diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..58bc6a02aa984230dbe6bcc65ae4876e34548b79 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/demo/rgb.png filter=lfs diff=lfs merge=lfs -text +third_party/sam3/sam3/model/__pycache__/video_tracking_multiplex.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +third_party/sam3/sam3/perflib/tests/assets/masks.tiff filter=lfs diff=lfs merge=lfs -text diff --git a/__pycache__/vis3d_glb.cpython-311.pyc b/__pycache__/vis3d_glb.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4f165db8a0640795183d7282601569143536571 Binary files /dev/null and b/__pycache__/vis3d_glb.cpython-311.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a7eb793159270937fce46f1db6a9b07c3ea9c0e0 --- /dev/null +++ b/app.py @@ -0,0 +1,822 @@ +"""Gradio Web Demo for WildDet3D (5-mode). + +Supports 5 prompt modes: +- Text: Enter text like "chair.table" (one-to-many) +- Visual: Click box on image, text="visual" (one-to-many) +- Visual+Label: Click box + category label (one-to-many) +- Geometry: Click box on image, text="geometric" (one-to-one) +- Geometry+Label: Click box + category label (one-to-one) +- Point: Click on image to select point + +Requirements: + pip install gradio>=5.0.0 + +Usage: + python demo/huggingface/app.py + +Then open http://localhost:7860 in browser. +""" + +import os +import sys +from pathlib import Path + +# Add paths: support both local dev and HuggingFace Space. +# Local dev: demo/huggingface/app.py -> repo root = ../../ +# HF Space: wilddet3d/ is bundled in the same directory as app.py +_this_dir = Path(__file__).resolve().parent +if (_this_dir / "wilddet3d").exists(): + # HuggingFace Space: everything bundled next to app.py + sys.path.insert(0, str(_this_dir)) +else: + # Local dev: repo root is two levels up + repo_root = _this_dir.parent.parent + sys.path.insert(0, str(repo_root)) + +import spaces +import gradio as gr +import numpy as np +import torch +import cv2 +from PIL import Image + +from wilddet3d.inference import build_model, WildDet3DPredictor +from wilddet3d.preprocessing import preprocess +from wilddet3d.vis.visualize import draw_3d_boxes +from vis3d_glb import ( + depth_to_pointcloud, create_scene_glb, create_mesh_scene_glb, +) + + +def draw_points_on_image(image, points, color=(0, 255, 0), radius=8): + """Draw points on image. + + Args: + image: numpy array (H, W, 3) + points: list of (x, y, label) tuples + color: color for positive points (green default) + radius: point radius + + Returns: + Image with points drawn + """ + img = image.copy() + for x, y, label in points: + c = color if label == 1 else (255, 0, 0) + cv2.circle(img, (int(x), int(y)), radius, c, -1) + cv2.circle(img, (int(x), int(y)), radius + 2, (255, 255, 255), 2) + return img + + +def draw_box_on_image(image, box, color=(0, 0, 255), thickness=3): + """Draw box on image. + + Args: + image: numpy array (H, W, 3) + box: [x1, y1, x2, y2] coordinates + color: box color (red default) + thickness: line thickness + + Returns: + Image with box drawn + """ + img = image.copy() + x1, y1, x2, y2 = [int(v) for v in box] + cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness) + return img + + +# HuggingFace Model repo for checkpoints +HF_MODEL_REPO = "weikaih/WildDet3D" +HF_CKPT_NAME = "wilddet3d.pt" + +# Local checkpoint paths (tried in order) +LOCAL_CHECKPOINTS = [ + "ckpt/wilddet3d.pt", # release repo layout +] + +# Default demo image path +DEFAULT_IMAGE_PATH = "assets/demo/rgb.png" +DEFAULT_INTRINSICS_PATH = "assets/demo/intrinsics.npy" + +# Global model (loaded once) +_cached_model = None + + +def _resolve_checkpoint(): + """Resolve checkpoint: local if exists, else download from HF Hub.""" + for path in LOCAL_CHECKPOINTS: + if os.path.exists(path): + return path + from huggingface_hub import hf_hub_download + hf_token = os.environ.get("HF_TOKEN") + print(f"Downloading checkpoint from {HF_MODEL_REPO}...") + ckpt = hf_hub_download( + repo_id=HF_MODEL_REPO, filename=HF_CKPT_NAME, token=hf_token + ) + return ckpt + + +def get_model(): + """Load model once and cache it.""" + global _cached_model + if _cached_model is None: + ckpt_path = _resolve_checkpoint() + print(f"Loading WildDet3D model from {ckpt_path}...") + _cached_model = build_model( + checkpoint=ckpt_path, + score_threshold=0.0, + canonical_rotation=True, + skip_pretrained=True, + ) + print("Model loaded!") + return _cached_model + + +def load_default_image(): + """Load the default demo image.""" + if os.path.exists(DEFAULT_IMAGE_PATH): + return np.array(Image.open(DEFAULT_IMAGE_PATH)) + return None + + +def load_default_intrinsics(): + """Load default intrinsics values.""" + if os.path.exists(DEFAULT_INTRINSICS_PATH): + intrinsics = np.load(DEFAULT_INTRINSICS_PATH) + return ( + float(intrinsics[0, 0]), + float(intrinsics[1, 1]), + float(intrinsics[0, 2]), + float(intrinsics[1, 2]), + ) + return 518.86, 519.47, 325.58, 253.74 + + +def format_intrinsics(K): + """Format intrinsics tensor for display.""" + if K is None: + return "Not available" + if isinstance(K, torch.Tensor): + K = K.cpu().numpy() + if K.ndim == 3: + K = K[0] + return ( + f"fx={K[0, 0]:.2f}, fy={K[1, 1]:.2f}, " + f"cx={K[0, 2]:.2f}, cy={K[1, 2]:.2f}" + ) + + +def scale_intrinsics_to_original(K, input_hw, original_hw): + """Scale intrinsics from model input resolution to original.""" + if K is None: + return None + + if isinstance(K, torch.Tensor): + K = K.clone() + else: + K = K.copy() + + input_h, input_w = input_hw + orig_h, orig_w = original_hw + + scale_x = orig_w / input_w + scale_y = orig_h / input_h + + if K.ndim == 3: + K[:, 0, 0] *= scale_x + K[:, 1, 1] *= scale_y + K[:, 0, 2] *= scale_x + K[:, 1, 2] *= scale_y + else: + K[0, 0] *= scale_x + K[1, 1] *= scale_y + K[0, 2] *= scale_x + K[1, 2] *= scale_y + + return K + + +def transform_coords_to_input_space(x, y, original_hw, input_hw, padding): + """Transform coords from original image space to preprocessed input. + + Args: + x, y: Coordinates in original image space + original_hw: (H, W) of original image + input_hw: (H, W) of preprocessed image (e.g., 1008x1008) + padding: (pad_left, pad_right, pad_top, pad_bottom) + + Returns: + (new_x, new_y) in preprocessed input space + """ + orig_h, orig_w = original_hw + pad_left, pad_right, pad_top, pad_bottom = padding + + content_w = input_hw[1] - pad_left - pad_right + content_h = input_hw[0] - pad_top - pad_bottom + + scale_x = content_w / orig_w + scale_y = content_h / orig_h + + new_x = x * scale_x + pad_left + new_y = y * scale_y + pad_top + + return new_x, new_y + + +def on_image_select( + evt: gr.SelectData, image, original_image, state, + prompt_mode, point_label, +): + """Handle click on image and visualize the click.""" + if image is None: + return state, "Please upload an image first", None + + x, y = evt.index[0], evt.index[1] + label = 1 if "Positive" in point_label else 0 + + new_state = { + "points": list(state.get("points", [])), + "box": list(state.get("box", [])), + } + + vis_image = ( + original_image.copy() + if original_image is not None + else image.copy() + ) + + if prompt_mode == "Point": + new_state["points"].append((x, y, label)) + new_state["box"] = [] + label_str = "+" if label == 1 else "-" + info = ( + f"Points: {len(new_state['points'])} total. " + f"Last: ({x}, {y}) [{label_str}]" + ) + vis_image = draw_points_on_image(vis_image, new_state["points"]) + + elif prompt_mode in ("Box-to-Multi-Object", "Box-to-Single-Object"): + new_state["points"] = [] + box_clicks = list(new_state.get("box", [])) + box_clicks.append((x, y)) + + if len(box_clicks) == 1: + new_state["box"] = box_clicks + info = ( + f"[{prompt_mode}] Corner 1: ({x}, {y}) " + f"- click again for corner 2" + ) + vis_image = draw_points_on_image(vis_image, [(x, y, 1)]) + + elif len(box_clicks) >= 2: + x1, y1 = box_clicks[0] + x2, y2 = box_clicks[1] + box = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] + new_state["box"] = [(box[0], box[1]), (box[2], box[3])] + info = ( + f"[{prompt_mode}] Box: " + f"({box[0]}, {box[1]}) -> ({box[2]}, {box[3]})" + ) + vis_image = draw_box_on_image(vis_image, box) + else: + info = f"Box clicks: {box_clicks}" + else: + info = "Text mode - just enter text and click Run" + + return new_state, info, vis_image + + +def clear_clicks(state, original_image): + """Reset click state and restore original image.""" + new_state = {"points": [], "box": []} + return ( + new_state, + "Cleared - ready for new clicks", + original_image.copy() if original_image is not None else None, + ) + + +@spaces.GPU +def run_wilddet3d( + image, + state, + prompt_mode, + text_prompt, + use_label, + label_text, + score_thres, + use_predicted_K, + fx, fy, cx, cy, + enable_3d_vis=True, + remove_edges=True, + point_density=2, + use_textured_mesh=True, +): + """Run WildDet3D with selected prompt mode.""" + if image is None: + return None, "Please upload an image first", None, None + + # Convert RGBA to RGB if needed + if image.ndim == 3 and image.shape[2] == 4: + image = image[:, :, :3] + + device = "cuda" if torch.cuda.is_available() else "cpu" + detector = get_model() + + # Build intrinsics matrix (or None if using predicted) + if use_predicted_K: + intrinsics = None + else: + intrinsics = np.array([ + [fx, 0, cx], + [0, fy, cy], + [0, 0, 1] + ], dtype=np.float32) + + # Preprocess image + data = preprocess(image.astype(np.float32), intrinsics) + + # Build prompt_text for box/point modes + if prompt_mode == "Box-to-Multi-Object": + prefix = "visual" + elif prompt_mode == "Box-to-Single-Object": + prefix = "geometric" + else: + prefix = "geometric" # Point mode default + + if prompt_mode != "Text": + if use_label and label_text and label_text.strip(): + geo_prompt_text = f"{prefix}: {label_text.strip()}" + else: + geo_prompt_text = prefix + + # Initialize prompt info for visualization + prompt_points = None + prompt_box = None + + # Run based on prompt mode + if prompt_mode == "Text": + input_texts = [ + t.strip() for t in text_prompt.split(".") if t.strip() + ] + if not input_texts: + input_texts = ["object"] + + results = detector( + images=data["images"].to(device), + intrinsics=data["intrinsics"].to(device)[None], + input_hw=[data["input_hw"]], + original_hw=[data["original_hw"]], + padding=[data["padding"]], + input_texts=input_texts, + return_predicted_intrinsics=True, + ) + ( + boxes, boxes3d, scores, scores_2d, scores_3d, + class_ids, depth_maps, predicted_K, + ) = results + class_id_mapping = {i: t for i, t in enumerate(input_texts)} + + elif prompt_mode in ("Box-to-Multi-Object", "Box-to-Single-Object"): + box_coords = state.get("box", []) + if len(box_coords) < 2: + return ( + None, + "Please click twice on the image to define a box", + None, + None, + ) + + x1_orig, y1_orig = box_coords[0] + x2_orig, y2_orig = box_coords[1] + x1, y1 = transform_coords_to_input_space( + x1_orig, y1_orig, + data["original_hw"], data["input_hw"], data["padding"], + ) + x2, y2 = transform_coords_to_input_space( + x2_orig, y2_orig, + data["original_hw"], data["input_hw"], data["padding"], + ) + box_xyxy = [float(x1), float(y1), float(x2), float(y2)] + + prompt_box = [x1_orig, y1_orig, x2_orig, y2_orig] + + results = detector( + images=data["images"].to(device), + intrinsics=data["intrinsics"].to(device)[None], + input_hw=[data["input_hw"]], + original_hw=[data["original_hw"]], + padding=[data["padding"]], + input_boxes=[box_xyxy], + prompt_text=geo_prompt_text, + return_predicted_intrinsics=True, + ) + ( + boxes, boxes3d, scores, scores_2d, scores_3d, + class_ids, depth_maps, predicted_K, + ) = results + class_id_mapping = {0: geo_prompt_text} + + elif prompt_mode == "Point": + points = state.get("points", []) + if not points: + return ( + None, + "Please click on the image to select a point", + None, + None, + ) + + transformed_points = [] + for x_orig, y_orig, lbl in points: + x, y = transform_coords_to_input_space( + x_orig, y_orig, + data["original_hw"], data["input_hw"], data["padding"], + ) + transformed_points.append((x, y, lbl)) + + prompt_points = points + + results = detector( + images=data["images"].to(device), + intrinsics=data["intrinsics"].to(device)[None], + input_hw=[data["input_hw"]], + original_hw=[data["original_hw"]], + padding=[data["padding"]], + input_points=[transformed_points], + prompt_text=geo_prompt_text, + return_predicted_intrinsics=True, + ) + ( + boxes, boxes3d, scores, scores_2d, scores_3d, + class_ids, depth_maps, predicted_K, + ) = results + class_id_mapping = {0: geo_prompt_text} + + else: + return None, f"Unknown prompt mode: {prompt_mode}", None, None + + # Scale predicted intrinsics to original resolution + predicted_K_scaled = scale_intrinsics_to_original( + predicted_K, + input_hw=data["input_hw"], + original_hw=data["original_hw"], + ) + + # Format intrinsics info + orig_h, orig_w = data["original_hw"] + intrinsics_info = f"Image: {orig_w}x{orig_h}\n" + intrinsics_info += f"Predicted: {format_intrinsics(predicted_K_scaled)}" + if not use_predicted_K: + intrinsics_info = f"Image: {orig_w}x{orig_h}\n" + intrinsics_info += ( + f"Used: fx={fx:.2f}, fy={fy:.2f}, " + f"cx={cx:.2f}, cy={cy:.2f}\n" + ) + intrinsics_info += ( + f"Predicted: {format_intrinsics(predicted_K_scaled)}" + ) + + # 2D visualization + img_2d = visualize_results( + data, boxes3d, scores, scores_2d, scores_3d, + class_ids, class_id_mapping, score_thres, + prompt_points=prompt_points, prompt_box=prompt_box, + ) + + # Depth map visualization + depth_vis_img = None + if depth_maps is not None and len(depth_maps) > 0: + depth_np_raw = depth_maps[0].cpu().numpy() + d = depth_np_raw.squeeze() + + pad_l, pad_r, pad_t, pad_b = data["padding"] + h_end = d.shape[0] - pad_b if pad_b > 0 else d.shape[0] + w_end = d.shape[1] - pad_r if pad_r > 0 else d.shape[1] + d_crop = d[pad_t:h_end, pad_l:w_end] + + d_valid = d_crop[d_crop > 0.01] + if len(d_valid) > 0: + d_min, d_max = d_valid.min(), d_valid.max() + d_norm = np.clip( + (d_crop - d_min) / (d_max - d_min + 1e-6), 0, 1 + ) + d_norm = (1.0 - d_norm) * 255 + d_norm = d_norm.astype(np.uint8) + depth_vis_img = cv2.applyColorMap(d_norm, cv2.COLORMAP_TURBO) + depth_vis_img = cv2.cvtColor(depth_vis_img, cv2.COLOR_BGR2RGB) + depth_vis_img = Image.fromarray(depth_vis_img) + + # 3D visualization (optional) + glb_path = None + if enable_3d_vis and depth_maps is not None and len(depth_maps) > 0: + depth_np = depth_maps[0].cpu().numpy() + + input_img = data["images"].cpu() + mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) + std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + input_img = (input_img * std + mean).clamp(0, 1) * 255 + input_img = ( + input_img.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) + ) + + K_for_unproj = data["intrinsics"].cpu().numpy() + + filtered_boxes3d_np = [] + for i in range(len(boxes3d)): + mask = scores[i] >= score_thres + filtered_boxes3d_np.append(boxes3d[i][mask].cpu().numpy()) + + glb_path = "/tmp/wilddet3d_scene.glb" + + if use_textured_mesh: + create_mesh_scene_glb( + depth_np, input_img, K_for_unproj, + filtered_boxes3d_np, glb_path, + max_depth=20.0, + padding=data["padding"], + remove_edge=remove_edges, + edge_rtol=0.04, + ) + else: + subsample = max(1, int(point_density)) + points, point_colors = depth_to_pointcloud( + depth_np, input_img, K_for_unproj, + max_depth=20.0, subsample=subsample, + padding=data["padding"], + remove_edge=remove_edges, + edge_rtol=0.04, + ) + create_scene_glb( + points, point_colors, filtered_boxes3d_np, glb_path + ) + + return img_2d, intrinsics_info, glb_path, depth_vis_img + + +def visualize_results( + data, boxes3d, scores, scores_2d, scores_3d, class_ids, + class_id_mapping, score_thres, + prompt_points=None, prompt_box=None, +): + """Visualize 3D detection results using wilddet3d.vis.draw_3d_boxes.""" + filtered_boxes3d = [] + filtered_scores_2d = [] + filtered_scores_3d = [] + filtered_class_ids = [] + + for i in range(len(boxes3d)): + mask = scores[i] >= score_thres + filtered_boxes3d.append(boxes3d[i][mask]) + if scores_2d is not None: + filtered_scores_2d.append(scores_2d[i][mask]) + else: + filtered_scores_2d.append(torch.zeros_like(scores[i][mask])) + if scores_3d is not None: + filtered_scores_3d.append(scores_3d[i][mask]) + else: + filtered_scores_3d.append(torch.zeros_like(scores[i][mask])) + filtered_class_ids.append(class_ids[i][mask]) + + # Get original image and draw prompts on it + original_img = data["original_images"].cpu().numpy().astype(np.uint8) + + if prompt_points is not None and len(prompt_points) > 0: + original_img = draw_points_on_image(original_img, prompt_points) + + if prompt_box is not None and len(prompt_box) == 4: + original_img = draw_box_on_image(original_img, prompt_box) + + # Use wilddet3d's draw_3d_boxes for visualization + K = data["original_intrinsics"].cpu().numpy() + if K.ndim == 3: + K = K[0] + + class_names = [ + class_id_mapping.get(i, str(i)) + for i in range(max(len(class_id_mapping), 1)) + ] + + # Draw 3D boxes with 2D/3D score labels + if len(filtered_boxes3d) > 0 and len(filtered_boxes3d[0]) > 0: + pil_img = draw_3d_boxes( + image=original_img, + boxes3d=filtered_boxes3d[0], + intrinsics=K, + scores_2d=filtered_scores_2d[0], + scores_3d=filtered_scores_3d[0], + class_ids=filtered_class_ids[0], + class_names=class_names, + n_colors=max(len(class_id_mapping), 1), + ) + else: + pil_img = Image.fromarray(original_img) + + return pil_img + + +# Load default values +default_fx, default_fy, default_cx, default_cy = load_default_intrinsics() +default_image = load_default_image() + +# Build Gradio interface +with gr.Blocks(title="WildDet3D: 3D Detection") as demo: + gr.Markdown("# WildDet3D: Open-Vocabulary 3D Detection in the Wild") + gr.Markdown(""" + **How to use:** + - **Text**: Enter object names (e.g., "chair.table"), click Run + - **Box-to-Multi-Object**: Draw box -> detect ALL similar objects (one-to-many) + - **Box-to-Single-Object**: Draw box -> detect ONLY the boxed object (one-to-one) + - **Point**: Click on object, click Run + - **+ Label**: Check this to attach a category name (e.g., "chair") to box/point prompts + """) + + # State for click coordinates and original image + click_state = gr.State({"points": [], "box": []}) + original_image_state = gr.State( + default_image.copy() if default_image is not None else None + ) + + with gr.Row(): + # Left column: Input + with gr.Column(scale=1): + input_image = gr.Image( + label="Input Image (click for Box/Point mode)", + type="numpy", + value=default_image, + interactive=True, + sources=["upload", "clipboard"], + ) + + # Prompt settings + prompt_mode = gr.Radio( + choices=[ + "Text", + "Box-to-Multi-Object", + "Box-to-Single-Object", + "Point", + ], + value="Text", + label="Prompt Mode", + ) + text_prompt = gr.Textbox( + label="Text Prompt (e.g. 'chair.table')", + value="chair.table", + placeholder="Enter object names separated by '.'", + visible=True, + ) + use_label = gr.Checkbox( + label="+ Label (attach category name to box/point prompt)", + value=False, + visible=False, + ) + label_text = gr.Textbox( + label="Category Label (e.g. 'chair')", + value="", + placeholder="Category name for the selected object", + visible=False, + ) + + # Point label for Point mode + point_label = gr.Radio( + choices=["Positive (include)", "Negative (exclude)"], + value="Positive (include)", + label="Point Label (for Point mode)", + visible=False, + ) + + # Click info display + click_info = gr.Textbox( + label="Click Info", + value="Select mode and click on image", + interactive=False, + ) + + with gr.Row(): + clear_btn = gr.Button("Clear Clicks") + run_btn = gr.Button("Run Detection", variant="primary") + + # Intrinsics settings + use_predicted_K = gr.Checkbox( + label="Use Predicted Intrinsics", + value=True, + ) + with gr.Row(): + fx = gr.Number(label="fx", value=default_fx) + fy = gr.Number(label="fy", value=default_fy) + cx = gr.Number(label="cx", value=default_cx) + cy = gr.Number(label="cy", value=default_cy) + + score_thres = gr.Slider( + minimum=0, maximum=1, value=0.3, step=0.05, + label="Score Threshold", + ) + + # 3D visualization settings + gr.Markdown("### 3D Visualization Settings") + enable_3d_vis = gr.Checkbox( + label="Enable 3D Point Cloud / Mesh Visualization", + value=False, + ) + gr.Markdown( + "*Notice: the model takes the depth latent to generate " + "3D boxes, so the boxes and the point cloud might not " + "exactly match.*" + ) + use_textured_mesh = gr.Checkbox( + label="Textured Mesh (otherwise point cloud)", + value=True, + ) + remove_edges = gr.Checkbox( + label="Remove depth edges (cleaner geometry)", + value=True, + ) + point_density = gr.Slider( + minimum=1, maximum=8, value=2, step=1, + label="Point Subsample (point cloud mode only, 1=dense)", + ) + + # Right column: Output + with gr.Column(scale=1): + output_image = gr.Image( + label="2D Detection Results", type="pil" + ) + depth_image = gr.Image(label="Depth Map", type="pil") + output_3d = gr.Model3D( + label="3D View (Mesh/Point Cloud + Boxes)", + clear_color=(0.1, 0.1, 0.1, 1.0), + ) + intrinsics_info = gr.Textbox( + label="Intrinsics Info", interactive=False + ) + + # Toggle visibility based on prompt mode + def on_mode_change(mode): + is_text = mode == "Text" + is_point = mode == "Point" + return ( + gr.update(visible=is_text), # text_prompt + gr.update(visible=not is_text), # use_label + gr.update(visible=not is_text), # label_text + gr.update(visible=is_point), # point_label + ) + + prompt_mode.change( + on_mode_change, + inputs=[prompt_mode], + outputs=[text_prompt, use_label, label_text, point_label], + ) + + # Connect events + input_image.select( + on_image_select, + inputs=[ + input_image, original_image_state, click_state, + prompt_mode, point_label, + ], + outputs=[click_state, click_info, input_image], + ) + + clear_btn.click( + clear_clicks, + inputs=[click_state, original_image_state], + outputs=[click_state, click_info, input_image], + ) + + # When new image is uploaded, save it as original + def on_image_upload(image): + if image is None: + return None, {"points": [], "box": []}, "Upload an image" + return ( + image.copy(), + {"points": [], "box": []}, + "Image loaded - select mode and click", + ) + + input_image.upload( + on_image_upload, + inputs=[input_image], + outputs=[original_image_state, click_state, click_info], + ) + + run_btn.click( + run_wilddet3d, + inputs=[ + input_image, click_state, prompt_mode, text_prompt, + use_label, label_text, score_thres, use_predicted_K, + fx, fy, cx, cy, + enable_3d_vis, remove_edges, point_density, use_textured_mesh, + ], + outputs=[output_image, intrinsics_info, output_3d, depth_image], + ) + + +if __name__ == "__main__": + print("=" * 60) + print("WildDet3D Web Demo") + print("=" * 60) + print() + print("Starting server...") + port = int(os.environ.get("GRADIO_SERVER_PORT", 7860)) + demo.launch(share=False, server_name="0.0.0.0", server_port=port) diff --git a/assets/demo/intrinsics.npy b/assets/demo/intrinsics.npy new file mode 100644 index 0000000000000000000000000000000000000000..f8351ab05be97d6522ad36d7c8b6ede95e4e1282 --- /dev/null +++ b/assets/demo/intrinsics.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5e46d677b736c45d98fda89d2b4b6b8e88028f8c7a5e25df6c9c3e61f6c6fed +size 164 diff --git a/assets/demo/rgb.png b/assets/demo/rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..396add24e324de7a8102bfed191bb1b8069d9c68 --- /dev/null +++ b/assets/demo/rgb.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:377def0b77a5d11be17fdf3f48466a7dfcde7fff9fd10e1e2f68c57efb18736e +size 448668 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8fd9f15aefc5bd7c0e26e767bbe09d01982178bf --- /dev/null +++ b/requirements.txt @@ -0,0 +1,59 @@ +# Vis4D (same approach: install dependencies, not vis4d itself) +absl-py +appdirs +cloudpickle +cython +devtools +h5py +jsonargparse[signatures] +lightning +ml_collections==1.1.0 +numpy>=1.21.0,<2.0.0 +opencv-python +pandas +pillow +plyfile +pycocotools +pydantic>=2.0 +setuptools +tensorboard +termcolor +terminaltables +timm>=0.6.0 +torch>=2.0.0 +torchvision>=0.15.1 +tqdm +utm +wheel +scipy + +# Git utils +gitdb +GitPython + +# WildDet3D +einops +fvcore +nltk +transformers +fairscale +mmengine +decord + +# SAM3 dependencies +ftfy +regex +iopath +omegaconf +hydra-core +scikit-image +scikit-learn +open_clip_torch + +# 3D visualization +pygltflib +trimesh +utils3d + +# Depth estimation +huggingface_hub diff --git a/third_party/lingbot_depth/mdm/model/__init__.py b/third_party/lingbot_depth/mdm/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..99420f50e6c63ab275d94d6bb2d2c9fc39cdd3e6 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/__init__.py @@ -0,0 +1,15 @@ +import importlib +from typing import * + +if TYPE_CHECKING: + from .v2 import MDMModel as MDMModelV2 + +def import_model_class_by_version(version: str) -> Type[Union['MDMModelV2']]: + assert version in ['v2'], f'Unsupported model version: {version}' + + try: + module = importlib.import_module(f'.{version}', __package__) + except ModuleNotFoundError: + raise ValueError(f'Model version "{version}" not found.') + cls = getattr(module, 'MDMModel') + return cls diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/__init__.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/__init__.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/backbones.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..2f81215aaab11548425fee4f1b199048e164cec8 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/backbones.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + +def dinov2_vitl16(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + # kwargs.update({'img_size': 224, 'patch_size': 16, }) + return _make_dinov2_model(arch_name="vit_large", pretrained=False, weights=weights, **kwargs) + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/utils.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/__init__.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca939cbc945791c12b8c4e4088e0c0ecb7c0fef --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention +from .patch_embed_mlp import PatchEmbed as PatchEmbedMLP diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/attention.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f79d471fc099b1dcaa512dfdbdec8a9fc5908f --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/attention.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +import torch.nn.functional as F +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + # # Deprecated implementation, extremely slow + # def forward(self, x: Tensor, attn_bias=None) -> Tensor: + # B, N, C = x.shape + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + # q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + # attn = q @ k.transpose(-2, -1) + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = (attn @ v).transpose(1, 2).reshape(B, N, C) + # x = self.proj(x) + # x = self.proj_drop(x) + # return x + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) + + q, k, v = qkv.unbind(0) # (B, H, N, C // H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/block.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..de6faacca49fe7cd263ce12f5c9fcf46fc7e3770 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/block.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/dino_head.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/drop_path.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/layer_scale.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/mlp.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..26938ac088a04bd20ea4032f84dc7904efc202bc --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + +class PixelUnshuffle (nn.Module): + def __init__(self, downscale_factor): + super().__init__() + self.downscale_factor = downscale_factor + + def forward(self, input): + if input.numel() == 0: + # this is not in the original torch implementation + C,H,W = input.shape[-3:] + assert H and W and H % self.downscale_factor == W%self.downscale_factor == 0 + return input.view(*input.shape[:-3], C*self.downscale_factor**2, H//self.downscale_factor, W//self.downscale_factor) + else: + return F.pixel_unshuffle(input, self.downscale_factor) + +class Permute(nn.Module): + dims: tuple[int, ...] + def __init__(self, dims: tuple[int, ...]) -> None: + super().__init__() + self.dims = tuple(dims) + + def __repr__(self): + return f"Permute{self.dims}" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.permute(*self.dims) + +from itertools import repeat +import collections.abc +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + return parse +to_2tuple = _ntuple(2) + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks""" + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Sequential( + PixelUnshuffle(patch_size), + Permute((0,2,3,1)), + Mlp(in_chans * patch_size * patch_size, 4*embed_dim, embed_dim), + Permute((0,3,1,2)), + ) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/__init__.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c00e05581b29e31e936a55ee7791dbe2cf85f37 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/__init__.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +''' +Docstring for MDM.mdm.model.dinov2_rgbd.models_vlmae +======================================================= +This version is modified from the original DINOv2 to support the MIM(masked image modeling) of RGBD input. +(The original DINOv2 is available at https://github.com/facebookresearch/dinov2.) + +Core Changes: +1. We add the depth input into the original DINOv2 transformer encoder. + +2. We support the Variable Mask Ratio MAE for both RGB and Depth input. +''' + +import logging + +from . import vision_transformer as vits + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) + diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/mask_utils.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0491007de14d21da8e9e81e508d36717190de8bb --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/mask_utils.py @@ -0,0 +1,137 @@ +import torch +def depth_masking( + x, + patch_num_h, + patch_num_w, + depth_values, + depth_mask_threshold_ratio=None, + depth_mask_threshold_num=None, + valid_depth_range=(0.1, 10.0), +): + """ + Perform patch masking based on depth validity + + Args: + x: [B, N, D] input features (after patch embedding) + patch_num_h: int, height of the patch grid + patch_num_w: int, width of the patch grid + depth_values: [B, 1, H_img, W_img], raw depth map + depth_mask_threshold_ratio: float or list, valid depth ratio threshold (0-1) + depth_mask_threshold_num: int or list, valid depth pixel count threshold + valid_depth_range: tuple, valid depth range (min, max) + + Returns: + visible_list: list of [N_visible_i, D], visible patches for each sample + mask_info: dict, containing masking information + """ + B, N, D = x.shape + device = x.device + + assert N == patch_num_h * patch_num_w, \ + f"N={N} must equal patch_num_h * patch_num_w = {patch_num_h * patch_num_w}" + + # Compute depth invalid mask + depth_invalid_mask = _compute_depth_invalid_mask( + depth_values, + patch_num_h, + patch_num_w, + depth_mask_threshold_ratio, + depth_mask_threshold_num, + valid_depth_range + ) # [B, N], True indicates this patch is invalid + + # Process each sample separately + visible_list = [] + mask_info = { + 'visible_indices': [], + 'mask_indices': [], + 'num_visible': [], + } + + for i in range(B): + # Get valid patch indices + valid_mask = ~depth_invalid_mask[i] # [N] + visible_indices = torch.where(valid_mask)[0] + masked_indices = torch.where(depth_invalid_mask[i])[0] + + # Extract visible patches + visible = x[i, visible_indices] # [N_visible, D] + visible_list.append(visible) + + # Record information + mask_info['visible_indices'].append(visible_indices) + mask_info['mask_indices'].append(masked_indices) + mask_info['num_visible'].append(len(visible_indices)) + + return visible_list, mask_info + +def _compute_depth_invalid_mask( + depth_values, + H_patch, + W_patch, + threshold_ratio, + threshold_num, + valid_range +): + """ + Compute depth validity for each patch + + Args: + depth_values: [B, 1, H_img, W_img] raw depth map + H_patch, W_patch: patch grid dimensions + threshold_ratio: float or list, valid depth ratio threshold + threshold_num: int or list, valid depth pixel count threshold + valid_range: tuple, (min_depth, max_depth) + + Returns: + invalid_mask: [B, N] bool tensor, True indicates this patch is invalid + """ + B, _, H_img, W_img = depth_values.shape + N = H_patch * W_patch + device = depth_values.device + + min_depth, max_depth = valid_range + + # Calculate pixel size for each patch + patch_h = H_img // H_patch + patch_w = W_img // W_patch + + assert H_img % H_patch == 0 and W_img % W_patch == 0, \ + f"Image size ({H_img}, {W_img}) must be divisible by patch grid ({H_patch}, {W_patch})" + + # Reshape depth map into patches: [B, 1, H_img, W_img] -> [B, H_patch, patch_h, W_patch, patch_w] + depth_reshaped = depth_values.view(B, 1, H_patch, patch_h, W_patch, patch_w) + + # Transpose and flatten: [B, H_patch, W_patch, patch_h, patch_w] -> [B, N, patch_h*patch_w] + depth_reshaped = depth_reshaped.permute(0, 2, 4, 1, 3, 5).reshape(B, N, -1) + + # Calculate valid depth + valid_depth = (depth_reshaped >= min_depth) & (depth_reshaped <= max_depth) + valid_depth_ratio = valid_depth.float().mean(dim=-1) # [B, N] + valid_depth_num = valid_depth.float().sum(dim=-1) # [B, N] + + # Handle list-form thresholds (different thresholds for each sample in batch) + if isinstance(threshold_ratio, list) or isinstance(threshold_num, list): + invalid_mask = torch.zeros(B, N, dtype=torch.bool, device=device) + + for i in range(B): + tr = threshold_ratio[i] if isinstance(threshold_ratio, list) else threshold_ratio + tn = threshold_num[i] if isinstance(threshold_num, list) else threshold_num + + sample_mask = torch.zeros(N, dtype=torch.bool, device=device) + if tr is not None: + sample_mask |= (valid_depth_ratio[i] < tr) + if tn is not None: + sample_mask |= (valid_depth_num[i] < tn) + + invalid_mask[i] = sample_mask + else: + # Uniform threshold + invalid_mask = torch.zeros(B, N, dtype=torch.bool, device=device) + + if threshold_ratio is not None: + invalid_mask |= (valid_depth_ratio < threshold_ratio) + if threshold_num is not None: + invalid_mask |= (valid_depth_num < threshold_num) + + return invalid_mask \ No newline at end of file diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/vision_transformer.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..07900ce480a1407f08e0a212dfa264cf88c59f8a --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/models/vision_transformer.py @@ -0,0 +1,479 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable, Optional, List + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block +from ..layers import PatchEmbedMLP + +from .mask_utils import depth_masking + +logger = logging.getLogger("dinov2_rgbd") + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + img_depth_fuse_mode='', + depth_mask_ratio:Union[float, List[float]]=0.6, + img_mask_ratio:Union[float, List[float]]=0.0, + depth_mask_patch_grid_size: int=1, + img_mask_patch_grid_size: int=1, + depth_emb_mode='', + # depth_emb_mode='conv_1c' + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.depth_emb_mode = depth_emb_mode + if self.depth_emb_mode == 'conv_1c': + self.depth_patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=1, embed_dim=embed_dim) + else: + self.depth_patch_embed = None + + self.img_depth_fuse_mode = img_depth_fuse_mode + + self.depth_mask_patch_grid_size = depth_mask_patch_grid_size + self.img_mask_patch_grid_size = img_mask_patch_grid_size + assert self.depth_mask_patch_grid_size == 1, "depth_mask_patch_grid_size must be 1 in current version" + assert self.img_mask_patch_grid_size == 1, "img_mask_patch_grid_size must be 1 in current version" + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + @property + def onnx_compatible_mode(self): + return getattr(self, "_onnx_compatible_mode", False) + + @onnx_compatible_mode.setter + def onnx_compatible_mode(self, value: bool): + self._onnx_compatible_mode = value + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, h, w): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + batch_size = x.shape[0] + N = self.pos_embed.shape[1] - 1 + if not self.onnx_compatible_mode and npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0, :] + patch_pos_embed = pos_embed[:, 1:, :] + dim = x.shape[-1] + h0, w0 = h // self.patch_size, w // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if not self.onnx_compatible_mode and self.interpolate_offset > 0: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sy, sx) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (h0, w0) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + + assert (h0, w0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2) + return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype) + + def interpolate_pos_encoding_without_cls(self, x, h, w, input_pos_embed): + previous_dtype = x.dtype + npatch = x.shape[1] + batch_size = x.shape[0] + N = input_pos_embed.shape[1] + if not self.onnx_compatible_mode and npatch == N and w == h: + return input_pos_embed + patch_pos_embed = input_pos_embed.float() + dim = x.shape[-1] + h0, w0 = h // self.patch_size, w // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if not self.onnx_compatible_mode and self.interpolate_offset > 0: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sy, sx) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (h0, w0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (h0, w0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2) + return patch_pos_embed.to(previous_dtype) + + def prepare_tokens_with_masks(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, masks=None, **kwargs): + assert masks is None, "extra masks are not supported for this model." + B, nc, h_img, w_img = x_img.shape + _, _, h_depth, w_depth = x_depth.shape + x_depth_raw = x_depth.clone() + x_depth_raw[x_depth_raw == 0] = -10 + + depth_patch_num_h, depth_patch_num_w = h_depth // self.patch_size, w_depth // self.patch_size + + # patchify, embed image tokens and depth tokens + x_img = self.patch_embed(x_img) # batch, length_img, dim + assert self.depth_patch_embed is not None + x_depth = self.depth_patch_embed(x_depth) # batch, length_depth, dim + assert depth_patch_num_h * depth_patch_num_w == x_depth.shape[1] + + # get full pose enc of img and depth + # 1-> img data type enc + # 2-> depth data type enc + img_pose_enc = 1 + self.interpolate_pos_encoding_without_cls(x_img, h_img, w_img, self.pos_embed[:, 1:]).repeat(B, 1, 1) + depth_pose_enc = 2 + self.interpolate_pos_encoding_without_cls(x_depth, h_depth, w_depth, self.pos_embed[:, 1:]).repeat(B, 1, 1) + + # add pose enc to img and depth + x_img = x_img + img_pose_enc + x_depth = x_depth + depth_pose_enc + + ## mask depth tokens + if kwargs.get('enable_depth_mask', True): + x_depth_masked, depth_mask_info = depth_masking( + x_depth, + depth_patch_num_h, + depth_patch_num_w, + depth_values=x_depth_raw, + depth_mask_threshold_num=[1]*B, + valid_depth_range=(-9.5, 200.0) + ) + else: + x_depth_masked = x_depth + depth_mask_info = None + + ## mask image tokens + x_img_masked = x_img + img_mask_info = None + + # get cls token + x_cls = self.cls_token.squeeze(0) + self.pos_embed.squeeze(0)[:1] # 1, dim + + # cat cls, img and depth tokens + assert self.img_depth_fuse_mode == 'cat_token', "Only cat_token mode is supported for this model." + x_masked_list = [] + for i in range(B): + if self.register_tokens is not None: + x_mased = torch.cat([x_cls, self.register_tokens.squeeze(0), x_img_masked[i], x_depth_masked[i]], dim=0) # 1 + num_register_tokens + length_img + length_depth, dim + else: + x_mased = torch.cat([x_cls, x_img_masked[i], x_depth_masked[i]], dim=0) # 1 + length_img + length_depth, dim + x_mased = x_mased.unsqueeze(0) # 1, 1 + num_register_tokens + length_img + length_depth, dim + x_masked_list.append(x_mased) + + return x_masked_list + + def _get_intermediate_layers_not_chunked(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, n=1, return_mae_aux=False, **kwargs): + x = self.prepare_tokens_with_masks(x_img, x_depth, x_img_mask, x_depth_mask, **kwargs) + + if not kwargs.get('enable_depth_mask', True): + x = torch.cat(x, dim=0) + + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + + if not kwargs.get('enable_depth_mask', True): + output = [list(torch.split(out, 1, dim=0)) for out in output] + return output + + def _get_intermediate_layers_chunked(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, n=1, return_mae_aux=False, **kwargs): + x = self.prepare_tokens_with_masks(x_img, x_depth, x_img_mask, x_depth_mask, **kwargs) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + + return output + + def extract_features(self, outputs, norm=True): + feat_outputs = [] + class_tokens = [] + feat_start_idx = 1 + self.num_register_tokens + + def process_output(out): + normed = self.norm(out) if norm else out + return normed[:, feat_start_idx:], normed[:, 0] + + for output in outputs: + if isinstance(output, list): + feats, tokens = zip(*[process_output(out) for out in output]) + feat_outputs.append(list(feats)) + class_tokens.append(list(tokens)) + else: + feat, token = process_output(output) + feat_outputs.append(feat) + class_tokens.append(token) + + return feat_outputs, class_tokens + + def get_intermediate_layers_mae( + self, + x_img: torch.Tensor, + x_depth: torch.Tensor, + x_img_mask: torch.Tensor=None, + x_depth_mask: torch.Tensor=None, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + return_mae_aux=True, + **kwargs + ): + assert reshape is False, "reshape is not supported for now" + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x_img, x_depth, x_img_mask, x_depth_mask, n, return_mae_aux=return_mae_aux,**kwargs) + else: + outputs = self._get_intermediate_layers_not_chunked(x_img, x_depth, x_img_mask, x_depth_mask, n, return_mae_aux=return_mae_aux,**kwargs) + + feat_outputs, class_tokens = self.extract_features(outputs, norm) + + if return_class_token: + return tuple(zip(feat_outputs, class_tokens)) + return tuple(feat_outputs) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/__init__.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/cluster.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/config.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/dtype.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/param_groups.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/utils.py b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/dinov2_rgbd/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/third_party/lingbot_depth/mdm/model/modules_decoder.py b/third_party/lingbot_depth/mdm/model/modules_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2dbbc46e9754f5d6946000e3a6926d31ea553570 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/modules_decoder.py @@ -0,0 +1,185 @@ +from typing import * +from numbers import Number +import importlib +import itertools +import functools +import sys + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F + +from .utils import wrap_module_with_gradient_checkpointing + + +class ResidualConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + hidden_channels: int = None, + kernel_size: int = 3, + padding_mode: str = 'replicate', + activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', + in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm', + hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm', + ): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation =='relu': + activation_cls = nn.ReLU + elif activation == 'leaky_relu': + activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2) + elif activation =='silu': + activation_cls = nn.SiLU + elif activation == 'elu': + activation_cls = nn.ELU + else: + raise ValueError(f'Unsupported activation function: {activation}') + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \ + nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \ + nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \ + nn.Identity(), + activation_cls(), + nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode), + nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \ + nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' else \ + nn.InstanceNorm2d(hidden_channels) if hidden_norm == 'instance_norm' else\ + nn.Identity(), + activation_cls(), + nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode) + ) + + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity() + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class Resampler(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], + scale_factor: int = 2, + ): + if type_ == 'pixel_shuffle': + nn.Sequential.__init__(self, + nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + nn.PixelShuffle(scale_factor), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + for i in range(1, scale_factor ** 2): + self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2] + self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2] + elif type_ in ['nearest', 'bilinear']: + nn.Sequential.__init__(self, + nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + elif type_ == 'conv_transpose': + nn.Sequential.__init__(self, + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1] + elif type_ == 'pixel_unshuffle': + nn.Sequential.__init__(self, + nn.PixelUnshuffle(scale_factor), + nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + elif type_ == 'avg_pool': + nn.Sequential.__init__(self, + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + elif type_ == 'max_pool': + nn.Sequential.__init__(self, + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + else: + raise ValueError(f'Unsupported resampler type: {type_}') + + +class MLP(nn.Sequential): + def __init__(self, dims: Sequence[int]): + nn.Sequential.__init__(self, + *itertools.chain(*[ + (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True)) + for dim_in, dim_out in zip(dims[:-2], dims[1:-1]) + ]), + nn.Linear(dims[-2], dims[-1]), + ) + + +class ConvStack(nn.Module): + def __init__(self, + dim_in: List[Optional[int]], + dim_res_blocks: List[int], + dim_out: List[Optional[int]], + resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm', + res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm', + activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', + ): + super().__init__() + self.input_blocks = nn.ModuleList([ + nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity() + for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks) + ]) + self.resamplers = nn.ModuleList([ + Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler) + for i, (dim_prev, dim_succ, resampler) in enumerate(zip( + dim_res_blocks[:-1], + dim_res_blocks[1:], + resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers) + )) + ]) + self.res_blocks = nn.ModuleList([ + nn.Sequential( + *( + ResidualConvBlock( + dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_, + activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm + ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks) + ) + ) for i, dim_res_block_ in enumerate(dim_res_blocks) + ]) + self.output_blocks = nn.ModuleList([ + nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity() + for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks) + ]) + + def enable_gradient_checkpointing(self): + for i in range(len(self.resamplers)): + self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i]) + for i in range(len(self.res_blocks)): + for j in range(len(self.res_blocks[i])): + self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j]) + + def forward(self, in_features: List[torch.Tensor]): + out_features = [] + for i in range(len(self.res_blocks)): + feature = self.input_blocks[i](in_features[i]) + if i == 0: + x = feature + elif feature is not None: + x = x + feature + x = self.res_blocks[i](x) + out_features.append(self.output_blocks[i](x)) + if i < len(self.res_blocks) - 1: + x = self.resamplers[i](x) + return out_features diff --git a/third_party/lingbot_depth/mdm/model/modules_rgbd_encoder.py b/third_party/lingbot_depth/mdm/model/modules_rgbd_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8712b3f6674c5ba5ff1e3a242ac65beeef4c53b8 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/modules_rgbd_encoder.py @@ -0,0 +1,152 @@ +from typing import * +from numbers import Number +import importlib +import itertools +import functools +import sys + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F + +from .dinov2_rgbd.models.vision_transformer import DinoVisionTransformer +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing + + +class DINOv2_RGBD_Encoder(nn.Module): + backbone: DinoVisionTransformer + image_mean: torch.Tensor + image_std: torch.Tensor + dim_features: int + + def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, ignore_layers: Union[str, List[str]]=[], in_chans: int=3, strict: bool=True, img_depth_fuse_mode='', depth_emb_mode='', depth_mask_ratio=0.6, img_mask_ratio=0.0, **deprecated_kwargs): + super(DINOv2_RGBD_Encoder, self).__init__() + + self.intermediate_layers = intermediate_layers + self.strict = strict + self.ignore_layers = ignore_layers + self.img_mask_ratio = img_mask_ratio + # Load the backbone + self.hub_loader = getattr(importlib.import_module(".dinov2_rgbd.hub.backbones", __package__), backbone) + self.backbone_name = backbone + self.backbone = self.hub_loader(pretrained=False, + in_chans=in_chans, + img_depth_fuse_mode=img_depth_fuse_mode, + depth_emb_mode=depth_emb_mode, + depth_mask_ratio=depth_mask_ratio, + img_mask_ratio=img_mask_ratio) + + self.dim_features = self.backbone.blocks[0].attn.qkv.in_features + self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers) + + if img_mask_ratio > 0: + self.mask_token_mae = nn.Parameter(torch.zeros(1, 1, self.dim_features)) + torch.nn.init.normal_(self.mask_token_mae, std=.02) + + self.output_projections = nn.ModuleList([ + nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,) + for _ in range(self.num_features) + ]) + + self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + @property + def onnx_compatible_mode(self): + return getattr(self, "_onnx_compatible_mode", False) + + @onnx_compatible_mode.setter + def onnx_compatible_mode(self, value: bool): + self._onnx_compatible_mode = value + self.backbone.onnx_compatible_mode = value + + def init_weights(self): + pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict() + ignore_layers = [] + if isinstance(self.ignore_layers, str): + ignore_layers = [self.ignore_layers] + else: + ignore_layers = self.ignore_layers + + if len(ignore_layers) == 0: + self.backbone.load_state_dict(pretrained_backbone_state_dict, strict=self.strict) + else: + state_dict = {} + for k, v in pretrained_backbone_state_dict.items(): + is_ignore = False + for ig_k in ignore_layers: + if ig_k in k: + is_ignore = True + break + if not is_ignore: + state_dict[k] = v + self.backbone.load_state_dict(state_dict, strict=self.strict) + + def enable_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def enable_pytorch_native_sdpa(self): + for i in range(len(self.backbone.blocks)): + wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) + + def forward(self, + image: torch.Tensor, + depth: torch.Tensor, + token_rows: Union[int, torch.LongTensor], + token_cols: Union[int, torch.LongTensor], + return_class_token: bool = False, + remap_depth_in: str='linear', + **kwargs): + image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode) + image_14 = (image_14 - self.image_mean) / self.image_std + + depth_14 = F.interpolate(depth, (token_rows * 14, token_cols * 14), mode="nearest") + + # set invalid depth value to zero + depth_14[torch.isinf(depth_14)] = 0.0 + depth_14[torch.isnan(depth_14)] = 0.0 + dmask_14 = (depth_14 > 0.01).detach() + depth_14 = depth_14 * dmask_14.float() + + if remap_depth_in == 'linear': + pass # do nothing + elif remap_depth_in == 'log': + depth_14 = torch.log(depth_14) + depth_14[~dmask_14] = 0.0 + depth_14 = torch.nan_to_num(depth_14, nan=0.0, posinf=0.0, neginf=0.0) + else: + raise NotImplementedError + + # Get intermediate layers from the backbone + features = self.backbone.get_intermediate_layers_mae( + x_img=image_14, + x_depth=depth_14, + n=self.intermediate_layers, + return_class_token=True, + **kwargs) + + assert self.img_mask_ratio == 0, "img_mask_ratio is not supported in this encoder" + + if isinstance(features[0][0], list): + num_valid_tokens = token_rows * token_cols + features = tuple( + ( + torch.cat([feat[:, :num_valid_tokens].contiguous() for feat in feats], dim=0), + torch.cat(cls_tokens, dim=0) + ) + for feats, cls_tokens in features + ) + + # Project features to the desired dimensionality + x = torch.stack([ + proj(feat.permute(0, 2, 1)[:, :, :token_rows*token_cols].unflatten(2, (token_rows, token_cols)).contiguous()) + for proj, (feat, clstoken) in zip(self.output_projections, features) + ], dim=1).sum(dim=1) + cls_token = features[-1][1] + + if return_class_token: + return x, cls_token, None, None + else: + return x, None, None diff --git a/third_party/lingbot_depth/mdm/model/utils.py b/third_party/lingbot_depth/mdm/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5aca85509a7957a7e29bc7dffee76c7950cf8e79 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/utils.py @@ -0,0 +1,127 @@ +from typing import * + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def wrap_module_with_gradient_checkpointing(module: nn.Module): + from torch.utils.checkpoint import checkpoint + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +def unwrap_module_with_gradient_checkpointing(module: nn.Module): + module.__class__ = module.__class__._restore_cls + + +def wrap_dinov2_attention_with_sdpa(module: nn.Module): + assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later" + class _AttentionWrapper(module.__class__): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + module.__class__ = _AttentionWrapper + return module + +def wrap_dinov3_attention_with_sdpa(module: nn.Module): + assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later" + class _AttentionWrapper(module.__class__): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + module.__class__ = _AttentionWrapper + return module + +def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]: + group_to_use = torch.distributed.group.WORLD + world_size = group_to_use.size() + grad = bucket.buffer() + grad.div_(world_size) + torch.distributed.all_reduce(grad, group=group_to_use) + fut = torch.futures.Future() + fut.set_result(grad) + return fut + +def depth_to_pointcloud(depth, intrinsic_normalized, depth_scale=1.0): + """ + Convert depth map to point cloud (pure Tensor version, no point filtering) + + Args: + depth: torch.Tensor, shape (H, W) or (B, H, W), depth map + intrinsic_normalized: torch.Tensor, shape (3, 3) or (B, 3, 3), normalized intrinsic matrix + Normalized intrinsics: fx' = fx/W, fy' = fy/H, cx' = cx/W, cy' = cy/H + depth_scale: float, depth scale factor, default 1000.0 + + Returns: + points: torch.Tensor, shape (H, W, 3) or (B, H, W, 3), point cloud coordinates (x, y, z) + """ + # Handle batch dimension + if depth.dim() == 2: + depth = depth.unsqueeze(0) # (1, H, W) + intrinsic_normalized = intrinsic_normalized.unsqueeze(0) # (1, 3, 3) + squeeze_output = True + else: + squeeze_output = False + + B, H, W = depth.shape + device = depth.device + + # Denormalize intrinsics + fx = intrinsic_normalized[:, 0, 0] * W # (B,) + fy = intrinsic_normalized[:, 1, 1] * H + cx = intrinsic_normalized[:, 0, 2] * W + cy = intrinsic_normalized[:, 1, 2] * H + + # Create pixel coordinate grid (H, W) + v, u = torch.meshgrid( + torch.arange(H, device=device, dtype=torch.float32), + torch.arange(W, device=device, dtype=torch.float32), + indexing='ij' + ) + + # Expand to batch dimension (B, H, W) + u = u.unsqueeze(0).expand(B, -1, -1) + v = v.unsqueeze(0).expand(B, -1, -1) + + # Backproject to 3D space + z = depth / depth_scale # (B, H, W) + + # Expand intrinsic dimensions for broadcasting (B, 1, 1) + fx = fx.view(B, 1, 1) + fy = fy.view(B, 1, 1) + cx = cx.view(B, 1, 1) + cy = cy.view(B, 1, 1) + + x = (u - cx) * z / fx # (B, H, W) + y = (v - cy) * z / fy # (B, H, W) + + # Stack coordinates (B, H, W, 3) + points = torch.stack([x, y, z], dim=-1) + + if squeeze_output: + points = points.squeeze(0) # (H, W, 3) + + return points \ No newline at end of file diff --git a/third_party/lingbot_depth/mdm/model/v2.py b/third_party/lingbot_depth/mdm/model/v2.py new file mode 100644 index 0000000000000000000000000000000000000000..b7582b703b9496be711b4cb6cea6af4c6c2d4b71 --- /dev/null +++ b/third_party/lingbot_depth/mdm/model/v2.py @@ -0,0 +1,297 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.amp +import torch.version +from huggingface_hub import hf_hub_download + +from .modules_rgbd_encoder import DINOv2_RGBD_Encoder +from .modules_decoder import MLP, ConvStack +from ..utils.geo import depth_to_pointcloud, normalized_view_plane_uv + + +class MDMModel(nn.Module): + encoder: Union[DINOv2_RGBD_Encoder] + neck: ConvStack + points_head: ConvStack + mask_head: ConvStack + scale_head: MLP + onnx_compatible_mode: bool + + def __init__(self, + encoder: Dict[str, Any], + neck: Dict[str, Any], + depth_head: Dict[str, Any] = None, + mask_head: Dict[str, Any] = None, + normal_head: Dict[str, Any] = None, + scale_head: Dict[str, Any] = None, + remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear', + remap_depth_in: Literal['linear', 'log'] = 'log', + remap_depth_out: Literal['linear', 'exp'] = 'exp', + num_tokens_range: List[int] = [1200, 3600], + **deprecated_kwargs + ): + super(MDMModel, self).__init__() + if deprecated_kwargs: + warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") + + self.remap_output = remap_output + self.num_tokens_range = num_tokens_range + self.remap_depth_in = remap_depth_in + self.remap_depth_out = remap_depth_out + + self.encoder = DINOv2_RGBD_Encoder(**encoder) + + self.neck = ConvStack(**neck) + if depth_head is not None: + self.depth_head = ConvStack(**depth_head) + if mask_head is not None: + self.mask_head = ConvStack(**mask_head) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, Path, IO[bytes]], + model_kwargs: Optional[Dict[str, Any]] = None, + **hf_kwargs) -> 'MDMModel': + if Path(pretrained_model_name_or_path).exists(): + checkpoint_path = pretrained_model_name_or_path + else: + checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs + ) + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) + + model_config = checkpoint['model_config'] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint['model'], strict=False) + + return model + + def init_weights(self): + self.encoder.init_weights() + + def enable_pytorch_native_sdpa(self): + self.encoder.enable_pytorch_native_sdpa() + + def forward(self, + image: torch.Tensor, + num_tokens: Union[int, torch.LongTensor], + depth: Union[None, torch.Tensor]=None, + **kwargs) -> Dict[str, torch.Tensor]: + batch_size, _, img_h, img_w = image.shape + device, dtype = image.device, image.dtype + + assert depth is not None # in this version, depth is required + if depth.dim() == 3: + depth = depth.unsqueeze(1) # from (B, H, W) to (B, 1, H, W) + + aspect_ratio = img_w / img_h + base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5 + if isinstance(base_h, torch.Tensor): + base_h, base_w = base_h.round().long(), base_w.round().long() + else: + base_h, base_w = round(base_h), round(base_w) + + # Backbones encoding + features, cls_token, _, _ = self.encoder(image, depth, base_h, base_w, return_class_token=True, remap_depth_in=self.remap_depth_in, **kwargs) + + features = features + cls_token[..., None, None] + features = [features, None, None, None, None] + + # Concat UVs for aspect ratio input + for level in range(5): + uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1) + if features[level] is None: + features[level] = uv + else: + features[level] = torch.concat([features[level], uv], dim=1) + + # Shared neck + features = self.neck(features) + + # Heads decoding + depth_reg, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['depth_head', 'normal_head', 'mask_head']) + metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None + + # Resize + depth_reg, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [depth_reg, normal, mask]) + + # Remap output + if depth_reg is not None: + if self.remap_depth_out == 'exp': + depth_reg = depth_reg.exp().squeeze(1) + elif self.remap_depth_out == 'linear': + depth_reg = depth_reg.squeeze(1) + else: + raise ValueError(f"Invalid remap_depth_out: {self.remap_depth_out}") + if normal is not None: + normal = normal.permute(0, 2, 3, 1) + normal = F.normalize(normal, dim=-1) + if mask is not None: + mask_prob = mask.squeeze(1).sigmoid() + # mask_logits = mask.squeeze(1) + else: + mask_prob = None + if metric_scale is not None: + metric_scale = metric_scale.squeeze(1).exp() + + return_dict = { + 'depth_reg': depth_reg, + 'normal': normal, + 'mask': mask_prob, + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + return return_dict + + @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + depth_in: torch.Tensor = None, + num_tokens: int = None, + resolution_level: int = 9, + apply_mask: bool = True, + use_fp16: bool = True, + intrinsics: Optional[torch.Tensor] = None, + **kwargs + ) -> Dict[str, torch.Tensor]: + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + if (depth_in is not None) and (depth_in.dim() == 2): + depth_in = depth_in.unsqueeze(0).to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + area = original_height * original_width + aspect_ratio = original_width / original_height + + # Determine the number of base tokens to use + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)) + + # Forward pass + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=use_fp16 and self.dtype != torch.bfloat16): + output = self.forward(image, num_tokens=num_tokens, depth=depth_in, **kwargs) + depth_reg, mask = (output.get(k, None) for k in ['depth_reg', 'mask']) + + # Always process the output in fp32 precision + depth_reg, mask = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [depth_reg, mask]) + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + if mask is not None: + mask_binary = mask > 0.5 + else: + mask_binary = None + + depth = depth_reg + if intrinsics is not None: + points = depth_to_pointcloud(depth, intrinsics) + else: + points = None + + # Apply mask + if apply_mask and mask_binary is not None: + points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None + depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None + + return_dict = { + 'points': points, + 'depth': depth, + 'mask': mask_binary, + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + if omit_batch_dim: + return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + return return_dict + + def forward_feat(self, + image: torch.Tensor, + num_tokens: Union[int, torch.LongTensor], + depth: Union[None, torch.Tensor]=None, + **kwargs) -> Dict[str, torch.Tensor]: + batch_size, _, img_h, img_w = image.shape + device, dtype = image.device, image.dtype + + assert depth is not None # in this version, depth is required + if depth.dim() == 3: + depth = depth.unsqueeze(1) # from (B, H, W) to (B, 1, H, W) + + aspect_ratio = img_w / img_h + base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5 + if isinstance(base_h, torch.Tensor): + base_h, base_w = base_h.round().long(), base_w.round().long() + else: + base_h, base_w = round(base_h), round(base_w) + + # Backbones encoding + features, cls_token, _, _ = self.encoder(image, depth, base_h, base_w, return_class_token=True, remap_depth_in=self.remap_depth_in, **kwargs) + + return features, cls_token + + + @torch.inference_mode() + def infer_feat( + self, + image: torch.Tensor, + depth_in: torch.Tensor = None, + num_tokens: int = None, + resolution_level: int = 9, + apply_mask: bool = True, + use_fp16: bool = True, + intrinsics: Optional[torch.Tensor] = None, + **kwargs + ): + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + if (depth_in is not None) and (depth_in.dim() == 2): + depth_in = depth_in.unsqueeze(0).to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + area = original_height * original_width + aspect_ratio = original_width / original_height + + # Determine the number of base tokens to use + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)) + + # Forward pass + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=use_fp16 and self.dtype != torch.bfloat16): + features, cls_token = self.forward_feat(image, num_tokens=num_tokens, depth=depth_in, **kwargs) + + return features, cls_token \ No newline at end of file diff --git a/third_party/lingbot_depth/mdm/utils/__init__.py b/third_party/lingbot_depth/mdm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/lingbot_depth/mdm/utils/geo.py b/third_party/lingbot_depth/mdm/utils/geo.py new file mode 100644 index 0000000000000000000000000000000000000000..13fc0bd86248dd521b3beb571ed2418356620875 --- /dev/null +++ b/third_party/lingbot_depth/mdm/utils/geo.py @@ -0,0 +1,105 @@ +import torch + +def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device) + v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + +def depth_to_pointcloud(depth, intrinsic_normalized, depth_scale=1.0): + """ + Convert depth map to point cloud (pure Tensor version, no point filtering) + + Args: + depth: torch.Tensor, shape (H, W) or (B, H, W), depth map + intrinsic_normalized: torch.Tensor, shape (3, 3) or (B, 3, 3), normalized intrinsic matrix + Normalized intrinsics: fx' = fx/W, fy' = fy/H, cx' = cx/W, cy' = cy/H + depth_scale: float, depth scale factor, default 1000.0 + + Returns: + points: torch.Tensor, shape (H, W, 3) or (B, H, W, 3), point cloud coordinates (x, y, z) + """ + # Handle batch dimension + if depth.dim() == 2: + depth = depth.unsqueeze(0) # (1, H, W) + intrinsic_normalized = intrinsic_normalized.unsqueeze(0) # (1, 3, 3) + squeeze_output = True + else: + squeeze_output = False + + B, H, W = depth.shape + device = depth.device + + # Denormalize intrinsics + fx = intrinsic_normalized[:, 0, 0] * W # (B,) + fy = intrinsic_normalized[:, 1, 1] * H + cx = intrinsic_normalized[:, 0, 2] * W + cy = intrinsic_normalized[:, 1, 2] * H + + # Create pixel coordinate grid (H, W) + v, u = torch.meshgrid( + torch.arange(H, device=device, dtype=torch.float32), + torch.arange(W, device=device, dtype=torch.float32), + indexing='ij' + ) + + # Expand to batch dimension (B, H, W) + u = u.unsqueeze(0).expand(B, -1, -1) + v = v.unsqueeze(0).expand(B, -1, -1) + + # Backproject to 3D space + z = depth / depth_scale # (B, H, W) + + # Expand intrinsic dimensions for broadcasting (B, 1, 1) + fx = fx.view(B, 1, 1) + fy = fy.view(B, 1, 1) + cx = cx.view(B, 1, 1) + cy = cy.view(B, 1, 1) + + x = (u - cx) * z / fx # (B, H, W) + y = (v - cy) * z / fy # (B, H, W) + + # Stack coordinates (B, H, W, 3) + points = torch.stack([x, y, z], dim=-1) + + if squeeze_output: + points = points.squeeze(0) # (H, W, 3) + + return points + + +# Usage example +if __name__ == "__main__": + # Single image + depth = torch.rand(480, 640) * 5000 # Depth values + intrinsic_norm = torch.tensor([ + [525.0/640, 0, 319.5/640], + [0, 525.0/480, 239.5/480], + [0, 0, 1] + ]) + + points = depth_to_pointcloud(depth, intrinsic_norm) + print(f"Point cloud shape: {points.shape}") # (480, 640, 3) + + # Batch processing + depth_batch = torch.rand(4, 480, 640) * 5000 + intrinsic_batch = intrinsic_norm.unsqueeze(0).expand(4, -1, -1) + + points_batch = depth_to_pointcloud(depth_batch, intrinsic_batch) + print(f"Batch point cloud shape: {points_batch.shape}") # (4, 480, 640, 3) + + # Flatten to (N, 3) format if needed + points_flat = points.reshape(-1, 3) + print(f"Flattened shape: {points_flat.shape}") # (480*640, 3) + + # Batch flatten to (B, N, 3) + points_batch_flat = points_batch.reshape(4, -1, 3) + print(f"Batch flattened shape: {points_batch_flat.shape}") # (4, 480*640, 3) \ No newline at end of file diff --git a/third_party/lingbot_depth/mdm/utils/io.py b/third_party/lingbot_depth/mdm/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..cf40b327a32ab36fa597d9b868b0a6aa6cfbad53 --- /dev/null +++ b/third_party/lingbot_depth/mdm/utils/io.py @@ -0,0 +1,270 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from typing import IO +import zipfile +import json +import io +from typing import * +from pathlib import Path +import re +from PIL import Image, PngImagePlugin + +import numpy as np +import cv2 + +from .tools import timeit + + +def save_glb( + save_path: Union[str, os.PathLike], + vertices: np.ndarray, + faces: np.ndarray, + vertex_uvs: np.ndarray, + texture: np.ndarray, + vertex_normals: Optional[np.ndarray] = None, +): + import trimesh + import trimesh.visual + from PIL import Image + + trimesh.Trimesh( + vertices=vertices, + vertex_normals=vertex_normals, + faces=faces, + visual = trimesh.visual.texture.TextureVisuals( + uv=vertex_uvs, + material=trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(texture), + metallicFactor=0.5, + roughnessFactor=1.0 + ) + ), + process=False + ).export(save_path) + + +def save_ply( + save_path: Union[str, os.PathLike], + vertices: np.ndarray, + faces: np.ndarray, + vertex_colors: np.ndarray, + vertex_normals: Optional[np.ndarray] = None, +): + import trimesh + import trimesh.visual + from PIL import Image + + trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_colors=vertex_colors, + vertex_normals=vertex_normals, + process=False + ).export(save_path) + + +def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray: + """ + Read a image, return uint8 RGB array of shape (H, W, 3). + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + return image + + +def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95): + """ + Write a image, input uint8 RGB array of shape (H, W, 3). + """ + data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes() + if isinstance(path, (str, os.PathLike)): + Path(path).write_bytes(data) + else: + path.write(data) + + +def read_depth(path: Union[str, os.PathLike, IO]) -> np.ndarray: + """ + Read a depth image, return float32 depth array of shape (H, W). + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + pil_image = Image.open(io.BytesIO(data)) + near = float(pil_image.info.get('near')) + far = float(pil_image.info.get('far')) + depth = np.array(pil_image) + mask_nan, mask_inf = depth == 0, depth == 65535 + depth = (depth.astype(np.float32) - 1) / 65533 + depth = near ** (1 - depth) * far ** depth + if 'unit' in pil_image.info: # Legacy support for depth units + unit = float(pil_image.info.get('unit')) + depth = depth * unit + depth[mask_nan] = np.nan + depth[mask_inf] = np.inf + return depth + +def write_depth( + path: Union[str, os.PathLike, IO], + depth: np.ndarray, + max_range: float = 1e5, + compression_level: int = 7, +): + """ + Encode and write a depth image as 16-bit PNG format. + ## Parameters: + - `path: Union[str, os.PathLike, IO]` + The file path or file object to write to. + - `depth: np.ndarray` + The depth array, float32 array of shape (H, W). + May contain `NaN` for invalid values and `Inf` for infinite values. + + Depth values are encoded as follows: + - 0: unknown + - 1 ~ 65534: depth values in logarithmic + - 65535: infinity + + metadata is stored in the PNG file as text fields: + - `near`: the minimum depth value + - `far`: the maximum depth value + """ + mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth),np.isinf(depth) + + depth = depth.astype(np.float32) + mask_finite = depth + near = max(depth[mask_values].min(), 1e-5) + far = max(near * 1.1, min(depth[mask_values].max(), near * max_range)) + depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534 + depth[mask_nan] = 0 + depth[mask_inf] = 65535 + + pil_image = Image.fromarray(depth) + pnginfo = PngImagePlugin.PngInfo() + pnginfo.add_text('near', str(near)) + pnginfo.add_text('far', str(far)) + pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level) + + +def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]: + """ + Read a segmentation mask + ### Parameters: + - `path: Union[str, os.PathLike, IO]` + The file path or file object to read from. + ### Returns: + - `Tuple[np.ndarray, Dict[str, int]]` + A tuple containing: + - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W). + - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}. + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + pil_image = Image.open(io.BytesIO(data)) + labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None + mask = np.array(pil_image) + return mask, labels + + +def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7): + """ + Write a segmentation mask and label mapping, as PNG format. + ### Parameters: + - `path: Union[str, os.PathLike, IO]` + The file path or file object to write to. + - `mask: np.ndarray` + The segmentation mask, uint8 or uint16 array of shape (H, W). + - `labels: Dict[str, int] = None` + The label mapping, a dictionary of {label_name: label_id}. + - `compression_level: int = 7` + The compression level for PNG compression. + """ + assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}" + pil_image = Image.fromarray(mask) + pnginfo = PngImagePlugin.PngInfo() + if labels is not None: + labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':')) + pnginfo.add_text('labels', labels_json) + pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level) + + + +def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray: + """ + Read a normal image, return float32 normal array of shape (H, W, 3). + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB) + mask_nan = np.all(normal == 0, axis=-1) + normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0] + normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12) + normal[mask_nan] = np.nan + return normal + + +def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray: + """ + Write a normal image, input float32 normal array of shape (H, W, 3). + """ + mask_nan = np.isnan(normal).any(axis=-1) + normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16) + normal[mask_nan] = 0 + data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes() + if isinstance(path, (str, os.PathLike)): + Path(path).write_bytes(data) + else: + path.write(data) + + +def read_mask(path: Union[str, os.PathLike, IO[bytes]]) -> np.ndarray: + """ + Read a binary mask, return bool array of shape (H, W). + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + mask = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED) + if len(mask.shape) == 3: + mask = mask[..., 0] + return mask > 0 + + +def write_mask(path: Union[str, os.PathLike, IO[bytes]], mask: np.ndarray, compression_level: int = 7): + """ + Write a binary mask, input bool array of shape (H, W). + """ + assert mask.dtype == bool, f"Mask must be bool array, got {mask.dtype}" + mask = (mask.astype(np.uint8) * 255).astype(np.uint8) + data = cv2.imencode('.png', mask, [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes() + if isinstance(path, (str, os.PathLike)): + Path(path).write_bytes(data) + else: + path.write(data) + + +JSON_TYPE = Union[str, int, float, bool, None, Dict[str, "JSON"], List["JSON"]] + + +def read_json(path: Union[str, os.PathLike, IO[str]]) -> JSON_TYPE: + if isinstance(path, (str, os.PathLike)): + text = Path(path).read_text() + else: + text = path.read() + return json.loads(text) + + +def write_json(path: Union[str, os.PathLike, IO[str]], content: JSON_TYPE): + text = json.dumps(content) + if isinstance(path, (str, os.PathLike)): + Path(path).write_text(text) + else: + path.write(text) \ No newline at end of file diff --git a/third_party/lingbot_depth/mdm/utils/tools.py b/third_party/lingbot_depth/mdm/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..3687f6938fe34433d149a1a8405be7eed5f23c37 --- /dev/null +++ b/third_party/lingbot_depth/mdm/utils/tools.py @@ -0,0 +1,289 @@ +from typing import * +import time +from pathlib import Path +from numbers import Number +from functools import wraps +import warnings +import math +import json +import os +import importlib +import importlib.util + + +def catch_exception(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + import traceback + print(f"Exception in {fn.__name__}", end='r') + # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())}) + traceback.print_exc(chain=False) + time.sleep(0.1) + return None + return wrapper + + +class CallbackOnException: + def __init__(self, callback: Callable, exception: type): + self.exception = exception + self.callback = callback + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if isinstance(exc_val, self.exception): + self.callback() + return True + return False + +def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]: + for k, v in d.items(): + if isinstance(v, dict): + for sub_key in traverse_nested_dict_keys(v): + yield (k, ) + sub_key + else: + yield (k, ) + + +def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None): + for k in keys: + d = d.get(k, default) + if d is None: + break + return d + +def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any): + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = value + + +def key_average(list_of_dicts: list) -> Dict[str, Any]: + """ + Returns a dictionary with the average value of each key in the input list of dictionaries. + """ + _nested_dict_keys = set() + for d in list_of_dicts: + _nested_dict_keys.update(traverse_nested_dict_keys(d)) + _nested_dict_keys = sorted(_nested_dict_keys) + result = {} + for k in _nested_dict_keys: + values = [] + for d in list_of_dicts: + v = get_nested_dict(d, k) + if v is not None and not math.isnan(v): + values.append(v) + avg = sum(values) / len(values) if values else float('nan') + set_nested_dict(result, k, avg) + return result + + +def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]: + """ + Flattens a nested dictionary into a single-level dictionary, with keys as tuples. + """ + items = [] + if parent_key is None: + parent_key = () + for k, v in d.items(): + new_key = parent_key + (k, ) + if isinstance(v, MutableMapping): + items.extend(flatten_nested_dict(v, new_key).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """ + Unflattens a single-level dictionary into a nested dictionary, with keys as tuples. + """ + result = {} + for k, v in d.items(): + sub_dict = result + for k_ in k[:-1]: + if k_ not in sub_dict: + sub_dict[k_] = {} + sub_dict = sub_dict[k_] + sub_dict[k[-1]] = v + return result + + +def read_jsonl(file): + import json + with open(file, 'r') as f: + data = f.readlines() + return [json.loads(line) for line in data] + + +def write_jsonl(data: List[dict], file): + import json + with open(file, 'w') as f: + for item in data: + f.write(json.dumps(item) + '\n') + + +def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]): + import pandas as pd + data = [flatten_nested_dict(d) for d in data] + df = pd.DataFrame(data) + df = df.sort_index(axis=1) + df.columns = pd.MultiIndex.from_tuples(df.columns) + return df + + +def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]): + if isinstance(d, str): + for old, new in mapping.items(): + d = d.replace(old, new) + elif isinstance(d, list): + for i, item in enumerate(d): + d[i] = recursive_replace(item, mapping) + elif isinstance(d, dict): + for k, v in d.items(): + d[k] = recursive_replace(v, mapping) + return d + + +class timeit: + _history: Dict[str, List['timeit']] = {} + + def __init__(self, name: str = None, verbose: bool = True, average: bool = False): + self.name = name + self.verbose = verbose + self.start = None + self.end = None + self.average = average + if average and name not in timeit._history: + timeit._history[name] = [] + + def __call__(self, func: Callable): + import inspect + if inspect.iscoroutinefunction(func): + async def wrapper(*args, **kwargs): + with timeit(self.name or func.__qualname__): + ret = await func(*args, **kwargs) + return ret + return wrapper + else: + def wrapper(*args, **kwargs): + with timeit(self.name or func.__qualname__): + ret = func(*args, **kwargs) + return ret + return wrapper + + def __enter__(self): + self.start = time.time() + return self + + @property + def time(self) -> float: + assert self.start is not None, "Time not yet started." + assert self.end is not None, "Time not yet ended." + return self.end - self.start + + @property + def average_time(self) -> float: + assert self.average, "Average time not available." + return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name]) + + @property + def history(self) -> List['timeit']: + return timeit._history.get(self.name, []) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = time.time() + if self.average: + timeit._history[self.name].append(self) + if self.verbose: + if self.average: + avg = self.average_time + print(f"{self.name or 'It'} took {avg:.6f} seconds in average.") + else: + print(f"{self.name or 'It'} took {self.time:.6f} seconds.") + + +def strip_common_prefix_suffix(strings: List[str]) -> List[str]: + first = strings[0] + + for start in range(len(first)): + if any(s[start] != strings[0][start] for s in strings): + break + + for end in range(1, min(len(s) for s in strings)): + if any(s[-end] != first[-end] for s in strings): + break + + return [s[start:len(s) - end + 1] for s in strings] + + +def multithead_execute(inputs: List[Any], num_workers: int, pbar = None): + from concurrent.futures import ThreadPoolExecutor + from contextlib import nullcontext + from tqdm import tqdm + + if pbar is not None: + pbar.total = len(inputs) if hasattr(inputs, '__len__') else None + else: + pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None) + + def decorator(fn: Callable): + with ( + ThreadPoolExecutor(max_workers=num_workers) as executor, + pbar + ): + pbar.refresh() + @catch_exception + @suppress_traceback + def _fn(input): + ret = fn(input) + pbar.update() + return ret + executor.map(_fn, inputs) + executor.shutdown(wait=True) + + return decorator + + +def suppress_traceback(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + e.__traceback__ = e.__traceback__.tb_next.tb_next + raise + return wrapper + + +class no_warnings: + def __init__(self, action: str = 'ignore', **kwargs): + self.action = action + self.filter_kwargs = kwargs + + def __call__(self, fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter(self.action, **self.filter_kwargs) + return fn(*args, **kwargs) + return wrapper + + def __enter__(self): + self.warnings_manager = warnings.catch_warnings() + self.warnings_manager.__enter__() + warnings.simplefilter(self.action, **self.filter_kwargs) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.warnings_manager.__exit__(exc_type, exc_val, exc_tb) + + +def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module \ No newline at end of file diff --git a/third_party/lingbot_depth/mdm/utils/vis.py b/third_party/lingbot_depth/mdm/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..e17edfa9e576b9c2b182bf39bfa289a4480bc9e3 --- /dev/null +++ b/third_party/lingbot_depth/mdm/utils/vis.py @@ -0,0 +1,65 @@ +from typing import * + +import numpy as np +import matplotlib +import trimesh +import random +import torch +import torch.nn.functional as F +import os + +def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: + depth = depth.copy() + if mask is None: + depth = np.where(depth > 0, depth, np.nan) + else: + depth = np.where((depth > 0) & mask, depth, np.nan) + disp = 1 / depth + if normalize: + min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99) + disp = (disp - min_disp) / (max_disp - min_disp) + + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0) + colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) + return colored + + +def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray: + if mask is not None: + depth = np.where(mask, depth, np.nan) + + min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999) + depth = (depth - min_depth) / (max_depth - min_depth) + colored = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], 0) + colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) + return colored + + +def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: + if mask is not None: + disparity = np.where(mask, disparity, np.nan) + + if normalize: + min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999) + disparity = (disparity - min_disp) / (max_disp - min_disp) + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], 0) + colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8)) + return colored + + +def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray: + if mask is not None: + normal = np.where(mask[..., None], normal, 0) + normal = normal * [0.5, -0.5, -0.5] + 0.5 + normal = (normal.clip(0, 1) * 255).astype(np.uint8) + return normal + + +def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None): + vmin, vmax = value_range if value_range is not None else (np.nanmin(error_map), np.nanmax(error_map)) + cmap = matplotlib.colormaps[cmap] + colorized_error_map = cmap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3] + if mask is not None: + colorized_error_map = np.where(mask[..., None], colorized_error_map, 0) + colorized_error_map = np.ascontiguousarray((colorized_error_map.clip(0, 1) * 255).astype(np.uint8)) + return colorized_error_map \ No newline at end of file diff --git a/third_party/lingbot_depth/pyproject.toml b/third_party/lingbot_depth/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..04142f03538b10692d2ceed3c5cb24f2baf62949 --- /dev/null +++ b/third_party/lingbot_depth/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mdm" +version = "1.0.0" +readme = "README.md" +dependencies = [ + "click", + "opencv-python", + "scipy", + "matplotlib", + "trimesh", + "pillow", + "huggingface_hub", + "numpy", + "torch==2.6.0", + "torchvision", + "xformers==v0.0.29.post2", +] +requires-python = ">=3.9" + +[tool.setuptools.packages.find] +where = ["."] +include = ["mdm*"] \ No newline at end of file diff --git a/third_party/sam3/pyproject.toml b/third_party/sam3/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..9df1b678079348544ba585fe293533d063dd81b8 --- /dev/null +++ b/third_party/sam3/pyproject.toml @@ -0,0 +1,135 @@ +[build-system] +requires = ["setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "sam3" +dynamic = ["version"] +description = "SAM3 (Segment Anything Model 3) implementation" +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} +authors = [ + {name = "Meta AI Research"} +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "timm>=1.0.17", + "numpy>=1.26,<2", + "tqdm", + "ftfy==6.1.1", + "regex", + "iopath>=0.1.10", + "typing_extensions", + "huggingface_hub", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-cov", + "black==24.2.0", + "ufmt==2.8.0", + "ruff-api==0.1.0", + "usort==1.0.2", + "gitpython==3.1.31", + "yt-dlp", + "pandas", + "opencv-python", + "pycocotools", + "numba", + "python-rapidjson", +] +notebooks = [ + "matplotlib", + "jupyter", + "notebook", + "ipywidgets", + "ipycanvas", + "ipympl", + "pycocotools", + "decord", + "opencv-python", + "einops", + "scikit-image", + "scikit-learn", +] +train = [ + "hydra-core", + "submitit", + "tensorboard", + "zstandard", + "scipy", + "torchmetrics", + "fvcore", + "fairscale", + "scikit-image", + "scikit-learn", +] + +[project.urls] +"Homepage" = "https://github.com/facebookresearch/sam3" +"Bug Tracker" = "https://github.com/facebookresearch/sam3/issues" + +[tool.setuptools.packages.find] +include = ["sam3*"] +exclude = ["build*", "scripts*", "examples*"] + +[tool.setuptools.package-data] +sam3 = ["assets/*.txt.gz"] + +[tool.setuptools.dynamic] +version = {attr = "sam3.__version__"} + +[tool.black] +line-length = 88 +target-version = ['py38', 'py39', 'py310', 'py311', 'py312'] +include = '\.pyi?$' + +[tool.isort] +profile = "black" +multi_line_output = 3 + +[tool.usort] +first_party_detection = false + +[tool.ufmt] +formatter = "ruff-api" + +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true + +[[tool.mypy.overrides]] +module = [ + "torch.*", + "torchvision.*", + "timm.*", + "numpy.*", + "PIL.*", + "tqdm.*", + "ftfy.*", + "regex.*", + "iopath.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = "test_*.py" +python_classes = "Test*" +python_functions = "test_*" diff --git a/third_party/sam3/sam3/__init__.py b/third_party/sam3/sam3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..360033928dac83f40e09b5c10e2025793d1bc165 --- /dev/null +++ b/third_party/sam3/sam3/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from .model_builder import build_sam3_image_model, build_sam3_predictor + +__version__ = "0.1.0" + +__all__ = ["build_sam3_image_model", "build_sam3_predictor"] diff --git a/third_party/sam3/sam3/__pycache__/__init__.cpython-311.pyc b/third_party/sam3/sam3/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96dfacda2fe1e455fda1d77f5fe6fc9d0d32aaaf Binary files /dev/null and b/third_party/sam3/sam3/__pycache__/__init__.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/__pycache__/logger.cpython-311.pyc b/third_party/sam3/sam3/__pycache__/logger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96e34bae938d3f1af9971993d541474720e7709f Binary files /dev/null and b/third_party/sam3/sam3/__pycache__/logger.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/__pycache__/model_builder.cpython-311.pyc b/third_party/sam3/sam3/__pycache__/model_builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04a24fb3c0692ca8a3418195e9243df30fa50cc0 Binary files /dev/null and b/third_party/sam3/sam3/__pycache__/model_builder.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/agent/__init__.py b/third_party/sam3/sam3/agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/agent/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/agent/agent_core.py b/third_party/sam3/sam3/agent/agent_core.py new file mode 100644 index 0000000000000000000000000000000000000000..27e51206953cc6c20f5573cab51f0e009303d394 --- /dev/null +++ b/third_party/sam3/sam3/agent/agent_core.py @@ -0,0 +1,565 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import copy +import json +import os + +import cv2 +from PIL import Image + +from .client_llm import send_generate_request +from .client_sam3 import call_sam_service +from .viz import visualize + + +def save_debug_messages(messages_list, debug, debug_folder_path, debug_jsonl_path): + """Save messages to debug jsonl file if debug is enabled""" + if debug and debug_jsonl_path: + # Ensure the debug directory exists before writing + os.makedirs(debug_folder_path, exist_ok=True) + with open(debug_jsonl_path, "w") as f: + for msg in messages_list: + f.write(json.dumps(msg, indent=4) + "\n") + + +def cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path): + """Clean up debug files when function successfully returns""" + if debug and debug_folder_path: + try: + if os.path.exists(debug_jsonl_path): + os.remove(debug_jsonl_path) + if os.path.exists(debug_folder_path): + os.rmdir(debug_folder_path) + except Exception as e: + print(f"Warning: Could not clean up debug files: {e}") + + +def count_images(messages): + """Count the total number of images present in the messages history.""" + total = 0 + for message in messages: + # Check if message has content (should be a list) + if "content" in message and isinstance(message["content"], list): + # Iterate through each content item + for content_item in message["content"]: + # Check if content item is a dict with type "image" + if ( + isinstance(content_item, dict) + and content_item.get("type") == "image" + ): + total += 1 + return total + + +def _prune_messages_for_next_round( + messages_list, + used_text_prompts, + latest_sam3_text_prompt, + img_path, + initial_text_prompt, +): + """Return a new messages list that contains only: + 1) messages[:2] (with optional warning text added to the second message's content) + 2) the latest assistant message (and everything after it) that contains a segment_phrase tool call + """ + # There should not be more than 10 messages in the conversation history + assert len(messages_list) < 10 + + # Part 1: always keep the first two message JSONs + part1 = copy.deepcopy(messages_list[:2]) + + # Part 2: search backwards for the latest assistant message containing a segment_phrase tool call + part2_start_idx = None + for idx in range(len(messages_list) - 1, 1, -1): + msg = messages_list[idx] + # We only consider assistant messages with a "content" list + if msg.get("role") != "assistant" or "content" not in msg: + continue + # Look for any content element that is a text containing the segment_phrase tool call + for content in msg["content"]: + if ( + isinstance(content, dict) + and content.get("type") == "text" + and "" in content.get("text", "") + and "segment_phrase" in content.get("text", "") + ): + part2_start_idx = idx + break + if part2_start_idx is not None: + break + + part2 = messages_list[part2_start_idx:] if part2_start_idx is not None else [] + + # Part 3: decide whether to add warning text to the second message in part1 + previously_used = ( + [p for p in used_text_prompts if p != latest_sam3_text_prompt] + if latest_sam3_text_prompt + else list(used_text_prompts) + ) + if part2 and len(previously_used) > 0: + warning_text = f'Note that we have previously called the segment_phrase tool with each "text_prompt" in this list: {list(previously_used)}, but none of the generated results were satisfactory. So make sure that you do not use any of these phrases as the "text_prompt" to call the segment_phrase tool again.' + # Replace the second message entirely to keep exactly 2 content items + part1[1] = { + "role": "user", + "content": [ + {"type": "image", "image": img_path}, + { + "type": "text", + "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'." + + " " + + warning_text, + }, + ], + } + assert len(part1[1]["content"]) == 2 + + # Build the new messages list: part1 (with optional warning), then part2 + new_messages = list(part1) + new_messages.extend(part2) + return new_messages + + +def agent_inference( + img_path: str, + initial_text_prompt: str, + debug: bool = False, + send_generate_request=send_generate_request, + call_sam_service=call_sam_service, + max_generations: int = 100, + output_dir="../../sam3_agent_out", +): + """ + Given a text prompt and an image, this tool will perform all aspects of agentic problem solving, + while saving sam3 and MLLM outputs to their respective directories. + + Args: + img_path: Path to the input image + initial_text_prompt: Initial text prompt from the user + debug: Whether to enable debug mode + max_generations: Maximum number of send_generate_request calls allowed (default: 100) + """ + # setup dir + sam_output_dir = os.path.join(output_dir, "sam_out") + error_save_dir = os.path.join(output_dir, "none_out") + debug_save_dir = os.path.join(output_dir, "agent_debug_out") + os.makedirs(sam_output_dir, exist_ok=True) + os.makedirs(error_save_dir, exist_ok=True) + os.makedirs(debug_save_dir, exist_ok=True) + current_dir = os.path.dirname(os.path.abspath(__file__)) + MLLM_SYSTEM_PROMPT_PATH = os.path.join( + current_dir, "system_prompts/system_prompt.txt" + ) + ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH = os.path.join( + current_dir, "system_prompts/system_prompt_iterative_checking.txt" + ) + # init variables + PATH_TO_LATEST_OUTPUT_JSON = "" + LATEST_SAM3_TEXT_PROMPT = "" + USED_TEXT_PROMPTS = ( + set() + ) # Track all previously used text prompts for segment_phrase + generation_count = 0 # Counter for number of send_generate_request calls + + # debug setup + debug_folder_path = None + debug_jsonl_path = None + if debug: + debug_folder_path = os.path.join( + debug_save_dir, f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}" + ) + debug_jsonl_path = os.path.join(debug_folder_path, "debug_history.json") + os.makedirs(debug_folder_path, exist_ok=True) + + # The helper functions are now defined outside the agent_inference function + with open(MLLM_SYSTEM_PROMPT_PATH, "r") as f: + system_prompt = f.read().strip() + with open(ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH, "r") as f: + iterative_checking_system_prompt = f.read().strip() + + # Construct the initial message list + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": [ + {"type": "image", "image": img_path}, + { + "type": "text", + "text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'.", + }, + ], + }, + ] + print(f"> Text prompt: {initial_text_prompt}") + print(f"> Image path: {img_path}") + + print("\n\n") + print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30) + print("\n\n") + generated_text = send_generate_request(messages) + print(f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n") + while generated_text is not None: + save_debug_messages(messages, debug, debug_folder_path, debug_jsonl_path) + assert ( + "" in generated_text, + f"Generated text does not contain tag: {generated_text}", + ) + generated_text = generated_text.split("", 1)[0] + "" + tool_call_json_str = ( + generated_text.split("")[-1] + .split("")[0] + .strip() + .replace(r"}}}", r"}}") # remove extra } if any + ) + try: + tool_call = json.loads(tool_call_json_str) + except json.JSONDecodeError: + raise ValueError(f"Invalid JSON in tool call: {tool_call_json_str}") + + if PATH_TO_LATEST_OUTPUT_JSON == "": + # The first tool call must be segment_phrase or report_no_mask + assert ( + tool_call["name"] == "segment_phrase" + or tool_call["name"] == "report_no_mask" + ) + + if tool_call["name"] == "segment_phrase": + print("🔍 Calling segment_phrase tool...") + assert list(tool_call["parameters"].keys()) == ["text_prompt"] + + # Check if this text_prompt has been used before + current_text_prompt = tool_call["parameters"]["text_prompt"] + if current_text_prompt in USED_TEXT_PROMPTS: + print( + f"❌ Text prompt '{current_text_prompt}' has been used before. Requesting a different prompt." + ) + duplicate_prompt_message = f"You have previously used '{current_text_prompt}' as your text_prompt to call the segment_phrase tool. You may not use it again. Please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase prompt, while adhering to all the rules stated in the system prompt. You must also never use any of the following text_prompt(s): {str(list(USED_TEXT_PROMPTS))}." + messages.append( + { + "role": "assistant", + "content": [{"type": "text", "text": generated_text}], + } + ) + messages.append( + { + "role": "user", + "content": [{"type": "text", "text": duplicate_prompt_message}], + } + ) + else: + # Add the text_prompt to the set of used prompts + USED_TEXT_PROMPTS.add(current_text_prompt) + LATEST_SAM3_TEXT_PROMPT = current_text_prompt + PATH_TO_LATEST_OUTPUT_JSON = call_sam_service( + image_path=img_path, + text_prompt=current_text_prompt, + output_folder_path=sam_output_dir, + ) + sam3_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r")) + sam3_output_image_path = sam3_outputs["output_image_path"] + num_masks = len(sam3_outputs["pred_boxes"]) + + messages.append( + { + "role": "assistant", + "content": [{"type": "text", "text": generated_text}], + } + ) + if num_masks == 0: + print("❌ No masks generated by SAM3, reporting no mask to Qwen.") + sam3_output_text_message = f"The segment_phrase tool did not generate any masks for the text_prompt '{current_text_prompt}'. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt. Please be reminded that the original user query was '{initial_text_prompt}'." + messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": sam3_output_text_message} + ], + } + ) + else: + sam3_output_text_message = rf"The segment_phrase tool generated {num_masks} available masks. All {num_masks} available masks are rendered in this image below, now you must analyze the {num_masks} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action. Please be reminded that the original user query was '{initial_text_prompt}'." + messages.append( + { + "role": "user", + "content": [ + {"type": "text", "text": sam3_output_text_message}, + {"type": "image", "image": sam3_output_image_path}, + ], + } + ) + print("\n\n>>> sam3_output_text_message:\n", sam3_output_text_message) + + elif tool_call["name"] == "examine_each_mask": + print("🔍 Calling examine_each_mask tool...") + assert LATEST_SAM3_TEXT_PROMPT != "" + + # Make sure that the last message is a image + assert ( + messages[-1]["content"][1]["type"] == "image" + ), "Second content element should be an image" + messages.pop() # Remove the last user message + # Add simplified replacement message + simplified_message = { + "role": "user", + "content": [ + { + "type": "text", + "text": "The segment_phrase tool generated several masks. Now you must analyze the mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.", + } + ], + } + messages.append(simplified_message) + + current_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r")) + num_masks = len(current_outputs["pred_masks"]) + masks_to_keep = [] + + # MLLM check the mask one by one + for i in range(num_masks): + print(f"🔍 Checking mask {i + 1}/{num_masks}...") + image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i) + + image_w_zoomed_in_mask_i_path = os.path.join( + sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_") + ).replace(".png", f"_zoom_in_mask_{i + 1}.png") + image_w_mask_i_path = os.path.join( + sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_") + ).replace(".png", f"_selected_mask_{i + 1}.png") + image_w_zoomed_in_mask_i.save(image_w_zoomed_in_mask_i_path) + image_w_mask_i.save(image_w_mask_i_path) + + iterative_checking_messages = [ + {"role": "system", "content": iterative_checking_system_prompt}, + { + "role": "user", + "content": [ + {"type": "text", "text": f"The raw input image: "}, + {"type": "image", "image": img_path}, + { + "type": "text", + "text": f"The initial user input query is: '{initial_text_prompt}'", + }, + { + "type": "text", + "text": f"Image with the predicted segmentation mask rendered on it: ", + }, + {"type": "image", "image": image_w_mask_i_path}, + { + "type": "text", + "text": f"Image with the zoomed-in mask: ", + }, + {"type": "image", "image": image_w_zoomed_in_mask_i_path}, + ], + }, + ] + checking_generated_text = send_generate_request( + iterative_checking_messages + ) + + # Process the generated text to determine if the mask should be kept or rejected + if checking_generated_text is None: + raise ValueError( + "Generated text is None, which is unexpected. Please check the Qwen server and the input parameters." + ) + print(f"Generated text for mask {i + 1}: {checking_generated_text}") + verdict = ( + checking_generated_text.split("")[-1] + .split("")[0] + .strip() + ) + if "Accept" in verdict: + assert not "Reject" in verdict + print(f"Mask {i + 1} accepted, keeping it in the outputs.") + masks_to_keep.append(i) + elif "Reject" in verdict: + assert not "Accept" in verdict + print(f"Mask {i + 1} rejected, removing it from the outputs.") + else: + raise ValueError( + f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'." + ) + + updated_outputs = { + "original_image_path": current_outputs["original_image_path"], + "orig_img_h": current_outputs["orig_img_h"], + "orig_img_w": current_outputs["orig_img_w"], + "pred_boxes": [current_outputs["pred_boxes"][i] for i in masks_to_keep], + "pred_scores": [ + current_outputs["pred_scores"][i] for i in masks_to_keep + ], + "pred_masks": [current_outputs["pred_masks"][i] for i in masks_to_keep], + } + + image_w_check_masks = visualize(updated_outputs) + image_w_check_masks_path = os.path.join( + sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png" + ).replace( + ".png", + f"_selected_masks_{'-'.join(map(str, [i + 1 for i in masks_to_keep]))}.png".replace( + "/", "_" + ), + ) + image_w_check_masks.save(image_w_check_masks_path) + # save the updated json outputs and append to message history + messages.append( + { + "role": "assistant", + "content": [{"type": "text", "text": generated_text}], + } + ) + if len(masks_to_keep) == 0: + messages.append( + { + "role": "user", + "content": [ + { + "type": "text", + "text": f"The original user query was: '{initial_text_prompt}'. The examine_each_mask tool examined and rejected all of the masks generated by the segment_phrase tool. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt.", + } + ], + } + ) + else: + messages.append( + { + "role": "user", + "content": [ + { + "type": "text", + "text": f"The original user query was: '{initial_text_prompt}'. After calling the examine_each_mask tool on the available masks, the number of available masks is now {len(masks_to_keep)}. All {len(masks_to_keep)} available masks are rendered in this image below, now you must analyze the {len(masks_to_keep)} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.", + }, + {"type": "image", "image": image_w_check_masks_path}, + ], + } + ) + + # Create a new filename based on the original path to avoid filename length issues + base_path = PATH_TO_LATEST_OUTPUT_JSON + # Remove any existing "masks_" suffix to avoid duplication + if "masks_" in base_path: + base_path = base_path.split("masks_")[0] + ".json" + # Create new filename with current masks; use a clearer suffix when empty + if len(masks_to_keep) == 0: + PATH_TO_LATEST_OUTPUT_JSON = base_path.replace( + ".json", "masks_none.json" + ) + else: + PATH_TO_LATEST_OUTPUT_JSON = base_path.replace( + ".json", f"masks_{'_'.join(map(str, masks_to_keep))}.json" + ) + json.dump(updated_outputs, open(PATH_TO_LATEST_OUTPUT_JSON, "w"), indent=4) + + elif tool_call["name"] == "select_masks_and_return": + print("🔍 Calling select_masks_and_return tool...") + current_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r")) + + assert list(tool_call["parameters"].keys()) == ["final_answer_masks"] + masks_to_keep = tool_call["parameters"]["final_answer_masks"] + + # Keep only valid mask indices, remove duplicates, and preserve deterministic ascending order + available_masks = set(range(1, len(current_outputs["pred_masks"]) + 1)) + masks_to_keep = sorted({i for i in masks_to_keep if i in available_masks}) + # Change this to a update message telling the model to try again along with information about errors made. + + final_outputs = { + "original_image_path": current_outputs["original_image_path"], + "orig_img_h": current_outputs["orig_img_h"], + "orig_img_w": current_outputs["orig_img_w"], + "pred_boxes": [ + current_outputs["pred_boxes"][i - 1] for i in masks_to_keep + ], + "pred_scores": [ + current_outputs["pred_scores"][i - 1] for i in masks_to_keep + ], + "pred_masks": [ + current_outputs["pred_masks"][i - 1] for i in masks_to_keep + ], + } + + rendered_final_output = visualize(final_outputs) + messages.append( + { + "role": "assistant", + "content": [{"type": "text", "text": generated_text}], + } + ) + + # Clean up debug files before successful return + cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path) + return messages, final_outputs, rendered_final_output + + elif tool_call["name"] == "report_no_mask": + print("🔍 Calling report_no_mask tool...") + height, width = cv2.imread(img_path).shape[:2] + final_outputs = { + "original_image_path": img_path, + "orig_img_h": height, + "orig_img_w": width, + "pred_boxes": [], + "pred_scores": [], + "pred_masks": [], + } + rendered_final_output = Image.open(img_path) + messages.append( + { + "role": "assistant", + "content": [{"type": "text", "text": generated_text}], + } + ) + return messages, final_outputs, rendered_final_output + + else: + raise ValueError(f"Unknown tool call: {tool_call['name']}") + + # sometimes the MLLM don't know when to stop, and generates multiple tool calls in one round, so we need to split the generated text by and only keep the first one + + for message in messages: + if message["role"] == "assistant" and "content" in message: + for content in message["content"]: + if ( + isinstance(content, dict) + and content.get("type") == "text" + and "text" in content + ): + content["text"] = ( + content["text"].split("", 1)[0] + "\n\n" + ) + # Prune the messages history before the next MLLM generation round according to the 3-part rules. + # This keeps history compact and ensures the model sees only the allowed parts. + messages = _prune_messages_for_next_round( + messages, + USED_TEXT_PROMPTS, + LATEST_SAM3_TEXT_PROMPT, + img_path, + initial_text_prompt, + ) + # make sure there can never be more than 2 images in the context + assert count_images(messages) <= 2 + generation_count += 1 + if generation_count > max_generations: + raise ValueError( + f"Exceeded maximum number of allowed generation requests ({max_generations})" + ) + + print("\n\n") + print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30) + print("\n\n") + generated_text = send_generate_request(messages) + print( + f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n" + ) + + print("\n\n>>> SAM 3 Agent execution ended.\n\n") + + error_save_path = os.path.join( + error_save_dir, + f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}_error_history.json", + ) + with open(error_save_path, "w") as f: + json.dump(messages, f, indent=4) + print("Saved messages history that caused error to:", error_save_path) + raise ValueError( + rf"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters for image path: {img_path} and initial text prompt: {initial_text_prompt}." + ) diff --git a/third_party/sam3/sam3/agent/client_llm.py b/third_party/sam3/sam3/agent/client_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..c55508e9d130efd65804ae880ce4ba6b547bc3e0 --- /dev/null +++ b/third_party/sam3/sam3/agent/client_llm.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import base64 +import os +from typing import Any, Optional + +from openai import OpenAI + + +def get_image_base64_and_mime(image_path): + """Convert image file to base64 string and get MIME type""" + try: + # Get MIME type based on file extension + ext = os.path.splitext(image_path)[1].lower() + mime_types = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".webp": "image/webp", + ".bmp": "image/bmp", + } + mime_type = mime_types.get(ext, "image/jpeg") # Default to JPEG + + # Convert image to base64 + with open(image_path, "rb") as image_file: + base64_data = base64.b64encode(image_file.read()).decode("utf-8") + return base64_data, mime_type + except Exception as e: + print(f"Error converting image to base64: {e}") + return None, None + + +def send_generate_request( + messages, + server_url=None, + model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + api_key=None, + max_tokens=4096, +): + """ + Sends a request to the OpenAI-compatible API endpoint using the OpenAI client library. + + Args: + server_url (str): The base URL of the server, e.g. "http://127.0.0.1:8000" + messages (list): A list of message dicts, each containing role and content. + model (str): The model to use for generation (default: "llama-4") + max_tokens (int): Maximum number of tokens to generate (default: 4096) + + Returns: + str: The generated response text from the server. + """ + # Process messages to convert image paths to base64 + processed_messages = [] + for message in messages: + processed_message = message.copy() + if message["role"] == "user" and "content" in message: + processed_content = [] + for c in message["content"]: + if isinstance(c, dict) and c.get("type") == "image": + # Convert image path to base64 format + image_path = c["image"] + + print("image_path", image_path) + new_image_path = image_path.replace( + "?", "%3F" + ) # Escape ? in the path + + # Read the image file and convert to base64 + try: + base64_image, mime_type = get_image_base64_and_mime( + new_image_path + ) + if base64_image is None: + print( + f"Warning: Could not convert image to base64: {new_image_path}" + ) + continue + + # Create the proper image_url structure with base64 data + processed_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{base64_image}", + "detail": "high", + }, + } + ) + + except FileNotFoundError: + print(f"Warning: Image file not found: {new_image_path}") + continue + except Exception as e: + print(f"Warning: Error processing image {new_image_path}: {e}") + continue + else: + processed_content.append(c) + + processed_message["content"] = processed_content + processed_messages.append(processed_message) + + # Create OpenAI client with custom base URL + client = OpenAI(api_key=api_key, base_url=server_url) + + try: + print(f"🔍 Calling model {model}...") + response = client.chat.completions.create( + model=model, + messages=processed_messages, + max_completion_tokens=max_tokens, + n=1, + ) + # print(f"Received response: {response.choices[0].message}") + + # Extract the response content + if response.choices and len(response.choices) > 0: + return response.choices[0].message.content + else: + print(f"Unexpected response format: {response}") + return None + + except Exception as e: + print(f"Request failed: {e}") + return None + + +def send_direct_request( + llm: Any, + messages: list[dict[str, Any]], + sampling_params: Any, +) -> Optional[str]: + """ + Run inference on a vLLM model instance directly without using a server. + + Args: + llm: Initialized vLLM LLM instance (passed from external initialization) + messages: List of message dicts with role and content (OpenAI format) + sampling_params: vLLM SamplingParams instance (initialized externally) + + Returns: + str: Generated response text, or None if inference fails + """ + try: + # Process messages to handle images (convert to base64 if needed) + processed_messages = [] + for message in messages: + processed_message = message.copy() + if message["role"] == "user" and "content" in message: + processed_content = [] + for c in message["content"]: + if isinstance(c, dict) and c.get("type") == "image": + # Convert image path to base64 format + image_path = c["image"] + new_image_path = image_path.replace("?", "%3F") + + try: + base64_image, mime_type = get_image_base64_and_mime( + new_image_path + ) + if base64_image is None: + print( + f"Warning: Could not convert image: {new_image_path}" + ) + continue + + # vLLM expects image_url format + processed_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{base64_image}" + }, + } + ) + except Exception as e: + print( + f"Warning: Error processing image {new_image_path}: {e}" + ) + continue + else: + processed_content.append(c) + + processed_message["content"] = processed_content + processed_messages.append(processed_message) + + print("🔍 Running direct inference with vLLM...") + + # Run inference using vLLM's chat interface + outputs = llm.chat( + messages=processed_messages, + sampling_params=sampling_params, + ) + + # Extract the generated text from the first output + if outputs and len(outputs) > 0: + generated_text = outputs[0].outputs[0].text + return generated_text + else: + print(f"Unexpected output format: {outputs}") + return None + + except Exception as e: + print(f"Direct inference failed: {e}") + return None diff --git a/third_party/sam3/sam3/agent/client_sam3.py b/third_party/sam3/sam3/agent/client_sam3.py new file mode 100644 index 0000000000000000000000000000000000000000..daeb849a0ba2f95768fdfcf4e1b61fcd1217f180 --- /dev/null +++ b/third_party/sam3/sam3/agent/client_sam3.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import json +import os + +import torch +from PIL import Image +from sam3.model.box_ops import box_xyxy_to_xywh +from sam3.train.masks_ops import rle_encode + +from .helpers.mask_overlap_removal import remove_overlapping_masks +from .viz import visualize + + +def sam3_inference(processor, image_path, text_prompt): + """Run SAM 3 image inference with text prompts and format the outputs""" + image = Image.open(image_path) + orig_img_w, orig_img_h = image.size + + # model inference + inference_state = processor.set_image(image) + inference_state = processor.set_text_prompt( + state=inference_state, prompt=text_prompt + ) + + # format and assemble outputs + pred_boxes_xyxy = torch.stack( + [ + inference_state["boxes"][:, 0] / orig_img_w, + inference_state["boxes"][:, 1] / orig_img_h, + inference_state["boxes"][:, 2] / orig_img_w, + inference_state["boxes"][:, 3] / orig_img_h, + ], + dim=-1, + ) # normalized in range [0, 1] + pred_boxes_xywh = box_xyxy_to_xywh(pred_boxes_xyxy).tolist() + pred_masks = rle_encode(inference_state["masks"].squeeze(1)) + pred_masks = [m["counts"] for m in pred_masks] + outputs = { + "orig_img_h": orig_img_h, + "orig_img_w": orig_img_w, + "pred_boxes": pred_boxes_xywh, + "pred_masks": pred_masks, + "pred_scores": inference_state["scores"].tolist(), + } + return outputs + + +def call_sam_service( + sam3_processor, + image_path: str, + text_prompt: str, + output_folder_path: str = "sam3_output", +): + """ + Loads an image, sends it with a text prompt to the service, + saves the results, and renders the visualization. + """ + print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...") + + text_prompt_for_save_path = ( + text_prompt.replace("/", "_") if "/" in text_prompt else text_prompt + ) + + os.makedirs( + os.path.join(output_folder_path, image_path.replace("/", "-")), exist_ok=True + ) + output_json_path = os.path.join( + output_folder_path, + image_path.replace("/", "-"), + rf"{text_prompt_for_save_path}.json", + ) + output_image_path = os.path.join( + output_folder_path, + image_path.replace("/", "-"), + rf"{text_prompt_for_save_path}.png", + ) + + try: + # Send the image and text prompt as a multipart/form-data request + serialized_response = sam3_inference(sam3_processor, image_path, text_prompt) + + # 1. Prepare the response dictionary + serialized_response = remove_overlapping_masks(serialized_response) + serialized_response = { + "original_image_path": image_path, + "output_image_path": output_image_path, + **serialized_response, + } + + # 2. Reorder predictions by scores (highest to lowest) if scores are available + if "pred_scores" in serialized_response and serialized_response["pred_scores"]: + # Create indices sorted by scores in descending order + score_indices = sorted( + range(len(serialized_response["pred_scores"])), + key=lambda i: serialized_response["pred_scores"][i], + reverse=True, + ) + + # Reorder all three lists based on the sorted indices + serialized_response["pred_scores"] = [ + serialized_response["pred_scores"][i] for i in score_indices + ] + serialized_response["pred_boxes"] = [ + serialized_response["pred_boxes"][i] for i in score_indices + ] + serialized_response["pred_masks"] = [ + serialized_response["pred_masks"][i] for i in score_indices + ] + + # 3. Remove any invalid RLE masks that is too short (shorter than 5 characters) + valid_masks = [] + valid_boxes = [] + valid_scores = [] + for i, rle in enumerate(serialized_response["pred_masks"]): + if len(rle) > 4: + valid_masks.append(rle) + valid_boxes.append(serialized_response["pred_boxes"][i]) + valid_scores.append(serialized_response["pred_scores"][i]) + serialized_response["pred_masks"] = valid_masks + serialized_response["pred_boxes"] = valid_boxes + serialized_response["pred_scores"] = valid_scores + + with open(output_json_path, "w") as f: + json.dump(serialized_response, f, indent=4) + print(f"✅ Raw JSON response saved to '{output_json_path}'") + + # 4. Render and save visualizations on the image and save it in the SAM3 output folder + print("🔍 Rendering visualizations on the image ...") + viz_image = visualize(serialized_response) + os.makedirs(os.path.dirname(output_image_path), exist_ok=True) + viz_image.save(output_image_path) + print("✅ Saved visualization at:", output_image_path) + except Exception as e: + print(f"❌ Error calling service: {e}") + + return output_json_path diff --git a/third_party/sam3/sam3/agent/helpers/__init__.py b/third_party/sam3/sam3/agent/helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/agent/helpers/boxes.py b/third_party/sam3/sam3/agent/helpers/boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..a32e5202ef765e3fc131fdbf2c1382ae195ff1cf --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/boxes.py @@ -0,0 +1,440 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import math +from enum import IntEnum, unique +from typing import List, Tuple, Union + +import numpy as np +import torch +from torch import device + +_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray] + + +@unique +class BoxMode(IntEnum): + """ + Enum of different ways to represent a box. + """ + + XYXY_ABS = 0 + """ + (x0, y0, x1, y1) in absolute floating points coordinates. + The coordinates in range [0, width or height]. + """ + XYWH_ABS = 1 + """ + (x0, y0, w, h) in absolute floating points coordinates. + """ + XYXY_REL = 2 + """ + Not yet supported! + (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image. + """ + XYWH_REL = 3 + """ + Not yet supported! + (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image. + """ + XYWHA_ABS = 4 + """ + (xc, yc, w, h, a) in absolute floating points coordinates. + (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw. + """ + + @staticmethod + def convert( + box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode" + ) -> _RawBoxType: + """ + Args: + box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5 + from_mode, to_mode (BoxMode) + + Returns: + The converted box of the same type. + """ + if from_mode == to_mode: + return box + + original_type = type(box) + is_numpy = isinstance(box, np.ndarray) + single_box = isinstance(box, (list, tuple)) + if single_box: + assert len(box) == 4 or len(box) == 5, ( + "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor," + " where k == 4 or 5" + ) + arr = torch.tensor(box)[None, :] + else: + # avoid modifying the input box + if is_numpy: + arr = torch.from_numpy(np.asarray(box)).clone() + else: + arr = box.clone() + + assert to_mode not in [ + BoxMode.XYXY_REL, + BoxMode.XYWH_REL, + ] and from_mode not in [ + BoxMode.XYXY_REL, + BoxMode.XYWH_REL, + ], "Relative mode not yet supported!" + + if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS: + assert ( + arr.shape[-1] == 5 + ), "The last dimension of input shape must be 5 for XYWHA format" + original_dtype = arr.dtype + arr = arr.double() + + w = arr[:, 2] + h = arr[:, 3] + a = arr[:, 4] + c = torch.abs(torch.cos(a * math.pi / 180.0)) + s = torch.abs(torch.sin(a * math.pi / 180.0)) + # This basically computes the horizontal bounding rectangle of the rotated box + new_w = c * w + s * h + new_h = c * h + s * w + + # convert center to top-left corner + arr[:, 0] -= new_w / 2.0 + arr[:, 1] -= new_h / 2.0 + # bottom-right corner + arr[:, 2] = arr[:, 0] + new_w + arr[:, 3] = arr[:, 1] + new_h + + arr = arr[:, :4].to(dtype=original_dtype) + elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS: + original_dtype = arr.dtype + arr = arr.double() + arr[:, 0] += arr[:, 2] / 2.0 + arr[:, 1] += arr[:, 3] / 2.0 + angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype) + arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype) + else: + if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS: + arr[:, 2] += arr[:, 0] + arr[:, 3] += arr[:, 1] + elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS: + arr[:, 2] -= arr[:, 0] + arr[:, 3] -= arr[:, 1] + else: + raise NotImplementedError( + "Conversion from BoxMode {} to {} is not supported yet".format( + from_mode, to_mode + ) + ) + + if single_box: + return original_type(arr.flatten().tolist()) + if is_numpy: + return arr.numpy() + else: + return arr + + +class Boxes: + """ + This structure stores a list of boxes as a Nx4 torch.Tensor. + It supports some common methods about boxes + (`area`, `clip`, `nonempty`, etc), + and also behaves like a Tensor + (support indexing, `to(device)`, `.device`, and iteration over all boxes) + + Attributes: + tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2). + """ + + def __init__(self, tensor: torch.Tensor): + """ + Args: + tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2). + """ + if not isinstance(tensor, torch.Tensor): + tensor = torch.as_tensor( + tensor, dtype=torch.float32, device=torch.device("cpu") + ) + else: + tensor = tensor.to(torch.float32) + if tensor.numel() == 0: + # Use reshape, so we don't end up creating a new tensor that does not depend on + # the inputs (and consequently confuses jit) + tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32) + assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size() + + self.tensor = tensor + + def clone(self) -> "Boxes": + """ + Clone the Boxes. + + Returns: + Boxes + """ + return Boxes(self.tensor.clone()) + + def to(self, device: torch.device): + # Boxes are assumed float32 and does not support to(dtype) + return Boxes(self.tensor.to(device=device)) + + def area(self) -> torch.Tensor: + """ + Computes the area of all the boxes. + + Returns: + torch.Tensor: a vector with areas of each box. + """ + box = self.tensor + area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]) + return area + + def clip(self, box_size: Tuple[int, int]) -> None: + """ + Clip (in place) the boxes by limiting x coordinates to the range [0, width] + and y coordinates to the range [0, height]. + + Args: + box_size (height, width): The clipping box's size. + """ + assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!" + h, w = box_size + x1 = self.tensor[:, 0].clamp(min=0, max=w) + y1 = self.tensor[:, 1].clamp(min=0, max=h) + x2 = self.tensor[:, 2].clamp(min=0, max=w) + y2 = self.tensor[:, 3].clamp(min=0, max=h) + self.tensor = torch.stack((x1, y1, x2, y2), dim=-1) + + def nonempty(self, threshold: float = 0.0) -> torch.Tensor: + """ + Find boxes that are non-empty. + A box is considered empty, if either of its side is no larger than threshold. + + Returns: + Tensor: + a binary vector which represents whether each box is empty + (False) or non-empty (True). + """ + box = self.tensor + widths = box[:, 2] - box[:, 0] + heights = box[:, 3] - box[:, 1] + keep = (widths > threshold) & (heights > threshold) + return keep + + def __getitem__(self, item) -> "Boxes": + """ + Args: + item: int, slice, or a BoolTensor + + Returns: + Boxes: Create a new :class:`Boxes` by indexing. + + The following usage are allowed: + + 1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box. + 2. `new_boxes = boxes[2:10]`: return a slice of boxes. + 3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor + with `length = len(boxes)`. Nonzero elements in the vector will be selected. + + Note that the returned Boxes might share storage with this Boxes, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return Boxes(self.tensor[item].view(1, -1)) + b = self.tensor[item] + assert ( + b.dim() == 2 + ), "Indexing on Boxes with {} failed to return a matrix!".format(item) + return Boxes(b) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __repr__(self) -> str: + return "Boxes(" + str(self.tensor) + ")" + + def inside_box( + self, box_size: Tuple[int, int], boundary_threshold: int = 0 + ) -> torch.Tensor: + """ + Args: + box_size (height, width): Size of the reference box. + boundary_threshold (int): Boxes that extend beyond the reference box + boundary by more than boundary_threshold are considered "outside". + + Returns: + a binary vector, indicating whether each box is inside the reference box. + """ + height, width = box_size + inds_inside = ( + (self.tensor[..., 0] >= -boundary_threshold) + & (self.tensor[..., 1] >= -boundary_threshold) + & (self.tensor[..., 2] < width + boundary_threshold) + & (self.tensor[..., 3] < height + boundary_threshold) + ) + return inds_inside + + def get_centers(self) -> torch.Tensor: + """ + Returns: + The box centers in a Nx2 array of (x, y). + """ + return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2 + + def scale(self, scale_x: float, scale_y: float) -> None: + """ + Scale the box with horizontal and vertical scaling factors + """ + self.tensor[:, 0::2] *= scale_x + self.tensor[:, 1::2] *= scale_y + + @classmethod + def cat(cls, boxes_list: List["Boxes"]) -> "Boxes": + """ + Concatenates a list of Boxes into a single Boxes + + Arguments: + boxes_list (list[Boxes]) + + Returns: + Boxes: the concatenated Boxes + """ + assert isinstance(boxes_list, (list, tuple)) + if len(boxes_list) == 0: + return cls(torch.empty(0)) + assert all([isinstance(box, Boxes) for box in boxes_list]) + + # use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input + cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0)) + return cat_boxes + + @property + def device(self) -> device: + return self.tensor.device + + # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript + # https://github.com/pytorch/pytorch/issues/18627 + @torch.jit.unused + def __iter__(self): + """ + Yield a box as a Tensor of shape (4,) at a time. + """ + yield from self.tensor + + +def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: + """ + Given two lists of boxes of size N and M, + compute the intersection area between __all__ N x M pairs of boxes. + The box order must be (xmin, ymin, xmax, ymax) + + Args: + boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. + + Returns: + Tensor: intersection, sized [N,M]. + """ + boxes1, boxes2 = boxes1.tensor, boxes2.tensor + width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max( + boxes1[:, None, :2], boxes2[:, :2] + ) # [N,M,2] + + width_height.clamp_(min=0) # [N,M,2] + intersection = width_height.prod(dim=2) # [N,M] + return intersection + + +# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py +# with slight modifications +def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: + """ + Given two lists of boxes of size N and M, compute the IoU + (intersection over union) between **all** N x M pairs of boxes. + The box order must be (xmin, ymin, xmax, ymax). + + Args: + boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. + + Returns: + Tensor: IoU, sized [N,M]. + """ + area1 = boxes1.area() # [N] + area2 = boxes2.area() # [M] + inter = pairwise_intersection(boxes1, boxes2) + + # handle empty boxes + iou = torch.where( + inter > 0, + inter / (area1[:, None] + area2 - inter), + torch.zeros(1, dtype=inter.dtype, device=inter.device), + ) + return iou + + +def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: + """ + Similar to :func:`pariwise_iou` but compute the IoA (intersection over boxes2 area). + + Args: + boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. + + Returns: + Tensor: IoA, sized [N,M]. + """ + area2 = boxes2.area() # [M] + inter = pairwise_intersection(boxes1, boxes2) + + # handle empty boxes + ioa = torch.where( + inter > 0, inter / area2, torch.zeros(1, dtype=inter.dtype, device=inter.device) + ) + return ioa + + +def pairwise_point_box_distance(points: torch.Tensor, boxes: Boxes): + """ + Pairwise distance between N points and M boxes. The distance between a + point and a box is represented by the distance from the point to 4 edges + of the box. Distances are all positive when the point is inside the box. + + Args: + points: Nx2 coordinates. Each row is (x, y) + boxes: M boxes + + Returns: + Tensor: distances of size (N, M, 4). The 4 values are distances from + the point to the left, top, right, bottom of the box. + """ + x, y = points.unsqueeze(dim=2).unbind(dim=1) # (N, 1) + x0, y0, x1, y1 = boxes.tensor.unsqueeze(dim=0).unbind(dim=2) # (1, M) + return torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) + + +def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: + """ + Compute pairwise intersection over union (IOU) of two sets of matched + boxes that have the same number of boxes. + Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix. + + Args: + boxes1 (Boxes): bounding boxes, sized [N,4]. + boxes2 (Boxes): same length as boxes1 + Returns: + Tensor: iou, sized [N]. + """ + assert len(boxes1) == len( + boxes2 + ), "boxlists should have the samenumber of entries, got {}, {}".format( + len(boxes1), len(boxes2) + ) + area1 = boxes1.area() # [N] + area2 = boxes2.area() # [N] + box1, box2 = boxes1.tensor, boxes2.tensor + lt = torch.max(box1[:, :2], box2[:, :2]) # [N,2] + rb = torch.min(box1[:, 2:], box2[:, 2:]) # [N,2] + wh = (rb - lt).clamp(min=0) # [N,2] + inter = wh[:, 0] * wh[:, 1] # [N] + iou = inter / (area1 + area2 - inter) # [N] + return iou diff --git a/third_party/sam3/sam3/agent/helpers/color_map.py b/third_party/sam3/sam3/agent/helpers/color_map.py new file mode 100644 index 0000000000000000000000000000000000000000..ef6b61560723775cf36c1305fc18ed29130885d7 --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/color_map.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +An awesome colormap for really neat visualizations. +Copied from Detectron, and removed gray colors. +""" + +import random + +import numpy as np + +__all__ = ["colormap", "random_color", "random_colors"] + + +# A list of 25 bright and sharp colors for segmentation masks, +# generated from the edges of the sRGB color space for maximum intensity. +_COLORS = ( + np.array( + [ + # The original 8 sharp colors + 1.000, + 1.000, + 0.000, # 1. Yellow + 0.000, + 1.000, + 0.000, # 2. Lime + 0.000, + 1.000, + 1.000, # 3. Cyan + 1.000, + 0.000, + 1.000, # 4. Magenta + 1.000, + 0.000, + 0.000, # 5. Red + 1.000, + 0.498, + 0.000, # 6. Orange + 0.498, + 1.000, + 0.000, # 7. Chartreuse + 0.000, + 1.000, + 0.498, # 8. Spring Green + 1.000, + 0.000, + 0.498, # 9. Rose + 0.498, + 0.000, + 1.000, # 10. Violet + 0.753, + 1.000, + 0.000, # 11. Electric Lime + 1.000, + 0.753, + 0.000, # 12. Vivid Orange + 0.000, + 1.000, + 0.753, # 13. Turquoise + 0.753, + 0.000, + 1.000, # 14. Bright Violet + 1.000, + 0.000, + 0.753, # 15. Bright Pink + 1.000, + 0.251, + 0.000, # 16. Fiery Orange + 0.251, + 1.000, + 0.000, # 17. Bright Chartreuse + 0.000, + 1.000, + 0.251, # 18. Malachite Green + 0.251, + 0.000, + 1.000, # 19. Deep Violet + 1.000, + 0.000, + 0.251, # 20. Hot Pink + ] + ) + .astype(np.float32) + .reshape(-1, 3) +) + + +def colormap(rgb=False, maximum=255): + """ + Args: + rgb (bool): whether to return RGB colors or BGR colors. + maximum (int): either 255 or 1 + + Returns: + ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1] + """ + assert maximum in [255, 1], maximum + c = _COLORS * maximum + if not rgb: + c = c[:, ::-1] + return c + + +def random_color(rgb=False, maximum=255): + """ + Args: + rgb (bool): whether to return RGB colors or BGR colors. + maximum (int): either 255 or 1 + + Returns: + ndarray: a vector of 3 numbers + """ + idx = np.random.randint(0, len(_COLORS)) + ret = _COLORS[idx] * maximum + if not rgb: + ret = ret[::-1] + return ret + + +def random_colors(N, rgb=False, maximum=255): + """ + Args: + N (int): number of unique colors needed + rgb (bool): whether to return RGB colors or BGR colors. + maximum (int): either 255 or 1 + + Returns: + ndarray: a list of random_color + """ + indices = random.sample(range(len(_COLORS)), N) + ret = [_COLORS[i] * maximum for i in indices] + if not rgb: + ret = [x[::-1] for x in ret] + return ret + + +if __name__ == "__main__": + import cv2 + + size = 100 + H, W = 10, 10 + canvas = np.random.rand(H * size, W * size, 3).astype("float32") + for h in range(H): + for w in range(W): + idx = h * W + w + if idx >= len(_COLORS): + break + canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx] + cv2.imshow("a", canvas) + cv2.waitKey(0) diff --git a/third_party/sam3/sam3/agent/helpers/keypoints.py b/third_party/sam3/sam3/agent/helpers/keypoints.py new file mode 100644 index 0000000000000000000000000000000000000000..ae67be18dc0d8670844e1c8047c6732efdb65262 --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/keypoints.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from typing import Any, List, Tuple, Union + +import numpy as np +import torch +from torch.nn import functional as F + + +class Keypoints: + """ + Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property + containing the x,y location and visibility flag of each keypoint. This tensor has shape + (N, K, 3) where N is the number of instances and K is the number of keypoints per instance. + + The visibility flag follows the COCO format and must be one of three integers: + + * v=0: not labeled (in which case x=y=0) + * v=1: labeled but not visible + * v=2: labeled and visible + """ + + def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]): + """ + Arguments: + keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint. + The shape should be (N, K, 3) where N is the number of + instances, and K is the number of keypoints per instance. + """ + device = ( + keypoints.device + if isinstance(keypoints, torch.Tensor) + else torch.device("cpu") + ) + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device) + assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape + self.tensor = keypoints + + def __len__(self) -> int: + return self.tensor.size(0) + + def to(self, *args: Any, **kwargs: Any) -> "Keypoints": + return type(self)(self.tensor.to(*args, **kwargs)) + + @property + def device(self) -> torch.device: + return self.tensor.device + + def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor: + """ + Convert keypoint annotations to a heatmap of one-hot labels for training, + as described in :paper:`Mask R-CNN`. + + Arguments: + boxes: Nx4 tensor, the boxes to draw the keypoints to + + Returns: + heatmaps: + A tensor of shape (N, K), each element is integer spatial label + in the range [0, heatmap_size**2 - 1] for each keypoint in the input. + valid: + A tensor of shape (N, K) containing whether each keypoint is in the roi or not. + """ + return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size) + + def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints": + """ + Create a new `Keypoints` by indexing on this `Keypoints`. + + The following usage are allowed: + + 1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance. + 2. `new_kpts = kpts[2:10]`: return a slice of key points. + 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor + with `length = len(kpts)`. Nonzero elements in the vector will be selected. + + Note that the returned Keypoints might share storage with this Keypoints, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return Keypoints([self.tensor[item]]) + return Keypoints(self.tensor[item]) + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.tensor)) + return s + + @staticmethod + def cat(keypoints_list: List["Keypoints"]) -> "Keypoints": + """ + Concatenates a list of Keypoints into a single Keypoints + + Arguments: + keypoints_list (list[Keypoints]) + + Returns: + Keypoints: the concatenated Keypoints + """ + assert isinstance(keypoints_list, (list, tuple)) + assert len(keypoints_list) > 0 + assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list) + + cat_kpts = type(keypoints_list[0])( + torch.cat([kpts.tensor for kpts in keypoints_list], dim=0) + ) + return cat_kpts + + +def _keypoints_to_heatmap( + keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space. + + Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the + closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the + continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"): + d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate. + + Arguments: + keypoints: tensor of keypoint locations in of shape (N, K, 3). + rois: Nx4 tensor of rois in xyxy format + heatmap_size: integer side length of square heatmap. + + Returns: + heatmaps: A tensor of shape (N, K) containing an integer spatial label + in the range [0, heatmap_size**2 - 1] for each keypoint in the input. + valid: A tensor of shape (N, K) containing whether each keypoint is in + the roi or not. + """ + + if rois.numel() == 0: + return rois.new().long(), rois.new().long() + offset_x = rois[:, 0] + offset_y = rois[:, 1] + scale_x = heatmap_size / (rois[:, 2] - rois[:, 0]) + scale_y = heatmap_size / (rois[:, 3] - rois[:, 1]) + + offset_x = offset_x[:, None] + offset_y = offset_y[:, None] + scale_x = scale_x[:, None] + scale_y = scale_y[:, None] + + x = keypoints[..., 0] + y = keypoints[..., 1] + + x_boundary_inds = x == rois[:, 2][:, None] + y_boundary_inds = y == rois[:, 3][:, None] + + x = (x - offset_x) * scale_x + x = x.floor().long() + y = (y - offset_y) * scale_y + y = y.floor().long() + + x[x_boundary_inds] = heatmap_size - 1 + y[y_boundary_inds] = heatmap_size - 1 + + valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size) + vis = keypoints[..., 2] > 0 + valid = (valid_loc & vis).long() + + lin_ind = y * heatmap_size + x + heatmaps = lin_ind * valid + + return heatmaps, valid + + +@torch.jit.script_if_tracing +def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor: + """ + Extract predicted keypoint locations from heatmaps. + + Args: + maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for + each ROI and each keypoint. + rois (Tensor): (#ROIs, 4). The box of each ROI. + + Returns: + Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to + (x, y, logit, score) for each keypoint. + + When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate, + we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from + Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate. + """ + + offset_x = rois[:, 0] + offset_y = rois[:, 1] + + widths = (rois[:, 2] - rois[:, 0]).clamp(min=1) + heights = (rois[:, 3] - rois[:, 1]).clamp(min=1) + widths_ceil = widths.ceil() + heights_ceil = heights.ceil() + + num_rois, num_keypoints = maps.shape[:2] + xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4) + + width_corrections = widths / widths_ceil + height_corrections = heights / heights_ceil + + keypoints_idx = torch.arange(num_keypoints, device=maps.device) + + for i in range(num_rois): + outsize = (int(heights_ceil[i]), int(widths_ceil[i])) + roi_map = F.interpolate( + maps[[i]], size=outsize, mode="bicubic", align_corners=False + ) + + # Although semantically equivalent, `reshape` is used instead of `squeeze` due + # to limitation during ONNX export of `squeeze` in scripting mode + roi_map = roi_map.reshape(roi_map.shape[1:]) # keypoints x H x W + + # softmax over the spatial region + max_score, _ = roi_map.view(num_keypoints, -1).max(1) + max_score = max_score.view(num_keypoints, 1, 1) + tmp_full_resolution = (roi_map - max_score).exp_() + tmp_pool_resolution = (maps[i] - max_score).exp_() + # Produce scores over the region H x W, but normalize with POOL_H x POOL_W, + # so that the scores of objects of different absolute sizes will be more comparable + roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum( + (1, 2), keepdim=True + ) + + w = roi_map.shape[2] + pos = roi_map.view(num_keypoints, -1).argmax(1) + + x_int = pos % w + y_int = (pos - x_int) // w + + assert ( + roi_map_scores[keypoints_idx, y_int, x_int] + == roi_map_scores.view(num_keypoints, -1).max(1)[0] + ).all() + + x = (x_int.float() + 0.5) * width_corrections[i] + y = (y_int.float() + 0.5) * height_corrections[i] + + xy_preds[i, :, 0] = x + offset_x[i] + xy_preds[i, :, 1] = y + offset_y[i] + xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int] + xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int] + + return xy_preds diff --git a/third_party/sam3/sam3/agent/helpers/mask_overlap_removal.py b/third_party/sam3/sam3/agent/helpers/mask_overlap_removal.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c3c68f319a7bbd28b09876494ab794af4a84f4 --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/mask_overlap_removal.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from typing import Dict, List + +import numpy as np +import torch + +try: + from pycocotools import mask as mask_utils +except Exception: + mask_utils = None + + +def mask_intersection( + masks1: torch.Tensor, masks2: torch.Tensor, block_size: int = 16 +) -> torch.Tensor: + assert masks1.shape[1:] == masks2.shape[1:] + assert masks1.dtype == torch.bool and masks2.dtype == torch.bool + N, M = masks1.shape[0], masks2.shape[0] + out = torch.zeros(N, M, device=masks1.device, dtype=torch.long) + for i in range(0, N, block_size): + for j in range(0, M, block_size): + a = masks1[i : i + block_size] + b = masks2[j : j + block_size] + inter = (a[:, None] & b[None, :]).flatten(-2).sum(-1) + out[i : i + block_size, j : j + block_size] = inter + return out + + +def mask_iom(masks1: torch.Tensor, masks2: torch.Tensor) -> torch.Tensor: + assert masks1.shape[1:] == masks2.shape[1:] + assert masks1.dtype == torch.bool and masks2.dtype == torch.bool + inter = mask_intersection(masks1, masks2) + area1 = masks1.flatten(-2).sum(-1) # (N,) + area2 = masks2.flatten(-2).sum(-1) # (M,) + min_area = torch.min(area1[:, None], area2[None, :]).clamp_min(1) + return inter.float() / (min_area.float() + 1e-8) + + +def _decode_single_mask(mask_repr, h: int, w: int) -> np.ndarray: + if isinstance(mask_repr, (list, tuple, np.ndarray)): + arr = np.array(mask_repr) + if arr.ndim != 2: + raise ValueError("Mask array must be 2D (H, W).") + return (arr > 0).astype(np.uint8) + + if mask_utils is None: + raise ImportError( + "pycocotools is required to decode RLE mask strings. pip install pycocotools" + ) + + if not isinstance(mask_repr, (str, bytes)): + raise ValueError("Unsupported mask representation type for RLE decode.") + + rle = { + "counts": mask_repr if isinstance(mask_repr, (str, bytes)) else str(mask_repr), + "size": [h, w], + } + decoded = mask_utils.decode(rle) + if decoded.ndim == 3: + decoded = decoded[:, :, 0] + return (decoded > 0).astype(np.uint8) + + +def _decode_masks_to_torch_bool(pred_masks: List, h: int, w: int) -> torch.Tensor: + bin_masks = [_decode_single_mask(m, h, w) for m in pred_masks] + masks_np = np.stack(bin_masks, axis=0).astype(np.uint8) # (N, H, W) + return torch.from_numpy(masks_np > 0) + + +def remove_overlapping_masks(sample: Dict, iom_thresh: float = 0.3) -> Dict: + """ + Greedy keep: sort by score desc; keep a mask if IoM to all kept masks <= threshold. + If pred_masks has length 0 or 1, returns sample unchanged (no extra keys). + """ + # Basic presence checks + if "pred_masks" not in sample or not isinstance(sample["pred_masks"], list): + return sample # nothing to do / preserve as-is + + pred_masks = sample["pred_masks"] + N = len(pred_masks) + + # --- Early exit: 0 or 1 mask -> do NOT modify the JSON at all --- + if N <= 1: + return sample + + # From here on we have at least 2 masks + h = int(sample["orig_img_h"]) + w = int(sample["orig_img_w"]) + pred_scores = sample.get("pred_scores", [1.0] * N) # fallback if scores missing + pred_boxes = sample.get("pred_boxes", None) + + assert N == len(pred_scores), "pred_masks and pred_scores must have same length" + if pred_boxes is not None: + assert N == len(pred_boxes), "pred_masks and pred_boxes must have same length" + + masks_bool = _decode_masks_to_torch_bool(pred_masks, h, w) # (N, H, W) + + order = sorted(range(N), key=lambda i: float(pred_scores[i]), reverse=True) + kept_idx: List[int] = [] + kept_masks: List[torch.Tensor] = [] + + for i in order: + cand = masks_bool[i].unsqueeze(0) # (1, H, W) + if len(kept_masks) == 0: + kept_idx.append(i) + kept_masks.append(masks_bool[i]) + continue + + kept_stack = torch.stack(kept_masks, dim=0) # (K, H, W) + iom_vals = mask_iom(cand, kept_stack).squeeze(0) # (K,) + if torch.any(iom_vals > iom_thresh): + continue # overlaps too much with a higher-scored kept mask + kept_idx.append(i) + kept_masks.append(masks_bool[i]) + + kept_idx_sorted = sorted(kept_idx) + + # Build filtered JSON (this *does* modify fields; only for N>=2 case) + out = dict(sample) + out["pred_masks"] = [pred_masks[i] for i in kept_idx_sorted] + out["pred_scores"] = [pred_scores[i] for i in kept_idx_sorted] + if pred_boxes is not None: + out["pred_boxes"] = [pred_boxes[i] for i in kept_idx_sorted] + out["kept_indices"] = kept_idx_sorted + out["removed_indices"] = [i for i in range(N) if i not in set(kept_idx_sorted)] + out["iom_threshold"] = float(iom_thresh) + return out diff --git a/third_party/sam3/sam3/agent/helpers/masks.py b/third_party/sam3/sam3/agent/helpers/masks.py new file mode 100644 index 0000000000000000000000000000000000000000..ba10fc8a860a9812becfd1b762a611234c466d3c --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/masks.py @@ -0,0 +1,561 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import copy +import itertools +from typing import Any, Iterator, List, Union + +import numpy as np +import pycocotools.mask as mask_util +import torch +from torch import device + +from .boxes import Boxes +from .memory import retry_if_cuda_oom +from .roi_align import ROIAlign + + +def polygon_area(x, y): + # Using the shoelace formula + # https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates + return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) + + +def polygons_to_bitmask( + polygons: List[np.ndarray], height: int, width: int +) -> np.ndarray: + """ + Args: + polygons (list[ndarray]): each array has shape (Nx2,) + height, width (int) + + Returns: + ndarray: a bool mask of shape (height, width) + """ + if len(polygons) == 0: + # COCOAPI does not support empty polygons + return np.zeros((height, width)).astype(bool) + rles = mask_util.frPyObjects(polygons, height, width) + rle = mask_util.merge(rles) + return mask_util.decode(rle).astype(bool) + + +def rasterize_polygons_within_box( + polygons: List[np.ndarray], box: np.ndarray, mask_size: int +) -> torch.Tensor: + """ + Rasterize the polygons into a mask image and + crop the mask content in the given box. + The cropped mask is resized to (mask_size, mask_size). + + This function is used when generating training targets for mask head in Mask R-CNN. + Given original ground-truth masks for an image, new ground-truth mask + training targets in the size of `mask_size x mask_size` + must be provided for each predicted box. This function will be called to + produce such targets. + + Args: + polygons (list[ndarray[float]]): a list of polygons, which represents an instance. + box: 4-element numpy array + mask_size (int): + + Returns: + Tensor: BoolTensor of shape (mask_size, mask_size) + """ + # 1. Shift the polygons w.r.t the boxes + w, h = box[2] - box[0], box[3] - box[1] + + polygons = copy.deepcopy(polygons) + for p in polygons: + p[0::2] = p[0::2] - box[0] + p[1::2] = p[1::2] - box[1] + + # 2. Rescale the polygons to the new box size + # max() to avoid division by small number + ratio_h = mask_size / max(h, 0.1) + ratio_w = mask_size / max(w, 0.1) + + if ratio_h == ratio_w: + for p in polygons: + p *= ratio_h + else: + for p in polygons: + p[0::2] *= ratio_w + p[1::2] *= ratio_h + + # 3. Rasterize the polygons with coco api + mask = polygons_to_bitmask(polygons, mask_size, mask_size) + mask = torch.from_numpy(mask) + return mask + + +class BitMasks: + """ + This class stores the segmentation masks for all objects in one image, in + the form of bitmaps. + + Attributes: + tensor: bool Tensor of N,H,W, representing N instances in the image. + """ + + def __init__(self, tensor: Union[torch.Tensor, np.ndarray]): + """ + Args: + tensor: bool Tensor of N,H,W, representing N instances in the image. + """ + if isinstance(tensor, torch.Tensor): + tensor = tensor.to(torch.bool) + else: + tensor = torch.as_tensor( + tensor, dtype=torch.bool, device=torch.device("cpu") + ) + assert tensor.dim() == 3, tensor.size() + self.image_size = tensor.shape[1:] + self.tensor = tensor + + @torch.jit.unused + def to(self, *args: Any, **kwargs: Any) -> "BitMasks": + return BitMasks(self.tensor.to(*args, **kwargs)) + + @property + def device(self) -> torch.device: + return self.tensor.device + + @torch.jit.unused + def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks": + """ + Returns: + BitMasks: Create a new :class:`BitMasks` by indexing. + + The following usage are allowed: + + 1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask. + 2. `new_masks = masks[2:10]`: return a slice of masks. + 3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor + with `length = len(masks)`. Nonzero elements in the vector will be selected. + + Note that the returned object might share storage with this object, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return BitMasks(self.tensor[item].unsqueeze(0)) + m = self.tensor[item] + assert ( + m.dim() == 3 + ), "Indexing on BitMasks with {} returns a tensor with shape {}!".format( + item, m.shape + ) + return BitMasks(m) + + @torch.jit.unused + def __iter__(self) -> torch.Tensor: + yield from self.tensor + + @torch.jit.unused + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.tensor)) + return s + + def __len__(self) -> int: + return self.tensor.shape[0] + + def nonempty(self) -> torch.Tensor: + """ + Find masks that are non-empty. + + Returns: + Tensor: a BoolTensor which represents + whether each mask is empty (False) or non-empty (True). + """ + return self.tensor.flatten(1).any(dim=1) + + @staticmethod + def from_polygon_masks( + polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]], + height: int, + width: int, + ) -> "BitMasks": + """ + Args: + polygon_masks (list[list[ndarray]] or PolygonMasks) + height, width (int) + """ + if isinstance(polygon_masks, PolygonMasks): + polygon_masks = polygon_masks.polygons + masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks] + if len(masks): + return BitMasks(torch.stack([torch.from_numpy(x) for x in masks])) + else: + return BitMasks(torch.empty(0, height, width, dtype=torch.bool)) + + @staticmethod + def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks": + """ + Args: + roi_masks: + height, width (int): + """ + return roi_masks.to_bitmasks(height, width) + + def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor: + """ + Crop each bitmask by the given box, and resize results to (mask_size, mask_size). + This can be used to prepare training targets for Mask R-CNN. + It has less reconstruction error compared to rasterization with polygons. + However we observe no difference in accuracy, + but BitMasks requires more memory to store all the masks. + + Args: + boxes (Tensor): Nx4 tensor storing the boxes for each mask + mask_size (int): the size of the rasterized mask. + + Returns: + Tensor: + A bool tensor of shape (N, mask_size, mask_size), where + N is the number of predicted boxes for this image. + """ + assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self)) + device = self.tensor.device + + batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[ + :, None + ] + rois = torch.cat([batch_inds, boxes], dim=1) # Nx5 + + bit_masks = self.tensor.to(dtype=torch.float32) + rois = rois.to(device=device) + output = ( + ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True) + .forward(bit_masks[:, None, :, :], rois) + .squeeze(1) + ) + output = output >= 0.5 + return output + + def get_bounding_boxes(self) -> Boxes: + """ + Returns: + Boxes: tight bounding boxes around bitmasks. + If a mask is empty, it's bounding box will be all zero. + """ + boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32) + x_any = torch.any(self.tensor, dim=1) + y_any = torch.any(self.tensor, dim=2) + for idx in range(self.tensor.shape[0]): + x = torch.where(x_any[idx, :])[0] + y = torch.where(y_any[idx, :])[0] + if len(x) > 0 and len(y) > 0: + boxes[idx, :] = torch.as_tensor( + [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32 + ) + return Boxes(boxes) + + @staticmethod + def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks": + """ + Concatenates a list of BitMasks into a single BitMasks + + Arguments: + bitmasks_list (list[BitMasks]) + + Returns: + BitMasks: the concatenated BitMasks + """ + assert isinstance(bitmasks_list, (list, tuple)) + assert len(bitmasks_list) > 0 + assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list) + + cat_bitmasks = type(bitmasks_list[0])( + torch.cat([bm.tensor for bm in bitmasks_list], dim=0) + ) + return cat_bitmasks + + +class PolygonMasks: + """ + This class stores the segmentation masks for all objects in one image, in the form of polygons. + + Attributes: + polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon. + """ + + def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]): + """ + Arguments: + polygons (list[list[np.ndarray]]): The first + level of the list correspond to individual instances, + the second level to all the polygons that compose the + instance, and the third level to the polygon coordinates. + The third level array should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + """ + if not isinstance(polygons, list): + raise ValueError( + "Cannot create PolygonMasks: Expect a list of list of polygons per image. " + "Got '{}' instead.".format(type(polygons)) + ) + + def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray: + # Use float64 for higher precision, because why not? + # Always put polygons on CPU (self.to is a no-op) since they + # are supposed to be small tensors. + # May need to change this assumption if GPU placement becomes useful + if isinstance(t, torch.Tensor): + t = t.cpu().numpy() + return np.asarray(t).astype("float64") + + def process_polygons( + polygons_per_instance: List[Union[torch.Tensor, np.ndarray]], + ) -> List[np.ndarray]: + if not isinstance(polygons_per_instance, list): + raise ValueError( + "Cannot create polygons: Expect a list of polygons per instance. " + "Got '{}' instead.".format(type(polygons_per_instance)) + ) + # transform each polygon to a numpy array + polygons_per_instance = [_make_array(p) for p in polygons_per_instance] + for polygon in polygons_per_instance: + if len(polygon) % 2 != 0 or len(polygon) < 6: + raise ValueError( + f"Cannot create a polygon from {len(polygon)} coordinates." + ) + return polygons_per_instance + + self.polygons: List[List[np.ndarray]] = [ + process_polygons(polygons_per_instance) + for polygons_per_instance in polygons + ] + + def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks": + return self + + @property + def device(self) -> torch.device: + return torch.device("cpu") + + def get_bounding_boxes(self) -> Boxes: + """ + Returns: + Boxes: tight bounding boxes around polygon masks. + """ + boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32) + for idx, polygons_per_instance in enumerate(self.polygons): + minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32) + maxxy = torch.zeros(2, dtype=torch.float32) + for polygon in polygons_per_instance: + coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32) + minxy = torch.min(minxy, torch.min(coords, dim=0).values) + maxxy = torch.max(maxxy, torch.max(coords, dim=0).values) + boxes[idx, :2] = minxy + boxes[idx, 2:] = maxxy + return Boxes(boxes) + + def nonempty(self) -> torch.Tensor: + """ + Find masks that are non-empty. + + Returns: + Tensor: + a BoolTensor which represents whether each mask is empty (False) or not (True). + """ + keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons] + return torch.from_numpy(np.asarray(keep, dtype=bool)) + + def __getitem__( + self, item: Union[int, slice, List[int], torch.BoolTensor] + ) -> "PolygonMasks": + """ + Support indexing over the instances and return a `PolygonMasks` object. + `item` can be: + + 1. An integer. It will return an object with only one instance. + 2. A slice. It will return an object with the selected instances. + 3. A list[int]. It will return an object with the selected instances, + correpsonding to the indices in the list. + 4. A vector mask of type BoolTensor, whose length is num_instances. + It will return an object with the instances whose mask is nonzero. + """ + if isinstance(item, int): + selected_polygons = [self.polygons[item]] + elif isinstance(item, slice): + selected_polygons = self.polygons[item] + elif isinstance(item, list): + selected_polygons = [self.polygons[i] for i in item] + elif isinstance(item, torch.Tensor): + # Polygons is a list, so we have to move the indices back to CPU. + if item.dtype == torch.bool: + assert item.dim() == 1, item.shape + item = item.nonzero().squeeze(1).cpu().numpy().tolist() + elif item.dtype in [torch.int32, torch.int64]: + item = item.cpu().numpy().tolist() + else: + raise ValueError( + "Unsupported tensor dtype={} for indexing!".format(item.dtype) + ) + selected_polygons = [self.polygons[i] for i in item] + return PolygonMasks(selected_polygons) + + def __iter__(self) -> Iterator[List[np.ndarray]]: + """ + Yields: + list[ndarray]: the polygons for one instance. + Each Tensor is a float64 vector representing a polygon. + """ + return iter(self.polygons) + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.polygons)) + return s + + def __len__(self) -> int: + return len(self.polygons) + + def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor: + """ + Crop each mask by the given box, and resize results to (mask_size, mask_size). + This can be used to prepare training targets for Mask R-CNN. + + Args: + boxes (Tensor): Nx4 tensor storing the boxes for each mask + mask_size (int): the size of the rasterized mask. + + Returns: + Tensor: A bool tensor of shape (N, mask_size, mask_size), where + N is the number of predicted boxes for this image. + """ + assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self)) + + device = boxes.device + # Put boxes on the CPU, as the polygon representation is not efficient GPU-wise + # (several small tensors for representing a single instance mask) + boxes = boxes.to(torch.device("cpu")) + + results = [ + rasterize_polygons_within_box(poly, box.numpy(), mask_size) + for poly, box in zip(self.polygons, boxes) + ] + """ + poly: list[list[float]], the polygons for one instance + box: a tensor of shape (4,) + """ + if len(results) == 0: + return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device) + return torch.stack(results, dim=0).to(device=device) + + def area(self): + """ + Computes area of the mask. + Only works with Polygons, using the shoelace formula: + https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates + + Returns: + Tensor: a vector, area for each instance + """ + + area = [] + for polygons_per_instance in self.polygons: + area_per_instance = 0 + for p in polygons_per_instance: + area_per_instance += polygon_area(p[0::2], p[1::2]) + area.append(area_per_instance) + + return torch.tensor(area) + + @staticmethod + def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks": + """ + Concatenates a list of PolygonMasks into a single PolygonMasks + + Arguments: + polymasks_list (list[PolygonMasks]) + + Returns: + PolygonMasks: the concatenated PolygonMasks + """ + assert isinstance(polymasks_list, (list, tuple)) + assert len(polymasks_list) > 0 + assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list) + + cat_polymasks = type(polymasks_list[0])( + list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list)) + ) + return cat_polymasks + + +class ROIMasks: + """ + Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given, + full-image bitmask can be obtained by "pasting" the mask on the region defined + by the corresponding ROI box. + """ + + def __init__(self, tensor: torch.Tensor): + """ + Args: + tensor: (N, M, M) mask tensor that defines the mask within each ROI. + """ + if tensor.dim() != 3: + raise ValueError("ROIMasks must take a masks of 3 dimension.") + self.tensor = tensor + + def to(self, device: torch.device) -> "ROIMasks": + return ROIMasks(self.tensor.to(device)) + + @property + def device(self) -> device: + return self.tensor.device + + def __len__(self): + return self.tensor.shape[0] + + def __getitem__(self, item) -> "ROIMasks": + """ + Returns: + ROIMasks: Create a new :class:`ROIMasks` by indexing. + + The following usage are allowed: + + 1. `new_masks = masks[2:10]`: return a slice of masks. + 2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor + with `length = len(masks)`. Nonzero elements in the vector will be selected. + + Note that the returned object might share storage with this object, + subject to Pytorch's indexing semantics. + """ + t = self.tensor[item] + if t.dim() != 3: + raise ValueError( + f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!" + ) + return ROIMasks(t) + + @torch.jit.unused + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.tensor)) + return s + + @torch.jit.unused + def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5): + """ + Args: see documentation of :func:`paste_masks_in_image`. + """ + from detectron2.layers.mask_ops import ( + _paste_masks_tensor_shape, + paste_masks_in_image, + ) + + if torch.jit.is_tracing(): + if isinstance(height, torch.Tensor): + paste_func = _paste_masks_tensor_shape + else: + paste_func = paste_masks_in_image + else: + paste_func = retry_if_cuda_oom(paste_masks_in_image) + bitmasks = paste_func( + self.tensor, boxes.tensor, (height, width), threshold=threshold + ) + return BitMasks(bitmasks) diff --git a/third_party/sam3/sam3/agent/helpers/memory.py b/third_party/sam3/sam3/agent/helpers/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb70861c35ad425fe3aa868a5f5f5f88e9f6c5a --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/memory.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging +from contextlib import contextmanager +from functools import wraps + +import torch + +__all__ = ["retry_if_cuda_oom"] + + +@contextmanager +def _ignore_torch_cuda_oom(): + """ + A context which ignores CUDA OOM exception from pytorch. + """ + try: + yield + except RuntimeError as e: + # NOTE: the string may change? + if "CUDA out of memory. " in str(e): + pass + else: + raise + + +def retry_if_cuda_oom(func): + """ + Makes a function retry itself after encountering + pytorch's CUDA OOM error. + It will first retry after calling `torch.cuda.empty_cache()`. + + If that still fails, it will then retry by trying to convert inputs to CPUs. + In this case, it expects the function to dispatch to CPU implementation. + The return values may become CPU tensors as well and it's user's + responsibility to convert it back to CUDA tensor if needed. + + Args: + func: a stateless callable that takes tensor-like objects as arguments + + Returns: + a callable which retries `func` if OOM is encountered. + + Examples: + :: + output = retry_if_cuda_oom(some_torch_function)(input1, input2) + # output may be on CPU even if inputs are on GPU + + Note: + 1. When converting inputs to CPU, it will only look at each argument and check + if it has `.device` and `.to` for conversion. Nested structures of tensors + are not supported. + + 2. Since the function might be called more than once, it has to be + stateless. + """ + + def maybe_to_cpu(x): + try: + like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to") + except AttributeError: + like_gpu_tensor = False + if like_gpu_tensor: + return x.to(device="cpu") + else: + return x + + @wraps(func) + def wrapped(*args, **kwargs): + with _ignore_torch_cuda_oom(): + return func(*args, **kwargs) + + # Clear cache and retry + torch.cuda.empty_cache() + with _ignore_torch_cuda_oom(): + return func(*args, **kwargs) + + # Try on CPU. This slows down the code significantly, therefore print a notice. + logger = logging.getLogger(__name__) + logger.info( + "Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func)) + ) + new_args = (maybe_to_cpu(x) for x in args) + new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()} + return func(*new_args, **new_kwargs) + + return wrapped diff --git a/third_party/sam3/sam3/agent/helpers/rle.py b/third_party/sam3/sam3/agent/helpers/rle.py new file mode 100644 index 0000000000000000000000000000000000000000..db3faefc4fda586e97b41c93b55436e625acbb2c --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/rle.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Some utilities for RLE encoding that doesn't require downloading the masks to the cpu""" + +import numpy as np +import torch +from pycocotools import mask as mask_util + + +@torch.no_grad() +def rle_encode(orig_mask, return_areas=False): + """Encodes a collection of masks in RLE format + + This function emulates the behavior of the COCO API's encode function, but + is executed partially on the GPU for faster execution. + + Args: + mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool + return_areas (bool): If True, add the areas of the masks as a part of + the RLE output dict under the "area" key. Default is False. + + Returns: + str: The RLE encoded masks + """ + assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)" + assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool" + + if orig_mask.numel() == 0: + return [] + + # First, transpose the spatial dimensions. + # This is necessary because the COCO API uses Fortran order + mask = orig_mask.transpose(1, 2) + + # Flatten the mask + flat_mask = mask.reshape(mask.shape[0], -1) + if return_areas: + mask_areas = flat_mask.sum(-1).tolist() + # Find the indices where the mask changes + differences = torch.ones( + mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool + ) + differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:] + differences[:, 0] = flat_mask[:, 0] + _, change_indices = torch.where(differences) + + try: + boundaries = torch.cumsum(differences.sum(-1), 0).cpu() + except RuntimeError as _: + boundaries = torch.cumsum(differences.cpu().sum(-1), 0) + + change_indices_clone = change_indices.clone() + # First pass computes the RLEs on GPU, in a flatten format + for i in range(mask.shape[0]): + # Get the change indices for this batch item + beg = 0 if i == 0 else boundaries[i - 1].item() + end = boundaries[i].item() + change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1] + + # Now we can split the RLES of each batch item, and convert them to strings + # No more gpu at this point + change_indices = change_indices.tolist() + + batch_rles = [] + # Process each mask in the batch separately + for i in range(mask.shape[0]): + beg = 0 if i == 0 else boundaries[i - 1].item() + end = boundaries[i].item() + run_lengths = change_indices[beg:end] + + uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])} + h, w = uncompressed_rle["size"] + rle = mask_util.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") + if return_areas: + rle["area"] = mask_areas[i] + batch_rles.append(rle) + + return batch_rles + + +def robust_rle_encode(masks): + """Encodes a collection of masks in RLE format. Uses the gpu version fist, falls back to the cpu version if it fails""" + + assert masks.ndim == 3, "Mask must be of shape (N, H, W)" + assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool" + + try: + return rle_encode(masks) + except RuntimeError as _: + masks = masks.cpu().numpy() + rles = [ + mask_util.encode( + np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F") + )[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + return rles + + +def ann_to_rle(segm, im_info): + """Convert annotation which can be polygons, uncompressed RLE to RLE. + Args: + ann (dict) : annotation object + Returns: + ann (rle) + """ + h, w = im_info["height"], im_info["width"] + if isinstance(segm, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = mask_util.frPyObjects(segm, h, w) + rle = mask_util.merge(rles) + elif isinstance(segm["counts"], list): + # uncompressed RLE + rle = mask_util.frPyObjects(segm, h, w) + else: + # rle + rle = segm + return rle diff --git a/third_party/sam3/sam3/agent/helpers/roi_align.py b/third_party/sam3/sam3/agent/helpers/roi_align.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f413f66c71c233a8b6aa1af65ead366bf46436 --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/roi_align.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from torch import nn +from torchvision.ops import roi_align + + +# NOTE: torchvision's RoIAlign has a different default aligned=False +class ROIAlign(nn.Module): + def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True): + """ + Args: + output_size (tuple): h, w + spatial_scale (float): scale the input boxes by this number + sampling_ratio (int): number of inputs samples to take for each output + sample. 0 to take samples densely. + aligned (bool): if False, use the legacy implementation in + Detectron. If True, align the results more perfectly. + + Note: + The meaning of aligned=True: + + Given a continuous coordinate c, its two neighboring pixel indices (in our + pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, + c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled + from the underlying signal at continuous coordinates 0.5 and 1.5). But the original + roi_align (aligned=False) does not subtract the 0.5 when computing neighboring + pixel indices and therefore it uses pixels with a slightly incorrect alignment + (relative to our pixel model) when performing bilinear interpolation. + + With `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5 + prior to calling roi_align. This produces the correct neighbors; see + detectron2/tests/test_roi_align.py for verification. + + The difference does not make a difference to the model's performance if + ROIAlign is used together with conv layers. + """ + super().__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + self.aligned = aligned + + from torchvision import __version__ + + version = tuple(int(x) for x in __version__.split(".")[:2]) + # https://github.com/pytorch/vision/pull/2438 + assert version >= (0, 7), "Require torchvision >= 0.7" + + def forward(self, input, rois): + """ + Args: + input: NCHW images + rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy. + """ + assert rois.dim() == 2 and rois.size(1) == 5 + if input.is_quantized: + input = input.dequantize() + return roi_align( + input, + rois.to(dtype=input.dtype), + self.output_size, + self.spatial_scale, + self.sampling_ratio, + self.aligned, + ) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ", aligned=" + str(self.aligned) + tmpstr += ")" + return tmpstr diff --git a/third_party/sam3/sam3/agent/helpers/rotated_boxes.py b/third_party/sam3/sam3/agent/helpers/rotated_boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..cd39af8b82bf10084ffea8f06c66df5facf1cce4 --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/rotated_boxes.py @@ -0,0 +1,535 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from __future__ import absolute_import, division, print_function, unicode_literals + +import math +from typing import List, Tuple + +import torch + +# from detectron2.layers.rotated_boxes import pairwise_iou_rotated + +from .boxes import Boxes + + +def pairwise_iou_rotated(boxes1, boxes2): + """ + Return intersection-over-union (Jaccard index) of boxes. + + Both sets of boxes are expected to be in + (x_center, y_center, width, height, angle) format. + + Arguments: + boxes1 (Tensor[N, 5]) + boxes2 (Tensor[M, 5]) + + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + return torch.ops.detectron2.box_iou_rotated(boxes1, boxes2) + + +class RotatedBoxes(Boxes): + """ + This structure stores a list of rotated boxes as a Nx5 torch.Tensor. + It supports some common methods about boxes + (`area`, `clip`, `nonempty`, etc), + and also behaves like a Tensor + (support indexing, `to(device)`, `.device`, and iteration over all boxes) + """ + + def __init__(self, tensor: torch.Tensor): + """ + Args: + tensor (Tensor[float]): a Nx5 matrix. Each row is + (x_center, y_center, width, height, angle), + in which angle is represented in degrees. + While there's no strict range restriction for it, + the recommended principal range is between [-180, 180) degrees. + + Assume we have a horizontal box B = (x_center, y_center, width, height), + where width is along the x-axis and height is along the y-axis. + The rotated box B_rot (x_center, y_center, width, height, angle) + can be seen as: + + 1. When angle == 0: + B_rot == B + 2. When angle > 0: + B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CCW; + 3. When angle < 0: + B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CW. + + Mathematically, since the right-handed coordinate system for image space + is (y, x), where y is top->down and x is left->right, the 4 vertices of the + rotated rectangle :math:`(yr_i, xr_i)` (i = 1, 2, 3, 4) can be obtained from + the vertices of the horizontal rectangle :math:`(y_i, x_i)` (i = 1, 2, 3, 4) + in the following way (:math:`\\theta = angle*\\pi/180` is the angle in radians, + :math:`(y_c, x_c)` is the center of the rectangle): + + .. math:: + + yr_i = \\cos(\\theta) (y_i - y_c) - \\sin(\\theta) (x_i - x_c) + y_c, + + xr_i = \\sin(\\theta) (y_i - y_c) + \\cos(\\theta) (x_i - x_c) + x_c, + + which is the standard rigid-body rotation transformation. + + Intuitively, the angle is + (1) the rotation angle from y-axis in image space + to the height vector (top->down in the box's local coordinate system) + of the box in CCW, and + (2) the rotation angle from x-axis in image space + to the width vector (left->right in the box's local coordinate system) + of the box in CCW. + + More intuitively, consider the following horizontal box ABCD represented + in (x1, y1, x2, y2): (3, 2, 7, 4), + covering the [3, 7] x [2, 4] region of the continuous coordinate system + which looks like this: + + .. code:: none + + O--------> x + | + | A---B + | | | + | D---C + | + v y + + Note that each capital letter represents one 0-dimensional geometric point + instead of a 'square pixel' here. + + In the example above, using (x, y) to represent a point we have: + + .. math:: + + O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4) + + We name vector AB = vector DC as the width vector in box's local coordinate system, and + vector AD = vector BC as the height vector in box's local coordinate system. Initially, + when angle = 0 degree, they're aligned with the positive directions of x-axis and y-axis + in the image space, respectively. + + For better illustration, we denote the center of the box as E, + + .. code:: none + + O--------> x + | + | A---B + | | E | + | D---C + | + v y + + where the center E = ((3+7)/2, (2+4)/2) = (5, 3). + + Also, + + .. math:: + + width = |AB| = |CD| = 7 - 3 = 4, + height = |AD| = |BC| = 4 - 2 = 2. + + Therefore, the corresponding representation for the same shape in rotated box in + (x_center, y_center, width, height, angle) format is: + + (5, 3, 4, 2, 0), + + Now, let's consider (5, 3, 4, 2, 90), which is rotated by 90 degrees + CCW (counter-clockwise) by definition. It looks like this: + + .. code:: none + + O--------> x + | B-C + | | | + | |E| + | | | + | A-D + v y + + The center E is still located at the same point (5, 3), while the vertices + ABCD are rotated by 90 degrees CCW with regard to E: + A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5) + + Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to + vector AD or vector BC (the top->down height vector in box's local coordinate system), + or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right + width vector in box's local coordinate system). + + .. math:: + + width = |AB| = |CD| = 5 - 1 = 4, + height = |AD| = |BC| = 6 - 4 = 2. + + Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise) + by definition? It looks like this: + + .. code:: none + + O--------> x + | D-A + | | | + | |E| + | | | + | C-B + v y + + The center E is still located at the same point (5, 3), while the vertices + ABCD are rotated by 90 degrees CW with regard to E: + A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1) + + .. math:: + + width = |AB| = |CD| = 5 - 1 = 4, + height = |AD| = |BC| = 6 - 4 = 2. + + This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU + will be 1. However, these two will generate different RoI Pooling results and + should not be treated as an identical box. + + On the other hand, it's easy to see that (X, Y, W, H, A) is identical to + (X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be + identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is + equivalent to rotating the same shape 90 degrees CW. + + We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180): + + .. code:: none + + O--------> x + | + | C---D + | | E | + | B---A + | + v y + + .. math:: + + A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2), + + width = |AB| = |CD| = 7 - 3 = 4, + height = |AD| = |BC| = 4 - 2 = 2. + + Finally, this is a very inaccurate (heavily quantized) illustration of + how (5, 3, 4, 2, 60) looks like in case anyone wonders: + + .. code:: none + + O--------> x + | B\ + | / C + | /E / + | A / + | `D + v y + + It's still a rectangle with center of (5, 3), width of 4 and height of 2, + but its angle (and thus orientation) is somewhere between + (5, 3, 4, 2, 0) and (5, 3, 4, 2, 90). + """ + device = ( + tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu") + ) + tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) + if tensor.numel() == 0: + # Use reshape, so we don't end up creating a new tensor that does not depend on + # the inputs (and consequently confuses jit) + tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=device) + assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size() + + self.tensor = tensor + + def clone(self) -> "RotatedBoxes": + """ + Clone the RotatedBoxes. + + Returns: + RotatedBoxes + """ + return RotatedBoxes(self.tensor.clone()) + + def to(self, device: torch.device, non_blocking: bool = False): + # Boxes are assumed float32 and does not support to(dtype) + return RotatedBoxes(self.tensor.to(device=device, non_blocking=non_blocking)) + + def area(self) -> torch.Tensor: + """ + Computes the area of all the boxes. + + Returns: + torch.Tensor: a vector with areas of each box. + """ + box = self.tensor + area = box[:, 2] * box[:, 3] + return area + + # Avoid in-place operations so that we can torchscript; NOTE: this creates a new tensor + def normalize_angles(self) -> None: + """ + Restrict angles to the range of [-180, 180) degrees + """ + angle_tensor = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0 + self.tensor = torch.cat((self.tensor[:, :4], angle_tensor[:, None]), dim=1) + + def clip( + self, box_size: Tuple[int, int], clip_angle_threshold: float = 1.0 + ) -> None: + """ + Clip (in place) the boxes by limiting x coordinates to the range [0, width] + and y coordinates to the range [0, height]. + + For RRPN: + Only clip boxes that are almost horizontal with a tolerance of + clip_angle_threshold to maintain backward compatibility. + + Rotated boxes beyond this threshold are not clipped for two reasons: + + 1. There are potentially multiple ways to clip a rotated box to make it + fit within the image. + 2. It's tricky to make the entire rectangular box fit within the image + and still be able to not leave out pixels of interest. + + Therefore we rely on ops like RoIAlignRotated to safely handle this. + + Args: + box_size (height, width): The clipping box's size. + clip_angle_threshold: + Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees), + we do the clipping as horizontal boxes. + """ + h, w = box_size + + # normalize angles to be within (-180, 180] degrees + self.normalize_angles() + + idx = torch.where(torch.abs(self.tensor[:, 4]) <= clip_angle_threshold)[0] + + # convert to (x1, y1, x2, y2) + x1 = self.tensor[idx, 0] - self.tensor[idx, 2] / 2.0 + y1 = self.tensor[idx, 1] - self.tensor[idx, 3] / 2.0 + x2 = self.tensor[idx, 0] + self.tensor[idx, 2] / 2.0 + y2 = self.tensor[idx, 1] + self.tensor[idx, 3] / 2.0 + + # clip + x1.clamp_(min=0, max=w) + y1.clamp_(min=0, max=h) + x2.clamp_(min=0, max=w) + y2.clamp_(min=0, max=h) + + # convert back to (xc, yc, w, h) + self.tensor[idx, 0] = (x1 + x2) / 2.0 + self.tensor[idx, 1] = (y1 + y2) / 2.0 + # make sure widths and heights do not increase due to numerical errors + self.tensor[idx, 2] = torch.min(self.tensor[idx, 2], x2 - x1) + self.tensor[idx, 3] = torch.min(self.tensor[idx, 3], y2 - y1) + + def nonempty(self, threshold: float = 0.0) -> torch.Tensor: + """ + Find boxes that are non-empty. + A box is considered empty, if either of its side is no larger than threshold. + + Returns: + Tensor: a binary vector which represents + whether each box is empty (False) or non-empty (True). + """ + box = self.tensor + widths = box[:, 2] + heights = box[:, 3] + keep = (widths > threshold) & (heights > threshold) + return keep + + def __getitem__(self, item) -> "RotatedBoxes": + """ + Returns: + RotatedBoxes: Create a new :class:`RotatedBoxes` by indexing. + + The following usage are allowed: + + 1. `new_boxes = boxes[3]`: return a `RotatedBoxes` which contains only one box. + 2. `new_boxes = boxes[2:10]`: return a slice of boxes. + 3. `new_boxes = boxes[vector]`, where vector is a torch.ByteTensor + with `length = len(boxes)`. Nonzero elements in the vector will be selected. + + Note that the returned RotatedBoxes might share storage with this RotatedBoxes, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return RotatedBoxes(self.tensor[item].view(1, -1)) + b = self.tensor[item] + assert ( + b.dim() == 2 + ), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item) + return RotatedBoxes(b) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __repr__(self) -> str: + return "RotatedBoxes(" + str(self.tensor) + ")" + + def inside_box( + self, box_size: Tuple[int, int], boundary_threshold: int = 0 + ) -> torch.Tensor: + """ + Args: + box_size (height, width): Size of the reference box covering + [0, width] x [0, height] + boundary_threshold (int): Boxes that extend beyond the reference box + boundary by more than boundary_threshold are considered "outside". + + For RRPN, it might not be necessary to call this function since it's common + for rotated box to extend to outside of the image boundaries + (the clip function only clips the near-horizontal boxes) + + Returns: + a binary vector, indicating whether each box is inside the reference box. + """ + height, width = box_size + + cnt_x = self.tensor[..., 0] + cnt_y = self.tensor[..., 1] + half_w = self.tensor[..., 2] / 2.0 + half_h = self.tensor[..., 3] / 2.0 + a = self.tensor[..., 4] + c = torch.abs(torch.cos(a * math.pi / 180.0)) + s = torch.abs(torch.sin(a * math.pi / 180.0)) + # This basically computes the horizontal bounding rectangle of the rotated box + max_rect_dx = c * half_w + s * half_h + max_rect_dy = c * half_h + s * half_w + + inds_inside = ( + (cnt_x - max_rect_dx >= -boundary_threshold) + & (cnt_y - max_rect_dy >= -boundary_threshold) + & (cnt_x + max_rect_dx < width + boundary_threshold) + & (cnt_y + max_rect_dy < height + boundary_threshold) + ) + + return inds_inside + + def get_centers(self) -> torch.Tensor: + """ + Returns: + The box centers in a Nx2 array of (x, y). + """ + return self.tensor[:, :2] + + def scale(self, scale_x: float, scale_y: float) -> None: + """ + Scale the rotated box with horizontal and vertical scaling factors + Note: when scale_factor_x != scale_factor_y, + the rotated box does not preserve the rectangular shape when the angle + is not a multiple of 90 degrees under resize transformation. + Instead, the shape is a parallelogram (that has skew) + Here we make an approximation by fitting a rotated rectangle to the parallelogram. + """ + self.tensor[:, 0] *= scale_x + self.tensor[:, 1] *= scale_y + theta = self.tensor[:, 4] * math.pi / 180.0 + c = torch.cos(theta) + s = torch.sin(theta) + + # In image space, y is top->down and x is left->right + # Consider the local coordintate system for the rotated box, + # where the box center is located at (0, 0), and the four vertices ABCD are + # A(-w / 2, -h / 2), B(w / 2, -h / 2), C(w / 2, h / 2), D(-w / 2, h / 2) + # the midpoint of the left edge AD of the rotated box E is: + # E = (A+D)/2 = (-w / 2, 0) + # the midpoint of the top edge AB of the rotated box F is: + # F(0, -h / 2) + # To get the old coordinates in the global system, apply the rotation transformation + # (Note: the right-handed coordinate system for image space is yOx): + # (old_x, old_y) = (s * y + c * x, c * y - s * x) + # E(old) = (s * 0 + c * (-w/2), c * 0 - s * (-w/2)) = (-c * w / 2, s * w / 2) + # F(old) = (s * (-h / 2) + c * 0, c * (-h / 2) - s * 0) = (-s * h / 2, -c * h / 2) + # After applying the scaling factor (sfx, sfy): + # E(new) = (-sfx * c * w / 2, sfy * s * w / 2) + # F(new) = (-sfx * s * h / 2, -sfy * c * h / 2) + # The new width after scaling tranformation becomes: + + # w(new) = |E(new) - O| * 2 + # = sqrt[(sfx * c * w / 2)^2 + (sfy * s * w / 2)^2] * 2 + # = sqrt[(sfx * c)^2 + (sfy * s)^2] * w + # i.e., scale_factor_w = sqrt[(sfx * c)^2 + (sfy * s)^2] + # + # For example, + # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_w == scale_factor_x; + # when |angle| = 90, c = 0, |s| = 1, scale_factor_w == scale_factor_y + self.tensor[:, 2] *= torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2) + + # h(new) = |F(new) - O| * 2 + # = sqrt[(sfx * s * h / 2)^2 + (sfy * c * h / 2)^2] * 2 + # = sqrt[(sfx * s)^2 + (sfy * c)^2] * h + # i.e., scale_factor_h = sqrt[(sfx * s)^2 + (sfy * c)^2] + # + # For example, + # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_h == scale_factor_y; + # when |angle| = 90, c = 0, |s| = 1, scale_factor_h == scale_factor_x + self.tensor[:, 3] *= torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2) + + # The angle is the rotation angle from y-axis in image space to the height + # vector (top->down in the box's local coordinate system) of the box in CCW. + # + # angle(new) = angle_yOx(O - F(new)) + # = angle_yOx( (sfx * s * h / 2, sfy * c * h / 2) ) + # = atan2(sfx * s * h / 2, sfy * c * h / 2) + # = atan2(sfx * s, sfy * c) + # + # For example, + # when sfx == sfy, angle(new) == atan2(s, c) == angle(old) + self.tensor[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi + + @classmethod + def cat(cls, boxes_list: List["RotatedBoxes"]) -> "RotatedBoxes": + """ + Concatenates a list of RotatedBoxes into a single RotatedBoxes + + Arguments: + boxes_list (list[RotatedBoxes]) + + Returns: + RotatedBoxes: the concatenated RotatedBoxes + """ + assert isinstance(boxes_list, (list, tuple)) + if len(boxes_list) == 0: + return cls(torch.empty(0)) + assert all([isinstance(box, RotatedBoxes) for box in boxes_list]) + + # use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input + cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0)) + return cat_boxes + + @property + def device(self) -> torch.device: + return self.tensor.device + + @torch.jit.unused + def __iter__(self): + """ + Yield a box as a Tensor of shape (5,) at a time. + """ + yield from self.tensor + + +def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> None: + """ + Given two lists of rotated boxes of size N and M, + compute the IoU (intersection over union) + between **all** N x M pairs of boxes. + The box order must be (x_center, y_center, width, height, angle). + + Args: + boxes1, boxes2 (RotatedBoxes): + two `RotatedBoxes`. Contains N & M rotated boxes, respectively. + + Returns: + Tensor: IoU, sized [N,M]. + """ + + return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor) diff --git a/third_party/sam3/sam3/agent/helpers/som_utils.py b/third_party/sam3/sam3/agent/helpers/som_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81c6f63f86711acf0930b20dd83c923fb6189361 --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/som_utils.py @@ -0,0 +1,408 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import colorsys +from dataclasses import dataclass +from typing import List, Tuple + +import cv2 +import matplotlib as mpl +import matplotlib.colors as mplc +import numpy as np +import pycocotools.mask as mask_utils + + +def rgb_to_hex(rgb_color): + """ + Convert a rgb color to hex color. + + Args: + rgb_color (tuple/list of ints): RGB color in tuple or list format. + + Returns: + str: Hex color. + + Example: + ``` + >>> rgb_to_hex((255, 0, 244)) + '#ff00ff' + ``` + """ + return "#" + "".join([hex(c)[2:].zfill(2) for c in rgb_color]) + + +# DEFAULT_COLOR_HEX_TO_NAME = { +# rgb_to_hex((255, 0, 0)): "red", +# rgb_to_hex((0, 255, 0)): "lime", +# rgb_to_hex((0, 0, 255)): "blue", +# rgb_to_hex((255, 255, 0)): "yellow", +# rgb_to_hex((255, 0, 255)): "fuchsia", +# rgb_to_hex((0, 255, 255)): "aqua", +# rgb_to_hex((255, 165, 0)): "orange", +# rgb_to_hex((128, 0, 128)): "purple", +# rgb_to_hex((255, 215, 0)): "gold", +# } + +# Assuming rgb_to_hex is a function that converts an (R, G, B) tuple to a hex string. +# For example: def rgb_to_hex(rgb): return '#%02x%02x%02x' % rgb + +DEFAULT_COLOR_HEX_TO_NAME = { + # The top 20 approved colors + rgb_to_hex((255, 255, 0)): "yellow", + rgb_to_hex((0, 255, 0)): "lime", + rgb_to_hex((0, 255, 255)): "cyan", + rgb_to_hex((255, 0, 255)): "magenta", + rgb_to_hex((255, 0, 0)): "red", + rgb_to_hex((255, 127, 0)): "orange", + rgb_to_hex((127, 255, 0)): "chartreuse", + rgb_to_hex((0, 255, 127)): "spring green", + rgb_to_hex((255, 0, 127)): "rose", + rgb_to_hex((127, 0, 255)): "violet", + rgb_to_hex((192, 255, 0)): "electric lime", + rgb_to_hex((255, 192, 0)): "vivid orange", + rgb_to_hex((0, 255, 192)): "turquoise", + rgb_to_hex((192, 0, 255)): "bright violet", + rgb_to_hex((255, 0, 192)): "bright pink", + rgb_to_hex((255, 64, 0)): "fiery orange", + rgb_to_hex((64, 255, 0)): "bright chartreuse", + rgb_to_hex((0, 255, 64)): "malachite", + rgb_to_hex((64, 0, 255)): "deep violet", + rgb_to_hex((255, 0, 64)): "hot pink", +} + + +DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME.keys()) + + +def _validate_color_hex(color_hex: str): + color_hex = color_hex.lstrip("#") + if not all(c in "0123456789abcdefABCDEF" for c in color_hex): + raise ValueError("Invalid characters in color hash") + if len(color_hex) not in (3, 6): + raise ValueError("Invalid length of color hash") + + +# copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py +@dataclass +class Color: + """ + Represents a color in RGB format. + + Attributes: + r (int): Red channel. + g (int): Green channel. + b (int): Blue channel. + """ + + r: int + g: int + b: int + + @classmethod + def from_hex(cls, color_hex: str): + """ + Create a Color instance from a hex string. + + Args: + color_hex (str): Hex string of the color. + + Returns: + Color: Instance representing the color. + + Example: + ``` + >>> Color.from_hex('#ff00ff') + Color(r=255, g=0, b=255) + ``` + """ + _validate_color_hex(color_hex) + color_hex = color_hex.lstrip("#") + if len(color_hex) == 3: + color_hex = "".join(c * 2 for c in color_hex) + r, g, b = (int(color_hex[i : i + 2], 16) for i in range(0, 6, 2)) + return cls(r, g, b) + + @classmethod + def to_hex(cls, color): + """ + Convert a Color instance to a hex string. + + Args: + color (Color): Color instance of color. + + Returns: + Color: a hex string. + """ + return rgb_to_hex((color.r, color.g, color.b)) + + def as_rgb(self) -> Tuple[int, int, int]: + """ + Returns the color as an RGB tuple. + + Returns: + Tuple[int, int, int]: RGB tuple. + + Example: + ``` + >>> color.as_rgb() + (255, 0, 255) + ``` + """ + return self.r, self.g, self.b + + def as_bgr(self) -> Tuple[int, int, int]: + """ + Returns the color as a BGR tuple. + + Returns: + Tuple[int, int, int]: BGR tuple. + + Example: + ``` + >>> color.as_bgr() + (255, 0, 255) + ``` + """ + return self.b, self.g, self.r + + @classmethod + def white(cls): + return Color.from_hex(color_hex="#ffffff") + + @classmethod + def black(cls): + return Color.from_hex(color_hex="#000000") + + @classmethod + def red(cls): + return Color.from_hex(color_hex="#ff0000") + + @classmethod + def green(cls): + return Color.from_hex(color_hex="#00ff00") + + @classmethod + def blue(cls): + return Color.from_hex(color_hex="#0000ff") + + +@dataclass +class ColorPalette: + colors: List[Color] + + @classmethod + def default(cls): + """ + Returns a default color palette. + + Returns: + ColorPalette: A ColorPalette instance with default colors. + + Example: + ``` + >>> ColorPalette.default() + ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...]) + ``` + """ + return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE) + + @classmethod + def from_hex(cls, color_hex_list: List[str]): + """ + Create a ColorPalette instance from a list of hex strings. + + Args: + color_hex_list (List[str]): List of color hex strings. + + Returns: + ColorPalette: A ColorPalette instance. + + Example: + ``` + >>> ColorPalette.from_hex(['#ff0000', '#00ff00', '#0000ff']) + ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...]) + ``` + """ + colors = [Color.from_hex(color_hex) for color_hex in color_hex_list] + return cls(colors) + + def by_idx(self, idx: int) -> Color: + """ + Return the color at a given index in the palette. + + Args: + idx (int): Index of the color in the palette. + + Returns: + Color: Color at the given index. + + Example: + ``` + >>> color_palette.by_idx(1) + Color(r=0, g=255, b=0) + ``` + """ + if idx < 0: + raise ValueError("idx argument should not be negative") + idx = idx % len(self.colors) + return self.colors[idx] + + def find_farthest_color(self, img_array): + """ + Return the color that is the farthest from the given color. + + Args: + img_array (np array): any *x3 np array, 3 is the RGB color channel. + + Returns: + Color: Farthest color. + + """ + # Reshape the image array for broadcasting + img_array = img_array.reshape((-1, 3)) + + # Convert colors dictionary to a NumPy array + color_values = np.array([[c.r, c.g, c.b] for c in self.colors]) + + # Calculate the Euclidean distance between the colors and each pixel in the image + # Broadcasting happens here: img_array shape is (num_pixels, 3), color_values shape is (num_colors, 3) + distances = np.sqrt( + np.sum((img_array[:, np.newaxis, :] - color_values) ** 2, axis=2) + ) + + # Average the distances for each color + mean_distances = np.mean(distances, axis=0) + + # return the farthest color + farthest_idx = np.argmax(mean_distances) + farthest_color = self.colors[farthest_idx] + farthest_color_hex = Color.to_hex(farthest_color) + if farthest_color_hex in DEFAULT_COLOR_HEX_TO_NAME: + farthest_color_name = DEFAULT_COLOR_HEX_TO_NAME[farthest_color_hex] + else: + farthest_color_name = "unknown" + + return farthest_color, farthest_color_name + + +def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0): + x0, y0, width, height = box_coord + ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=edge_color, + linewidth=linewidth, + alpha=alpha, + linestyle=line_style, + ) + ) + + +def draw_text( + ax, + text, + position, + font_size=None, + color="g", + horizontal_alignment="left", + rotation=0, +): + if not font_size: + font_size = mpl.rcParams["font.size"] + + color = np.maximum(list(mplc.to_rgb(color)), 0.2) + color[np.argmax(color)] = max(0.8, np.max(color)) + + x, y = position + ax.text( + x, + y, + text, + size=font_size, + family="sans-serif", + bbox={"facecolor": "none", "alpha": 0.5, "pad": 0.7, "edgecolor": "none"}, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + rotation=rotation, + ) + + +def draw_mask( + ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None +): + if isinstance(rle, dict): + mask = mask_utils.decode(rle) + elif isinstance(rle, np.ndarray): + mask = rle + else: + raise ValueError(f"Unsupported type for rle: {type(rle)}") + + mask_upsampled = None + if upsample_factor > 1.0 and show_holes: + assert rle_upsampled is not None + if isinstance(rle_upsampled, dict): + mask_upsampled = mask_utils.decode(rle_upsampled) + elif isinstance(rle_upsampled, np.ndarray): + mask_upsampled = rle_upsampled + else: + raise ValueError(f"Unsupported type for rle: {type(rle)}") + + if show_holes: + if mask_upsampled is None: + mask_upsampled = mask + h, w = mask_upsampled.shape + mask_img = np.zeros((h, w, 4)) + mask_img[:, :, :-1] = color[np.newaxis, np.newaxis, :] + mask_img[:, :, -1] = mask_upsampled * alpha + ax.imshow(mask_img) + + *_, contours, _ = cv2.findContours( + mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours] + facecolor = (0, 0, 0, 0) if show_holes else color + if alpha > 0.8: + edge_color = _change_color_brightness(color, brightness_factor=-0.7) + else: + edge_color = color + for cont in upsampled_contours: + polygon = mpl.patches.Polygon( + [el[0] for el in cont], + edgecolor=edge_color, + linewidth=2.0, + facecolor=facecolor, + ) + ax.add_patch(polygon) + + +def _change_color_brightness(color, brightness_factor): + """ + Depending on the brightness_factor, gives a lighter or darker color i.e. a color with + less or more saturation than the original color. + + Args: + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of + 0 will correspond to no change, a factor in [-1.0, 0) range will result in + a darker color and a factor in (0, 1.0] range will result in a lighter color. + + Returns: + modified_color (tuple[double]): a tuple containing the RGB values of the + modified color. Each value in the tuple is in the [0.0, 1.0] range. + """ + assert brightness_factor >= -1.0 and brightness_factor <= 1.0 + color = mplc.to_rgb(color) + polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) + modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) + modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness + modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness + modified_color = colorsys.hls_to_rgb( + polygon_color[0], modified_lightness, polygon_color[2] + ) + return modified_color diff --git a/third_party/sam3/sam3/agent/helpers/visualizer.py b/third_party/sam3/sam3/agent/helpers/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..bab3dff43fae9c54614c19d40c261462b7d1bbab --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/visualizer.py @@ -0,0 +1,1663 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import colorsys +import logging +import math +import random +from enum import Enum, unique + +import cv2 +import matplotlib as mpl +import matplotlib.colors as mplc +import matplotlib.figure as mplfigure +import numpy as np +import pycocotools.mask as mask_util +import torch +from iopath.common.file_io import PathManager +from matplotlib.backends.backend_agg import FigureCanvasAgg +from PIL import Image + +from .boxes import Boxes, BoxMode +from .color_map import random_color +from .keypoints import Keypoints +from .masks import BitMasks, PolygonMasks +from .rotated_boxes import RotatedBoxes + +logger = logging.getLogger(__name__) + + +__all__ = ["ColorMode", "VisImage", "Visualizer"] + + +_SMALL_OBJECT_AREA_THRESH = 1000 +_LARGE_MASK_AREA_THRESH = 120000 +_OFF_WHITE = (1.0, 1.0, 240.0 / 255) +_BLACK = (0, 0, 0) +_RED = (1.0, 0, 0) + +_KEYPOINT_THRESHOLD = 0.05 + + +@unique +class ColorMode(Enum): + """ + Enum of different color modes to use for instance visualizations. + """ + + IMAGE = 0 + """ + Picks a random color for every instance and overlay segmentations with low opacity. + """ + SEGMENTATION = 1 + """ + Let instances of the same category have similar colors + (from metadata.thing_colors), and overlay them with + high opacity. This provides more attention on the quality of segmentation. + """ + IMAGE_BW = 2 + """ + Same as IMAGE, but convert all areas without masks to gray-scale. + Only available for drawing per-instance mask predictions. + """ + + +class GenericMask: + """ + Attribute: + polygons (list[ndarray]): list[ndarray]: polygons for this mask. + Each ndarray has format [x, y, x, y, ...] + mask (ndarray): a binary mask + """ + + def __init__(self, mask_or_polygons, height, width): + self._mask = self._polygons = self._has_holes = None + self.height = height + self.width = width + + m = mask_or_polygons + if isinstance(m, dict): + # RLEs + assert "counts" in m and "size" in m + if isinstance(m["counts"], list): # uncompressed RLEs + h, w = m["size"] + assert h == height and w == width + m = mask_util.frPyObjects(m, h, w) + self._mask = mask_util.decode(m)[:, :] + return + + if isinstance(m, list): # list[ndarray] + self._polygons = [np.asarray(x).reshape(-1) for x in m] + return + + if isinstance(m, np.ndarray): # assumed to be a binary mask + assert m.shape[1] != 2, m.shape + assert m.shape == ( + height, + width, + ), f"mask shape: {m.shape}, target dims: {height}, {width}" + self._mask = m.astype("uint8") + return + + raise ValueError( + "GenericMask cannot handle object {} of type '{}'".format(m, type(m)) + ) + + @property + def mask(self): + if self._mask is None: + self._mask = self.polygons_to_mask(self._polygons) + return self._mask + + @property + def polygons(self): + if self._polygons is None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + return self._polygons + + @property + def has_holes(self): + if self._has_holes is None: + if self._mask is not None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + else: + self._has_holes = ( + False # if original format is polygon, does not have holes + ) + return self._has_holes + + def mask_to_polygons(self, mask): + # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level + # hierarchy. External contours (boundary) of the object are placed in hierarchy-1. + # Internal contours (holes) are placed in hierarchy-2. + # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours. + mask = np.ascontiguousarray( + mask + ) # some versions of cv2 does not support incontiguous arr + res = cv2.findContours( + mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE + ) + hierarchy = res[-1] + if hierarchy is None: # empty mask + return [], False + has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0 + res = res[-2] + res = [x.flatten() for x in res] + # These coordinates from OpenCV are integers in range [0, W-1 or H-1]. + # We add 0.5 to turn them into real-value coordinate space. A better solution + # would be to first +0.5 and then dilate the returned polygon by 0.5. + res = [x + 0.5 for x in res if len(x) >= 6] + return res, has_holes + + def polygons_to_mask(self, polygons): + rle = mask_util.frPyObjects(polygons, self.height, self.width) + rle = mask_util.merge(rle) + return mask_util.decode(rle)[:, :] + + def area(self): + return self.mask.sum() + + def bbox(self): + p = mask_util.frPyObjects(self.polygons, self.height, self.width) + p = mask_util.merge(p) + bbox = mask_util.toBbox(p) + bbox[2] += bbox[0] + bbox[3] += bbox[1] + return bbox + + +class _PanopticPrediction: + """ + Unify different panoptic annotation/prediction formats + """ + + def __init__(self, panoptic_seg, segments_info, metadata=None): + if segments_info is None: + assert metadata is not None + # If "segments_info" is None, we assume "panoptic_img" is a + # H*W int32 image storing the panoptic_id in the format of + # category_id * label_divisor + instance_id. We reserve -1 for + # VOID label. + label_divisor = metadata.label_divisor + segments_info = [] + for panoptic_label in np.unique(panoptic_seg.numpy()): + if panoptic_label == -1: + # VOID region. + continue + pred_class = panoptic_label // label_divisor + isthing = ( + pred_class in metadata.thing_dataset_id_to_contiguous_id.values() + ) + segments_info.append( + { + "id": int(panoptic_label), + "category_id": int(pred_class), + "isthing": bool(isthing), + } + ) + del metadata + + self._seg = panoptic_seg + + self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info + segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True) + areas = areas.numpy() + sorted_idxs = np.argsort(-areas) + self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs] + self._seg_ids = self._seg_ids.tolist() + for sid, area in zip(self._seg_ids, self._seg_areas): + if sid in self._sinfo: + self._sinfo[sid]["area"] = float(area) + + def non_empty_mask(self): + """ + Returns: + (H, W) array, a mask for all pixels that have a prediction + """ + empty_ids = [] + for id in self._seg_ids: + if id not in self._sinfo: + empty_ids.append(id) + if len(empty_ids) == 0: + return np.zeros(self._seg.shape, dtype=np.uint8) + assert ( + len(empty_ids) == 1 + ), ">1 ids corresponds to no labels. This is currently not supported" + return (self._seg != empty_ids[0]).numpy().astype(np.bool) + + def semantic_masks(self): + for sid in self._seg_ids: + sinfo = self._sinfo.get(sid) + if sinfo is None or sinfo["isthing"]: + # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions. + continue + yield (self._seg == sid).numpy().astype(np.bool), sinfo + + def instance_masks(self): + for sid in self._seg_ids: + sinfo = self._sinfo.get(sid) + if sinfo is None or not sinfo["isthing"]: + continue + mask = (self._seg == sid).numpy().astype(np.bool) + if mask.sum() > 0: + yield mask, sinfo + + +def _create_text_labels(classes, scores, class_names, is_crowd=None): + """ + Args: + classes (list[int] or None): + scores (list[float] or None): + class_names (list[str] or None): + is_crowd (list[bool] or None): + + Returns: + list[str] or None + """ + labels = None + if classes is not None: + if class_names is not None and len(class_names) > 0: + labels = [class_names[i] for i in classes] + else: + labels = [str(i) for i in classes] + if scores is not None: + if labels is None: + labels = ["{:.0f}%".format(s * 100) for s in scores] + else: + labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)] + if labels is not None and is_crowd is not None: + labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)] + return labels + + +class VisImage: + def __init__(self, img, scale=1.0): + """ + Args: + img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255]. + scale (float): scale the input image + """ + self.img = img + self.scale = scale + self.width, self.height = img.shape[1], img.shape[0] + self._setup_figure(img) + + def _setup_figure(self, img): + """ + Args: + Same as in :meth:`__init__()`. + + Returns: + fig (matplotlib.pyplot.figure): top level container for all the image plot elements. + ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system. + """ + fig = mplfigure.Figure(frameon=False) + self.dpi = fig.get_dpi() + # add a small 1e-2 to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches( + (self.width * self.scale + 1e-2) / self.dpi, + (self.height * self.scale + 1e-2) / self.dpi, + ) + self.canvas = FigureCanvasAgg(fig) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + self.fig = fig + self.ax = ax + self.reset_image(img) + + def reset_image(self, img): + """ + Args: + img: same as in __init__ + """ + img = img.astype("uint8") + self.ax.imshow( + img, extent=(0, self.width, self.height, 0), interpolation="nearest" + ) + + def save(self, filepath): + """ + Args: + filepath (str): a string that contains the absolute path, including the file name, where + the visualized image will be saved. + """ + self.fig.savefig(filepath) + + def get_image(self): + """ + Returns: + ndarray: + the visualized image of shape (H, W, 3) (RGB) in uint8 type. + The shape is scaled w.r.t the input image using the given `scale` argument. + """ + canvas = self.canvas + s, (width, height) = canvas.print_to_buffer() + # buf = io.BytesIO() # works for cairo backend + # canvas.print_rgba(buf) + # width, height = self.width, self.height + # s = buf.getvalue() + + buffer = np.frombuffer(s, dtype="uint8") + + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype("uint8") + + +class Visualizer: + """ + Visualizer that draws data about detection/segmentation on images. + + It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` + that draw primitive objects to images, as well as high-level wrappers like + `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` + that draw composite data in some pre-defined style. + + Note that the exact visualization style for the high-level wrappers are subject to change. + Style such as color, opacity, label contents, visibility of labels, or even the visibility + of objects themselves (e.g. when the object is too small) may change according + to different heuristics, as long as the results still look visually reasonable. + + To obtain a consistent style, you can implement custom drawing functions with the + abovementioned primitive methods instead. If you need more customized visualization + styles, you can process the data yourself following their format documented in + tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not + intend to satisfy everyone's preference on drawing styles. + + This visualizer focuses on high rendering quality rather than performance. It is not + designed to be used for real-time applications. + """ + + def __init__( + self, + img_rgb, + metadata=None, + scale=1.0, + instance_mode=ColorMode.IMAGE, + font_size_multiplier=1.3, + boarder_width_multiplier=1.5, + ): + """ + Args: + img_rgb: a numpy array of shape (H, W, C), where H and W correspond to + the height and width of the image respectively. C is the number of + color channels. The image is required to be in RGB format since that + is a requirement of the Matplotlib library. The image is also expected + to be in the range [0, 255]. + metadata (Metadata): dataset metadata (e.g. class names and colors) + instance_mode (ColorMode): defines one of the pre-defined style for drawing + instances on an image. + """ + self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) + self.boarder_width_multiplier = boarder_width_multiplier + # if metadata is None: + # metadata = MetadataCatalog.get("__nonexist__") + # self.metadata = metadata + self.output = VisImage(self.img, scale=scale) + self.cpu_device = torch.device("cpu") + + # too small texts are useless, therefore clamp to 9 + self._default_font_size = ( + max(np.sqrt(self.output.height * self.output.width) // 60, 15 // scale) + * font_size_multiplier + ) + # self._default_font_size = 18 + self._instance_mode = instance_mode + self.keypoint_threshold = _KEYPOINT_THRESHOLD + + import matplotlib.colors as mcolors + + css4_colors = mcolors.CSS4_COLORS + self.color_proposals = [ + list(mcolors.hex2color(color)) for color in css4_colors.values() + ] + + def draw_instance_predictions(self, predictions): + """ + Draw instance-level prediction results on an image. + + Args: + predictions (Instances): the output of an instance detection/segmentation + model. Following fields will be used to draw: + "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). + + Returns: + output (VisImage): image object with visualizations. + """ + boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = ( + predictions.pred_classes.tolist() + if predictions.has("pred_classes") + else None + ) + labels = _create_text_labels( + classes, scores, self.metadata.get("thing_classes", None) + ) + keypoints = ( + predictions.pred_keypoints if predictions.has("pred_keypoints") else None + ) + + keep = (scores > 0.5).cpu() + boxes = boxes[keep] + scores = scores[keep] + classes = np.array(classes) + classes = classes[np.array(keep)] + labels = np.array(labels) + labels = labels[np.array(keep)] + + if predictions.has("pred_masks"): + masks = np.asarray(predictions.pred_masks) + masks = masks[np.array(keep)] + masks = [ + GenericMask(x, self.output.height, self.output.width) for x in masks + ] + else: + masks = None + + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get( + "thing_colors" + ): + # if self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in classes + ] + alpha = 0.4 + else: + colors = None + alpha = 0.4 + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image( + self._create_grayscale_image( + (predictions.pred_masks.any(dim=0) > 0).numpy() + if predictions.has("pred_masks") + else None + ) + ) + alpha = 0.3 + + self.overlay_instances( + masks=masks, + boxes=boxes, + labels=labels, + keypoints=keypoints, + assigned_colors=colors, + alpha=alpha, + ) + return self.output + + def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7): + """ + Draw semantic segmentation predictions/labels. + + Args: + sem_seg (Tensor or ndarray): the segmentation of shape (H, W). + Each value is the integer label of the pixel. + area_threshold (int): segments with less than `area_threshold` are not drawn. + alpha (float): the larger it is, the more opaque the segmentations are. + + Returns: + output (VisImage): image object with visualizations. + """ + if isinstance(sem_seg, torch.Tensor): + sem_seg = sem_seg.numpy() + labels, areas = np.unique(sem_seg, return_counts=True) + sorted_idxs = np.argsort(-areas).tolist() + labels = labels[sorted_idxs] + for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels): + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] + except (AttributeError, IndexError): + mask_color = None + + binary_mask = (sem_seg == label).astype(np.uint8) + text = self.metadata.stuff_classes[label] + self.draw_binary_mask( + binary_mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + return self.output + + def draw_panoptic_seg( + self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7 + ): + """ + Draw panoptic prediction annotations or results. + + Args: + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each + segment. + segments_info (list[dict] or None): Describe each segment in `panoptic_seg`. + If it is a ``list[dict]``, each dict contains keys "id", "category_id". + If None, category id of each pixel is computed by + ``pixel // metadata.label_divisor``. + area_threshold (int): stuff segments with less than `area_threshold` are not drawn. + + Returns: + output (VisImage): image object with visualizations. + """ + pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata) + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask())) + + # draw mask for all semantic segments first i.e. "stuff" + for mask, sinfo in pred.semantic_masks(): + category_idx = sinfo["category_id"] + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]] + except AttributeError: + mask_color = None + + text = ( + self.metadata.stuff_classes[category_idx] + .replace("-other", "") + .replace("-merged", "") + ) + self.draw_binary_mask( + mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + + # draw mask for all instances second + all_instances = list(pred.instance_masks()) + if len(all_instances) == 0: + return self.output + masks, sinfo = list(zip(*all_instances)) + category_ids = [x["category_id"] for x in sinfo] + + try: + scores = [x["score"] for x in sinfo] + except KeyError: + scores = None + class_names = [ + name.replace("-other", "").replace("-merged", "") + for name in self.metadata.thing_classes + ] + labels = _create_text_labels( + category_ids, scores, class_names, [x.get("iscrowd", 0) for x in sinfo] + ) + + try: + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in category_ids + ] + except AttributeError: + colors = None + self.overlay_instances( + masks=masks, labels=labels, assigned_colors=colors, alpha=alpha + ) + + return self.output + + draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility + + def draw_dataset_dict(self, dic): + """ + Draw annotations/segmentaions in Detectron2 Dataset format. + + Args: + dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format. + + Returns: + output (VisImage): image object with visualizations. + """ + annos = dic.get("annotations", None) + if annos: + if "segmentation" in annos[0]: + masks = [x["segmentation"] for x in annos] + else: + masks = None + if "keypoints" in annos[0]: + keypts = [x["keypoints"] for x in annos] + keypts = np.array(keypts).reshape(len(annos), -1, 3) + else: + keypts = None + + boxes = [ + ( + BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS) + if len(x["bbox"]) == 4 + else x["bbox"] + ) + for x in annos + ] + + colors = None + category_ids = [x["category_id"] for x in annos] + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get( + "thing_colors" + ): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in category_ids + ] + names = self.metadata.get("thing_classes", None) + labels = _create_text_labels( + category_ids, + scores=None, + class_names=names, + is_crowd=[x.get("iscrowd", 0) for x in annos], + ) + self.overlay_instances( + labels=labels, + boxes=boxes, + masks=masks, + keypoints=keypts, + assigned_colors=colors, + ) + + sem_seg = dic.get("sem_seg", None) + if sem_seg is None and "sem_seg_file_name" in dic: + with PathManager.open(dic["sem_seg_file_name"], "rb") as f: + sem_seg = Image.open(f) + sem_seg = np.asarray(sem_seg, dtype="uint8") + if sem_seg is not None: + self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4) + + pan_seg = dic.get("pan_seg", None) + if pan_seg is None and "pan_seg_file_name" in dic: + with PathManager.open(dic["pan_seg_file_name"], "rb") as f: + pan_seg = Image.open(f) + pan_seg = np.asarray(pan_seg) + from panopticapi.utils import rgb2id + + pan_seg = rgb2id(pan_seg) + if pan_seg is not None: + segments_info = dic["segments_info"] + pan_seg = torch.tensor(pan_seg) + self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7) + return self.output + + def overlay_instances( + self, + *, + boxes=None, + labels=None, + masks=None, + keypoints=None, + assigned_colors=None, + binary_masks=None, + alpha=0.5, + label_mode="1", + ): + """ + Args: + boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`, + or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, + or a :class:`RotatedBoxes`, + or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image, + labels (list[str]): the text to be displayed for each instance. + masks (masks-like object): Supported types are: + + * :class:`detectron2.structures.PolygonMasks`, + :class:`detectron2.structures.BitMasks`. + * list[list[ndarray]]: contains the segmentation masks for all objects in one image. + The first level of the list corresponds to individual instances. The second + level to all the polygon that compose the instance, and the third level + to the polygon coordinates. The third level should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + * list[ndarray]: each ndarray is a binary mask of shape (H, W). + * list[dict]: each dict is a COCO-style RLE. + keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), + where the N is the number of instances and K is the number of keypoints. + The last dimension corresponds to (x, y, visibility or score). + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = 0 + if boxes is not None: + boxes = self._convert_boxes(boxes) + num_instances = len(boxes) + if masks is not None: + masks = self._convert_masks(masks) + if num_instances: + assert len(masks) == num_instances + else: + num_instances = len(masks) + if keypoints is not None: + if num_instances: + assert len(keypoints) == num_instances + else: + num_instances = len(keypoints) + keypoints = self._convert_keypoints(keypoints) + if labels is not None: + assert len(labels) == num_instances + if assigned_colors is None: + assigned_colors = [ + random_color(rgb=True, maximum=1) for _ in range(num_instances) + ] + if num_instances == 0: + return labels, [], [] + if boxes is not None and boxes.shape[1] == 5: + return self.overlay_rotated_instances( + boxes=boxes, labels=labels, assigned_colors=assigned_colors + ) + + # Display in largest to smallest order to reduce occlusion. + areas = None + if boxes is not None: + areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) + elif masks is not None: + areas = np.asarray([x.area() for x in masks]) + + # if areas is not None: + # # sorted_idxs = np.argsort(areas).tolist() + # sorted_idxs = np.argsort(-areas).tolist() + # # Re-order overlapped instances in descending order. + # boxes = boxes[sorted_idxs] if boxes is not None else None + # labels = [labels[k] for k in sorted_idxs] if labels is not None else None + # masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None + # binary_masks = ( + # [binary_masks[idx] for idx in sorted_idxs] + # if binary_masks is not None + # else None + # ) + # assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] + # keypoints = keypoints[sorted_idxs] if keypoints is not None else None + + marks = [] + marks_position = [] + added_positions = set() + for i in range(num_instances): + color = assigned_colors[i] + if boxes is not None: + self.draw_box(boxes[i], alpha=1, edge_color=color) + if binary_masks is None: + # draw number for non-mask instances + mark = self._draw_number_in_box( + boxes[i], i + 1, color=color, label_mode=label_mode + ) + marks.append(mark) + + if binary_masks is not None: + mark, mask_position = self._draw_number_in_mask( + binary_mask=binary_masks[i].astype("uint8"), + text=i + 1, + color=color, + added_positions=added_positions, + label_mode=label_mode, + ) + marks.append(mark) + marks_position.append(mask_position) + + self.draw_binary_mask( + binary_masks[i], + color=color, + edge_color=_OFF_WHITE, + alpha=alpha, + ) + + if masks is not None: + for segment in masks[i].polygons: + self.draw_polygon( + segment.reshape(-1, 2), color, alpha=0 + ) # alpha=0 so holes in masks are not colored + + # draw keypoints + if keypoints is not None: + for keypoints_per_instance in keypoints: + self.draw_and_connect_keypoints(keypoints_per_instance) + + # return labels, marks, sorted_idxs, marks_position + return labels, marks, marks_position + + def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None): + """ + Args: + boxes (ndarray): an Nx5 numpy array of + (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image. + labels (list[str]): the text to be displayed for each instance. + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = len(boxes) + + if assigned_colors is None: + assigned_colors = [ + random_color(rgb=True, maximum=1) for _ in range(num_instances) + ] + if num_instances == 0: + return self.output + + # Display in largest to smallest order to reduce occlusion. + if boxes is not None: + areas = boxes[:, 2] * boxes[:, 3] + + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + colors = [assigned_colors[idx] for idx in sorted_idxs] + + for i in range(num_instances): + self.draw_rotated_box_with_label( + boxes[i], + edge_color=colors[i], + label=labels[i] if labels is not None else None, + ) + + return self.output + + def draw_and_connect_keypoints(self, keypoints): + """ + Draws keypoints of an instance and follows the rules for keypoint connections + to draw lines between appropriate keypoints. This follows color heuristics for + line color. + + Args: + keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints + and the last dimension corresponds to (x, y, probability). + + Returns: + output (VisImage): image object with visualizations. + """ + visible = {} + keypoint_names = self.metadata.get("keypoint_names") + for idx, keypoint in enumerate(keypoints): + # draw keypoint + x, y, prob = keypoint + if prob > self.keypoint_threshold: + self.draw_circle((x, y), color=_RED) + if keypoint_names: + keypoint_name = keypoint_names[idx] + visible[keypoint_name] = (x, y) + + if self.metadata.get("keypoint_connection_rules"): + for kp0, kp1, color in self.metadata.keypoint_connection_rules: + if kp0 in visible and kp1 in visible: + x0, y0 = visible[kp0] + x1, y1 = visible[kp1] + color = tuple(x / 255.0 for x in color) + self.draw_line([x0, x1], [y0, y1], color=color) + + # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip + # Note that this strategy is specific to person keypoints. + # For other keypoints, it should just do nothing + try: + ls_x, ls_y = visible["left_shoulder"] + rs_x, rs_y = visible["right_shoulder"] + mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2 + except KeyError: + pass + else: + # draw line from nose to mid-shoulder + nose_x, nose_y = visible.get("nose", (None, None)) + if nose_x is not None: + self.draw_line( + [nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED + ) + + try: + # draw line from mid-shoulder to mid-hip + lh_x, lh_y = visible["left_hip"] + rh_x, rh_y = visible["right_hip"] + except KeyError: + pass + else: + mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2 + self.draw_line( + [mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED + ) + return self.output + + def mask_dims_from_binary(self, binary_mask): + ind_y, ind_x = np.where(binary_mask == 1) + min_ind_x = np.min(ind_x) + max_ind_x = np.max(ind_x) + min_ind_y = np.min(ind_y) + max_ind_y = np.max(ind_y) + return (max_ind_x - min_ind_x), (max_ind_y - min_ind_y) + + def reposition_label(self, position, cur, binary_mask, move_count): + img_width, img_height = self.output.width, self.output.height + mask_width, mask_height = self.mask_dims_from_binary(binary_mask) + + # set resposition thresholds + mask_width_limit, mask_height_limit = ( + 25, + 25, + ) # limit for width and height size for object covering + location_diff_threshold = 15 # limit for the distance between two labels + x_boundry_limit, y_boundry_limit = ( + 20, + 20, + ) # limit for the distancing the label from edges + + offset_x = 15 # move in x direction + offset_y = 15 # move in y direction + + x1, y1 = position + + if ( + mask_width < mask_width_limit + and mask_height < mask_height_limit + and move_count == 0 + ): + move_x = offset_x if offset_x + x1 < img_width else -offset_x + move_y = offset_y if offset_y + y1 < img_height else -offset_y + return (True, move_x, move_y) + + for x2, y2 in cur: + if abs(x1 - x2) + abs(y1 - y2) < location_diff_threshold: + move_x = offset_x if x1 >= x2 else -offset_x + move_y = offset_y if y1 >= y2 else -offset_y + move_x = ( + 0 + if x1 + move_x > img_width - x_boundry_limit + or x1 + move_x < x_boundry_limit + else move_x + ) + move_y = ( + 0 + if y1 + move_y > img_height - y_boundry_limit + or y1 + move_y < y_boundry_limit + else move_y + ) + return ( + True, + move_x, + move_y, + ) + return (False, 0, 0) + + def locate_label_position(self, original_position, added_positions, binary_mask): + if added_positions is None or binary_mask is None: + return original_position + + x, y = original_position + + move_count = 0 + reposition, x_move, y_move = self.reposition_label( + (x, y), added_positions, binary_mask, move_count + ) + while reposition and move_count < 10: + x += x_move + y += y_move + move_count += 1 + reposition, x_move, y_move = self.reposition_label( + (x, y), added_positions, binary_mask, move_count + ) + added_positions.add((x, y)) + return x, y + + """ + Primitive drawing functions: + """ + + def draw_text( + self, + text, + position, + added_positions=None, + binary_mask=None, + *, + font_size=None, + color="g", + horizontal_alignment="center", + rotation=0, + ): + """ + Args: + text (str): class label + position (tuple): a tuple of the x and y coordinates to place text on image. + font_size (int, optional): font of the text. If not provided, a font size + proportional to the image width is calculated and used. + color: color of the text. Refer to `matplotlib.colors` for full list + of formats that are accepted. + horizontal_alignment (str): see `matplotlib.text.Text` + rotation: rotation angle in degrees CCW + + Returns: + output (VisImage): image object with text drawn. + """ + if not font_size: + font_size = self._default_font_size + + # since the text background is dark, we don't want the text to be dark + color = np.maximum(list(mplc.to_rgb(color)), 0.15) + color[np.argmax(color)] = max(0.8, np.max(color)) + + def contrasting_color(rgb): + """Returns 'white' or 'black' depending on which color contrasts more with the given RGB value.""" + + # Decompose the RGB tuple + R, G, B = rgb + + # Calculate the Y value + Y = 0.299 * R + 0.587 * G + 0.114 * B + + # If Y value is greater than 128, it's closer to white so return black. Otherwise, return white. + return "black" if Y > 128 else "white" + + bbox_background = contrasting_color(color * 255) + + x, y = self.locate_label_position( + original_position=position, + added_positions=added_positions, + binary_mask=binary_mask, + ) + + self.output.ax.text( + x, + y, + text, + size=font_size * self.output.scale, + family="sans-serif", + bbox={ + "facecolor": bbox_background, + "alpha": 0.8, + "pad": 0.7, + "edgecolor": "none", + }, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + zorder=10, + rotation=rotation, + ) + return self.output + + def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"): + """ + Args: + box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0 + are the coordinates of the image's top left corner. x1 and y1 are the + coordinates of the image's bottom right corner. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + + Returns: + output (VisImage): image object with box drawn. + """ + x0, y0, x1, y1 = box_coord + width = x1 - x0 + height = y1 - y0 + + linewidth = max(self._default_font_size / 12, 1) * self.boarder_width_multiplier + + self.output.ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=edge_color, + linewidth=linewidth * self.output.scale, + alpha=alpha, + linestyle=line_style, + ) + ) + return self.output + + def draw_rotated_box_with_label( + self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None + ): + """ + Draw a rotated box with label on its top-left corner. + + Args: + rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle), + where cnt_x and cnt_y are the center coordinates of the box. + w and h are the width and height of the box. angle represents how + many degrees the box is rotated CCW with regard to the 0-degree box. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + label (string): label for rotated box. It will not be rendered when set to None. + + Returns: + output (VisImage): image object with box drawn. + """ + cnt_x, cnt_y, w, h, angle = rotated_box + area = w * h + # use thinner lines when the box is small + linewidth = self._default_font_size / ( + 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3 + ) + + theta = angle * math.pi / 180.0 + c = math.cos(theta) + s = math.sin(theta) + rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)] + # x: left->right ; y: top->down + rotated_rect = [ + (s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect + ] + for k in range(4): + j = (k + 1) % 4 + self.draw_line( + [rotated_rect[k][0], rotated_rect[j][0]], + [rotated_rect[k][1], rotated_rect[j][1]], + color=edge_color, + linestyle="--" if k == 1 else line_style, + linewidth=linewidth, + ) + + if label is not None: + text_pos = rotated_rect[1] # topleft corner + + height_ratio = h / np.sqrt(self.output.height * self.output.width) + label_color = self._change_color_brightness( + edge_color, brightness_factor=0.7 + ) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + * 0.5 + * self._default_font_size + ) + self.draw_text( + label, text_pos, color=label_color, font_size=font_size, rotation=angle + ) + + return self.output + + def draw_circle(self, circle_coord, color, radius=3): + """ + Args: + circle_coord (list(int) or tuple(int)): contains the x and y coordinates + of the center of the circle. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + radius (int): radius of the circle. + + Returns: + output (VisImage): image object with box drawn. + """ + x, y = circle_coord + self.output.ax.add_patch( + mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color) + ) + return self.output + + def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None): + """ + Args: + x_data (list[int]): a list containing x values of all the points being drawn. + Length of list should match the length of y_data. + y_data (list[int]): a list containing y values of all the points being drawn. + Length of list should match the length of x_data. + color: color of the line. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + linestyle: style of the line. Refer to `matplotlib.lines.Line2D` + for a full list of formats that are accepted. + linewidth (float or None): width of the line. When it's None, + a default value will be computed and used. + + Returns: + output (VisImage): image object with line drawn. + """ + if linewidth is None: + linewidth = self._default_font_size / 3 + linewidth = max(linewidth, 1) + self.output.ax.add_line( + mpl.lines.Line2D( + x_data, + y_data, + linewidth=linewidth * self.output.scale, + color=color, + linestyle=linestyle, + ) + ) + return self.output + + def draw_binary_mask( + self, + binary_mask, + color=None, + *, + edge_color=None, + text=None, + alpha=0.7, + area_threshold=10, + ): + """ + Args: + binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and + W is the image width. Each value in the array is either a 0 or 1 value of uint8 + type. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + area_threshold (float): a connected component smaller than this area will not be shown. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + has_valid_segment = False + binary_mask = binary_mask.astype("uint8") # opencv needs uint8 + mask = GenericMask(binary_mask, self.output.height, self.output.width) + shape2d = (binary_mask.shape[0], binary_mask.shape[1]) + + if not mask.has_holes: + # draw polygons for regular masks + for segment in mask.polygons: + area = mask_util.area( + mask_util.frPyObjects([segment], shape2d[0], shape2d[1]) + ) + if area < (area_threshold or 0): + continue + has_valid_segment = True + segment = segment.reshape(-1, 2) + self.draw_polygon( + segment, color=color, edge_color=edge_color, alpha=alpha + ) + else: + # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha + has_valid_segment = True + self.output.ax.imshow( + rgba, extent=(0, self.output.width, self.output.height, 0) + ) + + if text is not None and has_valid_segment: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_binary_mask_with_number( + self, + binary_mask, + color=None, + *, + edge_color=None, + text=None, + label_mode="1", + alpha=0.1, + anno_mode=["Mask"], + area_threshold=10, + ): + """ + Args: + binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and + W is the image width. Each value in the array is either a 0 or 1 value of uint8 + type. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + area_threshold (float): a connected component smaller than this area will not be shown. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + randint = random.randint(0, len(self.color_proposals) - 1) + color = self.color_proposals[randint] + color = mplc.to_rgb(color) + + has_valid_segment = True + binary_mask = binary_mask.astype("uint8") # opencv needs uint8 + mask = GenericMask(binary_mask, self.output.height, self.output.width) + shape2d = (binary_mask.shape[0], binary_mask.shape[1]) + bbox = mask.bbox() + + if "Mask" in anno_mode: + if not mask.has_holes: + # draw polygons for regular masks + for segment in mask.polygons: + area = mask_util.area( + mask_util.frPyObjects([segment], shape2d[0], shape2d[1]) + ) + if area < (area_threshold or 0): + continue + has_valid_segment = True + segment = segment.reshape(-1, 2) + self.draw_polygon( + segment, color=color, edge_color=edge_color, alpha=alpha + ) + else: + # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha + has_valid_segment = True + self.output.ax.imshow( + rgba, extent=(0, self.output.width, self.output.height, 0) + ) + + if "Box" in anno_mode: + self.draw_box(bbox, edge_color=color, alpha=0.75) + + if "Mark" in anno_mode: + has_valid_segment = True + else: + has_valid_segment = False + + if text is not None and has_valid_segment: + # lighter_color = tuple([x*0.2 for x in color]) + lighter_color = [ + 1, + 1, + 1, + ] # self._change_color_brightness(color, brightness_factor=0.7) + self._draw_number_in_mask( + binary_mask=binary_mask, + text=text, + color=lighter_color, + label_mode=label_mode, + ) + return self.output + + def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5): + """ + Args: + soft_mask (ndarray): float array of shape (H, W), each value in [0, 1]. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + shape2d = (soft_mask.shape[0], soft_mask.shape[1]) + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = soft_mask * alpha + self.output.ax.imshow( + rgba, extent=(0, self.output.width, self.output.height, 0) + ) + + if text is not None: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + binary_mask = (soft_mask > 0.5).astype("uint8") + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_polygon(self, segment, color, edge_color=None, alpha=0.5): + """ + Args: + segment: numpy array of shape Nx2, containing all the points in the polygon. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. If not provided, a darker shade + of the polygon color will be used instead. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with polygon drawn. + """ + if edge_color is None: + # make edge color darker than the polygon color + if alpha > 0.8: + edge_color = self._change_color_brightness( + color, brightness_factor=-0.7 + ) + else: + edge_color = color + edge_color = mplc.to_rgb(edge_color) + (1,) + + polygon = mpl.patches.Polygon( + segment, + fill=True, + facecolor=mplc.to_rgb(color) + (alpha,), + edgecolor=edge_color, + linewidth=max(self._default_font_size // 15 * self.output.scale, 1), + ) + self.output.ax.add_patch(polygon) + return self.output + + """ + Internal methods: + """ + + def _jitter(self, color): + """ + Randomly modifies given color to produce a slightly different color than the color given. + + Args: + color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color + picked. The values in the list are in the [0.0, 1.0] range. + + Returns: + jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the + color after being jittered. The values in the list are in the [0.0, 1.0] range. + """ + color = mplc.to_rgb(color) + # np.random.seed(0) + vec = np.random.rand(3) + # better to do it in another color space + vec = vec / np.linalg.norm(vec) * 0.5 + res = np.clip(vec + color, 0, 1) + return tuple(res) + + def _create_grayscale_image(self, mask=None): + """ + Create a grayscale version of the original image. + The colors in masked area, if given, will be kept. + """ + img_bw = self.img.astype("f4").mean(axis=2) + img_bw = np.stack([img_bw] * 3, axis=2) + if mask is not None: + img_bw[mask] = self.img[mask] + return img_bw + + def _change_color_brightness(self, color, brightness_factor): + """ + Depending on the brightness_factor, gives a lighter or darker color i.e. a color with + less or more saturation than the original color. + + Args: + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of + 0 will correspond to no change, a factor in [-1.0, 0) range will result in + a darker color and a factor in (0, 1.0] range will result in a lighter color. + + Returns: + modified_color (tuple[double]): a tuple containing the RGB values of the + modified color. Each value in the tuple is in the [0.0, 1.0] range. + """ + assert brightness_factor >= -1.0 and brightness_factor <= 1.0 + color = mplc.to_rgb(color) + polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) + modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) + modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness + modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness + modified_color = colorsys.hls_to_rgb( + polygon_color[0], modified_lightness, polygon_color[2] + ) + return modified_color + + def _convert_boxes(self, boxes): + """ + Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension. + """ + if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes): + return boxes.tensor.detach().numpy() + else: + return np.asarray(boxes) + + def _convert_masks(self, masks_or_polygons): + """ + Convert different format of masks or polygons to a tuple of masks and polygons. + + Returns: + list[GenericMask]: + """ + + m = masks_or_polygons + if isinstance(m, PolygonMasks): + m = m.polygons + if isinstance(m, BitMasks): + m = m.tensor.numpy() + if isinstance(m, torch.Tensor): + m = m.numpy() + ret = [] + for x in m: + if isinstance(x, GenericMask): + ret.append(x) + else: + ret.append(GenericMask(x, self.output.height, self.output.width)) + return ret + + def _draw_number_in_box(self, box, text, color, label_mode="1"): + """ + Find proper places to draw text given a box. + """ + x0, y0, x1, y1 = box + text_pos = (x0, y0) # if drawing boxes, put text on the box corner. + horiz_align = "left" + # for small objects, draw text at the side to avoid occlusion + instance_area = (y1 - y0) * (x1 - x0) + if ( + instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale + or y1 - y0 < 40 * self.output.scale + ): + if y1 >= self.output.height - 5: + text_pos = (x1, y0) + else: + text_pos = (x0, y1) + + height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width) + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + * 0.65 + * self._default_font_size + ) + if label_mode == "a": + text = self.number_to_string(int(text)) + else: + text = text + self.draw_text( + text, + text_pos, + color=lighter_color, + horizontal_alignment=horiz_align, + font_size=font_size, + ) + + return str(text) + + @staticmethod + def number_to_string(n): + chars = [] + while n: + n, remainder = divmod(n - 1, 26) + chars.append(chr(97 + remainder)) + return "".join(reversed(chars)) + + def _draw_number_in_mask( + self, binary_mask, text, color, added_positions=None, label_mode="1" + ): + """ + Find proper places to draw text given a binary mask. + """ + binary_mask = np.pad(binary_mask, ((1, 1), (1, 1)), "constant") + mask_dt = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 0) + mask_dt = mask_dt[1:-1, 1:-1] + max_dist = np.max(mask_dt) + coords_y, coords_x = np.where(mask_dt == max_dist) # coords is [y, x] + + if label_mode == "a": + text = self.number_to_string(int(text)) + else: + text = text + + text_position = ( + coords_x[len(coords_x) // 2] + 2, + coords_y[len(coords_y) // 2] - 6, + ) + self.draw_text( + text, + text_position, + added_positions=added_positions, + binary_mask=binary_mask, + color=color, + ) + + return str(text), text_position + + # _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8) + # if stats[1:, -1].size == 0: + # return + # largest_component_id = np.argmax(stats[1:, -1]) + 1 + + # # draw text on the largest component, as well as other very large components. + # for cid in range(1, _num_cc): + # if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: + # # median is more stable than centroid + # # center = centroids[largest_component_id] + # center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] + # # bottom=np.max((cc_labels == cid).nonzero(), axis=1)[::-1] + # # center[1]=bottom[1]+2 + # self.draw_text(text, center, color=color) + + def _draw_text_in_mask(self, binary_mask, text, color): + """ + Find proper places to draw text given a binary mask. + """ + _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats( + binary_mask, 8 + ) + if stats[1:, -1].size == 0: + return + largest_component_id = np.argmax(stats[1:, -1]) + 1 + + # draw text on the largest component, as well as other very large components. + for cid in range(1, _num_cc): + if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: + # median is more stable than centroid + # center = centroids[largest_component_id] + center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] + bottom = np.max((cc_labels == cid).nonzero(), axis=1)[::-1] + center[1] = bottom[1] + 2 + self.draw_text(text, center, color=color) + + def _convert_keypoints(self, keypoints): + if isinstance(keypoints, Keypoints): + keypoints = keypoints.tensor + keypoints = np.asarray(keypoints) + return keypoints + + def get_output(self): + """ + Returns: + output (VisImage): the image output containing the visualizations added + to the image. + """ + return self.output diff --git a/third_party/sam3/sam3/agent/helpers/zoom_in.py b/third_party/sam3/sam3/agent/helpers/zoom_in.py new file mode 100644 index 0000000000000000000000000000000000000000..42a8c94afe24f5a4f271535b1d77e64b4b15c4ae --- /dev/null +++ b/third_party/sam3/sam3/agent/helpers/zoom_in.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import io +import math + +import matplotlib.pyplot as plt +import numpy as np +import pycocotools.mask as mask_utils +from PIL import Image + +from .som_utils import ColorPalette, draw_box, draw_mask, draw_text + + +def render_zoom_in( + object_data, + image_file, + show_box: bool = True, + show_text: bool = False, + show_holes: bool = True, + mask_alpha: float = 0.15, +): + """ + Render a two-panel visualization with a cropped original view (left/upper) and a zoomed-in + mask overlay (right/lower), then return it as a PIL.Image along with the chosen mask color (hex). + + Parameters + ---------- + object_data : dict + Dict containing "labels" and COCO RLE "segmentation". + Expected: + object_data["labels"][0]["noun_phrase"] : str + object_data["segmentation"] : COCO RLE (with "size": [H, W]) + image_file : PIL.Image.Image + Source image (PIL). + show_box : bool + Whether to draw the bbox on the cropped original panel. + show_text : bool + Whether to draw the noun phrase label near the bbox. + show_holes : bool + Whether to render mask holes (passed through to draw_mask). + mask_alpha : float + Alpha for the mask overlay. + + Returns + ------- + pil_img : PIL.Image.Image + The composed visualization image. + color_hex : str + Hex string of the chosen mask color. + """ + + # ---- local constants (avoid module-level globals) ---- + _AREA_LARGE = 0.25 + _AREA_MEDIUM = 0.05 + + # ---- local helpers (avoid name collisions in a larger class) ---- + def _get_shift(x, w, w_new, w_img): + assert 0 <= w_new <= w_img + shift = (w_new - w) / 2 + if x - shift + w_new > w_img: + shift = x + w_new - w_img + return min(x, shift) + + def _get_zoom_in_box(mask_box_xywh, img_h, img_w, mask_area): + box_w, box_h = mask_box_xywh[2], mask_box_xywh[3] + w_new = min(box_w + max(0.2 * box_w, 16), img_w) + h_new = min(box_h + max(0.2 * box_h, 16), img_h) + + mask_relative_area = mask_area / (w_new * h_new) + + # zoom-in (larger box if mask is relatively big) + w_new_large, h_new_large = w_new, h_new + if mask_relative_area > _AREA_LARGE: + ratio_large = math.sqrt(mask_relative_area / _AREA_LARGE) + w_new_large = min(w_new * ratio_large, img_w) + h_new_large = min(h_new * ratio_large, img_h) + + w_shift_large = _get_shift( + mask_box_xywh[0], mask_box_xywh[2], w_new_large, img_w + ) + h_shift_large = _get_shift( + mask_box_xywh[1], mask_box_xywh[3], h_new_large, img_h + ) + zoom_in_box = [ + mask_box_xywh[0] - w_shift_large, + mask_box_xywh[1] - h_shift_large, + w_new_large, + h_new_large, + ] + + # crop box for the original/cropped image + w_new_medium, h_new_medium = w_new, h_new + if mask_relative_area > _AREA_MEDIUM: + ratio_med = math.sqrt(mask_relative_area / _AREA_MEDIUM) + w_new_medium = min(w_new * ratio_med, img_w) + h_new_medium = min(h_new * ratio_med, img_h) + + w_shift_medium = _get_shift( + mask_box_xywh[0], mask_box_xywh[2], w_new_medium, img_w + ) + h_shift_medium = _get_shift( + mask_box_xywh[1], mask_box_xywh[3], h_new_medium, img_h + ) + img_crop_box = [ + mask_box_xywh[0] - w_shift_medium, + mask_box_xywh[1] - h_shift_medium, + w_new_medium, + h_new_medium, + ] + return zoom_in_box, img_crop_box + + # ---- main body ---- + # Input parsing + object_label = object_data["labels"][0]["noun_phrase"] + img = image_file.convert("RGB") + bbox_xywh = mask_utils.toBbox(object_data["segmentation"]) # [x, y, w, h] + + # Choose a stable, visually distant color based on crop + bbox_xyxy = [ + bbox_xywh[0], + bbox_xywh[1], + bbox_xywh[0] + bbox_xywh[2], + bbox_xywh[1] + bbox_xywh[3], + ] + crop_img = img.crop(bbox_xyxy) + color_palette = ColorPalette.default() + color_obj, _ = color_palette.find_farthest_color(np.array(crop_img)) + color = np.array([color_obj.r / 255, color_obj.g / 255, color_obj.b / 255]) + color_hex = f"#{color_obj.r:02x}{color_obj.g:02x}{color_obj.b:02x}" + + # Compute zoom-in / crop boxes + img_h, img_w = object_data["segmentation"]["size"] + mask_area = mask_utils.area(object_data["segmentation"]) + zoom_in_box, img_crop_box = _get_zoom_in_box(bbox_xywh, img_h, img_w, mask_area) + + # Layout choice + w, h = img_crop_box[2], img_crop_box[3] + if w < h: + fig, (ax1, ax2) = plt.subplots(1, 2) + else: + fig, (ax1, ax2) = plt.subplots(2, 1) + + # Panel 1: cropped original with optional box/text + img_crop_box_xyxy = [ + img_crop_box[0], + img_crop_box[1], + img_crop_box[0] + img_crop_box[2], + img_crop_box[1] + img_crop_box[3], + ] + img1 = img.crop(img_crop_box_xyxy) + bbox_xywh_rel = [ + bbox_xywh[0] - img_crop_box[0], + bbox_xywh[1] - img_crop_box[1], + bbox_xywh[2], + bbox_xywh[3], + ] + ax1.imshow(img1) + ax1.axis("off") + if show_box: + draw_box(ax1, bbox_xywh_rel, edge_color=color) + if show_text: + x0, y0 = bbox_xywh_rel[0] + 2, bbox_xywh_rel[1] + 2 + draw_text(ax1, object_label, [x0, y0], color=color) + + # Panel 2: zoomed-in mask overlay + binary_mask = mask_utils.decode(object_data["segmentation"]) + alpha = Image.fromarray((binary_mask * 255).astype("uint8")) + img_rgba = img.convert("RGBA") + img_rgba.putalpha(alpha) + zoom_in_box_xyxy = [ + zoom_in_box[0], + zoom_in_box[1], + zoom_in_box[0] + zoom_in_box[2], + zoom_in_box[1] + zoom_in_box[3], + ] + img_with_alpha_zoomin = img_rgba.crop(zoom_in_box_xyxy) + alpha_zoomin = img_with_alpha_zoomin.split()[3] + binary_mask_zoomin = np.array(alpha_zoomin).astype(bool) + + ax2.imshow(img_with_alpha_zoomin.convert("RGB")) + ax2.axis("off") + draw_mask( + ax2, binary_mask_zoomin, color=color, show_holes=show_holes, alpha=mask_alpha + ) + + plt.tight_layout() + + # Buffer -> PIL.Image + buf = io.BytesIO() + fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=100) + plt.close(fig) + buf.seek(0) + pil_img = Image.open(buf) + + return pil_img, color_hex diff --git a/third_party/sam3/sam3/agent/inference.py b/third_party/sam3/sam3/agent/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..01f1b63aafc36acdd33f3f874f71e0be5adf8e19 --- /dev/null +++ b/third_party/sam3/sam3/agent/inference.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import json +import os + +from sam3.agent.agent_core import agent_inference + + +def run_single_image_inference( + image_path, + text_prompt, + llm_config, + send_generate_request, + call_sam_service, + output_dir="agent_output", + debug=False, +): + """Run inference on a single image with provided prompt""" + + llm_name = llm_config["name"] + + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + # Generate output file names + image_basename = os.path.splitext(os.path.basename(image_path))[0] + prompt_for_filename = text_prompt.replace("/", "_").replace(" ", "_") + + base_filename = f"{image_basename}_{prompt_for_filename}_agent_{llm_name}" + output_json_path = os.path.join(output_dir, f"{base_filename}_pred.json") + output_image_path = os.path.join(output_dir, f"{base_filename}_pred.png") + agent_history_path = os.path.join(output_dir, f"{base_filename}_history.json") + + # Check if output already exists and skip + if os.path.exists(output_json_path): + print(f"Output JSON {output_json_path} already exists. Skipping.") + return + + print(f"{'-' * 30} Starting SAM 3 Agent Session... {'-' * 30} ") + agent_history, final_output_dict, rendered_final_output = agent_inference( + image_path, + text_prompt, + send_generate_request=send_generate_request, + call_sam_service=call_sam_service, + output_dir=output_dir, + debug=debug, + ) + print(f"{'-' * 30} End of SAM 3 Agent Session... {'-' * 30} ") + + final_output_dict["text_prompt"] = text_prompt + final_output_dict["image_path"] = image_path + + # Save outputs + json.dump(final_output_dict, open(output_json_path, "w"), indent=4) + json.dump(agent_history, open(agent_history_path, "w"), indent=4) + rendered_final_output.save(output_image_path) + + print(f"\n✅ Successfully processed single image!") + print(f"Output JSON: {output_json_path}") + print(f"Output Image: {output_image_path}") + print(f"Agent History: {agent_history_path}") + return output_image_path diff --git a/third_party/sam3/sam3/agent/system_prompts/system_prompt.txt b/third_party/sam3/sam3/agent/system_prompts/system_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..a1a6915cb16dc0d7dc3377ad58764d4cf49be34c --- /dev/null +++ b/third_party/sam3/sam3/agent/system_prompts/system_prompt.txt @@ -0,0 +1,242 @@ +You are a helpful visual-concept grounding assistant capable of leveraging tool calls to ground concepts the user refers to, and providing structured JSON outputs and tool calls. +The user may provide you with a referring expression that matches some part(s) of the image, or a question whose answer points to some part(s) of the image. +You should observe and analyze the image along with the initial user input query very carefully, note all details in the image, think about what the user is actually referring to, how to leverage existing tools below to ground the target(s), and then call exactly one tool per turn. +At each turn, all available mask(s) will be renumbered and re-rendered on the most recent image provided to you. The numbering and coloring can be different from previous turns. You should only refer to mask(s) rendered on the most recent image using their currently assigned number. +If a tool call does not produce the intended output, do not give up; be creative and try calling the segment_phrase tool again with different parameters, or try a different tool. You may take as many turns as needed, but you must call exactly one tool per turn and then immediately stop. There is no need to rush to find a solution in the current turn, so take your time! + + +How you should understand the initial user input query and the raw input image: + +1. If there are multiple instances of the target object class in the image, you should read the initial user input query very carefully and think about whether the initial user input query applies broadly to all the instances or just one specific instance, and ground accordingly. +2. You should think carefully and find the actual target object(s) the user is asking you to ground. Never call the segment_phrase tool to ground secondary object(s) in the initial user input query that only exist to help you identify the actual target. For example, given the initial user input query 'a giraffe with its head up', you should ground the whole 'giraffe' and not 'the head of the giraffe'. Given the initial user input query 'a person holding a blender with their left hand', you should ground 'person' instead of 'blender' or 'left hand'. Given the initial user input query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should ground 'woman' instead of 'dog' or 'bicycle'. Given the initial user input query "guy with white hat", you should ground the "guy" and not the "white hat". +3. Sometimes the user will mention or use non-target object(s) in their description to help identify the target object(s), you must make sure not to include mask(s) for those object(s) that are only used for identification purposes. For example, given the initial user input query "a man carrying a young girl", you should only ground the main target the "man" and not include the "young girl" in your final predicted mask(s). Given the initial user input query "a small girl staring at something, along with her older sister", you should only ground the "small girl" and not include her "older sister" in your final predicted mask(s). +4. Sometimes the target object(s) are not directly named in the description but are clearly referenced, in which case you should focus only on grounding the clearly referenced target object(s). For example, given the initial user input query "something that shows the man is playing golf" and an image of a man holding a golf club, you should ground the phrase "golf club" and not the phrase "man" even though "golf club" is not directly named in the initial user input query. +5. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query. +6. Sometimes the initial user input query can be slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red laptop" when the laptop computer in the image is purple (in this case you should call segment_phrase on the "text_prompt" "purple laptop computer"); or the user may ask you to ground "girl left" when there is no girl on the left of the image but rather a woman on the left of the image (in this case you should call segment_phrase to ground the phrase "left woman"). In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. You may slightly modify the initial user input query based on your observation of the original image to better match the user’s intent. +7. Sometimes the initial user input query may be grammatically incorrect, contain typos, or contain irrelevant information. In these cases, you should not blindly try to ground part(s) of the initial user input query using segment_phrase. Instead, you should reason step by step to think about what the user is actually referring to, and then modify the initial user input query based on your understanding and careful analysis of the raw input image. For example, you may see an initial user input query like "left back to us guy", which you can interpret as the man on the left who is facing the other direction (if you can see such a man exists in the image), and then call segment_phrase on "man" and then select the correct mask. You may also see an initial user input query like "big maybe hotdog middle back taste good", and there are just nine sandwiches in the image placed in three rows, then you can probably infer that the user is trying to ground the sandwich in the middle of the back row. You can then call segment_phrase to ground the phrase "sandwich" and use the select_masks_and_return tool to accurately choose only the sandwich in the middle of the back row in your "final_answer_masks" array. +8. The correct "final_answer_masks" array should never contain any mask(s) whose number is greater than 100. For example, you may never select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are never allowed to select more than 100 masks in your "final_answer_masks" array. +9. Please note that if the raw input image is composed of two individual sub-images concatenated visually; it still counts as only one image. If you find that there are "two" images in the chat context but the "second image" is not the same as the first image overlaid with numbered segmentation masks, this means that the "second image" is actually just a sub-image of the raw input image concatenated with the "first image" to serve as a combined raw input image. In this case, there is actually only one image in the chat context and you should follow the Scenario 1 instructions. This is very important! + +You should always follow the response format defined below and complete the Steps for Each Turn as specified below. Never break the specified format for any reason. + + +Available tools: + +segment_phrase: Use the experimental Segment Anything 3 model to ground all instances of a simple noun phrase by generating segmentation mask(s) that cover those instances on the raw input image. At the same time, all previously generated mask(s) will be deleted and cannot be referred to in future messages. +Use cases: "Given a simple, direct, and singular noun phrase (not a referring expression that requires additional understanding/reasoning), segment_phrase will try to locate all object instance(s) on the raw input image that match the simple noun phrase you provided. The tool will also render all of the generated segmentation mask(s) onto the image for you to examine and decide the next step." +Parameters for segment_phrase: {"type": "object", "properties": {"text_prompt": {"type": "string", "description": "A short and simple noun phrase, e.g., rope, bird beak, speed monitor, brown handbag, person torso"}}, "required": ["text_prompt"]} +Return type: A new image with differently colored segmentation mask(s) rendered on it, and a text message indicating the number of mask(s) generated by the experimental Segment Anything 3 model for this "text_prompt" only. +Important rules for using the segment_phrase tool: +1. You may use visual adjectives such as color to help identify the concept you want to ground, but do not use complicated descriptors like numbers or mention text that is written on the image as the segment_phrase tool does not have OCR capabilities. For example, use "black ball" instead of "8-ball" to ground a black ball with the number "8" written on it. If the user asks you to ground an object that can only be identified by the text or number written on it, you should generate mask(s) for all object(s) of that category and then cross-examine the original image against the masked image carefully to locate the exact mask(s) that match or answer the initial user input query and select only those mask(s). +2. Do not try to directly ground words, letters, or numbers in written text on the image. For example, if there is text on a sign to ground, you should use "sign" as your "text_prompt" instead of using the actual text itself as your "text_prompt". +3. If your call to segment_phrase does not generate any useful mask(s) or if the mask(s) are incomplete, you may want to try calling the segment_phrase tool again using a more general noun phrase. For example, if the "text_prompt" "elementary school teacher" does not give you any mask(s), you can call segment_phrase again with the "text_prompt": "person". +4. You should avoid identifying concepts using actions, relationships, or comparatives; instead, call segment_phrase on a more general phrase and let the segment_phrase tool generate more mask(s) than you need. Then, in the next turn, you can use the select_masks_and_return tool to remove some mask(s). For example, use "vase" instead of "the bigger vase", use "dog" instead of "the dog lying down", and use "brown pillow" instead of "the pillow on the chair". +5. If the results of segment_phrase are not what you expected, you can always call segment_phrase again using a different "text_prompt". For example, when grounding a dog's nose, you can try "dog nose" and "black marking" after "nose" does not work. +6. Sometimes when the target object(s) are too niche and the segment_phrase tool does not provide any mask(s), you may want to try grounding a more general version of the object. For example, when "sundial" does not produce any mask(s), you can try grounding "statue". +7. Be concise and get the right keywords; don't make your "text_prompt" long. +8. Do not ever use the exact same "text_prompt" more than once. This is very important! +9. Sometimes you may find that the user is referring to a person or some people as the main grounding target. In this case, you should absolutely avoid grounding identifying part(s) or attribute(s) of the person or people, even if these part(s) or component(s) are explicitly mentioned in the initial user input query. Instead, you should only call segment_phrase with general "text_prompt"s like "person", "man", "girl", "firefighter", etc. that refer to the person as a whole. Later you can refer back to these identifying part(s) or attribute(s) and look closely at the original image to help you select the correct mask(s). +10. If a previously used "text_prompt" does not work, avoid using it again and think of a new, creative "text_prompt" that may be indirect but can achieve the target result. For example, when grounding the center of the cake with text written on it, try grounding "birthday greeting" instead. +11. You should always call segment_phrase with a "text_prompt" that represents the entire grounding target to generate mask(s) that you can choose from (sometimes along with other entities of the same category if it is hard to avoid). Do not call segment_phrase with a "text_prompt" that refers to subpart(s) of the grounding target to narrow down your search, because your "final_answer_masks" array can only be composed of of mask(s) generated by segment_phrase. For example, when the grounding target is an adult, use the "text_prompt" "adult person" instead of "adult hand". +12. If the initial user input query refers only to one specific object instance of a category, while there are other object instance(s) of the same category in the image that are not being referred to, you should call segment_phrase with a "text_prompt" that is the singular form of the category of object(s), and then use the select_masks_and_return and/or examine_each_mask tool to narrow down your "final_answer_masks". +13. Every time you call the segment_phrase tool, all previously generated mask(s) will be deleted. You are forbidden from referring to mask(s) that exist only in previous images in the message history but have been deleted in the most recent turn (not rendered on the most recent image). +14. You should only ground object(s) that fully match or answer the initial user input query, and ignore object(s) that only partially match the initial user input query. For example, if the user is asking for object(s) used for inputting data and controlling the computer, you should only ground the keyboard and not the mouse, since the mouse is only used for controlling the computer but not for inputting data. +15. You should never propose a "text_prompt" that covers more area than the initial user input query, for example, if the initial user input query asks specifically for areas of the jeans that are broken, you should never propose the "text_prompt" "jeans" because it will definitely cover more area than the ground truth target. +16. You should never propose a "text_prompt" that covers less area than the initial user input query, for example, if the initial user input query asks for the person holding a microphone, you should never propose the "text_prompt" "microphone" because it will definitely cover less area than the ground truth target. +17. You should first try your best to propose a "text_prompt" that covers the exact same object(s) as referred to by the initial user input query, no more, no less. You may not propose a "text_prompt" that covers more object(s) than what is referred to by the initial user input query unless you have tried every creative "text_prompt" you can think of to cover exactly the correct object(s) and none of them worked. +18. Be creative in your "text_prompt" choice; you may use synonyms and use visual common sense to think of different "text_prompt" choices. You have unlimited turns to call each tool, so take your time! + +examine_each_mask: Use this tool when the segment_phrase tool generates multiple small or overlapping mask(s), making it difficult to distinguish the correct mask(s). examine_each_mask allows you to render and examine each mask independently to see small mask(s) clearly and avoid confusing overlapping mask(s). (examine_each_mask can only be called after segment_phrase has been called at least once.) +Use cases: "Sometimes there are multiple small mask(s) or overlapping mask(s) rendered on an image, making it difficult to distinguish each mask from others. In this case, you should call the examine_each_mask tool to individually verify each mask and filter out incorrect mask(s)." +Parameters for examine_each_mask: None +Return type: A new image with colored segmentation mask(s) accepted by the examine_each_mask tool, and a text message indicating how many masks were accepted. +Important rules for using the examine_each_mask tool: +1. You may only call the examine_each_mask tool when you have re-examined the raw input image and the most recent output image, and you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, and there are no missing correct mask(s). You must state this explicitly before you call the examine_each_mask tool. +2. Do not call the examine_each_mask tool if there is only one mask and the mask is not very small. +3. Do not call the examine_each_mask tool when there are many masks in the image but they are neither very small nor overlapping. +4. The purpose of calling examine_each_mask is to distinguish overlapping mask(s), to examine whether very small mask(s) are correct, or both. +5. After you have carefully compared the generated mask(s) against the initial user input query and the original image, and stated that you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, you may consider calling the examine_each_mask tool if there are multiple overlapping mask(s) generated and it is not easy for you to name the correct mask(s). For example, if the question is to ground "the cookie behind the other cookie", segment_phrase generates two mask(s) for the two cookies in the image, but they are overlapping. You can also call the examine_each_mask tool if there are one or more very small mask(s) that are generated and you are sure that some of them are correct, and it is not easy for you to directly decide the correct mask(s). For example, if the question is to ground "sharp teeth" and there are multiple small mask(s) generated but it is not easy for you to tell which ones are correct without zooming in on each mask. +6. Do not call the examine_each_mask tool if there are many masks in the image but you can clearly tell each mask apart from all other mask(s), and there is no significant challenge in identifying the correct mask(s). For example, if the question is asking "where people can sit" and there are many masks for chairs, and you just need to list all the mask numbers for chairs. +7. You may not call the examine_each_mask tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image. + +select_masks_and_return: Call this tool to select a subset of or all of the mask(s) rendered on the most recent image as your final output. When calling select_masks_and_return, you cannot select any mask(s) generated by previous rounds other than the most recent round in your "final_answer_masks". You can only use mask(s) from the most recent image in your message history. (select_masks_and_return can only be called after segment_phrase has been called at least once.) +Use cases: "Given an image with one or more segmentation mask(s) already rendered on it, select_masks_and_return returns the set of mask(s) you select as the final output." +Parameters for select_masks_and_return: {"type": "object", "properties": {"final_answer_masks": {"type": "array", "description": "An array of integers representing the selected mask(s) you want to choose as your final output, e.g., [1, 4, 5]"}}, "required": ["final_answer_masks"]} +Return type: None (End of Conversation) +Important rules for using the select_masks_and_return tool: +1. Do not call select_masks_and_return unless you are absolutely sure that the set of mask(s) you are about to return is the correct set of mask(s) that match or answer the initial user input query. +2. If at any point in your reasoning you indicated that there exist any target(s) in the image that match or answer the initial user input query, your final tool call must be select_masks_and_return; you cannot just give up grounding and call the report_no_mask tool. This is very important. +3. The mask(s) are numbered from 1 to N (N being the total number of mask(s) rendered on the most recent image). When you call select_masks_and_return, the integers in your "final_answer_masks" array must be within this range, no exceptions! Make sure of this! +4. There must never be any repeated integers in your "final_answer_masks" array; each integer must be unique. A "final_answer_masks" such as [1, 2, 3, 2, 1] is not acceptable and will trigger an error. You should avoid this format error at all costs. +5. You may only call select_masks_and_return on mask(s) rendered in the most recent image. You must ignore any mask(s) from earlier images as they have already been deleted. +6. The select_masks_and_return tool is what you would use for reporting your "final_answer_masks". If the currently available mask(s) in the most recent image (you cannot use mask(s) from earlier images) are not 100% complete, do not call the select_masks_and_return tool and continue updating them by calling other tools (possibly on more general noun phrases). +7. Every time you call the segment_phrase tool, you will delete all previously generated mask(s). You are forbidden from selecting mask(s) in previous images in the message history other than the most recent image. +8. Since you cannot refer to mask(s) generated in earlier calls to segment_phrase, you should plan out your tool calls carefully, and make sure that the most recent tool call to segment_phrase covers all the target object(s) you want to ground. +9. You may not call the select_masks_and_return tool if there are no mask(s) rendered on the most recent image returned by your most recent tool call. +10. The mask(s) you choose in your "final_answer_masks" should accurately capture the target object(s) and only the target object(s). It should not contain any other regions that do not belong to the target object(s). Nor should it contain only a part of the target object(s). If this criterion is not met, you must not call the select_masks_and_return tool. Instead, please continue using other tools to generate better mask(s). +11. Sometimes in the image you might see a mask with a two-digit number that is larger than N (the total number of available mask(s) rendered on the most recent image). For example, if the user tells you there are only 3 masks generated on the most recent image, but you see a mask with the number "12" on it. This is a visual illusion caused by mask "1" and mask "2" being too close to each other. In this case, you should never refer to mask "12" as it does not exist. Instead, you can only refer to masks "1", "2", and "3" as specified in the user input. +12. If there are a large number of masks you need to select in your "final_answer_masks" array, you are required to explicitly list all of them one by one. You may not use any form of abbreviation or code. For example, if there are 94 correct masks you need to return, you must generate a long response with the "final_answer_masks" being a long array of 94 integers. You must never use abbreviated code outputs such as {"final_answer_masks": [i for i in range(1, 94)]}. +13. If the initial user input query involves colors, you must carefully double-check the raw input image and explicitly compare it against the most recent image with available mask(s) rendered on it before selecting your "final_answer_masks". This is because the available mask(s) rendered on the most recent image are colored and will change the original color of the object(s) on the raw input image. +14. Before you are allowed to call the select_masks_and_return tool, you are required to carefully re-examine the raw input image, the initial user input query, and compare them against every single available segmentation mask on the most recent rendered image. You must explicitly restate the initial user input query, and verify the following three things: +a. You must verify you are able to accurately locate all the correct mask(s) that match the initial user input query in the most recent rendered image. +b. You must also verify that you have carefully checked each of the mask(s) you plan to select, and made sure that they best match the initial user input query. (list your reasoning for each mask) +c. You have also verified that the other available mask(s) you do not plan to select are definitely wrong and do not match the initial user input query. (list your reasoning for each mask) +15. The intermediate "text_prompt" used to call the segment_phrase tool should never be used or considered when you select the "final_answer_masks". Instead, you should only assess the available mask(s) by checking the initial user input query. For example, if the initial user input query was "The plane-shaped cake on the right" and the "text_prompt" you used for the segment_phrase tool was "green cake", you should select the available mask(s) that match "The plane-shaped cake on the right". +16. If the initial user input query involves relative positions, then you must explicitly state in your thinking process the spatial positions of each mask relative to other available mask(s) before you call the select_masks_and_return tool. +17. You may not select any mask(s) whose number is greater than 100. For example, you may not select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are not allowed to select more than 100 masks in your "final_answer_masks" array. +18. You may not call the select_masks_and_return tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image. + +report_no_mask: Call this tool when you are absolutely sure that there are no object(s) in the image that match or answer the initial user input query. +Use cases: "Reporting that the given image does not contain any target object(s) that match or answer the initial user input query." +Parameters for report_no_mask: None +Return type: None (End of Conversation) +Important rules for using the report_no_mask tool: +1. If at any point in your reasoning you indicated that there are target object(s) in the image that exactly match or answer the initial user input query without ambiguity, then you should never call the report_no_mask tool. Instead, you should keep trying other tools with different parameters until you get the correct mask(s). +2. If you have checked the image carefully and made sure that there are no concepts in the image that can possibly match or answer the initial user input query, you should call the report_no_mask tool. +3. If the image is completely unrelated to the initial user input query and it seems like the user has provided an incorrect image, you should call the report_no_mask tool. You should never break the standard response format by asking if the user provided the wrong image. +4. Before you are allowed to call the report_no_mask tool, you are required to carefully re-examine the raw input image and the initial user input query. You must explicitly restate the initial user input query, and analyze the image in detail to verify that there is indeed no object in the image that can possibly match the initial user input query. +5. Sometimes the initial user input query is slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red computer" when the computer in the image is purple; or the user may ask you to ground "girl on the left" when there is no girl on the left of the image but rather a woman on the left of the image. In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. +6. You should seldom call the report_no_mask tool and only reserve it for cases where the initial user input query is completely unrelated to the raw input image. +7. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query. + + +Steps for Each Turn: + +First, state the number of images there are in the chat context (There is at least one image and at most two images at any time.) Please note that if the raw input image is composed of two individual images concatenated visually; it still counts as only one image. This is very important! + +Scenario 1: If there is only one image in the context (it must be the raw input image with no mask on it), you must perform the following steps. Steps 1-5 are mandatory thinking steps and therefore must be generated within ..... HTML tags. Step 6 is the mandatory tool calling step and must be generated within ..... HTML tags. You must make sure to generate the opening and closing HTML tags correctly. +Your thinking steps: +1. Analyze: Carefully describe and analyze the raw input image provided to you in the context of the initial user input query. +2. Think: Based on your understanding of the image and the previously stated rules for how you should understand the initial user input query, think about precisely what target object(s) need to be grounded to accurately answer the initial user input query. +3. Remind: Remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). +4. Plan: Design a step-by-step tool call plan for how you will use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query. +5. Decide: Based on your reasoning, determine a simple noun phrase you think is suitable for calling the segment_phrase tool. The phrase should be a simple, direct, singular noun phrase. In some cases, it may include adjectives, but it should never contain articles, possessives, or numbers. +You mandatory tool call: +After you finish all 5 thinking steps and have decided the simple noun phrase you think is suitable for calling the segment_phrase tool, you must generate a mandatory tool call to the "segment_phrase" tool with the simple noun phrase you have selected as the "text_prompt". Make sure you closely follow the rules for calling the "segment_phrase" tool, and enclose the tool call within ..... HTML tags. + + +Scenario 2: If there are exactly two images in the context, the first image must be the raw input image, and the second and most recent image must be the image with all available mask(s) rendered on it. In Scenario 2, you must perform the following steps. Steps 1-5 are mandatory thinking steps and therefore must be generated within ..... HTML tags. Step 6 is the mandatory tool calling step and must be generated within ..... HTML tags. You must make sure to generate the opening and closing HTML tags correctly. +Your steps: +1. Analyze: Carefully describe and analyze both the first image (the raw input image) and the second and most recent image (the image with all available mask(s) rendered on it) in the context of the initial user input query. If there are fewer than twenty available mask(s) in the second (most recent) image, you are required to analyze each available mask individually on the second and most recent image and state why they are correct, or why they are incorrect. The specific analysis you generate for each mask should be determined based on the initial user input query and the raw input image. If the initial user input query mentions the relation of the target object(s) to other object(s) in the image, you must also explain each mask's relation to other available mask(s). For example, if the initial user input query is "the second man from the right", then your analysis for each available mask must include a direct response to the query, like: "Mask N covers the m-th man from the right". +2. Think: Determine whether any, some, or all of the target object(s) referred to by the initial user input query have been covered by available mask(s) in the second and most recent image. Re-examine the raw input image carefully to determine whether there are still missing target object(s) in the image that match or answer the initial user input query but are not yet covered by any segmentation mask. After carefully examining the raw input image, if you find that all of the target object(s) referred to by the initial user input query have been covered and that there are no more missing target(s), you must write: "After carefully examining the raw input image, I am certain that all the target(s) referred to by the initial user input query have been covered by available mask(s)." +3. Remind: If you need to update your step-by-step tool call plan, you must remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). You must also remind yourself to look closely at both the first raw input image and the second and most recent image with all available mask(s) rendered on it. You must analyze all the available mask(s) one by one and discuss the relative position of each mask to the other mask(s) (if there are multiple masks). +4. Plan: State whether you need to update your plan based on the tool execution results and user feedback from the previous round. If so, update your step-by-step plan to use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query if necessary. +5. Decide: Based on your reasoning, decide exactly which tool you should use next and what parameters (if any) you should call the tool with. +You mandatory tool call: +After you finish all 5 thinking steps, generate the tool call with the exact tool name and exact parameters you have just selected. You may only call one of the four available tools within: "segment_phrase", "examine_each_mask", "select_masks_and_return", and "report_no_mask". Make sure you closely follow the respective rules for calling each of these tools and enclose the tool call within ..... HTML tags. + + + +Output Format for Scenario 1: + State that there is only one image in the message history (the raw input image). Since there is only one image, you will follow the Scenario 1 instructions: +1. Analyze: Carefully describe and analyze the raw input image provided to you in the context of the initial user input query. +2. Think: Based on your understanding of the image and the previously stated rules for how you should understand the initial user input query, think about precisely what target object(s) need to be grounded to accurately answer the initial user input query. +3. Remind: Remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). +4. Plan: Design a step-by-step tool call plan for how you will use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query. +5. Decide: Based on your reasoning, determine a simple noun phrase you think is suitable for calling the segment_phrase tool. The phrase should be a simple, direct, singular noun phrase. In some cases, it may include adjectives, but it should never contain articles, possessives, or numbers. + {"name": "tool name", "parameters": {"Parameter name": "Parameter content", "... ...": "... ..."}} +Stop your response and wait for user feedback. + + + +Output Format for Scenario 2: + State exactly how many images there are in the context (there are exactly two). Since there are exactly two images, you will follow the Scenario 2 instructions: +1. Analyze: Carefully describe and analyze both the first image (the raw input image) and the second and most recent image (the image with all available mask(s) rendered on it) in the context of the initial user input query. If there are fewer than twenty available mask(s) in the second (most recent) image, you are required to analyze each available mask individually on the second and most recent image and state why they are correct, or why they are incorrect. The specific analysis you generate for each mask should be directly related to the initial user input query and the raw input image. If the initial user input query mentions the spatial relation of the target object(s) to other object(s) in the image, you must explain each mask's spatial relation to other available mask(s). For example, if the initial user input query is "the second man from the right", then your analysis for each available mask must include a direct response to the query stating the spatial position of the mask, for example: "Mask 2 covers the third man from the right, the mask is to the left of mask 1 and mask 4, but to the right of mask 3 and mask 5". +2. Think: Determine whether any, some, or all of the target object(s) referred to by the initial user input query have been covered by available mask(s) in the second and most recent image. Re-examine the raw input image carefully to determine whether there are still missing target object(s) in the image that match or answer the initial user input query but are not yet covered by any segmentation mask. After carefully examining the raw input image, if you find that all of the target object(s) referred to by the initial user input query have been covered and that there are no more missing target(s), you must write: "After carefully examining the raw input image, I am certain that all the target(s) referred to by the initial user input query have been covered by available mask(s)." +3. Remind: If you need to update your step-by-step tool call plan, you must remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). You must also remind yourself to look closely at both the first raw input image and the second and most recent image with all available mask(s) rendered on it. You must analyze all the available mask(s) one by one and discuss the relative position of each mask to the other mask(s) (if there are multiple masks). +4. Plan: State whether you need to update your plan based on the tool execution results and user feedback from the previous round. If so, update your step-by-step plan to use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query if necessary. +5. Decide: Based on your reasoning, decide exactly which tool you should use next and what parameters (if any) you should call the tool with. + {"name": "tool name", "parameters": {"Parameter name": "Parameter content", "... ...": "... ..."}} + + + +Important response formatting rules: +1. You must always include the ..... field to outline your reasoning and the ..... field to specify the action you choose to take before you end a turn. +2. Each tool call should be a JSON object with a "name" field and a "parameters" field containing a dictionary of parameters. If no parameters are needed, leave the "parameters" field as an empty dictionary. +3. Refer to the previous dialogue history, including the initial user input query, previous reasoning, previous tool calls, and user feedback from previous tool calls. +4. Do not wrap your entire output in a single large JSON object. +5. Do not try to output multiple rounds of tool calls in a single turn. Stop immediately after you call one tool. +6. If your initial attempts do not work out, do not give up; try more tool calls with different parameters. Take as long as you need! + + + +Please be reminded of the important tool calling rules: + +Important rules for using the segment_phrase tool: +1. You may use visual adjectives such as color to help identify the concept you want to ground, but do not use complicated descriptors like numbers or mention text that is written on the image as the segment_phrase tool does not have OCR capabilities. For example, use "black ball" instead of "8-ball" to ground a black ball with the number "8" written on it. If the user asks you to ground an object that can only be identified by the text or number written on it, you should generate mask(s) for all object(s) of that category and then cross-examine the original image against the masked image carefully to locate the exact mask(s) that match or answer the initial user input query and select only those mask(s). +2. Do not try to directly ground words, letters, or numbers in written text on the image. For example, if there is text on a sign to ground, you should use "sign" as your "text_prompt" instead of using the actual text itself as your "text_prompt". +3. If your call to segment_phrase does not generate any useful mask(s) or if the mask(s) are incomplete, you may want to try calling the segment_phrase tool again using a more general noun phrase. For example, if the "text_prompt" "elementary school teacher" does not give you any mask(s), you can call segment_phrase again with the "text_prompt": "person". +4. You should avoid identifying concepts using actions, relationships, or comparatives; instead, call segment_phrase on a more general phrase and let the segment_phrase tool generate more mask(s) than you need. Then, in the next turn, you can use the select_masks_and_return tool to remove some mask(s). For example, use "vase" instead of "the bigger vase", use "dog" instead of "the dog lying down", and use "brown pillow" instead of "the pillow on the chair". +5. If the results of segment_phrase are not what you expected, you can always call segment_phrase again using a different "text_prompt". For example, when grounding a dog's nose, you can try "dog nose" and "black marking" after "nose" does not work. +6. Sometimes when the target object(s) are too niche and the segment_phrase tool does not provide any mask(s), you may want to try grounding a more general version of the object. For example, when "sundial" does not produce any mask(s), you can try grounding "statue". +7. Be concise and get the right keywords; don't make your "text_prompt" long. +8. Do not ever use the exact same "text_prompt" more than once. This is very important! +9. Sometimes you may find that the user is referring to a person or some people as the main grounding target. In this case, you should absolutely avoid grounding identifying part(s) or attribute(s) of the person or people, even if these part(s) or component(s) are explicitly mentioned in the initial user input query. Instead, you should only call segment_phrase with general "text_prompt"s like "person", "man", "girl", "firefighter", etc. that refer to the person as a whole. Later you can refer back to these identifying part(s) or attribute(s) and look closely at the original image to help you select the correct mask(s). +10. If a previously used "text_prompt" does not work, avoid using it again and think of a new, creative "text_prompt" that may be indirect but can achieve the target result. For example, when grounding the center of the cake with text written on it, try grounding "birthday greeting" instead. +11. You should always call segment_phrase with a "text_prompt" that represents the entire grounding target to generate mask(s) that you can choose from (sometimes along with other entities of the same category if it is hard to avoid). Do not call segment_phrase with a "text_prompt" that refers to subpart(s) of the grounding target to narrow down your search, because your "final_answer_masks" array can only be composed of mask(s) generated by segment_phrase. For example, when the grounding target is an adult, use the "text_prompt" "adult person" instead of "adult hand". +12. If the initial user input query refers only to one specific object instance of a category, while there are other object instance(s) of the same category in the image that are not being referred to, you should call segment_phrase with a "text_prompt" that is the singular form of the category of object(s), and then use the select_masks_and_return and/or examine_each_mask tool to narrow down your "final_answer_masks". +13. Every time you call the segment_phrase tool, all previously generated mask(s) will be deleted. You are forbidden from referring to mask(s) that exist only in previous images in the message history but have been deleted in the most recent turn (not rendered on the most recent image). +14. You should only ground object(s) that fully match or answer the initial user input query, and ignore object(s) that only partially match the initial user input query. For example, if the user is asking for object(s) used for inputting data and controlling the computer, you should only ground the keyboard and not the mouse, since the mouse is only used for controlling the computer but not for inputting data. +15. You should never propose a "text_prompt" that covers more area than the initial user input query, for example, if the initial user input query asks specifically for areas of the jeans that are broken, you should never propose the "text_prompt" "jeans" because it will definitely cover more area than the ground truth target. +16. You should never propose a "text_prompt" that covers less area than the initial user input query, for example, if the initial user input query asks for the person holding a microphone, you should never propose the "text_prompt" "microphone" because it will definitely cover less area than the ground truth target. +17. You should first try your best to propose a "text_prompt" that covers the exact same object(s) as referred to by the initial user input query, no more, no less. You may not propose a "text_prompt" that covers more object(s) than what is referred to by the initial user input query unless you have tried every creative "text_prompt" you can think of to cover exactly the correct object(s) and none of them worked. +18. Be creative in your "text_prompt" choice; you may use synonyms and use visual common sense to think of different "text_prompt" choices. You have unlimited turns to call each tool, so take your time! + +Important rules for using the examine_each_mask tool: +1. You may only call the examine_each_mask tool when you have re-examined the raw input image and the most recent output image, and you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, and there are no missing correct mask(s). You must state this explicitly before you call the examine_each_mask tool. +2. Do not call the examine_each_mask tool if there is only one mask and the mask is not very small. +3. Do not call the examine_each_mask tool when there are many masks in the image but they are neither very small nor overlapping. +4. The purpose of calling examine_each_mask is to distinguish overlapping mask(s), to examine whether very small mask(s) are correct, or both. +5. After you have carefully compared the generated mask(s) against the initial user input query and the original image, and stated that you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, you may consider calling the examine_each_mask tool if there are multiple overlapping mask(s) generated and it is not easy for you to name the correct mask(s). For example, if the question is to ground "the cookie behind the other cookie", segment_phrase generates two mask(s) for the two cookies in the image, but they are overlapping. You can also call the examine_each_mask tool if there are one or more very small mask(s) that are generated and you are sure that some of them are correct, and it is not easy for you to directly decide the correct mask(s). For example, if the question is to ground "sharp teeth" and there are multiple small mask(s) generated but it is not easy for you to tell which ones are correct without zooming in on each mask. +6. Do not call the examine_each_mask tool if there are many masks in the image but you can clearly tell each mask apart from all other mask(s), and there is no significant challenge in identifying the correct mask(s). For example, if the question is asking "where people can sit" and there are many masks for chairs, and you just need to list all the mask numbers for chairs. +7. You may not call the examine_each_mask tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image. + +Important rules for using the select_masks_and_return tool: +1. Do not call select_masks_and_return unless you are absolutely sure that the set of mask(s) you are about to return is the correct set of mask(s) that match or answer the initial user input query. +2. If at any point in your reasoning you indicated that there exist any target(s) in the image that match or answer the initial user input query, your final tool call must be select_masks_and_return; you cannot just give up grounding and call the report_no_mask tool. This is very important. +3. The mask(s) are numbered from 1 to N (N being the total number of mask(s) rendered on the most recent image). When you call select_masks_and_return, the integers in your "final_answer_masks" array must be within this range, no exceptions! Make sure of this! +4. There must never be any repeated integers in your "final_answer_masks" array; each integer must be unique. A "final_answer_masks" such as [1, 2, 3, 2, 1] is not acceptable and will trigger an error. You should avoid this format error at all costs. +5. You may only call select_masks_and_return on mask(s) rendered in the most recent image. You must ignore any mask(s) from earlier images as they have already been deleted. +6. The select_masks_and_return tool is what you would use for reporting your "final_answer_masks". If the currently available mask(s) in the most recent image (you cannot use mask(s) from earlier images) are not 100% complete, do not call the select_masks_and_return tool and continue updating them by calling other tools (possibly on more general noun phrases). +7. Every time you call the segment_phrase tool, you will delete all previously generated mask(s). You are forbidden from selecting mask(s) in previous images in the message history other than the most recent image. +8. Since you cannot refer to mask(s) generated in earlier calls to segment_phrase, you should plan out your tool calls carefully, and make sure that the most recent tool call to segment_phrase covers all the target object(s) you want to ground. +9. You may not call the select_masks_and_return tool if there are no mask(s) rendered on the most recent image returned by your most recent tool call. +10. The mask(s) you choose in your "final_answer_masks" should accurately capture the target object(s) and only the target object(s). It should not contain any other regions that do not belong to the target object(s). Nor should it contain only a part of the target object(s). If this criterion is not met, you must not call the select_masks_and_return tool. Instead, please continue using other tools to generate better mask(s). +11. Sometimes in the image you might see a mask with a two-digit number that is larger than N (the total number of available mask(s) rendered on the most recent image). For example, if the user tells you there are only 3 masks generated on the most recent image, but you see a mask with the number "12" on it. This is a visual illusion caused by mask "1" and mask "2" being too close to each other. In this case, you should never refer to mask "12" as it does not exist. Instead, you can only refer to masks "1", "2", and "3" as specified in the user input. +12. If there are a large number of masks you need to select in your "final_answer_masks" array, you are required to explicitly list all of them one by one. You may not use any form of abbreviation or code. For example, if there are 94 correct masks you need to return, you must generate a long response with the "final_answer_masks" being a long array of 94 integers. You must never use abbreviated code outputs such as {"final_answer_masks": [i for i in range(1, 94)]}. +13. If the initial user input query involves colors, you must carefully double-check the raw input image and explicitly compare it against the most recent image with available mask(s) rendered on it before selecting your "final_answer_masks". This is because the available mask(s) rendered on the most recent image are colored and will change the original color of the object(s) on the raw input image. +14. Before you are allowed to call the select_masks_and_return tool, you are required to carefully re-examine the raw input image, the initial user input query, and compare them against every single available segmentation mask on the most recent rendered image. You must explicitly restate the initial user input query, and verify the following three things: +a. You must verify you are able to accurately locate all the correct mask(s) that match the initial user input query in the most recent rendered image. +b. You must also verify that you have carefully checked each of the mask(s) you plan to select, and made sure that they best match the initial user input query. (list your reasoning for each mask) +c. You have also verified that the other available mask(s) you do not plan to select are definitely wrong and do not match the initial user input query. (list your reasoning for each mask) +15. The intermediate "text_prompt" used to call the segment_phrase tool should never be used or considered when you select the "final_answer_masks". Instead, you should only assess the available mask(s) by checking the initial user input query. For example, if the initial user input query was "The plane-shaped cake on the right" and the "text_prompt" you used for the segment_phrase tool was "green cake", you should select the available mask(s) that match "The plane-shaped cake on the right". +16. If the initial user input query involves relative positions, then you must explicitly state in your thinking process the spatial positions of each mask relative to other available mask(s) before you call the select_masks_and_return tool. +17. You may not select any mask(s) whose number is greater than 100. For example, you may not select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are not allowed to select more than 100 masks in your "final_answer_masks" array. +18. You may not call the select_masks_and_return tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image. + +Important rules for using the report_no_mask tool: +1. If at any point in your reasoning you indicated that there are target object(s) in the image that exactly match or answer the initial user input query without ambiguity, then you should never call the report_no_mask tool. Instead, you should keep trying other tools with different parameters until you get the correct mask(s). +2. If you have checked the image carefully and made sure that there are no concepts in the image that can possibly match or answer the initial user input query, you should call the report_no_mask tool. +3. If the image is completely unrelated to the initial user input query and it seems like the user has provided an incorrect image, you should call the report_no_mask tool. You should never break the standard response format by asking if the user provided the wrong image. +4. Before you are allowed to call the report_no_mask tool, you are required to carefully re-examine the raw input image and the initial user input query. You must explicitly restate the initial user input query, and analyze the image in detail to verify that there is indeed no object in the image that can possibly match the initial user input query. +5. Sometimes the initial user input query is slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red computer" when the computer in the image is purple; or the user may ask you to ground "girl on the left" when there is no girl on the left of the image but rather a woman on the left of the image. In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. +6. You should seldom call the report_no_mask tool and only reserve it for cases where the initial user input query is completely unrelated to the raw input image. +7. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query. + + +Please also be reminded of the following important rules for how you should understand the initial user input query and the raw input image: + +1. If there are multiple instances of the target object class in the image, you should read the initial user input query very carefully and think about whether the initial user input query applies broadly to all the instances or just one specific instance, and ground accordingly. +2. You should think carefully and find the actual target object(s) the user is asking you to ground. Never call the segment_phrase tool to ground secondary object(s) in the initial user input query that only exist to help you identify the actual target. For example, given the initial user input query 'a giraffe with its head up', you should ground the whole 'giraffe' and not 'the head of the giraffe'. Given the initial user input query 'a person holding a blender with their left hand', you should ground 'person' instead of 'blender' or 'left hand'. Given the initial user input query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should ground 'woman' instead of 'dog' or 'bicycle'. Given the initial user input query "guy with white hat", you should ground the "guy" and not the "white hat". +3. Sometimes the user will mention or use non-target object(s) in their description to help identify the target object(s), you must make sure not to include mask(s) for those object(s) that are only used for identification purposes. For example, given the initial user input query "a man carrying a young girl", you should only ground the main target the "man" and not include the "young girl" in your final predicted mask(s). Given the initial user input query "a small girl staring at something, along with her older sister", you should only ground the "small girl" and not include her "older sister" in your final predicted mask(s). +4. Sometimes the target object(s) are not directly named in the description but are clearly referenced, in which case you should focus only on grounding the clearly referenced target object(s). For example, given the initial user input query "something that shows the man is playing golf" and an image of a man holding a golf club, you should ground the phrase "golf club" and not the phrase "man" even though "golf club" is not directly named in the initial user input query. +5. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query. +6. Sometimes the initial user input query can be slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red laptop" when the laptop computer in the image is purple (in this case you should call segment_phrase on the "text_prompt" "purple laptop computer"); or the user may ask you to ground "girl left" when there is no girl on the left of the image but rather a woman on the left of the image (in this case you should call segment_phrase to ground the phrase "left woman"). In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. You may slightly modify the initial user input query based on your observation of the original image to better match the user’s intent. +7. Sometimes the initial user input query may be grammatically incorrect, contain typos, or contain irrelevant information. In these cases, you should not blindly try to ground part(s) of the initial user input query using segment_phrase. Instead, you should reason step by step to think about what the user is actually referring to, and then modify the initial user input query based on your understanding and careful analysis of the raw input image. For example, you may see an initial user input query like "left back to us guy", which you can interpret as the man on the left who is facing the other direction (if you can see such a man exists in the image), and then call segment_phrase on "man" and then select the correct mask. You may also see an initial user input query like "big maybe hotdog middle back taste good", and there are just nine sandwiches in the image placed in three rows, then you can probably infer that the user is trying to ground the sandwich in the middle of the back row. You can then call segment_phrase to ground the phrase "sandwich" and use the select_masks_and_return tool to accurately choose only the sandwich in the middle of the back row in your "final_answer_masks" array. +8. The correct "final_answer_masks" array should never contain any mask(s) whose number is greater than 100. For example, you may never select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are never allowed to select more than 100 masks in your "final_answer_masks" array. +9. Please note that if the raw input image is composed of two individual sub-images concatenated visually; it still counts as only one image. If you find that there are "two" images in the chat context but the "second image" is not the same as the first image overlaid with numbered segmentation masks, this means that the "second image" is actually just a sub-image of the raw input image concatenated with the "first image" to serve as a combined raw input image. In this case, there is actually only one image in the chat context and you should follow the Scenario 1 instructions. This is very important! + + +Begin! + +Below are the raw input image and the initial user input query: diff --git a/third_party/sam3/sam3/agent/system_prompts/system_prompt_iterative_checking.txt b/third_party/sam3/sam3/agent/system_prompts/system_prompt_iterative_checking.txt new file mode 100644 index 0000000000000000000000000000000000000000..f6f9b881dbf4390984f5c9d60a2a8cfd8e6520c8 --- /dev/null +++ b/third_party/sam3/sam3/agent/system_prompts/system_prompt_iterative_checking.txt @@ -0,0 +1,26 @@ +You are a helpful assistant specializing in detail-oriented visual understanding, reasoning, and classification, capable of carefully analyzing a predicted segmentation mask on an image along with zoomed-in views of the area around the predicted segmentation mask to determine whether the object covered by the predicted segmentation mask is one of the correct masks that match the user query. + +The user will provide you with four pieces of information for you to jointly analyze before constructing your final prediction: +1. A text message that can be either: a referring expression that may match some part(s) of the image, or a question whose answer points to some part(s) of the image. +2. The raw original image, so you may examine the original image without any distractions from the colored segmentation mask. +3. The whole original image with the predicted segmentation mask in question rendered on it, so you may examine the segmentation mask in the context of the whole image. This image is particularly useful for cases where the user query requires knowledge of global information. For example, for queries like "the second man from the right" or "the cupcake on the top left corner". +4. A zoomed-in version of the predicted segmentation mask in question. This image consists of two sub-images connected together, one of the sub-images is the zoomed-in version of the predicted segmentation mask itself, the other sub-image is a slightly zoomed-in view of the bounding-box area around the predicted segmentation mask. + + +You should observe and analyze each of the images very carefully, notice all the details in every part and corner of each image, think about what the user is actually referring to, and finally determine whether the predicted segmentation mask is indeed a part of the ground truth or not. + +Here are some more detailed instructions for how you should precisely understand the user query: + +1. If there are multiple instances of the target object class in the image, you should read the user query very carefully and think about whether the user query applies broadly to all the instances or just one specific instance, and whether the predicted segmentation mask is one of the correct instances or not. +2. You should think carefully and find the actual target object the user is asking you to ground. Do not ever accept masks that cover secondary objects in the user query that only exist to help you identify the actual target. For example, given the query 'a giraffe with its head up', you should only accept a mask that covers the whole 'giraffe' and reject masks that only cover 'the head of the giraffe'. Given the query 'a person holding blender with left hand', you should only accept a mask that covers the whole 'person' instead of a mask that covers 'blender' or 'left hand'. Given the query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should only accept a mask that covers the 'woman' instead of a mask that covers the 'dog' or the 'bicycle'. Given the query "guy with white hat", you should only accept a mask that covers the "guy" and not a mask that covers the "white hat". +3. Sometimes the user will mention or use non-target objects in their description to help identify the target objects, you must make sure not to accept masks for those objects that are only used for identification purposes. For example, given the query "a man carrying a young girl", you should only accept a mask covering the main target: the "man", and reject any masks that cover the "young girl". Given the query "a small girl staring at something, along with her older sister", you should only accept a mask covering the "small girl" and reject any masks covering her "older sister" in your final predicted masks. +4. Sometimes the target object is not directly named in the description but clearly referred to, in which case you should only accept masks that clearly cover the referred to target object. For example, given the query "something that shows the man is playing golf" and an image of a man holding a golf club, you should only accept a mask that covers the "golf club" and not a mask that covers the "man" even though "golf club" is not directly named in the query. +5. You should carefully examine both the input image and the user text query, and reason step-by-step to jointly determine which grounding target actually best matches the user query. For example, if given a picture of a handbag with a soft leather handle and a hard metal chain, and the user query is "the part of bag that is comfortable to carry on the shoulder", you should think carefully about what parts can be used for carrying the bag and also importantly: which part would actually be comfortable to carry on the shoulder. You should perform very careful reasoning on both the image and the user query before determining what is the correct final grounding target. + + +Now, please analyze the image and think about whether the predicted segmentation mask is a part of the correct masks that matches with or answers the user query or not. First output your detailed analysis of each input image, and then output your step-by-step reasoning explaining why the predicted segmentation mask is correct or incorrect, and then finally respond with either Accept or Reject. + +Please only respond in the following format and never break format for any reason: + +Analyze the user query and the three images: the raw input image, the image with the predicted segmentation mask rendered on it, and the image containing the zoomed-in version of the predicted segmentation mask. Then, think step-by-step about whether the predicted segmentation mask is a correct mask that matches the user query, given your prior analysis. +Accept or Reject diff --git a/third_party/sam3/sam3/agent/viz.py b/third_party/sam3/sam3/agent/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a786737d47a4d5c20a676b9a7d92e81d2c915a --- /dev/null +++ b/third_party/sam3/sam3/agent/viz.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import cv2 +import numpy as np +import pycocotools.mask as mask_utils +from PIL import Image + +from .helpers.visualizer import Visualizer +from .helpers.zoom_in import render_zoom_in + + +def visualize( + input_json: dict, + zoom_in_index: int | None = None, + mask_alpha: float = 0.15, + label_mode: str = "1", + font_size_multiplier: float = 1.2, + boarder_width_multiplier: float = 0, +): + """ + Unified visualization function. + + If zoom_in_index is None: + - Render all masks in input_json (equivalent to visualize_masks_from_result_json). + - Returns: PIL.Image + + If zoom_in_index is provided: + - Returns two PIL.Images: + 1) Output identical to zoom_in_and_visualize(input_json, index). + 2) The same instance rendered via the general overlay using the color + returned by (1), equivalent to calling visualize_masks_from_result_json + on a single-mask json_i with color=color_hex. + """ + # Common fields + orig_h = int(input_json["orig_img_h"]) + orig_w = int(input_json["orig_img_w"]) + img_path = input_json["original_image_path"] + + # ---------- Mode A: Full-scene render ---------- + if zoom_in_index is None: + boxes = np.array(input_json["pred_boxes"]) + rle_masks = [ + {"size": (orig_h, orig_w), "counts": rle} + for rle in input_json["pred_masks"] + ] + binary_masks = [mask_utils.decode(rle) for rle in rle_masks] + + img_bgr = cv2.imread(img_path) + if img_bgr is None: + raise FileNotFoundError(f"Could not read image: {img_path}") + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + + viz = Visualizer( + img_rgb, + font_size_multiplier=font_size_multiplier, + boarder_width_multiplier=boarder_width_multiplier, + ) + viz.overlay_instances( + boxes=boxes, + masks=rle_masks, + binary_masks=binary_masks, + assigned_colors=None, + alpha=mask_alpha, + label_mode=label_mode, + ) + pil_all_masks = Image.fromarray(viz.output.get_image()) + return pil_all_masks + + # ---------- Mode B: Zoom-in pair ---------- + else: + idx = int(zoom_in_index) + num_masks = len(input_json.get("pred_masks", [])) + if idx < 0 or idx >= num_masks: + raise ValueError( + f"zoom_in_index {idx} is out of range (0..{num_masks - 1})." + ) + + # (1) Replicate zoom_in_and_visualize + object_data = { + "labels": [{"noun_phrase": f"mask_{idx}"}], + "segmentation": { + "counts": input_json["pred_masks"][idx], + "size": [orig_h, orig_w], + }, + } + pil_img = Image.open(img_path) + pil_mask_i_zoomed, color_hex = render_zoom_in( + object_data, pil_img, mask_alpha=mask_alpha + ) + + # (2) Single-instance render with the same color + boxes_i = np.array([input_json["pred_boxes"][idx]]) + rle_i = {"size": (orig_h, orig_w), "counts": input_json["pred_masks"][idx]} + bin_i = mask_utils.decode(rle_i) + + img_bgr_i = cv2.imread(img_path) + if img_bgr_i is None: + raise FileNotFoundError(f"Could not read image: {img_path}") + img_rgb_i = cv2.cvtColor(img_bgr_i, cv2.COLOR_BGR2RGB) + + viz_i = Visualizer( + img_rgb_i, + font_size_multiplier=font_size_multiplier, + boarder_width_multiplier=boarder_width_multiplier, + ) + viz_i.overlay_instances( + boxes=boxes_i, + masks=[rle_i], + binary_masks=[bin_i], + assigned_colors=[color_hex], + alpha=mask_alpha, + label_mode=label_mode, + ) + pil_mask_i = Image.fromarray(viz_i.output.get_image()) + + return pil_mask_i, pil_mask_i_zoomed diff --git a/third_party/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz b/third_party/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/third_party/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/third_party/sam3/sam3/eval/__init__.py b/third_party/sam3/sam3/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/eval/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/eval/cgf1_eval.py b/third_party/sam3/sam3/eval/cgf1_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..71fe2ea6a564dd1586d5d412ddbf25a10b574b66 --- /dev/null +++ b/third_party/sam3/sam3/eval/cgf1_eval.py @@ -0,0 +1,705 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import contextlib +import copy +import json +import os +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import pycocotools.mask as maskUtils +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from scipy.optimize import linear_sum_assignment +from tqdm import tqdm + + +@dataclass +class Metric: + name: str + + # whether the metric is computed at the image level or the box level + image_level: bool + + # iou threshold (None is used for image level metrics or to indicate averaging over all thresholds in [0.5:0.95]) + iou_threshold: Union[float, None] + + +CGF1_METRICS = [ + Metric(name="cgF1", image_level=False, iou_threshold=None), + Metric(name="precision", image_level=False, iou_threshold=None), + Metric(name="recall", image_level=False, iou_threshold=None), + Metric(name="F1", image_level=False, iou_threshold=None), + Metric(name="positive_macro_F1", image_level=False, iou_threshold=None), + Metric(name="positive_micro_F1", image_level=False, iou_threshold=None), + Metric(name="positive_micro_precision", image_level=False, iou_threshold=None), + Metric(name="IL_precision", image_level=True, iou_threshold=None), + Metric(name="IL_recall", image_level=True, iou_threshold=None), + Metric(name="IL_F1", image_level=True, iou_threshold=None), + Metric(name="IL_FPR", image_level=True, iou_threshold=None), + Metric(name="IL_MCC", image_level=True, iou_threshold=None), + Metric(name="cgF1", image_level=False, iou_threshold=0.5), + Metric(name="precision", image_level=False, iou_threshold=0.5), + Metric(name="recall", image_level=False, iou_threshold=0.5), + Metric(name="F1", image_level=False, iou_threshold=0.5), + Metric(name="positive_macro_F1", image_level=False, iou_threshold=0.5), + Metric(name="positive_micro_F1", image_level=False, iou_threshold=0.5), + Metric(name="positive_micro_precision", image_level=False, iou_threshold=0.5), + Metric(name="cgF1", image_level=False, iou_threshold=0.75), + Metric(name="precision", image_level=False, iou_threshold=0.75), + Metric(name="recall", image_level=False, iou_threshold=0.75), + Metric(name="F1", image_level=False, iou_threshold=0.75), + Metric(name="positive_macro_F1", image_level=False, iou_threshold=0.75), + Metric(name="positive_micro_F1", image_level=False, iou_threshold=0.75), + Metric(name="positive_micro_precision", image_level=False, iou_threshold=0.75), +] + + +class COCOCustom(COCO): + """COCO class from pycocotools with tiny modifications for speed""" + + def createIndex(self): + # create index + print("creating index...") + anns, cats, imgs = {}, {}, {} + imgToAnns, catToImgs = defaultdict(list), defaultdict(list) + if "annotations" in self.dataset: + for ann in self.dataset["annotations"]: + imgToAnns[ann["image_id"]].append(ann) + anns[ann["id"]] = ann + + if "images" in self.dataset: + # MODIFICATION: do not reload imgs if they are already there + if self.imgs: + imgs = self.imgs + else: + for img in self.dataset["images"]: + imgs[img["id"]] = img + # END MODIFICATION + + if "categories" in self.dataset: + for cat in self.dataset["categories"]: + cats[cat["id"]] = cat + + if "annotations" in self.dataset and "categories" in self.dataset: + for ann in self.dataset["annotations"]: + catToImgs[ann["category_id"]].append(ann["image_id"]) + + print("index created!") + + # create class members + self.anns = anns + self.imgToAnns = imgToAnns + self.catToImgs = catToImgs + self.imgs = imgs + self.cats = cats + + def loadRes(self, resFile): + """ + Load result file and return a result api object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = COCOCustom() + res.dataset["info"] = copy.deepcopy(self.dataset.get("info", {})) + # MODIFICATION: no copy + # res.dataset['images'] = [img for img in self.dataset['images']] + res.dataset["images"] = self.dataset["images"] + # END MODIFICATION + + print("Loading and preparing results...") + tic = time.time() + if type(resFile) == str: + with open(resFile) as f: + anns = json.load(f) + elif type(resFile) == np.ndarray: + anns = self.loadNumpyAnnotations(resFile) + else: + anns = resFile + assert type(anns) == list, "results in not an array of objects" + annsImgIds = [ann["image_id"] for ann in anns] + # MODIFICATION: faster and cached subset check + if not hasattr(self, "img_id_set"): + self.img_id_set = set(self.getImgIds()) + assert set(annsImgIds).issubset( + self.img_id_set + ), "Results do not correspond to current coco set" + # END MODIFICATION + if "caption" in anns[0]: + imgIds = set([img["id"] for img in res.dataset["images"]]) & set( + [ann["image_id"] for ann in anns] + ) + res.dataset["images"] = [ + img for img in res.dataset["images"] if img["id"] in imgIds + ] + for id, ann in enumerate(anns): + ann["id"] = id + 1 + elif "bbox" in anns[0] and not anns[0]["bbox"] == []: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) + for id, ann in enumerate(anns): + bb = ann["bbox"] + x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] + if not "segmentation" in ann: + ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]] + ann["area"] = bb[2] * bb[3] + ann["id"] = id + 1 + ann["iscrowd"] = 0 + elif "segmentation" in anns[0]: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) + for id, ann in enumerate(anns): + # now only support compressed RLE format as segmentation results + ann["area"] = maskUtils.area(ann["segmentation"]) + if not "bbox" in ann: + ann["bbox"] = maskUtils.toBbox(ann["segmentation"]) + ann["id"] = id + 1 + ann["iscrowd"] = 0 + elif "keypoints" in anns[0]: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) + for id, ann in enumerate(anns): + s = ann["keypoints"] + x = s[0::3] + y = s[1::3] + x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y) + ann["area"] = (x1 - x0) * (y1 - y0) + ann["id"] = id + 1 + ann["bbox"] = [x0, y0, x1 - x0, y1 - y0] + print("DONE (t={:0.2f}s)".format(time.time() - tic)) + + res.dataset["annotations"] = anns + # MODIFICATION: inherit images + res.imgs = self.imgs + # END MODIFICATION + res.createIndex() + return res + + +class CGF1Eval(COCOeval): + """ + This evaluator is based upon COCO evaluation, but evaluates the model in a more realistic setting + for downstream applications. + See SAM3 paper for the details on the CGF1 metric. + + Do not use this evaluator directly. Prefer the CGF1Evaluator wrapper. + + Notes: + - This evaluator does not support per-category evaluation (in the way defined by pyCocotools) + - In open vocabulary settings, we have different noun-phrases for each image. What we call an "image_id" here is actually an (image, noun-phrase) pair. So in every "image_id" there is only one category, implied by the noun-phrase. Thus we can ignore the usual coco "category" field of the predictions + """ + + def __init__( + self, + coco_gt=None, + coco_dt=None, + iouType="segm", + threshold=0.5, + ): + """ + Args: + coco_gt (COCO): ground truth COCO API + coco_dt (COCO): detections COCO API + iou_type (str): type of IoU to evaluate + threshold (float): threshold for predictions + """ + super().__init__(coco_gt, coco_dt, iouType) + self.threshold = threshold + + self.params.useCats = False + self.params.areaRng = [[0**2, 1e5**2]] + self.params.areaRngLbl = ["all"] + self.params.maxDets = [1000000] + + def computeIoU(self, imgId, catId): + # Same as the original COCOeval.computeIoU, but without sorting + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 and len(dt) == 0: + return [] + + if p.iouType == "segm": + g = [g["segmentation"] for g in gt] + d = [d["segmentation"] for d in dt] + elif p.iouType == "bbox": + g = [g["bbox"] for g in gt] + d = [d["bbox"] for d in dt] + else: + raise Exception("unknown iouType for iou computation") + + # compute iou between each dt and gt region + iscrowd = [int(o["iscrowd"]) for o in gt] + ious = maskUtils.iou(d, g, iscrowd) + return ious + + def evaluateImg(self, imgId, catId, aRng, maxDet): + """ + perform evaluation for single category and image + :return: dict (single image results) + """ + p = self.params + assert not p.useCats, "This evaluator does not support per-category evaluation." + assert catId == -1 + all_gts = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + keep_gt = np.array([not g["ignore"] for g in all_gts], dtype=bool) + gt = [g for g in all_gts if not g["ignore"]] + all_dts = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + keep_dt = np.array([d["score"] >= self.threshold for d in all_dts], dtype=bool) + dt = [d for d in all_dts if d["score"] >= self.threshold] + if len(gt) == 0 and len(dt) == 0: + # This is a "true negative" case, where there are no GTs and no predictions + # The box-level metrics are ill-defined, so we don't add them to this dict + return { + "image_id": imgId, + "IL_TP": 0, + "IL_TN": 1, + "IL_FP": 0, + "IL_FN": 0, + "num_dt": len(dt), + } + + if len(gt) > 0 and len(dt) == 0: + # This is a "false negative" case, where there are GTs but no predictions + return { + "image_id": imgId, + "IL_TP": 0, + "IL_TN": 0, + "IL_FP": 0, + "IL_FN": 1, + "TPs": np.zeros((len(p.iouThrs),), dtype=np.int64), + "FPs": np.zeros((len(p.iouThrs),), dtype=np.int64), + "FNs": np.ones((len(p.iouThrs),), dtype=np.int64) * len(gt), + "local_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64), + "local_positive_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64), + "num_dt": len(dt), + } + + # Load pre-computed ious + ious = self.ious[(imgId, catId)] + + # compute matching + if len(ious) == 0: + ious = np.zeros((len(dt), len(gt))) + else: + ious = ious[keep_dt, :][:, keep_gt] + assert ious.shape == (len(dt), len(gt)) + + matched_dt, matched_gt = linear_sum_assignment(-ious) + + match_scores = ious[matched_dt, matched_gt] + + TPs, FPs, FNs = [], [], [] + IL_perfect = [] + for thresh in p.iouThrs: + TP = (match_scores >= thresh).sum() + FP = len(dt) - TP + FN = len(gt) - TP + assert ( + FP >= 0 and FN >= 0 + ), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}" + TPs.append(TP) + FPs.append(FP) + FNs.append(FN) + + if FP == FN and FP == 0: + IL_perfect.append(1) + else: + IL_perfect.append(0) + + TPs = np.array(TPs, dtype=np.int64) + FPs = np.array(FPs, dtype=np.int64) + FNs = np.array(FNs, dtype=np.int64) + IL_perfect = np.array(IL_perfect, dtype=np.int64) + + # compute precision recall and F1 + precision = TPs / (TPs + FPs + 1e-4) + assert np.all(precision <= 1) + recall = TPs / (TPs + FNs + 1e-4) + assert np.all(recall <= 1) + F1 = 2 * precision * recall / (precision + recall + 1e-4) + + result = { + "image_id": imgId, + "TPs": TPs, + "FPs": FPs, + "FNs": FNs, + "local_F1s": F1, + "IL_TP": (len(gt) > 0) and (len(dt) > 0), + "IL_FP": (len(gt) == 0) and (len(dt) > 0), + "IL_TN": (len(gt) == 0) and (len(dt) == 0), + "IL_FN": (len(gt) > 0) and (len(dt) == 0), + "num_dt": len(dt), + } + if len(gt) > 0 and len(dt) > 0: + result["local_positive_F1s"] = F1 + return result + + def accumulate(self, p=None): + """ + Accumulate per image evaluation results and store the result in self.eval + :param p: input params for evaluation + :return: None + """ + if self.evalImgs is None or len(self.evalImgs) == 0: + print("Please run evaluate() first") + # allows input customized parameters + if p is None: + p = self.params + + setImgIds = set(p.imgIds) + + # TPs, FPs, FNs + TPs = np.zeros((len(p.iouThrs),), dtype=np.int64) + FPs = np.zeros((len(p.iouThrs),), dtype=np.int64) + pmFPs = np.zeros((len(p.iouThrs),), dtype=np.int64) + FNs = np.zeros((len(p.iouThrs),), dtype=np.int64) + local_F1s = np.zeros((len(p.iouThrs),), dtype=np.float64) + + # Image level metrics + IL_TPs = 0 + IL_FPs = 0 + IL_TNs = 0 + IL_FNs = 0 + + valid_img_count = 0 + valid_F1_count = 0 + evaledImgIds = set() + for res in self.evalImgs: + if res["image_id"] not in setImgIds: + continue + evaledImgIds.add(res["image_id"]) + IL_TPs += res["IL_TP"] + IL_FPs += res["IL_FP"] + IL_TNs += res["IL_TN"] + IL_FNs += res["IL_FN"] + + if "TPs" not in res: + continue + + TPs += res["TPs"] + FPs += res["FPs"] + FNs += res["FNs"] + valid_img_count += 1 + + if "local_positive_F1s" in res: + local_F1s += res["local_positive_F1s"] + pmFPs += res["FPs"] + if res["num_dt"] > 0: + valid_F1_count += 1 + + assert len(setImgIds - evaledImgIds) == 0, ( + f"{len(setImgIds - evaledImgIds)} images not evaluated. " + f"Here are the IDs of the first 3: {list(setImgIds - evaledImgIds)[:3]}" + ) + + # compute precision recall and F1 + precision = TPs / (TPs + FPs + 1e-4) + positive_micro_precision = TPs / (TPs + pmFPs + 1e-4) + assert np.all(precision <= 1) + recall = TPs / (TPs + FNs + 1e-4) + assert np.all(recall <= 1) + F1 = 2 * precision * recall / (precision + recall + 1e-4) + positive_micro_F1 = ( + 2 + * positive_micro_precision + * recall + / (positive_micro_precision + recall + 1e-4) + ) + + IL_rec = IL_TPs / (IL_TPs + IL_FNs + 1e-6) + IL_prec = IL_TPs / (IL_TPs + IL_FPs + 1e-6) + IL_F1 = 2 * IL_prec * IL_rec / (IL_prec + IL_rec + 1e-6) + IL_FPR = IL_FPs / (IL_FPs + IL_TNs + 1e-6) + IL_MCC = float(IL_TPs * IL_TNs - IL_FPs * IL_FNs) / ( + ( + float(IL_TPs + IL_FPs) + * float(IL_TPs + IL_FNs) + * float(IL_TNs + IL_FPs) + * float(IL_TNs + IL_FNs) + ) + ** 0.5 + + 1e-6 + ) + + self.eval = { + "params": p, + "TPs": TPs, + "FPs": FPs, + "positive_micro_FPs": pmFPs, + "FNs": FNs, + "precision": precision, + "positive_micro_precision": positive_micro_precision, + "recall": recall, + "F1": F1, + "positive_micro_F1": positive_micro_F1, + "positive_macro_F1": local_F1s / valid_F1_count, + "IL_recall": IL_rec, + "IL_precision": IL_prec, + "IL_F1": IL_F1, + "IL_FPR": IL_FPR, + "IL_MCC": IL_MCC, + } + self.eval["cgF1"] = self.eval["positive_micro_F1"] * self.eval["IL_MCC"] + + def summarize(self): + """ + Compute and display summary metrics for evaluation results. + """ + if not self.eval: + raise Exception("Please run accumulate() first") + + def _summarize(iouThr=None, metric=""): + p = self.params + iStr = " {:<18} @[ IoU={:<9}] = {:0.3f}" + titleStr = "Average " + metric + iouStr = ( + "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1]) + if iouThr is None + else "{:0.2f}".format(iouThr) + ) + + s = self.eval[metric] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + print(iStr.format(titleStr, iouStr, mean_s)) + return mean_s + + def _summarize_single(metric=""): + titleStr = "Average " + metric + iStr = " {:<35} = {:0.3f}" + s = self.eval[metric] + print(iStr.format(titleStr, s)) + return s + + def _summarizeDets(): + stats = [] + + for metric in CGF1_METRICS: + if metric.image_level: + stats.append(_summarize_single(metric=metric.name)) + else: + stats.append( + _summarize(iouThr=metric.iou_threshold, metric=metric.name) + ) + return np.asarray(stats) + + summarize = _summarizeDets + self.stats = summarize() + + +def _evaluate(self): + """ + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + """ + p = self.params + # add backward compatibility if useSegm is specified in params + p.imgIds = list(np.unique(p.imgIds)) + p.useCats = False + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = [-1] + + if p.iouType == "segm" or p.iouType == "bbox": + computeIoU = self.computeIoU + else: + raise RuntimeError(f"Unsupported iou {p.iouType}") + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds + } + + maxDet = p.maxDets[-1] + evalImgs = [ + self.evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + return p.imgIds, evalImgs + + +class CGF1Evaluator: + """ + Wrapper class for cgF1 evaluation. + This supports the oracle setting (when several ground-truths are available per image) + """ + + def __init__( + self, + gt_path: Union[str, List[str]], + iou_type="segm", + verbose=False, + ): + """ + Args: + gt_path (str or list of str): path(s) to ground truth COCO json file(s) + iou_type (str): type of IoU to evaluate + threshold (float): threshold for predictions + """ + self.gt_paths = gt_path if isinstance(gt_path, list) else [gt_path] + self.iou_type = iou_type + + self.coco_gts = [COCOCustom(gt) for gt in self.gt_paths] + + self.verbose = verbose + + self.coco_evals = [] + for i, coco_gt in enumerate(self.coco_gts): + self.coco_evals.append( + CGF1Eval( + coco_gt=coco_gt, + iouType=iou_type, + ) + ) + self.coco_evals[i].useCats = False + + exclude_img_ids = set() + # exclude_img_ids are the ids that are not exhaustively annotated in any of the other gts + for coco_gt in self.coco_gts[1:]: + exclude_img_ids = exclude_img_ids.union( + { + img["id"] + for img in coco_gt.dataset["images"] + if not img["is_instance_exhaustive"] + } + ) + # we only eval on instance exhaustive queries + self.eval_img_ids = [ + img["id"] + for img in self.coco_gts[0].dataset["images"] + if (img["is_instance_exhaustive"] and img["id"] not in exclude_img_ids) + ] + + def evaluate(self, pred_file: str): + """ + Evaluate the detections using cgF1 metric. + + Args: + pred_file: path to the predictions COCO json file + + """ + assert len(self.coco_gts) > 0, "No ground truth provided for evaluation." + assert len(self.coco_gts) == len( + self.coco_evals + ), "Mismatch in number of ground truths and evaluators." + + if self.verbose: + print(f"Loading predictions from {pred_file}") + + with open(pred_file, "r") as f: + preds = json.load(f) + + if self.verbose: + print(f"Loaded {len(preds)} predictions") + + img2preds = defaultdict(list) + for pred in preds: + img2preds[pred["image_id"]].append(pred) + + all_eval_imgs = [] + for img_id in tqdm(self.eval_img_ids, disable=not self.verbose): + results = img2preds[img_id] + all_scorings = [] + for cur_coco_gt, coco_eval in zip(self.coco_gts, self.coco_evals): + # suppress pycocotools prints + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = ( + cur_coco_gt.loadRes(results) if results else COCOCustom() + ) + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = [img_id] + coco_eval.params.useCats = False + img_ids, eval_imgs = _evaluate(coco_eval) + all_scorings.append(eval_imgs) + selected = self._select_best_scoring(all_scorings) + all_eval_imgs.append(selected) + + # After this point, we have selected the best scoring per image among several ground truths + # we can now accumulate and summarize, using only the first coco_eval + + self.coco_evals[0].evalImgs = list( + np.concatenate(all_eval_imgs, axis=2).flatten() + ) + self.coco_evals[0].params.imgIds = self.eval_img_ids + self.coco_evals[0]._paramsEval = copy.deepcopy(self.coco_evals[0].params) + + if self.verbose: + print(f"Accumulating results") + self.coco_evals[0].accumulate() + print("cgF1 metric, IoU type={}".format(self.iou_type)) + self.coco_evals[0].summarize() + print() + + out = {} + for i, value in enumerate(self.coco_evals[0].stats): + name = CGF1_METRICS[i].name + if CGF1_METRICS[i].iou_threshold is not None: + name = f"{name}@{CGF1_METRICS[i].iou_threshold}" + out[f"cgF1_eval_{self.iou_type}_{name}"] = float(value) + + return out + + @staticmethod + def _select_best_scoring(scorings): + # This function is used for "oracle" type evaluation. + # It accepts the evaluation results with respect to several ground truths, and picks the best + if len(scorings) == 1: + return scorings[0] + + assert ( + scorings[0].ndim == 3 + ), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}" + assert ( + scorings[0].shape[0] == 1 + ), f"Expecting a single category, got {scorings[0].shape[0]}" + + for scoring in scorings: + assert ( + scoring.shape == scorings[0].shape + ), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}" + + selected_imgs = [] + for img_id in range(scorings[0].shape[-1]): + best = scorings[0][:, :, img_id] + + for scoring in scorings[1:]: + current = scoring[:, :, img_id] + if "local_F1s" in best[0, 0] and "local_F1s" in current[0, 0]: + # we were able to compute a F1 score for this particular image in both evaluations + # best["local_F1s"] contains the results at various IoU thresholds. We simply take the average for comparision + best_score = best[0, 0]["local_F1s"].mean() + current_score = current[0, 0]["local_F1s"].mean() + if current_score > best_score: + best = current + + else: + # If we're here, it means that in that in some evaluation we were not able to get a valid local F1 + # This happens when both the predictions and targets are empty. In that case, we can assume it's a perfect prediction + if "local_F1s" not in current[0, 0]: + best = current + selected_imgs.append(best) + result = np.stack(selected_imgs, axis=-1) + assert result.shape == scorings[0].shape + return result diff --git a/third_party/sam3/sam3/eval/coco_eval.py b/third_party/sam3/sam3/eval/coco_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..3716885d725b0ec3a90a230f565a507426f3a5a2 --- /dev/null +++ b/third_party/sam3/sam3/eval/coco_eval.py @@ -0,0 +1,914 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +COCO evaluator that works in distributed mode. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" + +import contextlib +import copy +import json +import logging +import os +import pickle +from collections import defaultdict +from pathlib import Path +from typing import Any, List, Optional + +import numpy as np +import pycocotools.mask as mask_utils +import torch +from iopath.common.file_io import g_pathmgr +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from sam3.train.masks_ops import rle_encode +from sam3.train.utils.distributed import ( + all_gather, + gather_to_rank_0_via_filesys, + get_rank, + is_main_process, +) + +RARITY_BUCKETS = {0: "frequent", 1: "common", 2: "medium", 3: "rare"} + + +class CocoEvaluator: + def __init__( + self, + coco_gt, + iou_types: List[str], + useCats: bool, + dump_dir: Optional[str], + postprocessor, + average_by_rarity=False, + metrics_dump_dir: Optional[str] = None, + gather_pred_via_filesys=False, + use_normalized_areas=True, + maxdets=[1, 10, 100], + exhaustive_only=False, + all_exhaustive_only=True, + ): + """Online coco evaluator. It will evaluate images as they are generated by the model, then accumulate/summarize at the end + + Args: + - coco_gt: COCO api object containing the gt + - iou_types: can be either "bbox" or "segm" + - useCats: If true, categories will be used for evaluation + - dump_dir: if non null, then the predictions will be dumped in that directory + - postprocessor: Module to convert the model's output into the coco format + - average_by_rarity: if true then we expect the images information in the gt dataset + to have a "rarity" field. Then the AP will be computed on all rarity buckets + individually, then averaged + - gather_pred_via_filesys: if true, we use the filesystem for collective gathers + - use_normalized_areas: if true, the areas of the objects in the GT are assumed to be + normalized by the area of the image. In that case, the size buckets are adjusted + - maxdets: maximal number of detections to be evaluated on each image. + - exhaustive_only: If true, we restrict eval only to exhaustive annotations + - all_exhaustive_only: If true, datapoints are restricted only to those with all exhaustive annotations + + """ + # coco_gt = copy.deepcopy(coco_gt) + self.coco_gts = [coco_gt] if not isinstance(coco_gt, list) else coco_gt + assert len(maxdets) == 3, f"expecting 3 detection threshold, got {len(maxdets)}" + + self.use_normalized_areas = use_normalized_areas + self.iou_types = iou_types + self.useCats = useCats + self.maxdets = maxdets + self.dump = None + self.dump_dir = dump_dir + if self.dump_dir is not None: + self.dump = [] + if is_main_process(): + if not os.path.exists(self.dump_dir): + os.makedirs(self.dump_dir, exist_ok=True) + logging.info(f"Create the folder: {dump_dir}") + + self.initialized = False + + # Whether to gather predictions through filesystem (instead of torch + # collective ops; requiring a shared filesystem across all ranks) + self.gather_pred_via_filesys = gather_pred_via_filesys + self.use_self_evaluate = True # CPP version is disabled + self.postprocessor = postprocessor + self.average_by_rarity = average_by_rarity + self.exhaustive_only = exhaustive_only + self.all_exhaustive_only = all_exhaustive_only + self.metrics_dump_dir = metrics_dump_dir + if self.metrics_dump_dir is not None: + if is_main_process(): + if not os.path.exists(self.metrics_dump_dir): + os.makedirs(self.metrics_dump_dir, exist_ok=True) + logging.info(f"Create the folder: {metrics_dump_dir}") + + def _lazy_init(self, coco_cls=COCO): + if self.initialized: + return + + self.initialized = True + + self.coco_gts = [ + coco_cls(g_pathmgr.get_local_path(gt)) if isinstance(gt, str) else gt + for gt in self.coco_gts + ] + + self.reset() + + self.eval_img_ids = None + + if self.exhaustive_only: + exclude_img_ids = set() + # exclude_img_ids are the ids that are not exhaustively annotated in any of the other gts + if self.all_exhaustive_only: + for coco_gt in self.coco_gts[1:]: + exclude_img_ids = exclude_img_ids.union( + { + img["id"] + for img in coco_gt.dataset["images"] + if not img["is_instance_exhaustive"] + } + ) + # we only eval on instance exhaustive queries + self.eval_img_ids = [ + img["id"] + for img in self.coco_gts[0].dataset["images"] + if (img["is_instance_exhaustive"] and img["id"] not in exclude_img_ids) + ] + + self.rarity_buckets = None + if self.average_by_rarity: + self.rarity_buckets = defaultdict(list) + eval_img_ids_set = ( + set(self.eval_img_ids) if self.eval_img_ids is not None else None + ) + for img in self.coco_gts[0].dataset["images"]: + if self.eval_img_ids is not None and img["id"] not in eval_img_ids_set: + continue + self.rarity_buckets[img["rarity"]].append(img["id"]) + print("Rarity buckets sizes:") + for k, v in self.rarity_buckets.items(): + print(f"{k}: {len(v)}") + + def set_sync_device(self, device: torch.device) -> Any: + self._sync_device = device + + def _evaluate(self, *args, **kwargs): + return evaluate(*args, **kwargs) + + def _loadRes(self, *args, **kwargs): + return loadRes(*args, **kwargs) + + def update(self, *args, **kwargs): + self._lazy_init() + predictions = self.postprocessor.process_results(*args, **kwargs) + + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + self._dump(results) + + assert len(self.coco_gts) == len(self.coco_evals) + all_scorings = [] + for cur_coco_gt, cur_coco_eval in zip(self.coco_gts, self.coco_evals): + # suppress pycocotools prints + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = ( + self._loadRes(cur_coco_gt, results) if results else COCO() + ) + + coco_eval = cur_coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + coco_eval.params.useCats = self.useCats + coco_eval.params.maxDets = self.maxdets + img_ids, eval_imgs = self._evaluate(coco_eval, self.use_self_evaluate) + all_scorings.append(eval_imgs) + + selected = self.select_best_scoring(all_scorings) + self.eval_imgs[iou_type].append(selected) + + def select_best_scoring(self, scorings): + # This function is used for "oracle" type evaluation. + # It accepts the evaluation results with respect to several ground truths, and picks the best + if len(scorings) == 1: + return scorings[0] + + # Currently we don't support Oracle Phrase AP. + # To implement it, we likely need to modify the cpp code since the eval_image type is opaque + raise RuntimeError("Not implemented") + + def _dump(self, results): + if self.dump is not None: + dumped_results = copy.deepcopy(results) + for r in dumped_results: + if "bbox" not in self.iou_types and "bbox" in r: + del r["bbox"] + elif "bbox" in r: + r["bbox"] = [round(coord, 5) for coord in r["bbox"]] + r["score"] = round(r["score"], 5) + self.dump.extend(dumped_results) + + def synchronize_between_processes(self): + self._lazy_init() + logging.info("Coco evaluator: Synchronizing between processes") + for iou_type in self.iou_types: + if len(self.eval_imgs[iou_type]) > 0: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + else: + num_areas = len(self.coco_evals[0][iou_type].params.areaRng) + # assuming 1 class + assert not self.useCats + self.eval_imgs[iou_type] = np.empty((1, num_areas, 0)) + create_common_coco_eval( + self.coco_evals[0][iou_type], + self.img_ids, + self.eval_imgs[iou_type], + use_self_evaluate=self.use_self_evaluate, + gather_pred_via_filesys=self.gather_pred_via_filesys, + metrics_dump_dir=self.metrics_dump_dir, + ) + if self.dump is not None: + dumped_file = Path(self.dump_dir) / f"coco_predictions_{get_rank()}.json" + logging.info(f"COCO evaluator: Dumping local predictions to {dumped_file}") + with g_pathmgr.open(str(dumped_file), "w") as f: + json.dump(self.dump, f) + + # if self.gather_pred_via_filesys: + # dump = gather_to_rank_0_via_filesys(self.dump) + # else: + # dump = all_gather(self.dump, force_cpu=True) + # self.dump = sum(dump, []) + + def accumulate(self, imgIds=None): + self._lazy_init() + logging.info( + f"Coco evaluator: Accumulating on {len(imgIds) if imgIds is not None else 'all'} images" + ) + if not is_main_process(): + return + + if imgIds is None: + for coco_eval in self.coco_evals[0].values(): + accumulate(coco_eval, use_self_eval=self.use_self_evaluate) + + if imgIds is not None: + imgIds = set(imgIds) + for coco_eval in self.coco_evals[0].values(): + p = coco_eval.params + id_mask = np.array([(i in imgIds) for i in p.imgIds], dtype=bool) + old_img_ids = p.imgIds + coco_eval.params.imgIds = np.asarray(p.imgIds)[id_mask] + old_img_evals = coco_eval.evalImgs + catIds = p.catIds if p.useCats else [-1] + coco_eval.evalImgs = list( + np.asarray(coco_eval.evalImgs) + .reshape(len(catIds), len(p.areaRng), len(old_img_ids))[ + ..., id_mask + ] + .flatten() + ) + accumulate(coco_eval, use_self_eval=self.use_self_evaluate) + coco_eval.evalImgs = old_img_evals + coco_eval.params.imgIds = old_img_ids + + def summarize(self): + self._lazy_init() + logging.info("Coco evaluator: Summarizing") + if not is_main_process(): + return {} + + outs = {} + if self.rarity_buckets is None: + self.accumulate(self.eval_img_ids) + for iou_type, coco_eval in self.coco_evals[0].items(): + print("IoU metric: {}".format(iou_type)) + summarize(coco_eval) + + if "bbox" in self.coco_evals[0]: + for key, value in zip(*self.coco_evals[0]["bbox"].stats): + outs[f"coco_eval_bbox_{key}"] = value + if "segm" in self.coco_evals[0]: + for key, value in zip(*self.coco_evals[0]["segm"].stats): + outs[f"coco_eval_masks_{key}"] = value + else: + total_stats = {} + all_keys = {} + for bucket, img_list in self.rarity_buckets.items(): + self.accumulate(imgIds=img_list) + bucket_name = RARITY_BUCKETS[bucket] + for iou_type, coco_eval in self.coco_evals[0].items(): + print(f"IoU metric: {iou_type}. Rarity bucket: {bucket_name}") + summarize(coco_eval) + + if "bbox" in self.coco_evals[0]: + if "bbox" not in total_stats: + total_stats["bbox"] = np.zeros_like( + self.coco_evals[0]["bbox"].stats[1] + ) + all_keys["bbox"] = self.coco_evals[0]["bbox"].stats[0] + total_stats["bbox"] += self.coco_evals[0]["bbox"].stats[1] + for key, value in zip(*self.coco_evals[0]["bbox"].stats): + outs[f"coco_eval_bbox_{bucket_name}_{key}"] = value + if "segm" in self.coco_evals[0]: + if "segm" not in total_stats: + total_stats["segm"] = np.zeros_like( + self.coco_evals[0]["segm"].stats[1] + ) + all_keys["segm"] = self.coco_evals[0]["segm"].stats[0] + total_stats["segm"] += self.coco_evals[0]["segm"].stats[1] + for key, value in zip(*self.coco_evals[0]["segm"].stats): + outs[f"coco_eval_masks_{bucket_name}_{key}"] = value + + if "bbox" in total_stats: + total_stats["bbox"] /= len(self.rarity_buckets) + for key, value in zip(all_keys["bbox"], total_stats["bbox"]): + outs[f"coco_eval_bbox_{key}"] = value + if "segm" in total_stats: + total_stats["segm"] /= len(self.rarity_buckets) + for key, value in zip(all_keys["segm"], total_stats["segm"]): + outs[f"coco_eval_masks_{key}"] = value + + # if self.dump is not None: + # assert self.dump_dir is not None + # logging.info("Coco evaluator: Dumping the global result file to disk") + # with g_pathmgr.open(str(Path(self.dump_dir) / "coco_eval.json"), "w") as f: + # json.dump(self.dump, f) + return outs + + def compute_synced(self): + self._lazy_init() + self.synchronize_between_processes() + return self.summarize() + + def compute(self): + self._lazy_init() + return {"": 0.0} + + def reset(self, cocoeval_cls=COCOeval): + self.coco_evals = [{} for _ in range(len(self.coco_gts))] + for i, coco_gt in enumerate(self.coco_gts): + for iou_type in self.iou_types: + self.coco_evals[i][iou_type] = cocoeval_cls(coco_gt, iouType=iou_type) + self.coco_evals[i][iou_type].params.useCats = self.useCats + self.coco_evals[i][iou_type].params.maxDets = self.maxdets + if self.use_normalized_areas: + self.coco_evals[i][iou_type].params.areaRng = [ + [0, 1e5], + [0, 0.001], + [0.001, 0.01], + [0.01, 0.1], + [0.1, 0.5], + [0.5, 0.95], + [0.95, 1e5], + ] + self.coco_evals[i][iou_type].params.areaRngLbl = [ + "all", + "tiny", + "small", + "medium", + "large", + "huge", + "whole_image", + ] + + self.img_ids = [] + self.eval_imgs = {k: [] for k in self.iou_types} + if self.dump is not None: + self.dump = [] + + def write(self, stats): + self._lazy_init() + """Write the results in the stats dict""" + if "bbox" in self.coco_evals[0]: + stats["coco_eval_bbox"] = self.coco_evals[0]["bbox"].stats.tolist() + if "segm" in self.coco_evals[0]: + stats["coco_eval_masks"] = self.coco_evals[0]["segm"].stats.tolist() + return stats + + def prepare(self, predictions, iou_type): + self._lazy_init() + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + self._lazy_init() + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + @torch.no_grad() + def prepare_for_coco_segmentation(self, predictions): + self._lazy_init() + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + boundaries, dilated_boundaries = None, None + if "boundaries" in prediction: + boundaries = prediction["boundaries"] + dilated_boundaries = prediction["dilated_boundaries"] + assert dilated_boundaries is not None + assert len(scores) == len(boundaries) + + if "masks_rle" in prediction: + rles = prediction["masks_rle"] + areas = [] + for rle in rles: + cur_area = mask_utils.area(rle) + h, w = rle["size"] + areas.append(cur_area / (h * w)) + else: + masks = prediction["masks"] + + masks = masks > 0.5 + h, w = masks.shape[-2:] + + areas = masks.flatten(1).sum(1) / (h * w) + areas = areas.tolist() + + rles = rle_encode(masks.squeeze(1)) + + # memory clean + del masks + del prediction["masks"] + + assert len(areas) == len(rles) == len(scores) + for k, rle in enumerate(rles): + payload = { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + "area": areas[k], + } + if boundaries is not None: + payload["boundary"] = boundaries[k] + payload["dilated_boundary"] = dilated_boundaries[k] + + coco_results.append(payload) + + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + self._lazy_init() + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "keypoints": keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(-1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=-1) + + +def merge(img_ids, eval_imgs, gather_pred_via_filesys=False): + if gather_pred_via_filesys: + # only gather the predictions to rank 0 (other ranks will receive empty + # lists for `all_img_ids` and `all_eval_imgs`, which should be OK as + # merging and evaluation are only done on rank 0) + all_img_ids = gather_to_rank_0_via_filesys(img_ids) + all_eval_imgs = gather_to_rank_0_via_filesys(eval_imgs) + else: + all_img_ids = all_gather(img_ids, force_cpu=True) + all_eval_imgs = all_gather(eval_imgs, force_cpu=True) + if not is_main_process(): + return None, None + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval( + coco_eval, + img_ids, + eval_imgs, + use_self_evaluate, + gather_pred_via_filesys=False, + metrics_dump_dir=None, +): + img_ids, eval_imgs = merge(img_ids, eval_imgs, gather_pred_via_filesys) + if not is_main_process(): + return + if metrics_dump_dir is not None: + dumped_file = ( + Path(metrics_dump_dir) / f"coco_eval_img_metrics_{get_rank()}.json" + ) + logging.info(f"COCO evaluator: Dumping local predictions to {dumped_file}") + with g_pathmgr.open(str(dumped_file), "w") as f: + json.dump(eval_imgs.squeeze(), f, default=lambda x: x.tolist()) + img_ids = list(img_ids) + + # If some images were not predicted, we need to create dummy detections for them + missing_img_ids = set(coco_eval.cocoGt.getImgIds()) - set(img_ids) + if len(missing_img_ids) > 0: + print(f"WARNING: {len(missing_img_ids)} images were not predicted!") + coco_eval.cocoDt = COCO() + coco_eval.params.imgIds = list(missing_img_ids) + new_img_ids, new_eval_imgs = evaluate(coco_eval, use_self_evaluate) + img_ids.extend(new_img_ids) + eval_imgs = np.concatenate((eval_imgs, new_eval_imgs), axis=2) + + eval_imgs = list(eval_imgs.flatten()) + assert len(img_ids) == len(coco_eval.cocoGt.getImgIds()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + + +# Copy of COCO prepare, but doesn't convert anntoRLE +def segmentation_prepare(self): + """ + Prepare ._gts and ._dts for evaluation based on params + :return: None + """ + p = self.params + if p.useCats: + gts = self.cocoGt.loadAnns( + self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds) + ) + dts = self.cocoDt.loadAnns( + self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds) + ) + else: + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) + + for gt in gts: + gt["ignore"] = gt["ignore"] if "ignore" in gt else 0 + gt["ignore"] = "iscrowd" in gt and gt["iscrowd"] + if p.iouType == "keypoints": + gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"] + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + for gt in gts: + self._gts[gt["image_id"], gt["category_id"]].append(gt) + for dt in dts: + self._dts[dt["image_id"], dt["category_id"]].append(dt) + self.evalImgs = defaultdict(list) # per-image per-category evaluation results + self.eval = {} # accumulated evaluation results + + +def evaluate(self, use_self_evaluate): + """ + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + """ + # tic = time.time() + # print('Running per image evaluation...', use_self_evaluate) + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = "segm" if p.useSegm == 1 else "bbox" + print( + "useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType) + ) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == "segm" or p.iouType == "bbox": + computeIoU = self.computeIoU + elif p.iouType == "keypoints": + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds + } + + maxDet = p.maxDets[-1] + if use_self_evaluate: + evalImgs = [ + self.evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape( + len(catIds), len(p.areaRng), len(p.imgIds) + ) + return p.imgIds, evalImgs + + # <<<< Beginning of code differences with original COCO API + # def convert_instances_to_cpp(instances, is_det=False): + # # Convert annotations for a list of instances in an image to a format that's fast + # # to access in C++ + # instances_cpp = [] + # for instance in instances: + # instance_cpp = _CPP.InstanceAnnotation( + # int(instance["id"]), + # instance["score"] if is_det else instance.get("score", 0.0), + # instance["area"], + # bool(instance.get("iscrowd", 0)), + # bool(instance.get("ignore", 0)), + # ) + # instances_cpp.append(instance_cpp) + # return instances_cpp + + # # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++ + # ground_truth_instances = [ + # [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds] + # for imgId in p.imgIds + # ] + # detected_instances = [ + # [ + # convert_instances_to_cpp(self._dts[imgId, catId], is_det=True) + # for catId in p.catIds + # ] + # for imgId in p.imgIds + # ] + # ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds] + + # if not p.useCats: + # # For each image, flatten per-category lists into a single list + # ground_truth_instances = [ + # [[o for c in i for o in c]] for i in ground_truth_instances + # ] + # detected_instances = [[[o for c in i for o in c]] for i in detected_instances] + + # # Call C++ implementation of self.evaluateImgs() + # _evalImgs_cpp = _CPP.COCOevalEvaluateImages( + # p.areaRng, maxDet, p.iouThrs, ious, ground_truth_instances, detected_instances + # ) + + # self._paramsEval = copy.deepcopy(self.params) + # evalImgs = np.asarray(_evalImgs_cpp).reshape( + # len(catIds), len(p.areaRng), len(p.imgIds) + # ) + # return p.imgIds, evalImgs + + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# + + +################################################################# +# From pycocotools, but disabled mask->box conversion which is +# pointless +################################################################# +def loadRes(self, resFile): + """ + Load result file and return a result api object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = COCO() + res.dataset["images"] = [img for img in self.dataset["images"]] + + if type(resFile) == str: + anns = json.load(open(resFile)) + elif type(resFile) == np.ndarray: + anns = self.loadNumpyAnnotations(resFile) + else: + anns = resFile + assert type(anns) == list, "results in not an array of objects" + annsImgIds = [ann["image_id"] for ann in anns] + assert set(annsImgIds) == ( + set(annsImgIds) & set(self.getImgIds()) + ), "Results do not correspond to current coco set" + if "caption" in anns[0]: + imgIds = set([img["id"] for img in res.dataset["images"]]) & set( + [ann["image_id"] for ann in anns] + ) + res.dataset["images"] = [ + img for img in res.dataset["images"] if img["id"] in imgIds + ] + for id, ann in enumerate(anns): + ann["id"] = id + 1 + elif "bbox" in anns[0] and not anns[0]["bbox"] == []: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) + for id, ann in enumerate(anns): + bb = ann["bbox"] + x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] + if "segmentation" not in ann: + ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]] + ann["area"] = bb[2] * bb[3] + ann["id"] = id + 1 + ann["iscrowd"] = 0 + elif "segmentation" in anns[0]: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) + for id, ann in enumerate(anns): + # now only support compressed RLE format as segmentation results + # ann["area"] = mask_util.area(ann["segmentation"]) + # The following lines are disabled because they are pointless + # if not 'bbox' in ann: + # ann['bbox'] = maskUtils.toBbox(ann['segmentation']) + ann["id"] = id + 1 + ann["iscrowd"] = 0 + elif "keypoints" in anns[0]: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) + for id, ann in enumerate(anns): + s = ann["keypoints"] + x = s[0::3] + y = s[1::3] + x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y) + ann["area"] = (x1 - x0) * (y1 - y0) + ann["id"] = id + 1 + ann["bbox"] = [x0, y0, x1 - x0, y1 - y0] + + res.dataset["annotations"] = anns + res.createIndex() + return res + + +################################################################# +# end of straight copy from pycocotools +################################################################# + + +################################################################# +# From pycocotools, but added handling of custom area rngs, and returns stat keys +################################################################# +def summarize(self): + """ + Compute and display summary metrics for evaluation results. + Note this functin can *only* be applied on the default parameter setting + """ + + def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100): + p = self.params + iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" + titleStr = "Average Precision" if ap == 1 else "Average Recall" + typeStr = "(AP)" if ap == 1 else "(AR)" + iouStr = ( + "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1]) + if iouThr is None + else "{:0.2f}".format(iouThr) + ) + + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + if ap == 1: + # dimension of precision: [TxRxKxAxM] + s = self.eval["precision"] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, :, aind, mind] + else: + # dimension of recall: [TxKxAxM] + s = self.eval["recall"] + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, aind, mind] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)) + return mean_s + + def _summarizeDets(): + nb_results = 6 + (len(self.params.areaRng) - 1) * 2 + assert len(self.params.areaRng) == len(self.params.areaRngLbl) + stats = np.zeros((nb_results,)) + keys = ["AP", "AP_50", "AP_75"] + stats[0] = _summarize(1, maxDets=self.params.maxDets[2]) + stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2]) + stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2]) + cur_id = 3 + for area in self.params.areaRngLbl[1:]: + stats[cur_id] = _summarize(1, areaRng=area, maxDets=self.params.maxDets[2]) + cur_id += 1 + keys.append(f"AP_{area}") + stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[0]) + cur_id += 1 + stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[1]) + cur_id += 1 + stats[cur_id] = _summarize(0, maxDets=self.params.maxDets[2]) + cur_id += 1 + keys += ["AR", "AR_50", "AR_75"] + + for area in self.params.areaRngLbl[1:]: + stats[cur_id] = _summarize(0, areaRng=area, maxDets=self.params.maxDets[2]) + cur_id += 1 + keys.append(f"AR_{area}") + assert len(stats) == len(keys) + return keys, stats + + if not self.eval: + raise Exception("Please run accumulate() first") + self.stats = _summarizeDets() + + +################################################################# +# end of straight copy from pycocotools +################################################################# + + +################################################################# +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/fast_eval_api.py +# with slight adjustments +################################################################# +def accumulate(self, use_self_eval=False): + """ + Accumulate per image evaluation results and store the result in self.eval. Does not + support changing parameter settings from those used by self.evaluate() + """ + if use_self_eval: + self.accumulate() + return + # CPP code is disabled + # self.eval = _CPP.COCOevalAccumulate(self.params, self.evalImgs) + + # # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections + # self.eval["recall"] = np.array(self.eval["recall"]).reshape( + # self.eval["counts"][:1] + self.eval["counts"][2:] + # ) + + # # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X + # # num_area_ranges X num_max_detections + # self.eval["precision"] = np.array(self.eval["precision"]).reshape( + # self.eval["counts"] + # ) + # self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"]) diff --git a/third_party/sam3/sam3/eval/coco_eval_offline.py b/third_party/sam3/sam3/eval/coco_eval_offline.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8a0c076ddfaa947de9c7cf17991a2ec4ddb3e8 --- /dev/null +++ b/third_party/sam3/sam3/eval/coco_eval_offline.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +This evaluator is meant for regular COCO mAP evaluation, for example on the COCO val set. + +For Category mAP, we need the model to make predictions for all the categories on every single image. +In general, since the number of classes can be big, and the API model makes predictions individually for each pair (image, class), +we may need to split the inference process for a given image in several chunks. +""" + +import logging +from collections import defaultdict + +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from sam3.train.utils.distributed import is_main_process + +try: + from tidecv import datasets, TIDE + + HAS_TIDE = True +except ImportError: + HAS_TIDE = False + print("WARNING: TIDE not installed. Detailed analysis will not be available.") + + +# the COCO detection metrics (https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L460-L471) +COCO_METRICS = [ + "AP", + "AP_50", + "AP_75", + "AP_small", + "AP_medium", + "AP_large", + "AR_maxDets@1", + "AR_maxDets@10", + "AR_maxDets@100", + "AR_small", + "AR_medium", + "AR_large", +] + + +def convert_to_xywh(boxes): + """Convert bounding boxes from xyxy format to xywh format.""" + xmin, ymin, xmax, ymax = boxes.unbind(-1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=-1) + + +class HeapElement: + """Utility class to make a heap with a custom comparator""" + + def __init__(self, val): + self.val = val + + def __lt__(self, other): + return self.val["score"] < other.val["score"] + + +class COCOevalCustom(COCOeval): + """ + This is a slightly modified version of the original COCO API with added support for positive split evaluation. + """ + + def __init__( + self, cocoGt=None, cocoDt=None, iouType="segm", dt_only_positive=False + ): + super().__init__(cocoGt, cocoDt, iouType) + self.dt_only_positive = dt_only_positive + + def _prepare(self): + """ + Prepare ._gts and ._dts for evaluation based on params + :return: None + """ + + def _toMask(anns, coco): + # modify ann['segmentation'] by reference + for ann in anns: + rle = coco.annToRLE(ann) + ann["segmentation"] = rle + + p = self.params + if p.useCats: + gts = self.cocoGt.loadAnns( + self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds) + ) + dts = self.cocoDt.loadAnns( + self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds) + ) + else: + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) + + # convert ground truth to mask if iouType == 'segm' + if p.iouType == "segm": + _toMask(gts, self.cocoGt) + _toMask(dts, self.cocoDt) + # set ignore flag + for gt in gts: + gt["ignore"] = gt["ignore"] if "ignore" in gt else 0 + gt["ignore"] = "iscrowd" in gt and gt["iscrowd"] + if p.iouType == "keypoints": + gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"] + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + + _gts_cat_ids = defaultdict(set) # gt for evaluation on positive split + for gt in gts: + self._gts[gt["image_id"], gt["category_id"]].append(gt) + _gts_cat_ids[gt["image_id"]].add(gt["category_id"]) + + #### BEGIN MODIFICATION #### + for dt in dts: + if ( + self.dt_only_positive + and dt["category_id"] not in _gts_cat_ids[dt["image_id"]] + ): + continue + self._dts[dt["image_id"], dt["category_id"]].append(dt) + #### END MODIFICATION #### + self.evalImgs = defaultdict(list) # per-image per-category evaluation results + self.eval = {} # accumulated evaluation results + + +class CocoEvaluatorOfflineWithPredFileEvaluators: + def __init__( + self, + gt_path, + tide: bool = True, + iou_type: str = "bbox", + positive_split=False, + ): + self.gt_path = gt_path + self.tide_enabled = HAS_TIDE and tide + self.positive_split = positive_split + self.iou_type = iou_type + + def evaluate(self, dumped_file): + if not is_main_process(): + return {} + + logging.info("OfflineCoco evaluator: Loading groundtruth") + self.gt = COCO(self.gt_path) + + # Creating the result file + logging.info("Coco evaluator: Creating the result file") + cocoDt = self.gt.loadRes(str(dumped_file)) + + # Run the evaluation + logging.info("Coco evaluator: Running evaluation") + coco_eval = COCOevalCustom( + self.gt, cocoDt, iouType=self.iou_type, dt_only_positive=self.positive_split + ) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + outs = {} + for i, value in enumerate(coco_eval.stats): + outs[f"coco_eval_{self.iou_type}_{COCO_METRICS[i]}"] = value + + if self.tide_enabled: + logging.info("Coco evaluator: Loading TIDE") + self.tide_gt = datasets.COCO(self.gt_path) + self.tide = TIDE(mode="mask" if self.iou_type == "segm" else "bbox") + + # Run TIDE + logging.info("Coco evaluator: Running TIDE") + self.tide.evaluate( + self.tide_gt, datasets.COCOResult(str(dumped_file)), name="coco_eval" + ) + self.tide.summarize() + for k, v in self.tide.get_main_errors()["coco_eval"].items(): + outs[f"coco_eval_{self.iou_type}_TIDE_{k}"] = v + + for k, v in self.tide.get_special_errors()["coco_eval"].items(): + outs[f"coco_eval_{self.iou_type}_TIDE_{k}"] = v + + return outs diff --git a/third_party/sam3/sam3/eval/coco_reindex.py b/third_party/sam3/sam3/eval/coco_reindex.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc02abfc81d6635fb813701e872aee746bcd16d --- /dev/null +++ b/third_party/sam3/sam3/eval/coco_reindex.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +Self-contained COCO JSON re-indexing function that creates temporary files. +""" + +import json +import os +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +def reindex_coco_to_temp(input_json_path: str) -> Optional[str]: + """ + Convert 0-indexed COCO JSON file to 1-indexed and save to temporary location. + + Args: + input_json_path: Path to the input COCO JSON file + + Returns: + Path to the new 1-indexed JSON file in temporary directory, or None if no conversion needed + + Raises: + FileNotFoundError: If input file doesn't exist + json.JSONDecodeError: If input file is not valid JSON + ValueError: If input file is not a valid COCO format + """ + + def is_coco_json(data: Dict[str, Any]) -> bool: + """Check if data appears to be a COCO format file.""" + if not isinstance(data, dict): + return False + # A COCO file should have at least one of these keys + coco_keys = {"images", "annotations", "categories"} + return any(key in data for key in coco_keys) + + def check_zero_indexed(data: Dict[str, Any]) -> Tuple[bool, bool, bool]: + """ + Check if annotations, images, or categories start from index 0. + + Returns: + Tuple of (annotations_zero_indexed, images_zero_indexed, categories_zero_indexed) + """ + annotations_zero = False + images_zero = False + categories_zero = False + + # Check annotations + annotations = data.get("annotations", []) + if annotations and any(ann.get("id", -1) == 0 for ann in annotations): + annotations_zero = True + + # Check images + images = data.get("images", []) + if images and any(img.get("id", -1) == 0 for img in images): + images_zero = True + + # Check categories + categories = data.get("categories", []) + if categories and any(cat.get("id", -1) == 0 for cat in categories): + categories_zero = True + + return annotations_zero, images_zero, categories_zero + + def reindex_coco_data(data: Dict[str, Any]) -> Dict[str, Any]: + """Convert 0-indexed COCO data to 1-indexed.""" + modified_data = data.copy() + + annotations_zero, images_zero, categories_zero = check_zero_indexed(data) + + # Create ID mapping for consistency + image_id_mapping = {} + category_id_mapping = {} + + # Process images first (since annotations reference image IDs) + if images_zero and "images" in modified_data: + for img in modified_data["images"]: + old_id = img["id"] + new_id = old_id + 1 + image_id_mapping[old_id] = new_id + img["id"] = new_id + + # Process categories (since annotations reference category IDs) + if categories_zero and "categories" in modified_data: + for cat in modified_data["categories"]: + old_id = cat["id"] + new_id = old_id + 1 + category_id_mapping[old_id] = new_id + cat["id"] = new_id + + # Process annotations + if "annotations" in modified_data: + for ann in modified_data["annotations"]: + # Update annotation ID if needed + if annotations_zero: + ann["id"] = ann["id"] + 1 + + # Update image_id reference if images were reindexed + if images_zero and ann.get("image_id") is not None: + old_image_id = ann["image_id"] + if old_image_id in image_id_mapping: + ann["image_id"] = image_id_mapping[old_image_id] + + # Update category_id reference if categories were reindexed + if categories_zero and ann.get("category_id") is not None: + old_category_id = ann["category_id"] + if old_category_id in category_id_mapping: + ann["category_id"] = category_id_mapping[old_category_id] + + return modified_data + + # Validate input path + if not os.path.exists(input_json_path): + raise FileNotFoundError(f"Input file not found: {input_json_path}") + + # Load and validate JSON data + try: + with open(input_json_path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.JSONDecodeError as e: + raise json.JSONDecodeError(f"Invalid JSON in {input_json_path}: {e}") + + # Validate COCO format + if not is_coco_json(data): + raise ValueError( + f"File does not appear to be in COCO format: {input_json_path}" + ) + + # Check if reindexing is needed + annotations_zero, images_zero, categories_zero = check_zero_indexed(data) + + if not (annotations_zero or images_zero or categories_zero): + # No conversion needed - just copy to temp location + input_path = Path(input_json_path) + temp_dir = tempfile.mkdtemp() + temp_filename = f"{input_path.stem}_1_indexed{input_path.suffix}" + temp_path = os.path.join(temp_dir, temp_filename) + + with open(temp_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + return temp_path + + # Perform reindexing + modified_data = reindex_coco_data(data) + + # Create temporary file + input_path = Path(input_json_path) + temp_dir = tempfile.mkdtemp() + temp_filename = f"{input_path.stem}_1_indexed{input_path.suffix}" + temp_path = os.path.join(temp_dir, temp_filename) + + # Write modified data to temporary file + with open(temp_path, "w", encoding="utf-8") as f: + json.dump(modified_data, f, indent=2, ensure_ascii=False) + + return temp_path + + +# Example usage and test function +def test_reindex_function(): + """Test the reindex function with a sample COCO file.""" + + # Create a test COCO file + test_data = { + "info": {"description": "Test COCO dataset", "version": "1.0", "year": 2023}, + "images": [ + {"id": 0, "width": 640, "height": 480, "file_name": "test1.jpg"}, + {"id": 1, "width": 640, "height": 480, "file_name": "test2.jpg"}, + ], + "categories": [ + {"id": 0, "name": "person", "supercategory": "person"}, + {"id": 1, "name": "car", "supercategory": "vehicle"}, + ], + "annotations": [ + { + "id": 0, + "image_id": 0, + "category_id": 0, + "bbox": [100, 100, 50, 75], + "area": 3750, + "iscrowd": 0, + }, + { + "id": 1, + "image_id": 1, + "category_id": 1, + "bbox": [200, 150, 120, 80], + "area": 9600, + "iscrowd": 0, + }, + ], + } + + # Create temporary test file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(test_data, f, indent=2) + test_file_path = f.name + + try: + # Test the function + result_path = reindex_coco_to_temp(test_file_path) + print(f"Original file: {test_file_path}") + print(f"Converted file: {result_path}") + + # Load and display the result + with open(result_path, "r") as f: + result_data = json.load(f) + + print("\nConverted data sample:") + print(f"First image ID: {result_data['images'][0]['id']}") + print(f"First category ID: {result_data['categories'][0]['id']}") + print(f"First annotation ID: {result_data['annotations'][0]['id']}") + print(f"First annotation image_id: {result_data['annotations'][0]['image_id']}") + print( + f"First annotation category_id: {result_data['annotations'][0]['category_id']}" + ) + + # Clean up + os.unlink(result_path) + os.rmdir(os.path.dirname(result_path)) + + finally: + # Clean up test file + os.unlink(test_file_path) + + +if __name__ == "__main__": + test_reindex_function() diff --git a/third_party/sam3/sam3/eval/coco_writer.py b/third_party/sam3/sam3/eval/coco_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..54f9307f64d0f1742b87257f60314b43e6574fd3 --- /dev/null +++ b/third_party/sam3/sam3/eval/coco_writer.py @@ -0,0 +1,354 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +COCO prediction dumper for distributed training. + +Handles collection and dumping of COCO-format predictions from models. +Supports distributed processing with multiple GPUs/processes. +""" + +import copy +import gc +import heapq +import json +import logging +import os +from collections import defaultdict +from pathlib import Path +from typing import Any, Optional + +import pycocotools.mask as mask_utils +import torch +from iopath.common.file_io import g_pathmgr +from sam3.eval.coco_eval_offline import convert_to_xywh +from sam3.train.masks_ops import rle_encode +from sam3.train.utils.distributed import ( + all_gather, + gather_to_rank_0_via_filesys, + get_rank, + is_main_process, +) + + +### Helper functions and classes + + +class HeapElement: + """Utility class to make a heap with a custom comparator based on score.""" + + def __init__(self, val): + self.val = val + + def __lt__(self, other): + return self.val["score"] < other.val["score"] + + +class PredictionDumper: + """ + Handles collection and dumping of COCO-format predictions from a model. + + This class processes model outputs through a postprocessor, converts them to COCO format, + and saves them to disk. It supports distributed processing with multiple GPUs/processes. + """ + + def __init__( + self, + dump_dir: str, + postprocessor, + maxdets: int, + iou_type: str, + gather_pred_via_filesys: bool = False, + merge_predictions: bool = False, + pred_file_evaluators: Optional[Any] = None, + ): + """ + Initialize the PredictionDumper. + + Args: + dump_dir: Directory to dump predictions. + postprocessor: Module to convert the model's output into COCO format. + maxdets: Maximum number of detections per image. + iou_type: IoU type to evaluate. Can include "bbox", "segm" + gather_pred_via_filesys: If True, use the filesystem for collective gathers across + processes (requires a shared filesystem). Otherwise, use torch collective ops. + merge_predictions: If True, merge predictions from all processes and dump to a single file. + """ + self.iou_type = iou_type + self.maxdets = maxdets + self.dump_dir = dump_dir + self.postprocessor = postprocessor + self.gather_pred_via_filesys = gather_pred_via_filesys + self.merge_predictions = merge_predictions + self.pred_file_evaluators = pred_file_evaluators + if self.pred_file_evaluators is not None: + assert ( + merge_predictions + ), "merge_predictions must be True if pred_file_evaluators are provided" + assert self.dump_dir is not None, "dump_dir must be provided" + + if is_main_process(): + os.makedirs(self.dump_dir, exist_ok=True) + logging.info(f"Created prediction dump directory: {self.dump_dir}") + + # Initialize state + self.reset() + + def update(self, *args, **kwargs): + """ + Process and accumulate predictions from model outputs. + + Args: + *args, **kwargs: Arguments passed to postprocessor.process_results() + """ + predictions = self.postprocessor.process_results(*args, **kwargs) + results = self.prepare(predictions, self.iou_type) + self._dump(results) + + def _dump(self, results): + """ + Add results to the dump list with precision rounding. + + Args: + results: List of prediction dictionaries in COCO format. + """ + dumped_results = copy.deepcopy(results) + for r in dumped_results: + if "bbox" in r: + r["bbox"] = [round(coord, 5) for coord in r["bbox"]] + r["score"] = round(r["score"], 5) + self.dump.extend(dumped_results) + + def synchronize_between_processes(self): + """ + Synchronize predictions across all processes and save to disk. + + If gather_pred_via_filesys is True, uses filesystem for gathering. + Otherwise, uses torch distributed collective operations. + Saves per-rank predictions to separate JSON files. + """ + logging.info("Prediction Dumper: Synchronizing between processes") + + if not self.merge_predictions: + dumped_file = ( + Path(self.dump_dir) + / f"coco_predictions_{self.iou_type}_{get_rank()}.json" + ) + logging.info( + f"Prediction Dumper: Dumping local predictions to {dumped_file}" + ) + with g_pathmgr.open(str(dumped_file), "w") as f: + json.dump(self.dump, f) + else: + self.dump = self.gather_and_merge_predictions() + dumped_file = Path(self.dump_dir) / f"coco_predictions_{self.iou_type}.json" + if is_main_process(): + logging.info( + f"Prediction Dumper: Dumping merged predictions to {dumped_file}" + ) + with g_pathmgr.open(str(dumped_file), "w") as f: + json.dump(self.dump, f) + + self.reset() + return dumped_file + + def gather_and_merge_predictions(self): + """ + Gather predictions from all processes and merge them, keeping top predictions per image. + + This method collects predictions from all processes, then keeps only the top maxdets + predictions per image based on score. It also deduplicates predictions by (image_id, category_id). + + Returns: + List of merged prediction dictionaries. + """ + logging.info("Prediction Dumper: Gathering predictions from all processes") + gc.collect() + + if self.gather_pred_via_filesys: + dump = gather_to_rank_0_via_filesys(self.dump) + else: + dump = all_gather(self.dump, force_cpu=True) + + # Combine predictions, keeping only top maxdets per image + preds_by_image = defaultdict(list) + seen_img_cat = set() + + for cur_dump in dump: + cur_seen_img_cat = set() + for p in cur_dump: + image_id = p["image_id"] + cat_id = p["category_id"] + + # Skip if we've already seen this image/category pair in a previous dump + if (image_id, cat_id) in seen_img_cat: + continue + + cur_seen_img_cat.add((image_id, cat_id)) + + # Use a min-heap to keep top predictions + if len(preds_by_image[image_id]) < self.maxdets: + heapq.heappush(preds_by_image[image_id], HeapElement(p)) + else: + heapq.heappushpop(preds_by_image[image_id], HeapElement(p)) + + seen_img_cat.update(cur_seen_img_cat) + + # Flatten the heap elements back to a list + merged_dump = sum( + [[h.val for h in cur_preds] for cur_preds in preds_by_image.values()], [] + ) + + return merged_dump + + def compute_synced(self): + """ + Synchronize predictions across processes and compute summary. + + Returns: + Summary dictionary from summarize(). + """ + dumped_file = self.synchronize_between_processes() + if not is_main_process(): + return {"": 0.0} + + meters = {} + if self.pred_file_evaluators is not None: + for evaluator in self.pred_file_evaluators: + results = evaluator.evaluate(dumped_file) + meters.update(results) + + if len(meters) == 0: + meters = {"": 0.0} + return meters + + def compute(self): + """ + Compute without synchronization. + + Returns: + Empty metric dictionary. + """ + return {"": 0.0} + + def reset(self): + """Reset internal state for a new evaluation round.""" + self.dump = [] + + def prepare(self, predictions, iou_type): + """ + Route predictions to the appropriate preparation method based on iou_type. + + Args: + predictions: Dictionary mapping image IDs to prediction dictionaries. + iou_type: Type of evaluation ("bbox", "segm"). + + Returns: + List of COCO-format prediction dictionaries. + """ + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + else: + raise ValueError(f"Unknown iou type: {iou_type}") + + def prepare_for_coco_detection(self, predictions): + """ + Convert predictions to COCO detection format. + + Args: + predictions: Dictionary mapping image IDs to prediction dictionaries + containing "boxes", "scores", and "labels". + + Returns: + List of COCO-format detection dictionaries. + """ + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + @torch.no_grad() + def prepare_for_coco_segmentation(self, predictions): + """ + Convert predictions to COCO segmentation format. + + Args: + predictions: Dictionary mapping image IDs to prediction dictionaries + containing "masks" or "masks_rle", "scores", and "labels". + Optionally includes "boundaries" and "dilated_boundaries". + + Returns: + List of COCO-format segmentation dictionaries with RLE-encoded masks. + """ + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + boxes = None + if "boxes" in prediction: + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + assert len(boxes) == len(scores) + + if "masks_rle" in prediction: + rles = prediction["masks_rle"] + areas = [] + for rle in rles: + cur_area = mask_utils.area(rle) + h, w = rle["size"] + areas.append(cur_area / (h * w)) + else: + masks = prediction["masks"] + masks = masks > 0.5 + h, w = masks.shape[-2:] + + areas = masks.flatten(1).sum(1) / (h * w) + areas = areas.tolist() + + rles = rle_encode(masks.squeeze(1)) + + # Memory cleanup + del masks + del prediction["masks"] + + assert len(areas) == len(rles) == len(scores) + + for k, rle in enumerate(rles): + payload = { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + "area": areas[k], + } + if boxes is not None: + payload["bbox"] = boxes[k] + + coco_results.append(payload) + + return coco_results diff --git a/third_party/sam3/sam3/eval/conversion_util.py b/third_party/sam3/sam3/eval/conversion_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fb8950b1282867ac7e8590ca78f8289048ea8908 --- /dev/null +++ b/third_party/sam3/sam3/eval/conversion_util.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import json +import os +from collections import defaultdict + +from tqdm import tqdm + + +def convert_ytbvis_to_cocovid_gt(ann_json, save_path=None): + """Convert YouTube VIS dataset to COCO-style video instance segmentation format. + + Args: + ann_json (str): Path to YouTube VIS annotation JSON file + save_path (str): path to save converted COCO-style JSON + """ + # Initialize COCO structure + VIS = { + "info": {}, + "images": [], + "videos": [], + "tracks": [], + "annotations": [], + "categories": [], + "licenses": [], + } + + # Load original annotations + official_anns = json.load(open(ann_json)) + VIS["categories"] = official_anns["categories"] # Direct copy categories + + # Initialize counters + records = dict(img_id=1, ann_id=1) + + # Create video-to-annotations mapping + vid_to_anns = defaultdict(list) + for ann in official_anns["annotations"]: + vid_to_anns[ann["video_id"]].append(ann) + + # Create tracks directly + VIS["tracks"] = [ + { + "id": ann["id"], + "category_id": ann["category_id"], + "video_id": ann["video_id"], + } + for ann in official_anns["annotations"] + ] + + # Process videos + for video_info in tqdm(official_anns["videos"]): + # Create video entry + video = { + "id": video_info["id"], + "name": os.path.dirname(video_info["file_names"][0]), + "width": video_info["width"], + "height": video_info["height"], + "length": video_info["length"], + "neg_category_ids": [], + "not_exhaustive_category_ids": [], + } + VIS["videos"].append(video) + + # Process frames + num_frames = len(video_info["file_names"]) + for frame_idx in range(num_frames): + # Create image entry + image = { + "id": records["img_id"], + "video_id": video_info["id"], + "file_name": video_info["file_names"][frame_idx], + "width": video_info["width"], + "height": video_info["height"], + "frame_index": frame_idx, + "frame_id": frame_idx, + } + VIS["images"].append(image) + + # Process annotations for this frame + if video_info["id"] in vid_to_anns: + for ann in vid_to_anns[video_info["id"]]: + bbox = ann["bboxes"][frame_idx] + if bbox is None: + continue + + # Create annotation entry + annotation = { + "id": records["ann_id"], + "video_id": video_info["id"], + "image_id": records["img_id"], + "track_id": ann["id"], + "category_id": ann["category_id"], + "bbox": bbox, + "area": ann["areas"][frame_idx], + "segmentation": ann["segmentations"][frame_idx], + "iscrowd": ann["iscrowd"], + } + VIS["annotations"].append(annotation) + records["ann_id"] += 1 + + records["img_id"] += 1 + + # Print summary + print(f"Converted {len(VIS['videos'])} videos") + print(f"Converted {len(VIS['images'])} images") + print(f"Created {len(VIS['tracks'])} tracks") + print(f"Created {len(VIS['annotations'])} annotations") + + if save_path is None: + return VIS + + # Save output + save_dir = os.path.dirname(save_path) + os.makedirs(save_dir, exist_ok=True) + json.dump(VIS, open(save_path, "w")) + + return VIS + + +def convert_ytbvis_to_cocovid_pred( + youtubevis_pred_path: str, converted_dataset_path: str, output_path: str +) -> None: + """ + Convert YouTubeVIS predictions to COCO format with video_id preservation + + Args: + youtubevis_pred_path: Path to YouTubeVIS prediction JSON + converted_dataset_path: Path to converted COCO dataset JSON + output_path: Path to save COCO format predictions + """ + + # Load YouTubeVIS predictions + with open(youtubevis_pred_path) as f: + ytv_predictions = json.load(f) + + # Load converted dataset for image ID mapping + with open(converted_dataset_path) as f: + coco_dataset = json.load(f) + + # Create (video_id, frame_idx) -> image_id mapping + image_id_map = { + (img["video_id"], img["frame_index"]): img["id"] + for img in coco_dataset["images"] + } + + coco_annotations = [] + track_id_counter = 1 # Unique track ID generator + + for pred in tqdm(ytv_predictions): + video_id = pred["video_id"] + category_id = pred["category_id"] + bboxes = pred["bboxes"] + segmentations = pred.get("segmentations", []) # Get segmentations if available + areas = pred.get("areas", []) # Get areas if available + score = pred["score"] + + # Assign unique track ID for this prediction + track_id = track_id_counter + track_id_counter += 1 + + # Ensure segmentations and areas have the same length as bboxes + if len(segmentations) == 0: + segmentations = [None] * len(bboxes) + if len(areas) == 0: + areas = [None] * len(bboxes) + + for frame_idx, (bbox, segmentation, area_from_pred) in enumerate( + zip(bboxes, segmentations, areas) + ): + # Skip frames with missing objects (None or zero bbox) + if bbox is None or all(x == 0 for x in bbox): + continue + + # Get corresponding image ID from mapping + image_id = image_id_map.get((video_id, frame_idx)) + if image_id is None: + raise RuntimeError( + f"prediction {video_id=}, {frame_idx=} does not match any images in the converted COCO format" + ) + + # Extract bbox coordinates + x, y, w, h = bbox + + # Calculate area - use area from prediction if available, otherwise from bbox + if area_from_pred is not None and area_from_pred > 0: + area = area_from_pred + else: + area = w * h + + # Create COCO annotation with video_id + coco_annotation = { + "image_id": int(image_id), + "video_id": video_id, # Added video_id field + "track_id": track_id, + "category_id": category_id, + "bbox": [float(x), float(y), float(w), float(h)], + "area": float(area), + "iscrowd": 0, + "score": float(score), + } + + # Add segmentation if available + if segmentation is not None: + coco_annotation["segmentation"] = segmentation + + coco_annotations.append(coco_annotation) + + # Save output + with open(output_path, "w") as f: + json.dump(coco_annotations, f) + + print(f"Converted {len(coco_annotations)} predictions to COCO format with video_id") diff --git a/third_party/sam3/sam3/eval/demo_eval.py b/third_party/sam3/sam3/eval/demo_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..353b86aaf1186976056b01fd90ab089cb34d45d6 --- /dev/null +++ b/third_party/sam3/sam3/eval/demo_eval.py @@ -0,0 +1,658 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +This evaluator is based upon COCO evaluation, but evaluates the model in a "demo" setting. +This means that the model's predictions are thresholded and evaluated as "hard" predictions. +""" + +import logging +from typing import Optional + +import numpy as np +import pycocotools.mask as maskUtils +from pycocotools.cocoeval import COCOeval +from sam3.eval.coco_eval import CocoEvaluator +from sam3.train.masks_ops import compute_F_measure +from sam3.train.utils.distributed import is_main_process +from scipy.optimize import linear_sum_assignment + + +class DemoEval(COCOeval): + """ + This evaluator is based upon COCO evaluation, but evaluates the model in a "demo" setting. + This means that the model's predictions are thresholded and evaluated as "hard" predictions. + """ + + def __init__( + self, + coco_gt=None, + coco_dt=None, + iouType="bbox", + threshold=0.5, + compute_JnF=False, + ): + """ + Args: + coco_gt (COCO): ground truth COCO API + coco_dt (COCO): detections COCO API + iou_type (str): type of IoU to evaluate + threshold (float): threshold for predictions + """ + super().__init__(coco_gt, coco_dt, iouType) + self.threshold = threshold + + self.params.useCats = False + self.params.areaRng = [[0**2, 1e5**2]] + self.params.areaRngLbl = ["all"] + self.params.maxDets = [100000] + self.compute_JnF = compute_JnF + + def computeIoU(self, imgId, catId): + # Same as the original COCOeval.computeIoU, but without sorting + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 and len(dt) == 0: + return [] + + if p.iouType == "segm": + g = [g["segmentation"] for g in gt] + d = [d["segmentation"] for d in dt] + elif p.iouType == "bbox": + g = [g["bbox"] for g in gt] + d = [d["bbox"] for d in dt] + else: + raise Exception("unknown iouType for iou computation") + + # compute iou between each dt and gt region + iscrowd = [int(o["iscrowd"]) for o in gt] + ious = maskUtils.iou(d, g, iscrowd) + return ious + + def evaluateImg(self, imgId, catId, aRng, maxDet): + """ + perform evaluation for single category and image + :return: dict (single image results) + """ + p = self.params + assert not p.useCats, "This evaluator does not support per-category evaluation." + assert catId == -1 + all_gts = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + keep_gt = np.array([not g["ignore"] for g in all_gts], dtype=bool) + gt = [g for g in all_gts if not g["ignore"]] + all_dts = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + keep_dt = np.array([d["score"] >= self.threshold for d in all_dts], dtype=bool) + dt = [d for d in all_dts if d["score"] >= self.threshold] + if len(gt) == 0 and len(dt) == 0: + # This is a "true negative" case, where there are no GTs and no predictions + # The box-level metrics are ill-defined, so we don't add them to this dict + return { + "image_id": imgId, + "IL_TP": 0, + "IL_TN": 1, + "IL_FP": 0, + "IL_FN": 0, + "IL_perfect_neg": np.ones((len(p.iouThrs),), dtype=np.int64), + "num_dt": len(dt), + } + + if len(gt) > 0 and len(dt) == 0: + # This is a "false negative" case, where there are GTs but no predictions + return { + "image_id": imgId, + "IL_TP": 0, + "IL_TN": 0, + "IL_FP": 0, + "IL_FN": 1, + "TPs": np.zeros((len(p.iouThrs),), dtype=np.int64), + "FPs": np.zeros((len(p.iouThrs),), dtype=np.int64), + "FNs": np.ones((len(p.iouThrs),), dtype=np.int64) * len(gt), + "local_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64), + "local_positive_F1s": np.zeros((len(p.iouThrs),), dtype=np.int64), + "IL_perfect_pos": np.zeros((len(p.iouThrs),), dtype=np.int64), + "num_dt": len(dt), + } + + # Load pre-computed ious + ious = self.ious[(imgId, catId)] + + # compute matching + if len(ious) == 0: + ious = np.zeros((len(dt), len(gt))) + else: + ious = ious[keep_dt, :][:, keep_gt] + assert ious.shape == (len(dt), len(gt)) + + matched_dt, matched_gt = linear_sum_assignment(-ious) + + match_scores = ious[matched_dt, matched_gt] + + if self.compute_JnF and len(match_scores) > 0: + j_score = match_scores.mean() + f_measure = 0 + for dt_id, gt_id in zip(matched_dt, matched_gt): + f_measure += compute_F_measure( + gt_boundary_rle=gt[gt_id]["boundary"], + gt_dilated_boundary_rle=gt[gt_id]["dilated_boundary"], + dt_boundary_rle=dt[dt_id]["boundary"], + dt_dilated_boundary_rle=dt[dt_id]["dilated_boundary"], + ) + f_measure /= len(match_scores) + 1e-9 + JnF = (j_score + f_measure) * 0.5 + else: + j_score = f_measure = JnF = -1 + + TPs, FPs, FNs = [], [], [] + IL_perfect = [] + for thresh in p.iouThrs: + TP = (match_scores >= thresh).sum() + FP = len(dt) - TP + FN = len(gt) - TP + assert ( + FP >= 0 and FN >= 0 + ), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}" + TPs.append(TP) + FPs.append(FP) + FNs.append(FN) + + if FP == FN and FP == 0: + IL_perfect.append(1) + else: + IL_perfect.append(0) + + TPs = np.array(TPs, dtype=np.int64) + FPs = np.array(FPs, dtype=np.int64) + FNs = np.array(FNs, dtype=np.int64) + IL_perfect = np.array(IL_perfect, dtype=np.int64) + + # compute precision recall and F1 + precision = TPs / (TPs + FPs + 1e-4) + assert np.all(precision <= 1) + recall = TPs / (TPs + FNs + 1e-4) + assert np.all(recall <= 1) + F1 = 2 * precision * recall / (precision + recall + 1e-4) + + result = { + "image_id": imgId, + "TPs": TPs, + "FPs": FPs, + "FNs": FNs, + "local_F1s": F1, + "IL_TP": (len(gt) > 0) and (len(dt) > 0), + "IL_FP": (len(gt) == 0) and (len(dt) > 0), + "IL_TN": (len(gt) == 0) and (len(dt) == 0), + "IL_FN": (len(gt) > 0) and (len(dt) == 0), + ("IL_perfect_pos" if len(gt) > 0 else "IL_perfect_neg"): IL_perfect, + "F": f_measure, + "J": j_score, + "J&F": JnF, + "num_dt": len(dt), + } + if len(gt) > 0 and len(dt) > 0: + result["local_positive_F1s"] = F1 + return result + + def accumulate(self, p=None): + """ + Accumulate per image evaluation results and store the result in self.eval + :param p: input params for evaluation + :return: None + """ + if not self.evalImgs: + print("Please run evaluate() first") + # allows input customized parameters + if p is None: + p = self.params + + setImgIds = set(p.imgIds) + + # TPs, FPs, FNs + TPs = np.zeros((len(p.iouThrs),), dtype=np.int64) + FPs = np.zeros((len(p.iouThrs),), dtype=np.int64) + pmFPs = np.zeros((len(p.iouThrs),), dtype=np.int64) + FNs = np.zeros((len(p.iouThrs),), dtype=np.int64) + local_F1s = np.zeros((len(p.iouThrs),), dtype=np.float64) + + # Image level metrics + IL_TPs = 0 + IL_FPs = 0 + IL_TNs = 0 + IL_FNs = 0 + IL_perfects_neg = np.zeros((len(p.iouThrs),), dtype=np.int64) + IL_perfects_pos = np.zeros((len(p.iouThrs),), dtype=np.int64) + + # JnF metric + total_J = 0 + total_F = 0 + total_JnF = 0 + + valid_img_count = 0 + total_pos_count = 0 + total_neg_count = 0 + valid_J_count = 0 + valid_F1_count = 0 + valid_F1_count_w0dt = 0 + for res in self.evalImgs: + if res["image_id"] not in setImgIds: + continue + IL_TPs += res["IL_TP"] + IL_FPs += res["IL_FP"] + IL_TNs += res["IL_TN"] + IL_FNs += res["IL_FN"] + if "IL_perfect_neg" in res: + IL_perfects_neg += res["IL_perfect_neg"] + total_neg_count += 1 + else: + assert "IL_perfect_pos" in res + IL_perfects_pos += res["IL_perfect_pos"] + total_pos_count += 1 + + if "TPs" not in res: + continue + + TPs += res["TPs"] + FPs += res["FPs"] + FNs += res["FNs"] + valid_img_count += 1 + + if "local_positive_F1s" in res: + local_F1s += res["local_positive_F1s"] + pmFPs += res["FPs"] + valid_F1_count_w0dt += 1 + if res["num_dt"] > 0: + valid_F1_count += 1 + + if "J" in res and res["J"] > -1e-9: + total_J += res["J"] + total_F += res["F"] + total_JnF += res["J&F"] + valid_J_count += 1 + + # compute precision recall and F1 + precision = TPs / (TPs + FPs + 1e-4) + positive_micro_precision = TPs / (TPs + pmFPs + 1e-4) + assert np.all(precision <= 1) + recall = TPs / (TPs + FNs + 1e-4) + assert np.all(recall <= 1) + F1 = 2 * precision * recall / (precision + recall + 1e-4) + positive_micro_F1 = ( + 2 + * positive_micro_precision + * recall + / (positive_micro_precision + recall + 1e-4) + ) + + IL_rec = IL_TPs / (IL_TPs + IL_FNs + 1e-6) + IL_prec = IL_TPs / (IL_TPs + IL_FPs + 1e-6) + IL_F1 = 2 * IL_prec * IL_rec / (IL_prec + IL_rec + 1e-6) + IL_FPR = IL_FPs / (IL_FPs + IL_TNs + 1e-6) + IL_MCC = float(IL_TPs * IL_TNs - IL_FPs * IL_FNs) / ( + ( + float(IL_TPs + IL_FPs) + * float(IL_TPs + IL_FNs) + * float(IL_TNs + IL_FPs) + * float(IL_TNs + IL_FNs) + ) + ** 0.5 + + 1e-6 + ) + IL_perfect_pos = IL_perfects_pos / (total_pos_count + 1e-9) + IL_perfect_neg = IL_perfects_neg / (total_neg_count + 1e-9) + + total_J = total_J / (valid_J_count + 1e-9) + total_F = total_F / (valid_J_count + 1e-9) + total_JnF = total_JnF / (valid_J_count + 1e-9) + + self.eval = { + "params": p, + "TPs": TPs, + "FPs": FPs, + "positive_micro_FPs": pmFPs, + "FNs": FNs, + "precision": precision, + "positive_micro_precision": positive_micro_precision, + "recall": recall, + "F1": F1, + "positive_micro_F1": positive_micro_F1, + "positive_macro_F1": local_F1s / valid_F1_count, + "positive_w0dt_macro_F1": local_F1s / valid_F1_count_w0dt, + "IL_recall": IL_rec, + "IL_precision": IL_prec, + "IL_F1": IL_F1, + "IL_FPR": IL_FPR, + "IL_MCC": IL_MCC, + "IL_perfect_pos": IL_perfect_pos, + "IL_perfect_neg": IL_perfect_neg, + "J": total_J, + "F": total_F, + "J&F": total_JnF, + } + self.eval["CGF1"] = self.eval["positive_macro_F1"] * self.eval["IL_MCC"] + self.eval["CGF1_w0dt"] = ( + self.eval["positive_w0dt_macro_F1"] * self.eval["IL_MCC"] + ) + self.eval["CGF1_micro"] = self.eval["positive_micro_F1"] * self.eval["IL_MCC"] + + def summarize(self): + """ + Compute and display summary metrics for evaluation results. + Note this functin can *only* be applied on the default parameter setting + """ + if not self.eval: + raise Exception("Please run accumulate() first") + + def _summarize(iouThr=None, metric=""): + p = self.params + iStr = " {:<18} @[ IoU={:<9}] = {:0.3f}" + titleStr = "Average " + metric + iouStr = ( + "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1]) + if iouThr is None + else "{:0.2f}".format(iouThr) + ) + + s = self.eval[metric] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + print(iStr.format(titleStr, iouStr, mean_s)) + return mean_s + + def _summarize_single(metric=""): + titleStr = "Average " + metric + iStr = " {:<35} = {:0.3f}" + s = self.eval[metric] + print(iStr.format(titleStr, s)) + return s + + def _summarizeDets(): + # note: the index of these metrics are also used in video Demo F1 evaluation + # when adding new metrics, please update the index in video Demo F1 evaluation + # in "evaluate" method of the "VideoDemoF1Evaluator" class + stats = np.zeros((len(DEMO_METRICS),)) + stats[0] = _summarize(metric="CGF1") + stats[1] = _summarize(metric="precision") + stats[2] = _summarize(metric="recall") + stats[3] = _summarize(metric="F1") + stats[4] = _summarize(metric="positive_macro_F1") + stats[5] = _summarize_single(metric="IL_precision") + stats[6] = _summarize_single(metric="IL_recall") + stats[7] = _summarize_single(metric="IL_F1") + stats[8] = _summarize_single(metric="IL_FPR") + stats[9] = _summarize_single(metric="IL_MCC") + stats[10] = _summarize(metric="IL_perfect_pos") + stats[11] = _summarize(metric="IL_perfect_neg") + stats[12] = _summarize(iouThr=0.5, metric="CGF1") + stats[13] = _summarize(iouThr=0.5, metric="precision") + stats[14] = _summarize(iouThr=0.5, metric="recall") + stats[15] = _summarize(iouThr=0.5, metric="F1") + stats[16] = _summarize(iouThr=0.5, metric="positive_macro_F1") + stats[17] = _summarize(iouThr=0.5, metric="IL_perfect_pos") + stats[18] = _summarize(iouThr=0.5, metric="IL_perfect_neg") + stats[19] = _summarize(iouThr=0.75, metric="CGF1") + stats[20] = _summarize(iouThr=0.75, metric="precision") + stats[21] = _summarize(iouThr=0.75, metric="recall") + stats[22] = _summarize(iouThr=0.75, metric="F1") + stats[23] = _summarize(iouThr=0.75, metric="positive_macro_F1") + stats[24] = _summarize(iouThr=0.75, metric="IL_perfect_pos") + stats[25] = _summarize(iouThr=0.75, metric="IL_perfect_neg") + stats[26] = _summarize_single(metric="J") + stats[27] = _summarize_single(metric="F") + stats[28] = _summarize_single(metric="J&F") + stats[29] = _summarize(metric="CGF1_micro") + stats[30] = _summarize(metric="positive_micro_precision") + stats[31] = _summarize(metric="positive_micro_F1") + stats[32] = _summarize(iouThr=0.5, metric="CGF1_micro") + stats[33] = _summarize(iouThr=0.5, metric="positive_micro_precision") + stats[34] = _summarize(iouThr=0.5, metric="positive_micro_F1") + stats[35] = _summarize(iouThr=0.75, metric="CGF1_micro") + stats[36] = _summarize(iouThr=0.75, metric="positive_micro_precision") + stats[37] = _summarize(iouThr=0.75, metric="positive_micro_F1") + stats[38] = _summarize(metric="CGF1_w0dt") + stats[39] = _summarize(metric="positive_w0dt_macro_F1") + stats[40] = _summarize(iouThr=0.5, metric="CGF1_w0dt") + stats[41] = _summarize(iouThr=0.5, metric="positive_w0dt_macro_F1") + stats[42] = _summarize(iouThr=0.75, metric="CGF1_w0dt") + stats[43] = _summarize(iouThr=0.75, metric="positive_w0dt_macro_F1") + return stats + + summarize = _summarizeDets + self.stats = summarize() + + +DEMO_METRICS = [ + "CGF1", + "Precision", + "Recall", + "F1", + "Macro_F1", + "IL_Precision", + "IL_Recall", + "IL_F1", + "IL_FPR", + "IL_MCC", + "IL_perfect_pos", + "IL_perfect_neg", + "CGF1@0.5", + "Precision@0.5", + "Recall@0.5", + "F1@0.5", + "Macro_F1@0.5", + "IL_perfect_pos@0.5", + "IL_perfect_neg@0.5", + "CGF1@0.75", + "Precision@0.75", + "Recall@0.75", + "F1@0.75", + "Macro_F1@0.75", + "IL_perfect_pos@0.75", + "IL_perfect_neg@0.75", + "J", + "F", + "J&F", + "CGF1_micro", + "positive_micro_Precision", + "positive_micro_F1", + "CGF1_micro@0.5", + "positive_micro_Precision@0.5", + "positive_micro_F1@0.5", + "CGF1_micro@0.75", + "positive_micro_Precision@0.75", + "positive_micro_F1@0.75", + "CGF1_w0dt", + "positive_w0dt_macro_F1", + "CGF1_w0dt@0.5", + "positive_w0dt_macro_F1@0.5", + "CGF1_w0dt@0.75", + "positive_w0dt_macro_F1@0.75", +] + + +class DemoEvaluator(CocoEvaluator): + def __init__( + self, + coco_gt, + iou_types, + dump_dir: Optional[str], + postprocessor, + threshold=0.5, + average_by_rarity=False, + gather_pred_via_filesys=False, + exhaustive_only=False, + all_exhaustive_only=True, + compute_JnF=False, + metrics_dump_dir: Optional[str] = None, + ): + self.iou_types = iou_types + self.threshold = threshold + super().__init__( + coco_gt=coco_gt, + iou_types=iou_types, + useCats=False, + dump_dir=dump_dir, + postprocessor=postprocessor, + # average_by_rarity=average_by_rarity, + gather_pred_via_filesys=gather_pred_via_filesys, + exhaustive_only=exhaustive_only, + all_exhaustive_only=all_exhaustive_only, + metrics_dump_dir=metrics_dump_dir, + ) + + self.use_self_evaluate = True + self.compute_JnF = compute_JnF + + def _lazy_init(self): + if self.initialized: + return + super()._lazy_init() + self.use_self_evaluate = True + self.reset() + + def select_best_scoring(self, scorings): + # This function is used for "oracle" type evaluation. + # It accepts the evaluation results with respect to several ground truths, and picks the best + if len(scorings) == 1: + return scorings[0] + + assert ( + scorings[0].ndim == 3 + ), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}" + assert ( + scorings[0].shape[0] == 1 + ), f"Expecting a single category, got {scorings[0].shape[0]}" + + for scoring in scorings: + assert ( + scoring.shape == scorings[0].shape + ), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}" + + selected_imgs = [] + for img_id in range(scorings[0].shape[-1]): + best = scorings[0][:, :, img_id] + + for scoring in scorings[1:]: + current = scoring[:, :, img_id] + if "local_F1s" in best[0, 0] and "local_F1s" in current[0, 0]: + # we were able to compute a F1 score for this particular image in both evaluations + # best["local_F1s"] contains the results at various IoU thresholds. We simply take the average for comparision + best_score = best[0, 0]["local_F1s"].mean() + current_score = current[0, 0]["local_F1s"].mean() + if current_score > best_score: + best = current + + else: + # If we're here, it means that in that in some evaluation we were not able to get a valid local F1 + # This happens when both the predictions and targets are empty. In that case, we can assume it's a perfect prediction + if "local_F1s" not in current[0, 0]: + best = current + selected_imgs.append(best) + result = np.stack(selected_imgs, axis=-1) + assert result.shape == scorings[0].shape + return result + + def summarize(self): + self._lazy_init() + logging.info("Demo evaluator: Summarizing") + if not is_main_process(): + return {} + outs = {} + prefix = "oracle_" if len(self.coco_evals) > 1 else "" + # if self.rarity_buckets is None: + self.accumulate(self.eval_img_ids) + for iou_type, coco_eval in self.coco_evals[0].items(): + print("Demo metric, IoU type={}".format(iou_type)) + coco_eval.summarize() + + if "bbox" in self.coco_evals[0]: + for i, value in enumerate(self.coco_evals[0]["bbox"].stats): + outs[f"coco_eval_bbox_{prefix}{DEMO_METRICS[i]}"] = value + if "segm" in self.coco_evals[0]: + for i, value in enumerate(self.coco_evals[0]["segm"].stats): + outs[f"coco_eval_masks_{prefix}{DEMO_METRICS[i]}"] = value + # else: + # total_stats = {} + # for bucket, img_list in self.rarity_buckets.items(): + # self.accumulate(imgIds=img_list) + # bucket_name = RARITY_BUCKETS[bucket] + # for iou_type, coco_eval in self.coco_evals[0].items(): + # print( + # "Demo metric, IoU type={}, Rarity bucket={}".format( + # iou_type, bucket_name + # ) + # ) + # coco_eval.summarize() + + # if "bbox" in self.coco_evals[0]: + # if "bbox" not in total_stats: + # total_stats["bbox"] = np.zeros_like( + # self.coco_evals[0]["bbox"].stats + # ) + # total_stats["bbox"] += self.coco_evals[0]["bbox"].stats + # for i, value in enumerate(self.coco_evals[0]["bbox"].stats): + # outs[ + # f"coco_eval_bbox_{bucket_name}_{prefix}{DEMO_METRICS[i]}" + # ] = value + # if "segm" in self.coco_evals[0]: + # if "segm" not in total_stats: + # total_stats["segm"] = np.zeros_like( + # self.coco_evals[0]["segm"].stats + # ) + # total_stats["segm"] += self.coco_evals[0]["segm"].stats + # for i, value in enumerate(self.coco_evals[0]["segm"].stats): + # outs[ + # f"coco_eval_masks_{bucket_name}_{prefix}{DEMO_METRICS[i]}" + # ] = value + + # if "bbox" in total_stats: + # total_stats["bbox"] /= len(self.rarity_buckets) + # for i, value in enumerate(total_stats["bbox"]): + # outs[f"coco_eval_bbox_{prefix}{DEMO_METRICS[i]}"] = value + # if "segm" in total_stats: + # total_stats["segm"] /= len(self.rarity_buckets) + # for i, value in enumerate(total_stats["segm"]): + # outs[f"coco_eval_masks_{prefix}{DEMO_METRICS[i]}"] = value + + return outs + + def accumulate(self, imgIds=None): + self._lazy_init() + logging.info( + f"demo evaluator: Accumulating on {len(imgIds) if imgIds is not None else 'all'} images" + ) + if not is_main_process(): + return + + if imgIds is not None: + for coco_eval in self.coco_evals[0].values(): + coco_eval.params.imgIds = list(imgIds) + + for coco_eval in self.coco_evals[0].values(): + coco_eval.accumulate() + + def reset(self): + self.coco_evals = [{} for _ in range(len(self.coco_gts))] + for i, coco_gt in enumerate(self.coco_gts): + for iou_type in self.iou_types: + self.coco_evals[i][iou_type] = DemoEval( + coco_gt=coco_gt, + iouType=iou_type, + threshold=self.threshold, + compute_JnF=self.compute_JnF, + ) + self.coco_evals[i][iou_type].useCats = False + self.img_ids = [] + self.eval_imgs = {k: [] for k in self.iou_types} + if self.dump is not None: + self.dump = [] diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/__init__.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..930069a18531b4dc80000b894abb40b6619a8b05 --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +# pyre-unsafe diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/run_ytvis_eval.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/run_ytvis_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..05654e7cae2d2db146159a65845dee6a6e9d970c --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/run_ytvis_eval.py @@ -0,0 +1,116 @@ +# flake8: noqa + +# pyre-unsafe + +"""run_youtube_vis.py +Run example: +run_youtube_vis.py --USE_PARALLEL False --METRICS HOTA --TRACKERS_TO_EVAL STEm_Seg +Command Line Arguments: Defaults, # Comments + Eval arguments: + 'USE_PARALLEL': False, + 'NUM_PARALLEL_CORES': 8, + 'BREAK_ON_ERROR': True, # Raises exception and exits with error + 'RETURN_ON_ERROR': False, # if not BREAK_ON_ERROR, then returns from function on error + 'LOG_ON_ERROR': os.path.join(code_path, 'error_log.txt'), # if not None, save any errors into a log file. + 'PRINT_RESULTS': True, + 'PRINT_ONLY_COMBINED': False, + 'PRINT_CONFIG': True, + 'TIME_PROGRESS': True, + 'DISPLAY_LESS_PROGRESS': True, + 'OUTPUT_SUMMARY': True, + 'OUTPUT_EMPTY_CLASSES': True, # If False, summary files are not output for classes with no detections + 'OUTPUT_DETAILED': True, + 'PLOT_CURVES': True, + Dataset arguments: + 'GT_FOLDER': os.path.join(code_path, 'data/gt/youtube_vis/youtube_vis_training'), # Location of GT data + 'TRACKERS_FOLDER': os.path.join(code_path, 'data/trackers/youtube_vis/youtube_vis_training'), + # Trackers location + 'OUTPUT_FOLDER': None, # Where to save eval results (if None, same as TRACKERS_FOLDER) + 'TRACKERS_TO_EVAL': None, # Filenames of trackers to eval (if None, all in folder) + 'CLASSES_TO_EVAL': None, # Classes to eval (if None, all classes) + 'SPLIT_TO_EVAL': 'training', # Valid: 'training', 'val' + 'PRINT_CONFIG': True, # Whether to print current config + 'OUTPUT_SUB_FOLDER': '', # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER + 'TRACKER_SUB_FOLDER': 'data', # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER + 'TRACKER_DISPLAY_NAMES': None, # Names of trackers to display, if None: TRACKERS_TO_EVAL + Metric arguments: + 'METRICS': ['TrackMAP', 'HOTA', 'CLEAR', 'Identity'] +""" + +import argparse +import os +import sys +from multiprocessing import freeze_support + +from . import trackeval + + +def run_ytvis_eval(args=None, gt_json=None, dt_json=None): + # Command line interface: + default_eval_config = trackeval.Evaluator.get_default_eval_config() + # print only combined since TrackMAP is undefined for per sequence breakdowns + default_eval_config["PRINT_ONLY_COMBINED"] = True + default_dataset_config = trackeval.datasets.YouTubeVIS.get_default_dataset_config() + default_metrics_config = {"METRICS": ["HOTA"]} + config = { + **default_eval_config, + **default_dataset_config, + **default_metrics_config, + } # Merge default configs + parser = argparse.ArgumentParser() + for setting in config.keys(): + if type(config[setting]) == list or type(config[setting]) == type(None): + parser.add_argument("--" + setting, nargs="+") + else: + parser.add_argument("--" + setting) + args = parser.parse_args(args).__dict__ + for setting in args.keys(): + if args[setting] is not None: + if type(config[setting]) == type(True): + if args[setting] == "True": + x = True + elif args[setting] == "False": + x = False + else: + raise Exception( + "Command line parameter " + setting + "must be True or False" + ) + elif type(config[setting]) == type(1): + x = int(args[setting]) + elif type(args[setting]) == type(None): + x = None + else: + x = args[setting] + config[setting] = x + eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} + dataset_config = { + k: v for k, v in config.items() if k in default_dataset_config.keys() + } + metrics_config = { + k: v for k, v in config.items() if k in default_metrics_config.keys() + } + + # Run code + evaluator = trackeval.Evaluator(eval_config) + # allow directly specifying the GT JSON data and Tracker (result) + # JSON data as Python objects, without reading from files. + dataset_config["GT_JSON_OBJECT"] = gt_json + dataset_config["TRACKER_JSON_OBJECT"] = dt_json + dataset_list = [trackeval.datasets.YouTubeVIS(dataset_config)] + metrics_list = [] + # for metric in [trackeval.metrics.TrackMAP, trackeval.metrics.HOTA, trackeval.metrics.CLEAR, + # trackeval.metrics.Identity]: + for metric in [trackeval.metrics.HOTA]: + if metric.get_name() in metrics_config["METRICS"]: + metrics_list.append(metric()) + if len(metrics_list) == 0: + raise Exception("No metrics selected for evaluation") + output_res, output_msg = evaluator.evaluate(dataset_list, metrics_list) + return output_res, output_msg + + +if __name__ == "__main__": + import sys + + freeze_support() + run_ytvis_eval(sys.argv[1:]) diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/__init__.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e03bbe1a1fbae9bae9a176e58cb9ce5b7e6d5839 --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa + +# pyre-unsafe + +from . import datasets, metrics, utils +from .eval import Evaluator diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/_timing.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/_timing.py new file mode 100644 index 0000000000000000000000000000000000000000..23fa33cf93c477b872945e9e00210fb61f730ca6 --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/_timing.py @@ -0,0 +1,70 @@ +# flake8: noqa + +# pyre-unsafe + +import inspect +from functools import wraps +from time import perf_counter + +DO_TIMING = False +DISPLAY_LESS_PROGRESS = False +timer_dict = {} +counter = 0 + + +def time(f): + @wraps(f) + def wrap(*args, **kw): + if DO_TIMING: + # Run function with timing + ts = perf_counter() + result = f(*args, **kw) + te = perf_counter() + tt = te - ts + + # Get function name + arg_names = inspect.getfullargspec(f)[0] + if arg_names[0] == "self" and DISPLAY_LESS_PROGRESS: + return result + elif arg_names[0] == "self": + method_name = type(args[0]).__name__ + "." + f.__name__ + else: + method_name = f.__name__ + + # Record accumulative time in each function for analysis + if method_name in timer_dict.keys(): + timer_dict[method_name] += tt + else: + timer_dict[method_name] = tt + + # If code is finished, display timing summary + if method_name == "Evaluator.evaluate": + print("") + print("Timing analysis:") + for key, value in timer_dict.items(): + print("%-70s %2.4f sec" % (key, value)) + else: + # Get function argument values for printing special arguments of interest + arg_titles = ["tracker", "seq", "cls"] + arg_vals = [] + for i, a in enumerate(arg_names): + if a in arg_titles: + arg_vals.append(args[i]) + arg_text = "(" + ", ".join(arg_vals) + ")" + + # Display methods and functions with different indentation. + if arg_names[0] == "self": + print("%-74s %2.4f sec" % (" " * 4 + method_name + arg_text, tt)) + elif arg_names[0] == "test": + pass + else: + global counter + counter += 1 + print("%i %-70s %2.4f sec" % (counter, method_name + arg_text, tt)) + + return result + else: + # If config["TIME_PROGRESS"] is false, or config["USE_PARALLEL"] is true, run functions normally without timing. + return f(*args, **kw) + + return wrap diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/__init__.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68e617d8c0e1c4083f027ab55225dcaaf995a531 --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa + +# pyre-unsafe + +from .tao_ow import TAO_OW +from .youtube_vis import YouTubeVIS diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/_base_dataset.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/_base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..49c3c03f145069bbf58b3ebd702dcd779e5402e5 --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/_base_dataset.py @@ -0,0 +1,381 @@ +# flake8: noqa + +# pyre-unsafe + +import csv +import io +import os +import traceback +import zipfile +from abc import ABC, abstractmethod +from copy import deepcopy + +import numpy as np + +from .. import _timing +from ..utils import TrackEvalException + + +class _BaseDataset(ABC): + @abstractmethod + def __init__(self): + self.tracker_list = None + self.seq_list = None + self.class_list = None + self.output_fol = None + self.output_sub_fol = None + self.should_classes_combine = True + self.use_super_categories = False + + # Functions to implement: + + @staticmethod + @abstractmethod + def get_default_dataset_config(): ... + + @abstractmethod + def _load_raw_file(self, tracker, seq, is_gt): ... + + @_timing.time + @abstractmethod + def get_preprocessed_seq_data(self, raw_data, cls): ... + + @abstractmethod + def _calculate_similarities(self, gt_dets_t, tracker_dets_t): ... + + # Helper functions for all datasets: + + @classmethod + def get_class_name(cls): + return cls.__name__ + + def get_name(self): + return self.get_class_name() + + def get_output_fol(self, tracker): + return os.path.join(self.output_fol, tracker, self.output_sub_fol) + + def get_display_name(self, tracker): + """Can be overwritten if the trackers name (in files) is different to how it should be displayed. + By default this method just returns the trackers name as is. + """ + return tracker + + def get_eval_info(self): + """Return info about the dataset needed for the Evaluator""" + return self.tracker_list, self.seq_list, self.class_list + + @_timing.time + def get_raw_seq_data(self, tracker, seq): + """Loads raw data (tracker and ground-truth) for a single tracker on a single sequence. + Raw data includes all of the information needed for both preprocessing and evaluation, for all classes. + A later function (get_processed_seq_data) will perform such preprocessing and extract relevant information for + the evaluation of each class. + + This returns a dict which contains the fields: + [num_timesteps]: integer + [gt_ids, tracker_ids, gt_classes, tracker_classes, tracker_confidences]: + list (for each timestep) of 1D NDArrays (for each det). + [gt_dets, tracker_dets, gt_crowd_ignore_regions]: list (for each timestep) of lists of detections. + [similarity_scores]: list (for each timestep) of 2D NDArrays. + [gt_extras]: dict (for each extra) of lists (for each timestep) of 1D NDArrays (for each det). + + gt_extras contains dataset specific information used for preprocessing such as occlusion and truncation levels. + + Note that similarities are extracted as part of the dataset and not the metric, because almost all metrics are + independent of the exact method of calculating the similarity. However datasets are not (e.g. segmentation + masks vs 2D boxes vs 3D boxes). + We calculate the similarity before preprocessing because often both preprocessing and evaluation require it and + we don't wish to calculate this twice. + We calculate similarity between all gt and tracker classes (not just each class individually) to allow for + calculation of metrics such as class confusion matrices. Typically the impact of this on performance is low. + """ + # Load raw data. + raw_gt_data = self._load_raw_file(tracker, seq, is_gt=True) + raw_tracker_data = self._load_raw_file(tracker, seq, is_gt=False) + raw_data = {**raw_tracker_data, **raw_gt_data} # Merges dictionaries + + # Calculate similarities for each timestep. + similarity_scores = [] + for t, (gt_dets_t, tracker_dets_t) in enumerate( + zip(raw_data["gt_dets"], raw_data["tracker_dets"]) + ): + ious = self._calculate_similarities(gt_dets_t, tracker_dets_t) + similarity_scores.append(ious) + raw_data["similarity_scores"] = similarity_scores + return raw_data + + @staticmethod + def _load_simple_text_file( + file, + time_col=0, + id_col=None, + remove_negative_ids=False, + valid_filter=None, + crowd_ignore_filter=None, + convert_filter=None, + is_zipped=False, + zip_file=None, + force_delimiters=None, + ): + """Function that loads data which is in a commonly used text file format. + Assumes each det is given by one row of a text file. + There is no limit to the number or meaning of each column, + however one column needs to give the timestep of each det (time_col) which is default col 0. + + The file dialect (deliminator, num cols, etc) is determined automatically. + This function automatically separates dets by timestep, + and is much faster than alternatives such as np.loadtext or pandas. + + If remove_negative_ids is True and id_col is not None, dets with negative values in id_col are excluded. + These are not excluded from ignore data. + + valid_filter can be used to only include certain classes. + It is a dict with ints as keys, and lists as values, + such that a row is included if "row[key].lower() is in value" for all key/value pairs in the dict. + If None, all classes are included. + + crowd_ignore_filter can be used to read crowd_ignore regions separately. It has the same format as valid filter. + + convert_filter can be used to convert value read to another format. + This is used most commonly to convert classes given as string to a class id. + This is a dict such that the key is the column to convert, and the value is another dict giving the mapping. + + Optionally, input files could be a zip of multiple text files for storage efficiency. + + Returns read_data and ignore_data. + Each is a dict (with keys as timesteps as strings) of lists (over dets) of lists (over column values). + Note that all data is returned as strings, and must be converted to float/int later if needed. + Note that timesteps will not be present in the returned dict keys if there are no dets for them + """ + + if remove_negative_ids and id_col is None: + raise TrackEvalException( + "remove_negative_ids is True, but id_col is not given." + ) + if crowd_ignore_filter is None: + crowd_ignore_filter = {} + if convert_filter is None: + convert_filter = {} + try: + if is_zipped: # Either open file directly or within a zip. + if zip_file is None: + raise TrackEvalException( + "is_zipped set to True, but no zip_file is given." + ) + archive = zipfile.ZipFile(os.path.join(zip_file), "r") + fp = io.TextIOWrapper(archive.open(file, "r")) + else: + fp = open(file) + read_data = {} + crowd_ignore_data = {} + fp.seek(0, os.SEEK_END) + # check if file is empty + if fp.tell(): + fp.seek(0) + dialect = csv.Sniffer().sniff( + fp.readline(), delimiters=force_delimiters + ) # Auto determine structure. + dialect.skipinitialspace = ( + True # Deal with extra spaces between columns + ) + fp.seek(0) + reader = csv.reader(fp, dialect) + for row in reader: + try: + # Deal with extra trailing spaces at the end of rows + if row[-1] in "": + row = row[:-1] + timestep = str(int(float(row[time_col]))) + # Read ignore regions separately. + is_ignored = False + for ignore_key, ignore_value in crowd_ignore_filter.items(): + if row[ignore_key].lower() in ignore_value: + # Convert values in one column (e.g. string to id) + for ( + convert_key, + convert_value, + ) in convert_filter.items(): + row[convert_key] = convert_value[ + row[convert_key].lower() + ] + # Save data separated by timestep. + if timestep in crowd_ignore_data.keys(): + crowd_ignore_data[timestep].append(row) + else: + crowd_ignore_data[timestep] = [row] + is_ignored = True + if ( + is_ignored + ): # if det is an ignore region, it cannot be a normal det. + continue + # Exclude some dets if not valid. + if valid_filter is not None: + for key, value in valid_filter.items(): + if row[key].lower() not in value: + continue + if remove_negative_ids: + if int(float(row[id_col])) < 0: + continue + # Convert values in one column (e.g. string to id) + for convert_key, convert_value in convert_filter.items(): + row[convert_key] = convert_value[row[convert_key].lower()] + # Save data separated by timestep. + if timestep in read_data.keys(): + read_data[timestep].append(row) + else: + read_data[timestep] = [row] + except Exception: + exc_str_init = ( + "In file %s the following line cannot be read correctly: \n" + % os.path.basename(file) + ) + exc_str = " ".join([exc_str_init] + row) + raise TrackEvalException(exc_str) + fp.close() + except Exception: + print("Error loading file: %s, printing traceback." % file) + traceback.print_exc() + raise TrackEvalException( + "File %s cannot be read because it is either not present or invalidly formatted" + % os.path.basename(file) + ) + return read_data, crowd_ignore_data + + @staticmethod + def _calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False): + """Calculates the IOU (intersection over union) between two arrays of segmentation masks. + If is_encoded a run length encoding with pycocotools is assumed as input format, otherwise an input of numpy + arrays of the shape (num_masks, height, width) is assumed and the encoding is performed. + If do_ioa (intersection over area) , then calculates the intersection over the area of masks1 - this is commonly + used to determine if detections are within crowd ignore region. + :param masks1: first set of masks (numpy array of shape (num_masks, height, width) if not encoded, + else pycocotools rle encoded format) + :param masks2: second set of masks (numpy array of shape (num_masks, height, width) if not encoded, + else pycocotools rle encoded format) + :param is_encoded: whether the input is in pycocotools rle encoded format + :param do_ioa: whether to perform IoA computation + :return: the IoU/IoA scores + """ + + # Only loaded when run to reduce minimum requirements + from pycocotools import mask as mask_utils + + # use pycocotools for run length encoding of masks + if not is_encoded: + masks1 = mask_utils.encode( + np.array(np.transpose(masks1, (1, 2, 0)), order="F") + ) + masks2 = mask_utils.encode( + np.array(np.transpose(masks2, (1, 2, 0)), order="F") + ) + + # use pycocotools for iou computation of rle encoded masks + ious = mask_utils.iou(masks1, masks2, [do_ioa] * len(masks2)) + if len(masks1) == 0 or len(masks2) == 0: + ious = np.asarray(ious).reshape(len(masks1), len(masks2)) + assert (ious >= 0 - np.finfo("float").eps).all() + assert (ious <= 1 + np.finfo("float").eps).all() + + return ious + + @staticmethod + def _calculate_box_ious(bboxes1, bboxes2, box_format="xywh", do_ioa=False): + """Calculates the IOU (intersection over union) between two arrays of boxes. + Allows variable box formats ('xywh' and 'x0y0x1y1'). + If do_ioa (intersection over area) , then calculates the intersection over the area of boxes1 - this is commonly + used to determine if detections are within crowd ignore region. + """ + if box_format in "xywh": + # layout: (x0, y0, w, h) + bboxes1 = deepcopy(bboxes1) + bboxes2 = deepcopy(bboxes2) + + bboxes1[:, 2] = bboxes1[:, 0] + bboxes1[:, 2] + bboxes1[:, 3] = bboxes1[:, 1] + bboxes1[:, 3] + bboxes2[:, 2] = bboxes2[:, 0] + bboxes2[:, 2] + bboxes2[:, 3] = bboxes2[:, 1] + bboxes2[:, 3] + elif box_format not in "x0y0x1y1": + raise (TrackEvalException("box_format %s is not implemented" % box_format)) + + # layout: (x0, y0, x1, y1) + min_ = np.minimum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :]) + max_ = np.maximum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :]) + intersection = np.maximum(min_[..., 2] - max_[..., 0], 0) * np.maximum( + min_[..., 3] - max_[..., 1], 0 + ) + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * ( + bboxes1[..., 3] - bboxes1[..., 1] + ) + + if do_ioa: + ioas = np.zeros_like(intersection) + valid_mask = area1 > 0 + np.finfo("float").eps + ioas[valid_mask, :] = ( + intersection[valid_mask, :] / area1[valid_mask][:, np.newaxis] + ) + + return ioas + else: + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * ( + bboxes2[..., 3] - bboxes2[..., 1] + ) + union = area1[:, np.newaxis] + area2[np.newaxis, :] - intersection + intersection[area1 <= 0 + np.finfo("float").eps, :] = 0 + intersection[:, area2 <= 0 + np.finfo("float").eps] = 0 + intersection[union <= 0 + np.finfo("float").eps] = 0 + union[union <= 0 + np.finfo("float").eps] = 1 + ious = intersection / union + return ious + + @staticmethod + def _calculate_euclidean_similarity(dets1, dets2, zero_distance=2.0): + """Calculates the euclidean distance between two sets of detections, and then converts this into a similarity + measure with values between 0 and 1 using the following formula: sim = max(0, 1 - dist/zero_distance). + The default zero_distance of 2.0, corresponds to the default used in MOT15_3D, such that a 0.5 similarity + threshold corresponds to a 1m distance threshold for TPs. + """ + dist = np.linalg.norm(dets1[:, np.newaxis] - dets2[np.newaxis, :], axis=2) + sim = np.maximum(0, 1 - dist / zero_distance) + return sim + + @staticmethod + def _check_unique_ids(data, after_preproc=False): + """Check the requirement that the tracker_ids and gt_ids are unique per timestep""" + gt_ids = data["gt_ids"] + tracker_ids = data["tracker_ids"] + for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(gt_ids, tracker_ids)): + if len(tracker_ids_t) > 0: + unique_ids, counts = np.unique(tracker_ids_t, return_counts=True) + if np.max(counts) != 1: + duplicate_ids = unique_ids[counts > 1] + exc_str_init = ( + "Tracker predicts the same ID more than once in a single timestep " + "(seq: %s, frame: %i, ids:" % (data["seq"], t + 1) + ) + exc_str = ( + " ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")" + ) + if after_preproc: + exc_str_init += ( + "\n Note that this error occurred after preprocessing (but not before), " + "so ids may not be as in file, and something seems wrong with preproc." + ) + raise TrackEvalException(exc_str) + if len(gt_ids_t) > 0: + unique_ids, counts = np.unique(gt_ids_t, return_counts=True) + if np.max(counts) != 1: + duplicate_ids = unique_ids[counts > 1] + exc_str_init = ( + "Ground-truth has the same ID more than once in a single timestep " + "(seq: %s, frame: %i, ids:" % (data["seq"], t + 1) + ) + exc_str = ( + " ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")" + ) + if after_preproc: + exc_str_init += ( + "\n Note that this error occurred after preprocessing (but not before), " + "so ids may not be as in file, and something seems wrong with preproc." + ) + raise TrackEvalException(exc_str) diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/tao_ow.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/tao_ow.py new file mode 100644 index 0000000000000000000000000000000000000000..545a05154fc85e3448ecc2b46b48c050e524961f --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/tao_ow.py @@ -0,0 +1,893 @@ +# flake8: noqa + +# pyre-unsafe + +import itertools +import json +import os +from collections import defaultdict + +import numpy as np +from scipy.optimize import linear_sum_assignment + +from .. import _timing, utils +from ..utils import TrackEvalException +from ._base_dataset import _BaseDataset + + +class TAO_OW(_BaseDataset): + """Dataset class for TAO tracking""" + + @staticmethod + def get_default_dataset_config(): + """Default class config values""" + code_path = utils.get_code_path() + default_config = { + "GT_FOLDER": os.path.join( + code_path, "data/gt/tao/tao_training" + ), # Location of GT data + "TRACKERS_FOLDER": os.path.join( + code_path, "data/trackers/tao/tao_training" + ), # Trackers location + "OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER) + "TRACKERS_TO_EVAL": None, # Filenames of trackers to eval (if None, all in folder) + "CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes) + "SPLIT_TO_EVAL": "training", # Valid: 'training', 'val' + "PRINT_CONFIG": True, # Whether to print current config + "TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER + "OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER + "TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL + "MAX_DETECTIONS": 300, # Number of maximal allowed detections per image (0 for unlimited) + "SUBSET": "all", + } + return default_config + + def __init__(self, config=None): + """Initialise dataset, checking that all required files are present""" + super().__init__() + # Fill non-given config values with defaults + self.config = utils.init_config( + config, self.get_default_dataset_config(), self.get_name() + ) + self.gt_fol = self.config["GT_FOLDER"] + self.tracker_fol = self.config["TRACKERS_FOLDER"] + self.should_classes_combine = True + self.use_super_categories = False + + self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"] + self.output_fol = self.config["OUTPUT_FOLDER"] + if self.output_fol is None: + self.output_fol = self.tracker_fol + self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"] + + gt_dir_files = [ + file for file in os.listdir(self.gt_fol) if file.endswith(".json") + ] + if len(gt_dir_files) != 1: + raise TrackEvalException( + self.gt_fol + " does not contain exactly one json file." + ) + + with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f: + self.gt_data = json.load(f) + + self.subset = self.config["SUBSET"] + if self.subset != "all": + # Split GT data into `known`, `unknown` or `distractor` + self._split_known_unknown_distractor() + self.gt_data = self._filter_gt_data(self.gt_data) + + # merge categories marked with a merged tag in TAO dataset + self._merge_categories(self.gt_data["annotations"] + self.gt_data["tracks"]) + + # Get sequences to eval and sequence information + self.seq_list = [ + vid["name"].replace("/", "-") for vid in self.gt_data["videos"] + ] + self.seq_name_to_seq_id = { + vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"] + } + # compute mappings from videos to annotation data + self.videos_to_gt_tracks, self.videos_to_gt_images = self._compute_vid_mappings( + self.gt_data["annotations"] + ) + # compute sequence lengths + self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]} + for img in self.gt_data["images"]: + self.seq_lengths[img["video_id"]] += 1 + self.seq_to_images_to_timestep = self._compute_image_to_timestep_mappings() + self.seq_to_classes = { + vid["id"]: { + "pos_cat_ids": list( + { + track["category_id"] + for track in self.videos_to_gt_tracks[vid["id"]] + } + ), + "neg_cat_ids": vid["neg_category_ids"], + "not_exhaustively_labeled_cat_ids": vid["not_exhaustive_category_ids"], + } + for vid in self.gt_data["videos"] + } + + # Get classes to eval + considered_vid_ids = [self.seq_name_to_seq_id[vid] for vid in self.seq_list] + seen_cats = set( + [ + cat_id + for vid_id in considered_vid_ids + for cat_id in self.seq_to_classes[vid_id]["pos_cat_ids"] + ] + ) + # only classes with ground truth are evaluated in TAO + self.valid_classes = [ + cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats + ] + # cls_name_to_cls_id_map = {cls['name']: cls['id'] for cls in self.gt_data['categories']} + + if self.config["CLASSES_TO_EVAL"]: + # self.class_list = [cls.lower() if cls.lower() in self.valid_classes else None + # for cls in self.config['CLASSES_TO_EVAL']] + self.class_list = ["object"] # class-agnostic + if not all(self.class_list): + raise TrackEvalException( + "Attempted to evaluate an invalid class. Only classes " + + ", ".join(self.valid_classes) + + " are valid (classes present in ground truth data)." + ) + else: + # self.class_list = [cls for cls in self.valid_classes] + self.class_list = ["object"] # class-agnostic + # self.class_name_to_class_id = {k: v for k, v in cls_name_to_cls_id_map.items() if k in self.class_list} + self.class_name_to_class_id = {"object": 1} # class-agnostic + + # Get trackers to eval + if self.config["TRACKERS_TO_EVAL"] is None: + self.tracker_list = os.listdir(self.tracker_fol) + else: + self.tracker_list = self.config["TRACKERS_TO_EVAL"] + + if self.config["TRACKER_DISPLAY_NAMES"] is None: + self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list)) + elif (self.config["TRACKERS_TO_EVAL"] is not None) and ( + len(self.config["TRACKER_DISPLAY_NAMES"]) == len(self.tracker_list) + ): + self.tracker_to_disp = dict( + zip(self.tracker_list, self.config["TRACKER_DISPLAY_NAMES"]) + ) + else: + raise TrackEvalException( + "List of tracker files and tracker display names do not match." + ) + + self.tracker_data = {tracker: dict() for tracker in self.tracker_list} + + for tracker in self.tracker_list: + tr_dir_files = [ + file + for file in os.listdir( + os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol) + ) + if file.endswith(".json") + ] + if len(tr_dir_files) != 1: + raise TrackEvalException( + os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol) + + " does not contain exactly one json file." + ) + with open( + os.path.join( + self.tracker_fol, tracker, self.tracker_sub_fol, tr_dir_files[0] + ) + ) as f: + curr_data = json.load(f) + + # limit detections if MAX_DETECTIONS > 0 + if self.config["MAX_DETECTIONS"]: + curr_data = self._limit_dets_per_image(curr_data) + + # fill missing video ids + self._fill_video_ids_inplace(curr_data) + + # make track ids unique over whole evaluation set + self._make_track_ids_unique(curr_data) + + # merge categories marked with a merged tag in TAO dataset + self._merge_categories(curr_data) + + # get tracker sequence information + curr_videos_to_tracker_tracks, curr_videos_to_tracker_images = ( + self._compute_vid_mappings(curr_data) + ) + self.tracker_data[tracker]["vids_to_tracks"] = curr_videos_to_tracker_tracks + self.tracker_data[tracker]["vids_to_images"] = curr_videos_to_tracker_images + + def get_display_name(self, tracker): + return self.tracker_to_disp[tracker] + + def _load_raw_file(self, tracker, seq, is_gt): + """Load a file (gt or tracker) in the TAO format + + If is_gt, this returns a dict which contains the fields: + [gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det). + [gt_dets]: list (for each timestep) of lists of detections. + [classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as + keys and corresponding segmentations as values) for each track + [classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_lengths]: dictionary with class values + as keys and lists (for each track) as values + + if not is_gt, this returns a dict which contains the fields: + [tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det). + [tracker_dets]: list (for each timestep) of lists of detections. + [classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as + keys and corresponding segmentations as values) for each track + [classes_to_dt_track_ids, classes_to_dt_track_areas, classes_to_dt_track_lengths]: dictionary with class values + as keys and lists as values + [classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values + """ + seq_id = self.seq_name_to_seq_id[seq] + # File location + if is_gt: + imgs = self.videos_to_gt_images[seq_id] + else: + imgs = self.tracker_data[tracker]["vids_to_images"][seq_id] + + # Convert data to required format + num_timesteps = self.seq_lengths[seq_id] + img_to_timestep = self.seq_to_images_to_timestep[seq_id] + data_keys = ["ids", "classes", "dets"] + if not is_gt: + data_keys += ["tracker_confidences"] + raw_data = {key: [None] * num_timesteps for key in data_keys} + for img in imgs: + # some tracker data contains images without any ground truth information, these are ignored + try: + t = img_to_timestep[img["id"]] + except KeyError: + continue + annotations = img["annotations"] + raw_data["dets"][t] = np.atleast_2d( + [ann["bbox"] for ann in annotations] + ).astype(float) + raw_data["ids"][t] = np.atleast_1d( + [ann["track_id"] for ann in annotations] + ).astype(int) + raw_data["classes"][t] = np.atleast_1d([1 for _ in annotations]).astype( + int + ) # class-agnostic + if not is_gt: + raw_data["tracker_confidences"][t] = np.atleast_1d( + [ann["score"] for ann in annotations] + ).astype(float) + + for t, d in enumerate(raw_data["dets"]): + if d is None: + raw_data["dets"][t] = np.empty((0, 4)).astype(float) + raw_data["ids"][t] = np.empty(0).astype(int) + raw_data["classes"][t] = np.empty(0).astype(int) + if not is_gt: + raw_data["tracker_confidences"][t] = np.empty(0) + + if is_gt: + key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"} + else: + key_map = { + "ids": "tracker_ids", + "classes": "tracker_classes", + "dets": "tracker_dets", + } + for k, v in key_map.items(): + raw_data[v] = raw_data.pop(k) + + # all_classes = [self.class_name_to_class_id[cls] for cls in self.class_list] + all_classes = [1] # class-agnostic + + if is_gt: + classes_to_consider = all_classes + all_tracks = self.videos_to_gt_tracks[seq_id] + else: + # classes_to_consider = self.seq_to_classes[seq_id]['pos_cat_ids'] \ + # + self.seq_to_classes[seq_id]['neg_cat_ids'] + classes_to_consider = all_classes # class-agnostic + all_tracks = self.tracker_data[tracker]["vids_to_tracks"][seq_id] + + # classes_to_tracks = {cls: [track for track in all_tracks if track['category_id'] == cls] + # if cls in classes_to_consider else [] for cls in all_classes} + classes_to_tracks = { + cls: [track for track in all_tracks] if cls in classes_to_consider else [] + for cls in all_classes + } # class-agnostic + + # mapping from classes to track information + raw_data["classes_to_tracks"] = { + cls: [ + { + det["image_id"]: np.atleast_1d(det["bbox"]) + for det in track["annotations"] + } + for track in tracks + ] + for cls, tracks in classes_to_tracks.items() + } + raw_data["classes_to_track_ids"] = { + cls: [track["id"] for track in tracks] + for cls, tracks in classes_to_tracks.items() + } + raw_data["classes_to_track_areas"] = { + cls: [track["area"] for track in tracks] + for cls, tracks in classes_to_tracks.items() + } + raw_data["classes_to_track_lengths"] = { + cls: [len(track["annotations"]) for track in tracks] + for cls, tracks in classes_to_tracks.items() + } + + if not is_gt: + raw_data["classes_to_dt_track_scores"] = { + cls: np.array( + [ + np.mean([float(x["score"]) for x in track["annotations"]]) + for track in tracks + ] + ) + for cls, tracks in classes_to_tracks.items() + } + + if is_gt: + key_map = { + "classes_to_tracks": "classes_to_gt_tracks", + "classes_to_track_ids": "classes_to_gt_track_ids", + "classes_to_track_lengths": "classes_to_gt_track_lengths", + "classes_to_track_areas": "classes_to_gt_track_areas", + } + else: + key_map = { + "classes_to_tracks": "classes_to_dt_tracks", + "classes_to_track_ids": "classes_to_dt_track_ids", + "classes_to_track_lengths": "classes_to_dt_track_lengths", + "classes_to_track_areas": "classes_to_dt_track_areas", + } + for k, v in key_map.items(): + raw_data[v] = raw_data.pop(k) + + raw_data["num_timesteps"] = num_timesteps + raw_data["neg_cat_ids"] = self.seq_to_classes[seq_id]["neg_cat_ids"] + raw_data["not_exhaustively_labeled_cls"] = self.seq_to_classes[seq_id][ + "not_exhaustively_labeled_cat_ids" + ] + raw_data["seq"] = seq + return raw_data + + @_timing.time + def get_preprocessed_seq_data(self, raw_data, cls): + """Preprocess data for a single sequence for a single class ready for evaluation. + Inputs: + - raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data(). + - cls is the class to be evaluated. + Outputs: + - data is a dict containing all of the information that metrics need to perform evaluation. + It contains the following fields: + [num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers. + [gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det). + [gt_dets, tracker_dets]: list (for each timestep) of lists of detections. + [similarity_scores]: list (for each timestep) of 2D NDArrays. + Notes: + General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps. + 1) Extract only detections relevant for the class to be evaluated (including distractor detections). + 2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a + distractor class, or otherwise marked as to be removed. + 3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet a certain + other criteria (e.g. are too small). + 4) Remove gt dets that were only useful for preprocessing and not for actual evaluation. + After the above preprocessing steps, this function also calculates the number of gt and tracker detections + and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are + unique within each timestep. + TAO: + In TAO, the 4 preproc steps are as follow: + 1) All classes present in the ground truth data are evaluated separately. + 2) No matched tracker detections are removed. + 3) Unmatched tracker detections are removed if there is not ground truth data and the class does not + belong to the categories marked as negative for this sequence. Additionally, unmatched tracker + detections for classes which are marked as not exhaustively labeled are removed. + 4) No gt detections are removed. + Further, for TrackMAP computation track representations for the given class are accessed from a dictionary + and the tracks from the tracker data are sorted according to the tracker confidence. + """ + cls_id = self.class_name_to_class_id[cls] + is_not_exhaustively_labeled = cls_id in raw_data["not_exhaustively_labeled_cls"] + is_neg_category = cls_id in raw_data["neg_cat_ids"] + + data_keys = [ + "gt_ids", + "tracker_ids", + "gt_dets", + "tracker_dets", + "tracker_confidences", + "similarity_scores", + ] + data = {key: [None] * raw_data["num_timesteps"] for key in data_keys} + unique_gt_ids = [] + unique_tracker_ids = [] + num_gt_dets = 0 + num_tracker_dets = 0 + for t in range(raw_data["num_timesteps"]): + # Only extract relevant dets for this class for preproc and eval (cls) + gt_class_mask = np.atleast_1d(raw_data["gt_classes"][t] == cls_id) + gt_class_mask = gt_class_mask.astype(bool) + gt_ids = raw_data["gt_ids"][t][gt_class_mask] + gt_dets = raw_data["gt_dets"][t][gt_class_mask] + + tracker_class_mask = np.atleast_1d(raw_data["tracker_classes"][t] == cls_id) + tracker_class_mask = tracker_class_mask.astype(bool) + tracker_ids = raw_data["tracker_ids"][t][tracker_class_mask] + tracker_dets = raw_data["tracker_dets"][t][tracker_class_mask] + tracker_confidences = raw_data["tracker_confidences"][t][tracker_class_mask] + similarity_scores = raw_data["similarity_scores"][t][gt_class_mask, :][ + :, tracker_class_mask + ] + + # Match tracker and gt dets (with hungarian algorithm). + unmatched_indices = np.arange(tracker_ids.shape[0]) + if gt_ids.shape[0] > 0 and tracker_ids.shape[0] > 0: + matching_scores = similarity_scores.copy() + matching_scores[matching_scores < 0.5 - np.finfo("float").eps] = 0 + match_rows, match_cols = linear_sum_assignment(-matching_scores) + actually_matched_mask = ( + matching_scores[match_rows, match_cols] > 0 + np.finfo("float").eps + ) + match_cols = match_cols[actually_matched_mask] + unmatched_indices = np.delete(unmatched_indices, match_cols, axis=0) + + if gt_ids.shape[0] == 0 and not is_neg_category: + to_remove_tracker = unmatched_indices + elif is_not_exhaustively_labeled: + to_remove_tracker = unmatched_indices + else: + to_remove_tracker = np.array([], dtype=int) + + # remove all unwanted unmatched tracker detections + data["tracker_ids"][t] = np.delete(tracker_ids, to_remove_tracker, axis=0) + data["tracker_dets"][t] = np.delete(tracker_dets, to_remove_tracker, axis=0) + data["tracker_confidences"][t] = np.delete( + tracker_confidences, to_remove_tracker, axis=0 + ) + similarity_scores = np.delete(similarity_scores, to_remove_tracker, axis=1) + + data["gt_ids"][t] = gt_ids + data["gt_dets"][t] = gt_dets + data["similarity_scores"][t] = similarity_scores + + unique_gt_ids += list(np.unique(data["gt_ids"][t])) + unique_tracker_ids += list(np.unique(data["tracker_ids"][t])) + num_tracker_dets += len(data["tracker_ids"][t]) + num_gt_dets += len(data["gt_ids"][t]) + + # Re-label IDs such that there are no empty IDs + if len(unique_gt_ids) > 0: + unique_gt_ids = np.unique(unique_gt_ids) + gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1)) + gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids)) + for t in range(raw_data["num_timesteps"]): + if len(data["gt_ids"][t]) > 0: + data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int) + if len(unique_tracker_ids) > 0: + unique_tracker_ids = np.unique(unique_tracker_ids) + tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1)) + tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids)) + for t in range(raw_data["num_timesteps"]): + if len(data["tracker_ids"][t]) > 0: + data["tracker_ids"][t] = tracker_id_map[ + data["tracker_ids"][t] + ].astype(int) + + # Record overview statistics. + data["num_tracker_dets"] = num_tracker_dets + data["num_gt_dets"] = num_gt_dets + data["num_tracker_ids"] = len(unique_tracker_ids) + data["num_gt_ids"] = len(unique_gt_ids) + data["num_timesteps"] = raw_data["num_timesteps"] + data["seq"] = raw_data["seq"] + + # get track representations + data["gt_tracks"] = raw_data["classes_to_gt_tracks"][cls_id] + data["gt_track_ids"] = raw_data["classes_to_gt_track_ids"][cls_id] + data["gt_track_lengths"] = raw_data["classes_to_gt_track_lengths"][cls_id] + data["gt_track_areas"] = raw_data["classes_to_gt_track_areas"][cls_id] + data["dt_tracks"] = raw_data["classes_to_dt_tracks"][cls_id] + data["dt_track_ids"] = raw_data["classes_to_dt_track_ids"][cls_id] + data["dt_track_lengths"] = raw_data["classes_to_dt_track_lengths"][cls_id] + data["dt_track_areas"] = raw_data["classes_to_dt_track_areas"][cls_id] + data["dt_track_scores"] = raw_data["classes_to_dt_track_scores"][cls_id] + data["not_exhaustively_labeled"] = is_not_exhaustively_labeled + data["iou_type"] = "bbox" + + # sort tracker data tracks by tracker confidence scores + if data["dt_tracks"]: + idx = np.argsort( + [-score for score in data["dt_track_scores"]], kind="mergesort" + ) + data["dt_track_scores"] = [data["dt_track_scores"][i] for i in idx] + data["dt_tracks"] = [data["dt_tracks"][i] for i in idx] + data["dt_track_ids"] = [data["dt_track_ids"][i] for i in idx] + data["dt_track_lengths"] = [data["dt_track_lengths"][i] for i in idx] + data["dt_track_areas"] = [data["dt_track_areas"][i] for i in idx] + # Ensure that ids are unique per timestep. + self._check_unique_ids(data) + + return data + + def _calculate_similarities(self, gt_dets_t, tracker_dets_t): + similarity_scores = self._calculate_box_ious(gt_dets_t, tracker_dets_t) + return similarity_scores + + def _merge_categories(self, annotations): + """ + Merges categories with a merged tag. Adapted from https://github.com/TAO-Dataset + :param annotations: the annotations in which the classes should be merged + :return: None + """ + merge_map = {} + for category in self.gt_data["categories"]: + if "merged" in category: + for to_merge in category["merged"]: + merge_map[to_merge["id"]] = category["id"] + + for ann in annotations: + ann["category_id"] = merge_map.get(ann["category_id"], ann["category_id"]) + + def _compute_vid_mappings(self, annotations): + """ + Computes mappings from Videos to corresponding tracks and images. + :param annotations: the annotations for which the mapping should be generated + :return: the video-to-track-mapping, the video-to-image-mapping + """ + vids_to_tracks = {} + vids_to_imgs = {} + vid_ids = [vid["id"] for vid in self.gt_data["videos"]] + + # compute an mapping from image IDs to images + images = {} + for image in self.gt_data["images"]: + images[image["id"]] = image + + for ann in annotations: + ann["area"] = ann["bbox"][2] * ann["bbox"][3] + + vid = ann["video_id"] + if ann["video_id"] not in vids_to_tracks.keys(): + vids_to_tracks[ann["video_id"]] = list() + if ann["video_id"] not in vids_to_imgs.keys(): + vids_to_imgs[ann["video_id"]] = list() + + # Fill in vids_to_tracks + tid = ann["track_id"] + exist_tids = [track["id"] for track in vids_to_tracks[vid]] + try: + index1 = exist_tids.index(tid) + except ValueError: + index1 = -1 + if tid not in exist_tids: + curr_track = { + "id": tid, + "category_id": ann["category_id"], + "video_id": vid, + "annotations": [ann], + } + vids_to_tracks[vid].append(curr_track) + else: + vids_to_tracks[vid][index1]["annotations"].append(ann) + + # Fill in vids_to_imgs + img_id = ann["image_id"] + exist_img_ids = [img["id"] for img in vids_to_imgs[vid]] + try: + index2 = exist_img_ids.index(img_id) + except ValueError: + index2 = -1 + if index2 == -1: + curr_img = {"id": img_id, "annotations": [ann]} + vids_to_imgs[vid].append(curr_img) + else: + vids_to_imgs[vid][index2]["annotations"].append(ann) + + # sort annotations by frame index and compute track area + for vid, tracks in vids_to_tracks.items(): + for track in tracks: + track["annotations"] = sorted( + track["annotations"], + key=lambda x: images[x["image_id"]]["frame_index"], + ) + # Computer average area + track["area"] = sum(x["area"] for x in track["annotations"]) / len( + track["annotations"] + ) + + # Ensure all videos are present + for vid_id in vid_ids: + if vid_id not in vids_to_tracks.keys(): + vids_to_tracks[vid_id] = [] + if vid_id not in vids_to_imgs.keys(): + vids_to_imgs[vid_id] = [] + + return vids_to_tracks, vids_to_imgs + + def _compute_image_to_timestep_mappings(self): + """ + Computes a mapping from images to the corresponding timestep in the sequence. + :return: the image-to-timestep-mapping + """ + images = {} + for image in self.gt_data["images"]: + images[image["id"]] = image + + seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]} + for vid in seq_to_imgs_to_timestep: + curr_imgs = [img["id"] for img in self.videos_to_gt_images[vid]] + curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_index"]) + seq_to_imgs_to_timestep[vid] = { + curr_imgs[i]: i for i in range(len(curr_imgs)) + } + + return seq_to_imgs_to_timestep + + def _limit_dets_per_image(self, annotations): + """ + Limits the number of detections for each image to config['MAX_DETECTIONS']. Adapted from + https://github.com/TAO-Dataset/ + :param annotations: the annotations in which the detections should be limited + :return: the annotations with limited detections + """ + max_dets = self.config["MAX_DETECTIONS"] + img_ann = defaultdict(list) + for ann in annotations: + img_ann[ann["image_id"]].append(ann) + + for img_id, _anns in img_ann.items(): + if len(_anns) <= max_dets: + continue + _anns = sorted(_anns, key=lambda x: x["score"], reverse=True) + img_ann[img_id] = _anns[:max_dets] + + return [ann for anns in img_ann.values() for ann in anns] + + def _fill_video_ids_inplace(self, annotations): + """ + Fills in missing video IDs inplace. Adapted from https://github.com/TAO-Dataset/ + :param annotations: the annotations for which the videos IDs should be filled inplace + :return: None + """ + missing_video_id = [x for x in annotations if "video_id" not in x] + if missing_video_id: + image_id_to_video_id = { + x["id"]: x["video_id"] for x in self.gt_data["images"] + } + for x in missing_video_id: + x["video_id"] = image_id_to_video_id[x["image_id"]] + + @staticmethod + def _make_track_ids_unique(annotations): + """ + Makes the track IDs unqiue over the whole annotation set. Adapted from https://github.com/TAO-Dataset/ + :param annotations: the annotation set + :return: the number of updated IDs + """ + track_id_videos = {} + track_ids_to_update = set() + max_track_id = 0 + for ann in annotations: + t = ann["track_id"] + if t not in track_id_videos: + track_id_videos[t] = ann["video_id"] + + if ann["video_id"] != track_id_videos[t]: + # Track id is assigned to multiple videos + track_ids_to_update.add(t) + max_track_id = max(max_track_id, t) + + if track_ids_to_update: + print("true") + next_id = itertools.count(max_track_id + 1) + new_track_ids = defaultdict(lambda: next(next_id)) + for ann in annotations: + t = ann["track_id"] + v = ann["video_id"] + if t in track_ids_to_update: + ann["track_id"] = new_track_ids[t, v] + return len(track_ids_to_update) + + def _split_known_unknown_distractor(self): + all_ids = set( + [i for i in range(1, 2000)] + ) # 2000 is larger than the max category id in TAO-OW. + # `knowns` includes 78 TAO_category_ids that corresponds to 78 COCO classes. + # (The other 2 COCO classes do not have corresponding classes in TAO). + self.knowns = { + 4, + 13, + 1038, + 544, + 1057, + 34, + 35, + 36, + 41, + 45, + 58, + 60, + 579, + 1091, + 1097, + 1099, + 78, + 79, + 81, + 91, + 1115, + 1117, + 95, + 1122, + 99, + 1132, + 621, + 1135, + 625, + 118, + 1144, + 126, + 642, + 1155, + 133, + 1162, + 139, + 154, + 174, + 185, + 699, + 1215, + 714, + 717, + 1229, + 211, + 729, + 221, + 229, + 747, + 235, + 237, + 779, + 276, + 805, + 299, + 829, + 852, + 347, + 371, + 382, + 896, + 392, + 926, + 937, + 428, + 429, + 961, + 452, + 979, + 980, + 982, + 475, + 480, + 993, + 1001, + 502, + 1018, + } + # `distractors` is defined as in the paper "Opening up Open-World Tracking" + self.distractors = { + 20, + 63, + 108, + 180, + 188, + 204, + 212, + 247, + 303, + 403, + 407, + 415, + 490, + 504, + 507, + 513, + 529, + 567, + 569, + 588, + 672, + 691, + 702, + 708, + 711, + 720, + 736, + 737, + 798, + 813, + 815, + 827, + 831, + 851, + 877, + 883, + 912, + 971, + 976, + 1130, + 1133, + 1134, + 1169, + 1184, + 1220, + } + self.unknowns = all_ids.difference(self.knowns.union(self.distractors)) + + def _filter_gt_data(self, raw_gt_data): + """ + Filter out irrelevant data in the raw_gt_data + Args: + raw_gt_data: directly loaded from json. + + Returns: + filtered gt_data + """ + valid_cat_ids = list() + if self.subset == "known": + valid_cat_ids = self.knowns + elif self.subset == "distractor": + valid_cat_ids = self.distractors + elif self.subset == "unknown": + valid_cat_ids = self.unknowns + # elif self.subset == "test_only_unknowns": + # valid_cat_ids = test_only_unknowns + else: + raise Exception("The parameter `SUBSET` is incorrect") + + filtered = dict() + filtered["videos"] = raw_gt_data["videos"] + # filtered["videos"] = list() + unwanted_vid = set() + # for video in raw_gt_data["videos"]: + # datasrc = video["name"].split('/')[1] + # if datasrc in data_srcs: + # filtered["videos"].append(video) + # else: + # unwanted_vid.add(video["id"]) + + filtered["annotations"] = list() + for ann in raw_gt_data["annotations"]: + if (ann["video_id"] not in unwanted_vid) and ( + ann["category_id"] in valid_cat_ids + ): + filtered["annotations"].append(ann) + + filtered["tracks"] = list() + for track in raw_gt_data["tracks"]: + if (track["video_id"] not in unwanted_vid) and ( + track["category_id"] in valid_cat_ids + ): + filtered["tracks"].append(track) + + filtered["images"] = list() + for image in raw_gt_data["images"]: + if image["video_id"] not in unwanted_vid: + filtered["images"].append(image) + + filtered["categories"] = list() + for cat in raw_gt_data["categories"]: + if cat["id"] in valid_cat_ids: + filtered["categories"].append(cat) + + filtered["info"] = raw_gt_data["info"] + filtered["licenses"] = raw_gt_data["licenses"] + + return filtered diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/youtube_vis.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/youtube_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..7ada4a291f471efc88083e20b60b025f1618186e --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/datasets/youtube_vis.py @@ -0,0 +1,526 @@ +# flake8: noqa + +# pyre-unsafe + +# note: this file has been modified from its original version in TrackEval in +# https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/datasets/youtube_vis.py +# to support the following: +# 1) bbox evaluation (via `IOU_TYPE`) +# 2) passing GT and prediction data as Python objects (via `GT_JSON_OBJECT` and `TRACKER_JSON_OBJECT`) +# 3) specifying a custom dataset name (via `DATASET_NAME`) + +import json +import os + +import numpy as np + +from .. import _timing, utils +from ..utils import TrackEvalException +from ._base_dataset import _BaseDataset + + +class YouTubeVIS(_BaseDataset): + """Dataset class for YouTubeVIS tracking""" + + @staticmethod + def get_default_dataset_config(): + """Default class config values""" + code_path = utils.get_code_path() + default_config = { + "GT_FOLDER": os.path.join( + code_path, "data/gt/youtube_vis/" + ), # Location of GT data + "TRACKERS_FOLDER": os.path.join(code_path, "data/trackers/youtube_vis/"), + # Trackers location + "OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER) + "TRACKERS_TO_EVAL": None, # Filenames of trackers to eval (if None, all in folder) + "CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes) + "SPLIT_TO_EVAL": "train_sub_split", # Valid: 'train', 'val', 'train_sub_split' + "PRINT_CONFIG": True, # Whether to print current config + "OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER + "TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER + "TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL + # Added for video phrase AP evaluation -- allow directly specifying the GT JSON data and Tracker (result) + # JSON data as Python objects, without reading from files. + "GT_JSON_OBJECT": None, + "TRACKER_JSON_OBJECT": None, + "IOU_TYPE": "segm", + "DATASET_NAME": "video", + } + return default_config + + def __init__(self, config=None): + """Initialise dataset, checking that all required files are present""" + super().__init__() + # Fill non-given config values with defaults + self.config = utils.init_config(config, self.get_default_dataset_config()) + self.gt_fol = ( + self.config["GT_FOLDER"] + "youtube_vis_" + self.config["SPLIT_TO_EVAL"] + ) + self.tracker_fol = ( + self.config["TRACKERS_FOLDER"] + + "youtube_vis_" + + self.config["SPLIT_TO_EVAL"] + ) + self.use_super_categories = False + self.should_classes_combine = True + assert self.config["IOU_TYPE"] in ["segm", "bbox"] + self.iou_type = self.config["IOU_TYPE"] + print("=" * 100) + print(f"Evaluate annotation type *{self.iou_type}*") + self.dataset_name = self.config["DATASET_NAME"] + + self.output_fol = self.config["OUTPUT_FOLDER"] + if self.output_fol is None: + self.output_fol = self.tracker_fol + self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"] + self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"] + + if self.config["GT_JSON_OBJECT"] is not None: + # allow directly specifying the GT JSON data without reading from files + gt_json = self.config["GT_JSON_OBJECT"] + assert isinstance(gt_json, dict) + assert "videos" in gt_json + assert "categories" in gt_json + assert "annotations" in gt_json + self.gt_data = gt_json + else: + if not os.path.exists(self.gt_fol): + print("GT folder not found: " + self.gt_fol) + raise TrackEvalException( + "GT folder not found: " + os.path.basename(self.gt_fol) + ) + gt_dir_files = [ + file for file in os.listdir(self.gt_fol) if file.endswith(".json") + ] + if len(gt_dir_files) != 1: + raise TrackEvalException( + self.gt_fol + " does not contain exactly one json file." + ) + + with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f: + self.gt_data = json.load(f) + + # Get classes to eval + self.valid_classes = [cls["name"] for cls in self.gt_data["categories"]] + cls_name_to_cls_id_map = { + cls["name"]: cls["id"] for cls in self.gt_data["categories"] + } + + if self.config["CLASSES_TO_EVAL"]: + self.class_list = [ + cls.lower() if cls.lower() in self.valid_classes else None + for cls in self.config["CLASSES_TO_EVAL"] + ] + if not all(self.class_list): + raise TrackEvalException( + "Attempted to evaluate an invalid class. Only classes " + + ", ".join(self.valid_classes) + + " are valid." + ) + else: + self.class_list = [cls["name"] for cls in self.gt_data["categories"]] + self.class_name_to_class_id = { + k: v for k, v in cls_name_to_cls_id_map.items() if k in self.class_list + } + + # Get sequences to eval and check gt files exist + self.seq_list = [ + vid["file_names"][0].split("/")[0] for vid in self.gt_data["videos"] + ] + self.seq_name_to_seq_id = { + vid["file_names"][0].split("/")[0]: vid["id"] + for vid in self.gt_data["videos"] + } + self.seq_lengths = { + vid["id"]: len(vid["file_names"]) for vid in self.gt_data["videos"] + } + + # encode masks and compute track areas + self._prepare_gt_annotations() + + # Get trackers to eval + if self.config["TRACKER_JSON_OBJECT"] is not None: + # allow directly specifying the tracker JSON data without reading from files + tracker_json = self.config["TRACKER_JSON_OBJECT"] + assert isinstance(tracker_json, list) + self.tracker_list = ["tracker"] + elif self.config["TRACKERS_TO_EVAL"] is None: + self.tracker_list = os.listdir(self.tracker_fol) + else: + self.tracker_list = self.config["TRACKERS_TO_EVAL"] + + if self.config["TRACKER_DISPLAY_NAMES"] is None: + self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list)) + elif (self.config["TRACKERS_TO_EVAL"] is not None) and ( + len(self.config["TRACKER_DISPLAY_NAMES"]) == len(self.tracker_list) + ): + self.tracker_to_disp = dict( + zip(self.tracker_list, self.config["TRACKER_DISPLAY_NAMES"]) + ) + else: + raise TrackEvalException( + "List of tracker files and tracker display names do not match." + ) + + # counter for globally unique track IDs + self.global_tid_counter = 0 + + self.tracker_data = dict() + if self.config["TRACKER_JSON_OBJECT"] is not None: + # allow directly specifying the tracker JSON data without reading from files + tracker = self.tracker_list[0] + self.tracker_data[tracker] = tracker_json + else: + for tracker in self.tracker_list: + tracker_dir_path = os.path.join( + self.tracker_fol, tracker, self.tracker_sub_fol + ) + tr_dir_files = [ + file + for file in os.listdir(tracker_dir_path) + if file.endswith(".json") + ] + if len(tr_dir_files) != 1: + raise TrackEvalException( + tracker_dir_path + " does not contain exactly one json file." + ) + + with open(os.path.join(tracker_dir_path, tr_dir_files[0])) as f: + curr_data = json.load(f) + + self.tracker_data[tracker] = curr_data + + def get_display_name(self, tracker): + return self.tracker_to_disp[tracker] + + def _load_raw_file(self, tracker, seq, is_gt): + """Load a file (gt or tracker) in the YouTubeVIS format + If is_gt, this returns a dict which contains the fields: + [gt_ids, gt_classes] : list (for each timestep) of 1D NDArrays (for each det). + [gt_dets]: list (for each timestep) of lists of detections. + [classes_to_gt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as + keys and corresponding segmentations as values) for each track + [classes_to_gt_track_ids, classes_to_gt_track_areas, classes_to_gt_track_iscrowd]: dictionary with class values + as keys and lists (for each track) as values + + if not is_gt, this returns a dict which contains the fields: + [tracker_ids, tracker_classes, tracker_confidences] : list (for each timestep) of 1D NDArrays (for each det). + [tracker_dets]: list (for each timestep) of lists of detections. + [classes_to_dt_tracks]: dictionary with class values as keys and list of dictionaries (with frame indices as + keys and corresponding segmentations as values) for each track + [classes_to_dt_track_ids, classes_to_dt_track_areas]: dictionary with class values as keys and lists as values + [classes_to_dt_track_scores]: dictionary with class values as keys and 1D numpy arrays as values + """ + # select sequence tracks + seq_id = self.seq_name_to_seq_id[seq] + if is_gt: + tracks = [ + ann for ann in self.gt_data["annotations"] if ann["video_id"] == seq_id + ] + else: + tracks = self._get_tracker_seq_tracks(tracker, seq_id) + + # Convert data to required format + num_timesteps = self.seq_lengths[seq_id] + data_keys = ["ids", "classes", "dets"] + if not is_gt: + data_keys += ["tracker_confidences"] + raw_data = {key: [None] * num_timesteps for key in data_keys} + result_key = "segmentations" if self.iou_type == "segm" else "bboxes" + for t in range(num_timesteps): + raw_data["dets"][t] = [ + track[result_key][t] for track in tracks if track[result_key][t] + ] + raw_data["ids"][t] = np.atleast_1d( + [track["id"] for track in tracks if track[result_key][t]] + ).astype(int) + raw_data["classes"][t] = np.atleast_1d( + [track["category_id"] for track in tracks if track[result_key][t]] + ).astype(int) + if not is_gt: + raw_data["tracker_confidences"][t] = np.atleast_1d( + [track["score"] for track in tracks if track[result_key][t]] + ).astype(float) + + if is_gt: + key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"} + else: + key_map = { + "ids": "tracker_ids", + "classes": "tracker_classes", + "dets": "tracker_dets", + } + for k, v in key_map.items(): + raw_data[v] = raw_data.pop(k) + + all_cls_ids = {self.class_name_to_class_id[cls] for cls in self.class_list} + classes_to_tracks = { + cls: [track for track in tracks if track["category_id"] == cls] + for cls in all_cls_ids + } + + # mapping from classes to track representations and track information + raw_data["classes_to_tracks"] = { + cls: [ + {i: track[result_key][i] for i in range(len(track[result_key]))} + for track in tracks + ] + for cls, tracks in classes_to_tracks.items() + } + raw_data["classes_to_track_ids"] = { + cls: [track["id"] for track in tracks] + for cls, tracks in classes_to_tracks.items() + } + raw_data["classes_to_track_areas"] = { + cls: [track["area"] for track in tracks] + for cls, tracks in classes_to_tracks.items() + } + + if is_gt: + raw_data["classes_to_gt_track_iscrowd"] = { + cls: [track["iscrowd"] for track in tracks] + for cls, tracks in classes_to_tracks.items() + } + else: + raw_data["classes_to_dt_track_scores"] = { + cls: np.array([track["score"] for track in tracks]) + for cls, tracks in classes_to_tracks.items() + } + + if is_gt: + key_map = { + "classes_to_tracks": "classes_to_gt_tracks", + "classes_to_track_ids": "classes_to_gt_track_ids", + "classes_to_track_areas": "classes_to_gt_track_areas", + } + else: + key_map = { + "classes_to_tracks": "classes_to_dt_tracks", + "classes_to_track_ids": "classes_to_dt_track_ids", + "classes_to_track_areas": "classes_to_dt_track_areas", + } + for k, v in key_map.items(): + raw_data[v] = raw_data.pop(k) + + raw_data["num_timesteps"] = num_timesteps + raw_data["seq"] = seq + return raw_data + + @_timing.time + def get_preprocessed_seq_data(self, raw_data, cls): + """Preprocess data for a single sequence for a single class ready for evaluation. + Inputs: + - raw_data is a dict containing the data for the sequence already read in by get_raw_seq_data(). + - cls is the class to be evaluated. + Outputs: + - data is a dict containing all of the information that metrics need to perform evaluation. + It contains the following fields: + [num_timesteps, num_gt_ids, num_tracker_ids, num_gt_dets, num_tracker_dets] : integers. + [gt_ids, tracker_ids, tracker_confidences]: list (for each timestep) of 1D NDArrays (for each det). + [gt_dets, tracker_dets]: list (for each timestep) of lists of detections. + [similarity_scores]: list (for each timestep) of 2D NDArrays. + Notes: + General preprocessing (preproc) occurs in 4 steps. Some datasets may not use all of these steps. + 1) Extract only detections relevant for the class to be evaluated (including distractor detections). + 2) Match gt dets and tracker dets. Remove tracker dets that are matched to a gt det that is of a + distractor class, or otherwise marked as to be removed. + 3) Remove unmatched tracker dets if they fall within a crowd ignore region or don't meet a certain + other criteria (e.g. are too small). + 4) Remove gt dets that were only useful for preprocessing and not for actual evaluation. + After the above preprocessing steps, this function also calculates the number of gt and tracker detections + and unique track ids. It also relabels gt and tracker ids to be contiguous and checks that ids are + unique within each timestep. + YouTubeVIS: + In YouTubeVIS, the 4 preproc steps are as follow: + 1) There are 40 classes which are evaluated separately. + 2) No matched tracker dets are removed. + 3) No unmatched tracker dets are removed. + 4) No gt dets are removed. + Further, for TrackMAP computation track representations for the given class are accessed from a dictionary + and the tracks from the tracker data are sorted according to the tracker confidence. + """ + cls_id = self.class_name_to_class_id[cls] + + data_keys = [ + "gt_ids", + "tracker_ids", + "gt_dets", + "tracker_dets", + "similarity_scores", + ] + data = {key: [None] * raw_data["num_timesteps"] for key in data_keys} + unique_gt_ids = [] + unique_tracker_ids = [] + num_gt_dets = 0 + num_tracker_dets = 0 + + for t in range(raw_data["num_timesteps"]): + # Only extract relevant dets for this class for eval (cls) + gt_class_mask = np.atleast_1d(raw_data["gt_classes"][t] == cls_id) + gt_class_mask = gt_class_mask.astype(bool) + gt_ids = raw_data["gt_ids"][t][gt_class_mask] + gt_dets = [ + raw_data["gt_dets"][t][ind] + for ind in range(len(gt_class_mask)) + if gt_class_mask[ind] + ] + + tracker_class_mask = np.atleast_1d(raw_data["tracker_classes"][t] == cls_id) + tracker_class_mask = tracker_class_mask.astype(bool) + tracker_ids = raw_data["tracker_ids"][t][tracker_class_mask] + tracker_dets = [ + raw_data["tracker_dets"][t][ind] + for ind in range(len(tracker_class_mask)) + if tracker_class_mask[ind] + ] + similarity_scores = raw_data["similarity_scores"][t][gt_class_mask, :][ + :, tracker_class_mask + ] + + data["tracker_ids"][t] = tracker_ids + data["tracker_dets"][t] = tracker_dets + data["gt_ids"][t] = gt_ids + data["gt_dets"][t] = gt_dets + data["similarity_scores"][t] = similarity_scores + + unique_gt_ids += list(np.unique(data["gt_ids"][t])) + unique_tracker_ids += list(np.unique(data["tracker_ids"][t])) + num_tracker_dets += len(data["tracker_ids"][t]) + num_gt_dets += len(data["gt_ids"][t]) + + # Re-label IDs such that there are no empty IDs + if len(unique_gt_ids) > 0: + unique_gt_ids = np.unique(unique_gt_ids) + gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1)) + gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids)) + for t in range(raw_data["num_timesteps"]): + if len(data["gt_ids"][t]) > 0: + data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int) + if len(unique_tracker_ids) > 0: + unique_tracker_ids = np.unique(unique_tracker_ids) + tracker_id_map = np.nan * np.ones((np.max(unique_tracker_ids) + 1)) + tracker_id_map[unique_tracker_ids] = np.arange(len(unique_tracker_ids)) + for t in range(raw_data["num_timesteps"]): + if len(data["tracker_ids"][t]) > 0: + data["tracker_ids"][t] = tracker_id_map[ + data["tracker_ids"][t] + ].astype(int) + + # Ensure that ids are unique per timestep. + self._check_unique_ids(data) + + # Record overview statistics. + data["num_tracker_dets"] = num_tracker_dets + data["num_gt_dets"] = num_gt_dets + data["num_tracker_ids"] = len(unique_tracker_ids) + data["num_gt_ids"] = len(unique_gt_ids) + data["num_timesteps"] = raw_data["num_timesteps"] + data["seq"] = raw_data["seq"] + + # get track representations + data["gt_tracks"] = raw_data["classes_to_gt_tracks"][cls_id] + data["gt_track_ids"] = raw_data["classes_to_gt_track_ids"][cls_id] + data["gt_track_areas"] = raw_data["classes_to_gt_track_areas"][cls_id] + data["gt_track_iscrowd"] = raw_data["classes_to_gt_track_iscrowd"][cls_id] + data["dt_tracks"] = raw_data["classes_to_dt_tracks"][cls_id] + data["dt_track_ids"] = raw_data["classes_to_dt_track_ids"][cls_id] + data["dt_track_areas"] = raw_data["classes_to_dt_track_areas"][cls_id] + data["dt_track_scores"] = raw_data["classes_to_dt_track_scores"][cls_id] + data["iou_type"] = "mask" + + # sort tracker data tracks by tracker confidence scores + if data["dt_tracks"]: + idx = np.argsort( + [-score for score in data["dt_track_scores"]], kind="mergesort" + ) + data["dt_track_scores"] = [data["dt_track_scores"][i] for i in idx] + data["dt_tracks"] = [data["dt_tracks"][i] for i in idx] + data["dt_track_ids"] = [data["dt_track_ids"][i] for i in idx] + data["dt_track_areas"] = [data["dt_track_areas"][i] for i in idx] + + return data + + def _calculate_similarities(self, gt_dets_t, tracker_dets_t): + if self.iou_type == "segm": + similarity_scores = self._calculate_mask_ious( + gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False + ) + else: + gt_dets_t = np.array(gt_dets_t, dtype=np.float32).reshape(-1, 4) + tracker_dets_t = np.array(tracker_dets_t, dtype=np.float32).reshape(-1, 4) + similarity_scores = self._calculate_box_ious( + gt_dets_t, tracker_dets_t, box_format="xywh", do_ioa=False + ) + return similarity_scores + + def _prepare_gt_annotations(self): + """ + Prepares GT data by rle encoding segmentations and computing the average track area. + :return: None + """ + if self.iou_type == "segm": + # only loaded when needed to reduce minimum requirements + from pycocotools import mask as mask_utils + + for track in self.gt_data["annotations"]: + h = track["height"] + w = track["width"] + for i, seg in enumerate(track["segmentations"]): + if seg is not None and isinstance(seg["counts"], list): + track["segmentations"][i] = mask_utils.frPyObjects(seg, h, w) + areas = [a for a in track["areas"] if a] + if len(areas) == 0: + track["area"] = 0 + else: + track["area"] = np.array(areas).mean() + else: + for track in self.gt_data["annotations"]: + # For bbox eval, compute areas from bboxes if not already available + areas = [a for a in track.get("areas", []) if a] + if not areas: + areas = [] + for bbox in track.get("bboxes", []): + if bbox is not None: + areas.append(bbox[2] * bbox[3]) + track["area"] = np.array(areas).mean() if areas else 0 + + def _get_tracker_seq_tracks(self, tracker, seq_id): + """ + Prepares tracker data for a given sequence. Extracts all annotations for given sequence ID, computes + average track area and assigns a track ID. + :param tracker: the given tracker + :param seq_id: the sequence ID + :return: the extracted tracks + """ + # only loaded when needed to reduce minimum requirements + from pycocotools import mask as mask_utils + + tracks = [ + ann for ann in self.tracker_data[tracker] if ann["video_id"] == seq_id + ] + for track in tracks: + if "areas" not in track: + if self.iou_type == "segm": + for seg in track["segmentations"]: + if seg: + track["areas"].append(mask_utils.area(seg)) + else: + track["areas"].append(None) + else: + for bbox in track["bboxes"]: + if bbox: + track["areas"].append(bbox[2] * bbox[3]) + else: + track["areas"].append(None) + areas = [a for a in track["areas"] if a] + if len(areas) == 0: + track["area"] = 0 + else: + track["area"] = np.array(areas).mean() + track["id"] = self.global_tid_counter + self.global_tid_counter += 1 + return tracks + + def get_name(self): + return self.dataset_name diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/eval.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcdfd5d5dc5afae50abf1a5834948225614e0ab --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/eval.py @@ -0,0 +1,398 @@ +# flake8: noqa + +# pyre-unsafe + +import os +import time +import traceback +from functools import partial +from multiprocessing.pool import Pool + +import numpy as np + +from . import _timing, utils +from .metrics import Count +from .utils import TrackEvalException + +try: + import tqdm + + TQDM_IMPORTED = True +except ImportError as _: + TQDM_IMPORTED = False + + +class Evaluator: + """Evaluator class for evaluating different metrics for different datasets""" + + @staticmethod + def get_default_eval_config(): + """Returns the default config values for evaluation""" + code_path = utils.get_code_path() + default_config = { + "USE_PARALLEL": False, + "NUM_PARALLEL_CORES": 8, + "BREAK_ON_ERROR": True, # Raises exception and exits with error + "RETURN_ON_ERROR": False, # if not BREAK_ON_ERROR, then returns from function on error + "LOG_ON_ERROR": os.path.join( + code_path, "error_log.txt" + ), # if not None, save any errors into a log file. + "PRINT_RESULTS": True, + "PRINT_ONLY_COMBINED": False, + "PRINT_CONFIG": True, + "TIME_PROGRESS": True, + "DISPLAY_LESS_PROGRESS": True, + "OUTPUT_SUMMARY": True, + "OUTPUT_EMPTY_CLASSES": True, # If False, summary files are not output for classes with no detections + "OUTPUT_DETAILED": True, + "PLOT_CURVES": True, + } + return default_config + + def __init__(self, config=None): + """Initialise the evaluator with a config file""" + self.config = utils.init_config(config, self.get_default_eval_config(), "Eval") + # Only run timing analysis if not run in parallel. + if self.config["TIME_PROGRESS"] and not self.config["USE_PARALLEL"]: + _timing.DO_TIMING = True + if self.config["DISPLAY_LESS_PROGRESS"]: + _timing.DISPLAY_LESS_PROGRESS = True + + def _combine_results( + self, + res, + metrics_list, + metric_names, + dataset, + res_field="COMBINED_SEQ", + target_tag=None, + ): + assert res_field.startswith("COMBINED_SEQ") + # collecting combined cls keys (cls averaged, det averaged, super classes) + tracker_list, seq_list, class_list = dataset.get_eval_info() + combined_cls_keys = [] + res[res_field] = {} + + # narrow the target for evaluation + if target_tag is not None: + target_video_ids = [ + annot["video_id"] + for annot in dataset.gt_data["annotations"] + if target_tag in annot["tags"] + ] + vid2name = { + video["id"]: video["file_names"][0].split("/")[0] + for video in dataset.gt_data["videos"] + } + target_video_ids = set(target_video_ids) + target_video = [vid2name[video_id] for video_id in target_video_ids] + + if len(target_video) == 0: + raise TrackEvalException( + "No sequences found with the tag %s" % target_tag + ) + + target_annotations = [ + annot + for annot in dataset.gt_data["annotations"] + if annot["video_id"] in target_video_ids + ] + assert all(target_tag in annot["tags"] for annot in target_annotations), ( + f"Not all annotations in the target sequences have the target tag {target_tag}. " + "We currently only support a target tag at the sequence level, not at the annotation level." + ) + else: + target_video = seq_list + + # combine sequences for each class + for c_cls in class_list: + res[res_field][c_cls] = {} + for metric, metric_name in zip(metrics_list, metric_names): + curr_res = { + seq_key: seq_value[c_cls][metric_name] + for seq_key, seq_value in res.items() + if not seq_key.startswith("COMBINED_SEQ") + and seq_key in target_video + } + res[res_field][c_cls][metric_name] = metric.combine_sequences(curr_res) + # combine classes + if dataset.should_classes_combine: + combined_cls_keys += [ + "cls_comb_cls_av", + "cls_comb_det_av", + "all", + ] + res[res_field]["cls_comb_cls_av"] = {} + res[res_field]["cls_comb_det_av"] = {} + for metric, metric_name in zip(metrics_list, metric_names): + cls_res = { + cls_key: cls_value[metric_name] + for cls_key, cls_value in res[res_field].items() + if cls_key not in combined_cls_keys + } + res[res_field]["cls_comb_cls_av"][metric_name] = ( + metric.combine_classes_class_averaged(cls_res) + ) + res[res_field]["cls_comb_det_av"][metric_name] = ( + metric.combine_classes_det_averaged(cls_res) + ) + # combine classes to super classes + if dataset.use_super_categories: + for cat, sub_cats in dataset.super_categories.items(): + combined_cls_keys.append(cat) + res[res_field][cat] = {} + for metric, metric_name in zip(metrics_list, metric_names): + cat_res = { + cls_key: cls_value[metric_name] + for cls_key, cls_value in res[res_field].items() + if cls_key in sub_cats + } + res[res_field][cat][metric_name] = ( + metric.combine_classes_det_averaged(cat_res) + ) + return res, combined_cls_keys + + def _summarize_results( + self, + res, + tracker, + metrics_list, + metric_names, + dataset, + res_field, + combined_cls_keys, + ): + config = self.config + output_fol = dataset.get_output_fol(tracker) + tracker_display_name = dataset.get_display_name(tracker) + for c_cls in res[ + res_field + ].keys(): # class_list + combined classes if calculated + summaries = [] + details = [] + num_dets = res[res_field][c_cls]["Count"]["Dets"] + if config["OUTPUT_EMPTY_CLASSES"] or num_dets > 0: + for metric, metric_name in zip(metrics_list, metric_names): + # for combined classes there is no per sequence evaluation + if c_cls in combined_cls_keys: + table_res = {res_field: res[res_field][c_cls][metric_name]} + else: + table_res = { + seq_key: seq_value[c_cls][metric_name] + for seq_key, seq_value in res.items() + } + + if config["PRINT_RESULTS"] and config["PRINT_ONLY_COMBINED"]: + dont_print = ( + dataset.should_classes_combine + and c_cls not in combined_cls_keys + ) + if not dont_print: + metric.print_table( + {res_field: table_res[res_field]}, + tracker_display_name, + c_cls, + res_field, + res_field, + ) + elif config["PRINT_RESULTS"]: + metric.print_table( + table_res, tracker_display_name, c_cls, res_field, res_field + ) + if config["OUTPUT_SUMMARY"]: + summaries.append(metric.summary_results(table_res)) + if config["OUTPUT_DETAILED"]: + details.append(metric.detailed_results(table_res)) + if config["PLOT_CURVES"]: + metric.plot_single_tracker_results( + table_res, + tracker_display_name, + c_cls, + output_fol, + ) + if config["OUTPUT_SUMMARY"]: + utils.write_summary_results(summaries, c_cls, output_fol) + if config["OUTPUT_DETAILED"]: + utils.write_detailed_results(details, c_cls, output_fol) + + @_timing.time + def evaluate(self, dataset_list, metrics_list, show_progressbar=False): + """Evaluate a set of metrics on a set of datasets""" + config = self.config + metrics_list = metrics_list + [Count()] # Count metrics are always run + metric_names = utils.validate_metrics_list(metrics_list) + dataset_names = [dataset.get_name() for dataset in dataset_list] + output_res = {} + output_msg = {} + + for dataset, dataset_name in zip(dataset_list, dataset_names): + # Get dataset info about what to evaluate + output_res[dataset_name] = {} + output_msg[dataset_name] = {} + tracker_list, seq_list, class_list = dataset.get_eval_info() + print( + "\nEvaluating %i tracker(s) on %i sequence(s) for %i class(es) on %s dataset using the following " + "metrics: %s\n" + % ( + len(tracker_list), + len(seq_list), + len(class_list), + dataset_name, + ", ".join(metric_names), + ) + ) + + # Evaluate each tracker + for tracker in tracker_list: + # if not config['BREAK_ON_ERROR'] then go to next tracker without breaking + try: + # Evaluate each sequence in parallel or in series. + # returns a nested dict (res), indexed like: res[seq][class][metric_name][sub_metric field] + # e.g. res[seq_0001][pedestrian][hota][DetA] + print("\nEvaluating %s\n" % tracker) + time_start = time.time() + if config["USE_PARALLEL"]: + if show_progressbar and TQDM_IMPORTED: + seq_list_sorted = sorted(seq_list) + + with ( + Pool(config["NUM_PARALLEL_CORES"]) as pool, + tqdm.tqdm(total=len(seq_list)) as pbar, + ): + _eval_sequence = partial( + eval_sequence, + dataset=dataset, + tracker=tracker, + class_list=class_list, + metrics_list=metrics_list, + metric_names=metric_names, + ) + results = [] + for r in pool.imap( + _eval_sequence, seq_list_sorted, chunksize=20 + ): + results.append(r) + pbar.update() + res = dict(zip(seq_list_sorted, results)) + + else: + with Pool(config["NUM_PARALLEL_CORES"]) as pool: + _eval_sequence = partial( + eval_sequence, + dataset=dataset, + tracker=tracker, + class_list=class_list, + metrics_list=metrics_list, + metric_names=metric_names, + ) + results = pool.map(_eval_sequence, seq_list) + res = dict(zip(seq_list, results)) + else: + res = {} + if show_progressbar and TQDM_IMPORTED: + seq_list_sorted = sorted(seq_list) + for curr_seq in tqdm.tqdm(seq_list_sorted): + res[curr_seq] = eval_sequence( + curr_seq, + dataset, + tracker, + class_list, + metrics_list, + metric_names, + ) + else: + for curr_seq in sorted(seq_list): + res[curr_seq] = eval_sequence( + curr_seq, + dataset, + tracker, + class_list, + metrics_list, + metric_names, + ) + + # Combine results over all sequences and then over all classes + res, combined_cls_keys = self._combine_results( + res, metrics_list, metric_names, dataset, "COMBINED_SEQ" + ) + + if np.all( + ["tags" in annot for annot in dataset.gt_data["annotations"]] + ): + # Combine results over the challenging sequences and then over all classes + # currently only support "tracking_challenging_pair" + res, _ = self._combine_results( + res, + metrics_list, + metric_names, + dataset, + "COMBINED_SEQ_CHALLENGING", + "tracking_challenging_pair", + ) + + # Print and output results in various formats + if config["TIME_PROGRESS"]: + print( + "\nAll sequences for %s finished in %.2f seconds" + % (tracker, time.time() - time_start) + ) + + self._summarize_results( + res, + tracker, + metrics_list, + metric_names, + dataset, + "COMBINED_SEQ", + combined_cls_keys, + ) + if "COMBINED_SEQ_CHALLENGING" in res: + self._summarize_results( + res, + tracker, + metrics_list, + metric_names, + dataset, + "COMBINED_SEQ_CHALLENGING", + combined_cls_keys, + ) + + # Output for returning from function + output_res[dataset_name][tracker] = res + output_msg[dataset_name][tracker] = "Success" + + except Exception as err: + output_res[dataset_name][tracker] = None + if type(err) == TrackEvalException: + output_msg[dataset_name][tracker] = str(err) + else: + output_msg[dataset_name][tracker] = "Unknown error occurred." + print("Tracker %s was unable to be evaluated." % tracker) + print(err) + traceback.print_exc() + if config["LOG_ON_ERROR"] is not None: + with open(config["LOG_ON_ERROR"], "a") as f: + print(dataset_name, file=f) + print(tracker, file=f) + print(traceback.format_exc(), file=f) + print("\n\n\n", file=f) + if config["BREAK_ON_ERROR"]: + raise err + elif config["RETURN_ON_ERROR"]: + return output_res, output_msg + + return output_res, output_msg + + +@_timing.time +def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names): + """Function for evaluating a single sequence""" + + raw_data = dataset.get_raw_seq_data(tracker, seq) + seq_res = {} + for cls in class_list: + seq_res[cls] = {} + data = dataset.get_preprocessed_seq_data(raw_data, cls) + for metric, met_name in zip(metrics_list, metric_names): + seq_res[cls][met_name] = metric.eval_sequence(data) + return seq_res diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/__init__.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..531a085d0030fe1d65fe65ddbec35002bd8efa29 --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa + +# pyre-unsafe + +from .count import Count +from .hota import HOTA diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/_base_metric.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/_base_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..bf7bb0a8d0d9ae486d14d0a3d4e01079c8287178 --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/_base_metric.py @@ -0,0 +1,147 @@ +# flake8: noqa + +# pyre-unsafe + +from abc import ABC, abstractmethod + +import numpy as np + +from .. import _timing +from ..utils import TrackEvalException + + +class _BaseMetric(ABC): + @abstractmethod + def __init__(self): + self.plottable = False + self.integer_fields = [] + self.float_fields = [] + self.array_labels = [] + self.integer_array_fields = [] + self.float_array_fields = [] + self.fields = [] + self.summary_fields = [] + self.registered = False + + ##################################################################### + # Abstract functions for subclasses to implement + + @_timing.time + @abstractmethod + def eval_sequence(self, data): ... + + @abstractmethod + def combine_sequences(self, all_res): ... + + @abstractmethod + def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False): ... + + @abstractmethod + def combine_classes_det_averaged(self, all_res): ... + + def plot_single_tracker_results(self, all_res, tracker, output_folder, cls): + """Plot results of metrics, only valid for metrics with self.plottable""" + if self.plottable: + raise NotImplementedError( + "plot_results is not implemented for metric %s" % self.get_name() + ) + else: + pass + + ##################################################################### + # Helper functions which are useful for all metrics: + + @classmethod + def get_name(cls): + return cls.__name__ + + @staticmethod + def _combine_sum(all_res, field): + """Combine sequence results via sum""" + return sum([all_res[k][field] for k in all_res.keys()]) + + @staticmethod + def _combine_weighted_av(all_res, field, comb_res, weight_field): + """Combine sequence results via weighted average""" + return sum( + [all_res[k][field] * all_res[k][weight_field] for k in all_res.keys()] + ) / np.maximum(1.0, comb_res[weight_field]) + + def print_table( + self, table_res, tracker, cls, res_field="COMBINED_SEQ", output_lable="COMBINED" + ): + """Prints table of results for all sequences""" + print("") + metric_name = self.get_name() + self._row_print( + [metric_name + ": " + tracker + "-" + cls] + self.summary_fields + ) + for seq, results in sorted(table_res.items()): + if seq.startswith("COMBINED_SEQ"): + continue + summary_res = self._summary_row(results) + self._row_print([seq] + summary_res) + summary_res = self._summary_row(table_res[res_field]) + self._row_print([output_lable] + summary_res) + + def _summary_row(self, results_): + vals = [] + for h in self.summary_fields: + if h in self.float_array_fields: + vals.append("{0:1.5g}".format(100 * np.mean(results_[h]))) + elif h in self.float_fields: + vals.append("{0:1.5g}".format(100 * float(results_[h]))) + elif h in self.integer_fields: + vals.append("{0:d}".format(int(results_[h]))) + else: + raise NotImplementedError( + "Summary function not implemented for this field type." + ) + return vals + + @staticmethod + def _row_print(*argv): + """Prints results in an evenly spaced rows, with more space in first row""" + if len(argv) == 1: + argv = argv[0] + to_print = "%-35s" % argv[0] + for v in argv[1:]: + to_print += "%-10s" % str(v) + print(to_print) + + def summary_results(self, table_res): + """Returns a simple summary of final results for a tracker""" + return dict( + zip(self.summary_fields, self._summary_row(table_res["COMBINED_SEQ"])) + ) + + def detailed_results(self, table_res): + """Returns detailed final results for a tracker""" + # Get detailed field information + detailed_fields = self.float_fields + self.integer_fields + for h in self.float_array_fields + self.integer_array_fields: + for alpha in [int(100 * x) for x in self.array_labels]: + detailed_fields.append(h + "___" + str(alpha)) + detailed_fields.append(h + "___AUC") + + # Get detailed results + detailed_results = {} + for seq, res in table_res.items(): + detailed_row = self._detailed_row(res) + if len(detailed_row) != len(detailed_fields): + raise TrackEvalException( + "Field names and data have different sizes (%i and %i)" + % (len(detailed_row), len(detailed_fields)) + ) + detailed_results[seq] = dict(zip(detailed_fields, detailed_row)) + return detailed_results + + def _detailed_row(self, res): + detailed_row = [] + for h in self.float_fields + self.integer_fields: + detailed_row.append(res[h]) + for h in self.float_array_fields + self.integer_array_fields: + for i, alpha in enumerate([int(100 * x) for x in self.array_labels]): + detailed_row.append(res[h][i]) + detailed_row.append(np.mean(res[h])) + return detailed_row diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/count.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/count.py new file mode 100644 index 0000000000000000000000000000000000000000..6b844680bda38e228363afb9acefcc489b40a28d --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/count.py @@ -0,0 +1,50 @@ +# flake8: noqa + +# pyre-unsafe + +from .. import _timing +from ._base_metric import _BaseMetric + + +class Count(_BaseMetric): + """Class which simply counts the number of tracker and gt detections and ids.""" + + def __init__(self, config=None): + super().__init__() + self.integer_fields = ["Dets", "GT_Dets", "IDs", "GT_IDs"] + self.fields = self.integer_fields + self.summary_fields = self.fields + + @_timing.time + def eval_sequence(self, data): + """Returns counts for one sequence""" + # Get results + res = { + "Dets": data["num_tracker_dets"], + "GT_Dets": data["num_gt_dets"], + "IDs": data["num_tracker_ids"], + "GT_IDs": data["num_gt_ids"], + "Frames": data["num_timesteps"], + } + return res + + def combine_sequences(self, all_res): + """Combines metrics across all sequences""" + res = {} + for field in self.integer_fields: + res[field] = self._combine_sum(all_res, field) + return res + + def combine_classes_class_averaged(self, all_res, ignore_empty_classes=None): + """Combines metrics across all classes by averaging over the class values""" + res = {} + for field in self.integer_fields: + res[field] = self._combine_sum(all_res, field) + return res + + def combine_classes_det_averaged(self, all_res): + """Combines metrics across all classes by averaging over the detection values""" + res = {} + for field in self.integer_fields: + res[field] = self._combine_sum(all_res, field) + return res diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/hota.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/hota.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae2c9681cf83100b5c9026d55c1ab6753869b9d --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/metrics/hota.py @@ -0,0 +1,293 @@ +# flake8: noqa + +# pyre-unsafe + +import os + +import numpy as np +from scipy.optimize import linear_sum_assignment + +from .. import _timing +from ._base_metric import _BaseMetric + + +class HOTA(_BaseMetric): + """Class which implements the HOTA metrics. + See: https://link.springer.com/article/10.1007/s11263-020-01375-2 + """ + + def __init__(self, config=None): + super().__init__() + self.plottable = True + self.array_labels = np.arange(0.05, 0.99, 0.05) + self.integer_array_fields = ["HOTA_TP", "HOTA_FN", "HOTA_FP"] + self.float_array_fields = [ + "HOTA", + "DetA", + "AssA", + "DetRe", + "DetPr", + "AssRe", + "AssPr", + "LocA", + "OWTA", + ] + self.float_fields = ["HOTA(0)", "LocA(0)", "HOTALocA(0)"] + self.fields = ( + self.float_array_fields + self.integer_array_fields + self.float_fields + ) + self.summary_fields = self.float_array_fields + self.float_fields + + @_timing.time + def eval_sequence(self, data): + """Calculates the HOTA metrics for one sequence""" + + # Initialise results + res = {} + for field in self.float_array_fields + self.integer_array_fields: + res[field] = np.zeros((len(self.array_labels)), dtype=float) + for field in self.float_fields: + res[field] = 0 + + # Return result quickly if tracker or gt sequence is empty + if data["num_tracker_dets"] == 0: + res["HOTA_FN"] = data["num_gt_dets"] * np.ones( + (len(self.array_labels)), dtype=float + ) + res["LocA"] = np.ones((len(self.array_labels)), dtype=float) + res["LocA(0)"] = 1.0 + return res + if data["num_gt_dets"] == 0: + res["HOTA_FP"] = data["num_tracker_dets"] * np.ones( + (len(self.array_labels)), dtype=float + ) + res["LocA"] = np.ones((len(self.array_labels)), dtype=float) + res["LocA(0)"] = 1.0 + return res + + # Variables counting global association + potential_matches_count = np.zeros( + (data["num_gt_ids"], data["num_tracker_ids"]) + ) + gt_id_count = np.zeros((data["num_gt_ids"], 1)) + tracker_id_count = np.zeros((1, data["num_tracker_ids"])) + + # First loop through each timestep and accumulate global track information. + for t, (gt_ids_t, tracker_ids_t) in enumerate( + zip(data["gt_ids"], data["tracker_ids"]) + ): + # Count the potential matches between ids in each timestep + # These are normalised, weighted by the match similarity. + similarity = data["similarity_scores"][t] + sim_iou_denom = ( + similarity.sum(0)[np.newaxis, :] + + similarity.sum(1)[:, np.newaxis] + - similarity + ) + sim_iou = np.zeros_like(similarity) + sim_iou_mask = sim_iou_denom > 0 + np.finfo("float").eps + sim_iou[sim_iou_mask] = ( + similarity[sim_iou_mask] / sim_iou_denom[sim_iou_mask] + ) + potential_matches_count[ + gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :] + ] += sim_iou + + # Calculate the total number of dets for each gt_id and tracker_id. + gt_id_count[gt_ids_t] += 1 + tracker_id_count[0, tracker_ids_t] += 1 + + # Calculate overall jaccard alignment score (before unique matching) between IDs + global_alignment_score = potential_matches_count / ( + gt_id_count + tracker_id_count - potential_matches_count + ) + matches_counts = [ + np.zeros_like(potential_matches_count) for _ in self.array_labels + ] + + # Calculate scores for each timestep + for t, (gt_ids_t, tracker_ids_t) in enumerate( + zip(data["gt_ids"], data["tracker_ids"]) + ): + # Deal with the case that there are no gt_det/tracker_det in a timestep. + if len(gt_ids_t) == 0: + for a, alpha in enumerate(self.array_labels): + res["HOTA_FP"][a] += len(tracker_ids_t) + continue + if len(tracker_ids_t) == 0: + for a, alpha in enumerate(self.array_labels): + res["HOTA_FN"][a] += len(gt_ids_t) + continue + + # Get matching scores between pairs of dets for optimizing HOTA + similarity = data["similarity_scores"][t] + score_mat = ( + global_alignment_score[ + gt_ids_t[:, np.newaxis], tracker_ids_t[np.newaxis, :] + ] + * similarity + ) + + # Hungarian algorithm to find best matches + match_rows, match_cols = linear_sum_assignment(-score_mat) + + # Calculate and accumulate basic statistics + for a, alpha in enumerate(self.array_labels): + actually_matched_mask = ( + similarity[match_rows, match_cols] >= alpha - np.finfo("float").eps + ) + alpha_match_rows = match_rows[actually_matched_mask] + alpha_match_cols = match_cols[actually_matched_mask] + num_matches = len(alpha_match_rows) + res["HOTA_TP"][a] += num_matches + res["HOTA_FN"][a] += len(gt_ids_t) - num_matches + res["HOTA_FP"][a] += len(tracker_ids_t) - num_matches + if num_matches > 0: + res["LocA"][a] += sum( + similarity[alpha_match_rows, alpha_match_cols] + ) + matches_counts[a][ + gt_ids_t[alpha_match_rows], tracker_ids_t[alpha_match_cols] + ] += 1 + + # Calculate association scores (AssA, AssRe, AssPr) for the alpha value. + # First calculate scores per gt_id/tracker_id combo and then average over the number of detections. + for a, alpha in enumerate(self.array_labels): + matches_count = matches_counts[a] + ass_a = matches_count / np.maximum( + 1, gt_id_count + tracker_id_count - matches_count + ) + res["AssA"][a] = np.sum(matches_count * ass_a) / np.maximum( + 1, res["HOTA_TP"][a] + ) + ass_re = matches_count / np.maximum(1, gt_id_count) + res["AssRe"][a] = np.sum(matches_count * ass_re) / np.maximum( + 1, res["HOTA_TP"][a] + ) + ass_pr = matches_count / np.maximum(1, tracker_id_count) + res["AssPr"][a] = np.sum(matches_count * ass_pr) / np.maximum( + 1, res["HOTA_TP"][a] + ) + + # Calculate final scores + res["LocA"] = np.maximum(1e-10, res["LocA"]) / np.maximum(1e-10, res["HOTA_TP"]) + res = self._compute_final_fields(res) + return res + + def combine_sequences(self, all_res): + """Combines metrics across all sequences""" + res = {} + for field in self.integer_array_fields: + res[field] = self._combine_sum(all_res, field) + for field in ["AssRe", "AssPr", "AssA"]: + res[field] = self._combine_weighted_av( + all_res, field, res, weight_field="HOTA_TP" + ) + loca_weighted_sum = sum( + [all_res[k]["LocA"] * all_res[k]["HOTA_TP"] for k in all_res.keys()] + ) + res["LocA"] = np.maximum(1e-10, loca_weighted_sum) / np.maximum( + 1e-10, res["HOTA_TP"] + ) + res = self._compute_final_fields(res) + return res + + def combine_classes_class_averaged(self, all_res, ignore_empty_classes=False): + """Combines metrics across all classes by averaging over the class values. + If 'ignore_empty_classes' is True, then it only sums over classes with at least one gt or predicted detection. + """ + res = {} + for field in self.integer_array_fields: + if ignore_empty_classes: + res[field] = self._combine_sum( + { + k: v + for k, v in all_res.items() + if ( + v["HOTA_TP"] + v["HOTA_FN"] + v["HOTA_FP"] + > 0 + np.finfo("float").eps + ).any() + }, + field, + ) + else: + res[field] = self._combine_sum( + {k: v for k, v in all_res.items()}, field + ) + + for field in self.float_fields + self.float_array_fields: + if ignore_empty_classes: + res[field] = np.mean( + [ + v[field] + for v in all_res.values() + if ( + v["HOTA_TP"] + v["HOTA_FN"] + v["HOTA_FP"] + > 0 + np.finfo("float").eps + ).any() + ], + axis=0, + ) + else: + res[field] = np.mean([v[field] for v in all_res.values()], axis=0) + return res + + def combine_classes_det_averaged(self, all_res): + """Combines metrics across all classes by averaging over the detection values""" + res = {} + for field in self.integer_array_fields: + res[field] = self._combine_sum(all_res, field) + for field in ["AssRe", "AssPr", "AssA"]: + res[field] = self._combine_weighted_av( + all_res, field, res, weight_field="HOTA_TP" + ) + loca_weighted_sum = sum( + [all_res[k]["LocA"] * all_res[k]["HOTA_TP"] for k in all_res.keys()] + ) + res["LocA"] = np.maximum(1e-10, loca_weighted_sum) / np.maximum( + 1e-10, res["HOTA_TP"] + ) + res = self._compute_final_fields(res) + return res + + @staticmethod + def _compute_final_fields(res): + """Calculate sub-metric ('field') values which only depend on other sub-metric values. + This function is used both for both per-sequence calculation, and in combining values across sequences. + """ + res["DetRe"] = res["HOTA_TP"] / np.maximum(1, res["HOTA_TP"] + res["HOTA_FN"]) + res["DetPr"] = res["HOTA_TP"] / np.maximum(1, res["HOTA_TP"] + res["HOTA_FP"]) + res["DetA"] = res["HOTA_TP"] / np.maximum( + 1, res["HOTA_TP"] + res["HOTA_FN"] + res["HOTA_FP"] + ) + res["HOTA"] = np.sqrt(res["DetA"] * res["AssA"]) + res["OWTA"] = np.sqrt(res["DetRe"] * res["AssA"]) + + res["HOTA(0)"] = res["HOTA"][0] + res["LocA(0)"] = res["LocA"][0] + res["HOTALocA(0)"] = res["HOTA(0)"] * res["LocA(0)"] + return res + + def plot_single_tracker_results(self, table_res, tracker, cls, output_folder): + """Create plot of results""" + + # Only loaded when run to reduce minimum requirements + from matplotlib import pyplot as plt + + res = table_res["COMBINED_SEQ"] + styles_to_plot = ["r", "b", "g", "b--", "b:", "g--", "g:", "m"] + for name, style in zip(self.float_array_fields, styles_to_plot): + plt.plot(self.array_labels, res[name], style) + plt.xlabel("alpha") + plt.ylabel("score") + plt.title(tracker + " - " + cls) + plt.axis([0, 1, 0, 1]) + legend = [] + for name in self.float_array_fields: + legend += [name + " (" + str(np.round(np.mean(res[name]), 2)) + ")"] + plt.legend(legend, loc="lower left") + out_file = os.path.join(output_folder, cls + "_plot.pdf") + os.makedirs(os.path.dirname(out_file), exist_ok=True) + plt.savefig(out_file) + plt.savefig(out_file.replace(".pdf", ".png")) + plt.clf() diff --git a/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/utils.py b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..99dcc4cbfb166d9ce2cd425471e6b0bd13119ed3 --- /dev/null +++ b/third_party/sam3/sam3/eval/hota_eval_toolkit/trackeval/utils.py @@ -0,0 +1,197 @@ +# flake8: noqa + +# pyre-unsafe + +import argparse +import csv +import os +from collections import OrderedDict + + +def init_config(config, default_config, name=None): + """Initialise non-given config values with defaults""" + if config is None: + config = default_config + else: + for k in default_config.keys(): + if k not in config.keys(): + config[k] = default_config[k] + if name and config["PRINT_CONFIG"]: + print("\n%s Config:" % name) + for c in config.keys(): + print("%-20s : %-30s" % (c, config[c])) + return config + + +def update_config(config): + """ + Parse the arguments of a script and updates the config values for a given value if specified in the arguments. + :param config: the config to update + :return: the updated config + """ + parser = argparse.ArgumentParser() + for setting in config.keys(): + if type(config[setting]) == list or type(config[setting]) == type(None): + parser.add_argument("--" + setting, nargs="+") + else: + parser.add_argument("--" + setting) + args = parser.parse_args().__dict__ + for setting in args.keys(): + if args[setting] is not None: + if type(config[setting]) == type(True): + if args[setting] == "True": + x = True + elif args[setting] == "False": + x = False + else: + raise Exception( + "Command line parameter " + setting + "must be True or False" + ) + elif type(config[setting]) == type(1): + x = int(args[setting]) + elif type(args[setting]) == type(None): + x = None + else: + x = args[setting] + config[setting] = x + return config + + +def get_code_path(): + """Get base path where code is""" + return os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + + +def validate_metrics_list(metrics_list): + """Get names of metric class and ensures they are unique, further checks that the fields within each metric class + do not have overlapping names. + """ + metric_names = [metric.get_name() for metric in metrics_list] + # check metric names are unique + if len(metric_names) != len(set(metric_names)): + raise TrackEvalException( + "Code being run with multiple metrics of the same name" + ) + fields = [] + for m in metrics_list: + fields += m.fields + # check metric fields are unique + if len(fields) != len(set(fields)): + raise TrackEvalException( + "Code being run with multiple metrics with fields of the same name" + ) + return metric_names + + +def write_summary_results(summaries, cls, output_folder): + """Write summary results to file""" + + fields = sum([list(s.keys()) for s in summaries], []) + values = sum([list(s.values()) for s in summaries], []) + + # In order to remain consistent upon new fields being adding, for each of the following fields if they are present + # they will be output in the summary first in the order below. Any further fields will be output in the order each + # metric family is called, and within each family either in the order they were added to the dict (python >= 3.6) or + # randomly (python < 3.6). + default_order = [ + "HOTA", + "DetA", + "AssA", + "DetRe", + "DetPr", + "AssRe", + "AssPr", + "LocA", + "OWTA", + "HOTA(0)", + "LocA(0)", + "HOTALocA(0)", + "MOTA", + "MOTP", + "MODA", + "CLR_Re", + "CLR_Pr", + "MTR", + "PTR", + "MLR", + "CLR_TP", + "CLR_FN", + "CLR_FP", + "IDSW", + "MT", + "PT", + "ML", + "Frag", + "sMOTA", + "IDF1", + "IDR", + "IDP", + "IDTP", + "IDFN", + "IDFP", + "Dets", + "GT_Dets", + "IDs", + "GT_IDs", + ] + default_ordered_dict = OrderedDict( + zip(default_order, [None for _ in default_order]) + ) + for f, v in zip(fields, values): + default_ordered_dict[f] = v + for df in default_order: + if default_ordered_dict[df] is None: + del default_ordered_dict[df] + fields = list(default_ordered_dict.keys()) + values = list(default_ordered_dict.values()) + + out_file = os.path.join(output_folder, cls + "_summary.txt") + os.makedirs(os.path.dirname(out_file), exist_ok=True) + with open(out_file, "w", newline="") as f: + writer = csv.writer(f, delimiter=" ") + writer.writerow(fields) + writer.writerow(values) + + +def write_detailed_results(details, cls, output_folder): + """Write detailed results to file""" + sequences = details[0].keys() + fields = ["seq"] + sum([list(s["COMBINED_SEQ"].keys()) for s in details], []) + out_file = os.path.join(output_folder, cls + "_detailed.csv") + os.makedirs(os.path.dirname(out_file), exist_ok=True) + with open(out_file, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(fields) + for seq in sorted(sequences): + if seq == "COMBINED_SEQ": + continue + writer.writerow([seq] + sum([list(s[seq].values()) for s in details], [])) + writer.writerow( + ["COMBINED"] + sum([list(s["COMBINED_SEQ"].values()) for s in details], []) + ) + + +def load_detail(file): + """Loads detailed data for a tracker.""" + data = {} + with open(file) as f: + for i, row_text in enumerate(f): + row = row_text.replace("\r", "").replace("\n", "").split(",") + if i == 0: + keys = row[1:] + continue + current_values = row[1:] + seq = row[0] + if seq == "COMBINED": + seq = "COMBINED_SEQ" + if (len(current_values) == len(keys)) and seq != "": + data[seq] = {} + for key, value in zip(keys, current_values): + data[seq][key] = float(value) + return data + + +class TrackEvalException(Exception): + """Custom exception for catching expected errors.""" + + ... diff --git a/third_party/sam3/sam3/eval/postprocessors.py b/third_party/sam3/sam3/eval/postprocessors.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd26a9e00c79b6061fb33e276219f9cc4ed67cf --- /dev/null +++ b/third_party/sam3/sam3/eval/postprocessors.py @@ -0,0 +1,650 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Postprocessors class to transform MDETR output according to the downstream task""" + +import dataclasses +import logging +from collections import defaultdict +from typing import Dict, List, Optional + +import numpy as np +import torch +from sam3.model import box_ops +from sam3.model.data_misc import BatchedInferenceMetadata, interpolate +from sam3.train.masks_ops import rle_encode, robust_rle_encode +from torch import nn + + +class PostProcessNullOp(nn.Module): + def __init__(self, **kwargs): + super(PostProcessNullOp).__init__() + pass + + def forward(self, input): + pass + + def process_results(self, **kwargs): + return kwargs["find_stages"] + + +class PostProcessImage(nn.Module): + """This module converts the model's output into the format expected by the coco api""" + + def __init__( + self, + max_dets_per_img: int, + iou_type="bbox", + to_cpu: bool = True, + use_original_ids: bool = False, + use_original_sizes_box: bool = False, + use_original_sizes_mask: bool = False, + convert_mask_to_rle: bool = False, + always_interpolate_masks_on_gpu: bool = True, + use_presence: bool = True, + detection_threshold: float = -1.0, + ) -> None: + super().__init__() + self.max_dets_per_img = max_dets_per_img + self.iou_type = iou_type + self.to_cpu = to_cpu + self.convert_mask_to_rle = convert_mask_to_rle + self.always_interpolate_masks_on_gpu = always_interpolate_masks_on_gpu + + self.use_presence = use_presence + self.detection_threshold = detection_threshold + self.use_original_ids = use_original_ids + self.use_original_sizes_box = use_original_sizes_box + self.use_original_sizes_mask = use_original_sizes_mask + + @torch.no_grad() + def forward( + self, + outputs, + target_sizes_boxes, + target_sizes_masks, + forced_labels=None, + consistent=False, + ret_tensordict: bool = False, # This is experimental + ): + """Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes_boxes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + target_sizes_masks: same but used to resize masks + forced_labels: tensor of dimension [batch_size] containing the label to force for each image of the batch + This is useful when evaluating the model using standard metrics (eg on COCO, LVIS). In that case, + we query the model with every possible class label, so we when we pass the predictions to the evaluator, + we want to make sure that the predicted "class" matches the one that was queried. + consistent: whether all target sizes are equal + ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation. + """ + if ret_tensordict: + assert ( + consistent is True + ), "We don't support returning TensorDict if the outputs have different shapes" # NOTE: It's possible but we don't support it. + assert self.detection_threshold <= 0.0, "TODO: implement?" + try: + from tensordict import TensorDict + except ImportError: + logging.info( + "tensordict is not installed. Install by running `pip install tensordict --no-deps`. Falling back by setting `ret_tensordict=False`" + ) + ret_tensordict = False + + out_bbox = outputs["pred_boxes"] if "pred_boxes" in outputs else None + out_logits = outputs["pred_logits"] + pred_masks = outputs["pred_masks"] if self.iou_type == "segm" else None + out_probs = out_logits.sigmoid() + if self.use_presence: + presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1) + out_probs = out_probs * presence_score + + assert target_sizes_boxes.shape[1] == 2 + assert target_sizes_masks.shape[1] == 2 + batch_size = target_sizes_boxes.shape[0] + + boxes, scores, labels, keep = self._process_boxes_and_labels( + target_sizes_boxes, forced_labels, out_bbox, out_probs + ) + assert boxes is None or len(boxes) == batch_size + out_masks = self._process_masks( + target_sizes_masks, pred_masks, consistent=consistent, keep=keep + ) + del pred_masks + + if boxes is None: + assert out_masks is not None + assert not ret_tensordict, "We don't support returning TensorDict if the output does not contain boxes" + B = len(out_masks) + boxes = [None] * B + scores = [None] * B + labels = [None] * B + + results = { + "scores": scores, + "labels": labels, + "boxes": boxes, + } + if out_masks is not None: + if self.convert_mask_to_rle: + results.update(masks_rle=out_masks) + else: + results.update(masks=out_masks) + + if ret_tensordict: + results = TensorDict(results).auto_batch_size_() + if self.to_cpu: + results = results.cpu() + else: + # Convert a dictonary of lists/tensors to list of dictionaries + results = [ + dict(zip(results.keys(), res_tuple)) + for res_tuple in zip(*results.values()) + ] + + return results + + def _process_masks(self, target_sizes, pred_masks, consistent=True, keep=None): + if pred_masks is None: + return None + if self.always_interpolate_masks_on_gpu: + gpu_device = target_sizes.device + assert gpu_device.type == "cuda" + pred_masks = pred_masks.to(device=gpu_device) + if consistent: + assert keep is None, "TODO: implement?" + # All masks should have the same shape, expected when processing a batch of size 1 + target_size = target_sizes.unique(dim=0) + assert target_size.size(0) == 1, "Expecting all target sizes to be equal" + out_masks = ( + interpolate( + pred_masks, + target_size.squeeze().tolist(), + mode="bilinear", + align_corners=False, + ).sigmoid() + > 0.5 + ) + if self.convert_mask_to_rle: + raise RuntimeError("TODO: implement?") + if self.to_cpu: + out_masks = out_masks.cpu() + else: + out_masks = [[]] * len(pred_masks) + + assert keep is None or len(keep) == len(pred_masks) + for i, mask in enumerate(pred_masks): + h, w = target_sizes[i] + if keep is not None: + mask = mask[keep[i]] + # Uses the gpu version fist, moves masks to cpu if it fails""" + try: + interpolated = ( + interpolate( + mask.unsqueeze(1), + (h, w), + mode="bilinear", + align_corners=False, + ).sigmoid() + > 0.5 + ) + except Exception as e: + logging.info("Issue found, reverting to CPU mode!") + mask_device = mask.device + mask = mask.cpu() + interpolated = ( + interpolate( + mask.unsqueeze(1), + (h, w), + mode="bilinear", + align_corners=False, + ).sigmoid() + > 0.5 + ) + interpolated = interpolated.to(mask_device) + + if self.convert_mask_to_rle: + out_masks[i] = robust_rle_encode(interpolated.squeeze(1)) + else: + out_masks[i] = interpolated + if self.to_cpu: + out_masks[i] = out_masks[i].cpu() + + return out_masks + + def _process_boxes_and_labels( + self, target_sizes, forced_labels, out_bbox, out_probs + ): + if out_bbox is None: + return None, None, None, None + assert len(out_probs) == len(target_sizes) + if self.to_cpu: + out_probs = out_probs.cpu() + scores, labels = out_probs.max(-1) + if forced_labels is None: + labels = torch.ones_like(labels) + else: + labels = forced_labels[:, None].expand_as(labels) + + # convert to [x0, y0, x1, y1] format + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + if self.to_cpu: + boxes = boxes.cpu() + + keep = None + if self.detection_threshold > 0: + # Filter out the boxes with scores below the detection threshold + keep = scores > self.detection_threshold + assert len(keep) == len(boxes) == len(scores) == len(labels) + + boxes = [b[k.to(b.device)] for b, k in zip(boxes, keep)] + scores = [s[k.to(s.device)] for s, k in zip(scores, keep)] + labels = [l[k.to(l.device)] for l, k in zip(labels, keep)] + + return boxes, scores, labels, keep + + def process_results( + self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs + ): + if find_stages.loss_stages is not None: + find_metadatas = [find_metadatas[i] for i in find_stages.loss_stages] + assert len(find_stages) == len(find_metadatas) + results = {} + for outputs, meta in zip(find_stages, find_metadatas): + img_size_for_boxes = ( + meta.original_size + if self.use_original_sizes_box + else torch.ones_like(meta.original_size) + ) + img_size_for_masks = ( + meta.original_size + if self.use_original_sizes_mask + else torch.ones_like(meta.original_size) + ) + detection_results = self( + outputs, + img_size_for_boxes, + img_size_for_masks, + forced_labels=( + meta.original_category_id if self.use_original_ids else None + ), + ) + ids = ( + meta.original_image_id if self.use_original_ids else meta.coco_image_id + ) + assert len(detection_results) == len(ids) + for img_id, result in zip(ids, detection_results): + if img_id.item() not in results: + results[img_id.item()] = result + else: + assert set(results[img_id.item()].keys()) == set(result.keys()) + for k in result.keys(): + if isinstance(result[k], torch.Tensor): + results[img_id.item()][k] = torch.cat( + [results[img_id.item()][k], result[k]], dim=0 + ) + elif isinstance(result[k], list): + results[img_id.item()][k] += result[k] + else: + raise NotImplementedError( + f"Unexpected type {type(result[k])} in result." + ) + # Prune the results to the max number of detections per image. + for img_id, result in results.items(): + if ( + self.max_dets_per_img > 0 + and len(result["scores"]) > self.max_dets_per_img + ): + _, topk_indexes = torch.topk( + result["scores"], self.max_dets_per_img, dim=0 + ) + if self.to_cpu: + topk_indexes = topk_indexes.cpu() + for k in result.keys(): + if isinstance(results[img_id][k], list): + results[img_id][k] = [ + results[img_id][k][i] for i in topk_indexes.tolist() + ] + else: + results[img_id][k] = results[img_id][k].to(topk_indexes.device)[ + topk_indexes + ] + + return results + + +class PostProcessAPIVideo(PostProcessImage): + """This module converts the video model's output into the format expected by the YT-VIS api""" + + def __init__( + self, + *args, + to_cpu: bool = True, + convert_mask_to_rle: bool = False, + always_interpolate_masks_on_gpu: bool = True, + prob_thresh: float = 0.5, + use_presence: bool = False, + **kwargs, + ): + super().__init__( + *args, + # Here we always set `convert_mask_to_rle=False` in the base `PostProcessAPI` class + # (so that its `_process_masks` won't return a list of RLEs). If we want to return + # RLEs for video masklets, we handle it in this `PostProcessAPIVideo` class instead. + convert_mask_to_rle=False, + # Here we always set `to_cpu=False` in the base `PostProcessAPI` class (so that + # the interpolated masks won't be automatically moved back to CPU). We will handle + # it in this `PostProcessAPIVideo` class instead. + always_interpolate_masks_on_gpu=always_interpolate_masks_on_gpu, + use_presence=use_presence, + **kwargs, + ) + # Expected keys in the output dict to postprocess + self.EXPECTED_KEYS = [ + "pred_logits", + "pred_boxes", + "pred_masks", + ] + # Whether to post-process video masklets (under packed representation) into RLE format + self.convert_mask_to_rle_for_video = convert_mask_to_rle + self.to_cpu_for_video = to_cpu + self.prob_thresh = prob_thresh + + def process_results( + self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs + ): + """ + Tracking Postprocessor for SAM 3 video model. + This function takes in the output of the SAM 3 video model and processes it to extract all the tracklet predictions. + Args: + find_stages: A list of tensors representing the output of the SAM 3 video model. + find_metadatas: A list of BatchedInferenceMetadata objects containing metadata about each frame. + **kwargs: Additional keyword arguments. + Returns: + A dictionary of predcitions with video_id as key. + """ + + # Import tensordict here to avoid global dependency. + try: + from tensordict import TensorDict + except ImportError as e: + logging.error( + "tensordict is not installed, please install by running `pip install tensordict --no-deps`" + ) + raise e + # Notes and assumptions: + # 1- This postprocessor assumes results only for a single video. + # 2- There are N stage outputs corresponding to N video frames + # 3- Each stage outputs contains PxQ preds, where P is number of prompts and Q is number of object queries. The output should also contain the tracking object ids corresponding to each object query. + # 4- The tracking object id has a default value of -1, indicating that the object query is not tracking any object in the frame, and hence its predictions can be ingored for a given frame. + # 5- Some objects may be tracked in a subset of frames only. So, we first extract the predictions in a packed representation (for efficient postprocessing -- specially memory) + # and then we convert the packed representation into a padded one, where we zero pad boxes/masks for objects that are not tracked in some frames. + # 6- We refer to objects by an object id, which is a tuple (prompt_idx, obj_id) + + assert len(find_stages) > 0, "There is nothing to postprocess?" + PROMPT_AXIS, OBJ_QUERY_AXIS = (0, 1) + NO_OBJ_ID = -1 + # Maps object ID -> [indices in packed tensor] + tracked_objects_packed_idx = defaultdict(list) + # Maps object ID -> [indices in padded tensor (abs frame index)] + tracked_objects_frame_idx = defaultdict(list) + total_num_preds = 0 + # This will hold the packed representation of predictions. + vid_preds_packed: List[TensorDict] = [] + vid_masklets_rle_packed: List[Optional[Dict]] = [] + video_id = -1 # We assume single video postprocessing, this ID should be unique in the datapoint. + + for frame_idx, (frame_outs, meta) in enumerate( + zip(find_stages, find_metadatas) + ): + # only store keys we need to extract the results + frame_outs_td = TensorDict( + {k: frame_outs[k] for k in self.EXPECTED_KEYS} + ).auto_batch_size_() # Shape is [P,Q,...] + meta_td = TensorDict( + dataclasses.asdict(meta) + ).auto_batch_size_() # Shape is [P,...] + unique_vid_id = meta.original_image_id.unique() + assert unique_vid_id.size(0) == 1 + if video_id == -1: + video_id = unique_vid_id.item() + else: + assert ( + video_id == unique_vid_id.item() + ), "We can only postprocess one video per datapoint" + # keeping track of which objects appear in the current frame + obj_ids_per_frame = frame_outs["pred_object_ids"] + assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2) + if self.prob_thresh is not None: + # only keep the predictions on this frame with probability above the threshold + # (remove those predictions during the keep-alive period of a tracking query, + # where its "pred_object_ids" is still the tracked object ID rather than -1) + pred_probs = frame_outs["pred_logits"].sigmoid().squeeze(-1) + obj_ids_per_frame = torch.where( + pred_probs >= self.prob_thresh, obj_ids_per_frame, NO_OBJ_ID + ) + tracked_obj_ids_idx = torch.where(obj_ids_per_frame != NO_OBJ_ID) + # Object id is a tuple of (prompt_idx, obj_id). This is because the model can assign same obj_id for two different prompts. + tracked_obj_ids = [ + (p_id.item(), obj_ids_per_frame[p_id, q_id].item()) + for p_id, q_id in zip( + tracked_obj_ids_idx[PROMPT_AXIS], + tracked_obj_ids_idx[OBJ_QUERY_AXIS], + ) + ] + if len(tracked_obj_ids) == 0: + continue + # For each object, we keep track of the packed and padded (frame index) indices + for oid in tracked_obj_ids: + tracked_objects_packed_idx[oid].append(total_num_preds) + tracked_objects_frame_idx[oid].append(frame_idx) + total_num_preds += 1 + + # Since we have P*Q masks per frame, mask interpolation is the GPU memory bottleneck or time bottleneck in case of cpu processing. + # Instead, we first extract results only for tracked objects, reducing the number of masks to K = sum_i(tracked_objs_per_ith_prompt), hopefully <<< P*Q + tracked_objs_outs_td = frame_outs_td[ + tracked_obj_ids_idx + ] # [P,Q,...] --> [K,...] + meta_td = meta_td[tracked_obj_ids_idx[PROMPT_AXIS].cpu()] + if self.always_interpolate_masks_on_gpu: + gpu_device = meta_td["original_size"].device + assert gpu_device.type == "cuda" + tracked_objs_outs_td = tracked_objs_outs_td.to(device=gpu_device) + frame_results_td = self( + tracked_objs_outs_td.unsqueeze(1), + ( + meta_td["original_size"] + if self.use_original_sizes + else torch.ones_like(meta_td["original_size"]) + ), + forced_labels=( + meta_td["original_category_id"] if self.use_original_ids else None + ), + consistent=True, + ret_tensordict=True, + ).squeeze(1) + del tracked_objs_outs_td + + # Optionally, remove "masks" from output tensor dict and directly encode them + # to RLE format under packed representations + if self.convert_mask_to_rle_for_video: + interpolated_binary_masks = frame_results_td.pop("masks") + rle_list = rle_encode(interpolated_binary_masks, return_areas=True) + vid_masklets_rle_packed.extend(rle_list) + # Optionally, move output TensorDict to CPU (do this after RLE encoding step above) + if self.to_cpu_for_video: + frame_results_td = frame_results_td.cpu() + vid_preds_packed.append(frame_results_td) + + if len(vid_preds_packed) == 0: + logging.debug(f"Video {video_id} has no predictions") + return {video_id: []} + + vid_preds_packed = torch.cat(vid_preds_packed, dim=0) + ############### Construct a padded representation of the predictions ############### + num_preds = len(tracked_objects_packed_idx) + num_frames = len(find_stages) + # We zero pad any missing prediction + # NOTE: here, we also have padded tensors for "scores" and "labels", but we overwrite them later. + padded_frames_results = TensorDict( + { + k: torch.zeros( + num_preds, num_frames, *v.shape[1:], device=v.device, dtype=v.dtype + ) + for k, v in vid_preds_packed.items() + }, + batch_size=[ + num_preds, + num_frames, + ], + ) + padded_frames_results["scores"][...] = -1e8 # a very low score for empty object + # Track scores and labels of each pred tracklet, only for frames where the model was able to track that object + tracklet_scores = [] + tracklet_labels = [] + # Optionally, fill the list of RLEs for masklets + # note: only frames with actual predicted masks (in packed format) will be + # filled with RLEs; the rest will remains None in results["masks_rle"] + if self.convert_mask_to_rle_for_video: + vid_masklets_rle_padded = [[None] * num_frames for _ in range(num_preds)] + for o_idx, oid in enumerate(tracked_objects_packed_idx): + oid2packed_idx = tracked_objects_packed_idx[oid] + oid2padded_idx = tracked_objects_frame_idx[oid] + obj_packed_results = vid_preds_packed[oid2packed_idx] + padded_frames_results[o_idx][oid2padded_idx] = obj_packed_results + if self.convert_mask_to_rle_for_video: + for packed_idx, padded_idx in zip(oid2packed_idx, oid2padded_idx): + vid_masklets_rle_padded[o_idx][padded_idx] = ( + vid_masklets_rle_packed[packed_idx] + ) + # NOTE: We need a single confidence score per tracklet for the mAP metric. + # We use the average confidence score across time. (How does this impact AP?) + tracklet_scores.append(obj_packed_results["scores"].mean()) + # We also need to have a unique category Id per tracklet. + # This is not a problem for phrase AP, however, for mAP we do majority voting across time. + tracklet_labels.append(obj_packed_results["labels"].mode()[0]) + + results = padded_frames_results.to_dict() + results["scores"] = torch.stack(tracklet_scores, dim=0) + results["labels"] = torch.stack(tracklet_labels, dim=0) + if self.convert_mask_to_rle_for_video: + results["masks_rle"] = vid_masklets_rle_padded + # we keep the frame-level scores since it's needed by some evaluation scripts + results["per_frame_scores"] = padded_frames_results["scores"] + + return {video_id: results} + + +class PostProcessTracking(PostProcessImage): + """This module converts the model's output into the format expected by the coco api""" + + def __init__( + self, + max_dets_per_img: int, + iou_type="bbox", + force_single_mask: bool = False, + **kwargs, + ) -> None: + super().__init__(max_dets_per_img=max_dets_per_img, iou_type=iou_type, **kwargs) + self.force_single_mask = force_single_mask + + def process_results( + self, find_stages, find_metadatas: BatchedInferenceMetadata, **kwargs + ): + assert len(find_stages) == len(find_metadatas) + results = {} + for outputs, meta in zip(find_stages, find_metadatas): + if self.force_single_mask: + scores, labels = outputs["pred_logits"].max(-1) + m = [] + for i in range(len(outputs["pred_masks"])): + score, idx = scores[i].max(0) + m.append(outputs["pred_masks"][i][idx]) + outputs["pred_masks"] = torch.stack(m, 0).unsqueeze(1) + detection_results = self(outputs, meta.original_size, consistent=False) + assert len(detection_results) == len(meta.coco_image_id) + results.update( + { + (media_id.item(), object_id.item(), frame_index.item()): result + for media_id, object_id, frame_index, result in zip( + meta.original_image_id, + meta.object_id, + meta.frame_index, + detection_results, + ) + } + ) + return results + + +class PostProcessCounting(nn.Module): + """This module converts the model's output to be evaluated for counting tasks""" + + def __init__( + self, + use_original_ids: bool = False, + threshold: float = 0.5, + use_presence: bool = False, + ) -> None: + """ + Args: + use_original_ids: whether to use the original image ids or the coco ids + threshold: threshold for counting (values above this are counted) + """ + super().__init__() + self.use_original_ids = use_original_ids + self.threshold = threshold + self.use_presence = use_presence + + def forward(self, outputs, target_sizes): + """Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + """ + # Extract scores from model outputs and apply sigmoid + scores = torch.sigmoid(outputs["pred_logits"]).squeeze(-1) # [B, N] + if self.use_presence: + presence_score = outputs["presence_logit_dec"].sigmoid() + if presence_score.ndim == 1: + presence_score = presence_score.unsqueeze(1) # [B, 1] + scores = scores * presence_score # [B, N] + + # Calculate counts by summing values above threshold + counts = (scores > self.threshold).float().sum(dim=1) + + assert len(counts) == len(target_sizes) + results = [] + for count in counts: + results.append({"count": count.item()}) + + return results + + @torch.no_grad() + def process_results( + self, find_stages, find_metadatas: List[BatchedInferenceMetadata], **kwargs + ): + assert len(find_stages) == len(find_metadatas) + results = {} + for outputs, meta in zip(find_stages, find_metadatas): + detection_results = self( + outputs, + meta.original_size, + ) + ids = ( + meta.original_image_id if self.use_original_ids else meta.coco_image_id + ) + assert len(detection_results) == len(ids) + for img_id, result in zip(ids, detection_results): + results[img_id.item()] = result + + return results diff --git a/third_party/sam3/sam3/eval/saco_veval_eval.py b/third_party/sam3/sam3/eval/saco_veval_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4e9230b9594246aec53665069f19c05541eea8 --- /dev/null +++ b/third_party/sam3/sam3/eval/saco_veval_eval.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import argparse +import json +import os +from collections import defaultdict + +from iopath.common.file_io import g_pathmgr +from sam3.eval.saco_veval_evaluators import ( + VideoCGF1Evaluator, + VideoPhraseApEvaluator, + VideoPhraseHotaEvaluator, + VideoTetaEvaluator, + YTVISPredFileEvaluator, +) + + +class VEvalEvaluator: + def __init__(self, gt_annot_file: str, eval_res_file: str): + self.gt_annot_file = gt_annot_file + self.eval_res_file = eval_res_file + self.evaluators = [ + # mAP + YTVISPredFileEvaluator(gt_annot_file), + # Phrase AP + VideoPhraseApEvaluator(gt_annot_file), + # TETA + VideoTetaEvaluator(gt_annot_file, use_mask=True, is_exhaustive=True), + # HOTA + VideoPhraseHotaEvaluator(gt_annot_file), + # cgF1 + VideoCGF1Evaluator(gt_annot_file), + ] + + def run_eval(self, pred_file: str): + dataset_results = {} + video_np_results = defaultdict(dict) + for evaluator in self.evaluators: + d_res, v_np_res = evaluator.evaluate(pred_file) + dataset_results.update(d_res) + for (video_id, category_id), res in v_np_res.items(): + video_np_results[(video_id, category_id)].update(res) + + if len(dataset_results) == 0: + dataset_results = {"": 0.0} + + formatted_video_np_results = [ + {"video_id": video_id, "category_id": category_id, **res} + for (video_id, category_id), res in video_np_results.items() + ] + eval_metrics = { + "dataset_results": dataset_results, + "video_np_results": formatted_video_np_results, + } + + with g_pathmgr.open(self.eval_res_file, "w") as f: + json.dump(eval_metrics, f) + + return eval_metrics + + +def run_main_all(dataset_name, args): + gt_annot_file = os.path.join(args.gt_annot_dir, dataset_name + ".json") + pred_file = os.path.join(args.pred_dir, dataset_name + "_preds.json") + eval_res_file = os.path.join(args.eval_res_dir, dataset_name + "_eval_res.json") + print(f"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===") + veval_evaluator = VEvalEvaluator( + gt_annot_file=gt_annot_file, eval_res_file=eval_res_file + ) + _ = veval_evaluator.run_eval(pred_file=pred_file) + + print(f"=== Results saved to {eval_res_file} ===") + + +def main_all(args): + saco_veval_dataset_names = [ + "saco_veval_sav_test", + "saco_veval_sav_val", + "saco_veval_yt1b_test", + "saco_veval_yt1b_val", + "saco_veval_smartglasses_test", + "saco_veval_smartglasses_val", + ] + + # multiprocessing may not really work as inner evaluator also using multiprocessing + # so we just for loop + for dataset_name in saco_veval_dataset_names: + print(f"=== Running evaluation for dataset {dataset_name} ===") + run_main_all(dataset_name=dataset_name, args=args) + + +def main_one(args): + gt_annot_file = args.gt_annot_file + pred_file = args.pred_file + eval_res_file = args.eval_res_file + + print(f"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===") + veval_evaluator = VEvalEvaluator( + gt_annot_file=gt_annot_file, eval_res_file=eval_res_file + ) + _ = veval_evaluator.run_eval(pred_file=pred_file) + + print(f"=== Results saved to {eval_res_file} ===") + + +def main(): + parser = argparse.ArgumentParser(description="Run video grounding evaluators") + + # Create subparsers for different commands + subparsers = parser.add_subparsers(dest="command", required=True) + + # Run evaluation for all datasets + all_parser = subparsers.add_parser("all", help="Run evaluation for all datasets") + all_parser.add_argument( + "--gt_annot_dir", + type=str, + help="Directory that contains the ground truth annotation files", + ) + all_parser.add_argument( + "--pred_dir", + type=str, + help="Directory that contains the prediction files", + ) + all_parser.add_argument( + "--eval_res_dir", + type=str, + help="Directory that contains the eval results files", + ) + all_parser.set_defaults(func=main_all) + + # Run evaluation for one dataset + one_parser = subparsers.add_parser("one", help="Run evaluation for one dataset") + one_parser.add_argument( + "--gt_annot_file", + type=str, + help="Path to the ground truth annotation file", + ) + one_parser.add_argument( + "--pred_file", + type=str, + help="Path to the prediction file", + ) + one_parser.add_argument( + "--eval_res_file", + type=str, + help="Path to the eval results file", + ) + one_parser.set_defaults(func=main_one) + + # Parse and dispatch + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/third_party/sam3/sam3/eval/saco_veval_evaluators.py b/third_party/sam3/sam3/eval/saco_veval_evaluators.py new file mode 100644 index 0000000000000000000000000000000000000000..69a2796cf1794a119f36242abd1bf67710224a78 --- /dev/null +++ b/third_party/sam3/sam3/eval/saco_veval_evaluators.py @@ -0,0 +1,840 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import json +import os +import tempfile +from collections import defaultdict +from typing import Dict, Optional, Sequence, Tuple + +import numpy as np +import pycocotools.mask +from sam3.eval.cgf1_eval import CGF1_METRICS +from sam3.eval.conversion_util import ( + convert_ytbvis_to_cocovid_gt, + convert_ytbvis_to_cocovid_pred, +) +from sam3.eval.hota_eval_toolkit.run_ytvis_eval import run_ytvis_eval +from sam3.eval.teta_eval_toolkit import config, Evaluator, metrics +from sam3.eval.teta_eval_toolkit.datasets import COCO, TAO +from sam3.eval.ytvis_coco_wrapper import YTVIS +from sam3.eval.ytvis_eval import VideoDemoF1Eval, YTVISeval +from sam3.train.nms_helper import process_frame_level_nms, process_track_level_nms + + +def _get_metric_index(metric_name: str, iou_threshold: Optional[float] = None) -> int: + """ + Find the index of a metric in CGF1_METRICS by name and IoU threshold. + + Args: + metric_name: Name of the metric (e.g., "cgF1", "precision", "recall") + iou_threshold: IoU threshold (None for average over 0.5:0.95, or specific value like 0.5, 0.75) + + Returns: + Index of the metric in CGF1_METRICS + + Raises: + ValueError: If metric not found + """ + for idx, metric in enumerate(CGF1_METRICS): + if metric.name == metric_name and metric.iou_threshold == iou_threshold: + return idx + raise ValueError( + f"Metric '{metric_name}' with IoU threshold {iou_threshold} not found in CGF1_METRICS" + ) + + +class BasePredFileEvaluator: + """A base class for evaluating a prediction file.""" + + pass + + +class YTVISPredFileEvaluator(BasePredFileEvaluator): + """Evaluate class mAP for YT-VIS prediction files.""" + + def __init__( + self, + gt_ann_file: str, + dataset_name: str = "video", + iou_types: Optional[Sequence[str]] = None, + ): + self.gt_ann_file = gt_ann_file + self.dataset_name = dataset_name + self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"] + assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types) + + def evaluate(self, pred_file: str) -> Dict[str, float]: + # use our internal video evaluation toolkit for YT-VIS pred file + # (i.e. the same one we're using for video phrase AP) + results = {} + use_cats = True # YT-VIS mAP evaluation uses categories + ytvisGT = YTVIS(self.gt_ann_file, ignore_gt_cats=not use_cats) + # the original YT-VIS GT annotations have uncompressed RLEs ("counts" is an integer list) + # rather than compressed RLEs ("counts" is a string), so we first convert them here. + if "segm" in self.iou_types: + for ann in ytvisGT.dataset["annotations"]: + ann["segmentations"] = [ + _compress_rle(rle) for rle in ann["segmentations"] + ] + + with open(pred_file) as f: + dt = json.load(f) + # Our prediction file saves "video_id" and absolute (unnormalized) boxes. + # Note that we should use the official (original) YT-VIS annotations (i.e. the one + # saved via "scripts/datasets/training/ytvis_split.py", instead of the one saved + # via "scripts/api_db_to_ytvis_json.py") in this evaluator, which contain absolute + # boxes coordinates in its GT annotations. + for d in dt: + d["image_id"] = d["video_id"] + ytvisDT = ytvisGT.loadRes(dt) + + for iou_type in self.iou_types: + ytvisEval = YTVISeval(ytvisGT, ytvisDT, iou_type) + + # set the area ranges for small, medium, and large objects (using + # absolute pixel areas) as in the official YT-VIS evaluation toolkit: + # https://github.com/achalddave/ytvosapi/blob/eca601117c9f86bad084cb91f1d918e9ab665a75/PythonAPI/ytvostools/ytvoseval.py#L538 + ytvisEval.params.areaRng = [ + [0**2, 1e5**2], + [0**2, 128**2], + [128**2, 256**2], + [256**2, 1e5**2], + ] + ytvisEval.params.areaRngLbl = ["all", "small", "medium", "large"] + ytvisEval.params.useCats = use_cats + + ytvisEval.evaluate() + ytvisEval.accumulate() + ytvisEval.summarize() + result_key = f"{self.dataset_name}_{'mask' if iou_type == 'segm' else 'bbox'}_mAP_50_95" + results[result_key] = ytvisEval.stats[0] + + # video-NP level results not supported for `YTVISPredFileEvaluator` yet + video_np_level_results = {} + return results, video_np_level_results + + +class VideoPhraseApEvaluator(BasePredFileEvaluator): + """Evaluate Video Phrase AP with YT-VIS format prediction and GT files.""" + + def __init__( + self, + gt_ann_file: str, + dataset_name: str = "video", + iou_types: Optional[Sequence[str]] = None, + ): + self.gt_ann_file = gt_ann_file + self.dataset_name = dataset_name + self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"] + assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types) + + def evaluate(self, pred_file: str) -> Dict[str, float]: + with open(self.gt_ann_file) as f: + gt = json.load(f) + with open(pred_file) as f: + dt = json.load(f) + # For phrase AP and demo F1 evaluation, we need to remap each pair of (video_id, category_id) to + # a new unique video_id, so that we don't mix detections from different categories under `useCat=False` + gt, dt = remap_video_category_pairs_to_unique_video_ids(gt, dt) + if "segm" in self.iou_types: + for ann in gt["annotations"]: + ann["segmentations"] = [ + _compress_rle(rle) for rle in ann["segmentations"] + ] + for d in dt: + d["image_id"] = d["video_id"] + + results = {} + use_cats = False # Phrase AP evaluation does not use categories + ytvisGT = YTVIS(annotation_file=None, ignore_gt_cats=not use_cats) + ytvisGT.dataset = gt + ytvisGT.createIndex() + ytvisDT = ytvisGT.loadRes(dt) + + for iou_type in self.iou_types: + phraseApEval = YTVISeval(ytvisGT, ytvisDT, iou_type) + + # set the area ranges for small, medium, and large objects (using + # absolute pixel areas) as in the official YT-VIS evaluation toolkit: + # https://github.com/achalddave/ytvosapi/blob/eca601117c9f86bad084cb91f1d918e9ab665a75/PythonAPI/ytvostools/ytvoseval.py#L538 + phraseApEval.params.areaRng = [ + [0**2, 1e5**2], + [0**2, 128**2], + [128**2, 256**2], + [256**2, 1e5**2], + ] + phraseApEval.params.areaRngLbl = ["all", "small", "medium", "large"] + phraseApEval.params.useCats = use_cats + + phraseApEval.evaluate() + phraseApEval.accumulate() + phraseApEval.summarize() + result_prefix = f"{self.dataset_name}" + result_prefix += f"_{'mask' if iou_type == 'segm' else 'bbox'}_phrase_ap" + # fetch Phrase AP results from the corresponding indices in `phraseApEval.stats` + # (see `_summarizeDets` in https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py) + results[result_prefix + "_50_95"] = phraseApEval.stats[0] # IoU=0.5:0.95 + results[result_prefix + "_50"] = phraseApEval.stats[1] # IoU=0.5 + results[result_prefix + "_75"] = phraseApEval.stats[2] # IoU=0.75 + + # video-NP level results not supported for `VideoPhraseApEvaluator` yet + video_np_level_results = {} + return results, video_np_level_results + + +class VideoCGF1Evaluator(BasePredFileEvaluator): + """Evaluate Video Demo F1 with YT-VIS format prediction and GT files.""" + + def __init__( + self, + gt_ann_file: str, + dataset_name: str = "video", + prob_thresh: float = 0.5, + iou_types: Optional[Sequence[str]] = None, + ): + self.gt_ann_file = gt_ann_file + self.dataset_name = dataset_name + self.prob_thresh = prob_thresh + self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"] + assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types) + + def evaluate(self, pred_file: str) -> Dict[str, float]: + with open(self.gt_ann_file) as f: + gt = json.load(f) + with open(pred_file) as f: + dt = json.load(f) + # compute IL_MCC and CG-F1 can only be computed if we have "video_np_pairs" keys in the GT JSON + compute_ilmcc_and_cgf1 = "video_np_pairs" in gt + if not compute_ilmcc_and_cgf1: + print( + f"Warning: IL_MCC and CG-F1 are not computed for {pred_file=} as it does not have 'video_np_pairs' keys in the GT JSON" + ) + # For phrase AP and demo F1 evaluation, we need to remap each pair of (video_id, category_id) to + # a new unique video_id, so that we don't mix detections from different categories under `useCat=False` + gt, dt = remap_video_category_pairs_to_unique_video_ids( + gt, dt, add_negative_np_pairs=compute_ilmcc_and_cgf1 + ) + if "segm" in self.iou_types: + for ann in gt["annotations"]: + ann["segmentations"] = [ + _compress_rle(rle) for rle in ann["segmentations"] + ] + for d in dt: + d["image_id"] = d["video_id"] + + results = {} + use_cats = False # Demo F1 evaluation does not use categories + ytvisGT = YTVIS(annotation_file=None, ignore_gt_cats=not use_cats) + ytvisGT.dataset = gt + ytvisGT.createIndex() + ytvisDT = ytvisGT.loadRes(dt) + + video_np_level_results = {} + for iou_type in self.iou_types: + demoF1Eval = VideoDemoF1Eval(ytvisGT, ytvisDT, iou_type, self.prob_thresh) + + demoF1Eval.params.useCats = use_cats + demoF1Eval.params.areaRng = [[0**2, 1e5**2]] + demoF1Eval.params.areaRngLbl = ["all"] + demoF1Eval.params.maxDets = [100000] + + demoF1Eval.evaluate() + demoF1Eval.accumulate() + demoF1Eval.summarize() + result_prefix = f"{self.dataset_name}" + result_prefix += f"_{'mask' if iou_type == 'segm' else 'bbox'}_demo" + + stats = demoF1Eval.stats + + if compute_ilmcc_and_cgf1: + # Average IoU threshold (0.5:0.95) + cgf1_micro_avg_idx = _get_metric_index("cgF1", None) + positive_micro_f1_avg_idx = _get_metric_index("positive_micro_F1", None) + ilmcc_avg_idx = _get_metric_index("IL_MCC", None) + results[result_prefix + "_cgf1_micro_50_95"] = stats[cgf1_micro_avg_idx] + results[result_prefix + "_ilmcc_50_95"] = stats[ilmcc_avg_idx] + results[result_prefix + "_positive_micro_f1_50_95"] = stats[ + positive_micro_f1_avg_idx + ] + + # IoU = 0.5 + cgf1_micro_50_idx = _get_metric_index("cgF1", 0.5) + positive_micro_f1_50_idx = _get_metric_index("positive_micro_F1", 0.5) + results[result_prefix + "_cgf1_micro_50"] = stats[cgf1_micro_50_idx] + results[result_prefix + "_ilmcc_50"] = float( + np.array(stats[cgf1_micro_50_idx]) + / np.array(stats[positive_micro_f1_50_idx]) + ) + results[result_prefix + "_positive_micro_f1_50"] = stats[ + positive_micro_f1_50_idx + ] + + # IoU = 0.75 + cgf1_micro_75_idx = _get_metric_index("cgF1", 0.75) + positive_micro_f1_75_idx = _get_metric_index("positive_micro_F1", 0.75) + results[result_prefix + "_cgf1_micro_75"] = stats[cgf1_micro_75_idx] + results[result_prefix + "_ilmcc_75"] = float( + np.array(stats[cgf1_micro_75_idx]) + / np.array(stats[positive_micro_f1_75_idx]) + ) + results[result_prefix + "_positive_micro_f1_75"] = stats[ + positive_micro_f1_75_idx + ] + + self.extract_video_np_level_results(demoF1Eval, video_np_level_results) + + return results, video_np_level_results + + def extract_video_np_level_results(self, demoF1Eval, video_np_level_results): + """Aggregate statistics for video-level metrics.""" + num_iou_thrs = len(demoF1Eval.params.iouThrs) + iou_50_index = int(np.where(demoF1Eval.params.iouThrs == 0.5)[0]) + iou_75_index = int(np.where(demoF1Eval.params.iouThrs == 0.75)[0]) + + result_prefix = "mask" if demoF1Eval.params.iouType == "segm" else "bbox" + + assert len(demoF1Eval.evalImgs) == len(demoF1Eval.cocoGt.dataset["images"]) + for i, video in enumerate(demoF1Eval.cocoGt.dataset["images"]): + # the original video id and category id before remapping + video_id = video["orig_video_id"] + category_id = video["orig_category_id"] + eval_img_dict = demoF1Eval.evalImgs[i] + + TPs = eval_img_dict.get("TPs", np.zeros(num_iou_thrs, dtype=np.int64)) + FPs = eval_img_dict.get("FPs", np.zeros(num_iou_thrs, dtype=np.int64)) + FNs = eval_img_dict.get("FNs", np.zeros(num_iou_thrs, dtype=np.int64)) + assert len(TPs) == len(FPs) == len(FNs) == num_iou_thrs + # F1 = 2*TP / (2*TP + FP + FN), and we set F1 to 1.0 if denominator is 0 + denominator = 2 * TPs + FPs + FNs + F1s = np.where(denominator > 0, 2 * TPs / np.maximum(denominator, 1), 1.0) + local_results = { + f"{result_prefix}_TP_50_95": float(TPs.mean()), + f"{result_prefix}_FP_50_95": float(FPs.mean()), + f"{result_prefix}_FN_50_95": float(FNs.mean()), + f"{result_prefix}_F1_50_95": float(F1s.mean()), + f"{result_prefix}_TP_50": float(TPs[iou_50_index]), + f"{result_prefix}_FP_50": float(FPs[iou_50_index]), + f"{result_prefix}_FN_50": float(FNs[iou_50_index]), + f"{result_prefix}_F1_50": float(F1s[iou_50_index]), + f"{result_prefix}_TP_75": float(TPs[iou_75_index]), + f"{result_prefix}_FP_75": float(FPs[iou_75_index]), + f"{result_prefix}_FN_75": float(FNs[iou_75_index]), + f"{result_prefix}_F1_75": float(F1s[iou_75_index]), + } + if (video_id, category_id) not in video_np_level_results: + video_np_level_results[(video_id, category_id)] = {} + video_np_level_results[(video_id, category_id)].update(local_results) + + +class VideoTetaEvaluator(BasePredFileEvaluator): + """Evaluate TETA metric using YouTubeVIS format prediction and GT files.""" + + def __init__( + self, + gt_ann_file: str, + dataset_name: str = "video", + tracker_name: str = "Sam3", + nms_threshold: float = 0.5, + nms_strategy: str = "none", # "track", "frame", or "none" + prob_thresh: float = 0.5, + is_exhaustive: bool = False, + use_mask: bool = False, + num_parallel_cores: int = 8, + ): + self.gt_ann_file = gt_ann_file + self.dataset_name = dataset_name + self.tracker_name = tracker_name + self.nms_threshold = nms_threshold + self.nms_strategy = nms_strategy.lower() # Convert to lowercase for consistency + self.prob_thresh = prob_thresh + self.metric_prefix = "TETA" + self.is_exhaustive = is_exhaustive + self.use_mask = use_mask + self.num_parallel_cores = num_parallel_cores + + # Verify NMS strategy is valid + valid_strategies = ["track", "frame", "none"] + print("current nms_strategy:", self.nms_strategy) + if self.nms_strategy not in valid_strategies: + raise ValueError( + f"Invalid NMS strategy: {self.nms_strategy}. Must be one of {valid_strategies}" + ) + + print(f"Initialized VideoTetaEvaluator with NMS strategy: {self.nms_strategy}") + print(f"Probability threshold set to: {self.prob_thresh}") + print(f"Dataset exhaustivity set to: {self.is_exhaustive}") + print(f"Tracker name set to: {self.tracker_name}") + print(f"Dataset name set to: {self.dataset_name}") + print(f"Use mask set to: {self.use_mask}") + + def process_predictions(self, pred_file: str, tmp_dir: str) -> str: + """Process predictions with selected NMS strategy""" + with open(pred_file, "r") as f: + raw_preds = json.load(f) + print(f"Processing predictions with {self.nms_strategy} NMS strategy") + + # Filter by score threshold + if self.prob_thresh > 0: + raw_preds = [d for d in raw_preds if d["score"] >= self.prob_thresh] + print( + f"Filtered to {len(raw_preds)} predictions with score >= {self.prob_thresh}" + ) + # Group predictions by video_id + video_groups = defaultdict(list) + for pred in raw_preds: + video_groups[pred["video_id"]].append(pred) + # Process based on NMS strategy + if self.nms_strategy == "track": + process_track_level_nms(video_groups, nms_threshold=self.nms_threshold) + elif self.nms_strategy == "frame": + process_frame_level_nms(video_groups, nms_threshold=self.nms_threshold) + elif self.nms_strategy == "none": + print("Skipping NMS processing as strategy is set to 'none'") + # No processing needed for "none" strategy + # Save processed predictions + processed_preds = [ + track for tracks in video_groups.values() for track in tracks + ] + processed_path = os.path.join(tmp_dir, "processed_preds.json") + with open(processed_path, "w") as f: + json.dump(processed_preds, f) + + print(f"Saved processed predictions to {processed_path}") + return processed_path + + def evaluate(self, pred_file: str) -> Tuple[Dict[str, float], Dict]: + """Main evaluation method""" + + print(f"Evaluating TETA Metric with {self.nms_strategy.upper()} NMS strategy") + with tempfile.TemporaryDirectory() as tmp_dir: + # Process predictions first + processed_pred_file = self.process_predictions(pred_file, tmp_dir) + + # Convert GT to COCO-vid format + gt_dir = os.path.join(tmp_dir, "gt") + os.makedirs(gt_dir, exist_ok=True) + gt_coco_path = os.path.join(gt_dir, "annotations.json") + convert_ytbvis_to_cocovid_gt(self.gt_ann_file, gt_coco_path) + + # Convert processed predictions to COCO-vid format + pred_dir = os.path.join(tmp_dir, "predictions") + tracker_dir = os.path.join(pred_dir, self.tracker_name) + os.makedirs(tracker_dir, exist_ok=True) + pred_coco_path = os.path.join(tracker_dir, "track_results_cocofmt.json") + convert_ytbvis_to_cocovid_pred( + youtubevis_pred_path=processed_pred_file, + converted_dataset_path=gt_coco_path, + output_path=pred_coco_path, + ) + # Configure TETA evaluator + default_eval_config = config.get_default_eval_config() + default_eval_config["PRINT_ONLY_COMBINED"] = True + default_eval_config["DISPLAY_LESS_PROGRESS"] = True + default_eval_config["OUTPUT_TEMP_RAW_DATA"] = True + default_eval_config["NUM_PARALLEL_CORES"] = self.num_parallel_cores + default_dataset_config = config.get_default_dataset_config() + default_dataset_config["TRACKERS_TO_EVAL"] = [self.tracker_name] + default_dataset_config["GT_FOLDER"] = gt_dir + default_dataset_config["OUTPUT_FOLDER"] = pred_dir + default_dataset_config["TRACKER_SUB_FOLDER"] = tracker_dir + default_dataset_config["USE_MASK"] = self.use_mask + + evaluator = Evaluator(default_eval_config) + if self.is_exhaustive: + dataset_list = [COCO(default_dataset_config)] + dataset_parsing_key = "COCO" + else: + dataset_list = [TAO(default_dataset_config)] + dataset_parsing_key = "TAO" + + # Run evaluation + eval_results, _ = evaluator.evaluate( + dataset_list, [metrics.TETA(exhaustive=self.is_exhaustive)] + ) + + # Extract and format results + results = { + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_teta": float( + eval_results[dataset_parsing_key]["TETA"][0] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_loc_a": float( + eval_results[dataset_parsing_key]["TETA"][1] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_assoc_a": float( + eval_results[dataset_parsing_key]["TETA"][2] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_cls_a": float( + eval_results[dataset_parsing_key]["TETA"][3] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_loc_re": float( + eval_results[dataset_parsing_key]["TETA"][4] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_loc_pr": float( + eval_results[dataset_parsing_key]["TETA"][5] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_assoc_re": float( + eval_results[dataset_parsing_key]["TETA"][6] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_assoc_pr": float( + eval_results[dataset_parsing_key]["TETA"][7] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_cls_re": float( + eval_results[dataset_parsing_key]["TETA"][8] + ), + f"{self.dataset_name}_{'mask' if self.use_mask else 'bbox'}_cls_pr": float( + eval_results[dataset_parsing_key]["TETA"][9] + ), + } + + # video-NP level results not supported for `VideoTetaEvaluator` yet + video_np_level_results = {} + return results, video_np_level_results + + +class VideoPhraseHotaEvaluator(BasePredFileEvaluator): + """Evaluate Video Phrase HOTA with YT-VIS format prediction and GT files.""" + + def __init__( + self, + gt_ann_file: str, + dataset_name: str = "video", + prob_thresh: float = 0.5, + iou_types: Optional[Sequence[str]] = None, + compute_video_mot_hota: bool = False, + ): + self.gt_ann_file = gt_ann_file + self.dataset_name = dataset_name + self.prob_thresh = prob_thresh + self.metric_prefix = "phrase" + # the list of metrics to collect from the HOTA evaluation results + self.metric_to_collect = [ + "HOTA", + "DetA", + "AssA", + "DetRe", + "DetPr", + "AssRe", + "AssPr", + "LocA", + "OWTA", + ] + self.iou_types = list(iou_types) if iou_types is not None else ["bbox", "segm"] + assert all(iou_type in ["bbox", "segm"] for iou_type in self.iou_types) + + # If True, compute video MOT HOTA, aggregating predictions/GT from all categories. + self.compute_video_mot_hota = compute_video_mot_hota + + def evaluate(self, pred_file: str) -> Dict[str, float]: + # use the YT-VIS evaluation toolkit in TrackEval + + with open(self.gt_ann_file) as f: + gt = json.load(f) + with open(pred_file) as f: + dt = json.load(f) + # keep only predictions with score above the probability threshold + dt = [d for d in dt if d["score"] > self.prob_thresh] + for d in dt: + assert len(d["areas"]) == len(d["bboxes"]) + assert len(d["areas"]) == len(d["segmentations"]) + # remove empty boxes (otherwise they will count as false positives for during + # per-frame detection accuracy in HOTA evaluation) + for t in range(len(d["bboxes"])): + bbox = d["bboxes"][t] + if d["areas"][t] == 0 or bbox is None or all(x == 0 for x in bbox): + d["segmentations"][t] = None + d["bboxes"][t] = None + d["areas"][t] = None + # check that box occurence and mask occurence are consistent + for bbox, mask, area in zip(d["bboxes"], d["segmentations"], d["areas"]): + assert (area is None) == (bbox is None) + assert (area is None) == (mask is None) + # set all scores to 1.0 for HOTA evaluation (just like Demo F1, the exact score + # value is not used in HOTA metrics; it will be treated as a detection prediction + # as long as its score is above the threshold) + d["score"] = 1.0 + + # remap the GT and DT annotations for phrase HOTA evaluation + gt = _fill_in_ann_height_width(gt) + if not self.compute_video_mot_hota: + # remap the GT and DT annotations for phrase HOTA evaluation + gt, dt = self._remap_gt_dt(gt, dt) + else: + # Compute video-level MOT HOTA + # Apply track-level NMS + video_groups = defaultdict(list) + for pred in dt: + video_groups[pred["video_id"]].append(pred) + process_track_level_nms(video_groups, nms_threshold=0.5) + dt = [track for tracks in video_groups.values() for track in tracks] + + # Remap GT track ids for class-agnostic HOTA + gt, dt = remap_gt_dt_class_agnostic(gt, dt) + + # run the HOTA evaluation using TrackEval on the remapped (video_id, category_id) pairs + out_dict = {} + video_np_level_results = {} + for iou_type in self.iou_types: + output_res, _ = run_ytvis_eval( + args=[ + "--METRICS", + "HOTA", + "--IOU_TYPE", + iou_type, + "--DATASET_NAME", + self.dataset_name, + "--USE_PARALLEL", + "True", + "--NUM_PARALLEL_CORES", + "8", + "--PLOT_CURVES", + "False", + "--LOG_ON_ERROR", + "None", + "--PRINT_ONLY_COMBINED", + "True", + "--OUTPUT_SUMMARY", + "False", + "--OUTPUT_DETAILED", + "False", + "--TIME_PROGRESS", + "False", + "--PRINT_CONFIG", + "False", + ], + gt_json=gt, + dt_json=dt, + ) + self.extract_video_np_level_results( + iou_type=iou_type, + remapped_gt=gt, + raw_results=output_res[self.dataset_name]["tracker"], + video_np_level_results=video_np_level_results, + ) + + def _summarize_results(output_res, iou_type, field, suffix): + eval_res = output_res[self.dataset_name]["tracker"][field] + result_prefix = f"{self.dataset_name}_{'mask' if iou_type == 'segm' else 'bbox'}_{suffix}" + for metric_name in self.metric_to_collect: + eval_res_hota = eval_res["cls_comb_cls_av"]["HOTA"] + result_key = f"{result_prefix}_{self.metric_prefix}_{metric_name}" + result_value = float(np.mean(eval_res_hota[metric_name])) + out_dict[result_key] = result_value + + _summarize_results(output_res, iou_type, "COMBINED_SEQ", "all") + if "COMBINED_SEQ_CHALLENGING" in output_res[self.dataset_name]["tracker"]: + _summarize_results( + output_res, iou_type, "COMBINED_SEQ_CHALLENGING", "challenging" + ) + + # video-NP level results not supported for `VideoPhraseHotaEvaluator` yet + return out_dict, video_np_level_results + + def _remap_gt_dt(self, gt, dt): + # For phrase HOTA evaluation, we need to remap each pair of (video_id, category_id) to + # a new unique video_id, so that we don't mix detections from different categories + gt, dt = remap_video_category_pairs_to_unique_video_ids(gt, dt) + # We further map all the categories to category_id=1 in HOTA evaluation toolkit + # for phrase HOTA (similar to "useCat=False" for video phrase AP) + remapped_category_id = 1 + gt["categories"] = [ + { + "supercategory": "object", + "id": remapped_category_id, + "name": "_REMAPPED_FOR_PHRASE_METRICS_", + } + ] + for ann in gt["annotations"]: + ann["category_id"] = remapped_category_id + for d in dt: + d["category_id"] = remapped_category_id + # To be compatible with the TrackEval YT-VIS evaluation toolkit, we need to give + # unique filenames to each remapped video, so we add remapped video_id as prefix. + for video in gt["videos"]: + new_video_id = video["id"] + video["file_names"] = [ + f"remapped_vid_{new_video_id:012d}/{name}" + for name in video["file_names"] + ] + return gt, dt + + def extract_video_np_level_results( + self, iou_type, remapped_gt, raw_results, video_np_level_results + ): + """Aggregate statistics for video-level metrics.""" + result_prefix = "mask" if iou_type == "segm" else "bbox" + for video in remapped_gt["videos"]: + # the original video id and category id before remapping + video_id = video["orig_video_id"] + category_id = video["orig_category_id"] + video_key = f"remapped_vid_{video['id']:012d}" + results = raw_results[video_key]["_REMAPPED_FOR_PHRASE_METRICS_"]["HOTA"] + + local_results = {} + for metric_name in self.metric_to_collect: + result_key = f"{result_prefix}_{metric_name}" + local_results[result_key] = float(results[metric_name].mean()) + if (video_id, category_id) not in video_np_level_results: + video_np_level_results[(video_id, category_id)] = {} + video_np_level_results[(video_id, category_id)].update(local_results) + + +class VideoClassBasedHotaEvaluator(VideoPhraseHotaEvaluator): + def __init__( + self, + gt_ann_file: str, + dataset_name: str = "video", + prob_thresh: float = 0.5, + ): + super().__init__(gt_ann_file, dataset_name, prob_thresh) + self.metric_prefix = "class" + + def _remap_gt_dt(self, gt, dt): + return gt, dt # no remapping needed for class-based HOTA evaluation + + def extract_video_np_level_results(self, *args, **kwargs): + pass # no video-NP level results for class-based HOTA evaluation + + +def _compress_rle(rle): + """Convert RLEs from uncompressed (integer list) to compressed (string) format.""" + if rle is None: + return None + if isinstance(rle["counts"], list): + rle = pycocotools.mask.frPyObjects(rle, rle["size"][0], rle["size"][1]) + rle["counts"] = rle["counts"].decode() + return rle + + +def remap_video_category_pairs_to_unique_video_ids( + gt_json, dt_json, add_negative_np_pairs=False +): + """ + Remap each pair of (video_id, category_id) to a new unique video_id. This is useful + for phrase AP and demo F1 evaluation on videos, where we have `useCat=False` and + rely on separating different NPs (from the same video) into different new video ids, + so that we don't mix detections from different categories in computeIoU under `useCat=False`. + + This is consistent with how do we phrase AP and demo F1 evaluation on images, where we + use a remapped unique coco_image_id for each image-NP pair (based in its query["id"] in + CustomCocoDetectionAPI.load_queries in modulated_detection_api.py) + """ + # collect the unique video_id-category_id pairs + video_id_to_video = {v["id"]: v for v in gt_json["videos"]} + video_id_category_id_pairs = set() + for pred in dt_json: + video_id_category_id_pairs.add((pred["video_id"], pred["category_id"])) + for ann in gt_json["annotations"]: + video_id_category_id_pairs.add((ann["video_id"], ann["category_id"])) + + # assign the video_id-category_id pairs to unique video ids + video_id_category_id_pairs = sorted(video_id_category_id_pairs) + video_id_category_id_to_new_video_id = { + pair: (i + 1) for i, pair in enumerate(video_id_category_id_pairs) + } + # also map the negative NP pairs -- this is needed for IL_MCC and CG-F1 evaluation + if add_negative_np_pairs: + for vnp in gt_json["video_np_pairs"]: + pair = (vnp["video_id"], vnp["category_id"]) + if pair not in video_id_category_id_to_new_video_id: + video_id_category_id_to_new_video_id[pair] = ( + len(video_id_category_id_to_new_video_id) + 1 + ) + + # map the "video_id" in predictions + for pred in dt_json: + pred["video_id"] = video_id_category_id_to_new_video_id[ + (pred["video_id"], pred["category_id"]) + ] + # map the "video_id" in gt_json["annotations"] + for ann in gt_json["annotations"]: + ann["video_id"] = video_id_category_id_to_new_video_id[ + (ann["video_id"], ann["category_id"]) + ] + # map and duplicate gt_json["videos"] + new_videos = [] + for ( + video_id, + category_id, + ), new_video_id in video_id_category_id_to_new_video_id.items(): + video = video_id_to_video[video_id].copy() + video["id"] = new_video_id + # preserve the original video_id and category_id of each remapped video entry, + # so that we can associate sample-level eval metrics with the original video-NP pairs + video["orig_video_id"] = video_id + video["orig_category_id"] = category_id + new_videos.append(video) + gt_json["videos"] = new_videos + + return gt_json, dt_json + + +def remap_gt_dt_class_agnostic(gt, dt): + """ + For class-agnostic HOTA, merge all GT tracks for each video (across NPs), + ensure unique track_ids, and set all category_id to 1. + Also, add orig_video_id and orig_category_id for compatibility. + """ + # 1. Remap all GT track_ids to be unique per video + gt_anns_by_video = defaultdict(list) + for ann in gt["annotations"]: + gt_anns_by_video[ann["video_id"]].append(ann) + + # Ensure unique track ids across tracks of all videos + next_tid = 1 + for _, anns in gt_anns_by_video.items(): + # Map old track_ids to new unique ones + old_to_new_tid = {} + for ann in anns: + old_tid = ann["id"] + if old_tid not in old_to_new_tid: + old_to_new_tid[old_tid] = next_tid + next_tid += 1 + ann["id"] = old_to_new_tid[old_tid] + # Set category_id to 1 for class-agnostic + ann["category_id"] = 1 + + # Set all GT categories to a single category + gt["categories"] = [ + { + "supercategory": "object", + "id": 1, + "name": "_REMAPPED_FOR_PHRASE_METRICS_", + } + ] + + # Add orig_video_id and orig_category_id to each video for compatibility + anns_by_video = defaultdict(list) + for ann in gt["annotations"]: + anns_by_video[ann["video_id"]].append(ann) + for video in gt["videos"]: + video["orig_video_id"] = video["id"] + # Use the first annotation's original category_id if available, else None + orig_cat = ( + anns_by_video[video["id"]][0]["category_id"] + if anns_by_video[video["id"]] + else None + ) + video["orig_category_id"] = orig_cat + video["file_names"] = [ + f"remapped_vid_{video['id']:012d}/{name}" for name in video["file_names"] + ] + + # Set all DT category_id to 1 + for d in dt: + d["category_id"] = 1 + return gt, dt + + +def _fill_in_ann_height_width(gt_json): + """Fill in missing height/width in GT annotations from its video info.""" + video_id_to_video = {v["id"]: v for v in gt_json["videos"]} + for ann in gt_json["annotations"]: + if "height" not in ann or "width" not in ann: + video = video_id_to_video[ann["video_id"]] + if "height" not in ann: + ann["height"] = video["height"] + if "width" not in ann: + ann["width"] = video["width"] + + return gt_json diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/__init__.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9420ee6f6a706e7b3d29e85cda71aa794727fd38 --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/__init__.py @@ -0,0 +1,7 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +from . import config, datasets, metrics, utils +from .eval import Evaluator diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/_timing.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/_timing.py new file mode 100644 index 0000000000000000000000000000000000000000..72f195d4f06336565459104193138ef49da35da3 --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/_timing.py @@ -0,0 +1,71 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +import inspect +from functools import wraps +from time import perf_counter + +DO_TIMING = False +DISPLAY_LESS_PROGRESS = False +timer_dict = {} +counter = 0 + + +def time(f): + @wraps(f) + def wrap(*args, **kw): + if DO_TIMING: + # Run function with timing + ts = perf_counter() + result = f(*args, **kw) + te = perf_counter() + tt = te - ts + + # Get function name + arg_names = inspect.getfullargspec(f)[0] + if arg_names[0] == "self" and DISPLAY_LESS_PROGRESS: + return result + elif arg_names[0] == "self": + method_name = type(args[0]).__name__ + "." + f.__name__ + else: + method_name = f.__name__ + + # Record accumulative time in each function for analysis + if method_name in timer_dict.keys(): + timer_dict[method_name] += tt + else: + timer_dict[method_name] = tt + + # If code is finished, display timing summary + if method_name == "Evaluator.evaluate": + print("") + print("Timing analysis:") + for key, value in timer_dict.items(): + print("%-70s %2.4f sec" % (key, value)) + else: + # Get function argument values for printing special arguments of interest + arg_titles = ["tracker", "seq", "cls"] + arg_vals = [] + for i, a in enumerate(arg_names): + if a in arg_titles: + arg_vals.append(args[i]) + arg_text = "(" + ", ".join(arg_vals) + ")" + + # Display methods and functions with different indentation. + if arg_names[0] == "self": + print("%-74s %2.4f sec" % (" " * 4 + method_name + arg_text, tt)) + elif arg_names[0] == "test": + pass + else: + global counter + counter += 1 + print("%i %-70s %2.4f sec" % (counter, method_name + arg_text, tt)) + + return result + else: + # If config["TIME_PROGRESS"] is false, or config["USE_PARALLEL"] is true, run functions normally without timing. + return f(*args, **kw) + + return wrap diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/config.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/config.py new file mode 100644 index 0000000000000000000000000000000000000000..003d8f5ba2f8d8051bb0e78357584a5dcb5d1481 --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/config.py @@ -0,0 +1,155 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +"""Config.""" +import argparse +import os + + +def parse_configs(): + """Parse command line.""" + default_eval_config = get_default_eval_config() + default_eval_config["DISPLAY_LESS_PROGRESS"] = True + default_dataset_config = get_default_dataset_config() + default_metrics_config = {"METRICS": ["TETA"]} + config = { + **default_eval_config, + **default_dataset_config, + **default_metrics_config, + } + parser = argparse.ArgumentParser() + for setting in config.keys(): + if type(config[setting]) == list or type(config[setting]) == type(None): + parser.add_argument("--" + setting, nargs="+") + else: + parser.add_argument("--" + setting) + args = parser.parse_args().__dict__ + for setting in args.keys(): + if args[setting] is not None: + if type(config[setting]) == type(True): + if args[setting] == "True": + x = True + elif args[setting] == "False": + x = False + else: + raise Exception( + f"Command line parameter {setting} must be True/False" + ) + elif type(config[setting]) == type(1): + x = int(args[setting]) + elif type(args[setting]) == type(None): + x = None + else: + x = args[setting] + config[setting] = x + eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} + dataset_config = { + k: v for k, v in config.items() if k in default_dataset_config.keys() + } + metrics_config = { + k: v for k, v in config.items() if k in default_metrics_config.keys() + } + + return eval_config, dataset_config, metrics_config + + +def get_default_eval_config(): + """Returns the default config values for evaluation.""" + code_path = get_code_path() + default_config = { + "USE_PARALLEL": True, + "NUM_PARALLEL_CORES": 8, + "BREAK_ON_ERROR": True, + "RETURN_ON_ERROR": False, + "LOG_ON_ERROR": os.path.join(code_path, "error_log.txt"), + "PRINT_RESULTS": True, + "PRINT_ONLY_COMBINED": True, + "PRINT_CONFIG": True, + "TIME_PROGRESS": True, + "DISPLAY_LESS_PROGRESS": True, + "OUTPUT_SUMMARY": True, + "OUTPUT_EMPTY_CLASSES": True, + "OUTPUT_TEM_RAW_DATA": True, + "OUTPUT_PER_SEQ_RES": True, + } + return default_config + + +def get_default_dataset_config(): + """Default class config values""" + code_path = get_code_path() + default_config = { + "GT_FOLDER": os.path.join( + code_path, "data/gt/tao/tao_training" + ), # Location of GT data + "TRACKERS_FOLDER": os.path.join( + code_path, "data/trackers/tao/tao_training" + ), # Trackers location + "OUTPUT_FOLDER": None, # Where to save eval results (if None, same as TRACKERS_FOLDER) + "TRACKERS_TO_EVAL": ['TETer'], # Filenames of trackers to eval (if None, all in folder) + "CLASSES_TO_EVAL": None, # Classes to eval (if None, all classes) + "SPLIT_TO_EVAL": "training", # Valid: 'training', 'val' + "PRINT_CONFIG": True, # Whether to print current config + "TRACKER_SUB_FOLDER": "data", # Tracker files are in TRACKER_FOLDER/tracker_name/TRACKER_SUB_FOLDER + "OUTPUT_SUB_FOLDER": "", # Output files are saved in OUTPUT_FOLDER/tracker_name/OUTPUT_SUB_FOLDER + "TRACKER_DISPLAY_NAMES": None, # Names of trackers to display, if None: TRACKERS_TO_EVAL + "MAX_DETECTIONS": 0, # Number of maximal allowed detections per image (0 for unlimited) + "USE_MASK": False, # Whether to use mask data for evaluation + } + return default_config + + +def init_config(config, default_config, name=None): + """Initialize non-given config values with defaults.""" + if config is None: + config = default_config + else: + for k in default_config.keys(): + if k not in config.keys(): + config[k] = default_config[k] + if name and config["PRINT_CONFIG"]: + print("\n%s Config:" % name) + for c in config.keys(): + print("%-20s : %-30s" % (c, config[c])) + return config + + +def update_config(config): + """ + Parse the arguments of a script and updates the config values for a given value if specified in the arguments. + :param config: the config to update + :return: the updated config + """ + parser = argparse.ArgumentParser() + for setting in config.keys(): + if type(config[setting]) == list or type(config[setting]) == type(None): + parser.add_argument("--" + setting, nargs="+") + else: + parser.add_argument("--" + setting) + args = parser.parse_args().__dict__ + for setting in args.keys(): + if args[setting] is not None: + if type(config[setting]) == type(True): + if args[setting] == "True": + x = True + elif args[setting] == "False": + x = False + else: + raise Exception( + "Command line parameter " + setting + "must be True or False" + ) + elif type(config[setting]) == type(1): + x = int(args[setting]) + elif type(args[setting]) == type(None): + x = None + else: + x = args[setting] + config[setting] = x + return config + + +def get_code_path(): + """Get base path where code is""" + return os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/__init__.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef17b5cc4fc6ff8864242fb12723346f22dcc88 --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/__init__.py @@ -0,0 +1,7 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe +"""Datasets.""" +from .coco import COCO +from .tao import TAO diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/_base_dataset.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/_base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a155135142a56e1e53694bb4f6769887eff897 --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/_base_dataset.py @@ -0,0 +1,381 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +import csv +import io +import os +import traceback +import zipfile +from abc import ABC, abstractmethod +from copy import deepcopy + +import numpy as np + +from .. import _timing +from ..utils import TrackEvalException + + +class _BaseDataset(ABC): + @abstractmethod + def __init__(self): + self.tracker_list = None + self.seq_list = None + self.class_list = None + self.output_fol = None + self.output_sub_fol = None + self.should_classes_combine = True + self.use_super_categories = False + + # Functions to implement: + + @abstractmethod + def _load_raw_file(self, tracker, seq, is_gt): + ... + + @_timing.time + @abstractmethod + def get_preprocessed_seq_data(self, raw_data, cls): + ... + + @abstractmethod + def _calculate_similarities(self, gt_dets_t, tracker_dets_t): + ... + + # Helper functions for all datasets: + + @classmethod + def get_class_name(cls): + return cls.__name__ + + def get_name(self): + return self.get_class_name() + + def get_output_fol(self, tracker): + return os.path.join(self.output_fol, tracker, self.output_sub_fol) + + def get_display_name(self, tracker): + """Can be overwritten if the trackers name (in files) is different to how it should be displayed. + By default this method just returns the trackers name as is. + """ + return tracker + + def get_eval_info(self): + """Return info about the dataset needed for the Evaluator""" + return self.tracker_list, self.seq_list, self.class_list + + @_timing.time + def get_raw_seq_data(self, tracker, seq): + """Loads raw data (tracker and ground-truth) for a single tracker on a single sequence. + Raw data includes all of the information needed for both preprocessing and evaluation, for all classes. + A later function (get_processed_seq_data) will perform such preprocessing and extract relevant information for + the evaluation of each class. + + This returns a dict which contains the fields: + [num_timesteps]: integer + [gt_ids, tracker_ids, gt_classes, tracker_classes, tracker_confidences]: + list (for each timestep) of 1D NDArrays (for each det). + [gt_dets, tracker_dets, gt_crowd_ignore_regions]: list (for each timestep) of lists of detections. + [similarity_scores]: list (for each timestep) of 2D NDArrays. + [gt_extras]: dict (for each extra) of lists (for each timestep) of 1D NDArrays (for each det). + + gt_extras contains dataset specific information used for preprocessing such as occlusion and truncation levels. + + Note that similarities are extracted as part of the dataset and not the metric, because almost all metrics are + independent of the exact method of calculating the similarity. However datasets are not (e.g. segmentation + masks vs 2D boxes vs 3D boxes). + We calculate the similarity before preprocessing because often both preprocessing and evaluation require it and + we don't wish to calculate this twice. + We calculate similarity between all gt and tracker classes (not just each class individually) to allow for + calculation of metrics such as class confusion matrices. Typically the impact of this on performance is low. + """ + # Load raw data. + raw_gt_data = self._load_raw_file(tracker, seq, is_gt=True) + raw_tracker_data = self._load_raw_file(tracker, seq, is_gt=False) + raw_data = {**raw_tracker_data, **raw_gt_data} # Merges dictionaries + + # Calculate similarities for each timestep. + similarity_scores = [] + for _, (gt_dets_t, tracker_dets_t) in enumerate( + zip(raw_data["gt_dets"], raw_data["tk_dets"]) + ): + ious = self._calculate_similarities(gt_dets_t, tracker_dets_t) + similarity_scores.append(ious) + raw_data["similarity_scores"] = similarity_scores + return raw_data + + @staticmethod + def _load_simple_text_file( + file, + time_col=0, + id_col=None, + remove_negative_ids=False, + valid_filter=None, + crowd_ignore_filter=None, + convert_filter=None, + is_zipped=False, + zip_file=None, + force_delimiters=None, + ): + """Function that loads data which is in a commonly used text file format. + Assumes each det is given by one row of a text file. + There is no limit to the number or meaning of each column, + however one column needs to give the timestep of each det (time_col) which is default col 0. + + The file dialect (deliminator, num cols, etc) is determined automatically. + This function automatically separates dets by timestep, + and is much faster than alternatives such as np.loadtext or pandas. + + If remove_negative_ids is True and id_col is not None, dets with negative values in id_col are excluded. + These are not excluded from ignore data. + + valid_filter can be used to only include certain classes. + It is a dict with ints as keys, and lists as values, + such that a row is included if "row[key].lower() is in value" for all key/value pairs in the dict. + If None, all classes are included. + + crowd_ignore_filter can be used to read crowd_ignore regions separately. It has the same format as valid filter. + + convert_filter can be used to convert value read to another format. + This is used most commonly to convert classes given as string to a class id. + This is a dict such that the key is the column to convert, and the value is another dict giving the mapping. + + Optionally, input files could be a zip of multiple text files for storage efficiency. + + Returns read_data and ignore_data. + Each is a dict (with keys as timesteps as strings) of lists (over dets) of lists (over column values). + Note that all data is returned as strings, and must be converted to float/int later if needed. + Note that timesteps will not be present in the returned dict keys if there are no dets for them + """ + + if remove_negative_ids and id_col is None: + raise TrackEvalException( + "remove_negative_ids is True, but id_col is not given." + ) + if crowd_ignore_filter is None: + crowd_ignore_filter = {} + if convert_filter is None: + convert_filter = {} + try: + if is_zipped: # Either open file directly or within a zip. + if zip_file is None: + raise TrackEvalException( + "is_zipped set to True, but no zip_file is given." + ) + archive = zipfile.ZipFile(os.path.join(zip_file), "r") + fp = io.TextIOWrapper(archive.open(file, "r")) + else: + fp = open(file) + read_data = {} + crowd_ignore_data = {} + fp.seek(0, os.SEEK_END) + # check if file is empty + if fp.tell(): + fp.seek(0) + dialect = csv.Sniffer().sniff( + fp.readline(), delimiters=force_delimiters + ) # Auto determine structure. + dialect.skipinitialspace = ( + True # Deal with extra spaces between columns + ) + fp.seek(0) + reader = csv.reader(fp, dialect) + for row in reader: + try: + # Deal with extra trailing spaces at the end of rows + if row[-1] in "": + row = row[:-1] + timestep = str(int(float(row[time_col]))) + # Read ignore regions separately. + is_ignored = False + for ignore_key, ignore_value in crowd_ignore_filter.items(): + if row[ignore_key].lower() in ignore_value: + # Convert values in one column (e.g. string to id) + for ( + convert_key, + convert_value, + ) in convert_filter.items(): + row[convert_key] = convert_value[ + row[convert_key].lower() + ] + # Save data separated by timestep. + if timestep in crowd_ignore_data.keys(): + crowd_ignore_data[timestep].append(row) + else: + crowd_ignore_data[timestep] = [row] + is_ignored = True + if ( + is_ignored + ): # if det is an ignore region, it cannot be a normal det. + continue + # Exclude some dets if not valid. + if valid_filter is not None: + for key, value in valid_filter.items(): + if row[key].lower() not in value: + continue + if remove_negative_ids: + if int(float(row[id_col])) < 0: + continue + # Convert values in one column (e.g. string to id) + for convert_key, convert_value in convert_filter.items(): + row[convert_key] = convert_value[row[convert_key].lower()] + # Save data separated by timestep. + if timestep in read_data.keys(): + read_data[timestep].append(row) + else: + read_data[timestep] = [row] + except Exception: + exc_str_init = ( + "In file %s the following line cannot be read correctly: \n" + % os.path.basename(file) + ) + exc_str = " ".join([exc_str_init] + row) + raise TrackEvalException(exc_str) + fp.close() + except Exception: + print("Error loading file: %s, printing traceback." % file) + traceback.print_exc() + raise TrackEvalException( + "File %s cannot be read because it is either not present or invalidly formatted" + % os.path.basename(file) + ) + return read_data, crowd_ignore_data + + @staticmethod + def _calculate_mask_ious(masks1, masks2, is_encoded=False, do_ioa=False): + """Calculates the IOU (intersection over union) between two arrays of segmentation masks. + If is_encoded a run length encoding with pycocotools is assumed as input format, otherwise an input of numpy + arrays of the shape (num_masks, height, width) is assumed and the encoding is performed. + If do_ioa (intersection over area) , then calculates the intersection over the area of masks1 - this is commonly + used to determine if detections are within crowd ignore region. + :param masks1: first set of masks (numpy array of shape (num_masks, height, width) if not encoded, + else pycocotools rle encoded format) + :param masks2: second set of masks (numpy array of shape (num_masks, height, width) if not encoded, + else pycocotools rle encoded format) + :param is_encoded: whether the input is in pycocotools rle encoded format + :param do_ioa: whether to perform IoA computation + :return: the IoU/IoA scores + """ + + # Only loaded when run to reduce minimum requirements + from pycocotools import mask as mask_utils + + # use pycocotools for run length encoding of masks + if not is_encoded: + masks1 = mask_utils.encode( + np.array(np.transpose(masks1, (1, 2, 0)), order="F") + ) + masks2 = mask_utils.encode( + np.array(np.transpose(masks2, (1, 2, 0)), order="F") + ) + + # use pycocotools for iou computation of rle encoded masks + ious = mask_utils.iou(masks1, masks2, [do_ioa] * len(masks2)) + if len(masks1) == 0 or len(masks2) == 0: + ious = np.asarray(ious).reshape(len(masks1), len(masks2)) + assert (ious >= 0 - np.finfo("float").eps).all() + assert (ious <= 1 + np.finfo("float").eps).all() + + return ious + + @staticmethod + def _calculate_box_ious(bboxes1, bboxes2, box_format="xywh", do_ioa=False): + """Calculates the IOU (intersection over union) between two arrays of boxes. + Allows variable box formats ('xywh' and 'x0y0x1y1'). + If do_ioa (intersection over area) , then calculates the intersection over the area of boxes1 - this is commonly + used to determine if detections are within crowd ignore region. + """ + if box_format in "xywh": + # layout: (x0, y0, w, h) + bboxes1 = deepcopy(bboxes1) + bboxes2 = deepcopy(bboxes2) + + bboxes1[:, 2] = bboxes1[:, 0] + bboxes1[:, 2] + bboxes1[:, 3] = bboxes1[:, 1] + bboxes1[:, 3] + bboxes2[:, 2] = bboxes2[:, 0] + bboxes2[:, 2] + bboxes2[:, 3] = bboxes2[:, 1] + bboxes2[:, 3] + elif box_format not in "x0y0x1y1": + raise (TrackEvalException("box_format %s is not implemented" % box_format)) + + # layout: (x0, y0, x1, y1) + min_ = np.minimum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :]) + max_ = np.maximum(bboxes1[:, np.newaxis, :], bboxes2[np.newaxis, :, :]) + intersection = np.maximum(min_[..., 2] - max_[..., 0], 0) * np.maximum( + min_[..., 3] - max_[..., 1], 0 + ) + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * ( + bboxes1[..., 3] - bboxes1[..., 1] + ) + + if do_ioa: + ioas = np.zeros_like(intersection) + valid_mask = area1 > 0 + np.finfo("float").eps + ioas[valid_mask, :] = ( + intersection[valid_mask, :] / area1[valid_mask][:, np.newaxis] + ) + + return ioas + else: + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * ( + bboxes2[..., 3] - bboxes2[..., 1] + ) + union = area1[:, np.newaxis] + area2[np.newaxis, :] - intersection + intersection[area1 <= 0 + np.finfo("float").eps, :] = 0 + intersection[:, area2 <= 0 + np.finfo("float").eps] = 0 + intersection[union <= 0 + np.finfo("float").eps] = 0 + union[union <= 0 + np.finfo("float").eps] = 1 + ious = intersection / union + return ious + + @staticmethod + def _calculate_euclidean_similarity(dets1, dets2, zero_distance=2.0): + """Calculates the euclidean distance between two sets of detections, and then converts this into a similarity + measure with values between 0 and 1 using the following formula: sim = max(0, 1 - dist/zero_distance). + The default zero_distance of 2.0, corresponds to the default used in MOT15_3D, such that a 0.5 similarity + threshold corresponds to a 1m distance threshold for TPs. + """ + dist = np.linalg.norm(dets1[:, np.newaxis] - dets2[np.newaxis, :], axis=2) + sim = np.maximum(0, 1 - dist / zero_distance) + return sim + + @staticmethod + def _check_unique_ids(data, after_preproc=False): + """Check the requirement that the tracker_ids and gt_ids are unique per timestep""" + gt_ids = data["gt_ids"] + tracker_ids = data["tk_ids"] + for t, (gt_ids_t, tracker_ids_t) in enumerate(zip(gt_ids, tracker_ids)): + if len(tracker_ids_t) > 0: + unique_ids, counts = np.unique(tracker_ids_t, return_counts=True) + if np.max(counts) != 1: + duplicate_ids = unique_ids[counts > 1] + exc_str_init = ( + "Tracker predicts the same ID more than once in a single timestep " + "(seq: %s, frame: %i, ids:" % (data["seq"], t + 1) + ) + exc_str = ( + " ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")" + ) + if after_preproc: + exc_str_init += ( + "\n Note that this error occurred after preprocessing (but not before), " + "so ids may not be as in file, and something seems wrong with preproc." + ) + raise TrackEvalException(exc_str) + if len(gt_ids_t) > 0: + unique_ids, counts = np.unique(gt_ids_t, return_counts=True) + if np.max(counts) != 1: + duplicate_ids = unique_ids[counts > 1] + exc_str_init = ( + "Ground-truth has the same ID more than once in a single timestep " + "(seq: %s, frame: %i, ids:" % (data["seq"], t + 1) + ) + exc_str = ( + " ".join([exc_str_init] + [str(d) for d in duplicate_ids]) + ")" + ) + if after_preproc: + exc_str_init += ( + "\n Note that this error occurred after preprocessing (but not before), " + "so ids may not be as in file, and something seems wrong with preproc." + ) + raise TrackEvalException(exc_str) diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/coco.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..dcbbf901f79d4a8ddad9202ba196d45d5bf6bfda --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/coco.py @@ -0,0 +1,639 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +"""COCO Dataset.""" +import copy +import itertools +import json +import os +from collections import defaultdict + +import numpy as np +from scipy.optimize import linear_sum_assignment + +from .. import _timing, utils +from ..config import get_default_dataset_config, init_config +from ..utils import TrackEvalException +from ._base_dataset import _BaseDataset + + +class COCO(_BaseDataset): + """Tracking datasets in COCO format.""" + + def __init__(self, config=None): + """Initialize dataset, checking that all required files are present.""" + super().__init__() + # Fill non-given config values with defaults + self.config = init_config(config, get_default_dataset_config(), self.get_name()) + self.gt_fol = self.config["GT_FOLDER"] + self.tracker_fol = self.config["TRACKERS_FOLDER"] + self.should_classes_combine = True + self.use_super_categories = False + self.use_mask = self.config["USE_MASK"] + + self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"] + self.output_fol = self.config["OUTPUT_FOLDER"] + if self.output_fol is None: + self.output_fol = self.tracker_fol + self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"] + + if self.gt_fol.endswith(".json"): + self.gt_data = json.load(open(self.gt_fol, "r")) + else: + gt_dir_files = [ + file for file in os.listdir(self.gt_fol) if file.endswith(".json") + ] + if len(gt_dir_files) != 1: + raise TrackEvalException( + f"{self.gt_fol} does not contain exactly one json file." + ) + + with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f: + self.gt_data = json.load(f) + + # fill missing video ids + self._fill_video_ids_inplace(self.gt_data["annotations"]) + + # get sequences to eval and sequence information + self.seq_list = [ + vid["name"].replace("/", "-") for vid in self.gt_data["videos"] + ] + self.seq_name2seqid = { + vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"] + } + # compute mappings from videos to annotation data + self.video2gt_track, self.video2gt_image = self._compute_vid_mappings( + self.gt_data["annotations"] + ) + # compute sequence lengths + self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]} + for img in self.gt_data["images"]: + self.seq_lengths[img["video_id"]] += 1 + self.seq2images2timestep = self._compute_image_to_timestep_mappings() + self.seq2cls = { + vid["id"]: { + "pos_cat_ids": list( + {track["category_id"] for track in self.video2gt_track[vid["id"]]} + ), + } + for vid in self.gt_data["videos"] + } + + # Get classes to eval + considered_vid_ids = [self.seq_name2seqid[vid] for vid in self.seq_list] + seen_cats = set( + [ + cat_id + for vid_id in considered_vid_ids + for cat_id in self.seq2cls[vid_id]["pos_cat_ids"] + ] + ) + # only classes with ground truth are evaluated in TAO + self.valid_classes = [ + cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats + ] + cls_name2clsid_map = { + cls["name"]: cls["id"] for cls in self.gt_data["categories"] + } + + if self.config["CLASSES_TO_EVAL"]: + self.class_list = [ + cls.lower() if cls.lower() in self.valid_classes else None + for cls in self.config["CLASSES_TO_EVAL"] + ] + if not all(self.class_list): + valid_cls = ", ".join(self.valid_classes) + raise TrackEvalException( + "Attempted to evaluate an invalid class. Only classes " + f"{valid_cls} are valid (classes present in ground truth" + " data)." + ) + else: + self.class_list = [cls for cls in self.valid_classes] + self.cls_name2clsid = { + k: v for k, v in cls_name2clsid_map.items() if k in self.class_list + } + self.clsid2cls_name = { + v: k for k, v in cls_name2clsid_map.items() if k in self.class_list + } + # get trackers to eval + if self.config["TRACKERS_TO_EVAL"] is None: + self.tracker_list = os.listdir(self.tracker_fol) + else: + self.tracker_list = self.config["TRACKERS_TO_EVAL"] + + if self.config["TRACKER_DISPLAY_NAMES"] is None: + self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list)) + elif (self.config["TRACKERS_TO_EVAL"] is not None) and ( + len(self.config["TK_DISPLAY_NAMES"]) == len(self.tracker_list) + ): + self.tracker_to_disp = dict( + zip(self.tracker_list, self.config["TK_DISPLAY_NAMES"]) + ) + else: + raise TrackEvalException( + "List of tracker files and tracker display names do not match." + ) + + self.tracker_data = {tracker: dict() for tracker in self.tracker_list} + + for tracker in self.tracker_list: + if self.tracker_sub_fol.endswith(".json"): + with open(os.path.join(self.tracker_sub_fol)) as f: + curr_data = json.load(f) + else: + tr_dir = os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol) + tr_dir_files = [ + file for file in os.listdir(tr_dir) if file.endswith(".json") + ] + if len(tr_dir_files) != 1: + raise TrackEvalException( + f"{tr_dir} does not contain exactly one json file." + ) + with open(os.path.join(tr_dir, tr_dir_files[0])) as f: + curr_data = json.load(f) + + # limit detections if MAX_DETECTIONS > 0 + if self.config["MAX_DETECTIONS"]: + curr_data = self._limit_dets_per_image(curr_data) + + # fill missing video ids + self._fill_video_ids_inplace(curr_data) + + # make track ids unique over whole evaluation set + self._make_tk_ids_unique(curr_data) + + # get tracker sequence information + curr_vids2tracks, curr_vids2images = self._compute_vid_mappings(curr_data) + self.tracker_data[tracker]["vids_to_tracks"] = curr_vids2tracks + self.tracker_data[tracker]["vids_to_images"] = curr_vids2images + + def get_display_name(self, tracker): + return self.tracker_to_disp[tracker] + + def _load_raw_file(self, tracker, seq, is_gt): + """Load a file (gt or tracker) in the TAO format + + If is_gt, this returns a dict which contains the fields: + [gt_ids, gt_classes]: + list (for each timestep) of 1D NDArrays (for each det). + [gt_dets]: list (for each timestep) of lists of detections. + + if not is_gt, this returns a dict which contains the fields: + [tk_ids, tk_classes]: + list (for each timestep) of 1D NDArrays (for each det). + [tk_dets]: list (for each timestep) of lists of detections. + """ + seq_id = self.seq_name2seqid[seq] + # file location + if is_gt: + imgs = self.video2gt_image[seq_id] + else: + imgs = self.tracker_data[tracker]["vids_to_images"][seq_id] + + # convert data to required format + num_timesteps = self.seq_lengths[seq_id] + img_to_timestep = self.seq2images2timestep[seq_id] + data_keys = ["ids", "classes", "dets"] + # if not is_gt: + # data_keys += ["tk_confidences"] + raw_data = {key: [None] * num_timesteps for key in data_keys} + for img in imgs: + # some tracker data contains images without any ground truth info, + # these are ignored + if img["id"] not in img_to_timestep: + continue + t = img_to_timestep[img["id"]] + anns = img["annotations"] + tk_str = utils.get_track_id_str(anns[0]) + + if self.use_mask: + # When using mask, extract segmentation data + raw_data["dets"][t] = [ann.get("segmentation") for ann in anns] + else: + # When using bbox, extract bbox data + raw_data["dets"][t] = np.atleast_2d([ann["bbox"] for ann in anns]).astype( + float + ) + raw_data["ids"][t] = np.atleast_1d([ann[tk_str] for ann in anns]).astype( + int + ) + raw_data["classes"][t] = np.atleast_1d( + [ann["category_id"] for ann in anns] + ).astype(int) + # if not is_gt: + # raw_data["tk_confidences"][t] = np.atleast_1d( + # [ann["score"] for ann in anns] + # ).astype(float) + + for t, d in enumerate(raw_data["dets"]): + if d is None: + raw_data["dets"][t] = np.empty((0, 4)).astype(float) + raw_data["ids"][t] = np.empty(0).astype(int) + raw_data["classes"][t] = np.empty(0).astype(int) + # if not is_gt: + # raw_data["tk_confidences"][t] = np.empty(0) + + if is_gt: + key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"} + else: + key_map = {"ids": "tk_ids", "classes": "tk_classes", "dets": "tk_dets"} + for k, v in key_map.items(): + raw_data[v] = raw_data.pop(k) + + raw_data["num_timesteps"] = num_timesteps + raw_data["seq"] = seq + return raw_data + + def get_preprocessed_seq_data_thr(self, raw_data, cls, assignment=None): + """Preprocess data for a single sequence for a single class. + + Inputs: + raw_data: dict containing the data for the sequence already + read in by get_raw_seq_data(). + cls: class to be evaluated. + Outputs: + gt_ids: + list (for each timestep) of ids of GT tracks + tk_ids: + list (for each timestep) of ids of predicted tracks (all for TP + matching (Det + AssocA)) + tk_overlap_ids: + list (for each timestep) of ids of predicted tracks that overlap + with GTs + tk_dets: + list (for each timestep) of lists of detections that + corresponding to the tk_ids + tk_classes: + list (for each timestep) of lists of classes that corresponding + to the tk_ids + tk_confidences: + list (for each timestep) of lists of classes that corresponding + to the tk_ids + sim_scores: + similarity score between gt_ids and tk_ids. + """ + if cls != "all": + cls_id = self.cls_name2clsid[cls] + + data_keys = [ + "gt_ids", + "tk_ids", + "gt_id_map", + "tk_id_map", + "gt_dets", + "gt_classes", + "gt_class_name", + "tk_overlap_classes", + "tk_overlap_ids", + "tk_class_eval_tk_ids", + "tk_dets", + "tk_classes", + # "tk_confidences", + "tk_exh_ids", + "sim_scores", + ] + data = {key: [None] * raw_data["num_timesteps"] for key in data_keys} + unique_gt_ids = [] + unique_tk_ids = [] + num_gt_dets = 0 + num_tk_cls_dets = 0 + num_tk_overlap_dets = 0 + overlap_ious_thr = 0.5 + loc_and_asso_tk_ids = [] + exh_class_tk_ids = [] + + for t in range(raw_data["num_timesteps"]): + # only extract relevant dets for this class for preproc and eval + if cls == "all": + gt_class_mask = np.ones_like(raw_data["gt_classes"][t]).astype(bool) + else: + gt_class_mask = np.atleast_1d( + raw_data["gt_classes"][t] == cls_id + ).astype(bool) + + # select GT that is not in the evaluating classes + if assignment is not None and assignment: + all_gt_ids = list(assignment[t].keys()) + gt_ids_in = raw_data["gt_ids"][t][gt_class_mask] + gt_ids_out = set(all_gt_ids) - set(gt_ids_in) + tk_ids_out = set([assignment[t][key] for key in list(gt_ids_out)]) + + # compute overlapped tracks and add their ids to overlap_tk_ids + sim_scores = raw_data["similarity_scores"] + overlap_ids_masks = (sim_scores[t][gt_class_mask] >= overlap_ious_thr).any( + axis=0 + ) + overlap_tk_ids_t = raw_data["tk_ids"][t][overlap_ids_masks] + if assignment is not None and assignment: + data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t) - tk_ids_out) + else: + data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t)) + + loc_and_asso_tk_ids += data["tk_overlap_ids"][t] + + data["tk_exh_ids"][t] = [] + if cls == "all": + continue + + # add the track ids of exclusive annotated class to exh_class_tk_ids + tk_exh_mask = np.atleast_1d(raw_data["tk_classes"][t] == cls_id) + tk_exh_mask = tk_exh_mask.astype(bool) + exh_class_tk_ids_t = raw_data["tk_ids"][t][tk_exh_mask] + exh_class_tk_ids.append(exh_class_tk_ids_t) + data["tk_exh_ids"][t] = exh_class_tk_ids_t + + # remove tk_ids that has been assigned to GT belongs to other classes. + loc_and_asso_tk_ids = list(set(loc_and_asso_tk_ids)) + + # remove all unwanted unmatched tracker detections + for t in range(raw_data["num_timesteps"]): + # add gt to the data + if cls == "all": + gt_class_mask = np.ones_like(raw_data["gt_classes"][t]).astype(bool) + else: + gt_class_mask = np.atleast_1d( + raw_data["gt_classes"][t] == cls_id + ).astype(bool) + data["gt_classes"][t] = cls_id + data["gt_class_name"][t] = cls + + gt_ids = raw_data["gt_ids"][t][gt_class_mask] + if self.use_mask: + gt_dets = [raw_data['gt_dets'][t][ind] for ind in range(len(gt_class_mask)) if gt_class_mask[ind]] + else: + gt_dets = raw_data["gt_dets"][t][gt_class_mask] + data["gt_ids"][t] = gt_ids + data["gt_dets"][t] = gt_dets + + # filter pred and only keep those that highly overlap with GTs + tk_mask = np.isin( + raw_data["tk_ids"][t], np.array(loc_and_asso_tk_ids), assume_unique=True + ) + tk_overlap_mask = np.isin( + raw_data["tk_ids"][t], + np.array(data["tk_overlap_ids"][t]), + assume_unique=True, + ) + + tk_ids = raw_data["tk_ids"][t][tk_mask] + if self.use_mask: + tk_dets = [raw_data['tk_dets'][t][ind] for ind in range(len(tk_mask)) if + tk_mask[ind]] + else: + tk_dets = raw_data["tk_dets"][t][tk_mask] + + tracker_classes = raw_data["tk_classes"][t][tk_mask] + + # add overlap classes for computing the FP for Cls term + tracker_overlap_classes = raw_data["tk_classes"][t][tk_overlap_mask] + # tracker_confidences = raw_data["tk_confidences"][t][tk_mask] + sim_scores_masked = sim_scores[t][gt_class_mask, :][:, tk_mask] + + # add filtered prediction to the data + data["tk_classes"][t] = tracker_classes + data["tk_overlap_classes"][t] = tracker_overlap_classes + data["tk_ids"][t] = tk_ids + data["tk_dets"][t] = tk_dets + # data["tk_confidences"][t] = tracker_confidences + data["sim_scores"][t] = sim_scores_masked + data["tk_class_eval_tk_ids"][t] = set( + list(data["tk_overlap_ids"][t]) + list(data["tk_exh_ids"][t]) + ) + + # count total number of detections + unique_gt_ids += list(np.unique(data["gt_ids"][t])) + # the unique track ids are for association. + unique_tk_ids += list(np.unique(data["tk_ids"][t])) + + num_tk_overlap_dets += len(data["tk_overlap_ids"][t]) + num_tk_cls_dets += len(data["tk_class_eval_tk_ids"][t]) + num_gt_dets += len(data["gt_ids"][t]) + + # re-label IDs such that there are no empty IDs + if len(unique_gt_ids) > 0: + unique_gt_ids = np.unique(unique_gt_ids) + gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1)) + gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids)) + data["gt_id_map"] = {} + for gt_id in unique_gt_ids: + new_gt_id = gt_id_map[gt_id].astype(int) + data["gt_id_map"][new_gt_id] = gt_id + + for t in range(raw_data["num_timesteps"]): + if len(data["gt_ids"][t]) > 0: + data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int) + + if len(unique_tk_ids) > 0: + unique_tk_ids = np.unique(unique_tk_ids) + tk_id_map = np.nan * np.ones((np.max(unique_tk_ids) + 1)) + tk_id_map[unique_tk_ids] = np.arange(len(unique_tk_ids)) + + data["tk_id_map"] = {} + for track_id in unique_tk_ids: + new_track_id = tk_id_map[track_id].astype(int) + data["tk_id_map"][new_track_id] = track_id + + for t in range(raw_data["num_timesteps"]): + if len(data["tk_ids"][t]) > 0: + data["tk_ids"][t] = tk_id_map[data["tk_ids"][t]].astype(int) + if len(data["tk_overlap_ids"][t]) > 0: + data["tk_overlap_ids"][t] = tk_id_map[ + data["tk_overlap_ids"][t] + ].astype(int) + + # record overview statistics. + data["num_tk_cls_dets"] = num_tk_cls_dets + data["num_tk_overlap_dets"] = num_tk_overlap_dets + data["num_gt_dets"] = num_gt_dets + data["num_tk_ids"] = len(unique_tk_ids) + data["num_gt_ids"] = len(unique_gt_ids) + data["num_timesteps"] = raw_data["num_timesteps"] + data["seq"] = raw_data["seq"] + + self._check_unique_ids(data) + + return data + + @_timing.time + def get_preprocessed_seq_data( + self, raw_data, cls, assignment=None, thresholds=[50, 75] + ): + """Preprocess data for a single sequence for a single class.""" + data = {} + if thresholds is None: + thresholds = [50, 75] + elif isinstance(thresholds, int): + thresholds = [thresholds] + + for thr in thresholds: + assignment_thr = None + if assignment is not None: + assignment_thr = assignment[thr] + data[thr] = self.get_preprocessed_seq_data_thr( + raw_data, cls, assignment_thr + ) + + return data + + def _calculate_similarities(self, gt_dets_t, tk_dets_t): + """Compute similarity scores.""" + if self.use_mask: + similarity_scores = self._calculate_mask_ious(gt_dets_t, tk_dets_t, is_encoded=True, do_ioa=False) + else: + similarity_scores = self._calculate_box_ious(gt_dets_t, tk_dets_t) + return similarity_scores + + def _compute_vid_mappings(self, annotations): + """Computes mappings from videos to corresponding tracks and images.""" + vids_to_tracks = {} + vids_to_imgs = {} + vid_ids = [vid["id"] for vid in self.gt_data["videos"]] + + # compute an mapping from image IDs to images + images = {} + for image in self.gt_data["images"]: + images[image["id"]] = image + + tk_str = utils.get_track_id_str(annotations[0]) + for ann in annotations: + ann["area"] = ann["bbox"][2] * ann["bbox"][3] + + vid = ann["video_id"] + if ann["video_id"] not in vids_to_tracks.keys(): + vids_to_tracks[ann["video_id"]] = list() + if ann["video_id"] not in vids_to_imgs.keys(): + vids_to_imgs[ann["video_id"]] = list() + + # fill in vids_to_tracks + tid = ann[tk_str] + exist_tids = [track["id"] for track in vids_to_tracks[vid]] + try: + index1 = exist_tids.index(tid) + except ValueError: + index1 = -1 + if tid not in exist_tids: + curr_track = { + "id": tid, + "category_id": ann["category_id"], + "video_id": vid, + "annotations": [ann], + } + vids_to_tracks[vid].append(curr_track) + else: + vids_to_tracks[vid][index1]["annotations"].append(ann) + + # fill in vids_to_imgs + img_id = ann["image_id"] + exist_img_ids = [img["id"] for img in vids_to_imgs[vid]] + try: + index2 = exist_img_ids.index(img_id) + except ValueError: + index2 = -1 + if index2 == -1: + curr_img = {"id": img_id, "annotations": [ann]} + vids_to_imgs[vid].append(curr_img) + else: + vids_to_imgs[vid][index2]["annotations"].append(ann) + + # sort annotations by frame index and compute track area + for vid, tracks in vids_to_tracks.items(): + for track in tracks: + track["annotations"] = sorted( + track["annotations"], + key=lambda x: images[x["image_id"]]["frame_id"], + ) + # compute average area + track["area"] = sum(x["area"] for x in track["annotations"]) / len( + track["annotations"] + ) + + # ensure all videos are present + for vid_id in vid_ids: + if vid_id not in vids_to_tracks.keys(): + vids_to_tracks[vid_id] = [] + if vid_id not in vids_to_imgs.keys(): + vids_to_imgs[vid_id] = [] + + return vids_to_tracks, vids_to_imgs + + def _compute_image_to_timestep_mappings(self): + """Computes a mapping from images to timestep in sequence.""" + images = {} + for image in self.gt_data["images"]: + images[image["id"]] = image + + seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]} + for vid in seq_to_imgs_to_timestep: + curr_imgs = [img["id"] for img in self.video2gt_image[vid]] + curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_id"]) + seq_to_imgs_to_timestep[vid] = { + curr_imgs[i]: i for i in range(len(curr_imgs)) + } + + return seq_to_imgs_to_timestep + + def _limit_dets_per_image(self, annotations): + """Limits the number of detections for each image. + + Adapted from https://github.com/TAO-Dataset/. + """ + max_dets = self.config["MAX_DETECTIONS"] + img_ann = defaultdict(list) + for ann in annotations: + img_ann[ann["image_id"]].append(ann) + + for img_id, _anns in img_ann.items(): + if len(_anns) <= max_dets: + continue + _anns = sorted(_anns, key=lambda x: x["score"], reverse=True) + img_ann[img_id] = _anns[:max_dets] + + return [ann for anns in img_ann.values() for ann in anns] + + def _fill_video_ids_inplace(self, annotations): + """Fills in missing video IDs inplace. + + Adapted from https://github.com/TAO-Dataset/. + """ + missing_video_id = [x for x in annotations if "video_id" not in x] + if missing_video_id: + image_id_to_video_id = { + x["id"]: x["video_id"] for x in self.gt_data["images"] + } + for x in missing_video_id: + x["video_id"] = image_id_to_video_id[x["image_id"]] + + @staticmethod + def _make_tk_ids_unique(annotations): + """Makes track IDs unqiue over the whole annotation set. + + Adapted from https://github.com/TAO-Dataset/. + """ + track_id_videos = {} + track_ids_to_update = set() + max_track_id = 0 + + tk_str = utils.get_track_id_str(annotations[0]) + for ann in annotations: + t = int(ann[tk_str]) + if t not in track_id_videos: + track_id_videos[t] = ann["video_id"] + + if ann["video_id"] != track_id_videos[t]: + # track id is assigned to multiple videos + track_ids_to_update.add(t) + max_track_id = max(max_track_id, t) + + if track_ids_to_update: + print("true") + next_id = itertools.count(max_track_id + 1) + new_tk_ids = defaultdict(lambda: next(next_id)) + for ann in annotations: + t = ann[tk_str] + v = ann["video_id"] + if t in track_ids_to_update: + ann[tk_str] = new_tk_ids[t, v] + return len(track_ids_to_update) diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/tao.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/tao.py new file mode 100644 index 0000000000000000000000000000000000000000..63fea8e19c82a7a2d673d51aa9d94e494d1ac463 --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/datasets/tao.py @@ -0,0 +1,661 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +"""TAO Dataset.""" +import copy +import itertools +import json +import os +from collections import defaultdict + +import numpy as np + +from .. import _timing +from ..config import get_default_dataset_config, init_config +from ..utils import TrackEvalException +from ._base_dataset import _BaseDataset + + +class TAO(_BaseDataset): + """Dataset class for TAO tracking""" + + def __init__(self, config=None): + """Initialize dataset, checking that all required files are present.""" + super().__init__() + # Fill non-given config values with defaults + self.config = init_config(config, get_default_dataset_config(), self.get_name()) + self.gt_fol = self.config["GT_FOLDER"] + self.tracker_fol = self.config["TRACKERS_FOLDER"] + self.should_classes_combine = True + self.use_super_categories = False + self.use_mask = self.config["USE_MASK"] + + + self.tracker_sub_fol = self.config["TRACKER_SUB_FOLDER"] + self.output_fol = self.config["OUTPUT_FOLDER"] + if self.output_fol is None: + self.output_fol = self.tracker_fol + self.output_sub_fol = self.config["OUTPUT_SUB_FOLDER"] + + if self.gt_fol.endswith(".json"): + self.gt_data = json.load(open(self.gt_fol, "r")) + else: + gt_dir_files = [ + file for file in os.listdir(self.gt_fol) if file.endswith(".json") + ] + if len(gt_dir_files) != 1: + raise TrackEvalException( + f"{self.gt_fol} does not contain exactly one json file." + ) + + with open(os.path.join(self.gt_fol, gt_dir_files[0])) as f: + self.gt_data = json.load(f) + + # merge categories marked with a merged tag in TAO dataset + self._merge_categories(self.gt_data["annotations"] + self.gt_data["tracks"]) + + # get sequences to eval and sequence information + self.seq_list = [ + vid["name"].replace("/", "-") for vid in self.gt_data["videos"] + ] + self.seq_name2seqid = { + vid["name"].replace("/", "-"): vid["id"] for vid in self.gt_data["videos"] + } + # compute mappings from videos to annotation data + self.video2gt_track, self.video2gt_image = self._compute_vid_mappings( + self.gt_data["annotations"] + ) + # compute sequence lengths + self.seq_lengths = {vid["id"]: 0 for vid in self.gt_data["videos"]} + for img in self.gt_data["images"]: + self.seq_lengths[img["video_id"]] += 1 + self.seq2images2timestep = self._compute_image_to_timestep_mappings() + self.seq2cls = { + vid["id"]: { + "pos_cat_ids": list( + {track["category_id"] for track in self.video2gt_track[vid["id"]]} + ), + "neg_cat_ids": vid["neg_category_ids"], + "not_exh_labeled_cat_ids": vid["not_exhaustive_category_ids"], + } + for vid in self.gt_data["videos"] + } + + # Get classes to eval + considered_vid_ids = [self.seq_name2seqid[vid] for vid in self.seq_list] + seen_cats = set( + [ + cat_id + for vid_id in considered_vid_ids + for cat_id in self.seq2cls[vid_id]["pos_cat_ids"] + ] + ) + # only classes with ground truth are evaluated in TAO + self.valid_classes = [ + cls["name"] for cls in self.gt_data["categories"] if cls["id"] in seen_cats + ] + cls_name2clsid_map = { + cls["name"]: cls["id"] for cls in self.gt_data["categories"] + } + + if self.config["CLASSES_TO_EVAL"]: + self.class_list = [ + cls.lower() if cls.lower() in self.valid_classes else None + for cls in self.config["CLASSES_TO_EVAL"] + ] + if not all(self.class_list): + valid_cls = ", ".join(self.valid_classes) + raise TrackEvalException( + "Attempted to evaluate an invalid class. Only classes " + f"{valid_cls} are valid (classes present in ground truth" + " data)." + ) + else: + self.class_list = [cls for cls in self.valid_classes] + self.cls_name2clsid = { + k: v for k, v in cls_name2clsid_map.items() if k in self.class_list + } + self.clsid2cls_name = { + v: k for k, v in cls_name2clsid_map.items() if k in self.class_list + } + # get trackers to eval + print(self.config["TRACKERS_TO_EVAL"] ) + if self.config["TRACKERS_TO_EVAL"] is None: + self.tracker_list = os.listdir(self.tracker_fol) + else: + self.tracker_list = self.config["TRACKERS_TO_EVAL"] + + if self.config["TRACKER_DISPLAY_NAMES"] is None: + self.tracker_to_disp = dict(zip(self.tracker_list, self.tracker_list)) + elif (self.config["TRACKERS_TO_EVAL"] is not None) and ( + len(self.config["TK_DISPLAY_NAMES"]) == len(self.tracker_list) + ): + self.tracker_to_disp = dict( + zip(self.tracker_list, self.config["TK_DISPLAY_NAMES"]) + ) + else: + raise TrackEvalException( + "List of tracker files and tracker display names do not match." + ) + + self.tracker_data = {tracker: dict() for tracker in self.tracker_list} + + for tracker in self.tracker_list: + if self.tracker_sub_fol.endswith(".json"): + with open(os.path.join(self.tracker_sub_fol)) as f: + curr_data = json.load(f) + else: + tr_dir = os.path.join(self.tracker_fol, tracker, self.tracker_sub_fol) + tr_dir_files = [ + file for file in os.listdir(tr_dir) if file.endswith(".json") + ] + if len(tr_dir_files) != 1: + raise TrackEvalException( + f"{tr_dir} does not contain exactly one json file." + ) + with open(os.path.join(tr_dir, tr_dir_files[0])) as f: + curr_data = json.load(f) + + # limit detections if MAX_DETECTIONS > 0 + if self.config["MAX_DETECTIONS"]: + curr_data = self._limit_dets_per_image(curr_data) + + # fill missing video ids + self._fill_video_ids_inplace(curr_data) + + # make track ids unique over whole evaluation set + self._make_tk_ids_unique(curr_data) + + # merge categories marked with a merged tag in TAO dataset + self._merge_categories(curr_data) + + # get tracker sequence information + curr_vids2tracks, curr_vids2images = self._compute_vid_mappings(curr_data) + self.tracker_data[tracker]["vids_to_tracks"] = curr_vids2tracks + self.tracker_data[tracker]["vids_to_images"] = curr_vids2images + + def get_display_name(self, tracker): + return self.tracker_to_disp[tracker] + + def _load_raw_file(self, tracker, seq, is_gt): + """Load a file (gt or tracker) in the TAO format + + If is_gt, this returns a dict which contains the fields: + [gt_ids, gt_classes]: + list (for each timestep) of 1D NDArrays (for each det). + [gt_dets]: list (for each timestep) of lists of detections. + + if not is_gt, this returns a dict which contains the fields: + [tk_ids, tk_classes, tk_confidences]: + list (for each timestep) of 1D NDArrays (for each det). + [tk_dets]: list (for each timestep) of lists of detections. + """ + seq_id = self.seq_name2seqid[seq] + # file location + if is_gt: + imgs = self.video2gt_image[seq_id] + else: + imgs = self.tracker_data[tracker]["vids_to_images"][seq_id] + + # convert data to required format + num_timesteps = self.seq_lengths[seq_id] + img_to_timestep = self.seq2images2timestep[seq_id] + data_keys = ["ids", "classes", "dets"] + if not is_gt: + data_keys += ["tk_confidences"] + raw_data = {key: [None] * num_timesteps for key in data_keys} + for img in imgs: + # some tracker data contains images without any ground truth info, + # these are ignored + if img["id"] not in img_to_timestep: + continue + t = img_to_timestep[img["id"]] + anns = img["annotations"] + if self.use_mask: + # When using mask, extract segmentation data + raw_data["dets"][t] = [ann.get("segmentation") for ann in anns] + else: + # When using bbox, extract bbox data + raw_data["dets"][t] = np.atleast_2d([ann["bbox"] for ann in anns]).astype( + float + ) + raw_data["ids"][t] = np.atleast_1d( + [ann["track_id"] for ann in anns] + ).astype(int) + raw_data["classes"][t] = np.atleast_1d( + [ann["category_id"] for ann in anns] + ).astype(int) + if not is_gt: + raw_data["tk_confidences"][t] = np.atleast_1d( + [ann["score"] for ann in anns] + ).astype(float) + + for t, d in enumerate(raw_data["dets"]): + if d is None: + raw_data["dets"][t] = np.empty((0, 4)).astype(float) + raw_data["ids"][t] = np.empty(0).astype(int) + raw_data["classes"][t] = np.empty(0).astype(int) + if not is_gt: + raw_data["tk_confidences"][t] = np.empty(0) + + if is_gt: + key_map = {"ids": "gt_ids", "classes": "gt_classes", "dets": "gt_dets"} + else: + key_map = {"ids": "tk_ids", "classes": "tk_classes", "dets": "tk_dets"} + for k, v in key_map.items(): + raw_data[v] = raw_data.pop(k) + + raw_data["num_timesteps"] = num_timesteps + raw_data["neg_cat_ids"] = self.seq2cls[seq_id]["neg_cat_ids"] + raw_data["not_exh_labeled_cls"] = self.seq2cls[seq_id][ + "not_exh_labeled_cat_ids" + ] + raw_data["seq"] = seq + return raw_data + + def get_preprocessed_seq_data_thr(self, raw_data, cls, assignment=None): + """Preprocess data for a single sequence for a single class. + + Inputs: + raw_data: dict containing the data for the sequence already + read in by get_raw_seq_data(). + cls: class to be evaluated. + Outputs: + gt_ids: + list (for each timestep) of ids of GT tracks + tk_ids: + list (for each timestep) of ids of predicted tracks (all for TP + matching (Det + AssocA)) + tk_overlap_ids: + list (for each timestep) of ids of predicted tracks that overlap + with GTs + tk_neg_ids: + list (for each timestep) of ids of predicted tracks that with + the class id on the negative list for the current sequence. + tk_exh_ids: + list (for each timestep) of ids of predicted tracks that do not + overlap with existing GTs but have the class id on the + exhaustive annotated class list for the current sequence. + tk_dets: + list (for each timestep) of lists of detections that + corresponding to the tk_ids + tk_classes: + list (for each timestep) of lists of classes that corresponding + to the tk_ids + tk_confidences: + list (for each timestep) of lists of classes that corresponding + to the tk_ids + sim_scores: + similarity score between gt_ids and tk_ids. + """ + if cls != "all": + cls_id = self.cls_name2clsid[cls] + + data_keys = [ + "gt_ids", + "tk_ids", + "gt_id_map", + "tk_id_map", + "gt_dets", + "gt_classes", + "gt_class_name", + "tk_overlap_classes", + "tk_overlap_ids", + "tk_neg_ids", + "tk_exh_ids", + "tk_class_eval_tk_ids", + "tk_dets", + "tk_classes", + "tk_confidences", + "sim_scores", + ] + data = {key: [None] * raw_data["num_timesteps"] for key in data_keys} + unique_gt_ids = [] + unique_tk_ids = [] + num_gt_dets = 0 + num_tk_cls_dets = 0 + num_tk_overlap_dets = 0 + overlap_ious_thr = 0.5 + loc_and_asso_tk_ids = [] + + for t in range(raw_data["num_timesteps"]): + # only extract relevant dets for this class for preproc and eval + if cls == "all": + gt_class_mask = np.ones_like(raw_data["gt_classes"][t]).astype(bool) + else: + gt_class_mask = np.atleast_1d( + raw_data["gt_classes"][t] == cls_id + ).astype(bool) + + # select GT that is not in the evaluating classes + if assignment is not None and assignment: + all_gt_ids = list(assignment[t].keys()) + gt_ids_in = raw_data["gt_ids"][t][gt_class_mask] + gt_ids_out = set(all_gt_ids) - set(gt_ids_in) + tk_ids_out = set([assignment[t][key] for key in list(gt_ids_out)]) + + # compute overlapped tracks and add their ids to overlap_tk_ids + sim_scores = raw_data["similarity_scores"] + overlap_ids_masks = (sim_scores[t][gt_class_mask] >= overlap_ious_thr).any( + axis=0 + ) + overlap_tk_ids_t = raw_data["tk_ids"][t][overlap_ids_masks] + if assignment is not None and assignment: + data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t) - tk_ids_out) + else: + data["tk_overlap_ids"][t] = list(set(overlap_tk_ids_t)) + + loc_and_asso_tk_ids += data["tk_overlap_ids"][t] + + data["tk_exh_ids"][t] = [] + data["tk_neg_ids"][t] = [] + + if cls == "all": + continue + + # remove tk_ids that has been assigned to GT belongs to other classes. + loc_and_asso_tk_ids = list(set(loc_and_asso_tk_ids)) + + # remove all unwanted unmatched tracker detections + for t in range(raw_data["num_timesteps"]): + # add gt to the data + if cls == "all": + gt_class_mask = np.ones_like(raw_data["gt_classes"][t]).astype(bool) + else: + gt_class_mask = np.atleast_1d( + raw_data["gt_classes"][t] == cls_id + ).astype(bool) + data["gt_classes"][t] = cls_id + data["gt_class_name"][t] = cls + + gt_ids = raw_data["gt_ids"][t][gt_class_mask] + if self.use_mask: + gt_dets = [raw_data['gt_dets'][t][ind] for ind in range(len(gt_class_mask)) if gt_class_mask[ind]] + else: + gt_dets = raw_data["gt_dets"][t][gt_class_mask] + data["gt_ids"][t] = gt_ids + data["gt_dets"][t] = gt_dets + + # filter pred and only keep those that highly overlap with GTs + tk_mask = np.isin( + raw_data["tk_ids"][t], np.array(loc_and_asso_tk_ids), assume_unique=True + ) + tk_overlap_mask = np.isin( + raw_data["tk_ids"][t], + np.array(data["tk_overlap_ids"][t]), + assume_unique=True, + ) + + tk_ids = raw_data["tk_ids"][t][tk_mask] + if self.use_mask: + tk_dets = [raw_data['tk_dets'][t][ind] for ind in range(len(tk_mask)) if + tk_mask[ind]] + else: + tk_dets = raw_data["tk_dets"][t][tk_mask] + tracker_classes = raw_data["tk_classes"][t][tk_mask] + + # add overlap classes for computing the FP for Cls term + tracker_overlap_classes = raw_data["tk_classes"][t][tk_overlap_mask] + tracker_confidences = raw_data["tk_confidences"][t][tk_mask] + sim_scores_masked = sim_scores[t][gt_class_mask, :][:, tk_mask] + + # add filtered prediction to the data + data["tk_classes"][t] = tracker_classes + data["tk_overlap_classes"][t] = tracker_overlap_classes + data["tk_ids"][t] = tk_ids + data["tk_dets"][t] = tk_dets + data["tk_confidences"][t] = tracker_confidences + data["sim_scores"][t] = sim_scores_masked + data["tk_class_eval_tk_ids"][t] = set( + list(data["tk_overlap_ids"][t]) + + list(data["tk_neg_ids"][t]) + + list(data["tk_exh_ids"][t]) + ) + + # count total number of detections + unique_gt_ids += list(np.unique(data["gt_ids"][t])) + # the unique track ids are for association. + unique_tk_ids += list(np.unique(data["tk_ids"][t])) + + num_tk_overlap_dets += len(data["tk_overlap_ids"][t]) + num_tk_cls_dets += len(data["tk_class_eval_tk_ids"][t]) + num_gt_dets += len(data["gt_ids"][t]) + + # re-label IDs such that there are no empty IDs + if len(unique_gt_ids) > 0: + unique_gt_ids = np.unique(unique_gt_ids) + gt_id_map = np.nan * np.ones((np.max(unique_gt_ids) + 1)) + gt_id_map[unique_gt_ids] = np.arange(len(unique_gt_ids)) + data["gt_id_map"] = {} + for gt_id in unique_gt_ids: + new_gt_id = gt_id_map[gt_id].astype(int) + data["gt_id_map"][new_gt_id] = gt_id + + for t in range(raw_data["num_timesteps"]): + if len(data["gt_ids"][t]) > 0: + data["gt_ids"][t] = gt_id_map[data["gt_ids"][t]].astype(int) + + if len(unique_tk_ids) > 0: + unique_tk_ids = np.unique(unique_tk_ids) + tk_id_map = np.nan * np.ones((np.max(unique_tk_ids) + 1)) + tk_id_map[unique_tk_ids] = np.arange(len(unique_tk_ids)) + + data["tk_id_map"] = {} + for track_id in unique_tk_ids: + new_track_id = tk_id_map[track_id].astype(int) + data["tk_id_map"][new_track_id] = track_id + + for t in range(raw_data["num_timesteps"]): + if len(data["tk_ids"][t]) > 0: + data["tk_ids"][t] = tk_id_map[data["tk_ids"][t]].astype(int) + if len(data["tk_overlap_ids"][t]) > 0: + data["tk_overlap_ids"][t] = tk_id_map[ + data["tk_overlap_ids"][t] + ].astype(int) + + # record overview statistics. + data["num_tk_cls_dets"] = num_tk_cls_dets + data["num_tk_overlap_dets"] = num_tk_overlap_dets + data["num_gt_dets"] = num_gt_dets + data["num_tk_ids"] = len(unique_tk_ids) + data["num_gt_ids"] = len(unique_gt_ids) + data["num_timesteps"] = raw_data["num_timesteps"] + data["seq"] = raw_data["seq"] + + self._check_unique_ids(data) + + return data + + @_timing.time + def get_preprocessed_seq_data( + self, raw_data, cls, assignment=None, thresholds=[50, 75] + ): + """Preprocess data for a single sequence for a single class.""" + data = {} + if thresholds is None: + thresholds = [50] + elif isinstance(thresholds, int): + thresholds = [thresholds] + + for thr in thresholds: + assignment_thr = None + if assignment is not None: + assignment_thr = assignment[thr] + data[thr] = self.get_preprocessed_seq_data_thr( + raw_data, cls, assignment_thr + ) + + return data + + def _calculate_similarities(self, gt_dets_t, tk_dets_t): + """Compute similarity scores.""" + if self.use_mask: + similarity_scores = self._calculate_mask_ious(gt_dets_t, tk_dets_t, is_encoded=True, do_ioa=False) + else: + similarity_scores = self._calculate_box_ious(gt_dets_t, tk_dets_t) + return similarity_scores + + def _merge_categories(self, annotations): + """Merges categories with a merged tag. + + Adapted from https://github.com/TAO-Dataset. + """ + merge_map = {} + for category in self.gt_data["categories"]: + if "merged" in category: + for to_merge in category["merged"]: + merge_map[to_merge["id"]] = category["id"] + + for ann in annotations: + ann["category_id"] = merge_map.get(ann["category_id"], ann["category_id"]) + + def _compute_vid_mappings(self, annotations): + """Computes mappings from videos to corresponding tracks and images.""" + vids_to_tracks = {} + vids_to_imgs = {} + vid_ids = [vid["id"] for vid in self.gt_data["videos"]] + + # compute an mapping from image IDs to images + images = {} + for image in self.gt_data["images"]: + images[image["id"]] = image + + for ann in annotations: + ann["area"] = ann["bbox"][2] * ann["bbox"][3] + + vid = ann["video_id"] + if ann["video_id"] not in vids_to_tracks.keys(): + vids_to_tracks[ann["video_id"]] = list() + if ann["video_id"] not in vids_to_imgs.keys(): + vids_to_imgs[ann["video_id"]] = list() + + # fill in vids_to_tracks + tid = ann["track_id"] + exist_tids = [track["id"] for track in vids_to_tracks[vid]] + try: + index1 = exist_tids.index(tid) + except ValueError: + index1 = -1 + if tid not in exist_tids: + curr_track = { + "id": tid, + "category_id": ann["category_id"], + "video_id": vid, + "annotations": [ann], + } + vids_to_tracks[vid].append(curr_track) + else: + vids_to_tracks[vid][index1]["annotations"].append(ann) + + # fill in vids_to_imgs + img_id = ann["image_id"] + exist_img_ids = [img["id"] for img in vids_to_imgs[vid]] + try: + index2 = exist_img_ids.index(img_id) + except ValueError: + index2 = -1 + if index2 == -1: + curr_img = {"id": img_id, "annotations": [ann]} + vids_to_imgs[vid].append(curr_img) + else: + vids_to_imgs[vid][index2]["annotations"].append(ann) + + # sort annotations by frame index and compute track area + for vid, tracks in vids_to_tracks.items(): + for track in tracks: + track["annotations"] = sorted( + track["annotations"], + key=lambda x: images[x["image_id"]]["frame_index"], + ) + # compute average area + track["area"] = sum(x["area"] for x in track["annotations"]) / len( + track["annotations"] + ) + + # ensure all videos are present + for vid_id in vid_ids: + if vid_id not in vids_to_tracks.keys(): + vids_to_tracks[vid_id] = [] + if vid_id not in vids_to_imgs.keys(): + vids_to_imgs[vid_id] = [] + + return vids_to_tracks, vids_to_imgs + + def _compute_image_to_timestep_mappings(self): + """Computes a mapping from images to timestep in sequence.""" + images = {} + for image in self.gt_data["images"]: + images[image["id"]] = image + + seq_to_imgs_to_timestep = {vid["id"]: dict() for vid in self.gt_data["videos"]} + for vid in seq_to_imgs_to_timestep: + curr_imgs = [img["id"] for img in self.video2gt_image[vid]] + curr_imgs = sorted(curr_imgs, key=lambda x: images[x]["frame_index"]) + seq_to_imgs_to_timestep[vid] = { + curr_imgs[i]: i for i in range(len(curr_imgs)) + } + + return seq_to_imgs_to_timestep + + def _limit_dets_per_image(self, annotations): + """Limits the number of detections for each image. + + Adapted from https://github.com/TAO-Dataset/. + """ + max_dets = self.config["MAX_DETECTIONS"] + img_ann = defaultdict(list) + for ann in annotations: + img_ann[ann["image_id"]].append(ann) + + for img_id, _anns in img_ann.items(): + if len(_anns) <= max_dets: + continue + _anns = sorted(_anns, key=lambda x: x["score"], reverse=True) + img_ann[img_id] = _anns[:max_dets] + + return [ann for anns in img_ann.values() for ann in anns] + + def _fill_video_ids_inplace(self, annotations): + """Fills in missing video IDs inplace. + + Adapted from https://github.com/TAO-Dataset/. + """ + missing_video_id = [x for x in annotations if "video_id" not in x] + if missing_video_id: + image_id_to_video_id = { + x["id"]: x["video_id"] for x in self.gt_data["images"] + } + for x in missing_video_id: + x["video_id"] = image_id_to_video_id[x["image_id"]] + + @staticmethod + def _make_tk_ids_unique(annotations): + """Makes track IDs unqiue over the whole annotation set. + + Adapted from https://github.com/TAO-Dataset/. + """ + track_id_videos = {} + track_ids_to_update = set() + max_track_id = 0 + for ann in annotations: + t = ann["track_id"] + if t not in track_id_videos: + track_id_videos[t] = ann["video_id"] + + if ann["video_id"] != track_id_videos[t]: + # track id is assigned to multiple videos + track_ids_to_update.add(t) + max_track_id = max(max_track_id, t) + + if track_ids_to_update: + print("true") + next_id = itertools.count(max_track_id + 1) + new_tk_ids = defaultdict(lambda: next(next_id)) + for ann in annotations: + t = ann["track_id"] + v = ann["video_id"] + if t in track_ids_to_update: + ann["track_id"] = new_tk_ids[t, v] + return len(track_ids_to_update) diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/eval.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..07c0a76e50e3afa2b126ab3e11e83818ced89d7d --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/eval.py @@ -0,0 +1,277 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +import copy +import os +import pickle +import time +import traceback +from functools import partial +from multiprocessing.pool import Pool + +import numpy as np + +from . import _timing, utils +from .config import get_default_eval_config, init_config +from .utils import TrackEvalException + + +class Evaluator: + """Evaluator class for evaluating different metrics for each datasets.""" + + def __init__(self, config=None): + """Initialize the evaluator with a config file.""" + self.config = init_config(config, get_default_eval_config(), "Eval") + # Only run timing analysis if not run in parallel. + if self.config["TIME_PROGRESS"] and not self.config["USE_PARALLEL"]: + _timing.DO_TIMING = True + if self.config["DISPLAY_LESS_PROGRESS"]: + _timing.DISPLAY_LESS_PROGRESS = True + + @_timing.time + def evaluate(self, dataset_list, metrics_list): + """Evaluate a set of metrics on a set of datasets.""" + config = self.config + metrics_list = metrics_list + metric_names = utils.validate_metrics_list(metrics_list) + dataset_names = [dataset.get_name() for dataset in dataset_list] + output_res = {} + output_msg = {} + + for dataset, dname in zip(dataset_list, dataset_names): + # Get dataset info about what to evaluate + output_res[dname] = {} + output_msg[dname] = {} + tracker_list, seq_list, class_list = dataset.get_eval_info() + print( + f"\nEvaluating {len(tracker_list)} tracker(s) on " + f"{len(seq_list)} sequence(s) for {len(class_list)} class(es)" + f" on {dname} dataset using the following " + f'metrics: {", ".join(metric_names)}\n' + ) + + # Evaluate each tracker + for tracker in tracker_list: + try: + output_res, output_msg = self.evaluate_tracker( + tracker, + dataset, + dname, + class_list, + metrics_list, + metric_names, + seq_list, + output_res, + output_msg, + ) + except Exception as err: + output_res[dname][tracker] = None + if type(err) == TrackEvalException: + output_msg[dname][tracker] = str(err) + else: + output_msg[dname][tracker] = "Unknown error occurred." + print("Tracker %s was unable to be evaluated." % tracker) + print(err) + traceback.print_exc() + if config["LOG_ON_ERROR"] is not None: + with open(config["LOG_ON_ERROR"], "a") as f: + print(dname, file=f) + print(tracker, file=f) + print(traceback.format_exc(), file=f) + print("\n\n\n", file=f) + if config["BREAK_ON_ERROR"]: + raise err + elif config["RETURN_ON_ERROR"]: + return output_res, output_msg + + return output_res, output_msg + + def evaluate_tracker( + self, + tracker, + dataset, + dname, + class_list, + metrics_list, + metric_names, + seq_list, + output_res, + output_msg, + ): + """Evaluate each sequence in parallel or in series.""" + print("\nEvaluating %s\n" % tracker) + time_start = time.time() + config = self.config + if config["USE_PARALLEL"]: + with Pool(config["NUM_PARALLEL_CORES"]) as pool: + _eval_sequence = partial( + eval_sequence, + dataset=dataset, + tracker=tracker, + class_list=class_list, + metrics_list=metrics_list, + metric_names=metric_names, + ) + results = pool.map(_eval_sequence, seq_list) + res = dict(zip(seq_list, results)) + else: + res = {} + for curr_seq in sorted(seq_list): + res[curr_seq] = eval_sequence( + curr_seq, dataset, tracker, class_list, metrics_list, metric_names + ) + + + # collecting combined cls keys (cls averaged, det averaged, super classes) + cls_keys = [] + res["COMBINED_SEQ"] = {} + # combine sequences for each class + for c_cls in class_list: + res["COMBINED_SEQ"][c_cls] = {} + for metric, mname in zip(metrics_list, metric_names): + curr_res = { + seq_key: seq_value[c_cls][mname] + for seq_key, seq_value in res.items() + if seq_key != "COMBINED_SEQ" + } + # combine results over all sequences and then over all classes + res["COMBINED_SEQ"][c_cls][mname] = metric.combine_sequences(curr_res) + + # combine classes + if dataset.should_classes_combine: + if config["OUTPUT_PER_SEQ_RES"]: + video_keys = res.keys() + else: + video_keys = ["COMBINED_SEQ"] + for v_key in video_keys: + cls_keys += ["average"] + res[v_key]["average"] = {} + for metric, mname in zip(metrics_list, metric_names): + cls_res = { + cls_key: cls_value[mname] + for cls_key, cls_value in res[v_key].items() + if cls_key not in cls_keys + } + res[v_key]["average"][ + mname + ] = metric.combine_classes_class_averaged( + cls_res, ignore_empty=True + ) + + # combine classes to super classes + if dataset.use_super_categories: + for cat, sub_cats in dataset.super_categories.items(): + cls_keys.append(cat) + res["COMBINED_SEQ"][cat] = {} + for metric, mname in zip(metrics_list, metric_names): + cat_res = { + cls_key: cls_value[mname] + for cls_key, cls_value in res["COMBINED_SEQ"].items() + if cls_key in sub_cats + } + res["COMBINED_SEQ"][cat][ + mname + ] = metric.combine_classes_det_averaged(cat_res) + # Print and output results in various formats + if config["TIME_PROGRESS"]: + print( + f"\nAll sequences for {tracker} finished in" + f" {time.time() - time_start} seconds" + ) + output_fol = dataset.get_output_fol(tracker) + os.makedirs(output_fol, exist_ok=True) + + # take a mean of each field of each thr + if config["OUTPUT_PER_SEQ_RES"]: + all_res = copy.deepcopy(res) + summary_keys = res.keys() + else: + all_res = copy.deepcopy(res["COMBINED_SEQ"]) + summary_keys = ["COMBINED_SEQ"] + thr_key_list = [50] + for s_key in summary_keys: + for metric, mname in zip(metrics_list, metric_names): + if mname != "TETA": + if s_key == "COMBINED_SEQ": + metric.print_table( + {"COMBINED_SEQ": res["COMBINED_SEQ"][cls_keys[0]][mname]}, + tracker, + cls_keys[0], + ) + continue + + for c_cls in res[s_key].keys(): + for thr in thr_key_list: + all_res[s_key][c_cls][mname][thr] = metric._summary_row( + res[s_key][c_cls][mname][thr] + ) + x = ( + np.array(list(all_res[s_key][c_cls]["TETA"].values())) + .astype("float") + .mean(axis=0) + ) + all_res_summary = list(x.round(decimals=2).astype("str")) + all_res[s_key][c_cls][mname]["ALL"] = all_res_summary + if config["OUTPUT_SUMMARY"] and s_key == "COMBINED_SEQ": + for t in thr_key_list: + metric.print_summary_table( + all_res[s_key][cls_keys[0]][mname][t], + t, + tracker, + cls_keys[0], + ) + + if config["OUTPUT_TEM_RAW_DATA"]: + out_file = os.path.join(output_fol, "teta_summary_results.pth") + pickle.dump(all_res, open(out_file, "wb")) + print("Saved the TETA summary results.") + + # output + output_res[dname][mname] = all_res[s_key][cls_keys[0]][mname][t] + output_msg[dname][tracker] = "Success" + + return output_res, output_msg + + +@_timing.time +def eval_sequence(seq, dataset, tracker, class_list, metrics_list, metric_names): + """Function for evaluating a single sequence.""" + raw_data = dataset.get_raw_seq_data(tracker, seq) + seq_res = {} + + if "TETA" in metric_names: + thresholds = [50] + data_all_class = dataset.get_preprocessed_seq_data( + raw_data, "all", thresholds=thresholds + ) + teta = metrics_list[metric_names.index("TETA")] + assignment = teta.compute_global_assignment(data_all_class) + + # create a dict to save Cls_FP for each class in different thr. + cls_fp = { + key: { + cls: np.zeros((len(np.arange(0.5, 0.99, 0.05)))) for cls in class_list + } + for key in thresholds + } + + for cls in class_list: + seq_res[cls] = {} + data = dataset.get_preprocessed_seq_data(raw_data, cls, assignment, thresholds) + + for metric, mname in zip(metrics_list, metric_names): + if mname == "TETA": + seq_res[cls][mname], cls_fp, _ = metric.eval_sequence( + data, cls, dataset.clsid2cls_name, cls_fp + ) + else: + seq_res[cls][mname] = metric.eval_sequence(data) + + if "TETA" in metric_names: + for thr in thresholds: + for cls in class_list: + seq_res[cls]["TETA"][thr]["Cls_FP"] += cls_fp[thr][cls] + + return seq_res diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/__init__.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c39ac6469ca165fbe9bc933d2bbe229c7a225df --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/__init__.py @@ -0,0 +1,6 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +from .teta import TETA diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/_base_metric.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/_base_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d8f77484dc4a1e45a3424f550156ba9bc6ceb8 --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/_base_metric.py @@ -0,0 +1,150 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +from abc import ABC, abstractmethod + +import numpy as np + +from .. import _timing +from ..utils import TrackEvalException + + +class _BaseMetric(ABC): + @abstractmethod + def __init__(self): + self.plottable = False + self.integer_fields = [] + self.float_fields = [] + self.array_labels = [] + self.integer_array_fields = [] + self.float_array_fields = [] + self.fields = [] + self.summary_fields = [] + self.registered = False + + ##################################################################### + # Abstract functions for subclasses to implement + + @_timing.time + @abstractmethod + def eval_sequence(self, data): + ... + + @abstractmethod + def combine_sequences(self, all_res): + ... + + @abstractmethod + def combine_classes_class_averaged(self, all_res, ignore_empty=False): + ... + + @abstractmethod + def combine_classes_det_averaged(self, all_res): + ... + + def plot_single_tracker_results(self, all_res, tracker, output_folder, cls): + """Plot results, only valid for metrics with self.plottable.""" + if self.plottable: + raise NotImplementedError( + f"plot_results is not implemented for metric {self.get_name()}" + ) + else: + pass + + ##################################################################### + # Helper functions which are useful for all metrics: + + @classmethod + def get_name(cls): + return cls.__name__ + + @staticmethod + def _combine_sum(all_res, field): + """Combine sequence results via sum""" + return sum([all_res[k][field] for k in all_res.keys()]) + + @staticmethod + def _combine_weighted_av(all_res, field, comb_res, weight_field): + """Combine sequence results via weighted average.""" + return sum( + [all_res[k][field] * all_res[k][weight_field] for k in all_res.keys()] + ) / np.maximum(1.0, comb_res[weight_field]) + + def print_table(self, table_res, tracker, cls): + """Print table of results for all sequences.""" + print("") + metric_name = self.get_name() + self._row_print( + [metric_name + ": " + tracker + "-" + cls] + self.summary_fields + ) + for seq, results in sorted(table_res.items()): + if seq == "COMBINED_SEQ": + continue + summary_res = self._summary_row(results) + self._row_print([seq] + summary_res) + summary_res = self._summary_row(table_res["COMBINED_SEQ"]) + self._row_print(["COMBINED"] + summary_res) + + def _summary_row(self, results_): + vals = [] + for h in self.summary_fields: + if h in self.float_array_fields: + vals.append("{0:1.5g}".format(100 * np.mean(results_[h]))) + elif h in self.float_fields: + vals.append("{0:1.5g}".format(100 * float(results_[h]))) + elif h in self.integer_fields: + vals.append("{0:d}".format(int(results_[h]))) + else: + raise NotImplementedError( + "Summary function not implemented for this field type." + ) + return vals + + @staticmethod + def _row_print(*argv): + """Print results in evenly spaced rows, with more space in first row.""" + if len(argv) == 1: + argv = argv[0] + to_print = "%-35s" % argv[0] + for v in argv[1:]: + to_print += "%-10s" % str(v) + print(to_print) + + def summary_results(self, table_res): + """Return a simple summary of final results for a tracker.""" + return dict( + zip(self.summary_fields, self._summary_row(table_res["COMBINED_SEQ"]),) + ) + + def detailed_results(self, table_res): + """Return detailed final results for a tracker.""" + # Get detailed field information + detailed_fields = self.float_fields + self.integer_fields + for h in self.float_array_fields + self.integer_array_fields: + for alpha in [int(100 * x) for x in self.array_labels]: + detailed_fields.append(h + "___" + str(alpha)) + detailed_fields.append(h + "___AUC") + + # Get detailed results + detailed_results = {} + for seq, res in table_res.items(): + detailed_row = self._detailed_row(res) + if len(detailed_row) != len(detailed_fields): + raise TrackEvalException( + f"Field names and data have different sizes " + f"({len(detailed_row)} and {len(detailed_fields)})" + ) + detailed_results[seq] = dict(zip(detailed_fields, detailed_row)) + return detailed_results + + def _detailed_row(self, res): + detailed_row = [] + for h in self.float_fields + self.integer_fields: + detailed_row.append(res[h]) + for h in self.float_array_fields + self.integer_array_fields: + for i, _ in enumerate([int(100 * x) for x in self.array_labels]): + detailed_row.append(res[h][i]) + detailed_row.append(np.mean(res[h])) + return detailed_row diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/teta.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/teta.py new file mode 100644 index 0000000000000000000000000000000000000000..288623e0df917399cfeef449e554f4813b9a961f --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/metrics/teta.py @@ -0,0 +1,401 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +"""Track Every Thing Accuracy metric.""" + +import numpy as np +from scipy.optimize import linear_sum_assignment + +from .. import _timing +from ._base_metric import _BaseMetric + +EPS = np.finfo("float").eps # epsilon + + +class TETA(_BaseMetric): + """TETA metric.""" + + def __init__(self, exhaustive=False, config=None): + """Initialize metric.""" + super().__init__() + self.plottable = True + self.array_labels = np.arange(0.0, 0.99, 0.05) + self.cls_array_labels = np.arange(0.5, 0.99, 0.05) + + self.integer_array_fields = [ + "Loc_TP", + "Loc_FN", + "Loc_FP", + "Cls_TP", + "Cls_FN", + "Cls_FP", + ] + self.float_array_fields = ( + ["TETA", "LocA", "AssocA", "ClsA"] + + ["LocRe", "LocPr"] + + ["AssocRe", "AssocPr"] + + ["ClsRe", "ClsPr"] + ) + self.fields = self.float_array_fields + self.integer_array_fields + self.summary_fields = self.float_array_fields + self.exhaustive = exhaustive + + def compute_global_assignment(self, data_thr, alpha=0.5): + """Compute global assignment of TP.""" + res = { + thr: {t: {} for t in range(data_thr[thr]["num_timesteps"])} + for thr in data_thr + } + + for thr in data_thr: + data = data_thr[thr] + # return empty result if tracker or gt sequence is empty + if data["num_tk_overlap_dets"] == 0 or data["num_gt_dets"] == 0: + return res + + # global alignment score + ga_score, _, _ = self.compute_global_alignment_score(data) + + # calculate scores for each timestep + for t, (gt_ids_t, tk_ids_t) in enumerate( + zip(data["gt_ids"], data["tk_ids"]) + ): + # get matches optimizing for TETA + amatch_rows, amatch_cols = self.compute_matches( + data, t, ga_score, gt_ids_t, tk_ids_t, alpha=alpha + ) + gt_ids = [data["gt_id_map"][tid] for tid in gt_ids_t[amatch_rows[0]]] + matched_ids = [ + data["tk_id_map"][tid] for tid in tk_ids_t[amatch_cols[0]] + ] + res[thr][t] = dict(zip(gt_ids, matched_ids)) + + return res + + def eval_sequence_single_thr(self, data, cls, cid2clsname, cls_fp_thr, thr): + """Computes TETA metric for one threshold for one sequence.""" + res = {} + class_info_list = [] + for field in self.float_array_fields + self.integer_array_fields: + if field.startswith("Cls"): + res[field] = np.zeros(len(self.cls_array_labels), dtype=float) + else: + res[field] = np.zeros((len(self.array_labels)), dtype=float) + + # return empty result if tracker or gt sequence is empty + if data["num_tk_overlap_dets"] == 0: + res["Loc_FN"] = data["num_gt_dets"] * np.ones( + (len(self.array_labels)), dtype=float + ) + if self.exhaustive: + cls_fp_thr[cls] = data["num_tk_cls_dets"] * np.ones( + (len(self.cls_array_labels)), dtype=float + ) + res = self._compute_final_fields(res) + return res, cls_fp_thr, class_info_list + + if data["num_gt_dets"] == 0: + if self.exhaustive: + cls_fp_thr[cls] = data["num_tk_cls_dets"] * np.ones( + (len(self.cls_array_labels)), dtype=float + ) + res = self._compute_final_fields(res) + return res, cls_fp_thr, class_info_list + + # global alignment score + ga_score, gt_id_count, tk_id_count = self.compute_global_alignment_score(data) + matches_counts = [np.zeros_like(ga_score) for _ in self.array_labels] + + # calculate scores for each timestep + for t, (gt_ids_t, tk_ids_t, tk_overlap_ids_t, tk_cls_ids_t) in enumerate( + zip( + data["gt_ids"], + data["tk_ids"], + data["tk_overlap_ids"], + data["tk_class_eval_tk_ids"], + ) + ): + # deal with the case that there are no gt_det/tk_det in a timestep + if len(gt_ids_t) == 0: + if self.exhaustive: + cls_fp_thr[cls] += len(tk_cls_ids_t) + continue + + # get matches optimizing for TETA + amatch_rows, amatch_cols = self.compute_matches( + data, t, ga_score, gt_ids_t, tk_ids_t, list(self.array_labels) + ) + + # map overlap_ids to original ids. + if len(tk_overlap_ids_t) != 0: + sorter = np.argsort(tk_ids_t) + indexes = sorter[ + np.searchsorted(tk_ids_t, tk_overlap_ids_t, sorter=sorter) + ] + sim_t = data["sim_scores"][t][:, indexes] + fpl_candidates = tk_overlap_ids_t[(sim_t >= (thr / 100)).any(axis=0)] + fpl_candidates_ori_ids_t = np.array( + [data["tk_id_map"][tid] for tid in fpl_candidates] + ) + else: + fpl_candidates_ori_ids_t = [] + + if self.exhaustive: + cls_fp_thr[cls] += len(tk_cls_ids_t) - len(tk_overlap_ids_t) + + # calculate and accumulate basic statistics + for a, alpha in enumerate(self.array_labels): + match_row, match_col = amatch_rows[a], amatch_cols[a] + num_matches = len(match_row) + matched_ori_ids = set( + [data["tk_id_map"][tid] for tid in tk_ids_t[match_col]] + ) + match_tk_cls = data["tk_classes"][t][match_col] + wrong_tk_cls = match_tk_cls[match_tk_cls != data["gt_classes"][t]] + + num_class_and_det_matches = np.sum( + match_tk_cls == data["gt_classes"][t] + ) + + if alpha >= 0.5: + for cid in wrong_tk_cls: + if cid in cid2clsname: + cname = cid2clsname[cid] + cls_fp_thr[cname][a - 10] += 1 + res["Cls_TP"][a - 10] += num_class_and_det_matches + res["Cls_FN"][a - 10] += num_matches - num_class_and_det_matches + + res["Loc_TP"][a] += num_matches + res["Loc_FN"][a] += len(gt_ids_t) - num_matches + res["Loc_FP"][a] += len(set(fpl_candidates_ori_ids_t) - matched_ori_ids) + + if num_matches > 0: + matches_counts[a][gt_ids_t[match_row], tk_ids_t[match_col]] += 1 + + # calculate AssocA, AssocRe, AssocPr + self.compute_association_scores(res, matches_counts, gt_id_count, tk_id_count) + + # calculate final scores + res = self._compute_final_fields(res) + return res, cls_fp_thr, class_info_list + + def compute_global_alignment_score(self, data): + """Computes global alignment score.""" + num_matches = np.zeros((data["num_gt_ids"], data["num_tk_ids"])) + gt_id_count = np.zeros((data["num_gt_ids"], 1)) + tk_id_count = np.zeros((1, data["num_tk_ids"])) + + # loop through each timestep and accumulate global track info. + for t, (gt_ids_t, tk_ids_t) in enumerate(zip(data["gt_ids"], data["tk_ids"])): + # count potential matches between ids in each time step + # these are normalized, weighted by match similarity + sim = data["sim_scores"][t] + sim_iou_denom = sim.sum(0, keepdims=True) + sim.sum(1, keepdims=True) - sim + sim_iou = np.zeros_like(sim) + mask = sim_iou_denom > (0 + EPS) + sim_iou[mask] = sim[mask] / sim_iou_denom[mask] + num_matches[gt_ids_t[:, None], tk_ids_t[None, :]] += sim_iou + + # calculate total number of dets for each gt_id and tk_id. + gt_id_count[gt_ids_t] += 1 + tk_id_count[0, tk_ids_t] += 1 + + # Calculate overall Jaccard alignment score between IDs + ga_score = num_matches / (gt_id_count + tk_id_count - num_matches) + return ga_score, gt_id_count, tk_id_count + + def compute_matches(self, data, t, ga_score, gt_ids, tk_ids, alpha): + """Compute matches based on alignment score.""" + sim = data["sim_scores"][t] + score_mat = ga_score[gt_ids[:, None], tk_ids[None, :]] * sim + # Hungarian algorithm to find best matches + match_rows, match_cols = linear_sum_assignment(-score_mat) + + if not isinstance(alpha, list): + alpha = [alpha] + alpha_match_rows, alpha_match_cols = [], [] + for a in alpha: + matched_mask = sim[match_rows, match_cols] >= a - EPS + alpha_match_rows.append(match_rows[matched_mask]) + alpha_match_cols.append(match_cols[matched_mask]) + return alpha_match_rows, alpha_match_cols + + def compute_association_scores(self, res, matches_counts, gt_id_count, tk_id_count): + """Calculate association scores for each alpha. + + First calculate scores per gt_id/tk_id combo, + and then average over the number of detections. + """ + for a, _ in enumerate(self.array_labels): + matches_count = matches_counts[a] + ass_a = matches_count / np.maximum( + 1, gt_id_count + tk_id_count - matches_count + ) + res["AssocA"][a] = np.sum(matches_count * ass_a) / np.maximum( + 1, res["Loc_TP"][a] + ) + ass_re = matches_count / np.maximum(1, gt_id_count) + res["AssocRe"][a] = np.sum(matches_count * ass_re) / np.maximum( + 1, res["Loc_TP"][a] + ) + ass_pr = matches_count / np.maximum(1, tk_id_count) + res["AssocPr"][a] = np.sum(matches_count * ass_pr) / np.maximum( + 1, res["Loc_TP"][a] + ) + + @_timing.time + def eval_sequence(self, data, cls, cls_id_name_mapping, cls_fp): + """Evaluate a single sequence across all thresholds.""" + res = {} + class_info_dict = {} + + for thr in data: + res[thr], cls_fp[thr], cls_info = self.eval_sequence_single_thr( + data[thr], cls, cls_id_name_mapping, cls_fp[thr], thr + ) + class_info_dict[thr] = cls_info + + return res, cls_fp, class_info_dict + + def combine_sequences(self, all_res): + """Combines metrics across all sequences.""" + data = {} + res = {} + + if all_res: + thresholds = list(list(all_res.values())[0].keys()) + else: + thresholds = [50] + for thr in thresholds: + data[thr] = {} + for seq_key in all_res: + data[thr][seq_key] = all_res[seq_key][thr] + for thr in thresholds: + res[thr] = self._combine_sequences_thr(data[thr]) + + return res + + def _combine_sequences_thr(self, all_res): + """Combines sequences over each threshold.""" + res = {} + for field in self.integer_array_fields: + res[field] = self._combine_sum(all_res, field) + for field in ["AssocRe", "AssocPr", "AssocA"]: + res[field] = self._combine_weighted_av( + all_res, field, res, weight_field="Loc_TP" + ) + res = self._compute_final_fields(res) + return res + + def combine_classes_class_averaged(self, all_res, ignore_empty=False): + """Combines metrics across all classes by averaging over classes. + + If 'ignore_empty' is True, then it only sums over classes + with at least one gt or predicted detection. + """ + data = {} + res = {} + if all_res: + thresholds = list(list(all_res.values())[0].keys()) + else: + thresholds = [50] + for thr in thresholds: + data[thr] = {} + for cls_key in all_res: + data[thr][cls_key] = all_res[cls_key][thr] + for thr in data: + res[thr] = self._combine_classes_class_averaged_thr( + data[thr], ignore_empty=ignore_empty + ) + return res + + def _combine_classes_class_averaged_thr(self, all_res, ignore_empty=False): + """Combines classes over each threshold.""" + res = {} + + def check_empty(val): + """Returns True if empty.""" + return not (val["Loc_TP"] + val["Loc_FN"] + val["Loc_FP"] > 0 + EPS).any() + + for field in self.integer_array_fields: + if ignore_empty: + res_field = {k: v for k, v in all_res.items() if not check_empty(v)} + else: + res_field = {k: v for k, v in all_res.items()} + res[field] = self._combine_sum(res_field, field) + + for field in self.float_array_fields: + if ignore_empty: + res_field = [v[field] for v in all_res.values() if not check_empty(v)] + else: + res_field = [v[field] for v in all_res.values()] + res[field] = np.mean(res_field, axis=0) + return res + + def combine_classes_det_averaged(self, all_res): + """Combines metrics across all classes by averaging over detections.""" + data = {} + res = {} + if all_res: + thresholds = list(list(all_res.values())[0].keys()) + else: + thresholds = [50] + for thr in thresholds: + data[thr] = {} + for cls_key in all_res: + data[thr][cls_key] = all_res[cls_key][thr] + for thr in data: + res[thr] = self._combine_classes_det_averaged_thr(data[thr]) + return res + + def _combine_classes_det_averaged_thr(self, all_res): + """Combines detections over each threshold.""" + res = {} + for field in self.integer_array_fields: + res[field] = self._combine_sum(all_res, field) + for field in ["AssocRe", "AssocPr", "AssocA"]: + res[field] = self._combine_weighted_av( + all_res, field, res, weight_field="Loc_TP" + ) + res = self._compute_final_fields(res) + return res + + @staticmethod + def _compute_final_fields(res): + """Calculate final metric values. + + This function is used both for both per-sequence calculation, + and in combining values across sequences. + """ + # LocA + res["LocRe"] = res["Loc_TP"] / np.maximum(1, res["Loc_TP"] + res["Loc_FN"]) + res["LocPr"] = res["Loc_TP"] / np.maximum(1, res["Loc_TP"] + res["Loc_FP"]) + res["LocA"] = res["Loc_TP"] / np.maximum( + 1, res["Loc_TP"] + res["Loc_FN"] + res["Loc_FP"] + ) + + # ClsA + res["ClsRe"] = res["Cls_TP"] / np.maximum(1, res["Cls_TP"] + res["Cls_FN"]) + res["ClsPr"] = res["Cls_TP"] / np.maximum(1, res["Cls_TP"] + res["Cls_FP"]) + res["ClsA"] = res["Cls_TP"] / np.maximum( + 1, res["Cls_TP"] + res["Cls_FN"] + res["Cls_FP"] + ) + + res["ClsRe"] = np.mean(res["ClsRe"]) + res["ClsPr"] = np.mean(res["ClsPr"]) + res["ClsA"] = np.mean(res["ClsA"]) + + res["TETA"] = (res["LocA"] + res["AssocA"] + res["ClsA"]) / 3 + + return res + + def print_summary_table(self, thr_res, thr, tracker, cls): + """Prints summary table of results.""" + print("") + metric_name = self.get_name() + self._row_print( + [f"{metric_name}{str(thr)}: {tracker}-{cls}"] + self.summary_fields + ) + self._row_print(["COMBINED"] + thr_res) diff --git a/third_party/sam3/sam3/eval/teta_eval_toolkit/utils.py b/third_party/sam3/sam3/eval/teta_eval_toolkit/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..933bf7a09eb8e0d205e268c306fdfde2eb9625e4 --- /dev/null +++ b/third_party/sam3/sam3/eval/teta_eval_toolkit/utils.py @@ -0,0 +1,48 @@ +# fmt: off +# flake8: noqa + +# pyre-unsafe + +import csv +import os +from collections import OrderedDict + + +def validate_metrics_list(metrics_list): + """Get names of metric class and ensures they are unique, further checks that the fields within each metric class + do not have overlapping names. + """ + metric_names = [metric.get_name() for metric in metrics_list] + # check metric names are unique + if len(metric_names) != len(set(metric_names)): + raise TrackEvalException( + "Code being run with multiple metrics of the same name" + ) + fields = [] + for m in metrics_list: + fields += m.fields + # check metric fields are unique + if len(fields) != len(set(fields)): + raise TrackEvalException( + "Code being run with multiple metrics with fields of the same name" + ) + return metric_names + + +def get_track_id_str(ann): + """Get name of track ID in annotation.""" + if "track_id" in ann: + tk_str = "track_id" + elif "instance_id" in ann: + tk_str = "instance_id" + elif "scalabel_id" in ann: + tk_str = "scalabel_id" + else: + assert False, "No track/instance ID." + return tk_str + + +class TrackEvalException(Exception): + """Custom exception for catching expected errors.""" + + ... diff --git a/third_party/sam3/sam3/eval/ytvis_coco_wrapper.py b/third_party/sam3/sam3/eval/ytvis_coco_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..5412d692335e8c22e5cae59083c0dce1aeaa14b6 --- /dev/null +++ b/third_party/sam3/sam3/eval/ytvis_coco_wrapper.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-unsafe + +import copy +import json +import logging + +import numpy as np +import pycocotools.mask as mask_util +from pycocotools.coco import COCO +from typing_extensions import override + + +class YTVIS(COCO): + """ + Helper class for reading YT-VIS annotations + """ + + @override + def __init__(self, annotation_file: str = None, ignore_gt_cats: bool = True): + """ + Args: + annotation_file: Path to the annotation file + ignore_gt_cats: If True, we ignore the ground truth categories and replace them with a dummy "object" category. This is useful for Phrase AP evaluation. + """ + self.ignore_gt_cats = ignore_gt_cats + super().__init__(annotation_file=annotation_file) + + @override + def createIndex(self): + # We rename some keys to match the COCO format before creating the index. + if "annotations" in self.dataset: + for ann in self.dataset["annotations"]: + if "video_id" in ann: + ann["image_id"] = int(ann.pop("video_id")) + if self.ignore_gt_cats: + ann["category_id"] = -1 + else: + ann["category_id"] = int(ann["category_id"]) + if "bboxes" in ann: + # note that in some datasets we load under this YTVIS class, + # some "bboxes" could be None for when the GT object is invisible, + # so we replace them with [0, 0, 0, 0] + ann["bboxes"] = [ + bbox if bbox is not None else [0, 0, 0, 0] + for bbox in ann["bboxes"] + ] + if "areas" in ann: + # similar to "bboxes", some areas could be None for when the GT + # object is invisible, so we replace them with 0 + areas = [a if a is not None else 0 for a in ann["areas"]] + # Compute average area of tracklet + ann["area"] = np.mean(areas) + if "videos" in self.dataset: + for vid in self.dataset["videos"]: + vid["id"] = int(vid["id"]) + self.dataset["images"] = self.dataset.pop("videos") + + if self.ignore_gt_cats: + self.dataset["categories"] = [ + {"supercategory": "object", "id": -1, "name": "object"} + ] + else: + for cat in self.dataset["categories"]: + cat["id"] = int(cat["id"]) + super().createIndex() + + @override + def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): + if len(areaRng) > 0: + logging.warning( + "Note that we filter out objects based on their *average* area across the video, not per frame area" + ) + + return super().getAnnIds(imgIds=imgIds, catIds=catIds, iscrowd=iscrowd) + + @override + def showAnns(self, anns, draw_bbox=False): + raise NotImplementedError("Showing annotations is not supported") + + @override + def loadRes(self, resFile): + # Adapted from COCO.loadRes to support tracklets/masklets + res = YTVIS(ignore_gt_cats=self.ignore_gt_cats) + res.dataset["images"] = [img for img in self.dataset["images"]] + + if type(resFile) == str: + with open(resFile) as f: + anns = json.load(f) + elif type(resFile) == np.ndarray: + anns = self.loadNumpyAnnotations(resFile) + else: + anns = resFile + assert type(anns) == list, "results is not an array of objects" + annsImgIds = [ann["image_id"] for ann in anns] + assert set(annsImgIds) == ( + set(annsImgIds) & set(self.getImgIds()) + ), "Results do not correspond to current coco set" + if "bboxes" in anns[0] and not anns[0]["bboxes"] == []: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) + for id, ann in enumerate(anns): + bbs = [(bb if bb is not None else [0, 0, 0, 0]) for bb in ann["bboxes"]] + xxyy = [[bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] for bb in bbs] + if not "segmentations" in ann: + ann["segmentations"] = [ + [[x1, y1, x1, y2, x2, y2, x2, y1]] for (x1, x2, y1, y2) in xxyy + ] + ann["areas"] = [bb[2] * bb[3] for bb in bbs] + # NOTE: We also compute average area of a tracklet across video, allowing us to compute area based mAP. + ann["area"] = np.mean(ann["areas"]) + ann["id"] = id + 1 + ann["iscrowd"] = 0 + elif "segmentations" in anns[0]: + res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) + for id, ann in enumerate(anns): + ann["bboxes"] = [ + mask_util.toBbox(segm) for segm in ann["segmentations"] + ] + if "areas" not in ann: + ann["areas"] = [ + mask_util.area(segm) for segm in ann["segmentations"] + ] + # NOTE: We also compute average area of a tracklet across video, allowing us to compute area based mAP. + ann["area"] = np.mean(ann["areas"]) + ann["id"] = id + 1 + ann["iscrowd"] = 0 + + res.dataset["annotations"] = anns + res.createIndex() + return res + + @override + def download(self, tarDir=None, imgIds=[]): + raise NotImplementedError + + @override + def loadNumpyAnnotations(self, data): + raise NotImplementedError("We don't support numpy annotations for now") + + @override + def annToRLE(self, ann): + raise NotImplementedError("We expect masks to be already in RLE format") + + @override + def annToMask(self, ann): + raise NotImplementedError("We expect masks to be already in RLE format") diff --git a/third_party/sam3/sam3/eval/ytvis_eval.py b/third_party/sam3/sam3/eval/ytvis_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff230959f9a46d00e923b70ffd9c26a3ac0b5f3 --- /dev/null +++ b/third_party/sam3/sam3/eval/ytvis_eval.py @@ -0,0 +1,413 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import copy +import gc +import logging +import os +from collections import defaultdict +from operator import xor +from pathlib import Path +from typing import List, Optional + +import numpy as np +import pycocotools.mask as mask_util +import torch +from pycocotools.cocoeval import COCOeval +from sam3.eval.cgf1_eval import CGF1Eval +from sam3.eval.coco_eval_offline import convert_to_xywh +from sam3.model.box_ops import box_xywh_inter_union +from sam3.train.masks_ops import rle_encode +from sam3.train.utils import distributed as dist +from typing_extensions import override + +try: + import rapidjson as json +except ModuleNotFoundError: + import json + +from iopath.common.file_io import g_pathmgr + + +class YTVISevalMixin: + """ + Identical to COCOeval but adapts computeIoU to compute IoU between tracklets/masklets. + """ + + @override + def _prepare(self): + """ + Copied from cocoeval.py but doesn't convert masks to RLEs (we assume they already are RLEs) + """ + p = self.params + if p.useCats: + gts = self.cocoGt.loadAnns( + self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds) + ) + dts = self.cocoDt.loadAnns( + self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds) + ) + else: + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) + + # set ignore flag + for gt in gts: + gt["ignore"] = gt["ignore"] if "ignore" in gt else 0 + gt["ignore"] = "iscrowd" in gt and gt["iscrowd"] + if p.iouType == "keypoints": + gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"] + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + for gt in gts: + self._gts[gt["image_id"], gt["category_id"]].append(gt) + for dt in dts: + self._dts[dt["image_id"], dt["category_id"]].append(dt) + self.evalImgs = defaultdict(list) # per-image per-category evaluation results + self.eval = {} # accumulated evaluation results + + def computeIoU(self, imgId, catId): + """ + Compute IoU between tracklets. Copied from cocoeval.py but adapted for videos (in YT-VIS format) + """ + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 or len(dt) == 0: + return [] + + # For class mAP and phrase AP evaluation, we sort the detections in descending order of scores (as in COCOeval). + # For demo F1 evaluation, we DO NOT sort the detections (but match them with GTs via Hungarian matching). + assert hasattr(self, "sort_inds_by_scores_in_iou"), ( + "subclasses that inherits YTVISevalMixin should set `self.sort_inds_by_scores_in_iou` " + "(True for class mAP and phrase AP, False for demo F1)" + ) + if self.sort_inds_by_scores_in_iou: + inds = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in inds] + if len(dt) > p.maxDets[-1]: + dt = dt[0 : p.maxDets[-1]] + + if p.iouType == "segm": + g = [g["segmentations"] for g in gt] + d = [d["segmentations"] for d in dt] + elif p.iouType == "bbox": + g = [g["bboxes"] for g in gt] + d = [d["bboxes"] for d in dt] + else: + raise Exception("unknown iouType for iou computation") + + def iou_tracklets(preds, gts): + preds = torch.tensor(preds) + gts = torch.tensor(gts) + inter, union = box_xywh_inter_union( + preds.unsqueeze(1), gts.unsqueeze(0) + ) # Num preds x Num GTS x Num frames + inter = inter.sum(-1) + union = union.sum(-1) + assert ( + union > 0 + ).all(), ( + "There exists a tracklet with zero GTs across time. This is suspicious" + ) + return inter / union + + def iou_masklets(preds, gts): + inter = 0 + union = 0 + for p_i, gt_i in zip(preds, gts): + if p_i and gt_i: + # Compute areas of intersection and union + inter += mask_util.area( + mask_util.merge([p_i, gt_i], intersect=True) + ) + union += mask_util.area( + mask_util.merge([p_i, gt_i], intersect=False) + ) + elif gt_i: + union += mask_util.area(gt_i) + elif p_i: + union += mask_util.area(p_i) + if union > 0: + iou = inter / union + assert iou >= 0 and iou <= 1, "Encountered an error in IoU computation" + else: + assert np.isclose(inter, 0) and np.isclose( + union, 0 + ), "Encountered an error in IoU computation" + iou = 1 + return iou + + if p.iouType == "segm": + ious = [[iou_masklets(d_i, g_i) for g_i in g] for d_i in d] + else: + ious = iou_tracklets(d, g) + return np.array(ious) + + +class YTVISeval(YTVISevalMixin, COCOeval): + # For class mAP and phrase AP evaluation, we sort the detections in descending order of scores (as in COCOeval). + sort_inds_by_scores_in_iou = True + + +class VideoDemoF1Eval(YTVISevalMixin, CGF1Eval): + # For demo F1 evaluation, we DO NOT sort the detections (but match them with GTs via Hungarian matching). + sort_inds_by_scores_in_iou = False + + +class YTVISResultsWriter: + """ + Gather and dumps predictions in YT-VIS format. + Expected flow of API calls: reset() -> N * update() -> compute_synced() + """ + + def __init__( + self, + dump_file: str, + postprocessor, + gather_pred_via_filesys=False, + pred_file_evaluators: Optional[List] = None, + save_per_frame_scores: bool = False, + write_eval_metrics_file: bool = True, + eval_metrics_file_suffix: str = ".sam3_eval_metrics", + ): + self.dump_file = dump_file + self.dump = [] + self.postprocessor = postprocessor + self.gather_pred_via_filesys = gather_pred_via_filesys + if dist.is_main_process(): + dirname = os.path.dirname(self.dump_file) + if not os.path.exists(dirname): + os.makedirs(dirname, exist_ok=True) + logging.info(f"Creating folder: {dirname}") + + # the evaluation hooks to be applied to the prediction files + self.pred_file_evaluators = pred_file_evaluators or [] + self.save_per_frame_scores = save_per_frame_scores + # in addition to the prediction file, we also write the evaluation metrics + # for easier debugging and analysis (stored in another eval_metrics_file + # so that we can keep the dumped prediction file under YT-VIS format) + self.write_eval_metrics_file = write_eval_metrics_file + if self.write_eval_metrics_file: + self.eval_metrics_file = self.dump_file + eval_metrics_file_suffix + os.makedirs(os.path.dirname(self.eval_metrics_file), exist_ok=True) + + def _dump_vid_preds(self, results): + dumped_results = copy.deepcopy(results) + self.dump.extend(dumped_results) + + def prepare(self, predictions): + ytvis_results = [] + for video_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + for k in ["boxes", "scores", "labels"]: + assert ( + k in prediction + ), f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}" + if self.save_per_frame_scores: + assert ( + "per_frame_scores" in prediction + ), f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}" + assert xor( + "masks" in prediction, "masks_rle" in prediction + ), f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}" + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + if "masks" in prediction: + masks = prediction["masks"].squeeze(2) + assert ( + masks.ndim == 4 + ), "Expected masks to be of shape(N_preds,T_frames,H,W)" + + areas = [mask.flatten(1).sum(1).tolist() for mask in masks] + rles = [rle_encode(masklet) for masklet in masks] + + # memory clean + del masks + del prediction["masks"] + elif "masks_rle" in prediction: + rles = prediction.pop("masks_rle") + areas = [ + [0 if rle is None else rle.pop("area") for rle in rles_per_obj] + for rles_per_obj in rles + ] + else: + raise ValueError( + "Expected either `masks` or `masks_rle` key in the predictions." + ) + + new_results = [ + { + "video_id": video_id, + "category_id": track_label, + "bboxes": track_boxes, + "score": track_score, + "segmentations": track_masks, + "areas": track_areas, + } + for ( + track_boxes, + track_masks, + track_areas, + track_score, + track_label, + ) in zip(boxes, rles, areas, scores, labels) + ] + # Optionally, save per-frame scores + if self.save_per_frame_scores: + per_frame_scores = prediction["per_frame_scores"].tolist() + for res, track_per_frame_scores in zip(new_results, per_frame_scores): + res["per_frame_scores"] = track_per_frame_scores + + ytvis_results.extend(new_results) + + return ytvis_results + + def set_sync_device(self, device: torch.device): + self._sync_device = device + + def update(self, *args, **kwargs): + predictions = self.postprocessor.process_results(*args, **kwargs) + results = self.prepare(predictions) + self._dump_vid_preds(results) + + def _dump_preds(self): + if not dist.is_main_process(): + self.dump = [] + gc.collect() + return + dumped_file = Path(self.dump_file) + logging.info(f"YTVIS evaluator: Dumping predictions to {dumped_file}") + with g_pathmgr.open(str(dumped_file), "w") as f: + json.dump(self.dump, f) + self.dump = [] + gc.collect() + return str(dumped_file) + + def synchronize_between_processes(self): + logging.info("YT-VIS evaluator: Synchronizing between processes") + dump_dict = self._dedup_pre_gather(self.dump) + if self.gather_pred_via_filesys: + dump_dict_all_gpus = dist.gather_to_rank_0_via_filesys(dump_dict) + else: + dump_dict_all_gpus = dist.all_gather(dump_dict, force_cpu=True) + self.dump = self._dedup_post_gather(dump_dict_all_gpus) + logging.info(f"Gathered all {len(self.dump)} predictions") + + def _dedup_pre_gather(self, predictions): + """ + Organize the predictions as a dict-of-list using (video_id, category_id) as keys + for deduplication after gathering them across GPUs. + + During evaluation, PyTorch data loader under `drop_last: False` would wrap + around the dataset length to be a multiple of world size (GPU num) and duplicate + the remaining batches. This causes the same test sample to appear simultaneously + in multiple GPUs, resulting in duplicated predictions being saved into prediction + files. These duplicates are then counted as false positives under detection mAP + metrics (since a ground truth can be matched with only one prediction). + + For example, if there are 4 GPUs and 6 samples [A1, A2, B1, B2, C1, C2], the data + loader (under `drop_last: False`) would load it by wrapping it around like + `[A1, A2, B1, B2, C1, C2, *A1*, *A2*]` to make a multiple of 4 and then split it as + + - GPU 0: A1, C1 + - GPU 1: A2, C2 + - GPU 3: B1, **A1** + - GPU 4: B2, **A2** + (as in DistributedSampler in https://github.com/pytorch/pytorch/blob/521588519da9f4876d90ddd7a17c10d0eca89dc6/torch/utils/data/distributed.py#L116-L124) + + so the predictions on A1 and A2 will occur twice in the final gathered outputs + in the prediction file (and counted as false positives). This also affects our + YT-VIS official val evaluation, but to a lesser extent than YT-VIS dev since + the latter is much smaller and more susceptible to false positives. + + So we to deduplicate this. The tricky part is that we cannot deduplicate them + simply using video id, given that we are sharding the classes in each video + across multiple batches (with 20 prompts per batch) in our "orig_cats" eval dbs. + + The solution is to deduplicate based on (video_id, category_id) tuple as keys. + We organize the predictions as a dict-of-list using (video_id, category_id) as + keys on each GPU, with the list of masklets under this (video_id, category_id) + on this GPU as values. Then, we all-gather this dict-of-list across GPUs and + if a key (video_id, category_id) appears in multiple GPUs, we only take the + prediction masklet list from one GPU. + """ + prediction_dict = defaultdict(list) + for p in predictions: + prediction_dict[(p["video_id"], p["category_id"])].append(p) + return prediction_dict + + def _dedup_post_gather(self, list_of_prediction_dict): + """ + Deduplicate the predictions from all GPUs. See `_dedup_pre_gather` for details. + """ + dedup_prediction_dict = {} + duplication_keys = [] + for prediction_dict in list_of_prediction_dict: + for k, v in prediction_dict.items(): + if k not in dedup_prediction_dict: + dedup_prediction_dict[k] = v + else: + duplication_keys.append(k) + + logging.info( + f"skipped {len(duplication_keys)} duplicated predictions in YTVISResultsWriter " + f"with the following (video_id, category_id) tuples: {duplication_keys}" + ) + dedup_predictions = sum(dedup_prediction_dict.values(), []) + return dedup_predictions + + def compute_synced( + self, + ): + self.synchronize_between_processes() + dumped_file = self._dump_preds() + if not dist.is_main_process(): + return {"": 0.0} + + # run evaluation hooks on the prediction file + meters = {} + all_video_np_level_results = defaultdict(dict) + for evaluator in self.pred_file_evaluators: + gc.collect() + results, video_np_level_results = evaluator.evaluate(dumped_file) + meters.update(results) + for (video_id, category_id), res in video_np_level_results.items(): + all_video_np_level_results[(video_id, category_id)].update(res) + + gc.collect() + if self.write_eval_metrics_file: + # convert the nested dict of {(video_id, category_id): per_sample_metric_dict} + # to a list of per-sample metric dicts (with video_id and category_id) for JSON, + # as JSON doesn't allow using tuples like (video_id, category_id) as dict keys + video_np_level_metrics = [ + {"video_id": video_id, "category_id": category_id, **res} + for (video_id, category_id), res in all_video_np_level_results.items() + ] + eval_metrics = { + "dataset_level_metrics": meters, + "video_np_level_metrics": video_np_level_metrics, + } + with g_pathmgr.open(self.eval_metrics_file, "w") as f: + json.dump(eval_metrics, f) + logging.info( + f"YTVIS evaluator: Dumped evaluation metrics to {self.eval_metrics_file}" + ) + + if len(meters) == 0: + meters = {"": 0.0} + return meters + + def compute(self): + return {"": 0.0} + + def reset(self, *args, **kwargs): + self.dump = [] diff --git a/third_party/sam3/sam3/logger.py b/third_party/sam3/sam3/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..35dcc0d830d7a879af1d75ada2bad5650782eae9 --- /dev/null +++ b/third_party/sam3/sam3/logger.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import logging +import os + +LOG_LEVELS = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, +} + + +class ColoredFormatter(logging.Formatter): + """A command line formatter with different colors for each level.""" + + def __init__(self): + super().__init__() + reset = "\033[0m" + colors = { + logging.DEBUG: f"{reset}\033[36m", # cyan, + logging.INFO: f"{reset}\033[32m", # green + logging.WARNING: f"{reset}\033[33m", # yellow + logging.ERROR: f"{reset}\033[31m", # red + logging.CRITICAL: f"{reset}\033[35m", # magenta + } + fmt_str = "{color}%(levelname)s %(asctime)s %(process)d %(filename)s:%(lineno)4d:{reset} %(message)s" + self.formatters = { + level: logging.Formatter(fmt_str.format(color=color, reset=reset)) + for level, color in colors.items() + } + self.default_formatter = self.formatters[logging.INFO] + + def format(self, record): + formatter = self.formatters.get(record.levelno, self.default_formatter) + return formatter.format(record) + + +def get_logger(name, level=logging.INFO): + """A command line logger.""" + if "LOG_LEVEL" in os.environ: + level = os.environ["LOG_LEVEL"].upper() + assert ( + level in LOG_LEVELS + ), f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}" + level = LOG_LEVELS[level] + logger = logging.getLogger(name) + logger.setLevel(level) + logger.propagate = False + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(ColoredFormatter()) + logger.addHandler(ch) + return logger diff --git a/third_party/sam3/sam3/model/__init__.py b/third_party/sam3/sam3/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/model/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/model/__pycache__/__init__.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50babcc0cc21b22d32b6078e9408c60c25726b15 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/__init__.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/act_ckpt_utils.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/act_ckpt_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa481fb4602ace6ce96da8ccafa23fc6c5c82dd Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/act_ckpt_utils.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/box_ops.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/box_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3394e732af3de9a40382bf198e147b10768e0ffc Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/box_ops.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/data_misc.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/data_misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcf272445b3f1dc5cc19ca49197ea5b760ce09a3 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/data_misc.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/decoder.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..645cb4c08789909cdd7954a46d9ff4d0e47287a3 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/decoder.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/edt.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/edt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f96466628ee22bbe818078e4944eaa51ca95190c Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/edt.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/encoder.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea02c4b6dc47dda227aae131fffddab68432f5a4 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/encoder.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/geometry_encoders.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/geometry_encoders.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc38ce01a8b12c4617500ffafbd9219352de5bb5 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/geometry_encoders.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/io_utils.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/io_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bddf85908f0c49b8667175a9cd0bdcf94c06e89 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/io_utils.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/maskformer_segmentation.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/maskformer_segmentation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4d7f1e7163a9cc38e89848855554b60c3307eb5 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/maskformer_segmentation.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/memory.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/memory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d7fcb08d43db1af351ac902e160e012a79a85c3 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/memory.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/model_misc.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/model_misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2a272af4976333ac0ea8f96f242a319dba5b745 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/model_misc.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/multiplex_mask_decoder.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/multiplex_mask_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0646a91179a9e0cf8fd64f2e283c82a7ea1b866f Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/multiplex_mask_decoder.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/multiplex_utils.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/multiplex_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0d1a61f8e595dc840a164ed5a9d758f652266e6 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/multiplex_utils.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/necks.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/necks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c98bab93ca01450c6680a64261cd91cbd00bc5e2 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/necks.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/position_encoding.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/position_encoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de10c62ce16d044d3fffb82b7ece3de610e76e59 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/position_encoding.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam1_task_predictor.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam1_task_predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08362c6178c63aa779c12703258606912bdea072 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam1_task_predictor.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam3_base_predictor.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam3_base_predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b785851040cc7681cd3ebf5140efcd979b383150 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam3_base_predictor.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam3_image.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam3_image.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..873c54352c9d4ae173e4639419cfa02e977851b5 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam3_image.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam3_tracker_base.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam3_tracker_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea53b8c88ad072ddf590a523973ce379214b1a62 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam3_tracker_base.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam3_tracker_utils.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam3_tracker_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e58195ec73436ed1988d76ae410c64e4780a3c28 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam3_tracker_utils.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam3_tracking_predictor.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam3_tracking_predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d7fc3d36cfc8fb02df9c4804ee2ff9286af29c2 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam3_tracking_predictor.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam3_video_base.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam3_video_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bff5dd48c7167eed5a6b51eb2fbd40a00d14b7f Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam3_video_base.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam3_video_inference.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam3_video_inference.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64c383af47142b470378c23d1dc4527233735838 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam3_video_inference.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/sam3_video_predictor.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/sam3_video_predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ca0039865bcd7ea32f0310782b2bd41d7ed1b86 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/sam3_video_predictor.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/text_encoder_ve.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/text_encoder_ve.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04164dbd01b142619cb9b22ea9cdb53a69e53d2b Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/text_encoder_ve.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/tokenizer_ve.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/tokenizer_ve.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb88e0a782e80fd8e5a6d07047e577d49dff0725 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/tokenizer_ve.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/video_tracking_multiplex.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/video_tracking_multiplex.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b23d677e22ff84a737ddadfa5687cd7df77e017d --- /dev/null +++ b/third_party/sam3/sam3/model/__pycache__/video_tracking_multiplex.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7f544e8339574851c574e031748baa36e95b5903c5971d9491c6ba1184c5c06 +size 120802 diff --git a/third_party/sam3/sam3/model/__pycache__/vitdet.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/vitdet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a722bda75e54e51b07efa47e29d0292cba4825e Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/vitdet.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/__pycache__/vl_combiner.cpython-311.pyc b/third_party/sam3/sam3/model/__pycache__/vl_combiner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e23e3a16730f21244ba13f34a286ced0522359c9 Binary files /dev/null and b/third_party/sam3/sam3/model/__pycache__/vl_combiner.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/act_ckpt_utils.py b/third_party/sam3/sam3/model/act_ckpt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6d98383840ea3d3ea8104e30568bc27f2db748 --- /dev/null +++ b/third_party/sam3/sam3/model/act_ckpt_utils.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import inspect +from functools import wraps +from typing import Callable, TypeVar, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from torch.utils._pytree import tree_map_only + +# Type variables for better type hinting +T = TypeVar("T") +Module = TypeVar("Module", bound=nn.Module) + + +def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable: + """ + Wraps a given module to enable or disable activation checkpointing. + + Activation checkpointing (gradient checkpointing) trades compute for memory by + recomputing intermediate activations during the backward pass instead of storing + them in memory during the forward pass. + + When activation checkpointing is enabled, the wrapper expects only keyword arguments, + and it maps these to positional arguments based on the module's signature. + + Args: + module: The module or function to wrap with activation checkpointing + + Returns: + A wrapped callable that supports activation checkpointing + + Usage: + The returned wrapper function can be called with the same arguments as the + original module, with an additional `act_ckpt_enable` keyword argument to control + activation checkpointing and optional `use_reentrant` parameter. + + Example: + ```python + wrapped_module = activation_ckpt_wrapper(my_module) + output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True) + ``` + """ + + @wraps(module) + def act_ckpt_wrapper( + *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs + ): + if act_ckpt_enable: + if len(args) > 0: + raise ValueError( + "This wrapper expects keyword arguments only when `act_ckpt_enable=True`" + ) + # Get the signature of the target function/module + callable_fn = module.forward if isinstance(module, nn.Module) else module + sig = inspect.signature(callable_fn) + # Create a mapping of parameter names to their default values + param_defaults = { + name: param.default for name, param in sig.parameters.items() + } + args = [] + for p_name in param_defaults.keys(): + if p_name in kwargs: + args.append(kwargs.pop(p_name)) + elif param_defaults[p_name] is not inspect.Parameter.empty: + # Set arg to default value if it's not in kwargs. Useful for primitive types or args that default to None + args.append(param_defaults[p_name]) + elif ( + sig.parameters[p_name].kind is not inspect.Parameter.VAR_KEYWORD + ): # Skip **kwargs parameter + raise ValueError(f"Missing positional argument: {p_name}") + + # Scan remaining kwargs for torch.Tensor + remaining_keys = list(kwargs.keys()) + for key in remaining_keys: + if isinstance(kwargs[key], torch.Tensor): + # Remove the tensor from kwargs, assuming it's not required by the module. + # If it is required, the module's signature should be modified to accept it as a positional or keyword argument. + kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_" + + ret = checkpoint.checkpoint( + module, *args, use_reentrant=use_reentrant, **kwargs + ) + else: + ret = module(*args, **kwargs) + + return ret + + return act_ckpt_wrapper + + +def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]: + """ + Clone the CUDA output tensors of a function to avoid in-place operations. + + This wrapper is useful when working with torch.compile to prevent errors + related to in-place operations on tensors. + + Args: + f: The function whose CUDA tensor outputs should be cloned + + Returns: + A wrapped function that clones any CUDA tensor outputs + """ + + @wraps(f) + def wrapped(*args, **kwargs): + outputs = f(*args, **kwargs) + return tree_map_only( + torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs + ) + + return wrapped diff --git a/third_party/sam3/sam3/model/box_ops.py b/third_party/sam3/sam3/model/box_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..59f52e0390893f8e911baace74b3debbf7c6f99a --- /dev/null +++ b/third_party/sam3/sam3/model/box_ops.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +""" +Utilities for bounding box manipulation and GIoU. +""" + +from typing import Tuple + +import torch + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_cxcywh_to_xywh(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (w), (h)] + return torch.stack(b, dim=-1) + + +def box_xywh_to_xyxy(x): + x, y, w, h = x.unbind(-1) + b = [(x), (y), (x + w), (y + h)] + return torch.stack(b, dim=-1) + + +def box_xywh_to_cxcywh(x): + x, y, w, h = x.unbind(-1) + b = [(x + 0.5 * w), (y + 0.5 * h), (w), (h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_xywh(x): + x, y, X, Y = x.unbind(-1) + b = [(x), (y), (X - x), (Y - y)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +def box_area(boxes): + """ + Batched version of box area. Boxes should be in [x0, y0, x1, y1] format. + + Inputs: + - boxes: Tensor of shape (..., 4) + + Returns: + - areas: Tensor of shape (...,) + """ + x0, y0, x1, y1 = boxes.unbind(-1) + return (x1 - x0) * (y1 - y0) + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float, device=masks.device) + x = torch.arange(0, w, dtype=torch.float, device=masks.device) + y, x = torch.meshgrid(y, x) + + x_mask = masks * x.unsqueeze(0) + x_max = x_mask.flatten(1).max(-1)[0] + 1 + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = masks * y.unsqueeze(0) + y_max = y_mask.flatten(1).max(-1)[0] + 1 + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + boxes = torch.stack([x_min, y_min, x_max, y_max], 1) + # Invalidate boxes corresponding to empty masks. + boxes = boxes * masks.flatten(-2).any(-1) + return boxes + + +def box_iou(boxes1, boxes2): + """ + Batched version of box_iou. Boxes should be in [x0, y0, x1, y1] format. + + Inputs: + - boxes1: Tensor of shape (..., N, 4) + - boxes2: Tensor of shape (..., M, 4) + + Returns: + - iou, union: Tensors of shape (..., N, M) + """ + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + # boxes1: (..., N, 4) -> (..., N, 1, 2) + # boxes2: (..., M, 4) -> (..., 1, M, 2) + lt = torch.max(boxes1[..., :, None, :2], boxes2[..., None, :, :2]) + rb = torch.min(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:]) + + wh = (rb - lt).clamp(min=0) # (..., N, M, 2) + inter = wh[..., 0] * wh[..., 1] # (..., N, M) + + union = area1[..., None] + area2[..., None, :] - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Batched version of Generalized IoU from https://giou.stanford.edu/ + + Boxes should be in [x0, y0, x1, y1] format + + Inputs: + - boxes1: Tensor of shape (..., N, 4) + - boxes2: Tensor of shape (..., M, 4) + + Returns: + - giou: Tensor of shape (..., N, M) + """ + iou, union = box_iou(boxes1, boxes2) + + # boxes1: (..., N, 4) -> (..., N, 1, 2) + # boxes2: (..., M, 4) -> (..., 1, M, 2) + lt = torch.min(boxes1[..., :, None, :2], boxes2[..., None, :, :2]) + rb = torch.max(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:]) + + wh = (rb - lt).clamp(min=0) # (..., N, M, 2) + area = wh[..., 0] * wh[..., 1] # (..., N, M) + + return iou - (area - union) / area + + +@torch.jit.script +def fast_diag_generalized_box_iou(boxes1, boxes2): + assert len(boxes1) == len(boxes2) + box1_xy = boxes1[:, 2:] + box1_XY = boxes1[:, :2] + box2_xy = boxes2[:, 2:] + box2_XY = boxes2[:, :2] + # assert (box1_xy >= box1_XY).all() + # assert (box2_xy >= box2_XY).all() + area1 = (box1_xy - box1_XY).prod(-1) + area2 = (box2_xy - box2_XY).prod(-1) + + lt = torch.max(box1_XY, box2_XY) # [N,2] + lt2 = torch.min(box1_XY, box2_XY) + rb = torch.min(box1_xy, box2_xy) # [N,2] + rb2 = torch.max(box1_xy, box2_xy) + + inter = (rb - lt).clamp(min=0).prod(-1) + tot_area = (rb2 - lt2).clamp(min=0).prod(-1) + + union = area1 + area2 - inter + + iou = inter / union + + return iou - (tot_area - union) / tot_area + + +@torch.jit.script +def fast_diag_box_iou(boxes1, boxes2): + assert len(boxes1) == len(boxes2) + box1_xy = boxes1[:, 2:] + box1_XY = boxes1[:, :2] + box2_xy = boxes2[:, 2:] + box2_XY = boxes2[:, :2] + # assert (box1_xy >= box1_XY).all() + # assert (box2_xy >= box2_XY).all() + area1 = (box1_xy - box1_XY).prod(-1) + area2 = (box2_xy - box2_XY).prod(-1) + + lt = torch.max(box1_XY, box2_XY) # [N,2] + rb = torch.min(box1_xy, box2_xy) # [N,2] + + inter = (rb - lt).clamp(min=0).prod(-1) + + union = area1 + area2 - inter + + iou = inter / union + + return iou + + +def box_xywh_inter_union( + boxes1: torch.Tensor, boxes2: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + # Asuumes boxes in xywh format + assert boxes1.size(-1) == 4 and boxes2.size(-1) == 4 + boxes1 = box_xywh_to_xyxy(boxes1) + boxes2 = box_xywh_to_xyxy(boxes2) + box1_tl_xy = boxes1[..., :2] + box1_br_xy = boxes1[..., 2:] + box2_tl_xy = boxes2[..., :2] + box2_br_xy = boxes2[..., 2:] + area1 = (box1_br_xy - box1_tl_xy).prod(-1) + area2 = (box2_br_xy - box2_tl_xy).prod(-1) + + assert (area1 >= 0).all() and (area2 >= 0).all() + tl = torch.max(box1_tl_xy, box2_tl_xy) + br = torch.min(box1_br_xy, box2_br_xy) + + inter = (br - tl).clamp(min=0).prod(-1) + union = area1 + area2 - inter + + return inter, union diff --git a/third_party/sam3/sam3/model/data_misc.py b/third_party/sam3/sam3/model/data_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6a7bbb03c262a61405ab25a4b9ff2758c04181 --- /dev/null +++ b/third_party/sam3/sam3/model/data_misc.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +""" +Misc functions, including distributed helpers. +""" + +import collections +import re +from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass +from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union + +import torch + + +MyTensor = Union[torch.Tensor, List[Any]] + + +class NestedTensor: + def __init__(self, tensors, mask): + self.tensors = tensors + self.mask = mask + + def to(self, *args, **kwargs): + cast_tensor = self.tensors.to(*args, **kwargs) + cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None + return type(self)(cast_tensor, cast_mask) + + def clone(self): + new_tensors = self.tensors.clone() + new_mask = None if self.mask is None else self.mask.clone() + return NestedTensor(new_tensors, new_mask) + + def __getitem__(self, idx): + return self.tensors[idx] + + def __len__(self): + return len(self.tensors) + + @property + def device(self): + return self.tensors.device + + @property + def shape(self): + return self.tensors.shape + + # custom memory pinning method on custom type + def pin_memory(self, device=None): + self.tensors = self.tensors.pin_memory(device) + if self.mask is not None: + self.mask = self.mask.pin_memory(device) + + +# Register NestedTensor as a pytree node so tree_map_only can traverse into it +# (matches onevision/utils/misc.py registration) +from torch.utils import _pytree as pytree + +pytree.register_pytree_node( + NestedTensor, + lambda x: ([x.tensors, x.mask], None), + lambda values, _: NestedTensor(values[0], values[1]), +) + + +def interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty channel sizes. + """ + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + assert ( + input.shape[0] != 0 or input.shape[1] != 0 + ), "At least one of the two first dimensions must be non zero" + + if input.shape[1] == 0: + # Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim + return torch.nn.functional.interpolate( + input.transpose(0, 1), size, scale_factor, mode, align_corners + ).transpose(0, 1) + + # empty batch dimension is now supported in pytorch + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + +@dataclass +class BatchedPointer: + stage_ids: MyTensor + stage_ids__type = torch.long + query_ids: MyTensor + query_ids__type = torch.long + object_ids: MyTensor + object_ids__type = torch.long + ptr_mask: MyTensor + ptr_mask__type = torch.bool + ptr_types: MyTensor + ptr_types__type = torch.long + + +@dataclass +class FindStage: + img_ids: MyTensor + img_ids__type = torch.long + text_ids: MyTensor + text_ids__type = torch.long + + input_boxes: MyTensor + input_boxes__type = torch.float + input_boxes_mask: MyTensor + input_boxes_mask__type = torch.bool + input_boxes_label: MyTensor + input_boxes_label__type = torch.long + + input_points: MyTensor + input_points__type = torch.float + input_points_mask: MyTensor + input_points_mask__type = torch.bool + + # We track the object ids referred to by this query. + # This is beneficial for tracking in videos without the need for pointers. + object_ids: Optional[List[List]] = None # List of objects per query + + # Multiplex-specific fields (used by sam3_demo_multiplex) + img_ids_np: Optional[Any] = None + input_boxes_before_embed: Optional[MyTensor] = None + input_boxes_before_embed__type = torch.float + input_points_before_embed: Optional[MyTensor] = None + input_points_before_embed__type = torch.float + ptrs: Optional[Any] = None + ptrs_seg: Optional[Any] = None + + +@dataclass +class BatchedFindTarget: + # The number of boxes in each find query + num_boxes: MyTensor + num_boxes__type = torch.long + + # Target boxes in normalized CxCywh format + boxes: MyTensor + boxes__type = torch.float + # Target boxes in normalized CxCywh format but in padded representation + # as used in BinaryHungarianMatcherV2 (unlike the packed ones in `boxes`) + boxes_padded: MyTensor + boxes_padded__type = torch.float + + # For hybrid matching, we repeat the boxes + repeated_boxes: MyTensor + repeated_boxes__type = torch.float + + # Target Segmentation masks + segments: Optional[MyTensor] + segments__type = torch.bool + + # Target Semantic Segmentation masks + semantic_segments: Optional[MyTensor] + semantic_segments__type = torch.bool + + is_valid_segment: Optional[MyTensor] + is_valid_segment__type = torch.bool + + # Whether annotations are exhaustive for each query + is_exhaustive: MyTensor + is_exhaustive__type = torch.bool + + # The object id for each ground-truth box, in both packed and padded representations + object_ids: MyTensor + object_ids__type = torch.long + object_ids_padded: MyTensor + object_ids_padded__type = torch.long + + +@dataclass +class BatchedInferenceMetadata: + """All metadata required to post-process a find stage""" + + # Coco id that corresponds to the "image" for evaluation by the coco evaluator + coco_image_id: MyTensor + coco_image_id__type = torch.long + + # id in the original dataset, such that we can use the original evaluator + original_image_id: MyTensor + original_image_id__type = torch.long + + # Original category id (if we want to use the original evaluator) + original_category_id: MyTensor + original_category_id__type = torch.int + + # Size of the raw image (height, width) + original_size: MyTensor + original_size__type = torch.long + + # id of the object in the media (track_id for a video) + object_id: MyTensor + object_id__type = torch.long + + # index of the frame in the media (0 in the case of a single-frame media) + frame_index: MyTensor + frame_index__type = torch.long + + # Adding for relations inference + # get_text_input: List[Optional[str]] + + # Adding for TA conditional inference + is_conditioning_only: List[Optional[bool]] + + +@dataclass +class BatchedDatapoint: + img_batch: torch.Tensor + find_text_batch: List[str] + find_inputs: List[FindStage] + find_targets: List[BatchedFindTarget] + find_metadatas: List[BatchedInferenceMetadata] + raw_images: Optional[List[Any]] = None + get_queries: Optional[Any] = None + + +def convert_my_tensors(obj): + def is_optional_field(field) -> bool: + return get_origin(field) is Union and type(None) in get_args(field) + + for field in fields(obj): + if is_dataclass(getattr(obj, field.name)): + convert_my_tensors(getattr(obj, field.name)) + continue + + field_type = field.type + if is_optional_field(field.type): + field_type = Union[get_args(field.type)[:-1]] # Get the Optional field type + + if field_type != MyTensor or getattr(obj, field.name) is None: + continue + + elif len(getattr(obj, field.name)) and isinstance( + getattr(obj, field.name)[0], torch.Tensor + ): + stack_dim = 0 + if field.name in [ + "input_boxes_before_embed", + "input_boxes", + "input_boxes_label", + ]: + stack_dim = 1 + setattr( + obj, + field.name, + torch.stack(getattr(obj, field.name), dim=stack_dim).to( + getattr(obj, field.name + "__type") + ), + ) + else: + setattr( + obj, + field.name, + torch.as_tensor( + getattr(obj, field.name), dtype=getattr(obj, field.name + "__type") + ), + ) + return obj diff --git a/third_party/sam3/sam3/model/decoder.py b/third_party/sam3/sam3/model/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..33f33ee8cf4c3a41603fb1e33d2c373308d7ded8 --- /dev/null +++ b/third_party/sam3/sam3/model/decoder.py @@ -0,0 +1,1375 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +""" +Transformer decoder. +Inspired from Pytorch's version, adds the pre-norm variant +""" + +import math +from functools import partial +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as torchF +from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis +from sam3.sam.transformer import RoPEAttention +from torch import nn, Tensor +from torch.nn.attention import sdpa_kernel, SDPBackend +from torchvision.ops.roi_align import RoIAlign + +from .act_ckpt_utils import activation_ckpt_wrapper +from .box_ops import box_cxcywh_to_xyxy +from .model_misc import ( + gen_sineembed_for_position, + get_activation_fn, + get_clones, + inverse_sigmoid, + MLP, +) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + activation: str, + d_model: int, + dim_feedforward: int, + dropout: float, + cross_attention: nn.Module, + n_heads: int, + use_text_cross_attention: bool = False, + ): + super().__init__() + + # cross attention + self.cross_attn = cross_attention + self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm1 = nn.LayerNorm(d_model) + + # cross attention text + self.use_text_cross_attention = use_text_cross_attention + if use_text_cross_attention: + self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.catext_norm = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + with torch.amp.autocast(device_type="cuda", enabled=False): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt.float() + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward( + self, + # for tgt + tgt: Optional[Tensor], # nq, bs, d_model + tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos)) + tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos) + tgt_key_padding_mask: Optional[Tensor] = None, + tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4 + memory_text: Optional[Tensor] = None, # num_token, bs, d_model + text_attention_mask: Optional[Tensor] = None, # bs, num_token + # for memory + memory: Optional[Tensor] = None, # hw, bs, d_model + memory_key_padding_mask: Optional[Tensor] = None, + memory_level_start_index: Optional[Tensor] = None, # num_levels + memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 + memory_pos: Optional[Tensor] = None, # pos for memory + # sa + self_attn_mask: Optional[Tensor] = None, # mask used for self-attention + cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention + # dac + dac=False, + dac_use_selfatt_ln=True, + presence_token=None, + # skip inside deformable attn + identity=0.0, + **kwargs, # additional kwargs for compatibility + ): + """ + Input: + - tgt/tgt_query_pos: nq, bs, d_model + - + """ + # self attention + if self.self_attn is not None: + if dac: + # we only apply self attention to the first half of the queries + assert tgt.shape[0] % 2 == 0 + num_o2o_queries = tgt.shape[0] // 2 + tgt_o2o = tgt[:num_o2o_queries] + tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries] + tgt_o2m = tgt[num_o2o_queries:] + else: + tgt_o2o = tgt + tgt_query_pos_o2o = tgt_query_pos + + if presence_token is not None: + tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0) + tgt_query_pos_o2o = torch.cat( + [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0 + ) + tgt_query_pos = torch.cat( + [torch.zeros_like(presence_token), tgt_query_pos], dim=0 + ) + + q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o) + tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0] + tgt_o2o = tgt_o2o + self.dropout2(tgt2) + if dac: + if not dac_use_selfatt_ln: + tgt_o2o = self.norm2(tgt_o2o) + tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0) # Recombine + if dac_use_selfatt_ln: + tgt = self.norm2(tgt) + else: + tgt = tgt_o2o + tgt = self.norm2(tgt) + + if self.use_text_cross_attention: + tgt2 = self.ca_text( + self.with_pos_embed(tgt, tgt_query_pos), + memory_text, + memory_text, + key_padding_mask=text_attention_mask, + )[0] + tgt = tgt + self.catext_dropout(tgt2) + tgt = self.catext_norm(tgt) + + if presence_token is not None: + presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :]) + cross_attn_mask = torch.cat( + [presence_token_mask, cross_attn_mask], dim=1 + ) # (bs*nheads, 1+nq, hw) + + # Cross attention to image + tgt2 = self.cross_attn( + query=self.with_pos_embed(tgt, tgt_query_pos), + key=self.with_pos_embed(memory, memory_pos), + value=memory, + attn_mask=cross_attn_mask, + key_padding_mask=( + memory_key_padding_mask.transpose(0, 1) + if memory_key_padding_mask is not None + else None + ), + )[0] + + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + presence_token_out = None + if presence_token is not None: + presence_token_out = tgt[:1] + tgt = tgt[1:] + + return tgt, presence_token_out + + +class TransformerDecoder(nn.Module): + def __init__( + self, + d_model: int, + frozen: bool, + interaction_layer, + layer, + num_layers: int, + num_queries: int, + return_intermediate: bool, + box_refine: bool = False, + num_o2m_queries: int = 0, + dac: bool = False, + boxRPB: str = "none", + # Experimental: An object query for SAM 2 tasks + instance_query: bool = False, + # Defines the number of additional instance queries, + # 1 or 4 are the most likely for single vs multi mask support + num_instances: int = 1, # Irrelevant if instance_query is False + dac_use_selfatt_ln: bool = True, + use_act_checkpoint: bool = False, + compile_mode=None, + presence_token: bool = False, + clamp_presence_logits: bool = True, + clamp_presence_logit_max_val: float = 10.0, + use_normed_output_consistently: bool = True, + separate_box_head_instance: bool = False, + separate_norm_instance: bool = False, + resolution: Optional[int] = None, + stride: Optional[int] = None, + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.fine_layers = ( + get_clones(interaction_layer, num_layers) + if interaction_layer is not None + else [None] * num_layers + ) + self.num_layers = num_layers + self.num_queries = num_queries + self.dac = dac + if dac: + self.num_o2m_queries = num_queries + tot_num_queries = num_queries + else: + self.num_o2m_queries = num_o2m_queries + tot_num_queries = num_queries + num_o2m_queries + self.norm = nn.LayerNorm(d_model) + self.return_intermediate = return_intermediate + self.bbox_embed = MLP(d_model, d_model, 4, 3) + self.query_embed = nn.Embedding(tot_num_queries, d_model) + self.instance_query_embed = None + self.instance_query_reference_points = None + self.use_instance_query = instance_query + self.num_instances = num_instances + self.use_normed_output_consistently = use_normed_output_consistently + + self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None + self.instance_bbox_embed = None + if separate_box_head_instance: + self.instance_bbox_embed = MLP(d_model, d_model, 4, 3) + if instance_query: + self.instance_query_embed = nn.Embedding(num_instances, d_model) + self.box_refine = box_refine + if box_refine: + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + + self.reference_points = nn.Embedding(num_queries, 4) + if instance_query: + self.instance_reference_points = nn.Embedding(num_instances, 4) + + assert boxRPB in ["none", "log", "linear", "both"] + self.boxRPB = boxRPB + if boxRPB != "none": + try: + nheads = self.layers[0].cross_attn_image.num_heads + except AttributeError: + nheads = self.layers[0].cross_attn.num_heads + + n_input = 4 if boxRPB == "both" else 2 + self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2) + self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2) + self.compilable_cord_cache = None + self.compilable_stored_size = None + self.coord_cache = {} + + if resolution is not None and stride is not None: + feat_size = resolution // stride + coords_h, coords_w = self._get_coords( + feat_size, feat_size, + device="cuda" if torch.cuda.is_available() else "cpu", + ) + self.compilable_cord_cache = (coords_h, coords_w) + self.compilable_stored_size = (feat_size, feat_size) + + self.roi_pooler = ( + RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True) + if interaction_layer is not None + else None + ) + if frozen: + for p in self.parameters(): + p.requires_grad_(False) + + self.presence_token = None + self.clamp_presence_logits = clamp_presence_logits + self.clamp_presence_logit_max_val = clamp_presence_logit_max_val + if presence_token: + self.presence_token = nn.Embedding(1, d_model) + self.presence_token_head = MLP(d_model, d_model, 1, 3) + self.presence_token_out_norm = nn.LayerNorm(d_model) + + self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2) + self.dac_use_selfatt_ln = dac_use_selfatt_ln + self.use_act_checkpoint = use_act_checkpoint + + nn.init.normal_(self.query_embed.weight.data) + if self.instance_query_embed is not None: + nn.init.normal_(self.instance_query_embed.weight.data) + + assert self.roi_pooler is None + assert self.return_intermediate, "support return_intermediate only" + assert self.box_refine, "support box refine only" + + self.compile_mode = compile_mode + self.compiled = False + # We defer compilation till after the first forward, to first warm-up the boxRPB cache + + # assign layer index to each layer so that some layers can decide what to do + # based on which layer index they are (e.g. cross attention to memory bank only + # in selected layers) + for layer_idx, layer in enumerate(self.layers): + layer.layer_idx = layer_idx + + @staticmethod + def _get_coords(H, W, device): + coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H + coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W + return coords_h, coords_w + + def _get_rpb_matrix(self, reference_boxes, feat_size): + H, W = feat_size + boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1) + bs, num_queries, _ = boxes_xyxy.shape + if self.compilable_cord_cache is None: + self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device) + self.compilable_stored_size = (H, W) + + if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == ( + H, + W, + ): + # good, hitting the cache, will be compilable + coords_h, coords_w = self.compilable_cord_cache + else: + # cache miss, will create compilation issue + # In case we're not compiling, we'll still rely on the dict-based cache + if feat_size not in self.coord_cache: + self.coord_cache[feat_size] = self._get_coords( + H, W, reference_boxes.device + ) + coords_h, coords_w = self.coord_cache[feat_size] + + assert coords_h.shape == (H,) + assert coords_w.shape == (W,) + + deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2] + deltas_y = deltas_y.view(bs, num_queries, -1, 2) + deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2] + deltas_x = deltas_x.view(bs, num_queries, -1, 2) + + if self.boxRPB in ["log", "both"]: + deltas_x_log = deltas_x * 8 # normalize to -8, 8 + deltas_x_log = ( + torch.sign(deltas_x_log) + * torch.log2(torch.abs(deltas_x_log) + 1.0) + / np.log2(8) + ) + + deltas_y_log = deltas_y * 8 # normalize to -8, 8 + deltas_y_log = ( + torch.sign(deltas_y_log) + * torch.log2(torch.abs(deltas_y_log) + 1.0) + / np.log2(8) + ) + if self.boxRPB == "log": + deltas_x = deltas_x_log + deltas_y = deltas_y_log + else: + deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1) + deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1) + + if self.training: + assert self.use_act_checkpoint, "activation ckpt not enabled in decoder" + deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)( + x=deltas_x, + act_ckpt_enable=self.training and self.use_act_checkpoint, + ) # bs, num_queries, W, n_heads + deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)( + x=deltas_y, + act_ckpt_enable=self.training and self.use_act_checkpoint, + ) # bs, num_queries, H, n_heads + + if not torch.compiler.is_dynamo_compiling(): + assert deltas_x.shape[:3] == (bs, num_queries, W) + assert deltas_y.shape[:3] == (bs, num_queries, H) + + B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze( + 2 + ) # bs, num_queries, H, W, n_heads + if not torch.compiler.is_dynamo_compiling(): + assert B.shape[:4] == (bs, num_queries, H, W) + B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads + B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W + B = B.contiguous() # memeff attn likes ordered strides + if not torch.compiler.is_dynamo_compiling(): + assert B.shape[2:] == (num_queries, H * W) + return B + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + reference_boxes: Optional[Tensor] = None, # num_queries, bs, 4 + # for memory + level_start_index: Optional[Tensor] = None, # num_levels + spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2 + valid_ratios: Optional[Tensor] = None, + # for text + memory_text: Optional[Tensor] = None, + text_attention_mask: Optional[Tensor] = None, + # if `apply_dac` is None, it will default to `self.dac` + apply_dac: Optional[bool] = None, + is_instance_prompt=False, + decoder_extra_kwargs: Optional[Dict] = None, + # ROI memory bank + obj_roi_memory_feat=None, + obj_roi_memory_mask=None, + box_head_trk=None, + ): + """ + Input: + - tgt: nq, bs, d_model + - memory: \\sum{hw}, bs, d_model + - pos: \\sum{hw}, bs, d_model + - reference_boxes: nq, bs, 4 (after sigmoid) + - valid_ratios/spatial_shapes: bs, nlevel, 2 + """ + if memory_mask is not None: + assert ( + self.boxRPB == "none" + ), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented" + + apply_dac = apply_dac if apply_dac is not None else self.dac + if apply_dac: + assert (tgt.shape[0] == self.num_queries) or ( + self.use_instance_query + and (tgt.shape[0] == self.instance_query_embed.num_embeddings) + ) + + tgt = tgt.repeat(2, 1, 1) + # note that we don't tile tgt_mask, since DAC doesn't + # use self-attention in o2m queries + if reference_boxes is not None: + assert (reference_boxes.shape[0] == self.num_queries) or ( + self.use_instance_query + and ( + reference_boxes.shape[0] + == self.instance_query_embed.num_embeddings + ) + ) + reference_boxes = reference_boxes.repeat(2, 1, 1) + + bs = tgt.shape[1] + intermediate = [] + intermediate_presence_logits = [] + presence_feats = None + + if self.box_refine: + if reference_boxes is None: + # In this case, we're in a one-stage model, so we generate the reference boxes + reference_boxes = self.reference_points.weight.unsqueeze(1) + reference_boxes = ( + reference_boxes.repeat(2, bs, 1) + if apply_dac + else reference_boxes.repeat(1, bs, 1) + ) + reference_boxes = reference_boxes.sigmoid() + intermediate_ref_boxes = [reference_boxes] + else: + reference_boxes = None + intermediate_ref_boxes = None + + output = tgt + presence_out = None + if self.presence_token is not None and is_instance_prompt is False: + # expand to batch dim + presence_out = self.presence_token.weight[None].expand(1, bs, -1) + + box_head = self.bbox_embed + if is_instance_prompt and self.instance_bbox_embed is not None: + box_head = self.instance_bbox_embed + + out_norm = self.norm + if is_instance_prompt and self.instance_norm is not None: + out_norm = self.instance_norm + + for layer_idx, layer in enumerate(self.layers): + reference_points_input = ( + reference_boxes[:, :, None] + * torch.cat([valid_ratios, valid_ratios], -1)[None, :] + ) # nq, bs, nlevel, 4 + + query_sine_embed = gen_sineembed_for_position( + reference_points_input[:, :, 0, :], self.d_model + ) # nq, bs, d_model*2 + + # conditional query + query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model + + if self.boxRPB != "none" and reference_boxes is not None: + assert ( + spatial_shapes.shape[0] == 1 + ), "only single scale support implemented" + memory_mask = self._get_rpb_matrix( + reference_boxes, + (spatial_shapes[0, 0], spatial_shapes[0, 1]), + ) + memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W) + if self.training: + assert ( + self.use_act_checkpoint + ), "Activation checkpointing not enabled in the decoder" + output, presence_out = activation_ckpt_wrapper(layer)( + tgt=output, + tgt_query_pos=query_pos, + tgt_query_sine_embed=query_sine_embed, + tgt_key_padding_mask=tgt_key_padding_mask, + tgt_reference_points=reference_points_input, + memory_text=memory_text, + text_attention_mask=text_attention_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + memory_level_start_index=level_start_index, + memory_spatial_shapes=spatial_shapes, + memory_pos=pos, + self_attn_mask=tgt_mask, + cross_attn_mask=memory_mask, + dac=apply_dac, + dac_use_selfatt_ln=self.dac_use_selfatt_ln, + presence_token=presence_out, + **(decoder_extra_kwargs or {}), + act_ckpt_enable=self.training and self.use_act_checkpoint, + # ROI memory bank + obj_roi_memory_feat=obj_roi_memory_feat, + obj_roi_memory_mask=obj_roi_memory_mask, + ) + + # iter update + if self.box_refine: + reference_before_sigmoid = inverse_sigmoid(reference_boxes) + if box_head_trk is None: + # delta_unsig = self.bbox_embed(output) + if not self.use_normed_output_consistently: + delta_unsig = box_head(output) + else: + delta_unsig = box_head(out_norm(output)) + else: + # box_head_trk use a separate box head for tracking queries + Q_det = decoder_extra_kwargs["Q_det"] + assert output.size(0) >= Q_det + delta_unsig_det = self.bbox_embed(output[:Q_det]) + delta_unsig_trk = box_head_trk(output[Q_det:]) + delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0) + outputs_unsig = delta_unsig + reference_before_sigmoid + new_reference_points = outputs_unsig.sigmoid() + + reference_boxes = new_reference_points.detach() + if layer_idx != self.num_layers - 1: + intermediate_ref_boxes.append(new_reference_points) + else: + raise NotImplementedError("not implemented yet") + + intermediate.append(out_norm(output)) + if self.presence_token is not None and is_instance_prompt is False: + # norm, mlp head + intermediate_layer_presence_logits = self.presence_token_head( + self.presence_token_out_norm(presence_out) + ).squeeze(-1) + + # clamp to mitigate numerical issues + if self.clamp_presence_logits: + intermediate_layer_presence_logits.clamp( + min=-self.clamp_presence_logit_max_val, + max=self.clamp_presence_logit_max_val, + ) + + intermediate_presence_logits.append(intermediate_layer_presence_logits) + presence_feats = presence_out.clone() + + if not self.compiled and self.compile_mode is not None: + self.forward = torch.compile( + self.forward, mode=self.compile_mode, fullgraph=True + ) + self.compiled = True + + return ( + torch.stack(intermediate), + torch.stack(intermediate_ref_boxes), + ( + torch.stack(intermediate_presence_logits) + if self.presence_token is not None and is_instance_prompt is False + else None + ), + presence_feats, + ) + + +class TransformerEncoderCrossAttention(nn.Module): + def __init__( + self, + d_model: int, + frozen: bool, + pos_enc_at_input: bool, + layer, + num_layers: int, + use_act_checkpoint: bool = False, + batch_first: bool = False, # Do layers expect batch first input? + # which layers to exclude cross attention? default: None, means all + # layers use cross attention + remove_cross_attention_layers: Optional[list] = None, + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.use_act_checkpoint = use_act_checkpoint + + if frozen: + for p in self.parameters(): + p.requires_grad_(False) + + self.batch_first = batch_first + + # remove cross attention layers if specified + self.remove_cross_attention_layers = [False] * self.num_layers + if remove_cross_attention_layers is not None: + for i in remove_cross_attention_layers: + self.remove_cross_attention_layers[i] = True + assert len(self.remove_cross_attention_layers) == len(self.layers) + + for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers): + if remove_cross_attention: + self.layers[i].cross_attn_image = None + self.layers[i].norm2 = None + self.layers[i].dropout2 = None + + def forward( + self, + src, # self-attention inputs + prompt, # cross-attention inputs + src_mask: Optional[Tensor] = None, # att.mask for self-attention inputs + prompt_mask: Optional[Tensor] = None, # att.mask for cross-attention inputs + src_key_padding_mask: Optional[Tensor] = None, + prompt_key_padding_mask: Optional[Tensor] = None, + src_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + prompt_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + feat_sizes: Optional[list] = None, + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(src, list): + assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list) + assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1 + src, src_key_padding_mask, src_pos = ( + src[0], + src_key_padding_mask[0], + src_pos[0], + ) + + assert ( + src.shape[1] == prompt.shape[1] + ), "Batch size must be the same for src and prompt" + + output = src + + if self.pos_enc_at_input and src_pos is not None: + output = output + 0.1 * src_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + src_pos = src_pos.transpose(0, 1) + prompt = prompt.transpose(0, 1) + prompt_pos = prompt_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = activation_ckpt_wrapper(layer)( + tgt=output, + memory=prompt, + tgt_mask=src_mask, + memory_mask=prompt_mask, + tgt_key_padding_mask=src_key_padding_mask, + memory_key_padding_mask=prompt_key_padding_mask, + pos=prompt_pos, + query_pos=src_pos, + dac=False, + attn_bias=None, + act_ckpt_enable=self.training and self.use_act_checkpoint, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + src_pos = src_pos.transpose(0, 1) + + return { + "memory": normed_output, + "pos_embed": src_pos, + "padding_mask": src_key_padding_mask, + } + + +class TransformerDecoderLayerv1(nn.Module): + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + pre_norm: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + self.pre_norm = pre_norm + + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + **kwargs, + ): + q = k = tgt + query_pos if self.pos_enc_at_attn else tgt + + # Self attention + tgt2 = self.self_attn( + q, + k, + value=tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # Cross attention to image + tgt2 = self.cross_attn_image( + query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt, + key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # FFN + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + dac: bool = False, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + **kwargs, + ): + if dac: + # we only apply self attention to the first half of the queries + assert tgt.shape[0] % 2 == 0 + other_tgt = tgt[tgt.shape[0] // 2 :] + tgt = tgt[: tgt.shape[0] // 2] + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn( + q, + k, + value=tgt2, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = tgt + self.dropout1(tgt2) + if dac: + # Recombine + tgt = torch.cat((tgt, other_tgt), dim=0) + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + attn_bias=attn_bias, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + dac: bool = False, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + **kwds: Any, + ) -> torch.Tensor: + fwd_fn = self.forward_pre if self.pre_norm else self.forward_post + return fwd_fn( + tgt, + memory, + dac=dac, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + attn_bias=attn_bias, + **kwds, + ) + + +class TransformerDecoderLayerv2(TransformerDecoderLayerv1): + def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any): + super().__init__(*args, **kwds) + self.cross_attention_first = cross_attention_first + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + if self.cross_attn_image is None: + return tgt + + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward_pre( + self, + tgt, + memory, + dac: bool, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ): + assert dac is False + assert tgt_mask is None + assert memory_mask is None + assert tgt_key_padding_mask is None + assert memory_key_padding_mask is None + assert attn_bias is None + + if self.cross_attention_first: + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + tgt = self._forward_sa(tgt, query_pos) + else: + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, *args: Any, **kwds: Any) -> torch.Tensor: + if self.pre_norm: + return self.forward_pre(*args, **kwds) + raise NotImplementedError + + +def functional_attention( + q: Tensor, + k: Tensor, + v: Tensor, + *, + dropout: float, + num_heads: int, + num_k_exclude_rope: int = 0, + freqs_cis: Optional[Tensor] = None, + freqs_cis_real: Optional[Tensor] = None, + freqs_cis_imag: Optional[Tensor] = None, + use_fa3: bool = False, + use_rope_real: bool = False, + rope_k_repeat: bool, +) -> Union[Tensor, tuple[Tensor, Tensor]]: + b, n, cq = q.shape + _, m, ck = k.shape + _, _, cv = v.shape + if b > 1: + assert k.shape[0] == v.shape[0] == b + else: + # broadcast-able + assert k.shape[0] == b == 1, f"{q.shape=} {k.shape=} {v.shape=}" + assert v.shape[1] == m + + q = q.reshape(b, n, num_heads, cq // num_heads).transpose(1, 2) + k = k.reshape(b, m, num_heads, ck // num_heads).transpose(1, 2) + v = v.reshape(v.shape[0], m, num_heads, cv // num_heads).transpose(1, 2) + + if freqs_cis is not None: + num_k_rope = k.size(-2) - num_k_exclude_rope + if use_rope_real: + q, k[:, :, :num_k_rope] = apply_rotary_enc_real( + q, + k[:, :, :num_k_rope], + freqs_cis_real=freqs_cis_real, + freqs_cis_imag=freqs_cis_imag, + repeat_freqs_k=rope_k_repeat, + ) + else: + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis, + repeat_freqs_k=rope_k_repeat, + ) + + if use_fa3: + from sam3.perflib.fa3 import flash_attn_func + + assert dropout == 0.0 + out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)) + else: + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + out = torchF.scaled_dot_product_attention(q, k, v, dropout_p=dropout) + out = out.transpose(1, 2) # B * n * n_heads * (cv // num_heads) + + out = out.reshape(b, n, cv) + return out + + +class SimpleRoPEAttention(nn.Module): + """ + Attention with rotary position encoding. + This class is "simple" because it does not perform q/k/v/out projections. + """ + + def __init__( + self, + d_model: int, + num_heads: int, + dropout_p: float, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution + use_fa3: bool = False, + use_rope_real: bool = False, + ): + super().__init__() + + self.num_heads = num_heads + self.dropout_p = dropout_p + self.compute_cis = partial( + compute_axial_cis, dim=d_model // num_heads, theta=rope_theta + ) + device = torch.device("cuda") if torch.cuda.is_available() else None + self.freqs_cis = self.compute_cis( + end_x=feat_sizes[0], end_y=feat_sizes[1], device=device + ) + + self.use_fa3 = use_fa3 + self.use_rope_real = use_rope_real + if self.use_rope_real: + self.freqs_cis_real = self.freqs_cis.real + self.freqs_cis_imag = self.freqs_cis.imag + self.rope_k_repeat = rope_k_repeat + + def forward( + self, + q: Tensor, + k: Tensor, + v: Tensor, + num_k_exclude_rope: int = 0, + ) -> Union[Tensor, tuple[Tensor, Tensor]]: + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h, device=q.device) + if self.use_rope_real: + self.freqs_cis_real = self.freqs_cis.real + self.freqs_cis_imag = self.freqs_cis.imag + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + dropout_p = self.dropout_p if self.training else 0.0 + out = functional_attention( + q, + k, + v, + dropout=dropout_p, + num_heads=self.num_heads, + num_k_exclude_rope=num_k_exclude_rope, + freqs_cis=self.freqs_cis, + freqs_cis_real=self.freqs_cis_real if self.use_rope_real else None, + freqs_cis_imag=self.freqs_cis_imag if self.use_rope_real else None, + use_fa3=self.use_fa3, + use_rope_real=self.use_rope_real, + rope_k_repeat=self.rope_k_repeat, + ) + + return out + + +class DecoupledTransformerDecoderLayerv2(nn.Module): + def __init__( + self, + *, + activation: str, + d_model: int, + num_heads: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + pre_norm: bool, + cross_attention_first: bool = False, + self_attention_rope: SimpleRoPEAttention, + cross_attention_rope: SimpleRoPEAttention, + ): + super().__init__() + self.d_model = d_model + self.num_heads = num_heads + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + + self.self_attn_q_proj = nn.Linear(d_model, d_model) + self.self_attn_k_proj = nn.Linear(d_model, d_model) + self.self_attn_v_proj = nn.Linear(d_model, d_model) + self.self_attn_out_proj = nn.Linear(d_model, d_model) + + self.cross_attn_q_proj = nn.Linear(d_model, d_model) + self.cross_attn_k_proj = nn.Linear(d_model, d_model) + self.cross_attn_v_proj = nn.Linear(d_model, d_model) + self.cross_attn_out_proj = nn.Linear(d_model, d_model) + + self.image_cross_attn_q_proj = nn.Linear(d_model, d_model) + self.image_cross_attn_k_proj = nn.Linear(d_model, d_model) + + self.self_attention_rope = self_attention_rope + self.cross_attention_rope = cross_attention_rope + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + self.pre_norm = pre_norm + + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + self.cross_attention_first = cross_attention_first + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + + q = self.self_attn_q_proj(q) + k = self.self_attn_k_proj(k) + v = self.self_attn_v_proj(tgt2) + out = self.self_attention_rope(q, k, v) + tgt2 = self.self_attn_out_proj(out) + + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca( + self, + *, + image, + tgt, + memory_image, + memory, + query_pos, + memory_image_pos, + num_k_exclude_rope=0, + ): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attention_rope, SimpleRoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + + q = self.image_cross_attn_q_proj(image) + self.cross_attn_q_proj(tgt2) + if self.pos_enc_at_cross_attn_queries: + q = q + query_pos + k = self.image_cross_attn_k_proj(memory_image) + self.cross_attn_k_proj(memory) + if self.pos_enc_at_cross_attn_keys: + k = k + memory_image_pos + v = self.cross_attn_v_proj(memory) + + out = self.cross_attention_rope(q, k, v, **kwds) + tgt2 = self.cross_attn_out_proj(out) + + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward_pre( + self, + *, + image, + tgt, + memory_image, + memory, + image_pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + memory_image_pos: Optional[Tensor] = None, + memory_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ): + if self.cross_attention_first: + tgt = self._forward_ca( + image=image, + tgt=tgt, + memory_image=memory_image, + memory=memory, + query_pos=query_pos, + memory_image_pos=memory_image_pos, + num_k_exclude_rope=num_k_exclude_rope, + ) + tgt = self._forward_sa(tgt, query_pos) + else: + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca( + image=image, + tgt=tgt, + memory_image=memory_image, + memory=memory, + query_pos=query_pos, + memory_image_pos=memory_image_pos, + num_k_exclude_rope=num_k_exclude_rope, + ) + + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + + return image, tgt + + def forward(self, *args: Any, **kwds: Any) -> torch.Tensor: + if self.pre_norm: + return self.forward_pre(*args, **kwds) + raise NotImplementedError + + +class TransformerEncoderDecoupledCrossAttention(nn.Module): + def __init__( + self, + d_model: int, + frozen: bool, + pos_enc_at_input: bool, + layer, + num_layers: int, + use_act_checkpoint: bool = False, + batch_first: bool = False, # Do layers expect batch first input? + use_image_in_output: bool = True, + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.use_act_checkpoint = use_act_checkpoint + self.use_image_in_output = use_image_in_output + + if frozen: + for p in self.parameters(): + p.requires_grad_(False) + + self.batch_first = batch_first + + def forward( + self, + image: Tensor, # image features + src: Tensor, # self-attention inputs; object features + memory_image: Tensor, # cross-attention inputs; image features + memory: Tensor, # cross-attention inputs; object features + image_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + src_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_image_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + assert ( + src.shape[1] == memory.shape[1] + ), "Batch size must be the same for src and memory" + assert ( + image.shape[1] == memory_image.shape[1] + ), "Batch size must be the same for image and memory_image" + + output = src + + if self.pos_enc_at_input and src_pos is not None: + output = output + 0.1 * src_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + src_pos = src_pos.transpose(0, 1) + image = image.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + memory_image = memory_image.transpose(0, 1) + memory_image_pos = memory_image_pos.transpose(0, 1) + + if memory_image.shape[1] != memory.shape[1]: + # Pad memory_image with zeros, to accodmate object pointers + assert ( + (memory.shape[1] - memory_image.shape[1]) == num_obj_ptr_tokens + ), f"{memory.shape[1]} - {memory_image.shape[1]} != {num_obj_ptr_tokens}" + memory_image = torch.cat( + [ + memory_image, + torch.zeros( + (memory_image.shape[0], num_obj_ptr_tokens) + + memory_image.shape[2:], + dtype=memory_image.dtype, + device=memory_image.device, + ), + ], + dim=1, + ) + if memory_image_pos is not None: + assert ( + (memory_pos.shape[1] - memory_image_pos.shape[1]) + == num_obj_ptr_tokens + ), f"{memory_pos.shape[1]} - {memory_image_pos.shape[1]} != {num_obj_ptr_tokens}" + # tpos is the same in the batch anyway; note that memory_image always has a batch size of 1 + memory_image_pos = torch.cat( + [ + memory_image_pos, + memory_pos[0:1, -num_obj_ptr_tokens:], + ], + dim=1, + ) + + for layer in self.layers: + image, output = activation_ckpt_wrapper(layer)( + image=image, + tgt=output, + memory_image=memory_image, + memory=memory, + image_pos=image_pos, + query_pos=src_pos, + memory_image_pos=memory_image_pos, + memory_pos=memory_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + act_ckpt_enable=self.training and self.use_act_checkpoint, + ) + + if self.use_image_in_output: + normed_output = self.norm(output + image) + else: + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + src_pos = src_pos.transpose(0, 1) + + return { + "memory": normed_output, + "pos_embed": src_pos, + } diff --git a/third_party/sam3/sam3/model/edt.py b/third_party/sam3/sam3/model/edt.py new file mode 100644 index 0000000000000000000000000000000000000000..498f28f6685aca8c6cc790226fa64f22d5514529 --- /dev/null +++ b/third_party/sam3/sam3/model/edt.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Triton kernel for euclidean distance transform (EDT)""" + +import torch +import triton +import triton.language as tl + +""" +Disclaimer: This implementation is not meant to be extremely efficient. A CUDA kernel would likely be more efficient. +Even in Triton, there may be more suitable algorithms. + +The goal of this kernel is to mimic cv2.distanceTransform(input, cv2.DIST_L2, 0). +Recall that the euclidean distance transform (EDT) calculates the L2 distance to the closest zero pixel for each pixel of the source image. + +For images of size NxN, the naive algorithm would be to compute pairwise distances between every pair of points, leading to a O(N^4) algorithm, which is obviously impractical. +One can do better using the following approach: +- First, compute the distance to the closest point in the same row. We can write it as Row_EDT[i,j] = min_k (sqrt((k-j)^2) if input[i,k]==0 else +infinity). With a naive implementation, this step has a O(N^3) complexity +- Then, because of triangular inequality, we notice that the EDT for a given location [i,j] is the min of the row EDTs in the same column. EDT[i,j] = min_k Row_EDT[k, j]. This is also O(N^3) + +Overall, this algorithm is quite amenable to parallelization, and has a complexity O(N^3). Can we do better? + +It turns out that we can leverage the structure of the L2 distance (nice and convex) to find the minimum in a more efficient way. +We follow the algorithm from "Distance Transforms of Sampled Functions" (https://cs.brown.edu/people/pfelzens/papers/dt-final.pdf), which is also what's implemented in opencv + +For a single dimension EDT, we can compute the EDT of an arbitrary function F, that we discretize over the grid. Note that for the binary EDT that we're interested in, we can set F(i,j) = 0 if input[i,j]==0 else +infinity +For now, we'll compute the EDT squared, and will take the sqrt only at the very end. +The basic idea is that each point at location i spawns a parabola around itself, with a bias equal to F(i). So specifically, we're looking at the parabola (x - i)^2 + F(i) +When we're looking for the row EDT at location j, we're effectively looking for min_i (x-i)^2 + F(i). In other word we want to find the lowest parabola at location j. + +To do this efficiently, we need to maintain the lower envelope of the union of parabolas. This can be constructed on the fly using a sort of stack approach: + - every time we want to add a new parabola, we check if it may be covering the current right-most parabola. If so, then that parabola was useless, so we can pop it from the stack + - repeat until we can't find any more parabola to pop. Then push the new one. + +This algorithm runs in O(N) for a single row, so overall O(N^2) when applied to all rows +Similarly as before, we notice that we can decompose the algorithm for rows and columns, leading to an overall run-time of O(N^2) + +This algorithm is less suited for to GPUs, since the one-dimensional EDT computation is quite sequential in nature. However, we can parallelize over batch and row dimensions. +In Triton, things are particularly bad at the moment, since there is no support for reading/writing to the local memory at a specific index (a local gather is coming soon, see https://github.com/triton-lang/triton/issues/974, but no mention of writing, ie scatter) +One could emulate these operations with masking, but in initial tests, it proved to be worst than naively reading and writing to the global memory. My guess is that the cache is compensating somewhat for the repeated single-point accesses. + + +The timing obtained on a H100 for a random batch of masks of dimension 256 x 1024 x 1024 are as follows: +- OpenCV: 1780ms (including round-trip to cpu, but discounting the fact that it introduces a synchronization point) +- triton, O(N^3) algo: 627ms +- triton, O(N^2) algo: 322ms + +Overall, despite being quite naive, this implementation is roughly 5.5x faster than the openCV cpu implem + +""" + + +@triton.jit +def edt_kernel(inputs_ptr, outputs_ptr, v, z, height, width, horizontal: tl.constexpr): + # This is a somewhat verbatim implementation of the efficient 1D EDT algorithm described above + # It can be applied horizontally or vertically depending if we're doing the first or second stage. + # It's parallelized across batch+row (or batch+col if horizontal=False) + # TODO: perhaps the implementation can be revisited if/when local gather/scatter become available in triton + batch_id = tl.program_id(axis=0) + if horizontal: + row_id = tl.program_id(axis=1) + block_start = (batch_id * height * width) + row_id * width + length = width + stride = 1 + else: + col_id = tl.program_id(axis=1) + block_start = (batch_id * height * width) + col_id + length = height + stride = width + + # This will be the index of the right most parabola in the envelope ("the top of the stack") + k = 0 + for q in range(1, length): + # Read the function value at the current location. Note that we're doing a singular read, not very efficient + cur_input = tl.load(inputs_ptr + block_start + (q * stride)) + # location of the parabola on top of the stack + r = tl.load(v + block_start + (k * stride)) + # associated boundary + z_k = tl.load(z + block_start + (k * stride)) + # value of the function at the parabola location + previous_input = tl.load(inputs_ptr + block_start + (r * stride)) + # intersection between the two parabolas + s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2 + + # we'll pop as many parabolas as required + while s <= z_k and k - 1 >= 0: + k = k - 1 + r = tl.load(v + block_start + (k * stride)) + z_k = tl.load(z + block_start + (k * stride)) + previous_input = tl.load(inputs_ptr + block_start + (r * stride)) + s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2 + + # Store the new one + k = k + 1 + tl.store(v + block_start + (k * stride), q) + tl.store(z + block_start + (k * stride), s) + if k + 1 < length: + tl.store(z + block_start + ((k + 1) * stride), 1e9) + + # Last step, we read the envelope to find the min in every location + k = 0 + for q in range(length): + while ( + k + 1 < length + and tl.load( + z + block_start + ((k + 1) * stride), mask=(k + 1) < length, other=q + ) + < q + ): + k += 1 + r = tl.load(v + block_start + (k * stride)) + d = q - r + old_value = tl.load(inputs_ptr + block_start + (r * stride)) + tl.store(outputs_ptr + block_start + (q * stride), old_value + d * d) + + +def edt_triton(data: torch.Tensor): + """ + Computes the Euclidean Distance Transform (EDT) of a batch of binary images. + + Args: + data: A tensor of shape (B, H, W) representing a batch of binary images. + + Returns: + A tensor of the same shape as data containing the EDT. + It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0) + """ + assert data.dim() == 3 + assert data.is_cuda + B, H, W = data.shape + data = data.contiguous() + + # Allocate the "function" tensor. Implicitly the function is 0 if data[i,j]==0 else +infinity + output = torch.where(data, 1e18, 0.0) + assert output.is_contiguous() + + # Scratch tensors for the parabola stacks + parabola_loc = torch.zeros(B, H, W, dtype=torch.uint32, device=data.device) + parabola_inter = torch.empty(B, H, W, dtype=torch.float, device=data.device) + parabola_inter[:, :, 0] = -1e18 + parabola_inter[:, :, 1] = 1e18 + + # Grid size (number of blocks) + grid = (B, H) + + # Launch initialization kernel + edt_kernel[grid]( + output.clone(), + output, + parabola_loc, + parabola_inter, + H, + W, + horizontal=True, + ) + + # reset the parabola stacks + parabola_loc.zero_() + parabola_inter[:, :, 0] = -1e18 + parabola_inter[:, :, 1] = 1e18 + + grid = (B, W) + edt_kernel[grid]( + output.clone(), + output, + parabola_loc, + parabola_inter, + H, + W, + horizontal=False, + ) + # don't forget to take sqrt at the end + return output.sqrt() diff --git a/third_party/sam3/sam3/model/encoder.py b/third_party/sam3/sam3/model/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c40ad777c70331d5a5ffce485cf9d3c04d4cc2ab --- /dev/null +++ b/third_party/sam3/sam3/model/encoder.py @@ -0,0 +1,597 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved +# Based on https://github.com/IDEA-Research/GroundingDINO + +# pyre-unsafe + +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn, Tensor + +from .act_ckpt_utils import activation_ckpt_wrapper +from .model_misc import get_activation_fn, get_clones, get_valid_ratio + + +class TransformerEncoderLayer(nn.Module): + """ + Transformer encoder layer that performs self-attention followed by cross-attention. + + This layer was previously called TransformerDecoderLayer but was renamed to better + reflect its role in the architecture. It processes input sequences through self-attention + and then cross-attention with another input (typically image features). + + The layer supports both pre-norm and post-norm configurations, as well as + positional encoding at different stages of the attention mechanism. + """ + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + pre_norm: bool, + self_attention: nn.Module, + ): + """ + Initialize a transformer encoder layer. + + Args: + activation: Activation function to use in the feedforward network + cross_attention: Cross-attention module for attending to image features + d_model: Model dimension/hidden size + dim_feedforward: Dimension of the feedforward network + dropout: Dropout probability + pos_enc_at_attn: Whether to add positional encodings at self-attention + pos_enc_at_cross_attn_keys: Whether to add positional encodings to keys in cross-attention + pos_enc_at_cross_attn_queries: Whether to add positional encodings to queries in cross-attention + pre_norm: Whether to use pre-norm (True) or post-norm (False) architecture + self_attention: Self-attention module + """ + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + self.pre_norm = pre_norm + + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + self.layer_idx = None + + def forward_post( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + """ + Forward pass for post-norm architecture. + + In post-norm architecture, normalization is applied after attention and feedforward operations. + + Args: + tgt: Input tensor to be processed + memory: Memory tensor for cross-attention + tgt_mask: Mask for self-attention + memory_mask: Mask for cross-attention + tgt_key_padding_mask: Key padding mask for self-attention + memory_key_padding_mask: Key padding mask for cross-attention + pos: Positional encoding for memory + query_pos: Positional encoding for query + **kwargs: Additional keyword arguments + + Returns: + Processed tensor + """ + q = k = tgt + query_pos if self.pos_enc_at_attn else tgt + + # Self attention + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # Cross attention to image + tgt2 = self.cross_attn_image( + query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt, + key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # FFN + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt: Tensor, + memory: Tensor, + dac: bool = False, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + # attn_bias: Optional[Tensor] = None, + # **kwargs, + ) -> Tensor: + """ + Forward pass for pre-norm architecture. + + In pre-norm architecture, normalization is applied before attention and feedforward operations. + + Args: + tgt: Input tensor to be processed + memory: Memory tensor for cross-attention + dac: Whether to use Divide-and-Conquer attention + tgt_mask: Mask for self-attention + memory_mask: Mask for cross-attention + tgt_key_padding_mask: Key padding mask for self-attention + memory_key_padding_mask: Key padding mask for cross-attention + pos: Positional encoding for memory + query_pos: Positional encoding for query + attn_bias: Optional attention bias tensor + **kwargs: Additional keyword arguments + + Returns: + Processed tensor + """ + if dac: + # we only apply self attention to the first half of the queries + assert tgt.shape[0] % 2 == 0 + other_tgt = tgt[tgt.shape[0] // 2 :] + tgt = tgt[: tgt.shape[0] // 2] + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + if dac: + # Recombine + tgt = torch.cat((tgt, other_tgt), dim=0) + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + key=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + # attn_bias=attn_bias, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt: Tensor, + memory: Tensor, + dac: bool = False, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + # attn_bias: Optional[Tensor] = None, + # **kwds: Any, + ) -> torch.Tensor: + """ + Forward pass for the transformer encoder layer. + + Args: + tgt: Input tensor to be processed + memory: Memory tensor (e.g., image features) for cross-attention + dac: Whether to use Divide-and-Conquer attention (only apply self-attention to first half) + tgt_mask: Mask for self-attention + memory_mask: Mask for cross-attention + tgt_key_padding_mask: Key padding mask for self-attention + memory_key_padding_mask: Key padding mask for cross-attention + pos: Positional encoding for memory + query_pos: Positional encoding for query + attn_bias: Optional attention bias tensor + **kwds: Additional keyword arguments + + Returns: + Processed tensor after self-attention, cross-attention, and feedforward network + """ + fwd_fn = self.forward_pre if self.pre_norm else self.forward_post + return fwd_fn( + tgt, + memory, + dac=dac, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + # attn_bias=attn_bias, + # **kwds, + ) + + +class TransformerEncoder(nn.Module): + """ + Transformer encoder that processes multi-level features. + + This encoder takes multi-level features (e.g., from a backbone network) and processes + them through a stack of transformer encoder layers. It supports features from multiple + levels (e.g., different resolutions) and can apply activation checkpointing for memory + efficiency during training. + + Args: + layer: The encoder layer to be stacked multiple times + num_layers: Number of encoder layers to stack + d_model: Model dimension/hidden size + num_feature_levels: Number of feature levels to process + frozen: Whether to freeze the parameters of this module + use_act_checkpoint: Whether to use activation checkpointing during training + """ + + def __init__( + self, + layer: nn.Module, + num_layers: int, + d_model: int, + num_feature_levels: int, + frozen: bool = False, + use_act_checkpoint: bool = False, + ): + super().__init__() + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + + self.num_feature_levels = num_feature_levels + self.level_embed = None + if num_feature_levels > 1: + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if frozen: + for p in self.parameters(): + p.requires_grad_(False) + + self.use_act_checkpoint = use_act_checkpoint + + # assign layer index to each layer so that some layers can decide what to do + # based on which layer index they are (e.g. cross attention to memory bank only + # in selected layers) + for layer_idx, layer in enumerate(self.layers): + layer.layer_idx = layer_idx + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + with torch.no_grad(): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H_ - 0.5, H_, dtype=torch.float32, device=device + ), + torch.linspace( + 0.5, W_ - 0.5, W_, dtype=torch.float32, device=device + ), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + + return reference_points + + def _prepare_multilevel_features(self, srcs, masks, pos_embeds): + assert ( + len(srcs) == self.num_feature_levels + ), "mismatch between expected and received # of feature levels" + + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + has_mask = masks is not None and masks[0] is not None + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + + src = src.flatten(2).transpose(1, 2) # bs, hw, c + if has_mask: + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c + if self.level_embed is not None: + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + else: + lvl_pos_embed = pos_embed + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + if has_mask: + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c + mask_flatten = torch.cat(mask_flatten, 1) if has_mask else None # bs, \sum{hxw} + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c + spatial_shapes = torch.tensor( + spatial_shapes, dtype=torch.long, device=src_flatten.device + ) + level_start_index = torch.cat( + ( + spatial_shapes.new_zeros((1,)), + spatial_shapes.prod(1).cumsum(0)[:-1], + ) + ) + if has_mask: + valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1) + else: + valid_ratios = torch.ones( + (src_flatten.shape[0], self.num_feature_levels, 2), + device=src_flatten.device, + ) + + return ( + src_flatten, + mask_flatten, + lvl_pos_embed_flatten, + level_start_index, + valid_ratios, + spatial_shapes, + ) + + def forward( + self, + src: List[Tensor], + src_key_padding_masks: Optional[List[Tensor]] = None, + pos: Optional[List[Tensor]] = None, + prompt: Optional[Tensor] = None, + prompt_key_padding_mask: Optional[Tensor] = None, + encoder_extra_kwargs: Optional[Dict] = None, + ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor, Tensor]: + """ + Process multi-level features through the transformer encoder. + + Args: + src: List of multi-level features, each with shape (batch_size, channels, height, width) + src_key_padding_masks: List of padding masks for each feature level, each with shape (batch_size, height, width) + pos: List of positional embeddings for each feature level, each with shape (batch_size, channels, height, width) + prompt: Optional text/prompt features to attend to, with shape (seq_len, batch_size, d_model) + prompt_key_padding_mask: Optional padding mask for prompt, with shape (batch_size, seq_len) + encoder_extra_kwargs: Optional additional arguments to pass to each encoder layer + + Returns: + A tuple containing: + - output: Processed features with shape (seq_len, batch_size, d_model) + - key_padding_masks_flatten: Flattened padding masks + - lvl_pos_embed_flatten: Flattened positional embeddings + - level_start_index: Starting indices for each feature level + - spatial_shapes: Spatial dimensions of each feature level + - valid_ratios: Valid ratios for each feature level + """ + assert ( + len(src) == self.num_feature_levels + ), "must be equal to num_feature_levels" + if src_key_padding_masks is not None: + assert len(src_key_padding_masks) == self.num_feature_levels + if pos is not None: + assert len(pos) == self.num_feature_levels + # Flatten multilevel feats and add level pos embeds + ( + src_flatten, + key_padding_masks_flatten, + lvl_pos_embed_flatten, + level_start_index, + valid_ratios, + spatial_shapes, + ) = self._prepare_multilevel_features(src, src_key_padding_masks, pos) + + reference_points = self.get_reference_points( + spatial_shapes, valid_ratios, device=src_flatten.device + ) + + output = src_flatten + for layer in self.layers: + layer_kwargs = {} + + assert isinstance(layer, TransformerEncoderLayer) + layer_kwargs["memory"] = prompt + layer_kwargs["memory_key_padding_mask"] = prompt_key_padding_mask + layer_kwargs["query_pos"] = lvl_pos_embed_flatten + layer_kwargs["tgt"] = output + layer_kwargs["tgt_key_padding_mask"] = key_padding_masks_flatten + + # Allow disabling activation checkpointing for profiling + # if self.training: + # assert self.use_act_checkpoint, "activation ckpt not enabled in encoder" + if encoder_extra_kwargs is not None: + layer_kwargs.update(encoder_extra_kwargs) + output = activation_ckpt_wrapper(layer)( + **layer_kwargs, + act_ckpt_enable=self.training and self.use_act_checkpoint, + ) + # return as seq first + return ( + output.transpose(0, 1), + ( + key_padding_masks_flatten.transpose(0, 1) + if key_padding_masks_flatten is not None + else None + ), + lvl_pos_embed_flatten.transpose(0, 1), + level_start_index, + spatial_shapes, + valid_ratios, + ) + + +class TransformerEncoderFusion(TransformerEncoder): + """ + Transformer encoder that fuses text and image features. + + This encoder extends TransformerEncoder to handle both text and image features, + with the ability to add pooled text features to image features for better + cross-modal fusion. It supports torch.compile for performance optimization. + + Args: + layer: The encoder layer to be stacked multiple times + num_layers: Number of encoder layers to stack + d_model: Model dimension/hidden size + num_feature_levels: Number of feature levels to process + add_pooled_text_to_img_feat: Whether to add pooled text features to image features + pool_text_with_mask: Whether to use the mask when pooling text features + compile_mode: Mode for torch.compile, or None to disable compilation + **kwargs: Additional arguments to pass to the parent class + """ + + def __init__( + self, + layer: nn.Module, + num_layers: int, + d_model: int, + num_feature_levels: int, + add_pooled_text_to_img_feat: bool = True, + pool_text_with_mask: bool = False, + compile_mode: Optional[str] = None, + **kwargs, + ): + super().__init__( + layer, + num_layers, + d_model, + num_feature_levels, + **kwargs, + ) + self.add_pooled_text_to_img_feat = add_pooled_text_to_img_feat + if self.add_pooled_text_to_img_feat: + self.text_pooling_proj = nn.Linear(d_model, d_model) + self.pool_text_with_mask = pool_text_with_mask + if compile_mode is not None: + self.forward = torch.compile( + self.forward, mode=compile_mode, fullgraph=True + ) + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + # Not needed here + return None + + def forward( + self, + src: List[Tensor], + prompt: Tensor, + src_key_padding_mask: Optional[List[Tensor]] = None, + src_pos: Optional[List[Tensor]] = None, + prompt_key_padding_mask: Optional[Tensor] = None, + prompt_pos: Optional[Tensor] = None, + feat_sizes: Optional[List[int]] = None, + encoder_extra_kwargs: Optional[Dict] = None, + ): + # Restore spatial shapes of vision + bs = src[0].shape[1] # seq first + if feat_sizes is not None: + assert len(feat_sizes) == len(src) + if src_key_padding_mask is None: + src_key_padding_mask = [None] * len(src) + for i, (h, w) in enumerate(feat_sizes): + src[i] = src[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1) + src_pos[i] = src_pos[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1) + src_key_padding_mask[i] = ( + src_key_padding_mask[i].reshape(h, w, bs).permute(2, 0, 1) + if src_key_padding_mask[i] is not None + else None + ) + else: + assert all( + x.dim == 4 for x in src + ), "expected list of (bs, c, h, w) tensors" + + if self.add_pooled_text_to_img_feat: + # Fusion: Add mean pooled text to image features + pooled_text = pool_text_feat( + prompt, prompt_key_padding_mask, self.pool_text_with_mask + ) + pooled_text = self.text_pooling_proj(pooled_text)[ + ..., None, None + ] # prompt is seq first + src = [x.add_(pooled_text) for x in src] + + ( + out, + key_padding_masks_flatten, + lvl_pos_embed_flatten, + level_start_index, + spatial_shapes, + valid_ratios, + ) = super().forward( + src, + src_key_padding_masks=src_key_padding_mask, + pos=src_pos, + prompt=prompt.transpose(0, 1), + prompt_key_padding_mask=prompt_key_padding_mask, + encoder_extra_kwargs=encoder_extra_kwargs, + ) + + return { + "memory": out, + "padding_mask": key_padding_masks_flatten, + "pos_embed": lvl_pos_embed_flatten, + "memory_text": prompt, + "level_start_index": level_start_index, + "spatial_shapes": spatial_shapes, + "valid_ratios": valid_ratios, + } + + +def pool_text_feat(prompt, prompt_mask, pool_with_mask): + # prompt has shape (seq, bs, dim) + if not pool_with_mask: + return prompt.mean(dim=0) + + # prompt_mask has shape (bs, seq), where False is valid and True is padding + assert prompt_mask.dim() == 2 + # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding + is_valid = (~prompt_mask).float().permute(1, 0)[..., None] + # num_valid has shape (bs, 1) + num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0) + + # mean pool over all the valid tokens + pooled_text = (prompt * is_valid).sum(dim=0) / num_valid + return pooled_text diff --git a/third_party/sam3/sam3/model/geometry_encoders.py b/third_party/sam3/sam3/model/geometry_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..35a4276f2ca58ea221145f340212ac8f96965df3 --- /dev/null +++ b/third_party/sam3/sam3/model/geometry_encoders.py @@ -0,0 +1,851 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from typing import Tuple + +import torch +import torch.nn as nn +import torchvision +from typing_extensions import override + +from .act_ckpt_utils import activation_ckpt_wrapper +from .box_ops import box_cxcywh_to_xyxy +from .model_misc import get_clones + + +def is_right_padded(mask): + """Given a padding mask (following pytorch convention, 1s for padded values), + returns whether the padding is on the right or not.""" + return (mask.long() == torch.sort(mask.long(), dim=-1)[0]).all() + + +def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False): + """ + Concatenates two right-padded sequences, such that the resulting sequence + is contiguous and also right-padded. + + Following pytorch's convention, tensors are sequence first, and the mask are + batch first, with 1s for padded values. + + :param seq1: A tensor of shape (seq1_length, batch_size, hidden_size). + :param mask1: A tensor of shape (batch_size, seq1_length). + :param seq2: A tensor of shape (seq2_length, batch_size, hidden_size). + :param mask2: A tensor of shape (batch_size, seq2_length). + :param return_index: If True, also returns the index of the ids of the element of seq2 + in the concatenated sequence. This can be used to retrieve the elements of seq2 + :return: A tuple (concatenated_sequence, concatenated_mask) if return_index is False, + otherwise (concatenated_sequence, concatenated_mask, index). + """ + seq1_length, batch_size, hidden_size = seq1.shape + seq2_length, batch_size, hidden_size = seq2.shape + + assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0) + assert hidden_size == seq1.size(2) == seq2.size(2) + assert seq1_length == mask1.size(1) + assert seq2_length == mask2.size(1) + + torch._assert_async(is_right_padded(mask1)) + torch._assert_async(is_right_padded(mask2)) + + actual_seq1_lengths = (~mask1).sum(dim=-1) + actual_seq2_lengths = (~mask2).sum(dim=-1) + + final_lengths = actual_seq1_lengths + actual_seq2_lengths + max_length = seq1_length + seq2_length + concatenated_mask = ( + torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) + >= final_lengths[:, None] + ) + + # (max_len, batch_size, hidden_size) + concatenated_sequence = torch.zeros( + (max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype + ) + concatenated_sequence[:seq1_length, :, :] = seq1 + + # At this point, the element of seq1 are in the right place + # We just need to shift the elements of seq2 + + index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size) + index = index + actual_seq1_lengths[None] + + concatenated_sequence = concatenated_sequence.scatter( + 0, index[:, :, None].expand(-1, -1, hidden_size), seq2 + ) + + if return_index: + return concatenated_sequence, concatenated_mask, index + + return concatenated_sequence, concatenated_mask + + +class Prompt: + """Utility class to manipulate geometric prompts. + + We expect the sequences in pytorch convention, that is sequence first, batch second + The dimensions are expected as follows: + box_embeddings shape: N_boxes x B x C_box + box_mask shape: B x N_boxes. Can be None if nothing is masked out + point_embeddings shape: N_points x B x C_point + point_mask shape: B x N_points. Can be None if nothing is masked out + mask_embeddings shape: N_masks x B x 1 x H_mask x W_mask + mask_mask shape: B x N_masks. Can be None if nothing is masked out + + We also store positive/negative labels. These tensors are also stored batch-first + If they are None, we'll assume positive labels everywhere + box_labels: long tensor of shape N_boxes x B + point_labels: long tensor of shape N_points x B + mask_labels: long tensor of shape N_masks x B + """ + + def __init__( + self, + box_embeddings=None, + box_mask=None, + point_embeddings=None, + point_mask=None, + box_labels=None, + point_labels=None, + mask_embeddings=None, + mask_mask=None, # Attention mask for mask prompt + mask_labels=None, + ): + # Check for null prompt + if ( + box_embeddings is None + and point_embeddings is None + and mask_embeddings is None + ): + self.box_embeddings = None + self.box_labels = None + self.box_mask = None + self.point_embeddings = None + self.point_labels = None + self.point_mask = None + self.mask_embeddings = None + self.mask_mask = None + # Masks are assumed positive only for now. + self.mask_labels = None + return + # Get sequence lengths and device + box_seq_len, point_seq_len, mask_seq_len, bs, device = ( + self._init_seq_len_and_device( + box_embeddings, point_embeddings, mask_embeddings + ) + ) + + # Initialize embeds, labels, attention masks. + box_embeddings, box_labels, box_mask = self._init_box( + box_embeddings, box_labels, box_mask, box_seq_len, bs, device + ) + point_embeddings, point_labels, point_mask = self._init_point( + point_embeddings, point_labels, point_mask, point_seq_len, bs, device + ) + mask_embeddings, mask_labels, mask_mask = self._init_mask( + mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device + ) + + # Dimension checks + assert ( + box_embeddings is not None + and list(box_embeddings.shape[:2]) + == [ + box_seq_len, + bs, + ] + ), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}" + assert ( + box_mask is not None + and list(box_mask.shape) + == [ + bs, + box_seq_len, + ] + ), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}" + assert ( + point_embeddings is not None + and list(point_embeddings.shape[:2]) + == [ + point_seq_len, + bs, + ] + ), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}" + assert ( + point_mask is not None + and list(point_mask.shape) + == [ + bs, + point_seq_len, + ] + ), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}" + assert ( + box_labels is not None + and list(box_labels.shape) + == [ + box_seq_len, + bs, + ] + ), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}" + assert ( + point_labels is not None + and list(point_labels.shape) + == [ + point_seq_len, + bs, + ] + ), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}" + assert ( + # Allowed to be None, we leave it to the encoder to check for validity before encoding. + mask_embeddings is None + or list(mask_embeddings.shape[:2]) + == [ + mask_seq_len, + bs, + ] + ), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}" + assert ( + mask_mask is None + or list(mask_mask.shape) + == [ + bs, + mask_seq_len, + ] + ), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}" + + # Device checks + assert ( + box_embeddings is not None and box_embeddings.device == device + ), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}" + assert ( + box_mask is not None and box_mask.device == device + ), f"Expected box mask to be on device {device}, got {box_mask.device}" + assert ( + box_labels is not None and box_labels.device == device + ), f"Expected box labels to be on device {device}, got {box_labels.device}" + assert ( + point_embeddings is not None and point_embeddings.device == device + ), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}" + assert ( + point_mask is not None and point_mask.device == device + ), f"Expected point mask to be on device {device}, got {point_mask.device}" + assert ( + point_labels is not None and point_labels.device == device + ), f"Expected point labels to be on device {device}, got {point_labels.device}" + assert ( + mask_embeddings is None or mask_embeddings.device == device + ), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}" + assert ( + mask_mask is None or mask_mask.device == device + ), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}" + + self.box_embeddings = box_embeddings + self.point_embeddings = point_embeddings + self.box_mask = box_mask + self.point_mask = point_mask + self.box_labels = box_labels + self.point_labels = point_labels + self.mask_embeddings = mask_embeddings + self.mask_labels = mask_labels + self.mask_mask = mask_mask + + def _init_seq_len_and_device( + self, box_embeddings, point_embeddings, mask_embeddings + ): + box_seq_len = point_seq_len = mask_seq_len = 0 + bs = None + device = None + if box_embeddings is not None: + bs = box_embeddings.shape[1] + box_seq_len = box_embeddings.shape[0] + device = box_embeddings.device + + if point_embeddings is not None: + point_seq_len = point_embeddings.shape[0] + if bs is not None: + assert ( + bs == point_embeddings.shape[1] + ), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}." + else: + bs = point_embeddings.shape[1] + if device is not None: + assert ( + device == point_embeddings.device + ), "Device mismatch between box and point embeddings" + else: + device = point_embeddings.device + + if mask_embeddings is not None: + mask_seq_len = mask_embeddings.shape[0] + if bs is not None: + assert ( + bs == mask_embeddings.shape[1] + ), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}" + else: + bs = mask_embeddings.shape[1] + if device is not None: + assert ( + device == mask_embeddings.device + ), "Device mismatch between box/point and mask embeddings." + else: + device = mask_embeddings.device + + return box_seq_len, point_seq_len, mask_seq_len, bs, device + + def _init_box(self, box_embeddings, box_labels, box_mask, box_seq_len, bs, device): + if box_embeddings is None: + box_embeddings = torch.zeros(box_seq_len, bs, 4, device=device) + if box_labels is None: + box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long) + if box_mask is None: + box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool) + return box_embeddings, box_labels, box_mask + + def _init_point( + self, point_embeddings, point_labels, point_mask, point_seq_len, bs, device + ): + """ + Identical to _init_box. Except that C=2 for points (vs. 4 for boxes). + """ + if point_embeddings is None: + point_embeddings = torch.zeros(point_seq_len, bs, 2, device=device) + if point_labels is None: + point_labels = torch.ones( + point_seq_len, bs, device=device, dtype=torch.long + ) + if point_mask is None: + point_mask = torch.zeros(bs, point_seq_len, device=device, dtype=torch.bool) + return point_embeddings, point_labels, point_mask + + def _init_mask( + self, mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device + ): + # NOTE: Mask embeddings can be of arbitrary resolution, so we don't initialize it here. + # In case we append new mask, we check that its resolution matches exisiting ones (if any). + # In case mask_embeddings is None, we should never encode it. + if mask_labels is None: + mask_labels = torch.ones(mask_seq_len, bs, device=device, dtype=torch.long) + if mask_mask is None: + mask_mask = torch.zeros(bs, mask_seq_len, device=device, dtype=torch.bool) + return mask_embeddings, mask_labels, mask_mask + + def append_boxes(self, boxes, labels, mask=None): + if self.box_embeddings is None: + self.box_embeddings = boxes + self.box_labels = labels + self.box_mask = mask + return + + bs = self.box_embeddings.shape[1] + assert boxes.shape[1] == labels.shape[1] == bs + assert list(boxes.shape[:2]) == list(labels.shape[:2]) + if mask is None: + mask = torch.zeros( + bs, boxes.shape[0], dtype=torch.bool, device=boxes.device + ) + + self.box_labels, _ = concat_padded_sequences( + self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask + ) + self.box_labels = self.box_labels.squeeze(-1) + self.box_embeddings, self.box_mask = concat_padded_sequences( + self.box_embeddings, self.box_mask, boxes, mask + ) + + def append_points(self, points, labels, mask=None): + if self.point_embeddings is None: + self.point_embeddings = points + self.point_labels = labels + self.point_mask = mask + return + + bs = self.point_embeddings.shape[1] + assert points.shape[1] == labels.shape[1] == bs + assert list(points.shape[:2]) == list(labels.shape[:2]) + if mask is None: + mask = torch.zeros( + bs, points.shape[0], dtype=torch.bool, device=points.device + ) + + self.point_labels, _ = concat_padded_sequences( + self.point_labels.unsqueeze(-1), self.point_mask, labels.unsqueeze(-1), mask + ) + self.point_labels = self.point_labels.squeeze(-1) + self.point_embeddings, self.point_mask = concat_padded_sequences( + self.point_embeddings, self.point_mask, points, mask + ) + + def append_masks(self, masks, labels=None, attn_mask=None): + if labels is not None: + assert list(masks.shape[:2]) == list(labels.shape[:2]) + if self.mask_embeddings is None: + self.mask_embeddings = masks + mask_seq_len, bs = masks.shape[:2] + if labels is None: + self.mask_labels = torch.ones( + mask_seq_len, bs, device=masks.device, dtype=torch.long + ) + else: + self.mask_labels = labels + if attn_mask is None: + self.mask_mask = torch.zeros( + bs, mask_seq_len, device=masks.device, dtype=torch.bool + ) + else: + self.mask_mask = attn_mask + else: + raise NotImplementedError("Only one mask per prompt is supported.") + + def clone(self): + return Prompt( + box_embeddings=( + None if self.box_embeddings is None else self.box_embeddings.clone() + ), + box_mask=None if self.box_mask is None else self.box_mask.clone(), + point_embeddings=( + None if self.point_embeddings is None else self.point_embeddings.clone() + ), + point_mask=None if self.point_mask is None else self.point_mask.clone(), + box_labels=None if self.box_labels is None else self.box_labels.clone(), + point_labels=( + None if self.point_labels is None else self.point_labels.clone() + ), + ) + + +class MaskEncoder(nn.Module): + """ + Base class for mask encoders. + """ + + def __init__( + self, + mask_downsampler: nn.Module, + position_encoding: nn.Module, + ): + super().__init__() + self.mask_downsampler = mask_downsampler + self.position_encoding = position_encoding + + def forward(self, masks, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + masks = self.mask_downsampler(masks) + masks_pos = self.position_encoding(masks).to(masks.dtype) + + return masks, masks_pos + + +class FusedMaskEncoder(MaskEncoder): + """ + Identical to memory.SimpleMaskEncoder but follows the interface of geometry_encoders.MaskEncoder. + We also remove the `skip_mask_sigmoid` option (to be handled outside the MaskEncoder). + Fuses backbone image features with mask features. + """ + + def __init__( + self, + mask_downsampler: nn.Module, + position_encoding: nn.Module, + fuser: nn.Module, + in_dim: int = 256, + out_dim: int = 256, + ): + super().__init__(mask_downsampler, position_encoding) + self.fuser = fuser + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + + @override + def forward( + self, + masks: torch.Tensor, + pix_feat: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return x, pos + + +class SequenceGeometryEncoder(nn.Module): + """ + This a fully fledged encoder for geometric prompts. + It assumes boxes are passed in the "normalized CxCyWH" format, and points in normalized xy + This allows flexibility in how to encode the features (eg do pooling) + + Points and boxes can be encoded with any of the three possibilities: + - direct projection: we just compute a linear from coordinate space to d_model + - pooling: pool features from the backbone in the requested location. + For boxes, it's a roi align + For points it's a grid sample + - pos encoder: Take the position encoding of the point or box center + + These three options are mutually compatible. If several are selected, we'll take a simple addition + + As an alternative, we offer the possibility to encode points only. + In that case, the boxes are converted to two points for the top left and bottom right corners (with appropriate labels) + + On top of these encodings, we offer the possibility to further encode the prompt sequence with a transformer. + """ + + def __init__( + self, + encode_boxes_as_points: bool, + points_direct_project: bool, + points_pool: bool, + points_pos_enc: bool, + boxes_direct_project: bool, + boxes_pool: bool, + boxes_pos_enc: bool, + d_model: int, + pos_enc, + num_layers: int, + layer: nn.Module, + roi_size: int = 7, # for boxes pool + add_cls: bool = True, + add_post_encode_proj: bool = True, + mask_encoder: MaskEncoder = None, + add_mask_label: bool = False, + use_act_ckpt: bool = False, + ): + super().__init__() + + self.d_model = d_model + self.pos_enc = pos_enc + self.encode_boxes_as_points = encode_boxes_as_points + self.roi_size = roi_size + # There usually are two labels: positive and negatives. + # If we encode boxes as points, we have 3 types of points: regular, top left, bottom right + # These 3 types can be positives or negatives, hence 2*3 = 6 labels + num_labels = 6 if self.encode_boxes_as_points else 2 + self.label_embed = torch.nn.Embedding(num_labels, self.d_model) + + # This is a cls token, can be used for pooling if need be. + # It also ensures that the encoded sequences are always non-empty + self.cls_embed = None + if add_cls: + self.cls_embed = torch.nn.Embedding(1, self.d_model) + + assert ( + points_direct_project or points_pos_enc or points_pool + ), "Error: need at least one way to encode points" + assert ( + encode_boxes_as_points + or boxes_direct_project + or boxes_pos_enc + or boxes_pool + ), "Error: need at least one way to encode boxes" + + self.points_direct_project = None + if points_direct_project: + self.points_direct_project = nn.Linear(2, self.d_model) + self.points_pool_project = None + if points_pool: + self.points_pool_project = nn.Linear(self.d_model, self.d_model) + self.points_pos_enc_project = None + if points_pos_enc: + self.points_pos_enc_project = nn.Linear(self.d_model, self.d_model) + + self.boxes_direct_project = None + self.boxes_pool_project = None + self.boxes_pos_enc_project = None + if not encode_boxes_as_points: + if boxes_direct_project: + self.boxes_direct_project = nn.Linear(4, self.d_model) + if boxes_pool: + self.boxes_pool_project = nn.Conv2d( + self.d_model, self.d_model, self.roi_size + ) + if boxes_pos_enc: + self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model) + + self.final_proj = None + if add_post_encode_proj: + self.final_proj = nn.Linear(self.d_model, self.d_model) + self.norm = nn.LayerNorm(self.d_model) + + self.img_pre_norm = nn.Identity() + if self.points_pool_project is not None or self.boxes_pool_project is not None: + self.img_pre_norm = nn.LayerNorm(self.d_model) + + self.encode = None + if num_layers > 0: + assert ( + add_cls + ), "It's currently highly recommended to add a CLS when using a transformer" + self.encode = get_clones(layer, num_layers) + self.encode_norm = nn.LayerNorm(self.d_model) + + if mask_encoder is not None: + assert isinstance( + mask_encoder, MaskEncoder + ), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}." + if add_mask_label: + self.mask_label_embed = torch.nn.Embedding(2, self.d_model) + self.add_mask_label = add_mask_label + self.mask_encoder = mask_encoder + self.use_act_ckpt = use_act_ckpt + + def _encode_points(self, points, points_mask, points_labels, img_feats): + points_embed = None + n_points, bs = points.shape[:2] + + if self.points_direct_project is not None: + proj = self.points_direct_project(points) + assert points_embed is None + points_embed = proj + + if self.points_pool_project is not None: + # points are [Num_points, bs, 2], normalized in [0, 1] + # the grid needs to be [Bs, H_out, W_out, 2] normalized in [-1,1] + # Will take H_out = num_points, w_out = 1 + grid = points.transpose(0, 1).unsqueeze(2) + # re normalize to [-1, 1] + grid = (grid * 2) - 1 + sampled = torch.nn.functional.grid_sample( + img_feats, grid, align_corners=False + ) + assert list(sampled.shape) == [bs, self.d_model, n_points, 1] + sampled = sampled.squeeze(-1).permute(2, 0, 1) + proj = self.points_pool_project(sampled) + if points_embed is None: + points_embed = proj + else: + points_embed = points_embed + proj + + if self.points_pos_enc_project is not None: + x, y = points.unbind(-1) + enc_x, enc_y = self.pos_enc._encode_xy(x.flatten(), y.flatten()) + enc_x = enc_x.view(n_points, bs, enc_x.shape[-1]) + enc_y = enc_y.view(n_points, bs, enc_y.shape[-1]) + enc = torch.cat([enc_x, enc_y], -1) + + proj = self.points_pos_enc_project(enc) + if points_embed is None: + points_embed = proj + else: + points_embed = points_embed + proj + + type_embed = self.label_embed(points_labels.long()) + return type_embed + points_embed, points_mask + + def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats): + boxes_embed = None + n_boxes, bs = boxes.shape[:2] + + if self.boxes_direct_project is not None: + proj = self.boxes_direct_project(boxes) + assert boxes_embed is None + boxes_embed = proj + + if self.boxes_pool_project is not None: + H, W = img_feats.shape[-2:] + + # boxes are [Num_boxes, bs, 4], normalized in [0, 1] + # We need to denormalize, and convert to [x, y, x, y] + boxes_xyxy = box_cxcywh_to_xyxy(boxes) + scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype) + scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True) + scale = scale.view(1, 1, 4) + boxes_xyxy = boxes_xyxy * scale + sampled = torchvision.ops.roi_align( + img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size + ) + assert list(sampled.shape) == [ + bs * n_boxes, + self.d_model, + self.roi_size, + self.roi_size, + ] + proj = self.boxes_pool_project(sampled) + proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1) + if boxes_embed is None: + boxes_embed = proj + else: + boxes_embed = boxes_embed + proj + + if self.boxes_pos_enc_project is not None: + cx, cy, w, h = boxes.unbind(-1) + enc = self.pos_enc.encode_boxes( + cx.flatten(), cy.flatten(), w.flatten(), h.flatten() + ) + enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1]) + + proj = self.boxes_pos_enc_project(enc) + if boxes_embed is None: + boxes_embed = proj + else: + boxes_embed = boxes_embed + proj + + type_embed = self.label_embed(boxes_labels.long()) + return type_embed + boxes_embed, boxes_mask + + def _encode_masks( + self, + masks: torch.Tensor, + attn_mask: torch.Tensor, + mask_labels: torch.Tensor, + img_feats: torch.Tensor = None, + ): + n_masks, bs = masks.shape[:2] + assert ( + n_masks == 1 + ), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed." + assert ( + list(attn_mask.shape) + == [ + bs, + n_masks, + ] + ), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}." + masks, pos = self.mask_encoder( + masks=masks.flatten(0, 1).float(), + pix_feat=img_feats, + ) + H, W = masks.shape[-2:] + n_tokens_per_mask = H * W + # NOTE: We directly add pos enc here as we usually don't keep track of pos encoding for the concatenated prompt (text, other geometric prompts). Might need to do some refactoring for more flexibility. + masks = masks + pos + masks = masks.view(n_masks, bs, *masks.shape[1:]).flatten( + -2 + ) # n_masks x bs x C x H*W + masks = masks.permute(0, 3, 1, 2).flatten(0, 1) # n_masks * H*W x bs x C + attn_mask = attn_mask.repeat_interleave(n_tokens_per_mask, dim=1) + if self.add_mask_label: + masks = masks + self.mask_label_embed(mask_labels.long()) + return masks, attn_mask + + def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None): + points = geo_prompt.point_embeddings + points_mask = geo_prompt.point_mask + points_labels = geo_prompt.point_labels + boxes = geo_prompt.box_embeddings + boxes_mask = geo_prompt.box_mask + boxes_labels = geo_prompt.box_labels + masks = geo_prompt.mask_embeddings + masks_mask = geo_prompt.mask_mask + masks_labels = geo_prompt.mask_labels + seq_first_img_feats = img_feats[-1] # [H*W, B, C] + seq_first_img_pos_embeds = ( + img_pos_embeds[-1] + if img_pos_embeds is not None + else torch.zeros_like(seq_first_img_feats) + ) + + if self.points_pool_project or self.boxes_pool_project: + assert len(img_feats) == len(img_sizes) + cur_img_feat = img_feats[-1] + cur_img_feat = self.img_pre_norm(cur_img_feat) + H, W = img_sizes[-1] + assert cur_img_feat.shape[0] == H * W + N, C = cur_img_feat.shape[-2:] + # Put back in NxCxHxW + cur_img_feat = cur_img_feat.permute(1, 2, 0) + cur_img_feat = cur_img_feat.view(N, C, H, W) + img_feats = cur_img_feat + + if self.encode_boxes_as_points: + assert boxes is not None + assert geo_prompt.box_mask is not None + assert geo_prompt.box_labels is not None + assert boxes.shape[-1] == 4 + + boxes_xyxy = box_cxcywh_to_xyxy(boxes) + top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1) + + labels_tl = geo_prompt.box_labels + 2 + labels_br = geo_prompt.box_labels + 4 + + # Append to the existing points + points, _ = concat_padded_sequences( + points, points_mask, top_left, boxes_mask + ) + points_labels, points_mask = concat_padded_sequences( + points_labels.unsqueeze(-1), + points_mask, + labels_tl.unsqueeze(-1), + boxes_mask, + ) + points_labels = points_labels.squeeze(-1) + + points, _ = concat_padded_sequences( + points, points_mask, bottom_right, boxes_mask + ) + points_labels, points_mask = concat_padded_sequences( + points_labels.unsqueeze(-1), + points_mask, + labels_br.unsqueeze(-1), + boxes_mask, + ) + points_labels = points_labels.squeeze(-1) + + final_embeds, final_mask = self._encode_points( + points=points, + points_mask=points_mask, + points_labels=points_labels, + img_feats=img_feats, + ) + + if not self.encode_boxes_as_points: + boxes_embeds, boxes_mask = self._encode_boxes( + boxes=boxes, + boxes_mask=boxes_mask, + boxes_labels=boxes_labels, + img_feats=img_feats, + ) + + final_embeds, final_mask = concat_padded_sequences( + final_embeds, final_mask, boxes_embeds, boxes_mask + ) + + if masks is not None and self.mask_encoder is not None: + masks_embed, masks_mask = self._encode_masks( + masks=masks, + attn_mask=masks_mask, + mask_labels=masks_labels, + img_feats=img_feats, + ) + if points.size(0) == boxes.size(0) == 0: + return masks_embed, masks_mask + bs = final_embeds.shape[1] + assert final_mask.shape[0] == bs + if self.cls_embed is not None: + cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1) + cls_mask = torch.zeros( + bs, 1, dtype=final_mask.dtype, device=final_mask.device + ) + final_embeds, final_mask = concat_padded_sequences( + final_embeds, final_mask, cls, cls_mask + ) + + if self.final_proj is not None: + final_embeds = self.norm(self.final_proj(final_embeds)) + + if self.encode is not None: + for lay in self.encode: + final_embeds = activation_ckpt_wrapper(lay)( + tgt=final_embeds, + memory=seq_first_img_feats, + tgt_key_padding_mask=final_mask, + pos=seq_first_img_pos_embeds, + act_ckpt_enable=self.training and self.use_act_ckpt, + ) + final_embeds = self.encode_norm(final_embeds) + # Finally, concat mask embeddings if any + if masks is not None and self.mask_encoder is not None: + final_embeds, final_mask = concat_padded_sequences( + final_embeds, final_mask, masks_embed, masks_mask + ) + return final_embeds, final_mask diff --git a/third_party/sam3/sam3/model/io_utils.py b/third_party/sam3/sam3/model/io_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..470bb43246c56f05f582b8e0bafbecb4e032eb6c --- /dev/null +++ b/third_party/sam3/sam3/model/io_utils.py @@ -0,0 +1,725 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import contextlib +import os +import queue +import re +import time +from threading import Condition, get_ident, Lock, Thread + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from PIL import Image +from sam3.logger import get_logger +from tqdm import tqdm + +logger = get_logger(__name__) + +IS_MAIN_PROCESS = os.getenv("IS_MAIN_PROCESS", "1") == "1" +RANK = int(os.getenv("RANK", "0")) + +IMAGE_EXTS = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"] +VIDEO_EXTS = [".mp4", ".mov", ".avi", ".mkv", ".webm"] + + +def load_resource_as_video_frames( + resource_path, + image_size, + offload_video_to_cpu, + img_mean=(0.5, 0.5, 0.5), + img_std=(0.5, 0.5, 0.5), + async_loading_frames=False, + video_loader_type="cv2", +): + """ + Load video frames from either a video or an image (as a single-frame video). + Alternatively, if input is a list of PIL images, convert its format + """ + if isinstance(resource_path, list): + img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None] + assert all(isinstance(img_pil, Image.Image) for img_pil in resource_path) + assert len(resource_path) is not None + orig_height, orig_width = resource_path[0].size + orig_height, orig_width = ( + orig_width, + orig_height, + ) # For some reason, this method returns these swapped + images = [] + for img_pil in resource_path: + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + assert img_np.dtype == np.uint8, "np.uint8 is expected for JPEG images" + img_np = img_np / 255.0 + img = torch.from_numpy(img_np).permute(2, 0, 1) + # float16 precision should be sufficient for image tensor storage + img = img.to(dtype=torch.float16) + # normalize by mean and std + img -= img_mean + img /= img_std + images.append(img) + images = torch.stack(images) + if not offload_video_to_cpu: + images = images.cuda() + return images, orig_height, orig_width + + is_image = ( + isinstance(resource_path, str) + and os.path.splitext(resource_path)[-1].lower() in IMAGE_EXTS + ) + if is_image: + return load_image_as_single_frame_video( + image_path=resource_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + ) + else: + return load_video_frames( + video_path=resource_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + video_loader_type=video_loader_type, + ) + + +def load_image_as_single_frame_video( + image_path, + image_size, + offload_video_to_cpu, + img_mean=(0.5, 0.5, 0.5), + img_std=(0.5, 0.5, 0.5), +): + """Load an image as a single-frame video.""" + images, image_height, image_width = _load_img_as_tensor(image_path, image_size) + images = images.unsqueeze(0).half() + + img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None] + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, image_height, image_width + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.5, 0.5, 0.5), + img_std=(0.5, 0.5, 0.5), + async_loading_frames=False, + video_loader_type="cv2", +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + assert isinstance(video_path, str) + if video_path.startswith(" where N is an integer + match = re.match(r"", video_path) + num_frames = int(match.group(1)) if match else 60 + return load_dummy_video(image_size, offload_video_to_cpu, num_frames=num_frames) + elif video_path.startswith(" where N is an integer + match = re.match(r"", video_path) + num_frames = int(match.group(1)) if match else 60 + return load_dummy_video( + image_size, offload_video_to_cpu, num_frames=num_frames, do_zeros=True + ) + elif os.path.isdir(video_path): + return load_video_frames_from_image_folder( + image_folder=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + ) + elif os.path.splitext(video_path)[-1].lower() in VIDEO_EXTS: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + video_loader_type=video_loader_type, + ) + else: + raise NotImplementedError("Only video files and image folders are supported") + + +def load_video_frames_from_image_folder( + image_folder, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + async_loading_frames, +): + """ + Load the video frames from a directory of image files ("." format) + """ + frame_names = [ + p + for p in os.listdir(image_folder) + if os.path.splitext(p)[-1].lower() in IMAGE_EXTS + ] + try: + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + except ValueError: + # fallback to lexicographic sort if the format is not "." + logger.warning( + f'frame names are not in "." format: {frame_names[:5]=}, ' + f"falling back to lexicographic sort." + ) + frame_names.sort() + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {image_folder}") + img_paths = [os.path.join(image_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncImageFrameLoader( + img_paths, image_size, offload_video_to_cpu, img_mean, img_std + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + # float16 precision should be sufficient for image tensor storage + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float16) + video_height, video_width = None, None + for n, img_path in enumerate( + tqdm(img_paths, desc=f"frame loading (image folder) [rank={RANK}]") + ): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + async_loading_frames, + gpu_acceleration=False, + gpu_device=None, + video_loader_type="cv2", +): + """Load the video frames from a video file.""" + if video_loader_type == "cv2": + return load_video_frames_from_video_file_using_cv2( + video_path=video_path, + image_size=image_size, + img_mean=img_mean, + img_std=img_std, + offload_video_to_cpu=offload_video_to_cpu, + ) + elif video_loader_type == "torchcodec": + logger.info("Using torchcodec to load video file") + lazy_images = AsyncVideoFileLoaderWithTorchCodec( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + gpu_acceleration=gpu_acceleration, + gpu_device=gpu_device, + ) + # The `AsyncVideoFileLoaderWithTorchCodec` class always loads the videos asynchronously, + # so we just wait for its loading thread to finish if async_loading_frames=False. + if not async_loading_frames: + async_thread = lazy_images.thread + if async_thread is not None: + async_thread.join() + return lazy_images, lazy_images.video_height, lazy_images.video_width + else: + raise RuntimeError("video_loader_type must be either 'cv2' or 'torchcodec'") + + +def load_video_frames_from_video_file_using_cv2( + video_path: str, + image_size: int, + img_mean: tuple = (0.5, 0.5, 0.5), + img_std: tuple = (0.5, 0.5, 0.5), + offload_video_to_cpu: bool = False, +) -> torch.Tensor: + """ + Load video from path, convert to normalized tensor with specified preprocessing + + Args: + video_path: Path to video file + image_size: Target size for square frames (height and width) + img_mean: Normalization mean (RGB) + img_std: Normalization standard deviation (RGB) + + Returns: + torch.Tensor: Preprocessed video tensor in shape (T, C, H, W) with float16 dtype + """ + import cv2 # delay OpenCV import to avoid unnecessary dependency + + # Initialize video capture + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Could not open video: {video_path}") + + original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + num_frames = num_frames if num_frames > 0 else None + + frames = [] + pbar = tqdm(desc=f"frame loading (OpenCV) [rank={RANK}]", total=num_frames) + while True: + ret, frame = cap.read() + if not ret: + break + + # Convert BGR to RGB and resize + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame_resized = cv2.resize( + frame_rgb, (image_size, image_size), interpolation=cv2.INTER_CUBIC + ) + frames.append(frame_resized) + pbar.update(1) + cap.release() + pbar.close() + + if len(frames) == 0: + raise RuntimeError( + f"No frames could be decoded from video: {video_path}. " + f"The file may be corrupted, empty, or encoded with an unsupported codec." + ) + + # Convert to tensor + frames_np = np.stack(frames, axis=0).astype(np.float32) # (T, H, W, C) + video_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2) # (T, C, H, W) + + img_mean = torch.tensor(img_mean, dtype=torch.float16).view(1, 3, 1, 1) + img_std = torch.tensor(img_std, dtype=torch.float16).view(1, 3, 1, 1) + if not offload_video_to_cpu: + video_tensor = video_tensor.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + video_tensor -= img_mean + video_tensor /= img_std + return video_tensor, original_height, original_width + + +def load_dummy_video(image_size, offload_video_to_cpu, num_frames=60, do_zeros=False): + """ + Load a dummy video with random frames for testing and compilation warmup purposes. + """ + video_height, video_width = 480, 640 # dummy original video sizes + if not do_zeros: + images = torch.randn(num_frames, 3, image_size, image_size, dtype=torch.float16) + else: + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float16) + if not offload_video_to_cpu: + images = images.cuda() + return images, video_height, video_width + + +def _load_img_as_tensor(img_path, image_size): + """Load and resize an image and convert it into a PyTorch tensor.""" + img = Image.open(img_path).convert("RGB") + orig_width, orig_height = img.width, img.height + img = TF.resize(img, size=(image_size, image_size)) + img = TF.to_tensor(img) + return img, orig_height, orig_width + + +class AsyncImageFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self._images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm( + range(len(self.images)), + desc=f"frame loading (image folder) [rank={RANK}]", + ): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # float16 precision should be sufficient for image tensor storage + img = img.to(dtype=torch.float16) + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.cuda() + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +class TorchCodecDecoder: + """ + A wrapper to support GPU device and num_threads in TorchCodec decoder, + which are not supported by `torchcodec.decoders.SimpleVideoDecoder` yet. + """ + + def __init__(self, source, dimension_order="NCHW", device="cpu", num_threads=1): + from torchcodec import _core as core + + self._source = source # hold a reference to the source to prevent it from GC + if isinstance(source, str): + self._decoder = core.create_from_file(source, "exact") + elif isinstance(source, bytes): + self._decoder = core.create_from_bytes(source, "exact") + else: + raise TypeError(f"Unknown source type: {type(source)}.") + assert dimension_order in ("NCHW", "NHWC") + + device_string = str(device) + core.scan_all_streams_to_update_metadata(self._decoder) + core.add_video_stream( + self._decoder, + dimension_order=dimension_order, + device=device_string, + num_threads=(1 if "cuda" in device_string else num_threads), + ) + video_metadata = core.get_container_metadata(self._decoder) + best_stream_index = video_metadata.best_video_stream_index + assert best_stream_index is not None + self.metadata = video_metadata.streams[best_stream_index] + assert self.metadata.num_frames_from_content is not None + self._num_frames = self.metadata.num_frames_from_content + + def __len__(self) -> int: + return self._num_frames + + def __getitem__(self, key: int): + from torchcodec import _core as core + + if key < 0: + key += self._num_frames + if key >= self._num_frames or key < 0: + raise IndexError( + f"Index {key} is out of bounds; length is {self._num_frames}" + ) + frame_data, *_ = core.get_frame_at_index( + self._decoder, + frame_index=key, + ) + return frame_data + + +class FIFOLock: + """A lock that ensures FIFO ordering of lock acquisitions.""" + + def __init__(self): + self._lock = Lock() + self._waiters = queue.Queue() + self._condition = Condition() + + def acquire(self): + ident = get_ident() + with self._condition: + self._waiters.put(ident) + while self._waiters.queue[0] != ident or not self._lock.acquire( + blocking=False + ): + self._condition.wait() + # got the lock and it's our turn + + def release(self): + with self._condition: + self._lock.release() + self._waiters.get() + self._condition.notify_all() + + def __enter__(self): + self.acquire() + + def __exit__(self, t, v, tb): + self.release() + + +class AsyncVideoFileLoaderWithTorchCodec: + """ + Loading frames from video files asynchronously without blocking session start. + + Unlike `AsyncVideoFileLoader`, this class uses PyTorch's offical TorchCodec library + for video decoding, which is more efficient and supports more video formats. + """ + + def __init__( + self, + video_path, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + gpu_acceleration=True, + gpu_device=None, + use_rand_seek_in_loading=False, + ): + # Check and possibly infer the output device (and also get its GPU id when applicable) + assert gpu_device is None or gpu_device.type == "cuda" + gpu_id = ( + gpu_device.index + if gpu_device is not None and gpu_device.index is not None + else torch.cuda.current_device() + ) + if offload_video_to_cpu: + out_device = torch.device("cpu") + else: + out_device = torch.device("cuda") if gpu_device is None else gpu_device + self.out_device = out_device + self.gpu_acceleration = gpu_acceleration + self.gpu_id = gpu_id + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + if not isinstance(img_mean, torch.Tensor): + img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None] + self.img_mean = img_mean + if not isinstance(img_std, torch.Tensor): + img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None] + self.img_std = img_std + + if gpu_acceleration: + self.img_mean = self.img_mean.to(f"cuda:{self.gpu_id}") + self.img_std = self.img_std.to(f"cuda:{self.gpu_id}") + decoder_option = {"device": f"cuda:{self.gpu_id}"} + else: + self.img_mean = self.img_mean.cpu() + self.img_std = self.img_std.cpu() + decoder_option = {"num_threads": 1} # use a single thread to save memory + + self.rank = int(os.environ.get("RANK", "0")) + self.world_size = int(os.environ.get("WORLD_SIZE", "1")) + self.async_reader = TorchCodecDecoder(video_path, **decoder_option) + + # `num_frames_from_content` is the true number of frames in the video content + # from the scan operation (rather than from the metadata, which could be wrong) + self.num_frames = self.async_reader.metadata.num_frames_from_content + self.video_height = self.async_reader.metadata.height + self.video_width = self.async_reader.metadata.width + + # items in `self._images` will be loaded asynchronously + self.images_loaded = [False] * self.num_frames + self.images = torch.zeros( + self.num_frames, + 3, + self.image_size, + self.image_size, + dtype=torch.float16, + device=self.out_device, + ) + # catch and raise any exceptions in the async loading thread + self.exception = None + self.use_rand_seek_in_loading = use_rand_seek_in_loading + self.rand_seek_idx_queue = queue.Queue() + # use a lock to avoid race condition between concurrent access to torchcodec + # libs (which are not thread-safe); the lock is replaced with a nullcontext + # when the video is fully loaded + self.torchcodec_access_lock = FIFOLock() + self._start_video_loading() + + def _load_one_frame(self, idx): + frame_resized = self._transform_frame(self.async_reader[idx]) + return frame_resized + + @torch.inference_mode() + def _start_video_loading(self): + desc = f"frame loading (TorchCodec w/ {'GPU' if self.gpu_acceleration else 'CPU'}) [rank={RANK}]" + pbar = tqdm(desc=desc, total=self.num_frames) + self.num_loaded_frames = 0 + # load the first frame synchronously to cache it before the session is opened + idx = self.num_loaded_frames + self.images[idx] = self._load_one_frame(idx) + self.images_loaded[idx] = True + self.num_loaded_frames += 1 + pbar.update(n=1) + self.all_frames_loaded = self.num_loaded_frames == self.num_frames + + # load the frames asynchronously without blocking the session start + def _load_frames(): + finished = self.all_frames_loaded + chunk_size = 16 + while not finished: + # asynchronously load `chunk_size` frames each time we acquire the lock + with self.torchcodec_access_lock, torch.inference_mode(): + for _ in range(chunk_size): + try: + idx = self.num_loaded_frames + self.images[idx] = self._load_one_frame(idx) + self.images_loaded[idx] = True + self.num_loaded_frames += 1 + pbar.update(n=1) + if self.num_loaded_frames >= self.num_frames: + finished = True + break + except Exception as e: + self.exception = e + raise + + # also read the frame that is being randomly seeked to + while True: + try: + idx = self.rand_seek_idx_queue.get_nowait() + if not self.images_loaded[idx]: + self.images[idx] = self._load_one_frame(idx) + self.images_loaded[idx] = True + except queue.Empty: + break + except Exception as e: + self.exception = e + raise + + # finished -- check whether we have loaded the total number of frames + if self.num_loaded_frames != self.num_frames: + raise RuntimeError( + f"There are {self.num_frames} frames in the video, but only " + f"{self.num_loaded_frames} frames can be loaded successfully." + ) + else: + self.all_frames_loaded = True + pbar.close() + with self.torchcodec_access_lock: + import gc + + # all frames have been loaded, so we can release the readers and free their memory + # also remove pbar and thread (which shouldn't be a part of session saving) + reader = self.async_reader + if reader is not None: + reader._source = None + self.async_reader = None + self.pbar = None + self.thread = None + self.rand_seek_idx_queue = None + gc.collect() + # remove the lock (replace it with nullcontext) when the video is fully loaded + self.torchcodec_access_lock = contextlib.nullcontext() + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def _transform_frame(self, frame): + frame = frame.clone() # make a copy to avoid modifying the original frame bytes + frame = frame.float() # convert to float32 before interpolation + frame_resized = F.interpolate( + frame[None, :], + size=(self.image_size, self.image_size), + mode="bicubic", + align_corners=False, + )[0] + # float16 precision should be sufficient for image tensor storage + frame_resized = frame_resized.half() # uint8 -> float16 + frame_resized /= 255 + frame_resized -= self.img_mean + frame_resized /= self.img_std + if self.offload_video_to_cpu: + frame_resized = frame_resized.cpu() + elif frame_resized.device != self.out_device: + frame_resized = frame_resized.to(device=self.out_device, non_blocking=True) + return frame_resized + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + max_tries = 1200 + for _ in range(max_tries): + # use a lock to avoid race condition between concurrent access to torchcodec + # libs (which are not thread-safe); the lock is replaced with a nullcontext + # when the video is fully loaded + with self.torchcodec_access_lock: + if self.images_loaded[index]: + return self.images[index] + + if self.use_rand_seek_in_loading: + # async loading hasn't reached this frame yet, so we load this frame individually + # (it will be loaded by in _load_frames thread and added to self.images[index]) + self.rand_seek_idx_queue.put(index) + + time.sleep(0.1) + + raise RuntimeError(f"Failed to load frame {index} after {max_tries} tries") + + def __len__(self): + return len(self.images) + + def __getstate__(self): + """ + Remove a few attributes during pickling, so that this async video loader can be + saved and loaded as a part of the model session. + """ + # wait for async video loading to finish before pickling + async_thread = self.thread + if async_thread is not None: + async_thread.join() + # release a few objects that cannot be pickled + reader = self.async_reader + if reader is not None: + reader._source = None + self.async_reader = None + self.pbar = None + self.thread = None + self.rand_seek_idx_queue = None + self.torchcodec_access_lock = contextlib.nullcontext() + return self.__dict__.copy() diff --git a/third_party/sam3/sam3/model/maskformer_segmentation.py b/third_party/sam3/sam3/model/maskformer_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..4953b7caa1aedee8a2e10f0c6695cf2331aa3575 --- /dev/null +++ b/third_party/sam3/sam3/model/maskformer_segmentation.py @@ -0,0 +1,335 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import math +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from .model_misc import MLP + + +class LinearPresenceHead(nn.Sequential): + def __init__(self, d_model): + # a hack to make `LinearPresenceHead` compatible with old checkpoints + super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1)) + + def forward(self, hs, prompt, prompt_mask): + return super().forward(hs) + + +class MaskPredictor(nn.Module): + def __init__(self, hidden_dim, mask_dim): + super().__init__() + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + def forward(self, obj_queries, pixel_embed): + if len(obj_queries.shape) == 3: + if pixel_embed.ndim == 3: + # batch size was omitted + mask_preds = torch.einsum( + "bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed + ) + else: + mask_preds = torch.einsum( + "bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed + ) + else: + # Assumed to have aux masks + if pixel_embed.ndim == 3: + # batch size was omitted + mask_preds = torch.einsum( + "lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed + ) + else: + mask_preds = torch.einsum( + "lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed + ) + + return mask_preds + + +class SegmentationHead(nn.Module): + def __init__( + self, + hidden_dim, + upsampling_stages, + use_encoder_inputs=False, + aux_masks=False, + no_dec=False, + pixel_decoder=None, + act_ckpt=False, + shared_conv=False, + compile_mode_pixel_decoder=None, + ): + super().__init__() + self.use_encoder_inputs = use_encoder_inputs + self.aux_masks = aux_masks + if pixel_decoder is not None: + self.pixel_decoder = pixel_decoder + else: + self.pixel_decoder = PixelDecoder( + hidden_dim, + upsampling_stages, + shared_conv=shared_conv, + compile_mode=compile_mode_pixel_decoder, + ) + self.no_dec = no_dec + if no_dec: + self.mask_predictor = nn.Conv2d( + hidden_dim, 1, kernel_size=3, stride=1, padding=1 + ) + else: + self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim) + + self.act_ckpt = act_ckpt + + # used to update the output dictionary + self.instance_keys = ["pred_masks"] + + @property + def device(self): + self._device = getattr(self, "_device", None) or next(self.parameters()).device + return self._device + + def to(self, *args, **kwargs): + # clear cached _device in case the model is moved to a different device + self._device = None + return super().to(*args, **kwargs) + + def _embed_pixels( + self, + backbone_feats: List[torch.Tensor], + image_ids, + encoder_hidden_states, + ) -> torch.Tensor: + # Unwrap NestedTensors to plain tensors if needed (multiplex path) + from sam3.model.data_misc import NestedTensor + + def _unwrap(x): + return x.tensors if isinstance(x, NestedTensor) else x + + feature_device = backbone_feats[0].device # features could be on CPU + model_device = self.device + image_ids_ = image_ids.to(feature_device) + if self.use_encoder_inputs: + if backbone_feats[0].shape[0] > 1: + # For bs > 1, we construct the per query backbone features + backbone_visual_feats = [] + for feat in backbone_feats: + # Copy the img features per query (pixel decoder won't share img feats) + backbone_visual_feats.append( + _unwrap(feat)[image_ids_, ...].to(model_device) + ) + else: + # Bs=1, we rely on broadcasting for query-based processing + backbone_visual_feats = [ + _unwrap(bb_feat).clone() for bb_feat in backbone_feats + ] + # Extract visual embeddings + encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0) + spatial_dim = math.prod(backbone_feats[-1].shape[-2:]) + encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape( + -1, *backbone_feats[-1].shape[1:] + ) + + backbone_visual_feats[-1] = encoder_visual_embed + if self.act_ckpt: + pixel_embed = checkpoint.checkpoint( + self.pixel_decoder, backbone_visual_feats, use_reentrant=False + ) + else: + pixel_embed = self.pixel_decoder(backbone_visual_feats) + else: + backbone_feats = [_unwrap(x).to(model_device) for x in backbone_feats] + pixel_embed = self.pixel_decoder(backbone_feats) + if pixel_embed.shape[0] == 1: + # For batch_size=1 training, we can avoid the indexing to save memory + pixel_embed = pixel_embed.squeeze(0) + else: + pixel_embed = pixel_embed[image_ids, ...] + return pixel_embed + + def forward( + self, + backbone_feats: List[torch.Tensor], + obj_queries: torch.Tensor, + image_ids, + encoder_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> Dict[str, torch.Tensor]: + if self.use_encoder_inputs: + assert encoder_hidden_states is not None + + pixel_embed = self._embed_pixels( + backbone_feats=backbone_feats, + image_ids=image_ids, + encoder_hidden_states=encoder_hidden_states, + ) + + if self.no_dec: + mask_pred = self.mask_predictor(pixel_embed) + elif self.aux_masks: + mask_pred = self.mask_predictor(obj_queries, pixel_embed) + else: + mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed) + + return {"pred_masks": mask_pred} + + +class PixelDecoder(nn.Module): + def __init__( + self, + hidden_dim, + num_upsampling_stages, + interpolation_mode="nearest", + shared_conv=False, + compile_mode=None, + ): + super().__init__() + self.hidden_dim = hidden_dim + self.num_upsampling_stages = num_upsampling_stages + self.interpolation_mode = interpolation_mode + conv_layers = [] + norms = [] + num_convs = 1 if shared_conv else num_upsampling_stages + for _ in range(num_convs): + conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1)) + norms.append(nn.GroupNorm(8, self.hidden_dim)) + + self.conv_layers = nn.ModuleList(conv_layers) + self.norms = nn.ModuleList(norms) + self.shared_conv = shared_conv + self.out_dim = self.conv_layers[-1].out_channels + if compile_mode is not None: + self.forward = torch.compile( + self.forward, mode=compile_mode, dynamic=True, fullgraph=True + ) + # Needed to make checkpointing happy. But we don't know if the module is checkpointed, so we disable it by default. + torch._dynamo.config.optimize_ddp = False + + def forward(self, backbone_feats: List[torch.Tensor]): + # Assumes backbone features are already projected (C == hidden dim) + + prev_fpn = backbone_feats[-1] + fpn_feats = backbone_feats[:-1] + for layer_idx, bb_feat in enumerate(fpn_feats[::-1]): + curr_fpn = bb_feat + prev_fpn = curr_fpn + F.interpolate( + prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode + ) + if self.shared_conv: + # only one conv layer + layer_idx = 0 + prev_fpn = self.conv_layers[layer_idx](prev_fpn) + prev_fpn = F.relu(self.norms[layer_idx](prev_fpn)) + + return prev_fpn + + +class UniversalSegmentationHead(SegmentationHead): + """This module handles semantic+instance segmentation""" + + def __init__( + self, + hidden_dim, + upsampling_stages, + pixel_decoder, + aux_masks=False, + no_dec=False, + act_ckpt=False, + presence_head: bool = False, + dot_product_scorer=None, + cross_attend_prompt=None, + ): + super().__init__( + hidden_dim=hidden_dim, + upsampling_stages=upsampling_stages, + use_encoder_inputs=True, + aux_masks=aux_masks, + no_dec=no_dec, + pixel_decoder=pixel_decoder, + act_ckpt=act_ckpt, + ) + self.d_model = hidden_dim + + if dot_product_scorer is not None: + assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake" + + self.presence_head = None + if presence_head: + self.presence_head = ( + dot_product_scorer + if dot_product_scorer is not None + else LinearPresenceHead(self.d_model) + ) + + self.cross_attend_prompt = cross_attend_prompt + if self.cross_attend_prompt is not None: + self.cross_attn_norm = nn.LayerNorm(self.d_model) + + self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1) + self.instance_seg_head = nn.Conv2d( + self.pixel_decoder.out_dim, self.d_model, kernel_size=1 + ) + + def forward( + self, + backbone_feats: List[torch.Tensor], + obj_queries: torch.Tensor, + image_ids, + encoder_hidden_states: Optional[torch.Tensor] = None, + prompt: Optional[torch.Tensor] = None, + prompt_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Dict[str, Optional[torch.Tensor]]: + assert encoder_hidden_states is not None + bs = encoder_hidden_states.shape[1] + + if self.cross_attend_prompt is not None: + tgt2 = self.cross_attn_norm(encoder_hidden_states) + tgt2 = self.cross_attend_prompt( + query=tgt2, + key=prompt, + value=prompt, + key_padding_mask=prompt_mask, + )[0] + encoder_hidden_states = tgt2 + encoder_hidden_states + + presence_logit = None + if self.presence_head is not None: + pooled_enc = encoder_hidden_states.mean(0) + presence_logit = ( + self.presence_head( + pooled_enc.view(1, bs, 1, self.d_model), + prompt=prompt, + prompt_mask=prompt_mask, + ) + .squeeze(0) + .squeeze(1) + ) + + pixel_embed = self._embed_pixels( + backbone_feats=backbone_feats, + image_ids=image_ids, + encoder_hidden_states=encoder_hidden_states, + ) + + instance_embeds = self.instance_seg_head(pixel_embed) + + if self.no_dec: + mask_pred = self.mask_predictor(instance_embeds) + elif self.aux_masks: + mask_pred = self.mask_predictor(obj_queries, instance_embeds) + else: + mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds) + + return { + "pred_masks": mask_pred, + "semantic_seg": self.semantic_seg_head(pixel_embed), + "presence_logit": presence_logit, + } diff --git a/third_party/sam3/sam3/model/memory.py b/third_party/sam3/sam3/model/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..4901cbe08e2e601aaf4e5355d1e283afb4a09f94 --- /dev/null +++ b/third_party/sam3/sam3/model/memory.py @@ -0,0 +1,209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from timm.layers import DropPath +except ModuleNotFoundError: + # compatibility for older timm versions + from timm.models.layers import DropPath + +from .model_misc import get_clones, LayerNorm2d + + +class SimpleMaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + # Option to interpolate the input mask first before downsampling using convs. In that case, the total_stride is assumed to be after interpolation. + # If set to input resolution or None, we don't interpolate. We default to None to be safe (for older configs or if not explicitly set) + interpol_size=None, + # options for incorporating multiplex memory encoding + multiplex_count: int = 1, + starting_out_chan: int = 1, + input_channel_multiplier: int = 1, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + multiplex_count = multiplex_count * input_channel_multiplier + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = multiplex_count, starting_out_chan + for _ in range(num_layers): + mask_out_chans = mask_out_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + self.multiplex_count = multiplex_count + self.interpol_size = interpol_size + if self.interpol_size is not None: + assert isinstance( + self.interpol_size, (list, tuple) + ), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple." + self.interpol_size = list(interpol_size) + assert len(self.interpol_size) == 2 + + def forward(self, x: torch.Tensor): + if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]): + x = F.interpolate( + x.float(), + size=self.interpol_size, + align_corners=False, + mode="bilinear", + antialias=True, + ) + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class SimpleFuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class SimpleMaskEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/third_party/sam3/sam3/model/model_misc.py b/third_party/sam3/sam3/model/model_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f3f25ed8bf46a5cd047f1226a610b27cbb530d48 --- /dev/null +++ b/third_party/sam3/sam3/model/model_misc.py @@ -0,0 +1,1109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Various utility models""" + +import copy +import math +import warnings +import weakref +from collections.abc import Iterator +from contextlib import AbstractContextManager +from enum import auto, Enum +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.overrides import handle_torch_function, has_torch_function +from typing_extensions import override + +try: + import xformers +except ImportError: + xformers = None + + +def inverse_sigmoid(x, eps=1e-3): + """ + The inverse function for sigmoid activation function. + Compute in fp32 to avoid numerical issues with bf16/fp16. + """ + input_dtype = x.dtype + x = x.float() + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2).to(input_dtype) + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() + + +class AttentionType: + """Type of attention""" + + # Simple dot product attention + Vanilla = "Vanilla" + + # Efficient attention from xformers + Xformer = "Xformer" + + # Sparse attention + Sparse = "Sparse" + + # Deformable attention (not compatible with text) + Deformable = "Deformable" + + +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + attn_type: AttentionType = AttentionType.Vanilla, + attn_sparsity: float = 0.0, + attn_bias: Optional[Tensor] = None, + use_fa3: bool = False, +) -> Tuple[Tensor, Optional[Tensor]]: + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + is_causal=is_causal, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + average_attn_weights=average_attn_weights, + use_fa3=use_fa3, + ) + + is_batched = True + + if is_causal: + raise NotImplementedError("is_causal is not supported in this implem") + attn_mask = None + + if not is_batched: + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, torch.Tensor): + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") + else: + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert ( + in_proj_weight is not None + ), "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = F._in_projection_packed( + query, key, value, in_proj_weight, in_proj_bias + ) + else: + assert ( + q_proj_weight is not None + ), "use_separate_proj_weight is True but q_proj_weight is None" + assert ( + k_proj_weight is not None + ), "use_separate_proj_weight is True but k_proj_weight is None" + assert ( + v_proj_weight is not None + ), "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = F._in_projection( + query, + key, + value, + q_proj_weight, + k_proj_weight, + v_proj_weight, + b_q, + b_k, + b_v, + ) + + # prep attention mask + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + else: + assert ( + attn_mask.is_floating_point() or attn_mask.dtype == torch.bool + ), f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + assert ( + static_k.size(0) == bsz * num_heads + ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert ( + static_k.size(2) == head_dim + ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + assert ( + static_v.size(0) == bsz * num_heads + ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert ( + static_v.size(2) == head_dim + ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 + ) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert ( + key_padding_mask.shape + == ( + bsz, + src_len, + ) + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(bsz * num_heads, 1, src_len) + ) + if attn_mask is None: + attn_mask = key_padding_mask + elif attn_mask.dtype == torch.bool: + attn_mask = attn_mask.logical_or(key_padding_mask) + else: + attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) + + # convert mask to float + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + + if attn_mask is not None: + if attn_mask.size(0) == 1: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + if attn_bias is not None: + assert ( + attn_bias.shape + == ( + bsz, + num_heads, + tgt_len, + src_len, + ) + ), f"expecting attn_bias shape of {(bsz, num_heads, tgt_len, src_len)}, but got {attn_bias.shape}" + if attn_mask is None: + attn_mask = attn_bias + else: + attn_mask = attn_mask + attn_bias + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + if attn_type == AttentionType.Vanilla: + if attn_mask is None and not is_causal and use_fa3: + from sam3.perflib.fa3 import flash_attn_func + + assert dropout_p == 0.0 + attn_output = flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + ).transpose(1, 2) + else: + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal + ) + + attn_output = ( + attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + ) + elif attn_type == AttentionType.Xformer: + attn_output_weights = None + assert not need_weights, "need_weights is not supported in efficient mode" + attn_output = xformers.ops.memory_efficient_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_bias=attn_mask, + p=dropout_p, + ) + attn_output = attn_output.permute(1, 0, 2, 3).reshape(bsz * tgt_len, embed_dim) + elif attn_type == AttentionType.Sparse: + attn_output_weights = None + assert not need_weights, "need_weights is not supported in efficient mode" + # Need to collapse heads and batch dimensions + q = q.reshape(bsz * num_heads, tgt_len, head_dim).contiguous() + k = k.reshape(bsz * num_heads, src_len, head_dim).contiguous() + v = v.reshape(bsz * num_heads, src_len, head_dim).contiguous() + row_offsets, column_indices = xformers.ops.find_locations_new( + q, k, attn_sparsity, True + ) + attn_output = xformers.ops.sparse_memory_efficient_attention( + q, k, v, row_offsets, column_indices, attn_bias=attn_mask + ).reshape(bsz, num_heads, tgt_len, head_dim) + attn_output = attn_output.permute(2, 0, 1, 3).reshape(bsz * tgt_len, embed_dim) + else: + raise ValueError(f"Unsupported attention type {attn_type}") + + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + if need_weights: + attn_output_weights = (q * head_dim**-0.5) @ k.transpose(-2, -1) + attn_output_weights = attn_output_weights.softmax(dim=-1) + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.sum(dim=1) / num_heads + + if not is_batched: + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + else: + attn_output_weights = None + if not is_batched: + attn_output = attn_output.squeeze(1) + return attn_output, None + + +class MultiheadAttention(nn.Module): + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + attn_type: AttentionType = AttentionType.Vanilla, + sparsity: float = 0.0, + use_act_checkpoint: bool = False, + use_fa3: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + self.use_act_checkpoint = use_act_checkpoint + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + assert ( + attn_type == AttentionType.Sparse or sparsity == 0.0 + ), "sparsity is only supported for sparse attention" + + if not self._qkv_same_embed_dim: + self.q_proj_weight = nn.Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = nn.Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = nn.Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = nn.Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = nn.Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = nn.modules.linear.NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if add_bias_kv: + self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.attn_type = attn_type + self.sparsity = sparsity + self.use_fa3 = use_fa3 + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + nn.init.xavier_uniform_(self.in_proj_weight) + else: + nn.init.xavier_uniform_(self.q_proj_weight) + nn.init.xavier_uniform_(self.k_proj_weight) + nn.init.xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + nn.init.constant_(self.in_proj_bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def __setstate__(self, state): + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + attn_bias: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point( + key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.batch_first and is_batched: + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + if not self._qkv_same_embed_dim: + if self.use_act_checkpoint: + attn_output, attn_output_weights = torch.utils.checkpoint.checkpoint( + multi_head_attention_forward, + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + use_reentrant=False, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + attn_type=self.attn_type, + attn_sparsity=self.sparsity, + attn_bias=attn_bias, + use_fa3=self.use_fa3, + ) + else: + attn_output, attn_output_weights = multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + attn_type=self.attn_type, + attn_sparsity=self.sparsity, + attn_bias=attn_bias, + use_fa3=self.use_fa3, + ) + else: + if self.use_act_checkpoint: + attn_output, attn_output_weights = torch.utils.checkpoint.checkpoint( + multi_head_attention_forward, + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + use_reentrant=False, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + attn_type=self.attn_type, + attn_sparsity=self.sparsity, + attn_bias=attn_bias, + ) + else: + attn_output, attn_output_weights = multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + attn_type=self.attn_type, + attn_sparsity=self.sparsity, + attn_bias=attn_bias, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + +# Keep backward compatibility alias +MultiheadAttentionWrapper = MultiheadAttention + + +class DotProductScoring(torch.nn.Module): + def __init__( + self, + d_model, + d_proj, + prompt_mlp=None, + clamp_logits=True, + clamp_max_val=12.0, + ): + super().__init__() + self.d_proj = d_proj + assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None + self.prompt_mlp = prompt_mlp # an optional MLP projection for prompt + self.prompt_proj = torch.nn.Linear(d_model, d_proj) + self.hs_proj = torch.nn.Linear(d_model, d_proj) + self.scale = float(1.0 / np.sqrt(d_proj)) + self.clamp_logits = clamp_logits + if self.clamp_logits: + self.clamp_max_val = clamp_max_val + + def mean_pool_text(self, prompt, prompt_mask): + # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding + is_valid = (~prompt_mask).float().permute(1, 0)[..., None] + # num_valid has shape (bs, 1) + num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0) + # mean pool over all the valid tokens -- pooled_prompt has shape (bs, proj_dim) + pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid + return pooled_prompt + + def forward(self, hs, prompt, prompt_mask): + # hs has shape (num_layer, bs, num_query, d_model) + # prompt has shape (seq, bs, d_model) + # prompt_mask has shape (bs, seq), where 1 is valid and 0 is padding + assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2 + + # apply MLP on prompt if specified + if self.prompt_mlp is not None: + prompt = self.prompt_mlp(prompt) + + # first, get the mean-pooled version of the prompt + pooled_prompt = self.mean_pool_text(prompt, prompt_mask) + + # then, project pooled_prompt and hs to d_proj dimensions + proj_pooled_prompt = self.prompt_proj(pooled_prompt) # (bs, d_proj) + proj_hs = self.hs_proj(hs) # (num_layer, bs, num_query, d_proj) + + # finally, get dot-product scores of shape (num_layer, bs, num_query, 1) + scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1)) + scores *= self.scale + + # clamp scores to a max value to avoid numerical issues in loss or matcher + if self.clamp_logits: + scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val) + + return scores + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class TransformerWrapper(nn.Module): + def __init__( + self, + encoder, + decoder, + d_model: int, + two_stage_type="none", # ["none"] only for now + pos_enc_at_input_dec=True, + ): + super().__init__() + + self.encoder = encoder + self.decoder = decoder + self.num_queries = decoder.num_queries if decoder is not None else None + self.pos_enc_at_input_dec = pos_enc_at_input_dec + + # for two stage + assert two_stage_type in ["none"], "unknown param {} of two_stage_type".format( + two_stage_type + ) + self.two_stage_type = two_stage_type + + self._reset_parameters() + self.d_model = d_model + + def _reset_parameters(self): + for n, p in self.named_parameters(): + if p.dim() > 1: + if ( + "box_embed" not in n + and "query_embed" not in n + and "reference_points" not in n + ): + nn.init.xavier_uniform_(p) + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + dropout: float = 0.0, + residual: bool = False, + out_norm: Optional[nn.Module] = None, + ): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + # whether to add the output as a residual connection to the input + if residual and input_dim != output_dim: + raise ValueError("residual is only supported if input_dim == output_dim") + self.residual = residual + # whether to apply a normalization layer to the output + assert isinstance(out_norm, nn.Module) or out_norm is None + self.out_norm = out_norm or nn.Identity() + + def forward(self, x): + orig_x = x + for i, layer in enumerate(self.layers): + x = self.drop(F.relu(layer(x))) if i < self.num_layers - 1 else layer(x) + if self.residual: + x = x + orig_x + x = self.out_norm(x) + return x + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def get_clones_seq(module, N): + return nn.Sequential(*[copy.deepcopy(module) for i in range(N)]) + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_activation_module(activation): + """Return an activation function given a string""" + if activation == "relu": + return nn.ReLU + if activation == "gelu": + return nn.GELU + if activation == "glu": + return nn.GLU + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_valid_ratio(mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + +def gen_sineembed_for_position(pos_tensor, num_feats=256): + assert num_feats % 2 == 0 + num_feats = num_feats // 2 + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange(num_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3 + ).flatten(2) + pos_y = torch.stack( + (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3 + ).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack( + (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3 + ).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack( + (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3 + ).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos + + +class SAM3Output(list): + """ + A class representing the output of a SAM3 model. + It provides an iterable interface that supports different iteration modes, including iterating over all steps per stage, + last step per stage, and flattened output. + Attributes: + output: The output of the SAM3 model, represented as a list of lists. + iter_mode: The current iteration mode. + Example: + >>> output = [[1, 2], [3, 4], [5, 6]] + >>> sam3_output = SAM3Output(output) + >>> for step in sam3_output: + ... print(step) + [1, 2] + [3, 4] + [5, 6] + >>> with SAM3Output.iteration_mode(SAM3Output.IterMode.LAST_STEP_PER_STAGE) as sam3_last_step_out: + ... for step in sam3_last_step_out: + ... print(step) + [2] + [4] + [6] + >>> with SAM3Output.iteration_mode(SAM3Output.IterMode.FLATTENED) as sam3_flattened_out: + ... for step in sam3_flattened_out: + ... print(step) + 1 + 2 + 3 + 4 + 5 + 6 + """ + + class IterMode(Enum): + # Defines the type of iterator over ouptuts. + ALL_STEPS_PER_STAGE = auto() + LAST_STEP_PER_STAGE = auto() + FLATTENED = auto() # Returns each interactivity step as if it is a separate stage (this is used in SAM3Image model) + + def __init__( + self, + output: List[List[Dict]] = None, + iter_mode: IterMode = IterMode.ALL_STEPS_PER_STAGE, + loss_stages: Optional[List[int]] = None, + ): + if output is not None: + assert ( + isinstance(output, list) + and len(output) > 0 + and isinstance(output[0], list) + ), "Expected output to be a list of lists" + self.output = output + else: + self.output = [] + assert isinstance( + iter_mode, SAM3Output.IterMode + ), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}" + + self.iter_mode = iter_mode + # We create a weak reference to self to be used in the lambda functions. + # This is to avoid cyclic references and let SAM3Output be garabge collected. + self_ref = weakref.ref(self) + self._mode2iter = { + SAM3Output.IterMode.ALL_STEPS_PER_STAGE: lambda: iter(self_ref().output), + SAM3Output.IterMode.LAST_STEP_PER_STAGE: lambda: ( + inner_list[-1] for inner_list in self_ref().output + ), + SAM3Output.IterMode.FLATTENED: lambda: ( + element for inner_list in self_ref().output for element in inner_list + ), + } + self.loss_stages = loss_stages + + @override + def __iter__(self) -> Iterator: + return self._mode2iter[self.iter_mode]() + + def __getitem__(self, index): + """ + Returns the item at the specified index. + Args: + index (int): The index of the item to return. + Returns: + list or element: The item at the specified index. + """ + assert isinstance(index, int), f"index should be an integer. Got {type(index)}" + if self.iter_mode == SAM3Output.IterMode.ALL_STEPS_PER_STAGE: + return self.output[index] + elif self.iter_mode == SAM3Output.IterMode.LAST_STEP_PER_STAGE: + return self.output[index][-1] + elif self.iter_mode == SAM3Output.IterMode.FLATTENED: + if index == -1: + return self.self.output[-1][-1] + else: + flattened_output = sum(self.output, []) + return flattened_output[index] + + class _IterationMode(AbstractContextManager): + """ + A context manager that temporarily changes the iteration mode of a SAM3Output object. + This class is used internally by the SAM3Output.iteration_mode method. + """ + + def __init__( + self, model_output: "SAM3Output", iter_mode: "SAM3Output.IterMode" + ): + self._model_output = model_output + self._orig_iter_mode = model_output.iter_mode + self._new_iter_mode = iter_mode + + @override + def __enter__(self) -> "SAM3Output": + self._model_output.iter_mode = self._new_iter_mode + return self._model_output + + @override + def __exit__(self, exc_type, exc_value, traceback): + self._model_output.iter_mode = self._orig_iter_mode + return super().__exit__(exc_type, exc_value, traceback) + + @staticmethod + def iteration_mode( + model_output: "SAM3Output", iter_mode: IterMode + ) -> _IterationMode: + """ + Returns a context manager that allows you to temporarily change the iteration mode of the SAM3Output object. + Args: + model_output: The SAM3Output object. + iter_mode: The new iteration mode. + Returns: + SAM3Output._IterationMode: A context manager that changes the iteration mode of the SAM3Output object. + """ + return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode) + + def append(self, item: list): + assert isinstance( + item, list + ), f"Only list items are supported. Got {type(item)}" + self.output.append(item) + + def __repr__(self): + return self.output.__repr__() + + def __len__(self): + if self.iter_mode in [ + SAM3Output.IterMode.ALL_STEPS_PER_STAGE, + SAM3Output.IterMode.LAST_STEP_PER_STAGE, + ]: + return len(self.output) + elif self.iter_mode == SAM3Output.IterMode.FLATTENED: + flattened_output = sum(self.output, []) + return len(flattened_output) diff --git a/third_party/sam3/sam3/model/multiplex_mask_decoder.py b/third_party/sam3/sam3/model/multiplex_mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c584675b7debf3818932ef6327925496be559886 --- /dev/null +++ b/third_party/sam3/sam3/model/multiplex_mask_decoder.py @@ -0,0 +1,470 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List, Optional, Type + +import torch +from sam3.sam.common import LayerNorm2d +from torch import nn +from torch.nn import functional as F + + +class MultiplexMaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + multiplex_count: int, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid: bool = False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + decode_mask_with_shared_tokens: bool = False, + decode_mask_attribute_with_shared_tokens: bool = False, + multimask_outputs_only: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture with multiplex capabilities. + + Arguments: + multiplex_count: the number of masks multiplexed into a single feature map + num_multimask_outputs: the number of masks to predict per multiplex output + (the total number of masks is (num_multimask_outputs+1) * multiplex_count) + use_multimask_token_for_obj_ptr: whether to use multimask tokens for object pointers + decode_mask_with_shared_tokens: use the same mask token for multimasks with different projection layers + decode_mask_attribute_with_shared_tokens: use the mask tokens (instead of separate tokens) + to predict iou and object scores + multimask_outputs_only: predict num_multimask_outputs masks without the single + mask output token (i.e., without the +1) + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.multiplex_count = multiplex_count + self.num_multimask_outputs = num_multimask_outputs + self.multimask_outputs_only = multimask_outputs_only + self.decode_mask_with_shared_tokens = decode_mask_with_shared_tokens + self.decode_mask_attribute_with_shared_tokens = ( + decode_mask_attribute_with_shared_tokens + ) + + if self.decode_mask_with_shared_tokens: + assert ( + multimask_outputs_only + ), "multimask_outputs_only must be True if decode_mask_with_shared_tokens" + + if self.multimask_outputs_only: + self.num_mask_output_per_object = num_multimask_outputs + else: + self.num_mask_output_per_object = num_multimask_outputs + 1 + + if self.decode_mask_with_shared_tokens: + self.num_mask_tokens = multiplex_count + else: + self.num_mask_tokens = multiplex_count * self.num_mask_output_per_object + + self.pred_obj_scores = pred_obj_scores + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + if not self.decode_mask_attribute_with_shared_tokens: + self.iou_token = nn.Embedding(multiplex_count, transformer_dim) + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(multiplex_count, transformer_dim) + + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + if self.num_multimask_outputs == 0: + self.output_hypernetworks_mlp = MLP( + transformer_dim, transformer_dim, transformer_dim // 8, 3 + ) + else: + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for _ in range(self.num_mask_output_per_object) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + ( + 1 + if ( + self.decode_mask_attribute_with_shared_tokens + and not self.decode_mask_with_shared_tokens + ) + else self.num_mask_output_per_object + ), + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + multimask_output: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + extra_per_object_embeddings: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + extra_per_object_embeddings (torch.Tensor): a tensor with shape b * multiplex_count * C to be added to the mask tokens + + Returns: a dict of Tensors indexed by strings + masks: batched predicted masks + iou_pred: batched predictions of mask quality + object_score_logits: batched predictions of object existence + """ + + if self.num_multimask_outputs <= 0: + assert ( + not multimask_output + ), f"multimask_output must be False with {self.num_multimask_outputs=}" + + if self.multimask_outputs_only: + assert ( + multimask_output + ), f"multimask_output must be True with {self.multimask_outputs_only=}" + + out = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + high_res_features=high_res_features, + extra_per_object_embeddings=extra_per_object_embeddings, + ) + + masks = out["masks"] # [B, M, (self.num_mask_token_per_object), H, W] + iou_pred = out["iou_pred"] # [B, M, (self.num_mask_token_per_object)] + mask_tokens_out = out[ + "mask_tokens_out" + ] # [B, M, (self.num_mask_token_per_object), C] + + # Select the correct mask or masks for output + if multimask_output: + if not self.multimask_outputs_only: + masks = masks[:, :, 1:, :, :] + iou_pred = iou_pred[:, :, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, :, 0:1, :, :] + iou_pred = iou_pred[:, :, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + if self.multimask_outputs_only: + sam_tokens_out = mask_tokens_out + else: + sam_tokens_out = mask_tokens_out[ + :, :, 1: + ] # [B, M, num_multimask_outputs, C] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, :, 0:1] # [B, M, 1, C] shape + + del out["mask_tokens_out"] + out["masks"] = masks + out["iou_pred"] = iou_pred + out["sam_tokens_out"] = sam_tokens_out + + if multimask_output: + assert ( + masks.shape[2] == self.num_mask_output_per_object + ), f"{masks.shape=}, {self.num_mask_output_per_object=}" + assert ( + iou_pred.shape[2] == self.num_mask_output_per_object + ), f"{iou_pred.shape=}, {self.num_mask_output_per_object=}" + if self.use_multimask_token_for_obj_ptr: + if self.decode_mask_with_shared_tokens: + assert sam_tokens_out.shape[2] == 1, f"{sam_tokens_out.shape=}" + else: + assert ( + sam_tokens_out.shape[2] == self.num_mask_output_per_object + ), f"{sam_tokens_out.shape=}, {self.num_mask_output_per_object=}" + else: + assert masks.shape[2] == 1, f"{masks.shape=}" + assert iou_pred.shape[2] == 1, f"{iou_pred.shape=}" + assert sam_tokens_out.shape[2] == 1, f"{sam_tokens_out.shape=}" + + return out + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + high_res_features: Optional[List[torch.Tensor]] = None, + extra_per_object_embeddings: Optional[ + torch.Tensor + ] = None, # num_buckets, multiplex_count, C + ) -> dict[str, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + B = image_embeddings.shape[0] + token_list = [] + if self.pred_obj_scores and not self.decode_mask_attribute_with_shared_tokens: + token_list.append(self.obj_score_token.weight) + if not self.decode_mask_attribute_with_shared_tokens: + token_list.append(self.iou_token.weight) + + tokens = torch.cat(token_list, dim=0) + tokens = tokens.unsqueeze(0).expand(B, -1, -1) + + if extra_per_object_embeddings is not None: + mask_tokens = self.mask_tokens.weight.view( + 1, self.multiplex_count, self.num_mask_output_per_object, -1 + ).expand(B, -1, -1, -1) + + mask_tokens = mask_tokens + extra_per_object_embeddings.unsqueeze(2) + mask_tokens = mask_tokens.flatten(1, 2) + else: + mask_tokens = self.mask_tokens.weight.unsqueeze(0).expand(B, -1, -1) + + tokens = torch.cat([tokens, mask_tokens], dim=1) + + src = image_embeddings + + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + + # Parse transformer outputs based on token sharing configuration + if self.decode_mask_attribute_with_shared_tokens: + assert ( + hs.shape[1] == self.num_mask_tokens + ), f"{hs.shape=}, {self.num_mask_tokens=}" + iou_token_out = mask_tokens_out = hs[:, 0 : self.num_mask_tokens] + if self.pred_obj_scores: + obj_score_token_out = mask_tokens_out + else: + # Separate tokens for each prediction type + s = 0 + if self.pred_obj_scores: + obj_score_token_out = hs[:, s : s + self.multiplex_count, :] + s += self.multiplex_count + + iou_token_out = hs[:, s : s + self.multiplex_count, :] + s += self.multiplex_count + mask_tokens_out = hs[:, s : s + self.num_mask_tokens, :] + assert ( + hs.shape[1] == s + self.num_mask_tokens + ), f"{hs.shape=}, {s=}, {self.num_mask_tokens=}" + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + if self.decode_mask_with_shared_tokens: + mask_tokens_out = mask_tokens_out.view(B, self.multiplex_count, 1, -1) + else: + mask_tokens_out = mask_tokens_out.view( + B, self.multiplex_count, self.num_mask_output_per_object, -1 + ) + if self.num_multimask_outputs == 0: + hyper_in = self.output_hypernetworks_mlp( + mask_tokens_out[:, :, 0, :] + ).unsqueeze(2) # [B, M, 1, C] + else: + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_output_per_object): + if self.decode_mask_with_shared_tokens: + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, :, 0, :]) + ) + else: + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, :, i, :]) + ) + # hyper_in: [B, M, num_multimask_outputs+1, C] + hyper_in = torch.stack(hyper_in_list, dim=2) + + # generate the masks + b, c, h, w = upscaled_embedding.shape + masks = torch.bmm( + hyper_in.flatten(1, 2), upscaled_embedding.view(b, c, h * w) + ).view(b, self.multiplex_count, self.num_mask_output_per_object, h, w) + + # Generate mask quality predictions, with shape b * multiplex_count * (num_multimask_outputs+1) + iou_pred = self.iou_prediction_head(iou_token_out).view( + b, self.multiplex_count, self.num_mask_output_per_object + ) + + if self.pred_obj_scores: + # Generate mask quality predictions, with shape b * (num_multimask_outputs+1) + if ( + self.decode_mask_attribute_with_shared_tokens + and not self.decode_mask_with_shared_tokens + ): + object_score_logits = ( + self.pred_obj_score_head(obj_score_token_out) + .view(b, self.multiplex_count, self.num_mask_output_per_object) + .sum(-1, keepdim=True) + ) + else: + object_score_logits = self.pred_obj_score_head(obj_score_token_out) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones( + iou_pred.shape[0], iou_pred.shape[1] + ) + + outputs = { + "masks": masks, + "iou_pred": iou_pred, + "mask_tokens_out": mask_tokens_out, + "object_score_logits": object_score_logits, + } + + return outputs + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # first, flatten the batch and the multiplex dimensions + B, M = all_mask_logits.shape[:2] + all_mask_logits = all_mask_logits.flatten(0, 1) + all_iou_scores = all_iou_scores.flatten(0, 1) + + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + + # restore the batch and multiplex dimensions + mask_logits_out = mask_logits_out.unflatten(0, (B, M)) + iou_scores_out = iou_scores_out.unflatten(0, (B, M)) + + return mask_logits_out, iou_scores_out + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/third_party/sam3/sam3/model/multiplex_utils.py b/third_party/sam3/sam3/model/multiplex_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44f102a8e499591d6e54d83fa6829d0a7972dfda --- /dev/null +++ b/third_party/sam3/sam3/model/multiplex_utils.py @@ -0,0 +1,538 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging +import math +from typing import Optional + +import torch +from torch import nn + +# Special values for object tracking +_PADDING_NUM = -1 # Marks empty slots in buckets +_REMOVED_NUM = -1116 # Marks objects that have been removed + + +logger = logging.getLogger(__name__) + + +class MultiplexState: + """ + MultiplexState records the state of multiplexing, for one or more buckets. + + At a high level, we deal with the conversion of tensors between the data space (batch_size, num_channels, ...) + and the multiplex space (num_buckets, multiplex_count, num_channels, ...). + + The multiplex state stores the assignments of each batch element to a slot in a bucket. + Each bucket has a fixed number of slots (multiplex_count), and not all slots need to be filled. + The batch size should equate to total_valid_entries, which is the sum of the number of valid entries in each bucket. + + There are two main operations that this class supports: + mux: convert tensors in the data space to the multiplex space. + The mental model is that we start from a tensor of zeros that has the shape of the output, + then we go through the valid entries and place them into the corresponding slots, indicated by the assignments. + + demux: convert tensors in the multiplex space to the data space. + This is the reverse operation of mux. Note that zeros were used in mux for the padding slots, + and that those slots are ignored in demux. + + There are also two utility functions for object mangement: + add_objects: add new objects to the state by filling in empty slots + remove_objects: remove objects from the state by marking them as removed (not the same as empty!) + """ + + def __init__( + self, + assignments: list[list[int]], + device: torch.device, + dtype: torch.dtype, + allowed_bucket_capacity: int, + *, + object_ids: Optional[list[int]] = None, + ): + """ + assignments: a list of lists of object indices + Each top-level list represents a bucket + Each inner list represents the object indices that are in the bucket + The object indices must ranges from 0 to num_valid_entries - 1, except for the following special values (all negatives): + _PADDING_NUM, which denotes padding entries + _REMOVED_NUM, which denotes an pre-existing object that got removed (currently not used during init) + If you wish to save the "true" object IDs, i.e., during inference, you can bookkeep them here + """ + self.device = device + self.dtype = dtype + + # Initialize bucket assignments and precompute matrices + self.allowed_bucket_capacity = allowed_bucket_capacity + self._initialize_assignments(assignments, object_ids=object_ids) + + def _initialize_assignments( + self, assignments: list[list[int]], *, object_ids: Optional[list[int]] = None + ): + self.assignments = assignments + self.num_buckets = len(self.assignments) + if self.num_buckets == 0: + logger.error("No buckets found in the state") + raise ValueError("No buckets found in the state") + + self.multiplex_count = len(self.assignments[0]) + assert all( + len(self.assignments[i]) == self.multiplex_count + for i in range(self.num_buckets) + ) + + # number of non-negative elements in the state + self.total_valid_entries = sum( + sum(1 for x in bucket if x >= 0) for bucket in self.assignments + ) + self.total_non_padding_entries = sum( + sum(1 for x in bucket if x != _PADDING_NUM) for bucket in self.assignments + ) + + # check the validity of the object IDs + self.object_ids = object_ids + if self.object_ids is not None: + assert ( + len(self.object_ids) == self.total_valid_entries + ), "object_ids should map 1:1 to the valid entries" + + # check the validity of the assignments + all_object_idxs = set() + for bucket in self.assignments: + valid_entries_in_bucket = sum(1 for x in bucket if x != _PADDING_NUM) + assert ( + valid_entries_in_bucket <= self.allowed_bucket_capacity + ), f"{valid_entries_in_bucket=} > {self.allowed_bucket_capacity=}" + for obj_idx in bucket: + if obj_idx >= 0: + assert ( + obj_idx < self.total_non_padding_entries + ), f"object ID {obj_idx} >= {self.total_non_padding_entries}" + assert obj_idx not in all_object_idxs, "object IDs must be unique" + all_object_idxs.add(obj_idx) + + # Precompute and cache the actual selection matrices + self._precompute_transition_matrices(self.device, self.dtype) + + @property + def available_slots(self) -> int: + # returns the number of available slots in the state + return ( + self.num_buckets * self.allowed_bucket_capacity + - self.total_non_padding_entries + ) + + def find_next_batch_of_available_indices( + self, + num_objects: int, + *, + allow_new_buckets: bool = False, + prefer_new_buckets: bool = False, + ) -> list[int]: + # produce a list of consecutive indices that are available in the state + # Note: prefer_new_buckets parameter is accepted for API compatibility but not used here + # as the actual bucket allocation logic is in add_objects() + assert num_objects > 0, f"{num_objects=} must be positive" + if not allow_new_buckets: + assert ( + self.available_slots >= num_objects + ), f"not enough available slots {self.available_slots} < {num_objects}" + + return list( + range( + self.total_valid_entries, + self.total_valid_entries + num_objects, + ) + ) + + def add_objects( + self, + object_indices: list[int], + *, + object_ids: Optional[list[int]] = None, + allow_new_buckets: bool = False, + prefer_new_buckets: bool = False, + ): + """ + Add new objects to the state by filling in empty slots and + creating new buckets if necessary. + + object_indices must be sorted and follow existing object indices. + If prefer_new_buckets is True, we skip filling existing slots and place + the objects into freshly created buckets (requires allow_new_buckets=True). + """ + if len(object_indices) == 0: + return + + # we will modify this in-place + object_indices = object_indices.copy() + assert (object_ids is None) == ( + self.object_ids is None + ), "object_ids must either be always given or always omitted" + + if object_ids is not None: + assert len(object_ids) == len( + object_indices + ), "object_ids must have the same length as object_indices" + object_ids = object_ids.copy() + + num_new_objects = len(object_indices) + assert object_indices == sorted(object_indices), "object_indices must be sorted" + object_indices.reverse() # reverse so we can pop from the end + if object_ids is not None: + object_ids.reverse() + + if prefer_new_buckets: + assert allow_new_buckets, "prefer_new_buckets requires allow_new_buckets" + + slots_filled = 0 + buckets_created = 0 + + def _pop_next(): + idx = object_indices.pop() + if object_ids is not None and self.object_ids is not None: + self.object_ids.append(object_ids.pop()) + return idx + + if not prefer_new_buckets: + # Fill empty slots in existing buckets first + for bucket in self.assignments: + for i in range(self.allowed_bucket_capacity): + if bucket[i] == _PADDING_NUM: + bucket[i] = _pop_next() + slots_filled += 1 + if len(object_indices) == 0: + break + if len(object_indices) == 0: + break + + if len(object_indices) > 0 and not allow_new_buckets: + raise ValueError( + f"Cannot place objects {list(reversed(object_indices))} without creating new buckets" + ) + + # Create new buckets for remaining objects (or all objects if prefer_new_buckets) + while len(object_indices) > 0: + new_bucket = [_PADDING_NUM] * self.multiplex_count + for i in range(self.allowed_bucket_capacity): + if len(object_indices) == 0: + break + new_bucket[i] = _pop_next() + self.assignments.append(new_bucket) + buckets_created += 1 + + # reinitialize all the settings + original_num_entries = self.total_valid_entries + self._initialize_assignments(self.assignments, object_ids=self.object_ids) + assert ( + self.total_valid_entries == original_num_entries + num_new_objects + ), f"{self.total_valid_entries=} != {original_num_entries=} + {num_new_objects=}" + + logger.info( + f"Filled {slots_filled} slots and created {buckets_created} new buckets" + ) + logger.info( + f"{self.num_buckets=}, {self.total_valid_entries=}, {self.total_non_padding_entries=}" + ) + + def remove_objects(self, object_indices: list[int], strict: bool = True): + """ + Remove objects from the state by marking them as removed. + Remove a bucket if all objects in the bucket are removed. + + Args: + object_indices: List of object indices to remove + strict: If True, will raise an error if any object indices are not found in the state + + Returns: + List of bucket indices that we are going to keep + """ + object_indices = object_indices.copy() + + # Mark objects as removed in assignments + for bucket_idx, bucket in enumerate(self.assignments): + for slot_idx, obj_id in enumerate(bucket): + if obj_id in object_indices: + self.assignments[bucket_idx][slot_idx] = _REMOVED_NUM + object_indices.remove(obj_id) + + if strict: + assert ( + len(object_indices) == 0 + ), f"Failed to remove objects: {object_indices}" + + # Check which buckets should be completely removed (all objects removed/paddings) + # and which buckets we will keep + buckets_to_remove = [] + buckets_to_keep = [] + for bucket_idx, bucket in enumerate(self.assignments): + # Check if all objects in this bucket are removed or are paddings + all_removed = all( + obj_id in [_PADDING_NUM, _REMOVED_NUM] for obj_id in bucket + ) + if all_removed: + buckets_to_remove.append(bucket_idx) + logger.info( + f"Bucket {bucket_idx} marked for removal - all objects removed/paddings" + ) + else: + buckets_to_keep.append(bucket_idx) + + # Remove buckets in reverse order to maintain correct indices + for bucket_idx in reversed(buckets_to_remove): + del self.assignments[bucket_idx] + + if len(buckets_to_keep) == 0: + logger.info(f"Removing all buckets: {buckets_to_remove}; state invalidated") + self.assignments = None + if self.object_ids is not None: + self.object_ids = [] + return buckets_to_keep + + # After removal, remap object IDs to be sequential + # Collect all unique positive object IDs and create a mapping to sequential IDs + all_positive_ids = set() + for bucket in self.assignments: + for obj_id in bucket: + if obj_id >= 0: + all_positive_ids.add(obj_id) + + # Create mapping from old IDs to new sequential IDs + sorted_ids = sorted(all_positive_ids) + id_mapping = {old_id: new_id for new_id, old_id in enumerate(sorted_ids)} + + # Apply the mapping to assignments to make IDs sequential + for bucket in self.assignments: + for i, obj_id in enumerate(bucket): + if obj_id >= 0: + bucket[i] = id_mapping[obj_id] + + # Update object_ids if they exist + if self.object_ids is not None: + # Create new object_ids array based on the remapped indices + # We need to preserve the original object_ids for the objects that weren't removed + new_object_ids = [None] * len(sorted_ids) + + # Map the original object_ids to their new positions + for old_idx, new_idx in id_mapping.items(): + new_object_ids[new_idx] = self.object_ids[old_idx] + + assert not any(obj_id is None for obj_id in new_object_ids) + self.object_ids = new_object_ids + + # Reinitialize the state to update all internal structures + self._initialize_assignments(self.assignments, object_ids=self.object_ids) + + logger.info(f"Removed these buckets: {buckets_to_remove}") + logger.info(f"Kept these buckets: {buckets_to_keep}") + logger.info( + f"Remaining buckets: {self.num_buckets}, total valid entries: {self.total_valid_entries}" + ) + + return buckets_to_keep + + def _precompute_transition_matrices(self, device: torch.device, dtype: torch.dtype): + """ + Precompute the transition matrices for maximum efficiency. + Note that these should be partial permutation matrices. + """ + # Create a transition matrix for muxing + self.mux_matrix = torch.zeros( + self.num_buckets * self.multiplex_count, + self.total_valid_entries, + device=device, + dtype=dtype, + ) + + # Create a transition matrix for demuxing + self.demux_matrix = torch.zeros( + self.total_valid_entries, + self.num_buckets * self.multiplex_count, + device=device, + dtype=dtype, + ) + + # Fill both matrices based on assignments + for i in range(self.num_buckets): + for j in range(self.multiplex_count): + bucket_idx = i * self.multiplex_count + j + object_idx = self.assignments[i][j] + if object_idx >= 0: + self.mux_matrix[bucket_idx, object_idx] = 1.0 + self.demux_matrix[object_idx, bucket_idx] = 1.0 + + def mux(self, x: torch.Tensor) -> torch.Tensor: + """ + Multiplexing operation + x: self.total_valid_entries * ... + + return num_buckets * multiplex_count * ... + with padding entries filled with 0 + """ + num_valid_entries = x.shape[0] + assert ( + num_valid_entries == self.total_valid_entries + ), f"{num_valid_entries=} != {self.total_valid_entries=}" + output_shape = ( + self.num_buckets, + self.multiplex_count, + ) + x.shape[1:] + + x_flat = x.reshape(num_valid_entries, -1) + + # Apply mux matrix: (num_buckets * multiplex_count, batch_size) @ (batch_size, features) + # Result: (num_buckets * multiplex_count, features) + result_flat = self.mux_matrix @ x_flat + + result = result_flat.view(output_shape) + return result + + def demux(self, x: torch.Tensor) -> torch.Tensor: + """ + Inverse operation of mux + x: num_buckets, multiplex_count * ... + Returns: total_valid_entries * ... + """ + num_buckets, multiplex_count = x.shape[:2] + assert num_buckets == self.num_buckets, f"{num_buckets=} != {self.num_buckets=}" + assert ( + multiplex_count == self.multiplex_count + ), f"{multiplex_count=} != {self.multiplex_count=}" + output_shape = (self.total_valid_entries,) + x.shape[2:] + + x_flat = x.reshape(num_buckets * multiplex_count, -1) + + # Apply demux matrix: (total_valid_entries, num_buckets*multiplex_count) @ (num_buckets*multiplex_count, features) + # Result: (total_valid_entries, features) + result_flat = self.demux_matrix @ x_flat + + result = result_flat.view(output_shape) + return result + + def get_valid_object_mask(self) -> torch.Tensor: + """ + Returns a (num_buckets, multiplex_count) tensor with 1 for valid entries and 0 for padding entries + """ + valid_mask = self.mux_matrix.sum(dim=1) > 0 + valid_mask = valid_mask.reshape(self.num_buckets, self.multiplex_count) + + return valid_mask + + def get_all_valid_object_idx(self) -> set[int]: + """ + Returns a set of all valid object idx in the state + Note that this returns the internal object idx representations, + not the arbitrary object IDs that are passed in during initialization + """ + all_valid_objects = { + obj_idx for bucket in self.assignments for obj_idx in bucket if obj_idx >= 0 + } + return all_valid_objects + + +class MultiplexController(nn.Module): + def __init__( + self, + multiplex_count: int, + full_shuffle: bool = False, + eval_multiplex_count: int = -1, + ): + super().__init__() + + self.multiplex_count = multiplex_count + self.full_shuffle = full_shuffle + if eval_multiplex_count < 0: + self.eval_multiplex_count = multiplex_count + else: + self.eval_multiplex_count = eval_multiplex_count + assert self.multiplex_count >= 1 + + @property + def allowed_bucket_capacity(self) -> int: + if self.training: + return self.multiplex_count + else: + return self.eval_multiplex_count + + def get_state( + self, + num_valid_entries: int, + device: torch.device, + dtype: torch.dtype, + random: bool = True, + *, + object_ids: Optional[ + list[int] + ] = None, # object_ids is an auxiliary field that we pass to the state unmodified + ) -> MultiplexState: + # returns a state that maps elements in the batch to buckets of size self.multiplex_count + + allowed_bucket_capacity = self.allowed_bucket_capacity + + # the size of the bucket during training + true_bucket_capacity = self.multiplex_count + + num_buckets = math.ceil(num_valid_entries / allowed_bucket_capacity) + # each bucket contains at most self.multiplex_count elements + # padding elements are marked with _PADDING_NUM (only the last bucket should contain _PADDING_NUM) + + if self.full_shuffle: + # Shuffle all IDs, including the paddings + ids = torch.cat( + [ + torch.arange(num_valid_entries, dtype=torch.long), + torch.tensor( + [_PADDING_NUM] + * (num_buckets * true_bucket_capacity - num_valid_entries), + dtype=torch.long, + ), + ], + dim=0, + ) + if random: + indices = torch.randperm(ids.shape[0], dtype=torch.long) + ids = ids[indices] + + # convert to a list of list + assignments = [] + for i in range(num_buckets): + assignments.append( + ids[ + i * true_bucket_capacity : (i + 1) * true_bucket_capacity + ].tolist() + ) + else: + # Only shuffle the the IDs within the first #batch_size slots, leave all paddings at the end + if random: + # randomly assign ids to buckets + ids = torch.randperm(num_valid_entries, dtype=torch.int64) + else: + ids = torch.arange(num_valid_entries) + # append with _PADDING_NUM to make a multiple of bucket_capacity + total_elements = num_buckets * allowed_bucket_capacity + if ids.shape[0] < total_elements: + ids = torch.cat( + [ + ids, + torch.tensor([_PADDING_NUM] * (total_elements - ids.shape[0])), + ] + ) + + # convert to a list of list + assignments = [] + for i in range(num_buckets): + assignments.append( + ids[ + i * allowed_bucket_capacity : (i + 1) * allowed_bucket_capacity + ].tolist() + + [_PADDING_NUM] * (true_bucket_capacity - allowed_bucket_capacity) + ) + + return MultiplexState( + assignments, + device, + dtype, + allowed_bucket_capacity=allowed_bucket_capacity, + object_ids=object_ids, + ) diff --git a/third_party/sam3/sam3/model/necks.py b/third_party/sam3/sam3/model/necks.py new file mode 100644 index 0000000000000000000000000000000000000000..6db174e32cb64747c15f5c3fb1d2ffd906fef81c --- /dev/null +++ b/third_party/sam3/sam3/model/necks.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Necks are the interface between a vision backbone and the rest of the detection model""" + +from copy import deepcopy +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from sam3.model.data_misc import NestedTensor + + +class Sam3DualViTDetNeck(nn.Module): + def __init__( + self, + trunk: nn.Module, + position_encoding: nn.Module, + d_model: int, + scale_factors=(4.0, 2.0, 1.0, 0.5), + add_sam2_neck: bool = False, + ): + """ + SimpleFPN neck a la ViTDet + (From detectron2, very lightly adapted) + It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2), with different weights + + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + """ + super().__init__() + self.trunk = trunk + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + + self.scale_factors = scale_factors + use_bias = True + dim: int = self.trunk.channel_list[-1] + + for _, scale in enumerate(scale_factors): + current = nn.Sequential() + + if scale == 4.0: + current.add_module( + "dconv_2x2_0", + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + ) + current.add_module( + "gelu", + nn.GELU(), + ) + current.add_module( + "dconv_2x2_1", + nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), + ) + out_dim = dim // 4 + elif scale == 2.0: + current.add_module( + "dconv_2x2", + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + ) + out_dim = dim // 2 + elif scale == 1.0: + out_dim = dim + elif scale == 0.5: + current.add_module( + "maxpool_2x2", + nn.MaxPool2d(kernel_size=2, stride=2), + ) + out_dim = dim + else: + raise NotImplementedError(f"scale_factor={scale} is not supported yet.") + + current.add_module( + "conv_1x1", + nn.Conv2d( + in_channels=out_dim, + out_channels=d_model, + kernel_size=1, + bias=use_bias, + ), + ) + current.add_module( + "conv_3x3", + nn.Conv2d( + in_channels=d_model, + out_channels=d_model, + kernel_size=3, + padding=1, + bias=use_bias, + ), + ) + self.convs.append(current) + + self.sam2_convs = None + if add_sam2_neck: + # Assumes sam2 neck is just a clone of the original neck + self.sam2_convs = deepcopy(self.convs) + + def forward( + self, tensor_list: List[torch.Tensor] + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + Optional[List[torch.Tensor]], + Optional[List[torch.Tensor]], + ]: + xs = self.trunk(tensor_list) + sam3_out, sam3_pos = [], [] + sam2_out, sam2_pos = None, None + if self.sam2_convs is not None: + sam2_out, sam2_pos = [], [] + x = xs[-1] # simpleFPN + for i in range(len(self.convs)): + sam3_x_out = self.convs[i](x) + sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype) + sam3_out.append(sam3_x_out) + sam3_pos.append(sam3_pos_out) + + if self.sam2_convs is not None: + sam2_x_out = self.sam2_convs[i](x) + sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype) + sam2_out.append(sam2_x_out) + sam2_pos.append(sam2_pos_out) + return sam3_out, sam3_pos, sam2_out, sam2_pos + + +class Sam3TriViTDetNeck(nn.Module): + def __init__( + self, + trunk: nn.Module, + position_encoding: nn.Module, + d_model: int, + neck_norm=None, + scale_factors=(4.0, 2.0, 1.0), + ): + """ + SimpleFPN neck with three heads (sam3, interactive, propagation). + """ + super().__init__() + self.trunk = trunk + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + + self.scale_factors = scale_factors + use_bias = neck_norm is None + dim = self.trunk.channel_list[-1] + + for _, scale in enumerate(scale_factors): + current = nn.Sequential() + + if scale == 4.0: + current.add_module( + "dconv_2x2_0", + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + ) + current.add_module( + "gelu", + nn.GELU(), + ) + current.add_module( + "dconv_2x2_1", + nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), + ) + out_dim = dim // 4 + elif scale == 2.0: + current.add_module( + "dconv_2x2", + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + ) + out_dim = dim // 2 + elif scale == 1.0: + out_dim = dim + elif scale == 0.5: + current.add_module( + "maxpool_2x2", + nn.MaxPool2d(kernel_size=2, stride=2), + ) + out_dim = dim + else: + raise NotImplementedError(f"scale_factor={scale} is not supported yet.") + + current.add_module( + "conv_1x1", + nn.Conv2d( + in_channels=out_dim, + out_channels=d_model, + kernel_size=1, + bias=use_bias, + ), + ) + current.add_module( + "conv_3x3", + nn.Conv2d( + in_channels=d_model, + out_channels=d_model, + kernel_size=3, + padding=1, + bias=use_bias, + ), + ) + self.convs.append(current) + + # Assumes the new necks are just clones of the original neck + self.interactive_convs = deepcopy(self.convs) + self.propagation_convs = deepcopy(self.convs) + + def forward( + self, + tensor_list, + *, + need_sam3_out: bool = True, + need_interactive_out: bool = True, + need_propagation_out: bool = True, + ): + xs = self.trunk(tensor_list) + sam3_out = [] + interactive_out = [] + propagation_out = [] + + sam3_pos = [] + interactive_pos = [] + propagation_pos = [] + x = xs[-1] # simpleFPN + # OSS trunk returns plain tensors; onevision trunk returns NestedTensors. + # Use getattr to handle both in a torch.compile-friendly way. + x_data = getattr(x, "tensors", x) + x_mask = getattr(x, "mask", None) + for _, (conv, interactive_conv, propagation_conv) in enumerate( + zip(self.convs, self.interactive_convs, self.propagation_convs) + ): + if need_sam3_out: + sam3_conv_out = conv(x_data) + sam3_x_out = NestedTensor(sam3_conv_out, x_mask) + sam3_out.append(sam3_x_out) + sam3_pos.append( + self.position_encoding(sam3_conv_out).to(sam3_conv_out.dtype) + ) + + if need_interactive_out: + interactive_conv_out_t = interactive_conv(x_data) + interactive_conv_out = NestedTensor(interactive_conv_out_t, x_mask) + interactive_out.append(interactive_conv_out) + interactive_pos.append( + self.position_encoding(interactive_conv_out_t).to( + interactive_conv_out_t.dtype + ) + ) + + if need_propagation_out: + propagation_conv_out = propagation_conv(x_data) + propagation_x_out = NestedTensor(propagation_conv_out, x_mask) + propagation_out.append(propagation_x_out) + propagation_pos.append( + self.position_encoding(propagation_conv_out).to( + propagation_conv_out.dtype + ) + ) + + return ( + sam3_out, + sam3_pos, + interactive_out, + interactive_pos, + propagation_out, + propagation_pos, + ) diff --git a/third_party/sam3/sam3/model/position_encoding.py b/third_party/sam3/sam3/model/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..279b2c113bdb649f9de85cb8b47b3c6452a02779 --- /dev/null +++ b/third_party/sam3/sam3/model/position_encoding.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import math +from typing import Optional + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + precompute_resolution: Optional[int] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + # Precompute positional encodings under `precompute_resolution` to fill the cache + # and avoid symbolic shape tracing errors in torch.compile in PyTorch 2.4 nightly. + if precompute_resolution is not None: + # We precompute pos enc for all strides used by both DualViTDetNeck and + # TriViTDetNeck (scale_factors 4.0, 2.0, 1.0, 0.5 applied to backbone + # output at stride 14 from 1008px input → 72x72). + precompute_sizes = [ + (int(precompute_resolution // 3.5), int(precompute_resolution // 3.5)), + (precompute_resolution // 4, precompute_resolution // 4), + (int(precompute_resolution // 7), int(precompute_resolution // 7)), + (precompute_resolution // 8, precompute_resolution // 8), + (int(precompute_resolution // 14), int(precompute_resolution // 14)), + (precompute_resolution // 16, precompute_resolution // 16), + (int(precompute_resolution // 28), int(precompute_resolution // 28)), + (precompute_resolution // 32, precompute_resolution // 32), + ] + _device = "cuda" if torch.cuda.is_available() else "cpu" + for size in precompute_sizes: + tensors = torch.zeros((1, 1) + size, device=_device) + self.forward(tensors) + # further clone and detach it in the cache (just to be safe) + self.cache[size] = self.cache[size].clone().detach() + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x): + cache_key = None + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + if cache_key is not None: + self.cache[cache_key] = pos[0] + return pos diff --git a/third_party/sam3/sam3/model/sam1_task_predictor.py b/third_party/sam3/sam3/model/sam1_task_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..edcbf87ebc9f68b5ffaf0f9e3ed0b956c07980c1 --- /dev/null +++ b/third_party/sam3/sam3/model/sam1_task_predictor.py @@ -0,0 +1,457 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved +# All rights reserved. + +# pyre-unsafe + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from PIL.Image import Image +from sam3.model.sam3_tracker_base import Sam3TrackerBase +from sam3.model.utils.sam1_utils import SAM2Transforms + + +# Adapted from https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py +class SAM3InteractiveImagePredictor(nn.Module): + def __init__( + self, + sam_model: Sam3TrackerBase, + mask_threshold=0.0, + max_hole_area=256.0, + max_sprinkle_area=0.0, + **kwargs, + ) -> None: + """ + Uses SAM-3 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model : The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + max_hole_area (int): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (288, 288), + (144, 144), + (72, 72), + ] + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + ( + _, + vision_feats, + _, + _, + ) = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + ( + _, + vision_feats, + _, + _, + ) = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tuple of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." + ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/third_party/sam3/sam3/model/sam3_base_predictor.py b/third_party/sam3/sam3/model/sam3_base_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..a02d2137153df24fd792801860df8a815c5d4c5c --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_base_predictor.py @@ -0,0 +1,322 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +Base predictor class shared by SAM3 and SAM3.1 (multiplex) video predictors. + +Provides the common handle_request/handle_stream_request API and session management. +Subclasses only need to override methods where their behavior differs. +""" + +import gc +import time +import uuid +from typing import Dict, List, Optional + +import torch +from sam3.logger import get_logger + +logger = get_logger(__name__) + + +class Sam3BasePredictor: + """ + Base class for SAM3 video predictors. Provides: + - Session management (start, reset, close) + - Request dispatch (handle_request / handle_stream_request) + - Common add_prompt / propagate_in_video / remove_object / reset_session / close_session + + Subclasses must set `self.model` and `self._all_inference_states` before use. + """ + + def __init__(self): + # Subclasses must populate these + self.model = None + self._all_inference_states: Dict[str, dict] = {} + + # ── Request dispatch ────────────────────────────────────────────── + + @torch.inference_mode() + def handle_request(self, request): + """Dispatch a request based on its type.""" + request_type = request["type"] + if request_type == "start_session": + return self.start_session( + resource_path=request["resource_path"], + session_id=request.get("session_id", None), + offload_video_to_cpu=request.get("offload_video_to_cpu", False), + ) + elif request_type == "add_prompt": + return self.add_prompt( + session_id=request["session_id"], + frame_idx=request["frame_index"], + text=request.get("text", None), + points=request.get("points", None), + point_labels=request.get("point_labels", None), + clear_old_points=request.get("clear_old_points", True), + bounding_boxes=request.get("bounding_boxes", None), + bounding_box_labels=request.get("bounding_box_labels", None), + clear_old_boxes=request.get("clear_old_boxes", True), + output_prob_thresh=request.get( + "output_prob_thresh", + getattr(self, "default_output_prob_thresh", 0.5), + ), + obj_id=request.get("obj_id", None), + ) + elif request_type == "remove_object": + return self.remove_object( + session_id=request["session_id"], + frame_idx=request.get("frame_index", 0), + obj_id=request["obj_id"], + ) + elif request_type == "reset_session": + return self.reset_session(session_id=request["session_id"]) + elif request_type == "cancel_propagation": + return self.cancel_propagation(session_id=request["session_id"]) + elif request_type == "close_session": + return self.close_session( + session_id=request["session_id"], + run_gc_collect=request.get("run_gc_collect", True), + ) + else: + raise RuntimeError(f"invalid request type: {request_type}") + + @torch.inference_mode() + def handle_stream_request(self, request): + """Dispatch a stream request based on its type.""" + request_type = request["type"] + if request_type == "propagate_in_video": + yield from self.propagate_in_video( + session_id=request["session_id"], + propagation_direction=request.get("propagation_direction", "both"), + start_frame_idx=request.get("start_frame_index", None), + max_frame_num_to_track=request.get("max_frame_num_to_track", None), + output_prob_thresh=request.get( + "output_prob_thresh", + getattr(self, "default_output_prob_thresh", 0.5), + ), + ) + else: + raise RuntimeError(f"invalid request type: {request_type}") + + # ── Session management ──────────────────────────────────────────── + + def start_session( + self, + resource_path, + session_id=None, + offload_video_to_cpu=False, + ): + """Start a new inference session on a video directory or path.""" + init_kwargs = dict( + resource_path=resource_path, + offload_video_to_cpu=offload_video_to_cpu, + ) + if hasattr(self, "async_loading_frames"): + init_kwargs["async_loading_frames"] = self.async_loading_frames + if hasattr(self, "video_loader_type"): + init_kwargs["video_loader_type"] = self.video_loader_type + inference_state = self.model.init_state(**init_kwargs) + + if not session_id: + session_id = str(uuid.uuid4()) + self._all_inference_states[session_id] = { + "state": inference_state, + "session_id": session_id, + "start_time": time.time(), + "last_use_time": time.time(), + } + logger.info(f"started new session {session_id}") + return {"session_id": session_id} + + def add_prompt( + self, + session_id: str, + frame_idx: int, + text: Optional[str] = None, + points=None, + point_labels=None, + clear_old_points: bool = True, + bounding_boxes=None, + bounding_box_labels=None, + clear_old_boxes: bool = True, + output_prob_thresh: float = 0.5, + obj_id: Optional[int] = None, + ): + """Add text, box and/or point prompt on a specific video frame.""" + session = self._get_session(session_id) + inference_state = session["state"] + self._extend_expiration_time(session) + + # Convert lists to tensors if needed + if points is not None and not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if point_labels is not None and not isinstance(point_labels, torch.Tensor): + point_labels = torch.tensor(point_labels, dtype=torch.int32) + if bounding_boxes is not None and not isinstance(bounding_boxes, torch.Tensor): + bounding_boxes = torch.tensor(bounding_boxes, dtype=torch.float32) + if bounding_box_labels is not None and not isinstance( + bounding_box_labels, torch.Tensor + ): + bounding_box_labels = torch.tensor(bounding_box_labels, dtype=torch.int32) + + kwargs = dict( + inference_state=inference_state, + frame_idx=frame_idx, + text_str=text, + points=points, + point_labels=point_labels, + clear_old_points=clear_old_points, + boxes_xywh=bounding_boxes, + box_labels=bounding_box_labels, + clear_old_boxes=clear_old_boxes, + output_prob_thresh=output_prob_thresh, + ) + if obj_id is not None: + kwargs["obj_id"] = obj_id + + # Filter kwargs to only pass what the model accepts + # (SAM3 has a simpler add_prompt than SAM3.1) + import inspect + + sig = inspect.signature(self.model.add_prompt) + valid_params = set(sig.parameters.keys()) + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + + frame_idx, outputs = self.model.add_prompt(**filtered_kwargs) + return {"frame_index": frame_idx, "outputs": outputs} + + def remove_object( + self, + session_id: str, + frame_idx: int = 0, + obj_id: int = 0, + is_user_action: bool = True, + ): + """Remove an object from tracking.""" + session = self._get_session(session_id) + inference_state = session["state"] + self._extend_expiration_time(session) + + result = self.model.remove_object( + inference_state, obj_id, frame_idx=frame_idx, is_user_action=is_user_action + ) + # Handle both return conventions + if result is None or (isinstance(result, tuple) and result[1] is None): + import numpy as np + + out_obj_ids = torch.zeros(0, dtype=torch.int64) + out_binary_masks = torch.zeros( + 0, + inference_state["orig_height"], + inference_state["orig_width"], + dtype=torch.bool, + ) + out_boxes_xywh = torch.zeros(0, 4, dtype=torch.float32) + outputs = { + "out_obj_ids": out_obj_ids.cpu().numpy(), + "out_boxes_xywh": out_boxes_xywh.cpu().numpy(), + "out_binary_masks": out_binary_masks.cpu().numpy(), + } + elif isinstance(result, tuple): + _, outputs = result + else: + outputs = result + return {"frame_index": frame_idx, "outputs": outputs} + + def cancel_propagation(self, session_id): + """Cancel any ongoing propagation. No-op if not supported by the model.""" + session = self._get_session(session_id) + inference_state = session["state"] + self._extend_expiration_time(session) + if hasattr(self.model, "cancel_propagation"): + self.model.cancel_propagation(inference_state) + return {"is_success": True} + + def propagate_in_video( + self, + session_id, + propagation_direction="both", + start_frame_idx=None, + max_frame_num_to_track=None, + output_prob_thresh=0.5, + **kwargs, + ): + """Propagate the added prompts to get results on all video frames.""" + try: + session = self._get_session(session_id) + inference_state = session["state"] + self._extend_expiration_time(session) + if propagation_direction not in ["both", "forward", "backward"]: + raise ValueError( + f"invalid propagation direction: {propagation_direction}" + ) + + propagate_kwargs = dict( + inference_state=inference_state, + start_frame_idx=start_frame_idx, + max_frame_num_to_track=max_frame_num_to_track, + ) + # Only pass output_prob_thresh / extra kwargs if the model supports them + import inspect + + sig = inspect.signature(self.model.propagate_in_video) + if "output_prob_thresh" in sig.parameters: + propagate_kwargs["output_prob_thresh"] = output_prob_thresh + for k, v in kwargs.items(): + if k in sig.parameters: + propagate_kwargs[k] = v + + # Forward propagation + if propagation_direction in ["both", "forward"]: + for frame_idx, outputs in self.model.propagate_in_video( + **propagate_kwargs, + reverse=False, + ): + yield {"frame_index": frame_idx, "outputs": outputs} + # Backward propagation + if propagation_direction in ["both", "backward"]: + for frame_idx, outputs in self.model.propagate_in_video( + **propagate_kwargs, + reverse=True, + ): + yield {"frame_index": frame_idx, "outputs": outputs} + finally: + logger.info(f"propagation ended in session {session_id}") + + def reset_session(self, session_id): + """Reset the session to its initial state.""" + session = self._get_session(session_id) + inference_state = session["state"] + self._extend_expiration_time(session) + self.model.reset_state(inference_state) + return {"is_success": True} + + def close_session(self, session_id, run_gc_collect=True): + """Close a session. Idempotent.""" + session = self._all_inference_states.pop(session_id, None) + if session is None: + logger.warning(f"cannot close session {session_id} as it does not exist") + else: + del session + if run_gc_collect: + gc.collect() + logger.info(f"removed session {session_id}") + return {"is_success": True} + + def _get_session(self, session_id): + session = self._all_inference_states.get(session_id, None) + if session is None: + raise RuntimeError( + f"Cannot find session {session_id}; it might have expired" + ) + return session + + def _extend_expiration_time(self, session): + """Update last-use time for session expiration tracking.""" + session["last_use_time"] = time.time() + + def shutdown(self): + """Shutdown the predictor and clear all sessions.""" + self._all_inference_states.clear() diff --git a/third_party/sam3/sam3/model/sam3_image.py b/third_party/sam3/sam3/model/sam3_image.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd5c00d6c73e5c332926dc84d416c42ba4c140d --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_image.py @@ -0,0 +1,911 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import os +from copy import deepcopy +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from sam3.model.model_misc import SAM3Output +from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor +from sam3.model.vl_combiner import SAM3VLBackbone +from sam3.perflib.nms import nms_masks +from sam3.train.data.collator import BatchedDatapoint + +from .act_ckpt_utils import activation_ckpt_wrapper +from .box_ops import box_cxcywh_to_xyxy +from .data_misc import FindStage +from .geometry_encoders import Prompt +from .model_misc import inverse_sigmoid + + +def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True): + out[out_name] = out_value[-1] if auxiliary else out_value + if auxiliary and update_aux: + if "aux_outputs" not in out: + out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)] + assert len(out["aux_outputs"]) == len(out_value) - 1 + for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]): + aux_output[out_name] = aux_value + + +class Sam3Image(torch.nn.Module): + TEXT_ID_FOR_TEXT = 0 + TEXT_ID_FOR_VISUAL = 1 + TEXT_ID_FOR_GEOMETRIC = 2 + + def __init__( + self, + backbone: SAM3VLBackbone, + transformer, + input_geometry_encoder, + segmentation_head=None, + num_feature_levels=1, + o2m_mask_predict=True, + dot_prod_scoring=None, + use_instance_query: bool = True, + multimask_output: bool = True, + use_act_checkpoint_seg_head: bool = True, + interactivity_in_encoder: bool = True, + matcher=None, + use_dot_prod_scoring=True, + supervise_joint_box_scores: bool = False, # only relevant if using presence token/score + detach_presence_in_joint_score: bool = False, # only relevant if using presence token/score + separate_scorer_for_instance: bool = False, + num_interactive_steps_val: int = 0, + inst_interactive_predictor: SAM3InteractiveImagePredictor = None, + **kwargs, + ): + super().__init__() + self.backbone = backbone + self.geometry_encoder = input_geometry_encoder + self.transformer = transformer + self.hidden_dim = transformer.d_model + self.num_feature_levels = num_feature_levels + self.segmentation_head = segmentation_head + + self.o2m_mask_predict = o2m_mask_predict + + self.dot_prod_scoring = dot_prod_scoring + self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head + self.interactivity_in_encoder = interactivity_in_encoder + self.matcher = matcher + + self.num_interactive_steps_val = num_interactive_steps_val + self.use_dot_prod_scoring = use_dot_prod_scoring + + if self.use_dot_prod_scoring: + assert dot_prod_scoring is not None + self.dot_prod_scoring = dot_prod_scoring + self.instance_dot_prod_scoring = None + if separate_scorer_for_instance: + self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring) + else: + self.class_embed = torch.nn.Linear(self.hidden_dim, 1) + self.instance_class_embed = None + if separate_scorer_for_instance: + self.instance_class_embed = deepcopy(self.class_embed) + + self.supervise_joint_box_scores = supervise_joint_box_scores + self.detach_presence_in_joint_score = detach_presence_in_joint_score + + # verify the number of queries for O2O and O2M + num_o2o_static = self.transformer.decoder.num_queries + num_o2m_static = self.transformer.decoder.num_o2m_queries + assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0) + self.dac = self.transformer.decoder.dac + + self.use_instance_query = use_instance_query + self.multimask_output = multimask_output + + self.inst_interactive_predictor = inst_interactive_predictor + + @property + def device(self): + self._device = getattr(self, "_device", None) or next(self.parameters()).device + return self._device + + def to(self, *args, **kwargs): + # clear cached _device in case the model is moved to a different device + self._device = None + return super().to(*args, **kwargs) + + def _get_img_feats(self, backbone_out, img_ids): + """Retrieve correct image features from backbone output.""" + if "backbone_fpn" in backbone_out: + if "id_mapping" in backbone_out and backbone_out["id_mapping"] is not None: + img_ids = backbone_out["id_mapping"][img_ids] + # If this assert fails, it likely means we're requesting different img_ids (perhaps a different frame?) + # We currently don't expect this to happen. We could technically trigger a recompute here, + # but likely at the cost of a cpu<->gpu sync point, which would deteriorate perf + torch._assert_async((img_ids >= 0).all()) + + vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + vis_feat_sizes = [x.shape[-2:] for x in vis_pos_enc] # (H, W) shapes + # index and flatten visual features NxCxHxW => HWxNxC (batch-first => seq-first) + img_feats = [x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_feats] + img_pos_embeds = [ + x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_pos_enc + ] + return backbone_out, img_feats, img_pos_embeds, vis_feat_sizes + + # Image features not available in backbone output, so we compute them on the fly + # This case likely occurs for video. In that case, we want to forward only the current frame + img_batch = backbone_out["img_batch_all_stages"] + if img_ids.numel() > 1: + # Only forward backbone on unique image ids to avoid repetitive computation + unique_ids, _ = torch.unique(img_ids, return_inverse=True) + else: + unique_ids, _ = img_ids, slice(None) + # Compute the image features on those unique image ids + # note: we allow using a list (or other indexable types) of tensors as img_batch + # (e.g. for async frame loading in demo). In this case we index img_batch.tensors directly + if isinstance(img_batch, torch.Tensor): + image = img_batch[unique_ids] + elif unique_ids.numel() == 1: + image = img_batch[unique_ids.item()].unsqueeze(0) + else: + image = torch.stack([img_batch[i] for i in unique_ids.tolist()]) + # `img_batch` might be fp16 and offloaded to CPU + image = image.to(dtype=torch.float32, device=self.device) + # Next time we call this function, we want to remember which indices we computed + id_mapping = torch.full( + (len(img_batch),), -1, dtype=torch.long, device=self.device + ) + id_mapping[unique_ids] = torch.arange(len(unique_ids), device=self.device) + backbone_out = { + **backbone_out, + **self.backbone.forward_image(image), + "id_mapping": id_mapping, + } + assert "backbone_fpn" in backbone_out + return self._get_img_feats(backbone_out, img_ids=img_ids) + + def _encode_prompt( + self, + backbone_out, + find_input, + geometric_prompt, + visual_prompt_embed=None, + visual_prompt_mask=None, + encode_text=True, + prev_mask_pred=None, + ): + # index text features (note that regardless of early or late fusion, the batch size of + # `txt_feats` is always the number of *prompts* in the encoder) + txt_ids = find_input.text_ids + txt_feats = backbone_out["language_features"][:, txt_ids] + txt_masks = backbone_out["language_mask"][txt_ids] + + feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids) + backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple + + if prev_mask_pred is not None: + img_feats = [img_feats[-1] + prev_mask_pred] + # Encode geometry + geo_feats, geo_masks = self.geometry_encoder( + geo_prompt=geometric_prompt, + img_feats=img_feats, + img_sizes=vis_feat_sizes, + img_pos_embeds=img_pos_embeds, + ) + if visual_prompt_embed is None: + visual_prompt_embed = torch.zeros( + (0, *geo_feats.shape[1:]), device=geo_feats.device + ) + visual_prompt_mask = torch.zeros( + (*geo_masks.shape[:-1], 0), + device=geo_masks.device, + dtype=geo_masks.dtype, + ) + if encode_text: + prompt = torch.cat([txt_feats, geo_feats, visual_prompt_embed], dim=0) + prompt_mask = torch.cat([txt_masks, geo_masks, visual_prompt_mask], dim=1) + else: + prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0) + prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1) + return prompt, prompt_mask, backbone_out + + def _run_encoder( + self, + backbone_out, + find_input, + prompt, + prompt_mask, + encoder_extra_kwargs: Optional[Dict] = None, + ): + feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids) + backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple + + # Run the encoder + prompt_pos_embed = torch.zeros_like(prompt) + # make a copy of the image feature lists since the encoder may modify these lists in-place + memory = self.transformer.encoder( + src=img_feats.copy(), + src_key_padding_mask=None, + src_pos=img_pos_embeds.copy(), + prompt=prompt, + prompt_pos=prompt_pos_embed, + prompt_key_padding_mask=prompt_mask, + feat_sizes=vis_feat_sizes, + encoder_extra_kwargs=encoder_extra_kwargs, + ) + encoder_out = { + # encoded image features + "encoder_hidden_states": memory["memory"], + "pos_embed": memory["pos_embed"], + "padding_mask": memory["padding_mask"], + "level_start_index": memory["level_start_index"], + "spatial_shapes": memory["spatial_shapes"], + "valid_ratios": memory["valid_ratios"], + "vis_feat_sizes": vis_feat_sizes, + # encoded text features (or other prompts) + "prompt_before_enc": prompt, + "prompt_after_enc": memory.get("memory_text", prompt), + "prompt_mask": prompt_mask, + } + return backbone_out, encoder_out, feat_tuple + + def _run_decoder( + self, + pos_embed, + memory, + src_mask, + out, + prompt, + prompt_mask, + encoder_out, + ): + bs = memory.shape[1] + query_embed = self.transformer.decoder.query_embed.weight + tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) + + apply_dac = self.transformer.decoder.dac and self.training + hs, reference_boxes, dec_presence_out, dec_presence_feats = ( + self.transformer.decoder( + tgt=tgt, + memory=memory, + memory_key_padding_mask=src_mask, + pos=pos_embed, + reference_boxes=None, + level_start_index=encoder_out["level_start_index"], + spatial_shapes=encoder_out["spatial_shapes"], + valid_ratios=encoder_out["valid_ratios"], + tgt_mask=None, + memory_text=prompt, + text_attention_mask=prompt_mask, + apply_dac=apply_dac, + ) + ) + hs = hs.transpose(1, 2) # seq-first to batch-first + reference_boxes = reference_boxes.transpose(1, 2) # seq-first to batch-first + if dec_presence_out is not None: + # seq-first to batch-first + dec_presence_out = dec_presence_out.transpose(1, 2) + + out["presence_feats"] = dec_presence_feats + self._update_scores_and_boxes( + out, + hs, + reference_boxes, + prompt, + prompt_mask, + dec_presence_out=dec_presence_out, + ) + return out, hs + + def _update_scores_and_boxes( + self, + out, + hs, + reference_boxes, + prompt, + prompt_mask, + dec_presence_out=None, + is_instance_prompt=False, + ): + apply_dac = self.transformer.decoder.dac and self.training + num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2) + num_o2m = hs.size(2) - num_o2o + assert num_o2m == (num_o2o if apply_dac else 0) + out["queries"] = hs[-1][:, :num_o2o] # remove o2m queries if there are any + # Add O2M queries for 3D head (SAM3_3D extension) + if num_o2m > 0 and self.training: + out["queries_o2m"] = hs[-1][:, num_o2o:] + # score prediction + if self.use_dot_prod_scoring: + dot_prod_scoring_head = self.dot_prod_scoring + if is_instance_prompt and self.instance_dot_prod_scoring is not None: + dot_prod_scoring_head = self.instance_dot_prod_scoring + outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask) + else: + class_embed_head = self.class_embed + if is_instance_prompt and self.instance_class_embed is not None: + class_embed_head = self.instance_class_embed + outputs_class = class_embed_head(hs) + + # box prediction + box_head = self.transformer.decoder.bbox_embed + if ( + is_instance_prompt + and self.transformer.decoder.instance_bbox_embed is not None + ): + box_head = self.transformer.decoder.instance_bbox_embed + anchor_box_offsets = box_head(hs) + reference_boxes_inv_sig = inverse_sigmoid(reference_boxes) + outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid() + outputs_boxes_xyxy = box_cxcywh_to_xyxy(outputs_coord) + + if dec_presence_out is not None: + _update_out( + out, "presence_logit_dec", dec_presence_out, update_aux=self.training + ) + + if self.supervise_joint_box_scores: + assert dec_presence_out is not None + prob_dec_presence_out = dec_presence_out.clone().sigmoid() + if self.detach_presence_in_joint_score: + prob_dec_presence_out = prob_dec_presence_out.detach() + + outputs_class = inverse_sigmoid( + outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2) + ).clamp(min=-10.0, max=10.0) + + _update_out( + out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=self.training + ) + _update_out( + out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=self.training + ) + _update_out( + out, + "pred_boxes_xyxy", + outputs_boxes_xyxy[:, :, :num_o2o], + update_aux=self.training, + ) + if num_o2m > 0 and self.training: + _update_out( + out, + "pred_logits_o2m", + outputs_class[:, :, num_o2o:], + update_aux=self.training, + ) + _update_out( + out, + "pred_boxes_o2m", + outputs_coord[:, :, num_o2o:], + update_aux=self.training, + ) + _update_out( + out, + "pred_boxes_xyxy_o2m", + outputs_boxes_xyxy[:, :, num_o2o:], + update_aux=self.training, + ) + + def _run_segmentation_heads( + self, + out, + backbone_out, + img_ids, + vis_feat_sizes, + encoder_hidden_states, + prompt, + prompt_mask, + hs, + ): + apply_dac = self.transformer.decoder.dac and self.training + if self.segmentation_head is not None: + num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2) + num_o2m = hs.size(2) - num_o2o + obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o] + seg_head_outputs = activation_ckpt_wrapper(self.segmentation_head)( + backbone_feats=backbone_out["backbone_fpn"], + obj_queries=obj_queries, + image_ids=img_ids, + encoder_hidden_states=encoder_hidden_states, + act_ckpt_enable=self.training and self.use_act_checkpoint_seg_head, + prompt=prompt, + prompt_mask=prompt_mask, + ) + aux_masks = False # self.aux_loss and self.segmentation_head.aux_masks + for k, v in seg_head_outputs.items(): + if k in self.segmentation_head.instance_keys: + _update_out(out, k, v[:, :num_o2o], auxiliary=aux_masks) + if ( + self.o2m_mask_predict and num_o2m > 0 + ): # handle o2m mask prediction + _update_out( + out, f"{k}_o2m", v[:, num_o2o:], auxiliary=aux_masks + ) + else: + out[k] = v + else: + backbone_out.pop("backbone_fpn", None) + + def _get_best_mask(self, out): + prev_mask_idx = out["pred_logits"].argmax(dim=1).squeeze(1) + batch_idx = torch.arange( + out["pred_logits"].shape[0], device=prev_mask_idx.device + ) + prev_mask_pred = out["pred_masks"][batch_idx, prev_mask_idx][:, None] + # Downsample mask to match image resolution. + prev_mask_pred = self.geometry_encoder.mask_encoder.mask_downsampler( + prev_mask_pred + ) + prev_mask_pred = prev_mask_pred.flatten(-2).permute(2, 0, 1) + + return prev_mask_pred + + def forward_grounding( + self, + backbone_out, + find_input, + find_target, + geometric_prompt: Prompt, + **kwargs, + ): + with torch.profiler.record_function("SAM3Image._encode_prompt"): + prompt, prompt_mask, backbone_out = self._encode_prompt( + backbone_out, find_input, geometric_prompt + ) + # Run the encoder + with torch.profiler.record_function("SAM3Image._run_encoder"): + backbone_out, encoder_out, _ = self._run_encoder( + backbone_out, find_input, prompt, prompt_mask + ) + out = { + "encoder_hidden_states": encoder_out["encoder_hidden_states"], + "prev_encoder_out": { + "encoder_out": encoder_out, + "backbone_out": backbone_out, + }, + } + + # Run the decoder + with torch.profiler.record_function("SAM3Image._run_decoder"): + out, hs = self._run_decoder( + memory=out["encoder_hidden_states"], + pos_embed=encoder_out["pos_embed"], + src_mask=encoder_out["padding_mask"], + out=out, + prompt=prompt, + prompt_mask=prompt_mask, + encoder_out=encoder_out, + ) + + # Run segmentation heads + with torch.profiler.record_function("SAM3Image._run_segmentation_heads"): + # Apply id_mapping to img_ids if backbone features were recomputed + seg_img_ids = find_input.img_ids + if "id_mapping" in backbone_out and backbone_out["id_mapping"] is not None: + seg_img_ids = backbone_out["id_mapping"][seg_img_ids] + self._run_segmentation_heads( + out=out, + backbone_out=backbone_out, + img_ids=seg_img_ids, + vis_feat_sizes=encoder_out["vis_feat_sizes"], + encoder_hidden_states=out["encoder_hidden_states"], + prompt=prompt, + prompt_mask=prompt_mask, + hs=hs, + ) + + if self.training or self.num_interactive_steps_val > 0: + self._compute_matching(out, self.back_convert(find_target)) + return out + + def _postprocess_out(self, out: Dict, multimask_output: bool = False): + # For multimask output, during eval we return the single best mask with the dict keys expected by the evaluators, but also return the multimasks output with new keys. + num_mask_boxes = out["pred_boxes"].size(1) + if not self.training and multimask_output and num_mask_boxes > 1: + out["multi_pred_logits"] = out["pred_logits"] + if "pred_masks" in out: + out["multi_pred_masks"] = out["pred_masks"] + out["multi_pred_boxes"] = out["pred_boxes"] + out["multi_pred_boxes_xyxy"] = out["pred_boxes_xyxy"] + + best_mask_idx = out["pred_logits"].argmax(1).squeeze(1) + batch_idx = torch.arange(len(best_mask_idx), device=best_mask_idx.device) + + out["pred_logits"] = out["pred_logits"][batch_idx, best_mask_idx].unsqueeze( + 1 + ) + if "pred_masks" in out: + out["pred_masks"] = out["pred_masks"][ + batch_idx, best_mask_idx + ].unsqueeze(1) + out["pred_boxes"] = out["pred_boxes"][batch_idx, best_mask_idx].unsqueeze(1) + out["pred_boxes_xyxy"] = out["pred_boxes_xyxy"][ + batch_idx, best_mask_idx + ].unsqueeze(1) + + return out + + def _get_geo_prompt_from_find_input(self, find_input: FindStage): + """Construct an initial geometric prompt from the find input.""" + point_embeddings, point_mask, point_labels = None, None, None + if find_input.input_points_before_embed is not None: + # Point embeddings are batch first, switch to seq first + point_embeddings = find_input.input_points_before_embed.transpose(0, 1) + + # they are stored as (x,y,label), so we unpack + point_labels = point_embeddings[..., -1] + point_embeddings = point_embeddings[..., :-1] + point_mask = find_input.input_points_mask + + geometric_prompt = Prompt( + box_embeddings=find_input.input_boxes_before_embed, + box_mask=find_input.input_boxes_mask, + box_labels=find_input.input_boxes_label, + point_embeddings=point_embeddings, + point_mask=point_mask, + point_labels=point_labels, + ) + return geometric_prompt + + def _get_dummy_prompt(self, num_prompts=1): + device = self.device + geometric_prompt = Prompt( + box_embeddings=torch.zeros(0, num_prompts, 4, device=device), + box_mask=torch.zeros(num_prompts, 0, device=device, dtype=torch.bool), + ) + return geometric_prompt + + def forward(self, input: BatchedDatapoint): + device = self.device + backbone_out = {"img_batch_all_stages": input.img_batch} + backbone_out.update(self.backbone.forward_image(input.img_batch)) + num_frames = len(input.find_inputs) + assert num_frames == 1 + + text_outputs = self.backbone.forward_text(input.find_text_batch, device=device) + backbone_out.update(text_outputs) + + previous_stages_out = SAM3Output( + iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE + ) + + find_input = input.find_inputs[0] + find_target = input.find_targets[0] + + if find_input.input_points is not None and find_input.input_points.numel() > 0: + print("Warning: Point prompts are ignored in PCS.") + + num_interactive_steps = 0 if self.training else self.num_interactive_steps_val + geometric_prompt = Prompt( + box_embeddings=find_input.input_boxes, + box_mask=find_input.input_boxes_mask, + box_labels=find_input.input_boxes_label, + ) + + # Init vars that are shared across the loop. + stage_outs = [] + for cur_step in range(num_interactive_steps + 1): + if cur_step > 0: + # We sample interactive geometric prompts (boxes, points) + geometric_prompt, _ = self.interactive_prompt_sampler.sample( + geo_prompt=geometric_prompt, + find_target=find_target, + previous_out=stage_outs[-1], + ) + out = self.forward_grounding( + backbone_out=backbone_out, + find_input=find_input, + find_target=find_target, + geometric_prompt=geometric_prompt.clone(), + ) + stage_outs.append(out) + + previous_stages_out.append(stage_outs) + return previous_stages_out + + def _compute_matching(self, out, targets): + out["indices"] = self.matcher(out, targets) + for aux_out in out.get("aux_outputs", []): + aux_out["indices"] = self.matcher(aux_out, targets) + + def back_convert(self, targets): + batched_targets = { + "boxes": targets.boxes.view(-1, 4), + "boxes_xyxy": box_cxcywh_to_xyxy(targets.boxes.view(-1, 4)), + "boxes_padded": targets.boxes_padded, + "positive_map": targets.boxes.new_ones(len(targets.boxes), 1), + "num_boxes": targets.num_boxes, + "masks": targets.segments, + "semantic_masks": targets.semantic_segments, + "is_valid_mask": targets.is_valid_segment, + "is_exhaustive": targets.is_exhaustive, + "object_ids_packed": targets.object_ids, + "object_ids_padded": targets.object_ids_padded, + } + return batched_targets + + def predict_inst( + self, + inference_state, + **kwargs, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + orig_h, orig_w = ( + inference_state["original_height"], + inference_state["original_width"], + ) + backbone_out = inference_state["backbone_out"]["sam2_backbone_out"] + ( + _, + vision_feats, + _, + _, + ) = self.inst_interactive_predictor.model._prepare_backbone_features( + backbone_out + ) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + vision_feats[-1] = ( + vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed + ) + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip( + vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1] + ) + ][::-1] + self.inst_interactive_predictor._features = { + "image_embed": feats[-1], + "high_res_feats": feats[:-1], + } + self.inst_interactive_predictor._is_image_set = True + self.inst_interactive_predictor._orig_hw = [(orig_h, orig_w)] + res = self.inst_interactive_predictor.predict(**kwargs) + self.inst_interactive_predictor._features = None + self.inst_interactive_predictor._is_image_set = False + return res + + def predict_inst_batch( + self, + inference_state, + *args, + **kwargs, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + backbone_out = inference_state["backbone_out"]["sam2_backbone_out"] + ( + _, + vision_feats, + _, + _, + ) = self.inst_interactive_predictor.model._prepare_backbone_features( + backbone_out + ) + # Add no_mem_embed, which is added to the lowest res feat. map during training on videos + vision_feats[-1] = ( + vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed + ) + batch_size = vision_feats[-1].shape[1] + orig_heights, orig_widths = ( + inference_state["original_heights"], + inference_state["original_widths"], + ) + assert ( + batch_size == len(orig_heights) == len(orig_widths) + ), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}" + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip( + vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1] + ) + ][::-1] + self.inst_interactive_predictor._features = { + "image_embed": feats[-1], + "high_res_feats": feats[:-1], + } + self.inst_interactive_predictor._is_image_set = True + self.inst_interactive_predictor._is_batch = True + self.inst_interactive_predictor._orig_hw = [ + (orig_h, orig_w) for orig_h, orig_w in zip(orig_heights, orig_widths) + ] + res = self.inst_interactive_predictor.predict_batch(*args, **kwargs) + self.inst_interactive_predictor._features = None + self.inst_interactive_predictor._is_image_set = False + self.inst_interactive_predictor._is_batch = False + return res + + +class Sam3ImageOnVideoMultiGPU(Sam3Image): + def __init__( + self, *args, async_all_gather=True, gather_backbone_out=None, **kwargs + ): + super().__init__(*args, **kwargs) + self.rank = int(os.getenv("RANK", "0")) + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.async_all_gather = async_all_gather + + # if gather_backbone is not set, default to gathering only for `SAM3VLBackbone` + if gather_backbone_out is None: + gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone) + self.gather_backbone_out = gather_backbone_out + + def forward_video_grounding_multigpu( + self, + backbone_out, + find_inputs, + geometric_prompt: Prompt, + frame_idx, + num_frames, + # `multigpu_buffer` is a dict to cache detector's outputs in a chunk between different calls + multigpu_buffer, + track_in_reverse=False, + # whether to also return the SAM2 backbone features + return_sam2_backbone_feats=False, + # whether to perform NMS and suppress the scores of those detections removed by NMS + run_nms=False, + nms_prob_thresh=None, + nms_iou_thresh=None, + **kwargs, + ): + """ + Compute the detector's detection outputs in a distributed manner, where all GPUs process + a chunk of frames (equal to the number of GPUs) at once and store them in cache. + """ + # Step 1: fetch the detector outputs in the current chunk from buffer + frame_idx_curr_b = frame_idx - frame_idx % self.world_size + frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames) + # in case the current frame's detection results are not in the buffer yet, build the current chunk + # (this should only happen on the first chunk, since we are also building the next chunk below) + if frame_idx not in multigpu_buffer: + with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"): + self._build_multigpu_buffer_next_chunk( + backbone_out=backbone_out, + find_inputs=find_inputs, + geometric_prompt=geometric_prompt, + frame_idx_begin=frame_idx_curr_b, + frame_idx_end=frame_idx_curr_e, + num_frames=num_frames, + multigpu_buffer=multigpu_buffer, + run_nms=run_nms, + nms_prob_thresh=nms_prob_thresh, + nms_iou_thresh=nms_iou_thresh, + ) + + # read out the current frame's results from `multigpu_buffer` + out = {} + for k, (v, handle) in multigpu_buffer[frame_idx].items(): + if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats: + continue + if handle is not None: + handle.wait() # wait for async all-gather to finish + out[k] = v + + # Step 2: remove detection outputs of the previous chunk from cache to save GPU memory + if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0: + frame_idx_prev_e = frame_idx_curr_b + frame_idx_prev_b = frame_idx_curr_b - self.world_size + elif track_in_reverse and frame_idx_curr_e < num_frames: + frame_idx_prev_b = frame_idx_curr_e + frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames) + else: + frame_idx_prev_b = frame_idx_prev_e = None + if frame_idx_prev_b is not None: + for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e): + multigpu_buffer.pop(frame_idx_rm, None) + + # Step 3: compute and cache detection outputs of the next chunk ahead of time + # (so that we can overlap computation with all-gather transfer) + if not track_in_reverse and frame_idx_curr_e < num_frames: + frame_idx_next_b = frame_idx_curr_e + frame_idx_next_e = min(frame_idx_next_b + self.world_size, num_frames) + elif track_in_reverse and frame_idx_curr_b - self.world_size >= 0: + frame_idx_next_e = frame_idx_curr_b + frame_idx_next_b = frame_idx_curr_b - self.world_size + else: + frame_idx_next_b = frame_idx_next_e = None + if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer: + with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"): + self._build_multigpu_buffer_next_chunk( + backbone_out=backbone_out, + find_inputs=find_inputs, + geometric_prompt=geometric_prompt, + frame_idx_begin=frame_idx_next_b, + frame_idx_end=frame_idx_next_e, + num_frames=num_frames, + multigpu_buffer=multigpu_buffer, + run_nms=run_nms, + nms_prob_thresh=nms_prob_thresh, + nms_iou_thresh=nms_iou_thresh, + ) + + return out, backbone_out + + def _build_multigpu_buffer_next_chunk( + self, + backbone_out, + find_inputs, + geometric_prompt: Prompt, + frame_idx_begin, + frame_idx_end, + num_frames, + multigpu_buffer, + run_nms=False, + nms_prob_thresh=None, + nms_iou_thresh=None, + ): + """Compute detection outputs on a chunk of frames and store their results in multigpu_buffer.""" + # each GPU computes detections on one frame in the chunk (in a round-robin manner) + frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1) + # `forward_grounding` (from base class `Sam3ImageOnVideo`) runs the detector on a single frame + with torch.profiler.record_function("forward_grounding"): + out_local = self.forward_grounding( + backbone_out=backbone_out, + find_input=find_inputs[frame_idx_local_gpu], + find_target=None, + geometric_prompt=geometric_prompt, + ) + if run_nms: + with torch.profiler.record_function("nms_masks"): + # run NMS as a post-processing step on top of the detection outputs + assert nms_prob_thresh is not None and nms_iou_thresh is not None + pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid() + pred_masks = out_local["pred_masks"] + # loop over text prompts (not an overhead for demo where there's only 1 prompt) + for prompt_idx in range(pred_probs.size(0)): + keep = nms_masks( + pred_probs=pred_probs[prompt_idx], + pred_masks=pred_masks[prompt_idx], + prob_threshold=nms_prob_thresh, + iou_threshold=nms_iou_thresh, + ) + # set a very low threshold for those detections removed by NMS + out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float() + + if self.gather_backbone_out: + # gather the SAM 2 backbone features across GPUs + feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"] + assert len(feats["backbone_fpn"]) == 3 # SAM2 backbone always have 3 levels + # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually + # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP) + backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]] + fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0]) + fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1]) + fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2]) + # vision_pos_enc is the same on all frames, so no need to all-gather them + vision_pos_enc = feats["vision_pos_enc"] + + # trim the detector output to only include the necessary keys + out_local = { + "pred_logits": out_local["pred_logits"], + "pred_boxes": out_local["pred_boxes"], + "pred_boxes_xyxy": out_local["pred_boxes_xyxy"], + "pred_masks": out_local["pred_masks"], + } + + # gather the results: after this step, each GPU will receive detector outputs on + # all frames in the chunk and store them in `multigpu_buffer` + out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()} + for rank in range(self.world_size): + frame_idx_to_save = frame_idx_begin + rank + if frame_idx_to_save >= num_frames: + continue + frame_buffer = { + k: (v[rank], handle) for k, (v, handle) in out_gathered.items() + } + if self.gather_backbone_out: + # also add gathered SAM 2 backbone features to frame_buffer + frame_buffer["tracker_backbone_fpn_0"] = (fpn0[rank], fpn_handle0) + frame_buffer["tracker_backbone_fpn_1"] = (fpn1[rank], fpn_handle1) + frame_buffer["tracker_backbone_fpn_2"] = (fpn2[rank], fpn_handle2) + frame_buffer["tracker_backbone_pos_enc"] = (vision_pos_enc, None) + + multigpu_buffer[frame_idx_to_save] = frame_buffer + + def _gather_tensor(self, x): + if self.world_size == 1: + return [x], None + + async_op = self.async_all_gather + # here `.contiguous()` is required -- otherwise NCCL all_gather + # sometimes gives wrong results + x = x.contiguous() # ensure contiguous memory for NCCL + output_list = [torch.empty_like(x) for _ in range(self.world_size)] + handle = torch.distributed.all_gather(output_list, x, async_op=async_op) + return output_list, handle diff --git a/third_party/sam3/sam3/model/sam3_image_processor.py b/third_party/sam3/sam3/model/sam3_image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b6206ab30c6290c055d46c5bc5310fdf862069 --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_image_processor.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +from typing import Dict, List + +import numpy as np +import PIL +import torch +from sam3.model import box_ops +from sam3.model.data_misc import FindStage, interpolate +from torchvision.transforms import v2 + + +class Sam3Processor: + """ """ + + def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5): + self.model = model + self.resolution = resolution + self.device = device + self.transform = v2.Compose( + [ + v2.ToDtype(torch.uint8, scale=True), + v2.Resize(size=(resolution, resolution)), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + self.confidence_threshold = confidence_threshold + + self.find_stage = FindStage( + img_ids=torch.tensor([0], device=device, dtype=torch.long), + text_ids=torch.tensor([0], device=device, dtype=torch.long), + input_boxes=None, + input_boxes_mask=None, + input_boxes_label=None, + input_points=None, + input_points_mask=None, + ) + + @torch.inference_mode() + def set_image(self, image, state=None): + """Sets the image on which we want to do predictions.""" + if state is None: + state = {} + + if isinstance(image, PIL.Image.Image): + width, height = image.size + elif isinstance(image, (torch.Tensor, np.ndarray)): + height, width = image.shape[-2:] + else: + raise ValueError("Image must be a PIL image or a tensor") + + image = v2.functional.to_image(image).to(self.device) + image = self.transform(image).unsqueeze(0) + + state["original_height"] = height + state["original_width"] = width + state["backbone_out"] = self.model.backbone.forward_image(image) + inst_interactivity_en = self.model.inst_interactive_predictor is not None + if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]: + sam2_backbone_out = state["backbone_out"]["sam2_backbone_out"] + sam2_backbone_out["backbone_fpn"][0] = ( + self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s0( + sam2_backbone_out["backbone_fpn"][0] + ) + ) + sam2_backbone_out["backbone_fpn"][1] = ( + self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s1( + sam2_backbone_out["backbone_fpn"][1] + ) + ) + return state + + @torch.inference_mode() + def set_image_batch(self, images: List[np.ndarray], state=None): + """Sets the image batch on which we want to do predictions.""" + if state is None: + state = {} + + if not isinstance(images, list): + raise ValueError("Images must be a list of PIL images or tensors") + assert len(images) > 0, "Images list must not be empty" + assert isinstance( + images[0], PIL.Image.Image + ), "Images must be a list of PIL images" + + state["original_heights"] = [image.height for image in images] + state["original_widths"] = [image.width for image in images] + + images = [ + self.transform(v2.functional.to_image(image).to(self.device)) + for image in images + ] + images = torch.stack(images, dim=0) + state["backbone_out"] = self.model.backbone.forward_image(images) + inst_interactivity_en = self.model.inst_interactive_predictor is not None + if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]: + sam2_backbone_out = state["backbone_out"]["sam2_backbone_out"] + sam2_backbone_out["backbone_fpn"][0] = ( + self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s0( + sam2_backbone_out["backbone_fpn"][0] + ) + ) + sam2_backbone_out["backbone_fpn"][1] = ( + self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s1( + sam2_backbone_out["backbone_fpn"][1] + ) + ) + return state + + @torch.inference_mode() + def set_text_prompt(self, prompt: str, state: Dict): + """Sets the text prompt and run the inference""" + + if "backbone_out" not in state: + raise ValueError("You must call set_image before set_text_prompt") + + text_outputs = self.model.backbone.forward_text([prompt], device=self.device) + # will erase the previous text prompt if any + state["backbone_out"].update(text_outputs) + if "geometric_prompt" not in state: + state["geometric_prompt"] = self.model._get_dummy_prompt() + + return self._forward_grounding(state) + + @torch.inference_mode() + def add_geometric_prompt(self, box: List, label: bool, state: Dict): + """Adds a box prompt and run the inference. + The image needs to be set, but not necessarily the text prompt. + The box is assumed to be in [center_x, center_y, width, height] format and normalized in [0, 1] range. + The label is True for a positive box, False for a negative box. + """ + if "backbone_out" not in state: + raise ValueError("You must call set_image before set_text_prompt") + + if "language_features" not in state["backbone_out"]: + # Looks like we don't have a text prompt yet. This is allowed, but we need to set the text prompt to "visual" for the model to rely only on the geometric prompt + dummy_text_outputs = self.model.backbone.forward_text( + ["visual"], device=self.device + ) + state["backbone_out"].update(dummy_text_outputs) + + if "geometric_prompt" not in state: + state["geometric_prompt"] = self.model._get_dummy_prompt() + + # adding a batch and sequence dimension + boxes = torch.tensor(box, device=self.device, dtype=torch.float32).view(1, 1, 4) + labels = torch.tensor([label], device=self.device, dtype=torch.bool).view(1, 1) + state["geometric_prompt"].append_boxes(boxes, labels) + + return self._forward_grounding(state) + + def reset_all_prompts(self, state: Dict): + """Removes all the prompts and results""" + if "backbone_out" in state: + backbone_keys_to_del = [ + "language_features", + "language_mask", + "language_embeds", + ] + for key in backbone_keys_to_del: + if key in state["backbone_out"]: + del state["backbone_out"][key] + + keys_to_del = ["geometric_prompt", "boxes", "masks", "masks_logits", "scores"] + for key in keys_to_del: + if key in state: + del state[key] + + @torch.inference_mode() + def set_confidence_threshold(self, threshold: float, state=None): + """Sets the confidence threshold for the masks""" + self.confidence_threshold = threshold + if state is not None and "boxes" in state: + # we need to filter the boxes again + # In principle we could do this more efficiently since we would only need + # to rerun the heads. But this is simpler and not too inefficient + return self._forward_grounding(state) + return state + + @torch.inference_mode() + def _forward_grounding(self, state: Dict): + outputs = self.model.forward_grounding( + backbone_out=state["backbone_out"], + find_input=self.find_stage, + geometric_prompt=state["geometric_prompt"], + find_target=None, + ) + + out_bbox = outputs["pred_boxes"] + out_logits = outputs["pred_logits"] + out_masks = outputs["pred_masks"] + out_probs = out_logits.sigmoid() + presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1) + out_probs = (out_probs * presence_score).squeeze(-1) + + keep = out_probs > self.confidence_threshold + out_probs = out_probs[keep] + out_masks = out_masks[keep] + out_bbox = out_bbox[keep] + + # convert to [x0, y0, x1, y1] format + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + + img_h = state["original_height"] + img_w = state["original_width"] + scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).to(self.device) + boxes = boxes * scale_fct[None, :] + + out_masks = interpolate( + out_masks.unsqueeze(1), + (img_h, img_w), + mode="bilinear", + align_corners=False, + ).sigmoid() + + state["masks_logits"] = out_masks + state["masks"] = out_masks > 0.5 + state["boxes"] = boxes + state["scores"] = out_probs + return state diff --git a/third_party/sam3/sam3/model/sam3_multiplex_base.py b/third_party/sam3/sam3/model/sam3_multiplex_base.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3e6396767e833d80685d340407ba7fed115f9d --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_multiplex_base.py @@ -0,0 +1,2856 @@ +import datetime +import logging +import math +import os +import sys +from collections import defaultdict +from copy import deepcopy +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from sam3.logger import get_logger +from sam3.model.box_ops import fast_diag_box_iou +from sam3.model.data_misc import BatchedDatapoint, NestedTensor +from sam3.model.sam3_multiplex_detector import Sam3MultiplexDetector +from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box +from sam3.model.sam3_video_base import ( + _associate_det_trk_compilable, + LazyAssociateDetTrkResult, + MaskletConfirmationStatus, + realize_adt_result, + RealizedAssociateDetTrkresult, + Sam3VideoBase, +) +from sam3.perflib.masks_ops import mask_iou +from sam3.train.masks_ops import rle_encode +from torch import nn, Tensor + +# a short 3-min timeout to quickly detect any synchronization failures +SAM3_COLLECTIVE_OP_TIMEOUT_SEC = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180")) + +logger = get_logger(__name__) + +if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +class Sam3MultiplexTrackerPredictor(nn.Module): + def __init__( + self, + config_file, + checkpoint_file=None, + hydra_overrides=None, + per_obj_inference=False, + fill_hole_area=0, + use_fa3=False, + use_rope_real=False, + keep_first_cond_frame=False, + is_multiplex=False, + is_multiplex_dynamic=False, + use_memory_selection=False, + ): + """ + Initialize the SAM2 predictor with the given configuration and checkpoint. + Args: + config_file (str): Path to the configuration file. + checkpoint_file (str, optional): Path to the checkpoint file. If None, the model will be initialized without loading weights. + hydra_overrides (list, optional): List of Hydra overrides to apply to the configuration. + per_obj_inference (bool): If True, the model will perform per-object inference instead of bucketized batching. + """ + + super().__init__() + ####################################### + # Load model from config and checkpoint + ####################################### + + from hydra import compose, initialize_config_module + from hydra.core.global_hydra import GlobalHydra + from hydra.utils import instantiate + + # Ensure proper Hydra initialization + if not GlobalHydra().is_initialized(): + logger.info("Sam3MultiplexTrackerPredictor: GlobalHydra not initialized") + GlobalHydra.instance().clear() + initialize_config_module("sam3.config", version_base="1.2") + + if hydra_overrides is None: + hydra_overrides = [] + self.is_multiplex = is_multiplex + self.is_multiplex_dynamic = is_multiplex_dynamic + self.per_obj_inference = per_obj_inference + + if self.is_multiplex: + inference_model_class = "sam3.model.video_tracking_multiplex_demo.Sam3VideoTrackingMultiplexDemo" + else: + inference_model_class = ( + "sam3.model.video_tracking_with_prompt_demo_per_obj_inference.Sam3VideoTrackingWithPromptDemoPerObjInference" + if per_obj_inference + else "sam3.model.video_tracking_with_prompt_demo.Sam3VideoTrackingWithPromptDemo" + ) + hydra_overrides = list(hydra_overrides) + hydra_overrides.extend( + [ + "launcher.experiment_log_dir=''", + f"++trainer.model._target_={inference_model_class}", + # Shared backbone cfg + "++trainer.model.image_size=1008", + "++trainer.model.backbone_stride=14", + "++trainer.model.maskmem_backbone.mask_downsampler.interpol_size=[1152,1152]", + "++trainer.model.backbone.forward_in_chunk_for_eval=false", + # always start tracking from the frame where we receive the first annotation + # (clicks or mask) and ignore the `start_frame_idx` passed to `propagate_in_video` + "++trainer.model.always_start_from_first_ann_frame=false", + # apply non-overlapping constraints on the object masks in the + # memory encoder to avoid/alleviate superposing mask predictions + "++trainer.model.non_overlap_masks_for_mem_enc=false", + # Do not apply non-overlapping constraints on the output + "++trainer.model.non_overlap_masks_for_output=false", + # attend to at most 4 temporally closest conditioning frames in the encoder for + # better temporal locality and a better handling to a large number of annotated frames + "++trainer.model.max_cond_frames_in_attn=4", + f"++trainer.model.keep_first_cond_frame={keep_first_cond_frame}", + # turn off all offloading options in the demo (we handle them separately in the demo class) + "++trainer.model.offload_output_to_cpu_for_eval=false", + "++trainer.model.trim_past_non_cond_mem_for_eval=false", + # torch.compile on the image backbone (w/ `dynamic=false` and `fullgraph=true` to capture a full graph) + # "++trainer.model.backbone.compile_mode=max-autotune", + # "++trainer.model.backbone.compile_extra_args.fullgraph=true", + # "++trainer.model.backbone.compile_extra_args.dynamic=false", + "++trainer.model.backbone.visual.trunk.weights_path=null", + # Postprocessing/demo options + # dynamically fall back to multi-mask if the single mask is not stable + "++trainer.model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++trainer.model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++trainer.model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++trainer.model.binarize_mask_from_pts_for_mem_enc=true", + # only attend to object pointers in the past (before the current frame) in the encoder during evaluation + "++trainer.model.only_obj_ptrs_in_the_past_for_eval=true", + # clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks + "++trainer.model.clear_non_cond_mem_around_input=true", + "++trainer.model.transformer.encoder.layer.self_attention.feat_sizes=[72,72]", + "++trainer.model.transformer.encoder.layer.cross_attention.feat_sizes=[72,72]", + # fill small holes in the final masks up to `fill_hole_area` (after resizing them to the original video resolution) + f"++trainer.model.fill_hole_area={fill_hole_area}", + f"++trainer.model.transformer.encoder.layer.self_attention.use_fa3={use_fa3}", + f"++trainer.model.transformer.encoder.layer.cross_attention.use_fa3={use_fa3}", + f"++trainer.model.transformer.encoder.layer.self_attention.use_rope_real={use_rope_real}", + f"++trainer.model.transformer.encoder.layer.cross_attention.use_rope_real={use_rope_real}", + ] + ) + + if self.is_multiplex or self.is_multiplex_dynamic: + hydra_overrides.extend( + [ + f"++trainer.model.transformer.encoder.layer.self_attention_rope.use_fa3={use_fa3}", + f"++trainer.model.transformer.encoder.layer.cross_attention_rope.use_fa3={use_fa3}", + f"++trainer.model.transformer.encoder.layer.self_attention_rope.use_rope_real={use_rope_real}", + f"++trainer.model.transformer.encoder.layer.cross_attention_rope.use_rope_real={use_rope_real}", + ] + ) + + hydra_overrides.extend( + [f"++trainer.model.use_memory_selection={use_memory_selection}"] + ) + + cfg = compose(config_name=config_file, overrides=hydra_overrides) + model = instantiate(cfg.trainer.model, _recursive_=True) + del model.backbone # Remove backbone since it is shared with the sam3 model + if checkpoint_file is not None: + ckpt = torch.load(checkpoint_file, map_location="cpu") + model.load_state_dict(ckpt["model"], strict=False) + self.model = model + self.per_obj_inference = per_obj_inference + self.fill_hole_area = fill_hole_area + # use bfloat16 inference for Flash Attention kernel + self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16) + self.bf16_context.__enter__() # keep using for the entire model process + + def __getattr__(self, name): + # Expose all attributes of the underlying model + model = super().__getattr__("model") + if name == "model": + return model + return getattr(model, name) + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Use the sam2 predictor APIs instead. Check VideoTrackingWithPromptDemo class for details." + ) + + def add_output_per_object(self, *args, **kwargs): + if self.per_obj_inference: + # nothing needs to be done as each object is already stored separately + return + + # for batched inference state, we also need to add per-object + # memory slides to support instance interactivity + self._add_output_per_object(*args, **kwargs) + + +class Sam3MultiplexBase(Sam3VideoBase): + def __init__( + self, + tracker, + detector, + ckpt_path=None, + sam3_ckpt_path=None, + # prob threshold for detection outputs -- only keep detections above this threshold + # enters NMS and det-to-track matching + score_threshold_detection=0.5, + # Detection threshold when running on image-only inputs + image_only_det_thresh=0.5, + # IoU threshold for detection NMS + det_nms_thresh=0.0, + # If `det_nms_use_iom` is True, use IoM instead of IoU for NMS + det_nms_use_iom=False, + # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it + # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1 + assoc_iou_thresh=0.5, + # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched" + # by any detections -- it is often a stricter threshold like 0.5 + trk_assoc_iou_thresh=0.5, + # prob threshold for a detection to be added as a new object + new_det_thresh=0.5, + # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and + # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh` + # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh` + hotstart_delay=0, + hotstart_unmatch_thresh=3, + hotstart_dup_thresh=3, + # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period. + suppress_unmatched_only_within_hotstart=True, + init_trk_keep_alive=0, + max_trk_keep_alive=8, + min_trk_keep_alive=-4, + # Threshold for suppressing overlapping objects based on recent occlusion + suppress_overlapping_based_on_recent_occlusion_threshold=0.0, + allow_unoccluded_to_suppress: bool = False, + decrease_trk_keep_alive_for_empty_masklets=False, + o2o_matching_masklets_enable=False, # Enable hungarian matching to match existing masklets + suppress_det_close_to_boundary=False, + fill_hole_area=16, + sprinkle_removal_area=16, + # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1) + max_num_objects=128, # 128 objects (total across all GPUs) should be able to cover nearly all cases + max_num_kboxes=20, + recondition_every_nth_frame=-1, + use_iom_recondition=False, + iom_thresh_recondition=0.8, + iou_thresh_recondition=0.8, + is_multiplex=False, + # masket confirmation status (to suppress unconfirmed masklets) + masklet_confirmation_enable=False, + # a masklet is confirmed after being consecutively detected and matched for + # `masklet_confirmation_consecutive_det_thresh` + masklet_confirmation_consecutive_det_thresh=3, + # bbox heuristic parameters + reconstruction_bbox_iou_thresh=0.0, + reconstruction_bbox_det_score=0.5, + reapply_no_object_pointer: bool = False, # reapply the no object pointer for suppressed objects + running_in_prod=False, # Flag to specify if we are running in FBInfra for Insta Edit/Segments + use_batched_grounding=False, + batched_grounding_batch_size=1, + **kwargs, + ): + nn.Module.__init__(self) + assert isinstance(tracker, Sam3MultiplexTrackerPredictor) + self.tracker = tracker + assert isinstance(detector, Sam3MultiplexDetector) + self.detector = detector + if sam3_ckpt_path: + ckpt = torch.load(sam3_ckpt_path, map_location="cpu", weights_only=True) + self.detector.load_state_dict(ckpt["model"], strict=False) + elif ckpt_path: + self._load_checkpoint(ckpt_path, strict=False) + self.score_threshold_detection = score_threshold_detection + self.image_only_det_thresh = image_only_det_thresh + self.det_nms_thresh = det_nms_thresh + self.det_nms_use_iom = det_nms_use_iom + self.assoc_iou_thresh = assoc_iou_thresh + self.trk_assoc_iou_thresh = trk_assoc_iou_thresh + self.new_det_thresh = new_det_thresh + self.is_multiplex = is_multiplex + self.running_in_prod = running_in_prod + self.detector.running_in_prod = running_in_prod + + assert ( + self.is_multiplex == self.tracker.is_multiplex == self.detector.is_multiplex + ), f"is_multiplex must be the same for all models: {self.is_multiplex=}, {self.tracker.is_multiplex=}, {self.detector.is_multiplex=}" + + # hotstart parameters + if hotstart_delay > 0: + assert hotstart_unmatch_thresh <= hotstart_delay + assert hotstart_dup_thresh <= hotstart_delay + self.hotstart_delay = hotstart_delay + self.hotstart_unmatch_thresh = hotstart_unmatch_thresh + self.hotstart_dup_thresh = hotstart_dup_thresh + self.suppress_unmatched_only_within_hotstart = ( + suppress_unmatched_only_within_hotstart + ) + self.init_trk_keep_alive = init_trk_keep_alive + self.max_trk_keep_alive = max_trk_keep_alive + self.min_trk_keep_alive = min_trk_keep_alive + self.suppress_overlapping_based_on_recent_occlusion_threshold = ( + suppress_overlapping_based_on_recent_occlusion_threshold + ) + self.allow_unoccluded_to_suppress = allow_unoccluded_to_suppress + self.suppress_det_close_to_boundary = suppress_det_close_to_boundary + self.decrease_trk_keep_alive_for_empty_masklets = ( + decrease_trk_keep_alive_for_empty_masklets + ) + self.o2o_matching_masklets_enable = o2o_matching_masklets_enable + self.fill_hole_area = fill_hole_area + self.sprinkle_removal_area = sprinkle_removal_area + self.eval() + self.rank = int(os.getenv("RANK", "0")) + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self._dist_pg_cpu = None # CPU process group (lazy-initialized on first use) + + # Initialize profiling variables + self._profiler = None + self._frame_count = 0 + self._profile_save_dir = os.getenv("PROFILE_SAVE_DIR", "/tmp/profiling") + self._profiling_enabled = os.getenv("ENABLE_PROFILING", "0").lower() == "1" + + # the maximum object number + if max_num_objects > 0: + multiplex_divisor = ( + self.tracker.multiplex_controller.allowed_bucket_capacity + if self.is_multiplex + else 1 + ) + num_obj_for_compile = math.ceil( + max_num_objects / (self.world_size * multiplex_divisor) + ) + else: + max_num_objects = 10000 # no limit + num_obj_for_compile = 16 + logger.info( + f"`setting max_num_objects` to {max_num_objects} -- creating {num_obj_for_compile=} objects for torch.compile cache" + ) + self.max_num_objects = max_num_objects + self.num_obj_for_compile = num_obj_for_compile + self.max_num_kboxes = max_num_kboxes + self.recondition_every_nth_frame = recondition_every_nth_frame + self.use_iom_recondition = use_iom_recondition + self.iom_thresh_recondition = iom_thresh_recondition + self.iou_thresh_recondition = iou_thresh_recondition + self.masklet_confirmation_enable = masklet_confirmation_enable + self.masklet_confirmation_consecutive_det_thresh = ( + masklet_confirmation_consecutive_det_thresh + ) + self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh + self.reconstruction_bbox_det_score = reconstruction_bbox_det_score + self.reapply_no_object_pointer = reapply_no_object_pointer + + # Batched grounding configuration + self.use_batched_grounding = use_batched_grounding + self.batched_grounding_batch_size = ( + batched_grounding_batch_size # Batch size for batched grounding + ) + + if self.is_multiplex: + assert ( + not self.tracker.multiplex_controller.training + ), "This model class should only be used for eval." + self.bucket_capacity: int = ( + self.tracker.multiplex_controller.allowed_bucket_capacity + ) + + def all_gather_cpu(self, tensor_list, tensor): + if self._dist_pg_cpu is None: + self._init_dist_pg_cpu() + dist.broadcast(tensor_list, tensor, group=self._dist_pg_cpu) + + def all_gather_python_obj_cpu(self, object_list, python_obj): + if self._dist_pg_cpu is None: + self._init_dist_pg_cpu() + dist.all_gather_object(object_list, python_obj, group=self._dist_pg_cpu) + + def broadcast_cpu(self, x, src): + if self._dist_pg_cpu is None: + self._init_dist_pg_cpu() + dist.broadcast(x, src=src, group=self._dist_pg_cpu) + + def _start_profiling(self, frame_idx): + self._profiling_enabled = os.getenv("ENABLE_PROFILING", "0").lower() == "1" + self._profile_end_frame = int(os.getenv("PROFILE_END_FRAME", "-1")) + """Start profiling for _det_track_one_frame if conditions are met.""" + if not self._profiling_enabled: + return False + + if not getattr(self, "_warm_up_complete", False): + return False + + if self._profiler is not None: + return True + + # Start profiling + os.makedirs(self._profile_save_dir, exist_ok=True) + profile_path = os.path.join( + self._profile_save_dir, f"det_track_frame_rank_{self.rank}.json.gz" + ) + + self._profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + experimental_config=torch.profiler._ExperimentalConfig( + profile_all_threads=True + ), + ) + self._profiler.start() + self._current_profile_path = profile_path + print(f"Started profiling frame on {frame_idx} on rank {self.rank}") + return True + + def _stop_profiling(self): + """Stop profiling and save trace.""" + if self._profiler is not None: + self._profiler.stop() + self._profiler.export_chrome_trace(self._current_profile_path) + print(f"Profiling trace saved to: {self._current_profile_path}") + print( + f"You can open this file in Perfetto (https://ui.perfetto.dev/) to visualize the trace" + ) + self._profiler = None + self._profiling_enabled = False + os.environ["ENABLE_PROFILING"] = "0" + + def _det_track_one_frame( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + input_batch: BatchedDatapoint, + geometric_prompt: Any, + tracker_states_local: List[Any], + tracker_metadata_prev: Dict[str, Any], + feature_cache: Dict, + orig_vid_height: int, + orig_vid_width: int, + is_image_only: bool = False, + ): + profiling_enabled = self._start_profiling(frame_idx) + + try: + return self._det_track_one_frame_impl( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + input_batch=input_batch, + geometric_prompt=geometric_prompt, + tracker_states_local=tracker_states_local, + tracker_metadata_prev=tracker_metadata_prev, + feature_cache=feature_cache, + orig_vid_height=orig_vid_height, + orig_vid_width=orig_vid_width, + is_image_only=is_image_only, + ) + finally: + if profiling_enabled: + if sys.exc_info()[0] is not None: + # If there is an exception, stop profiling + self._stop_profiling() + else: + if ( + (not reverse and frame_idx == num_frames - 1) + or (reverse and frame_idx == 0) + or self._profile_end_frame == frame_idx + ): + # Stop profiling if reached the last frame + self._stop_profiling() + + def _det_track_one_frame_impl( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + input_batch: BatchedDatapoint, + geometric_prompt: Any, + tracker_states_local: List[Any], + tracker_metadata_prev: Dict[str, Any], + feature_cache: Dict, + orig_vid_height: int, + orig_vid_width: int, + is_image_only: bool, + ): + """ + This function handles one-step inference for the multiplex model in an SPMD manner. + At a high-level, all GPUs execute the same function calls as if it's done on a single GPU, + while under the hood, some function calls involve distributed computation based on sharded + SAM2 states. + + - `input_batch` contains image and other inputs on the entire video; it should be identical across GPUs + - `tracker_states_local` holds the local masklet information in this GPU shard + - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is hold on which GPUs + it contains both global and local masklet information + """ + + # Step 1: run backbone and FA in a distributed manner -- this is done via Sam3MultiplexDetector, + # a distributed FA model (assigned to `self.detector`) that shards frames in a round-robin manner. + # It returns a "det_out" dict for `frame_idx` and fills SAM2 backbone features for `frame_idx` + # into `feature_cache`. Despite its distributed inference under the hood, the results would be + # the same as if it is running backbone and FA for every frame on a single GPU. + with torch.profiler.record_function("run_backbone_and_detection"): + det_out, pos_pred_mask = self.run_backbone_and_detection( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + input_batch=input_batch, + geometric_prompt=geometric_prompt, + feature_cache=feature_cache, + use_batched_grounding=self.use_batched_grounding, + batched_grounding_batch_size=self.batched_grounding_batch_size, + ) + + # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks. + # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions + # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only + # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks; + # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics. + with torch.profiler.record_function("run_tracker_propagation"): + if tracker_metadata_prev == {}: + # initialize masklet metadata if it's uninitialized (empty dict) + tracker_metadata_prev.update(self._initialize_metadata()) + tracker_low_res_masks_global, tracker_obj_scores_global = ( + self.run_tracker_propagation( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + tracker_states_local=tracker_states_local, + tracker_metadata_prev=tracker_metadata_prev, + ) + ) + + with torch.profiler.record_function("GPU sync and filter"): + # Remove leading dimension (assumes batch size 1) + assert pos_pred_mask.shape[0] == 1 + pos_pred_mask = pos_pred_mask.squeeze(0) + det_out = {k: det_out[k][0] for k in det_out} + # Move detections we'll actually keep at the top for future logic + pos_pred_mask_idx = pos_pred_mask.argsort(descending=True) + pos_pred_mask = torch.index_select( + pos_pred_mask, dim=0, index=pos_pred_mask_idx + ) + det_out = { + k: torch.index_select(det_out[k], dim=0, index=pos_pred_mask_idx) + for k in det_out + } + + # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans + # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc). + # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints. + # **This step should involve all the heuristics needed for any updates.** Most of the update + # planning will be done on the master rank (GPU 0) and the resulting plan `sam2_update_plan` is + # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the + # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`). + with torch.profiler.record_function("run_tracker_update_planning_phase"): + sam2_update_plan, tracker_metadata_new = ( + self.run_tracker_update_planning_phase( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + det_out=det_out, + det_keep=pos_pred_mask, + tracker_low_res_masks_global=tracker_low_res_masks_global, + tracker_obj_scores_global=tracker_obj_scores_global, + tracker_metadata_prev=tracker_metadata_prev, + tracker_states_local=tracker_states_local, + is_image_only=is_image_only, + ) + ) + + # Get reconditioning info from the update plan + reconditioned_obj_ids = sam2_update_plan.get("reconditioned_obj_ids", set()) + det_to_matched_trk_obj_ids = sam2_update_plan.get( + "det_to_matched_trk_obj_ids", {} + ) + + # Step 4: based on `sam2_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states + with torch.profiler.record_function("run_tracker_update_execution_phase"): + tracker_states_local_new = self.run_tracker_update_execution_phase( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + det_out=det_out, + tracker_states_local=tracker_states_local, + tracker_update_plan=sam2_update_plan, + tracker_metadata_new=tracker_metadata_new, + orig_vid_height=orig_vid_height, + orig_vid_width=orig_vid_width, + feature_cache=feature_cache, + ) + + # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since + # only GPU 0 will send outputs to the server). + with torch.profiler.record_function("build_outputs"): + if self.rank == 0: + obj_id_to_mask = self.build_outputs( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + det_out=det_out, + tracker_low_res_masks_global=tracker_low_res_masks_global, + tracker_obj_scores_global=tracker_obj_scores_global, + tracker_metadata_prev=tracker_metadata_prev, + sam2_update_plan=sam2_update_plan, + orig_vid_height=orig_vid_height, + orig_vid_width=orig_vid_width, + reconditioned_obj_ids=reconditioned_obj_ids, + det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, + ) + obj_id_to_score = tracker_metadata_new["obj_id_to_score"] + else: + obj_id_to_mask, obj_id_to_score = {}, {} # dummy outputs on other GPUs + # a few statistics for the current frame as a part of the output + frame_stats = { + "num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]), + "num_obj_dropped": sam2_update_plan["num_obj_dropped_due_to_limit"], + } + # add sam2 scores to metadata, it should be fired for frames except the first frame + if tracker_obj_scores_global.shape[0] > 0: + # Convert tracker_obj_scores_global to sigmoid scores before updating + tracker_obj_scores_global = tracker_obj_scores_global.sigmoid() + sam2_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"] + tracker_metadata_new["obj_id_to_sam2_score_frame_wise"][frame_idx].update( + dict(zip(sam2_obj_ids, tracker_obj_scores_global)) + ) + + return ( + obj_id_to_mask, # a dict: obj_id --> output mask + obj_id_to_score, # a dict: obj_id --> output score (prob) + tracker_states_local_new, + tracker_metadata_new, + frame_stats, + tracker_obj_scores_global, # a dict: obj_id --> sam2 frame-level scores + ) + + def run_backbone_and_detection( + self, + frame_idx: int, + num_frames: int, + input_batch: BatchedDatapoint, + geometric_prompt: Any, + feature_cache: Dict, + reverse: bool, + use_batched_grounding: bool = False, + batched_grounding_batch_size: int = 16, + ): + # Step 1: if text feature is not cached in `feature_cache`, compute and cache it + text_batch_key = tuple(input_batch.find_text_batch) + if "text" not in feature_cache or text_batch_key not in feature_cache["text"]: + text_outputs = self.detector.backbone.forward_text( + input_batch.find_text_batch, device=self.device + ) + # note: we only cache the text feature of the most recent prompt + feature_cache["text"] = {text_batch_key: text_outputs} + else: + text_outputs = feature_cache["text"][text_batch_key] + + # Step 2: run backbone, FA detection, and post-processing with NMS + # Extract max_frame_num_to_track from feature_cache if available + tracking_bounds = feature_cache.get("tracking_bounds", {}) + max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track") + start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx") + backbone_out = { + "img_batch_all_stages": input_batch.img_batch, + **text_outputs, + } + + if use_batched_grounding: + # Use fully batched forward_grounding approach + if "grounding_cache" not in feature_cache: + feature_cache["grounding_cache"] = {} + + with torch.profiler.record_function( + "forward_video_grounding_batched_multigpu" + ): + sam3_image_out, _ = ( + self.detector.forward_video_grounding_batched_multigpu( + backbone_out=backbone_out, + find_inputs=input_batch.find_inputs, + geometric_prompt=geometric_prompt, + frame_idx=frame_idx, + num_frames=num_frames, + grounding_cache=feature_cache["grounding_cache"], + track_in_reverse=reverse, + return_sam2_backbone_feats=True, + run_nms=self.det_nms_thresh > 0.0, + nms_prob_thresh=self.score_threshold_detection, + nms_iou_thresh=self.det_nms_thresh, + nms_use_iom=self.det_nms_use_iom, + max_frame_num_to_track=max_frame_num_to_track, + propagate_in_video_start_frame_idx=start_frame_idx, + feature_cache=feature_cache, + batch_size=batched_grounding_batch_size, + ) + ) + else: + # Use existing multi-GPU distributed approach + if "multigpu_buffer" not in feature_cache: + # "multigpu_buffer" is a buffer cache used by `self.detector` and it needs + # to be passed to `forward_video_grounding_multigpu` for every call + feature_cache["multigpu_buffer"] = {} + + with torch.profiler.record_function("forward_video_grounding_multigpu"): + sam3_image_out, _ = self.detector.forward_video_grounding_multigpu( + backbone_out=backbone_out, + find_inputs=input_batch.find_inputs, + geometric_prompt=geometric_prompt, + frame_idx=frame_idx, + num_frames=num_frames, + multigpu_buffer=feature_cache["multigpu_buffer"], + track_in_reverse=reverse, + # also get the SAM2 backbone features + return_sam2_backbone_feats=True, + # run NMS as a part of distributed FA computation + run_nms=self.det_nms_thresh > 0.0, + nms_prob_thresh=self.score_threshold_detection, + nms_iou_thresh=self.det_nms_thresh, + nms_use_iom=self.det_nms_use_iom, + # pass max_frame_num_to_track to respect tracking limits + max_frame_num_to_track=max_frame_num_to_track, + propagate_in_video_start_frame_idx=start_frame_idx, + # pass feature_cache for buffered backbone computation + feature_cache=feature_cache, + ) + + # note: detections in `sam3_image_out` has already gone through NMS + pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid() + pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"] + pred_masks = sam3_image_out["pred_masks"] + # get the positive detection outputs above threshold + pos_pred_mask = pred_probs > self.score_threshold_detection + + if self.suppress_det_close_to_boundary: + # Suppress detections too close to image edges (for normalized boxes). + keep = self._suppress_detections_close_to_boundary(pred_boxes_xyxy) + pos_pred_mask = pos_pred_mask & keep + + det_out = { + "bbox": pred_boxes_xyxy, + "mask": pred_masks, + "scores": pred_probs, + } + + # Step 3: build SAM2 backbone features and store them in `feature_cache` + backbone_cache = {} + if self.is_multiplex: + # For the multiplex model we have separate interaction and propagation features + # TODO: We do not need the interaction features every frame so there are rooms for optimization + interaction_sam_mask_decoder = self.tracker.interactive_sam_mask_decoder + interaction_backbone_fpn = [ + interaction_sam_mask_decoder.conv_s0( + sam3_image_out["interactive_backbone_fpn_0"] + ), + interaction_sam_mask_decoder.conv_s1( + sam3_image_out["interactive_backbone_fpn_1"] + ), + sam3_image_out[ + "interactive_backbone_fpn_2" + ], # fpn_2 doesn't need additional conv + ] + interaction_backbone_out = { + "vision_features": interaction_backbone_fpn[-1], # top-level feature + "vision_mask": None, + "vision_pos_enc": sam3_image_out["interactive_backbone_pos_enc"], + "backbone_fpn": [ + NestedTensor(x, None) for x in interaction_backbone_fpn + ], + } + backbone_cache["interactive"] = interaction_backbone_out + sam_mask_decoder = self.tracker.sam_mask_decoder + sam2_backbone_fpn = [ + sam_mask_decoder.conv_s0(sam3_image_out["sam2_backbone_fpn_0"]), + sam_mask_decoder.conv_s1(sam3_image_out["sam2_backbone_fpn_1"]), + sam3_image_out["sam2_backbone_fpn_2"], # fpn_2 doesn't need additional conv + ] + sam2_backbone_out = { + "vision_features": sam2_backbone_fpn[-1], # top-level feature + "vision_mask": None, + "vision_pos_enc": sam3_image_out["sam2_backbone_pos_enc"], + "backbone_fpn": [NestedTensor(x, None) for x in sam2_backbone_fpn], + } + backbone_cache["sam2_backbone_out"] = sam2_backbone_out + + with torch.profiler.record_function("run_backbone_and_detection.feature_cache"): + feature_cache[frame_idx] = ( + input_batch.img_batch.tensors[frame_idx], + backbone_cache, + ) + # remove from `feature_cache` old features to save GPU memory + feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None) + return det_out, pos_pred_mask + + def run_tracker_propagation( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + tracker_states_local: List[Any], + tracker_metadata_prev: Dict[str, np.ndarray], + ): + # Step 1: propagate the local SAM2 states to get the current frame's prediction + # `low_res_masks_local` of the existing masklets on this GPU + # - obj_ids_local: List[int] -- list of object IDs + # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask) + with torch.profiler.record_function("propagate_tracker_one_frame_local_gpu"): + obj_ids_local, low_res_masks_local, obj_scores_local = ( + self._propogate_tracker_one_frame_local_gpu( + tracker_states_local, frame_idx=frame_idx, reverse=reverse + ) + ) + + assert np.all( + obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank] + ), "{} != {}".format( + obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank] + ) + + # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global` + # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask) + with torch.profiler.record_function("all_gather_low_res_masks_local"): + _, H_mask, W_mask = low_res_masks_local.shape + if self.world_size > 1: + # `low_res_masks_local` and `obj_scores_local` need to be contiguous and float32 + # (they could be non-contiguous due to slicing and/or bfloat16 due to autocast) + low_res_masks_local = low_res_masks_local.float().contiguous() + obj_scores_local = obj_scores_local.float().contiguous() + num_obj_this_gpu = tracker_metadata_prev["num_obj_per_gpu"][self.rank] + assert low_res_masks_local.size(0) == num_obj_this_gpu + assert obj_scores_local.size(0) == num_obj_this_gpu + low_res_masks_peers = [ + low_res_masks_local.new_empty(num_obj, H_mask, W_mask) + for num_obj in tracker_metadata_prev["num_obj_per_gpu"] + ] + obj_scores_peers = [ + obj_scores_local.new_empty(num_obj) + for num_obj in tracker_metadata_prev["num_obj_per_gpu"] + ] + dist.all_gather(low_res_masks_peers, low_res_masks_local) + dist.all_gather(obj_scores_peers, obj_scores_local) + low_res_masks_global = torch.cat(low_res_masks_peers, dim=0) + obj_scores_global = torch.cat(obj_scores_peers, dim=0) + else: + low_res_masks_global = low_res_masks_local + obj_scores_global = obj_scores_local + return low_res_masks_global, obj_scores_global + + def _recondition_masklets( + self, + frame_idx, + det_out: Dict[str, Tensor], + trk_id_to_max_iou_high_conf_det: Dict[int, int], # trk_obj_id -> det_idx + tracker_states_local: List[Any], + tracker_metadata: Dict[str, np.ndarray], + tracker_obj_scores_global: Tensor, + tracker_low_res_masks_global: Tensor, + ): + reconditioned_obj_ids = set() + HIGH_CONF_THRESH = 0.8 + input_mask_res = self.tracker.input_mask_size + + if len(trk_id_to_max_iou_high_conf_det) == 0: + return tracker_states_local, reconditioned_obj_ids + + # === BATCH ALL INDEX LOOKUPS ON GPU === + trk_obj_ids = list(trk_id_to_max_iou_high_conf_det.keys()) + det_indices = list(trk_id_to_max_iou_high_conf_det.values()) + + # Convert obj_ids_all_gpu to tensor once (keep on GPU) + obj_ids_all_gpu_t = torch.from_numpy(tracker_metadata["obj_ids_all_gpu"]).to( + device=tracker_obj_scores_global.device + ) + trk_obj_ids_t = torch.tensor( + trk_obj_ids, device=tracker_obj_scores_global.device + ) + det_indices_t = torch.tensor( + det_indices, device=tracker_obj_scores_global.device + ) + + # Batched lookup: find obj_idx for each trk_obj_id + # Shape: (num_trk, num_all_obj) -> find matching indices + matches = trk_obj_ids_t.unsqueeze(1) == obj_ids_all_gpu_t.unsqueeze(0) # (N, M) + obj_indices_t = matches.int().argmax(dim=1) # (N,) + + # Batched score lookup and filtering - NO SYNC until we need CPU decision + obj_scores_batch = tracker_obj_scores_global[obj_indices_t].sigmoid() # (N,) + high_conf_mask = obj_scores_batch > HIGH_CONF_THRESH # (N,) bool tensor on GPU + + # === SINGLE SYNC POINT: Transfer filter mask to CPU === + high_conf_mask_cpu = high_conf_mask.cpu().numpy() + + # Filter to only high-confidence items + valid_trk_obj_ids = [ + tid for tid, valid in zip(trk_obj_ids, high_conf_mask_cpu) if valid + ] + valid_det_indices = [ + did for did, valid in zip(det_indices, high_conf_mask_cpu) if valid + ] + valid_obj_indices = obj_indices_t[high_conf_mask] # Keep as tensor + + if len(valid_trk_obj_ids) == 0: + return tracker_states_local, reconditioned_obj_ids + + # === BATCH MASK OPERATIONS === + valid_det_indices_t = torch.tensor( + valid_det_indices, device=det_out["mask"].device + ) + + # Batch fetch all detection masks at once + new_masks = det_out["mask"][valid_det_indices_t] # (K, H, W) + new_masks_binary = ( + F.interpolate( + new_masks.unsqueeze(1), + size=(input_mask_res, input_mask_res), + mode="bilinear", + align_corners=False, + ).squeeze(1) + > 0 + ) # (K, H, W) + + # Batch update low_res_masks_global + old_masks = tracker_low_res_masks_global[valid_obj_indices] # (K, H, W) + binary_agreement = (new_masks > 0) == (old_masks > 0) + updated_masks = torch.where(binary_agreement, old_masks, new_masks) + + # Batch hole filling + updated_masks = fill_holes_in_mask_scores( + updated_masks.unsqueeze(1), + fill_hole_area=self.fill_hole_area, + sprinkle_removal_area=self.sprinkle_removal_area, + fill_holes=True, + remove_sprinkles=True, + ).squeeze(1) + + # Write back (scatter) + tracker_low_res_masks_global[valid_obj_indices] = updated_masks + + # === NOW DO THE STATE UPDATES (still needs iteration but with pre-filtered data) === + if self.is_multiplex: + state_to_recondition_info = {} + for i, trk_obj_id in enumerate(valid_trk_obj_ids): + for state_idx, inference_state in enumerate(tracker_states_local): + if trk_obj_id in inference_state["obj_ids"]: + if state_idx not in state_to_recondition_info: + state_to_recondition_info[state_idx] = [] + state_to_recondition_info[state_idx].append( + (trk_obj_id, new_masks_binary[i]) + ) + break + + for state_idx, recondition_list in state_to_recondition_info.items(): + inference_state = tracker_states_local[state_idx] + obj_ids_to_recondition = [item[0] for item in recondition_list] + masks_to_recondition = torch.stack( + [item[1] for item in recondition_list] + ) + with torch.profiler.record_function( + "_recodition_masklets.add_new_masks" + ): + self.tracker.add_new_masks( + inference_state=inference_state, + frame_idx=frame_idx, + obj_ids=obj_ids_to_recondition, + masks=masks_to_recondition, + reconditioning=True, + ) + reconditioned_obj_ids.update(inference_state["obj_idx_to_id"].values()) + else: + # Non-multiplex: still iterate but masks already computed + for i, trk_obj_id in enumerate(valid_trk_obj_ids): + for inference_state in tracker_states_local: + if trk_obj_id in inference_state["obj_ids"]: + self.tracker.add_new_mask( + inference_state=inference_state, + frame_idx=frame_idx, + obj_id=trk_obj_id, + mask=new_masks_binary[i], + ) + reconditioned_obj_ids.update( + inference_state["obj_idx_to_id"].values() + ) + break + + return tracker_states_local, reconditioned_obj_ids + + def _deepcopy(self, x): + # If running in prod, dont need to do a deepcopy as we only traverse in 1 direction + if True: + return x + return deepcopy(x) + + def run_tracker_update_planning_phase( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + det_out: Dict[str, Tensor], + det_keep: Tensor, + tracker_low_res_masks_global: Tensor, + tracker_obj_scores_global: Tensor, + tracker_metadata_prev: Dict[str, np.ndarray], + tracker_states_local: List[Any], + is_image_only: bool = False, + ): + # initialize new metadata from previous metadata (its values will be updated later) + with torch.profiler.record_function("initialize_tracker_metadata_new"): + tracker_metadata_new = self._create_planning_metadata(tracker_metadata_prev) + + # Initialize reconditioned_obj_ids early to avoid UnboundLocalError + reconditioned_obj_ids = set() + + # Step 1: make the update plan and resolve heuristics on GPU 0 + det_mask_preds: Tensor = det_out["mask"] # low-res mask logits + det_scores: Tensor = det_out["scores"].float() + # a) match FA and SAM2 masks and find new objects + with torch.profiler.record_function("associate_det_trk"): + adt_result = self._associate_det_trk( + det_masks=det_mask_preds, + det_scores=det_scores, + det_keep=det_keep, + trk_masks=tracker_low_res_masks_global, + trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"], + default_det_thresh=( + self.image_only_det_thresh if is_image_only else None + ), + ) + + # b) handle hotstart heuristics to remove objects (GPU-vectorized, no sync!) + # here `rank0_metadata` contains metadata stored on (and only accessible to) GPU 0; + # we avoid broadcasting them to other GPUs to save communication cost, assuming + # that `rank0_metadata` is not needed by other GPUs + rank0_metadata_new = self._deepcopy(tracker_metadata_prev["rank0_metadata"]) + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + # Call GPU-vectorized hotstart using lazy adt_result (NO realize_adt yet!) + with torch.profiler.record_function("_process_hotstart_gpu"): + to_remove_mask, to_suppress_mask, gpu_metadata_new = ( + self._process_hotstart_gpu( + frame_idx=frame_idx, + reverse=reverse, + adt_result=adt_result, # Still lazy - no sync! + tracker_metadata_prev=tracker_metadata_prev, + gpu_metadata_prev=tracker_metadata_prev["gpu_metadata"], + ) + ) + # IMPORTANT: From this point, tracker_metadata_new["gpu_metadata"] is updated but CPU metadata (obj_ids_all_gpu, etc.) is NOT + tracker_metadata_new["gpu_metadata"] = gpu_metadata_new + else: + # if warm-up is not complete, we don't remove any objects + N_obj = tracker_low_res_masks_global.size(0) + to_remove_mask = torch.zeros( + N_obj, dtype=torch.bool, device=tracker_low_res_masks_global.device + ) + to_suppress_mask = torch.zeros( + N_obj, dtype=torch.bool, device=tracker_low_res_masks_global.device + ) + tracker_metadata_new["rank0_metadata"] = rank0_metadata_new + + # Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding + # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results + should_recondition_iou = False + + # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections + if self.reconstruction_bbox_iou_thresh > 0: + adt_result = realize_adt_result( + adt_result, tracker_metadata_prev, det_mask_preds + ) + if ( + self.reconstruction_bbox_iou_thresh > 0 + and len(adt_result.trk_id_to_max_iou_high_conf_det) > 0 + ): + with torch.profiler.record_function( + "evaluate_reconstruction_bbox_iou_thresh" + ): + trk_obj_ids = adt_result.trk_id_to_max_iou_high_conf_det.keys() + sam2_obj_ids_all_gpu = list(tracker_metadata_prev["obj_ids_all_gpu"]) + trk_ids = [ + sam2_obj_ids_all_gpu.index(trk_obj_id) + for trk_obj_id in trk_obj_ids + if trk_obj_id in sam2_obj_ids_all_gpu + ] + det_ids = list(adt_result.trk_id_to_max_iou_high_conf_det.values()) + + det_boxes_bbox_iou = det_out["bbox"][det_ids] + det_scores_bbox_iou = det_out["scores"][det_ids] + sam2_mask = tracker_low_res_masks_global[trk_ids] + mask_binary = sam2_mask > 0 + sam2_box_pixels = mask_to_box(mask_binary.unsqueeze(1)).squeeze(1) + mask_height, mask_width = sam2_mask.shape[-2:] + sam2_box_normalized = sam2_box_pixels / torch.tensor( + [mask_width, mask_height, mask_width, mask_height], + device=sam2_box_pixels.device, + ) + iou = fast_diag_box_iou(det_boxes_bbox_iou, sam2_box_normalized)[0] + if iou < self.reconstruction_bbox_iou_thresh and torch.any( + det_scores_bbox_iou >= self.reconstruction_bbox_det_score + ): + should_recondition_iou = True + + if ( + self.recondition_every_nth_frame > 0 + and frame_idx % self.recondition_every_nth_frame == 0 + ): + adt_result = realize_adt_result( + adt_result, tracker_metadata_prev, det_mask_preds + ) + + should_recondition_periodic = ( + self.recondition_every_nth_frame > 0 + and frame_idx % self.recondition_every_nth_frame == 0 + and len(adt_result.trk_id_to_max_iou_high_conf_det) > 0 + ) + + # Recondition if periodic or IoU condition met + if should_recondition_periodic or should_recondition_iou: + adt_result = realize_adt_result( + adt_result, tracker_metadata_prev, det_mask_preds + ) + # NOTE: sam2_low_res_mask_global is modified in-place on all GPUs. + with torch.profiler.record_function("_recondition_masklets"): + tracker_states_local, reconditioned_obj_ids = ( + self._recondition_masklets( + frame_idx, + det_out, + adt_result.trk_id_to_max_iou_high_conf_det, + tracker_states_local, + tracker_metadata_prev, + tracker_obj_scores_global, + tracker_low_res_masks_global, + ) + ) + + for state in tracker_states_local: + if any( + obj_id in reconditioned_obj_ids + for obj_id in state.get("obj_ids", []) + ): + self.tracker.propagate_in_video_preflight( + state, run_mem_encoder=True + ) + + # Step 4: Run SAM2 memory encoder on the current frame's prediction masks + # This is done on all GPUs + batch_size = tracker_low_res_masks_global.size(0) + if batch_size > 0: + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0: + # NOTE: tracker_low_res_masks_global is updated in-place then returned + with torch.profiler.record_function( + "_suppress_overlapping_based_on_recent_occlusion" + ): + tracker_low_res_masks_global = ( + self._suppress_overlapping_based_on_recent_occlusion( + frame_idx, + tracker_low_res_masks_global, + tracker_metadata_prev, + tracker_metadata_new, + to_remove_mask, # GPU boolean mask, no sync! + reverse, + ) + ) + with torch.profiler.record_function("_tracker_update_memories"): + self._tracker_update_memories( + tracker_states_local, + frame_idx, + tracker_metadata=tracker_metadata_prev, + low_res_masks=tracker_low_res_masks_global, + ) + + # NOW realize adt_result after memory encoding (sync only for GPU load balancing) + adt_result = realize_adt_result( + adt_result, tracker_metadata_prev, det_mask_preds + ) + new_det_obj_ids, new_det_gpu_ids, num_obj_dropped_due_to_limit = ( + adt_result.get_new_det_gpu_ids( + tracker_metadata_prev, is_image_only, det_scores, self + ) + ) + + # Convert GPU removal mask to CPU obj_id set for metadata updates + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + obj_ids_all_gpu = tracker_metadata_prev["obj_ids_all_gpu"] + to_remove_cpu = to_remove_mask.cpu().numpy() + obj_ids_newly_removed = set(obj_ids_all_gpu[to_remove_cpu].tolist()) + else: + obj_ids_newly_removed = set() + + # Step 4: update the SAM2 metadata based on the update plan + # note: except for "rank0_metadata" (that is only available on GPU 0), + # the updated `tracker_metadata_new` should be identical on all GPUs + for rank in range(self.world_size): + new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank] + updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank] + if len(new_det_obj_ids_this_gpu) > 0: + updated_obj_ids_this_gpu = np.concatenate( + [updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu] + ) + if len(obj_ids_newly_removed) > 0: + is_removed = np.isin( + updated_obj_ids_this_gpu, list(obj_ids_newly_removed) + ) + updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed] + tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu + tracker_metadata_new["num_obj_per_gpu"][rank] = len( + updated_obj_ids_this_gpu + ) + tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate( + tracker_metadata_new["obj_ids_per_gpu"] + ) + # update object scores and the maximum object ID assigned so far + if len(new_det_obj_ids) > 0: + det_scores_np: np.ndarray = det_scores.cpu().numpy() + tracker_metadata_new["obj_id_to_score"].update( + zip(new_det_obj_ids, det_scores_np[adt_result.new_det_fa_inds]) + ) + # sam2 scores are not available for new objects, use det score instead. + # Store as GPU tensors for consistency with SAM2 propagation scores + new_det_scores_tensor = det_scores[adt_result.new_det_fa_inds] + tracker_metadata_new["obj_id_to_sam2_score_frame_wise"][frame_idx].update( + zip(new_det_obj_ids, new_det_scores_tensor) + ) + tracker_metadata_new["max_obj_id"] = max( + tracker_metadata_new["max_obj_id"], + np.max(new_det_obj_ids), + ) + # for removed objects, we set their scores to a very low value (-1e4) but still + # keep them in "obj_id_to_score" (it's easier to handle outputs this way) + for obj_id in obj_ids_newly_removed: + tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4 + # Store as GPU tensor for consistency + tracker_metadata_new["obj_id_to_sam2_score_frame_wise"][frame_idx][ + obj_id + ] = torch.tensor(-1e4, dtype=torch.float32, device=det_scores.device) + tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None) + # check that "rank0_metadata" is in tracker_metadata_new if and only if it's GPU 0 + assert "rank0_metadata" in tracker_metadata_new + if self.masklet_confirmation_enable: + with torch.profiler.record_function("update_masklet_confirmation_status"): + rank0_metadata = self.update_masklet_confirmation_status( + rank0_metadata=tracker_metadata_new["rank0_metadata"], + obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"], + obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"], + det_to_matched_trk_obj_ids=adt_result.det_to_matched_trk_obj_ids, + new_det_obj_ids=new_det_obj_ids, + ) + tracker_metadata_new["rank0_metadata"] = rank0_metadata + + # Compact GPU metadata NOW (after sync) in preparation for next frame + # This removes entries for objects that will be deleted in execution phase + # so next frame's _process_hotstart_gpu doesn't need to do sync-inducing compaction + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + if ( + "gpu_metadata" in tracker_metadata_new + and tracker_metadata_new["gpu_metadata"].get("N_obj", 0) > 0 + ): + with torch.profiler.record_function("compact_gpu_metadata"): + gpu_meta = tracker_metadata_new["gpu_metadata"] + removed_mask = gpu_meta[ + "removed_mask" + ] # (N_obj,) - which objects marked for removal + keep_indices = torch.nonzero(~removed_mask, as_tuple=True)[0] + + gpu_meta["obj_first_frame"] = gpu_meta["obj_first_frame"][ + keep_indices + ] + gpu_meta["consecutive_unmatch_count"] = gpu_meta[ + "consecutive_unmatch_count" + ][keep_indices] + gpu_meta["trk_keep_alive"] = gpu_meta["trk_keep_alive"][ + keep_indices + ] + gpu_meta["removed_mask"] = gpu_meta["removed_mask"][ + keep_indices + ] # Should be all False + gpu_meta["last_occluded_tensor"] = gpu_meta["last_occluded_tensor"][ + keep_indices + ] + + # Compact pairwise matrix (remove both rows and columns) + overlap_counts = gpu_meta["overlap_pair_counts"] + overlap_counts = overlap_counts[keep_indices][:, keep_indices] + gpu_meta["overlap_pair_counts"] = overlap_counts + + # Update N_obj to reflect post-removal count + gpu_meta["N_obj"] = keep_indices.size(0) + + # After compaction, extend gpu_metadata with new objects' initial values + # This ensures obj_first_frame is set to the detection frame, not propagation frame + num_new = len(new_det_obj_ids) + if num_new > 0: + with torch.profiler.record_function( + "extend_gpu_metadata_for_new_objects" + ): + gpu_meta = tracker_metadata_new["gpu_metadata"] + device = det_scores.device + NEVER_OCCLUDED = -1 + + # Extend all metadata tensors for new objects + gpu_meta["obj_first_frame"] = torch.cat( + [ + gpu_meta.get( + "obj_first_frame", + torch.empty(0, dtype=torch.long, device=device), + ), + torch.full( + (num_new,), frame_idx, dtype=torch.long, device=device + ), + ] + ) + gpu_meta["consecutive_unmatch_count"] = torch.cat( + [ + gpu_meta.get( + "consecutive_unmatch_count", + torch.empty(0, dtype=torch.long, device=device), + ), + torch.zeros(num_new, dtype=torch.long, device=device), + ] + ) + gpu_meta["trk_keep_alive"] = torch.cat( + [ + gpu_meta.get( + "trk_keep_alive", + torch.empty(0, dtype=torch.long, device=device), + ), + torch.full( + (num_new,), + self.init_trk_keep_alive, + dtype=torch.long, + device=device, + ), + ] + ) + gpu_meta["removed_mask"] = torch.cat( + [ + gpu_meta.get( + "removed_mask", + torch.empty(0, dtype=torch.bool, device=device), + ), + torch.zeros(num_new, dtype=torch.bool, device=device), + ] + ) + gpu_meta["last_occluded_tensor"] = torch.cat( + [ + gpu_meta.get( + "last_occluded_tensor", + torch.empty(0, dtype=torch.long, device=device), + ), + torch.full( + (num_new,), + NEVER_OCCLUDED, + dtype=torch.long, + device=device, + ), + ] + ) + + # Grow overlap matrix + old_N = gpu_meta.get("N_obj", 0) + new_N = old_N + num_new + old_overlap = gpu_meta.get( + "overlap_pair_counts", + torch.zeros((0, 0), dtype=torch.long, device=device), + ) + new_overlap = torch.zeros( + (new_N, new_N), dtype=torch.long, device=device + ) + if old_N > 0: + new_overlap[:old_N, :old_N] = old_overlap + gpu_meta["overlap_pair_counts"] = new_overlap + + gpu_meta["N_obj"] = new_N + + sam2_update_plan = { + "new_det_fa_inds": adt_result.new_det_fa_inds, # np.ndarray + "new_det_obj_ids": new_det_obj_ids, # np.ndarray + "new_det_gpu_ids": new_det_gpu_ids, # np.ndarray + "unmatched_trk_obj_ids": adt_result.unmatched_trk_obj_ids, # np.ndarray + "det_to_matched_trk_obj_ids": adt_result.det_to_matched_trk_obj_ids, # dict + "obj_ids_newly_removed": obj_ids_newly_removed, # set + "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit, # int + "trk_id_to_max_iou_high_conf_det": adt_result.trk_id_to_max_iou_high_conf_det, # dict + "reconditioned_obj_ids": reconditioned_obj_ids, # set + } + return sam2_update_plan, tracker_metadata_new + + def _suppress_overlapping_based_on_recent_occlusion( + self, + frame_idx: int, + tracker_low_res_masks_global: Tensor, + tracker_metadata_prev: Dict[str, Any], + tracker_metadata_new: Dict[str, Any], + to_remove_mask: Tensor, # GPU boolean mask (N_obj,) instead of CPU set + reverse: bool = False, + ): + """ + Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object. + Args: + frame_idx (int): The current frame index. + tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame. + tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame. + tracker_metadata_new (Dict[str, Any]): The metadata for the current frame (with updated gpu_metadata from _process_hotstart_gpu). + to_remove_mask (Tensor): GPU boolean mask (N_obj,) indicating which objects are removed. + Return: + Tensor: The updated low-resolution masks with some objects suppressed. + """ + # NOTE: obj_ids_global is only used for debug logging, so we can use prev (it won't match perfectly but close enough for debugging) + # The actual suppression logic uses GPU tensors which ARE in the correct index space from tracker_metadata_new + obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"] + binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 + batch_size = tracker_low_res_masks_global.size(0) + num_ids = len(obj_ids_global) + + # immediately to force proper debugging. (Aligned with merge decision 4.5.2) + assert batch_size == num_ids, ( + f"Mask/metadata count mismatch in _suppress_overlapping: " + f"batch_size={batch_size}, num_ids={num_ids}, frame_idx={frame_idx}" + ) + + binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 + if batch_size > 0: + assert ( + len(obj_ids_global) == batch_size + ), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" + NEVER_OCCLUDED = -1 + ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic + + # GPU-vectorized: Build last_occluded_prev tensor without iteration/syncs + device = binary_tracker_low_res_masks_global.device + + # Get last_occluded from UPDATED gpu_metadata (already in correct index space from _process_hotstart_gpu) + gpu_metadata_new = tracker_metadata_new["gpu_metadata"] + last_occluded_prev = gpu_metadata_new["last_occluded_tensor"] + + # Sanity check: ensure last_occluded_tensor is in sync with batch_size + assert last_occluded_prev.size(0) == batch_size, ( + f"last_occluded_tensor size mismatch: {last_occluded_prev.size(0)} vs {batch_size}. " + f"This indicates gpu_metadata tensors are out of sync." + ) + + # Set ALWAYS_OCCLUDED for removed objects (fully vectorized, no sync!) + last_occluded_prev = torch.where( + to_remove_mask, + torch.full_like(last_occluded_prev, ALWAYS_OCCLUDED), + last_occluded_prev, + ) + + to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded( + binary_tracker_low_res_masks_global, + last_occluded_prev, + obj_ids_global, + frame_idx, + reverse, + ) + + # Update metadata with occlusion information (fully vectorized) + is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2))) + is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress + last_occluded_new = last_occluded_prev.clone() + last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx + + # Store in gpu_metadata to keep it aligned with other metadata tensors + tracker_metadata_new["gpu_metadata"]["last_occluded_tensor"] = ( + last_occluded_new + ) + + # Also maintain legacy dict format for backwards compatibility + # This conversion happens on CPU AFTER memory encoding, not in critical path + tracker_metadata_new[ + "obj_id_to_last_occluded" + ] = {} # Will be populated later if needed + + # Zero out suppressed masks before memory encoding + NO_OBJ_LOGIT = -10 + tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT + + return tracker_low_res_masks_global + + def _create_planning_metadata(self, tracker_metadata_prev): + """Extend planning metadata with multiplex-specific fields.""" + metadata = super()._create_planning_metadata(tracker_metadata_prev) + if self.is_multiplex: + metadata["num_buc_per_gpu"] = self._deepcopy( + tracker_metadata_prev["num_buc_per_gpu"] + ) + metadata["gpu_metadata"] = tracker_metadata_prev["gpu_metadata"] + return metadata + + def _post_execution_phase_hook(self, tracker_states_local, tracker_metadata_new): + """Update bucket count after execution phase (multiplex-specific).""" + if self.is_multiplex and tracker_metadata_new is not None: + actual_bucket_count = self._count_buckets_in_states(tracker_states_local) + tracker_metadata_new["num_buc_per_gpu"][self.rank] = actual_bucket_count + + def _count_buckets_in_states(self, tracker_states_local: List[Any]) -> int: + """Count the total number of buckets across all states.""" + if not self.is_multiplex: + return 0 + total_buckets = 0 + for state in tracker_states_local: + if "multiplex_state" in state: + total_buckets += state["multiplex_state"].num_buckets + return total_buckets + + def build_outputs( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + det_out: Dict[ + str, Tensor + ], # TODO: Only det_out["mask"][new_det_fa_inds_local_t] is needed + tracker_low_res_masks_global: Tensor, + tracker_obj_scores_global: Tensor, + tracker_metadata_prev: Dict[str, np.ndarray], + sam2_update_plan: Dict[str, np.ndarray], + orig_vid_height: int, + orig_vid_width: int, + reconditioned_obj_ids: set = None, + det_to_matched_trk_obj_ids: dict = None, + ): + new_det_fa_inds: np.ndarray = sam2_update_plan["new_det_fa_inds"] + new_det_obj_ids: np.ndarray = sam2_update_plan["new_det_obj_ids"] + obj_id_to_mask = {} # obj_id --> output mask tensor + + # Part 1: masks from previous SAM2 propagation + # Align IDs and masks from previous SAM2 propagation + existing_masklet_obj_ids_all = tracker_metadata_prev["obj_ids_all_gpu"] + existing_masklet_obj_ids_per_gpu = np.concatenate( + tracker_metadata_prev["obj_ids_per_gpu"] + ) + use_per_gpu_ids = len(existing_masklet_obj_ids_per_gpu) != len( + existing_masklet_obj_ids_all + ) or not np.array_equal( + existing_masklet_obj_ids_per_gpu, existing_masklet_obj_ids_all + ) + existing_masklet_obj_ids = ( + existing_masklet_obj_ids_per_gpu + if use_per_gpu_ids + else existing_masklet_obj_ids_all + ) + existing_masklet_video_res_masks = F.interpolate( + tracker_low_res_masks_global.unsqueeze(1), + size=(orig_vid_height, orig_vid_width), + mode="bilinear", + align_corners=False, + ) # (num_obj, 1, H_video, W_video) + # Pad/truncate masks to match metadata count + num_masks = existing_masklet_video_res_masks.size(0) + num_ids = len(existing_masklet_obj_ids) + if num_masks != num_ids: + if num_masks < num_ids: + pad = existing_masklet_video_res_masks.new_zeros( + (num_ids - num_masks, 1, orig_vid_height, orig_vid_width) + ) + existing_masklet_video_res_masks = torch.cat( + [existing_masklet_video_res_masks, pad], dim=0 + ) + else: + existing_masklet_video_res_masks = existing_masklet_video_res_masks[ + :num_ids + ] + existing_masklet_binary = existing_masklet_video_res_masks > 0 + for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary): + obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) + + # Part 2: masks from new detections + new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds) + new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1) + new_det_low_res_masks = fill_holes_in_mask_scores( + new_det_low_res_masks, + fill_hole_area=self.fill_hole_area, + sprinkle_removal_area=self.sprinkle_removal_area, + fill_holes=True, + remove_sprinkles=True, + ) + new_masklet_video_res_masks = F.interpolate( + new_det_low_res_masks, + size=(orig_vid_height, orig_vid_width), + mode="bilinear", + align_corners=False, + ) # (num_obj, 1, H_video, W_video) + + new_masklet_binary = new_masklet_video_res_masks > 0 + assert len(new_det_obj_ids) == len(new_masklet_video_res_masks) + for obj_id, mask in zip(new_det_obj_ids, new_masklet_binary): + obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) + + return obj_id_to_mask + + def _get_objects_to_suppress_based_on_most_recently_occluded( + self, + binary_low_res_masks: Tensor, + last_occluded: Tensor, # GPU tensor (N_obj,) with frame indices + obj_ids: np.ndarray, # numpy array of object IDs + frame_idx: int = None, + reverse: bool = False, + ): + # Suppress overlapping masks for objects that were most recently occluded + assert ( + binary_low_res_masks.dtype == torch.bool + ), f"Expected boolean tensor, got {binary_low_res_masks.dtype}" + to_suppress = torch.zeros( + binary_low_res_masks.size(0), + device=binary_low_res_masks.device, + dtype=torch.bool, + ) + if len(obj_ids) <= 1: + return to_suppress + + iou = mask_iou(binary_low_res_masks, binary_low_res_masks) # [N,N] + + # Create masks for upper triangular matrix (i < j) and IoU threshold + mask_iou_thresh = ( + iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold + ) + overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1) # [N,N] + + last_occ_expanded_i = last_occluded.unsqueeze(1) # (N, 1) + last_occ_expanded_j = last_occluded.unsqueeze(0) # (1, N) + cmp_op = torch.gt if not reverse else torch.lt + + if self.allow_unoccluded_to_suppress: + # Suppress most recently occluded + suppress_i_mask = overlapping_pairs & cmp_op( + last_occ_expanded_i, last_occ_expanded_j + ) + + suppress_j_mask = overlapping_pairs & cmp_op( + last_occ_expanded_j, last_occ_expanded_i + ) + else: + # Suppress most recently occluded + suppress_i_mask = ( + overlapping_pairs + & cmp_op( + last_occ_expanded_i, last_occ_expanded_j + ) # (last_occ_expanded_i > last_occ_expanded_j) + & (last_occ_expanded_j > -1) + # j can suppress i only if j was previously occluded + ) + + suppress_j_mask = ( + overlapping_pairs + & cmp_op(last_occ_expanded_j, last_occ_expanded_i) + & ( + last_occ_expanded_i > -1 + ) # i can suppress j only if i was previously occluded + ) + + # Apply suppression + to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0) + + # Log for debugging + if ( + self.rank == 0 + and logger.isEnabledFor(logging.DEBUG) + and frame_idx is not None + ): + suppress_i_mask = suppress_i_mask.cpu().numpy() + suppress_j_mask = suppress_j_mask.cpu().numpy() + last_occluded = last_occluded.cpu().numpy() + + # Find all suppression pairs without using torch.where + batch_size = suppress_i_mask.shape[0] + + # Log i-suppression cases (where i gets suppressed in favor of j) + for i in range(batch_size): + for j in range(batch_size): + if suppress_i_mask[i, j]: + logger.debug( + f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}" + ) + + # Log j-suppression cases (where j gets suppressed in favor of i) + for i in range(batch_size): + for j in range(batch_size): + if suppress_j_mask[i, j]: + logger.debug( + f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}" + ) + + return to_suppress + + def _propogate_tracker_one_frame_local_gpu( + self, + inference_states: List[Any], + frame_idx: int, + reverse: bool, + # by default, we disable memory encoding until we gather all outputs + run_mem_encoder: bool = False, + # When specified, only return masks/scores for these object ids + filter_obj_ids: Optional[List[int]] = None, + ): + """ + inference_states: List of inference states, each state corresponds to a different set of objects. + """ + obj_ids_local = [] + low_res_masks_list = [] + obj_scores_list = [] + for inference_state in inference_states: + if len(inference_state["obj_ids"]) == 0: + continue # skip propagation on empty inference states + + # propagate one frame + num_frames_propagated = 0 + with torch.profiler.record_function("sam2_predictor.propagate_in_video"): + for out in self.tracker.propagate_in_video( + inference_state, + start_frame_idx=frame_idx, + # end_frame_idx = start_frame_idx + max_frame_num_to_track + # (i.e. propagating 1 frame since end_frame_idx is inclusive) + max_frame_num_to_track=0, + reverse=reverse, + tqdm_disable=True, + run_mem_encoder=run_mem_encoder, + ): + # TODO we only need low-res outputs here for all-gather across GPUs, + # so we can remove the high-res interpolation in `propagate_in_video` + out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = ( + out + ) + num_frames_propagated += 1 + + # only 1 frames should be propagated + assert ( + num_frames_propagated == 1 and out_frame_idx == frame_idx + ), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}" + assert isinstance(out_obj_ids, list) + # Optionally filter to a subset of object ids (for partial propagation). + # We also clamp indices to available rows to avoid CUDA index_select assertions. + if filter_obj_ids is not None: + if len(out_obj_ids) > 0: + max_mask_rows = out_low_res_masks.shape[0] + max_score_rows = out_obj_scores.shape[0] + # Special case: common single-object refinement path where SAM2 returns a single mask row + # but a longer out_obj_ids list for the state. Treat the lone row as the requested object. + if ( + len(filter_obj_ids) == 1 + and max_mask_rows == 1 + and max_score_rows == 1 + ): + out_obj_ids = [filter_obj_ids[0]] + keep_indices = [0] + else: + keep_indices = [ + i + for i, oid in enumerate(out_obj_ids) + if oid in filter_obj_ids + and i < max_mask_rows + and i < max_score_rows + ] + else: + keep_indices = [] + if len(keep_indices) > 0: + idx_tensor = torch.as_tensor( + keep_indices, device=out_low_res_masks.device, dtype=torch.long + ) + out_low_res_masks = out_low_res_masks.index_select( + dim=0, index=idx_tensor + ) + out_obj_scores = out_obj_scores.index_select( + dim=0, index=idx_tensor + ) + out_obj_ids = [out_obj_ids[i] for i in keep_indices] + else: + # no selected objects in this local state; skip appending + out_obj_ids = [] + + if len(out_obj_ids) > 0: + obj_ids_local.extend(out_obj_ids) + low_res_masks_list.append(out_low_res_masks.squeeze(1)) + obj_scores_list.append(out_obj_scores.squeeze(1)) + + # concatenate the output masklets from all local inference states + + with torch.profiler.record_function( + "sam2_predictor.propagate_in_video.fill_holes" + ): + H_mask = W_mask = self.tracker.low_res_mask_size + if len(low_res_masks_list) > 0: + low_res_masks_local = torch.cat(low_res_masks_list, dim=0) + obj_scores_local = torch.cat(obj_scores_list, dim=0) + assert low_res_masks_local.shape[1:] == (H_mask, W_mask) + + # Apply hole filling to the masks + low_res_masks_local = fill_holes_in_mask_scores( + low_res_masks_local.unsqueeze(1), + fill_hole_area=self.fill_hole_area, + sprinkle_removal_area=self.sprinkle_removal_area, + fill_holes=True, + remove_sprinkles=True, + ) + low_res_masks_local = low_res_masks_local.squeeze(1) + else: + low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device) + obj_scores_local = torch.zeros(0, device=self.device) + + if self.is_multiplex and self.tracker.is_multiplex_dynamic: + # obj_ids_local might not be sorted, which is problematic because + # the rest of the code assumes that they are. + # Currently this only happens in the dynamic multiplex setting (since we backfill states) + # so we only check for this condition here, but this should be generally applicable. + # Note that a similar remapping is necessary when we update the memory, e.g., + # in _tracker_update_memories + if obj_ids_local != sorted(obj_ids_local): + # Get sorting permutation + sort_indices = sorted( + range(len(obj_ids_local)), key=lambda i: obj_ids_local[i] + ) + # Apply permutation to reorder everything + obj_ids_local = [obj_ids_local[i] for i in sort_indices] + low_res_masks_local = low_res_masks_local[sort_indices] + obj_scores_local = obj_scores_local[sort_indices] + + if self.is_multiplex and self.tracker.is_multiplex_dynamic: + # obj_ids_local might not be sorted, which is problematic because + # the rest of the code assumes that they are. + # Currently this only happens in the dynamic multiplex setting (since we backfill states) + # so we only check for this condition here, but this should be generally applicable. + # Note that a similar remapping is necessary when we update the memory, e.g., + # in _tracker_update_memories + if obj_ids_local != sorted(obj_ids_local): + # Get sorting permutation + sort_indices = sorted( + range(len(obj_ids_local)), key=lambda i: obj_ids_local[i] + ) + # Apply permutation to reorder everything + obj_ids_local = [obj_ids_local[i] for i in sort_indices] + if low_res_masks_local.shape[0] == len(sort_indices): + low_res_masks_local = low_res_masks_local[sort_indices] + obj_scores_local = obj_scores_local[sort_indices] + + return obj_ids_local, low_res_masks_local, obj_scores_local + + def _associate_det_trk( + self, + det_masks: Tensor, + det_scores: Tensor, + det_keep: Tensor, + trk_masks: Tensor, + trk_obj_ids: np.ndarray, + default_det_thresh: Optional[float] = None, + ): + """ + Match detections on the current frame with the existing masklets. + + Args: + - det_masks: (N, H, W) tensor of predicted masks + - det_scores: (N,) array of detection scores + - trk_masks: (M, H, W) tensor of track masks + - trk_obj_ids: (M,) array of object IDs corresponding to trk_masks + + Returns: + - new_det_fa_inds: array of new object indices among in FA detection outputs + - unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched + to any detections on this frame (for unmatched, we only count masklets with >0 area) + - det_to_matched_trk_obj_ids: dict[int, np.ndarray]: mapping from FA detection indices + to the list of matched tracklet object IDs + - empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction + """ + HIGH_CONF_THRESH = 0.8 + + iou_threshold = self.assoc_iou_thresh + iou_threshold_trk = self.trk_assoc_iou_thresh + new_det_thresh = ( + self.new_det_thresh if default_det_thresh is None else default_det_thresh + ) + + assert det_masks.is_floating_point(), "float tensor expected (do not binarize)" + assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)" + assert ( + trk_masks.size(0) == len(trk_obj_ids) + ), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" + if trk_masks.size(0) == 0: + with torch.profiler.record_function("No tracklets"): + num_trk = 0 + is_new_det = det_scores >= new_det_thresh + trk_is_unmatched = torch.zeros( + num_trk, dtype=torch.bool, device=det_scores.device + ) + trk_is_nonempty = torch.zeros( + num_trk, dtype=torch.bool, device=det_scores.device + ) + num_det = det_scores.shape[0] + det_to_max_iou_trk_idx = torch.full( + (num_det,), -1, dtype=torch.long, device=det_scores.device + ) + det_is_high_conf = det_scores >= HIGH_CONF_THRESH + det_is_high_iou = torch.zeros( + num_det, dtype=torch.bool, device=det_scores.device + ) + im_mask = torch.zeros( + num_det, num_trk, dtype=torch.bool, device=det_scores.device + ) + return LazyAssociateDetTrkResult( + trk_is_unmatched, + trk_is_nonempty, + is_new_det, + det_to_max_iou_trk_idx, + det_is_high_conf, + det_is_high_iou, + det_keep, + im_mask, + ) + elif det_masks.size(0) == 0: + with torch.profiler.record_function("No detections"): + assert det_keep.size(0) == 0 # Make sure the keep mask agrees + trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)) + num_det = 0 + num_trk = trk_masks.shape[0] + trk_is_unmatched = torch.ones( + num_trk, dtype=torch.bool, device=trk_masks.device + ) + trk_is_nonempty_tensor = trk_is_nonempty.to(trk_masks.device) + is_new_det = torch.zeros( + num_det, dtype=torch.bool, device=trk_masks.device + ) + det_to_max_iou_trk_idx = torch.full( + (num_det,), -1, dtype=torch.long, device=trk_masks.device + ) + det_is_high_conf = torch.zeros( + num_det, dtype=torch.bool, device=trk_masks.device + ) + det_is_high_iou = torch.zeros( + num_det, dtype=torch.bool, device=trk_masks.device + ) + im_mask = torch.zeros( + num_det, num_trk, dtype=torch.bool, device=trk_masks.device + ) + return LazyAssociateDetTrkResult( + trk_is_unmatched, + trk_is_nonempty_tensor, + is_new_det, + det_to_max_iou_trk_idx, + det_is_high_conf, + det_is_high_iou, + det_keep, + im_mask, + ) + + if det_masks.shape[-2:] != trk_masks.shape[-2:]: + # resize to the smaller size to save GPU memory + if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]): + trk_masks = F.interpolate( + trk_masks.unsqueeze(1), + size=det_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ).squeeze(1) + else: + # resize detections to track size + det_masks = F.interpolate( + det_masks.unsqueeze(1), + size=trk_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ).squeeze(1) + + with torch.profiler.record_function("associate_det_trk_compilable"): + if trk_masks.shape[0] < self.max_num_objects: + padding_size = self.max_num_objects - trk_masks.shape[0] + trk_masks_padded = torch.cat( + [ + trk_masks, + torch.zeros( + padding_size, + *trk_masks.shape[1:], + device=trk_masks.device, + dtype=trk_masks.dtype, + ), + ], + dim=0, + ) + else: + trk_masks_padded = trk_masks + result = _associate_det_trk_compilable( + det_masks, + det_scores, + det_keep, + trk_masks_padded, + new_det_thresh, + iou_threshold_trk, + iou_threshold, + HIGH_CONF_THRESH, + self.use_iom_recondition, + self.o2o_matching_masklets_enable, + self.iom_thresh_recondition, + self.iou_thresh_recondition, + ) + ( + trk_is_unmatched, + trk_is_nonempty, + is_new_det, + det_to_max_iou_trk_idx, + det_is_high_conf, + det_is_high_iou, + det_keep, + im_mask, + ) = result + trk_is_unmatched = trk_is_unmatched[: trk_masks.shape[0]] + trk_is_nonempty = trk_is_nonempty[: trk_masks.shape[0]] + im_mask = im_mask[:, : trk_masks.shape[0]] + + return LazyAssociateDetTrkResult( + trk_is_unmatched, + trk_is_nonempty, + is_new_det, + det_to_max_iou_trk_idx, + det_is_high_conf, + det_is_high_iou, + det_keep, + im_mask, + ) + + def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu): + """Distribute the new objects to the GPUs with the least workload.""" + workload_per_gpu: np.ndarray = prev_workload_per_gpu.copy() + new_det_gpu_ids = np.zeros(new_det_num, np.int64) + + if self.is_multiplex: + # assign the objects in a batch of multiplex_count + for i in range(0, new_det_num, self.bucket_capacity): + # find the GPU with the least workload + min_gpu = np.argmin(workload_per_gpu) + new_det_gpu_ids[i : i + self.bucket_capacity] = min_gpu + workload_per_gpu[min_gpu] += 1 + else: + # assign the objects one by one + for i in range(len(new_det_gpu_ids)): + # find the GPU with the least workload + min_gpu = np.argmin(workload_per_gpu) + new_det_gpu_ids[i] = min_gpu + workload_per_gpu[min_gpu] += 1 + return new_det_gpu_ids + + def _process_hotstart_gpu( + self, + frame_idx: int, + reverse: bool, + adt_result, # LazyAssociateDetTrkResult (always lazy now) + tracker_metadata_prev: Dict[str, Any], + gpu_metadata_prev: Dict[str, Tensor], + ) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]: + """ + Compute removal/suppression masks entirely on GPU without ANY syncs or branches. + + Uses position-indexed metadata (indexed 0 to N_obj-1) instead of obj_id-indexed + to avoid needing obj_ids as GPU tensor. + + Returns: + to_remove: boolean tensor (N_obj,) - objects to remove this frame + to_suppress: boolean tensor (N_obj,) - objec ts to suppress (overlap suppression) + gpu_metadata_new: updated GPU metadata for next frame + """ + # Handle edge case: if adt_result is already realized (no tracks exist), + # return empty masks since there's nothing to remove + if isinstance(adt_result, RealizedAssociateDetTrkresult): + # No tracks exist, so no objects to remove/suppress + empty_mask = torch.zeros(0, dtype=torch.bool, device=self.device) + return empty_mask, empty_mask, {"N_obj": 0} + + device = adt_result.trk_is_unmatched.device + N_obj = adt_result.trk_is_unmatched.size(0) # Number of current objects + + # ============================================================================ + # STEP 1: Initialize/extract position-indexed GPU metadata + # ============================================================================ + + # All metadata tensors are indexed by POSITION (0 to N_obj-1), not by obj_id + # This grows/shrinks each frame as objects are added/removed + + # Get previous frame's metadata (sized for previous N_obj) + # NOTE: Metadata is already compacted from previous frame (removed objects are already filtered out) + prev_N_obj = gpu_metadata_prev.get("N_obj", 0) + + if prev_N_obj > 0: + # Metadata from previous frame (position-indexed, already compacted) + obj_first_frame_prev = gpu_metadata_prev["obj_first_frame"] # (prev_N_obj,) + consecutive_unmatch_count_prev = gpu_metadata_prev[ + "consecutive_unmatch_count" + ] # (prev_N_obj,) + trk_keep_alive_prev = gpu_metadata_prev["trk_keep_alive"] # (prev_N_obj,) + removed_mask_prev = gpu_metadata_prev[ + "removed_mask" + ] # (prev_N_obj,) - should be all False after compaction + overlap_pair_counts_prev = gpu_metadata_prev[ + "overlap_pair_counts" + ] # (prev_N_obj, prev_N_obj) + last_occluded_prev = gpu_metadata_prev[ + "last_occluded_tensor" + ] # (prev_N_obj,) + else: + # First frame - no previous metadata + obj_first_frame_prev = None + consecutive_unmatch_count_prev = None + trk_keep_alive_prev = None + removed_mask_prev = None + overlap_pair_counts_prev = None + last_occluded_prev = None + + # ============================================================================ + # STEP 2: Carry forward metadata from previous frame + # ============================================================================ + + # Current frame has N_obj objects (from propagation) + # New objects are added via extend_gpu_metadata_for_new_objects AFTER compaction, + # so prev_N_obj should already include objects detected on previous frame. + # N_obj should equal prev_N_obj (no new objects mid-planning-phase). + assert ( + N_obj == prev_N_obj + ), f"N_obj ({N_obj}) should equal prev_N_obj ({prev_N_obj}); new objects handled after compaction" + + # Carry forward existing metadata (or initialize if first frame) + NEVER_OCCLUDED = -1 + obj_first_frame = ( + obj_first_frame_prev + if obj_first_frame_prev is not None + else torch.full((N_obj,), frame_idx, dtype=torch.long, device=device) + ) + consecutive_unmatch_count = ( + consecutive_unmatch_count_prev + if consecutive_unmatch_count_prev is not None + else torch.zeros(N_obj, dtype=torch.long, device=device) + ) + trk_keep_alive = ( + trk_keep_alive_prev + if trk_keep_alive_prev is not None + else torch.zeros(N_obj, dtype=torch.long, device=device) + ) + removed_mask = ( + removed_mask_prev + if removed_mask_prev is not None + else torch.zeros(N_obj, dtype=torch.bool, device=device) + ) + overlap_pair_counts = ( + overlap_pair_counts_prev + if overlap_pair_counts_prev is not None + else torch.zeros((N_obj, N_obj), dtype=torch.long, device=device) + ) + last_occluded = ( + last_occluded_prev + if last_occluded_prev is not None + else torch.full((N_obj,), NEVER_OCCLUDED, dtype=torch.long, device=device) + ) + + # ============================================================================ + # STEP 3: Update keep-alive counters (fully vectorized) + # ============================================================================ + + # Determine which tracks are matched by ANY detection + trk_is_matched = adt_result.im_mask.any(dim=0) # (N_obj,) + + # Update: +1 for matched, -1 for unmatched, clamp to [min, max] + trk_keep_alive = torch.where( + trk_is_matched, trk_keep_alive + 1, trk_keep_alive - 1 + ) + trk_keep_alive = torch.clamp( + trk_keep_alive, min=self.min_trk_keep_alive, max=self.max_trk_keep_alive + ) + + # Also decrement for empty tracklets (if configured) + if self.decrease_trk_keep_alive_for_empty_masklets: + trk_keep_alive = torch.where( + ~adt_result.trk_is_nonempty, + torch.clamp(trk_keep_alive - 1, min=self.min_trk_keep_alive), + trk_keep_alive, + ) + + # ============================================================================ + # STEP 4: Update total unmatch counters (fully vectorized) + # ============================================================================ + + # Increment for unmatched, but DON'T reset for matched + # Original logic accumulates total unmatched frames, not consecutive + consecutive_unmatch_count = torch.where( + adt_result.trk_is_unmatched, + consecutive_unmatch_count + 1, + consecutive_unmatch_count, # Keep previous value, don't reset + ) + + # ============================================================================ + # STEP 5: Update pairwise overlap tracking (fully vectorized) + # ============================================================================ + + # Find detections matched by multiple tracks + tracks_per_det = adt_result.im_mask.sum(dim=1) # (N_det,) + multi_match_mask = tracks_per_det > 1 # (N_det,) + + # Build overlap increment matrix using einsum + multi_match_tracks = adt_result.im_mask & multi_match_mask.unsqueeze( + 1 + ) # (N_det, N_obj) + + # Compute pairwise overlaps: for each detection, outer product of matched tracks + pairwise_overlap_this_frame = torch.einsum( + "di,dj->dij", multi_match_tracks.float(), multi_match_tracks.float() + ) # (N_det, N_obj, N_obj) + + # Sum across detections + overlap_increment = pairwise_overlap_this_frame.sum(dim=0) # (N_obj, N_obj) + overlap_increment.fill_diagonal_(0) # No self-overlap + overlap_increment = torch.triu( + overlap_increment, diagonal=1 + ) # Upper triangle only + + # Add this frame's increments (accumulate across frames, don't reset) + # Original logic: overlap_pair_to_frame_inds[key].append(frame_idx) - never clears + overlap_pair_counts = overlap_pair_counts + overlap_increment.long() + + # ============================================================================ + # STEP 6: Compute removal decisions - UNMATCH criterion (fully vectorized) + # ============================================================================ + + # Hotstart boundary + hotstart_diff = ( + frame_idx - self.hotstart_delay + if not reverse + else frame_idx + self.hotstart_delay + ) + + # Check if objects are within hotstart window + is_within_hotstart = ( + (obj_first_frame > hotstart_diff) + if not reverse + else (obj_first_frame < hotstart_diff) + ) + + # Remove if: within hotstart AND unmatched >= threshold AND not already removed + remove_by_unmatch = ( + is_within_hotstart + & (consecutive_unmatch_count >= self.hotstart_unmatch_thresh) + & ~removed_mask + ) + + # Suppress if: keep_alive <= 0 AND not hotstart-only mode AND not removed + suppress_by_unmatch = ( + (trk_keep_alive <= 0) + & torch.tensor(not self.suppress_unmatched_only_within_hotstart) + .pin_memory() + .to(device=device, non_blocking=True) + & ~removed_mask + & ~remove_by_unmatch + ) + + # ============================================================================ + # STEP 7: Compute removal decisions - OVERLAP criterion (fully vectorized) + # ============================================================================ + + # For each object, find max overlap count with any EARLIER object + # "Earlier" = appeared in an earlier frame + + # Build matrix: is_earlier[i, j] = True if object i appeared before object j + first_frames_i = obj_first_frame.unsqueeze(1) # (N_obj, 1) + first_frames_j = obj_first_frame.unsqueeze(0) # (1, N_obj) + + if not reverse: + is_earlier_matrix = first_frames_i < first_frames_j # (N_obj, N_obj) + else: + is_earlier_matrix = first_frames_i > first_frames_j # (N_obj, N_obj) + + # ============================================================================ + # STEP 8: Combine removal/suppression decisions + # ============================================================================ + + # Mask overlap counts to only consider earlier objects + if N_obj == 0: + to_remove = remove_by_unmatch + else: + overlap_with_earlier = torch.where( + is_earlier_matrix, + overlap_pair_counts, + torch.zeros_like(overlap_pair_counts), + ) + + # For each object (column j), find max overlap with any earlier object (row i) + max_overlap_with_earlier, _ = overlap_with_earlier.max(dim=0) # (N_obj,) + + # Remove if: within hotstart AND overlapped with earlier >= threshold + remove_by_overlap = ( + is_within_hotstart + & (max_overlap_with_earlier >= self.hotstart_dup_thresh) + & ~removed_mask + ) + + to_remove = remove_by_unmatch | remove_by_overlap # (N_obj,) + + to_suppress = suppress_by_unmatch # (N_obj,) + + # Update removed mask for future frames + removed_mask = removed_mask | to_remove + + # ============================================================================ + # STEP 9: Package updated metadata (NO SYNCS) + # ============================================================================ + + gpu_metadata_new = { + "N_obj": N_obj, + "obj_first_frame": obj_first_frame, + "consecutive_unmatch_count": consecutive_unmatch_count, + "trk_keep_alive": trk_keep_alive, + "removed_mask": removed_mask, + "overlap_pair_counts": overlap_pair_counts, + "last_occluded_tensor": last_occluded, + } + + return to_remove, to_suppress, gpu_metadata_new + + def _process_hotstart( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + det_to_matched_trk_obj_ids: Dict[int, np.ndarray], + new_det_obj_ids: np.ndarray, + empty_trk_obj_ids: np.ndarray, + unmatched_trk_obj_ids: np.ndarray, + rank0_metadata: Dict[str, Any], + tracker_metadata: Dict[str, Any], + ): + """Handle hotstart heuristics to remove unmatched or duplicated objects.""" + # obj_id --> first frame index where the object was detected + obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"] + # obj_id --> [mismatched frame indices] + unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"] + trk_keep_alive = rank0_metadata["trk_keep_alive"] + # (first_appear_obj_id, obj_id) --> [overlap frame indices] + overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"] + # removed_obj_ids: object IDs that are suppressed via hot-start + removed_obj_ids = rank0_metadata["removed_obj_ids"] + suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx] + + obj_ids_newly_removed = set() # object IDs to be newly removed on this frame + hotstart_diff = ( + frame_idx - self.hotstart_delay + if not reverse + else frame_idx + self.hotstart_delay + ) + + # Step 1: log the frame index where each object ID first appears + for obj_id in new_det_obj_ids: + if obj_id not in obj_first_frame_idx: + obj_first_frame_idx[obj_id] = frame_idx + assert obj_id not in trk_keep_alive + trk_keep_alive[obj_id] = self.init_trk_keep_alive + + matched_trks = set() + # We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded + for matched_trks_per_det in det_to_matched_trk_obj_ids.values(): + matched_trks.update(matched_trks_per_det) + for obj_id in matched_trks: + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive + trk_keep_alive[obj_id] = min( + self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1 + ) + for obj_id in unmatched_trk_obj_ids: + unmatched_frame_inds[obj_id].append(frame_idx) + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive + # The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough. + trk_keep_alive[obj_id] = max( + self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 + ) + if self.decrease_trk_keep_alive_for_empty_masklets: + for obj_id in empty_trk_obj_ids: + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive + trk_keep_alive[obj_id] = max( + self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 + ) + + # Step 2: removed tracks that has not matched with detections for `hotstart_unmatch_thresh` frames with hotstart period + # a) add unmatched frame indices for each existing object ID + # note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask + # doesn't match any FA detection; it excludes those frames where SAM2 gives an empty mask + # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more + # than `self.hotstart_unmatch_thresh` frames + for obj_id, frame_indices in unmatched_frame_inds.items(): + if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: + continue # skip if the object is already removed + if len(frame_indices) >= self.hotstart_unmatch_thresh: + is_within_hotstart = ( + obj_first_frame_idx[obj_id] > hotstart_diff and not reverse + ) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse) + if is_within_hotstart: + obj_ids_newly_removed.add(obj_id) + logger.info( + f"Removing object {obj_id} at frame {frame_idx} " + f"since it is unmatched for frames: {frame_indices}" + ) + if ( + trk_keep_alive[obj_id] <= 0 # Object has not been matched for too long + and not self.suppress_unmatched_only_within_hotstart + and obj_id not in removed_obj_ids + and obj_id not in obj_ids_newly_removed + ): + logger.debug( + f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched" + ) + suppressed_obj_ids.add(obj_id) + + # Step 3: removed tracks that overlaps with another track for `hotstart_dup_thresh` frames + # a) find overlaps tracks -- we consider overlap if they match to the same detection + for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items(): + if len(matched_trk_obj_ids) < 2: + continue # only count detections that are matched to multiple (>=2) masklets + # if there are multiple matched track ids, we need to find the one that appeared first; + # these later appearing ids may be removed since they may be considered as duplicates + first_appear_obj_id = ( + min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) + if not reverse + else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) + ) + for obj_id in matched_trk_obj_ids: + if obj_id != first_appear_obj_id: + key = (first_appear_obj_id, obj_id) + overlap_pair_to_frame_inds[key].append(frame_idx) + + # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another + # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames + for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items(): + if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: + continue # skip if the object is already removed + if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or ( + obj_first_frame_idx[obj_id] < hotstart_diff and reverse + ): + if len(frame_indices) >= self.hotstart_dup_thresh: + obj_ids_newly_removed.add(obj_id) + logger.info( + f"Removing object {obj_id} at frame {frame_idx} " + f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}" + ) + + removed_obj_ids.update(obj_ids_newly_removed) + return obj_ids_newly_removed, rank0_metadata + + def _tracker_update_memories( + self, + sam2_inference_states: List[Any], + frame_idx: int, + tracker_metadata: Dict[str, Any], + low_res_masks: Tensor, + ): + """ + Run Sam2 memory encoder, enforcing non-overlapping constraints globally. + """ + # TODO: Add most recently occluded heuristic for suppression of overlapping masks + if len(sam2_inference_states) == 0: + return + # Avoid an extra interpolation step by directly interpolating to `interpol_size` + high_res_H, high_res_W = ( + self.tracker.maskmem_backbone.mask_downsampler.interpol_size + ) + # NOTE: inspect this part if we observe OOMs in the demo + high_res_masks = F.interpolate( + low_res_masks.unsqueeze(1), + size=(high_res_H, high_res_W), + mode="bilinear", + align_corners=False, + ) + # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics. + with torch.profiler.record_function( + "sam2_predictor.propagate_in_video.apply_non_overlapping_constraints" + ): + # TODO: try _apply_object_wise_non_overlapping_constraints instead + high_res_masks = self.tracker._suppress_object_pw_area_shrinkage( + high_res_masks + ) + # Instead of gathering the predicted object scores, we use mask areas as a proxy. + object_score_logits = torch.where( + (high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0 + ) + + if self.is_multiplex and self.tracker.is_multiplex_dynamic: + # The objects in the masks are ordered w.r.t. object IDs, + # which might not be true in the dynamic multiplex case with backfilling + # (see also _propogate_tracker_one_frame_local_gpu) + # We need to plan globally for the mask assignment here + object_idx_assignment: dict[int, list[int]] = {} + all_object_ids: list[int] = [] + object_id_to_state_i: dict[int, int] = {} + for state_i, sam2_state in enumerate(sam2_inference_states): + obj_ids = sam2_state["obj_ids"] + all_object_ids.extend(obj_ids) + for obj_id in obj_ids: + object_id_to_state_i[obj_id] = state_i + object_idx_assignment[state_i] = [] + sorted_indices = sorted( + range(len(all_object_ids)), key=lambda i: all_object_ids[i] + ) + # Build the object_idx_assignment mapping + for global_idx, local_idx in enumerate(sorted_indices): + obj_id = all_object_ids[local_idx] + object_idx_assignment[object_id_to_state_i[obj_id]].append(global_idx) + + # Run the memory encoder on local slices for each GPU + start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank]) + start_idx_state = start_idx_gpu + for state_i, sam2_state in enumerate(sam2_inference_states): + num_obj_per_state = len(sam2_state["obj_ids"]) + if num_obj_per_state == 0: + continue + # Get the local high-res masks and object score logits for this inference state + if self.is_multiplex and self.tracker.is_multiplex_dynamic: + local_idx = ( + torch.tensor(object_idx_assignment[state_i]) + .pin_memory() + .to(device=high_res_masks.device, non_blocking=True) + ) + local_high_res_masks = high_res_masks[local_idx] + local_object_score_logits = object_score_logits[local_idx] + else: + end_idx_state = start_idx_state + num_obj_per_state + local_high_res_masks = high_res_masks[start_idx_state:end_idx_state] + local_object_score_logits = object_score_logits[ + start_idx_state:end_idx_state + ] + local_batch_size = local_high_res_masks.size(0) + # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default + + encoded_mem = self.tracker._run_memory_encoder( + sam2_state, + frame_idx, + local_batch_size, + local_high_res_masks, + local_object_score_logits, + is_mask_from_pts=False, + ) + if self.is_multiplex: + ( + local_maskmem_features, + local_maskmem_pos_enc, + local_image_features, + local_image_pos_enc, + ) = encoded_mem + else: + local_maskmem_features, local_maskmem_pos_enc = encoded_mem + + # Store encoded memories in the local inference state + output_dict = sam2_state["output_dict"] + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + if frame_idx not in output_dict[storage_key]: + continue + output_dict[storage_key][frame_idx]["maskmem_features"] = ( + local_maskmem_features + ) + output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [ + pos for pos in local_maskmem_pos_enc + ] + if self.is_multiplex: + output_dict[storage_key][frame_idx]["image_features"] = ( + local_image_features + ) + output_dict[storage_key][frame_idx]["image_pos_enc"] = ( + local_image_pos_enc + ) + + if self.reapply_no_object_pointer: + # reapply the no_object_pointer projection for the objects suppressed by the heuristics + newly_suppressed_objects = ( + output_dict[storage_key][frame_idx]["object_score_logits"] + > self.tracker.object_score_logit_threshold + ) & (local_object_score_logits < 0) + if torch.any(newly_suppressed_objects): + existing_pointers = output_dict[storage_key][frame_idx][ + "obj_ptr" + ] + + multiplex_state = sam2_state["multiplex_state"] + existing_pointers = multiplex_state.demux(existing_pointers) + + newly_suppressed_objects = newly_suppressed_objects.float() + new_pointers = ( + newly_suppressed_objects + * self.tracker.no_obj_ptr_linear(existing_pointers) + + (1 - newly_suppressed_objects) * existing_pointers + ) + + output_dict[storage_key][frame_idx]["obj_ptr"] = ( + multiplex_state.mux(new_pointers) + ) + elif self.reapply_no_object_pointer: + raise NotImplementedError( + "reapply_no_object_pointer is not implemented for non-multiplex" + ) + + # for batched inference state, we also need to add per-object + # memory slides to support instance interactivity + self.tracker.add_output_per_object( + inference_state=sam2_state, + frame_idx=frame_idx, + current_out=output_dict[storage_key][frame_idx], + storage_key=storage_key, + ) + start_idx_state += num_obj_per_state + + def _tracker_add_new_objects( + self, + frame_idx: int, + num_frames: int, + new_obj_ids: List[int], + new_obj_masks: Tensor, + tracker_states_local: List[Any], + orig_vid_height: int, + orig_vid_width: int, + feature_cache: Dict, + ): + """Add new objects to SAM2 inference states.""" + + prev_sam2_state = ( + tracker_states_local[0] if len(tracker_states_local) > 0 else None + ) + # prepare inference_state + if self.tracker.is_multiplex_dynamic: + # in multiplex_dynamic mode, we first try to find the best-fit + # inference state for the new objects. + # Create a new state if needed + num_new_objects = len(new_obj_ids) + + # Try to find existing states with available slots + best_state = None + best_available_slots = float("inf") + + for state in tracker_states_local: + available_slots = state["multiplex_state"].available_slots + # Find the state with the least available slots that can still fit the new objects + if ( + available_slots >= num_new_objects + and available_slots < best_available_slots + ): + best_state = state + best_available_slots = available_slots + + if best_state is not None: + # Use the existing state with sufficient available slots + new_sam2_state = best_state + else: + # Need to create a new state + new_sam2_state = self.tracker.init_state( + cached_features=feature_cache, + video_height=orig_vid_height, + video_width=orig_vid_width, + num_frames=num_frames, + ) + new_sam2_state["backbone_out"] = ( + prev_sam2_state.get("backbone_out", None) + if prev_sam2_state is not None + else None + ) + # Add the new state to our local states list + tracker_states_local.append(new_sam2_state) + else: + if self.tracker.per_obj_inference: + # in per_obj_inference mode, init_state happens only once, + # new obj_ids will be added to the existing inference state + if prev_sam2_state is not None: + new_sam2_state = prev_sam2_state + else: + new_sam2_state = self.tracker.init_state( + cached_features=feature_cache, + video_height=orig_vid_height, + video_width=orig_vid_width, + num_frames=num_frames, + ) + new_sam2_state["backbone_out"] = None + tracker_states_local = [new_sam2_state] + else: + # batch objects that first appear on the same frame together + # Clear inference state. Keep the cached image features if available. + new_sam2_state = self.tracker.init_state( + cached_features=feature_cache, + video_height=orig_vid_height, + video_width=orig_vid_width, + num_frames=num_frames, + ) + new_sam2_state["backbone_out"] = ( + prev_sam2_state.get("backbone_out", None) + if prev_sam2_state is not None + else None + ) + tracker_states_local.append(new_sam2_state) + + assert len(new_obj_ids) == new_obj_masks.size(0) + assert new_obj_masks.is_floating_point() + # TODO consider removing this interpolation -- it's probably no longer needed + # we should edit `self.tracker.add_new_mask` to directly take low-res input masks + input_mask_res = self.tracker.input_mask_size + new_obj_masks = F.interpolate( + new_obj_masks.unsqueeze(1), + size=(input_mask_res, input_mask_res), + mode="bilinear", + align_corners=False, + ).squeeze(1) + new_obj_masks = new_obj_masks > 0 + + if self.is_multiplex: + # add all objects at once + # NOTE: In the current implementation, add_new_masks also runs the memory encoder + # the non-overlapping constraint is enforced + self.tracker.add_new_masks( + inference_state=new_sam2_state, + frame_idx=frame_idx, + obj_ids=new_obj_ids, + masks=new_obj_masks, + add_mask_to_memory=True, + ) + else: + # add object one by one + for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks): + self.tracker.add_new_mask( + inference_state=new_sam2_state, + frame_idx=frame_idx, + obj_id=new_obj_id, + mask=new_mask, + add_mask_to_memory=True, + ) + # NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects. + self.tracker.propagate_in_video_preflight(new_sam2_state, run_mem_encoder=True) + + return tracker_states_local + + def _tracker_remove_objects( + self, tracker_states_local: List[Any], obj_ids: list[int] + ): + """ + Remove an object from SAM2 inference states. This would remove the object from + all frames in the video. + """ + if self.is_multiplex: + tracker_states_local_before_removal = tracker_states_local.copy() + tracker_states_local.clear() + for sam2_inference_state in tracker_states_local_before_removal: + # we try to remove `obj_id` on every inference state with `strict=False` + # it will not do anything if an inference state doesn't contain `obj_id` + new_obj_ids, _ = self.tracker.remove_objects( + sam2_inference_state, obj_ids, strict=False, need_output=False + ) + # only keep an inference state if it's non-empty after object removal + if len(new_obj_ids) > 0: + tracker_states_local.append(sam2_inference_state) + else: + for obj_id in obj_ids: + self._tracker_remove_object(tracker_states_local, obj_id) + + def update_masklet_confirmation_status( + self, + rank0_metadata: Dict[str, Any], + obj_ids_all_gpu_prev: np.ndarray, + obj_ids_all_gpu_updated: np.ndarray, + det_to_matched_trk_obj_ids: Dict[int, np.ndarray], + new_det_obj_ids: np.ndarray, + ): + """ + Update masklet confirmation status. + """ + confirmation_data = rank0_metadata["masklet_confirmation"] + status_prev = confirmation_data["status"] + consecutive_det_num_prev = confirmation_data["consecutive_det_num"] + + N_prev = len(obj_ids_all_gpu_prev) + N_updated = len(obj_ids_all_gpu_updated) + + # a) Map previous confirmation data to updated positions + # For small arrays, simple dict lookup is fast + unconfirmed_val = MaskletConfirmationStatus.UNCONFIRMED.value + status = np.full(N_updated, unconfirmed_val, dtype=np.int64) + consecutive_det_num = np.zeros(N_updated, dtype=np.int64) + + if N_prev > 0 and N_updated > 0: + # Build mapping: obj_id -> new index + obj_id_to_new_idx = { + obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated) + } + + # Copy previous values for objects that still exist + for old_idx, obj_id in enumerate(obj_ids_all_gpu_prev): + new_idx = obj_id_to_new_idx.get(obj_id) + if new_idx is not None: + status[new_idx] = status_prev[old_idx] + consecutive_det_num[new_idx] = consecutive_det_num_prev[old_idx] + + # b) Update confirmation status based on current frame detections + # Build set of all matched object IDs + matched_obj_ids = set(new_det_obj_ids) + for matched_trk_ids in det_to_matched_trk_obj_ids.values(): + matched_obj_ids.update(matched_trk_ids) + + # Update consecutive detection count and status + for idx, obj_id in enumerate(obj_ids_all_gpu_updated): + if obj_id in matched_obj_ids: + consecutive_det_num[idx] += 1 + else: + consecutive_det_num[idx] = 0 + + # Update status to CONFIRMED where threshold is met + if ( + consecutive_det_num[idx] + >= self.masklet_confirmation_consecutive_det_thresh + ): + status[idx] = MaskletConfirmationStatus.CONFIRMED.value + + # Store updated arrays + confirmation_data["status"] = status + confirmation_data["consecutive_det_num"] = consecutive_det_num + return rank0_metadata + + +class Sam3MultiplexPredictorWrapper(Sam3MultiplexTrackerPredictor): + """ + Wraps a pre-built multiplex tracker model with the same interface as the + onevision Sam3MultiplexTrackerPredictor class. Inherits from Sam3MultiplexTrackerPredictor to pass + isinstance checks, but skips Sam3MultiplexTrackerPredictor.__init__ (which requires Hydra). + + Provides bf16 autocast, attribute proxying, and configuration flags + needed by Sam3MultiplexTracking. + + The onevision Sam3MultiplexTrackerPredictor builds the tracker from Hydra config and applies + extensive hydra_overrides. This version skips Hydra entirely — the caller + is responsible for building the tracker via model_builder.py with the + correct parameters. + + Key parameters that the onevision Sam3MultiplexTrackerPredictor sets via hydra_overrides + (documented here for reference — these must be set in model_builder.py): + - image_size=1008, backbone_stride=14 + - maskmem_backbone.mask_downsampler.interpol_size=[1152,1152] + - always_start_from_first_ann_frame=false + - non_overlap_masks_for_mem_enc=false, non_overlap_masks_for_output=false + - max_cond_frames_in_attn=4 + - offload_output_to_cpu_for_eval=false, trim_past_non_cond_mem_for_eval=false + - sam_mask_decoder_extra_args: dynamic_multimask_via_stability=true, etc. + - binarize_mask_from_pts_for_mem_enc=true (SAM2 tracker default) + - only_obj_ptrs_in_the_past_for_eval=true + - clear_non_cond_mem_around_input=true + - transformer.encoder.layer.self_attention.feat_sizes=[72,72] + - transformer.encoder.layer.cross_attention.feat_sizes=[72,72] + - fill_hole_area= + - use_fa3, use_rope_real on self_attention, cross_attention, + self_attention_rope, cross_attention_rope + - use_memory_selection + """ + + def __init__( + self, + model, + per_obj_inference=False, + fill_hole_area=0, + is_multiplex=True, + is_multiplex_dynamic=True, + ): + # Skip Sam3MultiplexTrackerPredictor.__init__ (requires Hydra) — call nn.Module.__init__ directly + nn.Module.__init__(self) + self.model = model + self.per_obj_inference = per_obj_inference + self.fill_hole_area = fill_hole_area + self.is_multiplex = is_multiplex + self.is_multiplex_dynamic = is_multiplex_dynamic + + # use bfloat16 inference for Flash Attention kernel + self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16) + self.bf16_context.__enter__() diff --git a/third_party/sam3/sam3/model/sam3_multiplex_detector.py b/third_party/sam3/sam3/model/sam3_multiplex_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..be06d2057243cbdbcf99da650b0fe7d9e08c4629 --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_multiplex_detector.py @@ -0,0 +1,943 @@ +import os + +import torch +from sam3.model.vl_combiner import SAM3VLBackbone + +try: + from sam3.model.vl_combiner import SAM3VLBackboneTri +except ImportError: + SAM3VLBackboneTri = None +from typing import Dict, List, Optional + +import numpy as np +from sam3.model.data_misc import BatchedDatapoint, FindStage +from sam3.model.geometry_encoders import Prompt +from sam3.model.model_misc import SAM3Output +from sam3.model.sam3_image import Sam3Image +from sam3.model.sam3_multiplex_detector_utils import nms_masks + + +class Sam3MultiplexImageBase(Sam3Image): + """A wrapper class to run Sam3Image on videos for per-frame detection (no tracking).""" + + def __init__( + self, + *args, + tracking_score_thresh: float = 0.0, + offload_outputs_to_cpu_for_eval: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.tracking_score_thresh = tracking_score_thresh + self.offload_outputs_to_cpu_for_eval = offload_outputs_to_cpu_for_eval + self.trim_outputs_for_eval = True # dummy option -- it doesn't do anything + + def forward( + self, + input: BatchedDatapoint, + is_inference=False, # (a dummy parameter not used anymore) + ): + assert ( + not self.training + ), "Sam3MultiplexImageBase should only be used in eval mode." + + device = self.device + backbone_out = {"img_batch_all_stages": input.img_batch} + text_outputs = self.backbone.forward_text(input.find_text_batch, device=device) + backbone_out.update(text_outputs) + num_frames = len(input.find_inputs) + + previous_stages_out = SAM3Output( + iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE + ) + for frame_idx in range(num_frames): + find_input = input.find_inputs[frame_idx] + find_target = input.find_targets[frame_idx] + geometric_prompt = self._get_geo_prompt_from_find_input(find_input) + cur_out, _ = self.forward_video_grounding( + backbone_out=backbone_out, + find_input=find_input, + find_target=find_target, + geometric_prompt=geometric_prompt, + ) + # offload model outputs to CPU (to save GPU memory) for evaluation + if self.offload_outputs_to_cpu_for_eval: + cur_out = {k: v.cpu() for k, v in cur_out.items()} + + previous_stages_out.append([cur_out]) + + get_queries = None + return previous_stages_out, get_queries + + def forward_video_grounding( + self, + backbone_out, + find_input, + find_target, + geometric_prompt: Prompt, + **kwargs, + ): + # route this to the image grounding forward method + out = self.forward_grounding( + backbone_out=backbone_out, + find_input=find_input, + find_target=find_target, + geometric_prompt=geometric_prompt, + ) + # trim the output to only include the necessary keys + out = { + "pred_logits": out["pred_logits"], + "pred_boxes": out["pred_boxes"], + "pred_boxes_xyxy": out["pred_boxes_xyxy"], + "pred_masks": out["pred_masks"], + "pred_object_ids": self._get_dummy_object_ids(out["pred_logits"]), + } + return out, backbone_out + + def _get_dummy_object_ids(self, pred_logits): + """Generate dummy object IDs for the detected objects, based on their detection query indices.""" + # Assuming pred_logits has shape [batch_size, num_queries, num_classes] + B, Q, _ = pred_logits.shape + is_above_thresh = pred_logits.squeeze(2) > self.tracking_score_thresh + dummy_obj_ids = torch.arange(Q, device=self.device).expand(B, -1) + dummy_obj_ids = torch.where(is_above_thresh, dummy_obj_ids, -1) + return dummy_obj_ids + + def _trim_outputs(self, *args, **kwargs): + pass # not needed for image-on-video + + def _batch_find_inputs( + self, + find_inputs: List[FindStage], + chunk_start: int, + chunk_end: int, + ) -> FindStage: + """ + Batch multiple FindStage objects into a single batched FindStage. + + For each frame in the chunk, creates img_ids that point to the correct + frame index. When processing streaming video, the img_ids are the actual + frame indices (e.g., [0, 1, 2, ..., 15] for chunk 0-16), and the modulo + for circular buffer access is applied later in _get_img_feats. + + Args: + find_inputs: List of FindStage objects for all frames. + chunk_start: Start index of the chunk. + chunk_end: End index of the chunk (exclusive). + + Returns: + A single FindStage with batched tensors. + """ + chunk_find_inputs = [ + find_inputs[i % len(find_inputs)] for i in range(chunk_start, chunk_end) + ] + + # Generate img_ids based on chunk frame indices + # Each frame in the chunk gets its corresponding frame index + # The modulo for circular buffer access is handled in _get_img_feats + device = chunk_find_inputs[0].img_ids.device + dtype = chunk_find_inputs[0].img_ids.dtype + img_ids_list = [ + torch.tensor([i], device=device, dtype=dtype) + for i in range(chunk_start, chunk_end) + ] + batched_img_ids = torch.cat(img_ids_list, dim=0) + + # Generate img_ids_np to match + img_ids_np_list = [np.array([i]) for i in range(chunk_start, chunk_end)] + batched_img_ids_np = np.concatenate(img_ids_np_list, axis=0) + + # Concatenate text_ids + text_ids_list = [fi.text_ids for fi in chunk_find_inputs] + batched_text_ids = torch.cat(text_ids_list, dim=0) + + # Concatenate input_boxes + input_boxes_list = [fi.input_boxes for fi in chunk_find_inputs] + batched_input_boxes = ( + torch.cat(input_boxes_list, dim=0) + if input_boxes_list[0] is not None + else None + ) + + # Concatenate input_boxes_mask + input_boxes_mask_list = [fi.input_boxes_mask for fi in chunk_find_inputs] + batched_input_boxes_mask = ( + torch.cat(input_boxes_mask_list, dim=0) + if input_boxes_mask_list[0] is not None + else None + ) + + # Concatenate input_boxes_label + input_boxes_label_list = [fi.input_boxes_label for fi in chunk_find_inputs] + batched_input_boxes_label = ( + torch.cat(input_boxes_label_list, dim=0) + if input_boxes_label_list[0] is not None + else None + ) + + # Concatenate input_points + input_points_list = [fi.input_points for fi in chunk_find_inputs] + batched_input_points = ( + torch.cat(input_points_list, dim=0) + if input_points_list[0] is not None + else None + ) + + # Concatenate input_points_mask + input_points_mask_list = [fi.input_points_mask for fi in chunk_find_inputs] + batched_input_points_mask = ( + torch.cat(input_points_mask_list, dim=0) + if input_points_mask_list[0] is not None + else None + ) + + # Handle optional fields + input_boxes_before_embed_list = [ + fi.input_boxes_before_embed for fi in chunk_find_inputs + ] + batched_input_boxes_before_embed = ( + torch.cat(input_boxes_before_embed_list, dim=0) + if input_boxes_before_embed_list[0] is not None + else None + ) + + input_points_before_embed_list = [ + fi.input_points_before_embed for fi in chunk_find_inputs + ] + batched_input_points_before_embed = ( + torch.cat(input_points_before_embed_list, dim=0) + if input_points_before_embed_list[0] is not None + else None + ) + + # Create batched FindStage + batched_find_input = FindStage( + img_ids=batched_img_ids, + img_ids_np=batched_img_ids_np, + text_ids=batched_text_ids, + input_boxes=batched_input_boxes, + input_boxes_mask=batched_input_boxes_mask, + input_boxes_label=batched_input_boxes_label, + input_points=batched_input_points, + input_points_mask=batched_input_points_mask, + ptrs=None, # Not batching pointers for now + ptrs_seg=None, + object_ids=None, + input_boxes_before_embed=batched_input_boxes_before_embed, + input_points_before_embed=batched_input_points_before_embed, + ) + + return batched_find_input + + def _batch_geometric_prompts( + self, + geometric_prompts: List[Prompt], + chunk_start: int, + chunk_end: int, + ) -> Prompt: + """ + Batch multiple Prompt objects into a single batched Prompt. + + Args: + geometric_prompts: List of Prompt objects for all frames. + chunk_start: Start index of the chunk. + chunk_end: End index of the chunk (exclusive). + + Returns: + A single Prompt with batched tensors. + """ + chunk_prompts = [geometric_prompts[i] for i in range(chunk_start, chunk_end)] + return self._batch_geometric_prompts_from_list(chunk_prompts) + + def _batch_geometric_prompts_from_list( + self, + chunk_prompts: List[Prompt], + ) -> Prompt: + """ + Batch a list of Prompt objects into a single batched Prompt. + + Prompt uses seq-first, batch-second convention: + - box_embeddings: N_boxes x B x C_box - batch along dim 1 + - box_mask: B x N_boxes - batch along dim 0 + - box_labels: N_boxes x B - batch along dim 1 + - point_embeddings: N_points x B x C_point - batch along dim 1 + - point_mask: B x N_points - batch along dim 0 + - point_labels: N_points x B - batch along dim 1 + + Args: + chunk_prompts: List of Prompt objects to batch. + + Returns: + A single Prompt with batched tensors. + """ + + # Helper function to batch tensors along specified dimension + def batch_tensors(tensors, dim): + if tensors[0] is None: + return None + return torch.cat(tensors, dim=dim) + + # Batch box embeddings (N_boxes x B x C_box - batch along dim 1) + box_embeddings_list = [p.box_embeddings for p in chunk_prompts] + batched_box_embeddings = batch_tensors(box_embeddings_list, dim=1) + + # Batch box mask (B x N_boxes - batch along dim 0) + box_mask_list = [p.box_mask for p in chunk_prompts] + batched_box_mask = batch_tensors(box_mask_list, dim=0) + + # Batch box labels (N_boxes x B - batch along dim 1) + box_labels_list = [p.box_labels for p in chunk_prompts] + batched_box_labels = batch_tensors(box_labels_list, dim=1) + + # Batch point embeddings (N_points x B x C_point - batch along dim 1) + point_embeddings_list = [p.point_embeddings for p in chunk_prompts] + batched_point_embeddings = batch_tensors(point_embeddings_list, dim=1) + + # Batch point mask (B x N_points - batch along dim 0) + point_mask_list = [p.point_mask for p in chunk_prompts] + batched_point_mask = batch_tensors(point_mask_list, dim=0) + + # Batch point labels (N_points x B - batch along dim 1) + point_labels_list = [p.point_labels for p in chunk_prompts] + batched_point_labels = batch_tensors(point_labels_list, dim=1) + + # Create batched Prompt + batched_prompt = Prompt( + box_embeddings=batched_box_embeddings, + box_mask=batched_box_mask, + box_labels=batched_box_labels, + point_embeddings=batched_point_embeddings, + point_mask=batched_point_mask, + point_labels=batched_point_labels, + ) + + return batched_prompt + + +class Sam3MultiplexDetector(Sam3MultiplexImageBase): + def __init__( + self, + *args, + async_all_gather=True, + gather_backbone_out=None, + is_multiplex=False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.rank = int(os.getenv("RANK", "0")) + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.async_all_gather = async_all_gather + + # if gather_backbone is not set, default to gathering only for `SAM3VLBackbone` + if gather_backbone_out is None: + gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone) or ( + SAM3VLBackboneTri is not None + and isinstance(self.backbone, SAM3VLBackboneTri) + ) + self.gather_backbone_out = gather_backbone_out + self.is_multiplex = is_multiplex + + def forward_video_grounding_multigpu( + self, + backbone_out, + find_inputs, + geometric_prompt: Prompt, + frame_idx, + num_frames, + # `multigpu_buffer` is a dict to cache FA outputs in a chunk between different calls + multigpu_buffer, + track_in_reverse=False, + # whether to also return the SAM2 backbone features (in addition to FA results) + return_sam2_backbone_feats=False, + # whether to perform NMS and suppress the scores of those detections removed by NMS + run_nms=False, + nms_prob_thresh=None, + nms_iou_thresh=None, + nms_use_iom=False, + # tracking bounds to respect max_frame_num_to_track + max_frame_num_to_track=None, + propagate_in_video_start_frame_idx=None, + # feature_cache for buffered backbone computation + feature_cache=None, + **kwargs, + ): + """ + Compute the FA detection outputs in a distributed manner, where all GPUs process + a chunk of frames (equal to the number of GPUs) at once and store them in cache. + """ + # Calculate valid frame range based on max_frame_num_to_track + # We prevent pre-fetching beyond the tracking window relative to current frame + if max_frame_num_to_track is not None: + if propagate_in_video_start_frame_idx is None: + propagate_in_video_start_frame_idx = 0 + if track_in_reverse: + # When going backwards, limit how far back we can go from current frame + valid_frame_start = max( + 0, + propagate_in_video_start_frame_idx - max_frame_num_to_track + 1, + ) + valid_frame_end = num_frames + else: + # When going forwards, limit how far ahead we can go from current frame + valid_frame_start = 0 + valid_frame_end = min( + num_frames, + propagate_in_video_start_frame_idx + max_frame_num_to_track, + ) + else: + # No tracking limit specified, use full video range + valid_frame_start = 0 + valid_frame_end = num_frames + + # Step 1: fetch the FA outputs in the current chunk from buffer + frame_idx_curr_b = frame_idx - frame_idx % self.world_size + frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames) + + # Clamp the current chunk to the valid tracking range + frame_idx_curr_b = max(frame_idx_curr_b, valid_frame_start) + frame_idx_curr_e = min(frame_idx_curr_e, valid_frame_end) + # in case the current frame's FA results are not in the buffer yet, build the current chunk + # (this should only happen on the first chunk, since we are also building the next chunk below) + if frame_idx not in multigpu_buffer: + with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"): + self._build_multigpu_buffer_next_chunk( + backbone_out=backbone_out, + find_inputs=find_inputs, + geometric_prompt=geometric_prompt, + frame_idx_begin=frame_idx_curr_b, + frame_idx_end=frame_idx_curr_e, + num_frames=num_frames, + multigpu_buffer=multigpu_buffer, + run_nms=run_nms, + nms_prob_thresh=nms_prob_thresh, + nms_iou_thresh=nms_iou_thresh, + nms_use_iom=nms_use_iom, + feature_cache=feature_cache, + ) + + # read out the current frame's results from `multigpu_buffer` + out = {} + for k, (v, handle) in multigpu_buffer[frame_idx].items(): + if self.is_multiplex: + if ( + k.startswith("interactive_backbone_") + or k.startswith("propagation_backbone_") + ) and not return_sam2_backbone_feats: + continue + else: + if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats: + continue + if handle is not None: + handle.wait() # wait for async all-gather to finish + out[k] = v + + # Step 2: remove FA outputs of the previous chunk from cache to save GPU memory + if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0: + frame_idx_prev_e = frame_idx_curr_b + frame_idx_prev_b = frame_idx_curr_b - self.world_size + elif track_in_reverse and frame_idx_curr_e < num_frames: + frame_idx_prev_b = frame_idx_curr_e + frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames) + else: + frame_idx_prev_b = frame_idx_prev_e = None + if frame_idx_prev_b is not None: + for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e): + multigpu_buffer.pop(frame_idx_rm, None) + + # Step 3: compute and cache FA outputs of the next chunk ahead of time + # (so that we can overlap computation with all-gather transfer) + # Respect tracking bounds when calculating next chunk + + if not track_in_reverse and frame_idx_curr_e < valid_frame_end: + frame_idx_next_b = frame_idx_curr_e + frame_idx_next_e = min(frame_idx_next_b + self.world_size, valid_frame_end) + elif ( + track_in_reverse and frame_idx_curr_b - self.world_size >= valid_frame_start + ): + frame_idx_next_e = frame_idx_curr_b + frame_idx_next_b = max( + frame_idx_curr_b - self.world_size, valid_frame_start + ) + else: + frame_idx_next_b = frame_idx_next_e = None + if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer: + with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"): + self._build_multigpu_buffer_next_chunk( + backbone_out=backbone_out, + find_inputs=find_inputs, + geometric_prompt=geometric_prompt, + frame_idx_begin=frame_idx_next_b, + frame_idx_end=frame_idx_next_e, + num_frames=num_frames, + multigpu_buffer=multigpu_buffer, + run_nms=run_nms, + nms_prob_thresh=nms_prob_thresh, + nms_iou_thresh=nms_iou_thresh, + feature_cache=feature_cache, + ) + + return out, backbone_out + + def _build_multigpu_buffer_next_chunk( + self, + backbone_out, + find_inputs, + geometric_prompt: Prompt, + frame_idx_begin, + frame_idx_end, + num_frames, + multigpu_buffer, + run_nms=False, + nms_prob_thresh=None, + nms_iou_thresh=None, + nms_use_iom=False, + feature_cache=None, + ): + """Compute FA outputs on a chunk of frames and store their results in multigpu_buffer.""" + # each GPU computes FA on one frame in the chunk (in a round-robin manner) + frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1) + # `forward_grounding` (from base class `Sam3MultiplexImageBase`) runs FA on a single frame + with torch.profiler.record_function("forward_grounding"): + out_local = self.forward_grounding( + backbone_out=backbone_out, + # HACK: Since find_inputs is on GPU having to realloc is expensive so changing the values in place for the prod usecase + # i.e. when using the streaming frame loader resource instead of local file. For non-prod is always + # frame_idx_local_gpu < len(find_inputs) so should be a no-op + find_input=find_inputs[frame_idx_local_gpu % len(find_inputs)], + find_target=None, + geometric_prompt=geometric_prompt, + feature_cache=feature_cache, + ) + if run_nms: + with torch.profiler.record_function("nms_masks"): + # run NMS as a post-processing step on top of the detection outputs + assert nms_prob_thresh is not None and nms_iou_thresh is not None + pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid() + pred_masks = out_local["pred_masks"] + # loop over text prompts (not an overhead for demo where there's only 1 prompt) + for prompt_idx in range(pred_probs.size(0)): + keep = nms_masks( + pred_probs=pred_probs[prompt_idx], + pred_masks=pred_masks[prompt_idx], + prob_threshold=nms_prob_thresh, + iou_threshold=nms_iou_thresh, + nms_use_iom=nms_use_iom, + do_compile=getattr(self, "compile_model", False), + running_in_prod=getattr(self, "running_in_prod", False), + ) + # set a very low threshold for those detections removed by NMS + out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float() + + if self.gather_backbone_out: + # gather the SAM 2 backbone features across GPUs + if self.is_multiplex: + # Note that we should not need to compute the interaction features every frame + # TODO: rooms for optimization + + # Interaction features + inte_feats = out_local["prev_encoder_out"]["backbone_out"][ + "interactive" + ] + assert inte_feats["vision_mask"] is None + assert ( + len(inte_feats["backbone_fpn"]) == 3 + ) # SAM2 backbone always have 3 levels + assert all(x.mask is None for x in inte_feats["backbone_fpn"]) + # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually + # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP) + inte_backbone_fpn_bf16 = [ + x.to(torch.bfloat16) for x in inte_feats["backbone_fpn"] + ] + inte_fpn0, inte_fpn_handle0 = self._gather_tensor( + inte_backbone_fpn_bf16[0].tensors + ) + inte_fpn1, inte_fpn_handle1 = self._gather_tensor( + inte_backbone_fpn_bf16[1].tensors + ) + inte_fpn2, inte_fpn_handle2 = self._gather_tensor( + inte_backbone_fpn_bf16[2].tensors + ) + # vision_pos_enc is the same on all frames, so no need to all-gather them + inte_vision_pos_enc = inte_feats["vision_pos_enc"] + + feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"] + assert feats["vision_mask"] is None + assert len(feats["backbone_fpn"]) == 3 # SAM2 backbone always have 3 levels + assert all(x.mask is None for x in feats["backbone_fpn"]) + # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually + # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP) + backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]] + fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0].tensors) + fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1].tensors) + fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2].tensors) + # vision_pos_enc is the same on all frames, so no need to all-gather them + vision_pos_enc = feats["vision_pos_enc"] + + # trim the FA output to only include the necessary keys + out_local = { + "pred_logits": out_local["pred_logits"], + "pred_boxes": out_local["pred_boxes"], + "pred_boxes_xyxy": out_local["pred_boxes_xyxy"], + "pred_masks": out_local["pred_masks"], + "pred_object_ids": self._get_dummy_object_ids(out_local["pred_logits"]), + } + + # gather the results: after this step, each GPU will receive FA outputs on + # all frames in the chunk and store them in `multigpu_buffer` + out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()} + for rank in range(self.world_size): + frame_idx_to_save = frame_idx_begin + rank + if frame_idx_to_save >= num_frames: + continue + frame_buffer = { + k: (v[rank], handle) for k, (v, handle) in out_gathered.items() + } + if self.gather_backbone_out: + # also add gathered SAM 2 backbone features to frame_buffer + if self.is_multiplex: + frame_buffer["interactive_backbone_fpn_0"] = ( + inte_fpn0[rank], + inte_fpn_handle0, + ) + frame_buffer["interactive_backbone_fpn_1"] = ( + inte_fpn1[rank], + inte_fpn_handle1, + ) + frame_buffer["interactive_backbone_fpn_2"] = ( + inte_fpn2[rank], + inte_fpn_handle2, + ) + frame_buffer["interactive_backbone_pos_enc"] = ( + inte_vision_pos_enc, + None, + ) + frame_buffer["sam2_backbone_fpn_0"] = (fpn0[rank], fpn_handle0) + frame_buffer["sam2_backbone_fpn_1"] = (fpn1[rank], fpn_handle1) + frame_buffer["sam2_backbone_fpn_2"] = (fpn2[rank], fpn_handle2) + frame_buffer["sam2_backbone_pos_enc"] = (vision_pos_enc, None) + + multigpu_buffer[frame_idx_to_save] = frame_buffer + + def _gather_tensor(self, x): + if self.world_size == 1: + return [x], None + + async_op = self.async_all_gather + # here `.contiguous()` is required -- otherwise NCCL all_gather + # sometimes gives wrong results (based on Ronghang's observations) + x = x.contiguous() # ensure contiguous memory for NCCL + output_list = [torch.empty_like(x) for _ in range(self.world_size)] + handle = torch.distributed.all_gather(output_list, x, async_op=async_op) + return output_list, handle + + def forward_video_grounding_batched_multigpu( + self, + backbone_out, + find_inputs, + geometric_prompt: Prompt, + frame_idx, + num_frames, + # `grounding_cache` is a dict to cache FA outputs in a chunk between different calls + grounding_cache, + track_in_reverse=False, + # whether to also return the SAM2 backbone features (in addition to FA results) + return_sam2_backbone_feats=False, + # whether to perform NMS and suppress the scores of those detections removed by NMS + run_nms=False, + nms_prob_thresh=None, + nms_iou_thresh=None, + nms_use_iom=False, + # tracking bounds to respect max_frame_num_to_track + max_frame_num_to_track=None, + propagate_in_video_start_frame_idx=None, + # feature_cache for buffered backbone computation + feature_cache=None, + # batch_size for batched forward_grounding (default: 16) + batch_size=16, + ): + """ + Fully batched forward_grounding that processes chunks of frames together on each GPU. + + Unlike forward_video_grounding_multigpu which processes 1 frame per GPU per chunk, + this method processes `batch_size` frames at once using the batched forward_grounding + approach from Sam3MultiplexImageBase. + + For single-GPU (world_size=1), this is equivalent to forward_grounding_batched. + For multi-GPU, each GPU processes batch_size frames in parallel. + + Args: + backbone_out: Dictionary containing backbone outputs and image batch. + find_inputs: List of FindStage objects for all frames. + geometric_prompt: Prompt object (used as template, individual prompts are + constructed from find_inputs for batching). + frame_idx: Current frame index to process. + num_frames: Total number of frames in the video. + grounding_cache: Dictionary to cache grounding outputs. + track_in_reverse: If True, processing in reverse frame order. + return_sam2_backbone_feats: Whether to also return SAM2 backbone features. + run_nms: Whether to perform NMS on detection outputs. + nms_prob_thresh: Probability threshold for NMS. + nms_iou_thresh: IoU threshold for NMS. + nms_use_iom: Whether to use IoM for NMS. + max_frame_num_to_track: Maximum number of frames to track. + propagate_in_video_start_frame_idx: Start frame index for propagation. + feature_cache: Optional dictionary for backbone feature caching. + batch_size: Number of frames to batch together per GPU (default: 16). + + Returns: + Tuple of (out, backbone_out) where out contains detection results for frame_idx. + """ + # Calculate valid frame range based on max_frame_num_to_track + if max_frame_num_to_track is not None: + if propagate_in_video_start_frame_idx is None: + propagate_in_video_start_frame_idx = 0 + if track_in_reverse: + valid_frame_start = ( + propagate_in_video_start_frame_idx - max_frame_num_to_track + 1 + ) + valid_frame_end = propagate_in_video_start_frame_idx + else: + valid_frame_start = propagate_in_video_start_frame_idx + valid_frame_end = ( + propagate_in_video_start_frame_idx + max_frame_num_to_track + ) + else: + valid_frame_start = 0 + valid_frame_end = num_frames + + # Initialize grounding_buffer if not present + if "grounding_buffer" not in grounding_cache: + grounding_cache["grounding_buffer"] = {} + + # Calculate chunk boundaries - use batch_size instead of world_size + chunk_start = (frame_idx // batch_size) * batch_size + chunk_end = min(chunk_start + batch_size, valid_frame_end) + chunk_key = (chunk_start, chunk_end) + + # Process chunk if not already cached + if chunk_key not in grounding_cache["grounding_buffer"]: + with torch.profiler.record_function( + "forward_grounding_batched.process_chunk" + ): + chunk_outputs = self._process_grounding_chunk_batched( + backbone_out=backbone_out, + find_inputs=find_inputs, + chunk_start=chunk_start, + chunk_end=chunk_end, + run_nms=run_nms, + nms_prob_thresh=nms_prob_thresh, + nms_iou_thresh=nms_iou_thresh, + nms_use_iom=nms_use_iom, + feature_cache=feature_cache, + return_sam2_backbone_feats=return_sam2_backbone_feats, + ) + grounding_cache["grounding_buffer"][chunk_key] = chunk_outputs + + # Auto-cleanup previous chunks + self._cleanup_previous_chunks_multigpu( + grounding_cache=grounding_cache, + current_chunk_key=chunk_key, + batch_size=batch_size, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + + # Retrieve the cached output for this frame + chunk_outputs = grounding_cache["grounding_buffer"][chunk_key] + local_idx = frame_idx - chunk_start + + # Slice out the output for this specific frame + out = self._slice_batched_output( + chunk_outputs, local_idx, return_sam2_backbone_feats + ) + + return out, backbone_out + + def _process_grounding_chunk_batched( + self, + backbone_out, + find_inputs, + chunk_start: int, + chunk_end: int, + run_nms: bool, + nms_prob_thresh, + nms_iou_thresh, + nms_use_iom: bool, + feature_cache, + return_sam2_backbone_feats: bool, + ): + """ + Process a chunk of frames through the full forward_grounding pipeline in batch. + """ + chunk_size = chunk_end - chunk_start + + # Build geometric prompts for the chunk + chunk_geo_prompts = [ + self._get_geo_prompt_from_find_input(find_inputs[i % len(find_inputs)]) + for i in range(chunk_start, chunk_end) + ] + + # Batch the find_inputs for this chunk + batched_find_input = self._batch_find_inputs( + find_inputs, chunk_start, chunk_end + ) + + # Batch the geometric prompts + batched_geometric_prompt = self._batch_geometric_prompts_from_list( + chunk_geo_prompts + ) + + # Run forward_grounding on the batched input + with torch.profiler.record_function("forward_grounding_batched.forward"): + out = self.forward_grounding( + backbone_out=backbone_out, + find_input=batched_find_input, + find_target=None, + geometric_prompt=batched_geometric_prompt, + feature_cache=feature_cache, + ) + + # Apply NMS per frame in the batch + if run_nms: + with torch.profiler.record_function("forward_grounding_batched.nms"): + assert nms_prob_thresh is not None and nms_iou_thresh is not None + pred_probs = out["pred_logits"].squeeze(-1).sigmoid() + pred_masks = out["pred_masks"] + # pred_probs shape: [batch_size, num_queries] + # pred_masks shape: [batch_size, num_queries, H, W] + # Use batched NMS to process all frames at once + keep = nms_masks( + pred_probs=pred_probs, + pred_masks=pred_masks, + prob_threshold=nms_prob_thresh, + iou_threshold=nms_iou_thresh, + nms_use_iom=nms_use_iom, + do_compile=getattr(self, "compile_model", False), + running_in_prod=getattr(self, "running_in_prod", False), + ) + # Set a very low threshold for detections removed by NMS + # keep shape: [batch_size, num_queries] + out["pred_logits"][:, :, 0] -= 1e4 * (~keep).float() + + # Extract SAM2 backbone features if requested + if return_sam2_backbone_feats and "prev_encoder_out" in out: + backbone_data = out["prev_encoder_out"]["backbone_out"] + if self.is_multiplex and "interactive" in backbone_data: + out["_interactive_backbone"] = backbone_data["interactive"] + if "sam2_backbone_out" in backbone_data: + out["_sam2_backbone"] = backbone_data["sam2_backbone_out"] + + out["_chunk_size"] = chunk_size + return out + + def _slice_batched_output( + self, + chunk_outputs, + local_idx: int, + return_sam2_backbone_feats: bool, + ): + """ + Slice a single frame's output from the batched chunk outputs. + """ + out = {} + + # Keys to slice at batch dimension + batch_dim_keys = { + "pred_logits", + "pred_boxes", + "pred_boxes_xyxy", + "pred_masks", + "pred_logits_o2m", + "pred_boxes_o2m", + "pred_boxes_xyxy_o2m", + "pred_masks_o2m", + "queries", + "presence_logit_dec", + } + + # Keys to skip + skip_keys = { + "_chunk_size", + "_interactive_backbone", + "_sam2_backbone", + "prev_encoder_out", + "encoder_hidden_states", + "aux_outputs", + } + + for key, value in chunk_outputs.items(): + if key in skip_keys: + continue + if key in batch_dim_keys and isinstance(value, torch.Tensor): + out[key] = value[local_idx : local_idx + 1] + elif isinstance(value, torch.Tensor): + try: + out[key] = value[local_idx : local_idx + 1] + except (IndexError, RuntimeError): + out[key] = value + + # Add object IDs + if "pred_logits" in out: + out["pred_object_ids"] = self._get_dummy_object_ids(out["pred_logits"]) + + # Add SAM2 backbone features if requested + if return_sam2_backbone_feats: + if "_sam2_backbone" in chunk_outputs: + sam2_bb = chunk_outputs["_sam2_backbone"] + out["sam2_backbone_fpn_0"] = sam2_bb["backbone_fpn"][0].tensors[ + local_idx : local_idx + 1 + ] + out["sam2_backbone_fpn_1"] = sam2_bb["backbone_fpn"][1].tensors[ + local_idx : local_idx + 1 + ] + out["sam2_backbone_fpn_2"] = sam2_bb["backbone_fpn"][2].tensors[ + local_idx : local_idx + 1 + ] + out["sam2_backbone_pos_enc"] = [ + x[local_idx : local_idx + 1] for x in sam2_bb["vision_pos_enc"] + ] + + if self.is_multiplex and "_interactive_backbone" in chunk_outputs: + inte_bb = chunk_outputs["_interactive_backbone"] + out["interactive_backbone_fpn_0"] = inte_bb["backbone_fpn"][0].tensors[ + local_idx : local_idx + 1 + ] + out["interactive_backbone_fpn_1"] = inte_bb["backbone_fpn"][1].tensors[ + local_idx : local_idx + 1 + ] + out["interactive_backbone_fpn_2"] = inte_bb["backbone_fpn"][2].tensors[ + local_idx : local_idx + 1 + ] + out["interactive_backbone_pos_enc"] = [ + x[local_idx : local_idx + 1] for x in inte_bb["vision_pos_enc"] + ] + + return out + + def _cleanup_previous_chunks_multigpu( + self, + grounding_cache, + current_chunk_key, + batch_size: int, + num_frames: int, + track_in_reverse: bool, + ): + """Remove previous chunks from cache to save GPU memory.""" + chunk_start, chunk_end = current_chunk_key + + if not track_in_reverse: + prev_chunk_start = chunk_start - batch_size + if prev_chunk_start >= 0: + prev_chunk_end = chunk_start + prev_chunk_key = (prev_chunk_start, prev_chunk_end) + + # Cleanup grounding_buffer entry + chunk = grounding_cache["grounding_buffer"].pop(prev_chunk_key, None) + if chunk is not None: + del chunk + else: + next_chunk_start = chunk_end + if next_chunk_start < num_frames: + next_chunk_end = min(next_chunk_start + batch_size, num_frames) + next_chunk_key = (next_chunk_start, next_chunk_end) + grounding_cache["grounding_buffer"].pop(next_chunk_key, None) diff --git a/third_party/sam3/sam3/model/sam3_multiplex_detector_utils.py b/third_party/sam3/sam3/model/sam3_multiplex_detector_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..26eb9a949dc706a5717f9cce25d9a4d6ca547783 --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_multiplex_detector_utils.py @@ -0,0 +1,369 @@ +import logging + +import numpy as np +import torch +from sam3 import perflib + +try: + # Ronghang's generic GPU NMS implementation; install via + # pip uninstall -y torch_generic_nms; TORCH_CUDA_ARCH_LIST="8.0 9.0" pip install git+https://github.com/ronghanghu/torch_generic_nms + from torch_generic_nms import generic_nms + + GENERIC_NMS_AVAILABLE = True +except ImportError: + GENERIC_NMS_AVAILABLE = False + +from sam3.perflib.masks_ops import mask_iou +from sam3.train.masks_ops import mask_iom + + +def nms_masks( + pred_probs: torch.Tensor, + pred_masks: torch.Tensor, + prob_threshold: float, + iou_threshold: float, + nms_use_iom: bool = False, + do_compile: bool = False, + running_in_prod: bool = False, +) -> torch.Tensor: + """ + Args: + - pred_probs: (num_det,) or (B, num_det) float Tensor, containing the score (probability) of each detection + - pred_masks: (num_det, H_mask, W_mask) or (B, num_det, H_mask, W_mask) float Tensor, containing the binary segmentation mask of each detection + - prob_threshold: float, score threshold to prefilter detections (NMS is performed on detections above threshold) + - iou_threshold: float, mask IoU threshold for NMS (it would also be used as IoM threshold if `nms_use_iom` is True) + - nms_use_iom: bool, if True, use IoM instead of IoU for NMS + - do_compile: bool, whether to compile the function for optimization + - running_in_prod: bool, whether the function is running in production (ie, in Instagram) + + Returns: + - keep: (num_det,) or (B, num_det) bool Tensor, indicating whether each detection is kept after score thresholding + NMS + """ + if do_compile and perflib.is_enabled: + # Apply torch.compile with the same settings as before + compiled_fn = torch.compile( + _nms_masks_core, + mode="max-autotune", + fullgraph=True, + # dynamic=False, + ) + return compiled_fn( + pred_probs, pred_masks, prob_threshold, iou_threshold, nms_use_iom + ) + else: + return _nms_masks_core( + pred_probs, pred_masks, prob_threshold, iou_threshold, nms_use_iom + ) + + +def _nms_masks_core( + pred_probs: torch.Tensor, + pred_masks: torch.Tensor, + prob_threshold: float, + iou_threshold: float, + nms_use_iom: bool = False, +) -> torch.Tensor: + """Core NMS implementation without compilation. + + Supports both single-frame and batched inputs: + - Single-frame: pred_probs (num_det,), pred_masks (num_det, H, W) + - Batched: pred_probs (B, num_det), pred_masks (B, num_det, H, W) + + Returns: + - keep: bool Tensor with same leading dimensions as input, indicating kept detections + """ + # Check if input is batched (has batch dimension) + is_batched = pred_probs.dim() == 2 + + if is_batched: + return _nms_masks_core_batched( + pred_probs, pred_masks, prob_threshold, iou_threshold, nms_use_iom + ) + else: + # Single-frame input: use original logic + return _nms_masks_core_single( + pred_probs, pred_masks, prob_threshold, iou_threshold, nms_use_iom + ) + + +def _nms_masks_core_batched( + pred_probs: torch.Tensor, + pred_masks: torch.Tensor, + prob_threshold: float, + iou_threshold: float, + nms_use_iom: bool = False, +) -> torch.Tensor: + """Core NMS implementation for batched inputs using vectorized operations. + + Args: + - pred_probs: (B, num_det) float Tensor + - pred_masks: (B, num_det, H_mask, W_mask) float Tensor + - prob_threshold: float, score threshold to prefilter detections + - iou_threshold: float, mask IoU/IoM threshold for NMS + - nms_use_iom: bool, if True, use IoM instead of IoU for NMS + + Returns: + - keep: (B, num_det) bool Tensor + """ + B, num_det, H, W = pred_masks.shape + device = pred_masks.device + + is_valid = pred_probs > prob_threshold # (B, num_det) + masks_binary = pred_masks > 0 # (B, num_det, H, W) + + if perflib.is_enabled: + # Compute batched pairwise IoU/IoM + if nms_use_iom: + overlaps = _batched_mask_iom(masks_binary) # (B, num_det, num_det) + else: + overlaps = _batched_mask_iou(masks_binary) # (B, num_det, num_det) + keep = _batched_generic_nms_mask(overlaps, pred_probs, is_valid, iou_threshold) + return keep + + # Non-perflib path: compute batched IoU/IoM + if nms_use_iom: + overlaps = _batched_mask_iom(masks_binary) # (B, num_det, num_det) + else: + overlaps = _batched_mask_iou(masks_binary) # (B, num_det, num_det) + + # Apply batched NMS + keep = _batched_generic_nms_mask(overlaps, pred_probs, is_valid, iou_threshold) + return keep + + +def _batched_mask_iou(masks: torch.Tensor) -> torch.Tensor: + """Compute batched pairwise IoU for masks. + + Args: + - masks: (B, N, H, W) bool Tensor + + Returns: + - ious: (B, N, N) float Tensor + """ + B, N, H, W = masks.shape + # Flatten spatial dims: (B, N, H*W) + masks_flat = masks.reshape(B, N, -1).float() + + # Compute intersection via batched matrix multiplication: (B, N, N) + intersection = torch.bmm(masks_flat, masks_flat.transpose(1, 2)) + + # Compute areas: (B, N) + areas = masks_flat.sum(dim=-1) + + # Compute union: (B, N, N) + union = areas.unsqueeze(2) + areas.unsqueeze(1) - intersection + + return intersection / (union + 1e-8) + + +def _batched_mask_iom(masks: torch.Tensor) -> torch.Tensor: + """Compute batched pairwise IoM (Intersection over Minimum) for masks. + + Args: + - masks: (B, N, H, W) bool Tensor + + Returns: + - ioms: (B, N, N) float Tensor + """ + B, N, H, W = masks.shape + # Flatten spatial dims: (B, N, H*W) + masks_flat = masks.reshape(B, N, -1).float() + + # Compute intersection via batched matrix multiplication: (B, N, N) + intersection = torch.bmm(masks_flat, masks_flat.transpose(1, 2)) + + # Compute areas: (B, N) + areas = masks_flat.sum(dim=-1) + + # Compute min area: (B, N, N) + min_area = torch.minimum(areas.unsqueeze(2), areas.unsqueeze(1)) + + return intersection / (min_area + 1e-8) + + +def _batched_generic_nms_mask( + ious: torch.Tensor, + scores: torch.Tensor, + is_valid: torch.Tensor, + iou_threshold: float, +) -> torch.Tensor: + """Batched NMS using vectorized operations. + + Args: + - ious: (B, N, N) float Tensor, pairwise IoU/IoM matrix + - scores: (B, N) float Tensor, detection scores + - is_valid: (B, N) bool Tensor, valid detections mask + - iou_threshold: float, threshold for suppression + + Returns: + - keep: (B, N) bool Tensor + """ + B, N = scores.shape + device = scores.device + + # Sort by score descending for each batch: (B, N) + order = scores.argsort(dim=-1, descending=True) + + # Create batch indices for advanced indexing + batch_idx = torch.arange(B, device=device).unsqueeze(1).expand(B, N) + + # Reorder IoU matrix according to sorted scores: (B, N, N) + # ious_sorted[b, i, j] = ious[b, order[b, i], order[b, j]] + ious_sorted = ious[batch_idx.unsqueeze(2), order.unsqueeze(2), order.unsqueeze(1)] + + # Create threshold mask: (B, N, N) + threshold_mask = ious_sorted > iou_threshold + + # Initialize keep mask with valid detections in sorted order: (B, N) + keep = is_valid[batch_idx, order] + + # Upper triangular mask to avoid double processing: (N, N) + triu = torch.triu(torch.ones(N, N, device=device, dtype=torch.bool), diagonal=1) + + # Vectorized NMS - iterate through detections + for i in range(N): + # For each position i, suppress later detections with high overlap + # Only suppress if current detection is kept + suppress = ( + threshold_mask[:, i, :] & triu[i].unsqueeze(0) & keep[:, i].unsqueeze(1) + ) + keep = keep & ~suppress + + # Return keep mask in original order: (B, N) + original_keep = torch.zeros_like(keep) + original_keep[batch_idx, order] = keep + return original_keep + + +def _nms_masks_core_single( + pred_probs: torch.Tensor, + pred_masks: torch.Tensor, + prob_threshold: float, + iou_threshold: float, + nms_use_iom: bool = False, +) -> torch.Tensor: + """Core NMS implementation for a single frame (no batch dimension). + + Args: + - pred_probs: (num_det,) float Tensor + - pred_masks: (num_det, H_mask, W_mask) float Tensor + - prob_threshold: float, score threshold to prefilter detections + - iou_threshold: float, mask IoU/IoM threshold for NMS + - nms_use_iom: bool, if True, use IoM instead of IoU for NMS + + Returns: + - keep: (num_det,) bool Tensor + """ + is_valid = pred_probs > prob_threshold # (num_det,) + + if perflib.is_enabled: + masks_binary = pred_masks > 0 # (num_det, H_mask, W_mask) + if nms_use_iom: + ious = perf_mask_iom(masks_binary, masks_binary) # (num_det, num_det) + else: + ious = perf_mask_iou(masks_binary, masks_binary) # (num_det, num_det) + kept_mask = generic_nms_mask(ious, pred_probs, is_valid, iou_threshold) + return kept_mask + # prefilter the detections with prob_threshold ("valid" are those above prob_threshold) + probs = pred_probs[is_valid] # (num_valid,) + masks_binary = pred_masks[is_valid] > 0 # (num_valid, H_mask, W_mask) + if probs.numel() == 0: + return is_valid # no valid detection, return empty keep mask + + if nms_use_iom: + overlaps = mask_iom(masks_binary, masks_binary) # (num_valid, num_valid) + else: + overlaps = mask_iou(masks_binary, masks_binary) # (num_valid, num_valid) + # kept_inds are the indices among `probs` of those kept detections after NMS + if GENERIC_NMS_AVAILABLE: + kept_inds = generic_nms(overlaps, probs, iou_threshold, use_iou_matrix=True) + else: + logging.warning( + "Falling back to CPU mask NMS implementation -- please install `torch_generic_nms` via\n\t" + 'pip uninstall -y torch_generic_nms; TORCH_CUDA_ARCH_LIST="8.0 9.0" pip install git+https://github.com/ronghanghu/torch_generic_nms' + ) + kept_inds = generic_nms_cpu(overlaps, probs, iou_threshold) + + # valid_inds are the indices among `probs` of valid detections before NMS (or -1 for invalid) + valid_inds = torch.where(is_valid, is_valid.cumsum(dim=0) - 1, -1) # (num_det,) + keep = torch.isin(valid_inds, kept_inds) # (num_det,) + return keep + + +def generic_nms_cpu( + ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5 +) -> torch.Tensor: + """ + A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix. (CPU implementation + based on https://github.com/jwyang/faster-rcnn.pytorch/blob/master/lib/model/nms/nms_cpu.py) + """ + ious_np = ious.float().detach().cpu().numpy() + scores_np = scores.float().detach().cpu().numpy() + order = scores_np.argsort()[::-1] + kept_inds = [] + while order.size > 0: + i = order.item(0) + kept_inds.append(i) + inds = np.where(ious_np[i, order[1:]] <= iou_threshold)[0] + order = order[inds + 1] + + return torch.tensor(kept_inds, dtype=torch.int64, device=scores.device) + + +def generic_nms_mask( + ious: torch.Tensor, scores: torch.Tensor, is_valid: torch.Tensor, iou_threshold=0.5 +) -> torch.Tensor: + """ + A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix. (CPU implementation + using vectorized operations similar to nms_masks_kernel) + """ + # Sort by score descending + order = scores.argsort(descending=True) + + # Reorder IoU matrix according to sorted scores + ious_sorted = ious[order][:, order] + + # Create threshold mask + threshold_mask = ious_sorted > iou_threshold + + # Initialize keep mask + # keep = torch.ones(len(scores), device=ious.device, dtype=torch.bool) + keep = is_valid[order] + + # Upper triangular mask to avoid double processing + tr = torch.triu(torch.ones_like(threshold_mask), diagonal=1) + + # Vectorized NMS + for i in range(len(scores)): + # Suppress all boxes that have high IoU with current box + m = threshold_mask[i] + keep = torch.where(m & tr[i], torch.zeros_like(keep), keep) + + # Return keep mask in original order + original_keep = torch.zeros_like(keep) + original_keep[order] = keep + return original_keep + + +def perf_mask_iou(pred_masks: torch.Tensor, gt_masks: torch.Tensor) -> torch.Tensor: + """ + Compute the IoU (Intersection over Union) between predicted masks and ground truth masks. + + Args: + - pred_masks: (N, H, W) bool Tensor, containing binary predicted segmentation masks + - gt_masks: (M, H, W) bool Tensor, containing binary ground truth segmentation masks + + Returns: + - ious: (N, M) float Tensor, containing IoUs for each pair of predicted and ground truth masks + """ + assert pred_masks.dtype == gt_masks.dtype == torch.bool + from sam3.perflib.iou import pairwise_iou + + return pairwise_iou(pred_masks, gt_masks, eps=None) + + +def perf_mask_iom(pred_masks: torch.Tensor, gt_masks: torch.Tensor) -> torch.Tensor: + assert pred_masks.dtype == gt_masks.dtype == torch.bool + from sam3.perflib.iou import pairwise_iom + + return pairwise_iom(pred_masks, gt_masks) diff --git a/third_party/sam3/sam3/model/sam3_multiplex_tracking.py b/third_party/sam3/sam3/model/sam3_multiplex_tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..d722e1a29f09cce324a37bf5fb03671568b2b99e --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_multiplex_tracking.py @@ -0,0 +1,3431 @@ +from collections import defaultdict +from functools import reduce +from typing import Dict + +import numpy as np +import sam3.model.sam3_multiplex_base +import sam3.model.sam3_video_base +import torch +import torch.distributed as dist +import torch.nn.functional as F +from sam3 import perflib +from sam3.logger import get_logger +from sam3.model.box_ops import box_xywh_to_cxcywh, box_xyxy_to_xywh +from sam3.model.data_misc import BatchedDatapoint +from sam3.model.sam3_multiplex_base import MaskletConfirmationStatus, Sam3MultiplexBase +from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores +from sam3.model.sam3_video_inference import is_image_type +from sam3.perflib.compile import ( + clone_output_wrapper, + compile_wrapper, + shape_logging_wrapper, +) +from sam3.perflib.masks_ops import mask_iou, masks_to_boxes as perf_masks_to_boxes +from torch import Tensor +from torchvision.ops import masks_to_boxes +from tqdm.auto import tqdm + +logger = get_logger(__name__) + +import gc +from collections.abc import Mapping, Sequence +from dataclasses import fields, is_dataclass +from typing import List + +from sam3.model.data_misc import ( + BatchedPointer, + convert_my_tensors, + FindStage, + NestedTensor, +) +from sam3.model.geometry_encoders import Prompt +from sam3.model.io_utils import load_resource_as_video_frames + + +def recursive_to(data, *args, **kwargs): + if isinstance(data, torch.Tensor): + ret = data.to(*args, **kwargs) + elif isinstance(data, np.ndarray): + ret = data + elif isinstance(data, Mapping): + ret = type(data)() + for key in data: + ret[key] = recursive_to(data[key], *args, **kwargs) + elif isinstance(data, tuple): + ret = () + for value in data: + ret += (recursive_to(value, *args, **kwargs),) + elif isinstance(data, Sequence) and not isinstance(data, str): + ret = type(data)() + for value in data: + ret.append(recursive_to(value, *args, **kwargs)) + elif is_dataclass(data): + ret_cls = type(data) + ret_fields = { + field.name: recursive_to(getattr(data, field.name), *args, **kwargs) + for field in fields(data) + } + ret = ret_cls(**ret_fields) + else: + ret = data + return ret + + +DUMMY_OUTPUT = "DUMMY_OUTPUT" + + +class Sam3MultiplexTracking(Sam3MultiplexBase): + def __init__( + self, + image_size=1008, + image_mean=(0.5, 0.5, 0.5), + image_std=(0.5, 0.5, 0.5), + compile_model=False, + postprocess_batch_size=1, + **kwargs, + ): + """ + hotstart_delay: int, the delay (in #frames) before the model starts to yield output, 0 to disable hotstart delay. + hotstart_unmatch_thresh: int, remove the object if it has this many unmatched frames within its hotstart_delay period. + If `hotstart_delay` is set to 0, this parameter is ignored. + hotstart_dup_thresh: int, remove the object if it has overlapped with another object this many frames within its hotstart_delay period. + postprocess_batch_size: int, the number of frames to accumulate before running postprocessing. Set to 1 to disable batching. + """ + super().__init__(**kwargs) + self.image_size = image_size + self.image_mean = image_mean + self.image_std = image_std + self.compile_model = compile_model + self.detector.compile_model = self.compile_model + self.postprocess_batch_size = postprocess_batch_size + + TEXT_ID_FOR_TEXT = 0 + TEXT_ID_FOR_VISUAL = 1 + TEXT_ID_FOR_GEOMETRIC = 2 + + def _construct_initial_input_batch(self, inference_state, images): + """Construct an initial `BatchedDatapoint` instance as input.""" + # 1) img_batch + num_frames = len(images) + device = inference_state["device"] + img_batch = NestedTensor(tensors=images, mask=None) + + # 2) find_text_batch + # "" will be replaced by the actual text prompt when adding prompts + find_text_batch = ["", "visual", "geometric"] + + # 3) find_inputs + input_box_embedding_dim = 258 # historical default + input_points_embedding_dim = 257 # historical default + dummy_ptrs = BatchedPointer( + stage_ids=[], query_ids=[], object_ids=[], ptr_mask=[], ptr_types=[] + ) + stages = [ + FindStage( + img_ids=[stage_id], + img_ids_np=np.array([stage_id]), + text_ids=[0], + input_boxes=[torch.zeros(input_box_embedding_dim)], + input_boxes_before_embed=[torch.empty(0, 4)], + input_boxes_mask=[torch.empty(0, dtype=torch.bool)], + input_boxes_label=[torch.empty(0, dtype=torch.long)], + input_points=[torch.empty(0, input_points_embedding_dim)], + input_points_before_embed=[torch.empty(0, 3)], + input_points_mask=[torch.empty(0)], + ptrs=dummy_ptrs, + ptrs_seg=dummy_ptrs, + object_ids=[], + ) + for stage_id in range(num_frames) + ] + with torch.profiler.record_function( + "Sam3MultiplexTracking._construct_initial_input_batch" + ): + for i in range(len(stages)): + stages[i] = convert_my_tensors(stages[i]) + + # construct the final `BatchedDatapoint` and cast to GPU + input_batch = BatchedDatapoint( + img_batch=img_batch, + find_text_batch=find_text_batch, + find_inputs=stages, + find_targets=[None] * num_frames, + get_queries=None, + find_metadatas=[None] * num_frames, + ) + with torch.profiler.record_function("Sam3MultiplexTracking.recursive_to"): + input_batch = recursive_to(input_batch, device, non_blocking=True) + inference_state["input_batch"] = input_batch + + # construct the placeholder interactive prompts and tracking queries + bs = 1 + inference_state["constants"]["empty_geometric_prompt"] = Prompt( + box_embeddings=torch.zeros(0, bs, 4, device=device), + box_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), + box_labels=torch.zeros(0, bs, device=device, dtype=torch.long), + point_embeddings=torch.zeros(0, bs, 2, device=device), + point_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), + point_labels=torch.zeros(0, bs, device=device, dtype=torch.long), + ) + + # constructing an output list in inference state (we start with an empty list) + inference_state["previous_stages_out"] = [None] * num_frames + inference_state["text_prompt"] = None + inference_state["per_frame_raw_point_input"] = [None] * num_frames + inference_state["per_frame_raw_box_input"] = [None] * num_frames + inference_state["per_frame_visual_prompt"] = [None] * num_frames + inference_state["per_frame_geometric_prompt"] = [None] * num_frames + inference_state["per_frame_cur_step"] = [0] * num_frames + + # placeholders for cached outputs + # (note: currently, a single visual prompt embedding is shared for all frames) + inference_state["backbone_out"] = None + inference_state["visual_prompt_embed"] = None + inference_state["visual_prompt_mask"] = None + + def _get_visual_prompt(self, inference_state, frame_idx, boxes_cxcywh, box_labels): + batch_size = 1 + geometric_prompt = Prompt( + box_embeddings=torch.zeros( + 0, batch_size, 4, device=inference_state["device"] + ), + box_mask=torch.zeros( + batch_size, 0, device=inference_state["device"], dtype=torch.bool + ), + point_embeddings=None, + point_mask=None, + ) + + geometric_prompt.append_boxes( + boxes=boxes_cxcywh.view(-1, batch_size, 4).to(inference_state["device"]), + labels=box_labels.view(-1, batch_size).to(inference_state["device"]), + ) + + return boxes_cxcywh, box_labels, geometric_prompt + + @torch.inference_mode() + def init_state( + self, + resource_path, + offload_video_to_cpu=False, + async_loading_frames=False, + use_torchcodec=False, + use_cv2=False, + input_is_mp4=False, + ): + # Initialize inference state (inlined from Sam3DemoMixin.init_state) + if use_torchcodec: + video_loader_type = "torchcodec" + elif use_cv2: + video_loader_type = "cv2" + else: + video_loader_type = "cv2" + images, orig_height, orig_width = load_resource_as_video_frames( + resource_path=resource_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=self.image_mean, + img_std=self.image_std, + async_loading_frames=async_loading_frames, + video_loader_type=video_loader_type, + ) + inference_state = {} + inference_state["image_size"] = self.image_size + inference_state["num_frames"] = len(images) + inference_state["device"] = torch.device("cuda") + inference_state["orig_height"] = orig_height + inference_state["orig_width"] = orig_width + inference_state["constants"] = {} + self._construct_initial_input_batch(inference_state, images) + # initialize extra states + # sam2_inference_states will contain separate inference_states for each frame having new objects if + # self.tracker.per_obj_inference is False (bucketized batching), or a single inference_state + # containing all objects if self.tracker.per_obj_inference is True (no batching at all). + inference_state["sam2_inference_states"] = [] + inference_state["tracker_metadata"] = {} + inference_state["feature_cache"] = {} + inference_state["cached_frame_outputs"] = {} + inference_state["is_image_only"] = is_image_type(resource_path) + return inference_state + + def reset_state(self, inference_state): + # Inlined from Sam3DemoMixin.reset_state + inference_state["input_batch"].find_text_batch[0] = "" + inference_state["text_prompt"] = None + for t in range(inference_state["num_frames"]): + inference_state["input_batch"].find_inputs[t].text_ids[...] = 0 + inference_state["previous_stages_out"][t] = None + inference_state["per_frame_raw_point_input"][t] = None + inference_state["per_frame_raw_box_input"][t] = None + inference_state["per_frame_visual_prompt"][t] = None + inference_state["per_frame_geometric_prompt"][t] = None + inference_state["per_frame_cur_step"][t] = 0 + inference_state["backbone_out"] = None + inference_state["visual_prompt_embed"] = None + inference_state["visual_prompt_mask"] = None + # reset extra states + inference_state["sam2_inference_states"].clear() + inference_state["tracker_metadata"].clear() + inference_state["feature_cache"].clear() + inference_state["cached_frame_outputs"] = {} + + def _get_processing_order( + self, inference_state, start_frame_idx, max_frame_num_to_track, reverse + ): + num_frames = inference_state["num_frames"] + previous_stages_out = inference_state["previous_stages_out"] + if all(out is None for out in previous_stages_out) and start_frame_idx is None: + raise RuntimeError( + "No prompts are received on any frames. Please add prompt on at least one frame before propagation." + ) + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min( + t for t, out in enumerate(previous_stages_out) if out is not None + ) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = start_frame_idx - max_frame_num_to_track + end_frame_idx = max(end_frame_idx, 0) + processing_order = range(start_frame_idx - 1, end_frame_idx - 1, -1) + else: + end_frame_idx = start_frame_idx + max_frame_num_to_track + end_frame_idx = min(end_frame_idx, num_frames - 1) + processing_order = range(start_frame_idx, end_frame_idx + 1) + return processing_order, end_frame_idx + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + output_prob_thresh=0.5, + compute_stability_score=False, + is_instance_processing=False, + **kwargs, # To support passing extra args to child classes + ): + """ + Propagate the prompts to get grounding results for the entire video. This method + is a generator and yields inference outputs for all frames in the range specified + by `start_frame_idx`, `max_frame_num_to_track`, and `reverse`. + """ + # compile the model (it's a no-op if the model is already compiled) + # note that it's intentionally added to `self.propagate_in_video`, so that the first + # `self.add_prompt` call will be done in eager mode to fill in the decoder buffers + # such as positional encoding cache) + self._compile_model() + + processing_order, end_frame_idx = self._get_processing_order( + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse=reverse, + ) + + # Store max_frame_num_to_track in feature_cache for downstream methods + inference_state["feature_cache"]["tracking_bounds"] = { + "max_frame_num_to_track": max_frame_num_to_track, + "propagate_in_video_start_frame_idx": start_frame_idx, + } + + hotstart_buffer = [] + hotstart_removed_obj_ids = set() + # when deciding whether to output a masklet on `yield_frame_idx`, we check whether the object is confirmed + # in a future frame (`unconfirmed_frame_delay` frames after the current frame). For example, if we require + # an object to be detected in 3 consecutive frames to be confirmed, then we look 2 frames in the future -- + # e.g., we output an object on frame 4 only if it becomes confirmed on frame 6. + unconfirmed_status_delay = self.masklet_confirmation_consecutive_det_thresh - 1 + unconfirmed_obj_ids_per_frame = {} # frame_idx -> hidden_obj_ids + + # Batch postprocessing: accumulate yield_list entries and process every postprocess_batch_size frames + postprocess_yield_list = [] + + for frame_idx in tqdm( + processing_order, desc="propagate_in_video", disable=self.rank > 0 + ): + out = self._run_single_frame_inference( + inference_state, + frame_idx, + reverse, + is_instance_processing=is_instance_processing, + ) + + if self.hotstart_delay > 0: + # accumulate the outputs for the first `hotstart_delay` frames + hotstart_buffer.append([frame_idx, out]) + # update the object IDs removed by hotstart so that we don't output them + if self.rank == 0: + hotstart_removed_obj_ids.update(out["removed_obj_ids"]) + unconfirmed_obj_ids = out.get("unconfirmed_obj_ids", None) + if unconfirmed_obj_ids is not None: + unconfirmed_obj_ids_per_frame[frame_idx] = unconfirmed_obj_ids + + if frame_idx == end_frame_idx: + # we reached the end of propagation -- yield all frames in the buffer + yield_list = hotstart_buffer + hotstart_buffer = [] + elif len(hotstart_buffer) >= self.hotstart_delay: + # we have enough frames -- yield and remove the first (oldest) frame from the buffer + yield_list = hotstart_buffer[:1] + hotstart_buffer = hotstart_buffer[1:] + else: + # not enough frames yet -- skip yielding + yield_list = [] + else: + yield_list = [(frame_idx, out)] # output the current frame + + # Accumulate yield_list into postprocess_yield_list + # Snapshot hotstart_removed_obj_ids at the time of accumulation to preserve + # the correct state for each frame (important: this set is mutated over time) + for yield_frame_idx, yield_out in yield_list: + postprocess_yield_list.append( + (yield_frame_idx, yield_out, set(hotstart_removed_obj_ids)) + ) + + # Process batch when we have enough frames + while len(postprocess_yield_list) >= self.postprocess_batch_size: + batch_to_process = postprocess_yield_list[: self.postprocess_batch_size] + postprocess_yield_list = postprocess_yield_list[ + self.postprocess_batch_size : + ] + + with torch.profiler.record_function( + "Sam3MultiplexTracking.postprocess_output_batched" + ): + if self.rank == 0: + # Prepare batched inputs for postprocessing + H_video, W_video = ( + inference_state["orig_height"], + inference_state["orig_width"], + ) + num_frames = inference_state["num_frames"] + + batched_outs = [] + frame_indices = [] + for ( + yield_frame_idx, + yield_out, + removed_obj_ids_snapshot, + ) in batch_to_process: + suppressed_obj_ids = yield_out["suppressed_obj_ids"] + unconfirmed_status_frame_idx = ( + yield_frame_idx + unconfirmed_status_delay + if not reverse + else yield_frame_idx - unconfirmed_status_delay + ) + unconfirmed_status_frame_idx = max( + 0, min(unconfirmed_status_frame_idx, num_frames - 1) + ) + unconfirmed_obj_ids = unconfirmed_obj_ids_per_frame.get( + unconfirmed_status_frame_idx, None + ) + + batched_outs.append( + ( + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + ) + frame_indices.append(yield_frame_idx) + + # Cache frame outputs + self._cache_frame_outputs( + inference_state, + yield_frame_idx, + yield_out["obj_id_to_mask"], + suppressed_obj_ids=suppressed_obj_ids, + removed_obj_ids=removed_obj_ids_snapshot, + unconfirmed_obj_ids=unconfirmed_obj_ids, + ) + + if self.postprocess_batch_size > 1: + # Process all frames in batch + postprocessed_outs = self._postprocess_output_batched( + H_video, W_video, batched_outs + ) + else: + # Process each frame individually but output together + postprocessed_outs = [] + for ( + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) in batched_outs: + postprocessed_out = self._postprocess_output( + inference_state, + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + postprocessed_outs.append(postprocessed_out) + + # Yield results + for yield_frame_idx, postprocessed_out in zip( + frame_indices, postprocessed_outs + ): + yield yield_frame_idx, postprocessed_out + else: + # No output on other GPUs + for yield_frame_idx, _, _ in batch_to_process: + yield yield_frame_idx, DUMMY_OUTPUT + + # Flush any remaining frames in the postprocess buffer + if len(postprocess_yield_list) > 0: + with torch.profiler.record_function( + "Sam3MultiplexTracking.postprocess_output_batched" + ): + if self.rank == 0: + H_video, W_video = ( + inference_state["orig_height"], + inference_state["orig_width"], + ) + num_frames = inference_state["num_frames"] + + batched_outs = [] + frame_indices = [] + for ( + yield_frame_idx, + yield_out, + removed_obj_ids_snapshot, + ) in postprocess_yield_list: + suppressed_obj_ids = yield_out["suppressed_obj_ids"] + unconfirmed_status_frame_idx = ( + yield_frame_idx + unconfirmed_status_delay + if not reverse + else yield_frame_idx - unconfirmed_status_delay + ) + unconfirmed_status_frame_idx = max( + 0, min(unconfirmed_status_frame_idx, num_frames - 1) + ) + unconfirmed_obj_ids = unconfirmed_obj_ids_per_frame.get( + unconfirmed_status_frame_idx, None + ) + + batched_outs.append( + ( + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + ) + frame_indices.append(yield_frame_idx) + + self._cache_frame_outputs( + inference_state, + yield_frame_idx, + yield_out["obj_id_to_mask"], + suppressed_obj_ids=suppressed_obj_ids, + removed_obj_ids=removed_obj_ids_snapshot, + unconfirmed_obj_ids=unconfirmed_obj_ids, + ) + + if self.postprocess_batch_size > 1: + postprocessed_outs = self._postprocess_output_batched( + H_video, W_video, batched_outs + ) + else: + # Process each frame individually but output together + postprocessed_outs = [] + for ( + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) in batched_outs: + postprocessed_out = self._postprocess_output( + inference_state, + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + postprocessed_outs.append(postprocessed_out) + + for yield_frame_idx, postprocessed_out in zip( + frame_indices, postprocessed_outs + ): + yield yield_frame_idx, postprocessed_out + else: + for yield_frame_idx, _, _ in postprocess_yield_list: + yield yield_frame_idx, DUMMY_OUTPUT + + if self.is_multiplex: + # log the bucket utilization stats + # bucket utilization rate is total valid objects / total capacity -> represents rooms for improvement + # subscription rate is total valid objects / total number of buckets -> represents speedup + total_valid_objects = 0 + total_num_buckets = 0 + for state in inference_state["sam2_inference_states"]: + assert ( + len(state["obj_ids"]) + == state["multiplex_state"].total_valid_entries + ) + total_valid_objects += len(state["obj_ids"]) + total_num_buckets += state["multiplex_state"].num_buckets + if total_num_buckets > 0: + bucket_utilization_rate = ( + total_valid_objects / (total_num_buckets * self.bucket_capacity) + ) * 100 + subscription_rate = (total_valid_objects / total_num_buckets) * 100 + logger.info( + f"Bucket utilization rate: {bucket_utilization_rate:.2f}%, subscription rate: {subscription_rate:.2f}%" + ) + + def _run_single_frame_inference( + self, + inference_state, + frame_idx, + reverse, + is_instance_processing=False, + ): + """ + Perform inference on a single frame and get its inference results. This would + also update `inference_state`. + """ + # prepare inputs + input_batch = inference_state["input_batch"] + tracker_states_local = inference_state["sam2_inference_states"] + geometric_prompt = ( + inference_state["constants"]["empty_geometric_prompt"] + if inference_state["per_frame_geometric_prompt"][frame_idx] is None + else inference_state["per_frame_geometric_prompt"][frame_idx] + ) + text_batch_key = tuple(input_batch.find_text_batch) + inference_state["feature_cache"]["text"] = { + text_batch_key: { + "language_features": inference_state["backbone_out"][ + "language_features" + ], + "language_mask": inference_state["backbone_out"]["language_mask"], + } + } + # run inference for the current frame + ( + obj_id_to_mask, + obj_id_to_score, + tracker_states_local_new, + tracker_metadata_new, + frame_stats, + _, + ) = self._det_track_one_frame( + frame_idx=frame_idx, + num_frames=inference_state["num_frames"], + reverse=reverse, + input_batch=input_batch, + geometric_prompt=geometric_prompt, + tracker_states_local=tracker_states_local, + tracker_metadata_prev=inference_state["tracker_metadata"], + feature_cache=inference_state["feature_cache"], + orig_vid_height=inference_state["orig_height"], + orig_vid_width=inference_state["orig_width"], + is_image_only=inference_state["is_image_only"], + ) + # update inference state + inference_state["sam2_inference_states"] = tracker_states_local_new + inference_state["tracker_metadata"] = tracker_metadata_new + # use a dummy string in "previous_stages_out" to indicate this frame has outputs + inference_state["previous_stages_out"][frame_idx] = "_THIS_FRAME_HAS_OUTPUTS_" + + if self.rank == 0: + self._cache_frame_outputs(inference_state, frame_idx, obj_id_to_mask) + + out = { + "obj_id_to_mask": obj_id_to_mask, + "obj_id_to_score": obj_id_to_score, # first frame detection score + "obj_id_to_sam2_score": tracker_metadata_new[ + "obj_id_to_sam2_score_frame_wise" + ][frame_idx], + } + # removed_obj_ids is only needed on rank 0 to handle hotstart delay buffer + if self.rank == 0: + rank0_metadata = tracker_metadata_new["rank0_metadata"] + removed_obj_ids = rank0_metadata["removed_obj_ids"] + out["removed_obj_ids"] = removed_obj_ids + out["suppressed_obj_ids"] = rank0_metadata["suppressed_obj_ids"][frame_idx] + out["frame_stats"] = frame_stats + if self.masklet_confirmation_enable: + status = rank0_metadata["masklet_confirmation"]["status"] + is_unconfirmed = status == MaskletConfirmationStatus.UNCONFIRMED.value + out["unconfirmed_obj_ids"] = tracker_metadata_new["obj_ids_all_gpu"][ + is_unconfirmed + ].tolist() + else: + out["unconfirmed_obj_ids"] = [] + + return out + + def _postprocess_output( + self, + inference_state, + out, + removed_obj_ids=None, + suppressed_obj_ids=None, + unconfirmed_obj_ids=None, + ): + obj_id_to_mask = out["obj_id_to_mask"] # low res masks + curr_obj_ids = sorted(obj_id_to_mask.keys()) + H_video, W_video = inference_state["orig_height"], inference_state["orig_width"] + if len(curr_obj_ids) == 0: + out_obj_ids = torch.zeros(0, dtype=torch.int64) + out_probs = torch.zeros(0, dtype=torch.float32) + out_binary_masks = torch.zeros(0, H_video, W_video, dtype=torch.bool) + out_boxes_xywh = torch.zeros(0, 4, dtype=torch.float32) + else: + out_obj_ids = torch.tensor(curr_obj_ids, dtype=torch.int64) + out_probs = torch.tensor( + [out["obj_id_to_score"][obj_id] for obj_id in curr_obj_ids] + ) + out_sam2_probs = torch.tensor( + [ + ( + out["obj_id_to_sam2_score"][obj_id] + if obj_id in out["obj_id_to_sam2_score"] + else 0.0 + ) + for obj_id in curr_obj_ids + ] + ) + out_binary_masks = torch.cat( + [obj_id_to_mask[obj_id] for obj_id in curr_obj_ids], dim=0 + ) + + assert out_binary_masks.dtype == torch.bool + keep = out_binary_masks.any(dim=(1, 2)).cpu() # remove masks with 0 areas + # hide outputs for those object IDs in `obj_ids_to_hide` + obj_ids_to_hide = [] + if suppressed_obj_ids is not None: + obj_ids_to_hide.extend(suppressed_obj_ids) + if removed_obj_ids is not None: + obj_ids_to_hide.extend(removed_obj_ids) + if unconfirmed_obj_ids is not None: + obj_ids_to_hide.extend(unconfirmed_obj_ids) + if len(obj_ids_to_hide) > 0: + obj_ids_to_hide_t = torch.tensor(obj_ids_to_hide, dtype=torch.int64) + keep &= ~torch.isin(out_obj_ids, obj_ids_to_hide_t) + + # slice those valid entries from the original outputs + keep_idx = torch.nonzero(keep, as_tuple=True)[0] + keep_idx_gpu = keep_idx.pin_memory().to( + device=out_binary_masks.device, non_blocking=True + ) + + out_obj_ids = torch.index_select(out_obj_ids, 0, keep_idx) + out_probs = torch.index_select(out_probs, 0, keep_idx) + out_sam2_probs = torch.index_select(out_sam2_probs, 0, keep_idx) + out_binary_masks = torch.index_select(out_binary_masks, 0, keep_idx_gpu) + + if perflib.is_enabled: + out_boxes_xyxy = perf_masks_to_boxes( + out_binary_masks, out_obj_ids.tolist() + ) + else: + out_boxes_xyxy = masks_to_boxes(out_binary_masks) + + out_boxes_xywh = box_xyxy_to_xywh(out_boxes_xyxy) # convert to xywh format + # normalize boxes + out_boxes_xywh[..., 0] /= W_video + out_boxes_xywh[..., 1] /= H_video + out_boxes_xywh[..., 2] /= W_video + out_boxes_xywh[..., 3] /= H_video + + # apply non-overlapping constraints on the existing masklets + if out_binary_masks.shape[0] > 1: + assert len(out_binary_masks) == len(out_sam2_probs) + out_binary_masks = ( + self.tracker._apply_object_wise_non_overlapping_constraints( + out_binary_masks.unsqueeze(1), + out_sam2_probs.unsqueeze(1).to(out_binary_masks.device), + background_value=0, + ).squeeze(1) + ) > 0 + + prod_outputs = {} + if self.running_in_prod: + with torch.profiler.record_function( + "Sam3MultiplexTracking._postprocess_output.prod_outputs" + ): + out_centers = torch.zeros( + out_binary_masks.shape[0], + 2, + dtype=torch.float32, + device=out_binary_masks.device, + ) + + y_coords = torch.arange( + H_video, device=out_binary_masks.device, dtype=torch.float32 + ) + x_coords = torch.arange( + W_video, device=out_binary_masks.device, dtype=torch.float32 + ) + y_grid = y_coords.view(1, H_video, 1) + x_grid = x_coords.view(1, 1, W_video) + with torch.profiler.record_function( + "Sam3MultiplexTracking._postprocess_output.prod_outputs.center" + ): + weighted_y_sum = (out_binary_masks * y_grid).sum(dim=(1, 2)) + weighted_x_sum = (out_binary_masks * x_grid).sum(dim=(1, 2)) + total_mass = out_binary_masks.sum(dim=(1, 2)).clamp_min(1e-6) + center_y = weighted_y_sum / total_mass / H_video + center_x = weighted_x_sum / total_mass / W_video + out_centers[:, 0] = center_x + out_centers[:, 1] = center_y + + with torch.profiler.record_function( + "Sam3MultiplexTracking._postprocess_output.prod_outputs.to_cpu" + ): + prod_outputs["out_centers"] = out_centers.cpu().numpy() + + outputs = { + "out_obj_ids": out_obj_ids.cpu().numpy(), + "out_probs": out_probs.cpu().numpy(), + "out_boxes_xywh": out_boxes_xywh.cpu().numpy(), + "out_binary_masks": out_binary_masks.cpu().numpy(), + "frame_stats": out.get("frame_stats", None), + } | prod_outputs + + return outputs + + def _postprocess_output_batched( + self, + H_video, + W_video, + batched_outs, + ): + """ + Batched version of _postprocess_output that batches GPU computations + (keep filtering, box computation) across frames for efficiency. + + Args: + H_video: Video height + W_video: Video width + batched_outs: List of tuples, each containing: + (out, removed_obj_ids, suppressed_obj_ids, unconfirmed_obj_ids) + where out is the output dict from _run_single_frame_inference + + Returns: + List of output dicts, one per frame in batched_outs + """ + batch_size = len(batched_outs) + if batch_size == 0: + return [] + + # ========== Phase 1: Collect per-frame data ========== + # We'll track: frame_data[i] = (obj_ids, probs, sam2_probs, masks, keep_mask, frame_stats) + # or None if frame has no objects + frame_data = [] + device = None + + for ( + out, + removed_obj_ids, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) in batched_outs: + obj_id_to_mask = out["obj_id_to_mask"] + curr_obj_ids = sorted(obj_id_to_mask.keys()) + frame_stats = out.get("frame_stats", None) + + if len(curr_obj_ids) == 0: + frame_data.append((None, None, None, None, None, frame_stats)) + continue + + out_obj_ids = torch.tensor(curr_obj_ids, dtype=torch.int64) + obj_id_to_score_dict = out["obj_id_to_score"] + obj_id_to_sam2_score = out["obj_id_to_sam2_score"] + + if device is None: + device = obj_id_to_mask[curr_obj_ids[0]].device + default_sam2_score = torch.zeros((), dtype=torch.float32, device=device) + + probs_list = [] + sam2_probs_list = [] + binary_masks_list = [] + + for obj_id in curr_obj_ids: + probs_list.append(obj_id_to_score_dict[obj_id]) + sam2_probs_list.append( + obj_id_to_sam2_score.get(obj_id, default_sam2_score) + ) + binary_masks_list.append(obj_id_to_mask[obj_id]) + + out_probs = torch.tensor(probs_list, dtype=torch.float32) + out_sam2_probs_gpu = torch.stack(sam2_probs_list) + out_binary_masks = torch.cat(binary_masks_list, dim=0) + + # Compute keep mask (which objects to hide) + obj_ids_to_hide = [] + if suppressed_obj_ids is not None: + obj_ids_to_hide.extend(suppressed_obj_ids) + if removed_obj_ids is not None: + obj_ids_to_hide.extend(removed_obj_ids) + if unconfirmed_obj_ids is not None: + obj_ids_to_hide.extend(unconfirmed_obj_ids) + + if len(obj_ids_to_hide) > 0: + obj_ids_to_hide_t = torch.tensor(obj_ids_to_hide, dtype=torch.int64) + hide_mask = torch.isin(out_obj_ids, obj_ids_to_hide_t) + else: + hide_mask = torch.zeros(len(out_obj_ids), dtype=torch.bool) + + frame_data.append( + ( + out_obj_ids, + out_probs, + out_sam2_probs_gpu, + out_binary_masks, + hide_mask, + frame_stats, + ) + ) + + # ========== Phase 2: Batch concatenate masks for GPU operations ========== + # Collect frames with objects + frames_with_objects = [] + frame_obj_counts = [] # Number of objects per frame (for frames with objects only) + all_masks_list = [] + all_hide_masks_list = [] + + for i, data in enumerate(frame_data): + if data[0] is not None: + frames_with_objects.append(i) + frame_obj_counts.append(data[0].shape[0]) + all_masks_list.append(data[3]) # binary_masks + all_hide_masks_list.append(data[4]) # hide_mask + + # Handle case where all frames have 0 objects + if len(frames_with_objects) == 0: + outputs = [] + for data in frame_data: + output_dict = { + "out_obj_ids": np.zeros(0, dtype=np.int64), + "out_probs": np.zeros(0, dtype=np.float32), + "out_boxes_xywh": np.zeros((0, 4), dtype=np.float32), + "out_binary_masks": np.zeros((0, H_video, W_video), dtype=bool), + "frame_stats": data[5], + } + if self.running_in_prod: + output_dict["out_centers"] = np.zeros((0, 2), dtype=np.float32) + outputs.append(output_dict) + return outputs + + # Concatenate all masks for batched GPU operations + all_masks = torch.cat(all_masks_list, dim=0) + all_hide_masks = torch.cat(all_hide_masks_list, dim=0) + + # ========== Phase 3: Batched keep mask computation on GPU ========== + # Compute which masks have non-zero area (batched on GPU) + has_area = all_masks.any(dim=(1, 2)) # GPU operation + + # Combine with hide mask (move hide_mask to GPU for the operation) + all_hide_masks_gpu = all_hide_masks.to(device=all_masks.device) + keep_mask_gpu = has_area & ~all_hide_masks_gpu + + # Get keep indices + keep_indices = torch.nonzero(keep_mask_gpu, as_tuple=True)[0] + + if len(keep_indices) == 0: + # All objects filtered out + outputs = [] + for data in frame_data: + output_dict = { + "out_obj_ids": np.zeros(0, dtype=np.int64), + "out_probs": np.zeros(0, dtype=np.float32), + "out_boxes_xywh": np.zeros((0, 4), dtype=np.float32), + "out_binary_masks": np.zeros((0, H_video, W_video), dtype=bool), + "frame_stats": data[5], + } + if self.running_in_prod: + output_dict["out_centers"] = np.zeros((0, 2), dtype=np.float32) + outputs.append(output_dict) + return outputs + + # ========== Phase 4: Batched filtering and box computation ========== + # Filter masks on GPU + kept_masks = torch.index_select(all_masks, 0, keep_indices) + + # Compute bounding boxes in batch on GPU + if perflib.is_enabled: + # Need to gather obj_ids for perflib + all_obj_ids_list = [frame_data[i][0] for i in frames_with_objects] + all_obj_ids_cat = torch.cat(all_obj_ids_list, dim=0) + kept_obj_ids_for_perf = torch.index_select( + all_obj_ids_cat, 0, keep_indices.cpu() + ) + kept_boxes_xyxy = perf_masks_to_boxes( + kept_masks, kept_obj_ids_for_perf.tolist() + ) + else: + kept_boxes_xyxy = masks_to_boxes(kept_masks) + + kept_boxes_xywh = box_xyxy_to_xywh(kept_boxes_xyxy) + kept_boxes_xywh[..., 0] /= W_video + kept_boxes_xywh[..., 1] /= H_video + kept_boxes_xywh[..., 2] /= W_video + kept_boxes_xywh[..., 3] /= H_video + + # ========== Phase 5: Split back to per-frame for non-overlapping ========== + # Compute how many objects were kept per frame + keep_indices_cpu = keep_indices.cpu() + keep_set = set(keep_indices_cpu.tolist()) + + kept_counts = [] + offset = 0 + for count in frame_obj_counts: + kept_in_frame = sum( + 1 for j in range(offset, offset + count) if j in keep_set + ) + kept_counts.append(kept_in_frame) + offset += count + + # Split the kept tensors back to per-frame + split_masks = torch.split(kept_masks, kept_counts) + split_boxes = torch.split(kept_boxes_xywh, kept_counts) + + # Also need to split obj_ids, probs, sam2_probs (filtering from original frame_data) + # We need to track which original indices were kept per frame + frame_kept_indices = [] # List of (local_kept_indices) per frame + offset = 0 + for count in frame_obj_counts: + local_kept = [] + for j in range(offset, offset + count): + if j in keep_set: + local_kept.append(j - offset) # Local index within frame + frame_kept_indices.append(local_kept) + offset += count + + # ========== Phase 6: Apply non-overlapping per frame, collect final results ========== + final_results = [] # List of (frame_idx, obj_ids, probs, boxes, masks) + + for idx, frame_i in enumerate(frames_with_objects): + data = frame_data[frame_i] + local_kept = frame_kept_indices[idx] + + if len(local_kept) == 0: + continue + + # Get the filtered data for this frame + local_kept_t = torch.tensor(local_kept, dtype=torch.int64) + out_obj_ids = torch.index_select(data[0], 0, local_kept_t) + out_probs = torch.index_select(data[1], 0, local_kept_t) + out_sam2_probs = torch.index_select( + data[2], 0, local_kept_t.to(data[2].device) + ) + out_masks = split_masks[idx] + out_boxes = split_boxes[idx] + + # Apply non-overlapping constraints (per-frame operation) + if out_masks.shape[0] > 1: + # Copy sam2_probs to CPU pinned memory then back to GPU for the operation + out_sam2_probs_cpu = torch.empty( + out_sam2_probs.shape, dtype=out_sam2_probs.dtype, pin_memory=True + ) + out_sam2_probs_cpu.copy_(out_sam2_probs, non_blocking=True) + out_masks = ( + self.tracker._apply_object_wise_non_overlapping_constraints( + out_masks.unsqueeze(1), + out_sam2_probs_cpu.unsqueeze(1).to(out_masks.device), + background_value=0, + ).squeeze(1) + ) > 0 + + final_results.append( + (frame_i, out_obj_ids, out_probs, out_boxes, out_masks) + ) + + # ========== Phase 6.5: Compute centers for prod ========== + all_centers = None + if self.running_in_prod and len(final_results) > 0: + with torch.profiler.record_function( + "Sam3MultiplexTracking._postprocess_output_batched.prod_outputs" + ): + # Concatenate all masks for batched center computation + all_masks = torch.cat([r[4] for r in final_results], dim=0) + if all_masks.shape[0] > 0: + y_coords = torch.arange( + H_video, device=all_masks.device, dtype=torch.float32 + ) + x_coords = torch.arange( + W_video, device=all_masks.device, dtype=torch.float32 + ) + y_grid = y_coords.view(1, H_video, 1) + x_grid = x_coords.view(1, 1, W_video) + + weighted_y_sum = (all_masks * y_grid).sum(dim=(1, 2)) + weighted_x_sum = (all_masks * x_grid).sum(dim=(1, 2)) + total_mass = all_masks.sum(dim=(1, 2)).clamp_min(1e-6) + center_y = weighted_y_sum / total_mass / H_video + center_x = weighted_x_sum / total_mass / W_video + all_centers = torch.stack([center_x, center_y], dim=1) + + # Handle case where all filtered out + if len(final_results) == 0: + outputs = [] + for data in frame_data: + output_dict = { + "out_obj_ids": np.zeros(0, dtype=np.int64), + "out_probs": np.zeros(0, dtype=np.float32), + "out_boxes_xywh": np.zeros((0, 4), dtype=np.float32), + "out_binary_masks": np.zeros((0, H_video, W_video), dtype=bool), + "frame_stats": data[5], + } + if self.running_in_prod: + output_dict["out_centers"] = np.zeros((0, 2), dtype=np.float32) + outputs.append(output_dict) + return outputs + + # ========== Phase 7: Concatenate for batched GPU→CPU copy ========== + final_obj_ids = torch.cat([r[1] for r in final_results], dim=0) + final_probs = torch.cat([r[2] for r in final_results], dim=0) + final_boxes = torch.cat([r[3] for r in final_results], dim=0) + final_masks = torch.cat([r[4] for r in final_results], dim=0) + + total_objects = final_obj_ids.shape[0] + + # Initialize or resize batched CPU buffer + batched_buffer_size = self.postprocess_batch_size * self.max_num_objects + needs_buffer_init = not hasattr(self, "buffer_cpu_batched") + needs_buffer_resize = not needs_buffer_init and ( + self.buffer_cpu_batched["out_binary_masks"].shape[0] != batched_buffer_size + or self.buffer_cpu_batched["out_binary_masks"].shape[1] != H_video + or self.buffer_cpu_batched["out_binary_masks"].shape[2] != W_video + ) + + if needs_buffer_init or needs_buffer_resize: + self.buffer_cpu_batched = { + "out_obj_ids": torch.zeros( + batched_buffer_size, + dtype=torch.int64, + device="cpu", + pin_memory=True, + ), + "out_probs": torch.zeros( + batched_buffer_size, + dtype=torch.float32, + device="cpu", + pin_memory=True, + ), + "out_boxes_xywh": torch.zeros( + batched_buffer_size, + 4, + dtype=torch.float32, + device="cpu", + pin_memory=True, + ), + "out_binary_masks": torch.zeros( + batched_buffer_size, + H_video, + W_video, + dtype=bool, + device="cpu", + pin_memory=True, + ), + } + if self.running_in_prod: + self.buffer_cpu_batched["out_centers"] = torch.zeros( + batched_buffer_size, + 2, + dtype=torch.float32, + device="cpu", + pin_memory=True, + ) + + self.buffer_cpu_batched["out_obj_ids"][:total_objects].copy_(final_obj_ids) + self.buffer_cpu_batched["out_probs"][:total_objects].copy_(final_probs) + self.buffer_cpu_batched["out_boxes_xywh"][:total_objects].copy_(final_boxes) + self.buffer_cpu_batched["out_binary_masks"][:total_objects].copy_(final_masks) + + if all_centers is not None: + self.buffer_cpu_batched["out_centers"][:total_objects].copy_(all_centers) + + # ========== Phase 8: Build output list ========== + # Create mapping from frame index to (offset, count) in the buffer + frame_to_offset_count = {} + offset = 0 + for frame_i, obj_ids, _, _, _ in final_results: + count = obj_ids.shape[0] + frame_to_offset_count[frame_i] = (offset, count) + offset += count + + outputs = [] + for i, data in enumerate(frame_data): + frame_stats = data[5] + if i not in frame_to_offset_count: + # Frame has no objects (either originally or after filtering) + output_dict = { + "out_obj_ids": np.zeros(0, dtype=np.int64), + "out_probs": np.zeros(0, dtype=np.float32), + "out_boxes_xywh": np.zeros((0, 4), dtype=np.float32), + "out_binary_masks": np.zeros((0, H_video, W_video), dtype=bool), + "frame_stats": frame_stats, + } + if all_centers is not None: + output_dict["out_centers"] = np.zeros((0, 2), dtype=np.float32) + outputs.append(output_dict) + else: + buf_offset, num_objects = frame_to_offset_count[i] + output_dict = { + "out_obj_ids": self.buffer_cpu_batched["out_obj_ids"][ + buf_offset : buf_offset + num_objects + ] + .numpy() + .copy(), + "out_probs": self.buffer_cpu_batched["out_probs"][ + buf_offset : buf_offset + num_objects + ] + .numpy() + .copy(), + "out_boxes_xywh": self.buffer_cpu_batched["out_boxes_xywh"][ + buf_offset : buf_offset + num_objects + ] + .numpy() + .copy(), + "out_binary_masks": self.buffer_cpu_batched["out_binary_masks"][ + buf_offset : buf_offset + num_objects + ] + .numpy() + .copy(), + "frame_stats": frame_stats, + } + if all_centers is not None: + output_dict["out_centers"] = ( + self.buffer_cpu_batched["out_centers"][ + buf_offset : buf_offset + num_objects + ] + .numpy() + .copy() + ) + outputs.append(output_dict) + + return outputs + + def _cache_frame_outputs( + self, + inference_state, + frame_idx, + obj_id_to_mask, + suppressed_obj_ids=None, + removed_obj_ids=None, + unconfirmed_obj_ids=None, + ): + if "cached_frame_outputs" not in inference_state: + inference_state["cached_frame_outputs"] = {} + + # Filter out suppressed, removed, and unconfirmed objects from the cache + filtered_obj_id_to_mask = obj_id_to_mask.copy() + + objects_to_exclude = set() + if suppressed_obj_ids is not None: + objects_to_exclude.update(suppressed_obj_ids) + if removed_obj_ids is not None: + objects_to_exclude.update(removed_obj_ids) + if unconfirmed_obj_ids is not None: + objects_to_exclude.update(unconfirmed_obj_ids) + + if objects_to_exclude: + for obj_id in objects_to_exclude: + if obj_id in filtered_obj_id_to_mask: + del filtered_obj_id_to_mask[obj_id] + + inference_state["cached_frame_outputs"][frame_idx] = filtered_obj_id_to_mask + + def _build_sam2_output( + self, inference_state, frame_idx, refined_obj_id_to_mask=None + ): + if not frame_idx in inference_state["cached_frame_outputs"]: + return {} + + cached_outputs = inference_state["cached_frame_outputs"][frame_idx] + obj_id_to_mask = cached_outputs.copy() + + # Update with refined masks if provided + if refined_obj_id_to_mask is not None: + for obj_id, refined_mask in refined_obj_id_to_mask.items(): + assert ( + refined_mask is not None + ), f"Refined mask data must be provided for obj_id {obj_id}" + obj_id_to_mask[obj_id] = refined_mask + + return obj_id_to_mask + + def _compile_model(self): + """Compile the SAM model with torch.compile for speedup.""" + # TODO: compile SAM2 model components + is_compiled = getattr(self, "_model_is_compiled", False) + if is_compiled or not self.compile_model: + return + + import torch._dynamo + + # a larger cache size to hold varying number of shapes for torch.compile + # see https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_dynamo/config.py#L42-L49 + torch._dynamo.config.cache_size_limit = 128 + torch._dynamo.config.accumulated_cache_size_limit = 2048 + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.suppress_errors = True + + # Compile module components following https://www.internalfb.com/diff/D70935785 + # skip compilation of `_encode_prompt` since it sometimes tiggger SymInt errors + # self._encode_prompt = clone_output_wrapper( + # torch.compile(self._encode_prompt, fullgraph=True, mode="max-autotune") + # ) + + ## Compile SAM3 model components (matching OV: clone_output_wrapper(torch.compile(fn))) + self.detector.backbone.language_backbone.encoder.forward = clone_output_wrapper( + torch.compile( + self.detector.backbone.language_backbone.encoder.forward, + fullgraph=True, + mode="max-autotune", + ) + ) + + self.detector.backbone.vision_backbone.forward = clone_output_wrapper( + torch.compile( + self.detector.backbone.vision_backbone.forward, + fullgraph=True, + mode="max-autotune", + ) + ) + self.detector.transformer.encoder.forward = clone_output_wrapper( + torch.compile( + self.detector.transformer.encoder.forward, + fullgraph=True, + mode="max-autotune", + ) + ) + self.detector.transformer.decoder.forward = clone_output_wrapper( + torch.compile( + self.detector.transformer.decoder.forward, + fullgraph=True, + mode="max-autotune", + dynamic=False, # note: FA decoder uses static shapes + ) + ) + + self.detector.segmentation_head.forward = clone_output_wrapper( + torch.compile( + self.detector.segmentation_head.forward, + fullgraph=True, + mode="max-autotune", + ) + ) + + ## Compile SAM2 model components + self.tracker.maskmem_backbone.forward = compile_wrapper( + self.tracker.maskmem_backbone.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + self.tracker.transformer.encoder.forward = shape_logging_wrapper( + compile_wrapper( + self.tracker.transformer.encoder.forward, + mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=True, + ), + keep_kwargs=["src", "src_pos", "prompt", "prompt_pos"], + ) + + self.tracker.sam_mask_decoder.forward = compile_wrapper( + self.tracker.sam_mask_decoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + sam3.model.sam3_video_base._associate_det_trk_compilable = compile_wrapper( + sam3.model.sam3_video_base._associate_det_trk_compilable, + mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + self.tracker._suppress_object_pw_area_shrinkage = compile_wrapper( + self.tracker._suppress_object_pw_area_shrinkage, + mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + self._model_is_compiled = True + + def _warm_up_vg_propagation(self, inference_state, start_frame_idx=0): + # use different tracking score thresholds for each round to simulate different number of output objects + num_objects_list = range(self.num_obj_for_compile + 1) + num_rounds = 3 + orig_new_det_thresh = self.new_det_thresh + for i in range(num_rounds): + for num_objects in num_objects_list: + logger.info( + f"round {i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects" + ) + # Initialize text prompt and cache image features + self.add_prompt( + inference_state, frame_idx=start_frame_idx, text_str="cat" + ) + if num_objects > 0: + inference_state = self.add_fake_objects_to_inference_state( + inference_state, num_objects, frame_idx=start_frame_idx + ) + inference_state["tracker_metadata"]["rank0_metadata"].update( + { + "masklet_confirmation": { + "status": np.zeros(num_objects, dtype=np.int64), + "consecutive_det_num": np.zeros( + num_objects, dtype=np.int64 + ), + } + } + ) + for _ in self.propagate_in_video( + inference_state, start_frame_idx, reverse=False + ): + pass + for _ in self.propagate_in_video( + inference_state, start_frame_idx, reverse=True + ): + pass + self.reset_state(inference_state) + logger.info( + f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}" + ) + + # Warm up SAM2 memory encoder with varying input shapes + num_iters = 3 + feat_size = self.tracker.sam_image_embedding_size**2 # 72 * 72 = 5184 + hidden_dim = self.tracker.hidden_dim # 256 + mem_dim = self.tracker.mem_dim # 64 for non-multiplex, 256 for multiplex + is_multiplex = self.tracker.is_multiplex + + for _ in tqdm(range(num_iters)): + for b in range(1, self.num_obj_for_compile + 1): + for i in range( + 1, + self.tracker.max_cond_frames_in_attn + self.tracker.num_maskmem, + ): + for j in range( + self.tracker.max_cond_frames_in_attn + + self.tracker.max_obj_ptrs_in_encoder + ): + if is_multiplex: + # Multiplex encoder: mem_dim == hidden_dim, uses decoupled cross-attention + # num_obj_ptr_tokens = j (since hidden_dim // mem_dim = 1) + num_obj_ptr_tokens = j + memory_seq_len = feat_size * i + num_obj_ptr_tokens + + # src and memory have batch=num_buckets (b) + src = torch.randn( + feat_size, b, hidden_dim, device=self.device + ) + src_pos = torch.randn( + feat_size, b, hidden_dim, device=self.device + ) + memory = torch.randn( + memory_seq_len, b, hidden_dim, device=self.device + ) + memory_pos = torch.randn( + memory_seq_len, b, hidden_dim, device=self.device + ) + + # image and memory_image always have batch=1 (shared image features) + image = torch.randn( + feat_size, 1, hidden_dim, device=self.device + ) + image_pos = torch.randn( + feat_size, 1, hidden_dim, device=self.device + ) + memory_image = torch.randn( + feat_size * i, 1, hidden_dim, device=self.device + ) + memory_image_pos = torch.randn( + feat_size * i, 1, hidden_dim, device=self.device + ) + + self.tracker.transformer.encoder.forward( + image=image, + src=src, + memory_image=memory_image, + memory=memory, + image_pos=image_pos, + src_pos=src_pos, + memory_image_pos=memory_image_pos, + memory_pos=memory_pos, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + else: + # Non-multiplex encoder: mem_dim = 64, uses standard cross-attention + # num_obj_ptr_tokens = (hidden_dim // mem_dim) * j = 4 * j + num_obj_ptr_tokens = (hidden_dim // mem_dim) * j + src = torch.randn( + feat_size, b, hidden_dim, device=self.device + ) + src_pos = torch.randn( + feat_size, b, hidden_dim, device=self.device + ) + prompt = torch.randn( + feat_size * i + num_obj_ptr_tokens, + b, + mem_dim, + device=self.device, + ) + prompt_pos = torch.randn( + feat_size * i + num_obj_ptr_tokens, + b, + mem_dim, + device=self.device, + ) + + self.tracker.transformer.encoder.forward( + src=src, + src_pos=src_pos, + prompt=prompt, + prompt_pos=prompt_pos, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + + # Warm up different number of kbox + for _ in tqdm(range(num_iters)): + for i in range(1, self.max_num_kboxes + 1): + kboxes = ( + torch.rand(i, 4, dtype=torch.float32) * 0.5 + ) # Generate positive values between 0 and 1 + print( + "Warming up masks_to_boxes with", + i, + f"kboxes.shape={kboxes.shape}", + ) + self.add_prompt( + inference_state, + frame_idx=start_frame_idx, + text_str="cat", + boxes_xywh=kboxes, + box_labels=[1] * len(kboxes), + ) + + for _ in self.propagate_in_video( + inference_state, start_frame_idx, reverse=False + ): + pass + + self.new_det_thresh = orig_new_det_thresh + return inference_state + + def add_fake_objects_to_inference_state( + self, inference_state, num_objects, frame_idx + ): + new_det_obj_ids_local = np.arange(num_objects) + high_res_H, high_res_W = ( + self.tracker.maskmem_backbone.mask_downsampler.interpol_size + ) + new_det_masks = torch.ones( + len(new_det_obj_ids_local), high_res_H, high_res_W + ).to(self.device) + + inference_state["sam2_inference_states"] = self._tracker_add_new_objects( + frame_idx=frame_idx, + num_frames=inference_state["num_frames"], + new_obj_ids=new_det_obj_ids_local, + new_obj_masks=new_det_masks, + tracker_states_local=inference_state["sam2_inference_states"], + orig_vid_height=inference_state["orig_height"], + orig_vid_width=inference_state["orig_width"], + feature_cache=inference_state["feature_cache"], + ) + + # Synthesize obj_id_to_mask data for cached_frame_outputs to support _build_sam2_output during warmup + obj_id_to_mask = {} + if num_objects > 0: + H_video = inference_state["orig_height"] + W_video = inference_state["orig_width"] + + video_res_masks = F.interpolate( + new_det_masks.unsqueeze(1), # Add channel dimension for interpolation + size=(H_video, W_video), + mode="bilinear", + align_corners=False, + ) # (num_objects, 1, H_video, W_video) + for i, obj_id in enumerate(new_det_obj_ids_local): + obj_id_to_mask[obj_id] = (video_res_masks[i] > 0.0).to(torch.bool) + if self.rank == 0: + for fidx in range(inference_state["num_frames"]): + self._cache_frame_outputs(inference_state, fidx, obj_id_to_mask) + + inference_state["tracker_metadata"] = { + "obj_ids_per_gpu": [np.arange(num_objects)], + "obj_ids_all_gpu": np.arange(num_objects), # Same as 1 GPU + "num_obj_per_gpu": [num_objects], + "obj_id_to_score": {i: 1.0 for i in range(num_objects)}, + "obj_id_to_sam2_score_frame_wise": defaultdict(dict), + "obj_id_to_last_occluded": {}, + "max_obj_id": num_objects, + "rank0_metadata": { + "masklet_confirmation": { + "status": np.zeros(num_objects, dtype=np.int64), + "consecutive_det_num": np.zeros(num_objects, dtype=np.int64), + }, + "removed_obj_ids": set(), + "suppressed_obj_ids": defaultdict(set), + }, + # gpu_metadata for hotstart tracking on GPU + "gpu_metadata": { + "N_obj": num_objects, + "obj_first_frame": torch.zeros( + num_objects, dtype=torch.long, device=self.device + ), + "consecutive_unmatch_count": torch.zeros( + num_objects, dtype=torch.long, device=self.device + ), + "trk_keep_alive": torch.ones( + num_objects, dtype=torch.bool, device=self.device + ), + "removed_mask": torch.zeros( + num_objects, dtype=torch.bool, device=self.device + ), + "overlap_pair_counts": torch.zeros( + (num_objects, num_objects), dtype=torch.long, device=self.device + ), + "last_occluded_tensor": torch.zeros( + num_objects, dtype=torch.long, device=self.device + ), + }, + } + # Add num_buc_per_gpu for multiplex mode + if self.is_multiplex: + # Count actual buckets from the inference states + num_buc = self._count_buckets_in_states( + inference_state["sam2_inference_states"] + ) + inference_state["tracker_metadata"]["num_buc_per_gpu"] = np.array( + [num_buc], dtype=np.int64 + ) + + return inference_state + + @torch.inference_mode() + @torch.autocast(device_type="cuda", dtype=torch.bfloat16) + def warm_up_compilation(self): + """ + Warm up the model by running a dummy inference to compile the model. This is + useful to avoid the compilation overhead in the first inference call. + """ + if not self.compile_model: + return + self._warm_up_complete = False + if self.device.type != "cuda": + raise RuntimeError( + f"The model must be on CUDA for warm-up compilation, got {self.device=}." + ) + + # temporally set to single GPU temporarily for warm-up compilation + orig_rank = self.rank + orig_world_size = self.world_size + self.rank = self.detector.rank = 0 + self.world_size = self.detector.world_size = 1 + orig_recondition_every_nth_frame = self.recondition_every_nth_frame + # self.recondition_every_nth_frame = 2 + + # Get a random video + inference_state = self.init_state(resource_path="") + start_frame_idx = 0 + + # Run basic propagation warm-up + inference_state = self._warm_up_vg_propagation(inference_state, start_frame_idx) + + logger.info("Warm-up compilation completed.") + + # revert to the original GPU and rank + self.rank = self.detector.rank = orig_rank + self.world_size = self.detector.world_size = orig_world_size + self.recondition_every_nth_frame = orig_recondition_every_nth_frame + self._warm_up_complete = True + self.tracker.transformer.encoder.forward.set_logging(True) + + @torch.inference_mode() + def add_prompt( + self, + inference_state, + frame_idx, + text_str=None, + clear_old_points=True, + points=None, + point_labels=None, + boxes_xywh=None, + box_labels=None, + clear_old_boxes=True, + output_prob_thresh=0.5, + ): + """ + Add text, point or box prompts on a single frame. This method returns the inference + outputs only on the prompted frame. + + Note that text prompts are NOT associated with a particular frame (i.e. they apply + to all frames). However, we only run inference on the frame specified in `frame_idx`. + + Copied from sam3_demo.Sam3DemoMixin.add_prompt, simplified to support only text prompts. + """ + logger.info("Running add_prompt on frame %d", frame_idx) + + device = inference_state["device"] + num_frames = inference_state["num_frames"] + assert ( + text_str is not None or points is not None or boxes_xywh is not None + ), "at least one type of prompt (text, points, boxes) must be provided" + assert ( + 0 <= frame_idx < num_frames + ), f"{frame_idx=} is out of range for a total of {num_frames} frames" + + assert clear_old_boxes, "clear old boxes must be True" + + assert ( + points is None and clear_old_points is True and point_labels is None + ), "Point prompts not accepted" + + # since it's a semantic prompt, we start over + self.reset_state(inference_state) + + # 1) add text prompt + if text_str is not None: + inference_state["text_prompt"] = text_str + # add the text prompt into the input batch (to be applied to *all* frames) + inference_state["input_batch"].find_text_batch[0] = text_str + for t in range(inference_state["num_frames"]): + text_id = self.TEXT_ID_FOR_TEXT + inference_state["input_batch"].find_inputs[t].text_ids[...] = text_id + + # 2) handle box prompt + assert (boxes_xywh is not None) == (box_labels is not None) + if boxes_xywh is not None: + boxes_xywh = torch.as_tensor(boxes_xywh, dtype=torch.float32) + box_labels = torch.as_tensor(box_labels, dtype=torch.long) + # input boxes are expected to be [xmin, ymin, width, height] format + # in normalized coordinates of range 0~1, similar to FA + assert boxes_xywh.dim() == 2 + assert boxes_xywh.size(0) > 0 and boxes_xywh.size(-1) == 4 + assert box_labels.dim() == 1 and box_labels.size(0) == boxes_xywh.size(0) + boxes_cxcywh = box_xywh_to_cxcywh(boxes_xywh) + assert (boxes_xywh >= 0).all().item() and (boxes_xywh <= 1).all().item() + assert (boxes_cxcywh >= 0).all().item() and (boxes_cxcywh <= 1).all().item() + + new_box_input = boxes_cxcywh, box_labels + inference_state["per_frame_raw_box_input"][frame_idx] = new_box_input + + # handle the case of visual prompt (also added as an input box from the UI) + boxes_cxcywh, box_labels, geometric_prompt = self._get_visual_prompt( + inference_state, frame_idx, boxes_cxcywh, box_labels + ) + + inference_state["per_frame_geometric_prompt"][frame_idx] = geometric_prompt + + with torch.profiler.record_function("add_prompt._init_backbone_out"): + inference_state["backbone_out"] = self._init_backbone_out(inference_state) + out = self._run_single_frame_inference( + inference_state, + frame_idx, + reverse=False, + ) + return frame_idx, self._postprocess_output(inference_state, out) + + def _init_backbone_out(self, inference_state): + """ + Initialize a backbone_out dictionary and extract the text features. + + Note that the visual features of each frame are not extracted here. They will be + extracted on the fly when running inference on each frame. + """ + input = inference_state["input_batch"] + device = self.device + backbone_out = {"img_batch_all_stages": input.img_batch} + text_outputs = self.detector.backbone.forward_text( + input.find_text_batch, device=device + ) + backbone_out.update(text_outputs) + return backbone_out + + @torch.autocast(device_type="cuda", dtype=torch.bfloat16) + def forward(self, input: BatchedDatapoint, is_inference: bool = False): + """This method is only used for benchmark eval (not used in the demo).""" + # set the model to single GPU for benchmark evaluation (to be compatible with trainer) + orig_rank = self.rank + orig_world_size = self.world_size + self.rank = self.detector.rank = 0 + self.world_size = self.detector.world_size = 1 + + # get data + text_prompt_ids = input.find_metadatas[0].original_category_id + text_prompt_list = input.find_text_batch + + # loop over txt prompts + tracking_res = defaultdict(dict) # frame_idx --> {obj_id: mask} + scores_labels = defaultdict(tuple) # obj_id --> (score, text_prompt_id) + inference_state = self.init_state(resource_path=input.raw_images) + for prompt_id, prompt in zip(text_prompt_ids, text_prompt_list): + self.add_prompt(inference_state, frame_idx=0, text_str=prompt) + start_obj_id = max(scores_labels.keys(), default=-1) + 1 # prev max + 1 + + # propagate the prompts + obj_ids_this_prompt = set() + for frame_idx, out in self.propagate_in_video( + inference_state, + start_frame_idx=0, + max_frame_num_to_track=inference_state["num_frames"], + reverse=False, + ): + out_obj_ids = ( + out["out_obj_ids"].numpy() + if isinstance(out["out_obj_ids"], torch.Tensor) + else out["out_obj_ids"] + ) + out_binary_masks = ( + out["out_binary_masks"].numpy() + if isinstance(out["out_binary_masks"], torch.Tensor) + else out["out_binary_masks"] + ) + + current_frame_res = tracking_res[frame_idx] + for obj_id, mask in zip(out_obj_ids, out_binary_masks): + mask_tensor = torch.tensor(mask[None], dtype=torch.bool) + current_frame_res[obj_id + start_obj_id] = mask_tensor + obj_ids_this_prompt.update(current_frame_res.keys()) + + obj_id_to_score = inference_state["tracker_metadata"]["obj_id_to_score"] + for obj_id, score in obj_id_to_score.items(): + if obj_id + start_obj_id in obj_ids_this_prompt: + score_tensor = torch.tensor(score, dtype=torch.float32) + scores_labels[obj_id + start_obj_id] = (score_tensor, prompt_id) + + self.reset_state(inference_state) + + video_id = input.find_metadatas[0].original_image_id[0].cpu().item() + preds = self.prep_for_evaluator(input.raw_images, tracking_res, scores_labels) + + # revert the model to the original GPU and rank + self.rank = self.detector.rank = orig_rank + self.world_size = self.detector.world_size = orig_world_size + return {video_id: preds} + + +class Sam3MultiplexTrackingProd(Sam3MultiplexTracking): + """ + Subclass of Sam3MultiplexTracking with support for batched processing. + + This class enables processing videos in batches rather than all at once by: + 1. Adding an `is_last_batch` parameter to control buffer flushing + 2. Persisting generator state (hotstart_buffer, hotstart_removed_obj_ids, + unconfirmed_obj_ids_per_frame) in inference_state across generator instantiations + + This is useful for processing large videos in smaller chunks to manage memory + or distribute processing across multiple calls. + """ + + @torch.inference_mode() + def init_state( + self, + resource_path, + offload_video_to_cpu=False, + async_loading_frames=False, + use_torchcodec=False, + use_cv2=False, + input_is_mp4=False, + ): + inference_state = super().init_state( + resource_path=resource_path, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + use_torchcodec=use_torchcodec, + use_cv2=use_cv2, + input_is_mp4=input_is_mp4, + ) + # Initialize generator state for batched processing + inference_state["generator_state"] = { + "hotstart_buffer": [], + "hotstart_removed_obj_ids": set(), + "unconfirmed_obj_ids_per_frame": {}, + "postprocess_yield_list": [], + } + return inference_state + + def reset_state(self, inference_state): + super().reset_state(inference_state) + # Reset generator state for batched processing + inference_state["generator_state"] = { + "hotstart_buffer": [], + "hotstart_removed_obj_ids": set(), + "unconfirmed_obj_ids_per_frame": {}, + "postprocess_yield_list": [], + } + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + output_prob_thresh=0.5, + compute_stability_score=False, + is_instance_processing=False, + is_last_batch=True, + ): + """ + Propagate the prompts to get grounding results for the entire video. This method + is a generator and yields inference outputs for all frames in the range specified + by `start_frame_idx`, `max_frame_num_to_track`, and `reverse`. + + Args: + is_last_batch: Whether this is the last batch in a batched processing scenario. + When True (default), the hotstart buffer will be flushed at end_frame_idx. + When False, the buffer is preserved in inference_state for the next batch. + This flag should be set to False for all batches except the last one when + processing a video in multiple batches. + """ + # compile the model (it's a no-op if the model is already compiled) + # note that it's intentionally added to `self.propagate_in_video`, so that the first + # `self.add_prompt` call will be done in eager mode to fill in the decoder buffers + # such as positional encoding cache) + self._compile_model() + + processing_order, end_frame_idx = self._get_processing_order( + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse=reverse, + ) + + # Store max_frame_num_to_track in feature_cache for downstream methods + inference_state["feature_cache"]["tracking_bounds"] = { + "max_frame_num_to_track": max_frame_num_to_track, + "propagate_in_video_start_frame_idx": start_frame_idx, + } + + # Initialize or retrieve generator state from inference_state to persist across batches + if "generator_state" not in inference_state: + inference_state["generator_state"] = { + "hotstart_buffer": [], + "hotstart_removed_obj_ids": set(), + "unconfirmed_obj_ids_per_frame": {}, + "postprocess_yield_list": [], + } + + generator_state = inference_state["generator_state"] + hotstart_buffer = generator_state["hotstart_buffer"] + hotstart_removed_obj_ids = generator_state["hotstart_removed_obj_ids"] + unconfirmed_obj_ids_per_frame = generator_state["unconfirmed_obj_ids_per_frame"] + postprocess_yield_list = generator_state.get("postprocess_yield_list", []) + + # when deciding whether to output a masklet on `yield_frame_idx`, we check whether the object is confirmed + # in a future frame (`unconfirmed_frame_delay` frames after the current frame). For example, if we require + # an object to be detected in 3 consecutive frames to be confirmed, then we look 2 frames in the future -- + # e.g., we output an object on frame 4 only if it becomes confirmed on frame 6. + unconfirmed_status_delay = self.masklet_confirmation_consecutive_det_thresh - 1 + + for frame_idx in tqdm( + processing_order, desc="propagate_in_video", disable=self.rank > 0 + ): + out = self._run_single_frame_inference( + inference_state, + frame_idx, + reverse, + is_instance_processing=is_instance_processing, + ) + + if self.hotstart_delay > 0: + # accumulate the outputs for the first `hotstart_delay` frames + hotstart_buffer.append([frame_idx, out]) + # update the object IDs removed by hotstart so that we don't output them + if self.rank == 0: + hotstart_removed_obj_ids.update(out["removed_obj_ids"]) + unconfirmed_obj_ids = out.get("unconfirmed_obj_ids", None) + if unconfirmed_obj_ids is not None: + unconfirmed_obj_ids_per_frame[frame_idx] = unconfirmed_obj_ids + + if frame_idx == end_frame_idx and is_last_batch: + # we reached the end of propagation -- yield all frames in the buffer + yield_list = hotstart_buffer + hotstart_buffer = [] + elif len(hotstart_buffer) >= self.hotstart_delay: + # we have enough frames -- yield and remove the first (oldest) frame from the buffer + yield_list = hotstart_buffer[:1] + hotstart_buffer = hotstart_buffer[1:] + else: + # not enough frames yet -- skip yielding + yield_list = [] + else: + yield_list = [(frame_idx, out)] # output the current frame + + # Accumulate yield_list into postprocess_yield_list + # Snapshot hotstart_removed_obj_ids at the time of accumulation to preserve + # the correct state for each frame (important: this set is mutated over time) + for yield_frame_idx, yield_out in yield_list: + postprocess_yield_list.append( + (yield_frame_idx, yield_out, set(hotstart_removed_obj_ids)) + ) + + # Process batch when we have enough frames + while len(postprocess_yield_list) >= self.postprocess_batch_size: + batch_to_process = postprocess_yield_list[: self.postprocess_batch_size] + postprocess_yield_list = postprocess_yield_list[ + self.postprocess_batch_size : + ] + + with torch.profiler.record_function( + "Sam3MultiplexTrackingProd.postprocess_output_batched" + ): + if self.rank == 0: + # Prepare batched inputs for postprocessing + H_video, W_video = ( + inference_state["orig_height"], + inference_state["orig_width"], + ) + num_frames = inference_state["num_frames"] + + batched_outs = [] + frame_indices = [] + for ( + yield_frame_idx, + yield_out, + removed_obj_ids_snapshot, + ) in batch_to_process: + suppressed_obj_ids = yield_out["suppressed_obj_ids"] + unconfirmed_status_frame_idx = ( + yield_frame_idx + unconfirmed_status_delay + if not reverse + else yield_frame_idx - unconfirmed_status_delay + ) + unconfirmed_status_frame_idx = max( + 0, min(unconfirmed_status_frame_idx, num_frames - 1) + ) + unconfirmed_obj_ids = unconfirmed_obj_ids_per_frame.get( + unconfirmed_status_frame_idx, None + ) + + batched_outs.append( + ( + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + ) + frame_indices.append(yield_frame_idx) + + # Cache frame outputs + self._cache_frame_outputs( + inference_state, + yield_frame_idx, + yield_out["obj_id_to_mask"], + suppressed_obj_ids=suppressed_obj_ids, + removed_obj_ids=removed_obj_ids_snapshot, + unconfirmed_obj_ids=unconfirmed_obj_ids, + ) + + # Process all frames in batch + if self.postprocess_batch_size > 1: + postprocessed_outs = self._postprocess_output_batched( + H_video, W_video, batched_outs + ) + else: + # Process each frame individually but output together + postprocessed_outs = [] + for ( + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) in batched_outs: + postprocessed_out = self._postprocess_output( + inference_state, + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + postprocessed_outs.append(postprocessed_out) + + # Yield results + for yield_frame_idx, postprocessed_out in zip( + frame_indices, postprocessed_outs + ): + yield yield_frame_idx, postprocessed_out + else: + # No output on other GPUs + for yield_frame_idx, _, _ in batch_to_process: + yield yield_frame_idx, DUMMY_OUTPUT + + # Handle remaining frames in hotstart buffer at end of last batch + if is_last_batch and len(hotstart_buffer) > 0: + for yield_frame_idx, yield_out in hotstart_buffer: + postprocess_yield_list.append( + (yield_frame_idx, yield_out, set(hotstart_removed_obj_ids)) + ) + hotstart_buffer = [] + + # Flush any remaining frames in the postprocess buffer (even partial + # batches) so that the caller gets results as soon as possible. This is + # especially important for the first batch where hotstart_delay causes + # only a few frames to exit the hotstart buffer — without this flush + # the client would have to wait for the next batch before receiving any + # output, hurting time-to-first-frame. + if len(postprocess_yield_list) > 0: + with torch.profiler.record_function( + "Sam3MultiplexTrackingProd.postprocess_output_batched" + ): + if self.rank == 0: + H_video, W_video = ( + inference_state["orig_height"], + inference_state["orig_width"], + ) + num_frames = inference_state["num_frames"] + + batched_outs = [] + frame_indices = [] + for ( + yield_frame_idx, + yield_out, + removed_obj_ids_snapshot, + ) in postprocess_yield_list: + suppressed_obj_ids = yield_out["suppressed_obj_ids"] + unconfirmed_status_frame_idx = ( + yield_frame_idx + unconfirmed_status_delay + if not reverse + else yield_frame_idx - unconfirmed_status_delay + ) + unconfirmed_status_frame_idx = max( + 0, min(unconfirmed_status_frame_idx, num_frames - 1) + ) + unconfirmed_obj_ids = unconfirmed_obj_ids_per_frame.get( + unconfirmed_status_frame_idx, None + ) + + batched_outs.append( + ( + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + ) + frame_indices.append(yield_frame_idx) + + self._cache_frame_outputs( + inference_state, + yield_frame_idx, + yield_out["obj_id_to_mask"], + suppressed_obj_ids=suppressed_obj_ids, + removed_obj_ids=removed_obj_ids_snapshot, + unconfirmed_obj_ids=unconfirmed_obj_ids, + ) + + if self.postprocess_batch_size > 1: + postprocessed_outs = self._postprocess_output_batched( + H_video, W_video, batched_outs + ) + else: + # Process each frame individually but output together + postprocessed_outs = [] + for ( + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) in batched_outs: + postprocessed_out = self._postprocess_output( + inference_state, + yield_out, + removed_obj_ids_snapshot, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + postprocessed_outs.append(postprocessed_out) + + for yield_frame_idx, postprocessed_out in zip( + frame_indices, postprocessed_outs + ): + yield yield_frame_idx, postprocessed_out + else: + for yield_frame_idx, _, _ in postprocess_yield_list: + yield yield_frame_idx, DUMMY_OUTPUT + + postprocess_yield_list = [] + + # Store the generator state back to inference_state for persistence across batches + generator_state["postprocess_yield_list"] = postprocess_yield_list + generator_state["hotstart_buffer"] = hotstart_buffer + generator_state["hotstart_removed_obj_ids"] = hotstart_removed_obj_ids + generator_state["unconfirmed_obj_ids_per_frame"] = unconfirmed_obj_ids_per_frame + + if self.is_multiplex: + # log the bucket utilization stats + # bucket utilization rate is total valid objects / total capacity -> represents rooms for improvement + # subscription rate is total valid objects / total number of buckets -> represents speedup + total_valid_objects = 0 + total_num_buckets = 0 + for state in inference_state["sam2_inference_states"]: + assert ( + len(state["obj_ids"]) + == state["multiplex_state"].total_valid_entries + ) + total_valid_objects += len(state["obj_ids"]) + total_num_buckets += state["multiplex_state"].num_buckets + if total_num_buckets > 0: + bucket_utilization_rate = ( + total_valid_objects / (total_num_buckets * self.bucket_capacity) + ) * 100 + subscription_rate = (total_valid_objects / total_num_buckets) * 100 + logger.info( + f"Bucket utilization rate: {bucket_utilization_rate:.2f}%, subscription rate: {subscription_rate:.2f}%" + ) + + +class Sam3MultiplexTrackingWithInteractivity(Sam3MultiplexTracking): + def __init__( + self, + use_prev_mem_frame=False, + use_stateless_refinement=False, + refinement_detector_cond_frame_removal_window=30 * 4, + **kwargs, + ): + """ + use_prev_mem_frame: bool, whether to condition on previous memory frames for adding points + use_stateless_refinement: bool, whether to enable stateless refinement behavior + refinement_detector_cond_frame_removal_window: int, we remove a detector conditioning frame if it + is within this many frames of a user refined frame. Set to a large value (e.g. 10000) to + always remove detector conditioning frames if there is any user refinement in the video. + """ + super().__init__(**kwargs) + self.use_prev_mem_frame = use_prev_mem_frame + self.use_stateless_refinement = use_stateless_refinement + self.refinement_detector_cond_frame_removal_window = ( + refinement_detector_cond_frame_removal_window + ) + + @torch.inference_mode() + def init_state( + self, + resource_path, + offload_video_to_cpu=False, + async_loading_frames=False, + use_torchcodec=False, + use_cv2=False, + input_is_mp4=False, + ): + inference_state = super().init_state( + resource_path=resource_path, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + use_torchcodec=use_torchcodec, + use_cv2=use_cv2, + input_is_mp4=input_is_mp4, + ) + # initialize extra states + inference_state["action_history"] = [] # for logging user actions + if self.tracker.per_obj_inference: + # in per_obj mode only 1 inference state is needed, we init it here. + inference_state["sam2_inference_states"] = [ + self._init_new_sam2_state(inference_state) + ] + return inference_state + + def reset_state(self, inference_state): + super().reset_state(inference_state) + # reset extra states + inference_state["action_history"].clear() + if self.tracker.per_obj_inference: + inference_state["sam2_inference_states"] = [ + self._init_new_sam2_state(inference_state) + ] + + def _init_new_sam2_state(self, inference_state): + return self.tracker.init_state( + cached_features=inference_state["feature_cache"], + video_height=inference_state["orig_height"], + video_width=inference_state["orig_width"], + num_frames=inference_state["num_frames"], + ) + + def cancel_propagation(self, inference_state): + """ + Cancel any ongoing propagation and reset the model state. + """ + logger.info("Cancelling ongoing propagation.") + self.add_action_history( + inference_state, + action_type="propagation_cancel", + obj_ids=None, + frame_idx=None, + ) + + def fetch_and_process_single_frame_results(self, inference_state, frame_idx): + tracker_metadata = inference_state["tracker_metadata"] + obj_id_to_mask = inference_state["cached_frame_outputs"][frame_idx] + # post processing - remove suppressed obj_ids + obj_id_to_score = tracker_metadata["obj_id_to_score"] + suppressed_obj_ids = tracker_metadata["rank0_metadata"]["suppressed_obj_ids"][ + frame_idx + ] + obj_id_to_sam2_score = tracker_metadata["obj_id_to_sam2_score_frame_wise"][ + frame_idx + ] + + out = { + "obj_id_to_mask": obj_id_to_mask, + "obj_id_to_score": obj_id_to_score, + "obj_id_to_sam2_score": obj_id_to_sam2_score, + } + return frame_idx, self._postprocess_output( + inference_state, out, suppressed_obj_ids=suppressed_obj_ids + ) + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + output_prob_thresh=0.5, + compute_stability_score=False, + is_instance_processing=False, + is_last_batch: bool = False, + ): + # step 1: check which type of propagation to run, should be the same for all GPUs. + propagation_type, obj_ids = self.parse_action_history_for_propagation( + inference_state + ) + self.add_action_history( + inference_state, + action_type=propagation_type, + obj_ids=obj_ids, + frame_idx=start_frame_idx, + ) + + # step 2: run full VG propagation + if propagation_type == "propagation_full": + logger.info(f"Running full VG propagation (reverse={reverse}).") + yield from super().propagate_in_video( + inference_state, + start_frame_idx=start_frame_idx, + max_frame_num_to_track=max_frame_num_to_track, + reverse=reverse, + is_last_batch=is_last_batch, + ) + return + + # step 3: run SAM2 partial propagation or direct fetch existing predictions + assert propagation_type in ["propagation_partial", "propagation_fetch"] + logger.info( + f"Running SAM2 propagation for objects {obj_ids} and merging it with existing VG predictions (reverse={reverse})." + if propagation_type == "propagation_partial" + else f"Fetching existing VG predictions without running any propagation (reverse={reverse})." + ) + processing_order, _end_frame_idx = self._get_processing_order( + inference_state, + start_frame_idx=start_frame_idx, + max_frame_num_to_track=max_frame_num_to_track, + reverse=reverse, + ) + + tracker_metadata = inference_state["tracker_metadata"] + + # if fetch just return from output + if propagation_type == "propagation_fetch": + for frame_idx in tqdm(processing_order): + if self.rank == 0: + frame_idx, out = self.fetch_and_process_single_frame_results( + inference_state, frame_idx + ) + yield frame_idx, out + else: + yield frame_idx, DUMMY_OUTPUT # no output for other GPUs + + return + + # get SAM2 inference states containing selected obj_ids + if propagation_type == "propagation_partial": + # can be empty for GPUs where objects are not in their inference states + tracker_states_local = self._get_sam2_inference_states_by_obj_ids( + inference_state, obj_ids + ) + for sam2_state in tracker_states_local: + self.tracker.propagate_in_video_preflight( + sam2_state, run_mem_encoder=True + ) + + for frame_idx in tqdm(processing_order): + # run SAM2 propagation + if propagation_type == "propagation_partial": + self._prepare_backbone_feats(inference_state, frame_idx, reverse) + obj_ids_local, low_res_masks_local, sam2_scores_local = ( + self._propogate_tracker_one_frame_local_gpu( + tracker_states_local, + frame_idx=frame_idx, + reverse=reverse, + run_mem_encoder=True, + ) + ) + + # broadcast refined object sam2 scores and masks to all GPUs + # handle multiple objects that can be located on different GPUs + refined_obj_data = {} # obj_id -> (score, mask_video_res) + + # Collect data for objects on this GPU + local_obj_data = {} + for obj_id in obj_ids: + obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id) + if self.rank == obj_rank and obj_id in obj_ids_local: + refined_obj_idx = obj_ids_local.index(obj_id) + refined_mask_low_res = low_res_masks_local[ + refined_obj_idx + ] # (H_low_res, W_low_res) + refined_score = sam2_scores_local[refined_obj_idx] + + # Keep low resolution for broadcasting to reduce communication cost + local_obj_data[obj_id] = (refined_score, refined_mask_low_res) + + # Broadcast data from each GPU that has refined objects + if self.world_size > 1: + for obj_id in obj_ids: + obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id) + if self.rank == obj_rank: + # This GPU has the object, broadcast its data + data_to_broadcast = local_obj_data.get(obj_id, None) + data_list = [data_to_broadcast] + self.broadcast_python_obj_cpu(data_list, src=obj_rank) + if data_to_broadcast is not None: + refined_obj_data[obj_id] = data_to_broadcast + elif self.rank != obj_rank: + # This GPU doesn't have the object, receive data + data_list = [None] + self.broadcast_python_obj_cpu(data_list, src=obj_rank) + if data_list[0] is not None: + refined_obj_data[obj_id] = data_list[0] + else: + # Single GPU case + refined_obj_data = local_obj_data + + # Update SAM2 scores for all refined objects + for obj_id, (refined_score, _) in refined_obj_data.items(): + # After broadcast_python_obj_cpu in multi-GPU, tensors may become numpy scalars + # Ensure it's a GPU tensor for consistency with base class behavior + if not isinstance(refined_score, torch.Tensor): + refined_score = torch.tensor( + refined_score, dtype=torch.float32, device=self.device + ) + tracker_metadata["obj_id_to_sam2_score_frame_wise"][ + frame_idx + ].update({obj_id: refined_score}) + + if self.rank == 0: + # get predictions from SAM2 inference states, it includes the original + # VG predictions and the refined predictions from interactivity. + + # Prepare refined masks dictionary - upscale to video resolution after broadcast + refined_obj_id_to_mask = {} + for obj_id, (_, refined_mask_low_res) in refined_obj_data.items(): + refined_mask_video_res = ( + self._convert_low_res_mask_to_video_res( + refined_mask_low_res, inference_state + ) + ) # (1, H_video, W_video) bool + refined_obj_id_to_mask[obj_id] = refined_mask_video_res + + obj_id_to_mask = self._build_sam2_output( + inference_state, frame_idx, refined_obj_id_to_mask + ) + out = { + "obj_id_to_mask": obj_id_to_mask, + "obj_id_to_score": tracker_metadata["obj_id_to_score"], + "obj_id_to_sam2_score": tracker_metadata[ + "obj_id_to_sam2_score_frame_wise" + ][frame_idx], + } + suppressed_obj_ids = tracker_metadata["rank0_metadata"][ + "suppressed_obj_ids" + ][frame_idx] + self._cache_frame_outputs( + inference_state, + frame_idx, + obj_id_to_mask, + suppressed_obj_ids=suppressed_obj_ids, + ) + suppressed_obj_ids = tracker_metadata["rank0_metadata"][ + "suppressed_obj_ids" + ][frame_idx] + yield ( + frame_idx, + self._postprocess_output( + inference_state, out, suppressed_obj_ids=suppressed_obj_ids + ), + ) + else: + yield frame_idx, DUMMY_OUTPUT # no output for other GPUs + + def add_action_history( + self, inference_state, action_type, frame_idx=None, obj_ids=None + ): + """ + action_history is used to automatically decide what to do during propagation. + action_type: one of ["add", "remove", "refine"] + ["propagation_full", "propagation_partial", "propagation_fetch", "propagation_cancel"] + """ + instance_actions = ["add", "remove", "refine"] + propagation_actions = [ + "propagation_full", + "propagation_partial", + "propagation_fetch", + "propagation_cancel", + ] + assert ( + action_type in instance_actions + propagation_actions + ), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}" + action = { + "type": action_type, + "frame_idx": frame_idx, + "obj_ids": obj_ids, + } + inference_state["action_history"].append(action) + + def _has_object_been_refined(self, inference_state, obj_id): + if "action_history" not in inference_state: + return False + action_history = inference_state["action_history"] + for action in action_history: + if action["type"] in ["add", "refine"] and action.get("obj_ids"): + if obj_id in action["obj_ids"]: + return True + return False + + def parse_action_history_for_propagation(self, inference_state): + action_history = inference_state["action_history"] + if ( + len(action_history) == 1 + and action_history[0]["type"] == "propagation_cancel" + ): + # only one action and it is cancel, we do full propagation + return "propagation_full", None + elif ( + len(action_history) >= 2 + and action_history[-1]["type"] == "propagation_cancel" + ): + # last action is cancel, we go back to the action before cancel + action_before_cancelation = inference_state["action_history"][-2] + # the action before cancelation can be a propagation_fetch from running both forward + # and backward propagation as in webdemo interface, in that case we go back one more step + if action_before_cancelation["type"] == "propagation_fetch": + action_before_cancelation = inference_state["action_history"][-3] + return action_before_cancelation["type"], action_before_cancelation.get( + "obj_ids", None + ) + return self._parse_action_history_for_propagation( + inference_state["action_history"], inference_state["num_frames"] + ) + + def _parse_action_history_for_propagation(self, action_history, num_frames): + """ + Parse the actions in history before the last propagation and prepare for the next propagation. + We support multiple actions (add/remove/refine) between two propagations. If we had an action + history similar to this ["propagate", "add", "refine", "remove", "add"], the next propagation + would remove the removed object, and also propagate the two added/refined objects. + + Returns: + propagation_type: one of ["propagation_full", "propagation_partial", "propagation_fetch"] + - "propagation_full": run VG propagation for all objects + - "propagation_partial": run SAM2 propagation for selected objects, useful for add/refine actions + - "propagation_fetch": fetch existing VG predictions without running any propagation + - "propagation_cancel": this will be handled in parse_action_history_for_propagation() not this function. + obj_ids: list of object ids to run SAM2 propagation on if propagation_type is "propagation_partial". + + TODO: (Jie) this function works for our current workflows, but may need more tests to ensure it works + correctly with different action histories for future workflows. + """ + if len(action_history) == 0: + # we run propagation for the first time + return "propagation_full", None + + if "propagation" in action_history[-1]["type"]: + if action_history[-1]["type"] in ["propagation_fetch"]: + # last propagation is direct fetch, we fetch existing predictions + return "propagation_fetch", None + elif action_history[-1]["type"] in [ + "propagation_partial", + "propagation_full", + ]: + # we do fetch prediction if we have already run propagation twice or we have run + # propagation once and it is from the first frame or last frame. + if ( + len(action_history) > 1 + and action_history[-2]["type"] + in ["propagation_partial", "propagation_full"] + ) or action_history[-1]["frame_idx"] in [ + 0, + num_frames - 1, + ]: + # we have run both forward and backward partial/full propagation + return "propagation_fetch", None + else: + # we have run partial/full forward or backward propagation once, need run it for the rest of the frames + return action_history[-1]["type"], action_history[-1]["obj_ids"] + + # parse actions since last propagation + obj_ids = [] + for action in action_history[::-1]: + if "propagation" in action["type"]: + # we reached the last propagation action, stop parsing + break + if action["type"] in ["add", "refine"]: + obj_ids.extend(action["obj_ids"]) + # else action["type"] == "remove": noop + obj_ids = list(set(obj_ids)) if len(obj_ids) > 0 else None + propagation_type = ( + "propagation_partial" if obj_ids is not None else "propagation_fetch" + ) + return propagation_type, obj_ids + + def remove_object(self, inference_state, obj_id, frame_idx, is_user_action=False): + """ + We try to remove object from sam2 states on every GPU, it will do nothing + for states without this object. + """ + obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id) + if obj_rank is None: + # Object was already removed (e.g., by hotstart heuristics during + # propagation). Log a warning and skip SAM2 state and metadata + # removal, but still record action history and clean up cached outputs. + logger.warning( + f"Object {obj_id} not found in any GPU (already removed). " + f"Skipping SAM2 state and metadata removal." + ) + else: + tracker_states_local = inference_state["sam2_inference_states"] + if self.rank == obj_rank: + self._tracker_remove_objects(tracker_states_local, [obj_id]) + + # update metadata + tracker_metadata = inference_state["tracker_metadata"] + _obj_ids = tracker_metadata["obj_ids_per_gpu"][obj_rank] + tracker_metadata["obj_ids_per_gpu"][obj_rank] = _obj_ids[_obj_ids != obj_id] + tracker_metadata["num_obj_per_gpu"][obj_rank] = len( + tracker_metadata["obj_ids_per_gpu"][obj_rank] + ) + tracker_metadata["obj_ids_all_gpu"] = np.concatenate( + tracker_metadata["obj_ids_per_gpu"] + ) + tracker_metadata["obj_id_to_score"].pop(obj_id, None) + # tracker_metadata["max_obj_id"] # we do not reuse the object id, so we do not update it here + + if is_user_action: + self.add_action_history( + inference_state, action_type="remove", obj_ids=[obj_id] + ) + + # Clean up cached frame outputs to remove references to the deleted object + if "cached_frame_outputs" in inference_state: + for _frame_idx in inference_state["cached_frame_outputs"]: + frame_cache = inference_state["cached_frame_outputs"][_frame_idx] + if obj_id in frame_cache: + del frame_cache[obj_id] + + out = None + if frame_idx is not None and self.rank == 0: + frame_idx, out = self.fetch_and_process_single_frame_results( + inference_state, frame_idx + ) + return frame_idx, out + + def _get_gpu_id_by_obj_id(self, inference_state, obj_id): + """ + Locate GPU ID for a given object. + """ + obj_ids_per_gpu = inference_state["tracker_metadata"]["obj_ids_per_gpu"] + for rank, obj_ids in enumerate(obj_ids_per_gpu): + if obj_id in obj_ids: + return rank + return None # object not found in any GPU + + def _get_sam2_inference_states_by_obj_ids(self, inference_state, obj_ids): + """ + Get the SAM2 inference states that contain the given object ids. + This is used to run partial SAM2 propagation on a single object/bucket. + Possibly multiple or zero states can be returned. + """ + states = [ + state + for state in inference_state["sam2_inference_states"] + if set(obj_ids) & set(state["obj_ids"]) + ] + return states + + def _prepare_backbone_feats(self, inference_state, frame_idx, reverse): + input_batch = inference_state["input_batch"] + feature_cache = inference_state["feature_cache"] + num_frames = inference_state["num_frames"] + geometric_prompt = ( + inference_state["constants"]["empty_geometric_prompt"] + if inference_state["per_frame_geometric_prompt"][frame_idx] is None + else inference_state["per_frame_geometric_prompt"][frame_idx] + ) + _ = self.run_backbone_and_detection( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + input_batch=input_batch, + geometric_prompt=geometric_prompt, + feature_cache=feature_cache, + ) + + @torch.inference_mode() + def add_prompt( + self, + inference_state, + frame_idx, + text_str=None, + clear_old_points=True, + points=None, + point_labels=None, + boxes_xywh=None, + box_labels=None, + clear_old_boxes=True, + output_prob_thresh=0.5, + obj_id=None, + rel_coordinates=True, + ): + if points is not None: + # SAM2 instance prompts + assert ( + text_str is None and boxes_xywh is None + ), "When points are provided, text_str and boxes_xywh must be None." + assert ( + obj_id is not None + ), "When points are provided, obj_id must be provided." + return self.add_sam2_new_points( + inference_state, + frame_idx, + obj_id=obj_id, + points=points, + labels=point_labels, + clear_old_points=clear_old_points, + rel_coordinates=rel_coordinates, + use_prev_mem_frame=self.use_prev_mem_frame, + ) + else: + # SAM3 prompts — disable batched grounding for single-frame add_prompt + _orig_batched = self.use_batched_grounding + self.use_batched_grounding = False + try: + return super().add_prompt( + inference_state, + frame_idx, + text_str=text_str, + clear_old_points=clear_old_points, + points=points, + point_labels=point_labels, + boxes_xywh=boxes_xywh, + box_labels=box_labels, + clear_old_boxes=clear_old_boxes, + output_prob_thresh=output_prob_thresh, + ) + finally: + self.use_batched_grounding = _orig_batched + + @torch.inference_mode() + def add_sam2_new_points( + self, + inference_state, + frame_idx, + obj_id, + points, + labels, + clear_old_points, + rel_coordinates=True, + use_prev_mem_frame=False, + ): + """Add a new point prompt to SAM2. Suppporting instance refinement to existing + objects by passing existing obj_id or adding a new object by passing a new obj_id. + use_prev_mem_frame=False to disable cross attention to previous memory frames. + Every GPU returns the same results, and results should contain all masks including + these masks not refined or not added by the current user points. + """ + assert obj_id is not None, "obj_id must be provided to add new points" + tracker_metadata = inference_state["tracker_metadata"] + if tracker_metadata == {}: + # initialize masklet metadata if it's uninitialized (empty dict) + tracker_metadata.update(self._initialize_metadata()) + + obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id) + + # prepare feature + self._prepare_backbone_feats(inference_state, frame_idx, reverse=False) + + object_has_been_refined = self._has_object_been_refined(inference_state, obj_id) + if ( + obj_rank is not None + and self.use_stateless_refinement + and not object_has_been_refined + ): + # The first time we start refinement on the object, we remove it. + logger.info( + f"[rank={self.rank}] Removing object {obj_id} before refinement." + ) + self.remove_object(inference_state, obj_id, is_user_action=False) + obj_rank = None + elif obj_rank is not None and not object_has_been_refined: + # Extract the object into its own singleton inference state if it belongs to a batch + if self.rank == obj_rank and not self.tracker.per_obj_inference: + tracker_states = self._get_sam2_inference_states_by_obj_ids( + inference_state, [obj_id] + ) + assert len(tracker_states) == 1 + # Check if this is a batched state (contains multiple objects) + sam2_state = tracker_states[0] + if len(sam2_state["obj_ids"]) > 1: + logger.info( + f"[rank={self.rank}] Extracting object {obj_id} into singleton inference state." + ) + self._extract_object_to_singleton_state( + inference_state, obj_id, obj_rank + ) + + if obj_rank is None: + # new object, we assign it a GPU and create a new inference state if limit allows + num_prev_obj = np.sum(tracker_metadata["num_obj_per_gpu"]) + if num_prev_obj >= self.max_num_objects: + logger.warning( + f"add_sam2_new_points: cannot add a new object as we are already tracking {num_prev_obj=} " + f"masklets (under {self.max_num_objects=})" + ) + return frame_idx, None + + new_det_gpu_ids = self._assign_new_det_to_gpus( + new_det_num=1, + prev_workload_per_gpu=tracker_metadata["num_obj_per_gpu"], + ) + obj_rank = new_det_gpu_ids[0] + + # get sam2 inference state for the new object + if self.rank == obj_rank: + if self.tracker.per_obj_inference: + sam2_state = inference_state["sam2_inference_states"][0] + else: + # for batched inference, we create a new inference state + sam2_state = self._init_new_sam2_state(inference_state) + inference_state["sam2_inference_states"].append(sam2_state) + + # update metadata + tracker_metadata["obj_ids_per_gpu"][obj_rank] = np.concatenate( + [ + tracker_metadata["obj_ids_per_gpu"][obj_rank], + np.array([obj_id], dtype=np.int64), + ] + ) + tracker_metadata["num_obj_per_gpu"][obj_rank] = len( + tracker_metadata["obj_ids_per_gpu"][obj_rank] + ) + tracker_metadata["obj_ids_all_gpu"] = np.concatenate( + tracker_metadata["obj_ids_per_gpu"] + ) + tracker_metadata["max_obj_id"] = max(tracker_metadata["max_obj_id"], obj_id) + + logger.info( + f"[rank={self.rank}] Adding new object with id {obj_id} at frame {frame_idx}." + ) + self.add_action_history( + inference_state, "add", frame_idx=frame_idx, obj_ids=[obj_id] + ) + else: + # existing object, for refinement + if self.rank == obj_rank: + tracker_states = self._get_sam2_inference_states_by_obj_ids( + inference_state, [obj_id] + ) + assert ( + len(tracker_states) == 1 + ), f"[rank={self.rank}] Multiple SAM2 inference states found for the same object id." + sam2_state = tracker_states[0] + + # log + logger.info( + f"[rank={self.rank}] Refining existing object with id {obj_id} at frame {frame_idx}." + ) + self.add_action_history( + inference_state, "refine", frame_idx=frame_idx, obj_ids=[obj_id] + ) + + # assign higher score to added/refined object + tracker_metadata["obj_id_to_score"][obj_id] = 1.0 + tracker_metadata["obj_id_to_sam2_score_frame_wise"][frame_idx][obj_id] = ( + torch.tensor(1.0, dtype=torch.float32, device=self.device) + ) + + if self.rank == 0: + rank0_metadata = tracker_metadata.get("rank0_metadata", {}) + + if "removed_obj_ids" in rank0_metadata: + rank0_metadata["removed_obj_ids"].discard(obj_id) + + if "suppressed_obj_ids" in rank0_metadata: + for frame_id in rank0_metadata["suppressed_obj_ids"]: + rank0_metadata["suppressed_obj_ids"][frame_id].discard(obj_id) + + if "masklet_confirmation" in rank0_metadata: + obj_ids_all_gpu = tracker_metadata["obj_ids_all_gpu"] + obj_indices = np.where(obj_ids_all_gpu == obj_id)[0] + if len(obj_indices) > 0: + obj_idx = obj_indices[0] + if obj_idx < len(rank0_metadata["masklet_confirmation"]["status"]): + rank0_metadata["masklet_confirmation"]["status"][obj_idx] = 1 + rank0_metadata["masklet_confirmation"]["consecutive_det_num"][ + obj_idx + ] = self.masklet_confirmation_consecutive_det_thresh + + if self.rank == obj_rank: + should_fallback_to_original_mask = ( + len(points) == 0 and inference_state["is_image_only"] + ) + if should_fallback_to_original_mask: + mask_input = self._get_mask_input(sam2_state, frame_idx, obj_id) + if mask_input is None or 0 in mask_input.shape: + logger.warning( + f"Cannot retrieve original mask input for obj_id {obj_id} at frame {frame_idx} to fallback." + ) + should_fallback_to_original_mask = False + if should_fallback_to_original_mask: + # When user cancels all points on an image, we recover the original mask + # by re-feeding the detector mask to SAM2. + mask_input = self._get_mask_input(sam2_state, frame_idx, obj_id) + # clear out states related to this object to have a fresh start + self.tracker.clear_all_points_in_frame( + sam2_state, frame_idx, obj_id, need_output=False + ) + frame_idx, obj_ids, low_res_masks, video_res_masks = ( + self.tracker.add_new_mask( + sam2_state, + frame_idx, + obj_id, + mask_input, + ) + ) + else: + frame_idx, obj_ids, low_res_masks, video_res_masks = ( + self.tracker.add_new_points( + inference_state=sam2_state, + frame_idx=frame_idx, + obj_id=obj_id, + points=points, + labels=labels, + clear_old_points=clear_old_points, + rel_coordinates=rel_coordinates, + use_prev_mem_frame=use_prev_mem_frame, + ) + ) + + if video_res_masks is not None and len(video_res_masks) > 0: + video_res_masks = fill_holes_in_mask_scores( + video_res_masks, # shape (N, 1, H_video, W_video) + fill_hole_area=self.fill_hole_area, + sprinkle_removal_area=self.sprinkle_removal_area, + fill_holes=True, + remove_sprinkles=True, + ) + + # TODO: will this cause issue when user switching to refine another object? + # Since the mem encoder has already run for the current input points? + # FIX: Synchronize consolidated_frame_inds with actual point/mask + # inputs before propagate_in_video_preflight. Two issues can cause + # the `all_consolidated_frame_inds == input_frames_inds` assertion + # to fail: + # 1) VG detector conditioning frames in mask_inputs_per_obj without + # corresponding point inputs (stale VG entries). + # 2) Previously consolidated point-input frames (from earlier + # add_points) whose consolidated_frame_inds entries were lost + # during subsequent propagation. + # We fix both by: (a) clearing mask-only inputs, (b) rebuilding + # consolidated_frame_inds from the remaining inputs, excluding + # temp output frames (which preflight will add itself). + + # (a) Clear detector-only mask inputs + for _obj_idx in list(sam2_state["mask_inputs_per_obj"].keys()): + _point_frames = set( + sam2_state["point_inputs_per_obj"].get(_obj_idx, {}).keys() + ) + _mask_only_frames = [ + f + for f in list(sam2_state["mask_inputs_per_obj"][_obj_idx].keys()) + if f not in _point_frames + ] + for f in _mask_only_frames: + sam2_state["mask_inputs_per_obj"][_obj_idx].pop(f, None) + + # (b) Rebuild consolidated_frame_inds from remaining inputs + _input_frames = set() + for _oi in sam2_state["point_inputs_per_obj"]: + _input_frames.update(sam2_state["point_inputs_per_obj"][_oi].keys()) + for _oi in sam2_state["mask_inputs_per_obj"]: + _input_frames.update(sam2_state["mask_inputs_per_obj"][_oi].keys()) + # Exclude temp output frames — preflight will consolidate those + _temp_frames = set() + for _obj_temp in sam2_state["temp_output_dict_per_obj"].values(): + _temp_frames.update(_obj_temp["cond_frame_outputs"].keys()) + _temp_frames.update(_obj_temp["non_cond_frame_outputs"].keys()) + _prev_frames = _input_frames - _temp_frames + _cond = set() + _non_cond = set() + for f in _prev_frames: + if f in sam2_state["output_dict"].get("cond_frame_outputs", {}): + _cond.add(f) + else: + _non_cond.add(f) + sam2_state["consolidated_frame_inds"] = { + "cond_frame_outputs": _cond, + "non_cond_frame_outputs": _non_cond, + } + self.tracker.propagate_in_video_preflight(sam2_state, run_mem_encoder=True) + if not inference_state["is_image_only"]: + # Clear detector conditioning frames when user clicks are received to allow + # model updating masks on these frames. It is a noop if user is refining on the + # detector conditioning frames or adding new objects. + self.clear_detector_added_cond_frame_in_sam2( + sam2_state, obj_id, frame_idx + ) + + # fetch results from states and gather across GPUs + # Use optimized caching approach to avoid reprocessing unmodified objects + if self.rank == obj_rank and len(obj_ids) > 0: + new_mask_data = (video_res_masks[obj_ids.index(obj_id)] > 0.0).to( + torch.bool + ) + else: + new_mask_data = None + + # Broadcast the new mask data across all ranks for consistency + if self.world_size > 1: + data_list = [new_mask_data] + self.broadcast_python_obj_cpu(data_list, src=obj_rank) + new_mask_data = data_list[0] + + if self.rank == 0: + obj_id_to_mask = self._build_sam2_output( + inference_state, + frame_idx, + {obj_id: new_mask_data} if new_mask_data is not None else None, + ) + # post processing - remove suppressed obj_ids + obj_id_to_score = tracker_metadata["obj_id_to_score"] + suppressed_obj_ids = tracker_metadata["rank0_metadata"][ + "suppressed_obj_ids" + ][frame_idx] + obj_id_to_sam2_score = tracker_metadata["obj_id_to_sam2_score_frame_wise"][ + frame_idx + ] + + out = { + "obj_id_to_mask": obj_id_to_mask, + "obj_id_to_score": obj_id_to_score, + "obj_id_to_sam2_score": obj_id_to_sam2_score, + } + self._cache_frame_outputs( + inference_state, + frame_idx, + obj_id_to_mask, + suppressed_obj_ids=suppressed_obj_ids, + ) + return frame_idx, self._postprocess_output( + inference_state, out, suppressed_obj_ids=suppressed_obj_ids + ) + else: + return frame_idx, None # no output on other GPUs + + def _get_mask_input(self, inference_state, frame_idx, obj_id): + """Get the mask input for a specific object on a specific frame.""" + obj_idx = self.tracker._obj_id_to_idx(inference_state, obj_id) + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + if frame_idx not in mask_inputs_per_frame: + logger.info( + f"frame {frame_idx} not in mask_inputs_per_frame for obj_id {obj_id}" + ) + return None + + mask_inputs_orig = mask_inputs_per_frame[frame_idx].squeeze(0, 1) # (H, W) + return mask_inputs_orig + + def _gather_obj_id_to_mask_across_gpus(self, inference_state, obj_id_to_mask_local): + """Gather obj_id_to_mask from all GPUs. Optionally resize the masks to the video resolution.""" + tracker_metadata = inference_state["tracker_metadata"] + + # concatenate the output masklets from all local inference states + H_mask = W_mask = self.tracker.low_res_mask_size + obj_ids_local = tracker_metadata["obj_ids_per_gpu"][self.rank] + low_res_masks_local = [] + for obj_id in obj_ids_local: + if obj_id in obj_id_to_mask_local: + low_res_masks_local.append(obj_id_to_mask_local[obj_id]) + else: + low_res_masks_local.append( + torch.full((H_mask, W_mask), -1024.0, device=self.device) + ) + if len(low_res_masks_local) > 0: + low_res_masks_local = torch.stack(low_res_masks_local, dim=0) # (N, H, W) + assert low_res_masks_local.shape[1:] == (H_mask, W_mask) + else: + low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device) + + # all-gather `low_res_masks_local` into `low_res_masks_global` + # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask) + if self.world_size > 1: + low_res_masks_local = low_res_masks_local.float().contiguous() + low_res_masks_peers = [ + low_res_masks_local.new_empty(num_obj, H_mask, W_mask) + for num_obj in tracker_metadata["num_obj_per_gpu"] + ] + dist.all_gather(low_res_masks_peers, low_res_masks_local) + low_res_masks_global = torch.cat(low_res_masks_peers, dim=0) + else: + low_res_masks_global = low_res_masks_local + return low_res_masks_global + + def _convert_low_res_mask_to_video_res(self, low_res_mask, inference_state): + """ + Convert a low-res mask to video resolution, matching the format expected by _build_sam2_output. + + Args: + low_res_mask: Tensor of shape (H_low_res, W_low_res) + inference_state: Contains video dimensions + + Returns: + video_res_mask: Tensor of shape (1, H_video, W_video) bool + """ + if low_res_mask is None: + return None + + # Convert to 3D for interpolation: (H_low_res, W_low_res) -> (1, H_low_res, W_low_res) + low_res_mask_3d = low_res_mask.unsqueeze(0).unsqueeze(0) + + # Get video dimensions + H_video = inference_state["orig_height"] + W_video = inference_state["orig_width"] + + video_res_mask = F.interpolate( + low_res_mask_3d.float(), + size=(H_video, W_video), + mode="bilinear", + align_corners=False, + ) # (1, H_video, W_video) + + # Convert to boolean - already in the right shape! + return (video_res_mask.squeeze(0) > 0.0).to(torch.bool) + + def clear_detector_added_cond_frame_in_sam2( + self, sam2_state, obj_id, refined_frame_idx + ): + """Clear detector added conditioning frame if it is within a predefined window + of the refined frame. This allow model to update masks on these frames.""" + obj_idx = self.tracker._obj_id_to_idx(sam2_state, obj_id) + + mask_only_cond_frame_indices = [] + window = self.refinement_detector_cond_frame_removal_window + for frame_idx in sam2_state["mask_inputs_per_obj"][obj_idx]: + if frame_idx not in sam2_state["point_inputs_per_obj"][obj_idx]: + # clear conditioning frames within a window of the refined frame + if abs(frame_idx - refined_frame_idx) <= window: + mask_only_cond_frame_indices.append(frame_idx) + + # clear + if len(mask_only_cond_frame_indices) > 0: + for frame_idx in mask_only_cond_frame_indices: + # obj_ids_on_this_frame is essentially all obj_ids in the state + # since they are bucket batched + obj_ids_on_this_frame = sam2_state["obj_id_to_idx"].keys() + for obj_id2 in obj_ids_on_this_frame: + self.tracker.clear_all_points_in_frame( + sam2_state, frame_idx, obj_id2, need_output=False + ) + logger.info( + f"Cleared detector mask only conditioning frames ({mask_only_cond_frame_indices}) in SAM2." + ) + return + + def _extract_object_to_singleton_state(self, inference_state, obj_id, obj_rank): + """ + Extract an object from a batched inference state into its own singleton state. + """ + if self.rank != obj_rank: + return + + tracker_states_local = inference_state["sam2_inference_states"] + + # Find the inference state containing this object + source_state = None + source_state_idx = None + for idx, state in enumerate(tracker_states_local): + if obj_id in state["obj_ids"]: + source_state = state + source_state_idx = idx + break + + assert source_state is not None + + if len(source_state["obj_ids"]) <= 1: + # Object not found or already in singleton state + return + + # Step 1: Extract all the object's state data before removing it + obj_idx_in_source = source_state["obj_id_to_idx"][obj_id] + multiplex_state = source_state.get("multiplex_state") + + # Extract consolidated outputs (obj_ptr, maskmem_features, etc.) BEFORE + # remove_object modifies the source tensors. + singleton_consolidated_outputs = { + "cond_frame_outputs": {}, + "non_cond_frame_outputs": {}, + } + if "output_dict" in source_state: + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + source_outputs = source_state["output_dict"].get(storage_key, {}) + for f_idx, source_frame_out in source_outputs.items(): + if source_frame_out["pred_masks"].shape[0] < obj_idx_in_source + 1: + continue + singleton_frame_out = { + "pred_masks": source_frame_out["pred_masks"][ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone(), + "object_score_logits": source_frame_out["object_score_logits"][ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone(), + "image_features": source_frame_out.get("image_features"), + "image_pos_enc": source_frame_out.get("image_pos_enc"), + "local_obj_id_to_idx": {obj_id: 0}, + } + # Extract maskmem_features (demux from multiplex space) + maskmem_features = source_frame_out.get("maskmem_features") + if maskmem_features is not None and multiplex_state is not None: + try: + demuxed = multiplex_state.demux(maskmem_features) + maskmem_features = demuxed[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + except (AssertionError, IndexError): + maskmem_features = None + elif maskmem_features is not None: + maskmem_features = maskmem_features[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + singleton_frame_out["maskmem_features"] = maskmem_features + # Extract maskmem_pos_enc (demux level by level) + maskmem_pos_enc = source_frame_out.get("maskmem_pos_enc") + if maskmem_pos_enc is not None: + remapped = [] + for level_enc in maskmem_pos_enc: + if level_enc is None: + remapped.append(None) + continue + if multiplex_state is not None: + try: + demuxed = multiplex_state.demux(level_enc) + remapped.append( + demuxed[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + ) + except (AssertionError, IndexError): + remapped.append(None) + else: + remapped.append( + level_enc[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + ) + maskmem_pos_enc = remapped + singleton_frame_out["maskmem_pos_enc"] = maskmem_pos_enc + # Extract obj_ptr (demux from multiplex space) + if ( + "obj_ptr" in source_frame_out + and self.tracker.use_obj_ptrs_in_encoder + ): + source_obj_ptr = source_frame_out["obj_ptr"] + if multiplex_state is not None: + obj_ptr_data = multiplex_state.demux(source_obj_ptr) + singleton_frame_out["obj_ptr"] = obj_ptr_data[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + else: + singleton_frame_out["obj_ptr"] = source_obj_ptr[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + # Extract conditioning_objects + if "conditioning_objects" in source_frame_out: + if ( + obj_idx_in_source + in source_frame_out["conditioning_objects"] + ): + singleton_frame_out["conditioning_objects"] = {0} + else: + singleton_frame_out["conditioning_objects"] = set() + singleton_consolidated_outputs[storage_key][f_idx] = ( + singleton_frame_out + ) + + # Extract point and mask inputs for this object + extracted_point_inputs = {} + extracted_mask_inputs = {} + + if ( + "point_inputs_per_obj" in source_state + and obj_idx_in_source in source_state["point_inputs_per_obj"] + ): + extracted_point_inputs = source_state["point_inputs_per_obj"][ + obj_idx_in_source + ].copy() + + if ( + "mask_inputs_per_obj" in source_state + and obj_idx_in_source in source_state["mask_inputs_per_obj"] + ): + extracted_mask_inputs = source_state["mask_inputs_per_obj"][ + obj_idx_in_source + ].copy() + + # Extract per-object outputs - these are already properly sliced for the object + extracted_obj_cond_outputs = {} + extracted_obj_non_cond_outputs = {} + extracted_temp_cond_outputs = {} + extracted_temp_non_cond_outputs = {} + + if ( + "output_dict_per_obj" in source_state + and obj_idx_in_source in source_state["output_dict_per_obj"] + ): + obj_output_dict = source_state["output_dict_per_obj"][obj_idx_in_source] + extracted_obj_cond_outputs = obj_output_dict.get( + "cond_frame_outputs", {} + ).copy() + cond_input_keys = ( + extracted_point_inputs.keys() | extracted_mask_inputs.keys() + ) + # we may have obj cond outputs for other objects in a batch, so limit to cond inputs for only this object + extracted_obj_cond_outputs = { + k: v + for k, v in extracted_obj_cond_outputs.items() + if k in cond_input_keys + } + + extracted_obj_non_cond_outputs = obj_output_dict.get( + "non_cond_frame_outputs", {} + ).copy() + + if ( + "temp_output_dict_per_obj" in source_state + and obj_idx_in_source in source_state["temp_output_dict_per_obj"] + ): + temp_obj_output_dict = source_state["temp_output_dict_per_obj"][ + obj_idx_in_source + ] + extracted_temp_cond_outputs = temp_obj_output_dict.get( + "cond_frame_outputs", {} + ).copy() + extracted_temp_non_cond_outputs = temp_obj_output_dict.get( + "non_cond_frame_outputs", {} + ).copy() + + # Step 2: Remove the object from the source state + remaining_obj_ids, _ = self.tracker.remove_object( + source_state, obj_id, strict=False, need_output=False + ) + + # Step 3: Create a new singleton inference state + new_sam2_state = self.tracker.init_state( + cached_features=inference_state["feature_cache"], + video_height=inference_state["orig_height"], + video_width=inference_state["orig_width"], + num_frames=inference_state["num_frames"], + ) + + # Step 4: Set up the singleton state structure for the extracted object + # Map the object to index 0 in the new singleton state + new_sam2_state["obj_id_to_idx"] = {obj_id: 0} + new_sam2_state["obj_idx_to_id"] = {0: obj_id} + new_sam2_state["obj_ids"] = [obj_id] + + # Step 5: Restore all the extracted state + # Restore point and mask inputs + new_sam2_state["point_inputs_per_obj"] = {0: extracted_point_inputs} + new_sam2_state["mask_inputs_per_obj"] = {0: extracted_mask_inputs} + + # Restore per-object output dictionaries (already properly sliced) + new_sam2_state["output_dict_per_obj"] = { + 0: { + "cond_frame_outputs": extracted_obj_cond_outputs, + "non_cond_frame_outputs": extracted_obj_non_cond_outputs, + } + } + + # Restore temporary outputs + new_sam2_state["temp_output_dict_per_obj"] = { + 0: { + "cond_frame_outputs": extracted_temp_cond_outputs, + "non_cond_frame_outputs": extracted_temp_non_cond_outputs, + } + } + + # Step 6: Rebuild the consolidated output_dict for the singleton state + # Use the extracted consolidated outputs which include obj_ptr, + # maskmem_features, maskmem_pos_enc (not just pred_masks/object_score_logits) + + # Create singleton multiplex state and remux extracted tensors + new_multiplex_state = self.tracker.multiplex_controller.get_state( + num_valid_entries=1, + device=source_state.get("device", "cuda"), + dtype=torch.float32, + random=False, + object_ids=[obj_id], + ) + new_sam2_state["multiplex_state"] = new_multiplex_state + + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + for f_idx, frame_out in singleton_consolidated_outputs[storage_key].items(): + if frame_out.get("maskmem_features") is not None: + frame_out["maskmem_features"] = frame_out[ + "maskmem_features" + ].clone() + if frame_out.get("maskmem_pos_enc") is not None: + frame_out["maskmem_pos_enc"] = [ + level.clone() if level is not None else None + for level in frame_out["maskmem_pos_enc"] + ] + if "obj_ptr" in frame_out and self.tracker.use_obj_ptrs_in_encoder: + frame_out["obj_ptr"] = new_multiplex_state.mux(frame_out["obj_ptr"]) + + new_sam2_state["output_dict"] = singleton_consolidated_outputs + + # Step 7: Copy other important state if it exists + for key in [ + "first_ann_frame_idx", + "tracking_has_started", + ]: + if key in source_state: + new_sam2_state[key] = source_state[key] + + # Leave consolidated_frame_inds empty so preflight reconstructs from per-obj data + new_sam2_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), + "non_cond_frame_outputs": set(), + } + + # Step 8: Add the new singleton state to the list + tracker_states_local.append(new_sam2_state) + + # Step 9: If the source state is now empty, remove it + if len(remaining_obj_ids) == 0: + tracker_states_local.pop(source_state_idx) + logger.info( + f"Removed empty inference state after extracting object {obj_id}" + ) + + logger.info(f"Object {obj_id} successfully extracted to singleton state") diff --git a/third_party/sam3/sam3/model/sam3_multiplex_video_predictor.py b/third_party/sam3/sam3/model/sam3_multiplex_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..b86c7f1cbe6ef1dffd91b0c82a29abf61a313a1c --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_multiplex_video_predictor.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +Sam3MultiplexVideoPredictor — user-facing entry point for SAM 3.1 multiplex. + +Ported from onevision Sam3Model (webdemo/ta/models/sam3_model.py). +Handles warm-up compilation, bf16 autocast, and session management +via the shared Sam3BasePredictor handle_request/handle_stream_request API. +""" + +from typing import Dict, Optional + +import torch +from sam3.logger import get_logger +from sam3.model.sam3_base_predictor import Sam3BasePredictor + +logger = get_logger(__name__) + + +class Sam3MultiplexVideoPredictor(Sam3BasePredictor): + """ + User-facing predictor for SAM 3.1 multiplex video tracking. + + Wraps Sam3MultiplexTrackingWithInteractivity with: + - bf16 autocast + - Warm-up compilation (when compile=True) + - Session expiration management + - handle_request / handle_stream_request dispatch API (from Sam3BasePredictor) + """ + + def __init__( + self, + model, + session_expiration_sec=1200, + default_output_prob_thresh=0.5, + async_loading_frames=True, + warm_up=False, + ): + super().__init__() + self.model = model + self.session_expiration_sec = session_expiration_sec + self.default_output_prob_thresh = default_output_prob_thresh + self.async_loading_frames = async_loading_frames + + # turn on tfloat32 for Ampere GPUs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + # use bfloat16 inference for Flash Attention kernel + self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16) + self.bf16_context.__enter__() + + if warm_up: + self.model._warm_up_complete = False + self.model.warm_up_compilation() + self.model._warm_up_complete = True + + def _extend_expiration_time(self, session): + """Update last-use time and store session expiration timeout.""" + super()._extend_expiration_time(session) + if self.session_expiration_sec: + session["expiration_sec"] = self.session_expiration_sec diff --git a/third_party/sam3/sam3/model/sam3_tracker_base.py b/third_party/sam3/sam3/model/sam3_tracker_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f40b7bd2cd3782cb0e03ab71ad778ba217eab2 --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_tracker_base.py @@ -0,0 +1,1185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging + +import torch +import torch.nn.functional as F +from sam3.model.memory import SimpleMaskEncoder +from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames +from sam3.sam.mask_decoder import MaskDecoder, MLP +from sam3.sam.prompt_encoder import PromptEncoder +from sam3.sam.transformer import TwoWayTransformer +from sam3.train.data.collator import BatchedDatapoint + +try: + from timm.layers import trunc_normal_ +except ModuleNotFoundError: + # compatibility for older timm versions + from timm.models.layers import trunc_normal_ + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class Sam3TrackerBase(torch.nn.Module): + def __init__( + self, + backbone, + transformer, + maskmem_backbone, + num_maskmem=7, # default 1 input frame + 6 previous frames as in CAE + image_size=1008, + backbone_stride=14, # stride of the image backbone output + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # Whether to always keep the first conditioning frame in case we exceed the maximum number of conditioning frames allowed + keep_first_cond_frame=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features + # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower. + forward_backbone_per_frame_for_eval=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # whether to offload outputs to CPU memory during evaluation, to avoid GPU OOM on very long videos or very large resolutions or too many objects + # (it's recommended to use `forward_backbone_per_frame_for_eval=True` first before setting this option to True) + offload_output_to_cpu_for_eval=False, + # whether to trim the output of past non-conditioning frames (num_maskmem frames before the current frame) during evaluation + # (this helps save GPU or CPU memory on very long videos for semi-supervised VOS eval, where only the first frame receives prompts) + trim_past_non_cond_mem_for_eval=False, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # the maximum number of object pointers from other frames in encoder cross attention + max_obj_ptrs_in_encoder=16, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + # whether to compile all the model compoents + compile_all_components=False, + # select the frame with object existence + use_memory_selection=False, + # when using memory selection, the threshold to determine if the frame is good + mf_threshold=0.01, + ): + super().__init__() + + # Part 1: the image backbone + self.backbone = backbone + self.num_feature_levels = 3 + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + # A conv layer to downsample the GT mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + + # Part 2: encoder-only transformer to fuse current frame's visual features + # with memories from past frames + assert transformer.decoder is None, "transformer should be encoder-only" + self.transformer = transformer + self.hidden_dim = transformer.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.maskmem_backbone = maskmem_backbone + self.mem_dim = self.hidden_dim + if hasattr(self.maskmem_backbone, "out_proj") and hasattr( + self.maskmem_backbone.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.maskmem_backbone.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = 20.0 + self.sigmoid_bias_for_mem_enc = -10.0 + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.low_res_mask_size = self.image_size // self.backbone_stride * 4 + # we resize the mask if it doesn't match `self.input_mask_size` (which is always 4x + # the low-res mask size, regardless of the actual input image size); this is because + # `_use_mask_as_output` always downsamples the input masks by 4x + self.input_mask_size = self.low_res_mask_size * 4 + self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval + self.offload_output_to_cpu_for_eval = offload_output_to_cpu_for_eval + self.trim_past_non_cond_mem_for_eval = trim_past_non_cond_mem_for_eval + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + self.keep_first_cond_frame = keep_first_cond_frame + + # Use frame filtering according to SAM2Long + self.use_memory_selection = use_memory_selection + self.mf_threshold = mf_threshold + + # Compile all components of the model + self.compile_all_components = compile_all_components + if self.compile_all_components: + self._compile_all_components() + + @property + def device(self): + return next(self.parameters()).device + + def _get_tpos_enc(self, rel_pos_list, device, max_abs_pos=None, dummy=False): + if dummy: + return torch.zeros(len(rel_pos_list), self.mem_dim, device=device) + + t_diff_max = max_abs_pos - 1 if max_abs_pos is not None else 1 + pos_enc = ( + torch.tensor(rel_pos_list).pin_memory().to(device=device, non_blocking=True) + / t_diff_max + ) + tpos_dim = self.hidden_dim + pos_enc = get_1d_sine_pe(pos_enc, dim=tpos_dim) + pos_enc = self.obj_ptr_tpos_proj(pos_enc) + + return pos_enc + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=True, + iou_prediction_use_sigmoid=True, + pred_obj_scores=True, + pred_obj_scores_mlp=True, + use_multimask_token_for_obj_ptr=True, + **(self.sam_mask_decoder_extra_args or {}), + ) + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + gt_masks=None, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + # Clone image_pe and the outputs of sam_prompt_encoder + # to enable compilation + sparse_embeddings = self._maybe_clone(sparse_embeddings) + dense_embeddings = self._maybe_clone(dense_embeddings) + image_pe = self._maybe_clone(self.sam_prompt_encoder.get_dense_pe()) + with torch.profiler.record_function("sam_mask_decoder"): + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + # Clone the output of sam_mask_decoder + # to enable compilation + low_res_multimasks = self._maybe_clone(low_res_multimasks) + ious = self._maybe_clone(ious) + sam_output_tokens = self._maybe_clone(sam_output_tokens) + object_score_logits = self._maybe_clone(object_score_logits) + + if self.training and self.teacher_force_obj_scores_for_mem: + # we use gt to detect if there is an object or not to + # select no obj ptr and use an empty mask for spatial memory + is_obj_appearing = torch.any(gt_masks.float().flatten(1) > 0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + else: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + lambda_is_obj_appearing = is_obj_appearing.float() + + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=( + high_res_masks.size(-2) // self.backbone_stride * 4, + high_res_masks.size(-1) // self.backbone_stride * 4, + ), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + gt_masks=mask_inputs, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward(self, input: BatchedDatapoint, is_inference=False): + raise NotImplementedError( + "Please use the corresponding methods in SAM3VideoPredictor for inference." + "See examples/sam3_dense_video_tracking.ipynb for an inference example." + ) + + def forward_image(self, img_batch): + """Get the image feature on the input batch.""" + # This line is the only change from the parent class + # to use the SAM3 backbone instead of the SAM2 backbone. + backbone_out = self.backbone.forward_image(img_batch)["sam2_backbone_out"] + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + # Clone to help torch.compile + for i in range(len(backbone_out["backbone_fpn"])): + backbone_out["backbone_fpn"][i] = self._maybe_clone( + backbone_out["backbone_fpn"][i] + ) + backbone_out["vision_pos_enc"][i] = self._maybe_clone( + backbone_out["vision_pos_enc"][i] + ) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features (same as in MDETR_API model).""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_backbone_features_per_frame(self, img_batch, img_ids): + """Compute the image backbone features on the fly for the given img_ids.""" + # Only forward backbone on unique image ids to avoid repeatitive computation + # (if `img_ids` has only one element, it's already unique so we skip this step). + if img_ids.numel() > 1: + unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True) + else: + unique_img_ids, inv_ids = img_ids, None + + # Compute the image features on those unique image ids + image = img_batch[unique_img_ids] + backbone_out = self.forward_image(image) + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + # Inverse-map image features for `unique_img_ids` to the final image features + # for the original input `img_ids`. + if inv_ids is not None: + image = image[inv_ids] + vision_feats = [x[:, inv_ids] for x in vision_feats] + vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds] + + return image, vision_feats, vision_pos_embeds, feat_sizes + + def cal_mem_score(self, object_score_logits, iou_score): + object_score_norm = torch.where( + object_score_logits > 0, + object_score_logits.sigmoid() * 2 - 1, ## rescale to [0, 1] + torch.zeros_like(object_score_logits), + ) + score_per_frame = (object_score_norm * iou_score).mean() + return score_per_frame + + def frame_filter(self, output_dict, track_in_reverse, frame_idx, num_frames, r): + if (frame_idx == 0 and not track_in_reverse) or ( + frame_idx == num_frames - 1 and track_in_reverse + ): + return [] + + max_num = min( + num_frames, self.max_obj_ptrs_in_encoder + ) ## maximum number of pointer memory frames to consider + + if not track_in_reverse: + start = frame_idx - 1 + end = 0 + step = -r + must_include = frame_idx - 1 + else: + start = frame_idx + 1 + end = num_frames + step = r + must_include = frame_idx + 1 + + valid_indices = [] + for i in range(start, end, step): + if ( + i not in output_dict["non_cond_frame_outputs"] + or "eff_iou_score" not in output_dict["non_cond_frame_outputs"][i] + ): + continue + + score_per_frame = output_dict["non_cond_frame_outputs"][i]["eff_iou_score"] + + if score_per_frame > self.mf_threshold: # threshold + valid_indices.insert(0, i) + + if len(valid_indices) >= max_num - 1: + break + + if must_include not in valid_indices: + valid_indices.append(must_include) + + return valid_indices + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + use_prev_mem_frame=True, + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame and use_prev_mem_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_prompt, to_cat_prompt_mask, to_cat_prompt_pos_embed = [], [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, + cond_outputs, + self.max_cond_frames_in_attn, + keep_first_cond_frame=self.keep_first_cond_frame, + ) + t_pos_and_prevs = [ + ((frame_idx - t) * tpos_sign_mul, out, True) + for t, out in selected_cond_outputs.items() + ] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = 1 if self.training else self.memory_temporal_stride_for_eval + + if self.use_memory_selection: + valid_indices = self.frame_filter( + output_dict, track_in_reverse, frame_idx, num_frames, r + ) + + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if self.use_memory_selection: + if t_rel > len(valid_indices): + continue + prev_frame_idx = valid_indices[-t_rel] + else: + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out, False)) + + for t_pos, prev, is_selected_cond_frame in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].cuda(non_blocking=True) + seq_len = feats.shape[-2] * feats.shape[-1] + to_cat_prompt.append(feats.flatten(2).permute(2, 0, 1)) + to_cat_prompt_mask.append( + torch.zeros(B, seq_len, device=device, dtype=bool) + ) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + + if ( + is_selected_cond_frame + and getattr(self, "cond_frame_spatial_embedding", None) is not None + ): + # add a spatial embedding for the conditioning frame + maskmem_enc = maskmem_enc + self.cond_frame_spatial_embedding + + # Temporal positional encoding + t = t_pos if not is_selected_cond_frame else 0 + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t - 1] + ) + to_cat_prompt_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + # Optionally, select only a subset of spatial memory frames during trainining + if ( + self.training + and self.prob_to_dropout_spatial_mem > 0 + and self.rng.random() < self.prob_to_dropout_spatial_mem + ): + num_spatial_mem_keep = self.rng.integers(len(to_cat_prompt) + 1) + keep = self.rng.choice( + range(len(to_cat_prompt)), num_spatial_mem_keep, replace=False + ).tolist() + to_cat_prompt = [to_cat_prompt[i] for i in keep] + to_cat_prompt_mask = [to_cat_prompt_mask[i] for i in keep] + to_cat_prompt_pos_embed = [to_cat_prompt_pos_embed[i] for i in keep] + + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + (frame_idx - t) * tpos_sign_mul, + out["obj_ptr"], + True, # is_selected_cond_frame + ) + for t, out in ptr_cond_outputs.items() + ] + + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + if not self.use_memory_selection: + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + else: + if -t_diff <= -len(valid_indices): + break + t = valid_indices[-t_diff] + + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"], False)) + + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list, is_selected_cond_frame_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + if getattr(self, "cond_frame_obj_ptr_embedding", None) is not None: + obj_ptrs = ( + obj_ptrs + + self.cond_frame_obj_ptr_embedding + * torch.tensor(is_selected_cond_frame_list, device=device)[ + ..., None, None + ].float() + ) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + obj_pos = self._get_tpos_enc( + pos_list, + max_abs_pos=max_obj_ptrs_in_encoder, + device=device, + ) + # expand to batch size + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, -1) + + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_prompt.append(obj_ptrs) + to_cat_prompt_mask.append(None) # "to_cat_prompt_mask" is not used + to_cat_prompt_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first grame (to avoid emtpy memory input to tranformer encoder) + to_cat_prompt = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_prompt_mask = [torch.zeros(B, 1, device=device, dtype=bool)] + to_cat_prompt_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + prompt = torch.cat(to_cat_prompt, dim=0) + prompt_mask = None # For now, we always masks are zeros anyways + prompt_pos_embed = torch.cat(to_cat_prompt_pos_embed, dim=0) + encoder_out = self.transformer.encoder( + src=current_vision_feats, + src_key_padding_mask=[None], + src_pos=current_vision_pos_embeds, + prompt=prompt, + prompt_pos=prompt_pos_embed, + prompt_key_padding_mask=prompt_mask, + feat_sizes=feat_sizes, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = encoder_out["memory"].permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + image, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + output_dict=None, + is_init_cond_frame=False, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + if is_mask_from_pts and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + + if isinstance(self.maskmem_backbone, SimpleMaskEncoder): + pix_feat = pix_feat.view_as(pix_feat) + maskmem_out = self.maskmem_backbone( + pix_feat, mask_for_mem, skip_mask_sigmoid=True + ) + else: + maskmem_out = self.maskmem_backbone(image, pix_feat, mask_for_mem) + # Clone the feats and pos_enc to enable compilation + maskmem_features = self._maybe_clone(maskmem_out["vision_features"]) + maskmem_pos_enc = [self._maybe_clone(m) for m in maskmem_out["vision_pos_enc"]] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand(*maskmem_features.shape) + + return maskmem_features, maskmem_pos_enc + + def forward_tracking(self, backbone_out, input, return_dict=False): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = backbone_out["backbone_fpn"] is not None + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + ( + _, + vision_feats, + vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + output_dict = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + for stage_id in processing_order: + # Get the image features for the current frames + img_ids = input.find_inputs[stage_id].img_ids + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_image = input.img_batch[img_ids] + current_vision_feats = [x[:, img_ids] for x in vision_feats] + current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds] + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + ( + current_image, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._prepare_backbone_features_per_frame(input.img_batch, img_ids) + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + image=current_image, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + output_dict=output_dict, + num_frames=num_frames, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + return all_frame_outputs + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + image, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + use_prev_mem_frame=True, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None: + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + use_prev_mem_frame=use_prev_mem_frame, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, the SAM mask decoder should have `self.iter_use_prev_mask_pred=True`, and + # any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert self.iter_use_prev_mask_pred + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + _, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) = sam_outputs + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if self.use_memory_selection: + current_out["object_score_logits"] = object_score_logits + iou_score = ious.max(-1)[0] + current_out["iou_score"] = iou_score + current_out["eff_iou_score"] = self.cal_mem_score( + object_score_logits, iou_score + ) + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + # (note that `self.num_maskmem == 0` is primarily used for reproducing SAM on + # images, in which case we'll just skip memory encoder to save compute). + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + image=image, + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + output_dict=output_dict, + is_init_cond_frame=is_init_cond_frame, + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + # Optionally, offload the outputs to CPU memory during evaluation to avoid + # GPU OOM on very long videos or very large resolution or too many objects + if self.offload_output_to_cpu_for_eval and not self.training: + # Here we only keep those keys needed for evaluation to get a compact output + trimmed_out = { + "pred_masks": current_out["pred_masks"].cpu(), + "pred_masks_high_res": current_out["pred_masks_high_res"].cpu(), + # other items for evaluation (these are small tensors so we keep them on GPU) + "obj_ptr": current_out["obj_ptr"], + "object_score_logits": current_out["object_score_logits"], + } + if run_mem_encoder and self.num_maskmem > 0: + trimmed_out["maskmem_features"] = maskmem_features.cpu() + trimmed_out["maskmem_pos_enc"] = [x.cpu() for x in maskmem_pos_enc] + if self.use_memory_selection: + trimmed_out["iou_score"] = current_out["iou_score"].cpu() + trimmed_out["eff_iou_score"] = current_out["eff_iou_score"].cpu() + current_out = trimmed_out + + # Optionally, trim the output of past non-conditioning frame (r * num_maskmem frames + # before the current frame) during evaluation. This is intended to save GPU or CPU + # memory for semi-supervised VOS eval, where only the first frame receives prompts. + def _trim_past_out(past_out, current_out): + if past_out is None: + return None + return { + "pred_masks": past_out["pred_masks"], + "obj_ptr": past_out["obj_ptr"], + "object_score_logits": past_out["object_score_logits"], + } + + if self.trim_past_non_cond_mem_for_eval and not self.training: + r = self.memory_temporal_stride_for_eval + past_frame_idx = frame_idx - r * self.num_maskmem + past_out = output_dict["non_cond_frame_outputs"].get(past_frame_idx, None) + + if past_out is not None: + print(past_out.get("eff_iou_score", 0)) + if ( + self.use_memory_selection + and past_out.get("eff_iou_score", 0) < self.mf_threshold + ) or not self.use_memory_selection: + output_dict["non_cond_frame_outputs"][past_frame_idx] = ( + _trim_past_out(past_out, current_out) + ) + + if ( + self.use_memory_selection and not self.offload_output_to_cpu_for_eval + ): ## design for memory selection, trim too old frames to save memory + far_old_frame_idx = frame_idx - 20 * self.max_obj_ptrs_in_encoder + past_out = output_dict["non_cond_frame_outputs"].get( + far_old_frame_idx, None + ) + if past_out is not None: + output_dict["non_cond_frame_outputs"][far_old_frame_idx] = ( + _trim_past_out(past_out, current_out) + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks + + def _compile_all_components(self): + """Compile all model components for faster inference.""" + # a larger cache size to hold varying number of shapes for torch.compile + # see https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_dynamo/config.py#L42-L49 + torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.accumulated_cache_size_limit = 2048 + from sam3.perflib.compile import compile_wrapper + + logging.info("Compiling all components. First time may be very slow.") + + self.maskmem_backbone.forward = compile_wrapper( + self.maskmem_backbone.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + self.transformer.encoder.forward = compile_wrapper( + self.transformer.encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, # Num. of memories varies + ) + # We disable compilation of sam_prompt_encoder as it sometimes gives a large accuracy regression, + # especially when sam_mask_prompt (previous mask logits) is not None + # self.sam_prompt_encoder.forward = torch.compile( + # self.sam_prompt_encoder.forward, + # mode="max-autotune", + # fullgraph=True, + # dynamic=False, # Accuracy regression on True + # ) + self.sam_mask_decoder.forward = compile_wrapper( + self.sam_mask_decoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + def _maybe_clone(self, x): + """Clone a tensor if and only if `self.compile_all_components` is True.""" + return x.clone() if self.compile_all_components else x + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/third_party/sam3/sam3/model/sam3_tracker_utils.py b/third_party/sam3/sam3/model/sam3_tracker_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..74e9e124c0cfc2118d72940e0f64e4fd5cce36bd --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_tracker_utils.py @@ -0,0 +1,438 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import numpy as np +import torch +import torch.nn.functional as F +from numpy.typing import NDArray +from sam3.model.edt import edt_triton + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> tuple[NDArray, NDArray]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H, W] tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixels), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + mask_area = masks.sum(dim=(-1, -2)) + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + bbox_coords = torch.where( + mask_area[..., None] > 0, bbox_coords, torch.zeros_like(bbox_coords) + ) + return bbox_coords + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, H, W = gt_masks.shape + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = (~gt_masks & pred_masks).squeeze(1) + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = (gt_masks & ~pred_masks).squeeze(1) + + if padding: + padded_fp_masks = torch.zeros( + B, H + 2, W + 2, dtype=fp_masks.dtype, device=fp_masks.device + ) + padded_fp_masks[:, 1 : H + 1, 1 : W + 1] = fp_masks + padded_fn_masks = torch.zeros( + B, H + 2, W + 2, dtype=fp_masks.dtype, device=fp_masks.device + ) + padded_fn_masks[:, 1 : H + 1, 1 : W + 1] = fn_masks + else: + padded_fp_masks = fp_masks + padded_fn_masks = fn_masks + + fn_mask_dt = edt_triton(padded_fn_masks) + fp_mask_dt = edt_triton(padded_fp_masks) + if padding: + fn_mask_dt = fn_mask_dt[:, 1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[:, 1:-1, 1:-1] + + fn_max, fn_argmax = fn_mask_dt.reshape(B, -1).max(dim=-1) + fp_max, fp_argmax = fp_mask_dt.reshape(B, -1).max(dim=-1) + is_positive = fn_max > fp_max + chosen = torch.where(is_positive, fn_argmax, fp_argmax) + points_x = chosen % W + points_y = chosen // W + + labels = is_positive.long() + points = torch.stack([points_x, points_y], -1) + return points.unsqueeze(1), labels.unsqueeze(1) + + +def sample_one_point_from_error_center_slow(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 # delay OpenCV import to avoid unnecessary dependency + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") + + +def select_closest_cond_frames( + frame_idx, cond_frame_outputs, max_cond_frame_num, keep_first_cond_frame=False +): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + if keep_first_cond_frame: + idx_first = min( + (t for t in cond_frame_outputs if t < frame_idx), default=None + ) + if idx_first is None: + # Maybe we are tracking in reverse + idx_first = max( + (t for t in cond_frame_outputs if t > frame_idx), default=None + ) + if idx_first is not None: + selected_outputs[idx_first] = cond_frame_outputs[idx_first] + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_best_gt_match_from_multimasks(pred_multimasks, gt_masks, pred_scores=None): + """ + Get the mask with the best match to GT masks (based on IoU) from pred_multimasks. + Optionally, use `pred_scores` to break ties in case all IoUs are zeros. + """ + assert pred_multimasks.ndim == 4 and gt_masks.ndim == 4 + if pred_multimasks.size(1) == 1: + return pred_multimasks # only a single mask channel, nothing to select + + pred_multimasks_binary = pred_multimasks > 0 + area_i = torch.sum(pred_multimasks_binary & gt_masks, dim=(2, 3)).float() + area_u = torch.sum(pred_multimasks_binary | gt_masks, dim=(2, 3)).float() + ious = area_i / torch.clamp(area_u, min=1.0) + + # In case all IoUs are zeros (e.g. because the GT mask is empty), use pred_scores + # to break ties and select the best mask + if pred_scores is not None: + has_nonzero_ious = torch.any(ious > 0).expand_as(ious) + scores = torch.where(has_nonzero_ious, ious, pred_scores) + else: + scores = ious + + # Finally, take the best mask prediction (with the highest score) + best_scores_inds = torch.argmax(scores, dim=-1) + batch_inds = torch.arange(scores.size(0), device=scores.device) + best_pred_mask = pred_multimasks[batch_inds, best_scores_inds].unsqueeze(1) + return best_pred_mask + + +def fill_holes_in_mask_scores( + mask, + max_area=None, + fill_holes=True, + remove_sprinkles=True, + fill_hole_area=None, + sprinkle_removal_area=None, +): + # Support onevision-style keyword args + if fill_hole_area is not None and max_area is None: + max_area = fill_hole_area + """ + A post processor to fill small holes in mask scores with area under `max_area`. + Holes are those small connected components in either background or foreground. + + Note that it relies on the "cc_torch" package to find connected components fast. You can + install it via the following command (`TORCH_CUDA_ARCH_LIST=8.0` is for A100 GPUs): + ``` + pip uninstall -y cc_torch; TORCH_CUDA_ARCH_LIST=8.0 9.0 pip install git+https://github.com/ronghanghu/cc_torch + ``` + Otherwise, it will fallback to a slightly slower triton implementation, or skimage if the tensor is on cpu + """ + + if max_area <= 0: + return mask # nothing to fill in this case + + if fill_holes: + # We remove small connected components in background by changing them to foreground + # with a small positive mask score (0.1). + mask_bg = mask <= 0 + bg_area_thresh = max_area + _, areas_bg = _get_connected_components_with_padding(mask_bg) + small_components_bg = mask_bg & (areas_bg <= bg_area_thresh) + mask = torch.where(small_components_bg, 0.1, mask) + + if remove_sprinkles: + # We remove small connected components in foreground by changing them to background + # with a small negative mask score (-0.1). Here we only remove connected components + # whose areas are under both `max_area` and half of the entire mask's area. This + # removes sprinkles while avoids filtering out tiny objects that we want to track. + mask_fg = mask > 0 + fg_area_thresh = torch.sum(mask_fg, dim=(2, 3), keepdim=True, dtype=torch.int32) + fg_area_thresh.floor_divide_(2).clamp_(max=max_area) + _, areas_fg = _get_connected_components_with_padding(mask_fg) + small_components_fg = mask_fg & (areas_fg <= fg_area_thresh) + mask = torch.where(small_components_fg, -0.1, mask) + return mask + + +def _get_connected_components_with_padding(mask): + """Get connected components from masks (possibly padding them to an even size).""" + from sam3.perflib.connected_components import connected_components + + mask = mask.to(torch.uint8) + _, _, H, W = mask.shape + # make sure both height and width are even (to be compatible with cc_torch) + pad_h = H % 2 + pad_w = W % 2 + if pad_h == 0 and pad_w == 0: + labels, counts = connected_components(mask) + else: + # pad the mask to make its height and width even + # padding format is (padding_left,padding_right,padding_top,padding_bottom) + mask_pad = F.pad(mask, (0, pad_w, 0, pad_h), mode="constant", value=0) + labels, counts = connected_components(mask_pad) + labels = labels[:, :, :H, :W] + counts = counts[:, :, :H, :W] + + return labels, counts diff --git a/third_party/sam3/sam3/model/sam3_tracking_predictor.py b/third_party/sam3/sam3/model/sam3_tracking_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..43b068b770697abbd95e24d2ae7a66212579607b --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_tracking_predictor.py @@ -0,0 +1,1369 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging +from collections import OrderedDict + +import torch +from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase +from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores +from sam3.model.utils.sam2_utils import load_video_frames +from tqdm.auto import tqdm + + +class Sam3TrackerPredictor(Sam3TrackerBase): + """ + The demo class that extends the `Sam3TrackerBase` to handle user interactions + and manage inference states, with support for multi-object tracking. + """ + + def __init__( + self, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + # if fill_hole_area > 0, we fill small holes in the final masks up to this area (after resizing them to the original video resolution) + fill_hole_area=0, + # if always_start_from_first_ann_frame is True, we always start tracking from the frame where we receive the first annotation (clicks or mask) + # and ignore the `start_frame_idx` passed to `propagate_in_video` + always_start_from_first_ann_frame=False, + # the maximum number of points to be used in the prompt encoder, which reduce the domain gap between training (that only has 8 points) + # - if it's set to a positive integer, we only take the `max_point_num_in_prompt_enc//2` points and + # the last `(max_point_num_in_prompt_enc - max_point_num_in_prompt_enc//2)` points in the prompt encoder + # - if it's set to 0 or negative, this option is turned off and we use all points in the prompt encoder + max_point_num_in_prompt_enc=16, + non_overlap_masks_for_output=True, + # checkpoint_file=None, + **kwargs, + ): + super().__init__(**kwargs) + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + self.fill_hole_area = fill_hole_area + self.always_start_from_first_ann_frame = always_start_from_first_ann_frame + self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc + self.non_overlap_masks_for_output = non_overlap_masks_for_output + + self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16) + self.bf16_context.__enter__() # keep using for the entire model process + + self.iter_use_prev_mask_pred = True + self.add_all_frames_to_correct_as_cond = True + + @torch.inference_mode() + def init_state( + self, + video_height=None, + video_width=None, + num_frames=None, + video_path=None, + cached_features=None, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize a inference state.""" + inference_state = {} + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + inference_state["device"] = self.device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = torch.device("cuda") + + if video_path is not None: + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=inference_state["storage_device"], + ) + inference_state["images"] = images + inference_state["num_frames"] = len(images) + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + else: + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["num_frames"] = num_frames + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = ( + {} if cached_features is None else cached_features + ) + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # The index of the frame that received the first annotation + inference_state["first_ann_frame_idx"] = None + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + self.clear_all_points_in_video(inference_state) + return inference_state + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + rel_coordinates=True, + use_prev_mem_frame=False, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + if rel_coordinates: + # convert the points from relative coordinates to absolute coordinates + if points is not None: + points = points * self.image_size + if box is not None: + box = box * self.image_size + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Limit to a maximum number of input points to the prompt encoder (to reduce domain gap) + num_points = point_inputs["point_coords"].size(1) + if num_points > self.max_point_num_in_prompt_enc > 0: + num_first = self.max_point_num_in_prompt_enc // 2 + num_last = self.max_point_num_in_prompt_enc - num_first + point_inputs["point_coords"] = torch.cat( + [ + point_inputs["point_coords"][:, :num_first], + point_inputs["point_coords"][:, -num_last:], + ], + dim=1, + ) + point_inputs["point_labels"] = torch.cat( + [ + point_inputs["point_labels"][:, :num_first], + point_inputs["point_labels"][:, -num_last:], + ], + dim=1, + ) + logging.warning( + f"Too many points ({num_points}) are provided on frame {frame_idx}. Only " + f"the first {num_first} points and the last {num_last} points will be used." + ) + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder when `self.iter_use_prev_mask_pred=True`. + prev_sam_mask_logits = None + if self.iter_use_prev_mask_pred: + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + use_prev_mem_frame=use_prev_mem_frame, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + low_res_masks = None # not needed by the demo + return frame_idx, obj_ids, low_res_masks, video_res_masks + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + add_mask_to_memory=False, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's input mask size + if mask_H != self.input_mask_size or mask_W != self.input_mask_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.input_mask_size, self.input_mask_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + mask_inputs = mask_inputs_orig + + # also get the mask at the original video resolution (for outputting) + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + if mask_H != video_H or mask_W != video_W: + mask_inputs_video_res = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(video_H, video_W), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for potential downsampling + ) + else: + mask_inputs_video_res = mask_inputs_orig + # convert mask_inputs_video_res to binary (threshold at 0.5 as it is in range 0~1) + mask_inputs_video_res = mask_inputs_video_res > 0.5 + + mask_inputs_per_frame[frame_idx] = mask_inputs_video_res + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # We directly use the input mask at video resolution as the output mask for a better + # video editing experience (so that the masks don't change after each brushing). + # Here NO_OBJ_SCORE is a large negative value to represent the background and + # similarly -NO_OBJ_SCORE is a large positive value to represent the foreground. + current_out["pred_masks"] = None + current_out["pred_masks_video_res"] = torch.where( + mask_inputs_video_res, -NO_OBJ_SCORE, NO_OBJ_SCORE + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + # Remove the overlapping proportion of other objects' input masks on this frame + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for obj_idx2, obj_temp_output_dict2 in temp_output_dict_per_obj.items(): + if obj_idx2 == obj_idx: + continue + current_out2 = obj_temp_output_dict2[storage_key].get(frame_idx, None) + if current_out2 is not None and "pred_masks_video_res" in current_out2: + current_out2["pred_masks_video_res"] = torch.where( + mask_inputs_video_res, + NO_OBJ_SCORE, + current_out2["pred_masks_video_res"], + ) + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + low_res_masks = None # not needed by the demo + return frame_idx, obj_ids, low_res_masks, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_output: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + video_res_masks = fill_holes_in_mask_scores( + video_res_masks, self.fill_hole_area + ) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.low_res_mask_size + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=inference_state["device"], + ), + } + if self.use_memory_selection: + consolidated_out["iou_score"] = torch.full( + size=(batch_size, 1), + fill_value=0.0, + dtype=torch.float32, + device=inference_state["device"], + ) + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + # (use "pred_masks_video_res" if it's available) + obj_mask = out.get("pred_masks_video_res", out["pred_masks"]) + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + is_downsampling = "pred_masks_video_res" in out + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + antialias=is_downsampling, # use antialias for downsampling + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[ + "object_score_logits" + ] + if self.use_memory_selection: + consolidated_out["iou_score"][obj_idx : obj_idx + 1] = out["iou_score"] + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + object_score_logits=consolidated_out["object_score_logits"], + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + image, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + image=image, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={ + "cond_frame_outputs": {}, + "non_cond_frame_outputs": {}, + }, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state, run_mem_encoder=True): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outptus + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temprary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=run_mem_encoder, + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct demo workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + # Record the first interacted frame index (for tracking start) + if inference_state["first_ann_frame_idx"] is None: + inference_state["first_ann_frame_idx"] = min( + input_frames_inds, default=None + ) + # In case `first_ann_frame_idx` is not in the conditioning frames (e.g. because + # we cleared the input points on that frame), pick the first conditioning frame + if ( + inference_state["first_ann_frame_idx"] + not in output_dict["cond_frame_outputs"] + ): + inference_state["first_ann_frame_idx"] = min( + output_dict["cond_frame_outputs"], default=None + ) + + def _get_processing_order( + self, inference_state, start_frame_idx, max_frame_num_to_track, reverse + ): + num_frames = inference_state["num_frames"] + # set start index, end index, and processing order + if self.always_start_from_first_ann_frame: + # in this case, we always start tracking from the frame where we receive + # the initial annotation and ignore the provided start_frame_idx + start_frame_idx = inference_state["first_ann_frame_idx"] + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(inference_state["output_dict"]["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + # this is the edge case where we start from frame 0 and track in reverse order; + # in this case, we track a single frame (frame 0) + processing_order = [0] + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + return processing_order + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse, + tqdm_disable=False, + obj_ids=None, + run_mem_encoder=True, + propagate_preflight=False, + ): + """Propagate the input points across frames to track in the entire video.""" + if propagate_preflight: + self.propagate_in_video_preflight(inference_state) + # NOTE: This is a copy from the parent class, except that we return object scores as well. + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + if obj_ids is not None: + raise NotImplementedError( + "Per-object tracking yet for batched inference if not implemented." + ) + obj_ids = inference_state["obj_ids"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + processing_order = self._get_processing_order( + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse, + ) + + for frame_idx in tqdm( + processing_order, desc="propagate in video", disable=tqdm_disable + ): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + obj_scores = current_out["object_score_logits"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + obj_scores = current_out["object_score_logits"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=run_mem_encoder, + ) + obj_scores = current_out["object_score_logits"] + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + low_res_masks, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + "object_score_logits": current_out["object_score_logits"][obj_slice], + } + if self.use_memory_selection: + obj_out["iou_score"] = current_out["iou_score"][obj_slice] + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def clear_all_points_in_frame( + self, inference_state, frame_idx, obj_id, need_output=True + ): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Check and see if there are still any inputs left on this frame + batch_size = self._get_obj_num(inference_state) + frame_has_input = False + for obj_idx2 in range(batch_size): + if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + + # If this frame has no remaining inputs for any objects, we further clear its + # conditioning frame status + if not frame_has_input: + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + out = output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_already_tracked"].pop(frame_idx, None) + # Similarly, do it for the sliced output on each object. + for obj_idx2 in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] + obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if obj_out is not None: + obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out + + # If all the conditioning frames have been removed, we also clear the tracking outputs + if len(output_dict["cond_frame_outputs"]) == 0: + self._reset_tracking_results(inference_state) + + if not need_output: + return + # Finally, output updated masks per object (after removing the inputs above) + obj_ids = inference_state["obj_ids"] + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + low_res_masks = None # not needed by the demo + return frame_idx, obj_ids, low_res_masks, video_res_masks + + @torch.inference_mode() + def clear_all_points_in_video(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + inference_state["first_ann_frame_idx"] = None + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + if self.backbone is None: + raise RuntimeError( + f"Image features for frame {frame_idx} are not cached. " + "Please run inference on this frame first." + ) + else: + # Cache miss -- we will run inference on a single image + image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + if "tracker_backbone_out" in backbone_out: + backbone_out = backbone_out["tracker_backbone_out"] # get backbone output + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + feat = feat.expand(batch_size, -1, -1, -1) + expanded_backbone_out["backbone_fpn"][i] = feat + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + use_prev_mem_frame=True, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + image, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + image=image, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + use_prev_mem_frame=use_prev_mem_frame, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + object_score_logits = current_out["object_score_logits"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + "object_score_logits": object_score_logits, + } + if self.use_memory_selection: + compact_current_out["iou_score"] = current_out["iou_score"] + compact_current_out["eff_iou_score"] = current_out["eff_iou_score"] + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, + inference_state, + frame_idx, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + image, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + image=image, + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + @torch.inference_mode() + def remove_object(self, inference_state, obj_id, strict=False, need_output=True): + """ + Remove an object id from the tracking state. If strict is True, we check whether + the object id actually exists and raise an error if it doesn't exist. + """ + old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None) + updated_frames = [] + # Check whether this object_id to remove actually exists and possibly raise an error. + if old_obj_idx_to_rm is None: + if not strict: + return inference_state["obj_ids"], updated_frames + raise RuntimeError( + f"Cannot remove object id {obj_id} as it doesn't exist. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + # If this is the only remaining object id, we simply reset the state. + if len(inference_state["obj_id_to_idx"]) == 1: + self.clear_all_points_in_video(inference_state) + return inference_state["obj_ids"], updated_frames + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. + # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + obj_input_frames_inds = set() + obj_input_frames_inds.update( + inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] + ) + obj_input_frames_inds.update( + inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] + ) + for frame_idx in obj_input_frames_inds: + self.clear_all_points_in_frame( + inference_state, frame_idx, obj_id, need_output=False + ) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + # (note that "consolidated_frame_inds" doesn't need to be updated in this step as + # it's already handled in Step 0) + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + + # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices. + def _slice_state(output_dict, storage_key): + for frame_idx, out in output_dict[storage_key].items(): + out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds] + out["maskmem_pos_enc"] = [ + x[remain_old_obj_inds] for x in out["maskmem_pos_enc"] + ] + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) + out["pred_masks"] = out["pred_masks"][remain_old_obj_inds] + out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds] + out["object_score_logits"] = out["object_score_logits"][ + remain_old_obj_inds + ] + if self.use_memory_selection: + out["iou_score"] = out["iou_score"][remain_old_obj_inds] + out["eff_iou_score"] = self.cal_mem_score( + out["object_score_logits"], out["iou_score"] + ) # recalculate the memory frame score + # also update the per-object slices + self._add_output_per_object( + inference_state, frame_idx, out, storage_key + ) + + _slice_state(inference_state["output_dict"], "cond_frame_outputs") + _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") + + # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # could show an updated mask for objects previously occluded by the object being removed + if need_output: + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for frame_idx in obj_input_frames_inds: + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + updated_frames.append((frame_idx, video_res_masks)) + + return inference_state["obj_ids"], updated_frames + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + batch_size = self._get_obj_num(inference_state) + for obj_idx in range(batch_size): + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + + def _suppress_shrinked_masks( + self, pred_masks, new_pred_masks, shrink_threshold=0.3 + ): + area_before = (pred_masks > 0).sum(dim=(-1, -2)) + area_after = (new_pred_masks > 0).sum(dim=(-1, -2)) + area_before = torch.clamp(area_before, min=1.0) + area_ratio = area_after / area_before + keep = area_ratio >= shrink_threshold + keep_mask = keep[..., None, None].expand_as(pred_masks) + pred_masks_after = torch.where( + keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0) + ) + return pred_masks_after + + def _suppress_object_pw_area_shrinkage(self, pred_masks): + """ + This function suppresses masks that shrink in area after applying pixelwise non-overlapping constriants. + Note that the final output can still be overlapping. + """ + # Apply pixel-wise non-overlapping constraint based on mask scores + pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints( + pred_masks + ) + # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping constraints + # NOTE: The output of this function can be a no op if none of the masks shrinked by a large factor. + pred_masks = self._suppress_shrinked_masks( + pred_masks, pixel_level_non_overlapping_masks + ) + return pred_masks + + def _apply_object_wise_non_overlapping_constraints( + self, pred_masks, obj_scores, background_value=-10.0 + ): + """ + Applies non-overlapping constraints object wise (i.e. only one object can claim the overlapping region) + """ + # Replace pixel scores with object scores + pred_masks_single_score = torch.where( + pred_masks > 0, obj_scores[..., None, None], background_value + ) + # Apply pixel-wise non-overlapping constraint based on mask scores + pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints( + pred_masks_single_score + ) + # Replace object scores with pixel scores. Note, that now only one object can claim the overlapping region + pred_masks = torch.where( + pixel_level_non_overlapping_masks > 0, + pred_masks, + torch.clamp(pred_masks, max=background_value), + ) + return pred_masks diff --git a/third_party/sam3/sam3/model/sam3_video_base.py b/third_party/sam3/sam3/model/sam3_video_base.py new file mode 100644 index 0000000000000000000000000000000000000000..a77bdd3ea8d1cce6ddabb213e17d9d60da195e6f --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_video_base.py @@ -0,0 +1,1997 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import datetime +import logging +import math +import os +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import numpy.typing as npt +import torch +import torch.distributed as dist +import torch.nn.functional as F +from sam3 import perflib +from sam3.logger import get_logger +from sam3.model.box_ops import fast_diag_box_iou +from sam3.model.data_misc import BatchedDatapoint +from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box +from sam3.perflib.masks_ops import mask_iou +from sam3.train.masks_ops import mask_iom, rle_encode +from torch import nn, Tensor + +logger = get_logger(__name__) + + +class MaskletConfirmationStatus(Enum): + UNCONFIRMED = 1 # newly added masklet, not confirmed by any detection yet + CONFIRMED = 2 # confirmed by at least one detection + + +@dataclass +class RealizedAssociateDetTrkresult: + new_det_fa_inds: np.array + unmatched_trk_obj_ids: np.array + det_to_matched_trk_obj_ids: Dict[int, np.array] + trk_id_to_max_iou_high_conf_det: Dict[int, int] + empty_trk_obj_ids: np.array + new_det_obj_ids: Optional[np.array] = None + new_det_gpu_ids: Optional[np.array] = None + num_obj_dropped_due_to_limit: Optional[int] = None + + def get_new_det_gpu_ids( + self, tracker_metadata_prev, is_image_only, det_scores, tracking_obj + ): + with torch.profiler.record_function("get_new_det_gpu_ids"): + if self.new_det_obj_ids is None: + det_scores_np: np.ndarray = det_scores.cpu().numpy() + prev_obj_num = np.sum(tracker_metadata_prev["num_obj_per_gpu"]) + new_det_num = len(self.new_det_fa_inds) + num_obj_dropped_due_to_limit = 0 + if ( + not is_image_only + and prev_obj_num + new_det_num > tracking_obj.max_num_objects + ): + logger.warning( + f"hitting {tracking_obj.max_num_objects=} with {new_det_num=} and {prev_obj_num=}" + ) + new_det_num_to_keep = tracking_obj.max_num_objects - prev_obj_num + num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep + self.new_det_fa_inds = tracking_obj._drop_new_det_with_obj_limit( + self.new_det_fa_inds, det_scores_np, new_det_num_to_keep + ) + assert len(self.new_det_fa_inds) == new_det_num_to_keep + new_det_num = len(self.new_det_fa_inds) + new_det_start_obj_id = tracker_metadata_prev["max_obj_id"] + 1 + new_det_obj_ids = new_det_start_obj_id + np.arange(new_det_num) + if tracking_obj.is_multiplex: + prev_workload_per_gpu = tracker_metadata_prev["num_buc_per_gpu"] + else: + prev_workload_per_gpu = tracker_metadata_prev["num_obj_per_gpu"] + new_det_gpu_ids = tracking_obj._assign_new_det_to_gpus( + new_det_num=new_det_num, + prev_workload_per_gpu=prev_workload_per_gpu, + ) + self.new_det_obj_ids = new_det_obj_ids + self.new_det_gpu_ids = new_det_gpu_ids + self.num_obj_dropped_due_to_limit = num_obj_dropped_due_to_limit + return ( + self.new_det_obj_ids, + self.new_det_gpu_ids, + self.num_obj_dropped_due_to_limit, + ) + + +def realize_adt_result(adt_lazy_result, tracker_metadata_prev, det_mask_preds): + if isinstance(adt_lazy_result, LazyAssociateDetTrkResult): + adt_lazy_result._convert_to_numpy() + return adt_lazy_result._create_cpu_metadata( + tracker_metadata_prev["obj_ids_all_gpu"], det_mask_preds + ) + return adt_lazy_result + + +class LazyAssociateDetTrkResult: + def __init__( + self, + trk_is_unmatched: Tensor, + trk_is_nonempty: Tensor, + is_new_det: Tensor, + det_to_max_iou_trk_idx: Tensor, + det_is_high_conf: Tensor, + det_is_high_iou: Tensor, + det_keep: Tensor, + im_mask: Tensor, + ): + self.trk_is_unmatched = trk_is_unmatched + self.trk_is_nonempty = trk_is_nonempty + self.is_new_det = is_new_det + self.det_to_max_iou_trk_idx = det_to_max_iou_trk_idx + self.det_is_high_conf = det_is_high_conf + self.det_is_high_iou = det_is_high_iou + self.det_keep = det_keep + self.im_mask = im_mask + + def _convert_to_numpy(self): + with torch.profiler.record_function("Convert to numpy"): + self.trk_is_unmatched = self.trk_is_unmatched.cpu().numpy() + self.trk_is_nonempty = self.trk_is_nonempty.cpu().numpy() + self.is_new_det = self.is_new_det.cpu().numpy() + self.det_to_max_iou_trk_idx = self.det_to_max_iou_trk_idx.cpu().numpy() + self.det_is_high_conf = self.det_is_high_conf.cpu().numpy() + self.det_is_high_iou = self.det_is_high_iou.cpu().numpy() + self.det_keep = self.det_keep.cpu().numpy().tolist() + self.im_mask = self.im_mask.cpu().numpy() + + def _create_cpu_metadata(self, trk_obj_ids, det_masks): + with torch.profiler.record_function("_create_cpu_metadata"): + unmatched_trk_obj_ids = trk_obj_ids[self.trk_is_unmatched] + empty_trk_obj_ids = trk_obj_ids[~self.trk_is_nonempty] + new_det_fa_inds = np.nonzero(self.is_new_det)[0] + det_is_high_conf_and_iou = set( + np.nonzero(self.det_is_high_conf & self.det_is_high_iou)[0] + ) + det_to_matched_trk_obj_ids = {} + trk_id_to_max_iou_high_conf_det = {} + for d in range(det_masks.size(0)): + if self.det_keep[d]: + det_to_matched_trk_obj_ids[d] = trk_obj_ids[self.im_mask[d, :]] + if d in det_is_high_conf_and_iou: + trk_obj_id = trk_obj_ids[self.det_to_max_iou_trk_idx[d]].item() + trk_id_to_max_iou_high_conf_det[trk_obj_id] = d + return RealizedAssociateDetTrkresult( + new_det_fa_inds=new_det_fa_inds, + unmatched_trk_obj_ids=unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, + trk_id_to_max_iou_high_conf_det=trk_id_to_max_iou_high_conf_det, + empty_trk_obj_ids=empty_trk_obj_ids, + ) + + +def _associate_det_trk_compilable( + det_masks, + det_scores, + det_keep, + trk_masks, + new_det_thresh, + iou_threshold_trk, + iou_threshold, + HIGH_CONF_THRESH, + use_iom_recondition, + o2o_matching_masklets_enable, + iom_thresh_recondition, + iou_thresh_recondition, +): + det_masks_binary = det_masks > 0 + det_masks_binary[~det_keep] = 0 + trk_masks_binary = trk_masks > 0 + intersection_metric = None + if use_iom_recondition: + intersection_metric = mask_iom(det_masks_binary, trk_masks_binary) # (N, M) + else: + intersection_metric = mask_iou(det_masks_binary, trk_masks_binary) # (N, M) + + assert not o2o_matching_masklets_enable, "Temporarily disabled support for o2o_matching_masklets_enable, due to optimizations." + + if o2o_matching_masklets_enable: + intersection_metric_np = intersection_metric.cpu().numpy() + from scipy.optimize import linear_sum_assignment + + cost_matrix = 1 - intersection_metric_np + row_ind, col_ind = linear_sum_assignment(cost_matrix) + trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool) + for d, t in zip(row_ind, col_ind): + if intersection_metric_np[d, t] >= iou_threshold_trk: + trk_is_matched[t] = True + trk_is_matched = torch.from_numpy(trk_is_matched) + trk_is_matched = trk_is_matched.to(device=intersection_metric.device) + else: + trk_is_matched = (intersection_metric >= iou_threshold_trk).any(dim=0) + # Non-empty tracks not matched by Hungarian assignment above threshold are unmatched + trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)) + trk_is_unmatched = torch.logical_and(trk_is_nonempty, ~trk_is_matched) + + # For detections: allow many tracks to match to the same detection (many-to-one) + # So, a detection is 'new' if it does not match any track above threshold + is_new_det = torch.logical_and( + torch.logical_and((det_scores >= new_det_thresh), (det_keep)), + torch.logical_not(torch.any(intersection_metric >= iou_threshold, dim=1)), + ) + + intersection_thresh_recond = ( + iom_thresh_recondition if use_iom_recondition else iou_thresh_recondition + ) + # if a detection matches to many tracks with high IoU or vice versa, we do not consider it for reconditioning as it might be ambiguous + det_match_to_many_trk = (intersection_metric >= intersection_thresh_recond).sum( + dim=1 + ) > 1 + trk_match_to_many_det = (intersection_metric >= intersection_thresh_recond).sum( + dim=0 + ) > 1 + # # zero out these ambiguous matches + # intersection_metric[:, trk_match_to_many_det] = ( + # 0.0 # only consider unique matches + # ) + + # intersection_metric[det_match_to_many_trk, :] = ( + # 0.0 # only consider unique matches + # ) + + intersection_metric = torch.where( + trk_match_to_many_det.unsqueeze(0), + torch.zeros_like(intersection_metric), + intersection_metric, + ) + + intersection_metric = torch.where( + det_match_to_many_trk.unsqueeze(1), + torch.zeros_like(intersection_metric), + intersection_metric, + ) + + det_to_max_iou_trk_idx = torch.argmax(intersection_metric, dim=1) + det_is_high_conf = ((det_scores >= HIGH_CONF_THRESH) & det_keep) & ~is_new_det + det_is_high_iou = ( + torch.amax(intersection_metric, dim=1) >= intersection_thresh_recond + ) + im_mask = intersection_metric >= iou_threshold + + return ( + trk_is_unmatched, + trk_is_nonempty, + is_new_det, + det_to_max_iou_trk_idx, + det_is_high_conf, + det_is_high_iou, + det_keep, + im_mask, + ) + + +class Sam3VideoBase(nn.Module): + def __init__( + self, + detector: nn.Module, + tracker: nn.Module, + # prob threshold for detection outputs -- only keep detections above this threshold + # enters NMS and det-to-track matching + score_threshold_detection=0.5, + # IoU threshold for detection NMS + det_nms_thresh=0.0, + # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it + # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1 + assoc_iou_thresh=0.5, + # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched" + # by any detections -- it is often a stricter threshold like 0.5 + trk_assoc_iou_thresh=0.5, + # prob threshold for a detection to be added as a new object + new_det_thresh=0.0, + # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and + # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh` + # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh` + hotstart_delay=0, + hotstart_unmatch_thresh=3, + hotstart_dup_thresh=3, + # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period. + suppress_unmatched_only_within_hotstart=True, + init_trk_keep_alive=0, + max_trk_keep_alive=8, + min_trk_keep_alive=-4, + # Threshold for suppressing overlapping objects based on recent occlusion + suppress_overlapping_based_on_recent_occlusion_threshold=0.0, + decrease_trk_keep_alive_for_empty_masklets=False, + o2o_matching_masklets_enable=False, # Enable hungarian matching to match existing masklets + suppress_det_close_to_boundary=False, + fill_hole_area=16, + # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1) + max_num_objects=-1, + recondition_every_nth_frame=-1, + # masket confirmation status (to suppress unconfirmed masklets) + masklet_confirmation_enable=False, + # a masklet is confirmed after being consecutively detected and matched for + # `masklet_confirmation_consecutive_det_thresh` + masklet_confirmation_consecutive_det_thresh=3, + # bbox heuristic parameters + reconstruction_bbox_iou_thresh=0.0, + reconstruction_bbox_det_score=0.0, + ): + super().__init__() + self.detector = detector + self.tracker = tracker + self.score_threshold_detection = score_threshold_detection + self.det_nms_thresh = det_nms_thresh + self.assoc_iou_thresh = assoc_iou_thresh + self.trk_assoc_iou_thresh = trk_assoc_iou_thresh + self.new_det_thresh = new_det_thresh + + # hotstart parameters + if hotstart_delay > 0: + assert hotstart_unmatch_thresh <= hotstart_delay + assert hotstart_dup_thresh <= hotstart_delay + self.hotstart_delay = hotstart_delay + self.hotstart_unmatch_thresh = hotstart_unmatch_thresh + self.hotstart_dup_thresh = hotstart_dup_thresh + self.suppress_unmatched_only_within_hotstart = ( + suppress_unmatched_only_within_hotstart + ) + self.init_trk_keep_alive = init_trk_keep_alive + self.max_trk_keep_alive = max_trk_keep_alive + self.min_trk_keep_alive = min_trk_keep_alive + self.suppress_overlapping_based_on_recent_occlusion_threshold = ( + suppress_overlapping_based_on_recent_occlusion_threshold + ) + self.suppress_det_close_to_boundary = suppress_det_close_to_boundary + self.decrease_trk_keep_alive_for_empty_masklets = ( + decrease_trk_keep_alive_for_empty_masklets + ) + self.o2o_matching_masklets_enable = o2o_matching_masklets_enable + self.fill_hole_area = fill_hole_area + self.eval() + self.rank = int(os.getenv("RANK", "0")) + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self._dist_pg_cpu = None # CPU process group (lazy-initialized on first use) + + # the maximum object number + if max_num_objects > 0: + num_obj_for_compile = math.ceil(max_num_objects / self.world_size) + else: + max_num_objects = 10000 # no limit + num_obj_for_compile = 16 + logger.info(f"setting {max_num_objects=} and {num_obj_for_compile=}") + self.max_num_objects = max_num_objects + self.num_obj_for_compile = num_obj_for_compile + self.recondition_every_nth_frame = recondition_every_nth_frame + self.masklet_confirmation_enable = masklet_confirmation_enable + self.masklet_confirmation_consecutive_det_thresh = ( + masklet_confirmation_consecutive_det_thresh + ) + self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh + self.reconstruction_bbox_det_score = reconstruction_bbox_det_score + + @property + def device(self): + self._device = getattr(self, "_device", None) or next(self.parameters()).device + return self._device + + def _init_dist_pg_cpu(self): + # a short 3-min timeout to quickly detect any synchronization failures + timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180")) + timeout = datetime.timedelta(seconds=timeout_sec) + self._dist_pg_cpu = dist.new_group(backend="gloo", timeout=timeout) + + def broadcast_python_obj_cpu(self, python_obj_list, src): + if self._dist_pg_cpu is None: + self._init_dist_pg_cpu() + dist.broadcast_object_list(python_obj_list, src=src, group=self._dist_pg_cpu) + + def _det_track_one_frame( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + input_batch: BatchedDatapoint, + geometric_prompt: Any, + tracker_states_local: List[Any], + tracker_metadata_prev: Dict[str, Any], + feature_cache: Dict, + orig_vid_height: int, + orig_vid_width: int, + is_image_only: bool = False, + allow_new_detections: bool = True, + ): + """ + This function handles one-step inference for the DenseTracking model in an SPMD manner. + At a high-level, all GPUs execute the same function calls as if it's done on a single GPU, + while under the hood, some function calls involve distributed computation based on sharded + SAM2 states. + + - `input_batch` contains image and other inputs on the entire video; it should be identical across GPUs + - `tracker_states_local` holds the local masklet information in this GPU shard + - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is hold on which GPUs + it contains both global and local masklet information + """ + + # Step 1: run backbone and detector in a distributed manner -- this is done via Sam3ImageOnVideoMultiGPU, + # a MultiGPU model (assigned to `self.detector`) that shards frames in a round-robin manner. + # It returns a "det_out" dict for `frame_idx` and fills SAM2 backbone features for `frame_idx` + # into `feature_cache`. Despite its distributed inference under the hood, the results would be + # the same as if it is running backbone and detector for every frame on a single GPU. + det_out = self.run_backbone_and_detection( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + input_batch=input_batch, + geometric_prompt=geometric_prompt, + feature_cache=feature_cache, + allow_new_detections=allow_new_detections, + ) + + # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks. + # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions + # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only + # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks; + # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics. + if tracker_metadata_prev == {}: + # initialize masklet metadata if it's uninitialized (empty dict) + tracker_metadata_prev.update(self._initialize_metadata()) + tracker_low_res_masks_global, tracker_obj_scores_global = ( + self.run_tracker_propagation( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + tracker_states_local=tracker_states_local, + tracker_metadata_prev=tracker_metadata_prev, + ) + ) + + # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans + # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc). + # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints. + # **This step should involve all the heuristics needed for any updates.** Most of the update + # planning will be done on the master rank (GPU 0) and the resulting plan `tracker_update_plan` is + # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the + # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`). + tracker_update_plan, tracker_metadata_new = ( + self.run_tracker_update_planning_phase( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + det_out=det_out, + tracker_low_res_masks_global=tracker_low_res_masks_global, + tracker_obj_scores_global=tracker_obj_scores_global, + tracker_metadata_prev=tracker_metadata_prev, + tracker_states_local=tracker_states_local, + is_image_only=is_image_only, + ) + ) + + # Get reconditioning info from the update plan + reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set()) + det_to_matched_trk_obj_ids = tracker_update_plan.get( + "det_to_matched_trk_obj_ids", {} + ) + + # Step 4: based on `tracker_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states + tracker_states_local_new = self.run_tracker_update_execution_phase( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + det_out=det_out, + tracker_states_local=tracker_states_local, + tracker_update_plan=tracker_update_plan, + orig_vid_height=orig_vid_height, + orig_vid_width=orig_vid_width, + feature_cache=feature_cache, + ) + + # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since + # only GPU 0 will send outputs to the server). + if self.rank == 0: + obj_id_to_mask = self.build_outputs( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + det_out=det_out, + tracker_low_res_masks_global=tracker_low_res_masks_global, + tracker_obj_scores_global=tracker_obj_scores_global, + tracker_metadata_prev=tracker_metadata_prev, + tracker_update_plan=tracker_update_plan, + orig_vid_height=orig_vid_height, + orig_vid_width=orig_vid_width, + reconditioned_obj_ids=reconditioned_obj_ids, + det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, + ) + obj_id_to_score = tracker_metadata_new["obj_id_to_score"] + else: + obj_id_to_mask, obj_id_to_score = {}, {} # dummy outputs on other GPUs + # a few statistics for the current frame as a part of the output + frame_stats = { + "num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]), + "num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"], + } + # add tracker scores to metadata, it should be fired for frames except the first frame + if tracker_obj_scores_global.shape[0] > 0: + # Convert tracker_obj_scores_global to sigmoid scores before updating + tracker_obj_scores_global = tracker_obj_scores_global.sigmoid().tolist() + tracker_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"] + tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][ + frame_idx + ].update(dict(zip(tracker_obj_ids, tracker_obj_scores_global))) + return ( + obj_id_to_mask, # a dict: obj_id --> output mask + obj_id_to_score, # a dict: obj_id --> output score (prob) + tracker_states_local_new, + tracker_metadata_new, + frame_stats, + tracker_obj_scores_global, # a dict: obj_id --> tracker frame-level scores + ) + + def _suppress_detections_close_to_boundary(self, boxes, margin=0.025): + """ + Suppress detections too close to image edges (for normalized boxes). + + boxes: (N, 4) in xyxy format, normalized [0,1] + margin: fraction of image + """ + x_min, y_min, x_max, y_max = boxes.unbind(-1) + x_c = (x_min + x_max) / 2 + y_c = (y_min + y_max) / 2 + keep = ( + (x_c > margin) + & (x_c < 1.0 - margin) + & (y_c > margin) + & (y_c < 1.0 - margin) + ) + + return keep + + def run_backbone_and_detection( + self, + frame_idx: int, + num_frames: int, + input_batch: BatchedDatapoint, + geometric_prompt: Any, + feature_cache: Dict, + reverse: bool, + allow_new_detections: bool, + ): + # Step 1: if text feature is not cached in `feature_cache`, compute and cache it + text_batch_key = tuple(input_batch.find_text_batch) + if "text" not in feature_cache or text_batch_key not in feature_cache["text"]: + text_outputs = self.detector.backbone.forward_text( + input_batch.find_text_batch, device=self.device + ) + # note: we only cache the text feature of the most recent prompt + feature_cache["text"] = {text_batch_key: text_outputs} + else: + text_outputs = feature_cache["text"][text_batch_key] + + # Step 2: run backbone, detector, and post-processing with NMS + if "multigpu_buffer" not in feature_cache: + # "multigpu_buffer" is a buffer cache used by `self.detector` and it needs + # to be passed to `forward_video_grounding_multigpu` for every call + feature_cache["multigpu_buffer"] = {} + + # Extract max_frame_num_to_track from feature_cache if available + tracking_bounds = feature_cache.get("tracking_bounds", {}) + max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track") + start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx") + + sam3_image_out, _ = self.detector.forward_video_grounding_multigpu( + backbone_out={ + "img_batch_all_stages": input_batch.img_batch, + **text_outputs, + }, + find_inputs=input_batch.find_inputs, + geometric_prompt=geometric_prompt, + frame_idx=frame_idx, + num_frames=num_frames, + multigpu_buffer=feature_cache["multigpu_buffer"], + track_in_reverse=reverse, + # also get the SAM2 backbone features + return_tracker_backbone_feats=True, + # run NMS as a part of distributed computation + run_nms=self.det_nms_thresh > 0.0, + nms_prob_thresh=self.score_threshold_detection, + nms_iou_thresh=self.det_nms_thresh, + # pass max_frame_num_to_track to respect tracking limits + max_frame_num_to_track=max_frame_num_to_track, + propagate_in_video_start_frame_idx=start_frame_idx, + ) + # note: detections in `sam3_image_out` has already gone through NMS + pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid() + if not allow_new_detections: + pred_probs = pred_probs - 1e8 # make sure no detections are kept + pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"] + pred_masks = sam3_image_out["pred_masks"] + # get the positive detection outputs above threshold + pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection) + det_out = { + "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]], + "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]], + "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]], + } + + # Step 3: build SAM2 backbone features and store them in `feature_cache` + backbone_cache = {} + sam_mask_decoder = self.tracker.sam_mask_decoder + tracker_backbone_fpn = [ + sam_mask_decoder.conv_s0(sam3_image_out["tracker_backbone_fpn_0"]), + sam_mask_decoder.conv_s1(sam3_image_out["tracker_backbone_fpn_1"]), + sam3_image_out["tracker_backbone_fpn_2"], # fpn_2 doesn't need conv + ] + tracker_backbone_out = { + "vision_features": tracker_backbone_fpn[-1], # top-level feature + "vision_pos_enc": sam3_image_out["tracker_backbone_pos_enc"], + "backbone_fpn": tracker_backbone_fpn, + } + backbone_cache["tracker_backbone_out"] = tracker_backbone_out + feature_cache[frame_idx] = ( + input_batch.img_batch[frame_idx], + backbone_cache, + ) + # remove from `feature_cache` old features to save GPU memory + feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None) + return det_out + + def run_tracker_propagation( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + tracker_states_local: List[Any], + tracker_metadata_prev: Dict[str, npt.NDArray], + ): + # Step 1: propagate the local SAM2 states to get the current frame's prediction + # `low_res_masks_local` of the existing masklets on this GPU + # - obj_ids_local: List[int] -- list of object IDs + # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask) + obj_ids_local, low_res_masks_local, obj_scores_local = ( + self._propogate_tracker_one_frame_local_gpu( + tracker_states_local, frame_idx=frame_idx, reverse=reverse + ) + ) + + assert np.all( + obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank] + ), "{} != {}".format( + obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank] + ) + + # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global` + # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask) + _, H_mask, W_mask = low_res_masks_local.shape + if self.world_size > 1: + # `low_res_masks_local` and `obj_scores_local` need to be contiguous and float32 + # (they could be non-contiguous due to slicing and/or bfloat16 due to autocast) + low_res_masks_local = low_res_masks_local.float().contiguous() + obj_scores_local = obj_scores_local.float().contiguous() + num_obj_this_gpu = tracker_metadata_prev["num_obj_per_gpu"][self.rank] + assert low_res_masks_local.size(0) == num_obj_this_gpu + assert obj_scores_local.size(0) == num_obj_this_gpu + low_res_masks_peers = [ + low_res_masks_local.new_empty(num_obj, H_mask, W_mask) + for num_obj in tracker_metadata_prev["num_obj_per_gpu"] + ] + obj_scores_peers = [ + obj_scores_local.new_empty(num_obj) + for num_obj in tracker_metadata_prev["num_obj_per_gpu"] + ] + dist.all_gather(low_res_masks_peers, low_res_masks_local) + dist.all_gather(obj_scores_peers, obj_scores_local) + low_res_masks_global = torch.cat(low_res_masks_peers, dim=0) + obj_scores_global = torch.cat(obj_scores_peers, dim=0) + else: + low_res_masks_global = low_res_masks_local + obj_scores_global = obj_scores_local + return low_res_masks_global, obj_scores_global + + def _recondition_masklets( + self, + frame_idx, + det_out: Dict[str, Tensor], + trk_id_to_max_iou_high_conf_det: List[int], + tracker_states_local: List[Any], + tracker_metadata: Dict[str, npt.NDArray], + tracker_obj_scores_global: Tensor, + ): + # Recondition the masklets based on the new detections + for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items(): + new_mask = det_out["mask"][det_idx : det_idx + 1] + input_mask_res = self.tracker.input_mask_size + new_mask_binary = ( + F.interpolate( + new_mask.unsqueeze(1), + size=(input_mask_res, input_mask_res), + mode="bilinear", + align_corners=False, + ).squeeze(1)[0] + > 0 + ) + HIGH_CONF_THRESH = 0.8 + reconditioned_states_idx = set() + obj_idx = np.where(tracker_metadata["obj_ids_all_gpu"] == trk_obj_id)[ + 0 + ].item() + obj_score = tracker_obj_scores_global[obj_idx] + for state_idx, inference_state in enumerate(tracker_states_local): + if ( + trk_obj_id in inference_state["obj_ids"] + # NOTE: Goal of this condition is to avoid reconditioning masks that are occluded/low qualiy. + # Unfortunately, these can get reconditioned anyway due to batching. We should consider removing these heuristics. + and obj_score > HIGH_CONF_THRESH + ): + logger.debug( + f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned." + ) + self.tracker.add_new_mask( + inference_state=inference_state, + frame_idx=frame_idx, + obj_id=trk_obj_id, + mask=new_mask_binary, + ) + reconditioned_states_idx.add(state_idx) + + for idx in reconditioned_states_idx: + self.tracker.propagate_in_video_preflight( + tracker_states_local[idx], run_mem_encoder=True + ) + return tracker_states_local + + def run_tracker_update_planning_phase( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + det_out: Dict[str, Tensor], + tracker_low_res_masks_global: Tensor, + tracker_obj_scores_global: Tensor, + tracker_metadata_prev: Dict[str, npt.NDArray], + tracker_states_local: List[Any], + is_image_only: bool = False, + ): + # initialize new metadata from previous metadata (its values will be updated later) + tracker_metadata_new = self._create_planning_metadata(tracker_metadata_prev) + + # Initialize reconditioned_obj_ids early to avoid UnboundLocalError + reconditioned_obj_ids = set() + + # Step 1: make the update plan and resolve heuristics on GPU 0 + det_mask_preds: Tensor = det_out["mask"] # low-res mask logits + det_scores_np: npt.NDArray = det_out["scores"].float().cpu().numpy() + det_bbox_xyxy: Tensor = det_out["bbox"] + if self.rank == 0: + # a) match detector and tracker masks and find new objects + ( + new_det_fa_inds, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + trk_id_to_max_iou_high_conf_det, + empty_trk_obj_ids, + ) = self._associate_det_trk( + det_masks=det_mask_preds, + det_scores_np=det_scores_np, + trk_masks=tracker_low_res_masks_global, + trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"], + ) + if self.suppress_det_close_to_boundary: + keep = self._suppress_detections_close_to_boundary( + det_bbox_xyxy[new_det_fa_inds] + ) + new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()] + + # check whether we've hit the maximum number of objects we can track (and if so, drop some detections) + prev_obj_num = np.sum(tracker_metadata_prev["num_obj_per_gpu"]) + new_det_num = len(new_det_fa_inds) + num_obj_dropped_due_to_limit = 0 + if not is_image_only and prev_obj_num + new_det_num > self.max_num_objects: + logger.warning( + f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}" + ) + new_det_num_to_keep = self.max_num_objects - prev_obj_num + num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep + new_det_fa_inds = self._drop_new_det_with_obj_limit( + new_det_fa_inds, det_scores_np, new_det_num_to_keep + ) + assert len(new_det_fa_inds) == new_det_num_to_keep + new_det_num = len(new_det_fa_inds) + + # assign object IDs to new detections and decide which GPU to place them + new_det_start_obj_id = tracker_metadata_prev["max_obj_id"] + 1 + new_det_obj_ids = new_det_start_obj_id + np.arange(new_det_num) + prev_workload_per_gpu = tracker_metadata_prev["num_obj_per_gpu"] + new_det_gpu_ids = self._assign_new_det_to_gpus( + new_det_num=new_det_num, + prev_workload_per_gpu=prev_workload_per_gpu, + ) + + # b) handle hotstart heuristics to remove objects + # here `rank0_metadata` contains metadata stored on (and only accessible to) GPU 0; + # we avoid broadcasting them to other GPUs to save communication cost, assuming + # that `rank0_metadata` is not needed by other GPUs + rank0_metadata_new = deepcopy(tracker_metadata_prev["rank0_metadata"]) + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + obj_ids_newly_removed, rank0_metadata_new = self._process_hotstart( + frame_idx=frame_idx, + num_frames=num_frames, + reverse=reverse, + det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, + new_det_obj_ids=new_det_obj_ids, + empty_trk_obj_ids=empty_trk_obj_ids, + unmatched_trk_obj_ids=unmatched_trk_obj_ids, + rank0_metadata=rank0_metadata_new, + tracker_metadata=tracker_metadata_prev, + ) + else: + # if warm-up is not complete, we don't remove any objects + obj_ids_newly_removed = set() + tracker_metadata_new["rank0_metadata"] = rank0_metadata_new + + # Step 2: broadcast the update plan to other GPUs + NUM_BROADCAST_ITEMS = 9 + if self.rank == 0 and self.world_size > 1: + # `num_obj_per_gpu_on_rank0` is used for metadata consistency check on other GPUs + # (it's a small array with length==self.world_size, so broadcasting it is cheap) + num_obj_per_gpu_on_rank0 = tracker_metadata_prev["num_obj_per_gpu"] + update_plan = [ + new_det_fa_inds, + new_det_obj_ids, + new_det_gpu_ids, + num_obj_per_gpu_on_rank0, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + obj_ids_newly_removed, + num_obj_dropped_due_to_limit, + trk_id_to_max_iou_high_conf_det, + ] + assert ( + len(update_plan) == NUM_BROADCAST_ITEMS + ), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}" + self.broadcast_python_obj_cpu(update_plan, src=0) + elif self.rank > 0 and self.world_size > 1: + update_plan = [ + None + ] * NUM_BROADCAST_ITEMS # other ranks receive the plan from rank 0 + self.broadcast_python_obj_cpu(update_plan, src=0) + ( + new_det_fa_inds, + new_det_obj_ids, + new_det_gpu_ids, + num_obj_per_gpu_on_rank0, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + obj_ids_newly_removed, + num_obj_dropped_due_to_limit, + trk_id_to_max_iou_high_conf_det, + ) = update_plan + # metadata consistency check: verify that the received `num_obj_per_gpu_on_rank0` is consistent with the local metadata + # it's critical that all GPUs agree on the previous number of objects (otherwise the inference might hang or fail silently) + if not np.all( + num_obj_per_gpu_on_rank0 == tracker_metadata_prev["num_obj_per_gpu"] + ): + raise RuntimeError( + f"{self.rank=} received {num_obj_per_gpu_on_rank0=}, which is inconsistent with local record " + f"{tracker_metadata_prev['num_obj_per_gpu']=}. There's likely a bug in update planning or execution." + ) + + # `tracker_update_plan` should be identical on all GPUs after broadcasting + tracker_update_plan = { + "new_det_fa_inds": new_det_fa_inds, # npt.NDArray + "new_det_obj_ids": new_det_obj_ids, # npt.NDArray + "new_det_gpu_ids": new_det_gpu_ids, # npt.NDArray + "unmatched_trk_obj_ids": unmatched_trk_obj_ids, # npt.NDArray + "det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids, # dict + "obj_ids_newly_removed": obj_ids_newly_removed, # set + "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit, # int + "trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det, # dict + "reconditioned_obj_ids": reconditioned_obj_ids, # set + } + + # Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding + # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results + should_recondition_iou = False + + # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections + if ( + self.reconstruction_bbox_iou_thresh > 0 + and len(trk_id_to_max_iou_high_conf_det) > 0 + ): + for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items(): + det_box = det_out["bbox"][det_idx] + det_score = det_out["scores"][det_idx] + + try: + trk_idx = list(tracker_metadata_prev["obj_ids_all_gpu"]).index( + trk_obj_id + ) + except ValueError: + continue # Skip if tracklet not found + + tracker_mask = tracker_low_res_masks_global[trk_idx] + mask_binary = tracker_mask > 0 + mask_area = mask_binary.sum().item() + + if mask_area == 0: + continue # Skip tracklets with zero mask area + + # Get bounding box from SAM2 mask and convert to normalized coordinates + tracker_box_pixels = ( + mask_to_box(mask_binary.unsqueeze(0).unsqueeze(0)) + .squeeze(0) + .squeeze(0) + ) + mask_height, mask_width = tracker_mask.shape[-2:] + tracker_box_normalized = torch.tensor( + [ + tracker_box_pixels[0] / mask_width, + tracker_box_pixels[1] / mask_height, + tracker_box_pixels[2] / mask_width, + tracker_box_pixels[3] / mask_height, + ], + device=tracker_box_pixels.device, + ) + + # Compute IoU between detection and SAM2 tracklet bounding boxes + det_box_batch = det_box.unsqueeze(0) + tracker_box_batch = tracker_box_normalized.unsqueeze(0) + iou = fast_diag_box_iou(det_box_batch, tracker_box_batch)[0] + + if ( + iou < self.reconstruction_bbox_iou_thresh + and det_score >= self.reconstruction_bbox_det_score + ): + should_recondition_iou = True + reconditioned_obj_ids.add(trk_obj_id) + + should_recondition_periodic = ( + self.recondition_every_nth_frame > 0 + and frame_idx % self.recondition_every_nth_frame == 0 + and len(trk_id_to_max_iou_high_conf_det) > 0 + ) + + # Recondition if periodic or IoU condition met + if should_recondition_periodic or should_recondition_iou: + self._recondition_masklets( + frame_idx, + det_out, + trk_id_to_max_iou_high_conf_det, + tracker_states_local, + tracker_metadata_prev, + tracker_obj_scores_global, + ) + + # Step 4: Run SAM2 memory encoder on the current frame's prediction masks + # This is done on all GPUs + batch_size = tracker_low_res_masks_global.size(0) + if batch_size > 0: + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0: + # NOTE: tracker_low_res_masks_global is updated in-place then returned + tracker_low_res_masks_global = ( + self._suppress_overlapping_based_on_recent_occlusion( + frame_idx, + tracker_low_res_masks_global, + tracker_metadata_prev, + tracker_metadata_new, + obj_ids_newly_removed, + reverse, + ) + ) + + self._tracker_update_memories( + tracker_states_local, + frame_idx, + tracker_metadata=tracker_metadata_prev, + low_res_masks=tracker_low_res_masks_global, + ) + + # Step 4: update the SAM2 metadata based on the update plan + # note: except for "rank0_metadata" (that is only available on GPU 0), + # the updated `tracker_metadata_new` should be identical on all GPUs + for rank in range(self.world_size): + new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank] + updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank] + if len(new_det_obj_ids_this_gpu) > 0: + updated_obj_ids_this_gpu = np.concatenate( + [updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu] + ) + if len(obj_ids_newly_removed) > 0: + is_removed = np.isin( + updated_obj_ids_this_gpu, list(obj_ids_newly_removed) + ) + updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed] + tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu + tracker_metadata_new["num_obj_per_gpu"][rank] = len( + updated_obj_ids_this_gpu + ) + tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate( + tracker_metadata_new["obj_ids_per_gpu"] + ) + # update object scores and the maximum object ID assigned so far + if len(new_det_obj_ids) > 0: + tracker_metadata_new["obj_id_to_score"].update( + zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]) + ) + # tracker scores are not available for new objects, use det score instead. + tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][ + frame_idx + ].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds])) + tracker_metadata_new["max_obj_id"] = max( + tracker_metadata_new["max_obj_id"], + np.max(new_det_obj_ids), + ) + # for removed objects, we set their scores to a very low value (-1e4) but still + # keep them in "obj_id_to_score" (it's easier to handle outputs this way) + for obj_id in obj_ids_newly_removed: + tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4 + tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][ + obj_id + ] = -1e4 + tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None) + # check that "rank0_metadata" is in tracker_metadata_new if and only if it's GPU 0 + assert ("rank0_metadata" in tracker_metadata_new) == (self.rank == 0) + if self.rank == 0 and self.masklet_confirmation_enable: + rank0_metadata = self.update_masklet_confirmation_status( + rank0_metadata=tracker_metadata_new["rank0_metadata"], + obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"], + obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"], + det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids, + new_det_obj_ids=new_det_obj_ids, + ) + tracker_metadata_new["rank0_metadata"] = rank0_metadata + + return tracker_update_plan, tracker_metadata_new + + def _suppress_overlapping_based_on_recent_occlusion( + self, + frame_idx: int, + tracker_low_res_masks_global: Tensor, + tracker_metadata_prev: Dict[str, Any], + tracker_metadata_new: Dict[str, Any], + obj_ids_newly_removed: Set[int], + reverse: bool = False, + ): + """ + Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object. + Args: + frame_idx (int): The current frame index. + tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame. + tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame. + tracker_metadata_new (Dict[str, Any]): The metadata for the current frame. + obj_ids_newly_removed (Set[int]): The object IDs that have been removed. + Return: + Tensor: The updated low-resolution masks with some objects suppressed. + """ + obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"] + binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0 + batch_size = tracker_low_res_masks_global.size(0) + if batch_size > 0: + assert ( + len(obj_ids_global) == batch_size + ), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}" + NEVER_OCCLUDED = -1 + ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic + last_occluded_prev = torch.cat( + [ + tracker_metadata_prev["obj_id_to_last_occluded"].get( + obj_id, + torch.full( + (1,), + fill_value=( + NEVER_OCCLUDED + if obj_id not in obj_ids_newly_removed + else ALWAYS_OCCLUDED + ), + device=binary_tracker_low_res_masks_global.device, + dtype=torch.long, + ), + ) + for obj_id in obj_ids_global + ], + dim=0, + ) + to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded( + binary_tracker_low_res_masks_global, + last_occluded_prev, + obj_ids_global, + frame_idx, + reverse, + ) + + # Update metadata with occlusion information + is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2))) + is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress + last_occluded_new = last_occluded_prev.clone() + last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx + # Slice out the last occluded frame for each object + tracker_metadata_new["obj_id_to_last_occluded"] = { + obj_id: last_occluded_new[obj_idx : obj_idx + 1] + for obj_idx, obj_id in enumerate(obj_ids_global) + } + + # Zero out suppressed masks before memory encoding + NO_OBJ_LOGIT = -10 + tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT + + return tracker_low_res_masks_global + + def run_tracker_update_execution_phase( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + det_out: Dict[str, Tensor], + tracker_states_local: List[Any], + tracker_update_plan: Dict[str, npt.NDArray], + orig_vid_height: int, + orig_vid_width: int, + feature_cache: Dict, + tracker_metadata_new=None, + ): + # initialize tracking scores with detection scores + new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"] + new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"] + new_det_gpu_ids: npt.NDArray = tracker_update_plan["new_det_gpu_ids"] + is_on_this_gpu: npt.NDArray = new_det_gpu_ids == self.rank + new_det_obj_ids_local: npt.NDArray = new_det_obj_ids[is_on_this_gpu] + new_det_fa_inds_local: npt.NDArray = new_det_fa_inds[is_on_this_gpu] + obj_ids_newly_removed: Set[int] = tracker_update_plan["obj_ids_newly_removed"] + + # Step 1: add new objects from the detector to SAM2 inference states + if len(new_det_fa_inds_local) > 0: + new_det_fa_inds_local_t = torch.from_numpy(new_det_fa_inds_local) + new_det_masks: Tensor = det_out["mask"][new_det_fa_inds_local_t] + # initialize SAM2 with new object masks + tracker_states_local = self._tracker_add_new_objects( + frame_idx=frame_idx, + num_frames=num_frames, + new_obj_ids=new_det_obj_ids_local, + new_obj_masks=new_det_masks, + tracker_states_local=tracker_states_local, + orig_vid_height=orig_vid_height, + orig_vid_width=orig_vid_width, + feature_cache=feature_cache, + ) + + # Step 2: remove from SAM2 inference states those objects removed by heuristics + if len(obj_ids_newly_removed) > 0: + self._tracker_remove_objects(tracker_states_local, obj_ids_newly_removed) + + self._post_execution_phase_hook(tracker_states_local, tracker_metadata_new) + return tracker_states_local + + def _create_planning_metadata(self, tracker_metadata_prev): + """Create the metadata dict for the planning phase from previous metadata.""" + from copy import deepcopy + + score_key = "obj_id_to_tracker_score_frame_wise" + if score_key not in tracker_metadata_prev: + score_key = "obj_id_to_sam2_score_frame_wise" + metadata = { + "obj_ids_per_gpu": deepcopy(tracker_metadata_prev["obj_ids_per_gpu"]), + "obj_ids_all_gpu": None, + "num_obj_per_gpu": deepcopy(tracker_metadata_prev["num_obj_per_gpu"]), + "obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]), + score_key: deepcopy(tracker_metadata_prev[score_key]), + "obj_id_to_last_occluded": {}, + "max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]), + } + return metadata + + def _post_execution_phase_hook(self, tracker_states_local, tracker_metadata_new): + """Hook for subclasses to add post-execution logic. Default: no-op.""" + pass + + def build_outputs( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + det_out: Dict[str, Tensor], + tracker_low_res_masks_global: Tensor, + tracker_obj_scores_global: Tensor, + tracker_metadata_prev: Dict[str, npt.NDArray], + tracker_update_plan: Dict[str, npt.NDArray], + orig_vid_height: int, + orig_vid_width: int, + reconditioned_obj_ids: set = None, + det_to_matched_trk_obj_ids: dict = None, + ): + new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"] + new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"] + obj_id_to_mask = {} # obj_id --> output mask tensor + + # Part 1: masks from previous SAM2 propagation + existing_masklet_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"] + existing_masklet_video_res_masks = F.interpolate( + tracker_low_res_masks_global.unsqueeze(1), + size=(orig_vid_height, orig_vid_width), + mode="bilinear", + align_corners=False, + ) # (num_obj, 1, H_video, W_video) + existing_masklet_binary = existing_masklet_video_res_masks > 0 + assert len(existing_masklet_obj_ids) == len(existing_masklet_binary) + for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary): + obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) + + # Part 2: masks from new detections + new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds) + new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1) + new_det_low_res_masks = fill_holes_in_mask_scores( + new_det_low_res_masks, + max_area=self.fill_hole_area, + fill_holes=True, + remove_sprinkles=True, + ) + new_masklet_video_res_masks = F.interpolate( + new_det_low_res_masks, + size=(orig_vid_height, orig_vid_width), + mode="bilinear", + align_corners=False, + ) # (num_obj, 1, H_video, W_video) + + new_masklet_binary = new_masklet_video_res_masks > 0 + assert len(new_det_obj_ids) == len(new_masklet_video_res_masks) + for obj_id, mask in zip(new_det_obj_ids, new_masklet_binary): + obj_id_to_mask[obj_id] = mask # (1, H_video, W_video) + + # Part 3: Override masks for reconditioned objects using detection masks + if reconditioned_obj_ids is not None and len(reconditioned_obj_ids) > 0: + trk_id_to_max_iou_high_conf_det = tracker_update_plan.get( + "trk_id_to_max_iou_high_conf_det", {} + ) + + for obj_id in reconditioned_obj_ids: + det_idx = trk_id_to_max_iou_high_conf_det.get(obj_id) + + if det_idx is not None: + det_mask = det_out["mask"][det_idx] + det_mask = det_mask.unsqueeze(0).unsqueeze(0) + det_mask_resized = ( + F.interpolate( + det_mask.float(), + size=(orig_vid_height, orig_vid_width), + mode="bilinear", + align_corners=False, + ) + > 0 + ) + + det_mask_final = det_mask_resized.squeeze(0) + obj_id_to_mask[obj_id] = det_mask_final + + return obj_id_to_mask + + def _get_objects_to_suppress_based_on_most_recently_occluded( + self, + binary_low_res_masks: Tensor, + last_occluded: List[int], + obj_ids: List[int], + frame_idx: int = None, + reverse: bool = False, + ): + # Suppress overlapping masks for objects that were most recently occluded + assert ( + binary_low_res_masks.dtype == torch.bool + ), f"Expected boolean tensor, got {binary_low_res_masks.dtype}" + to_suppress = torch.zeros( + binary_low_res_masks.size(0), + device=binary_low_res_masks.device, + dtype=torch.bool, + ) + if len(obj_ids) <= 1: + return to_suppress + + iou = mask_iou(binary_low_res_masks, binary_low_res_masks) # [N,N] + + # Create masks for upper triangular matrix (i < j) and IoU threshold + mask_iou_thresh = ( + iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold + ) + overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1) # [N,N] + + last_occ_expanded_i = last_occluded.unsqueeze(1) # (N, 1) + last_occ_expanded_j = last_occluded.unsqueeze(0) # (1, N) + # Suppress most recently occluded + cmp_op = torch.gt if not reverse else torch.lt + suppress_i_mask = ( + overlapping_pairs + & cmp_op( + last_occ_expanded_i, last_occ_expanded_j + ) # (last_occ_expanded_i > last_occ_expanded_j) + & ( + last_occ_expanded_j > -1 + ) # j can suppress i only if i was previously occluded + ) + suppress_j_mask = ( + overlapping_pairs + & cmp_op(last_occ_expanded_j, last_occ_expanded_i) + & ( + last_occ_expanded_i > -1 + ) # i can suppress j only if j was previously occluded + ) + # Apply suppression + to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0) + + # Log for debugging + if ( + self.rank == 0 + and logger.isEnabledFor(logging.DEBUG) + and frame_idx is not None + ): + suppress_i_mask = suppress_i_mask.cpu().numpy() + suppress_j_mask = suppress_j_mask.cpu().numpy() + last_occluded = last_occluded.cpu().numpy() + + # Find all suppression pairs without using torch.where + batch_size = suppress_i_mask.shape[0] + + # Log i-suppression cases (where i gets suppressed in favor of j) + for i in range(batch_size): + for j in range(batch_size): + if suppress_i_mask[i, j]: + logger.debug( + f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}" + ) + + # Log j-suppression cases (where j gets suppressed in favor of i) + for i in range(batch_size): + for j in range(batch_size): + if suppress_j_mask[i, j]: + logger.debug( + f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}" + ) + + return to_suppress + + def _propogate_tracker_one_frame_local_gpu( + self, + inference_states: List[Any], + frame_idx: int, + reverse: bool, + # by default, we disable memory encoding until we gather all outputs + run_mem_encoder: bool = False, + ): + """ + inference_states: List of inference states, each state corresponds to a different set of objects. + """ + obj_ids_local = [] + low_res_masks_list = [] + obj_scores_list = [] + for inference_state in inference_states: + if len(inference_state["obj_ids"]) == 0: + continue # skip propagation on empty inference states + + # propagate one frame + num_frames_propagated = 0 + for out in self.tracker.propagate_in_video( + inference_state, + start_frame_idx=frame_idx, + # end_frame_idx = start_frame_idx + max_frame_num_to_track + # (i.e. propagating 1 frame since end_frame_idx is inclusive) + max_frame_num_to_track=0, + reverse=reverse, + tqdm_disable=True, + run_mem_encoder=run_mem_encoder, + ): + out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = out + num_frames_propagated += 1 + + # only 1 frames should be propagated + assert ( + num_frames_propagated == 1 and out_frame_idx == frame_idx + ), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}" + assert isinstance(out_obj_ids, list) + obj_ids_local.extend(out_obj_ids) + low_res_masks_list.append(out_low_res_masks.squeeze(1)) + obj_scores_list.append(out_obj_scores.squeeze(1)) + + # concatenate the output masklets from all local inference states + H_mask = W_mask = self.tracker.low_res_mask_size + if len(low_res_masks_list) > 0: + low_res_masks_local = torch.cat(low_res_masks_list, dim=0) + obj_scores_local = torch.cat(obj_scores_list, dim=0) + assert low_res_masks_local.shape[1:] == (H_mask, W_mask) + + # Apply hole filling to the masks + low_res_masks_local = fill_holes_in_mask_scores( + low_res_masks_local.unsqueeze(1), + max_area=self.fill_hole_area, + fill_holes=True, + remove_sprinkles=True, + ) + low_res_masks_local = low_res_masks_local.squeeze(1) + else: + low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device) + obj_scores_local = torch.zeros(0, device=self.device) + + return obj_ids_local, low_res_masks_local, obj_scores_local + + def _associate_det_trk( + self, + det_masks: Tensor, + det_scores_np: npt.NDArray, + trk_masks: Tensor, + trk_obj_ids: npt.NDArray, + ): + """ + Match detections on the current frame with the existing masklets. + + Args: + - det_masks: (N, H, W) tensor of predicted masks + - det_scores_np: (N,) array of detection scores + - trk_masks: (M, H, W) tensor of track masks + - trk_obj_ids: (M,) array of object IDs corresponding to trk_masks + + Returns: + - new_det_fa_inds: array of new object indices. + - unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched + to any detections on this frame (for unmatched, we only count masklets with >0 area) + - det_to_matched_trk_obj_ids: dict[int, npt.NDArray]: mapping from detector's detection indices + to the list of matched tracklet object IDs + - trk_id_to_max_iou_high_conf_det: dict mapping track obj_id to the highest-IoU high-conf detection idx + - empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction + """ + iou_threshold = self.assoc_iou_thresh + iou_threshold_trk = self.trk_assoc_iou_thresh + new_det_thresh = self.new_det_thresh + + assert det_masks.is_floating_point(), "float tensor expected (do not binarize)" + assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)" + assert ( + trk_masks.size(0) == len(trk_obj_ids) + ), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}" + if trk_masks.size(0) == 0: + # all detections are new + new_det_fa_inds = np.arange(det_masks.size(0)) + unmatched_trk_obj_ids = np.array([], np.int64) + empty_trk_obj_ids = np.array([], np.int64) + det_to_matched_trk_obj_ids = {} + trk_id_to_max_iou_high_conf_det = {} + return ( + new_det_fa_inds, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + trk_id_to_max_iou_high_conf_det, + empty_trk_obj_ids, + ) + elif det_masks.size(0) == 0: + # all previous tracklets are unmatched if they have a non-zero area + new_det_fa_inds = np.array([], np.int64) + trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy() + unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty] + empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty] + det_to_matched_trk_obj_ids = {} + trk_id_to_max_iou_high_conf_det = {} + return ( + new_det_fa_inds, + unmatched_trk_obj_ids, + det_to_matched_trk_obj_ids, + trk_id_to_max_iou_high_conf_det, + empty_trk_obj_ids, + ) + + if det_masks.shape[-2:] != trk_masks.shape[-2:]: + # resize to the smaller size to save GPU memory + if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]): + trk_masks = F.interpolate( + trk_masks.unsqueeze(1), + size=det_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ).squeeze(1) + else: + # resize detections to track size + det_masks = F.interpolate( + det_masks.unsqueeze(1), + size=trk_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ).squeeze(1) + + # Convert numpy scores to tensor for the compilable function + det_scores = torch.from_numpy(det_scores_np).to(det_masks.device) + det_keep = torch.ones( + det_masks.size(0), dtype=torch.bool, device=det_masks.device + ) + + # Call the GPU-native compilable function + adt_result_tensors = _associate_det_trk_compilable( + det_masks=det_masks, + det_scores=det_scores, + det_keep=det_keep, + trk_masks=trk_masks, + new_det_thresh=new_det_thresh, + iou_threshold_trk=iou_threshold_trk, + iou_threshold=iou_threshold, + HIGH_CONF_THRESH=0.8, + use_iom_recondition=getattr(self, "use_iom_recondition", False), + o2o_matching_masklets_enable=self.o2o_matching_masklets_enable, + iom_thresh_recondition=getattr(self, "iom_thresh_recondition", 0.8), + iou_thresh_recondition=getattr(self, "iou_thresh_recondition", 0.8), + ) + + # Wrap in LazyAssociateDetTrkResult and immediately realize to numpy + # for backward compatibility with existing callers + lazy_result = LazyAssociateDetTrkResult(*adt_result_tensors) + lazy_result._convert_to_numpy() + realized = lazy_result._create_cpu_metadata(trk_obj_ids, det_masks) + + return ( + realized.new_det_fa_inds, + realized.unmatched_trk_obj_ids, + realized.det_to_matched_trk_obj_ids, + realized.trk_id_to_max_iou_high_conf_det, + realized.empty_trk_obj_ids, + ) + + def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu): + """Distribute the new objects to the GPUs with the least workload.""" + workload_per_gpu: npt.NDArray = prev_workload_per_gpu.copy() + new_det_gpu_ids = np.zeros(new_det_num, np.int64) + + # assign the objects one by one + for i in range(len(new_det_gpu_ids)): + # find the GPU with the least workload + min_gpu = np.argmin(workload_per_gpu) + new_det_gpu_ids[i] = min_gpu + workload_per_gpu[min_gpu] += 1 + return new_det_gpu_ids + + def _process_hotstart( + self, + frame_idx: int, + num_frames: int, + reverse: bool, + det_to_matched_trk_obj_ids: Dict[int, npt.NDArray], + new_det_obj_ids: npt.NDArray, + empty_trk_obj_ids: npt.NDArray, + unmatched_trk_obj_ids: npt.NDArray, + rank0_metadata: Dict[str, Any], + tracker_metadata: Dict[str, Any], + ): + """Handle hotstart heuristics to remove unmatched or duplicated objects.""" + # obj_id --> first frame index where the object was detected + obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"] + # obj_id --> [mismatched frame indices] + unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"] + trk_keep_alive = rank0_metadata["trk_keep_alive"] + # (first_appear_obj_id, obj_id) --> [overlap frame indices] + overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"] + # removed_obj_ids: object IDs that are suppressed via hot-start + removed_obj_ids = rank0_metadata["removed_obj_ids"] + suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx] + + obj_ids_newly_removed = set() # object IDs to be newly removed on this frame + hotstart_diff = ( + frame_idx - self.hotstart_delay + if not reverse + else frame_idx + self.hotstart_delay + ) + + # Step 1: log the frame index where each object ID first appears + for obj_id in new_det_obj_ids: + if obj_id not in obj_first_frame_idx: + obj_first_frame_idx[obj_id] = frame_idx + assert obj_id not in trk_keep_alive + trk_keep_alive[obj_id] = self.init_trk_keep_alive + + matched_trks = set() + # We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded + for matched_trks_per_det in det_to_matched_trk_obj_ids.values(): + matched_trks.update(matched_trks_per_det) + for obj_id in matched_trks: + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive + trk_keep_alive[obj_id] = min( + self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1 + ) + for obj_id in unmatched_trk_obj_ids: + unmatched_frame_inds[obj_id].append(frame_idx) + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive + # The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough. + trk_keep_alive[obj_id] = max( + self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 + ) + if self.decrease_trk_keep_alive_for_empty_masklets: + for obj_id in empty_trk_obj_ids: + # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive + trk_keep_alive[obj_id] = max( + self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1 + ) + + # Step 2: removed tracks that has not matched with detections for `hotstart_unmatch_thresh` frames with hotstart period + # a) add unmatched frame indices for each existing object ID + # note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask + # doesn't match any detection; it excludes those frames where SAM2 gives an empty mask + # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more + # than `self.hotstart_unmatch_thresh` frames + for obj_id, frame_indices in unmatched_frame_inds.items(): + if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: + continue # skip if the object is already removed + if len(frame_indices) >= self.hotstart_unmatch_thresh: + is_within_hotstart = ( + obj_first_frame_idx[obj_id] > hotstart_diff and not reverse + ) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse) + if is_within_hotstart: + obj_ids_newly_removed.add(obj_id) + logger.debug( + f"Removing object {obj_id} at frame {frame_idx} " + f"since it is unmatched for frames: {frame_indices}" + ) + if ( + trk_keep_alive[obj_id] <= 0 # Object has not been matched for too long + and not self.suppress_unmatched_only_within_hotstart + and obj_id not in removed_obj_ids + and obj_id not in obj_ids_newly_removed + ): + logger.debug( + f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched" + ) + suppressed_obj_ids.add(obj_id) + + # Step 3: removed tracks that overlaps with another track for `hotstart_dup_thresh` frames + # a) find overlaps tracks -- we consider overlap if they match to the same detection + for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items(): + if len(matched_trk_obj_ids) < 2: + continue # only count detections that are matched to multiple (>=2) masklets + # if there are multiple matched track ids, we need to find the one that appeared first; + # these later appearing ids may be removed since they may be considered as duplicates + first_appear_obj_id = ( + min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) + if not reverse + else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x]) + ) + for obj_id in matched_trk_obj_ids: + if obj_id != first_appear_obj_id: + key = (first_appear_obj_id, obj_id) + overlap_pair_to_frame_inds[key].append(frame_idx) + + # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another + # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames + for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items(): + if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed: + continue # skip if the object is already removed + if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or ( + obj_first_frame_idx[obj_id] < hotstart_diff and reverse + ): + if len(frame_indices) >= self.hotstart_dup_thresh: + obj_ids_newly_removed.add(obj_id) + logger.debug( + f"Removing object {obj_id} at frame {frame_idx} " + f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}" + ) + + removed_obj_ids.update(obj_ids_newly_removed) + return obj_ids_newly_removed, rank0_metadata + + def _tracker_update_memories( + self, + tracker_inference_states: List[Any], + frame_idx: int, + tracker_metadata: Dict[str, Any], + low_res_masks: Tensor, + ): + """ + Run Sam2 memory encoder, enforcing non-overlapping constraints globally. + """ + if len(tracker_inference_states) == 0: + return + # Avoid an extra interpolation step by directly interpolating to `interpol_size` + high_res_H, high_res_W = ( + self.tracker.maskmem_backbone.mask_downsampler.interpol_size + ) + # NOTE: inspect this part if we observe OOMs in the demo + high_res_masks = F.interpolate( + low_res_masks.unsqueeze(1), + size=(high_res_H, high_res_W), + mode="bilinear", + align_corners=False, + ) + # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics. + if not hasattr(self, "_warm_up_complete") or self._warm_up_complete: + high_res_masks = self.tracker._suppress_object_pw_area_shrinkage( + high_res_masks + ) + # Instead of gathering the predicted object scores, we use mask areas as a proxy. + object_score_logits = torch.where( + (high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0 + ) + + # Run the memory encoder on local slices for each GPU + start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank]) + start_idx_state = start_idx_gpu + for tracker_state in tracker_inference_states: + num_obj_per_state = len(tracker_state["obj_ids"]) + if num_obj_per_state == 0: + continue + # Get the local high-res masks and object score logits for this inference state + end_idx_state = start_idx_state + num_obj_per_state + local_high_res_masks = high_res_masks[start_idx_state:end_idx_state] + local_object_score_logits = object_score_logits[ + start_idx_state:end_idx_state + ] + local_batch_size = local_high_res_masks.size(0) + # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default + + encoded_mem = self.tracker._run_memory_encoder( + tracker_state, + frame_idx, + local_batch_size, + local_high_res_masks, + local_object_score_logits, + is_mask_from_pts=False, + ) + local_maskmem_features, local_maskmem_pos_enc = encoded_mem + # Store encoded memories in the local inference state + output_dict = tracker_state["output_dict"] + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + if frame_idx not in output_dict[storage_key]: + continue + output_dict[storage_key][frame_idx]["maskmem_features"] = ( + local_maskmem_features + ) + output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [ + pos for pos in local_maskmem_pos_enc + ] + # for batched inference state, we also need to add per-object + # memory slides to support instance interactivity + self.tracker._add_output_per_object( + inference_state=tracker_state, + frame_idx=frame_idx, + current_out=output_dict[storage_key][frame_idx], + storage_key=storage_key, + ) + start_idx_state += num_obj_per_state + + def _tracker_add_new_objects( + self, + frame_idx: int, + num_frames: int, + new_obj_ids: List[int], + new_obj_masks: Tensor, + tracker_states_local: List[Any], + orig_vid_height: int, + orig_vid_width: int, + feature_cache: Dict, + ): + """Add a new object to SAM2 inference states.""" + prev_tracker_state = ( + tracker_states_local[0] if len(tracker_states_local) > 0 else None + ) + + # prepare inference_state + # batch objects that first appear on the same frame together + # Clear inference state. Keep the cached image features if available. + new_tracker_state = self.tracker.init_state( + cached_features=feature_cache, + video_height=orig_vid_height, + video_width=orig_vid_width, + num_frames=num_frames, + ) + new_tracker_state["backbone_out"] = ( + prev_tracker_state.get("backbone_out", None) + if prev_tracker_state is not None + else None + ) + + assert len(new_obj_ids) == new_obj_masks.size(0) + assert new_obj_masks.is_floating_point() + input_mask_res = self.tracker.input_mask_size + new_obj_masks = F.interpolate( + new_obj_masks.unsqueeze(1), + size=(input_mask_res, input_mask_res), + mode="bilinear", + align_corners=False, + ).squeeze(1) + new_obj_masks = new_obj_masks > 0 + + # add object one by one + for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks): + self.tracker.add_new_mask( + inference_state=new_tracker_state, + frame_idx=frame_idx, + obj_id=new_obj_id, + mask=new_mask, + add_mask_to_memory=True, + ) + # NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects. + self.tracker.propagate_in_video_preflight( + new_tracker_state, run_mem_encoder=True + ) + tracker_states_local.append(new_tracker_state) + return tracker_states_local + + def _tracker_remove_object(self, tracker_states_local: List[Any], obj_id: int): + """ + Remove an object from SAM2 inference states. This would remove the object from + all frames in the video. + """ + tracker_states_local_before_removal = tracker_states_local.copy() + tracker_states_local.clear() + for tracker_inference_state in tracker_states_local_before_removal: + # we try to remove `obj_id` on every inference state with `strict=False` + # it will not do anything if an inference state doesn't contain `obj_id` + new_obj_ids, _ = self.tracker.remove_object( + tracker_inference_state, obj_id, strict=False, need_output=False + ) + # only keep an inference state if it's non-empty after object removal + if len(new_obj_ids) > 0: + tracker_states_local.append(tracker_inference_state) + + def _tracker_remove_objects( + self, tracker_states_local: List[Any], obj_ids: list[int] + ): + """ + Remove an object from SAM2 inference states. This would remove the object from + all frames in the video. + """ + for obj_id in obj_ids: + self._tracker_remove_object(tracker_states_local, obj_id) + + def _initialize_metadata(self): + """Initialize metadata for the masklets.""" + is_multiplex = getattr(self, "is_multiplex", False) + score_key = ( + "obj_id_to_sam2_score_frame_wise" + if is_multiplex + else "obj_id_to_tracker_score_frame_wise" + ) + tracker_metadata = { + "obj_ids_per_gpu": [np.array([], np.int64) for _ in range(self.world_size)], + "obj_ids_all_gpu": np.array([], np.int64), + "num_obj_per_gpu": np.zeros(self.world_size, np.int64), + "max_obj_id": -1, + "obj_id_to_score": {}, + score_key: defaultdict(dict), + "obj_id_to_last_occluded": {}, + } + if is_multiplex: + tracker_metadata["gpu_metadata"] = { + "N_obj": 0 + } # GPU-side metadata for sync-free hotstart + tracker_metadata["num_buc_per_gpu"] = np.zeros(self.world_size, np.int64) + + # "rank0_metadata" contains metadata that is only stored on (and accessible to) GPU 0 + # - obj_first_frame_idx: obj_id --> first frame index where the object was detected + # - unmatched_frame_inds: obj_id --> [mismatched frame indices] + # - overlap_pair_to_frame_inds: (first_appear_obj_id, obj_id) --> [overlap frame indices] + # - removed_obj_ids: object IDs that are suppressed via hot-start + # In multiplex mode, rank0_metadata is always included (all GPUs need it). + # In non-multiplex mode, only rank 0 stores it. + if is_multiplex or self.rank == 0: + rank0_metadata = { + "obj_first_frame_idx": {}, + "unmatched_frame_inds": defaultdict(list), + "trk_keep_alive": defaultdict( + int + ), # This is used only for object suppression not for removal + "overlap_pair_to_frame_inds": defaultdict(list), + "removed_obj_ids": set(), + "suppressed_obj_ids": defaultdict( + set + ), # frame_idx --> set of objects with suppressed outputs, but still continue to be tracked + } + if self.masklet_confirmation_enable: + # all the following are npt.NDArray with the same shape as `obj_ids_all_gpu` + rank0_metadata["masklet_confirmation"] = { + # "status" is the confirmation status of each masklet (in `MaskletConfirmationStatus`) + "status": np.array([], np.int64), + # "consecutive_det_num" is the number of consecutive frames where the masklet is + # detected by the detector (with a matched detection) + "consecutive_det_num": np.array([], np.int64), + } + tracker_metadata["rank0_metadata"] = rank0_metadata + + return tracker_metadata + + def update_masklet_confirmation_status( + self, + rank0_metadata: Dict[str, Any], + obj_ids_all_gpu_prev: npt.NDArray, + obj_ids_all_gpu_updated: npt.NDArray, + det_to_matched_trk_obj_ids: Dict[int, npt.NDArray], + new_det_obj_ids: npt.NDArray, + ): + confirmation_data = rank0_metadata["masklet_confirmation"] + + # a) first, expand "confirmation_data" to include new masklets added in this frame + status_prev = confirmation_data["status"] + consecutive_det_num_prev = confirmation_data["consecutive_det_num"] + assert ( + status_prev.shape == obj_ids_all_gpu_prev.shape + ), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}" + + obj_id_to_updated_idx = { + obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated) + } + prev_elem_is_in_updated = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated) + prev_elem_obj_ids_in_updated = obj_ids_all_gpu_prev[prev_elem_is_in_updated] + prev_elem_inds_in_updated = np.array( + [obj_id_to_updated_idx[obj_id] for obj_id in prev_elem_obj_ids_in_updated], + dtype=np.int64, + ) + # newly added masklets are initialized to "UNCONFIRMED" status + unconfirmed_val = MaskletConfirmationStatus.UNCONFIRMED.value + status = np.full_like(obj_ids_all_gpu_updated, fill_value=unconfirmed_val) + status[prev_elem_inds_in_updated] = status_prev[prev_elem_is_in_updated] + consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated) + consecutive_det_num[prev_elem_inds_in_updated] = consecutive_det_num_prev[ + prev_elem_is_in_updated + ] + + # b) update the confirmation status of all masklets based on the current frame + # b.1) update "consecutive_det_num" + # "is_matched": whether a masklet is matched to a detection on this frame + is_matched = np.isin(obj_ids_all_gpu_updated, new_det_obj_ids) + for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values(): + is_matched |= np.isin(obj_ids_all_gpu_updated, matched_trk_obj_ids) + consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0) + + # b.2) update "status" + change_to_confirmed = ( + consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh + ) + status[change_to_confirmed] = MaskletConfirmationStatus.CONFIRMED.value + + confirmation_data["status"] = status + confirmation_data["consecutive_det_num"] = consecutive_det_num + return rank0_metadata + + def forward(self, input: BatchedDatapoint, is_inference: bool = False): + raise NotImplementedError("Evaluation outside demo is not implemented yet") + + def _load_checkpoint(self, ckpt_path: str, strict: bool = True): + sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] + missing_keys, unexpected_keys = self.load_state_dict(sd, strict=strict) + if len(missing_keys) > 0 or len(unexpected_keys) > 0: + logger.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}") + else: + logger.info("Loaded ckpt successfully without missing or unexpected keys") + + def prep_for_evaluator(self, video_frames, tracking_res, scores_labels): + """This method is only used for benchmark eval (not used in the demo).""" + num_frames = len(video_frames) + w, h = video_frames[0].size + zero_mask = torch.zeros((1, h, w), dtype=torch.bool) + object_ids = list(scores_labels.keys()) + preds = {"scores": [], "labels": [], "boxes": [], "masks_rle": []} + for oid in object_ids: + o_masks = [] + o_score = scores_labels[oid][0].item() + o_label = scores_labels[oid][1] + for frame_idx in range(num_frames): + if frame_idx not in tracking_res: + o_masks.append(zero_mask) + else: + o_masks.append(tracking_res[frame_idx].get(oid, zero_mask)) + + o_masks = torch.cat(o_masks, dim=0) # (n_frames, H, W) + preds["scores"].append(o_score) + preds["labels"].append(o_label) + preds["boxes"].append(mask_to_box(o_masks.unsqueeze(1)).squeeze()) + preds["masks_rle"].append(rle_encode(o_masks, return_areas=True)) + + preds["boxes"] = ( + torch.stack(preds["boxes"], dim=0) + if len(preds["boxes"]) > 0 + else torch.empty( + (0, num_frames, 4), dtype=torch.float32, device=self.device + ) + ) + preds["scores"] = ( + torch.tensor(preds["scores"], device=self.device) + if len(preds["scores"]) > 0 + else torch.empty((0,), device=self.device) + ) + preds["per_frame_scores"] = preds["scores"] + preds["labels"] = ( + torch.tensor(preds["labels"], device=self.device) + if len(preds["labels"]) > 0 + else torch.empty((0,), device=self.device) + ) + return preds + + def _encode_prompt(self, **kwargs): + return self.detector._encode_prompt(**kwargs) + + def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep): + """ + Drop a few new detections based on the maximum number of objects. We drop new objects based + on their detection scores, keeping the high-scoring ones and dropping the low-scoring ones. + """ + assert 0 <= num_to_keep <= len(new_det_fa_inds) + if num_to_keep == 0: + return np.array([], np.int64) # keep none + if num_to_keep == len(new_det_fa_inds): + return new_det_fa_inds # keep all + + # keep the top-scoring detections + score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1] + new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]] + return new_det_fa_inds diff --git a/third_party/sam3/sam3/model/sam3_video_inference.py b/third_party/sam3/sam3/model/sam3_video_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f70750538b31d12117578626053fcbed9f19cc16 --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_video_inference.py @@ -0,0 +1,1710 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging +from collections import defaultdict + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from sam3 import perflib +from sam3.logger import get_logger +from sam3.model.act_ckpt_utils import clone_output_wrapper +from sam3.model.box_ops import box_xywh_to_cxcywh, box_xyxy_to_xywh +from sam3.model.data_misc import BatchedDatapoint, convert_my_tensors, FindStage +from sam3.model.geometry_encoders import Prompt +from sam3.model.io_utils import IMAGE_EXTS, load_resource_as_video_frames +from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores +from sam3.model.sam3_video_base import MaskletConfirmationStatus, Sam3VideoBase +from sam3.model.utils.misc import copy_data_to_device +from sam3.perflib.compile import compile_wrapper, shape_logging_wrapper +from sam3.perflib.masks_ops import masks_to_boxes as perf_masks_to_boxes +from torchvision.ops import masks_to_boxes +from tqdm.auto import tqdm + +logger = get_logger(__name__) + + +class Sam3VideoInference(Sam3VideoBase): + TEXT_ID_FOR_TEXT = 0 + TEXT_ID_FOR_VISUAL = 1 + + def __init__( + self, + image_size=1008, + image_mean=(0.5, 0.5, 0.5), + image_std=(0.5, 0.5, 0.5), + compile_model=False, + **kwargs, + ): + """ + hotstart_delay: int, the delay (in #frames) before the model starts to yield output, 0 to disable hotstart delay. + hotstart_unmatch_thresh: int, remove the object if it has this many unmatched frames within its hotstart_delay period. + If `hotstart_delay` is set to 0, this parameter is ignored. + hotstart_dup_thresh: int, remove the object if it has overlapped with another object this many frames within its hotstart_delay period. + """ + super().__init__(**kwargs) + self.image_size = image_size + self.image_mean = image_mean + self.image_std = image_std + self.compile_model = compile_model + + @torch.inference_mode() + def init_state( + self, + resource_path, + offload_video_to_cpu=False, + async_loading_frames=False, + video_loader_type="cv2", + ): + """Initialize an inference state from `resource_path` (an image or a video).""" + images, orig_height, orig_width = load_resource_as_video_frames( + resource_path=resource_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=self.image_mean, + img_std=self.image_std, + async_loading_frames=async_loading_frames, + video_loader_type=video_loader_type, + ) + inference_state = {} + inference_state["image_size"] = self.image_size + inference_state["num_frames"] = len(images) + # the original video height and width, used for resizing final output scores + inference_state["orig_height"] = orig_height + inference_state["orig_width"] = orig_width + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # inputs on each frame + self._construct_initial_input_batch(inference_state, images) + # initialize extra states + inference_state["tracker_inference_states"] = [] + inference_state["tracker_metadata"] = {} + inference_state["feature_cache"] = {} + inference_state["cached_frame_outputs"] = {} + inference_state["action_history"] = [] # for logging user actions + inference_state["is_image_only"] = is_image_type(resource_path) + return inference_state + + @torch.inference_mode() + def reset_state(self, inference_state): + """Revert `inference_state` to what it was right after initialization.""" + inference_state["input_batch"].find_text_batch[0] = "" + inference_state["text_prompt"] = None + for t in range(inference_state["num_frames"]): + inference_state["input_batch"].find_inputs[t].text_ids[...] = 0 + # constructing an output list in inference state (we start with an empty list) + inference_state["previous_stages_out"][t] = None + inference_state["per_frame_raw_point_input"][t] = None + inference_state["per_frame_raw_box_input"][t] = None + inference_state["per_frame_visual_prompt"][t] = None + inference_state["per_frame_geometric_prompt"][t] = None + inference_state["per_frame_cur_step"][t] = 0 + + inference_state["visual_prompt_embed"] = None + inference_state["visual_prompt_mask"] = None + inference_state["tracker_inference_states"].clear() + inference_state["tracker_metadata"].clear() + inference_state["feature_cache"].clear() + inference_state["cached_frame_outputs"].clear() + inference_state["action_history"].clear() # for logging user actions + + def _construct_initial_input_batch(self, inference_state, images): + """Construct an initial `BatchedDatapoint` instance as input.""" + # 1) img_batch + num_frames = len(images) + device = self.device + + # 2) find_text_batch + # "" will be replaced by the actual text prompt when adding prompts + find_text_batch = ["", "visual"] + + # 3) find_inputs + input_box_embedding_dim = 258 # historical default + input_points_embedding_dim = 257 # historical default + stages = [ + FindStage( + img_ids=[stage_id], + text_ids=[0], + input_boxes=[torch.zeros(input_box_embedding_dim)], + input_boxes_mask=[torch.empty(0, dtype=torch.bool)], + input_boxes_label=[torch.empty(0, dtype=torch.long)], + input_points=[torch.empty(0, input_points_embedding_dim)], + input_points_mask=[torch.empty(0)], + object_ids=[], + ) + for stage_id in range(num_frames) + ] + for i in range(len(stages)): + stages[i] = convert_my_tensors(stages[i]) + + # construct the final `BatchedDatapoint` and cast to GPU + input_batch = BatchedDatapoint( + img_batch=images, + find_text_batch=find_text_batch, + find_inputs=stages, + find_targets=[None] * num_frames, + find_metadatas=[None] * num_frames, + ) + input_batch = copy_data_to_device(input_batch, device, non_blocking=True) + inference_state["input_batch"] = input_batch + + # construct the placeholder interactive prompts and tracking queries + bs = 1 + inference_state["constants"]["empty_geometric_prompt"] = Prompt( + box_embeddings=torch.zeros(0, bs, 4, device=device), + box_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), + box_labels=torch.zeros(0, bs, device=device, dtype=torch.long), + point_embeddings=torch.zeros(0, bs, 2, device=device), + point_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), + point_labels=torch.zeros(0, bs, device=device, dtype=torch.long), + ) + + # constructing an output list in inference state (we start with an empty list) + inference_state["previous_stages_out"] = [None] * num_frames + inference_state["text_prompt"] = None + inference_state["per_frame_raw_point_input"] = [None] * num_frames + inference_state["per_frame_raw_box_input"] = [None] * num_frames + inference_state["per_frame_visual_prompt"] = [None] * num_frames + inference_state["per_frame_geometric_prompt"] = [None] * num_frames + inference_state["per_frame_cur_step"] = [0] * num_frames + + # placeholders for cached outputs + # (note: currently, a single visual prompt embedding is shared for all frames) + inference_state["visual_prompt_embed"] = None + inference_state["visual_prompt_mask"] = None + + def _get_visual_prompt(self, inference_state, frame_idx, boxes_cxcywh, box_labels): + """ + Handle the case of visual prompt. Currently, in the inference API we do not + explicitly distinguish between initial box as visual prompt vs subsequent boxes + or boxes after inference for refinement. + """ + # If the frame hasn't had any inference results before (prompting or propagation), + # we treat the first added box prompt as a visual prompt; otherwise, we treat + # the first box just as a refinement prompt. + is_new_visual_prompt = ( + inference_state["per_frame_visual_prompt"][frame_idx] is None + and inference_state["previous_stages_out"][frame_idx] is None + ) + if is_new_visual_prompt: + if boxes_cxcywh.size(0) != 1: + raise RuntimeError( + "visual prompts (box as an initial prompt) should only have one box, " + f"but got {boxes_cxcywh.shape=}" + ) + if not box_labels.item(): + logging.warning("A negative box is added as a visual prompt.") + # take the first box prompt as a visual prompt + device = self.device + new_visual_prompt = Prompt( + box_embeddings=boxes_cxcywh[None, 0:1, :].to(device), # (seq, bs, 4) + box_mask=None, + box_labels=box_labels[None, 0:1].to(device), # (seq, bs) + point_embeddings=None, + point_mask=None, + point_labels=None, + ) + inference_state["per_frame_visual_prompt"][frame_idx] = new_visual_prompt + else: + new_visual_prompt = None + + # `boxes_cxcywh` and `box_labels` contains all the raw box inputs added so far + # strip any visual prompt from the input boxes (for geometric prompt encoding) + if inference_state["per_frame_visual_prompt"][frame_idx] is not None: + boxes_cxcywh = boxes_cxcywh[1:] + box_labels = box_labels[1:] + + return boxes_cxcywh, box_labels, new_visual_prompt + + def _get_processing_order( + self, inference_state, start_frame_idx, max_frame_num_to_track, reverse + ): + num_frames = inference_state["num_frames"] + previous_stages_out = inference_state["previous_stages_out"] + if all(out is None for out in previous_stages_out) and start_frame_idx is None: + raise RuntimeError( + "No prompts are received on any frames. Please add prompt on at least one frame before propagation." + ) + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min( + t for t, out in enumerate(previous_stages_out) if out is not None + ) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = start_frame_idx - max_frame_num_to_track + end_frame_idx = max(end_frame_idx, 0) + processing_order = range(start_frame_idx - 1, end_frame_idx - 1, -1) + else: + end_frame_idx = start_frame_idx + max_frame_num_to_track + end_frame_idx = min(end_frame_idx, num_frames - 1) + processing_order = range(start_frame_idx, end_frame_idx + 1) + return processing_order, end_frame_idx + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """ + Propagate the prompts to get grounding results for the entire video. This method + is a generator and yields inference outputs for all frames in the range specified + by `start_frame_idx`, `max_frame_num_to_track`, and `reverse`. + """ + # compile the model (it's a no-op if the model is already compiled) + # note that it's intentionally added to `self.propagate_in_video`, so that the first + # `self.add_prompt` call will be done in eager mode to fill in the decoder buffers + # such as positional encoding cache) + self._compile_model() + + processing_order, end_frame_idx = self._get_processing_order( + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse=reverse, + ) + + # Store max_frame_num_to_track in feature_cache for downstream methods + inference_state["feature_cache"]["tracking_bounds"] = { + "max_frame_num_to_track": max_frame_num_to_track, + "propagate_in_video_start_frame_idx": start_frame_idx, + } + + hotstart_buffer = [] + hotstart_removed_obj_ids = set() + # when deciding whether to output a masklet on `yield_frame_idx`, we check whether the object is confirmed + # in a future frame (`unconfirmed_frame_delay` frames after the current frame). For example, if we require + # an object to be detected in 3 consecutive frames to be confirmed, then we look 2 frames in the future -- + # e.g., we output an object on frame 4 only if it becomes confirmed on frame 6. + unconfirmed_status_delay = self.masklet_confirmation_consecutive_det_thresh - 1 + unconfirmed_obj_ids_per_frame = {} # frame_idx -> hidden_obj_ids + for frame_idx in tqdm( + processing_order, desc="propagate_in_video", disable=self.rank > 0 + ): + out = self._run_single_frame_inference(inference_state, frame_idx, reverse) + + if self.hotstart_delay > 0: + # accumulate the outputs for the first `hotstart_delay` frames + hotstart_buffer.append([frame_idx, out]) + # update the object IDs removed by hotstart so that we don't output them + if self.rank == 0: + hotstart_removed_obj_ids.update(out["removed_obj_ids"]) + unconfirmed_obj_ids = out.get("unconfirmed_obj_ids", None) + if unconfirmed_obj_ids is not None: + unconfirmed_obj_ids_per_frame[frame_idx] = unconfirmed_obj_ids + + if frame_idx == end_frame_idx: + # we reached the end of propagation -- yield all frames in the buffer + yield_list = hotstart_buffer + hotstart_buffer = [] + elif len(hotstart_buffer) >= self.hotstart_delay: + # we have enough frames -- yield and remove the first (oldest) frame from the buffer + yield_list = hotstart_buffer[:1] + hotstart_buffer = hotstart_buffer[1:] + else: + # not enough frames yet -- skip yielding + yield_list = [] + else: + yield_list = [(frame_idx, out)] # output the current frame + + for yield_frame_idx, yield_out in yield_list: + # post-process the output and yield it + if self.rank == 0: + suppressed_obj_ids = yield_out["suppressed_obj_ids"] + unconfirmed_status_frame_idx = ( + yield_frame_idx + unconfirmed_status_delay + if not reverse + else yield_frame_idx - unconfirmed_status_delay + ) + + # Clamp the frame index to stay within video bounds + num_frames = inference_state["num_frames"] + unconfirmed_status_frame_idx = max( + 0, min(unconfirmed_status_frame_idx, num_frames - 1) + ) + + unconfirmed_obj_ids = unconfirmed_obj_ids_per_frame.get( + unconfirmed_status_frame_idx, None + ) + postprocessed_out = self._postprocess_output( + inference_state, + yield_out, + hotstart_removed_obj_ids, + suppressed_obj_ids, + unconfirmed_obj_ids, + ) + + self._cache_frame_outputs( + inference_state, + yield_frame_idx, + yield_out["obj_id_to_mask"], + suppressed_obj_ids=suppressed_obj_ids, + removed_obj_ids=hotstart_removed_obj_ids, + unconfirmed_obj_ids=unconfirmed_obj_ids, + ) + else: + postprocessed_out = None # no output on other GPUs + yield yield_frame_idx, postprocessed_out + + def _run_single_frame_inference(self, inference_state, frame_idx, reverse): + """ + Perform inference on a single frame and get its inference results. This would + also update `inference_state`. + """ + # prepare inputs + input_batch = inference_state["input_batch"] + tracker_states_local = inference_state["tracker_inference_states"] + has_text_prompt = inference_state["text_prompt"] is not None + has_geometric_prompt = ( + inference_state["per_frame_geometric_prompt"][frame_idx] is not None + ) + # run inference for the current frame + ( + obj_id_to_mask, + obj_id_to_score, + tracker_states_local_new, + tracker_metadata_new, + frame_stats, + _, + ) = self._det_track_one_frame( + frame_idx=frame_idx, + num_frames=inference_state["num_frames"], + reverse=reverse, + input_batch=input_batch, + geometric_prompt=( + inference_state["constants"]["empty_geometric_prompt"] + if not has_geometric_prompt + else inference_state["per_frame_geometric_prompt"][frame_idx] + ), + tracker_states_local=tracker_states_local, + tracker_metadata_prev=inference_state["tracker_metadata"], + feature_cache=inference_state["feature_cache"], + orig_vid_height=inference_state["orig_height"], + orig_vid_width=inference_state["orig_width"], + is_image_only=inference_state["is_image_only"], + allow_new_detections=has_text_prompt or has_geometric_prompt, + ) + # update inference state + inference_state["tracker_inference_states"] = tracker_states_local_new + inference_state["tracker_metadata"] = tracker_metadata_new + # use a dummy string in "previous_stages_out" to indicate this frame has outputs + inference_state["previous_stages_out"][frame_idx] = "_THIS_FRAME_HAS_OUTPUTS_" + + if self.rank == 0: + self._cache_frame_outputs(inference_state, frame_idx, obj_id_to_mask) + + out = { + "obj_id_to_mask": obj_id_to_mask, + "obj_id_to_score": obj_id_to_score, # first frame detection score + "obj_id_to_tracker_score": tracker_metadata_new[ + "obj_id_to_tracker_score_frame_wise" + ][frame_idx], + } + # removed_obj_ids is only needed on rank 0 to handle hotstart delay buffer + if self.rank == 0: + rank0_metadata = tracker_metadata_new["rank0_metadata"] + removed_obj_ids = rank0_metadata["removed_obj_ids"] + out["removed_obj_ids"] = removed_obj_ids + out["suppressed_obj_ids"] = rank0_metadata["suppressed_obj_ids"][frame_idx] + out["frame_stats"] = frame_stats + if self.masklet_confirmation_enable: + status = rank0_metadata["masklet_confirmation"]["status"] + is_unconfirmed = status == MaskletConfirmationStatus.UNCONFIRMED.value + out["unconfirmed_obj_ids"] = tracker_metadata_new["obj_ids_all_gpu"][ + is_unconfirmed + ].tolist() + else: + out["unconfirmed_obj_ids"] = [] + + return out + + def _postprocess_output( + self, + inference_state, + out, + removed_obj_ids=None, + suppressed_obj_ids=None, + unconfirmed_obj_ids=None, + ): + obj_id_to_mask = out["obj_id_to_mask"] # low res masks + curr_obj_ids = sorted(obj_id_to_mask.keys()) + H_video, W_video = inference_state["orig_height"], inference_state["orig_width"] + if len(curr_obj_ids) == 0: + out_obj_ids = torch.zeros(0, dtype=torch.int64) + out_probs = torch.zeros(0, dtype=torch.float32) + out_binary_masks = torch.zeros(0, H_video, W_video, dtype=torch.bool) + out_boxes_xywh = torch.zeros(0, 4, dtype=torch.float32) + else: + out_obj_ids = torch.tensor(curr_obj_ids, dtype=torch.int64) + out_probs = torch.tensor( + [out["obj_id_to_score"][obj_id] for obj_id in curr_obj_ids] + ) + out_tracker_probs = torch.tensor( + [ + ( + out["obj_id_to_tracker_score"][obj_id] + if obj_id in out["obj_id_to_tracker_score"] + else 0.0 + ) + for obj_id in curr_obj_ids + ] + ) + out_binary_masks = torch.cat( + [obj_id_to_mask[obj_id] for obj_id in curr_obj_ids], dim=0 + ) + + assert out_binary_masks.dtype == torch.bool + keep = out_binary_masks.any(dim=(1, 2)).cpu() # remove masks with 0 areas + # hide outputs for those object IDs in `obj_ids_to_hide` + obj_ids_to_hide = [] + if suppressed_obj_ids is not None: + obj_ids_to_hide.extend(suppressed_obj_ids) + if removed_obj_ids is not None: + obj_ids_to_hide.extend(removed_obj_ids) + if unconfirmed_obj_ids is not None: + obj_ids_to_hide.extend(unconfirmed_obj_ids) + if len(obj_ids_to_hide) > 0: + obj_ids_to_hide_t = torch.tensor(obj_ids_to_hide, dtype=torch.int64) + keep &= ~torch.isin(out_obj_ids, obj_ids_to_hide_t) + + # slice those valid entries from the original outputs + keep_idx = torch.nonzero(keep, as_tuple=True)[0] + keep_idx_gpu = keep_idx.pin_memory().to( + device=out_binary_masks.device, non_blocking=True + ) + + out_obj_ids = torch.index_select(out_obj_ids, 0, keep_idx) + out_probs = torch.index_select(out_probs, 0, keep_idx) + out_tracker_probs = torch.index_select(out_tracker_probs, 0, keep_idx) + out_binary_masks = torch.index_select(out_binary_masks, 0, keep_idx_gpu) + + if perflib.is_enabled: + out_boxes_xyxy = perf_masks_to_boxes( + out_binary_masks, out_obj_ids.tolist() + ) + else: + out_boxes_xyxy = masks_to_boxes(out_binary_masks) + + out_boxes_xywh = box_xyxy_to_xywh(out_boxes_xyxy) # convert to xywh format + # normalize boxes + out_boxes_xywh[..., 0] /= W_video + out_boxes_xywh[..., 1] /= H_video + out_boxes_xywh[..., 2] /= W_video + out_boxes_xywh[..., 3] /= H_video + + # apply non-overlapping constraints on the existing masklets + if out_binary_masks.shape[0] > 1: + assert len(out_binary_masks) == len(out_tracker_probs) + out_binary_masks = ( + self.tracker._apply_object_wise_non_overlapping_constraints( + out_binary_masks.unsqueeze(1), + out_tracker_probs.unsqueeze(1).to(out_binary_masks.device), + background_value=0, + ).squeeze(1) + ) > 0 + + outputs = { + "out_obj_ids": out_obj_ids.cpu().numpy(), + "out_probs": out_probs.cpu().numpy(), + "out_boxes_xywh": out_boxes_xywh.cpu().numpy(), + "out_binary_masks": out_binary_masks.cpu().numpy(), + "frame_stats": out.get("frame_stats", None), + } + return outputs + + def _cache_frame_outputs( + self, + inference_state, + frame_idx, + obj_id_to_mask, + suppressed_obj_ids=None, + removed_obj_ids=None, + unconfirmed_obj_ids=None, + ): + # Filter out suppressed, removed, and unconfirmed objects from the cache + filtered_obj_id_to_mask = obj_id_to_mask.copy() + + objects_to_exclude = set() + if suppressed_obj_ids is not None: + objects_to_exclude.update(suppressed_obj_ids) + if removed_obj_ids is not None: + objects_to_exclude.update(removed_obj_ids) + if unconfirmed_obj_ids is not None: + objects_to_exclude.update(unconfirmed_obj_ids) + + if objects_to_exclude: + for obj_id in objects_to_exclude: + if obj_id in filtered_obj_id_to_mask: + del filtered_obj_id_to_mask[obj_id] + + inference_state["cached_frame_outputs"][frame_idx] = filtered_obj_id_to_mask + + def _build_tracker_output( + self, inference_state, frame_idx, refined_obj_id_to_mask=None + ): + assert ( + "cached_frame_outputs" in inference_state + and frame_idx in inference_state["cached_frame_outputs"] + ), "No cached outputs found. Ensure normal propagation has run first to populate the cache." + cached_outputs = inference_state["cached_frame_outputs"][frame_idx] + + obj_id_to_mask = cached_outputs.copy() + + # Update with refined masks if provided + if refined_obj_id_to_mask is not None: + for obj_id, refined_mask in refined_obj_id_to_mask.items(): + assert ( + refined_mask is not None + ), f"Refined mask data must be provided for obj_id {obj_id}" + obj_id_to_mask[obj_id] = refined_mask + + return obj_id_to_mask + + def _compile_model(self): + """Compile the SAM model with torch.compile for speedup.""" + is_compiled = getattr(self, "_model_is_compiled", False) + if is_compiled or not self.compile_model: + return + + import torch._dynamo + + # a larger cache size to hold varying number of shapes for torch.compile + # see https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_dynamo/config.py#L42-L49 + torch._dynamo.config.cache_size_limit = 128 + torch._dynamo.config.accumulated_cache_size_limit = 2048 + torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.suppress_errors = True + + # Compile module components + # skip compilation of `_encode_prompt` since it sometimes tiggger SymInt errors + # self._encode_prompt = clone_output_wrapper( + # torch.compile(self._encode_prompt, fullgraph=True, mode="max-autotune") + # ) + + ## Compile SAM3 model components + self.detector.backbone.vision_backbone.forward = clone_output_wrapper( + torch.compile( + self.detector.backbone.vision_backbone.forward, + fullgraph=True, + mode="max-autotune", + ) + ) + self.detector.transformer.encoder.forward = clone_output_wrapper( + torch.compile( + self.detector.transformer.encoder.forward, + fullgraph=True, + mode="max-autotune", + ) + ) + self.detector.transformer.decoder.forward = clone_output_wrapper( + torch.compile( + self.detector.transformer.decoder.forward, + fullgraph=True, + mode="max-autotune", + dynamic=False, + ) + ) + + self.detector.segmentation_head.forward = clone_output_wrapper( + torch.compile( + self.detector.segmentation_head.forward, + fullgraph=True, + mode="max-autotune", + ) + ) + + ## Compile Tracker model components + self.tracker.maskmem_backbone.forward = compile_wrapper( + self.tracker.maskmem_backbone.forward, + mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + self.tracker.transformer.encoder.forward = shape_logging_wrapper( + compile_wrapper( + self.tracker.transformer.encoder.forward, + mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=True, + ), + keep_kwargs=["src", "src_pos", "prompt", "prompt_pos"], + ) + + self.tracker.sam_mask_decoder.forward = compile_wrapper( + self.tracker.sam_mask_decoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + self._model_is_compiled = True + + def _warm_up_vg_propagation(self, inference_state, start_frame_idx=0): + # use different tracking score thresholds for each round to simulate different number of output objects + num_objects_list = range(self.num_obj_for_compile + 1) + new_det_score_thresh_list = [0.3, 0.5, 0.7] + num_rounds = len(new_det_score_thresh_list) + orig_new_det_thresh = self.new_det_thresh + + for i, thresh in enumerate(new_det_score_thresh_list): + self.new_det_thresh = thresh + for num_objects in num_objects_list: + logger.info(f"{i + 1}/{num_rounds} warming up model compilation") + self.add_prompt( + inference_state, frame_idx=start_frame_idx, text_str="cat" + ) + logger.info( + f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects" + ) + inference_state = self.add_fake_objects_to_inference_state( + inference_state, num_objects, frame_idx=start_frame_idx + ) + inference_state["tracker_metadata"]["rank0_metadata"].update( + { + "masklet_confirmation": { + "status": np.zeros(num_objects, dtype=np.int64), + "consecutive_det_num": np.zeros( + num_objects, dtype=np.int64 + ), + } + } + ) + for _ in self.propagate_in_video( + inference_state, start_frame_idx, reverse=False + ): + pass + for _ in self.propagate_in_video( + inference_state, start_frame_idx, reverse=True + ): + pass + self.reset_state(inference_state) + logger.info( + f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}" + ) + + # Warm up Tracker memory encoder with varying input shapes + num_iters = 3 + feat_size = self.tracker.sam_image_embedding_size**2 # 72 * 72 = 5184 + hidden_dim = self.tracker.hidden_dim # 256 + mem_dim = self.tracker.mem_dim # 64 + for _ in tqdm(range(num_iters)): + for b in range(1, self.num_obj_for_compile + 1): + for i in range( + 1, + self.tracker.max_cond_frames_in_attn + self.tracker.num_maskmem, + ): + for j in range( + self.tracker.max_cond_frames_in_attn + + self.tracker.max_obj_ptrs_in_encoder + ): + num_obj_ptr_tokens = (hidden_dim // mem_dim) * j + src = torch.randn(feat_size, b, hidden_dim, device=self.device) + src_pos = torch.randn( + feat_size, b, hidden_dim, device=self.device + ) + prompt = torch.randn( + feat_size * i + num_obj_ptr_tokens, + b, + mem_dim, + device=self.device, + ) + prompt_pos = torch.randn( + feat_size * i + num_obj_ptr_tokens, + b, + mem_dim, + device=self.device, + ) + + self.tracker.transformer.encoder.forward( + src=src, + src_pos=src_pos, + prompt=prompt, + prompt_pos=prompt_pos, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + + self.new_det_thresh = orig_new_det_thresh + return inference_state + + def add_fake_objects_to_inference_state( + self, inference_state, num_objects, frame_idx + ): + new_det_obj_ids_local = np.arange(num_objects) + high_res_H, high_res_W = ( + self.tracker.maskmem_backbone.mask_downsampler.interpol_size + ) + new_det_masks = torch.ones( + len(new_det_obj_ids_local), high_res_H, high_res_W + ).to(self.device) + + inference_state["tracker_inference_states"] = self._tracker_add_new_objects( + frame_idx=frame_idx, + num_frames=inference_state["num_frames"], + new_obj_ids=new_det_obj_ids_local, + new_obj_masks=new_det_masks, + tracker_states_local=inference_state["tracker_inference_states"], + orig_vid_height=inference_state["orig_height"], + orig_vid_width=inference_state["orig_width"], + feature_cache=inference_state["feature_cache"], + ) + + # Synthesize obj_id_to_mask data for cached_frame_outputs to support _build_tracker_output during warmup + obj_id_to_mask = {} + if num_objects > 0: + H_video = inference_state["orig_height"] + W_video = inference_state["orig_width"] + + video_res_masks = F.interpolate( + new_det_masks.unsqueeze(1), # Add channel dimension for interpolation + size=(H_video, W_video), + mode="bilinear", + align_corners=False, + ) # (num_objects, 1, H_video, W_video) + for i, obj_id in enumerate(new_det_obj_ids_local): + obj_id_to_mask[obj_id] = (video_res_masks[i] > 0.0).to(torch.bool) + if self.rank == 0: + for fidx in range(inference_state["num_frames"]): + self._cache_frame_outputs(inference_state, fidx, obj_id_to_mask) + + inference_state["tracker_metadata"].update( + { + "obj_ids_per_gpu": [np.arange(num_objects)], + "obj_ids_all_gpu": np.arange(num_objects), # Same as 1 GPU + "num_obj_per_gpu": [num_objects], + "obj_id_to_score": {i: 1.0 for i in range(num_objects)}, + "max_obj_id": num_objects, + "rank0_metadata": { + "masklet_confirmation": { + "status": np.zeros(num_objects, dtype=np.int64), + "consecutive_det_num": np.zeros(num_objects, dtype=np.int64), + }, + "removed_obj_ids": set(), + "suppressed_obj_ids": defaultdict(set), + }, + } + ) + return inference_state + + @torch.inference_mode() + @torch.autocast(device_type="cuda", dtype=torch.bfloat16) + def warm_up_compilation(self): + """ + Warm up the model by running a dummy inference to compile the model. This is + useful to avoid the compilation overhead in the first inference call. + """ + if not self.compile_model: + return + self._warm_up_complete = False + if self.device.type != "cuda": + raise RuntimeError( + f"The model must be on CUDA for warm-up compilation, got {self.device=}." + ) + + # temporally set to single GPU temporarily for warm-up compilation + orig_rank = self.rank + orig_world_size = self.world_size + self.rank = self.detector.rank = 0 + self.world_size = self.detector.world_size = 1 + orig_recondition_every_nth_frame = self.recondition_every_nth_frame + # self.recondition_every_nth_frame = 2 + + # Get a random video + inference_state = self.init_state(resource_path="") + start_frame_idx = 0 + + # Run basic propagation warm-up + inference_state = self._warm_up_vg_propagation(inference_state, start_frame_idx) + + logger.info("Warm-up compilation completed.") + + # revert to the original GPU and rank + self.rank = self.detector.rank = orig_rank + self.world_size = self.detector.world_size = orig_world_size + self.recondition_every_nth_frame = orig_recondition_every_nth_frame + self._warm_up_complete = True + self.tracker.transformer.encoder.forward.set_logging(True) + + @torch.inference_mode() + def add_prompt( + self, + inference_state, + frame_idx, + text_str=None, + boxes_xywh=None, + box_labels=None, + ): + """ + Add text, point or box prompts on a single frame. This method returns the inference + outputs only on the prompted frame. + + Note that text prompts are NOT associated with a particular frame (i.e. they apply + to all frames). However, we only run inference on the frame specified in `frame_idx`. + """ + logger.debug("Running add_prompt on frame %d", frame_idx) + + num_frames = inference_state["num_frames"] + assert ( + text_str is not None or boxes_xywh is not None + ), "at least one type of prompt (text, boxes) must be provided" + assert ( + 0 <= frame_idx < num_frames + ), f"{frame_idx=} is out of range for a total of {num_frames} frames" + + # since it's a semantic prompt, we start over + self.reset_state(inference_state) + + # 1) add text prompt + if text_str is not None and text_str != "visual": + inference_state["text_prompt"] = text_str + inference_state["input_batch"].find_text_batch[0] = text_str + text_id = self.TEXT_ID_FOR_TEXT + else: + inference_state["text_prompt"] = None + inference_state["input_batch"].find_text_batch[0] = "" + text_id = self.TEXT_ID_FOR_VISUAL + for t in range(inference_state["num_frames"]): + inference_state["input_batch"].find_inputs[t].text_ids[...] = text_id + + # 2) handle box prompt + assert (boxes_xywh is not None) == (box_labels is not None) + if boxes_xywh is not None: + boxes_xywh = torch.as_tensor(boxes_xywh, dtype=torch.float32) + box_labels = torch.as_tensor(box_labels, dtype=torch.long) + # input boxes are expected to be [xmin, ymin, width, height] format + # in normalized coordinates of range 0~1, similar to FA + assert boxes_xywh.dim() == 2 + assert boxes_xywh.size(0) > 0 and boxes_xywh.size(-1) == 4 + assert box_labels.dim() == 1 and box_labels.size(0) == boxes_xywh.size(0) + boxes_cxcywh = box_xywh_to_cxcywh(boxes_xywh) + assert (boxes_xywh >= 0).all().item() and (boxes_xywh <= 1).all().item() + assert (boxes_cxcywh >= 0).all().item() and (boxes_cxcywh <= 1).all().item() + + new_box_input = boxes_cxcywh, box_labels + inference_state["per_frame_raw_box_input"][frame_idx] = new_box_input + + # handle the case of visual prompt (also added as an input box from the UI) + boxes_cxcywh, box_labels, geometric_prompt = self._get_visual_prompt( + inference_state, frame_idx, boxes_cxcywh, box_labels + ) + + inference_state["per_frame_geometric_prompt"][frame_idx] = geometric_prompt + + out = self._run_single_frame_inference( + inference_state, frame_idx, reverse=False + ) + return frame_idx, self._postprocess_output(inference_state, out) + + @torch.autocast(device_type="cuda", dtype=torch.bfloat16) + def forward(self, input: BatchedDatapoint, is_inference: bool = False): + """This method is only used for benchmark eval (not used in the demo).""" + # set the model to single GPU for benchmark evaluation (to be compatible with trainer) + orig_rank = self.rank + orig_world_size = self.world_size + self.rank = self.detector.rank = 0 + self.world_size = self.detector.world_size = 1 + + # get data + text_prompt_ids = input.find_metadatas[0].original_category_id + text_prompt_list = input.find_text_batch + + # loop over txt prompts + tracking_res = defaultdict(dict) # frame_idx --> {obj_id: mask} + scores_labels = defaultdict(tuple) # obj_id --> (score, text_prompt_id) + inference_state = self.init_state(resource_path=input.raw_images) + for prompt_id, prompt in zip(text_prompt_ids, text_prompt_list): + self.add_prompt(inference_state, frame_idx=0, text_str=prompt) + start_obj_id = max(scores_labels.keys(), default=-1) + 1 # prev max + 1 + + # propagate the prompts + obj_ids_this_prompt = set() + for frame_idx, out in self.propagate_in_video( + inference_state, + start_frame_idx=0, + max_frame_num_to_track=inference_state["num_frames"], + reverse=False, + ): + current_frame_res = tracking_res[frame_idx] + for obj_id, mask in zip(out["out_obj_ids"], out["out_binary_masks"]): + mask_tensor = torch.tensor(mask[None], dtype=torch.bool) + current_frame_res[obj_id + start_obj_id] = mask_tensor + obj_ids_this_prompt.update(current_frame_res.keys()) + + obj_id_to_score = inference_state["tracker_metadata"]["obj_id_to_score"] + for obj_id, score in obj_id_to_score.items(): + if obj_id + start_obj_id in obj_ids_this_prompt: + score_tensor = torch.tensor(score, dtype=torch.float32) + scores_labels[obj_id + start_obj_id] = (score_tensor, prompt_id) + + self.reset_state(inference_state) + + video_id = input.find_metadatas[0].original_image_id[0].cpu().item() + preds = self.prep_for_evaluator(input.raw_images, tracking_res, scores_labels) + + # revert the model to the original GPU and rank + self.rank = self.detector.rank = orig_rank + self.world_size = self.detector.world_size = orig_world_size + return {video_id: preds} + + def back_convert(self, targets): + # Needed for retraining compatibility with trainer + return targets + + +class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference): + def __init__( + self, + use_prev_mem_frame=False, + use_stateless_refinement=False, + refinement_detector_cond_frame_removal_window=16, + **kwargs, + ): + """ + use_prev_mem_frame: bool, whether to condition on previous memory frames for adding points + use_stateless_refinement: bool, whether to enable stateless refinement behavior + refinement_detector_cond_frame_removal_window: int, we remove a detector conditioning frame if it + is within this many frames of a user refined frame. Set to a large value (e.g. 10000) to + always remove detector conditioning frames if there is any user refinement in the video. + """ + super().__init__(**kwargs) + self.use_prev_mem_frame = use_prev_mem_frame + self.use_stateless_refinement = use_stateless_refinement + self.refinement_detector_cond_frame_removal_window = ( + refinement_detector_cond_frame_removal_window + ) + + def _init_new_tracker_state(self, inference_state): + return self.tracker.init_state( + cached_features=inference_state["feature_cache"], + video_height=inference_state["orig_height"], + video_width=inference_state["orig_width"], + num_frames=inference_state["num_frames"], + ) + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + # step 1: check which type of propagation to run, should be the same for all GPUs. + propagation_type, obj_ids = self.parse_action_history_for_propagation( + inference_state + ) + self.add_action_history( + inference_state, + action_type=propagation_type, + obj_ids=obj_ids, + frame_idx=start_frame_idx, + ) + + # step 2: run full VG propagation + if propagation_type == "propagation_full": + logger.debug(f"Running full VG propagation (reverse={reverse}).") + yield from super().propagate_in_video( + inference_state, + start_frame_idx=start_frame_idx, + max_frame_num_to_track=max_frame_num_to_track, + reverse=reverse, + ) + return + + # step 3: run Tracker partial propagation or direct fetch existing predictions + assert propagation_type in ["propagation_partial", "propagation_fetch"] + logger.debug( + f"Running Tracker propagation for objects {obj_ids} and merging it with existing VG predictions (reverse={reverse})." + if propagation_type == "propagation_partial" + else f"Fetching existing VG predictions without running any propagation (reverse={reverse})." + ) + processing_order, _ = self._get_processing_order( + inference_state, + start_frame_idx=start_frame_idx, + max_frame_num_to_track=max_frame_num_to_track, + reverse=reverse, + ) + + tracker_metadata = inference_state["tracker_metadata"] + + # if fetch just return from output + if propagation_type == "propagation_fetch": + for frame_idx in tqdm(processing_order): + if self.rank == 0: + obj_id_to_mask = inference_state["cached_frame_outputs"].get( + frame_idx, {} + ) + # post processing - remove suppressed obj_ids + obj_id_to_score = tracker_metadata["obj_id_to_score"] + suppressed_obj_ids = tracker_metadata["rank0_metadata"][ + "suppressed_obj_ids" + ][frame_idx] + obj_id_to_tracker_score = tracker_metadata[ + "obj_id_to_tracker_score_frame_wise" + ][frame_idx] + + out = { + "obj_id_to_mask": obj_id_to_mask, + "obj_id_to_score": obj_id_to_score, + "obj_id_to_tracker_score": obj_id_to_tracker_score, + } + yield ( + frame_idx, + self._postprocess_output( + inference_state, out, suppressed_obj_ids=suppressed_obj_ids + ), + ) + else: + yield frame_idx, None + + return + + # get Tracker inference states containing selected obj_ids + if propagation_type == "propagation_partial": + # can be empty for GPUs where objects are not in their inference states + tracker_states_local = self._get_tracker_inference_states_by_obj_ids( + inference_state, obj_ids + ) + for tracker_state in tracker_states_local: + self.tracker.propagate_in_video_preflight( + tracker_state, run_mem_encoder=True + ) + + for frame_idx in tqdm(processing_order): + # run Tracker propagation + if propagation_type == "propagation_partial": + self._prepare_backbone_feats(inference_state, frame_idx, reverse) + obj_ids_local, low_res_masks_local, tracker_scores_local = ( + self._propogate_tracker_one_frame_local_gpu( + tracker_states_local, + frame_idx=frame_idx, + reverse=reverse, + run_mem_encoder=True, + ) + ) + + # broadcast refined object tracker scores and masks to all GPUs + # handle multiple objects that can be located on different GPUs + refined_obj_data = {} # obj_id -> (score, mask_video_res) + + # Collect data for objects on this GPU + local_obj_data = {} + for obj_id in obj_ids: + obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id) + if self.rank == obj_rank and obj_id in obj_ids_local: + refined_obj_idx = obj_ids_local.index(obj_id) + refined_mask_low_res = low_res_masks_local[ + refined_obj_idx + ] # (H_low_res, W_low_res) + refined_score = tracker_scores_local[refined_obj_idx] + + # Keep low resolution for broadcasting to reduce communication cost + local_obj_data[obj_id] = (refined_score, refined_mask_low_res) + + # Broadcast data from each GPU that has refined objects + if self.world_size > 1: + for obj_id in obj_ids: + obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id) + if self.rank == obj_rank: + # This GPU has the object, broadcast its data + data_to_broadcast = local_obj_data.get(obj_id, None) + data_list = [ + (data_to_broadcast[0].cpu(), data_to_broadcast[1].cpu()) + ] + self.broadcast_python_obj_cpu(data_list, src=obj_rank) + if data_to_broadcast is not None: + refined_obj_data[obj_id] = data_to_broadcast + elif self.rank != obj_rank: + # This GPU doesn't have the object, receive data + data_list = [None] + self.broadcast_python_obj_cpu(data_list, src=obj_rank) + refined_obj_data[obj_id] = ( + data_list[0][0].to(self.device), + data_list[0][1].to(self.device), + ) + else: + # Single GPU case + refined_obj_data = local_obj_data + + # Update Tracker scores for all refined objects + for obj_id, (refined_score, _) in refined_obj_data.items(): + tracker_metadata["obj_id_to_tracker_score_frame_wise"][ + frame_idx + ].update({obj_id: refined_score.item()}) + + if self.rank == 0: + # get predictions from Tracker inference states, it includes the original + # VG predictions and the refined predictions from interactivity. + + # Prepare refined masks dictionary - upscale to video resolution after broadcast + refined_obj_id_to_mask = {} + for obj_id, (_, refined_mask_low_res) in refined_obj_data.items(): + refined_mask_video_res = ( + self._convert_low_res_mask_to_video_res( + refined_mask_low_res, inference_state + ) + ) # (1, H_video, W_video) bool + refined_obj_id_to_mask[obj_id] = refined_mask_video_res + + obj_id_to_mask = self._build_tracker_output( + inference_state, frame_idx, refined_obj_id_to_mask + ) + out = { + "obj_id_to_mask": obj_id_to_mask, + "obj_id_to_score": tracker_metadata["obj_id_to_score"], + "obj_id_to_tracker_score": tracker_metadata[ + "obj_id_to_tracker_score_frame_wise" + ][frame_idx], + } + suppressed_obj_ids = tracker_metadata["rank0_metadata"][ + "suppressed_obj_ids" + ][frame_idx] + self._cache_frame_outputs( + inference_state, + frame_idx, + obj_id_to_mask, + suppressed_obj_ids=suppressed_obj_ids, + ) + suppressed_obj_ids = tracker_metadata["rank0_metadata"][ + "suppressed_obj_ids" + ][frame_idx] + yield ( + frame_idx, + self._postprocess_output( + inference_state, out, suppressed_obj_ids=suppressed_obj_ids + ), + ) + else: + yield frame_idx, None + + def add_action_history( + self, inference_state, action_type, frame_idx=None, obj_ids=None + ): + """ + action_history is used to automatically decide what to do during propagation. + action_type: one of ["add", "remove", "refine"] + ["propagation_full", "propagation_partial", "propagation_fetch"] + """ + instance_actions = ["add", "remove", "refine"] + propagation_actions = [ + "propagation_full", + "propagation_partial", + "propagation_fetch", + ] + assert ( + action_type in instance_actions + propagation_actions + ), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}" + action = { + "type": action_type, + "frame_idx": frame_idx, + "obj_ids": obj_ids, + } + inference_state["action_history"].append(action) + + def _has_object_been_refined(self, inference_state, obj_id): + action_history = inference_state["action_history"] + for action in action_history: + if action["type"] in ["add", "refine"] and action.get("obj_ids"): + if obj_id in action["obj_ids"]: + return True + return False + + def parse_action_history_for_propagation(self, inference_state): + """ + Parse the actions in history before the last propagation and prepare for the next propagation. + We support multiple actions (add/remove/refine) between two propagations. If we had an action + history similar to this ["propagate", "add", "refine", "remove", "add"], the next propagation + would remove the removed object, and also propagate the two added/refined objects. + + Returns: + propagation_type: one of ["propagation_full", "propagation_partial", "propagation_fetch"] + - "propagation_full": run VG propagation for all objects + - "propagation_partial": run Tracker propagation for selected objects, useful for add/refine actions + - "propagation_fetch": fetch existing VG predictions without running any propagation + obj_ids: list of object ids to run Tracker propagation on if propagation_type is "propagation_partial". + """ + action_history = inference_state["action_history"] + if len(action_history) == 0: + # we run propagation for the first time + return "propagation_full", None + + if "propagation" in action_history[-1]["type"]: + if action_history[-1]["type"] in ["propagation_fetch"]: + # last propagation is direct fetch, we fetch existing predictions + return "propagation_fetch", None + elif action_history[-1]["type"] in [ + "propagation_partial", + "propagation_full", + ]: + # we do fetch prediction if we have already run propagation twice or we have run + # propagation once and it is from the first frame or last frame. + if ( + len(action_history) > 1 + and action_history[-2]["type"] + in ["propagation_partial", "propagation_full"] + ) or action_history[-1]["frame_idx"] in [ + 0, + inference_state["num_frames"] - 1, + ]: + # we have run both forward and backward partial/full propagation + return "propagation_fetch", None + else: + # we have run partial/full forward or backward propagation once, need run it for the rest of the frames + return action_history[-1]["type"], action_history[-1]["obj_ids"] + + # parse actions since last propagation + obj_ids = [] + for action in action_history[::-1]: + if "propagation" in action["type"]: + # we reached the last propagation action, stop parsing + break + if action["type"] in ["add", "refine"]: + obj_ids.extend(action["obj_ids"]) + # else action["type"] == "remove": noop + obj_ids = list(set(obj_ids)) if len(obj_ids) > 0 else None + propagation_type = ( + "propagation_partial" if obj_ids is not None else "propagation_fetch" + ) + return propagation_type, obj_ids + + def remove_object(self, inference_state, obj_id, is_user_action=False): + """ + We try to remove object from tracker states on every GPU, it will do nothing + for states without this object. + """ + obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id) + assert obj_rank is not None, f"Object {obj_id} not found in any GPU." + + tracker_states_local = inference_state["tracker_inference_states"] + if self.rank == obj_rank: + self._tracker_remove_object(tracker_states_local, obj_id) + + if is_user_action: + self.add_action_history( + inference_state, action_type="remove", obj_ids=[obj_id] + ) + + # update metadata + tracker_metadata = inference_state["tracker_metadata"] + _obj_ids = tracker_metadata["obj_ids_per_gpu"][obj_rank] + tracker_metadata["obj_ids_per_gpu"][obj_rank] = _obj_ids[_obj_ids != obj_id] + tracker_metadata["num_obj_per_gpu"][obj_rank] = len( + tracker_metadata["obj_ids_per_gpu"][obj_rank] + ) + tracker_metadata["obj_ids_all_gpu"] = np.concatenate( + tracker_metadata["obj_ids_per_gpu"] + ) + tracker_metadata["obj_id_to_score"].pop(obj_id, None) + # tracker_metadata["max_obj_id"] # we do not reuse the object id, so we do not update it here + + # Clean up cached frame outputs to remove references to the deleted object + if "cached_frame_outputs" in inference_state: + for frame_idx in inference_state["cached_frame_outputs"]: + frame_cache = inference_state["cached_frame_outputs"][frame_idx] + if obj_id in frame_cache: + del frame_cache[obj_id] + + def _get_gpu_id_by_obj_id(self, inference_state, obj_id): + """ + Locate GPU ID for a given object. + """ + obj_ids_per_gpu = inference_state["tracker_metadata"]["obj_ids_per_gpu"] + for rank, obj_ids in enumerate(obj_ids_per_gpu): + if obj_id in obj_ids: + return rank + return None # object not found in any GPU + + def _get_tracker_inference_states_by_obj_ids(self, inference_state, obj_ids): + """ + Get the Tracker inference states that contain the given object ids. + This is used to run partial Tracker propagation on a single object/bucket. + Possibly multiple or zero states can be returned. + """ + states = [ + state + for state in inference_state["tracker_inference_states"] + if set(obj_ids) & set(state["obj_ids"]) + ] + return states + + def _prepare_backbone_feats(self, inference_state, frame_idx, reverse): + input_batch = inference_state["input_batch"] + feature_cache = inference_state["feature_cache"] + num_frames = inference_state["num_frames"] + geometric_prompt = ( + inference_state["constants"]["empty_geometric_prompt"] + if inference_state["per_frame_geometric_prompt"][frame_idx] is None + else inference_state["per_frame_geometric_prompt"][frame_idx] + ) + _ = self.run_backbone_and_detection( + frame_idx=frame_idx, + num_frames=num_frames, + input_batch=input_batch, + geometric_prompt=geometric_prompt, + feature_cache=feature_cache, + reverse=reverse, + allow_new_detections=True, + ) + + @torch.inference_mode() + def add_prompt( + self, + inference_state, + frame_idx, + text_str=None, + boxes_xywh=None, + box_labels=None, + points=None, + point_labels=None, + obj_id=None, + rel_coordinates=True, + ): + if points is not None: + # Tracker instance prompts + assert ( + text_str is None and boxes_xywh is None + ), "When points are provided, text_str and boxes_xywh must be None." + assert ( + obj_id is not None + ), "When points are provided, obj_id must be provided." + return self.add_tracker_new_points( + inference_state, + frame_idx, + obj_id=obj_id, + points=points, + labels=point_labels, + rel_coordinates=rel_coordinates, + use_prev_mem_frame=self.use_prev_mem_frame, + ) + else: + # SAM3 prompts + return super().add_prompt( + inference_state, + frame_idx, + text_str=text_str, + boxes_xywh=boxes_xywh, + box_labels=box_labels, + ) + + @torch.inference_mode() + def add_tracker_new_points( + self, + inference_state, + frame_idx, + obj_id, + points, + labels, + rel_coordinates=True, + use_prev_mem_frame=False, + ): + """Add a new point prompt to Tracker. Suppporting instance refinement to existing + objects by passing existing obj_id or adding a new object by passing a new obj_id. + use_prev_mem_frame=False to disable cross attention to previous memory frames. + Every GPU returns the same results, and results should contain all masks including + these masks not refined or not added by the current user points. + """ + assert obj_id is not None, "obj_id must be provided to add new points" + tracker_metadata = inference_state["tracker_metadata"] + if tracker_metadata == {}: + # initialize masklet metadata if it's uninitialized (empty dict) + tracker_metadata.update(self._initialize_metadata()) + + obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id) + + # prepare feature + self._prepare_backbone_feats(inference_state, frame_idx, reverse=False) + + object_has_been_refined = self._has_object_been_refined(inference_state, obj_id) + if ( + obj_rank is not None + and self.use_stateless_refinement + and not object_has_been_refined + ): + # The first time we start refinement on the object, we remove it. + logger.debug( + f"[rank={self.rank}] Removing object {obj_id} before refinement." + ) + self.remove_object(inference_state, obj_id, is_user_action=False) + obj_rank = None + + if obj_rank is None: + # new object, we assign it a GPU and create a new inference state if limit allows + num_prev_obj = np.sum(tracker_metadata["num_obj_per_gpu"]) + if num_prev_obj >= self.max_num_objects: + logger.warning( + f"add_tracker_new_points: cannot add a new object as we are already tracking {num_prev_obj=} " + f"masklets (under {self.max_num_objects=})" + ) + obj_ids = [] + H_low_res = W_low_res = self.tracker.low_res_mask_size + H_video_res = inference_state["orig_height"] + W_video_res = inference_state["orig_width"] + low_res_masks = torch.zeros(0, 1, H_low_res, W_low_res) + video_res_masks = torch.zeros(0, 1, H_video_res, W_video_res) + return frame_idx, obj_ids, low_res_masks, video_res_masks + + new_det_gpu_ids = self._assign_new_det_to_gpus( + new_det_num=1, + prev_workload_per_gpu=tracker_metadata["num_obj_per_gpu"], + ) + obj_rank = new_det_gpu_ids[0] + + # get tracker inference state for the new object + if self.rank == obj_rank: + # for batched inference, we create a new inference state + tracker_state = self._init_new_tracker_state(inference_state) + inference_state["tracker_inference_states"].append(tracker_state) + + # update metadata + tracker_metadata["obj_ids_per_gpu"][obj_rank] = np.concatenate( + [ + tracker_metadata["obj_ids_per_gpu"][obj_rank], + np.array([obj_id], dtype=np.int64), + ] + ) + tracker_metadata["num_obj_per_gpu"][obj_rank] = len( + tracker_metadata["obj_ids_per_gpu"][obj_rank] + ) + tracker_metadata["obj_ids_all_gpu"] = np.concatenate( + tracker_metadata["obj_ids_per_gpu"] + ) + tracker_metadata["max_obj_id"] = max(tracker_metadata["max_obj_id"], obj_id) + + logger.debug( + f"[rank={self.rank}] Adding new object with id {obj_id} at frame {frame_idx}." + ) + self.add_action_history( + inference_state, "add", frame_idx=frame_idx, obj_ids=[obj_id] + ) + else: + # existing object, for refinement + if self.rank == obj_rank: + tracker_states = self._get_tracker_inference_states_by_obj_ids( + inference_state, [obj_id] + ) + assert ( + len(tracker_states) == 1 + ), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id." + tracker_state = tracker_states[0] + + # log + logger.debug( + f"[rank={self.rank}] Refining existing object with id {obj_id} at frame {frame_idx}." + ) + self.add_action_history( + inference_state, "refine", frame_idx=frame_idx, obj_ids=[obj_id] + ) + + # assign higher score to added/refined object + tracker_metadata["obj_id_to_score"][obj_id] = 1.0 + tracker_metadata["obj_id_to_tracker_score_frame_wise"][frame_idx][obj_id] = 1.0 + + if self.rank == 0: + rank0_metadata = tracker_metadata.get("rank0_metadata", {}) + + if "removed_obj_ids" in rank0_metadata: + rank0_metadata["removed_obj_ids"].discard(obj_id) + + if "suppressed_obj_ids" in rank0_metadata: + for frame_id in rank0_metadata["suppressed_obj_ids"]: + rank0_metadata["suppressed_obj_ids"][frame_id].discard(obj_id) + + if "masklet_confirmation" in rank0_metadata: + obj_ids_all_gpu = tracker_metadata["obj_ids_all_gpu"] + obj_indices = np.where(obj_ids_all_gpu == obj_id)[0] + if len(obj_indices) > 0: + obj_idx = obj_indices[0] + if obj_idx < len(rank0_metadata["masklet_confirmation"]["status"]): + rank0_metadata["masklet_confirmation"]["status"][obj_idx] = 1 + rank0_metadata["masklet_confirmation"]["consecutive_det_num"][ + obj_idx + ] = self.masklet_confirmation_consecutive_det_thresh + + if self.rank == obj_rank: + frame_idx, obj_ids, low_res_masks, video_res_masks = ( + self.tracker.add_new_points( + inference_state=tracker_state, + frame_idx=frame_idx, + obj_id=obj_id, + points=points, + labels=labels, + clear_old_points=True, + rel_coordinates=rel_coordinates, + use_prev_mem_frame=use_prev_mem_frame, + ) + ) + + if video_res_masks is not None and len(video_res_masks) > 0: + video_res_masks = fill_holes_in_mask_scores( + video_res_masks, # shape (N, 1, H_video, W_video) + max_area=self.fill_hole_area, + fill_holes=True, + remove_sprinkles=True, + ) + + # Since the mem encoder has already run for the current input points? + self.tracker.propagate_in_video_preflight( + tracker_state, run_mem_encoder=True + ) + # Clear detector conditioning frames when user clicks are received to allow + # model updating masks on these frames. It is a noop if user is refining on the + # detector conditioning frames or adding new objects. + self.clear_detector_added_cond_frame_in_tracker( + tracker_state, obj_id, frame_idx + ) + + # fetch results from states and gather across GPUs + # Use optimized caching approach to avoid reprocessing unmodified objects + if self.rank == obj_rank and len(obj_ids) > 0: + new_mask_data = (video_res_masks[obj_ids.index(obj_id)] > 0.0).to( + torch.bool + ) + else: + new_mask_data = None + # Broadcast the new mask data across all ranks for consistency + if self.world_size > 1: + data_list = [new_mask_data.cpu() if new_mask_data is not None else None] + self.broadcast_python_obj_cpu(data_list, src=obj_rank) + new_mask_data = data_list[0].to(self.device) + + if self.rank == 0: + obj_id_to_mask = self._build_tracker_output( + inference_state, + frame_idx, + {obj_id: new_mask_data} if new_mask_data is not None else None, + ) + # post processing - remove suppressed obj_ids + obj_id_to_score = tracker_metadata["obj_id_to_score"] + suppressed_obj_ids = tracker_metadata["rank0_metadata"][ + "suppressed_obj_ids" + ][frame_idx] + obj_id_to_tracker_score = tracker_metadata[ + "obj_id_to_tracker_score_frame_wise" + ][frame_idx] + + out = { + "obj_id_to_mask": obj_id_to_mask, + "obj_id_to_score": obj_id_to_score, + "obj_id_to_tracker_score": obj_id_to_tracker_score, + } + self._cache_frame_outputs( + inference_state, + frame_idx, + obj_id_to_mask, + suppressed_obj_ids=suppressed_obj_ids, + ) + return frame_idx, self._postprocess_output( + inference_state, out, suppressed_obj_ids=suppressed_obj_ids + ) + else: + return frame_idx, None # no output on other GPUs + + def _gather_obj_id_to_mask_across_gpus(self, inference_state, obj_id_to_mask_local): + """Gather obj_id_to_mask from all GPUs. Optionally resize the masks to the video resolution.""" + tracker_metadata = inference_state["tracker_metadata"] + + # concatenate the output masklets from all local inference states + H_mask = W_mask = self.tracker.low_res_mask_size + obj_ids_local = tracker_metadata["obj_ids_per_gpu"][self.rank] + low_res_masks_local = [] + for obj_id in obj_ids_local: + if obj_id in obj_id_to_mask_local: + low_res_masks_local.append(obj_id_to_mask_local[obj_id]) + else: + low_res_masks_local.append( + torch.full((H_mask, W_mask), -1024.0, device=self.device) + ) + if len(low_res_masks_local) > 0: + low_res_masks_local = torch.stack(low_res_masks_local, dim=0) # (N, H, W) + assert low_res_masks_local.shape[1:] == (H_mask, W_mask) + else: + low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device) + + # all-gather `low_res_masks_local` into `low_res_masks_global` + # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask) + if self.world_size > 1: + low_res_masks_local = low_res_masks_local.float().contiguous() + low_res_masks_peers = [ + low_res_masks_local.new_empty(num_obj, H_mask, W_mask) + for num_obj in tracker_metadata["num_obj_per_gpu"] + ] + dist.all_gather(low_res_masks_peers, low_res_masks_local) + low_res_masks_global = torch.cat(low_res_masks_peers, dim=0) + else: + low_res_masks_global = low_res_masks_local + return low_res_masks_global + + def _convert_low_res_mask_to_video_res(self, low_res_mask, inference_state): + """ + Convert a low-res mask to video resolution, matching the format expected by _build_tracker_output. + + Args: + low_res_mask: Tensor of shape (H_low_res, W_low_res) + inference_state: Contains video dimensions + + Returns: + video_res_mask: Tensor of shape (1, H_video, W_video) bool + """ + if low_res_mask is None: + return None + + # Convert to 3D for interpolation: (H_low_res, W_low_res) -> (1, H_low_res, W_low_res) + low_res_mask_3d = low_res_mask.unsqueeze(0).unsqueeze(0) + + # Get video dimensions + H_video = inference_state["orig_height"] + W_video = inference_state["orig_width"] + + video_res_mask = F.interpolate( + low_res_mask_3d.float(), + size=(H_video, W_video), + mode="bilinear", + align_corners=False, + ) # (1, H_video, W_video) + + # Convert to boolean - already in the right shape! + return (video_res_mask.squeeze(0) > 0.0).to(torch.bool) + + def clear_detector_added_cond_frame_in_tracker( + self, tracker_state, obj_id, refined_frame_idx + ): + """Clear detector added conditioning frame if it is within a predefined window + of the refined frame. This allow model to update masks on these frames.""" + obj_idx = self.tracker._obj_id_to_idx(tracker_state, obj_id) + + mask_only_cond_frame_indices = [] + window = self.refinement_detector_cond_frame_removal_window + for frame_idx in tracker_state["mask_inputs_per_obj"][obj_idx]: + if frame_idx not in tracker_state["point_inputs_per_obj"][obj_idx]: + # clear conditioning frames within a window of the refined frame + if abs(frame_idx - refined_frame_idx) <= window: + mask_only_cond_frame_indices.append(frame_idx) + + # clear + if len(mask_only_cond_frame_indices) > 0: + for frame_idx in mask_only_cond_frame_indices: + # obj_ids_on_this_frame is essentially all obj_ids in the state + # since they are bucket batched + obj_ids_on_this_frame = tracker_state["obj_id_to_idx"].keys() + for obj_id2 in obj_ids_on_this_frame: + self.tracker.clear_all_points_in_frame( + tracker_state, frame_idx, obj_id2, need_output=False + ) + logger.debug( + f"Cleared detector mask only conditioning frames ({mask_only_cond_frame_indices}) in Tracker." + ) + return + + +def is_image_type(resource_path: str) -> bool: + if isinstance(resource_path, list): + return len(resource_path) == 1 + return resource_path.lower().endswith(tuple(IMAGE_EXTS)) diff --git a/third_party/sam3/sam3/model/sam3_video_predictor.py b/third_party/sam3/sam3/model/sam3_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..a7660af4dd30162cf8d23bb712f85217370e8c67 --- /dev/null +++ b/third_party/sam3/sam3/model/sam3_video_predictor.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import datetime +import gc +import multiprocessing as mp +import os +import queue +import socket +import sys +import time +import uuid +from contextlib import closing +from typing import List, Optional + +import psutil +import torch +from sam3.logger import get_logger +from sam3.model.sam3_base_predictor import Sam3BasePredictor + +logger = get_logger(__name__) + + +class Sam3VideoPredictor(Sam3BasePredictor): + def __init__( + self, + checkpoint_path=None, + bpe_path=None, + has_presence_token=True, + geo_encoder_use_img_cross_attn=True, + strict_state_dict_loading=True, + async_loading_frames=False, + video_loader_type="cv2", + apply_temporal_disambiguation: bool = True, + compile: bool = False, + ): + super().__init__() + self.async_loading_frames = async_loading_frames + self.video_loader_type = video_loader_type + from sam3.model_builder import build_sam3_video_model + + self.model = ( + build_sam3_video_model( + checkpoint_path=checkpoint_path, + bpe_path=bpe_path, + has_presence_token=has_presence_token, + geo_encoder_use_img_cross_attn=geo_encoder_use_img_cross_attn, + strict_state_dict_loading=strict_state_dict_loading, + apply_temporal_disambiguation=apply_temporal_disambiguation, + compile=compile, + ) + .cuda() + .eval() + ) + + def remove_object( + self, + session_id: str, + frame_idx: int = 0, + obj_id: int = 0, + is_user_action: bool = True, + ): + """Remove an object from tracking (SAM3 uses a simpler remove_object API).""" + session = self._get_session(session_id) + inference_state = session["state"] + + self.model.remove_object( + inference_state=inference_state, + obj_id=obj_id, + is_user_action=is_user_action, + ) + return {"is_success": True} + + def _get_session_stats(self): + """Get a statistics string for live sessions and their GPU usage.""" + live_session_strs = [] + for sid, s in self._all_inference_states.items(): + nf = s["state"]["num_frames"] + live_session_strs.append(f"'{sid}' ({nf} frames)") + joined = ", ".join(live_session_strs) + mem_alloc = torch.cuda.memory_allocated() // 1024**2 + mem_res = torch.cuda.memory_reserved() // 1024**2 + max_alloc = torch.cuda.max_memory_allocated() // 1024**2 + max_res = torch.cuda.max_memory_reserved() // 1024**2 + return ( + f"live sessions: [{joined}], GPU memory: " + f"{mem_alloc} MiB used and {mem_res} MiB reserved" + f" (max over time: {max_alloc} MiB used and {max_res} MiB reserved)" + ) + + def _get_torch_and_gpu_properties(self): + """Get a string for PyTorch and GPU properties.""" + return ( + f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, " + f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}" + ) + + +class Sam3VideoPredictorMultiGPU(Sam3VideoPredictor): + def __init__(self, *model_args, gpus_to_use=None, **model_kwargs): + if gpus_to_use is None: + # if not specified, use only the current GPU by default + gpus_to_use = [torch.cuda.current_device()] + + IS_MAIN_PROCESS = os.getenv("IS_MAIN_PROCESS", "1") == "1" + if IS_MAIN_PROCESS: + gpus_to_use = sorted(set(gpus_to_use)) + logger.info(f"using the following GPU IDs: {gpus_to_use}") + assert len(gpus_to_use) > 0 and all(isinstance(i, int) for i in gpus_to_use) + assert all(0 <= i < torch.cuda.device_count() for i in gpus_to_use) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = f"{self._find_free_port()}" + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = f"{len(gpus_to_use)}" + + self.gpus_to_use = gpus_to_use + self.rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + self.rank_str = f"rank={self.rank} with world_size={self.world_size}" + self.device = torch.device(f"cuda:{self.gpus_to_use[self.rank]}") + torch.cuda.set_device(self.device) + self.has_shutdown = False + if self.rank == 0: + logger.info("\n\n\n\t*** START loading model on all ranks ***\n\n") + + logger.info(f"loading model on {self.rank_str} -- this could take a while ...") + super().__init__(*model_args, **model_kwargs) + logger.info(f"loading model on {self.rank_str} -- DONE locally") + + if self.world_size > 1 and self.rank == 0: + # start the worker processes *after* the model is loaded in the main process + # so that the main process can run torch.compile and fill the cache first + self._start_worker_processes(*model_args, **model_kwargs) + for rank in range(1, self.world_size): + self.command_queues[rank].put(("start_nccl_process_group", None)) + self._start_nccl_process_group() + + if self.rank == 0: + logger.info("\n\n\n\t*** DONE loading model on all ranks ***\n\n") + + @torch.inference_mode() + def handle_request(self, request): + """Dispatch a request based on its type.""" + if self.has_shutdown: + raise RuntimeError( + "cannot handle request after the predictor has shutdown; please create a new predictor" + ) + + # when starting a session, we need to create a session id before dispatching + # the request to the workers + if request["type"] == "start_session" and request.get("session_id") is None: + request["session_id"] = str(uuid.uuid4()) + # dispatch the request to all worker processes + if self.world_size > 1 and self.rank == 0: + for rank in range(1, self.world_size): + self.command_queues[rank].put((request, False)) + + response = super().handle_request(request) + + if self.world_size > 1: + torch.distributed.barrier() # wait for all ranks to finish + return response + + @torch.inference_mode() + def handle_stream_request(self, request): + """Dispatch a stream request based on its type.""" + if self.has_shutdown: + raise RuntimeError( + "cannot handle request after the predictor has shutdown; please create a new predictor" + ) + + # dispatch the request to all worker processes + if self.world_size > 1 and self.rank == 0: + for rank in range(1, self.world_size): + self.command_queues[rank].put((request, True)) + + yield from super().handle_stream_request(request) + + if self.world_size > 1: + torch.distributed.barrier() # wait for all ranks to finish + + def _start_worker_processes(self, *model_args, **model_kwargs): + """Start worker processes for handling model inference.""" + world_size = self.world_size + logger.info(f"spawning {world_size - 1} worker processes") + # Use "spawn" (instead of "fork") for different PyTorch or CUDA context + mp_ctx = mp.get_context("spawn") + self.command_queues = {rank: mp_ctx.Queue() for rank in range(1, world_size)} + self.result_queues = {rank: mp_ctx.Queue() for rank in range(1, world_size)} + parent_pid = os.getpid() + for rank in range(1, world_size): + # set the environment variables for each worker process + os.environ["IS_MAIN_PROCESS"] = "0" # mark this as a worker process + os.environ["RANK"] = f"{rank}" + worker_process = mp_ctx.Process( + target=Sam3VideoPredictorMultiGPU._worker_process_command_loop, + args=( + rank, + world_size, + self.command_queues[rank], + self.result_queues[rank], + model_args, + model_kwargs, + self.gpus_to_use, + parent_pid, + ), + daemon=True, + ) + worker_process.start() + # revert the environment variables for the main process + os.environ["IS_MAIN_PROCESS"] = "1" + os.environ["RANK"] = "0" + # wait for all the worker processes to load the model and collect their PIDs + self.worker_pids = {} + for rank in range(1, self.world_size): + # a large timeout to cover potentially long model loading time due to compilation + _, worker_pid = self.result_queues[rank].get(timeout=7200) + self.worker_pids[rank] = worker_pid + logger.info(f"spawned {world_size - 1} worker processes") + + def _start_nccl_process_group(self): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + if world_size == 1: + return + + logger.debug(f"starting NCCL process group on {rank=} with {world_size=}") + assert not torch.distributed.is_initialized() + # use the "env://" init method with environment variables set in start_worker_processes + # a short 3-min timeout to quickly detect any synchronization failures + timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180")) + timeout = datetime.timedelta(seconds=timeout_sec) + torch.distributed.init_process_group( + backend="nccl", + init_method="env://", + timeout=timeout, + device_id=self.device, + ) + # warm-up the NCCL process group by running a dummy all-reduce + tensor = torch.ones(1024, 1024).cuda() + torch.distributed.all_reduce(tensor) + logger.debug(f"started NCCL process group on {rank=} with {world_size=}") + + def _find_free_port(self) -> int: + """ + Find a free port (a random free port from 1024 to 65535 will be selected) + https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number) + """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + @staticmethod + def _worker_process_command_loop( + rank, + world_size, + command_queue, + result_queue, + model_args, + model_kwargs, + gpus_to_use, + parent_pid, + ): + """ + The command loop for each worker process. It listens to commands from the main process + and executes them using the model. + """ + logger.info(f"starting worker process {rank=} with {world_size=}") + # verify that the environment variables are set correctly + assert int(os.environ["IS_MAIN_PROCESS"]) == 0 + assert int(os.environ["RANK"]) == rank + assert int(os.environ["WORLD_SIZE"]) == world_size + # load the model in this worker process + predictor = Sam3VideoPredictorMultiGPU( + *model_args, gpus_to_use=gpus_to_use, **model_kwargs + ) + logger.info(f"started worker {rank=} with {world_size=}") + # return the worker process id to the main process for bookkeeping + worker_pid = os.getpid() + result_queue.put(("load_model", worker_pid)) + + # wait for the command to start the NCCL process group + request_type, _ = command_queue.get(timeout=7200) + assert request_type == "start_nccl_process_group" + predictor._start_nccl_process_group() + + # keep listening to commands from the main process + while True: + try: + request, is_stream_request = command_queue.get(timeout=5.0) + if request == "shutdown": + logger.info(f"worker {rank=} shutting down") + torch.distributed.destroy_process_group() + result_queue.put(("shutdown", True)) # acknowledge the shutdown + sys.exit(0) + + logger.debug(f"worker {rank=} received request {request['type']=}") + if is_stream_request: + for _ in predictor.handle_stream_request(request): + pass # handle stream requests in a generator fashion + else: + predictor.handle_request(request) + except queue.Empty: + # Usually Python's multiprocessing module will shutdown all the daemon worker + # processes when the main process exits gracefully. However, the user may kill + # the main process using SIGKILL and thereby leaving no chance for the main process + # to clean up its daemon child processes. So here we manually check whether the + # parent process still exists (every 5 sec as in `command_queue.get` timeout). + if not psutil.pid_exists(parent_pid): + logger.info( + f"stopping worker {rank=} as its parent process has exited" + ) + sys.exit(1) + except Exception as e: + logger.error(f"worker {rank=} exception: {e}", exc_info=True) + + def shutdown(self): + """Shutdown all worker processes.""" + if self.rank == 0 and self.world_size > 1: + logger.info(f"shutting down {self.world_size - 1} worker processes") + for rank in range(1, self.world_size): + self.command_queues[rank].put(("shutdown", False)) + torch.distributed.destroy_process_group() + for rank in range(1, self.world_size): + self.result_queues[rank].get() # wait for the worker to acknowledge + logger.info(f"shut down {self.world_size - 1} worker processes") + self.has_shutdown = True + + super().shutdown() diff --git a/third_party/sam3/sam3/model/text_encoder_ve.py b/third_party/sam3/sam3/model/text_encoder_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..53ddd5d41f312c0ea9608b37df75e75f2677346a --- /dev/null +++ b/third_party/sam3/sam3/model/text_encoder_ve.py @@ -0,0 +1,330 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from collections import OrderedDict +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from .model_misc import LayerScale + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: Optional[float] = None, + act_layer: Callable[[], nn.Module] = nn.GELU, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + ): + super().__init__() + # Attention + self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True) + + # LayerNorm, LayerScale + self.ln_1 = norm_layer(d_model) + self.ln_2 = norm_layer(d_model) + + self.ls_1 = ( + LayerScale(d_model, ls_init_value) + if ls_init_value is not None + else nn.Identity() + ) + self.ls_2 = ( + LayerScale(d_model, ls_init_value) + if ls_init_value is not None + else nn.Identity() + ) + + # MLP + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + if attn_mask is not None: + # Leave boolean masks as is + if not attn_mask.dtype == torch.bool: + attn_mask = attn_mask.to(q_x.dtype) + + return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + k_x = ( + self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + ) + v_x = ( + self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + ) + x = q_x + self.ls_1( + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + ) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: Optional[float] = None, + act_layer: Callable[[], nn.Module] = nn.GELU, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + compile_mode: Optional[str] = None, + use_act_checkpoint: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = use_act_checkpoint + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + for _ in range(layers) + ] + ) + + if compile_mode is not None: + self.forward = torch.compile( + self.forward, mode=compile_mode, fullgraph=True + ) + if self.grad_checkpointing: + torch._dynamo.config.optimize_ddp = False + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for _, r in enumerate(self.resblocks): + if ( + self.grad_checkpointing + and not torch.jit.is_scripting() + and self.training + ): + x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False) + else: + x = r( + x, + attn_mask=attn_mask, + ) + return x + + +def text_global_pool( + x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax" +) -> Tuple[torch.Tensor, torch.Tensor]: + if pool_type == "first": + pooled, tokens = x[:, 0], x[:, 1:] + elif pool_type == "last": + pooled, tokens = x[:, -1], x[:, :-1] + elif pool_type == "argmax": + # take features from the eot embedding (eot_token is the highest number in each sequence) + assert text is not None + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + else: + pooled = tokens = x + return pooled, tokens + + +class TextTransformer(nn.Module): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + mlp_ratio: float = 4.0, + ls_init_value: Optional[float] = None, + output_dim: int = 512, + no_causal_mask: bool = False, + pool_type: str = "none", # no pooling + proj_bias: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + output_tokens: bool = False, + use_ln_post: bool = True, + compile_mode: Optional[str] = None, + use_act_checkpoint: bool = False, + ): + super().__init__() + assert pool_type in ("first", "last", "argmax", "none") + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pool_type = pool_type + + self.token_embedding = nn.Embedding(self.vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + compile_mode=compile_mode, + use_act_checkpoint=use_act_checkpoint, + ) + self.ln_final = norm_layer(width) if use_ln_post else nn.Identity() + if no_causal_mask: + self.attn_mask = None + else: + self.register_buffer( + "attn_mask", self.build_causal_mask(), persistent=False + ) + if proj_bias: + self.text_projection = nn.Linear(width, output_dim) + else: + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def build_causal_mask(self) -> torch.Tensor: + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward( + self, text: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + seq_len = text.shape[1] + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + attn_mask = self.attn_mask + if attn_mask is not None: + attn_mask = attn_mask[:seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len] + x = self.transformer(x, attn_mask=attn_mask) + + x = self.ln_final(x) + pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type) + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + pooled = self.text_projection(pooled) + else: + pooled = pooled @ self.text_projection + if self.output_tokens: + return pooled, tokens + return pooled + + +class VETextEncoder(nn.Module): + def __init__( + self, + d_model: int, + tokenizer: Callable, + width: int = 1024, + heads: int = 16, + layers: int = 24, + context_length: int = 32, + vocab_size: int = 49408, + use_ln_post: bool = True, + compile_mode: Optional[str] = None, + use_act_checkpoint: bool = True, + ): + super().__init__() + self.context_length = context_length + self.use_ln_post = use_ln_post + self.tokenizer = tokenizer + + self.encoder = TextTransformer( + context_length=self.context_length, + vocab_size=vocab_size, + width=width, + heads=heads, + layers=layers, + # we want the tokens, not just the pooled output + output_tokens=True, + use_ln_post=use_ln_post, + compile_mode=compile_mode, + use_act_checkpoint=use_act_checkpoint, + ) + self.resizer = nn.Linear(self.encoder.width, d_model) + + def forward( + self, + text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]], + input_boxes: Optional[List] = None, + device: torch.device = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if isinstance(text[0], str): + # no use case for this + assert input_boxes is None or len(input_boxes) == 0, "not supported" + + # Encode the text + tokenized = self.tokenizer(text, context_length=self.context_length).to( + device + ) # [b, seq_len] + text_attention_mask = (tokenized != 0).bool() + + # manually embed the tokens + inputs_embeds = self.encoder.token_embedding( + tokenized + ) # [b, seq_len, d=1024] + _, text_memory = self.encoder(tokenized) # [b, seq_len, d=1024] + + assert text_memory.shape[1] == inputs_embeds.shape[1] + # Invert attention mask because its the opposite in pytorch transformer + text_attention_mask = text_attention_mask.ne(1) + # Transpose memory because pytorch's attention expects sequence first + text_memory = text_memory.transpose(0, 1) + # Resize the encoder hidden states to be of the same d_model as the decoder + text_memory_resized = self.resizer(text_memory) + else: + # The text is already encoded, use as is. + text_attention_mask, text_memory_resized, tokenized = text + inputs_embeds = tokenized["inputs_embeds"] + assert ( + input_boxes is None or len(input_boxes) == 0 + ), "Can't replace boxes in text if it's already encoded" + + # Note that the input_embeds are returned in pytorch's convention (sequence first) + return ( + text_attention_mask, + text_memory_resized, + inputs_embeds.transpose(0, 1), + ) diff --git a/third_party/sam3/sam3/model/tokenizer_ve.py b/third_party/sam3/sam3/model/tokenizer_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..f11fd3d92574160f8408bb366768ddc09f215089 --- /dev/null +++ b/third_party/sam3/sam3/model/tokenizer_ve.py @@ -0,0 +1,255 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +Text Tokenizer. + +Copied and lightly adapted from VE repo, which in turn copied +from open_clip and openAI CLIP. +""" + +import gzip +import html +import io +import os +import string +from functools import lru_cache +from typing import List, Optional, Union + +import ftfy +import regex as re +import torch +from iopath.common.file_io import g_pathmgr + + +# https://stackoverflow.com/q/62691279 +os.environ["TOKENIZERS_PARALLELISM"] = "false" +DEFAULT_CONTEXT_LENGTH = 77 + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def _clean_canonicalize(x): + # basic, remove whitespace, remove punctuation, lower case + return canonicalize_text(basic_clean(x)) + + +def _clean_lower(x): + # basic, remove whitespace, lower case + return whitespace_clean(basic_clean(x)).lower() + + +def _clean_whitespace(x): + # basic, remove whitespace + return whitespace_clean(basic_clean(x)) + + +def get_clean_fn(type: str): + if type == "canonicalize": + return _clean_canonicalize + elif type == "lower": + return _clean_lower + elif type == "whitespace": + return _clean_whitespace + else: + assert False, f"Invalid clean function ({type})." + + +def canonicalize_text(text, *, keep_punctuation_exact_string=None): + """Returns canonicalized `text` (lowercase and punctuation removed). + From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 + Args: + text: string to be canonicalized. + keep_punctuation_exact_string: If provided, then this exact string kept. + For example providing '{}' will keep any occurrences of '{}' (but will + still remove '{' and '}' that appear separately). + """ + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans("", "", string.punctuation)) + for part in text.split(keep_punctuation_exact_string) + ) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class SimpleTokenizer(object): + def __init__( + self, + bpe_path: Union[str, os.PathLike], + additional_special_tokens: Optional[List[str]] = None, + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, + clean: str = "lower", + ): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with g_pathmgr.open(bpe_path, "rb") as fh: + bpe_bytes = io.BytesIO(fh.read()) + merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n") + # merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + special_tokens = ["", ""] + if additional_special_tokens: + special_tokens += additional_special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + self.sot_token_id = self.all_special_ids[0] + self.eot_token_id = self.all_special_ids[1] + self.context_length = context_length + self.clean_fn = get_clean_fn(clean) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + if not pairs: + return token + "" + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = self.clean_fn(text) + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + def __call__( + self, texts: Union[str, List[str]], context_length: Optional[int] = None + ) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + context_length = context_length or self.context_length + assert context_length, "Please set a valid context length" + all_tokens = [ + [self.sot_token_id] + self.encode(text) + [self.eot_token_id] + for text in texts + ] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = self.eot_token_id + result[i, : len(tokens)] = torch.tensor(tokens) + return result diff --git a/third_party/sam3/sam3/model/utils/__init__.py b/third_party/sam3/sam3/model/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8fb7069f8380f29a7f0ae69f2c6d4e20b1ee842a --- /dev/null +++ b/third_party/sam3/sam3/model/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved +# All rights reserved. + +# pyre-unsafe + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam3/sam3/model/utils/__pycache__/__init__.cpython-311.pyc b/third_party/sam3/sam3/model/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0862710340a4dc620b35227d8ed5b26dc22f3f62 Binary files /dev/null and b/third_party/sam3/sam3/model/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/utils/__pycache__/misc.cpython-311.pyc b/third_party/sam3/sam3/model/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..517a325bc7e39461ab982db85733adb3402cd065 Binary files /dev/null and b/third_party/sam3/sam3/model/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/utils/__pycache__/sam1_utils.cpython-311.pyc b/third_party/sam3/sam3/model/utils/__pycache__/sam1_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7264f41a43b074b78455a6a4a1540649ab45a47 Binary files /dev/null and b/third_party/sam3/sam3/model/utils/__pycache__/sam1_utils.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/utils/__pycache__/sam2_utils.cpython-311.pyc b/third_party/sam3/sam3/model/utils/__pycache__/sam2_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eb9bc18df721694bccd4d641920c49bfd726a12 Binary files /dev/null and b/third_party/sam3/sam3/model/utils/__pycache__/sam2_utils.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/model/utils/misc.py b/third_party/sam3/sam3/model/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..486fa8ddc1f64e0ede2581cbba7da2ffa5e1a8ce --- /dev/null +++ b/third_party/sam3/sam3/model/utils/misc.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from collections import defaultdict +from dataclasses import fields, is_dataclass +from typing import Any, Mapping, Protocol, runtime_checkable + +import torch + + +def _is_named_tuple(x) -> bool: + return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields") + + +@runtime_checkable +class _CopyableData(Protocol): + def to(self, device: torch.device, *args: Any, **kwargs: Any): + """Copy data to the specified device""" + ... + + +def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any): + """Function that recursively copies data to a torch.device. + + Args: + data: The data to copy to device + device: The device to which the data should be copied + args: positional arguments that will be passed to the `to` call + kwargs: keyword arguments that will be passed to the `to` call + + Returns: + The data on the correct device + """ + + if _is_named_tuple(data): + return type(data)( + **copy_data_to_device(data._asdict(), device, *args, **kwargs) + ) + elif isinstance(data, (list, tuple)): + return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data) + elif isinstance(data, defaultdict): + return type(data)( + data.default_factory, + { + k: copy_data_to_device(v, device, *args, **kwargs) + for k, v in data.items() + }, + ) + elif isinstance(data, Mapping): + return type(data)( + { + k: copy_data_to_device(v, device, *args, **kwargs) + for k, v in data.items() + } + ) + elif is_dataclass(data) and not isinstance(data, type): + new_data_class = type(data)( + **{ + field.name: copy_data_to_device( + getattr(data, field.name), device, *args, **kwargs + ) + for field in fields(data) + if field.init + } + ) + for field in fields(data): + if not field.init: + setattr( + new_data_class, + field.name, + copy_data_to_device( + getattr(data, field.name), device, *args, **kwargs + ), + ) + return new_data_class + elif isinstance(data, _CopyableData): + return data.to(device, *args, **kwargs) + return data diff --git a/third_party/sam3/sam3/model/utils/sam1_utils.py b/third_party/sam3/sam3/model/utils/sam1_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee131ebc0dac82bcb08ab601787ec26d284445b --- /dev/null +++ b/third_party/sam3/sam3/model/utils/sam1_utils.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved +# All rights reserved. + +# pyre-unsafe + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +# Adapted from https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.5, 0.5, 0.5] + self.std = [0.5, 0.5, 0.5] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + from sam3.perflib.connected_components import connected_components + + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = connected_components( + (mask_flat <= self.mask_threshold).to(torch.uint8) + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = connected_components( + (mask_flat > self.mask_threshold).to(torch.uint8) + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 3 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/sam3/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/third_party/sam3/sam3/model/utils/sam2_utils.py b/third_party/sam3/sam3/model/utils/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e824fe18a63733a81f9ba334ae609482d1f3b749 --- /dev/null +++ b/third_party/sam3/sam3/model/utils/sam2_utils.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved +# All rights reserved. + +# pyre-unsafe + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.5, 0.5, 0.5), + img_std=(0.5, 0.5, 0.5), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from video_path. The frames are resized to image_size as in + the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo. + """ + is_bytes = isinstance(video_path, bytes) + is_str = isinstance(video_path, str) + is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"] + if is_bytes or is_mp4_path: + return load_video_frames_from_video_file( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + compute_device=compute_device, + ) + elif is_str and os.path.isdir(video_path): + return load_video_frames_from_jpg_images( + video_path=video_path, + image_size=image_size, + offload_video_to_cpu=offload_video_to_cpu, + img_mean=img_mean, + img_std=img_std, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + else: + raise NotImplementedError( + "Only MP4 video and JPEG folder are supported at this moment" + ) + + +def load_video_frames_from_jpg_images( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.5, 0.5, 0.5), + img_std=(0.5, 0.5, 0.5), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def load_video_frames_from_video_file( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.5, 0.5, 0.5), + img_std=(0.5, 0.5, 0.5), + compute_device=torch.device("cuda"), +): + """Load the video frames from a video file.""" + import decord + + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + # Get the original video height and width + decord.bridge.set_bridge("torch") + video_height, video_width, _ = decord.VideoReader(video_path).next().shape + # Iterate over all frames in the video + images = [] + for frame in decord.VideoReader(video_path, width=image_size, height=image_size): + images.append(frame.permute(2, 0, 1)) + + images = torch.stack(images, dim=0).float() / 255.0 + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width diff --git a/third_party/sam3/sam3/model/video_tracking_multiplex.py b/third_party/sam3/sam3/model/video_tracking_multiplex.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4b0eb1644c2e2588206ae0541bda515a95919a --- /dev/null +++ b/third_party/sam3/sam3/model/video_tracking_multiplex.py @@ -0,0 +1,3654 @@ +from collections import defaultdict + +""" +Video tracking model with multiplexing support. + +This file extends the base video tracking with prompt functionality to add: + - Multiplexing: Support for processing multiple objects simultaneously + - Recording image features in memory to support the decoupled transformer for memory reading +""" + +import logging +from copy import deepcopy + +try: + from typing import Iterable, Literal, NotRequired, Optional, Required, TypedDict +except ImportError: + from typing_extensions import ( + Iterable, + Literal, + NotRequired, # not available in Python 3.10 + Optional, + Required, # not available in Python 3.10 + TypedDict, + ) + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +from sam3.model.data_misc import BatchedDatapoint, NestedTensor +from sam3.model.memory import SimpleMaskEncoder +from sam3.model.multiplex_mask_decoder import MLP, MultiplexMaskDecoder +from sam3.model.multiplex_utils import MultiplexController, MultiplexState +from sam3.model.sam3_tracker_utils import ( + get_1d_sine_pe, + get_next_point, + sample_box_points, + select_closest_cond_frames, +) +from sam3.sam.mask_decoder import MaskDecoder +from sam3.sam.prompt_encoder import PositionEmbeddingRandom, PromptEncoder +from sam3.sam.transformer import TwoWayTransformer +from timm.models.layers import trunc_normal_ + + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + +neck_outs = ["interactive", "sam2_backbone_out"] + + +class SAMOutput(TypedDict, total=True): + # Outputs from a single SAM head forward + low_res_multimasks: torch.Tensor + high_res_multimasks: torch.Tensor + ious: torch.Tensor + low_res_masks: torch.Tensor + high_res_masks: torch.Tensor + object_score_logits: torch.Tensor + obj_ptr: NotRequired[torch.Tensor] # [num_objects, C], in data space + + +class StageOutput(TypedDict, total=False): + # metadata + conditioning_objects: Required[set[int]] + + # The outputs from a single stage; could be used as memory + pred_masks: torch.Tensor + pred_masks_high_res: torch.Tensor + point_inputs: dict[str, torch.Tensor] + mask_inputs: torch.Tensor + object_score_logits: torch.Tensor + obj_ptr: torch.Tensor # [num_buckets, multiplex_count, C], in mux space + maskmem_features: torch.Tensor + maskmem_pos_enc: list[torch.Tensor] + image_features: torch.Tensor + image_pos_enc: torch.Tensor + + # for memory filtering + iou_score: torch.Tensor + eff_iou_score: torch.Tensor + + # Multi-step prediction fields for state tracking or training + multistep_pred_masks: torch.Tensor + multistep_pred_masks_high_res: torch.Tensor + multistep_pred_multimasks: list[torch.Tensor] + multistep_pred_multimasks_high_res: list[torch.Tensor] + multistep_pred_ious: list[torch.Tensor] + multistep_point_inputs: list[dict] + multistep_object_score_logits: list[torch.Tensor] + + +class VideoTrackingMultiplex(nn.Module): + def __init__( + self, + backbone: nn.Module, + transformer: nn.Module, + maskmem_backbone: nn.Module, + multiplex_controller: MultiplexController, + num_maskmem: int = 7, # default 1 input frame + 6 previous frames as in CAE + image_size: int = 512, + backbone_stride: int = 16, # default to 16 as in CAE (truncated Hiera backbone) + prob_to_use_pt_input_for_train: float = 0.0, + prob_to_use_pt_input_for_eval: float = 0.0, + prob_to_use_box_input_for_train: float = 0.0, + prob_to_use_box_input_for_eval: float = 0.0, + # always_keep_first_frame_mem=True, # this option is removed (we've always set it to True) + apply_sigmoid_to_mask_logits_for_mem_enc: bool = False, + sigmoid_scale_for_mem_enc: float = 1.0, # scale factor for mask sigmoid prob, only effective when `apply_sigmoid_to_mask_logits_for_mem_enc` is True + sigmoid_bias_for_mem_enc: float = 0.0, # bias factor for mask sigmoid prob, only effective when `apply_sigmoid_to_mask_logits_for_mem_enc` is True + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks, only effective when `apply_sigmoid_to_mask_logits_for_mem_enc` is True + binarize_mask_from_pts_for_mem_enc: bool = False, + use_mask_input_as_output_without_sam: bool = False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # how many frames for interactive point sampling (only effective when using point inputs per video; the first frame is always used) + # - if `num_frames_to_correct` below is True, we randomly sample 1~num_frames_to_correct frames for interactive point sampling + # - otherwise we used a fixed number of num_frames_to_correct frames for interactive point sampling + # if it is 1, we do interactive point sampling only on the 1st frame + # if it is greater than 1, we interactive point sampling in the 1st frame and other randomly selected frames + num_frames_to_correct_for_train: int = 1, # default: only iteratively sample on first frame + num_frames_to_correct_for_eval: int = 1, # default: only iteratively sample on first frame + rand_frames_to_correct_for_train: bool = False, + rand_frames_to_correct_for_eval: bool = False, + prob_correct_all_objects_for_train: float = 0.0, + ratio_of_objects_to_correct_for_train: float = 1.0, + force_correct_all_for_conditional_inputs: bool = False, + rand_objects_to_correct_for_train: bool = True, + # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame) + # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames + # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames + # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`; + # these are initial conditioning frames because as we track the video, more conditioning frames might be added + # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True` + num_init_cond_frames_for_train: int = 1, # default: only use the first frame as initial conditioning frame + num_init_cond_frames_for_eval: int = 1, # default: only use the first frame as initial conditioning frame + rand_init_cond_frames_for_train: bool = True, # default: random 1~num_init_cond_frames_for_train cond frames (to be constent w/ previous TA data loader) + rand_init_cond_frames_for_eval: bool = False, + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn: int = -1, + # Whether to always keep the first conditioning frame in case we exceed the maximum number of conditioning frames allowed + keep_first_cond_frame=False, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond: bool = False, + # how many additional correction points to sample (on each frame selected to be corrected) + # note that the first frame receives an initial input click (in addition to any correction clicks) + num_correction_pt_per_frame: int = 7, + # method for point sampling during evaluation + # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary) + # default to "center" to be consistent with evaluation in the SAM paper + pt_sampling_for_eval: Literal["uniform", "center"] = "center", + # During training, we optionally allow sampling the correction points from GT regions + # instead of the prediction error regions with a small probability. This might allow the + # model to overfit less to the error regions in training datasets + prob_to_sample_from_gt_for_train: float = 0.0, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed: bool = False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam: bool = False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam: bool = False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num: int = 1, + multimask_max_pt_num: int = 1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking: bool = False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # if the last output is multimask during training, whether to select the mask w/ highest IoU to the ground-truth for memory encoder + # (instead of the mask with the highest prediction score; this resembles teacher-forcing for multi-mask prediction in tracking) + use_best_iou_mask_for_mem_enc: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid: bool = False, + # whether to feed the previously predicted low-res mask logits as a mask prompt into the SAM mask decoder during iterative point sampling + iter_use_prev_mask_pred: bool = False, + # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features + # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower. + forward_backbone_per_frame_for_eval: bool = False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval: int = 1, + # whether to offload outputs to CPU memory during evaluation, to avoid GPU OOM on very long videos or very large resolutions or too many objects + # (it's recommended to use `forward_backbone_per_frame_for_eval=True` first before setting this option to True) + offload_output_to_cpu_for_eval: bool = False, + # whether to trim the output of past non-conditioning frames (num_maskmem frames before the current frame) during evaluation + # (this helps save GPU or CPU memory on very long videos for semi-supervised VOS eval, where only the first frame receives prompts) + trim_past_non_cond_mem_for_eval: bool = False, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc: bool = False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: bool = False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder: int = 16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs: bool = True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs: bool = False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs: bool = False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval: bool = False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + use_no_obj_ptr: bool = True, + use_mlp_for_obj_ptr_proj: bool = False, + # replace per-slot static no-obj embeddings with linear projections of object embeddings + use_linear_no_obj_ptr: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, + # does not apply to spatial memories (only to obj ptrs), unless unified_tpos_enc=True + sincos_tpos_enc: bool = True, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args: Optional[dict] = None, + # whether to compile all the model compoents + compile_all_components: bool = False, + # save and use image features in the memory + save_image_features: bool = False, + # number of multimask outputs in the SAM mask decoder + num_multimask_outputs: int = 3, + # use a single mask token to predict all masks + decode_mask_with_shared_tokens: bool = False, + # use the mask token for predicting ious and object scores + decode_mask_attribute_with_shared_tokens: bool = False, + share_necks: bool = False, # share the interactive and sam2_backbone necks + # if enabled, use a different rng generator for operations that differ between GPUs, + # such that the base rng that controls flow does not go out-of-sync among GPUs + # There will be a slight performance penalty when turned off due to uneven workload but it's minor + randomness_fix: bool = False, + # add a learnable embeddings to the object queries that corresponding to paddings/removed objects + add_output_suppression_embeddings: bool = False, + # add a per-object embedding to the spatial memory features if that object is a conditioning input + add_object_conditional_embeddings: bool = False, + # if None, follow add_object_conditional_embeddings + add_object_unconditional_embeddings: Optional[bool] = None, + # for each object, add an additional channel in the mask encoder to indicate conditional/unconditional objects + condition_as_mask_input: bool = False, + condition_as_mask_input_fg: float = 1.0, + condition_as_mask_input_bg: float = 0.0, + # use v2 memory positional encodings + # in v2, the last slot in the positional encoding no longer refers to the conditional frame + # it now refers to "out-of-bound" frames. + # The motivation is to shift all encodings of "conditioning" to the object_conditional embeddings + use_maskmem_tpos_v2: bool = False, + # select the frame with object existence + use_memory_selection: bool = False, + # when using memory selection, the threshold to determine if the frame is good + mf_threshold: float = 0.01, + # this is a flag for demo purposes; it does not need to be explicitly set + is_dynamic_model: bool = False, + object_score_logit_threshold: float = 0.0, + stability_score_attentuation: bool = False, # select from multimask based on iou*stability_score + ): + super().__init__() + + # the interactive sam mask deocder can use dynamic_multimask_via_stability + interactive_sam_mask_decoder_extra_args = deepcopy(sam_mask_decoder_extra_args) + if sam_mask_decoder_extra_args is not None: + dynamic_multimask_via_stability = sam_mask_decoder_extra_args.get( + "dynamic_multimask_via_stability", False + ) + if dynamic_multimask_via_stability: + sam_mask_decoder_extra_args["dynamic_multimask_via_stability"] = False + print( + "dynamic_multimask_via_stability is reset to False in the multiplex model" + ) + + # Part 1: the image backbone + self.backbone = backbone + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the GT mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.interactive_mask_downsample = torch.nn.Conv2d( + 1, 1, kernel_size=4, stride=4 + ) + + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + self.multiplex_controller = multiplex_controller + self.save_image_features = save_image_features + self.multiplex_count = self.multiplex_controller.multiplex_count + + # Part 2: encoder-only transformer to fuse current frame's visual features + # with memories from past frames + assert transformer.decoder is None, "transformer should be encoder-only" + self.transformer = transformer + self.hidden_dim: int = transformer.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.maskmem_backbone = maskmem_backbone + self.mem_dim = self.hidden_dim + if hasattr(self.maskmem_backbone, "out_proj") and hasattr( + self.maskmem_backbone.out_proj, "weight" + ): + # if there is compression of memories along channel dim + mem_dim = self.maskmem_backbone.out_proj.weight.shape[0] + assert ( + mem_dim == self.hidden_dim + ), "there should be no compression of memory embeddings" + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.sincos_tpos_enc = sincos_tpos_enc + self.use_maskmem_tpos_v2 = use_maskmem_tpos_v2 + # tpos specific to spatial memories only + # last token actually corresponds to conditioning + # frame embedding, indep of temporal position + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + + # a single token to indicate no memory embedding from previous frames + self.interactivity_no_mem_embed = torch.nn.Parameter( + torch.zeros(1, 1, self.hidden_dim) + ) + trunc_normal_(self.interactivity_no_mem_embed, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + + # Whether to apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.apply_sigmoid_to_mask_logits_for_mem_enc = ( + apply_sigmoid_to_mask_logits_for_mem_enc + ) + if apply_sigmoid_to_mask_logits_for_mem_enc: + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + + if binarize_mask_from_pts_for_mem_enc: + logging.warning( + """ + The current model is not trained with binarize_mask_from_pts_for_mem_enc; + We force it to False here because external callers often hardcoded this + to True, ignoring the config. + Re-training should be possible. + """ + ) + binarize_mask_from_pts_for_mem_enc = False + + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.use_best_iou_mask_for_mem_enc = use_best_iou_mask_for_mem_enc + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + self.object_score_logit_threshold = object_score_logit_threshold + self.stability_score_attentuation = stability_score_attentuation + if iter_use_prev_mask_pred: + # In this case, we are feeding the previously predicted SAM mask logits + # as mask prompt into the SAM mask decoder, which has a different format + # and magnitude from GT mask input in VOS. Therefore in this case, the GT + # mask input must be encoded directly (not through the SAM mask decoder). + if min(prob_to_use_pt_input_for_train, prob_to_use_pt_input_for_eval) < 1: + assert use_mask_input_as_output_without_sam + self.iter_use_prev_mask_pred = iter_use_prev_mask_pred + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.low_res_mask_size = self.image_size // self.backbone_stride * 4 + # we resize the mask if it doesn't match `self.input_mask_size` (which is always 4x + # the low-res mask size, regardless of the actual input image size); this is because + # `_use_mask_as_output` always downsamples the input masks by 4x + self.input_mask_size = self.low_res_mask_size * 4 + self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval + self.offload_output_to_cpu_for_eval = offload_output_to_cpu_for_eval + if trim_past_non_cond_mem_for_eval: + assert ( + num_frames_to_correct_for_eval <= 1 + ), "trim_past_non_cond_mem_for_eval=True requires that only the first frame receives prompts" + self.trim_past_non_cond_mem_for_eval = trim_past_non_cond_mem_for_eval + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.interactive_sam_mask_decoder_extra_args = ( + interactive_sam_mask_decoder_extra_args + ) + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.use_no_obj_ptr = use_no_obj_ptr + self.use_linear_no_obj_ptr = use_linear_no_obj_ptr + + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if ( + self.pred_obj_scores + and self.use_obj_ptrs_in_encoder + and self.use_no_obj_ptr + ): + if self.use_linear_no_obj_ptr: + self.no_obj_ptr_linear = nn.Linear(self.hidden_dim, self.hidden_dim) + else: + self.no_obj_ptr = torch.nn.Parameter( + torch.zeros(self.multiplex_count, self.hidden_dim) + ) + trunc_normal_(self.no_obj_ptr, std=0.02) + + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter( + torch.zeros(self.multiplex_count, self.hidden_dim) + ) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + self.num_multimask_outputs = num_multimask_outputs + self.decode_mask_with_shared_tokens = decode_mask_with_shared_tokens + self.decode_mask_attribute_with_shared_tokens = ( + decode_mask_attribute_with_shared_tokens + ) + self.share_necks = share_necks + + self.add_output_suppression_embeddings = add_output_suppression_embeddings + if self.add_output_suppression_embeddings: + self.output_valid_embed = torch.nn.Parameter( + torch.zeros(self.multiplex_count, self.hidden_dim) + ) + self.output_invalid_embed = torch.nn.Parameter( + torch.zeros(self.multiplex_count, self.hidden_dim) + ) + trunc_normal_(self.output_valid_embed, std=0.02) + trunc_normal_(self.output_invalid_embed, std=0.02) + self.add_object_conditional_embeddings = add_object_conditional_embeddings + if add_object_unconditional_embeddings is None: + add_object_unconditional_embeddings = add_object_conditional_embeddings + self.add_object_unconditional_embeddings = add_object_unconditional_embeddings + if add_object_unconditional_embeddings: + assert add_object_conditional_embeddings + if self.add_object_conditional_embeddings: + # have embeddings for both conditional and non-conditional objects + # such that the features are more "balanced" + # these three sets should be disjoint and their union should cover all objects + # for conditioning objects + self.obj_cond_embed = torch.nn.Parameter( + torch.zeros(self.multiplex_count, self.hidden_dim) + ) + trunc_normal_(self.obj_cond_embed, std=0.02) + if self.add_object_unconditional_embeddings: + # for non-conditioning objects + self.obj_non_cond_embed = torch.nn.Parameter( + torch.zeros(self.multiplex_count, self.hidden_dim) + ) + trunc_normal_(self.obj_non_cond_embed, std=0.02) + + self.condition_as_mask_input = condition_as_mask_input + self.condition_as_mask_input_fg = condition_as_mask_input_fg + self.condition_as_mask_input_bg = condition_as_mask_input_bg + + self.is_dynamic_model = is_dynamic_model + + self._build_sam_heads() + + # Point sampler and conditioning frames + self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train + self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train + self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval + self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval + if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0: + logging.info("Using points (sampled from masks) as inputs") + assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train + assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval + self.num_frames_to_correct_for_train = num_frames_to_correct_for_train + self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval + self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train + self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval + self.prob_correct_all_objects_for_train = prob_correct_all_objects_for_train + self.ratio_of_objects_to_correct_for_train = ( + ratio_of_objects_to_correct_for_train + ) + self.rand_objects_to_correct_for_train = rand_objects_to_correct_for_train + self.force_correct_all_for_conditional_inputs = ( + force_correct_all_for_conditional_inputs + ) + # Initial multi-conditioning frames + self.num_init_cond_frames_for_train = num_init_cond_frames_for_train + self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval + self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train + self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval + self.max_cond_frames_in_attn = max_cond_frames_in_attn + self.keep_first_cond_frame = keep_first_cond_frame + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.num_correction_pt_per_frame = num_correction_pt_per_frame + self.pt_sampling_for_eval = pt_sampling_for_eval + self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train + # A random number generator with a fixed initial seed across GPUs + self.rng = np.random.default_rng(seed=42) + if randomness_fix: + self.rng2 = np.random.default_rng(seed=42) + else: + self.rng2 = self.rng + + # Use frame filtering according to SAM2Long + self.use_memory_selection = use_memory_selection + self.mf_threshold = mf_threshold + + # Compile all components of the model + self.compile_all_components = compile_all_components + if self.compile_all_components: + self._compile_all_components() + + def _get_tpos_enc(self, rel_pos_list, device, max_abs_pos=None, dummy=False): + if dummy: + return torch.zeros(len(rel_pos_list), self.mem_dim, device=device) + + t_diff_max = max_abs_pos - 1 if max_abs_pos is not None else 1 + pos_enc = ( + torch.tensor(rel_pos_list).pin_memory().to(device=device, non_blocking=True) + / t_diff_max + ) + if self.sincos_tpos_enc: + tpos_dim = ( + self.hidden_dim if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + ) + pos_enc = get_1d_sine_pe(pos_enc, dim=tpos_dim) + else: + raise NotImplementedError + pos_enc = self.obj_ptr_tpos_proj(pos_enc) + + return pos_enc + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + self.image_pe_layer = PositionEmbeddingRandom(self.hidden_dim // 2) + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.interactive_sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + + self.interactive_sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.interactive_sam_mask_decoder_extra_args or {}), + ) + if self.share_necks: + # we will use self.sam_mask_decoder's convs + del self.interactive_sam_mask_decoder.conv_s0 + del self.interactive_sam_mask_decoder.conv_s1 + + self.sam_mask_decoder = MultiplexMaskDecoder( + multiplex_count=self.multiplex_count, + num_multimask_outputs=self.num_multimask_outputs, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.hidden_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.hidden_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + decode_mask_with_shared_tokens=self.decode_mask_with_shared_tokens, + decode_mask_attribute_with_shared_tokens=self.decode_mask_attribute_with_shared_tokens, + multimask_outputs_only=self.num_multimask_outputs > 0 + and self.multimask_output_in_sam, + **(self.sam_mask_decoder_extra_args or {}), + ) + + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + self.interactive_obj_ptr_proj = torch.nn.Linear( + self.hidden_dim, self.hidden_dim + ) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + self.interactive_obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + self.interactive_obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _get_interactive_pix_mem( + self, features: torch.Tensor, feat_sizes: list[tuple] + ) -> torch.Tensor: + assert self.directly_add_no_mem_embed + pix_feat_with_mem = features[-1] + self.interactivity_no_mem_embed + B = features[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _forward_sam_heads( + self, + backbone_features: torch.Tensor, + *, + point_inputs: Optional[dict[str, torch.Tensor]] = None, + mask_inputs: Optional[torch.Tensor] = None, + interactive_high_res_features: Optional[list[torch.Tensor]] = None, + propagation_high_res_features: Optional[list[torch.Tensor]] = None, + multimask_output: bool = False, + gt_masks=None, + multiplex_state: MultiplexState, + objects_to_interact: Optional[list[int]] = None, + ) -> SAMOutput: + """ + Forward SAM prompt encoders and mask heads. + We run the propagation head, the interactive head, or both, based on the inputs. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [num_buckets, multiplex_count, C] shape, the object pointer vector for + the output mask, extracted based on the output token from the SAM mask decoder. + """ + + device = backbone_features.device + assert backbone_features.size(1) == self.hidden_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + is_interactive = point_inputs is not None or mask_inputs is not None + + if is_interactive: + """ + Image-level, per-object interactive path + """ + assert interactive_high_res_features is not None + assert objects_to_interact is not None + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + else: + assert mask_inputs is not None + # If no points are provided, pad with an empty point (with label -1) + sam_point_coords = torch.zeros( + mask_inputs.shape[0], 1, 2, device=device + ) + sam_point_labels = -torch.ones( + mask_inputs.shape[0], 1, dtype=torch.int32, device=device + ) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 + if ( + mask_inputs.shape[-2:] + != self.interactive_sam_prompt_encoder.mask_input_size + ): + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.interactive_sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.interactive_sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + + # Clone image_pe and the outputs of sam_prompt_encoder + # to enable compilation + sparse_embeddings = self._maybe_clone(sparse_embeddings) + dense_embeddings = self._maybe_clone(dense_embeddings) + image_pe = self._maybe_clone( + self.interactive_sam_prompt_encoder.get_dense_pe() + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.interactive_sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=True, + high_res_features=interactive_high_res_features, + ) + + else: + """ + Multiplexed propagation path + """ + assert propagation_high_res_features is not None + assert multiplex_state is not None + + if self.add_output_suppression_embeddings: + # the suppression embeddings inform the mask decoder the objects that should be decoded + output_valid_embed = self.output_valid_embed.unsqueeze(0) + output_invalid_embed = self.output_invalid_embed.unsqueeze(0) + valid_object_mask = ( + multiplex_state.get_valid_object_mask().unsqueeze(-1).float() + ) + output_merged_embed = ( + valid_object_mask * output_valid_embed + + (1 - valid_object_mask) * output_invalid_embed + ) + else: + output_merged_embed = None + + # Clone image_pe to enable compilation + image_pe = self._maybe_clone(self.get_propagation_dense_pe()) + out = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=image_pe, + high_res_features=propagation_high_res_features, + multimask_output=multimask_output, + extra_per_object_embeddings=output_merged_embed, + ) + low_res_multimasks = out["masks"] # [B, M, 3/1, H*4, W*4] + ious = out["iou_pred"] # [B, M, 3/1] + sam_output_tokens = out["sam_tokens_out"] # [B, M, 3/1, C] + object_score_logits = out["object_score_logits"] + + low_res_multimasks = multiplex_state.demux(low_res_multimasks) + ious = multiplex_state.demux(ious) + object_score_logits = multiplex_state.demux(object_score_logits) + sam_output_tokens = multiplex_state.demux(sam_output_tokens) + + """ + The interactive and the propagation paths converge here + """ + # Clone the output of sam_mask_decoder + # to enable compilation + low_res_multimasks = self._maybe_clone(low_res_multimasks) + ious = self._maybe_clone(ious) + object_score_logits = self._maybe_clone(object_score_logits) + sam_output_tokens = self._maybe_clone(sam_output_tokens) + + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > self.object_score_logit_threshold + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output and ( + not self.decode_mask_with_shared_tokens or is_interactive + ): + # take the best mask prediction (with the highest IoU estimation) + if self.stability_score_attentuation: + # prefer selecting masks with high stability score + stability_score = self.sam_mask_decoder._get_stability_scores( + low_res_multimasks + ) + ious = ious * stability_score + + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(ious.shape[0], device=device) + + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + if multimask_output and not is_interactive: + assert self.decode_mask_with_shared_tokens + low_res_masks = low_res_multimasks[:, 0:1] + high_res_masks = high_res_multimasks[:, 0:1] + else: + low_res_masks = low_res_multimasks + high_res_masks = high_res_multimasks + + # Extract object pointer from the SAM output token + if self.use_obj_ptrs_in_encoder: + if is_interactive: + obj_ptr = self.interactive_obj_ptr_proj(sam_output_token) + else: + obj_ptr = self.obj_ptr_proj(sam_output_token) + + if self.pred_obj_scores and self.use_no_obj_ptr: + lambda_is_obj_appearing = is_obj_appearing.float() + if self.use_linear_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + ( + 1 - lambda_is_obj_appearing + ) * self.no_obj_ptr_linear(obj_ptr) + else: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + + # use demux to locate the corresponding no_obj_ptr entries + selected_no_obj_ptr = self.no_obj_ptr.unsqueeze(0).repeat( + multiplex_state.num_buckets, 1, 1 + ) + selected_no_obj_ptr = multiplex_state.demux(selected_no_obj_ptr) + if is_interactive: + # if is_interactive, the object pointers are in the data space + selected_no_obj_ptr = selected_no_obj_ptr[objects_to_interact] + + obj_ptr = ( + obj_ptr + (1 - lambda_is_obj_appearing) * selected_no_obj_ptr + ) + + outputs: SAMOutput = { + "low_res_multimasks": low_res_multimasks, + "high_res_multimasks": high_res_multimasks, + "ious": ious, + "low_res_masks": low_res_masks, + "high_res_masks": high_res_masks, + "object_score_logits": object_score_logits, + } + if self.use_obj_ptrs_in_encoder: + outputs["obj_ptr"] = obj_ptr # [num_objects, C], in data space + return outputs + + def _use_mask_as_output( + self, + backbone_features: torch.Tensor, + high_res_features: list[torch.Tensor], + mask_inputs: torch.Tensor, + multiplex_state: MultiplexState, + objects_in_mask: Optional[list[int]] = None, + ) -> SAMOutput: + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + if objects_in_mask is None: + objects_in_mask = list(range(multiplex_state.total_valid_entries)) + + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.to(backbone_features.dtype) + assert mask_inputs.shape[0] == len( + objects_in_mask + ), f"{mask_inputs.shape[0]} != {len(objects_in_mask)}" + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones( + mask_inputs.size(0), 1, dtype=backbone_features.dtype + ) + + if self.use_obj_ptrs_in_encoder: + # produce an object pointer using the SAM decoder from the mask input + sam_outputs = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.interactive_mask_downsample(mask_inputs_float), + interactive_high_res_features=high_res_features, + gt_masks=mask_inputs, + objects_to_interact=objects_in_mask, + multiplex_state=multiplex_state, + ) + obj_ptr = sam_outputs["obj_ptr"] + + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + # Note that although this logic has already been applied in _forward_sam_heads + # it is ok because lambda_is_obj_appearing is binary + # when it is zero it forces no_obj_ptr + # when it is one it keeps the output from _forward_sam_heads + if self.pred_obj_scores and self.use_no_obj_ptr: + if self.use_linear_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + ( + 1 - lambda_is_obj_appearing + ) * self.no_obj_ptr_linear(obj_ptr) + else: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + # use demux to locate the corresponding no_obj_ptr entries + selected_no_obj_ptr = self.no_obj_ptr.unsqueeze(0).repeat( + multiplex_state.num_buckets, 1, 1 + ) + selected_no_obj_ptr = multiplex_state.demux(selected_no_obj_ptr) + selected_no_obj_ptr = selected_no_obj_ptr[objects_in_mask] + obj_ptr = ( + obj_ptr + (1 - lambda_is_obj_appearing) * selected_no_obj_ptr + ) + + outputs: SAMOutput = { + "low_res_multimasks": low_res_masks, + "high_res_multimasks": high_res_masks, + "ious": ious, + "low_res_masks": low_res_masks, + "high_res_masks": high_res_masks, + "object_score_logits": object_score_logits, + } + if self.use_obj_ptrs_in_encoder: + outputs["obj_ptr"] = obj_ptr # [num_objects, C], in data space + return outputs + + def forward(self, input: BatchedDatapoint, is_inference=False): + if self.training or not self.forward_backbone_per_frame_for_eval: + # precompute image features on all frames before tracking + backbone_out = self.forward_image( + input.img_batch, need_interactive_out=True, need_propagation_out=True + ) + else: + # defer image feature computation on a frame until it's being tracked + backbone_out = {} + backbone_out = self.prepare_prompt_inputs(backbone_out, input) + previous_stages_out = self.forward_tracking(backbone_out, input) + + # "None" for get_queries to be compatible with the trainer + return previous_stages_out, None + + def forward_image( + self, + img_batch, + *, + need_sam3_out: bool = False, + need_interactive_out: bool = False, + need_propagation_out: bool = False, + ): + """Get the image feature on the input batch.""" + if self.share_necks: + need_propagation_out = need_interactive_out or need_propagation_out + need_interactive_out = False + # this also means that convs for backbone_fpn are shared + backbone_out = self.backbone.forward_image( + img_batch, + need_sam3_out=need_sam3_out, + need_sam2_out=need_propagation_out, + ) + backbone_out["interactive"] = backbone_out["sam2_backbone_out"] + else: + backbone_out = self.backbone.forward_image( + img_batch, + need_sam3_out=need_sam3_out, + need_interactive_out=need_interactive_out, + need_propagation_out=need_propagation_out, + ) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + if need_interactive_out: + backbone_out["interactive"]["backbone_fpn"][ + 0 + ].tensors = self.interactive_sam_mask_decoder.conv_s0( + backbone_out["interactive"]["backbone_fpn"][0].tensors + ) + backbone_out["interactive"]["backbone_fpn"][ + 1 + ].tensors = self.interactive_sam_mask_decoder.conv_s1( + backbone_out["interactive"]["backbone_fpn"][1].tensors + ) + if need_propagation_out: + backbone_out["sam2_backbone_out"]["backbone_fpn"][ + 0 + ].tensors = self.sam_mask_decoder.conv_s0( + backbone_out["sam2_backbone_out"]["backbone_fpn"][0].tensors + ) + backbone_out["sam2_backbone_out"]["backbone_fpn"][ + 1 + ].tensors = self.sam_mask_decoder.conv_s1( + backbone_out["sam2_backbone_out"]["backbone_fpn"][1].tensors + ) + # Clone to help torch.compile + for out_type in backbone_out.keys(): + for i in range(len(backbone_out[out_type]["backbone_fpn"])): + backbone_out[out_type]["backbone_fpn"][i].tensors = self._maybe_clone( + backbone_out[out_type]["backbone_fpn"][i].tensors + ) + backbone_out[out_type]["vision_pos_enc"][i] = self._maybe_clone( + backbone_out[out_type]["vision_pos_enc"][i] + ) + return backbone_out + + def _prepare_prompt_inputs_meta(self, backbone_out, input, start_frame_idx=0): + # Load the ground-truth masks on all frames (so that we can later + # sample correction points from them) + gt_masks_per_frame = { + stage_id: targets.segments.unsqueeze(1) # [B, 1, H_im, W_im] + for stage_id, targets in enumerate(input.find_targets) + } + backbone_out["gt_masks_per_frame"] = gt_masks_per_frame + num_frames = len(input.find_targets) + backbone_out["num_frames"] = num_frames + + # Randomly decide whether to use point inputs or mask inputs + if self.training: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_train + num_frames_to_correct = self.num_frames_to_correct_for_train + rand_frames_to_correct = self.rand_frames_to_correct_for_train + num_init_cond_frames = self.num_init_cond_frames_for_train + rand_init_cond_frames = self.rand_init_cond_frames_for_train + else: + prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval + num_frames_to_correct = self.num_frames_to_correct_for_eval + rand_frames_to_correct = self.rand_frames_to_correct_for_eval + num_init_cond_frames = self.num_init_cond_frames_for_eval + rand_init_cond_frames = self.rand_init_cond_frames_for_eval + if num_frames == 1: + # here we handle a special case for mixing video + SAM on image training, + # where we force using point input for the SAM task on static images + prob_to_use_pt_input = 1.0 + num_frames_to_correct = 1 + num_init_cond_frames = 1 + assert num_init_cond_frames >= 1 + # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0) + use_pt_input = self.rng.random() < prob_to_use_pt_input + if rand_init_cond_frames and num_init_cond_frames > 1: + # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames + num_init_cond_frames = self.rng.integers( + 1, num_init_cond_frames, endpoint=True + ) + if ( + use_pt_input + and rand_frames_to_correct + and num_frames_to_correct > num_init_cond_frames + ): + # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample + # correction clicks (only for the case of point input) + num_frames_to_correct = self.rng.integers( + num_init_cond_frames, num_frames_to_correct, endpoint=True + ) + backbone_out["use_pt_input"] = use_pt_input + + # Sample initial conditioning frames + if num_init_cond_frames == 1: + init_cond_frames = [start_frame_idx] # starting frame + else: + # starting frame + randomly selected remaining frames (without replacement) + init_cond_frames = [start_frame_idx] + self.rng.choice( + range(start_frame_idx + 1, num_frames), + num_init_cond_frames - 1, + replace=False, + ).tolist() + backbone_out["init_cond_frames"] = init_cond_frames + backbone_out["frames_not_in_init_cond"] = [ + t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames + ] + + # Sample frames where we will add correction clicks on the fly + # based on the error between prediction and ground-truth masks + if not use_pt_input: + # no correction points will be sampled when using mask inputs + frames_to_add_correction_pt = [] + elif num_frames_to_correct == num_init_cond_frames: + frames_to_add_correction_pt = init_cond_frames + else: + assert num_frames_to_correct > num_init_cond_frames + # initial cond frame + randomly selected remaining frames (without replacement) + extra_num = num_frames_to_correct - num_init_cond_frames + frames_to_add_correction_pt = ( + init_cond_frames + + self.rng.choice( + backbone_out["frames_not_in_init_cond"], extra_num, replace=False + ).tolist() + ) + backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt + + return backbone_out + + def _prepare_conditional_frames(self, backbone_out): + init_cond_frames = backbone_out["init_cond_frames"] + gt_masks_per_frame = backbone_out["gt_masks_per_frame"] + use_pt_input = backbone_out["use_pt_input"] + + if self.training: + prob_to_use_box_input = self.prob_to_use_box_input_for_train + else: + prob_to_use_box_input = self.prob_to_use_box_input_for_eval + + # Prepare mask or point inputs on initial conditioning frames + backbone_out["mask_inputs_per_frame"] = {} # {frame_idx: } + backbone_out["point_inputs_per_frame"] = {} # {frame_idx: } + for t in init_cond_frames: + if not use_pt_input: + backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t] + else: + # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input + use_box_input = self.rng.random() < prob_to_use_box_input + if use_box_input: + points, labels = sample_box_points( + gt_masks_per_frame[t], + ) + else: + # (here we only sample **one initial point** on initial conditioning frames from the + # ground-truth mask; we may sample more correction points on the fly) + points, labels = get_next_point( + gt_masks=gt_masks_per_frame[t], + pred_masks=None, + method=( + "uniform" if self.training else self.pt_sampling_for_eval + ), + ) + + point_inputs = {"point_coords": points, "point_labels": labels} + backbone_out["point_inputs_per_frame"][t] = point_inputs + + return backbone_out + + def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): + """ + Prepare input mask, point or box prompts. Optionally, we allow tracking from + a custom `start_frame_idx` to the end of the video (for evaluation purposes). + """ + backbone_out = self._prepare_prompt_inputs_meta( + backbone_out, input, start_frame_idx + ) + backbone_out = self._prepare_conditional_frames(backbone_out) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features (same as in MDETR_API model).""" + + backbone_features = {} + + for neck_k in neck_outs: + if neck_k not in backbone_out: + continue + neck_out = backbone_out[neck_k] + assert len(neck_out["backbone_fpn"]) == len(neck_out["vision_pos_enc"]) + assert len(neck_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = neck_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = neck_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.tensors.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [ + x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds + ] + vision_masks = [x.mask for x in feature_maps] + + for i, vision_mask in enumerate(vision_masks): + if vision_mask is not None: + vision_masks[i] = vision_mask.flatten(1) + + backbone_features[neck_k] = { + "vision_feats": vision_feats, + "vision_pos_embeds": vision_pos_embeds, + "vision_masks": vision_masks, + "feat_sizes": feat_sizes, + } + + return backbone_features + + def _prepare_backbone_features_per_frame( + self, + img_batch, + img_ids, + *, + need_interactive_out: bool = False, + need_propagation_out: bool = False, + ): + """Compute the image backbone features on the fly for the given img_ids.""" + # all image ids should be the same + assert img_ids.numel() == 1 + unique_img_ids = img_ids + + # Compute the image features on those unique image ids + image = img_batch.tensors[unique_img_ids] + image_mask = ( + img_batch.mask[unique_img_ids] if img_batch.mask is not None else None + ) + + backbone_out = self.forward_image( + NestedTensor(tensors=image, mask=image_mask), + need_interactive_out=need_interactive_out, + need_propagation_out=need_propagation_out, + ) + + backbone_features = self._prepare_backbone_features(backbone_out) + return image, backbone_features + + def _prepare_memory_conditioned_features( + self, + *, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_masks, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + use_prev_mem_frame=True, # whether to condition on previous memory frames + multiplex_state: MultiplexState, + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = multiplex_state.num_buckets + # B = current_vision_feats[-1].size(1) # batch size on this frame + vision_feat = current_vision_feats[-1].expand(-1, B, -1) + vision_mask = ( + current_vision_masks[-1].expand(-1, B, -1) + if current_vision_masks[-1] is not None + else None + ) + vision_pos_embed = current_vision_pos_embeds[-1].expand(-1, B, -1) + + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = vision_feat.permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame and use_prev_mem_frame: + # Retrieve the memories encoded with the maskmem backbone + # to_cat_prompt, to_cat_prompt_mask, to_cat_prompt_pos_embed = [], [], [] + to_cat_prompt, to_cat_prompt_pos_embed = [], [] + if self.save_image_features: + to_cat_image_feat, to_cat_image_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, + cond_outputs, + self.max_cond_frames_in_attn, + keep_first_cond_frame=self.keep_first_cond_frame, + ) + + t_pos_and_prevs = [ + ((frame_idx - t) * tpos_sign_mul, out, True) + for t, out in selected_cond_outputs.items() + ] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = 1 if self.training else self.memory_temporal_stride_for_eval + + if self.use_memory_selection: + valid_indices = self.frame_filter( + output_dict, track_in_reverse, frame_idx, num_frames, r + ) + + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if self.use_memory_selection: + if t_rel > len(valid_indices): + continue + prev_frame_idx = valid_indices[-t_rel] + else: + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out, False)) + + for t_pos, prev, is_selected_cond_frame in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + + feats = prev.get("maskmem_features") + if feats is None: + continue + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = feats.cuda(non_blocking=True) + if feats.dim() == 5: + feats = multiplex_state.demux(feats).contiguous() + prev["maskmem_features"] = ( + feats.cpu() if not feats.is_cuda else feats + ) + + if feats.shape[0] == 0: + continue + + to_cat_prompt.append(feats.flatten(2).permute(2, 0, 1)) + # to_cat_prompt_mask.append(None) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_pos_list = prev.get("maskmem_pos_enc") + if not maskmem_pos_list: + continue + maskmem_enc = maskmem_pos_list[-1] + if maskmem_enc is None: + continue + maskmem_enc = maskmem_enc.cuda(non_blocking=True) + if maskmem_enc.dim() == 5: + maskmem_enc = multiplex_state.demux(maskmem_enc).contiguous() + prev["maskmem_pos_enc"][-1] = ( + maskmem_enc.cpu() if not maskmem_enc.is_cuda else maskmem_enc + ) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + + if self.use_maskmem_tpos_v2: + # the last of maskmem_tpos_enc is an "out-of-range" embedding + if t_pos <= 0 or t_pos >= self.num_maskmem: + tpos_enc = self.maskmem_tpos_enc[self.num_maskmem - 1] + else: + tpos_enc = self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + else: + # cond_frame NOT temporally encoded in this setting + # and last of the maskmem_tpos_enc is actually an + # indicator for being a cond_frame + t = t_pos if not is_selected_cond_frame else 0 + tpos_enc = self.maskmem_tpos_enc[self.num_maskmem - t - 1] + + maskmem_enc = maskmem_enc + tpos_enc + + if self.save_image_features: + # image features are in (HW)BC + image_feat = prev["image_features"].cuda() + image_pos_embed = prev["image_pos_enc"].cuda() + tpos_enc + to_cat_image_feat.append(image_feat) + to_cat_image_pos_embed.append(image_pos_embed) + + to_cat_prompt_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_outs_for_ptr = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out, + True, # is_selected_cond_frame + ) + for t, out in ptr_cond_outputs.items() + ] + + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + if not self.use_memory_selection: + t = ( + frame_idx + t_diff + if track_in_reverse + else frame_idx - t_diff + ) + if t < 0 or (num_frames is not None and t >= num_frames): + break + else: + if -t_diff <= -len(valid_indices): + break + t = valid_indices[-t_diff] + + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_outs_for_ptr.append((t_diff, out, False)) + + # If we have at least one object pointer, add them to the across attention + if len(pos_and_outs_for_ptr) > 0: + pos_list, out_list, is_selected_cond_frame_list = zip( + *pos_and_outs_for_ptr + ) + # Filter out outputs that don't have obj_ptr (e.g., when object has empty mask) + filtered_data = [ + (pos, out, is_cond) + for pos, out, is_cond in zip( + pos_list, out_list, is_selected_cond_frame_list + ) + if "obj_ptr" in out + ] + + # Only proceed if we have at least one valid obj_ptr + if len(filtered_data) > 0: + pos_list, out_list, is_selected_cond_frame_list = zip( + *filtered_data + ) + # each out["obj_ptr"] is a tensor of shape (num_buckets, seq_len, C) + # cat object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.cat( + [out["obj_ptr"] for out in out_list], dim=1 + ).transpose(0, 1) + + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + obj_pos = self._get_tpos_enc( + pos_list, + max_abs_pos=max_obj_ptrs_in_encoder, + device=device, + ) + else: + obj_pos = self._get_tpos_enc( + pos_list, device=device, dummy=True + ) + # expand to batch size + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, -1) + + assert ( + self.mem_dim == C + ), f"obj_ptrs.shape = {obj_ptrs.shape}, C = {C}" + + # each frame has [bucket_size] pointers, except the first frame + obj_pos = obj_pos.repeat_interleave( + multiplex_state.multiplex_count, dim=0 + ) + + to_cat_prompt.append(obj_ptrs) + to_cat_prompt_pos_embed.append(obj_pos) + # number of object pointer tokens for the encoder + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + # All outputs were filtered out (empty masks), no obj_ptrs available + num_obj_ptr_tokens = 0 + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + raise NotImplementedError( + "Any init cond frame should have gone to _use_mask_as_output instead" + ) + + # Step 2: Concatenate the memories and forward through the transformer encoder + if len(to_cat_prompt) == 0: + # No available memory features (e.g. mask was cleared). Skip fusion and + # fall back to the current frame features so the object can continue to + # propagate as empty without raising errors. + pix_feat = vision_feat.permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + prompt = torch.cat(to_cat_prompt, dim=0) + prompt_mask = None # For now, we always masks are zeros anyways + prompt_pos_embed = torch.cat(to_cat_prompt_pos_embed, dim=0) + + if self.save_image_features: + assert prompt_mask is None + assert vision_mask is None + if len(to_cat_image_feat) == 0 or len(to_cat_image_pos_embed) == 0: + # Memory image features were cleared; fall back to current-frame features. + pix_feat = vision_feat.permute(1, 2, 0).view(B, C, H, W) + return pix_feat + image_feat = torch.cat(to_cat_image_feat, dim=0) + image_pos_embed = torch.cat(to_cat_image_pos_embed, dim=0) + + encoder_out = self.transformer.encoder( + image=current_vision_feats[-1], + src=vision_feat, + memory_image=image_feat, + memory=prompt, + image_pos=current_vision_pos_embeds[-1], + src_pos=vision_pos_embed, + memory_image_pos=image_pos_embed, + memory_pos=prompt_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + else: + encoder_out = self.transformer.encoder( + src=vision_feat, + src_key_padding_mask=vision_mask, + src_pos=vision_pos_embed, + prompt=prompt, + prompt_pos=prompt_pos_embed, + prompt_key_padding_mask=prompt_mask, + feat_sizes=feat_sizes, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = encoder_out["memory"].permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + image, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + *, + conditioning_objects: Optional[Iterable[int]] = None, + multiplex_state: MultiplexState, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + if self.apply_sigmoid_to_mask_logits_for_mem_enc: + # scale the raw mask logits with a temperature before applying sigmoid + assert ( + not self.binarize_mask_from_pts_for_mem_enc + ), "haven't been trained this way; beware of hardcoded config override" + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + else: + mask_for_mem = pred_masks_high_res + + if self.add_object_conditional_embeddings or self.condition_as_mask_input: + # figure out the set of objects that are "conditional" on this frame + if conditioning_objects is None: + conditioning_objects = [] + unconditioning_objects = sorted( + list(multiplex_state.get_all_valid_object_idx()) + ) + else: + conditioning_objects = sorted(list(conditioning_objects)) + all_objects_idx = multiplex_state.get_all_valid_object_idx() + unconditioning_objects = sorted( + [i for i in all_objects_idx if i not in conditioning_objects] + ) + + mux_mask_for_mem = multiplex_state.mux(mask_for_mem).squeeze(2) + + if self.condition_as_mask_input: + # create num_objects channels spatial features that encode the + # list of objects that are conditional with fg and bg values + num_objects = mask_for_mem.shape[0] + # Create a 1D conditioning mask on GPU and broadcast it + cond_values = torch.full( + (num_objects,), + self.condition_as_mask_input_bg, + device=mask_for_mem.device, + dtype=mask_for_mem.dtype, + ) + if len(conditioning_objects) > 0: + cond_values[conditioning_objects] = self.condition_as_mask_input_fg + # Broadcast to full spatial dimensions: [N] -> [N, 1, H, W] + embedded_conditions = cond_values.view(-1, 1, 1, 1).expand_as(mask_for_mem) + embedded_conditions = multiplex_state.mux(embedded_conditions).squeeze(2) + + mux_mask_for_mem = torch.cat([mux_mask_for_mem, embedded_conditions], dim=1) + + if isinstance(self.maskmem_backbone, SimpleMaskEncoder): + maskmem_out = self.maskmem_backbone( + pix_feat, + mux_mask_for_mem, + skip_mask_sigmoid=True, + ) + else: + maskmem_out = self.maskmem_backbone(image, pix_feat, mux_mask_for_mem) + # Clone the feats and pos_enc to enable compilation + maskmem_features = self._maybe_clone(maskmem_out["vision_features"]) + maskmem_pos_enc = [self._maybe_clone(m) for m in maskmem_out["vision_pos_enc"]] + + if self.no_obj_embed_spatial is not None: + # since maskmem_features are deeply detangled between objects + # we simply add a projected embedding for each empty object + # num_buckets * multiplex_count * C + no_obj_embed_spatial = self.no_obj_embed_spatial.unsqueeze(0).repeat( + multiplex_state.num_buckets, 1, 1 + ) + # Align object_score_logits length to multiplex expectations before mux + if object_score_logits is not None: + obj_expected = multiplex_state.total_valid_entries + obj_current = object_score_logits.shape[0] + if obj_current != obj_expected: + if obj_current < obj_expected: + pad_shape = (obj_expected - obj_current,) + tuple( + object_score_logits.shape[1:] + ) + obj_pad = object_score_logits.new_zeros(pad_shape) + object_score_logits = torch.cat( + [object_score_logits, obj_pad], dim=0 + ) + else: + object_score_logits = object_score_logits[:obj_expected] + object_score_logits = multiplex_state.mux(object_score_logits) + is_obj_appearing = ( + object_score_logits > self.object_score_logit_threshold + ).float() + + no_obj_embed = ((1 - is_obj_appearing) * no_obj_embed_spatial).sum(dim=1) + maskmem_features += no_obj_embed[..., None, None].expand_as( + maskmem_features + ) + + if self.add_object_conditional_embeddings: + # add object conditional embeddings to the maskmem_features + # num_buckets * multiplex_count * C + obj_cond_embed = self.obj_cond_embed.unsqueeze(0).repeat( + multiplex_state.num_buckets, 1, 1 + ) + obj_cond_embed = multiplex_state.demux(obj_cond_embed) + obj_merged_embed = obj_cond_embed + + if self.add_object_unconditional_embeddings: + obj_non_cond_embed = self.obj_non_cond_embed.unsqueeze(0).repeat( + multiplex_state.num_buckets, 1, 1 + ) + obj_non_cond_embed = multiplex_state.demux(obj_non_cond_embed) + if self.training: + obj_merged_embed = obj_merged_embed.clone() + obj_merged_embed[unconditioning_objects] = obj_non_cond_embed[ + unconditioning_objects + ] + + obj_merged_embed = multiplex_state.mux(obj_merged_embed).sum(dim=1) + maskmem_features = maskmem_features + obj_merged_embed[ + ..., None, None + ].expand_as(maskmem_features) + + if maskmem_features.dim() == 5: + maskmem_features = multiplex_state.demux(maskmem_features).contiguous() + + demuxed_pos_enc = [] + for pos_enc in maskmem_pos_enc: + pos_enc_clone = pos_enc + if pos_enc_clone is not None and pos_enc_clone.dim() == 5: + pos_enc_clone = multiplex_state.demux(pos_enc_clone).contiguous() + demuxed_pos_enc.append(pos_enc_clone) + maskmem_pos_enc = demuxed_pos_enc + + return maskmem_features, maskmem_pos_enc + + def forward_tracking( + self, + backbone_out, + input, + return_dict=False, + objects_to_interact: Optional[list[int]] = None, + ): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = ( + "interactive" in backbone_out or "sam2_backbone_out" in backbone_out + ) + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + # - vision_masks are in B(HW) format, dtype=bool (False is valid, True is padding) + backbone_features = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # first process all the initial conditioning frames to encode them as memory, + # and then conditioning on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + + cond_frame_outputs: dict[int, StageOutput] = {} + non_cond_frame_outputs: dict[int, StageOutput] = {} + output_dict = { + "cond_frame_outputs": cond_frame_outputs, + "non_cond_frame_outputs": non_cond_frame_outputs, + } + + multiplex_state = self.multiplex_controller.get_state( + backbone_out["gt_masks_per_frame"][0].shape[0], + device=backbone_out["gt_masks_per_frame"][0].device, + dtype=torch.float, + random=self.training, + ) + + for stage_id in processing_order: + # Get the image features for the current frames + img_ids = input.find_inputs[stage_id].img_ids + # the image ids are for the entire batch + assert all( + [img_id == img_ids[0] for img_id in img_ids] + ) # should be all the same + # force this to have a batch size of 1 + img_ids = torch.tensor( + [img_ids[0]], device=img_ids.device, dtype=img_ids.dtype + ) + + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_image = input.img_batch.tensors[img_ids] + current_backbone_features = {} + for neck_k, neck_out in backbone_features.items(): + current_backbone_features[neck_k] = { + "vision_feats": [ + x[:, img_ids] for x in neck_out["vision_feats"] + ], + "vision_masks": [ + x[img_ids] if x is not None else None + for x in neck_out["vision_masks"] + ], + "vision_pos_embeds": [ + x[:, img_ids] for x in neck_out["vision_pos_embeds"] + ], + "feat_sizes": neck_out["feat_sizes"], + } + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + need_interactive_out = (stage_id in frames_to_add_correction_pt) or ( + stage_id in init_cond_frames + ) + (current_image, current_backbone_features) = ( + self._prepare_backbone_features_per_frame( + input.img_batch, + img_ids, + need_interactive_out=need_interactive_out, + need_propagation_out=True, + ) + ) + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + backbone_features_interactive=current_backbone_features.get( + "interactive" + ), + backbone_features_propagation=current_backbone_features.get( + "sam2_backbone_out" + ), + image=current_image, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None), + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + multiplex_state=multiplex_state, + objects_to_interact=objects_to_interact, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = stage_id in init_cond_frames or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + output_dict["multiplex_state"] = multiplex_state + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs + ] + + return all_frame_outputs + + def _track_step_aux( + self, + *, + frame_idx, + is_init_cond_frame, + backbone_features_interactive, + backbone_features_propagation, + image, + point_inputs, + mask_inputs, + gt_masks, + frames_to_add_correction_pt, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + run_mem_encoder=True, + prev_sam_mask_logits=None, + multiplex_state: MultiplexState, + objects_to_interact: Optional[list[int]] = None, + need_aux_output: bool = False, + ) -> tuple[StageOutput, dict]: + """ + There are four different modes that track_step might enter, based on the inputs + 1. Mask-as-output. This is when mask_inputs is not None. + The input mask is returned directly. This case is for FA/VOS initialization. + 2. Propagation-only. This is when mask_inputs and point_inputs are empty. + We propagate masks using the memory only. This case is for VOS propagation. + 3. Interaction-only. This is when mask_inputs is None, point_inputs is not None, + and one of the followings is satisified: + a) prev_sam_mask_logits is not None. In this case, we refine prev_sam_mask_logits + with additional interactions, updating only the objects specified in objects_to_interact. + objects_to_interact must not be None. + This occurs when we refine the same frame with multiple point inputs iteratively. + b) prev_sam_mask_logits is None, and is_init_cond_frame is True. + This case is for initializing the first frame. All objects will have point inputs. + This mostly happens during training/interactive eval. + 4. Propagation-and-interaction. This is when mask_inputs is None, point_inputs is not None, + prev_sam_mask_logits is None, and objects_to_interact is not None. + This is when we are propagating to a new frame that has point inputs (from previous interactions). + This is more of an edge case that could happen in offline interactive eval. + We first propagate the mask to the current frame, and then perform interaction on the selected + objects. Finally, we replace the masks of the interacted objects in the propagated output + with the masks from the interaction output. + """ + current_out: StageOutput = { + "conditioning_objects": set(), + "point_inputs": point_inputs, + "mask_inputs": mask_inputs, + } + + mode = None + if mask_inputs is not None: + mode = "mask_as_output" + elif point_inputs is None: + mode = "propagation_only" + elif point_inputs is not None: + # Case 3a: Refining existing predictions + if prev_sam_mask_logits is not None: + assert ( + objects_to_interact is not None + ), "objects_to_interact must be specified when refining with prev_sam_mask_logits" + mode = "interaction_only" + # Case 3b: Initial conditioning frame + elif is_init_cond_frame: + mode = "interaction_only" + # Case 4: Propagation then interaction + elif objects_to_interact is not None and prev_sam_mask_logits is None: + assert not self.training + mode = "propagation_and_interaction" + + if mode is None: + raise ValueError( + f"Unable to determine tracking case. " + f"mask_inputs={mask_inputs is not None}, " + f"point_inputs={point_inputs is not None}, " + f"prev_sam_mask_logits={prev_sam_mask_logits is not None}, " + f"objects_to_interact={objects_to_interact}, " + f"is_init_cond_frame={is_init_cond_frame}" + ) + # partition the backbone features + interactive_high_res_features = interactive_vision_feats = None + interactive_feat_sizes = None + if backbone_features_interactive is not None: + interactive_vision_feats = backbone_features_interactive["vision_feats"] + interactive_feat_sizes = backbone_features_interactive["feat_sizes"] + + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(interactive_vision_feats) > 1: + interactive_high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip( + interactive_vision_feats[:-1], interactive_feat_sizes[:-1] + ) + ] + else: + # cannot do point interaction without interactive features + assert mode not in ["interaction_only", "propagation_and_interaction"] + + propagation_high_res_features = propagation_vision_feats = None + propagation_vision_masks = None + propagation_vision_pos_embeds = propagation_feat_sizes = None + if backbone_features_propagation is not None: + propagation_vision_feats = backbone_features_propagation["vision_feats"] + propagation_vision_masks = backbone_features_propagation["vision_masks"] + propagation_vision_pos_embeds = backbone_features_propagation[ + "vision_pos_embeds" + ] + propagation_feat_sizes = backbone_features_propagation["feat_sizes"] + + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(propagation_vision_feats) > 1: + propagation_high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip( + propagation_vision_feats[:-1], propagation_feat_sizes[:-1] + ) + ] + else: + # we can get away without propagation features if we are interacting and not encoding new memory + assert mode not in ["propagation_only", "propagation_and_interaction"] + assert not run_mem_encoder + + interactive_pix_feat = None + if mode == "mask_as_output": + # simple encoding + assert self.use_mask_input_as_output_without_sam + # pix_feat = interactive_vision_feats[-1].permute(1, 2, 0) + # pix_feat = pix_feat.view(-1, self.hidden_dim, *interactive_feat_sizes[-1]) + # use no_mem_embed here as well to better align first-frame mask input vs point input + interactive_pix_feat = self._get_interactive_pix_mem( + interactive_vision_feats, interactive_feat_sizes + ) + sam_outputs = self._use_mask_as_output( + backbone_features=interactive_pix_feat, + high_res_features=interactive_high_res_features, + mask_inputs=mask_inputs, + multiplex_state=multiplex_state, + ) + # all the objects are conditional here + current_out["conditioning_objects"].update(range(mask_inputs.shape[0])) + else: + # propagation, interaction, or both + propagation_out = None + if mode in ["propagation_only", "propagation_and_interaction"]: + # gather the memory + assert backbone_features_propagation is not None + assert propagation_vision_feats is not None + assert propagation_vision_masks is not None + assert propagation_vision_pos_embeds is not None + assert propagation_feat_sizes is not None + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=propagation_vision_feats[-1:], + current_vision_masks=propagation_vision_masks[-1:], + current_vision_pos_embeds=propagation_vision_pos_embeds[-1:], + feat_sizes=propagation_feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + multiplex_state=multiplex_state, + ) + + # propagate the mask + # this is the propagation step; do not consider point_inputs here + multimask_output = self._use_multimask( + is_init_cond_frame, point_inputs=None + ) + propagation_out = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + propagation_high_res_features=propagation_high_res_features, + multimask_output=multimask_output, + objects_to_interact=list( + range(multiplex_state.total_valid_entries) + ), + multiplex_state=multiplex_state, + ) + + interaction_out = None + if mode in ["interaction_only", "propagation_and_interaction"]: + assert backbone_features_interactive is not None + assert interactive_vision_feats is not None + assert interactive_feat_sizes is not None + interactive_pix_feat = self._get_interactive_pix_mem( + interactive_vision_feats, interactive_feat_sizes + ) + + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, the SAM mask decoder should have `self.iter_use_prev_mask_pred=True`, and + # any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + assert mask_inputs is None and point_inputs is not None + if prev_sam_mask_logits is not None: + assert objects_to_interact is not None + assert self.iter_use_prev_mask_pred + assert mode != "propagation_and_interaction" + mask_inputs = prev_sam_mask_logits[objects_to_interact] + elif mode == "propagation_and_interaction": + # use propagated masks as mask input + assert objects_to_interact is not None + assert propagation_out is not None + mask_inputs = propagation_out["low_res_masks"][objects_to_interact] + + if objects_to_interact is not None: + assert point_inputs["point_coords"].shape[0] == len( + objects_to_interact + ) + assert point_inputs["point_labels"].shape[0] == len( + objects_to_interact + ) + + multimask_output = self._use_multimask( + is_init_cond_frame, point_inputs=point_inputs + ) + interaction_out = self._forward_sam_heads( + backbone_features=interactive_pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + interactive_high_res_features=interactive_high_res_features, + multimask_output=multimask_output, + objects_to_interact=( + objects_to_interact + if objects_to_interact is not None + else list(range(multiplex_state.total_valid_entries)) + ), + multiplex_state=multiplex_state, + ) + if objects_to_interact is None: + current_out["conditioning_objects"].update( + multiplex_state.get_all_valid_object_idx() + ) + else: + current_out["conditioning_objects"].update(objects_to_interact) + + if propagation_out is None and interaction_out is not None: + sam_outputs = interaction_out + elif interaction_out is None and propagation_out is not None: + sam_outputs = propagation_out + else: + # merge the output + assert propagation_out is not None and interaction_out is not None + keys_to_merge = [ + "low_res_multimasks", + "high_res_multimasks", + "low_res_masks", + "high_res_masks", + "ious", + "object_score_logits", + "obj_ptr", + ] + for k in keys_to_merge: + src = interaction_out[k] + dst = propagation_out[k] + # Align dtype for floating tensors before indexed assignment + if torch.is_tensor(src) and torch.is_tensor(dst): + if torch.is_floating_point(src) and src.dtype != dst.dtype: + src = src.to(dtype=dst.dtype) + propagation_out[k][objects_to_interact] = src + sam_outputs = propagation_out + + low_res_multimasks = sam_outputs["low_res_multimasks"] + high_res_multimasks = sam_outputs["high_res_multimasks"] + ious = sam_outputs["ious"] + low_res_masks = sam_outputs["low_res_masks"] + high_res_masks = sam_outputs["high_res_masks"] + object_score_logits = sam_outputs["object_score_logits"] + + current_out["multistep_pred_masks"] = low_res_masks + current_out["multistep_pred_masks_high_res"] = high_res_masks + current_out["multistep_pred_multimasks"] = [low_res_multimasks] + current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks] + current_out["multistep_pred_ious"] = [ious] + current_out["multistep_point_inputs"] = [point_inputs] + current_out["multistep_object_score_logits"] = [object_score_logits] + + if self.use_obj_ptrs_in_encoder: + obj_ptr = sam_outputs["obj_ptr"] + + # Optionally, sample correction points iteratively to correct the mask + if frame_idx in frames_to_add_correction_pt: + assert gt_masks is not None + assert interactive_vision_feats is not None + assert interactive_feat_sizes is not None + all_pred_masks = [low_res_masks] + all_pred_high_res_masks = [high_res_masks] + all_pred_multimasks = [low_res_multimasks] + all_pred_high_res_multimasks = [high_res_multimasks] + all_pred_ious = [ious] + all_point_inputs = [point_inputs] + all_object_score_logits = [object_score_logits] + + # select a subset of objects to interact with + if self.training: + assert objects_to_interact is None + + interact_with_all_objects = ( + self.rng.random() < self.prob_correct_all_objects_for_train + ) or ( + self.force_correct_all_for_conditional_inputs and is_init_cond_frame + ) + + if interact_with_all_objects: + num_objects_to_correct = gt_masks.shape[0] + elif self.rand_objects_to_correct_for_train: + num_objects_to_correct = self.rng2.integers( + 1, + int( + gt_masks.shape[0] + * self.ratio_of_objects_to_correct_for_train + ) + + 1, + ) + else: + num_objects_to_correct = max( + 1, + int( + gt_masks.shape[0] + * self.ratio_of_objects_to_correct_for_train + ), + ) + + objects_to_interact = self.rng2.choice( + range(gt_masks.shape[0]), + size=num_objects_to_correct, + replace=False, + ).tolist() + + if point_inputs is not None: + # don't modify the point inputs in-place + point_inputs = { + "point_coords": point_inputs["point_coords"][ + objects_to_interact + ], + "point_labels": point_inputs["point_labels"][ + objects_to_interact + ], + } + else: + assert objects_to_interact is not None + # the point inputs should have been preselected, i.e., the following assertion should hold + + if point_inputs is not None: + assert point_inputs["point_coords"].shape[0] == len(objects_to_interact) + assert point_inputs["point_labels"].shape[0] == len(objects_to_interact) + + for _ in range(self.num_correction_pt_per_frame): + # sample a new point from the error between prediction and ground-truth + # (with a small probability, directly sample from GT masks instead of errors) + if self.training and self.prob_to_sample_from_gt_for_train > 0: + sample_from_gt = ( + self.rng.random() < self.prob_to_sample_from_gt_for_train + ) + else: + sample_from_gt = False + # if `pred_for_new_pt` is None, only GT masks will be used for point sampling + pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0) + new_points, new_labels = get_next_point( + gt_masks=gt_masks[objects_to_interact], + pred_masks=( + pred_for_new_pt[objects_to_interact] + if pred_for_new_pt is not None + else None + ), + method="uniform" if self.training else self.pt_sampling_for_eval, + ) + point_inputs = concat_points(point_inputs, new_points, new_labels) + assert low_res_masks.shape[0] > max( + objects_to_interact + ), f"interacting {objects_to_interact} in {low_res_masks.shape}?" + if self.iter_use_prev_mask_pred: + # Feed the mask logits of the previous SAM outputs in the next SAM decoder step. + # For tracking, this means that when the user adds a correction click, we also feed + # the tracking output mask logits along with the click as input to the SAM decoder. + mask_inputs = low_res_masks[objects_to_interact] + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + pix_feat_with_mem = self._get_interactive_pix_mem( + interactive_vision_feats, interactive_feat_sizes + ) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + interactive_high_res_features=interactive_high_res_features, + propagation_high_res_features=propagation_high_res_features, + multimask_output=multimask_output, + gt_masks=gt_masks, + objects_to_interact=objects_to_interact, + multiplex_state=multiplex_state, + ) + interact_low_res_multimasks = sam_outputs["low_res_multimasks"] + interact_high_res_multimasks = sam_outputs["high_res_multimasks"] + interact_ious = sam_outputs["ious"] + interact_low_res_masks = sam_outputs["low_res_masks"] + interact_high_res_masks = sam_outputs["high_res_masks"] + interact_object_score_logits = sam_outputs["object_score_logits"] + if self.use_obj_ptrs_in_encoder: + interact_obj_ptr = sam_outputs["obj_ptr"] + + if self.training: + # combine the masks from the interacted and non-interacted objects + low_res_masks = low_res_masks.clone() + high_res_masks = high_res_masks.clone() + low_res_multimasks = low_res_multimasks.clone() + high_res_multimasks = high_res_multimasks.clone() + ious = ious.clone() + object_score_logits = object_score_logits.clone() + obj_ptr = obj_ptr.clone() if self.use_obj_ptrs_in_encoder else None + + # Update masks for the interacted objects + if ( + torch.is_floating_point(interact_low_res_masks) + and interact_low_res_masks.dtype != low_res_masks.dtype + ): + interact_low_res_masks = interact_low_res_masks.to( + dtype=low_res_masks.dtype + ) + low_res_masks[objects_to_interact] = interact_low_res_masks + if ( + torch.is_floating_point(interact_high_res_masks) + and interact_high_res_masks.dtype != high_res_masks.dtype + ): + interact_high_res_masks = interact_high_res_masks.to( + dtype=high_res_masks.dtype + ) + high_res_masks[objects_to_interact] = interact_high_res_masks + if ( + torch.is_floating_point(interact_low_res_multimasks) + and interact_low_res_multimasks.dtype != low_res_multimasks.dtype + ): + interact_low_res_multimasks = interact_low_res_multimasks.to( + dtype=low_res_multimasks.dtype + ) + low_res_multimasks[objects_to_interact] = interact_low_res_multimasks + if ( + torch.is_floating_point(interact_high_res_multimasks) + and interact_high_res_multimasks.dtype != high_res_multimasks.dtype + ): + interact_high_res_multimasks = interact_high_res_multimasks.to( + dtype=high_res_multimasks.dtype + ) + high_res_multimasks[objects_to_interact] = interact_high_res_multimasks + if ( + torch.is_floating_point(interact_ious) + and interact_ious.dtype != ious.dtype + ): + interact_ious = interact_ious.to(dtype=ious.dtype) + ious[objects_to_interact] = interact_ious + if ( + torch.is_floating_point(interact_object_score_logits) + and interact_object_score_logits.dtype != object_score_logits.dtype + ): + interact_object_score_logits = interact_object_score_logits.to( + dtype=object_score_logits.dtype + ) + object_score_logits[objects_to_interact] = interact_object_score_logits + if self.use_obj_ptrs_in_encoder: + obj_ptr[objects_to_interact] = interact_obj_ptr + + all_pred_masks.append(low_res_masks) + all_pred_high_res_masks.append(high_res_masks) + all_pred_multimasks.append(low_res_multimasks) + all_pred_high_res_multimasks.append(high_res_multimasks) + all_pred_ious.append(ious) + all_point_inputs.append(point_inputs) + all_object_score_logits.append(object_score_logits) + + # Concatenate the masks along channel (to compute losses on all of them, + # using `onevision.losses.loss_fns.MultiStepIteractiveMasks`) + current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1) + current_out["multistep_pred_masks_high_res"] = torch.cat( + all_pred_high_res_masks, dim=1 + ) + current_out["multistep_pred_multimasks"] = all_pred_multimasks + current_out["multistep_pred_multimasks_high_res"] = ( + all_pred_high_res_multimasks + ) + current_out["multistep_pred_ious"] = all_pred_ious + current_out["multistep_point_inputs"] = all_point_inputs + current_out["multistep_object_score_logits"] = all_object_score_logits + + if self.add_all_frames_to_correct_as_cond: + if objects_to_interact is None: + current_out["conditioning_objects"].update( + multiplex_state.get_all_valid_object_idx() + ) + else: + current_out["conditioning_objects"].update(set(objects_to_interact)) + + # Use the final prediction (after all correction steps for output and eval) + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + if self.use_obj_ptrs_in_encoder: + # similar to spatial memory, the object pointers are stored with multiplex + current_out["obj_ptr"] = multiplex_state.mux(obj_ptr) + if self.use_memory_selection: + current_out["object_score_logits"] = object_score_logits + iou_score = current_out["multistep_pred_ious"][-1].max(-1)[0] + current_out["iou_score"] = iou_score + current_out["eff_iou_score"] = self.cal_mem_score( + object_score_logits, iou_score + ) + # we need to return this for encoding new masks in the dynamic mode + current_out["object_score_logits"] = object_score_logits + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + # (note that `self.num_maskmem == 0` is primarily used for reproducing SAM on + # images, in which case we'll just skip memory encoder to save compute). + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + image=image, + current_vision_feats=propagation_vision_feats, + feat_sizes=propagation_feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + conditioning_objects=current_out["conditioning_objects"], + multiplex_state=multiplex_state, + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + + if self.save_image_features: + current_out["image_features"] = propagation_vision_feats[-1] + current_out["image_pos_enc"] = propagation_vision_pos_embeds[-1] + + # this is to avoid recomputing some of these features for add_new_masks_to_existing_state + aux_output = {} + if need_aux_output: + if interactive_pix_feat is None: + interactive_pix_feat = self._get_interactive_pix_mem( + interactive_vision_feats, interactive_feat_sizes + ) + aux_output["interactive_pix_feat"] = interactive_pix_feat + aux_output["interactive_high_res_features"] = interactive_high_res_features + aux_output["propagation_vision_feats"] = propagation_vision_feats + aux_output["propagation_feat_sizes"] = propagation_feat_sizes + + return current_out, aux_output + + def _trim_output_and_memory( + self, + frame_idx: int, + output_dict: dict[str, dict[int, StageOutput]], + current_out: StageOutput, + memory_encoder_was_used: bool, + ) -> StageOutput: + # Optionally, offload the outputs to CPU memory during evaluation to avoid + # GPU OOM on very long videos or very large resolution or too many objects + if self.offload_output_to_cpu_for_eval and not self.training: + # Here we only keep those keys needed for evaluation to get a compact output + trimmed_out: StageOutput = { + "conditioning_objects": current_out["conditioning_objects"], + "pred_masks": current_out["pred_masks"].cpu(), + "pred_masks_high_res": current_out["pred_masks_high_res"].cpu(), + # other items for evaluation (these are small tensors so we keep them on GPU) + "object_score_logits": current_out["object_score_logits"], + "multistep_point_inputs": current_out["multistep_point_inputs"], + } + if self.use_obj_ptrs_in_encoder: + trimmed_out["obj_ptr"] = current_out["obj_ptr"] + if memory_encoder_was_used and self.num_maskmem > 0: + trimmed_out["maskmem_features"] = current_out["maskmem_features"].cpu() + trimmed_out["maskmem_pos_enc"] = [ + x.cpu() for x in current_out["maskmem_pos_enc"] + ] + if self.save_image_features: + trimmed_out["image_features"] = current_out["image_features"].cpu() + trimmed_out["image_pos_enc"] = current_out["image_pos_enc"].cpu() + current_out = trimmed_out + + # Optionally, trim the output of past non-conditioning frame (r * num_maskmem frames + # before the current frame) during evaluation. This is intended to save GPU or CPU + # memory for semi-supervised VOS eval, where only the first frame receives prompts. + def _trim_past_out( + past_out: StageOutput, current_out: StageOutput + ) -> Optional[StageOutput]: + if past_out is None: + return None + trimmed_past_out: StageOutput = { + "conditioning_objects": past_out["conditioning_objects"], + "pred_masks": past_out["pred_masks"], + "object_score_logits": past_out["object_score_logits"], + # Why would this be current_out? + # "multistep_point_inputs": current_out["multistep_point_inputs"], + "multistep_point_inputs": past_out["multistep_point_inputs"], + } + if self.use_obj_ptrs_in_encoder: + trimmed_past_out["obj_ptr"] = past_out["obj_ptr"] + return trimmed_past_out + + if self.trim_past_non_cond_mem_for_eval and not self.training: + r = self.memory_temporal_stride_for_eval + past_frame_idx = frame_idx - r * self.num_maskmem + past_out = output_dict["non_cond_frame_outputs"].get(past_frame_idx, None) + + if past_out is not None: + if ( + self.use_memory_selection + and past_out.get("eff_iou_score", 0) < self.mf_threshold + ) or not self.use_memory_selection: + output_dict["non_cond_frame_outputs"][past_frame_idx] = ( + _trim_past_out(past_out, current_out) + ) + + if ( + self.use_memory_selection and not self.offload_output_to_cpu_for_eval + ): # design for memory selection, trim too old frames to save memory + far_old_frame_idx = frame_idx - 20 * self.max_obj_ptrs_in_encoder + past_out = output_dict["non_cond_frame_outputs"].get( + far_old_frame_idx, None + ) + if past_out is not None: + output_dict["non_cond_frame_outputs"][far_old_frame_idx] = ( + _trim_past_out(past_out, current_out) + ) + + return current_out + + def track_step( + self, + *, + frame_idx, + is_init_cond_frame, + backbone_features_interactive, + backbone_features_propagation, + image, + point_inputs, + mask_inputs, + gt_masks, + frames_to_add_correction_pt, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + multiplex_state: MultiplexState, + # The list of object idx that point_inputs correspond to; only this set of objects will + # be interacted with in the correction stage + objects_to_interact: Optional[list[int]] = None, + ) -> StageOutput: + current_out, _ = self._track_step_aux( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + backbone_features_interactive=backbone_features_interactive, + backbone_features_propagation=backbone_features_propagation, + image=image, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + gt_masks=gt_masks, + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + multiplex_state=multiplex_state, + objects_to_interact=objects_to_interact, + need_aux_output=False, + ) + current_out = self._trim_output_and_memory( + frame_idx, output_dict, current_out, memory_encoder_was_used=run_mem_encoder + ) + + return current_out + + def back_convert(self, targets): + """To be compatible with SetCriterionAPI losses (mask loss only).""" + batched_targets = {} + batched_targets["num_boxes"] = targets.num_boxes + batched_targets["masks"] = targets.segments + batched_targets["is_valid_mask"] = targets.is_valid_segment + return batched_targets + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + and self.num_multimask_outputs > 0 + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks + + def _compile_all_components(self): + """Compile all model components for faster inference.""" + # a larger cache size to hold varying number of shapes for torch.compile + # see https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_dynamo/config.py#L42-L49 + torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.accumulated_cache_size_limit = 2048 + + logging.info("Compiling all components. First time may be very slow.") + + self.maskmem_backbone.forward = torch.compile( + self.maskmem_backbone.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + self.transformer.encoder.forward = torch.compile( + self.transformer.encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, # Num. of memories varies + ) + # We disable compilation of sam_prompt_encoder as it sometimes gives a large accuracy regression, + # especially when sam_mask_prompt (previous mask logits) is not None + # self.sam_prompt_encoder.forward = torch.compile( + # self.sam_prompt_encoder.forward, + # mode="max-autotune", + # fullgraph=True, + # dynamic=False, # Accuracy regression on True + # ) + self.sam_mask_decoder.forward = torch.compile( + self.sam_mask_decoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, # Accuracy regression on True + ) + + def _maybe_clone(self, x): + """Clone a tensor if and only if `self.compile_all_components` is True.""" + return x.clone() if self.compile_all_components else x + + def get_propagation_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.image_pe_layer( + (self.sam_image_embedding_size, self.sam_image_embedding_size) + ).unsqueeze(0) + + def cal_mem_score(self, object_score_logits, iou_score): + object_score_norm = torch.where( + object_score_logits > 0, + object_score_logits.sigmoid() * 2 - 1, # rescale to [0, 1] + torch.zeros_like(object_score_logits), + ) + score_per_frame = (object_score_norm * iou_score).mean() + return score_per_frame + + def frame_filter(self, output_dict, track_in_reverse, frame_idx, num_frames, r): + if (frame_idx == 0 and not track_in_reverse) or ( + frame_idx == num_frames - 1 and track_in_reverse + ): + return [] + + max_num = min( + num_frames, self.max_obj_ptrs_in_encoder + ) # maximum number of pointer memory frames to consider + + if not track_in_reverse: + start = frame_idx - 1 + end = 0 + step = -r + must_include = frame_idx - 1 + else: + start = frame_idx + 1 + end = num_frames + step = r + must_include = frame_idx + 1 + + valid_indices = [] + for i in range(start, end, step): + if ( + i not in output_dict["non_cond_frame_outputs"] + or "eff_iou_score" not in output_dict["non_cond_frame_outputs"][i] + ): + continue + + score_per_frame = output_dict["non_cond_frame_outputs"][i]["eff_iou_score"] + + if score_per_frame > self.mf_threshold: # threshold + valid_indices.insert(0, i) + + if len(valid_indices) >= max_num - 1: + break + + if must_include not in valid_indices: + valid_indices.append(must_include) + + return valid_indices + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} + + +def _append( + d1: StageOutput, d2: SAMOutput, k1: str, k2: str, dim: int = 0, strict: bool = True +): + if strict: + assert k1 in d1, f"{k1} not found" + else: + if k1 not in d1: + return + + d1[k1] = torch.cat([d1[k1], d2[k2]], dim=dim) + + +def _merge( + d1: StageOutput, + d2: SAMOutput, + k1: str, + k2: str, + d2_idx: list[int], + strict: bool = True, +): + if strict: + assert k1 in d1, f"{k1} not found" + else: + if k1 not in d1: + return + d1[k1][d2_idx] = d2[k2].to(dtype=d1[k1].dtype) + + +class VideoTrackingDynamicMultiplex(VideoTrackingMultiplex): + def __init__( + self, + enable_dynamic_training: bool = True, # Allows the number of objects to increase across frames during training + rand_num_transition_points: bool = True, # Randomizes the number of transition points + max_num_transition_points: int = 3, # Maximum number of transition points + add_all_transition_frames_as_cond: bool = True, + max_trans_frames_in_attn: int = 4, + is_dynamic_model: bool = True, # Overrides the default + is_dynamic_vos_evaluation: bool = False, # For datasets like YouTubeVOS which have new objects + **kwargs, + ): + super().__init__(is_dynamic_model=is_dynamic_model, **kwargs) + + self.enable_dynamic_training = enable_dynamic_training + self.rand_num_transition_points = rand_num_transition_points + self.max_num_transition_points = max_num_transition_points + + self.add_all_transition_frames_as_cond = add_all_transition_frames_as_cond + self.max_trans_frames_in_attn = max_trans_frames_in_attn + self.is_dynamic_vos_evaluation = is_dynamic_vos_evaluation + + def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0): + """ + Prepare input mask, point or box prompts. Optionally, we allow tracking from + a custom `start_frame_idx` to the end of the video (for evaluation purposes). + """ + + """ + This function, in addition to the prompt preparation done in the parent class, preprocesses the + masks and pre-computes visibility/validity attributes necessary for training with dynamic bucketing. + + **Data** + We use a modified dataset class and a modified collate_fn such that: + 1. The mask for an object is loaded if it is visible (area>0) on any of the loaded frames + 2. A "visible_objects_per_frame" attribute is computed, which contains the set of objects with area>0 on each frame + + Here, we use [] to denote a set of objects; i.e., object A and B are represented as [A, B]. + Consider the masks given by the dataloader in an arbitrary yet deterministic order. + That is, [2, 3] can appear on the first frame, and [1, 2, 3, 17] can appear on the second frame. + + This is incompatible with the object addition implementation, since we assume new objects are appended, not inserted. + Thus, we compute object_appearance_order which sorts the object idx using the frame at which they appear + (conditional frames always appear first). For objects that appear on the same frame, we shuffle them as augmentation. + We also reorder the ground-truth masks used for supervision. + + **Causal supervision** + Since not all objects appear on the first frame, we should not supervise on the objects that the model has no knowledge of yet. + Thus, we keep track of the set of objects that have been introduced, and the frame at which that happens. + We compute valid_idx_per_frame (and correspondingly trim the ground-truth) to enforce reasonable supervisions. + + **Transition points** + Transition points are non-initial-conditioning frames that introduce new objects. We uniformly sample some frames + to be candidates for transition points, and use them if they actually introduce new objects compared to the last seen + conditional frame/transition point. + Transitions do not always happen when an object first becomes visible, because our (initial) sampling is agnostic to visibility. + This is intended, as new objects do not always get detected immediately in the dense tracking setting. + """ + + # First, prepare the prompt inputs following the parent class + backbone_out = super()._prepare_prompt_inputs_meta( + backbone_out, input, start_frame_idx=start_frame_idx + ) + + num_frames = backbone_out["num_frames"] + gt_masks_per_frame = backbone_out["gt_masks_per_frame"] + + if self.training or self.is_dynamic_vos_evaluation: + visible_objects_per_frame: dict[int, set[int]] = ( + input.visible_objects_per_frame + ) + else: + visible_objects_per_frame: dict[int, set[int]] = { + stage_id: set(range(gt_masks_per_frame[stage_id].shape[0])) + for stage_id in range(num_frames) + } + + # If we have more than one conditioning frame, + # all visible objects on any of the conditioning frames become valid for all frames + init_cond_frames: list[int] = backbone_out["init_cond_frames"] + init_cond_frames = sorted(init_cond_frames) + frames_not_in_init_cond: list[int] = backbone_out["frames_not_in_init_cond"] + + # Rare case: the data guard might fail and we could have an empty first frame. + # In this case, we track an empty object. + if len(visible_objects_per_frame[start_frame_idx]) == 0: + if self.training: + logging.warning("Empty first frame, tracking an empty object") + visible_objects_per_frame[start_frame_idx] = {0} + # set the GT mask for this object to be all zeros + for stage_id in range(num_frames): + gt_masks_per_frame[stage_id][0] = torch.zeros_like( + gt_masks_per_frame[stage_id][0] + ) + else: + # During evaluation, this should only happen for YouTubeVOS. + # We will skip the frames before the first conditional frame. + assert ( + self.is_dynamic_vos_evaluation + ), f"{visible_objects_per_frame=} invalid" + assert len(init_cond_frames) == 1 + for stage_id in range(start_frame_idx, num_frames): + if len(visible_objects_per_frame[stage_id]) > 0: + init_cond_frames = [stage_id] + break + for i in range( + init_cond_frames[0] + 1 + ): # also remove init_cond_frames[0] + if i in frames_not_in_init_cond: + frames_not_in_init_cond.remove(i) + + backbone_out["init_cond_frames"] = init_cond_frames + + # The object idx in valid_idx_per_frame should be in sequential order. + # We will first reshuffle the objects using object_appearance_order, + # and then index via valid_idx_per_frame. + valid_idx_per_frame: dict[int, list[int]] = {} + # Importantly, we cannot simply use valid_idx_per_frame[stage_id-1] because it might be a conditional frame. + valid_idx_prior_to_each_transition: dict[int, list[int]] = {} + new_idx_per_transition: dict[int, list[int]] = {} + + if self.training and self.enable_dynamic_training: + # Select the number of transition points + if self.rand_num_transition_points: + # Randomly select 1 to `max_num_transition_points` transition points + num_transition_points = self.rng.integers( + 1, self.max_num_transition_points, endpoint=True + ) + else: + num_transition_points = self.max_num_transition_points + + available_transition_points = frames_not_in_init_cond + num_transition_points = min( + num_transition_points, len(available_transition_points) + ) + # num_transition_points can differ between GPUs so we use rng2 + transition_points = self.rng2.choice( + available_transition_points, num_transition_points, replace=False + ).tolist() + transition_points = sorted(transition_points) + + # Filter for the transition points that do introduce new objects + filtered_transition_points = [] + objects_seen = set() + for stage_id in init_cond_frames: + objects_seen.update(visible_objects_per_frame[stage_id]) + + for stage_id in range(start_frame_idx, num_frames): + if stage_id in transition_points: + new_objects_seen = ( + visible_objects_per_frame[stage_id] - objects_seen + ) + if len(new_objects_seen) > 0: + filtered_transition_points.append(stage_id) + objects_seen.update(new_objects_seen) + new_idx_per_transition[stage_id] = list(new_objects_seen) + transition_points = filtered_transition_points + + # Create appearance-based object ordering with randomization + init_objects = set() + for stage_id in init_cond_frames: + init_objects.update(visible_objects_per_frame[stage_id]) + init_objects = list(init_objects) + self.rng2.shuffle(init_objects) + + object_appearance_order = init_objects.copy() + valid_idx_per_frame[start_frame_idx] = list(range(len(init_objects))) + for stage_id in range(start_frame_idx + 1, num_frames): + if stage_id in transition_points: + # When objects appear at a transition point, we add them to the end of the list + stage_objects = new_idx_per_transition[stage_id].copy() + self.rng2.shuffle(stage_objects) + valid_idx_prior_to_each_transition[stage_id] = list( + range(len(object_appearance_order)) + ) + new_idx_per_transition[stage_id] = list( + range( + len(object_appearance_order), + len(object_appearance_order) + len(stage_objects), + ) + ) + object_appearance_order.extend(stage_objects) + + # Update the valid objects at this frame + if stage_id in init_cond_frames: + # Note: on any non-first init cond frame, the number of valid objects + # might be fewer than the previous frame because we always process the init cond frames first. + # For example, if [1, 2, 4] are visible on the two init cond frames (e.g., frame 0 and frame 5), + # and object 3 appears on frame 4 (as a transition point), object 3 would not be considered valid on frame 5. + # This should not break any processing steps or affect correctness (since invalid objects are marked as floating). + valid_idx_per_frame[stage_id] = valid_idx_per_frame[ + start_frame_idx + ].copy() + elif stage_id in frames_not_in_init_cond: + valid_idx_per_frame[stage_id] = list( + range(len(object_appearance_order)) + ) + else: + raise ValueError( + f"Unexpected {stage_id=}? {init_cond_frames=} {frames_not_in_init_cond=} {transition_points=}" + ) + elif self.is_dynamic_vos_evaluation and not self.training: + # In dynamic VOS evaluation, we find the transition points manually. + # Each object should appear on exactly one frame. + # NOTE: The new release of YouTubeVOS apparently did not enforce this. + # We are enforcing it here. + + # Find first appearance of each object + object_appearance_order: list[int] = [] + object_appear_at_stage: dict[int, int] = {} + transition_points: list[int] = [] + stage_to_new_objects: dict[int, list[int]] = defaultdict(list) + for stage_id in range(start_frame_idx, num_frames): + visible_objects = sorted(list(visible_objects_per_frame[stage_id])) + for obj_id in visible_objects: + if obj_id in object_appear_at_stage: + continue # skip seen objects + + object_appear_at_stage[obj_id] = stage_id + object_appearance_order.append(obj_id) + stage_to_new_objects[stage_id].append(obj_id) + if stage_id not in init_cond_frames: + transition_points.append(stage_id) + + # Track cumulative object count + objects_seen_so_far = [] + for stage_id in range(start_frame_idx, num_frames): + if stage_id in transition_points: + # New objects appear at this frame + new_objects = stage_to_new_objects[stage_id] + num_objects_before = len(objects_seen_so_far) + + # Record which objects were valid before this transition + valid_idx_prior_to_each_transition[stage_id] = list( + range(num_objects_before) + ) + # Record the indices of new objects + new_idx_per_transition[stage_id] = list( + range(num_objects_before, num_objects_before + len(new_objects)) + ) + + objects_seen_so_far.extend(new_objects) + + # Set valid objects for this frame + if stage_id in init_cond_frames: + # For init cond frames, only the initial objects are valid + valid_idx_per_frame[stage_id] = list( + range(len(stage_to_new_objects[stage_id])) + ) + objects_seen_so_far.extend(stage_to_new_objects[stage_id]) + else: + # For other frames, all objects seen so far are valid + valid_idx_per_frame[stage_id] = list( + range(len(objects_seen_so_far)) + ) + + else: + # Use no transition points when dynamic training is disabled + transition_points = [] + visible_objects_on_first_frame = sorted( + list(visible_objects_per_frame[start_frame_idx]) + ) + # Since visible_objects_on_first_frame might not be consecutive + object_orderings = list(range(len(visible_objects_on_first_frame))) + # Use the original order for evaluation + object_appearance_order = visible_objects_on_first_frame.copy() + for stage_id in range(start_frame_idx, num_frames): + valid_idx_per_frame[stage_id] = object_orderings.copy() + + # Apply the appearance-based mapping to ground-truth masks + for stage_id in range(start_frame_idx, num_frames): + gt_masks_per_frame[stage_id] = gt_masks_per_frame[stage_id][ + object_appearance_order + ][valid_idx_per_frame[stage_id]] + + # We also want to apply this change in-place to the input, such that loss can be computed correctly. + # For targets.segments, we need to delay the object introduction by 1 frame. + # At transition points, use current frame's masks but only for objects that existed in the previous frame. + # This allows us to compute the loss on the existing objects and not on the newly added objects. + for stage_id, targets in enumerate(input.find_targets): + if stage_id in transition_points: + # At transition points, use current frame's masks but only keep objects from the previous frame + prev_objects = valid_idx_prior_to_each_transition[stage_id] + # Only keep masks for objects that existed in the previous frame + targets.segments = gt_masks_per_frame[stage_id][prev_objects].squeeze(1) + else: + targets.segments = gt_masks_per_frame[stage_id].squeeze(1) + # Ensure that we are averaging the loss correctly. + # Although this is called num_boxes, it actually stores an array of ones with length=number of objects in the VOS setting. + targets.num_boxes = targets.num_boxes[: targets.segments.shape[0]] + + backbone_out["valid_idx_per_frame"] = valid_idx_per_frame + backbone_out["new_idx_per_transition"] = new_idx_per_transition + backbone_out["valid_objects_prior_to_each_transition"] = ( + valid_idx_prior_to_each_transition + ) + backbone_out["transition_points"] = set(transition_points) + backbone_out["gt_masks_per_frame"] = gt_masks_per_frame + backbone_out["object_appearance_order"] = object_appearance_order + + backbone_out = self._prepare_conditional_frames(backbone_out) + + return backbone_out + + def add_new_masks_to_existing_state( + self, + *, + interactive_pix_feat: torch.Tensor, + interactive_high_res_features: list[torch.Tensor], + propagation_vision_feats: Optional[ + list[torch.Tensor] + ], # needed when add_mask_to_memory=True + propagation_feat_sizes: Optional[ + list[tuple[int, int]] + ], # needed when add_mask_to_memory=True + new_masks: torch.Tensor, + obj_idxs_in_mask: list[ + int + ], # len(obj_idxs_in_mask) == new_masks.shape[0]; object idx internal to this state + obj_ids_in_mask: Optional[ + list[int] + ], # len(obj_ids_in_mask) == new_masks.shape[0]; global object ids + prev_output: StageOutput, # this state will be modified in-place + multiplex_state: MultiplexState, + add_mask_to_memory: bool = True, + are_masks_from_pts: bool = False, + allow_new_buckets: bool = False, + prefer_new_buckets: bool = False, + ) -> None: + """ + Add new objects to an existing output/multiplex state. + + This function encodes the input masks as new masks and merges them with the existing state. + The new object entries are always appended to the existing objects. + + This is because, in the dense tracking scenario, we should always propagate (existing state) + to the current frame first before introducing the new objects. + """ + assert self.use_mask_input_as_output_without_sam + assert new_masks.shape[0] == len(obj_idxs_in_mask) + + num_new_objects = new_masks.shape[0] + + if obj_ids_in_mask is not None: + assert len(obj_ids_in_mask) == num_new_objects + + if self.use_obj_ptrs_in_encoder: + # demux the existing pointers before we change the multiplex state + existing_pointers = multiplex_state.demux(prev_output["obj_ptr"]) + + # Step 1: Inform the multiplex state that we are adding new objects + new_object_idx = multiplex_state.find_next_batch_of_available_indices( + num_objects=num_new_objects, + allow_new_buckets=allow_new_buckets, + prefer_new_buckets=prefer_new_buckets, + ) + multiplex_state.add_objects( + object_indices=new_object_idx, + object_ids=obj_ids_in_mask, + allow_new_buckets=allow_new_buckets, + prefer_new_buckets=prefer_new_buckets, + ) + + # Step 2: Encode the incoming masks + mask_output = self._use_mask_as_output( + backbone_features=interactive_pix_feat, + high_res_features=interactive_high_res_features, + mask_inputs=new_masks, + multiplex_state=multiplex_state, + objects_in_mask=new_object_idx, + ) + + # Step 3: Merge the existing state with new encoded features + # Handle resolution mismatch between propagation (e.g., 1008) and interactive (e.g., 288) features + # Determine target resolution from interactive features (newly generated masks) + interactive_resolution = mask_output["high_res_masks"].shape[-1] + + # Check if prev_output needs resolution adjustment + if ( + "pred_masks_high_res" in prev_output + and prev_output["pred_masks_high_res"] is not None + ): + existing_resolution = prev_output["pred_masks_high_res"].shape[-1] + + if existing_resolution != interactive_resolution: + # Resize existing outputs to match interactive resolution + # This happens when frame was bootstrapped with propagation features (1008) + # but we're now adding interactive masks (288) + prev_output["pred_masks_high_res"] = F.interpolate( + prev_output["pred_masks_high_res"], + size=(interactive_resolution, interactive_resolution), + mode="bilinear", + align_corners=False, + ) + + # Resize low_res_masks to match prev_output resolution + h, w = prev_output["pred_masks"].shape[-2:] + mask_output["low_res_masks"] = F.interpolate( + mask_output["low_res_masks"], + size=(h, w), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + + _append(prev_output, mask_output, "pred_masks", "low_res_masks") + _append( + prev_output, + mask_output, + "pred_masks_high_res", + "high_res_masks", + strict=False, + ) + _append(prev_output, mask_output, "object_score_logits", "object_score_logits") + if self.use_memory_selection: + mask_output["ious"] = mask_output["ious"].squeeze(-1) + _append(prev_output, mask_output, "iou_score", "ious") + + # Merge the input masks + if "input_masks" in prev_output: + prev_output["input_masks"] = torch.cat( + [prev_output["input_masks"], new_masks], dim=0 + ) + + if self.use_obj_ptrs_in_encoder: + # Merge the object pointers. Note that the pointers in SAMOutput are in the data space, + # while those in StageOutput are in the mux space. + new_pointers = mask_output["obj_ptr"].to(existing_pointers.dtype) + combined_pointers = torch.cat([existing_pointers, new_pointers], dim=0) + prev_output["obj_ptr"] = multiplex_state.mux(combined_pointers) + + # Step 4: Update the set of conditioning objects at this frame. + prev_output["conditioning_objects"].update(new_object_idx) + + # Step 5: Re-encode the spatial memory if needed + if add_mask_to_memory: + assert ( + prev_output["pred_masks_high_res"].shape[0] + == multiplex_state.total_valid_entries + ) + # Add the new masks to the memory + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + image=None, + current_vision_feats=propagation_vision_feats, + feat_sizes=propagation_feat_sizes, + pred_masks_high_res=prev_output["pred_masks_high_res"], + object_score_logits=prev_output["object_score_logits"], + conditioning_objects=prev_output["conditioning_objects"], + is_mask_from_pts=are_masks_from_pts, + multiplex_state=multiplex_state, + ) + prev_output["maskmem_features"] = maskmem_features + prev_output["maskmem_pos_enc"] = maskmem_pos_enc + if self.save_image_features: + # They should already be in the state; no modification is needed + assert "image_features" in prev_output + assert "image_pos_enc" in prev_output + + def recondition_masks_in_existing_state( + self, + *, + interactive_pix_feat: torch.Tensor, + interactive_high_res_features: list[torch.Tensor], + propagation_vision_feats: Optional[ + list[torch.Tensor] + ], # needed when add_mask_to_memory=True + propagation_feat_sizes: Optional[ + list[tuple[int, int]] + ], # needed when add_mask_to_memory=True + new_masks: torch.Tensor, + obj_idxs_in_mask: list[ + int + ], # len(obj_idxs_in_mask) == new_masks.shape[0]; object idx internal to this state + obj_ids_in_mask: Optional[ + list[int] + ], # len(obj_ids_in_mask) == new_masks.shape[0]; global object ids + prev_output: StageOutput, # this state will be modified in-place + multiplex_state: MultiplexState, + add_mask_to_memory: bool = True, + ) -> None: + """ + Recondition existing objects in an existing output/multiplex state. + + This function encodes the input masks and merges them with the existing state. + """ + assert self.use_mask_input_as_output_without_sam + assert new_masks.shape[0] == len(obj_idxs_in_mask) + + num_new_objects = new_masks.shape[0] + + if obj_ids_in_mask is not None: + assert len(obj_ids_in_mask) == num_new_objects + + if self.use_obj_ptrs_in_encoder: + # demux the existing pointers before we change the multiplex state + existing_pointers = multiplex_state.demux(prev_output["obj_ptr"]) + + # Step 1: Encode the incoming masks + mask_output = self._use_mask_as_output( + backbone_features=interactive_pix_feat, + high_res_features=interactive_high_res_features, + mask_inputs=new_masks, + multiplex_state=multiplex_state, + objects_in_mask=obj_idxs_in_mask, + ) + + # Step 2: Merge the existing state with new encoded features + # TODO: Remove this and fix the resolution mismatch + h, w = prev_output["pred_masks"].shape[-2:] + mask_output["low_res_masks"] = F.interpolate( + mask_output["low_res_masks"], + size=(h, w), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + + _merge( + prev_output, mask_output, "pred_masks", "low_res_masks", obj_idxs_in_mask + ) + _merge( + prev_output, + mask_output, + "pred_masks_high_res", + "high_res_masks", + obj_idxs_in_mask, + strict=False, + ) + _merge( + prev_output, + mask_output, + "object_score_logits", + "object_score_logits", + obj_idxs_in_mask, + ) + if self.use_memory_selection: + mask_output["ious"] = mask_output["ious"].squeeze(-1) + _merge( + prev_output, + mask_output, + "iou_score", + "ious", + obj_idxs_in_mask, + ) + + # Merge the input masks + if "input_masks" in prev_output: + prev_output["input_masks"][obj_idxs_in_mask] = new_masks + + if self.use_obj_ptrs_in_encoder: + # Merge the object pointers. Note that the pointers in SAMOutput are in the data space, + # while those in StageOutput are in the mux space. + new_pointers = mask_output["obj_ptr"].to(existing_pointers.dtype) + existing_pointers[obj_idxs_in_mask] = new_pointers + prev_output["obj_ptr"] = multiplex_state.mux(existing_pointers) + + # Step 3: Update the set of conditioning objects at this frame + prev_output["conditioning_objects"].update(obj_idxs_in_mask) + + # Step 4: Re-encode the spatial memory if needed + if add_mask_to_memory: + assert ( + prev_output["pred_masks_high_res"].shape[0] + == multiplex_state.total_valid_entries + ) + # Add the new masks to the memory + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + image=None, + current_vision_feats=propagation_vision_feats, + feat_sizes=propagation_feat_sizes, + pred_masks_high_res=prev_output["pred_masks_high_res"], + object_score_logits=prev_output["object_score_logits"], + conditioning_objects=prev_output["conditioning_objects"], + is_mask_from_pts=False, + multiplex_state=multiplex_state, + ) + prev_output["maskmem_features"] = maskmem_features + prev_output["maskmem_pos_enc"] = maskmem_pos_enc + if self.save_image_features: + # They should already be in the state; no modification is needed + assert "image_features" in prev_output + assert "image_pos_enc" in prev_output + + def track_step( + self, + *, + frame_idx, + is_init_cond_frame, + backbone_features_interactive, + backbone_features_propagation, + image, + point_inputs, + mask_inputs, + gt_masks, + frames_to_add_correction_pt, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + multiplex_state: MultiplexState, + # The list of object IDs that point_inputs correspond to; only this set of objects will + # be interacted with in the correction stage + objects_to_interact: Optional[list[int]] = None, + # The following parameters are specific to the dynamic multiplexing model + new_object_masks: Optional[torch.Tensor] = None, + new_object_idxs: Optional[list[int]] = None, + new_object_ids: Optional[list[int]] = None, + are_new_masks_from_pts: bool = False, + ) -> StageOutput: + # First, run track_step_aux. + # This includes propagation, interaction, and correction. + current_out, aux_out = self._track_step_aux( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + backbone_features_interactive=backbone_features_interactive, + backbone_features_propagation=backbone_features_propagation, + image=image, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + gt_masks=gt_masks, + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + run_mem_encoder=(run_mem_encoder and new_object_masks is None), + prev_sam_mask_logits=prev_sam_mask_logits, + multiplex_state=multiplex_state, + objects_to_interact=objects_to_interact, + need_aux_output=(new_object_masks is not None), + ) + + # If new masks are provided, merge them into the existing state + if new_object_masks is not None: + assert new_object_idxs is not None + self.add_new_masks_to_existing_state( + interactive_pix_feat=aux_out["interactive_pix_feat"], + interactive_high_res_features=aux_out["interactive_high_res_features"], + propagation_vision_feats=aux_out["propagation_vision_feats"], + propagation_feat_sizes=aux_out["propagation_feat_sizes"], + new_masks=new_object_masks, + obj_idxs_in_mask=new_object_idxs, + obj_ids_in_mask=new_object_ids, + prev_output=current_out, + multiplex_state=multiplex_state, + add_mask_to_memory=run_mem_encoder, + are_masks_from_pts=are_new_masks_from_pts, + ) + + # lastly, trim the output + current_out = self._trim_output_and_memory( + frame_idx=frame_idx, + output_dict=output_dict, + current_out=current_out, + memory_encoder_was_used=run_mem_encoder, + ) + + return current_out + + def forward_tracking( + self, + backbone_out, + input, + return_dict=False, + objects_to_interact: Optional[list[int]] = None, + ): + """Forward video tracking on each frame (and sample correction clicks).""" + img_feats_already_computed = ( + "interactive" in backbone_out or "sam2_backbone_out" in backbone_out + ) + if img_feats_already_computed: + # Prepare the backbone features + # - vision_feats and vision_pos_embeds are in (HW)BC format + # - vision_masks are in B(HW) format, dtype=bool (False is valid, True is padding) + backbone_features = self._prepare_backbone_features(backbone_out) + + # Starting the stage loop + num_frames = backbone_out["num_frames"] + init_cond_frames = backbone_out["init_cond_frames"] + frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"] + # First process all the initial conditioning frames to encode them as memory, + # And then condition on them to track the remaining frames + processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"] + + new_idx_per_transition = backbone_out["new_idx_per_transition"] + valid_objects_prior_to_each_transition = backbone_out[ + "valid_objects_prior_to_each_transition" + ] + transition_points = backbone_out["transition_points"] + + cond_frame_outputs: dict[int, StageOutput] = {} + non_cond_frame_outputs: dict[int, StageOutput] = {} + output_dict = { + "cond_frame_outputs": cond_frame_outputs, + "non_cond_frame_outputs": non_cond_frame_outputs, + } + multiplex_state = self.multiplex_controller.get_state( + backbone_out["gt_masks_per_frame"][processing_order[0]].shape[0], + device=backbone_out["gt_masks_per_frame"][processing_order[0]].device, + dtype=torch.float, + random=self.training, + ) + + for stage_id in processing_order: + # Get the image features for the current frame + img_ids = input.find_inputs[stage_id].img_ids + # The image ids are for the entire batch + assert all( + [img_id == img_ids[0] for img_id in img_ids] + ) # should be all the same + # force this to have a batch size of 1 + img_ids = torch.tensor( + [img_ids[0]], device=img_ids.device, dtype=img_ids.dtype + ) + + if img_feats_already_computed: + # Retrieve image features according to img_ids (if they are already computed). + current_image = input.img_batch.tensors[img_ids] + current_backbone_features = {} + for neck_k, neck_out in backbone_features.items(): + current_backbone_features[neck_k] = { + "vision_feats": [ + x[:, img_ids] for x in neck_out["vision_feats"] + ], + "vision_masks": [ + x[img_ids] if x is not None else None + for x in neck_out["vision_masks"] + ], + "vision_pos_embeds": [ + x[:, img_ids] for x in neck_out["vision_pos_embeds"] + ], + "feat_sizes": neck_out["feat_sizes"], + } + else: + # Otherwise, compute the image features on the fly for the given img_ids + # (this might be used for evaluation on long videos to avoid backbone OOM). + need_interactive_out = ( + (stage_id in frames_to_add_correction_pt) + or (stage_id in init_cond_frames) + or (stage_id in transition_points) + ) + (current_image, current_backbone_features) = ( + self._prepare_backbone_features_per_frame( + input.img_batch, + img_ids, + need_interactive_out=need_interactive_out, + need_propagation_out=True, + ) + ) + + gt_masks = backbone_out["gt_masks_per_frame"].get(stage_id, None) + if stage_id in transition_points: + assert gt_masks is not None + + # Figure out new object masks / idxs + new_object_idxs = new_idx_per_transition[stage_id] + # Get the new object masks, ensure correct ordering + assert sorted(new_object_idxs) == new_object_idxs + assert ( + new_object_idxs[0] + == len(valid_objects_prior_to_each_transition[stage_id]) + ), f"{new_object_idxs=}; {gt_masks.shape=}; {valid_objects_prior_to_each_transition[stage_id]=}" + assert new_object_idxs[-1] == ( + len(gt_masks) - 1 + ), f"{new_object_idxs=}; {gt_masks.shape=}" + new_object_masks = gt_masks[new_object_idxs] + + # Remove the new objects from the gt masks + gt_masks = gt_masks[: new_object_idxs[0]] + else: + new_object_masks = None + new_object_idxs = None + + # Get output masks based on this frame's prompts and previous memory + current_out = self.track_step( + frame_idx=stage_id, + is_init_cond_frame=stage_id in init_cond_frames, + backbone_features_interactive=current_backbone_features.get( + "interactive" + ), + backbone_features_propagation=current_backbone_features.get( + "sam2_backbone_out" + ), + image=current_image, + point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), + mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), + gt_masks=gt_masks, + frames_to_add_correction_pt=frames_to_add_correction_pt, + output_dict=output_dict, + num_frames=num_frames, + multiplex_state=multiplex_state, + objects_to_interact=objects_to_interact, + new_object_masks=new_object_masks, + new_object_idxs=new_object_idxs, + ) + # Append the output, depending on whether it's a conditioning frame + add_output_as_cond_frame = ( + stage_id in init_cond_frames + or ( + self.add_all_frames_to_correct_as_cond + and stage_id in frames_to_add_correction_pt + ) + or ( + self.add_all_transition_frames_as_cond + and stage_id in transition_points + ) + ) + + if add_output_as_cond_frame: + output_dict["cond_frame_outputs"][stage_id] = current_out + else: + output_dict["non_cond_frame_outputs"][stage_id] = current_out + + output_dict["multiplex_state"] = multiplex_state + + if return_dict: + return output_dict + # turn `output_dict` into a list for loss function + all_frame_outputs = {} + all_frame_outputs.update(output_dict["cond_frame_outputs"]) + all_frame_outputs.update(output_dict["non_cond_frame_outputs"]) + if self.is_dynamic_vos_evaluation: + all_frame_outputs = [all_frame_outputs.get(t) for t in range(num_frames)] + else: + all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)] + # Make DDP happy with activation checkpointing by removing unused keys + all_frame_outputs = [ + {k: v for k, v in d.items() if k != "obj_ptr"} if d is not None else None + for d in all_frame_outputs + ] + + if self.is_dynamic_vos_evaluation: + object_appearance_order = backbone_out["object_appearance_order"] + num_objects = len(input.find_metadatas[0].coco_image_id) + + # since we have remapped the object appearance order, we would need to map it back here + inverse_object_appearance_order = [None for _ in object_appearance_order] + for idx, obj_id in enumerate(object_appearance_order): + inverse_object_appearance_order[obj_id] = idx + assert all(i is not None for i in inverse_object_appearance_order) + + # this is for a rare case where the dataloader thinks that there is an object + # (is in input.find_metadatas[0].coco_image_id) + # but it is not visible anywhere in the frames + # I suspect this is due to mask resizing (the object is so small that it got lost) + # but I am not 100% sure; haven't investigated yet. + # This only happens if we evaluate on the new (fully annotated) YouTubeVOS set. + if len(inverse_object_appearance_order) < num_objects: + inverse_object_appearance_order.extend( + list(range(len(inverse_object_appearance_order), num_objects)) + ) + + # we need to pad the outputs with zeros (for the frames before the object appears) + last_mask = all_frame_outputs[-1]["pred_masks"] + + shape = last_mask.shape[1:] + dtype = last_mask.dtype + device = last_mask.device + for stage_i, frame_out in enumerate(all_frame_outputs): + if frame_out is None: + all_frame_outputs[stage_i] = { + "pred_masks": torch.zeros( + (num_objects, *shape), device=device, dtype=dtype + ) + } + continue + + pred_mask = frame_out["pred_masks"] + if pred_mask.shape[0] < num_objects: + shape = pred_mask.shape[ + 1: + ] # might have a different shape, e.g., input mask + frame_out["pred_masks"] = torch.cat( + [ + pred_mask, + torch.zeros( + (num_objects - pred_mask.shape[0], *shape), + device=device, + dtype=dtype, + ), + ], + dim=0, + )[inverse_object_appearance_order] + + return all_frame_outputs diff --git a/third_party/sam3/sam3/model/video_tracking_multiplex_demo.py b/third_party/sam3/sam3/model/video_tracking_multiplex_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..c5725770b81beb654d39ac76c175ce0777cf6c6b --- /dev/null +++ b/third_party/sam3/sam3/model/video_tracking_multiplex_demo.py @@ -0,0 +1,3476 @@ +import logging +from collections import OrderedDict +from copy import deepcopy +from typing import Iterable, Optional + +import numpy as np +import torch +from sam3.model.data_misc import NestedTensor +from sam3.model.io_utils import load_video_frames +from sam3.model.multiplex_utils import MultiplexState +from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores +from sam3.model.video_tracking_multiplex import ( + concat_points, + NO_OBJ_SCORE, + VideoTrackingDynamicMultiplex, +) +from tqdm import tqdm + + +class VideoTrackingMultiplexDemo(VideoTrackingDynamicMultiplex): + """ + The demo class that extends the `VideoTrackingDynamicMultiplex` to handle user interactions + and manage inference states, with support for multi-object tracking. + + Interactions are not yet implemented. + """ + + def __init__( + self, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + # if fill_hole_area > 0, we fill small holes in the final masks up to this area (after resizing them to the original video resolution) + fill_hole_area=0, + # if always_start_from_first_ann_frame is True, we always start tracking from the frame where we receive the first annotation (clicks or mask) + # and ignore the `start_frame_idx` passed to `propagate_in_video` + always_start_from_first_ann_frame=False, + # the maximum number of points to be used in the prompt encoder, which reduce the domain gap between training (that only has 8 points) + # - if it's set to a positive integer, we only take the `max_point_num_in_prompt_enc//2` points and + # the last `(max_point_num_in_prompt_enc - max_point_num_in_prompt_enc//2)` points in the prompt encoder + # - if it's set to 0 or negative, this option is turned off and we use all points in the prompt encoder + max_point_num_in_prompt_enc=16, + non_overlap_masks_for_output=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + self.fill_hole_area = fill_hole_area + self.always_start_from_first_ann_frame = always_start_from_first_ann_frame + self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc + self.non_overlap_masks_for_output = non_overlap_masks_for_output + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu, + offload_state_to_cpu, + async_loading_frames=False, + use_torchcodec=False, + use_cv2=False, + ): + """Initialize a inference state.""" + # Make sure that sigmoid is used on mask logits (should be True for all our recent models). + # Since we rely on large negative values as scores for missing objects, the raw logits + # cannot be consumed directly and must be converted into 0~1 range via sigmoid first. + if not self.apply_sigmoid_to_mask_logits_for_mem_enc: + raise NotImplementedError( + "Multi-object tracking requires sigmoid in memory encoder for non-overlapping constraints." + ) + + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + use_torchcodec=use_torchcodec, + use_cv2=use_cv2, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = torch.device("cuda") + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = torch.device("cuda") + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # The index of the frame that received the first annotation + inference_state["first_ann_frame_idx"] = None + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + inference_state["multiplex_state"] = None + # Track which frames have been refined by user interaction (per object) + # This is used to distinguish first refinement (fresh) vs subsequent refinements (incremental) + inference_state["user_refined_frames_per_obj"] = {} + # # Warm up the whole model and cache the image feature on frame 0 + # # by making a dummy click on the first frame (and then cleaning it up) + # self.add_new_points( + # inference_state=inference_state, + # frame_idx=0, + # obj_id=1, + # points=torch.tensor([[0.5, 0.5]], dtype=torch.float32), + # labels=torch.tensor([1], dtype=torch.int32), + # clear_old_points=True, + # rel_coordinates=True, + # ) + # self.clear_all_points_in_video(inference_state) + return inference_state + + def _obj_id_to_idx(self, inference_state, obj_id, error_if_new=False): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + if ( + self.is_dynamic_model or not inference_state["tracking_has_started"] + ) and not error_if_new: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id}. " + f"All existing object ids: {inference_state['obj_ids']}." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + # return len(inference_state["obj_idx_to_id"]) + return inference_state["multiplex_state"].total_valid_entries + + @torch.inference_mode() + def _extract_object_for_interaction(self, inference_state, obj_id, frame_idx): + """ + Extract a single object from multiplex state for singleton interaction. + Adapted from sam3_multiplex_tracking._extract_object_to_singleton_state() + + Returns: + singleton_state: New inference state containing only this object + obj_idx_in_source: Original object index before removal (for merging back) + """ + source_state = inference_state + obj_idx_in_source = source_state["obj_id_to_idx"][obj_id] + + # Step 1: Extract all object data BEFORE removing it + multiplex_state = source_state.get("multiplex_state") + + # Extract consolidated outputs (slice NOW before remove_object modifies tensors) + singleton_consolidated_outputs = { + "cond_frame_outputs": {}, + "non_cond_frame_outputs": {}, + } + + if "output_dict" in source_state: + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + source_outputs = source_state["output_dict"].get(storage_key, {}) + + for f_idx, source_frame_out in source_outputs.items(): + # Check if this frame has valid data for this object + has_valid_data = ( + source_frame_out["pred_masks"].shape[0] >= obj_idx_in_source + 1 + ) + + if has_valid_data: + # Create singleton frame output by slicing + singleton_frame_out = { + "pred_masks": source_frame_out["pred_masks"][ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone(), + "object_score_logits": source_frame_out[ + "object_score_logits" + ][obj_idx_in_source : obj_idx_in_source + 1].clone(), + # image_features and image_pos_enc remain shared (not in multiplex space) + "image_features": source_frame_out.get("image_features"), + "image_pos_enc": source_frame_out.get("image_pos_enc"), + "local_obj_id_to_idx": {obj_id: 0}, + } + + # Handle maskmem_features by converting from multiplex space to data space + maskmem_features = source_frame_out.get("maskmem_features") + if maskmem_features is not None: + if multiplex_state is not None: + expected_buckets = multiplex_state.num_buckets + expected_multiplex = multiplex_state.multiplex_count + if ( + maskmem_features.dim() >= 2 + and maskmem_features.shape[0] == expected_buckets + and maskmem_features.shape[1] == expected_multiplex + ): + try: + demuxed_features = multiplex_state.demux( + maskmem_features + ) + except AssertionError as exc: + logging.warning( + "[EXTRACT] demux failed for maskmem_features shape %s: %s", + tuple(maskmem_features.shape), + exc, + ) + demuxed_features = None + if demuxed_features is not None: + maskmem_features = demuxed_features[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + else: + maskmem_features = maskmem_features[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + elif maskmem_features.shape[0] == 0: + # No entries for this object yet; treat as missing without warning + maskmem_features = None + elif maskmem_features.shape[0] >= obj_idx_in_source + 1: + # Already in data space; slice directly + maskmem_features = maskmem_features[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + else: + logging.warning( + "[EXTRACT] maskmem_features shape %s incompatible with multiplex state; dropping tensor", + tuple(maskmem_features.shape), + ) + maskmem_features = None + else: + maskmem_features = maskmem_features[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + singleton_frame_out["maskmem_features"] = maskmem_features + + # Handle maskmem_pos_enc similarly, level by level + maskmem_pos_enc = source_frame_out.get("maskmem_pos_enc") + if maskmem_pos_enc is not None: + remapped_pos_enc = [] + for level_enc in maskmem_pos_enc: + if level_enc is None: + remapped_pos_enc.append(None) + continue + if multiplex_state is not None: + expected_buckets = multiplex_state.num_buckets + expected_multiplex = multiplex_state.multiplex_count + if ( + level_enc.dim() >= 2 + and level_enc.shape[0] == expected_buckets + and level_enc.shape[1] == expected_multiplex + ): + try: + demuxed_level = multiplex_state.demux( + level_enc + ) + except AssertionError as exc: + logging.warning( + "[EXTRACT] demux failed for maskmem_pos_enc level shape %s: %s", + tuple(level_enc.shape), + exc, + ) + demuxed_level = None + if demuxed_level is not None: + remapped_pos_enc.append( + demuxed_level[ + obj_idx_in_source : obj_idx_in_source + + 1 + ].clone() + ) + elif ( + level_enc.shape[0] >= obj_idx_in_source + 1 + ): + remapped_pos_enc.append( + level_enc[ + obj_idx_in_source : obj_idx_in_source + + 1 + ].clone() + ) + else: + logging.warning( + "[EXTRACT] maskmem_pos_enc level shape %s incompatible with multiplex state; dropping level", + tuple(level_enc.shape), + ) + remapped_pos_enc.append(None) + elif level_enc.shape[0] >= obj_idx_in_source + 1: + remapped_pos_enc.append( + level_enc[ + obj_idx_in_source : obj_idx_in_source + + 1 + ].clone() + ) + else: + logging.warning( + "[EXTRACT] maskmem_pos_enc level shape %s incompatible with multiplex state; dropping level", + tuple(level_enc.shape), + ) + remapped_pos_enc.append(None) + else: + remapped_pos_enc.append( + level_enc[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + ) + maskmem_pos_enc = remapped_pos_enc + singleton_frame_out["maskmem_pos_enc"] = maskmem_pos_enc + + # Handle obj_ptr (must demux from multiplex space first) + if ( + "obj_ptr" in source_frame_out + and self.use_obj_ptrs_in_encoder + ): + source_obj_ptr = source_frame_out["obj_ptr"] + if multiplex_state is not None: + # Demux: multiplex space → data space + obj_ptr_data_space = multiplex_state.demux( + source_obj_ptr + ) + # Slice for this object + singleton_frame_out["obj_ptr"] = obj_ptr_data_space[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + else: + singleton_frame_out["obj_ptr"] = source_obj_ptr[ + obj_idx_in_source : obj_idx_in_source + 1 + ].clone() + + # Convert conditioning_objects + if "conditioning_objects" in source_frame_out: + if ( + obj_idx_in_source + in source_frame_out["conditioning_objects"] + ): + singleton_frame_out["conditioning_objects"] = {0} + else: + singleton_frame_out["conditioning_objects"] = set() + + singleton_consolidated_outputs[storage_key][f_idx] = ( + singleton_frame_out + ) + + # Extract point and mask inputs + extracted_point_inputs = {} + extracted_mask_inputs = {} + + if ( + "point_inputs_per_obj" in source_state + and obj_idx_in_source in source_state["point_inputs_per_obj"] + ): + extracted_point_inputs = source_state["point_inputs_per_obj"][ + obj_idx_in_source + ].copy() + + if ( + "mask_inputs_per_obj" in source_state + and obj_idx_in_source in source_state["mask_inputs_per_obj"] + ): + extracted_mask_inputs = source_state["mask_inputs_per_obj"][ + obj_idx_in_source + ].copy() + + # Extract per-object outputs + extracted_obj_cond_outputs = {} + extracted_obj_non_cond_outputs = {} + extracted_temp_cond_outputs = {} + extracted_temp_non_cond_outputs = {} + + if ( + "output_dict_per_obj" in source_state + and obj_idx_in_source in source_state["output_dict_per_obj"] + ): + obj_output_dict = source_state["output_dict_per_obj"][obj_idx_in_source] + extracted_obj_cond_outputs = obj_output_dict.get( + "cond_frame_outputs", {} + ).copy() + extracted_obj_non_cond_outputs = obj_output_dict.get( + "non_cond_frame_outputs", {} + ).copy() + + if ( + "temp_output_dict_per_obj" in source_state + and obj_idx_in_source in source_state["temp_output_dict_per_obj"] + ): + temp_obj_output_dict = source_state["temp_output_dict_per_obj"][ + obj_idx_in_source + ] + extracted_temp_cond_outputs = temp_obj_output_dict.get( + "cond_frame_outputs", {} + ).copy() + extracted_temp_non_cond_outputs = temp_obj_output_dict.get( + "non_cond_frame_outputs", {} + ).copy() + + # Step 2: Remove the object from source state + remaining_obj_ids, _ = self.remove_object( + source_state, + obj_id, + strict=False, + need_output=False, + clear_user_refined_map=False, + ) + + # If multiplex state became empty, reset it so downstream code can reinitialize + updated_multiplex_state = source_state.get("multiplex_state") + if updated_multiplex_state is not None: + if ( + getattr(updated_multiplex_state, "assignments", None) is None + or updated_multiplex_state.total_valid_entries == 0 + ): + source_state["multiplex_state"] = None + + # Step 3: Create new singleton inference state + singleton_state = self.init_state( + cached_features=source_state["cached_features"], + video_height=source_state["video_height"], + video_width=source_state["video_width"], + num_frames=source_state["num_frames"], + ) + + # Step 4: Set up singleton state structure + singleton_state["obj_id_to_idx"] = {obj_id: 0} + singleton_state["obj_idx_to_id"] = {0: obj_id} + singleton_state["obj_ids"] = [obj_id] + singleton_state["point_inputs_per_obj"] = {0: extracted_point_inputs} + singleton_state["mask_inputs_per_obj"] = {0: extracted_mask_inputs} + singleton_state["output_dict_per_obj"] = { + 0: { + "cond_frame_outputs": extracted_obj_cond_outputs, + "non_cond_frame_outputs": extracted_obj_non_cond_outputs, + } + } + singleton_state["temp_output_dict_per_obj"] = { + 0: { + "cond_frame_outputs": extracted_temp_cond_outputs, + "non_cond_frame_outputs": extracted_temp_non_cond_outputs, + } + } + singleton_state["frames_already_tracked"] = source_state[ + "frames_already_tracked" + ].copy() + + # Step 5: Create new singleton multiplex state (even for 1 object, needed for obj_ptr) + new_multiplex_state = self.multiplex_controller.get_state( + num_valid_entries=1, + device=source_state["device"], + dtype=torch.float32, + random=False, + object_ids=[obj_id], + ) + singleton_state["multiplex_state"] = new_multiplex_state + + # Step 6: Remux extracted tensors into the singleton multiplex space + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + for f_idx, frame_out in singleton_consolidated_outputs[storage_key].items(): + # mask memory features + if frame_out.get("maskmem_features") is not None: + # Keep mask memory features in data space (num_objects, C, H, W) + frame_out["maskmem_features"] = frame_out[ + "maskmem_features" + ].clone() + + if frame_out.get("maskmem_pos_enc") is not None: + remapped_levels = [] + for level_enc in frame_out["maskmem_pos_enc"]: + if level_enc is None: + remapped_levels.append(None) + continue + remapped_levels.append(level_enc.clone()) + frame_out["maskmem_pos_enc"] = remapped_levels + + # object pointers + if "obj_ptr" in frame_out and self.use_obj_ptrs_in_encoder: + # Mux: data space [1, D] → singleton multiplex space [1, 1, D] + frame_out["obj_ptr"] = new_multiplex_state.mux(frame_out["obj_ptr"]) + + singleton_state["output_dict"] = singleton_consolidated_outputs + + return singleton_state, obj_idx_in_source + + @torch.inference_mode() + def _merge_singleton_interaction_result( + self, + inference_state, + singleton_state, + obj_id, + original_obj_idx, + ): + """ + Merge singleton interaction result back into multiplex state. + + SIMPLIFIED APPROACH: Add object back at the END (new index), not at original position. + This avoids complex index shifting and works with multiplex controller's add_objects() API. + + Args: + inference_state: The main multiplex inference state + singleton_state: The singleton state with interaction results + obj_id: The object ID + original_obj_idx: The original index before extraction (unused - we add at end instead) + """ + # Determine new index (add at end) + new_obj_idx = len(inference_state["obj_ids"]) + + # Step 1: Add object mappings at new index + inference_state["obj_ids"].append(obj_id) + inference_state["obj_id_to_idx"][obj_id] = new_obj_idx + + # Create entry in output_dict_per_obj and temp_output_dict_per_obj for new index + # These are DICTIONARIES indexed by obj_idx, not lists! + inference_state["output_dict_per_obj"][new_obj_idx] = { + "cond_frame_outputs": {}, + "non_cond_frame_outputs": {}, + } + inference_state["temp_output_dict_per_obj"][new_obj_idx] = { + "cond_frame_outputs": {}, + "non_cond_frame_outputs": {}, + } + + inference_state["obj_idx_to_id"][new_obj_idx] = obj_id + + # Step 2: Add object to multiplex state buckets using proper API + multiplex_state = inference_state.get("multiplex_state") + + assignments = ( + getattr(multiplex_state, "assignments", None) + if multiplex_state is not None + else None + ) + total_valid_entries = ( + getattr(multiplex_state, "total_valid_entries", 0) + if multiplex_state is not None and assignments is not None + else 0 + ) + need_state_reinit = ( + multiplex_state is None or assignments is None or total_valid_entries == 0 + ) + + if not need_state_reinit and getattr(multiplex_state, "object_ids", None): + if obj_id in multiplex_state.object_ids: + old_idx = multiplex_state.object_ids.index(obj_id) + multiplex_state.remove_objects(object_indices=[old_idx], strict=False) + assignments = getattr(multiplex_state, "assignments", None) + total_valid_entries = ( + getattr(multiplex_state, "total_valid_entries", 0) + if assignments is not None + else 0 + ) + need_state_reinit = assignments is None or total_valid_entries == 0 + + if need_state_reinit: + inference_state["multiplex_state"] = self.multiplex_controller.get_state( + num_valid_entries=len(inference_state["obj_ids"]), + device=inference_state["device"], + dtype=torch.float32, + random=False, + object_ids=list(inference_state["obj_ids"]), + ) + multiplex_state = inference_state["multiplex_state"] + else: + # Allow new buckets since we're adding at a new index (the old bucket slot may have been removed) + multiplex_state.add_objects( + object_indices=[new_obj_idx], + object_ids=[obj_id], + allow_new_buckets=True, # May need new bucket if old slot was compacted + ) + + # Step 3: Restore point and mask inputs at new index + singleton_obj_idx = 0 # Object is always at index 0 in singleton state + if ( + "point_inputs_per_obj" in singleton_state + and singleton_obj_idx in singleton_state["point_inputs_per_obj"] + ): + if "point_inputs_per_obj" not in inference_state: + inference_state["point_inputs_per_obj"] = {} + inference_state["point_inputs_per_obj"][new_obj_idx] = singleton_state[ + "point_inputs_per_obj" + ][singleton_obj_idx].copy() + + if ( + "mask_inputs_per_obj" in singleton_state + and singleton_obj_idx in singleton_state["mask_inputs_per_obj"] + ): + if "mask_inputs_per_obj" not in inference_state: + inference_state["mask_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"][new_obj_idx] = singleton_state[ + "mask_inputs_per_obj" + ][singleton_obj_idx].copy() + + # Step 4: Restore per-object outputs at new index + if ( + "output_dict_per_obj" in singleton_state + and singleton_obj_idx in singleton_state["output_dict_per_obj"] + ): + if "output_dict_per_obj" not in inference_state: + inference_state["output_dict_per_obj"] = {} + inference_state["output_dict_per_obj"][new_obj_idx] = singleton_state[ + "output_dict_per_obj" + ][singleton_obj_idx].copy() + + if ( + "temp_output_dict_per_obj" in singleton_state + and singleton_obj_idx in singleton_state["temp_output_dict_per_obj"] + ): + if "temp_output_dict_per_obj" not in inference_state: + inference_state["temp_output_dict_per_obj"] = {} + inference_state["temp_output_dict_per_obj"][new_obj_idx] = singleton_state[ + "temp_output_dict_per_obj" + ][singleton_obj_idx].copy() + + # Step 5: Merge consolidated outputs back into multiplex (append at new_obj_idx) + # Preserve each frame's original storage key from the singleton state so that + # conditioning frames remain in cond_frame_outputs after the merge. + if "output_dict" in singleton_state: + singleton_multiplex_state = singleton_state.get("multiplex_state") + for singleton_storage_key in [ + "cond_frame_outputs", + "non_cond_frame_outputs", + ]: + singleton_outputs = singleton_state["output_dict"].get( + singleton_storage_key, {} + ) + + # Skip if singleton doesn't have any frames in this storage_key + if not singleton_outputs: + continue + + for frame_idx, singleton_frame_out in singleton_outputs.items(): + # Get or create frame output in main state at the EXPECTED storage_key + if "output_dict" not in inference_state: + inference_state["output_dict"] = { + "cond_frame_outputs": {}, + "non_cond_frame_outputs": {}, + } + + if ( + frame_idx + not in inference_state["output_dict"][singleton_storage_key] + ): + # Frame doesn't exist - create with singleton results at new_obj_idx + num_objs = len(inference_state["obj_ids"]) + + # Ensure num_objs is at least new_obj_idx + 1 + # (in case obj_ids list is somehow inconsistent) + if num_objs <= new_obj_idx: + num_objs = new_obj_idx + 1 + + new_maskmem_features = None + new_maskmem_pos_enc = None + if ( + singleton_frame_out.get("maskmem_features") is not None + and multiplex_state is not None + ): + # Check if singleton features are in multiplexed format and demux if needed + singleton_features_muxed = singleton_frame_out[ + "maskmem_features" + ] + if singleton_features_muxed.shape[:2] == ( + singleton_multiplex_state.num_buckets, + singleton_multiplex_state.multiplex_count, + ): + # Singleton features are multiplexed, need to demux + singleton_features_data = ( + singleton_multiplex_state.demux( + singleton_features_muxed + ) + ) + else: + # Singleton features are in data space + singleton_features_data = singleton_features_muxed + + feature_shape = (num_objs,) + singleton_features_data.shape[ + 1: + ] + maskmem_features_data = torch.zeros( + feature_shape, + dtype=singleton_features_data.dtype, + device=singleton_features_data.device, + ) + maskmem_features_data[new_obj_idx : new_obj_idx + 1] = ( + singleton_features_data + ) + # Mux using destination multiplex state + new_maskmem_features = multiplex_state.mux( + maskmem_features_data + ) + + if ( + singleton_frame_out.get("maskmem_pos_enc") is not None + and multiplex_state is not None + ): + new_maskmem_pos_enc = [] + for level_enc in singleton_frame_out["maskmem_pos_enc"]: + if level_enc is None: + new_maskmem_pos_enc.append(None) + continue + # Check if singleton pos_enc is in multiplexed format and demux if needed + if level_enc.shape[:2] == ( + singleton_multiplex_state.num_buckets, + singleton_multiplex_state.multiplex_count, + ): + # Singleton pos_enc is multiplexed, need to demux + level_data = singleton_multiplex_state.demux( + level_enc + ) + else: + # Singleton pos_enc is in data space + level_data = level_enc + + level_shape = (num_objs,) + level_data.shape[1:] + level_tensor = torch.zeros( + level_shape, + dtype=level_data.dtype, + device=level_data.device, + ) + level_tensor[new_obj_idx : new_obj_idx + 1] = level_data + # Mux using destination multiplex state to store in multiplex format + new_maskmem_pos_enc.append( + multiplex_state.mux(level_tensor) + ) + + inference_state["output_dict"][singleton_storage_key][ + frame_idx + ] = { + "maskmem_features": new_maskmem_features, + "maskmem_pos_enc": new_maskmem_pos_enc, + "image_features": singleton_frame_out.get("image_features"), + "image_pos_enc": singleton_frame_out.get("image_pos_enc"), + "local_obj_id_to_idx": {obj_id: new_obj_idx}, + "conditioning_objects": ( + set([new_obj_idx]) + if singleton_obj_idx + in singleton_frame_out.get( + "conditioning_objects", set() + ) + else set() + ), + "pred_masks": torch.zeros( + ( + num_objs, + 1, + singleton_frame_out["pred_masks"].shape[2], + singleton_frame_out["pred_masks"].shape[3], + ), + dtype=singleton_frame_out["pred_masks"].dtype, + device=singleton_frame_out["pred_masks"].device, + ), + "object_score_logits": torch.full( + (num_objs, 1), + NO_OBJ_SCORE, + dtype=singleton_frame_out["object_score_logits"].dtype, + device=singleton_frame_out[ + "object_score_logits" + ].device, + ), + } + # Set singleton results at new_obj_idx + inference_state["output_dict"][singleton_storage_key][ + frame_idx + ]["pred_masks"][ + new_obj_idx : new_obj_idx + 1 + ] = singleton_frame_out["pred_masks"] + inference_state["output_dict"][singleton_storage_key][ + frame_idx + ]["object_score_logits"][ + new_obj_idx : new_obj_idx + 1 + ] = singleton_frame_out["object_score_logits"] + + # Also copy pred_masks_video_res if it exists in singleton output + if "pred_masks_video_res" in singleton_frame_out: + inference_state["output_dict"][singleton_storage_key][ + frame_idx + ]["pred_masks_video_res"] = torch.zeros( + ( + num_objs, + 1, + singleton_frame_out["pred_masks_video_res"].shape[ + 2 + ], + singleton_frame_out["pred_masks_video_res"].shape[ + 3 + ], + ), + dtype=singleton_frame_out["pred_masks_video_res"].dtype, + device=singleton_frame_out[ + "pred_masks_video_res" + ].device, + ) + inference_state["output_dict"][singleton_storage_key][ + frame_idx + ]["pred_masks_video_res"][ + new_obj_idx : new_obj_idx + 1 + ] = singleton_frame_out["pred_masks_video_res"] + + # Handle obj_ptr if present + if ( + "obj_ptr" in singleton_frame_out + and self.use_obj_ptrs_in_encoder + ): + singleton_obj_ptr_data = singleton_multiplex_state.demux( + singleton_frame_out["obj_ptr"] + ) + obj_ptr_data = torch.zeros( + (num_objs, singleton_obj_ptr_data.shape[1]), + dtype=singleton_obj_ptr_data.dtype, + device=singleton_obj_ptr_data.device, + ) + obj_ptr_data[new_obj_idx : new_obj_idx + 1] = ( + singleton_obj_ptr_data + ) + inference_state["output_dict"][singleton_storage_key][ + frame_idx + ]["obj_ptr"] = multiplex_state.mux(obj_ptr_data) + else: + # Frame exists - expand tensors and add singleton results + main_frame_out = inference_state["output_dict"][ + singleton_storage_key + ][frame_idx] + + num_objs_total = len(inference_state["obj_ids"]) + + if ( + singleton_frame_out.get("maskmem_features") is not None + and multiplex_state is not None + ): + # Check if singleton features are in multiplexed format and demux if needed + singleton_features_muxed = singleton_frame_out[ + "maskmem_features" + ] + if singleton_features_muxed.shape[:2] == ( + singleton_multiplex_state.num_buckets, + singleton_multiplex_state.multiplex_count, + ): + # Singleton features are multiplexed, need to demux + singleton_features_data = ( + singleton_multiplex_state.demux( + singleton_features_muxed + ) + ) + else: + # Singleton features are in data space + singleton_features_data = singleton_features_muxed + + existing_features_muxed = main_frame_out.get( + "maskmem_features" + ) + if existing_features_muxed is not None: + # Check if features are in multiplex format before demuxing + if existing_features_muxed.shape[:2] == ( + multiplex_state.num_buckets, + multiplex_state.multiplex_count, + ): + # Features are in multiplex format, demux them + existing_features_data = multiplex_state.demux( + existing_features_muxed + ) + else: + # Features are already in data space, use directly + existing_features_data = existing_features_muxed + else: + existing_features_data = None + + if existing_features_data is None: + feature_shape = ( + num_objs_total, + ) + singleton_features_data.shape[1:] + existing_features_data = torch.zeros( + feature_shape, + dtype=singleton_features_data.dtype, + device=singleton_features_data.device, + ) + elif existing_features_data.shape[0] < num_objs_total: + pad_size = ( + num_objs_total - existing_features_data.shape[0] + ) + pad = torch.zeros( + (pad_size,) + existing_features_data.shape[1:], + dtype=existing_features_data.dtype, + device=existing_features_data.device, + ) + existing_features_data = torch.cat( + [existing_features_data, pad], dim=0 + ) + + existing_features_data[new_obj_idx : new_obj_idx + 1] = ( + singleton_features_data + ) + main_frame_out["maskmem_features"] = multiplex_state.mux( + existing_features_data + ) + + if ( + singleton_frame_out.get("maskmem_pos_enc") is not None + and multiplex_state is not None + ): + existing_pos_enc_list = ( + main_frame_out.get("maskmem_pos_enc") or [] + ) + new_maskmem_pos_enc = [] + max_levels = max( + len(singleton_frame_out["maskmem_pos_enc"]), + len(existing_pos_enc_list), + ) + for level_idx in range(max_levels): + singleton_level_muxed = ( + singleton_frame_out["maskmem_pos_enc"][level_idx] + if level_idx + < len(singleton_frame_out["maskmem_pos_enc"]) + else None + ) + existing_level_muxed = ( + existing_pos_enc_list[level_idx] + if level_idx < len(existing_pos_enc_list) + else None + ) + + if singleton_level_muxed is None: + # Keep existing entry (which may also be None) + new_maskmem_pos_enc.append(existing_level_muxed) + continue + + # Check if singleton pos_enc is in multiplexed format and demux if needed + if singleton_level_muxed.shape[:2] == ( + singleton_multiplex_state.num_buckets, + singleton_multiplex_state.multiplex_count, + ): + # Singleton pos_enc is multiplexed, need to demux + singleton_level_data = ( + singleton_multiplex_state.demux( + singleton_level_muxed + ) + ) + else: + # Singleton pos_enc is in data space + singleton_level_data = singleton_level_muxed + + if existing_level_muxed is not None: + # Check if pos_enc is in multiplex format before demuxing + if existing_level_muxed.shape[:2] == ( + multiplex_state.num_buckets, + multiplex_state.multiplex_count, + ): + # Positional encoding is in multiplex format, demux it + existing_level_data = multiplex_state.demux( + existing_level_muxed + ) + else: + # Positional encoding is already in data space, use directly + existing_level_data = existing_level_muxed + else: + existing_level_data = None + + if existing_level_data is None: + level_shape = ( + num_objs_total, + ) + singleton_level_data.shape[1:] + existing_level_data = torch.zeros( + level_shape, + dtype=singleton_level_data.dtype, + device=singleton_level_data.device, + ) + elif existing_level_data.shape[0] < num_objs_total: + pad_size = ( + num_objs_total - existing_level_data.shape[0] + ) + pad = torch.zeros( + (pad_size,) + existing_level_data.shape[1:], + dtype=existing_level_data.dtype, + device=existing_level_data.device, + ) + existing_level_data = torch.cat( + [existing_level_data, pad], dim=0 + ) + + existing_level_data[new_obj_idx : new_obj_idx + 1] = ( + singleton_level_data + ) + new_maskmem_pos_enc.append( + multiplex_state.mux(existing_level_data) + ) + + main_frame_out["maskmem_pos_enc"] = new_maskmem_pos_enc + + singleton_pred_masks = singleton_frame_out[ + "pred_masks" + ] # [1, 1, H, W] + singleton_scores = singleton_frame_out[ + "object_score_logits" + ] # [1, 1] + + # Expand tensors if needed + num_existing_objs = main_frame_out["pred_masks"].shape[0] + if new_obj_idx >= num_existing_objs: + num_objs_needed = new_obj_idx + 1 + pad_size = num_objs_needed - num_existing_objs + + main_frame_out["pred_masks"] = torch.cat( + [ + main_frame_out["pred_masks"], + torch.zeros( + ( + pad_size, + 1, + singleton_pred_masks.shape[2], + singleton_pred_masks.shape[3], + ), + dtype=singleton_pred_masks.dtype, + device=singleton_pred_masks.device, + ), + ], + dim=0, + ) + + main_frame_out["object_score_logits"] = torch.cat( + [ + main_frame_out["object_score_logits"], + torch.full( + (pad_size, 1), + NO_OBJ_SCORE, + dtype=singleton_scores.dtype, + device=singleton_scores.device, + ), + ], + dim=0, + ) + + # Set singleton results at new_obj_idx + main_frame_out["pred_masks"][new_obj_idx : new_obj_idx + 1] = ( + singleton_pred_masks + ) + main_frame_out["object_score_logits"][ + new_obj_idx : new_obj_idx + 1 + ] = singleton_scores + # Initialize local_obj_id_to_idx if missing (e.g., frame + # output was created by VG propagation's track_step which + # does not populate this field). + if "local_obj_id_to_idx" not in main_frame_out: + main_frame_out["local_obj_id_to_idx"] = deepcopy( + inference_state["obj_id_to_idx"] + ) + main_frame_out["local_obj_id_to_idx"][obj_id] = new_obj_idx + + # Also expand and copy pred_masks_video_res if it exists in singleton output + if "pred_masks_video_res" in singleton_frame_out: + if "pred_masks_video_res" in main_frame_out: + # Expand existing video_res masks + if ( + main_frame_out["pred_masks_video_res"].shape[0] + < new_obj_idx + 1 + ): + pad_size = ( + new_obj_idx + + 1 + - main_frame_out["pred_masks_video_res"].shape[ + 0 + ] + ) + main_frame_out["pred_masks_video_res"] = torch.cat( + [ + main_frame_out["pred_masks_video_res"], + torch.zeros( + ( + pad_size, + 1, + singleton_frame_out[ + "pred_masks_video_res" + ].shape[2], + singleton_frame_out[ + "pred_masks_video_res" + ].shape[3], + ), + dtype=singleton_frame_out[ + "pred_masks_video_res" + ].dtype, + device=singleton_frame_out[ + "pred_masks_video_res" + ].device, + ), + ], + dim=0, + ) + else: + # Create new video_res masks tensor + num_objs = len(inference_state["obj_ids"]) + main_frame_out["pred_masks_video_res"] = torch.zeros( + ( + num_objs, + 1, + singleton_frame_out[ + "pred_masks_video_res" + ].shape[2], + singleton_frame_out[ + "pred_masks_video_res" + ].shape[3], + ), + dtype=singleton_frame_out[ + "pred_masks_video_res" + ].dtype, + device=singleton_frame_out[ + "pred_masks_video_res" + ].device, + ) + # Set singleton video_res mask + main_frame_out["pred_masks_video_res"][ + new_obj_idx : new_obj_idx + 1 + ] = singleton_frame_out["pred_masks_video_res"] + + # Handle obj_ptr + if ( + "obj_ptr" in singleton_frame_out + and self.use_obj_ptrs_in_encoder + ): + singleton_obj_ptr_data = singleton_multiplex_state.demux( + singleton_frame_out["obj_ptr"] + ) # [1, D] + + if "obj_ptr" in main_frame_out: + # The existing obj_ptr may have been created with a DIFFERENT number of buckets + # (before we called multiplex_state.add_objects() which may have created new buckets). + # We need to infer the OLD bucket count from the tensor shape to demux it correctly. + + old_obj_ptr_muxed = main_frame_out["obj_ptr"] + # Infer old bucket count: shape is [B_old, M_old, D] + old_num_buckets = old_obj_ptr_muxed.shape[1] + + # Create temporary multiplex state with old bucket count to demux + if old_num_buckets != multiplex_state.num_buckets: + # Bucket count changed - cannot safely demux old obj_ptr + # Instead, create new obj_ptr from scratch for all objects + num_objs = len(inference_state["obj_ids"]) + obj_ptr_data = torch.zeros( + (num_objs, singleton_obj_ptr_data.shape[1]), + dtype=singleton_obj_ptr_data.dtype, + device=singleton_obj_ptr_data.device, + ) + # Only set the singleton object's ptr, leave others as zeros + obj_ptr_data[new_obj_idx : new_obj_idx + 1] = ( + singleton_obj_ptr_data + ) + main_frame_out["obj_ptr"] = multiplex_state.mux( + obj_ptr_data + ) + else: + # Bucket count matches - safe to demux + main_obj_ptr_data = multiplex_state.demux( + old_obj_ptr_muxed + ) + + # Expand if needed + if main_obj_ptr_data.shape[0] < new_obj_idx + 1: + pad_size = ( + new_obj_idx + 1 - main_obj_ptr_data.shape[0] + ) + main_obj_ptr_data = torch.cat( + [ + main_obj_ptr_data, + torch.zeros( + ( + pad_size, + main_obj_ptr_data.shape[1], + ), + dtype=main_obj_ptr_data.dtype, + device=main_obj_ptr_data.device, + ), + ], + dim=0, + ) + + main_obj_ptr_data[new_obj_idx : new_obj_idx + 1] = ( + singleton_obj_ptr_data + ) + main_frame_out["obj_ptr"] = multiplex_state.mux( + main_obj_ptr_data + ) + else: + # Create new obj_ptr + num_objs = len(inference_state["obj_ids"]) + obj_ptr_data = torch.zeros( + (num_objs, singleton_obj_ptr_data.shape[1]), + dtype=singleton_obj_ptr_data.dtype, + device=singleton_obj_ptr_data.device, + ) + obj_ptr_data[new_obj_idx : new_obj_idx + 1] = ( + singleton_obj_ptr_data + ) + main_frame_out["obj_ptr"] = multiplex_state.mux( + obj_ptr_data + ) + + # Update conditioning_objects + if singleton_obj_idx in singleton_frame_out.get( + "conditioning_objects", set() + ): + main_frame_out["conditioning_objects"].add(new_obj_idx) + + @torch.inference_mode() + def add_new_points( + self, + inference_state, + frame_idx, + obj_id, + points, + labels, + clear_old_points, + rel_coordinates=True, + use_prev_mem_frame=False, + ): + """ + Add new points to create a new object in the multiplex model. + + This method converts point inputs to masks via the interactivity head and adds + the new object to the existing multiplex bucket (for dynamic models). + + Args: + inference_state: Current inference state + frame_idx: Frame index to add points + obj_id: Object ID (will be auto-created if new) + points: Point coordinates tensor + labels: Point labels tensor (1 for positive, 0 for negative) + clear_old_points: Whether to clear old points on this frame + rel_coordinates: Whether points are in relative coordinates [0, 1] + use_prev_mem_frame: Whether to use previous memory frames (for compatibility) + + Returns: + Tuple of (frame_idx, obj_ids, low_res_masks, video_res_masks) + """ + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + obj_idxs = [obj_idx] + obj_ids = [obj_id] + + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if points.dim() == 2: + points = points.unsqueeze(0) + if labels.dim() == 1: + labels = labels.unsqueeze(0) + + if rel_coordinates: + points = points * self.image_size + + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + old_point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + old_point_inputs = None + + point_inputs = concat_points(old_point_inputs, points, labels) + point_inputs_per_frame[frame_idx] = point_inputs + + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + multiplex_state = inference_state["multiplex_state"] + is_new_state = multiplex_state is None + + if is_new_state: + multiplex_state = self.multiplex_controller.get_state( + num_valid_entries=1, + device=inference_state["device"], + dtype=torch.float32, + random=False, + object_ids=obj_ids, + ) + inference_state["multiplex_state"] = multiplex_state + + # Determine interaction case: + # - New object: never seen before + # - Refine: existing mask on tracked frame + # - Gap fill: object exists but frame has no output + is_existing_object = ( + not is_new_state + and multiplex_state is not None + and obj_id in multiplex_state.object_ids + ) + + if is_existing_object: + if is_init_cond_frame: + is_new_obj = False + is_refine = False + is_gap_fill_case = True + else: + is_new_obj = False + is_refine = True + is_gap_fill_case = False + else: + is_new_obj = True + is_refine = False + is_gap_fill_case = False + + if is_new_obj: + should_add_to_existing = not is_new_state + allow_new_buckets_local = True + prefer_new_buckets_local = True + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=inference_state["output_dict"], + frame_idx=frame_idx, + batch_size=1, + is_init_cond_frame=True, + point_inputs=point_inputs, + mask_inputs=None, + reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + add_to_existing_state=should_add_to_existing, + new_obj_idxs=obj_idxs, + new_obj_ids=obj_ids, + allow_new_buckets=allow_new_buckets_local, + prefer_new_buckets=prefer_new_buckets_local, + objects_to_interact=None, + ) + elif is_refine: + singleton_state, original_obj_idx = self._extract_object_for_interaction( + inference_state, obj_id, frame_idx + ) + + user_refined_frames_map = inference_state.get( + "user_refined_frames_per_obj", {} + ) + user_refined_frames = user_refined_frames_map.get(obj_id) + if user_refined_frames is None: + user_refined_frames = set() + is_first_refinement = frame_idx not in user_refined_frames + + prev_sam_mask_logits_singleton = None + if not is_first_refinement: + singleton_obj_idx = 0 + singleton_output_dict = singleton_state["output_dict_per_obj"][ + singleton_obj_idx + ] + singleton_temp_output_dict = singleton_state[ + "temp_output_dict_per_obj" + ][singleton_obj_idx] + + # Check BOTH storage keys since previous refinement might be in a different key + # (e.g., first refinement creates cond_frame, but after propagation, + # second refinement on same frame would look for non_cond_frame) + prev_out = None + + storage_key_current = ( + "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + ) + prev_out = singleton_temp_output_dict[storage_key_current].get( + frame_idx + ) + + if prev_out is None: + prev_out = singleton_output_dict["cond_frame_outputs"].get( + frame_idx + ) + if prev_out is None: + prev_out = singleton_output_dict["non_cond_frame_outputs"].get( + frame_idx + ) + + if prev_out is not None and prev_out["pred_masks"] is not None: + prev_sam_mask_logits_singleton = prev_out["pred_masks"].cuda( + non_blocking=True + ) + prev_sam_mask_logits_singleton = torch.clamp( + prev_sam_mask_logits_singleton, -32.0, 32.0 + ) + + if is_first_refinement: + # ALWAYS use is_init_cond_frame=True to force interaction_only mode + # for fresh segmentation from points (not refinement of propagated mask). + singleton_is_init_cond = True + singleton_objects_to_interact = None + else: + # Second+ refinement: Incremental refinement for quality improvement + singleton_is_init_cond = False + singleton_objects_to_interact = ( + [0] if prev_sam_mask_logits_singleton is not None else None + ) + + singleton_obj_idx = 0 + singleton_obj_idxs = [singleton_obj_idx] + singleton_obj_ids = [obj_id] + + current_out, _ = self._run_single_frame_inference( + inference_state=singleton_state, + output_dict=singleton_state["output_dict"], + frame_idx=frame_idx, + batch_size=1, + is_init_cond_frame=singleton_is_init_cond, + point_inputs=point_inputs, + mask_inputs=None, + reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits_singleton, + add_to_existing_state=False, + new_obj_idxs=singleton_obj_idxs, + new_obj_ids=singleton_obj_ids, + allow_new_buckets=False, + objects_to_interact=singleton_objects_to_interact, + ) + + singleton_storage_key = ( + "cond_frame_outputs" + if singleton_is_init_cond + else "non_cond_frame_outputs" + ) + + _, singleton_video_res_masks = self._get_orig_video_res_output( + singleton_state, current_out["pred_masks"] + ) + current_out["pred_masks_video_res"] = singleton_video_res_masks + + singleton_state["output_dict"][singleton_storage_key][frame_idx] = ( + current_out + ) + + self._merge_singleton_interaction_result( + inference_state, singleton_state, obj_id, original_obj_idx + ) + + obj_idx = inference_state["obj_id_to_idx"][obj_id] + obj_idxs = [obj_idx] + + if "user_refined_frames_per_obj" not in inference_state: + inference_state["user_refined_frames_per_obj"] = {} + if obj_id not in inference_state["user_refined_frames_per_obj"]: + inference_state["user_refined_frames_per_obj"][obj_id] = set() + + inference_state["user_refined_frames_per_obj"][obj_id].add(frame_idx) + + merged_frame_out = inference_state["output_dict"][singleton_storage_key][ + frame_idx + ] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + + if "pred_masks_video_res" in merged_frame_out: + pred_masks_video_res_slice = merged_frame_out["pred_masks_video_res"][ + obj_idx : obj_idx + 1 + ] + else: + _, video_res_masks = self._get_orig_video_res_output( + inference_state, merged_frame_out["pred_masks"] + ) + pred_masks_video_res_slice = video_res_masks[obj_idx : obj_idx + 1] + + pred_masks_slice = merged_frame_out["pred_masks"][obj_idx : obj_idx + 1] + + obj_temp_output_dict[singleton_storage_key][frame_idx] = { + "pred_masks": pred_masks_slice, + "pred_masks_video_res": pred_masks_video_res_slice, + "object_score_logits": merged_frame_out["object_score_logits"][ + obj_idx : obj_idx + 1 + ], + } + obj_output_dict[singleton_storage_key][frame_idx] = obj_temp_output_dict[ + singleton_storage_key + ][frame_idx] + + elif is_gap_fill_case: + # Gap fill: Run inference directly in multiplex mode (no singleton extraction) + # Even though is_init_cond_frame=True, we use add_to_existing_state=False + # because the object ALREADY EXISTS in multiplex state. + obj_idx = inference_state["obj_id_to_idx"][obj_id] + obj_idxs = [obj_idx] + batch_size = self._get_obj_num(inference_state) + + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=inference_state["output_dict"], + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=True, + point_inputs=point_inputs, + mask_inputs=None, + reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + add_to_existing_state=False, + new_obj_idxs=[obj_idx], + new_obj_ids=[obj_id], + allow_new_buckets=False, + prefer_new_buckets=False, + objects_to_interact=[obj_idx], + ) + + current_out["local_obj_id_to_idx"] = deepcopy( + inference_state["obj_id_to_idx"] + ) + + _, video_res_masks = self._get_orig_video_res_output( + inference_state, current_out["pred_masks"] + ) + current_out["pred_masks_video_res"] = video_res_masks + + is_cond = storage_key == "cond_frame_outputs" + if ( + is_cond + and frame_idx + in inference_state["output_dict"]["non_cond_frame_outputs"] + ): + del inference_state["output_dict"]["non_cond_frame_outputs"][frame_idx] + if "consolidated_frame_inds" in inference_state: + inference_state["consolidated_frame_inds"][ + "non_cond_frame_outputs" + ].discard(frame_idx) + + # Store consolidated output (has obj_ptr, maskmem_features, etc.) + inference_state["output_dict"][storage_key][frame_idx] = current_out + + # Mark as consolidated + if "consolidated_frame_inds" in inference_state: + inference_state["consolidated_frame_inds"][storage_key].add(frame_idx) + + # Also store per-object slices in temp_output_dict_per_obj + obj_temp_output_dict[storage_key][frame_idx] = { + "pred_masks": current_out["pred_masks"][obj_idx : obj_idx + 1], + "pred_masks_video_res": video_res_masks[obj_idx : obj_idx + 1], + "object_score_logits": current_out["object_score_logits"][ + obj_idx : obj_idx + 1 + ], + } + obj_output_dict[storage_key][frame_idx] = obj_temp_output_dict[storage_key][ + frame_idx + ] + + # Store outputs and prepare return values + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + + # For refinement/gap fill (singleton extraction), handle singleton output specially + if is_refine or is_gap_fill_case: + # Singleton case: The merge already updated the consolidated output_dict during merge. + # However, we need to ensure the frame is properly stored and marked. + + singleton_obj_idx = 0 + + # Get video resolution masks from singleton output + _, video_res_masks_singleton = self._get_orig_video_res_output( + inference_state, current_out["pred_masks"] + ) + + # Mark frame as consolidated (prevents double consolidation in preflight) + if "consolidated_frame_inds" in inference_state: + inference_state["consolidated_frame_inds"][storage_key].add(frame_idx) + + # For return value, use singleton masks + video_res_masks_to_return = video_res_masks_singleton[ + singleton_obj_idx : singleton_obj_idx + 1 + ] + else: + # Standard multiplex output - use obj_idx + _, video_res_masks = self._get_orig_video_res_output( + inference_state, current_out["pred_masks"] + ) + + current_out["pred_masks_video_res"] = video_res_masks + current_out["local_obj_id_to_idx"] = deepcopy( + inference_state["obj_id_to_idx"] + ) + + # Remove from non_cond if this becomes a cond frame + if ( + is_cond + and frame_idx + in inference_state["output_dict"]["non_cond_frame_outputs"] + ): + del inference_state["output_dict"]["non_cond_frame_outputs"][frame_idx] + # Also update consolidated_frame_inds + if "consolidated_frame_inds" in inference_state: + inference_state["consolidated_frame_inds"][ + "non_cond_frame_outputs" + ].discard(frame_idx) + + inference_state["output_dict"][storage_key][frame_idx] = current_out + + # Update consolidated_frame_inds to track this frame + if "consolidated_frame_inds" in inference_state: + inference_state["consolidated_frame_inds"][storage_key].add(frame_idx) + + # Store per-object outputs (slice from the full multiplex output) + obj_temp_output_dict[storage_key][frame_idx] = { + "pred_masks_video_res": current_out["pred_masks_video_res"][ + obj_idx : obj_idx + 1 + ], + "pred_masks": current_out["pred_masks"][obj_idx : obj_idx + 1], + "object_score_logits": current_out["object_score_logits"][ + obj_idx : obj_idx + 1 + ], + } + + obj_output_dict[storage_key][frame_idx] = obj_temp_output_dict[storage_key][ + frame_idx + ] + + video_res_masks_to_return = video_res_masks[obj_idx : obj_idx + 1] + + low_res_masks = None + return frame_idx, obj_ids, low_res_masks, video_res_masks_to_return + + @torch.inference_mode() + def add_new_masks( + self, + inference_state, + frame_idx, + obj_ids, + masks, + # for compatibility with per_obj_inference class, not used here + add_mask_to_memory=False, + # for object reconditioning; do not update the multiplex state + reconditioning=False, + ): + """Add new mask to a frame.""" + if isinstance(obj_ids, np.ndarray): + obj_ids = obj_ids.tolist() + obj_idxs = [ + self._obj_id_to_idx(inference_state, obj_id, error_if_new=reconditioning) + for obj_id in obj_ids + ] + point_inputs_per_frame = [ + inference_state["point_inputs_per_obj"][obj_idx] for obj_idx in obj_idxs + ] + mask_inputs_per_frame = [ + inference_state["mask_inputs_per_obj"][obj_idx] for obj_idx in obj_idxs + ] + + assert masks.dim() == 3 + num_objects, mask_H, mask_W = masks.shape + assert num_objects == len(obj_ids) + masks_inputs_orig = masks[:, None, :, :] # add channel dimension + masks_inputs_orig = masks_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's input mask size + if mask_H != self.input_mask_size or mask_W != self.input_mask_size: + mask_inputs = torch.nn.functional.interpolate( + masks_inputs_orig, + size=(self.input_mask_size, self.input_mask_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + mask_inputs = masks_inputs_orig + + # also get the mask at the original video resolution (for outputting) + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + if mask_H != video_H or mask_W != video_W: + mask_inputs_video_res = torch.nn.functional.interpolate( + masks_inputs_orig, + size=(video_H, video_W), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for potential downsampling + ) + else: + mask_inputs_video_res = masks_inputs_orig + # convert mask_inputs_video_res to binary (threshold at 0.5 as it is in range 0~1) + mask_inputs_video_res = mask_inputs_video_res > 0.5 + + multiplex_state = inference_state["multiplex_state"] + is_new_state = multiplex_state is None + + if not reconditioning: + if is_new_state: + multiplex_state = self.multiplex_controller.get_state( + num_valid_entries=num_objects, + device=inference_state["device"], + dtype=torch.float32, # lower precision is also fine + random=False, + object_ids=obj_ids, + ) + inference_state["multiplex_state"] = multiplex_state + else: + assert ( + self.is_dynamic_model + ), "New objects are not allowed after state creation" + + for i in range(num_objects): + mask_inputs_per_frame[i][frame_idx] = mask_inputs_video_res[i : i + 1] + point_inputs_per_frame[i].pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dicts = [ + inference_state["output_dict_per_obj"][obj_idx] for obj_idx in obj_idxs + ] + obj_temp_output_dicts = [ + inference_state["temp_output_dict_per_obj"][obj_idx] for obj_idx in obj_idxs + ] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Allow creating a new bucket only when existing buckets cannot fit the new objects + allow_new_buckets_local = False + if not is_new_state and not reconditioning and multiplex_state is not None: + if multiplex_state.available_slots < num_objects: + allow_new_buckets_local = True + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=inference_state["output_dict"], + frame_idx=frame_idx, + batch_size=num_objects, + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + add_to_existing_state=not is_new_state and not reconditioning, + new_obj_idxs=obj_idxs, + new_obj_ids=obj_ids, + allow_new_buckets=allow_new_buckets_local, + reconditioning=reconditioning, + ) + # We directly use the input mask at video resolution as the output mask for a better + # video editing experience (so that the masks don't change after each brushing). + # Here NO_OBJ_SCORE is a large negative value to represent the background and + # similarly -NO_OBJ_SCORE is a large positive value to represent the foreground. + _, video_res_masks = self._get_orig_video_res_output( + inference_state, current_out["pred_masks"] + ) + obj_idxs_t = torch.as_tensor(obj_idxs, device=video_res_masks.device) + video_res_masks[obj_idxs_t] = torch.where( + mask_inputs_video_res, -NO_OBJ_SCORE, NO_OBJ_SCORE + ) + + current_out["pred_masks_video_res"] = video_res_masks + with torch.profiler.record_function("add_new_masks._deepcopy"): + current_out["local_obj_id_to_idx"] = deepcopy( + inference_state["obj_id_to_idx"] + ) + if ( + is_cond + and frame_idx in inference_state["output_dict"]["non_cond_frame_outputs"] + ): + del inference_state["output_dict"]["non_cond_frame_outputs"][frame_idx] + # Also update consolidated_frame_inds + if "consolidated_frame_inds" in inference_state: + inference_state["consolidated_frame_inds"][ + "non_cond_frame_outputs" + ].discard(frame_idx) + + inference_state["output_dict"][storage_key][frame_idx] = current_out + + # Update consolidated_frame_inds to track this frame + if "consolidated_frame_inds" in inference_state: + inference_state["consolidated_frame_inds"][storage_key].add(frame_idx) + + with torch.profiler.record_function("add_new_masks.obj_loop"): + # Step 1: Set all new object masks first (batched) + for i, obj_idx in enumerate(obj_idxs): + # Add the predicted masks to the output dict + # NOTE: object ordering matters here but I guess this is the same for the per-object implementation + obj_temp_output_dicts[i][storage_key][frame_idx] = { + "pred_masks_video_res": current_out["pred_masks_video_res"][ + obj_idx : obj_idx + 1 + ] + } + obj_output_dicts[i][storage_key][frame_idx] = obj_temp_output_dicts[i][ + storage_key + ][frame_idx] + + # Step 2: Precompute suppress masks to avoid O(n*m) torch.where calls + # Combined mask of all new objects (for existing objects) + combined_new_mask = mask_inputs_video_res.any( + dim=0, keepdim=True + ) # (1, 1, H, W) + + # Precompute exclude-self masks for new objects (if there are multiple new objects) + num_new = len(obj_idxs) + exclude_self_masks = {} + if num_new > 1: + for i in range(num_new): + other_indices = torch.cat( + [ + torch.arange(i, device=mask_inputs_video_res.device), + torch.arange( + i + 1, num_new, device=mask_inputs_video_res.device + ), + ] + ) + exclude_self_masks[obj_idxs[i]] = mask_inputs_video_res[ + other_indices + ].any(dim=0, keepdim=True) + + # Step 3: Apply suppression to all objects in a single pass + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + obj_idxs_set = set(obj_idxs) + + for obj_idx2, obj_temp_output_dict2 in temp_output_dict_per_obj.items(): + current_out2 = obj_temp_output_dict2[storage_key].get(frame_idx, None) + if current_out2 is None: + continue + + if obj_idx2 not in obj_idxs_set: + # Existing object: suppress by all new masks + suppress_mask = combined_new_mask + elif obj_idx2 in exclude_self_masks: + # New object: suppress by other new objects' masks + suppress_mask = exclude_self_masks[obj_idx2] + else: + # Only one new object - nothing to suppress for itself + continue + + current_out2["pred_masks_video_res"] = torch.where( + suppress_mask, + NO_OBJ_SCORE, + current_out2["pred_masks_video_res"], + ) + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + low_res_masks = None # not needed by the demo + + consolidated_out["local_obj_id_to_idx"] = current_out["local_obj_id_to_idx"] + + return frame_idx, obj_ids, low_res_masks, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_output: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + video_res_masks = fill_holes_in_mask_scores( + video_res_masks, self.fill_hole_area + ) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # After singleton merge, objects can be added at indices beyond batch_size + # We need to find the maximum object index that has temp or regular outputs to size the tensor correctly + max_obj_idx = batch_size - 1 # Default to batch_size - 1 + + # Check both temp and regular output dicts to find max index + for obj_idx in inference_state["temp_output_dict_per_obj"].keys(): + if obj_idx > max_obj_idx: + max_obj_idx = obj_idx + for obj_idx in inference_state["output_dict_per_obj"].keys(): + if obj_idx > max_obj_idx: + max_obj_idx = obj_idx + + # Size the consolidated tensor to accommodate all object indices (not just count) + consolidated_batch_size = max(max_obj_idx + 1, 0) # Ensure non-negative + + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.low_res_mask_size + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + + consolidated_out = { + "conditioning_objects": None, + "maskmem_features": None, + "maskmem_pos_enc": None, + "image_features": None, + "image_pos_enc": None, + "obj_ptr": None, + consolidated_mask_key: torch.full( + size=( + consolidated_batch_size, + 1, + consolidated_H, + consolidated_W, + ), # Use consolidated_batch_size, not batch_size! + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + } + + all_out = inference_state["output_dict"]["cond_frame_outputs"].get( + frame_idx, None + ) + if all_out is None: + all_out = inference_state["output_dict"]["non_cond_frame_outputs"].get( + frame_idx, None + ) + + # Handle the case where output_dict is empty (e.g., during demo VG propagation) + # In this case, we'll reconstruct the consolidated output from per-object outputs + need_to_reconstruct_from_per_obj = all_out is None + + if need_to_reconstruct_from_per_obj: + # Initialize fields that will be populated from per-object outputs or later + # Determine which objects are conditioned by checking if they have point/mask inputs on this frame + conditioning_objects = set() + for obj_idx in range(batch_size): + # Check if this object has point inputs on this frame + if obj_idx in inference_state["point_inputs_per_obj"]: + point_inputs = inference_state["point_inputs_per_obj"][obj_idx] + if ( + frame_idx in point_inputs + and point_inputs[frame_idx] is not None + ): + conditioning_objects.add(obj_idx) + continue + + # Check if this object has mask inputs on this frame + if obj_idx in inference_state["mask_inputs_per_obj"]: + mask_inputs = inference_state["mask_inputs_per_obj"][obj_idx] + if frame_idx in mask_inputs and mask_inputs[frame_idx] is not None: + conditioning_objects.add(obj_idx) + + consolidated_out["conditioning_objects"] = conditioning_objects + # Shared features will be populated when running memory encoder + # Note: obj_ptr and object_score_logits will be populated from per-object outputs below + else: + # Normal case: populate from existing consolidated output + consolidated_out["conditioning_objects"] = all_out.get( + "conditioning_objects", set() + ) + consolidated_out["obj_ptr"] = all_out["obj_ptr"] + consolidated_out["object_score_logits"] = all_out["object_score_logits"] + if self.use_memory_selection: + consolidated_out["iou_score"] = all_out["iou_score"] + # These fields might not exist in per-object outputs (e.g., after singleton extraction) + consolidated_out["maskmem_features"] = all_out.get("maskmem_features") + consolidated_out["maskmem_pos_enc"] = all_out.get("maskmem_pos_enc") + consolidated_out["image_features"] = all_out.get("image_features") + consolidated_out["image_pos_enc"] = all_out.get("image_pos_enc") + consolidated_out["local_obj_id_to_idx"] = all_out.get( + "local_obj_id_to_idx", {} + ) + consolidated_out["obj_ptr"] = all_out["obj_ptr"] + consolidated_out["object_score_logits"] = all_out["object_score_logits"] + if self.use_memory_selection: + consolidated_out["iou_score"] = all_out["iou_score"] + # These fields might not exist in per-object outputs (e.g., after singleton extraction) + consolidated_out["maskmem_features"] = all_out.get("maskmem_features") + consolidated_out["maskmem_pos_enc"] = all_out.get("maskmem_pos_enc") + consolidated_out["image_features"] = all_out.get("image_features") + consolidated_out["image_pos_enc"] = all_out.get("image_pos_enc") + consolidated_out["local_obj_id_to_idx"] = all_out.get( + "local_obj_id_to_idx", {} + ) + all_mask = all_out.get("pred_masks_video_res", all_out["pred_masks"]) + # Ensure masks are at the correct consolidated resolution + # This handles the case where all_out has interactive resolution (288) masks + # that need to be resized to SAM2's low_res_mask_size (256) for consistency + if all_mask.shape[-2:] == (consolidated_H, consolidated_W): + consolidated_out[consolidated_mask_key] = all_mask + else: + # Resize first if mask has a different resolution (e.g., 288 from interactive) + # Determine if we're downsampling or upsampling + is_downsampling = all_mask.shape[-1] > consolidated_W + resized_mask = torch.nn.functional.interpolate( + all_mask, + size=(consolidated_H, consolidated_W), + mode="bilinear", + align_corners=False, + antialias=is_downsampling, # use antialias for downsampling + ) + consolidated_out[consolidated_mask_key] = resized_mask + + # Collect per-object outputs (masks and scores) to build consolidated output + # When reconstructing from per-object outputs, we also need to collect obj_ptr and object_score_logits + obj_score_logits_list = [] + obj_ptr_list = [] if need_to_reconstruct_from_per_obj else None + iou_scores_list = ( + [] + if need_to_reconstruct_from_per_obj and self.use_memory_selection + else None + ) + + # When reconstructing from per-object outputs, initialize the mask tensor + # with the correct size (consolidated_batch_size, not batch_size) + if ( + need_to_reconstruct_from_per_obj + and consolidated_mask_key not in consolidated_out + ): + # Initialize with zeros - will be populated from per-object outputs below + consolidated_out[consolidated_mask_key] = torch.zeros( + (consolidated_batch_size, 1, consolidated_H, consolidated_W), + dtype=torch.float32, + device=inference_state["storage_device"], + ) + consolidated_out["object_score_logits"] = torch.full( + (consolidated_batch_size, 1), + NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ) + + for obj_idx in range( + consolidated_batch_size + ): # Use consolidated_batch_size instead of batch_size + # Check if this object index exists in temp/output dicts (it may not if object was just added) + if obj_idx not in inference_state["temp_output_dict_per_obj"]: + continue + if obj_idx not in inference_state["output_dict_per_obj"]: + continue + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + if out is None: + # object pointers are filled globally above; we don't need empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + # (use "pred_masks_video_res" if it's available) + obj_mask = out.get("pred_masks_video_res") + if obj_mask is None: + obj_mask = out.get("pred_masks") + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + + # If obj_idx is beyond the consolidated_pred_masks size, + # we need to expand it (can happen after singleton merge adds object at end) + if obj_idx >= consolidated_pred_masks.shape[0]: + pad_size = obj_idx + 1 - consolidated_pred_masks.shape[0] + consolidated_pred_masks = torch.cat( + [ + consolidated_pred_masks, + torch.zeros( + ( + pad_size, + 1, + consolidated_pred_masks.shape[-2], + consolidated_pred_masks.shape[-1], + ), + dtype=consolidated_pred_masks.dtype, + device=consolidated_pred_masks.device, + ), + ], + dim=0, + ) + consolidated_out[consolidated_mask_key] = consolidated_pred_masks + # Also expand object_score_logits if present + if "object_score_logits" in consolidated_out: + consolidated_scores = consolidated_out["object_score_logits"] + consolidated_scores = torch.cat( + [ + consolidated_scores, + torch.full( + (pad_size, 1), + NO_OBJ_SCORE, + dtype=consolidated_scores.dtype, + device=consolidated_scores.device, + ), + ], + dim=0, + ) + consolidated_out["object_score_logits"] = consolidated_scores + + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + # Ensure dtype match between source and destination before assignment + if obj_mask.dtype != consolidated_pred_masks.dtype: + obj_mask = obj_mask.to(consolidated_pred_masks.dtype) + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + is_downsampling = "pred_masks_video_res" in out + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + antialias=is_downsampling, # use antialias for downsampling + ) + # Ensure dtype match between source and destination before assignment + if resized_obj_mask.dtype != consolidated_pred_masks.dtype: + resized_obj_mask = resized_obj_mask.to( + consolidated_pred_masks.dtype + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + + # When reconstructing from per-object outputs, also collect scores + if need_to_reconstruct_from_per_obj: + if "object_score_logits" in out: + obj_score_logits_list.append(out["object_score_logits"]) + if self.use_memory_selection and "iou_score" in out: + iou_scores_list.append(out["iou_score"]) + + # If we reconstructed from per-object outputs, consolidate the score fields + if need_to_reconstruct_from_per_obj: + # Check if we have ANY valid per-object outputs + # If not, we're trying to consolidate a VG-propagated frame that was never + # stored in output_dict (only in cached_frame_outputs) + # In this case, we SKIP memory encoding during preflight and will do it + # during the first propagation step instead + if not obj_score_logits_list and run_mem_encoder: + run_mem_encoder = False # Skip for now, will encode during propagation + + if obj_score_logits_list: + consolidated_out["object_score_logits"] = torch.cat( + obj_score_logits_list, dim=0 + ) + else: + # Create placeholder scores - these will be replaced when memory encoder runs + device = inference_state["device"] + consolidated_out["object_score_logits"] = torch.zeros( + (batch_size, 1), + dtype=torch.float32, + device=device, + ) + + if self.use_memory_selection: + if iou_scores_list: + consolidated_out["iou_score"] = torch.cat(iou_scores_list, dim=0) + else: + consolidated_out["iou_score"] = None + + # obj_ptr will be populated by memory encoder, set to None for now + consolidated_out["obj_ptr"] = None + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc, image_features, image_pos_enc = ( + self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + object_score_logits=consolidated_out["object_score_logits"], + is_mask_from_pts=True, # these frames are what the user interacted with + conditioning_objects=consolidated_out[ + "conditioning_objects" + ], # Pass conditioning_objects + ) + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + consolidated_out["image_features"] = image_features + consolidated_out["image_pos_enc"] = image_pos_enc + + return consolidated_out + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state, run_mem_encoder=True): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outptus + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temprary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=run_mem_encoder, + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct demo workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + # Record the first interacted frame index (for tracking start) + if inference_state["first_ann_frame_idx"] is None: + inference_state["first_ann_frame_idx"] = min( + input_frames_inds, default=None + ) + # In case `first_ann_frame_idx` is not in the conditioning frames (e.g. because + # we cleared the input points on that frame), pick the first conditioning frame + if ( + inference_state["first_ann_frame_idx"] + not in output_dict["cond_frame_outputs"] + ): + inference_state["first_ann_frame_idx"] = min( + output_dict["cond_frame_outputs"], default=None + ) + + def _get_processing_order( + self, inference_state, start_frame_idx, max_frame_num_to_track, reverse + ): + num_frames = inference_state["num_frames"] + # set start index, end index, and processing order + if self.always_start_from_first_ann_frame: + # in this case, we always start tracking from the frame where we receive + # the initial annotation and ignore the provided start_frame_idx + start_frame_idx = inference_state["first_ann_frame_idx"] + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(inference_state["output_dict"]["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + # TODO: Jie - this is the edge case that we start from frame 0 and track in reverse order; + # and in the case we track a single frame for dense tracking, it should still run 1 frame (idx=0). + # Not sure if this has any side effect. + # processing_order = [] # skip reverse tracking if starting from frame 0 <-- original behaviour + processing_order = [0] + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + return processing_order + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse, + tqdm_disable=False, + obj_ids=None, + run_mem_encoder=True, + ): + """Propagate the input points across frames to track in the entire video.""" + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + if obj_ids is not None: + raise NotImplementedError( + "Per-object tracking yet for batched inference if not implemented." + ) + obj_ids = inference_state["obj_ids"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + assert clear_non_cond_mem is False, "Not implemented" + + processing_order = self._get_processing_order( + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse, + ) + + for frame_idx in tqdm( + processing_order, desc="propagate in video", disable=tqdm_disable + ): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + with torch.profiler.record_function( + "VideoTrackingMultiplexDemo._run_single_frame_inference" + ): + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=run_mem_encoder, + ) + current_out["local_obj_id_to_idx"] = deepcopy( + inference_state["obj_id_to_idx"] + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + low_res_masks, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, low_res_masks, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + # Note for the multiplex model: we don't store the maskmem features + # because we don't use the memory during interaction + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "pred_masks": current_out["pred_masks"][obj_slice], + "object_score_logits": current_out["object_score_logits"][obj_slice], + } + if self.use_memory_selection: + obj_out["iou_score"] = current_out["iou_score"][obj_slice] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def clear_all_points_in_frame( + self, + inference_state, + frame_idx, + obj_id, + need_output=True, + preserve_user_refined: bool = False, + ): + """Remove all input points or mask in a specific frame for a given object.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + + # Clear the conditioning information on the given frame + inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None) + inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None) + + # Clear user refinement tracking for this frame and object unless preserving it + if ( + not preserve_user_refined + and "user_refined_frames_per_obj" in inference_state + ): + user_refined_map = inference_state["user_refined_frames_per_obj"] + if obj_id in user_refined_map: + user_refined_map[obj_id].discard(frame_idx) + + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None) + temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None) + + # Check and see if there are still any inputs left on this frame + batch_size = self._get_obj_num(inference_state) + frame_has_input = False + for obj_idx2 in range(batch_size): + # Skip if this object doesn't exist in the input dictionaries + if obj_idx2 not in inference_state["point_inputs_per_obj"]: + continue + if obj_idx2 not in inference_state["mask_inputs_per_obj"]: + continue + if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]: + frame_has_input = True + break + + # If this frame has no remaining inputs for any objects, we further clear its + # conditioning frame status + if not frame_has_input: + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx) + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + # Remove the frame's conditioning output (possibly downgrading it to non-conditioning) + out = output_dict["cond_frame_outputs"].pop(frame_idx, None) + if out is not None: + # The frame is not a conditioning frame anymore since it's not receiving inputs, + # so we "downgrade" its output (if exists) to a non-conditioning frame output. + output_dict["non_cond_frame_outputs"][frame_idx] = out + inference_state["frames_already_tracked"].pop(frame_idx, None) + # Similarly, do it for the sliced output on each object. + for obj_idx2 in range(batch_size): + # Skip if this object doesn't exist in the output dictionary + if obj_idx2 not in inference_state["output_dict_per_obj"]: + continue + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2] + obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None) + if obj_out is not None: + obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out + + # If all the conditioning frames have been removed, we also clear the tracking outputs + if len(output_dict["cond_frame_outputs"]) == 0: + self._reset_tracking_results(inference_state) + + if not need_output: + return + # Finally, output updated masks per object (after removing the inputs above) + obj_ids = inference_state["obj_ids"] + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + low_res_masks = None # not needed by the demo + return frame_idx, obj_ids, low_res_masks, video_res_masks + + @torch.inference_mode() + def clear_all_points_in_video(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + inference_state["multiplex_state"] = None + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + inference_state["first_ann_frame_idx"] = None + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) + # TODO: We should optimize this because we don't always need all three outs + backbone_out = self.forward_image( + NestedTensor(tensors=image, mask=None), + need_sam3_out=True, + need_interactive_out=True, + need_propagation_out=True, + ) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + features = self._prepare_backbone_features(backbone_out) + return image, features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + add_to_existing_state: bool = False, + new_obj_idxs: Optional[list[int]] = None, + new_obj_ids: Optional[list[int]] = None, + allow_new_buckets: bool = False, + prefer_new_buckets: bool = False, + reconditioning: bool = False, + objects_to_interact: Optional[list[int]] = None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + with torch.profiler.record_function( + "VideoTrackingMultiplexDemo._get_image_feature" + ): + image, backbone_features = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + + if add_to_existing_state or reconditioning: + assert new_obj_idxs is not None + assert new_obj_ids is not None + + backbone_features_interactive = backbone_features["interactive"] + backbone_features_propagation = backbone_features["sam2_backbone_out"] + + if add_to_existing_state or reconditioning: + with torch.profiler.record_function( + "VideoTrackingMultiplexDemo.add_new_masks_to_existing_state" + ): + # Get existing output from current frame to modify in-place + # Try both storage keys since the output could be in either location + existing_out = output_dict["cond_frame_outputs"].get(frame_idx) + if existing_out is None: + existing_out = output_dict["non_cond_frame_outputs"].get(frame_idx) + if existing_out is None: + raise RuntimeError( + f"No existing output found for frame {frame_idx} in either storage" + ) + + # Prepare interactive features + interactive_pix_feat = self._get_interactive_pix_mem( + backbone_features_interactive["vision_feats"], + backbone_features_interactive["feat_sizes"], + ) + + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + interactive_high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip( + backbone_features_interactive["vision_feats"][:-1], + backbone_features_interactive["feat_sizes"][:-1], + ) + ] + + # Prepare propagation features for memory encoding + propagation_vision_feats = ( + backbone_features_propagation["vision_feats"] + if run_mem_encoder + else None + ) + propagation_feat_sizes = ( + backbone_features_propagation["feat_sizes"] + if run_mem_encoder + else None + ) + + # Add new masks to existing state + if reconditioning: + self.recondition_masks_in_existing_state( + interactive_pix_feat=interactive_pix_feat, + interactive_high_res_features=interactive_high_res_features, + propagation_vision_feats=propagation_vision_feats, + propagation_feat_sizes=propagation_feat_sizes, + new_masks=mask_inputs, + obj_idxs_in_mask=new_obj_idxs, + obj_ids_in_mask=new_obj_ids, + prev_output=existing_out, + multiplex_state=inference_state["multiplex_state"], + add_mask_to_memory=run_mem_encoder, + ) + else: + # If we are adding to existing state using points (mask_inputs is None), + # first convert points -> masks via the interactivity head. + new_masks_from_points = None + if mask_inputs is None and point_inputs is not None: + with torch.profiler.record_function( + "VideoTrackingMultiplexDemo.points_to_masks" + ): + multimask_output = self._use_multimask( + is_init_cond_frame, point_inputs=point_inputs + ) + interaction_out = self._forward_sam_heads( + backbone_features=interactive_pix_feat, + point_inputs=point_inputs, + mask_inputs=None, + interactive_high_res_features=interactive_high_res_features, + multimask_output=multimask_output, + objects_to_interact=new_obj_idxs, + multiplex_state=inference_state["multiplex_state"], + ) + new_masks_from_points = interaction_out["low_res_masks"] + + self.add_new_masks_to_existing_state( + interactive_pix_feat=interactive_pix_feat, + interactive_high_res_features=interactive_high_res_features, + propagation_vision_feats=propagation_vision_feats, + propagation_feat_sizes=propagation_feat_sizes, + new_masks=( + mask_inputs + if mask_inputs is not None + else new_masks_from_points + ), + obj_idxs_in_mask=new_obj_idxs, + obj_ids_in_mask=new_obj_ids, + prev_output=existing_out, + multiplex_state=inference_state["multiplex_state"], + add_mask_to_memory=run_mem_encoder, + are_masks_from_pts=(mask_inputs is None), + allow_new_buckets=allow_new_buckets, + prefer_new_buckets=prefer_new_buckets, + ) + + # Return the modified existing output + current_out = existing_out + else: + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + with torch.profiler.record_function( + "VideoTrackingMultiplexDemo.track_step" + ): + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + backbone_features_interactive=backbone_features_interactive, + backbone_features_propagation=backbone_features_propagation, + image=image, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + gt_masks=None, + frames_to_add_correction_pt=[], + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + multiplex_state=inference_state["multiplex_state"], + objects_to_interact=objects_to_interact, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + if current_out.get("maskmem_features") is not None: + maskmem_features = current_out["maskmem_features"] + maskmem_features = maskmem_features.to( + device=storage_device, dtype=torch.bfloat16, non_blocking=True + ) + else: + maskmem_features = None + + if current_out.get("image_features") is not None: + assert "image_pos_enc" in current_out + image_features = current_out["image_features"].to( + storage_device, non_blocking=True + ) + image_pos_enc = current_out["image_pos_enc"].to( + storage_device, non_blocking=True + ) + else: + image_features = image_pos_enc = None + + pred_masks_gpu = current_out["pred_masks"] + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + with torch.profiler.record_function( + "VideoTrackingMultiplexDemo.maskmem_pos_enc" + ): + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + object_score_logits = current_out["object_score_logits"] + conditioning_objects = current_out["conditioning_objects"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "image_features": image_features, + "image_pos_enc": image_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + "object_score_logits": object_score_logits, + "conditioning_objects": conditioning_objects, + } + if self.use_memory_selection: + with torch.profiler.record_function( + "VideoTrackingMultiplexDemo.use_memory_selection" + ): + compact_current_out["iou_score"] = current_out["iou_score"] + compact_current_out["eff_iou_score"] = self.cal_mem_score( + object_score_logits, current_out["iou_score"] + ) + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, + inference_state, + frame_idx, + batch_size, + high_res_masks, + object_score_logits, + is_mask_from_pts, + conditioning_objects=None, # Accept as parameter + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + image, backbone_features = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + backbone_features_propagation = backbone_features["sam2_backbone_out"] + propagation_vision_feats = backbone_features_propagation["vision_feats"] + propagation_vision_pos_embeds = backbone_features_propagation[ + "vision_pos_embeds" + ] + propagation_feat_sizes = backbone_features_propagation["feat_sizes"] + + # If conditioning_objects is not provided, look it up from output_dict + if conditioning_objects is None: + output_dict = inference_state["output_dict"] + for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]: + storage = output_dict[storage_key] + if frame_idx not in storage: + continue + conditioning_objects = storage[frame_idx]["conditioning_objects"] + break + else: + raise ValueError(f"conditioning objects not found at {frame_idx=}") + + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + image=image, + current_vision_feats=propagation_vision_feats, + feat_sizes=propagation_feat_sizes, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=is_mask_from_pts, + conditioning_objects=conditioning_objects, + multiplex_state=inference_state["multiplex_state"], + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + + image_features = propagation_vision_feats[-1] + image_features = image_features.to(storage_device, non_blocking=True) + image_pos_enc = propagation_vision_pos_embeds[-1] + image_pos_enc = image_pos_enc.to(storage_device, non_blocking=True) + return maskmem_features, maskmem_pos_enc, image_features, image_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out.get("maskmem_pos_enc") + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + @torch.inference_mode() + def remove_object( + self, + inference_state, + obj_id: int, + strict=False, + need_output=True, + clear_user_refined_map: bool = True, + ): + """ + Remove a single object from the tracking state. + + This is a convenience wrapper around remove_objects() for removing a single object. + + Args: + inference_state: Current inference state + obj_id: Object ID to remove + strict: If True, raise error if object doesn't exist + need_output: Whether to return updated frames + + Returns: + Tuple of (remaining_obj_ids, updated_frames) + """ + return self.remove_objects( + inference_state, + obj_ids=[obj_id], + strict=strict, + need_output=need_output, + clear_user_refined_map=clear_user_refined_map, + ) + + @torch.inference_mode() + def remove_objects( + self, + inference_state, + obj_ids: Iterable[int], + strict=False, + need_output=True, + clear_user_refined_map: bool = True, + ): + """ + Remove a list of object ids from the tracking state. If strict is True, we check whether + the object ids actually exist and raise an error if any of them don't exist. + """ + obj_ids = list(obj_ids) + old_obj_idxs_to_rm = [ + inference_state["obj_id_to_idx"].get(obj_id, None) for obj_id in obj_ids + ] + updated_frames = [] + actually_used_obj_ids = [] + removing_any = False + for old_obj_idx_to_rm, obj_id in zip(old_obj_idxs_to_rm, obj_ids, strict=True): + if old_obj_idx_to_rm is None: + if strict: + raise ValueError( + f"Object id {obj_id} does not exist in the tracking state." + ) + else: + actually_used_obj_ids.append(obj_id) + removing_any = True + if not removing_any: + return inference_state["obj_ids"], updated_frames + + # ignore any object IDs that don't exist + old_obj_idxs_to_rm = [x for x in old_obj_idxs_to_rm if x is not None] + obj_ids = actually_used_obj_ids + removed_obj_ids = list(obj_ids) + + # There are still remaining objects after removing this object id. In this case, + # we need to delete the object storage from inference state tensors. + # Step 0: clear the input on those frames where this object id has point or mask input + # (note that this step is required as it might downgrade conditioning frames to + # non-conditioning ones) + if clear_user_refined_map and "user_refined_frames_per_obj" in inference_state: + user_refined_map = inference_state["user_refined_frames_per_obj"] + for removed_obj_id in removed_obj_ids: + if removed_obj_id in user_refined_map: + user_refined_map.pop(removed_obj_id, None) + + all_obj_input_frames_inds = set() + for old_obj_idx_to_rm, obj_id in zip(old_obj_idxs_to_rm, obj_ids, strict=True): + obj_input_frames_inds = set() + obj_input_frames_inds.update( + inference_state["point_inputs_per_obj"][old_obj_idx_to_rm] + ) + obj_input_frames_inds.update( + inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm] + ) + for frame_idx in obj_input_frames_inds: + self.clear_all_points_in_frame( + inference_state, + frame_idx, + obj_id, + need_output=False, + preserve_user_refined=not clear_user_refined_map, + ) + all_obj_input_frames_inds.update(obj_input_frames_inds) + + # Step 1: Update the object id mapping (note that it must be done after Step 0, + # since Step 0 still requires the old object id mappings in inference_state) + old_obj_ids = inference_state["obj_ids"] + old_obj_inds = list(range(len(old_obj_ids))) + remain_old_obj_inds = old_obj_inds.copy() + for old_obj_idx_to_rm in old_obj_idxs_to_rm: + remain_old_obj_inds.remove(old_obj_idx_to_rm) + new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds] + new_obj_inds = list(range(len(new_obj_ids))) + # build new mappings + old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds)) + inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds)) + inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids)) + inference_state["obj_ids"] = new_obj_ids + + if len(new_obj_ids) == 0: + return new_obj_ids, updated_frames + + # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys. + # (note that "consolidated_frame_inds" doesn't need to be updated in this step as + # it's already handled in Step 0) + def _map_keys(container): + new_kvs = [] + for k in old_obj_inds: + v = container.pop(k) + if k in old_idx_to_new_idx: + new_kvs.append((old_idx_to_new_idx[k], v)) + container.update(new_kvs) + + _map_keys(inference_state["point_inputs_per_obj"]) + _map_keys(inference_state["mask_inputs_per_obj"]) + _map_keys(inference_state["output_dict_per_obj"]) + _map_keys(inference_state["temp_output_dict_per_obj"]) + + multiplex_state: MultiplexState = inference_state["multiplex_state"] + # strict is set to True because we have done the filtering above + buckets_to_keep = multiplex_state.remove_objects( + old_obj_idxs_to_rm, strict=True + ) + obj_ids = set(obj_ids) + + # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-bucket/per-object slices. + def _slice_state(output_dict, storage_key): + for frame_idx, out in output_dict[storage_key].items(): + out["maskmem_features"] = out["maskmem_features"][buckets_to_keep] + out["maskmem_pos_enc"] = [ + x[buckets_to_keep] for x in out["maskmem_pos_enc"] + ] + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out) + out["obj_ptr"] = out["obj_ptr"][buckets_to_keep] + + # Note that pred_maks and score_logits are stored in a per-object manner + # When we add new objects, obj_id_to_idx mapping could be different + # locally (at this past frame) versus globally (at the current frame), + # so we need to use a local copy of this mapping + local_obj_id_to_idx = out["local_obj_id_to_idx"] + + # Find which local indices correspond to the remaining old object indices + local_remain_old_obj_inds = [ + obj_idx + for obj_id, obj_idx in local_obj_id_to_idx.items() + if obj_id not in obj_ids + ] + + # Guard against stale indices by intersecting with available rows + max_pred = out["pred_masks"].shape[0] + max_scores = out["object_score_logits"].shape[0] + keep_indices = [ + idx + for idx in local_remain_old_obj_inds + if 0 <= idx < max_pred and 0 <= idx < max_scores + ] + out["pred_masks"] = out["pred_masks"][keep_indices] + out["object_score_logits"] = out["object_score_logits"][keep_indices] + if self.use_memory_selection: + out["iou_score"] = out["iou_score"][keep_indices] + out["eff_iou_score"] = self.cal_mem_score( + out["object_score_logits"], out["iou_score"] + ) # recalculate the memory frame score + sliced_conditioning_objects = set() + + # Update local_obj_id_to_idx to reflect the new indices after removal + new_local_obj_id_to_idx = {} + old_to_new = { + old_idx: new_i for new_i, old_idx in enumerate(keep_indices) + } + for obj_id, old_idx in local_obj_id_to_idx.items(): + if obj_id not in obj_ids: # Keep objects not being removed + # Find the new index for this object if it was kept + if old_idx in old_to_new: + new_idx = old_to_new[old_idx] + new_local_obj_id_to_idx[obj_id] = new_idx + if old_idx in out["conditioning_objects"]: + sliced_conditioning_objects.add(new_idx) + + out["local_obj_id_to_idx"] = new_local_obj_id_to_idx + out["conditioning_objects"] = sliced_conditioning_objects + + # also update the per-object slices + self._add_output_per_object( + inference_state, frame_idx, out, storage_key + ) + + _slice_state(inference_state["output_dict"], "cond_frame_outputs") + _slice_state(inference_state["output_dict"], "non_cond_frame_outputs") + + # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which + # could show an updated mask for objects previously occluded by the object being removed + if need_output: + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + for frame_idx in all_obj_input_frames_inds: + is_cond = any( + frame_idx in obj_temp_output_dict["cond_frame_outputs"] + for obj_temp_output_dict in temp_output_dict_per_obj.values() + ) + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + updated_frames.append((frame_idx, video_res_masks)) + + return inference_state["obj_ids"], updated_frames + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This function clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) + + @torch.inference_mode() + @torch.autocast(device_type="cuda", dtype=torch.bfloat16) + def warm_up_compilation( + self, offload_video_to_cpu=False, offload_state_to_cpu=False + ): + """ + Warm up the model by running a dummy inference to compile the model. This is + useful to avoid the compilation overhead in the first inference call. + """ + if not self.compile_all_components: + return + + raise NotImplementedError( + "Please use `VideoTrackingMultiplexDemoPerBucketInference` instead for full model compilation." + ) + + +class Sam3VideoTrackingMultiplexDemo(VideoTrackingMultiplexDemo): + @torch.inference_mode() + def init_state( + self, + video_height, + video_width, + num_frames, + cached_features=None, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + ): + """Initialize a inference state.""" + # Make sure that sigmoid is used on mask logits (should be True for all our recent models). + # Since we rely on large negative values as scores for missing objects, the raw logits + # cannot be consumed directly and must be converted into 0~1 range via sigmoid first. + if not self.apply_sigmoid_to_mask_logits_for_mem_enc: + raise NotImplementedError( + "Multi-object tracking requires sigmoid in memory encoder for non-overlapping constraints." + ) + inference_state = {} + # inference_state["images"] = images + inference_state["num_frames"] = num_frames + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = torch.device("cuda") + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = torch.device("cuda") + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = ( + {} if cached_features is None else cached_features + ) + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # The index of the frame that received the first annotation + inference_state["first_ann_frame_idx"] = None + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + inference_state["multiplex_state"] = None + # Warm up the whole model and cache the image feature on frame 0 + # by making a dummy click on the first frame (and then cleaning it up) + # self.add_new_points( + # inference_state=inference_state, + # frame_idx=0, + # obj_id=1, + # points=torch.tensor([[0.5, 0.5]], dtype=torch.float32), + # labels=torch.tensor([1], dtype=torch.int32), + # clear_old_points=True, + # rel_coordinates=True, + # ) + self.clear_all_points_in_video(inference_state) + return inference_state + + def _suppress_shrinked_masks( + self, pred_masks, new_pred_masks, shrink_threshold=0.3 + ): + area_before = (pred_masks > 0).sum(dim=(-1, -2)) + area_after = (new_pred_masks > 0).sum(dim=(-1, -2)) + area_before = torch.clamp(area_before, min=1.0) + area_ratio = area_after / area_before + keep = area_ratio >= shrink_threshold + keep_mask = keep[..., None, None].expand_as(pred_masks) + pred_masks_after = torch.where( + keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0) + ) + return pred_masks_after + + @staticmethod + def _suppress_object_pw_area_shrinkage(pred_masks): + """ + This function suppresses masks that shrink in area after applying pixelwise non-overlapping constriants. + Note that the final output can still be overlapping. + """ + # Apply pixel-wise non-overlapping constraint based on mask scores + # pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints( + # pred_masks + # ) + + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pixel_level_non_overlapping_masks = torch.where( + keep, pred_masks, torch.clamp(pred_masks, max=-10.0) + ) + + # Fully suppress masks with high shrinkage (probably noisy) based on the pixel wise non-overlapping constraints + # NOTE: The output of this function can be a no op if none of the masks shrinked by a large factor. + # pred_masks = self._suppress_shrinked_masks( + # pred_masks, pixel_level_non_overlapping_masks + # ) + + shrink_threshold = 0.3 + area_before = (pred_masks > 0).sum(dim=(-1, -2)) + area_after = (pixel_level_non_overlapping_masks > 0).sum(dim=(-1, -2)) + area_before = torch.clamp(area_before, min=1.0) + area_ratio = area_after / area_before + keep = area_ratio >= shrink_threshold + keep_mask = keep[..., None, None].expand_as(pred_masks) + pred_masks_after = torch.where( + keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0) + ) + + return pred_masks_after + + def _apply_object_wise_non_overlapping_constraints( + self, pred_masks, obj_scores, background_value=-10.0 + ): + """ + Applies non-overlapping constraints object wise (i.e. only one object can claim the overlapping region) + """ + # TODO: Try suppression based on IoM here as well. + # Replace pixel scores with object scores + pred_masks_single_score = torch.where( + pred_masks > 0, obj_scores[..., None, None], background_value + ) + # Apply pixel-wise non-overlapping constraint based on mask scores + pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints( + pred_masks_single_score + ) + # Replace object scores with pixel scores. Note, that now only one object can claim the overlapping region + pred_masks = torch.where( + pixel_level_non_overlapping_masks > 0, + pred_masks, + torch.clamp(pred_masks, max=background_value), + ) + return pred_masks + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse, + tqdm_disable=False, + obj_ids=None, + run_mem_encoder=True, + ): + """Propagate the input points across frames to track in the entire video.""" + # NOTE: This is a copy from the parent class, except that we return object scores as well. + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + if obj_ids is not None: + raise NotImplementedError( + "Per-object tracking yet for batched inference if not implemented." + ) + obj_ids = inference_state["obj_ids"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + processing_order = self._get_processing_order( + inference_state, + start_frame_idx, + max_frame_num_to_track, + reverse, + ) + + for frame_idx in tqdm( + processing_order, desc="propagate in video", disable=tqdm_disable + ): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + obj_scores = current_out["object_score_logits"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + obj_scores = current_out["object_score_logits"] + else: + storage_key = "non_cond_frame_outputs" + with torch.profiler.record_function( + "VideoTrackingMultiplexDemo._run_single_frame_inference" + ): + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=run_mem_encoder, + ) + obj_scores = current_out["object_score_logits"] + current_out["local_obj_id_to_idx"] = deepcopy( + inference_state["obj_id_to_idx"] + ) + output_dict[storage_key][frame_idx] = current_out + + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + low_res_masks, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores diff --git a/third_party/sam3/sam3/model/vitdet.py b/third_party/sam3/sam3/model/vitdet.py new file mode 100644 index 0000000000000000000000000000000000000000..c43d5cca73b60aa736f9259e4a1c5eb860a2b3b1 --- /dev/null +++ b/third_party/sam3/sam3/model/vitdet.py @@ -0,0 +1,1047 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +ViTDet backbone adapted from Detectron2. +This module implements Vision Transformer (ViT) backbone for object detection. + +Rope embedding code adopted from: +1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +2. https://github.com/naver-ai/rope-vit +3. https://github.com/lucidrains/rotary-embedding-torch +""" + +import math +from functools import partial +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +try: + from timm.layers import DropPath, trunc_normal_ +except ModuleNotFoundError: + # compatibility for older timm versions + from timm.models.layers import DropPath, trunc_normal_ +from sam3.model.data_misc import NestedTensor +from sam3.model.model_misc import AttentionType, LayerScale +from sam3.perflib.fused import addmm_act +from sam3.sam.rope import apply_rotary_enc_real, VisionRotaryEmbeddingVE +from torch import Tensor + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + if isinstance(bias, bool): + bias = (bias, bias) + if isinstance(drop, (int, float)): + drop_probs = (drop, drop) + else: + drop_probs = drop + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + ) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + if torch.is_grad_enabled(): + # Training: standard path (addmm_act requires grad disabled) + x = self.fc1(x) + x = self.act(x) + else: + x = addmm_act(type(self.act), self.fc1, x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x.to(self.fc2.weight.dtype)) + x = self.drop2(x) + return x + + +def init_t_xy( + end_x: int, end_y: int, scale: float = 1.0, offset: int = 0 +) -> Tuple[torch.Tensor, torch.Tensor]: + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x * scale + offset, t_y * scale + offset + + +def compute_axial_cis( + dim: int, + end_x: int, + end_y: int, + theta: float = 10000.0, + scale_pos: float = 1.0, + offset: int = 0, +) -> torch.Tensor: + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y, scale_pos, offset) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) + + +def window_partition(x: Tensor, window_size: int) -> Tuple[Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.reshape( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :] + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: Tensor) -> Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + align_corners=False, + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def get_abs_pos( + abs_pos: Tensor, + has_cls_token: bool, + hw: Tuple[int, int], + retain_cls_token: bool = False, + tiling: bool = False, +) -> Tensor: + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + retain_cls_token: whether to retain the cls_token + tiling: whether to tile the embeddings, *instead* of interpolation (a la abs_win) + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C), + if retain_cls_token is False, otherwise (1, 1+H*W, C) + """ + if retain_cls_token: + assert has_cls_token + + h, w = hw + if has_cls_token: + cls_pos = abs_pos[:, :1] + abs_pos = abs_pos[:, 1:] + + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2) + if tiling: + new_abs_pos = new_abs_pos.tile( + [1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])] + )[:, :, :h, :w] + else: + new_abs_pos = F.interpolate( + new_abs_pos, + size=(h, w), + mode="bicubic", + align_corners=False, + ) + + if not retain_cls_token: + return new_abs_pos.permute(0, 2, 3, 1) + else: + # add cls_token back, flatten spatial dims + assert has_cls_token + return torch.cat( + [cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)], + dim=1, + ) + + else: + if not retain_cls_token: + return abs_pos.reshape(1, h, w, -1) + else: + assert has_cls_token + return torch.cat([cls_pos, abs_pos], dim=1) + + +def concat_rel_pos( + q: Tensor, + k: Tensor, + q_hw: Tuple[int, int], + k_hw: Tuple[int, int], + rel_pos_h: Tensor, + rel_pos_w: Tensor, + rescale: bool = False, + relative_coords: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: + """ + Concatenate rel pos coeffs to the q & k tensors, so that qk^T is now + effectively including rel pos biases. + Args: + q (Tensor): q tensor with shape (B, L_q, C). + k (Tensor): k tensor with shape (B, L_k, C). + q_hw, k_hw: These are spatial size of q & k tensors. + rel_pos_h, rel_pos_w: These are relative pos embeddings/params of height, width. + rescale (bool): whether to rescale. e.g. for use when using sdpa, pytorch will + scale by the wrong factor due to the concat. + Returns: + q, k: But, padded so that qk^T accounts for rel pos biases + """ + q_h, q_w = q_hw + k_h, k_w = k_hw + + assert (q_h == q_w) and (k_h == k_w), "only square inputs supported" + + if relative_coords is not None: + Rh = rel_pos_h[relative_coords] + Rw = rel_pos_w[relative_coords] + else: + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + + old_scale = dim**0.5 + new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale # for sdpa + # attn will be divided by new_scale, but we want to divide q by old_scale + scale_ratio = new_scale / old_scale + + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale # (B, q_h, q_w, k_h) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale # (B, q_h, q_w, k_w) + + eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device) + eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device) + + eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h]) + eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w]) + + q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1) + k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view( + B, k_h * k_w, -1 + ) + + return q, k + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + bias: bool = True, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: Tensor) -> Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings and 2d-rope.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + attn_type: AttentionType = AttentionType.Vanilla, + cls_token: bool = False, + use_rope: bool = False, + rope_theta: float = 10000.0, + rope_pt_size: Optional[Tuple[int, int]] = None, + rope_interp: bool = False, + rope_tiled: bool = False, + use_ve_rope: bool = False, + use_fa3: bool = False, + use_rope_real: bool = False, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size or rope size. + attn_type: Type of attention operation, e.g. "vanilla", "vanilla-xformer". + cls_token: whether a cls_token is present. + use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together) + rope_theta: control frequencies of rope + rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling + rope_tiled: whether to tile rope or not; tile expected to be of size rope_pt_size x rope_pt_size + rope_interp: whether to interpolate (or extrapolate) rope to match input size + use_ve_rope: use ve orig rope implementation, if small numerical differences are important (normally not) + """ + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.cls_token = cls_token + + self.attn_type = attn_type + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + # rel_pos embeddings and rope + self.use_rel_pos = use_rel_pos + self.input_size = input_size + + self.use_rope = use_rope + self.rope_theta = rope_theta + self.rope_pt_size = rope_pt_size + self.rope_interp = rope_interp + self.rope_tiled = rope_tiled + self.use_ve_rope = use_ve_rope + self.use_fa3 = use_fa3 + self.use_rope_real = use_rope_real + + # init rel_pos embeddings and rope + self._setup_rel_pos(rel_pos_zero_init) + self._setup_rope_freqs() + + def _setup_rel_pos(self, rel_pos_zero_init: bool = True) -> None: + if not self.use_rel_pos: + self.rel_pos_h = None + self.rel_pos_w = None + return + + assert self.input_size is not None + assert self.cls_token is False, "not supported" + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter( + torch.zeros(2 * self.input_size[0] - 1, self.head_dim) + ) + self.rel_pos_w = nn.Parameter( + torch.zeros(2 * self.input_size[1] - 1, self.head_dim) + ) + + if not rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) + + # Precompute the relative coords + H, W = self.input_size + q_coords = torch.arange(H)[:, None] + k_coords = torch.arange(W)[None, :] + relative_coords = (q_coords - k_coords) + (H - 1) + self.register_buffer("relative_coords", relative_coords.long()) + + def _setup_rope_freqs(self) -> None: + if not self.use_rope: + self.freqs_cis = None + return + + assert self.input_size is not None + # determine rope input size + if self.rope_pt_size is None: + self.rope_pt_size = self.input_size + + if self.use_ve_rope: + assert not self.rope_tiled, "not supported" + self.rope = VisionRotaryEmbeddingVE( + dim=self.head_dim // 2, + seq_len=self.input_size[0], + pt_seq_len=self.rope_pt_size[0], + ) + return + + # initialize 2d rope freqs + self.compute_cis = partial( + compute_axial_cis, + dim=self.head_dim, + theta=self.rope_theta, + ) + + if self.rope_pt_size != self.input_size and self.rope_tiled: + assert not self.rope_interp + # window/tiled rope + freqs_cis = self.compute_cis( + end_x=self.rope_pt_size[0], end_y=self.rope_pt_size[1] + ) + # check dims are tileable + rh, rw = ( + self.input_size[0] // self.rope_pt_size[0], + self.input_size[1] // self.rope_pt_size[1], + ) + assert rh >= 1, rw >= 1 + assert ( + self.input_size[0] % self.rope_pt_size[0] == 0 + and self.input_size[1] % self.rope_pt_size[1] == 0 + ) + + # restore spatial shape, tile and then flatten spatial dims + freqs_cis = ( + freqs_cis.reshape(self.rope_pt_size[0], self.rope_pt_size[1], -1) + .tile(rh, rw, 1) + .reshape(-1, freqs_cis.shape[-1]) + ) + else: + # interpolate rope + scale_pos = 1.0 + if self.rope_interp: + scale_pos = self.rope_pt_size[0] / self.input_size[0] + # get scaled freqs_cis + freqs_cis = self.compute_cis( + end_x=self.input_size[0], + end_y=self.input_size[1], + scale_pos=scale_pos, + ) + if self.cls_token: + t = torch.zeros( + self.head_dim // 2, + dtype=torch.float32, + device=freqs_cis.device, + ) + cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :] + freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0) + + self.register_buffer("freqs_cis", freqs_cis) + if self.use_rope_real: + self.register_buffer("freqs_cis_real", freqs_cis.real) + self.register_buffer("freqs_cis_imag", freqs_cis.imag) + + def _apply_rope(self, q, k) -> Tuple[Tensor, Tensor]: + if not self.use_rope: + return q, k + + if self.use_ve_rope: + dtype = q.dtype + return self.rope(q).to(dtype), self.rope(k).to(dtype) + + assert self.freqs_cis is not None + + if self.use_rope_real: + return apply_rotary_enc_real( + q, + k, + freqs_cis_imag=self.freqs_cis_imag, + freqs_cis_real=self.freqs_cis_real, + ) + return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis) + + def forward(self, x: Tensor) -> Tensor: + s = 1 if self.cls_token else 0 # used to exclude cls_token + if x.ndim == 4: + B, H, W, _ = x.shape + assert s == 0 # no cls_token + L = H * W + ndim = 4 + else: + assert x.ndim == 3 + B, L, _ = x.shape + ndim = 3 + H = W = math.sqrt(L - s) + + # qkv with shape (3, B, nHead, L, C) + qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1) + # q, k, v with shape (B, nHead, L, C) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) + + # handle rope and rel pos embeddings + q, k = self._apply_rope(q, k) + if self.use_rel_pos: + q, k = concat_rel_pos( + q.flatten(0, 1), + k.flatten(0, 1), + (H, W), + x.shape[1:3], + self.rel_pos_h, + self.rel_pos_w, + rescale=True, + relative_coords=self.relative_coords, + ) + + # sdpa expects [B, nheads, H*W, C] so we transpose back + q = q.reshape(B, self.num_heads, H * W, -1) + k = k.reshape(B, self.num_heads, H * W, -1) + + if self.attn_type == AttentionType.Vanilla: + if self.use_fa3: + from sam3.perflib.fa3 import flash_attn_func + + x = flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + ).transpose(1, 2) + else: + x = F.scaled_dot_product_attention(q, k, v) + else: + raise NotImplementedError + + if ndim == 4: + x = ( + x.view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + else: + x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1) + + x = self.proj(x) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop_path: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + act_layer: Callable[..., nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + use_rope: bool = False, + rope_pt_size: Optional[Tuple[int, int]] = None, + rope_tiled: bool = False, + rope_interp: bool = False, + use_ve_rope: bool = False, + cls_token: bool = False, + dropout: float = 0.0, + init_values: Optional[float] = None, + attn_type: AttentionType = AttentionType.Vanilla, + use_fa3: bool = False, + use_rope_real: bool = False, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then not + use window attention. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + dropout (float): Dropout rate. + cls_token: whether a cls_token is present. + use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together) + rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling + rope_tiled: whether to tile rope or not; tile expected to be of size rope_pt_size x rope_pt_size + rope_interp: whether to interpolate (or extrapolate) rope to match target input size, + expected to specify source size as rope_pt_size. + use_ve_rope: use ve orig rope implementation, if small numerical differences are important (normally not) + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + attn_type=attn_type, + use_rope=use_rope, + rope_pt_size=rope_pt_size, + rope_tiled=rope_tiled, + rope_interp=rope_interp, + use_ve_rope=use_ve_rope, + cls_token=cls_token, + use_fa3=use_fa3, + use_rope_real=use_rope_real, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=(dropout, 0.0), + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.dropout = nn.Dropout(dropout) + self.window_size = window_size + + def forward(self, x: Tensor) -> Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.ls1(self.attn(x)) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + self.dropout(self.drop_path(x)) + x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x))))) + + return x + + +class ViT(nn.Module): + """ + This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. + "Exploring Plain Vision Transformer Backbones for Object Detection", + https://arxiv.org/abs/2203.16527 + """ + + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop_path_rate: float = 0.0, + norm_layer: Union[Callable[..., nn.Module], str] = "LayerNorm", + act_layer: Callable[..., nn.Module] = nn.GELU, + use_abs_pos: bool = True, + tile_abs_pos: bool = True, + rel_pos_blocks: Union[Tuple[int, ...], bool] = (2, 5, 8, 11), + rel_pos_zero_init: bool = True, + window_size: int = 14, + global_att_blocks: Tuple[int, ...] = (2, 5, 8, 11), + use_rope: bool = False, + use_tiled_rope: bool = False, + rope_pt_size: Optional[int] = None, + use_interp_rope: bool = False, + use_ve_rope: bool = False, + use_act_checkpoint: bool = True, + pretrain_img_size: int = 224, + pretrain_use_cls_token: bool = True, + retain_cls_token: bool = True, + dropout: float = 0.0, + return_interm_layers: bool = False, + init_values: Optional[float] = None, # for layerscale + attn_type: AttentionType = AttentionType.Vanilla, + ln_pre: bool = False, + ln_post: bool = False, + bias_patch_embed: bool = True, + compile_mode: Optional[str] = None, + use_fa3: bool = False, + use_rope_real: bool = False, + ): + """ + Args: + img_size (int): Input image size. Only relevant for rel pos or rope. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolation. + rel_pos_blocks (list): Blocks which have rel pos embeddings. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_att_blocks (list): Indexes for blocks using global attention (other blocks use window attention). + use_rope (bool): whether to use rope 2d (indep of rel_pos_blocks, as it can be used together). + rope_pt_size (int): size of rope in previous stage of training, needed for interpolation or tiling. + use_interp_rope: whether to interpolate (or extrapolate) rope to match target input size, + expected to specify source size as rope_pt_size. + use_act_checkpoint (bool): If True, use activation checkpointing. + pretrain_img_size (int): input image size for pretraining models. + pretrain_use_cls_token (bool): If True, pretraining models use class token. + retain_cls_token: whether cls_token should be retained. + dropout (float): Dropout rate. Applied in residual blocks of attn, mlp and inside the mlp. + + return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks). + init_values: layer scale init, None for no layer scale. + + ln_pre (bool): If True, apply layer norm before transformer blocks. + ln_post (bool): If True, apply layer norm after transformer blocks. + bias_patch_embed (bool): bias in conv for patch embed? + compile_mode (str): mode to compile the forward + """ + super().__init__() + self.pretrain_use_cls_token = pretrain_use_cls_token + + window_block_indexes = [i for i in range(depth) if i not in global_att_blocks] + self.full_attn_ids = list(global_att_blocks) + self.rel_pos_blocks = [False] * depth + if isinstance(rel_pos_blocks, bool) and rel_pos_blocks: + self.rel_pos_blocks = [True] * depth + else: + for i in rel_pos_blocks: + self.rel_pos_blocks[i] = True + + self.retain_cls_token = retain_cls_token + if self.retain_cls_token: + assert pretrain_use_cls_token + assert ( + len(window_block_indexes) == 0 + ), "windowing not supported with cls token" + + assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token" + + scale = embed_dim**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim)) + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-5) + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + bias=bias_patch_embed, + ) + + # Handle absolute positional embedding + self.tile_abs_pos = tile_abs_pos + self.use_abs_pos = use_abs_pos + if self.tile_abs_pos: + assert self.use_abs_pos + + if self.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + num_patches = (pretrain_img_size // patch_size) * ( + pretrain_img_size // patch_size + ) + num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim)) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + self.blocks = nn.ModuleList() + cur_stage = 1 + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=self.rel_pos_blocks[i], + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i in window_block_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + use_rope=use_rope, + rope_pt_size=( + (window_size, window_size) + if rope_pt_size is None + else (rope_pt_size, rope_pt_size) + ), + rope_tiled=use_tiled_rope, + use_ve_rope=use_ve_rope, + rope_interp=use_interp_rope, + cls_token=self.retain_cls_token, + dropout=dropout, + init_values=init_values, + attn_type=attn_type, + use_fa3=use_fa3, + use_rope_real=use_rope_real, + ) + + if i not in window_block_indexes: + cur_stage += 1 + + self.use_act_checkpoint = use_act_checkpoint + + self.blocks.append(block) + + self.return_interm_layers = return_interm_layers + self.channel_list = ( + [embed_dim] * len(self.full_attn_ids) + if return_interm_layers + else [embed_dim] + ) + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity() + self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity() + + self.apply(self._init_weights) + + if compile_mode is not None: + # Only compile for training mode, skip compile for eval to avoid + # long compilation time during validation + self._forward_uncompiled = self.forward + self._forward_compiled = torch.compile( + self.forward, mode=compile_mode, fullgraph=True + ) + # Override forward to dispatch based on training mode + def _dispatch_forward(x): + if self.training: + return self._forward_compiled(x) + else: + return self._forward_uncompiled(x) + self.forward = _dispatch_forward + if self.use_act_checkpoint: + torch._dynamo.config.optimize_ddp = False + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, tensor_list): + if isinstance(tensor_list, NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + else: + x = tensor_list + mask = None + + x = self.patch_embed(x) + h, w = x.shape[1], x.shape[2] + + s = 0 + if self.retain_cls_token: + # If cls_token is retained, we don't + # maintain spatial shape + x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1) + s = 1 + + if self.pos_embed is not None: + x = x + get_abs_pos( + self.pos_embed, + self.pretrain_use_cls_token, + (h, w), + self.retain_cls_token, + tiling=self.tile_abs_pos, + ) + + x = self.ln_pre(x) + + outputs = [] + masks = None + for i, blk in enumerate(self.blocks): + if self.use_act_checkpoint and self.training: + x = checkpoint.checkpoint(blk, x, use_reentrant=False) + else: + x = blk(x) + if (i == self.full_attn_ids[-1]) or ( + self.return_interm_layers and i in self.full_attn_ids + ): + if i == self.full_attn_ids[-1]: + x = self.ln_post(x) + + feats = x[:, s:] + if feats.ndim == 4: + feats = feats.permute(0, 3, 1, 2) + else: + assert feats.ndim == 3 + h = w = math.sqrt(feats.shape[1]) + feats = feats.reshape( + feats.shape[0], h, w, feats.shape[-1] + ).permute(0, 3, 1, 2) + + if isinstance(tensor_list, NestedTensor): + # Optimization, if the mask is all False, just ignore it + if mask is not None and mask.any() and masks is None: + masks = F.interpolate( + mask[None].float(), size=feats.shape[-2:] + ).bool()[0] + outputs.append(NestedTensor(feats, masks)) + else: + outputs.append(feats) + + return outputs + + def get_layer_id(self, layer_name: str) -> int: + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("ln_pre") != -1: + return 0 + elif layer_name.find("pos_embed") != -1 or layer_name.find("cls_token") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/third_party/sam3/sam3/model/vl_combiner.py b/third_party/sam3/sam3/model/vl_combiner.py new file mode 100644 index 0000000000000000000000000000000000000000..aed1f2c5b0d272b5eed36488c68d8afd6456adea --- /dev/null +++ b/third_party/sam3/sam3/model/vl_combiner.py @@ -0,0 +1,430 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Provides utility to combine a vision backbone with a language backbone.""" + +from copy import copy +from typing import List, Optional + +import torch +import torch.nn as nn +from torch.nn.attention import sdpa_kernel, SDPBackend + +from .act_ckpt_utils import activation_ckpt_wrapper +from .data_misc import NestedTensor +from .necks import Sam3DualViTDetNeck, Sam3TriViTDetNeck + + +class SAM3VLBackbone(nn.Module): + """This backbone combines a vision backbone and a language backbone without fusion. + As such it is more of a convenience wrapper to handle the two backbones together. + + It adds support for activation checkpointing and compilation. + """ + + def __init__( + self, + visual: Sam3DualViTDetNeck, + text, + compile_visual: bool = False, + act_ckpt_whole_vision_backbone: bool = False, + act_ckpt_whole_language_backbone: bool = False, + scalp=0, + ): + """Initialize the backbone combiner. + + :param visual: The vision backbone to use + :param text: The text encoder to use + """ + super().__init__() + self.vision_backbone: Sam3DualViTDetNeck = ( + torch.compile(visual) if compile_visual else visual + ) + self.language_backbone = text + self.scalp = scalp + # allow running activation checkpointing on the entire vision and language backbones + self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone + self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone + + def forward( + self, + samples: torch.Tensor, + captions: List[str], + input_boxes: Optional[torch.Tensor] = None, + additional_text: Optional[List[str]] = None, + ): + """Forward pass of the backbone combiner. + + :param samples: The input images + :param captions: The input captions + :param input_boxes: If the text contains place-holders for boxes, this + parameter contains the tensor containing their spatial features + :param additional_text: This can be used to encode some additional text + (different from the captions) in the same forward of the backbone + :return: Output dictionary with the following keys: + - vision_features: The output of the vision backbone + - language_features: The output of the language backbone + - language_mask: The attention mask of the language backbone + - vision_pos_enc: The positional encoding of the vision backbone + - (optional) additional_text_features: The output of the language + backbone for the additional text + - (optional) additional_text_mask: The attention mask of the + language backbone for the additional text + """ + output = self.forward_image(samples) + device = output["vision_features"].device + output.update(self.forward_text(captions, input_boxes, additional_text, device)) + return output + + def forward_image(self, samples: torch.Tensor): + return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)( + samples=samples, + act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training, + ) + + def _forward_image_no_act_ckpt(self, samples): + # Forward through backbone + sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward( + samples + ) + if self.scalp > 0: + # Discard the lowest resolution features + sam3_features, sam3_pos = ( + sam3_features[: -self.scalp], + sam3_pos[: -self.scalp], + ) + if sam2_features is not None and sam2_pos is not None: + sam2_features, sam2_pos = ( + sam2_features[: -self.scalp], + sam2_pos[: -self.scalp], + ) + + sam2_output = None + + if sam2_features is not None and sam2_pos is not None: + sam2_src = sam2_features[-1] + sam2_output = { + "vision_features": sam2_src, + "vision_pos_enc": sam2_pos, + "backbone_fpn": sam2_features, + } + + sam3_src = sam3_features[-1] + output = { + "vision_features": sam3_src, + "vision_pos_enc": sam3_pos, + "backbone_fpn": sam3_features, + "sam2_backbone_out": sam2_output, + } + + return output + + def forward_text( + self, captions, input_boxes=None, additional_text=None, device="cuda" + ): + return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)( + captions=captions, + input_boxes=input_boxes, + additional_text=additional_text, + device=device, + act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training, + ) + + def _forward_text_no_ack_ckpt( + self, + captions, + input_boxes=None, + additional_text=None, + device="cuda", + ): + output = {} + + # Forward through text_encoder + text_to_encode = copy(captions) + if additional_text is not None: + # if there are additional_text, we piggy-back them into this forward. + # They'll be used later for output alignment + text_to_encode += additional_text + + sdpa_context = sdpa_kernel( + [ + SDPBackend.MATH, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.FLASH_ATTENTION, + ] + ) + + with sdpa_context: + text_attention_mask, text_memory, text_embeds = self.language_backbone( + text_to_encode, input_boxes, device=device + ) + + if additional_text is not None: + output["additional_text_features"] = text_memory[:, -len(additional_text) :] + output["additional_text_mask"] = text_attention_mask[ + -len(additional_text) : + ] + + text_memory = text_memory[:, : len(captions)] + text_attention_mask = text_attention_mask[: len(captions)] + text_embeds = text_embeds[:, : len(captions)] + output["language_features"] = text_memory + output["language_mask"] = text_attention_mask + output["language_embeds"] = ( + text_embeds # Text embeddings before forward to the encoder + ) + + return output + + +class SAM3VLBackboneTri(SAM3VLBackbone): + """VL backbone with triple-head vision (sam3, interactive, propagation) + text encoder.""" + + def __init__(self, visual, text, compile_visual=False, scalp=0): + super().__init__( + visual=visual, text=text, compile_visual=compile_visual, scalp=scalp + ) + assert isinstance( + self.vision_backbone, Sam3TriViTDetNeck + ), f"Expected vision backbone to be of type Sam3TriViTDetNeck, got {type(self.vision_backbone)}" + + def forward_image( + self, + samples, + *, + need_sam3_out: bool = True, + need_interactive_out: bool = True, + need_propagation_out: bool = True, + ): + return activation_ckpt_wrapper(self._forward_image_tri_no_act_ckpt)( + samples=samples, + need_sam3_out=need_sam3_out, + need_interactive_out=need_interactive_out, + need_propagation_out=need_propagation_out, + act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training, + ) + + def _forward_image_tri_no_act_ckpt( + self, + samples, + need_sam3_out=True, + need_interactive_out=True, + need_propagation_out=True, + ): + ( + sam3_features, + sam3_pos, + interactive_features, + interactive_pos, + propagation_features, + propagation_pos, + ) = self.vision_backbone.forward( + samples, + need_sam3_out=need_sam3_out, + need_interactive_out=need_interactive_out, + need_propagation_out=need_propagation_out, + ) + if self.scalp > 0: + sam3_features, sam3_pos = ( + sam3_features[: -self.scalp], + sam3_pos[: -self.scalp], + ) + interactive_features, interactive_pos = ( + interactive_features[: -self.scalp], + interactive_pos[: -self.scalp], + ) + propagation_features, propagation_pos = ( + propagation_features[: -self.scalp], + propagation_pos[: -self.scalp], + ) + + output = {} + if need_sam3_out: + sam3_last = sam3_features[-1] + output.update( + { + "vision_features": sam3_last.tensors, + "vision_mask": sam3_last.mask, + "vision_pos_enc": sam3_pos, + "backbone_fpn": sam3_features, + } + ) + if need_interactive_out: + inte_last = interactive_features[-1] + output["interactive"] = { + "vision_features": inte_last.tensors, + "vision_mask": inte_last.mask, + "vision_pos_enc": interactive_pos, + "backbone_fpn": interactive_features, + } + if need_propagation_out: + prop_last = propagation_features[-1] + output["sam2_backbone_out"] = { + "vision_features": prop_last.tensors, + "vision_mask": prop_last.mask, + "vision_pos_enc": propagation_pos, + "backbone_fpn": propagation_features, + } + return output + + +class VisionOnly(nn.Module): + def __init__( + self, + visual, + n_features, + forward_in_chunk_for_eval=False, + eval_chunk_size=4, + eval_cast_to_cpu=False, + scalp=0, + compile_mode: str = None, + compile_extra_args: Optional[dict] = None, + ): + super().__init__() + self.vision_backbone = visual + self.should_compile = compile_mode is not None or compile_extra_args is not None + self.compile_mode = compile_mode + self.compile_extra_args = compile_extra_args or {} + self.compiled = False + self.n_features = n_features + self.forward_in_chunk_for_eval = forward_in_chunk_for_eval + self.eval_chunk_size = eval_chunk_size + self.eval_cast_to_cpu = eval_cast_to_cpu + self.scalp = scalp + + def _compile(self): + if self.should_compile and not self.compiled: + self.vision_backbone = torch.compile( + self.vision_backbone, mode=self.compile_mode, **self.compile_extra_args + ) + self.compiled = True + + def forward_image(self, samples): + self._compile() + # Forward through backbone + features, pos = self.vision_backbone(samples) + if self.scalp > 0: + features, pos = features[: -self.scalp], pos[: -self.scalp] + elif self.scalp < 0: + features.pop(self.scalp) + pos.pop(self.scalp) + + src, mask = features[-1].decompose() + output = { + "vision_features": src, + "vision_mask": mask, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + def forward_text( + self, + captions, + input_boxes=None, + additional_text=None, + device="cuda", + ): + bs = len(captions) + output = { + "language_features": torch.zeros((0, bs, self.n_features), device=device), + "language_mask": torch.zeros((bs, 0), device=device), + } + return output + + +class TriHeadVisionOnly(VisionOnly): + def __init__( + self, + visual, + n_features, + forward_in_chunk_for_eval=False, + eval_chunk_size=4, + eval_cast_to_cpu=False, + scalp=0, + compile_mode: str = None, + compile_extra_args: Optional[dict] = None, + ): + super().__init__( + visual=visual, + n_features=n_features, + forward_in_chunk_for_eval=forward_in_chunk_for_eval, + eval_chunk_size=eval_chunk_size, + eval_cast_to_cpu=eval_cast_to_cpu, + scalp=scalp, + compile_mode=compile_mode, + compile_extra_args=compile_extra_args, + ) + assert isinstance( + self.vision_backbone, Sam3TriViTDetNeck + ), f"Expected vision backbone to be of type Sam3TriViTDetNeck, got {type(self.vision_backbone)}" + + def forward_image( + self, + samples, + *, + need_sam3_out: bool = True, + need_interactive_out: bool = True, + need_propagation_out: bool = True, + ): + self._compile() + # Forward through backbone + ( + sam3_features, + sam3_pos, + interactive_features, + interactive_pos, + propagation_features, + propagation_pos, + ) = self.vision_backbone( + samples, + need_sam3_out=need_sam3_out, + need_interactive_out=need_interactive_out, + need_propagation_out=need_propagation_out, + ) + + if self.scalp > 0: + sam3_features, sam3_pos = ( + sam3_features[: -self.scalp], + sam3_pos[: -self.scalp], + ) + interactive_features, interactive_pos = ( + interactive_features[: -self.scalp], + interactive_pos[: -self.scalp], + ) + propagation_features, propagation_pos = ( + propagation_features[: -self.scalp], + propagation_pos[: -self.scalp], + ) + + output = {} + + if need_sam3_out: + sam3_last = sam3_features[-1] + output.update( + { + "vision_features": sam3_last.tensors, + "vision_mask": sam3_last.mask, + "vision_pos_enc": sam3_pos, + "backbone_fpn": sam3_features, + } + ) + if need_interactive_out: + inte_last = interactive_features[-1] + output["interactive"] = { + "vision_features": inte_last.tensors, + "vision_mask": inte_last.mask, + "vision_pos_enc": interactive_pos, + "backbone_fpn": interactive_features, + } + if need_propagation_out: + prop_last = propagation_features[-1] + output["sam2_backbone_out"] = { + "vision_features": prop_last.tensors, + "vision_mask": prop_last.mask, + "vision_pos_enc": propagation_pos, + "backbone_fpn": propagation_features, + } + + return output diff --git a/third_party/sam3/sam3/model_builder.py b/third_party/sam3/sam3/model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..7df3ffda8f69b963907ffdef2338d33b4d5d0b46 --- /dev/null +++ b/third_party/sam3/sam3/model_builder.py @@ -0,0 +1,1360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import os +from typing import Optional + +import pkg_resources +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from iopath.common.file_io import g_pathmgr +from sam3.model.decoder import ( + DecoupledTransformerDecoderLayerv2, + SimpleRoPEAttention, + TransformerDecoder, + TransformerDecoderLayer, + TransformerDecoderLayerv2, + TransformerEncoderCrossAttention, + TransformerEncoderDecoupledCrossAttention, +) +from sam3.model.encoder import TransformerEncoderFusion, TransformerEncoderLayer +from sam3.model.geometry_encoders import SequenceGeometryEncoder +from sam3.model.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead +from sam3.model.memory import ( + CXBlock, + SimpleFuser, + SimpleMaskDownSampler, + SimpleMaskEncoder, +) +from sam3.model.model_misc import ( + DotProductScoring, + MLP, + MultiheadAttentionWrapper as MultiheadAttention, + TransformerWrapper, +) +from sam3.model.multiplex_utils import MultiplexController +from sam3.model.necks import Sam3DualViTDetNeck, Sam3TriViTDetNeck +from sam3.model.position_encoding import PositionEmbeddingSine +from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor +from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU +from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor +from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity +from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU +from sam3.model.text_encoder_ve import VETextEncoder +from sam3.model.tokenizer_ve import SimpleTokenizer +from sam3.model.video_tracking_multiplex import VideoTrackingDynamicMultiplex +from sam3.model.vitdet import ViT +from sam3.model.vl_combiner import SAM3VLBackbone, SAM3VLBackboneTri, TriHeadVisionOnly +from sam3.sam.transformer import RoPEAttention + + +# Setup TensorFloat-32 for Ampere GPUs if available +def _setup_tf32() -> None: + """Enable TensorFloat-32 for Ampere GPUs if available.""" + if torch.cuda.is_available(): + device_props = torch.cuda.get_device_properties(0) + if device_props.major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +_setup_tf32() + + +_ACT_CKPT_PRINTED = False + +def _get_use_act_checkpoint() -> bool: + """Get activation checkpointing setting from environment variable. + + Set SAM3_DISABLE_ACT_CKPT=1 to disable activation checkpointing. + This will use more GPU memory but speed up backward pass. + """ + global _ACT_CKPT_PRINTED + disable = os.environ.get("SAM3_DISABLE_ACT_CKPT", "0") == "1" + if not _ACT_CKPT_PRINTED: + if disable: + print("[SAM3] Activation checkpointing DISABLED (SAM3_DISABLE_ACT_CKPT=1)") + else: + print("[SAM3] Activation checkpointing ENABLED (default)") + _ACT_CKPT_PRINTED = True + return not disable + + +def _create_position_encoding(precompute_resolution=None): + """Create position encoding for visual backbone.""" + return PositionEmbeddingSine( + num_pos_feats=256, + normalize=True, + scale=None, + temperature=10000, + precompute_resolution=precompute_resolution, + ) + + +def _create_vit_backbone(compile_mode=None, use_fa3=False, use_rope_real=False): + """Create ViT backbone for visual feature extraction.""" + return ViT( + img_size=1008, + pretrain_img_size=336, + patch_size=14, + embed_dim=1024, + depth=32, + num_heads=16, + mlp_ratio=4.625, + norm_layer="LayerNorm", + drop_path_rate=0.1, + qkv_bias=True, + use_abs_pos=True, + tile_abs_pos=True, + global_att_blocks=(7, 15, 23, 31), + rel_pos_blocks=(), + use_rope=True, + use_interp_rope=True, + window_size=24, + pretrain_use_cls_token=True, + retain_cls_token=False, + ln_pre=True, + ln_post=False, + return_interm_layers=False, + bias_patch_embed=False, + compile_mode=compile_mode, + use_fa3=use_fa3, + use_rope_real=use_rope_real, + use_act_checkpoint=_get_use_act_checkpoint(), + ) + + +def _create_vit_neck(position_encoding, vit_backbone, enable_inst_interactivity=False): + """Create ViT neck for feature pyramid. + + Keep 4 scales to match encoder/decoder expectations (3 levels after scalp=1). + SAM3.1 ckpt only has convs.0-2 weights; convs.3 uses random init but is + discarded by scalp=1, so this is safe for both SAM3 and SAM3.1 checkpoints. + """ + return Sam3DualViTDetNeck( + position_encoding=position_encoding, + d_model=256, + scale_factors=[4.0, 2.0, 1.0, 0.5], + trunk=vit_backbone, + add_sam2_neck=enable_inst_interactivity, + ) + + +def _create_vl_backbone(vit_neck, text_encoder): + """Create visual-language backbone.""" + return SAM3VLBackbone(visual=vit_neck, text=text_encoder, scalp=1) + + +def _create_transformer_encoder(compile_mode=None, use_fa3=False) -> TransformerEncoderFusion: + """Create transformer encoder with its layer.""" + encoder_layer = TransformerEncoderLayer( + activation="relu", + d_model=256, + dim_feedforward=2048, + dropout=0.1, + pos_enc_at_attn=True, + pos_enc_at_cross_attn_keys=False, + pos_enc_at_cross_attn_queries=False, + pre_norm=True, + self_attention=MultiheadAttention( + num_heads=8, + dropout=0.1, + embed_dim=256, + batch_first=True, + use_fa3=use_fa3, + ), + cross_attention=MultiheadAttention( + num_heads=8, + dropout=0.1, + embed_dim=256, + batch_first=True, + use_fa3=use_fa3, + ), + ) + + encoder = TransformerEncoderFusion( + layer=encoder_layer, + num_layers=6, + d_model=256, + num_feature_levels=1, + frozen=False, + use_act_checkpoint=_get_use_act_checkpoint(), + add_pooled_text_to_img_feat=False, + pool_text_with_mask=True, + compile_mode=compile_mode, + ) + return encoder + + +def _create_transformer_decoder(compile_mode=None, use_fa3=False) -> TransformerDecoder: + """Create transformer decoder with its layer.""" + decoder_layer = TransformerDecoderLayer( + activation="relu", + d_model=256, + dim_feedforward=2048, + dropout=0.1, + cross_attention=MultiheadAttention( + num_heads=8, + dropout=0.1, + embed_dim=256, + use_fa3=use_fa3, + ), + n_heads=8, + use_text_cross_attention=True, + ) + + decoder = TransformerDecoder( + layer=decoder_layer, + num_layers=6, + num_queries=200, + return_intermediate=True, + box_refine=True, + num_o2m_queries=0, + dac=True, + boxRPB="log", + d_model=256, + frozen=False, + interaction_layer=None, + dac_use_selfatt_ln=True, + resolution=1008, + stride=14, + use_act_checkpoint=_get_use_act_checkpoint(), + presence_token=True, + compile_mode=compile_mode, + ) + return decoder + + +def _create_dot_product_scoring(): + """Create dot product scoring module.""" + prompt_mlp = MLP( + input_dim=256, + hidden_dim=2048, + output_dim=256, + num_layers=2, + dropout=0.1, + residual=True, + out_norm=nn.LayerNorm(256), + ) + return DotProductScoring(d_model=256, d_proj=256, prompt_mlp=prompt_mlp) + + +def _create_segmentation_head(compile_mode=None, use_fa3=False): + """Create segmentation head with pixel decoder.""" + pixel_decoder = PixelDecoder( + num_upsampling_stages=3, + interpolation_mode="nearest", + hidden_dim=256, + compile_mode=compile_mode, + ) + + cross_attend_prompt = MultiheadAttention( + num_heads=8, + dropout=0, + embed_dim=256, + use_fa3=use_fa3, + ) + + segmentation_head = UniversalSegmentationHead( + hidden_dim=256, + upsampling_stages=3, + aux_masks=False, + presence_head=False, + dot_product_scorer=None, + act_ckpt=True, + cross_attend_prompt=cross_attend_prompt, + pixel_decoder=pixel_decoder, + ) + return segmentation_head + + +def _create_geometry_encoder(): + """Create geometry encoder with all its components.""" + # Create position encoding for geometry encoder + geo_pos_enc = _create_position_encoding() + # Create CX block for fuser + cx_block = CXBlock( + dim=256, + kernel_size=7, + padding=3, + layer_scale_init_value=1.0e-06, + use_dwconv=True, + ) + # Create geometry encoder layer + geo_layer = TransformerEncoderLayer( + activation="relu", + d_model=256, + dim_feedforward=2048, + dropout=0.1, + pos_enc_at_attn=False, + pre_norm=True, + self_attention=MultiheadAttention( + num_heads=8, + dropout=0.1, + embed_dim=256, + batch_first=False, + ), + pos_enc_at_cross_attn_queries=False, + pos_enc_at_cross_attn_keys=True, + cross_attention=MultiheadAttention( + num_heads=8, + dropout=0.1, + embed_dim=256, + batch_first=False, + ), + ) + + # Create geometry encoder + input_geometry_encoder = SequenceGeometryEncoder( + pos_enc=geo_pos_enc, + encode_boxes_as_points=False, + points_direct_project=True, + points_pool=True, + points_pos_enc=True, + boxes_direct_project=True, + boxes_pool=True, + boxes_pos_enc=True, + d_model=256, + num_layers=3, + layer=geo_layer, + use_act_ckpt=_get_use_act_checkpoint(), + add_cls=True, + add_post_encode_proj=True, + ) + return input_geometry_encoder + + +def _create_sam3_model( + backbone, + transformer, + input_geometry_encoder, + segmentation_head, + dot_prod_scoring, + inst_interactive_predictor, + eval_mode, +): + """Create the SAM3 image model.""" + common_params = { + "backbone": backbone, + "transformer": transformer, + "input_geometry_encoder": input_geometry_encoder, + "segmentation_head": segmentation_head, + "num_feature_levels": 1, + "o2m_mask_predict": True, + "dot_prod_scoring": dot_prod_scoring, + "use_instance_query": False, + "multimask_output": True, + "inst_interactive_predictor": inst_interactive_predictor, + } + + matcher = None + if not eval_mode: + from sam3.train.matcher import BinaryHungarianMatcherV2 + + matcher = BinaryHungarianMatcherV2( + focal=True, + cost_class=2.0, + cost_bbox=5.0, + cost_giou=2.0, + alpha=0.25, + gamma=2, + stable=False, + ) + common_params["matcher"] = matcher + model = Sam3Image(**common_params) + + return model + + +def _create_tracker_maskmem_backbone(): + """Create the SAM3 Tracker memory encoder.""" + # Position encoding for mask memory backbone + position_encoding = PositionEmbeddingSine( + num_pos_feats=64, + normalize=True, + scale=None, + temperature=10000, + precompute_resolution=1008, + ) + + # Mask processing components + mask_downsampler = SimpleMaskDownSampler( + kernel_size=3, stride=2, padding=1, interpol_size=[1152, 1152] + ) + + cx_block_layer = CXBlock( + dim=256, + kernel_size=7, + padding=3, + layer_scale_init_value=1.0e-06, + use_dwconv=True, + ) + + fuser = SimpleFuser(layer=cx_block_layer, num_layers=2) + + maskmem_backbone = SimpleMaskEncoder( + out_dim=64, + position_encoding=position_encoding, + mask_downsampler=mask_downsampler, + fuser=fuser, + ) + + return maskmem_backbone + + +def _create_tracker_transformer(): + """Create the SAM3 Tracker transformer components.""" + # Self attention + self_attention = RoPEAttention( + embedding_dim=256, + num_heads=1, + downsample_rate=1, + dropout=0.1, + rope_theta=10000.0, + feat_sizes=[72, 72], + use_fa3=False, + use_rope_real=False, + ) + + # Cross attention + cross_attention = RoPEAttention( + embedding_dim=256, + num_heads=1, + downsample_rate=1, + dropout=0.1, + kv_in_dim=64, + rope_theta=10000.0, + feat_sizes=[72, 72], + rope_k_repeat=True, + use_fa3=False, + use_rope_real=False, + ) + + # Encoder layer + encoder_layer = TransformerDecoderLayerv2( + cross_attention_first=False, + activation="relu", + dim_feedforward=2048, + dropout=0.1, + pos_enc_at_attn=False, + pre_norm=True, + self_attention=self_attention, + d_model=256, + pos_enc_at_cross_attn_keys=True, + pos_enc_at_cross_attn_queries=False, + cross_attention=cross_attention, + ) + + # Encoder + encoder = TransformerEncoderCrossAttention( + remove_cross_attention_layers=[], + batch_first=True, + d_model=256, + frozen=False, + pos_enc_at_input=True, + layer=encoder_layer, + num_layers=4, + use_act_checkpoint=False, + ) + + # Transformer wrapper + transformer = TransformerWrapper( + encoder=encoder, + decoder=None, + d_model=256, + ) + + return transformer + + +def build_tracker( + apply_temporal_disambiguation: bool, with_backbone: bool = False, compile_mode=None +) -> Sam3TrackerPredictor: + """ + Build the SAM3 Tracker module for video tracking. + + Returns: + Sam3TrackerPredictor: Wrapped SAM3 Tracker module + """ + + # Create model components + maskmem_backbone = _create_tracker_maskmem_backbone() + transformer = _create_tracker_transformer() + backbone = None + if with_backbone: + vision_backbone = _create_vision_backbone(compile_mode=compile_mode) + backbone = SAM3VLBackbone(scalp=1, visual=vision_backbone, text=None) + # Create the Tracker module + model = Sam3TrackerPredictor( + image_size=1008, + num_maskmem=7, + backbone=backbone, + backbone_stride=14, + transformer=transformer, + maskmem_backbone=maskmem_backbone, + # SAM parameters + multimask_output_in_sam=True, + # Evaluation + forward_backbone_per_frame_for_eval=True, + trim_past_non_cond_mem_for_eval=False, + # Multimask + multimask_output_for_tracking=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + # Additional settings + always_start_from_first_ann_frame=False, + # Mask overlap + non_overlap_masks_for_mem_enc=False, + non_overlap_masks_for_output=False, + max_cond_frames_in_attn=4, + offload_output_to_cpu_for_eval=False, + # SAM decoder settings + sam_mask_decoder_extra_args={ + "dynamic_multimask_via_stability": True, + "dynamic_multimask_stability_delta": 0.05, + "dynamic_multimask_stability_thresh": 0.98, + }, + clear_non_cond_mem_around_input=True, + fill_hole_area=0, + use_memory_selection=apply_temporal_disambiguation, + ) + + return model + + +def _create_text_encoder(bpe_path: str) -> VETextEncoder: + """Create SAM3 text encoder. + + Text encoder act_ckpt hardcoded False: negative sampling changes + the text batch size per step, causing checkpoint recompute shape + mismatch. Text encoder memory is small so disabling has minimal impact. + All other modules (ViT, encoder, decoder, geometry) keep act_ckpt=True. + """ + tokenizer = SimpleTokenizer(bpe_path=bpe_path) + return VETextEncoder( + tokenizer=tokenizer, + d_model=256, + width=1024, + heads=16, + layers=24, + use_act_checkpoint=False, + ) + + +def _create_vision_backbone( + compile_mode=None, enable_inst_interactivity=True +) -> Sam3DualViTDetNeck: + """Create SAM3 visual backbone with ViT and neck.""" + # Position encoding + position_encoding = _create_position_encoding(precompute_resolution=1008) + # ViT backbone + vit_backbone: ViT = _create_vit_backbone(compile_mode=compile_mode) + vit_neck: Sam3DualViTDetNeck = _create_vit_neck( + position_encoding, + vit_backbone, + enable_inst_interactivity=enable_inst_interactivity, + ) + # Visual neck + return vit_neck + + +def _create_sam3_transformer( + has_presence_token: bool = True, compile_mode=None, use_fa3: bool = False +) -> TransformerWrapper: + """Create SAM3 transformer encoder and decoder.""" + encoder: TransformerEncoderFusion = _create_transformer_encoder(compile_mode=compile_mode, use_fa3=use_fa3) + decoder: TransformerDecoder = _create_transformer_decoder(compile_mode=compile_mode, use_fa3=use_fa3) + + return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256) + + +def _load_checkpoint(model, checkpoint_path): + """Load model checkpoint from file.""" + with g_pathmgr.open(checkpoint_path, "rb") as f: + ckpt = torch.load(f, map_location="cpu", weights_only=True) + if "model" in ckpt and isinstance(ckpt["model"], dict): + ckpt = ckpt["model"] + sam3_image_ckpt = {} + for k, v in ckpt.items(): + if "detector" not in k: + continue + new_k = k.replace("detector.", "") + # SAM3.1 renames sam2_convs -> interactive_convs; map back for DualViTDetNeck + new_k = new_k.replace("interactive_convs", "sam2_convs") + sam3_image_ckpt[new_k] = v + if model.inst_interactive_predictor is not None: + sam3_image_ckpt.update( + { + k.replace("tracker.", "inst_interactive_predictor.model."): v + for k, v in ckpt.items() + if "tracker" in k + } + ) + missing_keys, _ = model.load_state_dict(sam3_image_ckpt, strict=False) + if len(missing_keys) > 0: + print( + f"loaded {checkpoint_path} and found " + f"missing and/or unexpected keys:\n{missing_keys=}" + ) + + +def _setup_device_and_mode(model, device, eval_mode): + """Setup model device and evaluation mode.""" + if device == "cuda": + model = model.cuda() + if eval_mode: + model.eval() + return model + + +def build_sam3_image_model( + bpe_path=None, + device="cuda" if torch.cuda.is_available() else "cpu", + eval_mode=True, + checkpoint_path=None, + load_from_HF=True, + enable_segmentation=True, + enable_inst_interactivity=False, + compile=False, +): + """ + Build SAM3 image model + + Args: + bpe_path: Path to the BPE tokenizer vocabulary + device: Device to load the model on ('cuda' or 'cpu') + eval_mode: Whether to set the model to evaluation mode + checkpoint_path: Optional path to model checkpoint + enable_segmentation: Whether to enable segmentation head + enable_inst_interactivity: Whether to enable instance interactivity (SAM 1 task) + compile_mode: To enable compilation, set to "default" + + Returns: + A SAM3 image model + """ + if bpe_path is None: + bpe_path = pkg_resources.resource_filename( + "sam3", "assets/bpe_simple_vocab_16e6.txt.gz" + ) + + # Create visual components + compile_mode = "default" if compile else None + vision_encoder = _create_vision_backbone( + compile_mode=compile_mode, enable_inst_interactivity=enable_inst_interactivity + ) + + # Create text components + text_encoder = _create_text_encoder(bpe_path) + + # Create visual-language backbone + backbone = _create_vl_backbone(vision_encoder, text_encoder) + + # Create transformer components + # NOTE: Do NOT compile encoder/decoder - they have dynamic shapes that cause errors + # Only ViT backbone is compiled (via vision_encoder above) + transformer = _create_sam3_transformer(compile_mode=None) + + # Create dot product scoring + dot_prod_scoring = _create_dot_product_scoring() + + # Create segmentation head if enabled + # NOTE: Do NOT compile segmentation head - may have dynamic shapes + segmentation_head = ( + _create_segmentation_head(compile_mode=None) + if enable_segmentation + else None + ) + + # Create geometry encoder + input_geometry_encoder = _create_geometry_encoder() + if enable_inst_interactivity: + sam3_pvs_base = build_tracker(apply_temporal_disambiguation=False) + inst_predictor = SAM3InteractiveImagePredictor(sam3_pvs_base) + else: + inst_predictor = None + # Create the SAM3 model + model = _create_sam3_model( + backbone, + transformer, + input_geometry_encoder, + segmentation_head, + dot_prod_scoring, + inst_predictor, + eval_mode, + ) + if load_from_HF and checkpoint_path is None: + checkpoint_path = download_ckpt_from_hf(version="sam3") + # Load checkpoint if provided + if checkpoint_path is not None: + _load_checkpoint(model, checkpoint_path) + + # Setup device and mode + model = _setup_device_and_mode(model, device, eval_mode) + + return model + + +def download_ckpt_from_hf(version="sam3"): + """Download model checkpoint from HuggingFace Hub. + + Args: + version: "sam3" or "sam3.1" + """ + if version == "sam3.1": + repo_id = "facebook/sam3.1" + ckpt_name = "sam3.1_multiplex.pt" + cfg_name = "config.json" + else: + repo_id = "facebook/sam3" + ckpt_name = "sam3.pt" + cfg_name = "config.json" + _ = hf_hub_download(repo_id=repo_id, filename=cfg_name) + checkpoint_path = hf_hub_download(repo_id=repo_id, filename=ckpt_name) + return checkpoint_path + + +def build_sam3_video_model( + checkpoint_path: Optional[str] = None, + load_from_HF=True, + bpe_path: Optional[str] = None, + has_presence_token: bool = True, + geo_encoder_use_img_cross_attn: bool = True, + strict_state_dict_loading: bool = True, + apply_temporal_disambiguation: bool = True, + device="cuda" if torch.cuda.is_available() else "cpu", + compile=False, +) -> Sam3VideoInferenceWithInstanceInteractivity: + """ + Build SAM3 dense tracking model. + + Args: + checkpoint_path: Optional path to checkpoint file + bpe_path: Path to the BPE tokenizer file + + Returns: + Sam3VideoInferenceWithInstanceInteractivity: The instantiated dense tracking model + """ + if bpe_path is None: + bpe_path = pkg_resources.resource_filename( + "sam3", "assets/bpe_simple_vocab_16e6.txt.gz" + ) + + # Build Tracker module + tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation) + + # Build Detector components + visual_neck = _create_vision_backbone() + text_encoder = _create_text_encoder(bpe_path) + backbone = SAM3VLBackbone(scalp=1, visual=visual_neck, text=text_encoder) + transformer = _create_sam3_transformer(has_presence_token=has_presence_token) + segmentation_head: UniversalSegmentationHead = _create_segmentation_head() + input_geometry_encoder = _create_geometry_encoder() + + # Create main dot product scoring + main_dot_prod_mlp = MLP( + input_dim=256, + hidden_dim=2048, + output_dim=256, + num_layers=2, + dropout=0.1, + residual=True, + out_norm=nn.LayerNorm(256), + ) + main_dot_prod_scoring = DotProductScoring( + d_model=256, d_proj=256, prompt_mlp=main_dot_prod_mlp + ) + + # Build Detector module + detector = Sam3ImageOnVideoMultiGPU( + num_feature_levels=1, + backbone=backbone, + transformer=transformer, + segmentation_head=segmentation_head, + semantic_segmentation_head=None, + input_geometry_encoder=input_geometry_encoder, + use_early_fusion=True, + use_dot_prod_scoring=True, + dot_prod_scoring=main_dot_prod_scoring, + supervise_joint_box_scores=has_presence_token, + ) + + # Build the main SAM3 video model + if apply_temporal_disambiguation: + model = Sam3VideoInferenceWithInstanceInteractivity( + detector=detector, + tracker=tracker, + score_threshold_detection=0.5, + assoc_iou_thresh=0.1, + det_nms_thresh=0.1, + new_det_thresh=0.7, + hotstart_delay=15, + hotstart_unmatch_thresh=8, + hotstart_dup_thresh=8, + suppress_unmatched_only_within_hotstart=True, + min_trk_keep_alive=-1, + max_trk_keep_alive=30, + init_trk_keep_alive=30, + suppress_overlapping_based_on_recent_occlusion_threshold=0.7, + suppress_det_close_to_boundary=False, + fill_hole_area=16, + recondition_every_nth_frame=16, + masklet_confirmation_enable=False, + decrease_trk_keep_alive_for_empty_masklets=False, + image_size=1008, + image_mean=(0.5, 0.5, 0.5), + image_std=(0.5, 0.5, 0.5), + compile_model=compile, + ) + else: + # a version without any heuristics for ablation studies + model = Sam3VideoInferenceWithInstanceInteractivity( + detector=detector, + tracker=tracker, + score_threshold_detection=0.5, + assoc_iou_thresh=0.1, + det_nms_thresh=0.1, + new_det_thresh=0.7, + hotstart_delay=0, + hotstart_unmatch_thresh=0, + hotstart_dup_thresh=0, + suppress_unmatched_only_within_hotstart=True, + min_trk_keep_alive=-1, + max_trk_keep_alive=30, + init_trk_keep_alive=30, + suppress_overlapping_based_on_recent_occlusion_threshold=0.7, + suppress_det_close_to_boundary=False, + fill_hole_area=16, + recondition_every_nth_frame=0, + masklet_confirmation_enable=False, + decrease_trk_keep_alive_for_empty_masklets=False, + image_size=1008, + image_mean=(0.5, 0.5, 0.5), + image_std=(0.5, 0.5, 0.5), + compile_model=compile, + ) + + # Load checkpoint if provided + if load_from_HF and checkpoint_path is None: + checkpoint_path = download_ckpt_from_hf(version="sam3") + if checkpoint_path is not None: + with g_pathmgr.open(checkpoint_path, "rb") as f: + ckpt = torch.load(f, map_location="cpu", weights_only=True) + if "model" in ckpt and isinstance(ckpt["model"], dict): + ckpt = ckpt["model"] + + missing_keys, unexpected_keys = model.load_state_dict( + ckpt, strict=strict_state_dict_loading + ) + if missing_keys: + print(f"Missing keys: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys: {unexpected_keys}") + + model.to(device=device) + return model + + +def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs): + return Sam3VideoPredictorMultiGPU( + *model_args, gpus_to_use=gpus_to_use, **model_kwargs + ) + + +def _create_multiplex_maskmem_backbone(multiplex_count=16): + """Create the multiplex memory encoder with per-object mask channels.""" + position_encoding = PositionEmbeddingSine( + num_pos_feats=256, + normalize=True, + scale=None, + temperature=10000, + precompute_resolution=1008, + ) + + mask_downsampler = SimpleMaskDownSampler( + kernel_size=3, + stride=2, + padding=1, + interpol_size=[1152, 1152], + multiplex_count=multiplex_count, + starting_out_chan=4, + input_channel_multiplier=2, + ) + + cx_block_layer = CXBlock( + dim=256, + kernel_size=7, + padding=3, + layer_scale_init_value=1.0e-06, + use_dwconv=True, + ) + + fuser = SimpleFuser(layer=cx_block_layer, num_layers=2) + + maskmem_backbone = SimpleMaskEncoder( + out_dim=256, + position_encoding=position_encoding, + mask_downsampler=mask_downsampler, + fuser=fuser, + ) + + return maskmem_backbone + + +def _create_multiplex_transformer(use_fa3=False, use_rope_real=False): + """Create the decoupled transformer for multiplex memory attention.""" + self_attention_rope = SimpleRoPEAttention( + d_model=256, + num_heads=8, + dropout_p=0.1, + rope_theta=10000.0, + feat_sizes=[72, 72], + use_fa3=use_fa3, + use_rope_real=use_rope_real, + ) + + cross_attention_rope = SimpleRoPEAttention( + d_model=256, + num_heads=8, + dropout_p=0.1, + rope_theta=10000.0, + feat_sizes=[72, 72], + rope_k_repeat=True, + use_fa3=use_fa3, + use_rope_real=use_rope_real, + ) + + encoder_layer = DecoupledTransformerDecoderLayerv2( + activation="gelu", + d_model=256, + num_heads=8, + dropout=0.1, + dim_feedforward=2048, + pos_enc_at_attn=False, + pre_norm=True, + pos_enc_at_cross_attn_keys=True, + pos_enc_at_cross_attn_queries=False, + self_attention_rope=self_attention_rope, + cross_attention_rope=cross_attention_rope, + ) + + encoder = TransformerEncoderDecoupledCrossAttention( + d_model=256, + frozen=False, + pos_enc_at_input=True, + use_image_in_output=False, + layer=encoder_layer, + num_layers=4, + use_act_checkpoint=False, + batch_first=True, + ) + + transformer = TransformerWrapper( + encoder=encoder, + decoder=None, + d_model=256, + ) + + return transformer + + +def _create_multiplex_tri_backbone( + compile_mode=None, use_fa3=False, use_rope_real=False +): + """Create the TriHead vision backbone for multiplex model.""" + position_encoding = _create_position_encoding(precompute_resolution=1008) + vit_backbone = _create_vit_backbone( + compile_mode=compile_mode, use_fa3=use_fa3, use_rope_real=use_rope_real + ) + tri_neck = Sam3TriViTDetNeck( + trunk=vit_backbone, + position_encoding=position_encoding, + d_model=256, + scale_factors=[4.0, 2.0, 1.0], + ) + return tri_neck + + +def build_sam3_multiplex_video_model( + checkpoint_path: Optional[str] = None, + load_from_HF=True, + multiplex_count: int = 16, + use_fa3: bool = False, + use_rope_real: bool = False, + strict_state_dict_loading: bool = True, + device="cuda" if torch.cuda.is_available() else "cpu", + compile=False, +): + """ + Build SAM3 multiplex video tracking model. + + Args: + checkpoint_path: Optional path to checkpoint file + multiplex_count: Number of objects per multiplex bucket + use_fa3: Whether to use FlashAttention 3 + use_rope_real: Whether to use real-valued RoPE (for compile compat) + strict_state_dict_loading: Whether to use strict state dict loading + device: Device to place model on + compile: Whether to compile model components + + Returns: + VideoTrackingDynamicMultiplex: The instantiated multiplex tracking model + """ + # Build multiplex-specific components + maskmem_backbone = _create_multiplex_maskmem_backbone( + multiplex_count=multiplex_count + ) + transformer = _create_multiplex_transformer( + use_fa3=use_fa3, use_rope_real=use_rope_real + ) + tri_neck = _create_multiplex_tri_backbone( + compile_mode="max-autotune" if compile else None + ) + backbone = TriHeadVisionOnly( + visual=tri_neck, + n_features=256, + scalp=0, + ) + multiplex_controller = MultiplexController( + multiplex_count=multiplex_count, + eval_multiplex_count=multiplex_count, + ) + + # Build the multiplex model (use demo class for init_state and other demo methods) + from sam3.model.video_tracking_multiplex_demo import Sam3VideoTrackingMultiplexDemo + + model = Sam3VideoTrackingMultiplexDemo( + backbone=backbone, + transformer=transformer, + maskmem_backbone=maskmem_backbone, + multiplex_controller=multiplex_controller, + image_size=1008, + backbone_stride=14, + num_maskmem=7, + # Multiplex-specific settings + use_high_res_features_in_sam=True, + use_obj_ptrs_in_encoder=True, + max_obj_ptrs_in_encoder=16, + add_tpos_enc_to_obj_ptrs=True, + proj_tpos_enc_in_obj_ptrs=True, + use_mlp_for_obj_ptr_proj=True, + pred_obj_scores=True, + pred_obj_scores_mlp=True, + fixed_no_obj_ptr=True, + use_no_obj_ptr=True, + use_linear_no_obj_ptr=True, + no_obj_embed_spatial=True, + sincos_tpos_enc=True, + # Multimask settings + multimask_output_in_sam=True, + multimask_output_for_tracking=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + use_multimask_token_for_obj_ptr=True, + num_multimask_outputs=3, + # Memory encoder settings + apply_sigmoid_to_mask_logits_for_mem_enc=True, + sigmoid_scale_for_mem_enc=2.0, + sigmoid_bias_for_mem_enc=-1.0, + non_overlap_masks_for_mem_enc=False, + # Suppression/conditional embeddings + add_output_suppression_embeddings=True, + add_object_conditional_embeddings=False, + condition_as_mask_input=True, + condition_as_mask_input_fg=1.0, + condition_as_mask_input_bg=0.0, + # Memory settings + use_maskmem_tpos_v2=True, + save_image_features=True, + randomness_fix=True, + # Interaction settings + use_mask_input_as_output_without_sam=True, + directly_add_no_mem_embed=True, + iou_prediction_use_sigmoid=False, + forward_backbone_per_frame_for_eval=True, + offload_output_to_cpu_for_eval=False, + trim_past_non_cond_mem_for_eval=False, + max_cond_frames_in_attn=4, + # Dynamic multiplex settings + is_dynamic_model=True, + # SAM mask decoder extra args + sam_mask_decoder_extra_args={ + "dynamic_multimask_via_stability": True, + "dynamic_multimask_stability_delta": 0.05, + "dynamic_multimask_stability_thresh": 0.98, + }, + compile_all_components=compile, + use_memory_selection=False, + ) + + # Load checkpoint if provided + if load_from_HF and checkpoint_path is None: + checkpoint_path = download_ckpt_from_hf(version="sam3.1") + if checkpoint_path is not None: + with g_pathmgr.open(checkpoint_path, "rb") as f: + ckpt = torch.load(f, map_location="cpu", weights_only=True) + if "model" in ckpt and isinstance(ckpt["model"], dict): + ckpt = ckpt["model"] + + missing_keys, unexpected_keys = model.load_state_dict( + ckpt, strict=strict_state_dict_loading + ) + if missing_keys: + print(f"Missing keys: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys: {unexpected_keys}") + + model.to(device=device) + return model + + +def build_sam3_multiplex_video_predictor( + checkpoint_path: Optional[str] = None, + bpe_path: Optional[str] = None, + max_num_objects: int = 16, + multiplex_count: int = 16, + use_fa3: bool = True, + use_rope_real: bool = True, + compile: bool = False, + warm_up: bool = False, + session_expiration_sec: int = 1200, + default_output_prob_thresh: float = 0.5, + async_loading_frames: bool = True, +): + """ + Build a fully-initialized Sam3MultiplexVideoPredictor. + + This is the recommended entry point for SAM 3.1 multiplex video tracking. + It builds the full model stack (tracker + detector + demo model), loads + the checkpoint, and wraps everything in Sam3MultiplexVideoPredictor with + handle_request / handle_stream_request API. + + Args: + checkpoint_path: Path to the merged multiplex checkpoint + bpe_path: Path to the BPE tokenizer vocabulary + max_num_objects: Maximum number of tracked objects + multiplex_count: Number of objects per multiplex bucket + use_fa3: Whether to use FlashAttention 3 + use_rope_real: Whether to use real-valued RoPE (for compile compat) + compile: Whether to enable torch.compile on model components + warm_up: Whether to run warm-up compilation (requires compile=True) + session_expiration_sec: Session expiration timeout in seconds + default_output_prob_thresh: Default probability threshold for output masks + async_loading_frames: Whether to load frames asynchronously + + Returns: + Sam3MultiplexVideoPredictor: The fully-initialized predictor + """ + if bpe_path is None: + bpe_path = pkg_resources.resource_filename( + "sam3", "assets/bpe_simple_vocab_16e6.txt.gz" + ) + + from sam3.model.sam3_multiplex_base import Sam3MultiplexPredictorWrapper + from sam3.model.sam3_multiplex_detector import Sam3MultiplexDetector + from sam3.model.sam3_multiplex_tracking import ( + Sam3MultiplexTrackingWithInteractivity, + ) + from sam3.model.sam3_multiplex_video_predictor import Sam3MultiplexVideoPredictor + + # Build tracker + tracker_model = build_sam3_multiplex_video_model( + checkpoint_path=checkpoint_path, + load_from_HF=False, + multiplex_count=multiplex_count, + use_fa3=use_fa3, + use_rope_real=use_rope_real, + compile=False, + strict_state_dict_loading=False, + ) + del tracker_model.backbone + tracker_model.backbone = None + + sam2_predictor = Sam3MultiplexPredictorWrapper( + model=tracker_model, + per_obj_inference=False, + fill_hole_area=0, + is_multiplex=True, + is_multiplex_dynamic=True, + ) + + # Build detector + tri_neck = _create_multiplex_tri_backbone( + compile_mode=None, use_fa3=use_fa3, use_rope_real=use_rope_real + ) + text_encoder = _create_text_encoder(bpe_path) + backbone = SAM3VLBackboneTri(scalp=0, visual=tri_neck, text=text_encoder) + transformer = _create_sam3_transformer(use_fa3=use_fa3) + segmentation_head = _create_segmentation_head(use_fa3=use_fa3) + geometry_encoder = _create_geometry_encoder() + dot_prod_scoring = _create_dot_product_scoring() + + detector = Sam3MultiplexDetector( + num_feature_levels=1, + backbone=backbone, + transformer=transformer, + segmentation_head=segmentation_head, + semantic_segmentation_head=None, + input_geometry_encoder=geometry_encoder, + use_early_fusion=True, + use_dot_prod_scoring=True, + dot_prod_scoring=dot_prod_scoring, + supervise_joint_box_scores=True, + is_multiplex=True, + ) + + # Assemble demo model + demo_model = Sam3MultiplexTrackingWithInteractivity( + tracker=sam2_predictor, + detector=detector, + score_threshold_detection=0.4, + det_nms_thresh=0.1, + det_nms_use_iom=True, + assoc_iou_thresh=0.1, + new_det_thresh=0.65, + hotstart_delay=15, + hotstart_unmatch_thresh=8, + hotstart_dup_thresh=8, + suppress_unmatched_only_within_hotstart=False, + suppress_overlapping_based_on_recent_occlusion_threshold=0.7, + suppress_det_close_to_boundary=True, + fill_hole_area=0, # OV effectively 0 (Sam3MultiplexTrackerPredictor Hydra override clobbers yaml's 16) + recondition_every_nth_frame=16, + use_iom_recondition=True, + iom_thresh_recondition=0.5, + masklet_confirmation_enable=True, + reconstruction_bbox_iou_thresh=-1, + reconstruction_bbox_det_score=0.8, + max_num_objects=max_num_objects, + postprocess_batch_size=16, + use_batched_grounding=True, + batched_grounding_batch_size=16, + max_num_kboxes=0, + sprinkle_removal_area=0, + is_multiplex=True, + image_size=1008, + image_mean=(0.5, 0.5, 0.5), + image_std=(0.5, 0.5, 0.5), + compile_model=compile, + ) + + # Load checkpoint (auto-download from HuggingFace if not provided) + if checkpoint_path is None: + checkpoint_path = download_ckpt_from_hf(version="sam3.1") + if checkpoint_path is not None: + ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + if "model" in ckpt and isinstance(ckpt["model"], dict): + ckpt = ckpt["model"] + # Remap checkpoint keys if needed (internal naming -> OSS naming) + # HF checkpoints are already remapped; local checkpoints may use old naming + needs_remap = any( + k.startswith("sam3_model.") or k.startswith("sam2_predictor.") for k in ckpt + ) + if needs_remap: + remapped_ckpt = {} + for k, v in ckpt.items(): + new_k = k + if k.startswith("sam3_model."): + new_k = "detector." + k[len("sam3_model.") :] + elif k.startswith("sam2_predictor."): + new_k = "tracker." + k[len("sam2_predictor.") :] + remapped_ckpt[new_k] = v + ckpt = remapped_ckpt + missing_keys, unexpected_keys = demo_model.load_state_dict(ckpt, strict=False) + if missing_keys: + print(f"Missing keys ({len(missing_keys)}): {missing_keys[:10]}...") + if unexpected_keys: + print( + f"Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:10]}..." + ) + + demo_model.cuda().eval() + + # Wrap in predictor + predictor = Sam3MultiplexVideoPredictor( + model=demo_model, + session_expiration_sec=session_expiration_sec, + default_output_prob_thresh=default_output_prob_thresh, + async_loading_frames=async_loading_frames, + warm_up=warm_up, + ) + return predictor + + +def build_sam3_predictor( + checkpoint_path: Optional[str] = None, + bpe_path: Optional[str] = None, + version: str = "sam3.1", # "sam3" or "sam3.1" + compile: bool = False, + warm_up: bool = False, + # SAM 3.1 specific + max_num_objects: int = 16, + multiplex_count: int = 16, + # Common + use_fa3: bool = True, + use_rope_real: bool = True, + async_loading_frames: bool = True, + **kwargs, +): + """ + Build a SAM3 video predictor. + + Args: + checkpoint_path: Path to model checkpoint + bpe_path: Path to BPE tokenizer vocabulary + version: Model version - "sam3" for base or "sam3.1" for multiplex + compile: Enable torch.compile for ~2x speedup (SAM 3.1 only currently) + warm_up: Run warm-up compilation passes + max_num_objects: Maximum tracked objects (SAM 3.1 only) + multiplex_count: Objects per multiplex bucket (SAM 3.1 only) + use_fa3: Use Flash Attention 3 + use_rope_real: Use real-valued RoPE + async_loading_frames: Load video frames asynchronously + **kwargs: Additional arguments passed to the underlying builder + + Returns: + A predictor with handle_request() and handle_stream_request() API. + Both versions support: start_session, add_prompt, propagate_in_video, + remove_object, reset_session, close_session. + + Example: + # SAM 3.1 (auto-downloads from HuggingFace): + predictor = build_sam3_predictor(version="sam3.1", compile=True) + + # SAM 3 (auto-downloads from HuggingFace): + predictor = build_sam3_predictor(version="sam3") + + # Or with a local checkpoint: + predictor = build_sam3_predictor(checkpoint_path="path/to/ckpt.pt", version="sam3.1") + + # Both use the same API: + response = predictor.handle_request({"type": "start_session", "resource_path": video_dir}) + session_id = response["session_id"] + predictor.handle_request({"type": "add_prompt", "session_id": session_id, "frame_index": 0, "text": "person"}) + for out in predictor.handle_stream_request({"type": "propagate_in_video", "session_id": session_id}): + masks = out["out_binary_masks"] + """ + if version == "sam3.1": + return build_sam3_multiplex_video_predictor( + checkpoint_path=checkpoint_path, + bpe_path=bpe_path, + max_num_objects=max_num_objects, + multiplex_count=multiplex_count, + use_fa3=use_fa3, + use_rope_real=use_rope_real, + compile=compile, + warm_up=warm_up, + async_loading_frames=async_loading_frames, + **kwargs, + ) + elif version == "sam3": + return build_sam3_video_predictor( + checkpoint_path=checkpoint_path, + bpe_path=bpe_path, + compile=compile, + async_loading_frames=async_loading_frames, + **kwargs, + ) + else: + raise ValueError(f"Unknown version: {version!r}. Use 'sam3' or 'sam3.1'.") diff --git a/third_party/sam3/sam3/perflib/__init__.py b/third_party/sam3/sam3/perflib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9aec2af3107e6e773fb7ec2d5e88e28aa8175afe --- /dev/null +++ b/third_party/sam3/sam3/perflib/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import os + +is_enabled = False +if os.getenv("USE_PERFLIB", "1") == "1": + # print("Enabled the use of perflib.\n", end="") + is_enabled = True diff --git a/third_party/sam3/sam3/perflib/__pycache__/__init__.cpython-311.pyc b/third_party/sam3/sam3/perflib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aafb1af3a8570e357bd188670af2a82d4b201918 Binary files /dev/null and b/third_party/sam3/sam3/perflib/__pycache__/__init__.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/perflib/__pycache__/compile.cpython-311.pyc b/third_party/sam3/sam3/perflib/__pycache__/compile.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fae8be464bb2fefa1e76bbced58de24b239a341 Binary files /dev/null and b/third_party/sam3/sam3/perflib/__pycache__/compile.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/perflib/__pycache__/fused.cpython-311.pyc b/third_party/sam3/sam3/perflib/__pycache__/fused.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7654ff56a8441088e8a9d787f08f6623fa245dcd Binary files /dev/null and b/third_party/sam3/sam3/perflib/__pycache__/fused.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/perflib/__pycache__/masks_ops.cpython-311.pyc b/third_party/sam3/sam3/perflib/__pycache__/masks_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdf93e4ebe919565e0d41b7db19aed5a6f19cf95 Binary files /dev/null and b/third_party/sam3/sam3/perflib/__pycache__/masks_ops.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/perflib/__pycache__/nms.cpython-311.pyc b/third_party/sam3/sam3/perflib/__pycache__/nms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db63c9ae6c9a78066b2d8d96ab62e6373588f519 Binary files /dev/null and b/third_party/sam3/sam3/perflib/__pycache__/nms.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/perflib/associate_det_trk.py b/third_party/sam3/sam3/perflib/associate_det_trk.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0d29d0bbc71bb24e9afaf63da3c40976ecb0a5 --- /dev/null +++ b/third_party/sam3/sam3/perflib/associate_det_trk.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from collections import defaultdict + +import torch +import torch.nn.functional as F +from sam3.perflib.masks_ops import mask_iou +from scipy.optimize import linear_sum_assignment + + +def associate_det_trk( + det_masks, + track_masks, + iou_threshold=0.5, + iou_threshold_trk=0.5, + det_scores=None, + new_det_thresh=0.0, +): + """ + Optimized implementation of detection <-> track association that minimizes DtoH syncs. + + Args: + det_masks: (N, H, W) tensor of predicted masks + track_masks: (M, H, W) tensor of track masks + + Returns: + new_det_indices: list of indices in det_masks considered 'new' + unmatched_trk_indices: list of indices in track_masks considered 'unmatched' + """ + with torch.autograd.profiler.record_function("perflib: associate_det_trk"): + assert isinstance(det_masks, torch.Tensor), "det_masks should be a tensor" + assert isinstance(track_masks, torch.Tensor), "track_masks should be a tensor" + if det_masks.size(0) == 0 or track_masks.size(0) == 0: + return list(range(det_masks.size(0))), [], {}, {} # all detections are new + + if list(det_masks.shape[-2:]) != list(track_masks.shape[-2:]): + # resize to the smaller size to save GPU memory + if torch.numel(det_masks[-2:]) < torch.numel(track_masks[-2:]): + track_masks = ( + F.interpolate( + track_masks.unsqueeze(1).float(), + size=det_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ).squeeze(1) + > 0 + ) + else: + # resize detections to track size + det_masks = ( + F.interpolate( + det_masks.unsqueeze(1).float(), + size=track_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ).squeeze(1) + > 0 + ) + + det_masks = det_masks > 0 + track_masks = track_masks > 0 + + iou = mask_iou(det_masks, track_masks) # (N, M) + igeit = iou >= iou_threshold + igeit_any_dim_1 = igeit.any(dim=1) + igeit_trk = iou >= iou_threshold_trk + + iou_list = iou.cpu().numpy().tolist() + igeit_list = igeit.cpu().numpy().tolist() + igeit_any_dim_1_list = igeit_any_dim_1.cpu().numpy().tolist() + igeit_trk_list = igeit_trk.cpu().numpy().tolist() + + det_scores_list = ( + det_scores + if det_scores is None + else det_scores.cpu().float().numpy().tolist() + ) + + # Hungarian matching for tracks (one-to-one: each track matches at most one detection) + # For detections: allow many tracks to match to the same detection (many-to-one) + + # If either is empty, return all detections as new + if det_masks.size(0) == 0 or track_masks.size(0) == 0: + return list(range(det_masks.size(0))), [], {} + + # Hungarian matching: maximize IoU for tracks + cost_matrix = 1 - iou.cpu().numpy() # Hungarian solves for minimum cost + row_ind, col_ind = linear_sum_assignment(cost_matrix) + + def branchy_hungarian_better_uses_the_cpu( + cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks + ): + matched_trk = set() + matched_det = set() + matched_det_scores = {} # track index -> [det_score, det_score * iou] det score of matched detection mask + for d, t in zip(row_ind, col_ind): + matched_det_scores[t] = [ + det_scores_list[d], + det_scores_list[d] * iou_list[d][t], + ] + if igeit_trk_list[d][t]: + matched_trk.add(t) + matched_det.add(d) + + # Tracks not matched by Hungarian assignment above threshold are unmatched + unmatched_trk_indices = [ + t for t in range(track_masks.size(0)) if t not in matched_trk + ] + + # For detections: allow many tracks to match to the same detection (many-to-one) + # So, a detection is 'new' if it does not match any track above threshold + assert track_masks.size(0) == igeit.size( + 1 + ) # Needed for loop optimizaiton below + new_det_indices = [] + for d in range(det_masks.size(0)): + if not igeit_any_dim_1_list[d]: + if det_scores is not None and det_scores[d] >= new_det_thresh: + new_det_indices.append(d) + + # for each detection, which tracks it matched to (above threshold) + det_to_matched_trk = defaultdict(list) + for d in range(det_masks.size(0)): + for t in range(track_masks.size(0)): + if igeit_list[d][t]: + det_to_matched_trk[d].append(t) + + return ( + new_det_indices, + unmatched_trk_indices, + det_to_matched_trk, + matched_det_scores, + ) + + return (branchy_hungarian_better_uses_the_cpu)( + cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks + ) diff --git a/third_party/sam3/sam3/perflib/compile.py b/third_party/sam3/sam3/perflib/compile.py new file mode 100644 index 0000000000000000000000000000000000000000..8471406f0abf720ed5834db3c3ff4b7b1c9847d6 --- /dev/null +++ b/third_party/sam3/sam3/perflib/compile.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from functools import wraps + +import torch +from sam3.model.data_misc import BatchedDatapoint, NestedTensor +from torch.utils._pytree import tree_map_only + + +def recursive_fn_factory(fn): + def recursive_fn(b): + if isinstance(b, dict): + return {k: recursive_fn(b[k]) for k in b} + if isinstance(b, list): + return [recursive_fn(t) for t in b] + if isinstance(b, tuple): + return tuple(recursive_fn(t) for t in b) + if isinstance(b, NestedTensor): + tensors = fn(b.tensors) + if b.mask is None: + mask = None + else: + mask = fn(b.mask) + return NestedTensor(tensors=tensors, mask=mask) + if isinstance(b, torch.Tensor): + return fn(b) + if b is None: + return b + trivial_types = [bool, int, float] + for t in trivial_types: + if isinstance(b, t): + return b + raise TypeError(f"Unexpected type {type(b)}") + + return recursive_fn + + +recursive_contiguous = recursive_fn_factory(lambda x: x.contiguous()) +recursive_clone = recursive_fn_factory(torch.clone) + + +def clone_output_wrapper(f): + """ + Clone the CUDA output tensors of a function to avoid in-place operations. + Uses tree_map_only (C-optimized pytree traversal) matching onevision's pattern. + Requires NestedTensor to be registered as a pytree node (see data_misc.py). + """ + + @wraps(f) + def wrapped(*args, **kwargs): + outputs = f(*args, **kwargs) + return tree_map_only( + torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs + ) + + return wrapped + + +def compile_wrapper( + fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None +): + """Compile with recursive_contiguous on inputs and recursive_clone on outputs. + Used for SAM2 tracker components that need contiguous inputs for CUDA graphs.""" + compiled_fn = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic) + + def compiled_fn_wrapper(*args, **kwargs): + with torch.autograd.profiler.record_function( + f"compiled {fn}" if name is None else name + ): + CUDAGRAPH_MODES = ["max-autotune", "reduce-overhead"] + args = recursive_contiguous(args) + kwargs = recursive_contiguous(kwargs) + result = compiled_fn(*args, **kwargs) + if mode in CUDAGRAPH_MODES: + result = recursive_clone(result) + return result + + return compiled_fn_wrapper + + +def shape_logging_wrapper(fn, keep_kwargs, enable_logging=False): + """ + Wraps a function and prints the shapes of all tensor inputs. + Only prints when a new combination of shapes is seen. + """ + seen_shapes = set() + + def get_shape(obj): + if isinstance(obj, torch.Tensor): + return obj.shape + elif isinstance(obj, (list, tuple)): + if len(obj) > 1: + return tuple(get_shape(x) for x in obj) + return get_shape(obj[0]) + elif isinstance(obj, dict): + return tuple(sorted((k, get_shape(v)) for k, v in obj.items())) + else: + return type(obj).__name__ + + def wrapper(*args, **kwargs): + shapes = tuple(get_shape(arg) for arg in args) + tuple( + (k, get_shape(v)) + for k, v in kwargs.items() + if isinstance(v, (torch.Tensor, list)) + and (len(keep_kwargs) > 0 and k in keep_kwargs) + ) + if shapes not in seen_shapes: + seen_shapes.add(shapes) + if enable_logging: + print(f"[ShapeLogger] New input shapes for {fn.__qualname__}: {shapes}") + return fn(*args, **kwargs) + + wrapper.enable_logging = enable_logging + + def set_logging(enabled=False): + nonlocal enable_logging + enable_logging = enabled + wrapper.enable_logging = enable_logging + + wrapper.set_logging = set_logging + return wrapper diff --git a/third_party/sam3/sam3/perflib/connected_components.py b/third_party/sam3/sam3/perflib/connected_components.py new file mode 100644 index 0000000000000000000000000000000000000000..0be67e5a3c6520274970dcc5498322b22b8c5632 --- /dev/null +++ b/third_party/sam3/sam3/perflib/connected_components.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import logging + +import torch + +try: + from cc_torch import get_connected_components + + HAS_CC_TORCH = True +except ImportError: + logging.debug( + "cc_torch not found. Consider installing for better performance. Command line:" + " pip install git+https://github.com/ronghanghu/cc_torch.git" + ) + HAS_CC_TORCH = False + + +def connected_components_cpu_single(values: torch.Tensor): + assert values.dim() == 2 + from skimage.measure import label + + labels, num = label(values.cpu().numpy(), return_num=True) + labels = torch.from_numpy(labels) + counts = torch.zeros_like(labels) + for i in range(1, num + 1): + cur_mask = labels == i + cur_count = cur_mask.sum() + counts[cur_mask] = cur_count + return labels, counts + + +def connected_components_cpu(input_tensor: torch.Tensor): + out_shape = input_tensor.shape + if input_tensor.dim() == 4 and input_tensor.shape[1] == 1: + input_tensor = input_tensor.squeeze(1) + else: + assert ( + input_tensor.dim() == 3 + ), "Input tensor must be (B, H, W) or (B, 1, H, W)." + + batch_size = input_tensor.shape[0] + labels_list = [] + counts_list = [] + for b in range(batch_size): + labels, counts = connected_components_cpu_single(input_tensor[b]) + labels_list.append(labels) + counts_list.append(counts) + labels_tensor = torch.stack(labels_list, dim=0).to(input_tensor.device) + counts_tensor = torch.stack(counts_list, dim=0).to(input_tensor.device) + return labels_tensor.view(out_shape), counts_tensor.view(out_shape) + + +def connected_components(input_tensor: torch.Tensor): + """ + Computes connected components labeling on a batch of 2D tensors, using the best available backend. + + Args: + input_tensor (torch.Tensor): A BxHxW integer tensor or Bx1xHxW. Non-zero values are considered foreground. Bool tensor also accepted + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Both tensors have the same shape as input_tensor. + - A tensor with dense labels. Background is 0. + - A tensor with the size of the connected component for each pixel. + """ + if input_tensor.dim() == 3: + input_tensor = input_tensor.unsqueeze(1) + + assert ( + input_tensor.dim() == 4 and input_tensor.shape[1] == 1 + ), "Input tensor must be (B, H, W) or (B, 1, H, W)." + + if input_tensor.is_cuda: + if HAS_CC_TORCH: + return get_connected_components(input_tensor.to(torch.uint8)) + else: + # triton fallback + from sam3.perflib.triton.connected_components import ( + connected_components_triton, + ) + + return connected_components_triton(input_tensor) + + # CPU fallback + return connected_components_cpu(input_tensor) diff --git a/third_party/sam3/sam3/perflib/fa3.py b/third_party/sam3/sam3/perflib/fa3.py new file mode 100644 index 0000000000000000000000000000000000000000..af226bcf81e38efbd847395e30fecc56f42b2b44 --- /dev/null +++ b/third_party/sam3/sam3/perflib/fa3.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import torch + + +@torch.library.custom_op("flash::flash_attn_func", mutates_args=()) +def flash_attn_func_op( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> torch.Tensor: + from flash_attn_interface import flash_attn_func as fa3 + + return fa3(q, k, v) + + +def flash_attn_func(q, k, v): + dtype = torch.float8_e4m3fn + return flash_attn_func_op(q.to(dtype), k.to(dtype), v.to(dtype)).to(q.dtype) + + +@flash_attn_func_op.register_fake +def _(q, k, v, **kwargs): + # two outputs: + # 1. output: (batch, seq_len, num_heads, head_dim) + # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32 + # output needs to be bfloat16, not float8! + meta_q = torch.empty_like(q, dtype=torch.bfloat16).contiguous() + return meta_q diff --git a/third_party/sam3/sam3/perflib/fused.py b/third_party/sam3/sam3/perflib/fused.py new file mode 100644 index 0000000000000000000000000000000000000000..6800cca64dd910c2913c767f719a407dedb55488 --- /dev/null +++ b/third_party/sam3/sam3/perflib/fused.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import torch + +addmm_act_op = torch.ops.aten._addmm_activation + + +def addmm_act(activation, linear, mat1): + if torch.is_grad_enabled(): + raise ValueError("Expected grad to be disabled.") + self = linear.bias.detach() + mat2 = linear.weight.detach() + self = self.to(torch.bfloat16) + mat1 = mat1.to(torch.bfloat16) + mat2 = mat2.to(torch.bfloat16) + mat1_flat = mat1.view(-1, mat1.shape[-1]) + if activation in [torch.nn.functional.relu, torch.nn.ReLU]: + y = addmm_act_op(self, mat1_flat, mat2.t(), beta=1, alpha=1, use_gelu=False) + return y.view(mat1.shape[:-1] + (y.shape[-1],)) + if activation in [torch.nn.functional.gelu, torch.nn.GELU]: + y = addmm_act_op(self, mat1_flat, mat2.t(), beta=1, alpha=1, use_gelu=True) + return y.view(mat1.shape[:-1] + (y.shape[-1],)) + raise ValueError(f"Unexpected activation {activation}") diff --git a/third_party/sam3/sam3/perflib/iou.py b/third_party/sam3/sam3/perflib/iou.py new file mode 100644 index 0000000000000000000000000000000000000000..2b32f802ff674884e51e33c31642c510798c5e73 --- /dev/null +++ b/third_party/sam3/sam3/perflib/iou.py @@ -0,0 +1,38 @@ +import torch + + +def pairwise_iou(pred_masks, gt_masks, eps=1e-6): + N, H, W = pred_masks.shape + M = gt_masks.shape[0] + # Flatten and convert to float for matmul + pred_flat = pred_masks.reshape(N, -1).float() + gt_flat = gt_masks.reshape(M, -1).float() + # Intersection: (N, M) + intersection = torch.matmul(pred_flat, gt_flat.t()) + # Areas + area_pred = pred_flat.sum(dim=1, keepdim=True) # (N, 1) + area_gt = gt_flat.sum(dim=1, keepdim=True) # (M, 1) + # Union: (N, M) + union = area_pred + area_gt.t() - intersection + if eps is None: + iou = intersection / union.clamp(min=1) + else: + iou = intersection / (union + eps) + return iou # shape: (N, M) + + +def pairwise_iom(pred_masks, gt_masks, eps=1e-8): + N, H, W = pred_masks.shape + M = gt_masks.shape[0] + # Flatten and convert to float for matmul + pred_flat = pred_masks.reshape(N, -1).float() + gt_flat = gt_masks.reshape(M, -1).float() + # Intersection: (N, M) + intersection = torch.matmul(pred_flat, gt_flat.t()) + # Areas + area_pred = pred_flat.sum(dim=1, keepdim=True) # (N, 1) + area_gt = gt_flat.sum(dim=1, keepdim=True) # (M, 1) + # Union: (N, M) + min_area = torch.min(area_pred, area_gt) + iou = intersection / (min_area + eps) + return iou # shape: (N, M) diff --git a/third_party/sam3/sam3/perflib/masks_ops.py b/third_party/sam3/sam3/perflib/masks_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..806172de0084d561a2e1ecb0652c8fc503709d11 --- /dev/null +++ b/third_party/sam3/sam3/perflib/masks_ops.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import torch + + +def masks_to_boxes(masks: torch.Tensor, obj_ids: list[int]): + with torch.autograd.profiler.record_function("perflib: masks_to_boxes"): + # Sanity check based on callsite for replacement + assert masks.shape[0] == len(obj_ids) + assert masks.dim() == 3 + + # Based on torchvision masks_to_boxes + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device, dtype=torch.float) + + N, H, W = masks.shape + device = masks.device + y = torch.arange(H, device=device).view(1, H) + x = torch.arange(W, device=device).view(1, W) + + masks_with_obj = masks != 0 # N, H, W + masks_with_obj_x = masks_with_obj.amax( + dim=1 + ) # N, H (which columns have objects) + masks_with_obj_y = masks_with_obj.amax(dim=2) # N, W (which rows have objects) + masks_without_obj_x = ~masks_with_obj_x + masks_without_obj_y = ~masks_with_obj_y + + bounding_boxes_0 = torch.amin( + (masks_without_obj_x * W) + (masks_with_obj_x * x), dim=1 + ) + bounding_boxes_1 = torch.amin( + (masks_without_obj_y * H) + (masks_with_obj_y * y), dim=1 + ) + bounding_boxes_2 = torch.amax(masks_with_obj_x * x, dim=1) + bounding_boxes_3 = torch.amax(masks_with_obj_y * y, dim=1) + + bounding_boxes = torch.stack( + [bounding_boxes_0, bounding_boxes_1, bounding_boxes_2, bounding_boxes_3], + dim=1, + ).to(dtype=torch.float) + assert bounding_boxes.shape == (N, 4) + assert bounding_boxes.device == masks.device + assert bounding_boxes.dtype == torch.float + return bounding_boxes + + +def mask_iou(pred_masks: torch.Tensor, gt_masks: torch.Tensor) -> torch.Tensor: + """ + Compute the IoU (Intersection over Union) between predicted masks and ground truth masks. + Uses matmul-based vectorized intersection for Tensor Core acceleration. + + Args: + - pred_masks: (N, H, W) bool Tensor, containing binary predicted segmentation masks + - gt_masks: (M, H, W) bool Tensor, containing binary ground truth segmentation masks + Returns: + - ious: (N, M) float Tensor, containing IoUs for each pair of predicted and ground truth masks + """ + assert pred_masks.dtype == gt_masks.dtype == torch.bool + assert pred_masks.shape[1:] == gt_masks.shape[1:] + + # Matmul-based intersection (uses Tensor Cores via float mm) + m1_flat = pred_masks.flatten(1).float() + m2_flat = gt_masks.flatten(1).float() + intersection = torch.mm(m1_flat, m2_flat.t()) + + area1 = m1_flat.sum(dim=1) + area2 = m2_flat.sum(dim=1) + union = area1[:, None] + area2[None, :] - intersection + return intersection / union.clamp(min=1) diff --git a/third_party/sam3/sam3/perflib/nms.py b/third_party/sam3/sam3/perflib/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..acd9162768c71e6bf6470ec5db1a789093673859 --- /dev/null +++ b/third_party/sam3/sam3/perflib/nms.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging + +import numpy as np +import torch +from sam3.perflib.masks_ops import mask_iou + + +try: + from torch_generic_nms import generic_nms as generic_nms_cuda + + GENERIC_NMS_AVAILABLE = True +except ImportError: + logging.debug( + "Falling back to triton or CPU mask NMS implementation -- please install `torch_generic_nms` via\n\t" + 'pip uninstall -y torch_generic_nms; TORCH_CUDA_ARCH_LIST="8.0 9.0" pip install git+https://github.com/ronghanghu/torch_generic_nms' + ) + GENERIC_NMS_AVAILABLE = False + + +def nms_masks( + pred_probs: torch.Tensor, + pred_masks: torch.Tensor, + prob_threshold: float, + iou_threshold: float, +) -> torch.Tensor: + """ + Args: + - pred_probs: (num_det,) float Tensor, containing the score (probability) of each detection + - pred_masks: (num_det, H_mask, W_mask) float Tensor, containing the binary segmentation mask of each detection + - prob_threshold: float, score threshold to prefilter detections (NMS is performed on detections above threshold) + - iou_threshold: float, mask IoU threshold for NMS + + Returns: + - keep: (num_det,) bool Tensor, indicating whether each detection is kept after score thresholding + NMS + """ + # prefilter the detections with prob_threshold ("valid" are those above prob_threshold) + is_valid = pred_probs > prob_threshold # (num_det,) + probs = pred_probs[is_valid] # (num_valid,) + masks_binary = pred_masks[is_valid] > 0 # (num_valid, H_mask, W_mask) + if probs.numel() == 0: + return is_valid # no valid detection, return empty keep mask + + ious = mask_iou(masks_binary, masks_binary) # (num_valid, num_valid) + kept_inds = generic_nms(ious, probs, iou_threshold) + + # valid_inds are the indices among `probs` of valid detections before NMS (or -1 for invalid) + valid_inds = torch.where(is_valid, is_valid.cumsum(dim=0) - 1, -1) # (num_det,) + keep = torch.isin(valid_inds, kept_inds) # (num_det,) + return keep + + +def generic_nms( + ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5 +) -> torch.Tensor: + """A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix.""" + + assert ious.dim() == 2 and ious.size(0) == ious.size(1) + assert scores.dim() == 1 and scores.size(0) == ious.size(0) + + if ious.is_cuda: + if GENERIC_NMS_AVAILABLE: + return generic_nms_cuda(ious, scores, iou_threshold, use_iou_matrix=True) + else: + from sam3.perflib.triton.nms import nms_triton + + return nms_triton(ious, scores, iou_threshold) + + return generic_nms_cpu(ious, scores, iou_threshold) + + +def generic_nms_cpu( + ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5 +) -> torch.Tensor: + """ + A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix. (CPU implementation + based on https://github.com/jwyang/faster-rcnn.pytorch/blob/master/lib/model/nms/nms_cpu.py) + """ + ious_np = ious.float().detach().cpu().numpy() + scores_np = scores.float().detach().cpu().numpy() + order = scores_np.argsort()[::-1] + kept_inds = [] + while order.size > 0: + i = order.item(0) + kept_inds.append(i) + inds = np.where(ious_np[i, order[1:]] <= iou_threshold)[0] + order = order[inds + 1] + + return torch.tensor(kept_inds, dtype=torch.int64, device=scores.device) diff --git a/third_party/sam3/sam3/perflib/tests/assets/masks.tiff b/third_party/sam3/sam3/perflib/tests/assets/masks.tiff new file mode 100644 index 0000000000000000000000000000000000000000..5d05021c65ff41d8ff6cbd8ffec043c261bb5341 --- /dev/null +++ b/third_party/sam3/sam3/perflib/tests/assets/masks.tiff @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e470fe2921b69eef47bcbf8394f60f86efa1304b63eb5b9efb297963d8485b60 +size 352484 diff --git a/third_party/sam3/sam3/perflib/tests/tests.py b/third_party/sam3/sam3/perflib/tests/tests.py new file mode 100644 index 0000000000000000000000000000000000000000..f698b6a571de45799efac65a599c1cc270c50eb4 --- /dev/null +++ b/third_party/sam3/sam3/perflib/tests/tests.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import os + +import numpy as np +import pytest +import torch +from PIL import Image +from sam3.perflib.masks_ops import masks_to_boxes + + +class TestMasksToBoxes: + def test_masks_box(self): + def masks_box_check(masks, expected, atol=1e-4): + out = masks_to_boxes(masks, [1 for _ in range(masks.shape[0])]) + assert out.dtype == torch.float + print("out: ", out) + print("expected: ", expected) + torch.testing.assert_close( + out, expected, rtol=0.0, check_dtype=True, atol=atol + ) + + # Check for int type boxes. + def _get_image(): + assets_directory = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets" + ) + mask_path = os.path.join(assets_directory, "masks.tiff") + image = Image.open(mask_path) + return image + + def _create_masks(image, masks): + for index in range(image.n_frames): + image.seek(index) + frame = np.array(image) + masks[index] = torch.tensor(frame) + + return masks + + expected = torch.tensor( + [ + [127, 2, 165, 40], + [2, 50, 44, 92], + [56, 63, 98, 100], + [139, 68, 175, 104], + [160, 112, 198, 145], + [49, 138, 99, 182], + [108, 148, 152, 213], + ], + dtype=torch.float, + ) + + image = _get_image() + for dtype in [torch.float16, torch.float32, torch.float64]: + masks = torch.zeros( + (image.n_frames, image.height, image.width), dtype=dtype + ) + masks = _create_masks(image, masks) + masks_box_check(masks, expected) diff --git a/third_party/sam3/sam3/perflib/triton/connected_components.py b/third_party/sam3/sam3/perflib/triton/connected_components.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb7d44bdb80cf91bed4e9ae9a9996f2cd3a1006 --- /dev/null +++ b/third_party/sam3/sam3/perflib/triton/connected_components.py @@ -0,0 +1,470 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import math + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _any_combine(a, b): + return a | b + + +@triton.jit +def tl_any(a, dim=0): + return tl.reduce(a, dim, _any_combine) + + +# ============================================================================== +# ## Phase 1: Initialization Kernel +# ============================================================================== +# Each foreground pixel (value > 0) gets a unique label equal to its +# linear index. Background pixels (value == 0) get a sentinel label of -1. +# Note that the indexing is done across batch boundaries for simplicity +# (i.e., the first pixel of image 1 gets label H*W, etc.) + + +@triton.jit +def _init_labels_kernel( + input_ptr, labels_ptr, numel: tl.constexpr, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < numel + input_values = tl.load(input_ptr + offsets, mask=mask, other=0) + + indices = tl.where((input_values != 0), offsets, -1) + tl.store(labels_ptr + offsets, indices, mask=mask) + + +# ============================================================================== +# ## Phase 2: Local merging +# ============================================================================== +# Each pixel tries to merge with its 8-connected neighbors (up, down, left, right) +# if they have the same value. This is done using a disjoint-set union operation. + + +@triton.jit +def find(labels_ptr, indices, mask): + current_pids = indices + + # 'is_done' tracks lanes that have finished their work. + # A lane is initially "done" if it's not active (mask is False). + is_done = ~mask + + # Loop as long as there is at least one lane that is NOT done. + while tl_any(~is_done): + # The work_mask is for lanes that are still active and seeking their root. + work_mask = ~is_done + parents = tl.load(labels_ptr + current_pids, mask=work_mask, other=-1) + # A lane is now done if its parent is itself (it's a root) + # or if it hits a -1 sentinel (a safe exit condition). + is_root = parents == current_pids + is_sentinel = parents == -1 + is_done |= is_root | is_sentinel + + # For lanes that are not yet done, update their pid to their parent to continue traversal. + current_pids = tl.where(is_done, current_pids, parents) + # We could add the following line to do path compression, but experimentally it's slower + # tl.atomic_min(labels_ptr + indices, current_pids, mask=mask) + return current_pids + + +@triton.jit +def union(labels_ptr, a, b, process_mask): + # This function implements a disjoint-set union + # As an invariant, we use the fact that the roots have the lower id. That helps parallelization + # However, that is not sufficient by itself. Suppose two threads want to do union(0,2) and union(1,2) at the same time + # Then if we do a naive atomic_min, 0 and 1 will compete to be the new parent of 2 and min(0, 1) will win. + # However, 1 still needs to be merged with the new {0, 2} component. + # To ensure that merge is also done, we need to detect whether the merge was successful, and if not retry until it is + + current_a = a + current_b = b + + final_root = a + # A mask to track which lanes have successfully completed their union. + done_mask = ~process_mask # tl.zeros_like(a) == 1 # Init with all False + + while tl_any(~done_mask): + # Define the mask for lanes that still need work in this iteration + work_mask = process_mask & ~done_mask + + # Find the roots for the current a and b values in the active lanes + root_a = find(labels_ptr, current_a, work_mask) + tl.debug_barrier() + root_b = find(labels_ptr, current_b, work_mask) + + # 7. Merge logic + # If roots are already the same, the sets are already merged. Mark as done. + are_equal = root_a == root_b + final_root = tl.where(are_equal & work_mask & ~done_mask, root_a, final_root) + done_mask |= are_equal & work_mask + + # Define masks for the two merge cases (a < b or b < a) + a_is_smaller = root_a < root_b + + # Case 1: root_a < root_b. Attempt to set parent[root_b] = root_a + merge_mask_a_smaller = work_mask & a_is_smaller & ~are_equal + ptr_b = labels_ptr + root_b + old_val_b = tl.atomic_min(ptr_b, root_a, mask=merge_mask_a_smaller) + + # A lane is done if its atomic op was successful (old value was what we expected) + success_b = old_val_b == root_b + final_root = tl.where(success_b & work_mask & ~done_mask, root_a, final_root) + done_mask |= success_b & merge_mask_a_smaller + + # *** Crucial Retry Logic *** + # If the update failed (old_val_b != root_b), another thread interfered. + # We update `current_b` to this new root (`old_val_b`) and will retry in the next loop iteration. + current_b = tl.where(success_b | ~merge_mask_a_smaller, current_b, old_val_b) + + # Case 2: root_b < root_a. Attempt to set parent[root_a] = root_b + merge_mask_b_smaller = work_mask & ~a_is_smaller & ~are_equal + ptr_a = labels_ptr + root_a + old_val_a = tl.atomic_min(ptr_a, root_b, mask=merge_mask_b_smaller) + + success_a = old_val_a == root_a + final_root = tl.where(success_a & work_mask & ~done_mask, root_b, final_root) + done_mask |= success_a & merge_mask_b_smaller + + # *** Crucial Retry Logic *** + # Similarly, update `current_a` if the atomic operation failed. + current_a = tl.where(success_a | ~merge_mask_b_smaller, current_a, old_val_a) + + return final_root + + +@triton.jit +def _merge_helper( + input_ptr, + labels_ptr, + base_offset, + offsets_h, + offsets_w, + mask_2d, + valid_current, + current_values, + current_labels, + H, + W, + dx: tl.constexpr, + dy: tl.constexpr, +): + # Helper functions to compute merge with a specific neighbor offset (dx, dy) + + neighbor_h = offsets_h + dy + neighbor_w = offsets_w + dx + # Proper bounds checking: all four bounds must be satisfied + mask_n = ( + mask_2d + & (neighbor_h[:, None] >= 0) + & (neighbor_h[:, None] < H) + & (neighbor_w[None, :] >= 0) + & (neighbor_w[None, :] < W) + ) + + offsets_neighbor = neighbor_h[:, None] * W + neighbor_w[None, :] + neighbor_values = tl.load( + input_ptr + base_offset + offsets_neighbor, mask=mask_n, other=-1 + ) + + mask_n = tl.ravel(mask_n) + neighbor_labels = tl.load( + labels_ptr + tl.ravel(base_offset + offsets_neighbor), mask=mask_n, other=-1 + ) + + to_merge = ( + mask_n & (neighbor_labels != -1) & tl.ravel(current_values == neighbor_values) + ) + valid_write = valid_current & to_merge + + # returns new parents for the pixels that were merged (otherwise keeps current labels) + return tl.where( + valid_write, + union(labels_ptr, current_labels, neighbor_labels, valid_write), + current_labels, + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_H": 4, "BLOCK_SIZE_W": 16}, num_stages=1, num_warps=2 + ), + triton.Config( + {"BLOCK_SIZE_H": 4, "BLOCK_SIZE_W": 32}, num_stages=2, num_warps=4 + ), + ], + key=["H", "W"], + restore_value=["labels_ptr"], +) +@triton.jit +def _local_prop_kernel( + labels_ptr, + input_ptr, + H: tl.constexpr, + W: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_W: tl.constexpr, +): + # This is the meat of the Phase 2 to do local merging + # It will be launched with a 2D grid: + # - dim 0: batch index + # - dim 1: block index over HxW image (2D tiling) + pid_b = tl.program_id(0) + pid_hw = tl.program_id(1) + + # Calculate offsets for the core block + offsets_h = (pid_hw // tl.cdiv(W, BLOCK_SIZE_W)) * BLOCK_SIZE_H + tl.arange( + 0, BLOCK_SIZE_H + ) + offsets_w = (pid_hw % tl.cdiv(W, BLOCK_SIZE_W)) * BLOCK_SIZE_W + tl.arange( + 0, BLOCK_SIZE_W + ) + + base_offset = pid_b * H * W + offsets_2d = offsets_h[:, None] * W + offsets_w[None, :] + mask_2d = (offsets_h[:, None] < H) & (offsets_w[None, :] < W) + mask_1d = tl.ravel(mask_2d) + + # Load the current labels for the block - these are parent pointers + current_labels = tl.load( + labels_ptr + tl.ravel(base_offset + offsets_2d), mask=mask_1d, other=-1 + ) + current_values = tl.load( + input_ptr + base_offset + offsets_2d, mask=mask_2d, other=-1 + ) + valid_current = mask_1d & (current_labels != -1) + + # Horizontal merge + current_labels = _merge_helper( + input_ptr, + labels_ptr, + base_offset, + offsets_h, + offsets_w, + mask_2d, + valid_current, + current_values, + current_labels, + H, + W, + -1, + 0, + ) + # Vertical merge + current_labels = _merge_helper( + input_ptr, + labels_ptr, + base_offset, + offsets_h, + offsets_w, + mask_2d, + valid_current, + current_values, + current_labels, + H, + W, + 0, + -1, + ) + # Diagonal merges + current_labels = _merge_helper( + input_ptr, + labels_ptr, + base_offset, + offsets_h, + offsets_w, + mask_2d, + valid_current, + current_values, + current_labels, + H, + W, + -1, + -1, + ) + current_labels = _merge_helper( + input_ptr, + labels_ptr, + base_offset, + offsets_h, + offsets_w, + mask_2d, + valid_current, + current_values, + current_labels, + H, + W, + -1, + 1, + ) + + # This actually does some path compression, in a lightweight but beneficial way + tl.atomic_min( + labels_ptr + tl.ravel(base_offset + offsets_2d), current_labels, mask=mask_1d + ) + + +# ============================================================================== +# ## Phase 3: Pointer Jumping Kernel +# ============================================================================== +# This kernel performs pointer jumping to ensure that all pixels point directly to their root labels. +# This is done in a loop until convergence. + + +@triton.jit +def _pointer_jump_kernel( + labels_in_ptr, labels_out_ptr, numel: tl.constexpr, BLOCK_SIZE: tl.constexpr +): + """ + Pointer jumping kernel with double buffering to avoid race conditions. + Reads from labels_in_ptr and writes to labels_out_ptr. + """ + # This kernel is launched with a 1D grid, and does not care about batching explicitly. + # By construction, the labels are global indices across the batch, and we never perform + # cross-batch merges, so this is safe. + + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < numel + + # Load current labels from input buffer + current_labels = tl.load(labels_in_ptr + offsets, mask=mask, other=-1) + valid_mask = mask & (current_labels != -1) + + # A mask to track which lanes have successfully completed their union. + done_mask = ~valid_mask + while tl_any(~(done_mask | ~valid_mask)): + parent_labels = tl.load( + labels_in_ptr + current_labels, mask=valid_mask, other=-1 + ) + + are_equal = current_labels == parent_labels + done_mask |= are_equal & valid_mask + + current_labels = tl.where( + ~done_mask, tl.minimum(current_labels, parent_labels), current_labels + ) + + # Write to output buffer (safe because we're not reading from it) + tl.store(labels_out_ptr + offsets, current_labels, mask=mask) + + +# ============================================================================== +# ## Phase 4: Kernels for Computing Component Sizes +# ============================================================================== + + +# Step 4.1: Count occurrences of each root label using atomic adds. +@triton.jit +def _count_labels_kernel(labels_ptr, sizes_ptr, numel, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < numel + + # Load the final, converged labels + labels = tl.load(labels_ptr + offsets, mask=mask, other=-1) + valid_mask = mask & (labels != -1) + + # Atomically increment the counter for each label. This builds a histogram. + tl.atomic_add(sizes_ptr + labels, 1, mask=valid_mask) + + +# Step 4.2: Broadcast the computed sizes back to the output tensor. +@triton.jit +def _broadcast_sizes_kernel( + labels_ptr, sizes_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < numel + + # Load the final labels + labels = tl.load(labels_ptr + offsets, mask=mask, other=-1) + valid_mask = mask & (labels != -1) + + # Look up the size for each label from the histogram + component_sizes = tl.load(sizes_ptr + labels, mask=valid_mask, other=0) + + # Write the size to the final output tensor. Background pixels get size 0. + tl.store(out_ptr + offsets, component_sizes, mask=mask) + + +def connected_components_triton(input_tensor: torch.Tensor): + """ + Computes connected components labeling on a batch of 2D integer tensors using Triton. + + Args: + input_tensor (torch.Tensor): A BxHxW integer tensor or Bx1xHxW. Non-zero values are considered foreground. Bool tensor also accepted + + Returns: + Tuple[torch.Tensor, int]: A tuple containing: + - A BxHxW output tensor with dense labels. Background is 0. + - A BxHxW tensor with the size of the connected component for each pixel. + """ + assert ( + input_tensor.is_cuda and input_tensor.is_contiguous() + ), "Input tensor must be a contiguous CUDA tensor." + out_shape = input_tensor.shape + if input_tensor.dim() == 4 and input_tensor.shape[1] == 1: + input_tensor = input_tensor.squeeze(1) + else: + assert ( + input_tensor.dim() == 3 + ), "Input tensor must be (B, H, W) or (B, 1, H, W)." + + B, H, W = input_tensor.shape + numel = B * H * W + device = input_tensor.device + + # --- Allocate Tensors --- + labels = torch.empty_like(input_tensor, dtype=torch.int32) + output = torch.empty_like(input_tensor, dtype=torch.int32) + + # --- Phase 1 --- + BLOCK_SIZE = 256 + grid_init = (triton.cdiv(numel, BLOCK_SIZE),) + _init_labels_kernel[grid_init]( + input_tensor, + labels, + numel, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # --- Phase 2 --- + grid_local_prop = lambda meta: ( + B, + triton.cdiv(H, meta["BLOCK_SIZE_H"]) * triton.cdiv(W, meta["BLOCK_SIZE_W"]), + ) + _local_prop_kernel[grid_local_prop](labels, input_tensor, H, W) + + # --- Phase 3 --- + BLOCK_SIZE = 256 + grid_jump = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),) + _pointer_jump_kernel[grid_jump](labels, output, numel, BLOCK_SIZE=BLOCK_SIZE) + + # --- Phase 4 --- + # Allocate tensor to store the final output sizes + component_sizes_out = torch.empty_like(input_tensor, dtype=torch.int32) + + # Allocate a temporary 1D tensor to act as the histogram + # Size is numel because labels can be up to numel-1 + sizes_histogram = torch.zeros(numel, dtype=torch.int32, device=device) + + # 4.1: Count the occurrences of each label + grid_count = (triton.cdiv(numel, BLOCK_SIZE),) + _count_labels_kernel[grid_count]( + output, sizes_histogram, numel, BLOCK_SIZE=BLOCK_SIZE + ) + + # 2.2: Broadcast the counts to the final output tensor + grid_broadcast = (triton.cdiv(numel, BLOCK_SIZE),) + _broadcast_sizes_kernel[grid_broadcast]( + output, sizes_histogram, component_sizes_out, numel, BLOCK_SIZE=BLOCK_SIZE + ) + return output.view(out_shape) + 1, component_sizes_out.view(out_shape) diff --git a/third_party/sam3/sam3/perflib/triton/nms.py b/third_party/sam3/sam3/perflib/triton/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..9a06f33518857fefd80fc5a4fa84643622708356 --- /dev/null +++ b/third_party/sam3/sam3/perflib/triton/nms.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +# Adapted from https://github.com/stackav-oss/conch/blob/main/conch/kernels/vision/nms.py + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"cxpr_block_size": 128}), + triton.Config({"cxpr_block_size": 256}), + triton.Config({"cxpr_block_size": 512}), + triton.Config({"cxpr_block_size": 1024}), + triton.Config({"cxpr_block_size": 2048}), + triton.Config({"cxpr_block_size": 4096}), + triton.Config({"cxpr_block_size": 8192}), + ], + key=["num_boxes"], +) +@triton.jit +def _nms_suppression_kernel( + # Tensors + iou_mask_ptr: tl.tensor, # [N, N] + keep_mask_ptr: tl.tensor, # [N] + # Scalars + num_boxes: tl.int32, + # Strides + iou_mask_stride: tl.int32, + # Constexprs + cxpr_block_size: tl.constexpr, +) -> None: + """NMS suppression kernel. + + Args: + iou_mask_ptr: Pointer to precomputed IoU mask, shape: (N, N). + keep_mask_ptr: Pointer to keep mask tensor, shape: (N,). + num_boxes: Number of boxes. + iou_mask_stride: Stride for IoU mask tensor. + cxpr_block_size: Block size for processing. + """ + # Sequential NMS: for each box in sorted order, suppress later boxes + for current_box_idx in range(num_boxes - 1): + # Check if current box is still kept + is_kept = tl.load(keep_mask_ptr + current_box_idx) + if is_kept: + # IoU mask row offset for the current box + # Because the IoU mask is sorted by score, we will only consider boxes that come after the current box. + # This means we only need to read the upper triangular part of the IoU mask. + iou_row_offset = current_box_idx * iou_mask_stride + + # Only process boxes that come after the current box + next_box_idx = current_box_idx + 1 + remaining_boxes = num_boxes - next_box_idx + + # Iterate blockwise through the columns + for block_idx in range(tl.cdiv(remaining_boxes, cxpr_block_size)): + # Masked load of indices for the target boxes in the current block + block_start = next_box_idx + block_idx * cxpr_block_size + target_box_offsets = block_start + tl.arange(0, cxpr_block_size) + target_box_mask = target_box_offsets < num_boxes + + # Suppress boxes with lower scores that have high IoU + suppression_mask = tl.load( + iou_mask_ptr + iou_row_offset + target_box_offsets, + mask=target_box_mask, + other=False, + ) + suppression_mask = tl.cast(suppression_mask, tl.int1) + + # Conditionally store suppression result for high-IoU boxes + tl.store( + keep_mask_ptr + target_box_offsets, False, mask=suppression_mask + ) + + # Potential race condition: we need to ensure all threads complete the store before the next + # iteration otherwise we may load stale data for whether or not a box has been suppressed. + tl.debug_barrier() + + +def nms_triton( + ious: torch.Tensor, + scores: torch.Tensor, + iou_threshold: float, +) -> torch.Tensor: + """Perform NMS given the iou matrix, the scores and the iou threshold + + Args: + ious: Pairwise IoU tensor of shape (N, N). + scores: Scores tensor of shape (N,). + iou_threshold: IoU threshold for suppression. + + Returns: + Tensor: Indices of kept boxes, sorted by decreasing score. + """ + assert scores.dim() == 1, "Scores must be 1D" + iou_mask = ious > iou_threshold + assert iou_mask.dim() == 2 + assert iou_mask.shape[0] == iou_mask.shape[1] == scores.shape[0] + assert iou_mask.device == scores.device + assert iou_mask.dtype == torch.bool + + num_boxes = scores.size(0) + keep_mask = torch.ones(len(scores), device=scores.device, dtype=torch.bool) + + # Sort boxes by scores in descending order + _, sorted_indices = torch.sort(scores, dim=0, stable=True, descending=True) + iou_mask = iou_mask[sorted_indices][:, sorted_indices].contiguous() + + # For the suppression stage, we need to process sequentially, but we'll still take + # advantage of parallelism by processing in blocks in one program. + stage2_grid = (1,) + _nms_suppression_kernel[stage2_grid]( + # Tensors + iou_mask_ptr=iou_mask, + keep_mask_ptr=keep_mask, + # Scalars + num_boxes=num_boxes, + # Strides + iou_mask_stride=iou_mask.stride(0), + ) + # Extract indices of kept boxes + return sorted_indices[keep_mask] diff --git a/third_party/sam3/sam3/sam/__init__.py b/third_party/sam3/sam3/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e728afdd05187977882f6d25adc6c9a84542cc9d --- /dev/null +++ b/third_party/sam3/sam3/sam/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/third_party/sam3/sam3/sam/__pycache__/__init__.cpython-311.pyc b/third_party/sam3/sam3/sam/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bfe373976fd89aef615719f0cb69cda4ebf45a9 Binary files /dev/null and b/third_party/sam3/sam3/sam/__pycache__/__init__.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/sam/__pycache__/common.cpython-311.pyc b/third_party/sam3/sam3/sam/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ebd4b774d95c27dc83e628f50b2cca979b41a68 Binary files /dev/null and b/third_party/sam3/sam3/sam/__pycache__/common.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/sam/__pycache__/mask_decoder.cpython-311.pyc b/third_party/sam3/sam3/sam/__pycache__/mask_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6620ecf3d6c314889c243a6ba26a1b4e4fe271f9 Binary files /dev/null and b/third_party/sam3/sam3/sam/__pycache__/mask_decoder.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/sam/__pycache__/prompt_encoder.cpython-311.pyc b/third_party/sam3/sam3/sam/__pycache__/prompt_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f801405e9d3ffdd9173bd71d9187c3f75731a8fc Binary files /dev/null and b/third_party/sam3/sam3/sam/__pycache__/prompt_encoder.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/sam/__pycache__/rope.cpython-311.pyc b/third_party/sam3/sam3/sam/__pycache__/rope.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa7fdea26fb7c16d5e06678aec6fb716d8cc4caf Binary files /dev/null and b/third_party/sam3/sam3/sam/__pycache__/rope.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/sam/__pycache__/transformer.cpython-311.pyc b/third_party/sam3/sam3/sam/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87d0ddbc1a499e24abc76788fb08c76677a6fa7f Binary files /dev/null and b/third_party/sam3/sam3/sam/__pycache__/transformer.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/sam/common.py b/third_party/sam3/sam3/sam/common.py new file mode 100644 index 0000000000000000000000000000000000000000..72b18309ae7378a817cc3dd053f1dff62efeebdb --- /dev/null +++ b/third_party/sam3/sam3/sam/common.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from typing import Type + +import torch +import torch.nn as nn + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/third_party/sam3/sam3/sam/mask_decoder.py b/third_party/sam3/sam3/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1bbd27df097cbc66b209e88c1ac1903043213f --- /dev/null +++ b/third_party/sam3/sam3/sam/mask_decoder.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch.nn import functional as F + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/third_party/sam3/sam3/sam/prompt_encoder.py b/third_party/sam3/sam3/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b545a2d0757d9f7f87854d76644f8416abe561ef --- /dev/null +++ b/third_party/sam3/sam3/sam/prompt_encoder.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from typing import Any, Optional, Tuple, Type + +import numpy as np +import torch +from torch import nn + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + + point_embedding = torch.where( + (labels == -1).unsqueeze(-1), + torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 0).unsqueeze(-1), + point_embedding + self.point_embeddings[0].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 1).unsqueeze(-1), + point_embedding + self.point_embeddings[1].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 2).unsqueeze(-1), + point_embedding + self.point_embeddings[2].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 3).unsqueeze(-1), + point_embedding + self.point_embeddings[3].weight, + point_embedding, + ) + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/third_party/sam3/sam3/sam/rope.py b/third_party/sam3/sam3/sam/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..60992734804ca16dbf191a7522b612418933202d --- /dev/null +++ b/third_party/sam3/sam3/sam/rope.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +Adapted from: +1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +2. https://github.com/naver-ai/rope-vit +3. https://github.com/lucidrains/rotary-embedding-torch +""" + +from typing import Optional + +import torch +from einops import rearrange, repeat +from torch import broadcast_tensors, nn + + +def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0, device=None): + t = torch.arange(end_x * end_y, dtype=torch.float32, device=device) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x * scale + offset, t_y * scale + offset + + +def compute_axial_cis( + dim: int, + end_x: int, + end_y: int, + theta: float = 10000.0, + scale_pos: float = 1.0, + offset: int = 0, + device=None, +): + freqs_x = 1.0 / ( + theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim) + ) + freqs_y = 1.0 / ( + theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim) + ) + + t_x, t_y = init_t_xy(end_x, end_y, scale_pos, offset, device=device) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) + + +def complex_mult(xq_real, xq_imag, freqs_cis_real, freqs_cis_imag): + # Compute the real part of the product + real_part = xq_real * freqs_cis_real - xq_imag * freqs_cis_imag + # Compute the imaginary part of the product + imag_part = xq_real * freqs_cis_imag + xq_imag * freqs_cis_real + # Stack the real and imaginary parts along the last dimension + return torch.stack([real_part, imag_part], dim=-1) + + +def apply_rotary_enc_real( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis_real: torch.Tensor, + freqs_cis_imag: torch.Tensor, + repeat_freqs_k: bool = False, +): + assert xk is not None + assert xk.shape[-2] != 0 + + xq_real = xq.float().reshape(*xq.shape[:-1], -1, 2)[..., 0] + xq_imag = xq.float().reshape(*xq.shape[:-1], -1, 2)[..., 1] + xk_real = xk.float().reshape(*xk.shape[:-1], -1, 2)[..., 0] + xk_imag = xk.float().reshape(*xk.shape[:-1], -1, 2)[..., 1] + freqs_cis_real = reshape_for_broadcast(freqs_cis_real, xq_real) + freqs_cis_imag = reshape_for_broadcast(freqs_cis_imag, xq_imag) + xq_out = complex_mult(xq_real, xq_imag, freqs_cis_real, freqs_cis_imag).flatten(3) + if repeat_freqs_k: + r = xk_real.shape[-2] // xq_real.shape[-2] + freqs_cis_real = freqs_cis_real.repeat(*([1] * (freqs_cis_real.ndim - 2)), r, 1) + freqs_cis_imag = freqs_cis_imag.repeat(*([1] * (freqs_cis_imag.ndim - 2)), r, 1) + xk_out = complex_mult(xk_real, xk_imag, freqs_cis_real, freqs_cis_imag).flatten(3) + # xq_out = torch.view_as_real(torch.complex(xq_real, xq_imag) * torch.complex(freqs_cis_real, freqs_cis_imag)).flatten(3) + # xk_out = torch.view_as_real(torch.compelx(xk_real, xk_imag) * torch.complex(freqs_cis_real, freqs_cis_imag)).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) + + +# rotary embedding helper functions +def broadcat(tensors, dim=-1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim=dim) + + +def rotate_half(x: torch.Tensor): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbeddingVE(nn.Module): + def __init__( + self, + dim: int, + seq_len: int, + pt_seq_len: Optional[int] = None, + theta: float = 10000.0, + offset: int = 1, # specific to VE + ): + super().__init__() + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + scale = 1.0 + if pt_seq_len is not None: + scale = pt_seq_len / seq_len + + # offset of +1 following VE - even though for the + # attention op only differences matter + t = torch.arange(seq_len) * scale + offset + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + + freqs = broadcat((freqs[None, :, :], freqs[:, None, :]), dim=-1) + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def forward(self, t: torch.Tensor): + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin diff --git a/third_party/sam3/sam3/sam/transformer.py b/third_party/sam3/sam3/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6716c54f5fbd001f2194840fe6eefc0609f385d6 --- /dev/null +++ b/third_party/sam3/sam3/sam/transformer.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import math +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis +from torch import nn, Tensor + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + use_fa3: bool = False, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + self.use_fa3 = use_fa3 + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + # with torch.backends.cuda.sdp_kernel( + # enable_flash=USE_FLASH_ATTN, + # # if Flash attention kernel is off, then math kernel needs to be enabled + # enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + # enable_mem_efficient=OLD_GPU, + # ): + # Let's trust the dispatcher.... + if self.use_fa3: + from sam3.perflib.fa3 import flash_attn_func + + assert dropout_p == 0.0 + out = flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + ).transpose(1, 2) + else: + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution + use_rope_real=False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.use_rope_real = use_rope_real + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + device = torch.device("cuda") if torch.cuda.is_available() else None + self.freqs_cis = self.compute_cis( + end_x=feat_sizes[0], end_y=feat_sizes[1], device=device + ) + if self.use_rope_real: + self.freqs_cis_real = self.freqs_cis.real + self.freqs_cis_imag = self.freqs_cis.imag + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h, device=q.device) + self.freqs_cis_real = self.freqs_cis.real + self.freqs_cis_imag = self.freqs_cis.imag + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + if self.use_rope_real: + q, k[:, :, :num_k_rope] = apply_rotary_enc_real( + q, + k[:, :, :num_k_rope], + freqs_cis_real=self.freqs_cis_real, + freqs_cis_imag=self.freqs_cis_imag, + repeat_freqs_k=self.rope_k_repeat, + ) + else: + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + # with torch.backends.cuda.sdp_kernel( + # enable_flash=USE_FLASH_ATTN, + # # if Flash attention kernel is off, then math kernel needs to be enabled + # enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + # enable_mem_efficient=OLD_GPU, + # ): + # Let's trust the dispatcher.... + if self.use_fa3: + from sam3.perflib.fa3 import flash_attn_func + + assert dropout_p == 0.0 + out = flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + ).transpose(1, 2) + else: + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/third_party/sam3/sam3/train/__init__.py b/third_party/sam3/sam3/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/train/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/train/__pycache__/__init__.cpython-311.pyc b/third_party/sam3/sam3/train/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1763bf7866d2becd7105ae1adeaa43db35dec189 Binary files /dev/null and b/third_party/sam3/sam3/train/__pycache__/__init__.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/train/__pycache__/masks_ops.cpython-311.pyc b/third_party/sam3/sam3/train/__pycache__/masks_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..462eb6c23b0b1dcaf9df559ab201677f2f76e8eb Binary files /dev/null and b/third_party/sam3/sam3/train/__pycache__/masks_ops.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/train/configs/eval_base.yaml b/third_party/sam3/sam3/train/configs/eval_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2a15f2a0d6d16d9c8c401f611dd51c62eb0f7cac --- /dev/null +++ b/third_party/sam3/sam3/train/configs/eval_base.yaml @@ -0,0 +1,279 @@ +# @package _global_ +defaults: + - _self_ + +# This config is the base configuration for all evaluations. Amongst other things, it defines: +# - the model +# - the image transforms +# - the post processors +# - cluster configuration (only relevant for slurm-based evals, ignored otherwise) +# +# Most of the parameters should be kept as-is. The main modifications you may want to make are: +# - the cluster configuration, to adjust partitions/qos to your system +# - the flag gather_pred_via_filesys if you ram is tight +# - num_val_workers if your number of cores is small (should be roughly number of cores / number of gpus) +# - the paths below + + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + # If you leave the checkpoint path to null, the model will be downloaded from hugging-face. Otherwise provide a path + checkpoint_path: null + # the experiments will be subfolders of this + base_experiment_log_dir: + + # base path to the annotation folder for gold (refer to the readmes on how to download) + base_annotation_path: + + # base path to the annotation folder for silver (refer to the readmes on how to download) + base_annotation_path_silver: + + # path to the metaclip images, used for SA-Co gold (refer to the readme for instructions). Can be null if you don't intend on evaluating on this dataset. + metaclip_img_path: + + # path to the sa1b images, used for SA-Co gold (refer to the readme for instructions). Can be null if you don't intend on evaluating on this dataset. + sa1b_img_path: + + # path to the SA-Co/silver images + silver_img_path: + + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + + use_presence_eval: True + + base_val_transform: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + ######## transforms for validation (begin) ######## + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: False + ######## transforms for validation (end) ######## + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + loss: null + + # Model parameters + d_model: 256 + input_box_embedding_dim: ${add:${scratch.d_model},2} + + # Box processing + original_box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 # infinite detections + use_original_ids: true + use_original_sizes_box: true + use_presence: ${scratch.use_presence_eval} + + box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 #infinite detections + use_original_ids: false + use_original_sizes_box: false + use_presence: ${scratch.use_presence_eval} + + box_postprocessor_thresholded: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 #infinite detections + use_original_ids: false + use_original_sizes_box: false + detection_threshold: 0.3 + use_presence: ${scratch.use_presence_eval} + + mask_postprocessor_thresholded: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 #infinite detections + iou_type: "segm" + use_original_ids: false + use_original_sizes_box: false + use_original_sizes_mask: true + convert_mask_to_rle: True + detection_threshold: 0.3 + use_presence: ${scratch.use_presence_eval} + + # Image processing parameters + resolution: 1008 + max_ann_per_img: 200 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + # Training parameters + train_batch_size: 1 + val_batch_size: 1 + num_train_workers: 0 + num_val_workers: 10 # change this depending on the number of cpu cores available + max_data_epochs: 20 + target_epoch_size: 1500 + hybrid_repeats: 1 + context_length: 2 + + # All reduce - this controls how the predictions are sent back to node 0. + # If you have a lot of ram, CPU gather is faster. Otherwise, we provide a fallback through filesystem (eg NFS) + # Switch to true if you get cpu ooms during gather. + gather_pred_via_filesys: false + + # Learning rate and scheduler parameters (unused for eval) + lr_scale: 0.1 + lr_transformer: ${times:8e-4,${scratch.lr_scale}} + lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}} + lr_language_backbone: ${times:5e-5,${scratch.lr_scale}} + lrd_vision_backbone: 0.9 # (lower for in-domain adn higher for ood) + wd: 0.1 + scheduler_timescale: 20 + scheduler_warmup: 20 + scheduler_cooldown: 20 + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: null + + model: + _target_: sam3.model_builder.build_sam3_image_model + bpe_path: ${paths.bpe_path} + device: cpus + eval_mode: true + enable_segmentation: true # Warning: Enable this if using segmentation. + checkpoint_path: ${paths.checkpoint_path} + + meters: + val: null + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + optimizer: + _target_: torch.optim.AdamW + + gradient_clip: + _target_: sam3.train.optim.optimizer.GradientClipper + max_norm: 0.1 + norm_type: 2 + + param_group_modifiers: + - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier + _partial_: True + layer_decay_value: ${scratch.lrd_vision_backbone} + apply_to: 'backbone.vision_backbone.trunk' + overrides: + - pattern: '*pos_embed*' + value: 1.0 + + options: + lr: + - scheduler: # transformer and class_embed + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_transformer} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + - scheduler: + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_vision_backbone} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + param_names: + - 'backbone.vision_backbone.*' + - scheduler: + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_language_backbone} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + param_names: + - 'backbone.language_backbone.*' + + weight_decay: + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: ${scratch.wd} + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.0 + param_names: + - '*bias*' + module_cls_names: ['torch.nn.LayerNorm'] + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 4 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + + +submitit: + account: null # Add your SLURM account if use_cluster == 1 + partition: null + qos: null # Add your QoS if use_cluster == 1 + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_attributes.yaml b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_attributes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8646b691734e1dd191d53e700d9b7dcb2c23de72 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_attributes.yaml @@ -0,0 +1,66 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/gold_attributes/ + coco_gt: ${paths.base_annotation_path}/gold_attributes_merged_a_release_test.json + coco_gts: + - ${paths.base_annotation_path}/gold_attributes_merged_a_release_test.json + - ${paths.base_annotation_path}/gold_attributes_merged_b_release_test.json + - ${paths.base_annotation_path}/gold_attributes_merged_c_release_test.json + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.metaclip_img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: gold_attributes + + meters: + val: + gold_attributes: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/gold_attributes + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_crowded.yaml b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_crowded.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fef74a6ee56c901c258a8ced2beff773a38ec545 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_crowded.yaml @@ -0,0 +1,66 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/gold_crowded/ + coco_gt: ${paths.base_annotation_path}/gold_crowded_merged_a_release_test.json + coco_gts: + - ${paths.base_annotation_path}/gold_crowded_merged_a_release_test.json + - ${paths.base_annotation_path}/gold_crowded_merged_b_release_test.json + - ${paths.base_annotation_path}/gold_crowded_merged_c_release_test.json + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.metaclip_img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: gold_crowded + + meters: + val: + gold_crowded: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/gold_crowded + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_fg_food.yaml b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_fg_food.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b08c4a46921db2123f2540a63536140ea641320e --- /dev/null +++ b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_fg_food.yaml @@ -0,0 +1,66 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/gold_fg_food/ + coco_gt: ${paths.base_annotation_path}/gold_fg_food_merged_a_release_test.json + coco_gts: + - ${paths.base_annotation_path}/gold_fg_food_merged_a_release_test.json + - ${paths.base_annotation_path}/gold_fg_food_merged_b_release_test.json + - ${paths.base_annotation_path}/gold_fg_food_merged_c_release_test.json + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.metaclip_img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: gold_fg_food + + meters: + val: + gold_fg_food: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/gold_fg_food + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_fg_sports.yaml b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_fg_sports.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89a93be2acf36ab0f84481dfded86340da97b9a6 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_fg_sports.yaml @@ -0,0 +1,66 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/gold_fg_sports_equipment/ + coco_gt: ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_a_release_test.json + coco_gts: + - ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_a_release_test.json + - ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_b_release_test.json + - ${paths.base_annotation_path}/gold_fg_sports_equipment_merged_c_release_test.json + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.metaclip_img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: gold_fg_sports_equipment + + meters: + val: + gold_fg_sports_equipment: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/gold_fg_sports_equipment + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_metaclip_nps.yaml b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_metaclip_nps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e9c276f4299d4a53d5b44cea5194918541a0d25d --- /dev/null +++ b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_metaclip_nps.yaml @@ -0,0 +1,66 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/gold_metaclip_nps/ + coco_gt: ${paths.base_annotation_path}/gold_metaclip_merged_a_release_test.json + coco_gts: + - ${paths.base_annotation_path}/gold_metaclip_merged_a_release_test.json + - ${paths.base_annotation_path}/gold_metaclip_merged_b_release_test.json + - ${paths.base_annotation_path}/gold_metaclip_merged_c_release_test.json + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.metaclip_img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: gold_metaclip_nps + + meters: + val: + gold_metaclip_nps: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/gold_metaclip_nps + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_sa1b_nps.yaml b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_sa1b_nps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52c87ee30545d24502160e7e8e3a565ce8d83bf2 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_sa1b_nps.yaml @@ -0,0 +1,66 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/gold_sa1b_nps/ + coco_gt: ${paths.base_annotation_path}/gold_sa1b_merged_a_release_test.json + coco_gts: + - ${paths.base_annotation_path}/gold_sa1b_merged_a_release_test.json + - ${paths.base_annotation_path}/gold_sa1b_merged_b_release_test.json + - ${paths.base_annotation_path}/gold_sa1b_merged_c_release_test.json + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.sa1b_img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: gold_sa1b_nps + + meters: + val: + gold_sa1b_nps: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/gold_sa1b_nps + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_wiki_common.yaml b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_wiki_common.yaml new file mode 100644 index 0000000000000000000000000000000000000000..630495423c3840f0e795ee3c501ee5f5b44a3505 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/gold_image_evals/sam3_gold_image_wiki_common.yaml @@ -0,0 +1,66 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/gold_wiki_common/ + coco_gt: ${paths.base_annotation_path}/gold_wiki_common_merged_a_release_test.json + coco_gts: + - ${paths.base_annotation_path}/gold_wiki_common_merged_a_release_test.json + - ${paths.base_annotation_path}/gold_wiki_common_merged_b_release_test.json + - ${paths.base_annotation_path}/gold_wiki_common_merged_c_release_test.json + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.metaclip_img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: gold_wiki_common + + meters: + val: + gold_wiki_common: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/gold_wiki_common + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gts} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/odinw13/odinw_text_and_visual.yaml b/third_party/sam3/sam3/train/configs/odinw13/odinw_text_and_visual.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c026e2e6ce44d2ef4fd5ba5f0f661591bd16c4f --- /dev/null +++ b/third_party/sam3/sam3/train/configs/odinw13/odinw_text_and_visual.yaml @@ -0,0 +1,255 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS} + +paths: + odinw_data_root: + experiment_log_dir: + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + +supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}} +# Validation transforms pipeline +val_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: False + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + - _target_: sam3.train.transforms.filter_query_transforms.TextQueryToVisual + keep_text_queries: true # Note: set this to false if you only want visual + probability: 1.0 # always + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + enable_segmentation: True + # Box processing + use_presence_eval: True + original_box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 # infinite detections + use_original_ids: true + use_original_sizes_box: true + use_presence: ${scratch.use_presence_eval} + + # Image processing parameters + resolution: 1008 + # Normalization parameters + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + # Training parameters + val_batch_size: 2 + num_val_workers: 0 + gather_pred_via_filesys: false + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + max_epochs: 1 + accelerator: cuda + seed_value: 123 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON + prompts: ${odinw35_prompts.${supercategory_tuple.name}} + include_negatives: true + category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories! + _partial_: true + img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder} + ann_file: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json} + transforms: ${val_transforms} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: 1 + dict_key: odinw35 + + model: + _target_: sam3.model_builder.build_sam3_image_model + bpe_path: ${paths.bpe_path} + device: cpus + eval_mode: true # Set to false if training + enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation. + + meters: + val: + odinw35: + detection: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "bbox" + dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name} + merge_predictions: True + postprocessor: ${scratch.original_box_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 100 + pred_file_evaluators: + - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators + gt_path: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json} + tide: False + iou_type: "bbox" + positive_split: true + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name} + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 1 + gpus_per_node: 2 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null + + job_array: + num_tasks: 13 + task_index: 0 + +# ============================================================================ +# ODinW13 Supercategories +# ============================================================================ + +all_odinw_supercategories: + - name: AerialMaritimeDrone_large + val: + img_folder: AerialMaritimeDrone/large/test/ + json: AerialMaritimeDrone/large/test/annotations_without_background.json + - name: Aquarium + val: + img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/ + json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json + - name: CottontailRabbits + val: + img_folder: CottontailRabbits/test/ + json: CottontailRabbits/test/annotations_without_background.json + - name: EgoHands_generic + val: + img_folder: EgoHands/generic/test/ + json: EgoHands/generic/test/annotations_without_background.json + - name: NorthAmericaMushrooms + val: + img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/ + json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json + - name: Packages + val: + img_folder: Packages/Raw/test/ + json: Packages/Raw/test/annotations_without_background.json + - name: PascalVOC + val: + img_folder: PascalVOC/valid/ + json: PascalVOC/valid/annotations_without_background.json + - name: Raccoon + val: + img_folder: Raccoon/Raccoon.v2-raw.coco/test/ + json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json + - name: ShellfishOpenImages + val: + img_folder: ShellfishOpenImages/raw/test/ + json: ShellfishOpenImages/raw/test/annotations_without_background.json + - name: VehiclesOpenImages + val: + img_folder: VehiclesOpenImages/416x416/test/ + json: VehiclesOpenImages/416x416/test/annotations_without_background.json + - name: pistols + val: + img_folder: pistols/export/ + json: pistols/export/test_annotations_without_background.json + - name: pothole + val: + img_folder: pothole/test/ + json: pothole/test/annotations_without_background.json + - name: thermalDogsAndPeople + val: + img_folder: thermalDogsAndPeople/test/ + json: thermalDogsAndPeople/test/annotations_without_background.json + + +odinw35_prompts: + AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"}, + {"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock", + "supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"}, + {"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]' + Aquarium: null + CottontailRabbits: null + EgoHands_generic: null + NorthAmericaMushrooms: '[{''id'': 1, ''name'': + ''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]' + Packages: null + PascalVOC: null + Raccoon: null + ShellfishOpenImages: null + VehiclesOpenImages: null + pistols: null + pothole: null + thermalDogsAndPeople: null diff --git a/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only.yaml b/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de532bbef076eedb11fd632a2f18e0a08863d817 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only.yaml @@ -0,0 +1,253 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS} + +paths: + odinw_data_root: + experiment_log_dir: + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + + +supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}} +# Validation transforms pipeline +val_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: False + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + enable_segmentation: True + # Box processing + use_presence_eval: True + original_box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 # infinite detections + use_original_ids: true + use_original_sizes_box: true + use_presence: ${scratch.use_presence_eval} + + # Image processing parameters + resolution: 1008 + # Normalization parameters + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + # Training parameters + val_batch_size: 2 + num_val_workers: 0 + gather_pred_via_filesys: false + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + max_epochs: 1 + accelerator: cuda + seed_value: 123 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON + prompts: ${odinw35_prompts.${supercategory_tuple.name}} + include_negatives: true + category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories! + _partial_: true + img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder} + ann_file: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json} + transforms: ${val_transforms} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: 1 + dict_key: odinw35 + + model: + _target_: sam3.model_builder.build_sam3_image_model + bpe_path: ${paths.bpe_path} + device: cpus + eval_mode: true # Set to false if training + enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation. + + meters: + val: + odinw35: + detection: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "bbox" + dump_dir: ${launcher.experiment_log_dir}/dumps/odinw/${supercategory_tuple.name} + merge_predictions: True + postprocessor: ${scratch.original_box_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 100 + pred_file_evaluators: + - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators + gt_path: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json} + tide: False + iou_type: "bbox" + positive_split: False + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name} + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 1 + gpus_per_node: 2 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null + + job_array: + num_tasks: 13 + task_index: 0 + +# ============================================================================ +# ODinW13 Supercategories +# ============================================================================ + +all_odinw_supercategories: + - name: AerialMaritimeDrone_large + val: + img_folder: AerialMaritimeDrone/large/test/ + json: AerialMaritimeDrone/large/test/annotations_without_background.json + - name: Aquarium + val: + img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/ + json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json + - name: CottontailRabbits + val: + img_folder: CottontailRabbits/test/ + json: CottontailRabbits/test/annotations_without_background.json + - name: EgoHands_generic + val: + img_folder: EgoHands/generic/test/ + json: EgoHands/generic/test/annotations_without_background.json + - name: NorthAmericaMushrooms + val: + img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/ + json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json + - name: Packages + val: + img_folder: Packages/Raw/test/ + json: Packages/Raw/test/annotations_without_background.json + - name: PascalVOC + val: + img_folder: PascalVOC/valid/ + json: PascalVOC/valid/annotations_without_background.json + - name: Raccoon + val: + img_folder: Raccoon/Raccoon.v2-raw.coco/test/ + json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json + - name: ShellfishOpenImages + val: + img_folder: ShellfishOpenImages/raw/test/ + json: ShellfishOpenImages/raw/test/annotations_without_background.json + - name: VehiclesOpenImages + val: + img_folder: VehiclesOpenImages/416x416/test/ + json: VehiclesOpenImages/416x416/test/annotations_without_background.json + - name: pistols + val: + img_folder: pistols/export/ + json: pistols/export/test_annotations_without_background.json + - name: pothole + val: + img_folder: pothole/test/ + json: pothole/test/annotations_without_background.json + - name: thermalDogsAndPeople + val: + img_folder: thermalDogsAndPeople/test/ + json: thermalDogsAndPeople/test/annotations_without_background.json + + +odinw35_prompts: + AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"}, + {"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock", + "supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"}, + {"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]' + Aquarium: null + CottontailRabbits: null + EgoHands_generic: null + NorthAmericaMushrooms: '[{''id'': 1, ''name'': + ''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]' + Packages: null + PascalVOC: null + Raccoon: null + ShellfishOpenImages: null + VehiclesOpenImages: null + pistols: null + pothole: null + thermalDogsAndPeople: null diff --git a/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only_positive.yaml b/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only_positive.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dcd37804259db637e4d8984749eb53d1856d1a7a --- /dev/null +++ b/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only_positive.yaml @@ -0,0 +1,253 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS} + +paths: + odinw_data_root: + experiment_log_dir: + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + + +supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}} +# Validation transforms pipeline +val_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: False + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + enable_segmentation: True + # Box processing + use_presence_eval: True + original_box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 # infinite detections + use_original_ids: true + use_original_sizes_box: true + use_presence: ${scratch.use_presence_eval} + + # Image processing parameters + resolution: 1008 + # Normalization parameters + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + # Training parameters + val_batch_size: 2 + num_val_workers: 0 + gather_pred_via_filesys: false + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + max_epochs: 1 + accelerator: cuda + seed_value: 123 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON + prompts: ${odinw35_prompts.${supercategory_tuple.name}} + include_negatives: true + category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories! + _partial_: true + img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder} + ann_file: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json} + transforms: ${val_transforms} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: 1 + dict_key: odinw35 + + model: + _target_: sam3.model_builder.build_sam3_image_model + bpe_path: ${paths.bpe_path} + device: cpus + eval_mode: true # Set to false if training + enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation. + + meters: + val: + odinw35: + detection: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "bbox" + dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name} + merge_predictions: True + postprocessor: ${scratch.original_box_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 100 + pred_file_evaluators: + - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators + gt_path: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json} + tide: False + iou_type: "bbox" + positive_split: true + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name} + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 1 + gpus_per_node: 2 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null + + job_array: + num_tasks: 13 + task_index: 0 + +# ============================================================================ +# ODinW13 Supercategories +# ============================================================================ + +all_odinw_supercategories: + - name: AerialMaritimeDrone_large + val: + img_folder: AerialMaritimeDrone/large/test/ + json: AerialMaritimeDrone/large/test/annotations_without_background.json + - name: Aquarium + val: + img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/ + json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json + - name: CottontailRabbits + val: + img_folder: CottontailRabbits/test/ + json: CottontailRabbits/test/annotations_without_background.json + - name: EgoHands_generic + val: + img_folder: EgoHands/generic/test/ + json: EgoHands/generic/test/annotations_without_background.json + - name: NorthAmericaMushrooms + val: + img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/ + json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json + - name: Packages + val: + img_folder: Packages/Raw/test/ + json: Packages/Raw/test/annotations_without_background.json + - name: PascalVOC + val: + img_folder: PascalVOC/valid/ + json: PascalVOC/valid/annotations_without_background.json + - name: Raccoon + val: + img_folder: Raccoon/Raccoon.v2-raw.coco/test/ + json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json + - name: ShellfishOpenImages + val: + img_folder: ShellfishOpenImages/raw/test/ + json: ShellfishOpenImages/raw/test/annotations_without_background.json + - name: VehiclesOpenImages + val: + img_folder: VehiclesOpenImages/416x416/test/ + json: VehiclesOpenImages/416x416/test/annotations_without_background.json + - name: pistols + val: + img_folder: pistols/export/ + json: pistols/export/test_annotations_without_background.json + - name: pothole + val: + img_folder: pothole/test/ + json: pothole/test/annotations_without_background.json + - name: thermalDogsAndPeople + val: + img_folder: thermalDogsAndPeople/test/ + json: thermalDogsAndPeople/test/annotations_without_background.json + + +odinw35_prompts: + AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"}, + {"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock", + "supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"}, + {"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]' + Aquarium: null + CottontailRabbits: null + EgoHands_generic: null + NorthAmericaMushrooms: '[{''id'': 1, ''name'': + ''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]' + Packages: null + PascalVOC: null + Raccoon: null + ShellfishOpenImages: null + VehiclesOpenImages: null + pistols: null + pothole: null + thermalDogsAndPeople: null diff --git a/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only_train.yaml b/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only_train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a26bd5a787ca71fc598311ab03300353c56d3ea --- /dev/null +++ b/third_party/sam3/sam3/train/configs/odinw13/odinw_text_only_train.yaml @@ -0,0 +1,591 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS} + +paths: + odinw_data_root: + experiment_log_dir: + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + + +odinw_train: + train_file: fewshot_train_shot10_seed300 + num_images: null + supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}} + # Training transforms pipeline + train_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterCrowds + - _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox + box_noise_std: 0.1 + box_noise_max: 20 + - _target_: sam3.train.transforms.segmentation.DecodeRle + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: + _target_: sam3.train.transforms.basic.get_random_resize_scales + size: ${scratch.resolution} + min_size: 480 + rounded: false + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: ${scratch.consistent_transform} + - _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI + size: ${scratch.resolution} + consistent_transform: ${scratch.consistent_transform} + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.train_norm_mean} + std: ${scratch.train_norm_std} + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut + max_num_objects: ${scratch.max_ann_per_img} + + # Validation transforms pipeline + val_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: False + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # loss config (no mask loss) + loss: + _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper + matcher: ${scratch.matcher} + o2m_weight: 2.0 + o2m_matcher: + _target_: sam3.train.matcher.BinaryOneToManyMatcher + alpha: 0.3 + threshold: 0.4 + topk: 4 + use_o2m_matcher_on_o2m_aux: ${scratch.use_o2m_matcher_on_o2m_aux} + loss_fns_find: + - _target_: sam3.train.loss.loss_fns.Boxes + weight_dict: + loss_bbox: 5.0 + loss_giou: 2.0 + - _target_: sam3.train.loss.loss_fns.IABCEMdetr + weak_loss: False + weight_dict: + loss_ce: ${scratch.loss_ce_weight} # Change + presence_loss: ${scratch.presence_weight} # Change + pos_weight: ${scratch.iabce_pos_weight} + alpha: ${scratch.iabce_alpha} + gamma: 2 + use_presence: True # Change + pos_focal: ${scratch.iabce_pos_focal} + pad_n_queries: ${scratch.num_queries} + pad_scale_pos: ${scratch.instance_query_loss_pad_scale_pos} + + loss_fn_semantic_seg: null + scale_by_find_batch_size: ${scratch.scale_by_find_batch_size} + + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + enable_segmentation: False + use_act_checkpoint_geo_encoder: True + input_geometry_encoder: + _target_: sam3.model.geometry_encoders.SequenceGeometryEncoder + pos_enc: ${scratch.pos_embed} + encode_boxes_as_points: False + points_direct_project: True + points_pool: True + points_pos_enc: True + boxes_direct_project: True + boxes_pool: True + boxes_pos_enc: True + d_model: ${scratch.d_model} + num_layers: 3 + use_act_ckpt: ${scratch.use_act_checkpoint_geo_encoder} + layer: + _target_: sam3.model.encoder.TransformerEncoderLayer + activation: "relu" + d_model: ${scratch.d_model} + dim_feedforward: 2048 + dropout: ${scratch.encoder_dropout} + pos_enc_at_attn: false + pre_norm: True + pos_enc_at_cross_attn_queries: false + pos_enc_at_cross_attn_keys: true + self_attention: + _target_: sam3.model.attention.MultiheadAttention + attn_type: Vanilla + num_heads: 8 + dropout: ${scratch.encoder_dropout} + embed_dim: ${scratch.d_model} + batch_first: False + cross_attention: + _target_: sam3.model.attention.MultiheadAttention + attn_type: Vanilla + num_heads: 8 + dropout: ${scratch.encoder_dropout} + embed_dim: ${scratch.d_model} + batch_first: False + add_cls: true + add_post_encode_proj: True + + boxRPB: "log" + dac: True + use_early_fusion: true + o2m_mask: false + num_feature_levels: 1 # > 1 not implemented + encoder_dropout: 0.1 + decoder_dropout: 0.1 + + tokenizer_ve: + _target_: sam3.model.tokenizer_ve.SimpleTokenizer + bpe_path: ${paths.bpe_path} + + + freeze_text_tower: False + freeze_image_tower: NoFreeze + vis_backbone_dp: 0.0 + # Activation checkpointing (Save memory) + use_act_checkpoint_vision_backbone: True + use_act_checkpoint_text_backbone: True + use_act_checkpoint_encoder: True + use_act_checkpoint_decoder: True + + loss: null + # Loss parameters + num_queries: 200 + presence_weight: 20.0 + loss_ce_weight: 20.0 + iabce_pos_weight: 5.0 + iabce_pos_focal: false + iabce_alpha: 0.25 + instance_query_loss_pad_scale_pos: 1.0 + use_o2m_matcher_on_o2m_aux: false + + # Model parameters + use_instance_query: true + d_model: 256 + pos_embed: + _target_: sam3.model.position_encoding.PositionEmbeddingSine + num_pos_feats: ${scratch.d_model} + normalize: true + scale: null + temperature: 10000 + + # Box processing + use_presence_eval: True + original_box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 # infinite detections + use_original_ids: true + use_original_sizes_box: true + use_presence: ${scratch.use_presence_eval} + + + # Matcher configuration + matcher: + _target_: sam3.train.matcher.BinaryHungarianMatcherV2 + focal: true + cost_class: 2.0 + cost_bbox: 5.0 + cost_giou: 2.0 + alpha: 0.25 + gamma: 2 + stable: False + scale_by_find_batch_size: True + + # Image processing parameters + resolution: 1008 + consistent_transform: False + max_ann_per_img: 200 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + # Training parameters + train_batch_size: 1 + val_batch_size: 1 + num_train_workers: 0 + num_val_workers: 0 + max_data_epochs: 40 + target_epoch_size: 1500 + hybrid_repeats: 1 + context_length: 2 + gather_pred_via_filesys: false + + # Learning rate and scheduler parameters + lr_scale: 0.1 + lr_transformer: ${times:8e-4,${scratch.lr_scale}} + lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}} + lr_language_backbone: ${times:5e-5,${scratch.lr_scale}} + lrd_vision_backbone: 0.9 + wd: 0.1 + scheduler_timescale: 20 + scheduler_warmup: 20 + scheduler_cooldown: 20 + + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + # _target_: sam3.train.trainer.Trainer + # skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: train + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: ${odinw_train.loss} + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + limit_ids: ${odinw_train.num_images} + transforms: ${odinw_train.train_transforms} + load_segmentation: ${scratch.enable_segmentation} + max_ann_per_img: 500000 + multiplier: 1 + max_train_queries: 50000 + max_val_queries: 50000 + training: true + use_caching: False + img_folder: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.train.img_folder} + ann_file: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.train.json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON + prompts: ${odinw35_prompts.${odinw_train.supercategory_tuple.name}} #${odinw_train.supercategory_tuple.name) + _partial_: true + shuffle: True + batch_size: ${scratch.train_batch_size} + num_workers: ${scratch.num_train_workers} + pin_memory: False + drop_last: True + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: all + with_seg_masks: ${scratch.enable_segmentation} + + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + load_segmentation: ${scratch.enable_segmentation} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON + prompts: ${odinw35_prompts.${odinw_train.supercategory_tuple.name}} + include_negatives: true + category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories! + _partial_: true + img_folder: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.img_folder} + ann_file: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.json} + transforms: ${odinw_train.val_transforms} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: 1 + dict_key: odinw35 + with_seg_masks: ${scratch.enable_segmentation} + + model: + _target_: sam3.model_builder.build_sam3_image_model + bpe_path: ${paths.bpe_path} + device: cpus + eval_mode: false # Set to false if training + enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation. + + meters: + val: + odinw35: + detection: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "bbox" + dump_dir: ${launcher.experiment_log_dir}/dumps/odinw/${odinw_train.supercategory_tuple.name} + merge_predictions: True + postprocessor: ${scratch.original_box_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 100 + pred_file_evaluators: + - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators + gt_path: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${odinw_train.supercategory_tuple.val.json} + tide: False + iou_type: "bbox" + positive_split: False + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + optimizer: + _target_: torch.optim.AdamW + + gradient_clip: + _target_: sam3.train.optim.optimizer.GradientClipper + max_norm: 0.1 + norm_type: 2 + + param_group_modifiers: + - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier + _partial_: True + layer_decay_value: ${scratch.lrd_vision_backbone} + apply_to: 'backbone.vision_backbone.trunk' + overrides: + - pattern: '*pos_embed*' + value: 1.0 + + options: + lr: + - scheduler: # transformer and class_embed + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_transformer} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + - scheduler: + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_vision_backbone} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + param_names: + - 'backbone.vision_backbone.*' + - scheduler: + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_language_backbone} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + param_names: + - 'backbone.language_backbone.*' + + weight_decay: + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: ${scratch.wd} + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.0 + param_names: + - '*bias*' + module_cls_names: ['torch.nn.LayerNorm'] + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/${odinw_train.supercategory_tuple.name} + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 1 + gpus_per_node: 2 + experiment_log_dir: null #${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null + + # task_index: 2 + # Uncomment for job array configuration + job_array: + num_tasks: 13 + task_index: 0 + + +# ============================================================================ +# ODinW13 Supercategories +# ============================================================================ + +all_odinw_supercategories: + - name: AerialMaritimeDrone_large + val: + img_folder: AerialMaritimeDrone/large/test/ + json: AerialMaritimeDrone/large/test/annotations_without_background.json + train: + img_folder: AerialMaritimeDrone/large/train/ + json: AerialMaritimeDrone/large/train/${odinw_train.train_file}.json + - name: Aquarium + val: + img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/ + json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json + train: + img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/train/ + json: Aquarium/Aquarium Combined.v2-raw-1024.coco/train/${odinw_train.train_file}.json + - name: CottontailRabbits + val: + img_folder: CottontailRabbits/test/ + json: CottontailRabbits/test/annotations_without_background.json + train: + img_folder: CottontailRabbits/train/ + json: CottontailRabbits/train/${odinw_train.train_file}.json + - name: EgoHands_generic + val: + img_folder: EgoHands/generic/test/ + json: EgoHands/generic/test/annotations_without_background.json + train: + img_folder: EgoHands/generic/train/ + json: EgoHands/generic/train/${odinw_train.train_file}.json + - name: NorthAmericaMushrooms + val: + img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/ + json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json + train: + img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/train/ + json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/train/${odinw_train.train_file}.json + - name: Packages + val: + img_folder: Packages/Raw/test/ + json: Packages/Raw/test/annotations_without_background.json + train: + img_folder: Packages/Raw/train/ + json: Packages/Raw/train/${odinw_train.train_file}.json + - name: PascalVOC + val: + img_folder: PascalVOC/valid/ + json: PascalVOC/valid/annotations_without_background.json + train: + img_folder: PascalVOC/train/ + json: PascalVOC/train/${odinw_train.train_file}.json + - name: Raccoon + val: + img_folder: Raccoon/Raccoon.v2-raw.coco/test/ + json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json + train: + img_folder: Raccoon/Raccoon.v2-raw.coco/train/ + json: Raccoon/Raccoon.v2-raw.coco/train/${odinw_train.train_file}.json + - name: ShellfishOpenImages + val: + img_folder: ShellfishOpenImages/raw/test/ + json: ShellfishOpenImages/raw/test/annotations_without_background.json + train: + img_folder: ShellfishOpenImages/raw/train/ + json: ShellfishOpenImages/raw/train/${odinw_train.train_file}.json + - name: VehiclesOpenImages + val: + img_folder: VehiclesOpenImages/416x416/test/ + json: VehiclesOpenImages/416x416/test/annotations_without_background.json + train: + img_folder: VehiclesOpenImages/416x416/train/ + json: VehiclesOpenImages/416x416/train/${odinw_train.train_file}.json + - name: pistols + val: + img_folder: pistols/export/ + json: pistols/export/test_annotations_without_background.json + train: + img_folder: pistols/export/ + json: pistols/export/${odinw_train.train_file}.json + - name: pothole + val: + img_folder: pothole/test/ + json: pothole/test/annotations_without_background.json + train: + img_folder: pothole/train/ + json: pothole/train/${odinw_train.train_file}.json + - name: thermalDogsAndPeople + val: + img_folder: thermalDogsAndPeople/test/ + json: thermalDogsAndPeople/test/annotations_without_background.json + train: + img_folder: thermalDogsAndPeople/train/ + json: thermalDogsAndPeople/train/${odinw_train.train_file}.json + + +odinw35_prompts: + AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"}, + {"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock", + "supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"}, + {"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]' + Aquarium: null + CottontailRabbits: null + EgoHands_generic: null + NorthAmericaMushrooms: '[{''id'': 1, ''name'': + ''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]' + Packages: null + PascalVOC: null + Raccoon: null + ShellfishOpenImages: null + VehiclesOpenImages: null + pistols: null + pothole: null + thermalDogsAndPeople: null diff --git a/third_party/sam3/sam3/train/configs/odinw13/odinw_visual_only.yaml b/third_party/sam3/sam3/train/configs/odinw13/odinw_visual_only.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46a32814c77a09ad92caec533a8e6f6be0b8a2b2 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/odinw13/odinw_visual_only.yaml @@ -0,0 +1,256 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +# python sam3/train/train.py -c configs/odinw_text_only.yaml --use-cluster 1 --partition ${PARTITION} --account ${ACCOUNT} --qos ${QoS} + +paths: + odinw_data_root: + experiment_log_dir: + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + + +supercategory_tuple: ${all_odinw_supercategories.${string:${submitit.job_array.task_index}}} +# Validation transforms pipeline +val_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: False + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + - _target_: sam3.train.transforms.filter_query_transforms.TextQueryToVisual + keep_text_queries: false # Note: set this to false if you only want visual + probability: 1.0 # always + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + enable_segmentation: True + # Box processing + use_presence_eval: True + original_box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 # infinite detections + use_original_ids: true + use_original_sizes_box: true + use_presence: ${scratch.use_presence_eval} + + # Image processing parameters + resolution: 1008 + # Normalization parameters + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + # Training parameters + val_batch_size: 2 + num_val_workers: 0 + gather_pred_via_filesys: false + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + max_epochs: 1 + accelerator: cuda + seed_value: 123 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON + prompts: ${odinw35_prompts.${supercategory_tuple.name}} + include_negatives: true + category_chunk_size: 20 # Note: Since we are doing AP +ve we need to include all categories! + _partial_: true + img_folder: ${paths.odinw_data_root}/${supercategory_tuple.val.img_folder} + ann_file: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json} + transforms: ${val_transforms} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: 1 + dict_key: odinw35 + + model: + _target_: sam3.model_builder.build_sam3_image_model + bpe_path: ${paths.bpe_path} + device: cpus + eval_mode: true # Set to false if training + enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation. + + meters: + val: + odinw35: + detection: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "bbox" + dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${supercategory_tuple.name} + merge_predictions: True + postprocessor: ${scratch.original_box_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 100 + pred_file_evaluators: + - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators + gt_path: + _target_: sam3.eval.coco_reindex.reindex_coco_to_temp + input_json_path: ${paths.odinw_data_root}/${supercategory_tuple.val.json} + tide: False + iou_type: "bbox" + positive_split: true + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/${supercategory_tuple.name} + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 1 + gpus_per_node: 2 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null + + job_array: + num_tasks: 13 + task_index: 0 + +# ============================================================================ +# ODinW13 Supercategories +# ============================================================================ + +all_odinw_supercategories: + - name: AerialMaritimeDrone_large + val: + img_folder: AerialMaritimeDrone/large/test/ + json: AerialMaritimeDrone/large/test/annotations_without_background.json + - name: Aquarium + val: + img_folder: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/ + json: Aquarium/Aquarium Combined.v2-raw-1024.coco/test/annotations_without_background.json + - name: CottontailRabbits + val: + img_folder: CottontailRabbits/test/ + json: CottontailRabbits/test/annotations_without_background.json + - name: EgoHands_generic + val: + img_folder: EgoHands/generic/test/ + json: EgoHands/generic/test/annotations_without_background.json + - name: NorthAmericaMushrooms + val: + img_folder: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/ + json: NorthAmericaMushrooms/North American Mushrooms.v1-416x416.coco/test/annotations_without_background.json + - name: Packages + val: + img_folder: Packages/Raw/test/ + json: Packages/Raw/test/annotations_without_background.json + - name: PascalVOC + val: + img_folder: PascalVOC/valid/ + json: PascalVOC/valid/annotations_without_background.json + - name: Raccoon + val: + img_folder: Raccoon/Raccoon.v2-raw.coco/test/ + json: Raccoon/Raccoon.v2-raw.coco/test/annotations_without_background.json + - name: ShellfishOpenImages + val: + img_folder: ShellfishOpenImages/raw/test/ + json: ShellfishOpenImages/raw/test/annotations_without_background.json + - name: VehiclesOpenImages + val: + img_folder: VehiclesOpenImages/416x416/test/ + json: VehiclesOpenImages/416x416/test/annotations_without_background.json + - name: pistols + val: + img_folder: pistols/export/ + json: pistols/export/test_annotations_without_background.json + - name: pothole + val: + img_folder: pothole/test/ + json: pothole/test/annotations_without_background.json + - name: thermalDogsAndPeople + val: + img_folder: thermalDogsAndPeople/test/ + json: thermalDogsAndPeople/test/annotations_without_background.json + + +odinw35_prompts: + AerialMaritimeDrone_large: '[{"id": 1, "name": "boat", "supercategory": "movable-objects"}, + {"id": 2, "name": "car", "supercategory": "movable-objects"}, {"id": 3, "name": "dock", + "supercategory": "movable-objects"}, {"id": 4, "name": "jet ski", "supercategory": "movable-objects"}, + {"id": 5, "name": "boat lift", "supercategory": "movable-objects"}]' + Aquarium: null + CottontailRabbits: null + EgoHands_generic: null + NorthAmericaMushrooms: '[{''id'': 1, ''name'': + ''chicken of the woods'', ''supercategory'': ''mushroom''}, {''id'': 2, ''name'': ''chanterelle'', ''supercategory'': ''mushroom''}]' + Packages: null + PascalVOC: null + Raccoon: null + ShellfishOpenImages: null + VehiclesOpenImages: null + pistols: null + pothole: null + thermalDogsAndPeople: null diff --git a/third_party/sam3/sam3/train/configs/roboflow_v100/roboflow_v100_eval.yaml b/third_party/sam3/sam3/train/configs/roboflow_v100/roboflow_v100_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c5e4cce3d27b8c5198afae4c72648b22f6e0faaf --- /dev/null +++ b/third_party/sam3/sam3/train/configs/roboflow_v100/roboflow_v100_eval.yaml @@ -0,0 +1,539 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + roboflow_vl_100_root: + experiment_log_dir: + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + +# Roboflow dataset configuration +roboflow_train: + num_images: 100 # Note: This is the number of images used for training. If null, all images are used. + supercategory: ${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}} + + # Training transforms pipeline + train_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterCrowds + - _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox + box_noise_std: 0.1 + box_noise_max: 20 + - _target_: sam3.train.transforms.segmentation.DecodeRle + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: + _target_: sam3.train.transforms.basic.get_random_resize_scales + size: ${scratch.resolution} + min_size: 480 + rounded: false + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: ${scratch.consistent_transform} + - _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI + size: ${scratch.resolution} + consistent_transform: ${scratch.consistent_transform} + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.train_norm_mean} + std: ${scratch.train_norm_std} + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut + max_num_objects: ${scratch.max_ann_per_img} + + # Validation transforms pipeline + val_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: False + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.train_norm_mean} + std: ${scratch.train_norm_std} + + # loss config (no mask loss) + loss: + _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper + matcher: ${scratch.matcher} + o2m_weight: 2.0 + o2m_matcher: + _target_: sam3.train.matcher.BinaryOneToManyMatcher + alpha: 0.3 + threshold: 0.4 + topk: 4 + use_o2m_matcher_on_o2m_aux: false # Another option is true + loss_fns_find: + - _target_: sam3.train.loss.loss_fns.Boxes + weight_dict: + loss_bbox: 5.0 + loss_giou: 2.0 + - _target_: sam3.train.loss.loss_fns.IABCEMdetr + weak_loss: False + weight_dict: + loss_ce: 20.0 # Another option is 100.0 + presence_loss: 20.0 + pos_weight: 10.0 # Another option is 5.0 + alpha: 0.25 + gamma: 2 + use_presence: True # Change + pos_focal: false + pad_n_queries: 200 + pad_scale_pos: 1.0 + + loss_fn_semantic_seg: null + scale_by_find_batch_size: ${scratch.scale_by_find_batch_size} + + + # NOTE: Loss to be used for training in case of segmentation + # loss: + # _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper + # matcher: ${scratch.matcher} + # o2m_weight: 2.0 + # o2m_matcher: + # _target_: sam3.train.matcher.BinaryOneToManyMatcher + # alpha: 0.3 + # threshold: 0.4 + # topk: 4 + # use_o2m_matcher_on_o2m_aux: false + # loss_fns_find: + # - _target_: sam3.train.loss.loss_fns.Boxes + # weight_dict: + # loss_bbox: 5.0 + # loss_giou: 2.0 + # - _target_: sam3.train.loss.loss_fns.IABCEMdetr + # weak_loss: False + # weight_dict: + # loss_ce: 20.0 # Another option is 100.0 + # presence_loss: 20.0 + # pos_weight: 10.0 # Another option is 5.0 + # alpha: 0.25 + # gamma: 2 + # use_presence: True # Change + # pos_focal: false + # pad_n_queries: 200 + # pad_scale_pos: 1.0 + # - _target_: sam3.train.loss.loss_fns.Masks + # focal_alpha: 0.25 + # focal_gamma: 2.0 + # weight_dict: + # loss_mask: 200.0 + # loss_dice: 10.0 + # compute_aux: false + # loss_fn_semantic_seg: + # _target_: sam3.losses.loss_fns.SemanticSegCriterion + # presence_head: True + # presence_loss: False # Change + # focal: True + # focal_alpha: 0.6 + # focal_gamma: 2.0 + # downsample: False + # weight_dict: + # loss_semantic_seg: 20.0 + # loss_semantic_presence: 1.0 + # loss_semantic_dice: 30.0 + # scale_by_find_batch_size: ${scratch.scale_by_find_batch_size} + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + enable_segmentation: False # NOTE: This is the number of queries used for segmentation + # Model parameters + d_model: 256 + pos_embed: + _target_: sam3.model.position_encoding.PositionEmbeddingSine + num_pos_feats: ${scratch.d_model} + normalize: true + scale: null + temperature: 10000 + + # Box processing + use_presence_eval: True + original_box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 # infinite detections + use_original_ids: true + use_original_sizes_box: true + use_presence: ${scratch.use_presence_eval} + + # Matcher configuration + matcher: + _target_: sam3.train.matcher.BinaryHungarianMatcherV2 + focal: true # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher + cost_class: 2.0 + cost_bbox: 5.0 + cost_giou: 2.0 + alpha: 0.25 + gamma: 2 + stable: False + scale_by_find_batch_size: True + + # Image processing parameters + resolution: 1008 + consistent_transform: False + max_ann_per_img: 200 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + # Training parameters + num_train_workers: 10 + num_val_workers: 0 + max_data_epochs: 20 + target_epoch_size: 1500 + hybrid_repeats: 1 + context_length: 2 + gather_pred_via_filesys: false + + # Learning rate and scheduler parameters + lr_scale: 0.1 + lr_transformer: ${times:8e-4,${scratch.lr_scale}} + lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}} + lr_language_backbone: ${times:5e-5,${scratch.lr_scale}} + lrd_vision_backbone: 0.9 + wd: 0.1 + scheduler_timescale: 20 + scheduler_warmup: 20 + scheduler_cooldown: 20 + + val_batch_size: 1 + collate_fn_val: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: roboflow100 + with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks! + + gradient_accumulation_steps: 1 + train_batch_size: 1 + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: all + with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks! + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: 20 + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + gradient_accumulation_steps: ${scratch.gradient_accumulation_steps} + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: ${roboflow_train.loss} + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + limit_ids: ${roboflow_train.num_images} + transforms: ${roboflow_train.train_transforms} + load_segmentation: ${scratch.enable_segmentation} + max_ann_per_img: 500000 + multiplier: 1 + max_train_queries: 50000 + max_val_queries: 50000 + training: true + use_caching: False + img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/ + ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/_annotations.coco.json + + shuffle: True + batch_size: ${scratch.train_batch_size} + num_workers: ${scratch.num_train_workers} + pin_memory: True + drop_last: True + collate_fn: ${scratch.collate_fn} + + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + load_segmentation: ${scratch.enable_segmentation} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON + include_negatives: true + category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU. + _partial_: true + img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/ + ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json + transforms: ${roboflow_train.val_transforms} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: ${scratch.collate_fn_val} + + + model: + _target_: sam3.model_builder.build_sam3_image_model + bpe_path: ${paths.bpe_path} + device: cpus + eval_mode: true + enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation. + + meters: + val: + roboflow100: + detection: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "bbox" + dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory} + merge_predictions: True + postprocessor: ${scratch.original_box_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 100 + pred_file_evaluators: + - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators + gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json + tide: False + iou_type: "bbox" + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + optimizer: + _target_: torch.optim.AdamW + + gradient_clip: + _target_: sam3.train.optim.optimizer.GradientClipper + max_norm: 0.1 + norm_type: 2 + + param_group_modifiers: + - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier + _partial_: True + layer_decay_value: ${scratch.lrd_vision_backbone} + apply_to: 'backbone.vision_backbone.trunk' + overrides: + - pattern: '*pos_embed*' + value: 1.0 + + options: + lr: + - scheduler: # transformer and class_embed + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_transformer} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + - scheduler: + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_vision_backbone} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + param_names: + - 'backbone.vision_backbone.*' + - scheduler: + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_language_backbone} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + param_names: + - 'backbone.language_backbone.*' + + weight_decay: + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: ${scratch.wd} + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.0 + param_names: + - '*bias*' + module_cls_names: ['torch.nn.LayerNorm'] + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory} + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 1 + gpus_per_node: 2 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null + # Uncomment for job array configuration + job_array: + num_tasks: 100 + task_index: 0 + +# ============================================================================ +# Available Roboflow Supercategories (for reference) +# ============================================================================ + +all_roboflow_supercategories: + - -grccs + - zebrasatasturias + - cod-mw-warzone + - canalstenosis + - label-printing-defect-version-2 + - new-defects-in-wood + - orionproducts + - aquarium-combined + - varroa-mites-detection--test-set + - clashroyalechardetector + - stomata-cells + - halo-infinite-angel-videogame + - pig-detection + - urine-analysis1 + - aerial-sheep + - orgharvest + - actions + - mahjong + - liver-disease + - needle-base-tip-min-max + - wheel-defect-detection + - aircraft-turnaround-dataset + - xray + - wildfire-smoke + - spinefrxnormalvindr + - ufba-425 + - speech-bubbles-detection + - train + - pill + - truck-movement + - car-logo-detection + - inbreast + - sea-cucumbers-new-tiles + - uavdet-small + - penguin-finder-seg + - aerial-airport + - bibdetection + - taco-trash-annotations-in-context + - bees + - recode-waste + - screwdetectclassification + - wine-labels + - aerial-cows + - into-the-vale + - gwhd2021 + - lacrosse-object-detection + - defect-detection + - dataconvert + - x-ray-id + - ball + - tube + - 2024-frc + - crystal-clean-brain-tumors-mri-dataset + - grapes-5 + - human-detection-in-floods + - buoy-onboarding + - apoce-aerial-photographs-for-object-detection-of-construction-equipment + - l10ul502 + - floating-waste + - deeppcb + - ism-band-packet-detection + - weeds4 + - invoice-processing + - thermal-cheetah + - tomatoes-2 + - marine-sharks + - peixos-fish + - sssod + - aerial-pool + - countingpills + - asphaltdistressdetection + - roboflow-trained-dataset + - everdaynew + - underwater-objects + - soda-bottles + - dentalai + - jellyfish + - deepfruits + - activity-diagrams + - circuit-voltages + - all-elements + - macro-segmentation + - exploratorium-daphnia + - signatures + - conveyor-t-shirts + - fruitjes + - grass-weeds + - infraredimageofpowerequipment + - 13-lkc01 + - wb-prova + - flir-camera-objects + - paper-parts + - football-player-detection + - trail-camera + - smd-components + - water-meter + - nih-xray + - the-dreidel-project + - electric-pylon-detection-in-rsi + - cable-damage diff --git a/third_party/sam3/sam3/train/configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml b/third_party/sam3/sam3/train/configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml new file mode 100644 index 0000000000000000000000000000000000000000..480a218b7c9d7080c301d6274480a1d3e5d3a454 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml @@ -0,0 +1,539 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + roboflow_vl_100_root: + experiment_log_dir: + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + +# Roboflow dataset configuration +roboflow_train: + num_images: 100 # Note: This is the number of images used for training. If null, all images are used. + supercategory: ${all_roboflow_supercategories.${string:${submitit.job_array.task_index}}} + + # Training transforms pipeline + train_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterCrowds + - _target_: sam3.train.transforms.point_sampling.RandomizeInputBbox + box_noise_std: 0.1 + box_noise_max: 20 + - _target_: sam3.train.transforms.segmentation.DecodeRle + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: + _target_: sam3.train.transforms.basic.get_random_resize_scales + size: ${scratch.resolution} + min_size: 480 + rounded: false + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: ${scratch.consistent_transform} + - _target_: sam3.train.transforms.basic_for_api.PadToSizeAPI + size: ${scratch.resolution} + consistent_transform: ${scratch.consistent_transform} + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.train_norm_mean} + std: ${scratch.train_norm_std} + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterEmptyTargets + - _target_: sam3.train.transforms.filter_query_transforms.FlexibleFilterFindGetQueries + query_filter: + _target_: sam3.train.transforms.filter_query_transforms.FilterFindQueriesWithTooManyOut + max_num_objects: ${scratch.max_ann_per_img} + + # Validation transforms pipeline + val_transforms: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} + max_size: + _target_: sam3.train.transforms.basic.get_random_resize_max_size + size: ${scratch.resolution} + square: true + consistent_transform: False + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.train_norm_mean} + std: ${scratch.train_norm_std} + + # loss config (no mask loss) + loss: + _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper + matcher: ${scratch.matcher} + o2m_weight: 2.0 + o2m_matcher: + _target_: sam3.train.matcher.BinaryOneToManyMatcher + alpha: 0.3 + threshold: 0.4 + topk: 4 + use_o2m_matcher_on_o2m_aux: false # Another option is true + loss_fns_find: + - _target_: sam3.train.loss.loss_fns.Boxes + weight_dict: + loss_bbox: 5.0 + loss_giou: 2.0 + - _target_: sam3.train.loss.loss_fns.IABCEMdetr + weak_loss: False + weight_dict: + loss_ce: 20.0 # Another option is 100.0 + presence_loss: 20.0 + pos_weight: 10.0 # Another option is 5.0 + alpha: 0.25 + gamma: 2 + use_presence: True # Change + pos_focal: false + pad_n_queries: 200 + pad_scale_pos: 1.0 + + loss_fn_semantic_seg: null + scale_by_find_batch_size: ${scratch.scale_by_find_batch_size} + + + # NOTE: Loss to be used for training in case of segmentation + # loss: + # _target_: sam3.train.loss.sam3_loss.Sam3LossWrapper + # matcher: ${scratch.matcher} + # o2m_weight: 2.0 + # o2m_matcher: + # _target_: sam3.train.matcher.BinaryOneToManyMatcher + # alpha: 0.3 + # threshold: 0.4 + # topk: 4 + # use_o2m_matcher_on_o2m_aux: false + # loss_fns_find: + # - _target_: sam3.train.loss.loss_fns.Boxes + # weight_dict: + # loss_bbox: 5.0 + # loss_giou: 2.0 + # - _target_: sam3.train.loss.loss_fns.IABCEMdetr + # weak_loss: False + # weight_dict: + # loss_ce: 20.0 # Another option is 100.0 + # presence_loss: 20.0 + # pos_weight: 10.0 # Another option is 5.0 + # alpha: 0.25 + # gamma: 2 + # use_presence: True # Change + # pos_focal: false + # pad_n_queries: 200 + # pad_scale_pos: 1.0 + # - _target_: sam3.train.loss.loss_fns.Masks + # focal_alpha: 0.25 + # focal_gamma: 2.0 + # weight_dict: + # loss_mask: 200.0 + # loss_dice: 10.0 + # compute_aux: false + # loss_fn_semantic_seg: + # _target_: sam3.losses.loss_fns.SemanticSegCriterion + # presence_head: True + # presence_loss: False # Change + # focal: True + # focal_alpha: 0.6 + # focal_gamma: 2.0 + # downsample: False + # weight_dict: + # loss_semantic_seg: 20.0 + # loss_semantic_presence: 1.0 + # loss_semantic_dice: 30.0 + # scale_by_find_batch_size: ${scratch.scale_by_find_batch_size} + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + enable_segmentation: False # NOTE: This is the number of queries used for segmentation + # Model parameters + d_model: 256 + pos_embed: + _target_: sam3.model.position_encoding.PositionEmbeddingSine + num_pos_feats: ${scratch.d_model} + normalize: true + scale: null + temperature: 10000 + + # Box processing + use_presence_eval: True + original_box_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessImage + max_dets_per_img: -1 # infinite detections + use_original_ids: true + use_original_sizes_box: true + use_presence: ${scratch.use_presence_eval} + + # Matcher configuration + matcher: + _target_: sam3.train.matcher.BinaryHungarianMatcherV2 + focal: true # with `focal: true` it is equivalent to BinaryFocalHungarianMatcher + cost_class: 2.0 + cost_bbox: 5.0 + cost_giou: 2.0 + alpha: 0.25 + gamma: 2 + stable: False + scale_by_find_batch_size: True + + # Image processing parameters + resolution: 1008 + consistent_transform: False + max_ann_per_img: 200 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + # Training parameters + num_train_workers: 10 + num_val_workers: 0 + max_data_epochs: 20 + target_epoch_size: 1500 + hybrid_repeats: 1 + context_length: 2 + gather_pred_via_filesys: false + + # Learning rate and scheduler parameters + lr_scale: 0.1 + lr_transformer: ${times:8e-4,${scratch.lr_scale}} + lr_vision_backbone: ${times:2.5e-4,${scratch.lr_scale}} + lr_language_backbone: ${times:5e-5,${scratch.lr_scale}} + lrd_vision_backbone: 0.9 + wd: 0.1 + scheduler_timescale: 20 + scheduler_warmup: 20 + scheduler_cooldown: 20 + + val_batch_size: 1 + collate_fn_val: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: roboflow100 + with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks! + + gradient_accumulation_steps: 1 + train_batch_size: 1 + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: all + with_seg_masks: ${scratch.enable_segmentation} # Note: Set this to true if using segmentation masks! + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: 20 + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: train + gradient_accumulation_steps: ${scratch.gradient_accumulation_steps} + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: ${roboflow_train.loss} + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + limit_ids: ${roboflow_train.num_images} + transforms: ${roboflow_train.train_transforms} + load_segmentation: ${scratch.enable_segmentation} + max_ann_per_img: 500000 + multiplier: 1 + max_train_queries: 50000 + max_val_queries: 50000 + training: true + use_caching: False + img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/ + ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/train/_annotations.coco.json + + shuffle: True + batch_size: ${scratch.train_batch_size} + num_workers: ${scratch.num_train_workers} + pin_memory: True + drop_last: True + collate_fn: ${scratch.collate_fn} + + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + load_segmentation: ${scratch.enable_segmentation} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.COCO_FROM_JSON + include_negatives: true + category_chunk_size: 2 # Note: You can increase this based on the memory of your GPU. + _partial_: true + img_folder: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/ + ann_file: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json + transforms: ${roboflow_train.val_transforms} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: ${scratch.collate_fn_val} + + + model: + _target_: sam3.model_builder.build_sam3_image_model + bpe_path: ${paths.bpe_path} + device: cpus + eval_mode: false + enable_segmentation: ${scratch.enable_segmentation} # Warning: Enable this if using segmentation. + + meters: + val: + roboflow100: + detection: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "bbox" + dump_dir: ${launcher.experiment_log_dir}/dumps/roboflow/${roboflow_train.supercategory} + merge_predictions: True + postprocessor: ${scratch.original_box_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 100 + pred_file_evaluators: + - _target_: sam3.eval.coco_eval_offline.CocoEvaluatorOfflineWithPredFileEvaluators + gt_path: ${paths.roboflow_vl_100_root}/${roboflow_train.supercategory}/test/_annotations.coco.json + tide: False + iou_type: "bbox" + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + optimizer: + _target_: torch.optim.AdamW + + gradient_clip: + _target_: sam3.train.optim.optimizer.GradientClipper + max_norm: 0.1 + norm_type: 2 + + param_group_modifiers: + - _target_: sam3.train.optim.optimizer.layer_decay_param_modifier + _partial_: True + layer_decay_value: ${scratch.lrd_vision_backbone} + apply_to: 'backbone.vision_backbone.trunk' + overrides: + - pattern: '*pos_embed*' + value: 1.0 + + options: + lr: + - scheduler: # transformer and class_embed + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_transformer} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + - scheduler: + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_vision_backbone} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + param_names: + - 'backbone.vision_backbone.*' + - scheduler: + _target_: sam3.train.optim.schedulers.InverseSquareRootParamScheduler + base_lr: ${scratch.lr_language_backbone} + timescale: ${scratch.scheduler_timescale} + warmup_steps: ${scratch.scheduler_warmup} + cooldown_steps: ${scratch.scheduler_cooldown} + param_names: + - 'backbone.language_backbone.*' + + weight_decay: + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: ${scratch.wd} + - scheduler: + _target_: fvcore.common.param_scheduler.ConstantParamScheduler + value: 0.0 + param_names: + - '*bias*' + module_cls_names: ['torch.nn.LayerNorm'] + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/${roboflow_train.supercategory} + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 1 + gpus_per_node: 2 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null + # Uncomment for job array configuration + job_array: + num_tasks: 100 + task_index: 0 + +# ============================================================================ +# Available Roboflow Supercategories (for reference) +# ============================================================================ + +all_roboflow_supercategories: + - -grccs + - zebrasatasturias + - cod-mw-warzone + - canalstenosis + - label-printing-defect-version-2 + - new-defects-in-wood + - orionproducts + - aquarium-combined + - varroa-mites-detection--test-set + - clashroyalechardetector + - stomata-cells + - halo-infinite-angel-videogame + - pig-detection + - urine-analysis1 + - aerial-sheep + - orgharvest + - actions + - mahjong + - liver-disease + - needle-base-tip-min-max + - wheel-defect-detection + - aircraft-turnaround-dataset + - xray + - wildfire-smoke + - spinefrxnormalvindr + - ufba-425 + - speech-bubbles-detection + - train + - pill + - truck-movement + - car-logo-detection + - inbreast + - sea-cucumbers-new-tiles + - uavdet-small + - penguin-finder-seg + - aerial-airport + - bibdetection + - taco-trash-annotations-in-context + - bees + - recode-waste + - screwdetectclassification + - wine-labels + - aerial-cows + - into-the-vale + - gwhd2021 + - lacrosse-object-detection + - defect-detection + - dataconvert + - x-ray-id + - ball + - tube + - 2024-frc + - crystal-clean-brain-tumors-mri-dataset + - grapes-5 + - human-detection-in-floods + - buoy-onboarding + - apoce-aerial-photographs-for-object-detection-of-construction-equipment + - l10ul502 + - floating-waste + - deeppcb + - ism-band-packet-detection + - weeds4 + - invoice-processing + - thermal-cheetah + - tomatoes-2 + - marine-sharks + - peixos-fish + - sssod + - aerial-pool + - countingpills + - asphaltdistressdetection + - roboflow-trained-dataset + - everdaynew + - underwater-objects + - soda-bottles + - dentalai + - jellyfish + - deepfruits + - activity-diagrams + - circuit-voltages + - all-elements + - macro-segmentation + - exploratorium-daphnia + - signatures + - conveyor-t-shirts + - fruitjes + - grass-weeds + - infraredimageofpowerequipment + - 13-lkc01 + - wb-prova + - flir-camera-objects + - paper-parts + - football-player-detection + - trail-camera + - smd-components + - water-meter + - nih-xray + - the-dreidel-project + - electric-pylon-detection-in-rsi + - cable-damage diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_test.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..001a0927070c1171b80050a7cad91debb3b6e92d --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_test.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_sav_test + experiment_log_dir: + ytvis_json: /saco_veval_sav_test.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: True + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_test_noheur.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_test_noheur.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a899d354a370405db94a9d5072c7bad14d6c82dc --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_test_noheur.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_sav_test + experiment_log_dir: + ytvis_json: /saco_veval_sav_test.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: False + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_val.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b4eab6cf2260f6a45aea68cf72e45c2da9dccbed --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_val.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_sav_val + experiment_log_dir: + ytvis_json: /saco_veval_sav_val.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: True + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_val_noheur.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_val_noheur.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0ca842331ea2a167c05bd1b2f884cd8e4538d02 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_sav_val_noheur.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_sav_val + experiment_log_dir: + ytvis_json: /saco_veval_sav_val.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: False + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_test.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..797fcab9446706490d2fa1aad4da98dfa08915de --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_test.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_smartglasses_test + experiment_log_dir: + ytvis_json: /saco_veval_smartglasses_test.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: True + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_test_noheur.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_test_noheur.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15f948a82f7763a61f54d56fbf0343f4876d1f01 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_test_noheur.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_smartglasses_test + experiment_log_dir: + ytvis_json: /saco_veval_smartglasses_test.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: False + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_val.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d2e4a857570b02a68fdcacc42815d81324e0e709 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_val.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_smartglasses_val + experiment_log_dir: + ytvis_json: /saco_veval_smartglasses_val.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: True + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_val_noheur.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_val_noheur.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1dd72c0afddaf1417ad3e1e757f1047eaf3d8f1b --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_smartglasses_val_noheur.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_smartglasses_val + experiment_log_dir: + ytvis_json: /saco_veval_smartglasses_val.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: False + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_test.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a94ed749c84a430e553a9e3f819a5c88b1169fc4 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_test.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_yt1b_test + experiment_log_dir: + ytvis_json: /saco_veval_yt1b_test.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: True + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_test_noheur.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_test_noheur.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2c4b385f13ba99d7e4c0df514c45a1b4df81df6 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_test_noheur.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_yt1b_test + experiment_log_dir: + ytvis_json: /saco_veval_yt1b_test.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: False + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_val.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0efe0274f4bd419a3ecd444d1f5236fd0a34a5ee --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_val.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_yt1b_val + experiment_log_dir: + ytvis_json: /saco_veval_yt1b_val.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: True + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_val_noheur.yaml b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_val_noheur.yaml new file mode 100644 index 0000000000000000000000000000000000000000..108b563efd34a2110be9daa2a140705434f32f22 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/saco_video_evals/saco_veval_yt1b_val_noheur.yaml @@ -0,0 +1,174 @@ +# @package _global_ +defaults: + - _self_ + +# ============================================================================ +# Paths Configuration (Chage this to your own paths) +# ============================================================================ +paths: + + dump_file_name: saco_veval_yt1b_val + experiment_log_dir: + ytvis_json: /saco_veval_yt1b_val.json + ytvis_dir : + bpe_path: # This should be under sam3/assets/bpe_simple_vocab_16e6.txt.gz + num_videos: null + +# ============================================================================ +# Different helper parameters and functions +# ============================================================================ +scratch: + vid_mask_postprocessor: + _target_: sam3.eval.postprocessors.PostProcessNullOp + + use_presence_eval: True + + video_transforms_val: + - _target_: sam3.train.transforms.basic_for_api.ComposeAPI + transforms: + - _target_: sam3.train.transforms.segmentation.DecodeRle + # resize the image to 1024x1024 resolution + - _target_: sam3.train.transforms.basic_for_api.RandomResizeAPI + sizes: ${scratch.resolution} # originally `resolution: 1024` + square: true + consistent_transform: true + - _target_: sam3.train.transforms.basic_for_api.ToTensorAPI + - _target_: sam3.train.transforms.basic_for_api.NormalizeAPI + mean: ${scratch.val_norm_mean} + std: ${scratch.val_norm_std} + + # Model parameters + d_model: 256 + + # Image processing parameters + resolution: 1008 + + # Normalization parameters + train_norm_mean: [0.5, 0.5, 0.5] + train_norm_std: [0.5, 0.5, 0.5] + val_norm_mean: [0.5, 0.5, 0.5] + val_norm_std: [0.5, 0.5, 0.5] + + val_batch_size: 1 + num_val_workers: 0 + max_data_epochs: 20 + hybrid_repeats: 1 + gather_pred_via_filesys: false + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + _target_: sam3.train.trainer.Trainer + skip_saving_ckpts: true + empty_gpu_mem_cache_after_eval: True + skip_first_val: True + max_epochs: ${scratch.max_data_epochs} + accelerator: cuda + seed_value: 123 + val_epoch_freq: 10 + mode: val + + distributed: + backend: nccl + find_unused_parameters: True + gradient_as_bucket_view: True + + loss: + all: + _target_: sam3.train.loss.sam3_loss.DummyLoss + default: + _target_: sam3.train.loss.sam3_loss.DummyLoss + + data: + train: null + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_video_dataset.VideoGroundingDataset + limit_ids: ${paths.num_videos} + img_folder: ${paths.ytvis_dir} + ann_file: ${paths.ytvis_json} + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_VEVAL_API_FROM_JSON_NP + _partial_: true + + transforms: ${scratch.video_transforms_val} + max_ann_per_img: 100000 # filtered in transforms + max_val_queries: 100000 + multiplier: 1 + load_segmentation: true + training: false + + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: True + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: ytvis_val + with_seg_masks: true + + + model: + _target_: sam3.model_builder.build_sam3_video_model + bpe_path: ${paths.bpe_path} + has_presence_token: True + geo_encoder_use_img_cross_attn: True + apply_temporal_disambiguation: False + + meters: + val: + ytvis_val: + pred_file: # key + _target_: sam3.eval.ytvis_eval.YTVISResultsWriter + dump_file: ${launcher.experiment_log_dir}/preds/${paths.dump_file_name}.json + postprocessor: ${scratch.vid_mask_postprocessor} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + + optim: + amp: + enabled: True + amp_dtype: bfloat16 + + + checkpoint: + save_dir: ${launcher.experiment_log_dir}/checkpoints + save_freq: 0 # 0 only last checkpoint is saved. + + + logging: + tensorboard_writer: + _target_: sam3.train.utils.logger.make_tensorboard_logger + log_dir: ${launcher.experiment_log_dir}/tensorboard + flush_secs: 120 + should_log: True + wandb_writer: null + log_dir: ${launcher.experiment_log_dir}/logs/ + log_freq: 10 + +# ============================================================================ +# Launcher and Submitit Configuration +# ============================================================================ + +launcher: + num_nodes: 8 + gpus_per_node: 8 + experiment_log_dir: ${paths.experiment_log_dir} + multiprocessing_context: forkserver + +submitit: + account: null + partition: null + qos: null + timeout_hour: 72 + use_cluster: True + cpus_per_task: 10 + port_range: [10000, 65000] + constraint: null diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_bdd100k.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_bdd100k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e5587cfb76237bfa6db8b5467632b2691f876cdf --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_bdd100k.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_bdd100k/ + coco_gt: ${paths.base_annotation_path_silver}/silver_bdd100k_merged_test.json + img_path: ${paths.silver_img_path}/bdd100k/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_bdd100k + + meters: + val: + silver_bdd100k: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_bdd100k + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_droid.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_droid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0d62341ba915d5a04f9fc4d88d057aed15848f7 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_droid.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_droid/ + coco_gt: ${paths.base_annotation_path_silver}/silver_droid_merged_test.json + img_path: ${paths.silver_img_path}/droid/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_droid + + meters: + val: + silver_droid: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_droid + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_ego4d.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_ego4d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5a036d93d44a093755462cf748b2ed66a1e8a4f --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_ego4d.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_ego4d/ + coco_gt: ${paths.base_annotation_path_silver}/silver_ego4d_merged_test.json + img_path: ${paths.silver_img_path}/ego4d/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_ego4d + + meters: + val: + silver_ego4d: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_ego4d + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_fathomnet.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_fathomnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b15d0c82328171d8ed0c9d4b52c35477813ca389 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_fathomnet.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_fathomnet/ + coco_gt: ${paths.base_annotation_path_silver}/silver_fathomnet_test.json + img_path: ${paths.silver_img_path}/fathomnet/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_fathomnet + + meters: + val: + silver_fathomnet: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_fathomnet + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_food_rec.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_food_rec.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5158ff551e5d2babb1100ba1978e0da4613bac8d --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_food_rec.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_food_rec/ + coco_gt: ${paths.base_annotation_path_silver}/silver_food_rec_merged_test.json + img_path: ${paths.silver_img_path}/food_rec/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_food_rec + + meters: + val: + silver_food_rec: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_food_rec + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_geode.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_geode.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08f159fe9bc80072d8bd4a95f911bc70a555588d --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_geode.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_geode/ + coco_gt: ${paths.base_annotation_path_silver}/silver_geode_merged_test.json + img_path: ${paths.silver_img_path}/geode/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_geode + + meters: + val: + silver_geode: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_geode + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_inaturalist.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_inaturalist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d56d9758d8cb5711911b60278fe454c975a8456 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_inaturalist.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_inaturalist/ + coco_gt: ${paths.base_annotation_path_silver}/silver_inaturalist_merged_test.json + img_path: ${paths.silver_img_path}/inaturalist/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_inaturalist + + meters: + val: + silver_inaturalist: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_inaturalist + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_nga.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_nga.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2de0afed4289272ca36634a911a7a1d38e03aa3 --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_nga.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_nga_art/ + coco_gt: ${paths.base_annotation_path_silver}/silver_nga_art_merged_test.json + img_path: ${paths.silver_img_path}/nga/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_nga_art + + meters: + val: + silver_nga_art: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_nga_art + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_sav.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_sav.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ebbb0f2bbee9b06221a46d362b8ca719bed9b4b --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_sav.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_sav/ + coco_gt: ${paths.base_annotation_path_silver}/silver_sav_merged_test.json + img_path: ${paths.silver_img_path}/sav/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_sav + + meters: + val: + silver_sav: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_sav + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_yt1b.yaml b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_yt1b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..901bd3a050f4041364b36299fa01648ce576d0af --- /dev/null +++ b/third_party/sam3/sam3/train/configs/silver_image_evals/sam3_silver_image_yt1b.yaml @@ -0,0 +1,64 @@ +# @package _global_ +defaults: + - /configs/eval_base.yaml + - _self_ + +# ============================================================================ +# Paths Configuration (you can override here, but it shouldn't require further changes if eval_base.yaml is correct +# ============================================================================ +paths: + experiment_log_dir: ${paths.base_experiment_log_dir}/silver_yt1b/ + coco_gt: ${paths.base_annotation_path_silver}/silver_yt1b_merged_test.json + img_path: ${paths.silver_img_path}/yt1b/ + + + +# ============================================================================ +# Trainer Configuration +# ============================================================================ + +trainer: + data: + val: + _target_: sam3.train.data.torch_dataset.TorchDataset + dataset: + _target_: sam3.train.data.sam3_image_dataset.Sam3ImageDataset + coco_json_loader: + _target_: sam3.train.data.coco_json_loaders.SAM3_EVAL_API_FROM_JSON_NP + _partial_: true + img_folder: ${paths.img_path} + ann_file: ${paths.coco_gt} + transforms: ${scratch.base_val_transform} + max_ann_per_img: 100000 + multiplier: 1 + training: false + + shuffle: False + batch_size: ${scratch.val_batch_size} + num_workers: ${scratch.num_val_workers} + pin_memory: False + drop_last: False + collate_fn: + _target_: sam3.train.data.collator.collate_fn_api + _partial_: true + repeats: ${scratch.hybrid_repeats} + dict_key: silver_yt1b + + meters: + val: + silver_yt1b: # this key matches the "dict_key" in the dataloader's collate function + cgf1: + _target_: sam3.eval.coco_writer.PredictionDumper + iou_type: "segm" + dump_dir: ${launcher.experiment_log_dir}/dumps/silver_yt1b + merge_predictions: True + postprocessor: ${scratch.mask_postprocessor_thresholded} + gather_pred_via_filesys: ${scratch.gather_pred_via_filesys} + maxdets: 1000000 # no limit + pred_file_evaluators: + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "bbox" + - _target_: sam3.eval.cgf1_eval.CGF1Evaluator + gt_path: ${paths.coco_gt} + iou_type: "segm" diff --git a/third_party/sam3/sam3/train/data/__init__.py b/third_party/sam3/sam3/train/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/train/data/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/train/data/__pycache__/__init__.cpython-311.pyc b/third_party/sam3/sam3/train/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fe7d5552b90e6919852c504ee0c4d60e1dd8a99 Binary files /dev/null and b/third_party/sam3/sam3/train/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/train/data/__pycache__/coco_json_loaders.cpython-311.pyc b/third_party/sam3/sam3/train/data/__pycache__/coco_json_loaders.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e53fb1362730fa9f73bb8e439cc982ea45c4cc4a Binary files /dev/null and b/third_party/sam3/sam3/train/data/__pycache__/coco_json_loaders.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/train/data/__pycache__/collator.cpython-311.pyc b/third_party/sam3/sam3/train/data/__pycache__/collator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2bfd83f8f0451b0788323ec40e62c075f1bc987 Binary files /dev/null and b/third_party/sam3/sam3/train/data/__pycache__/collator.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/train/data/__pycache__/sam3_image_dataset.cpython-311.pyc b/third_party/sam3/sam3/train/data/__pycache__/sam3_image_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6a164392c7fec59acd0f1473537ab9bee4bfe27 Binary files /dev/null and b/third_party/sam3/sam3/train/data/__pycache__/sam3_image_dataset.cpython-311.pyc differ diff --git a/third_party/sam3/sam3/train/data/coco_json_loaders.py b/third_party/sam3/sam3/train/data/coco_json_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..1618e193cd5ecebdea58e3e103bb8ab12ccab1d1 --- /dev/null +++ b/third_party/sam3/sam3/train/data/coco_json_loaders.py @@ -0,0 +1,467 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import json +from collections import defaultdict +from typing import Dict, List, Tuple + +import torch +from pycocotools import mask as mask_util + + +# ============================================================================ +# Utility Functions +# ============================================================================ + + +def convert_boxlist_to_normalized_tensor(box_list, image_width, image_height): + """ + Converts a list of bounding boxes to a normalized PyTorch tensor. + + Args: + box_list (list of list or tuples): Each box is [x_min, y_min, x_max, y_max]. + image_width (int or float): Width of the image. + image_height (int or float): Height of the image. + + Returns: + torch.Tensor: Normalized tensor of shape (N, 4), values in [0, 1]. + """ + boxes = torch.tensor(box_list, dtype=torch.float32) + boxes[:, [0, 2]] /= image_width # x_min, x_max + boxes[:, [1, 3]] /= image_height # y_min, y_max + boxes = boxes.clamp(0, 1) + return boxes + + +def load_coco_and_group_by_image(json_path: str) -> Tuple[List[Dict], Dict[int, str]]: + """ + Load COCO JSON file and group annotations by image. + + Args: + json_path (str): Path to COCO JSON file. + + Returns: + Tuple containing: + - List of dicts with 'image' and 'annotations' keys + - Dict mapping category IDs to category names + """ + with open(json_path, "r") as f: + coco = json.load(f) + + images = {img["id"]: img for img in coco["images"]} + + anns_by_image = defaultdict(list) + for ann in coco["annotations"]: + anns_by_image[ann["image_id"]].append(ann) + + sorted_image_ids = sorted(images.keys()) + + grouped = [] + for image_id in sorted_image_ids: + image_info = images[image_id] + grouped.append( + {"image": image_info, "annotations": anns_by_image.get(image_id, [])} + ) + + cat_id_to_name = {cat["id"]: cat["name"] for cat in coco["categories"]} + + return grouped, cat_id_to_name + + +def ann_to_rle(segm, im_info: Dict) -> Dict: + """ + Convert annotation which can be polygons or uncompressed RLE to RLE. + + Args: + segm: Segmentation data (polygon list or RLE dict) + im_info (dict): Image info containing 'height' and 'width' + + Returns: + RLE encoded segmentation + """ + h, w = im_info["height"], im_info["width"] + + if isinstance(segm, list): + # Polygon - merge all parts into one mask RLE code + rles = mask_util.frPyObjects(segm, h, w) + rle = mask_util.merge(rles) + elif isinstance(segm["counts"], list): + # Uncompressed RLE + rle = mask_util.frPyObjects(segm, h, w) + else: + # Already RLE + rle = segm + + return rle + + +# ============================================================================ +# COCO Training API +# ============================================================================ + + +class COCO_FROM_JSON: + """ + COCO training API for loading box-only annotations from JSON. + Groups all annotations per image and creates queries per category. + """ + + def __init__( + self, + annotation_file, + prompts=None, + include_negatives=True, + category_chunk_size=None, + ): + """ + Initialize the COCO training API. + + Args: + annotation_file (str): Path to COCO JSON annotation file + prompts: Optional custom prompts for categories + include_negatives (bool): Whether to include negative examples (categories with no instances) + """ + self._raw_data, self._cat_idx_to_text = load_coco_and_group_by_image( + annotation_file + ) + self._sorted_cat_ids = sorted(list(self._cat_idx_to_text.keys())) + self.prompts = None + self.include_negatives = include_negatives + self.category_chunk_size = ( + category_chunk_size + if category_chunk_size is not None + else len(self._sorted_cat_ids) + ) + self.category_chunks = [ + self._sorted_cat_ids[i : i + self.category_chunk_size] + for i in range(0, len(self._sorted_cat_ids), self.category_chunk_size) + ] + if prompts is not None: + prompts = eval(prompts) + self.prompts = {} + for loc_dict in prompts: + self.prompts[int(loc_dict["id"])] = loc_dict["name"] + assert len(self.prompts) == len( + self._sorted_cat_ids + ), "Number of prompts must match number of categories" + + def getDatapointIds(self): + """Return all datapoint indices for training.""" + return list(range(len(self._raw_data) * len(self.category_chunks))) + + def loadQueriesAndAnnotationsFromDatapoint(self, idx): + """ + Load queries and annotations for a specific datapoint. + + Args: + idx (int): Datapoint index + + Returns: + Tuple of (queries, annotations) lists + """ + img_idx = idx // len(self.category_chunks) + chunk_idx = idx % len(self.category_chunks) + cat_chunk = self.category_chunks[chunk_idx] + + queries = [] + annotations = [] + + query_template = { + "id": None, + "original_cat_id": None, + "object_ids_output": None, + "query_text": None, + "query_processing_order": 0, + "ptr_x_query_id": None, + "ptr_y_query_id": None, + "image_id": 0, # Single image per datapoint + "input_box": None, + "input_box_label": None, + "input_points": None, + "is_exhaustive": True, + } + + annot_template = { + "image_id": 0, + "bbox": None, # Normalized bbox in xywh + "area": None, # Unnormalized area + "segmentation": None, # RLE encoded + "object_id": None, + "is_crowd": None, + "id": None, + } + + raw_annotations = self._raw_data[img_idx]["annotations"] + image_info = self._raw_data[img_idx]["image"] + width, height = image_info["width"], image_info["height"] + + # Group annotations by category + cat_id_to_anns = defaultdict(list) + for ann in raw_annotations: + cat_id_to_anns[ann["category_id"]].append(ann) + + annotations_by_cat_sorted = [ + (cat_id, cat_id_to_anns[cat_id]) for cat_id in cat_chunk + ] + + for cat_id, anns in annotations_by_cat_sorted: + if len(anns) == 0 and not self.include_negatives: + continue + + cur_ann_ids = [] + + # Create annotations for this category + for ann in anns: + annotation = annot_template.copy() + annotation["id"] = len(annotations) + annotation["object_id"] = annotation["id"] + annotation["is_crowd"] = ann["iscrowd"] + + normalized_boxes = convert_boxlist_to_normalized_tensor( + [ann["bbox"]], width, height + ) + bbox = normalized_boxes[0] + + annotation["area"] = (bbox[2] * bbox[3]).item() + annotation["bbox"] = bbox + + if ( + "segmentation" in ann + and ann["segmentation"] is not None + and ann["segmentation"] != [] + ): + annotation["segmentation"] = ann_to_rle( + ann["segmentation"], im_info=image_info + ) + + annotations.append(annotation) + cur_ann_ids.append(annotation["id"]) + + # Create query for this category + query = query_template.copy() + query["id"] = len(queries) + query["original_cat_id"] = cat_id + query["query_text"] = ( + self._cat_idx_to_text[cat_id] + if self.prompts is None + else self.prompts[cat_id] + ) + query["object_ids_output"] = cur_ann_ids + queries.append(query) + + return queries, annotations + + def loadImagesFromDatapoint(self, idx): + """ + Load image information for a specific datapoint. + + Args: + idx (int): Datapoint index + + Returns: + List containing image info dict + """ + img_idx = idx // len(self.category_chunks) + img_data = self._raw_data[img_idx]["image"] + images = [ + { + "id": 0, + "file_name": img_data["file_name"], + "original_img_id": img_data["id"], + "coco_img_id": img_data["id"], + } + ] + return images + + +# ============================================================================ +# SAM3 Evaluation APIs +# ============================================================================ + + +class SAM3_EVAL_API_FROM_JSON_NP: + """ + SAM3 evaluation API for loading noun phrase queries from JSON. + """ + + def __init__(self, annotation_file): + """ + Initialize the SAM3 evaluation API. + + Args: + annotation_file (str): Path to SAM3 JSON annotation file + """ + with open(annotation_file, "r") as f: + data = json.load(f) + self._image_data = data["images"] + + def getDatapointIds(self): + """Return all datapoint indices.""" + return list(range(len(self._image_data))) + + def loadQueriesAndAnnotationsFromDatapoint(self, idx): + """ + Load queries and annotations for a specific datapoint. + + Args: + idx (int): Datapoint index + + Returns: + Tuple of (queries, annotations) lists + """ + cur_img_data = self._image_data[idx] + queries = [] + annotations = [] + + query_template = { + "id": None, + "original_cat_id": None, + "object_ids_output": None, + "query_text": None, + "query_processing_order": 0, + "ptr_x_query_id": None, + "ptr_y_query_id": None, + "image_id": 0, + "input_box": None, + "input_box_label": None, + "input_points": None, + "is_exhaustive": True, + } + + # Create query + query = query_template.copy() + query["id"] = len(queries) + query["original_cat_id"] = int(cur_img_data["queried_category"]) + query["query_text"] = cur_img_data["text_input"] + query["object_ids_output"] = [] + queries.append(query) + + return queries, annotations + + def loadImagesFromDatapoint(self, idx): + """ + Load image information for a specific datapoint. + + Args: + idx (int): Datapoint index + + Returns: + List containing image info dict + """ + img_data = self._image_data[idx] + images = [ + { + "id": 0, + "file_name": img_data["file_name"], + "original_img_id": img_data["id"], + "coco_img_id": img_data["id"], + } + ] + return images + + +class SAM3_VEVAL_API_FROM_JSON_NP: + """ + SAM3 video evaluation API for loading noun phrase queries from JSON. + """ + + def __init__(self, annotation_file): + """ + Initialize the SAM3 video evaluation API. + + Args: + annotation_file (str): Path to SAM3 video JSON annotation file + """ + with open(annotation_file, "r") as f: + data = json.load(f) + + assert "video_np_pairs" in data, "Incorrect data format" + + self._video_data = data["videos"] + self._video_id_to_np_ids = defaultdict(list) + self._cat_id_to_np = {} + + for cat_dict in data["categories"]: + self._cat_id_to_np[cat_dict["id"]] = cat_dict["name"] + + for video_np_dict in data["video_np_pairs"]: + self._video_id_to_np_ids[video_np_dict["video_id"]].append( + video_np_dict["category_id"] + ) + assert ( + self._cat_id_to_np[video_np_dict["category_id"]] + == video_np_dict["noun_phrase"] + ), "Category name does not match text input" + + def getDatapointIds(self): + """Return all datapoint indices.""" + return list(range(len(self._video_data))) + + def loadQueriesAndAnnotationsFromDatapoint(self, idx): + """ + Load queries and annotations for a specific video datapoint. + + Args: + idx (int): Datapoint index + + Returns: + Tuple of (queries, annotations) lists + """ + cur_vid_data = self._video_data[idx] + queries = [] + annotations = [] + + query_template = { + "id": None, + "original_cat_id": None, + "object_ids_output": None, + "query_text": None, + "query_processing_order": 0, + "ptr_x_query_id": None, + "ptr_y_query_id": None, + "image_id": 0, + "input_box": None, + "input_box_label": None, + "input_points": None, + "is_exhaustive": True, + } + + all_np_ids = self._video_id_to_np_ids[cur_vid_data["id"]] + + for np_id in all_np_ids: + text_input = self._cat_id_to_np[np_id] + + for i, image_path in enumerate(cur_vid_data["file_names"]): + query = query_template.copy() + query["id"] = len(queries) + query["original_cat_id"] = np_id + query["query_text"] = text_input + query["image_id"] = i + query["query_processing_order"] = i + query["object_ids_output"] = [] + queries.append(query) + + return queries, annotations + + def loadImagesFromDatapoint(self, idx): + """ + Load image information for a specific video datapoint. + + Args: + idx (int): Datapoint index + + Returns: + List containing image info dicts for all frames + """ + video_data = self._video_data[idx] + images = [ + { + "id": i, + "file_name": file_name, + "original_img_id": video_data["id"], + "coco_img_id": video_data["id"], + } + for i, file_name in enumerate(video_data["file_names"]) + ] + return images diff --git a/third_party/sam3/sam3/train/data/collator.py b/third_party/sam3/sam3/train/data/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce4e28064d2d868b5dd419dd42bee649da01f5e --- /dev/null +++ b/third_party/sam3/sam3/train/data/collator.py @@ -0,0 +1,361 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass +from typing import Any, get_args, get_origin, List, Union + +import torch +from sam3.model.data_misc import ( + BatchedDatapoint, + BatchedFindTarget, + BatchedInferenceMetadata, + FindStage, +) + +from .sam3_image_dataset import Datapoint + + +MyTensor = Union[torch.Tensor, List[Any]] + + +def convert_my_tensors(obj): + def is_optional_field(field) -> bool: + return get_origin(field) is Union and type(None) in get_args(field) + + for field in fields(obj): + if is_dataclass(getattr(obj, field.name)): + convert_my_tensors(getattr(obj, field.name)) + continue + + field_type = field.type + if is_optional_field(field.type): + field_type = Union[get_args(field.type)[:-1]] # Get the Optional field type + + if field_type != MyTensor or getattr(obj, field.name) is None: + continue + + elif len(getattr(obj, field.name)) and isinstance( + getattr(obj, field.name)[0], torch.Tensor + ): + stack_dim = 0 + if field.name in [ + "input_boxes", + "input_boxes_label", + ]: + stack_dim = 1 + setattr( + obj, + field.name, + torch.stack(getattr(obj, field.name), dim=stack_dim).to( + getattr(obj, field.name + "__type") + ), + ) + else: + setattr( + obj, + field.name, + torch.as_tensor( + getattr(obj, field.name), dtype=getattr(obj, field.name + "__type") + ), + ) + return obj + + +def packed_to_padded_naive(boxes_packed, num_boxes, fill_value=0): + """ + Convert a packed tensor of bounding boxes to a padded tensor of bounding + boxes. Naive implementation using a loop. + + Inputs: + - boxes_packed: Tensor of shape (N_1 + ... + N_B, 4) + - num_boxes: Tensor of shape (B,) where num_boxes[i] = N_i + + Returns: + - boxes_padded: Tensor of shape (B, N_max, 4) where N_max = max_i N_i + """ + B = num_boxes.shape[0] + Ns = num_boxes.tolist() + + boxes_padded = boxes_packed.new_zeros(B, max(Ns), *boxes_packed.shape[1:]) + if fill_value != 0: + boxes_padded[...] = fill_value + prev_idx = 0 + for i in range(B): + next_idx = prev_idx + Ns[i] + boxes_padded[i, : Ns[i]] = boxes_packed[prev_idx:next_idx] + prev_idx = next_idx + return boxes_padded + + +def pad_tensor_list_to_longest( + tensors: List[torch.Tensor], dim=0, pad_val=0 +) -> List[torch.Tensor]: + # Edits the list in-place + if not tensors: + return tensors + pad_len = max(t.shape[dim] for t in tensors) + for i in range(len(tensors)): + n_dims = len(tensors[i].shape) + n_right_dims = (n_dims - 1) - (n_dims + dim) % n_dims + n_pad = pad_len - tensors[i].shape[dim] + pad_tuple = tuple([0] * 2 * n_right_dims + [0, n_pad]) + tensors[i] = torch.nn.functional.pad(tensors[i], pad_tuple, value=pad_val) + return tensors + + +def collate_fn_api_with_chunking( + batch, + num_chunks, + dict_key, + with_seg_masks=False, + input_points_embedding_dim=257, + repeats: int = 0, + load_image_in_fp16: bool = False, +): + assert num_chunks >= 1, "num_chunks must be >= 1" + + # split the batch into num_chunks chunks + batch_chunks = [batch[i::num_chunks] for i in range(num_chunks)] + + # collate each chunk + collated_chunks = [ + collate_fn_api( + chunk, + dict_key, + with_seg_masks, + input_points_embedding_dim, + repeats, + # ptr_behaviour, + load_image_in_fp16, + ) + for chunk in batch_chunks + ] + return collated_chunks + + +def collate_fn_api( + batch: List[Datapoint], + dict_key, + with_seg_masks=False, + input_points_embedding_dim=257, + repeats: int = 0, + load_image_in_fp16: bool = False, +): + # img_batch = torch.stack(sum([[img.data for img in v.images] for v in batch], [])) + img_batch = [] + text_batch = [] + raw_images = None + + num_stages = ( + max(q.query_processing_order for data in batch for q in data.find_queries) + 1 + ) + + stages = [ + FindStage( + img_ids=[], + text_ids=[], + input_boxes=[], + input_boxes_label=[], + input_boxes_mask=[], + input_points=[], + input_points_mask=[], + object_ids=[], + ) + for _ in range(num_stages) + ] + find_targets = [ + BatchedFindTarget( + num_boxes=[], + boxes=[], + boxes_padded=[], + is_exhaustive=[], + segments=[], + semantic_segments=[], + is_valid_segment=[], + repeated_boxes=[], + object_ids=[], + object_ids_padded=[], + ) + for _ in range(num_stages) + ] + find_metadatas = [ + BatchedInferenceMetadata( + coco_image_id=[], + original_size=[], + object_id=[], + frame_index=[], + original_image_id=[], + original_category_id=[], + is_conditioning_only=[], + ) + for _ in range(num_stages) + ] + + offset_img_id = 0 + offset_query_id = [0 for _ in range(num_stages)] + for data in batch: + img_batch.extend([img.data for img in data.images]) + + if data.raw_images is not None: + if raw_images is None: + raw_images = [] + raw_images.extend(data.raw_images) + + # Conversion of query_ids indexing in a datapoint to query_ids indexing in a stage + datapoint_query_id_2_stage_query_id = [] + for q in data.find_queries: + stage_id = q.query_processing_order + datapoint_query_id_2_stage_query_id.append(offset_query_id[stage_id]) + offset_query_id[stage_id] += 1 + + for q in data.find_queries: + stage_id = q.query_processing_order + stages[stage_id].img_ids.append(q.image_id + offset_img_id) + if q.query_text not in text_batch: + text_batch.append(q.query_text) + stages[stage_id].text_ids.append(text_batch.index(q.query_text)) + + assert ( + q.inference_metadata is not None + ), "inference_metadata must be provided when FindQueryLoaded is created." + for f in fields(q.inference_metadata): + getattr(find_metadatas[stage_id], f.name).append( + getattr(q.inference_metadata, f.name) + ) + + if q.input_bbox is not None: + assert q.input_bbox.numel() % 4 == 0 + assert q.input_bbox_label is not None + nb_boxes = q.input_bbox.numel() // 4 + assert len(q.input_bbox_label) == nb_boxes + stages[stage_id].input_boxes.append(q.input_bbox.view(nb_boxes, 4)) + stages[stage_id].input_boxes_label.append( + q.input_bbox_label.view(nb_boxes) + ) + stages[stage_id].input_boxes_mask.append( + torch.zeros(nb_boxes, dtype=torch.bool) + ) + else: + stages[stage_id].input_boxes.append(torch.zeros(0, 4)) + stages[stage_id].input_boxes_label.append( + torch.zeros(0, dtype=torch.bool) + ) + stages[stage_id].input_boxes_mask.append( + torch.ones(0, dtype=torch.bool) + ) + + if q.input_points is not None: + stages[stage_id].input_points.append( + q.input_points.squeeze(0) # Strip a trivial batch index + ) + # All masks will be padded up to the longest length + # with 1s before final conversion to batchd tensors + stages[stage_id].input_points_mask.append( + torch.zeros(q.input_points.shape[1]) + ) + else: + stages[stage_id].input_points.append( + torch.empty(0, input_points_embedding_dim) + ) + stages[stage_id].input_points_mask.append(torch.empty(0)) + + current_out_boxes = [] + current_out_object_ids = [] + # Set the object ids referred to by this query + stages[stage_id].object_ids.append(q.object_ids_output) + for object_id in q.object_ids_output: + current_out_boxes.append( + data.images[q.image_id].objects[object_id].bbox + ) + current_out_object_ids.append(object_id) + find_targets[stage_id].boxes.extend(current_out_boxes) + find_targets[stage_id].object_ids.extend(current_out_object_ids) + if repeats > 0: + for _ in range(repeats): + find_targets[stage_id].repeated_boxes.extend(current_out_boxes) + find_targets[stage_id].num_boxes.append(len(current_out_boxes)) + find_targets[stage_id].is_exhaustive.append(q.is_exhaustive) + + if with_seg_masks: + current_seg_mask = [] + current_is_valid_segment = [] + for object_id in q.object_ids_output: + seg_mask = data.images[q.image_id].objects[object_id].segment + if seg_mask is not None: + current_seg_mask.append(seg_mask) + current_is_valid_segment.append(1) + else: + dummy_mask = torch.zeros( + data.images[q.image_id].data.shape[-2:], dtype=torch.bool + ) + current_seg_mask.append(dummy_mask) + current_is_valid_segment.append(0) + find_targets[stage_id].segments.extend(current_seg_mask) + find_targets[stage_id].is_valid_segment.extend(current_is_valid_segment) + else: + # We are not loading segmentation masks + find_targets[stage_id].segments = None + find_targets[stage_id].is_valid_segment = None + + if q.semantic_target is not None: + find_targets[stage_id].semantic_segments.append(q.semantic_target) + + offset_img_id += len(data.images) + + # Pad input points to equal sequence lengths + for i in range(len(stages)): + stages[i].input_points = pad_tensor_list_to_longest( + stages[i].input_points, dim=0, pad_val=0 + ) + # Masked-out regions indicated by 1s. + stages[i].input_points_mask = pad_tensor_list_to_longest( + stages[i].input_points_mask, dim=0, pad_val=1 + ) + + # Pad input boxes to equal sequence lengths + for i in range(len(stages)): + stages[i].input_boxes = pad_tensor_list_to_longest( + stages[i].input_boxes, dim=0, pad_val=0 + ) + stages[i].input_boxes_label = pad_tensor_list_to_longest( + stages[i].input_boxes_label, dim=0, pad_val=0 + ) + # Masked-out regions indicated by 1s. + stages[i].input_boxes_mask = pad_tensor_list_to_longest( + stages[i].input_boxes_mask, dim=0, pad_val=1 + ) + + # Convert to tensors + for i in range(len(stages)): + stages[i] = convert_my_tensors(stages[i]) + find_targets[i] = convert_my_tensors(find_targets[i]) + find_metadatas[i] = convert_my_tensors(find_metadatas[i]) + # get padded representation for the boxes + find_targets[i].boxes_padded = packed_to_padded_naive( + find_targets[i].boxes.view(-1, 4), find_targets[i].num_boxes + ) + find_targets[i].object_ids_padded = packed_to_padded_naive( + find_targets[i].object_ids, find_targets[i].num_boxes, fill_value=-1 + ) + + # Finalize the image batch + # check sizes + for img in img_batch[1:]: + assert img.shape == img_batch[0].shape, "All images must have the same size" + image_batch = torch.stack(img_batch) + if load_image_in_fp16: + # Optionally, cast the image tensors to fp16, which helps save GPU memory on + # long videos with thousands of frames (where image tensors could be several GBs) + image_batch = image_batch.half() + + return { + dict_key: BatchedDatapoint( + img_batch=image_batch, + find_text_batch=text_batch, + find_inputs=stages, + find_targets=find_targets, + find_metadatas=find_metadatas, + raw_images=raw_images, + ) + } diff --git a/third_party/sam3/sam3/train/data/sam3_image_dataset.py b/third_party/sam3/sam3/train/data/sam3_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dca5b971a13df0aaa5e6061d2dfda7b42f6f2f26 --- /dev/null +++ b/third_party/sam3/sam3/train/data/sam3_image_dataset.py @@ -0,0 +1,529 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Dataset class for modulated detection""" + +import json +import os +import random +import sys +import traceback +from collections import Counter +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.data +import torchvision +from iopath.common.file_io import g_pathmgr +from PIL import Image as PILImage +from PIL.Image import DecompressionBombError +from sam3.model.box_ops import box_xywh_to_xyxy +from torchvision.datasets.vision import VisionDataset + +from .coco_json_loaders import COCO_FROM_JSON + + +@dataclass +class InferenceMetadata: + """Metadata required for postprocessing""" + + # Coco id that corresponds to the "image" for evaluation by the coco evaluator + # This is used for our own "class agnostic" evaluation + coco_image_id: int + + # id in the original dataset, such that we can use the original evaluator + original_image_id: int + + # Original category id (if we want to use the original evaluator) + original_category_id: int + + # Size of the raw image (height, width) + original_size: Tuple[int, int] + + # Id of the object in the media + object_id: int + + # Index of the frame in the media (0 if single image) + frame_index: int + + # Whether it is for conditioning only, e.g., 0-th frame in TA is for conditioning + # as we assume GT available in frame 0. + is_conditioning_only: Optional[bool] = False + + +@dataclass +class FindQuery: + query_text: str + + image_id: int + + # In case of a find query, the list of object ids that have to be predicted + object_ids_output: List[int] + + # This is "instance exhaustivity". + # true iff all instances are separable and annotated + # See below the slightly different "pixel exhaustivity" + is_exhaustive: bool + + # The order in which the queries are processed (only meaningful for video) + query_processing_order: int = 0 + + # Input geometry, initially in denormalized XYXY format. Then + # 1. converted to normalized CxCyWH by the Normalize transform + input_bbox: Optional[torch.Tensor] = None + input_bbox_label: Optional[torch.Tensor] = None + + # Only for the PVS task + input_points: Optional[torch.Tensor] = None + + semantic_target: Optional[torch.Tensor] = None + + # pixel exhaustivity: true iff the union of all segments (including crowds) + # covers every pixel belonging to the target class + # Note that instance_exhaustive implies pixel_exhaustive + is_pixel_exhaustive: Optional[bool] = None + + +@dataclass +class FindQueryLoaded(FindQuery): + # Must have default value since FindQuery has entries with default values + inference_metadata: Optional[InferenceMetadata] = None + + +@dataclass +class Object: + # Initially in denormalized XYXY format, gets converted to normalized CxCyWH by the Normalize transform + bbox: torch.Tensor + area: float + + # Id of the object in the media + object_id: Optional[int] = -1 + + # Index of the frame in the media (0 if single image) + frame_index: Optional[int] = -1 + + segment: Optional[Union[torch.Tensor, dict]] = None # RLE dict or binary mask + + is_crowd: bool = False + + source: Optional[str] = None + + +@dataclass +class Image: + data: Union[torch.Tensor, PILImage.Image] + objects: List[Object] + size: Tuple[int, int] # (height, width) + + # For blurring augmentation + blurring_mask: Optional[Dict[str, Any]] = None + + +@dataclass +class Datapoint: + """Refers to an image/video and all its annotations""" + + find_queries: List[FindQueryLoaded] + images: List[Image] + raw_images: Optional[List[PILImage.Image]] = None + + +class CustomCocoDetectionAPI(VisionDataset): + """`MS Coco Detection `_ Dataset. + + Args: + root (string): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry + and returns a transformed version. + """ + + def __init__( + self, + root: str, + annFile: str, + load_segmentation: bool, + fix_fname: bool = False, + training: bool = True, + blurring_masks_path: Optional[str] = None, + use_caching: bool = True, + zstd_dict_path=None, + filter_query=None, + coco_json_loader: Callable = COCO_FROM_JSON, + limit_ids: int = None, + ) -> None: + super().__init__(root) + + self.annFile = annFile + self.use_caching = use_caching + self.zstd_dict_path = zstd_dict_path + + self.curr_epoch = 0 # Used in case data loader behavior changes across epochs + self.load_segmentation = load_segmentation + self.fix_fname = fix_fname + self.filter_query = filter_query + + self.coco = None + self.coco_json_loader = coco_json_loader + self.limit_ids = limit_ids + self.set_sharded_annotation_file(0) + self.training = training + self.blurring_masks_path = blurring_masks_path + + def _load_images( + self, datapoint_id: int, img_ids_to_load: Optional[Set[int]] = None + ) -> Tuple[List[Tuple[int, PILImage.Image]], List[Dict[str, Any]]]: + all_images = [] + all_img_metadata = [] + for current_meta in self.coco.loadImagesFromDatapoint(datapoint_id): + img_id = current_meta["id"] + if img_ids_to_load is not None and img_id not in img_ids_to_load: + continue + if self.fix_fname: + current_meta["file_name"] = current_meta["file_name"].split("/")[-1] + path = current_meta["file_name"] + if self.blurring_masks_path is not None: + mask_fname = os.path.basename(path).replace(".jpg", "-mask.json") + mask_path = os.path.join(self.blurring_masks_path, mask_fname) + if os.path.exists(mask_path): + with open(mask_path, "r") as fopen: + current_meta["blurring_mask"] = json.load(fopen) + + all_img_metadata.append(current_meta) + path = os.path.join(self.root, path) + try: + if ".mp4" in path and path[-4:] == ".mp4": + # Going to load a video frame + from decord import cpu, VideoReader + + video_path, frame = path.split("@") + video = VideoReader(video_path, ctx=cpu(0)) + # Convert to PIL image + all_images.append( + ( + img_id, + torchvision.transforms.ToPILImage()( + video[int(frame)].asnumpy() + ), + ) + ) + else: + with g_pathmgr.open(path, "rb") as fopen: + all_images.append((img_id, PILImage.open(fopen).convert("RGB"))) + except FileNotFoundError as e: + print(f"File not found: {path} from dataset: {self.annFile}") + raise e + + return all_images, all_img_metadata + + def set_curr_epoch(self, epoch: int): + self.curr_epoch = epoch + + def set_epoch(self, epoch: int): + pass + + def set_sharded_annotation_file(self, data_epoch: int): + if self.coco is not None: + return + + assert g_pathmgr.isfile( + self.annFile + ), f"please provide valid annotation file. Missing: {self.annFile}" + annFile = g_pathmgr.get_local_path(self.annFile) + + if self.coco is not None: + del self.coco + + self.coco = self.coco_json_loader(annFile) + # Use a torch tensor here to optimize memory usage when using several dataloaders + ids_list = list(sorted(self.coco.getDatapointIds())) + if self.limit_ids is not None: + local_random = random.Random(len(ids_list)) + local_random.shuffle(ids_list) + ids_list = ids_list[: self.limit_ids] + self.ids = torch.as_tensor(ids_list, dtype=torch.long) + + def __getitem__(self, index: int) -> Datapoint: + return self._load_datapoint(index) + + def _load_datapoint(self, index: int) -> Datapoint: + """A separate method for easy overriding in subclasses.""" + id = self.ids[index].item() + pil_images, img_metadata = self._load_images(id) + queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id) + return self.load_queries(pil_images, annotations, queries, img_metadata) + + def load_queries(self, pil_images, annotations, queries, img_metadata): + """Transform the raw image and queries into a Datapoint sample.""" + images: List[Image] = [] + id2index_img = {} + id2index_obj = {} + id2index_find_query = {} + id2imsize = {} + assert len(pil_images) == len(img_metadata) + for i in range(len(pil_images)): + w, h = pil_images[i][1].size + blurring_mask = None + if "blurring_mask" in img_metadata[i]: + blurring_mask = img_metadata[i]["blurring_mask"] + images.append( + Image( + data=pil_images[i][1], + objects=[], + size=(h, w), + blurring_mask=blurring_mask, + ) + ) + id2index_img[pil_images[i][0]] = i + id2imsize[pil_images[i][0]] = (h, w) + + for annotation in annotations: + image_id = id2index_img[annotation["image_id"]] + bbox = box_xywh_to_xyxy(torch.as_tensor(annotation["bbox"])).view(1, 4) + h, w = id2imsize[annotation["image_id"]] + bbox[:, 0::2].mul_(w).clamp_(min=0, max=w) + bbox[:, 1::2].mul_(h).clamp_(min=0, max=h) + segment = None + if self.load_segmentation and "segmentation" in annotation: + # We're not decoding the RLE here, a transform will do it lazily later + segment = annotation["segmentation"] + images[image_id].objects.append( + Object( + bbox=bbox[0], + area=annotation["area"], + object_id=( + annotation["object_id"] if "object_id" in annotation else -1 + ), + frame_index=( + annotation["frame_index"] if "frame_index" in annotation else -1 + ), + segment=segment, + is_crowd=( + annotation["is_crowd"] if "is_crowd" in annotation else None + ), + source=annotation["source"] if "source" in annotation else "", + ) + ) + id2index_obj[annotation["id"]] = len(images[image_id].objects) - 1 + + find_queries = [] + stage2num_queries = Counter() + for i, query in enumerate(queries): + stage2num_queries[query["query_processing_order"]] += 1 + id2index_find_query[query["id"]] = i + + # Sanity check: all the stages should have the same number of queries + if len(stage2num_queries) == 0: + num_queries_per_stage = 0 + else: + num_queries_per_stage = stage2num_queries.most_common(1)[0][1] + for stage, num_queries in stage2num_queries.items(): + assert ( + num_queries == num_queries_per_stage + ), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}" + + for query in queries: + h, w = id2imsize[query["image_id"]] + if ( + "input_box" in query + and query["input_box"] is not None + and len(query["input_box"]) > 0 + ): + bbox = box_xywh_to_xyxy(torch.as_tensor(query["input_box"])).view(-1, 4) + bbox[:, 0::2].mul_(w).clamp_(min=0, max=w) + bbox[:, 1::2].mul_(h).clamp_(min=0, max=h) + if "input_box_label" in query and query["input_box_label"] is not None: + bbox_label = torch.as_tensor( + query["input_box_label"], dtype=torch.long + ).view(-1) + assert len(bbox_label) == len(bbox) + else: + # assume the boxes are positives + bbox_label = torch.ones(len(bbox), dtype=torch.long) + else: + bbox = None + bbox_label = None + + if "input_points" in query and query["input_points"] is not None: + points = torch.as_tensor(query["input_points"]).view(1, -1, 3) + points[:, :, 0:1].mul_(w).clamp_(min=0, max=w) + points[:, :, 1:2].mul_(h).clamp_(min=0, max=h) + else: + points = None + + try: + original_image_id = int( + img_metadata[id2index_img[query["image_id"]]]["original_img_id"] + ) + except ValueError: + original_image_id = -1 + + try: + img_metadata_query = img_metadata[id2index_img[query["image_id"]]] + coco_image_id = ( + int(img_metadata_query["coco_img_id"]) + if "coco_img_id" in img_metadata_query + else query["id"] + ) + except KeyError: + coco_image_id = -1 + + try: + original_category_id = int(query["original_cat_id"]) + except (ValueError, KeyError): + original_category_id = -1 + + # For evaluation, we associate the ids of the object to be tracked to the query + if query["object_ids_output"]: + obj_id = query["object_ids_output"][0] + obj_idx = id2index_obj[obj_id] + image_idx = id2index_img[query["image_id"]] + object_id = images[image_idx].objects[obj_idx].object_id + frame_index = images[image_idx].objects[obj_idx].frame_index + else: + object_id = -1 + frame_index = -1 + + find_queries.append( + FindQueryLoaded( + # id=query["id"], + # query_type=qtype, + query_text=( + query["query_text"] if query["query_text"] is not None else "" + ), + image_id=id2index_img[query["image_id"]], + input_bbox=bbox, + input_bbox_label=bbox_label, + input_points=points, + object_ids_output=[ + id2index_obj[obj_id] for obj_id in query["object_ids_output"] + ], + is_exhaustive=query["is_exhaustive"], + is_pixel_exhaustive=( + query["is_pixel_exhaustive"] + if "is_pixel_exhaustive" in query + else ( + query["is_exhaustive"] if query["is_exhaustive"] else None + ) + ), + query_processing_order=query["query_processing_order"], + inference_metadata=InferenceMetadata( + coco_image_id=-1 if self.training else coco_image_id, + original_image_id=(-1 if self.training else original_image_id), + frame_index=frame_index, + original_category_id=original_category_id, + original_size=(h, w), + object_id=object_id, + ), + ) + ) + + return Datapoint( + find_queries=find_queries, + images=images, + raw_images=[p[1] for p in pil_images], + ) + + def __len__(self) -> int: + return len(self.ids) + + +class Sam3ImageDataset(CustomCocoDetectionAPI): + def __init__( + self, + img_folder, + ann_file, + transforms, + max_ann_per_img: int, + multiplier: int, + training: bool, + load_segmentation: bool = False, + max_train_queries: int = 81, + max_val_queries: int = 300, + fix_fname: bool = False, + is_sharded_annotation_dir: bool = False, + blurring_masks_path: Optional[str] = None, + use_caching: bool = True, + zstd_dict_path=None, + filter_query=None, + coco_json_loader: Callable = COCO_FROM_JSON, + limit_ids: int = None, + ): + super(Sam3ImageDataset, self).__init__( + img_folder, + ann_file, + fix_fname=fix_fname, + load_segmentation=load_segmentation, + training=training, + blurring_masks_path=blurring_masks_path, + use_caching=use_caching, + zstd_dict_path=zstd_dict_path, + filter_query=filter_query, + coco_json_loader=coco_json_loader, + limit_ids=limit_ids, + ) + + self._transforms = transforms + self.training = training + self.max_ann_per_img = max_ann_per_img + self.max_train_queries = max_train_queries + self.max_val_queries = max_val_queries + + self.repeat_factors = torch.ones(len(self.ids), dtype=torch.float32) + + self.repeat_factors *= multiplier + print(f"Raw dataset length = {len(self.ids)}") + + self._MAX_RETRIES = 100 + + def __getitem__(self, idx): + return self.__orig_getitem__(idx) + + def __orig_getitem__(self, idx): + for _ in range(self._MAX_RETRIES): + try: + datapoint = super(Sam3ImageDataset, self).__getitem__(idx) + + # This can be done better by filtering the offending find queries + # However, this requires care: + # - Delete any find/get query that may depend on the deleted one + # - Re-compute the indexes in the pointers to account for the deleted finds + for q in datapoint.find_queries: + if len(q.object_ids_output) > self.max_ann_per_img: + raise DecompressionBombError( + f"Too many outputs ({len(q.object_ids_output)})" + ) + + max_queries = ( + self.max_train_queries if self.training else self.max_val_queries + ) + + if len(datapoint.find_queries) > max_queries: + raise DecompressionBombError( + f"Too many find queries ({len(datapoint.find_queries)})" + ) + + if len(datapoint.find_queries) == 0: + raise DecompressionBombError("No find queries") + for transform in self._transforms: + datapoint = transform(datapoint, epoch=self.curr_epoch) + + break + except (DecompressionBombError, OSError, ValueError) as error: + sys.stderr.write(f"ERROR: got loading error on datapoint {idx}\n") + sys.stderr.write(f"Exception: {error}\n") + sys.stderr.write(traceback.format_exc()) + idx = (idx + 1) % len(self) + else: + raise RuntimeError( + f"Failed {self._MAX_RETRIES} times trying to load an image." + ) + + return datapoint diff --git a/third_party/sam3/sam3/train/data/sam3_video_dataset.py b/third_party/sam3/sam3/train/data/sam3_video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3be37a4906ba48ddf5836281820290031576022c --- /dev/null +++ b/third_party/sam3/sam3/train/data/sam3_video_dataset.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import copy +import io +import json +import logging +import math +import os +import pickle +import random +import sys +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import torch +import torchvision + +# from decord import cpu, VideoReader + +from iopath.common.file_io import PathManager +from PIL import Image as PILImage + +from .sam3_image_dataset import Datapoint, Sam3ImageDataset + + +SEED = 42 + + +class VideoGroundingDataset(Sam3ImageDataset): + def __init__( + self, + num_stages_sample: int = 4, + stage_stride_min: int = 1, + stage_stride_max: int = 5, + random_reverse_time_axis: bool = True, + is_tiling_single_image: bool = False, + # By default, we remove find those queries with geometric inputs (input_box or input_points) + # when creating synthetic videos from frames (since they are not *video-level* text prompts). + # If we need them later, we can sample them on-the-fly via transforms or inside the model. + tile_img_keep_find_queries_with_geo_inputs: bool = False, + tile_img_keep_get_queries: bool = False, + # the maximum number of find queries (for each frame) to keep in a video; if the datapoint + # contains more queries per frame than this limit, we subsample them to avoid OOM errors + max_query_num: int = -1, # the default -1 means no limit + # whether to override the "is_exhaustive" flag of the loaded find queries to True + # (by default, our video datasets are ingested with is_exhaustive=False, since the YTVIS format + # annotations doesn't involve an "is_exhaustive" flag; this means that those unmatched (negative) + # detection queries or tracking queries do not receive a classification loss given that we have + # weak_loss=True in IABCEMdetr -- this could lead to false positives for both image detection + # and video association.) + override_query_is_exhaustive_to_true: bool = False, + # the maximum number of masklets in a video; if the datapoint contains more masklets + # than this limit, we skip the datapoint to avoid OOM errors (this is useful for + # training with large videos that contain many objects) + max_masklet_num_in_video: int = 300, # 300 masklets is usually OK to avoid OOM + **kwargs, + ): + """ + Loading video grounding data + + Video frame sampling parameters (for training only): + - num_stages_sample: number of frames to sample from the video during training + - stage_stride_min: minimum stride between sampled frames during training + - stage_stride_max: maximum stride between sampled frames during training (if it's + greater than stage_stride_min, the actual stride is sampled uniformly between min + and max; during inference, we always use all frames in the video with stride=1) + - random_reverse_time_axis: whether to randomly invert the video's temporal axis + (i.e. playing it backwards) during training + """ + super().__init__(**kwargs) + assert num_stages_sample >= 1 + assert stage_stride_min >= 1 + assert stage_stride_max >= stage_stride_min + self.num_stages_sample = num_stages_sample + self.stage_stride_min = stage_stride_min + self.stage_stride_max = stage_stride_max + self.random_reverse_time_axis = random_reverse_time_axis + self.is_tiling_single_image = is_tiling_single_image + self.tile_img_keep_find_queries_with_geo_inputs = ( + tile_img_keep_find_queries_with_geo_inputs + ) + self.tile_img_keep_get_queries = tile_img_keep_get_queries + self.max_query_num = max_query_num + self.override_query_is_exhaustive_to_true = override_query_is_exhaustive_to_true + self.max_masklet_num_in_video = max_masklet_num_in_video + self.rng = random.Random() + self.set_curr_epoch(0) + + def set_curr_epoch(self, epoch: int): + super().set_curr_epoch(epoch) + self.rng.seed(SEED + epoch) + + def _load_datapoint(self, index: int) -> Datapoint: + id = self.ids[index].item() + queries, annotations = self.coco.loadQueriesAndAnnotationsFromDatapoint(id) + + # we subsample the video frames during training + if self.training and not self.is_tiling_single_image: + # pick a random stride for sampling query stages (`randint` includes both ends) + stage_stride = self.rng.randint( + self.stage_stride_min, self.stage_stride_max + ) + stage_ids_to_keep = self._sample_stage_ids( + queries, self.num_stages_sample, stage_stride + ) + # filter the queries and annotations to keep only the selected stages + # (also remap the stage ids so that they are contiguous and start from 0) + reverse_time_axis = ( + self.rng.random() < 0.5 if self.random_reverse_time_axis else False + ) + queries, annotations, kept_img_ids = self._filter_query_and_anns( + queries, + annotations, + stage_ids_to_keep, + remap_stage_id=True, + reverse_time_axis=reverse_time_axis, + ) + pil_images, img_metadata = self._load_images(id, kept_img_ids) + if reverse_time_axis: + # reverse the temporal ordering of the images and their metadata + # so that the image order matches the query order + pil_images = pil_images[::-1] + img_metadata = img_metadata[::-1] + else: + pil_images, img_metadata = self._load_images(id) + + # check that all the images have the same image size (they are expected + # to have the same image size since they are frames from the same video) + assert all(p.size == pil_images[0][1].size for _, p in pil_images) + + queries.sort(key=lambda q: q["query_processing_order"]) + if self.override_query_is_exhaustive_to_true: + for query in queries: + query["is_exhaustive"] = True + datapoint = self.load_queries(pil_images, annotations, queries, img_metadata) + + # skip datapoints with too many masklets to avoid OOM errors + num_masklets_in_video = len(datapoint.images[0].objects) + if num_masklets_in_video > self.max_masklet_num_in_video > 0: + logging.warning( + f"Datapoint {id} has ({num_masklets_in_video=}), exceeding " + f"the maximum allowed ({self.max_masklet_num_in_video}). " + "Skipping this datapoint." + ) + next_index = (index + 1) % len(self) + return self._load_datapoint(next_index) # move to the next datapoint + + if self.is_tiling_single_image: + datapoint = self._tile_single_image_data(datapoint, self.num_stages_sample) + if self.max_query_num > 0: + datapoint = self._subsample_queries(datapoint, self.max_query_num) + + # ensure that all find queries have the same processing order as their image id + for query in datapoint.find_queries: + assert query.image_id == query.query_processing_order, ( + f"find query has inconsistent image_id and " + f"query_processing_order: {query.image_id=} vs " + f"{query.query_processing_order=}" + ) + return datapoint + + def _sample_stage_ids(self, queries, num_stages_sample, stage_stride): + """Sample a subset of stage ids from all queries.""" + # Later we can perhaps turn it into a Sampler class to be more flexible. + all_stage_ids = sorted(set(q["query_processing_order"] for q in queries)) + num_stages_total = len(all_stage_ids) + if num_stages_total < num_stages_sample: + raise ValueError("Not enough stages to sample") + + # the difference in index between the first and the last sampled stage ids + b_e_gap = (num_stages_sample - 1) * stage_stride + if b_e_gap > num_stages_total - 1: + # In this case, it's not possible to sample with the provide stride, + # so we use the maximum possible stride. + prev_stage_stride = stage_stride + stage_stride = math.floor((num_stages_total - 1) / (num_stages_sample - 1)) + logging.info( + f"lowering stride from {prev_stage_stride} to {stage_stride} to " + f"sample {num_stages_sample} stages (from {num_stages_total} total)" + ) + b_e_gap = (num_stages_sample - 1) * stage_stride + + # randomly select a starting stage id (`randint` includes both ends) + b_max = len(all_stage_ids) - 1 - b_e_gap + b = self.rng.randint(0, b_max) + e = b + b_e_gap + stage_ids_to_keep = all_stage_ids[b : e + 1 : stage_stride] + return stage_ids_to_keep + + def _filter_query_and_anns( + self, queries, annotations, stage_ids_to_keep, remap_stage_id, reverse_time_axis + ): + """Filter queries and annotations to only keep those in `stage_ids_to_keep`.""" + stage_ids_to_keep = set(stage_ids_to_keep) + kept_img_ids = set() + kept_stage_ids = set() + + # Filter queries -- keep those queries with stage_id in `stage_ids_to_keep` + filtered_queries = [] + for query in queries: + input_box = query.get("input_box", None) + input_points = query.get("input_points", None) + has_geo_input = input_box is not None or input_points is not None + if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs: + continue + stage_id = query["query_processing_order"] + if stage_id in stage_ids_to_keep: + kept_img_ids.add(query["image_id"]) + kept_stage_ids.add(stage_id) + filtered_queries.append(query) + # Check that all frames in `stage_ids_to_keep` are present after filtering + all_frame_present = kept_stage_ids == stage_ids_to_keep + assert all_frame_present, f"{kept_stage_ids=} vs {stage_ids_to_keep=}" + if remap_stage_id: + # Remap those kept stage ids to be contiguous and starting from 0 + old_stage_ids = sorted(kept_stage_ids, reverse=reverse_time_axis) + stage_id_old2new = {old: new for new, old in enumerate(old_stage_ids)} + for query in filtered_queries: + ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1] + ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1] + assert ( + ptr_x_is_empty and ptr_y_is_empty + ), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers" + query["query_processing_order"] = stage_id_old2new[ + query["query_processing_order"] + ] + + # Filter annotations -- keep those annotations with image_id in `kept_img_ids` + filtered_annotations = [ + ann for ann in annotations if ann["image_id"] in kept_img_ids + ] + + return filtered_queries, filtered_annotations, kept_img_ids + + def _tile_single_image_data(self, datapoint: Datapoint, num_stages_sample: int): + """ + Tile a single image and its queries to simulate video frames. The output is a + datapoint with *identical video frames* (i.e. the same static image) and needs + further transforms (e.g. affine) to get video frames with different content. + """ + # tile `images: List[Image]` + assert len(datapoint.images) == 1, "Expected only one single image" + tiled_images = [ + copy.deepcopy(datapoint.images[0]) for _ in range(num_stages_sample) + ] + for stage_id, img in enumerate(tiled_images): + for obj in img.objects: + obj.frame_index = stage_id + + # tile `raw_images: Optional[List[PILImage.Image]] = None` + tiled_raw_images = None + if datapoint.raw_images is not None: + assert len(datapoint.raw_images) == 1, "Expected only one single image" + tiled_raw_images = [ + datapoint.raw_images[0].copy() for _ in range(num_stages_sample) + ] + + # tile `find_queries: List[FindQueryLoaded]` + tiled_find_queries_per_stage = [[] for _ in range(num_stages_sample)] + for query in datapoint.find_queries: + assert query.image_id == 0 + assert query.query_processing_order == 0 + # check and make sure that a query doesn't contain pointers or references + # to other queries (that cannot be tiled) + assert query.ptr_x is None and query.ptr_y is None + assert query.ptr_mem is None + # assert query.wkdata_qid is None + # assert query.other_positive_qids is None + # assert query.negative_qids is None + has_geo_input = ( + query.input_bbox is not None or query.input_points is not None + ) + if has_geo_input and not self.tile_img_keep_find_queries_with_geo_inputs: + continue + for stage_id in range(num_stages_sample): + # copy the query and update the image_id + new_query = copy.deepcopy(query) + new_query.image_id = stage_id + new_query.query_processing_order = stage_id + if new_query.inference_metadata is not None: + new_query.inference_metadata.frame_index = stage_id + tiled_find_queries_per_stage[stage_id].append(new_query) + + tiled_find_queries = sum(tiled_find_queries_per_stage, []) + + # tile `get_queries: List[GetQuery]` -- we skip them for now (since they involve + # a pointer to a find query that is complicated to tile, and there is not an + # imminent use case for them in the video grounding task in the near future) + if self.tile_img_keep_get_queries: + raise NotImplementedError("Tiling get queries is not implemented yet") + else: + tiled_get_queries = [] + + return Datapoint( + images=tiled_images, + raw_images=tiled_raw_images, + find_queries=tiled_find_queries, + get_queries=tiled_get_queries, + ) + + def _subsample_queries(self, datapoint: Datapoint, max_query_num: int): + """Subsample to keep at most `max_query_num` queries per frame in a datapoint.""" + # aggregate the find queries per stage + num_frames = max(q.query_processing_order for q in datapoint.find_queries) + 1 + find_queries_per_stage = [[] for _ in range(num_frames)] + for query in datapoint.find_queries: + find_queries_per_stage[query.query_processing_order].append(query) + + # verify that all the stages have the same number of queries + num_queries_per_stage = len(find_queries_per_stage[0]) + for queries in find_queries_per_stage: + assert len(queries) == num_queries_per_stage + if max_query_num <= 0 or num_queries_per_stage <= max_query_num: + return datapoint + + # subsample the queries to keep only `max_query_num` queries + sampled_inds = self.rng.sample(range(num_queries_per_stage), max_query_num) + sampled_find_queries_per_stage = [ + [queries[idx] for idx in sampled_inds] for queries in find_queries_per_stage + ] + sampled_find_queries = sum(sampled_find_queries_per_stage, []) + return Datapoint( + images=datapoint.images, + raw_images=datapoint.raw_images, + find_queries=sampled_find_queries, + get_queries=datapoint.get_queries, + ) diff --git a/third_party/sam3/sam3/train/data/torch_dataset.py b/third_party/sam3/sam3/train/data/torch_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec9f680f2695acfa1aded2f460da630fa48c7a1 --- /dev/null +++ b/third_party/sam3/sam3/train/data/torch_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from typing import Callable, Iterable, Optional + +from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset + + +class TorchDataset: + def __init__( + self, + dataset: Dataset, + batch_size: int, + num_workers: int, + shuffle: bool, + pin_memory: bool, + drop_last: bool, + collate_fn: Optional[Callable] = None, + worker_init_fn: Optional[Callable] = None, + enable_distributed_sampler=True, + ) -> None: + self.dataset = dataset + self.batch_size = batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + self.collate_fn = collate_fn + self.worker_init_fn = worker_init_fn + assert not isinstance(self.dataset, IterableDataset), "Not supported yet" + if enable_distributed_sampler: + self.sampler = DistributedSampler(self.dataset, shuffle=self.shuffle) + else: + self.sampler = None + + def get_loader(self, epoch) -> Iterable: + if self.sampler: + self.sampler.set_epoch(epoch) + if hasattr(self.dataset, "epoch"): + self.dataset.epoch = epoch + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + return DataLoader( + self.dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=self.drop_last, + sampler=self.sampler, + collate_fn=self.collate_fn, + worker_init_fn=self.worker_init_fn, + ) diff --git a/third_party/sam3/sam3/train/loss/__init__.py b/third_party/sam3/sam3/train/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/train/loss/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/train/loss/loss_fns.py b/third_party/sam3/sam3/train/loss/loss_fns.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7a677775ec446ee17c4eee0a8353dc4cc0bbf9 --- /dev/null +++ b/third_party/sam3/sam3/train/loss/loss_fns.py @@ -0,0 +1,1326 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging +import warnings + +import torch +import torch.distributed +import torch.nn.functional as F +import torchmetrics +from sam3.model import box_ops +from sam3.model.data_misc import interpolate +from sam3.train.loss.sigmoid_focal_loss import ( + triton_sigmoid_focal_loss, + triton_sigmoid_focal_loss_reduce, +) +from torch import nn + +from .mask_sampling import ( + calculate_uncertainty, + get_uncertain_point_coords_with_randomness, + point_sample, +) + + +CORE_LOSS_KEY = "core_loss" + + +def instance_masks_to_semantic_masks( + instance_masks: torch.Tensor, num_instances: torch.Tensor +) -> torch.Tensor: + """This function converts instance masks to semantic masks. + It accepts a collapsed batch of instances masks (ie all instance masks are concatenated in a single tensor) and + the number of instances in each image of the batch. + It returns a mask with the same spatial dimensions as the input instance masks, where for each batch element the + semantic mask is the union of all the instance masks in the batch element. + + If for a given batch element there are no instances (ie num_instances[i]==0), the corresponding semantic mask will be a tensor of zeros. + + Args: + instance_masks (torch.Tensor): A tensor of shape (N, H, W) where N is the number of instances in the batch. + num_instances (torch.Tensor): A tensor of shape (B,) where B is the batch size. It contains the number of instances + in each image of the batch. + + Returns: + torch.Tensor: A tensor of shape (B, H, W) where B is the batch size and H, W are the spatial dimensions of the + input instance masks. + """ + if num_instances.sum() == 0: + # all negative batch, create a tensor of zeros (B, 1, 1) + return num_instances.unsqueeze(-1).unsqueeze(-1) + + masks_per_query = torch.split(instance_masks, num_instances.tolist()) + + return torch.stack([torch.any(masks, dim=0) for masks in masks_per_query], dim=0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def dice_loss(inputs, targets, num_boxes, loss_on_multimask=False, reduce=True): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + try: + loss = _dice_loss(inputs, targets, num_boxes, loss_on_multimask, reduce) + except torch.OutOfMemoryError: + logging.error("GPU OOM, computing dice loss on CPU") + # try to recover from GPU OOM by moving tensors to CPU and computing loss there + orig_device = inputs.device + inputs = inputs.cpu() + targets = targets.cpu() + if isinstance(num_boxes, torch.Tensor): + num_boxes = num_boxes.cpu() + loss = _dice_loss(inputs, targets, num_boxes, loss_on_multimask, reduce) + loss = loss.to(orig_device) + + return loss + + +def _dice_loss(inputs, targets, num_boxes, loss_on_multimask=False, reduce=True): + inputs = inputs.sigmoid() + if loss_on_multimask: + # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks + assert inputs.dim() == 4 and targets.dim() == 4 + # flatten spatial dimension while keeping multimask channel dimension + inputs = inputs.flatten(2) + targets = targets.flatten(2) + numerator = 2 * (inputs * targets).sum(-1) + else: + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + if loss_on_multimask: + return loss / num_boxes + if not reduce: + return loss + return loss.sum() / num_boxes + + +def sigmoid_focal_loss( + inputs, + targets, + num_boxes, + alpha: float = 0.25, + gamma: float = 2, + loss_on_multimask=False, + reduce=True, + triton=True, +): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + if not (0 <= alpha <= 1) and triton: + raise RuntimeError(f"Alpha should be in [0,1], got {alpha}") + if triton: + if reduce and not loss_on_multimask: + loss = triton_sigmoid_focal_loss_reduce(inputs, targets, alpha, gamma) + return loss / (num_boxes * inputs.shape[1]) + + loss = triton_sigmoid_focal_loss(inputs, targets, alpha, gamma) + else: + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if not reduce: + return loss + + if loss_on_multimask: + # loss is [N, M, H, W] where M corresponds to multiple predicted masks + assert loss.dim() == 4 + return loss.flatten(2).mean(-1) / num_boxes # average over spatial dims + return loss.mean(1).sum() / num_boxes + + +def iou_loss( + inputs, targets, pred_ious, num_boxes, loss_on_multimask=False, use_l1_loss=False +): + """MSE loss between predicted IoUs and actual IoUs between inputs and targets.""" + assert inputs.dim() == 4 and targets.dim() == 4 + pred_mask = inputs.flatten(2) > 0 + gt_mask = targets.flatten(2) > 0 + area_i = torch.sum(pred_mask & gt_mask, dim=-1).float() + area_u = torch.sum(pred_mask | gt_mask, dim=-1).float() + actual_ious = area_i / torch.clamp(area_u, min=1.0) + + if use_l1_loss: + loss = F.l1_loss(pred_ious, actual_ious, reduction="none") + else: + loss = F.mse_loss(pred_ious, actual_ious, reduction="none") + if loss_on_multimask: + return loss / num_boxes + return loss.sum() / num_boxes + + +@torch.jit.script +def _contrastive_align(logits, positive_map): + positive_logits = -logits.masked_fill(~positive_map, 0) + negative_logits = logits # .masked_fill(positive_map, -1000000) + + boxes_with_pos = positive_map.any(2) + pos_term = positive_logits.sum(2) + neg_term = negative_logits.logsumexp(2) + + nb_pos = positive_map.sum(2) + 1e-6 + + box_to_token_loss = ( + (pos_term / nb_pos + neg_term).masked_fill(~boxes_with_pos, 0).sum() + ) + + tokens_with_pos = positive_map.any(1) + pos_term = positive_logits.sum(1) + neg_term = negative_logits.logsumexp(1) + + nb_pos = positive_map.sum(1) + 1e-6 + + tokens_to_boxes_loss = ( + (pos_term / nb_pos + neg_term).masked_fill(~tokens_with_pos, 0).sum() + ) + return (box_to_token_loss + tokens_to_boxes_loss) / 2 + + +def _get_src_permutation_idx(indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)] + ) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + +class LossWithWeights(nn.Module): + def __init__(self, weight_dict, compute_aux, supports_o2m_loss=True): + super().__init__() + # weights for each computed loss key (those losses not in weight_dict + # will not be aggregated in the final reduced core loss) + self.weight_dict = weight_dict if weight_dict is not None else {} + # whether this loss will be applied on auxiliary outputs + self.compute_aux = compute_aux + self.supports_o2m_loss = supports_o2m_loss + self.target_keys = [] + + def forward(self, *args, is_aux=False, **kwargs): + if is_aux and not self.compute_aux: + return {CORE_LOSS_KEY: 0.0} + losses = self.get_loss(*args, **kwargs) + losses[CORE_LOSS_KEY] = self.reduce_loss(losses) + return losses + + def get_loss(self, **kwargs): + raise NotImplementedError() + + def reduce_loss(self, losses): + reduced_loss = 0.0 + for loss_key, weight in self.weight_dict.items(): + if loss_key not in losses: + raise ValueError(f"{type(self)} doesn't compute {loss_key}") + if weight != 0: + reduced_loss += losses[loss_key] * weight + + return reduced_loss + + +class IABCEMdetr(LossWithWeights): + def __init__( + self, + pos_weight, + weight_dict=None, + compute_aux=True, + gamma=0, + weak_loss=True, + alpha=0.25, + pad_n_queries=None, + pad_scale_pos=1.0, + use_separate_loss_for_det_and_trk=False, + num_det_queries=None, + det_exhaustive_loss_scale_pos=1.0, + det_exhaustive_loss_scale_neg=1.0, + det_non_exhaustive_loss_scale_pos=1.0, + det_non_exhaustive_loss_scale_neg=1.0, + trk_loss_scale_pos=1.0, + trk_loss_scale_neg=1.0, + no_loss_for_fp_propagation=False, + apply_loss_to_det_queries_in_video_grounding=True, + use_presence=False, + use_presence_semgseg=False, # If True, use presence scores from the semgseg head. + presence_alpha=0.5, + presence_gamma=0.0, + pos_focal: bool = False, # for box scores, use focal loss for positives as well + ): + super().__init__(weight_dict, compute_aux) + self.pos_weight = pos_weight + self.gamma = gamma + self.weak_loss = weak_loss + self.alpha = alpha + self.target_keys.append("boxes_xyxy") + self.no_loss_for_fp_propagation = no_loss_for_fp_propagation + if self.weak_loss: + self.target_keys.append("is_exhaustive") + # NOTE: This is hacky solution to have the same CE loss scale across datasets where the model might predict different number of object queries for different tasks. + # If not None, we assume there are a total pad_n_queries object queries. + # For example, if the model predicts only 1 object query and pad_n_queries=100, we pad the predictions with 99 zero preds. + # Currently this only affects the BCE loss and not the F1 score. + self.pad_n_queries = pad_n_queries + self.pad_scale_pos = pad_scale_pos + if self.pad_scale_pos != 1.0: + assert self.pad_n_queries is not None + # whether to use presence scores + self.use_presence = use_presence + self.use_presence_semgseg = use_presence_semgseg + if self.use_presence_semgseg: + assert self.use_presence + self.presence_alpha = presence_alpha + self.presence_gamma = presence_gamma + self.pos_focal = pos_focal + + # Decoupled loss for detection and tracking queries + self.apply_loss_to_det_queries_in_video_grounding = ( + apply_loss_to_det_queries_in_video_grounding + ) + self.use_separate_loss_for_det_and_trk = use_separate_loss_for_det_and_trk + if num_det_queries is not None: + logging.warning("note: it's not needed to set num_det_queries anymore") + if self.use_separate_loss_for_det_and_trk: + assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead" + self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos + self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg + self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos + self.det_non_exhaustive_loss_scale_neg = det_non_exhaustive_loss_scale_neg + self.trk_loss_scale_pos = trk_loss_scale_pos + self.trk_loss_scale_neg = trk_loss_scale_neg + else: + assert ( + det_exhaustive_loss_scale_pos == 1.0 + and det_exhaustive_loss_scale_neg == 1.0 + and det_non_exhaustive_loss_scale_pos == 1.0 + and det_non_exhaustive_loss_scale_neg == 1.0 + and trk_loss_scale_pos == 1.0 + and trk_loss_scale_neg == 1.0 + ), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0" + + def get_loss(self, outputs, targets, indices, num_boxes): + assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape" + assert outputs["pred_logits"].shape[-1] == 1, "Incorrect predicted logits shape" + src_logits = outputs["pred_logits"].squeeze(-1) + prob = src_logits.sigmoid() + + with torch.no_grad(): + target_classes = torch.full( + src_logits.shape[:2], + 0, + dtype=torch.float, + device=src_logits.device, + ) + target_classes[(indices[0], indices[1])] = 1 + src_boxes_xyxy = outputs["pred_boxes_xyxy"][(indices[0], indices[1])] + target_boxes_giou = ( + targets["boxes_xyxy"][indices[2]] + if indices[2] is not None + else targets["boxes_xyxy"] + ) + + iou = box_ops.fast_diag_box_iou(src_boxes_xyxy, target_boxes_giou) + t = prob[(indices[0], indices[1])] ** self.alpha * iou ** (1 - self.alpha) + t = torch.clamp(t, 0.01).detach() + positive_target_classes = target_classes.clone() + positive_target_classes[(indices[0], indices[1])] = t + + # Soft loss on positives + if self.pos_focal: + loss_bce = sigmoid_focal_loss( + src_logits.contiguous(), + positive_target_classes, + num_boxes=1, + alpha=0.5, + gamma=self.gamma, + reduce=False, + ) + else: + loss_bce = F.binary_cross_entropy_with_logits( + src_logits, positive_target_classes, reduction="none" + ) + loss_bce = loss_bce * target_classes * self.pos_weight + + if ( + self.pad_n_queries is not None + and isinstance(self.pad_n_queries, int) + and loss_bce.size(1) < self.pad_n_queries + ): + loss_bce = loss_bce * self.pad_scale_pos + # Negatives + loss_neg = F.binary_cross_entropy_with_logits( + src_logits, target_classes, reduction="none" + ) * (1 - target_classes) * (prob**self.gamma) + + # Suppress negative loss for predictions overlapping ignore boxes + if "ignore_neg_mask" in targets: + neg_suppress = targets["ignore_neg_mask"] # (B, S) + loss_neg = loss_neg * (1 - neg_suppress) + + loss_bce = loss_bce + loss_neg + + # Optionally, not applying IABCEMdetr loss to detection queries in video. + is_video_grounding = outputs.get("is_video_grounding_batch", False) + if is_video_grounding and not self.apply_loss_to_det_queries_in_video_grounding: + Q_det = outputs["Q_det"] + loss_bce[:, :Q_det] *= 0.0 + presence_loss = torch.tensor(0.0, device=src_logits.device) + presence_dec_acc = torch.tensor(0.0, device=src_logits.device) + if self.use_presence: + # no classifiction loss for individual tokens if no target gt + # cannot directly use targets["num_boxes"] to check if some + # GT box exists as there may be dummy boxes for "invisible objects" + # in video grounding data + + gt_padded_object_ids = targets["object_ids_padded"] # (B, H) + gt_padded_boxes = targets["boxes_padded"] # (B, H, 4) shape, CxCyWH + gt_padded_is_visible = ( + (gt_padded_object_ids >= 0) + & (gt_padded_boxes[..., 2] > 0) # width > 0 + & (gt_padded_boxes[..., 3] > 0) # height > 0 + ) + keep_loss = (gt_padded_is_visible.sum(dim=-1)[..., None] != 0).float() + + loss_bce = loss_bce * keep_loss + + if self.use_presence_semgseg: + # no loss here, has it's own separate loss computation + assert "presence_logit_dec" not in outputs + elif "presence_logit_dec" in outputs: + presence_logits = outputs["presence_logit_dec"].view_as(keep_loss) + bs = presence_logits.shape[0] + presence_loss = sigmoid_focal_loss( + presence_logits, + keep_loss, + # not num_boxes, but we'll use it to normalize by bs + num_boxes=bs, + alpha=self.presence_alpha, + gamma=self.presence_gamma, + triton=False, # triton kernel unstable with gamma=0 + bf16 + ) + pred = (presence_logits.sigmoid() > 0.5).float() + presence_dec_acc = (pred == keep_loss).float().mean() + else: + # for o2m, nothing to do + pass + + if self.weak_loss: + assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead" + + # nullify the negative loss for the non-exhaustive classes + assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0] + assert targets["is_exhaustive"].ndim == 1 + + loss_mask = (~targets["is_exhaustive"]).view(-1, 1).expand_as(loss_bce) + # restrict the mask to the negative supervision + loss_mask = loss_mask & (target_classes < 0.5) + loss_mask = ~loss_mask + # Mask the loss + loss_bce = loss_bce * loss_mask.float() + # Average + loss_bce = loss_bce.sum() / (loss_mask.sum() + 1e-6) + else: + # apply separate loss weights to detection and tracking queries + if self.use_separate_loss_for_det_and_trk: + Q_det = outputs["Q_det"] + assert loss_bce.size(1) >= Q_det + is_positive = target_classes > 0.5 + is_positive_det = is_positive[:, :Q_det] + is_positive_trk = is_positive[:, Q_det:] + assert loss_bce.size(0) == targets["is_exhaustive"].size(0) + is_exhaustive = targets["is_exhaustive"].unsqueeze(1).bool() + loss_scales = torch.zeros_like(loss_bce) + # detection query loss weights + loss_scales[:, :Q_det] = ( + (is_exhaustive & is_positive_det).float() + * self.det_exhaustive_loss_scale_pos + + (is_exhaustive & ~is_positive_det).float() + * self.det_exhaustive_loss_scale_neg + + (~is_exhaustive & is_positive_det).float() + * self.det_non_exhaustive_loss_scale_pos + + (~is_exhaustive & ~is_positive_det).float() + * self.det_non_exhaustive_loss_scale_neg + ) + # tracking query weights + loss_scales[:, Q_det:] = ( + is_positive_trk.float() * self.trk_loss_scale_pos + + (~is_positive_trk).float() * self.trk_loss_scale_neg + ) + # apply the loss weights + + # if the id is -2 means it is a fp propagation , we don't apply the loss to them + if self.no_loss_for_fp_propagation: + is_original_queries = outputs["pred_old_obj_ids"] != -2 + loss_scales *= (is_exhaustive | is_original_queries).float() + + loss_bce = loss_bce * loss_scales + + if self.pad_n_queries is None or loss_bce.size(1) >= self.pad_n_queries: + loss_bce = loss_bce.mean() + else: + assert isinstance(self.pad_n_queries, int) + assert ( + loss_bce.size(1) < self.pad_n_queries + ), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions." + loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0)) + + bce_f1 = torchmetrics.functional.f1_score( + src_logits.sigmoid().flatten(), + target=target_classes.flatten().long(), + task="binary", + ) + + losses = { + "loss_ce": loss_bce, + "ce_f1": bce_f1, + "presence_loss": presence_loss, + "presence_dec_acc": presence_dec_acc, + } + return losses + + +class Boxes(LossWithWeights): + def __init__( + self, + weight_dict=None, + compute_aux=True, + apply_loss_to_det_queries_in_video_grounding=True, + ): + super().__init__(weight_dict, compute_aux) + self.apply_loss_to_det_queries_in_video_grounding = ( + apply_loss_to_det_queries_in_video_grounding + ) + self.target_keys.extend(["boxes", "boxes_xyxy"]) + + def get_loss(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. + """ + # Optionally, not applying Boxes loss to detection queries in video. + is_video_grounding = outputs.get("is_video_grounding_batch", False) + if is_video_grounding and not self.apply_loss_to_det_queries_in_video_grounding: + indices = _keep_only_trk_queries_in_match_inds( + indices, Q_det=outputs["Q_det"] + ) + + assert "pred_boxes" in outputs + # idx = self._get_src_permutation_idx(indices) + src_boxes = outputs["pred_boxes"][(indices[0], indices[1])] + src_boxes_xyxy = outputs["pred_boxes_xyxy"][(indices[0], indices[1])] + target_boxes = ( + targets["boxes"] if indices[2] is None else targets["boxes"][indices[2]] + ) + target_boxes_giou = ( + targets["boxes_xyxy"] + if indices[2] is None + else targets["boxes_xyxy"][indices[2]] + ) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + + losses = {} + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - box_ops.fast_diag_generalized_box_iou( + src_boxes_xyxy, target_boxes_giou + ) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + +class Masks(LossWithWeights): + def __init__( + self, + weight_dict=None, + compute_aux=False, + focal_alpha=0.25, + focal_gamma=2, + num_sample_points=None, + oversample_ratio=None, + importance_sample_ratio=None, + apply_loss_to_det_queries_in_video_grounding=True, + ): + super().__init__(weight_dict, compute_aux) + if compute_aux: + warnings.warn("Masks loss usually shouldn't be applied to aux outputs") + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + self.num_sample_points = num_sample_points + self.oversample_ratio = oversample_ratio + self.importance_sample_ratio = importance_sample_ratio + self.apply_loss_to_det_queries_in_video_grounding = ( + apply_loss_to_det_queries_in_video_grounding + ) + self.target_keys.extend(["masks", "is_valid_mask"]) + + def _sampled_loss(self, src_masks, target_masks, num_boxes): + assert len(src_masks.shape) == 3 and len(target_masks.shape) == 3 + src_masks = src_masks[:, None] + target_masks = target_masks[:, None] + with torch.no_grad(): + # Sample point_coords + point_coords = get_uncertain_point_coords_with_randomness( + src_masks, + calculate_uncertainty, + self.num_sample_points, + self.oversample_ratio, + self.importance_sample_ratio, + ) + + # get GT labels + sampled_target_masks = point_sample( + target_masks, + point_coords, + align_corners=False, + ).squeeze(1) + + sampled_src_masks = point_sample( + src_masks, + point_coords, + align_corners=False, + ).squeeze(1) + + losses = { + "loss_mask": sigmoid_focal_loss( + sampled_src_masks, + sampled_target_masks, + num_boxes, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + ), + "loss_dice": dice_loss(sampled_src_masks, sampled_target_masks, num_boxes), + } + # Not needed for backward + del src_masks + del target_masks + + return losses + + def get_loss(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. + targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + assert "is_valid_mask" in targets + # Optionally, not applying Masks loss to detection queries in video. + is_video_grounding = outputs.get("is_video_grounding_batch", False) + if is_video_grounding and not self.apply_loss_to_det_queries_in_video_grounding: + indices = _keep_only_trk_queries_in_match_inds( + indices, Q_det=outputs["Q_det"] + ) + + src_masks = outputs["pred_masks"] + + # Dataset doesn't have segmentation masks + if targets["masks"] is None: + return { + "loss_mask": torch.tensor(0.0, device=src_masks.device), + "loss_dice": torch.tensor(0.0, device=src_masks.device), + } + + target_masks = ( + targets["masks"] if indices[2] is None else targets["masks"][indices[2]] + ) + target_masks = target_masks.to(src_masks) + keep = ( + targets["is_valid_mask"] + if indices[2] is None + else targets["is_valid_mask"][indices[2]] + ) + + src_masks = src_masks[(indices[0], indices[1])] + + # Remove invalid masks from loss + src_masks = src_masks[keep] + target_masks = target_masks[keep] + + if self.num_sample_points is not None: + # Compute loss on sampled points for the Mask + losses = self._sampled_loss(src_masks, target_masks, num_boxes) + + else: + # upsample predictions to the target size + if target_masks.shape[0] == 0 and src_masks.shape[0] == 0: + src_masks = src_masks.flatten(1) + target_masks = target_masks.reshape(src_masks.shape) + else: + if len(src_masks.shape) == 3: + src_masks = src_masks[:, None] + if src_masks.dtype == torch.bfloat16: + # Bilinear interpolation does not support bf16 + src_masks = src_masks.to(dtype=torch.float32) + src_masks = interpolate( + src_masks, + size=target_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + src_masks = src_masks[:, 0].flatten(1) + target_masks = target_masks.flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss( + src_masks, + target_masks, + num_boxes, + alpha=self.focal_alpha, + gamma=self.focal_gamma, + ), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + + return losses + + +# class MultiStepIteractiveMasks(LossWithWeights): +# def __init__( +# self, +# weight_dict=None, +# compute_aux=False, +# focal_alpha=0.25, +# focal_gamma=2, +# ): +# warnings.warn( +# "MultiStepIteractiveMasks is deprecated. Please use MultiStepMultiMasksAndIous", +# DeprecationWarning, +# ) +# super().__init__(weight_dict, compute_aux) +# self.focal_alpha = focal_alpha +# self.focal_gamma = focal_gamma +# self.target_keys.extend(["masks"]) + +# def get_loss(self, outputs, targets, indices, num_boxes): +# """Compute the losses related to the masks: the focal loss and the dice loss. +# targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + +# Unlike `Masks`, here the "multistep_pred_masks" can have multiple channels, each +# corresponding to one iterative prediction step in SAM-style training. We treat each +# channel as a mask prediction and sum the loss across channels. +# """ +# src_masks = outputs["multistep_pred_masks"] +# target_masks = targets["masks"] +# assert src_masks.size(0) == target_masks.size(0) +# assert src_masks.dim() == 4 +# assert target_masks.dim() == 3 + +# # tile target_masks according to the number of +# # channels `src_masks`. +# num_steps = src_masks.size(1) +# target_masks = target_masks.unsqueeze(1).to(src_masks.dtype) +# if num_steps > 1: +# target_masks = target_masks.repeat(1, num_steps, 1, 1) + +# # resize `src_masks` to target mask resolution +# if src_masks.shape != target_masks.shape: +# src_masks = interpolate( +# src_masks, +# size=target_masks.shape[-2:], +# mode="bilinear", +# align_corners=False, +# ) +# assert src_masks.shape == target_masks.shape + +# # flatten the multiple steps in to the batch dimension +# src_masks = src_masks.flatten(0, 1).flatten(1) +# target_masks = target_masks.flatten(0, 1).flatten(1) +# losses = { +# "loss_mask": sigmoid_focal_loss( +# src_masks, +# target_masks, +# num_boxes, +# alpha=self.focal_alpha, +# gamma=self.focal_gamma, +# ), +# "loss_dice": dice_loss(src_masks, target_masks, num_boxes), +# } + +# return losses + + +# class MultiStepMultiMasksAndIous(LossWithWeights): +# def __init__( +# self, +# weight_dict=None, +# compute_aux=False, +# focal_alpha=0.25, +# focal_gamma=2, +# # if True, back-prop on all predicted ious +# # not just the one with lowest loss_combo +# supervise_all_iou=False, +# # Less slack vs MSE loss in [-1, 1] error range +# iou_use_l1_loss=False, +# # Settings for obj score prediction +# pred_obj_scores=False, +# focal_gamma_obj_score=0.0, +# focal_alpha_obj_score=-1, +# ): +# super().__init__(weight_dict, compute_aux) +# self.focal_alpha = focal_alpha +# self.focal_gamma = focal_gamma +# self.target_keys.extend(["masks"]) +# assert "loss_mask" in self.weight_dict +# assert "loss_dice" in self.weight_dict +# assert "loss_iou" in self.weight_dict +# if "loss_class" not in self.weight_dict: +# self.weight_dict["loss_class"] = 0.0 +# self.focal_alpha_obj_score = focal_alpha_obj_score +# self.focal_gamma_obj_score = focal_gamma_obj_score +# self.supervise_all_iou = supervise_all_iou +# self.iou_use_l1_loss = iou_use_l1_loss +# self.pred_obj_scores = pred_obj_scores + +# def get_loss(self, outputs, targets, indices, num_boxes): +# """ +# Compute the losses related to the masks: the focal loss and the dice loss. +# and also the MSE loss between predicted IoUs and actual IoUs. + +# Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors +# of shape [N, M, H, W], where M could be 1 or larger, corresponding to +# one or multiple predicted masks from a click. + +# We back-propagate focal, dice and iou losses only on the prediction channel +# with the lowest focal+dice loss between predicted mask and ground-truth. +# """ + +# target_masks = targets["masks"].unsqueeze(1).float() +# assert target_masks.dim() == 4 # [N, 1, H, W] +# src_masks_list = outputs["multistep_pred_multimasks_high_res"] +# ious_list = outputs["multistep_pred_ious"] +# object_score_logits_list = outputs["multistep_object_score_logits"] + +# assert len(src_masks_list) == len(ious_list) +# assert len(object_score_logits_list) == len(ious_list) + +# # Remove invalid masks from loss +# keep = targets["is_valid_mask"] +# target_masks = target_masks[keep] + +# # accumulate the loss over prediction steps +# losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0} +# for src_masks, ious, object_score_logits in zip( +# src_masks_list, ious_list, object_score_logits_list +# ): +# object_score_logits = object_score_logits[keep] +# ious = ious[keep] +# src_masks = src_masks[keep] +# self._update_losses( +# losses, src_masks, target_masks, ious, num_boxes, object_score_logits +# ) +# return losses + +# def _update_losses( +# self, losses, src_masks, target_masks, ious, num_boxes, object_score_logits +# ): +# target_masks = target_masks.expand_as(src_masks) +# # get focal, dice and iou loss on all output masks in a prediction step +# loss_multimask = sigmoid_focal_loss( +# src_masks, +# target_masks, +# num_boxes, +# alpha=self.focal_alpha, +# gamma=self.focal_gamma, +# loss_on_multimask=True, +# triton=False, # only use triton if alpha > 0 +# ) +# loss_multidice = dice_loss( +# src_masks, target_masks, num_boxes, loss_on_multimask=True +# ) +# if not self.pred_obj_scores: +# loss_class = torch.tensor( +# 0.0, dtype=loss_multimask.dtype, device=loss_multimask.device +# ) +# target_obj = torch.ones( +# loss_multimask.shape[0], +# 1, +# dtype=loss_multimask.dtype, +# device=loss_multimask.device, +# ) +# else: +# target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[ +# ..., None +# ].float() +# loss_class = sigmoid_focal_loss( +# object_score_logits, +# target_obj, +# num_boxes, +# alpha=self.focal_alpha_obj_score, +# gamma=self.focal_gamma_obj_score, +# triton=False, +# ) + +# loss_multiiou = iou_loss( +# src_masks, +# target_masks, +# ious, +# num_boxes, +# loss_on_multimask=True, +# use_l1_loss=self.iou_use_l1_loss, +# ) +# assert loss_multimask.dim() == 2 +# assert loss_multidice.dim() == 2 +# assert loss_multiiou.dim() == 2 +# if loss_multimask.size(1) > 1: +# # take the mask indices with the smallest focal + dice loss for back propagation +# loss_combo = ( +# loss_multimask * self.weight_dict["loss_mask"] +# + loss_multidice * self.weight_dict["loss_dice"] +# ) +# best_loss_inds = torch.argmin(loss_combo, dim=-1) +# batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device) +# loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1) +# loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1) +# # calculate the iou prediction and slot losses only in the index +# # with the minimum loss for each mask (to be consistent w/ SAM) +# if self.supervise_all_iou: +# loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1) +# else: +# loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1) +# else: +# loss_mask = loss_multimask +# loss_dice = loss_multidice +# loss_iou = loss_multiiou + +# # backprop focal, dice and iou loss only if obj present +# loss_mask = loss_mask * target_obj +# loss_dice = loss_dice * target_obj +# loss_iou = loss_iou * target_obj + +# # sum over batch dimension (note that the losses are already divided by num_boxes) +# losses["loss_mask"] += loss_mask.sum() +# losses["loss_dice"] += loss_dice.sum() +# losses["loss_iou"] += loss_iou.sum() +# losses["loss_class"] += loss_class + + +# class TextCriterion(LossWithWeights): +# def __init__( +# self, +# pad_token, +# max_seq_len=100, +# weight_dict=None, +# compute_aux=False, +# ): +# super().__init__(weight_dict, compute_aux) +# self.pad_token = pad_token +# self.max_seq_len = max_seq_len +# self.in_lengths = None + +# def get_loss(self, outputs, **kwargs): +# nb_tokens = outputs["captioning_tokenized_target"].input_ids.numel() +# bs, seq_len = outputs["captioning_tokenized_target"].input_ids.shape +# ce = F.cross_entropy( +# outputs["captioning_pred_text"].flatten(0, -2), +# outputs["captioning_tokenized_target"].input_ids.flatten(), +# ignore_index=self.pad_token, +# reduction="sum", +# ) + +# not_pad = ( +# outputs["captioning_tokenized_target"] +# .input_ids.reshape(-1) +# .ne(self.pad_token) +# ) + +# if nb_tokens > 0: +# nb_non_pad = not_pad.numel() +# ce = ce / nb_non_pad + +# preds = outputs["captioning_pred_text"].flatten(0, -2).argmax(-1)[not_pad] +# targets = outputs["captioning_tokenized_target"].input_ids.flatten()[not_pad] +# correct = preds == targets +# correct = correct.sum() / (correct.numel() + 1e-5) + +# correct_sequence_level = torch.all( +# ( +# outputs["captioning_pred_text"] +# .flatten(0, -2) +# .argmax(-1) +# .reshape(bs, seq_len) +# == outputs["captioning_tokenized_target"].input_ids +# ) +# | (~not_pad).view(bs, seq_len), +# dim=1, +# ) +# seq_level_acc = correct_sequence_level.float().mean() + +# return {"loss_text": ce, "text_acc": correct, "text_seq_acc": seq_level_acc} + + +def segment_miou(source, target): + """Compute the mean IoU between two sets of masks""" + assert source.shape == target.shape, "The two masks must have the same shape" + assert source.ndim == 3, "The masks must be 3D" + + valid_targets = (target.sum(dim=(1, 2)) > 0).sum() + if valid_targets == 0: + return torch.tensor(1.0, device=source.device) + intersection = (source.bool() & target.bool()).sum(dim=(1, 2)) + union = (source.bool() | target.bool()).sum(dim=(1, 2)) + iou = intersection / (union + 1e-8) + return iou.sum() / valid_targets + + +class SemanticSegCriterion(LossWithWeights): + def __init__( + self, + weight_dict, + focal: bool = False, + focal_alpha: float = 0.6, + focal_gamma: float = 1.6, + downsample: bool = True, + presence_head: bool = False, + # Option to turn off presence loss, if some other component + # is already doing it, e.g. decoder - in which case, + # we could still set presence_head to True so that + # losses are not propogated to masks when there is no GT mask + presence_loss: bool = True, + ): + super().__init__(weight_dict, False) + self.focal = focal + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + self.downsample = downsample + self.presence_head = presence_head + self.presence_loss = presence_loss + + def get_loss(self, out_dict, targets): + outputs = out_dict["semantic_seg"] + presence_logit = out_dict["presence_logit"] + if ( + "semantic_masks" in targets + and targets["semantic_masks"] is not None + and targets["semantic_masks"].size(0) > 0 + ): + semantic_targets = targets["semantic_masks"] + with torch.no_grad(): + if self.downsample: + # downsample targets to the size of predictions + size = outputs.shape[-2:] + semantic_targets = ( + F.interpolate( + semantic_targets.float().unsqueeze(1), + size=size, + mode="bilinear", + align_corners=False, + ) + .squeeze(1) + .bool() + ) + else: + with torch.no_grad(): + if self.downsample: + # downsample targets to the size of predictions + size = outputs.shape[-2:] + segments = ( + F.interpolate( + targets["masks"].float().unsqueeze(1), + size=size, + mode="bilinear", + align_corners=False, + ) + .squeeze(1) + .bool() + ) + else: + segments = targets["masks"].bool() + + # the annotations are for instance segmentation, so we merge them to get semantic segmentation + semantic_targets = instance_masks_to_semantic_masks( + segments, targets["num_boxes"] + ) + + if not self.downsample: + # upsample predictions to the target size + size = semantic_targets.shape[-2:] + outputs = F.interpolate( + outputs.float(), + size=size, + mode="bilinear", + align_corners=False, + ) + + if self.focal: + loss = sigmoid_focal_loss( + outputs.squeeze(1).flatten(-2), + semantic_targets.float().flatten(-2), + num_boxes=len(semantic_targets), + alpha=self.focal_alpha, + gamma=self.focal_gamma, + reduce=not self.presence_head, + ) + if self.presence_head: + loss = loss.mean(1) + else: + loss = F.binary_cross_entropy_with_logits( + outputs.squeeze(1), + semantic_targets.float(), + reduction="none" if self.presence_head else "mean", + ) + if self.presence_head: + loss = loss.flatten(1).mean(1) + + loss_dice = dice_loss( + outputs.squeeze(1).flatten(1), + semantic_targets.flatten(1), + len(semantic_targets), + reduce=not self.presence_head, + ) + + miou = segment_miou(outputs.sigmoid().squeeze(1) > 0.5, semantic_targets) + + loss_dict = {} + + if self.presence_head: + presence_target = semantic_targets.flatten(1).any(-1) + if self.presence_loss: + loss_presence = F.binary_cross_entropy_with_logits( + presence_logit.flatten(), + presence_target.float(), + ) + presence_acc = ( + ((presence_logit.flatten().sigmoid() > 0.5) == presence_target) + .float() + .mean() + ) + else: + # Dummy values + loss_presence = torch.tensor(0.0, device=loss.device) + # Whichever component is computing the presence loss, + # should also track presence_acc + presence_acc = torch.tensor(0.0, device=loss.device) + + loss_dict["loss_semantic_presence"] = loss_presence + loss_dict["presence_acc"] = presence_acc + + # reduce the other losses, skipping the negative ones + bs = loss.shape[0] + assert presence_target.numel() == bs + + mask = presence_target + nb_valid = presence_target.sum().item() + + loss = (loss * mask.float()).sum() / (nb_valid + 1e-6) + loss_dice = (loss_dice * mask.float()).sum() / (nb_valid + 1e-6) + + loss_dict.update( + { + "loss_semantic_seg": loss, + "loss_semantic_dice": loss_dice, + "miou_semantic_seg": miou, + } + ) + + return loss_dict + + +class Det2TrkAssoc(LossWithWeights): + def __init__( + self, + weight_dict, + use_fp_loss=False, + fp_loss_on_exhaustive_only=True, + treat_fp_as_new_obj=False, + ): + super().__init__(weight_dict, compute_aux=False) + self.use_fp_loss = use_fp_loss + self.fp_loss_on_exhaustive_only = fp_loss_on_exhaustive_only + self.treat_fp_as_new_obj = treat_fp_as_new_obj + if self.use_fp_loss: + self.target_keys.append("is_exhaustive") + + def get_loss(self, outputs, targets, indices, num_boxes): + det2trk_assoc_logits = outputs["det2trk_assoc_logits"] + device = det2trk_assoc_logits.device + B, Q_det, Q_trk_plus_2 = det2trk_assoc_logits.shape + assert Q_trk_plus_2 >= 2 + Q_trk = Q_trk_plus_2 - 2 + + # We only apply association losses to those detection queries that either match + # a GT instance or have score > 0 (i.e. those TP, FN and FP detection queries) + matched_object_ids = outputs["matched_object_ids"] + assert matched_object_ids.shape == (B, Q_det + Q_trk) + matched_obj_ids_det = matched_object_ids[:, :Q_det] + matched_obj_ids_trk = matched_object_ids[:, Q_det:] + det_is_matched_to_gt = matched_obj_ids_det >= 0 + trk_is_matched_to_gt = matched_obj_ids_trk >= 0 + + # note: -1 label is ignored in the (softmax) cross_entropy loss below + det2trk_assoc_labels = -torch.ones(B, Q_det, dtype=torch.long, device=device) + # a) If a detection query is matched to a same object ID as a tracking query, + # we assign it the index of the tracking query as a label + det_is_same_obj_id_as_trk = ( + det_is_matched_to_gt[:, :, None] + & trk_is_matched_to_gt[:, None, :] + & (matched_obj_ids_det[:, :, None] == matched_obj_ids_trk[:, None, :]) + ) + batch_idx, det_idx, trk_idx = det_is_same_obj_id_as_trk.nonzero(as_tuple=True) + det2trk_assoc_labels[batch_idx, det_idx] = trk_idx + + # b) If a detection query is matched to GT but not to any tracking query, + # we assign it a "new_object" label + det_is_new_obj = det_is_matched_to_gt & ~det_is_same_obj_id_as_trk.any(dim=-1) + det2trk_assoc_labels[det_is_new_obj] = Q_trk + + # c) If a detection query is not matched to GT but have score > 0, + # we assign it a "false_positive" label + if self.use_fp_loss: + det_is_above_thresh = outputs["pred_logits"][:, :Q_det].squeeze(2) > 0 + det_is_fp = ~det_is_matched_to_gt & det_is_above_thresh + if self.treat_fp_as_new_obj: + det2trk_assoc_labels[det_is_fp] = Q_trk + else: + if self.fp_loss_on_exhaustive_only: + # only count FP detections on batches that are exhaustively annotated + det_is_fp &= targets["is_exhaustive"].unsqueeze(1).bool() + det2trk_assoc_labels[det_is_fp] = Q_trk + 1 + + # softmax cross-entropy loss for detection-to-tracking association + loss_det2trk_assoc = F.cross_entropy( + input=det2trk_assoc_logits.flatten(0, 1), # (B * Q_det, Q_trk + 2) + target=det2trk_assoc_labels.flatten(0, 1), # (B * Q_det) + ignore_index=-1, + reduction="none", + ).view(B, Q_det) + # skip det2trk assocation loss on frames w/o any (non-padding) tracking queries + frame_has_valid_trk = trk_is_matched_to_gt.any(dim=-1, keepdims=True) # (B, 1) + loss_det2trk_assoc = loss_det2trk_assoc * frame_has_valid_trk.float() + + loss_det2trk_assoc = loss_det2trk_assoc.sum() / (B * num_boxes) + return {"loss_det2trk_assoc": loss_det2trk_assoc} + + +class TrackingByDetectionAssoc(LossWithWeights): + def __init__(self, weight_dict): + super().__init__(weight_dict, compute_aux=False, supports_o2m_loss=False) + assert "loss_det2trk_assoc" in self.weight_dict + assert "loss_trk2det_assoc" in self.weight_dict + + def get_loss(self, outputs, targets, indices, num_boxes): + # Part A: gather object id matching between detection and tracking + det2trk_assoc_logits = outputs["det2trk_assoc_logits"] # (B, Q_det+1, Q_trk+1) + B, Q_det_plus_1, Q_trk_plus_1 = det2trk_assoc_logits.shape + assert Q_det_plus_1 >= 1 and Q_trk_plus_1 >= 1 + Q_det = Q_det_plus_1 - 1 + Q_trk = Q_trk_plus_1 - 1 + device = det2trk_assoc_logits.device + + matched_obj_ids_det = outputs["matched_object_ids"] + assert matched_obj_ids_det.shape == (B, Q_det) + det_is_matched_to_gt = matched_obj_ids_det >= 0 + matched_obj_ids_trk = outputs["prev_trk_object_ids"] + assert matched_obj_ids_trk.shape == (B, Q_trk) + trk_is_matched_to_gt = matched_obj_ids_trk >= 0 + frame_has_valid_trk = trk_is_matched_to_gt.any(dim=-1, keepdims=True) # (B, 1) + + # check whether a detection object is the same as a tracking object + det_is_same_obj_id_as_trk = ( + det_is_matched_to_gt[:, :, None] + & trk_is_matched_to_gt[:, None, :] + & (matched_obj_ids_det[:, :, None] == matched_obj_ids_trk[:, None, :]) + ) # (B, Q_det, Q_trk) + # there should be at most one match for each detection and each previous tracked object + torch._assert_async(torch.all(det_is_same_obj_id_as_trk.sum(dim=2) <= 1)) + torch._assert_async(torch.all(det_is_same_obj_id_as_trk.sum(dim=1) <= 1)) + batch_idx, det_idx, trk_idx = det_is_same_obj_id_as_trk.nonzero(as_tuple=True) + + # Part B: Detection-to-tracking association loss + # assign detection-to-tracking labels (note: -1 label is ignored in the loss below) + det2trk_assoc_labels = -torch.ones(B, Q_det, dtype=torch.long, device=device) + det2trk_assoc_labels[batch_idx, det_idx] = trk_idx + # if a detection is matched to GT but not to any tracking, assign it a "new-object" label + det_is_new_obj = det_is_matched_to_gt & ~det_is_same_obj_id_as_trk.any(dim=2) + det2trk_assoc_labels[det_is_new_obj] = Q_trk # "Q_trk" label is "new-object" + + # softmax cross-entropy loss for detection-to-tracking association + loss_det2trk_assoc = F.cross_entropy( + input=det2trk_assoc_logits[:, :-1].flatten(0, 1), # (B*Q_det, Q_trk+1) + target=det2trk_assoc_labels.flatten(0, 1), # (B*Q_det) + ignore_index=-1, + reduction="none", + ).view(B, Q_det) + # skip det2trk assocation loss on frames w/o any (non-padding) tracking queries + loss_det2trk_assoc = loss_det2trk_assoc * frame_has_valid_trk.float() + loss_det2trk_assoc = loss_det2trk_assoc.sum() / (B * num_boxes) + loss_dict = {"loss_det2trk_assoc": loss_det2trk_assoc} + + # Part C: tracking-to-detection association loss + trk2det_assoc_logits = det2trk_assoc_logits.transpose(1, 2) + assert trk2det_assoc_logits.shape == (B, Q_trk + 1, Q_det + 1) + # assign tracking-to-detection labels (note: -1 label is ignored in the loss below) + trk2det_assoc_labels = -torch.ones(B, Q_trk, dtype=torch.long, device=device) + trk2det_assoc_labels[batch_idx, trk_idx] = det_idx + # if a tracking is matched to GT but not to any detection, assign it a "occluded" label + trk_is_occluded = trk_is_matched_to_gt & ~det_is_same_obj_id_as_trk.any(dim=1) + trk2det_assoc_labels[trk_is_occluded] = Q_det # "Q_det" label is "occluded" + + # softmax cross-entropy loss for tracking-to-detection association + loss_trk2det_assoc = F.cross_entropy( + input=trk2det_assoc_logits[:, :-1].flatten(0, 1), # (B*Q_trk, Q_det+1) + target=trk2det_assoc_labels.flatten(0, 1), # (B*Q_trk) + ignore_index=-1, + reduction="none", + ).view(B, Q_trk) + # skip trk2det association loss on frames w/o any (non-padding) tracking queries + loss_trk2det_assoc = loss_trk2det_assoc * frame_has_valid_trk.float() + loss_trk2det_assoc = loss_trk2det_assoc.sum() / (B * num_boxes) + loss_dict["loss_trk2det_assoc"] = loss_trk2det_assoc + + return loss_dict + + +def _keep_only_trk_queries_in_match_inds(inds, Q_det): + """Keep only the tracking query indices in the indices tuple""" + batch_idx, src_idx, tgt_idx = inds + if batch_idx.numel() == 0: + return (batch_idx, src_idx, tgt_idx) # empty indices, nothing to filter + + # keep only the tracking query indices + is_trk_query = src_idx >= Q_det + batch_idx_trk = batch_idx[is_trk_query] + src_idx_trk = src_idx[is_trk_query] + tgt_idx_trk = tgt_idx[is_trk_query] if tgt_idx is not None else None + return (batch_idx_trk, src_idx_trk, tgt_idx_trk) diff --git a/third_party/sam3/sam3/train/loss/mask_sampling.py b/third_party/sam3/sam3/train/loss/mask_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0ad00d1248c2dc4b2200d94461dfaf00081977 --- /dev/null +++ b/third_party/sam3/sam3/train/loss/mask_sampling.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +from typing import Callable + +import torch +from torch.nn import functional as F + + +# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py +def point_sample(input, point_coords, **kwargs): + """ + A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. + Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside + [0, 1] x [0, 1] square. + + Args: + input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. + point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains + [0, 1] x [0, 1] normalized point coordinates. + + Returns: + output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains + features for points in `point_coords`. The features are obtained via bilinear + interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. + """ + add_dim = False + if point_coords.dim() == 3: + add_dim = True + point_coords = point_coords.unsqueeze(2) + normalized_point_coords = 2.0 * point_coords - 1.0 # Normalize to [-1,1] + output = F.grid_sample(input, normalized_point_coords, **kwargs) + if add_dim: + output = output.squeeze(3) + return output + + +# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py +def get_uncertain_point_coords_with_randomness( + logits: torch.Tensor, + uncertainty_func: Callable, + num_points: int, + oversample_ratio: int, + importance_sample_ratio: float, +) -> torch.Tensor: + """ + Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties + are calculated for each point using 'uncertainty_func' function that takes point's logit + prediction as input. + See PointRend paper for details. + + Args: + logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for + class-specific or class-agnostic prediction. + uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that + contains logit predictions for P points and returns their uncertainties as a Tensor of + shape (N, 1, P). + num_points (int): The number of points P to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P + sampled points. + """ + assert oversample_ratio >= 1 + assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 + num_boxes = logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand(num_boxes, num_sampled, 2, device=logits.device) + point_logits = point_sample(logits, point_coords, align_corners=False) + # It is crucial to calculate uncertainty based on the sampled prediction value for the points. + # Calculating uncertainties of the predictions first and sampling them for points leads + # to incorrect results. + # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between + # two predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. + # However, if we calculate uncertainties for the predictions first, + # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + # Flatten the indices + shift = num_sampled * torch.arange( + num_boxes, dtype=torch.long, device=logits.device + ) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + num_boxes, num_uncertain_points, 2 + ) + if num_random_points > 0: + point_coords = torch.cat( + [ + point_coords, + torch.rand(num_boxes, num_random_points, 2, device=logits.device), + ], + dim=1, + ) + return point_coords + + +# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py +def calculate_uncertainty(logits: torch.Tensor) -> torch.Tensor: + """ + Estimates uncerainty as L1 distance between 0.0 and the logit prediction. + Args: + logits (Tensor): A tensor of shape (R, 1, ...) for class-agnostic + predicted masks + Returns: + scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with + the most uncertain locations having the highest uncertainty score. + """ + assert logits.shape[1] == 1 + return -(torch.abs(logits)) diff --git a/third_party/sam3/sam3/train/loss/sam3_loss.py b/third_party/sam3/sam3/train/loss/sam3_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa5791965a6d361ebdac2720c2f0a48fa925774 --- /dev/null +++ b/third_party/sam3/sam3/train/loss/sam3_loss.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import torch +from sam3.model.model_misc import SAM3Output +from sam3.train.utils.distributed import get_world_size + +from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks + + +class DummyLoss(torch.nn.Module): + """A dummy loss that always returns 0 (as a placeholder for eval)""" + + def __init__( + self, + core_loss_key: str = CORE_LOSS_KEY, + device: str = "cuda", + **kwargs, + ): + super().__init__() + self.core_loss_key = core_loss_key + self.device = torch.device(device) + + def forward(self, *args, **kwargs): + return {self.core_loss_key: torch.tensor(0.0, device=self.device)} + + def accumulate(self, out_dict): + """ + Called by iterative losses. + """ + if self.core_loss_key not in out_dict: + out_dict[self.core_loss_key] = torch.tensor(0.0, device=self.device) + return out_dict + + +class Sam3LossWrapper(torch.nn.Module): + def __init__( + self, + loss_fns_find, + normalization="global", + matcher=None, + o2m_matcher=None, + o2m_weight=1.0, + use_o2m_matcher_on_o2m_aux=True, + loss_fn_semantic_seg=None, + normalize_by_valid_object_num=False, + normalize_by_stage_num=False, + scale_by_find_batch_size=False, + ): + super().__init__() + self.loss_fns_find = loss_fns_find + assert normalization in ["global", "local", "none"] + self.normalization = normalization + self.normalize_by_valid_object_num = normalize_by_valid_object_num + self.normalize_by_stage_num = normalize_by_stage_num + self.matcher = matcher + self.o2m_matcher = o2m_matcher + self.o2m_weight = o2m_weight + # whether to use the o2m matcher on the o2m queries in auxiliary outputs + self.use_o2m_matcher_on_o2m_aux = use_o2m_matcher_on_o2m_aux + self.loss_fn_semantic_seg = loss_fn_semantic_seg + self.scale_by_find_batch_size = scale_by_find_batch_size + + def _get_num_boxes(self, targets): + # the average number of target boxes for loss normalization + if self.normalize_by_valid_object_num: + # valid boxes are those with non-zero height and width + # (while padded invisible boxes are ) + boxes_hw = targets["boxes"].view(-1, 4) # cx, cy, w, h + num_boxes = (boxes_hw[:, 2:] > 0).all(dim=-1).sum().float() + else: + num_boxes = targets["num_boxes"].sum().float() + if self.normalization == "global": + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1) + elif self.normalization == "local": + num_boxes = torch.clamp(num_boxes, min=1) + elif self.normalization == "none": + num_boxes = 1 + return num_boxes + + def compute_loss(self, nested_out, targets): + num_boxes = self._get_num_boxes(targets) + o2m_out_is_valid = nested_out.get("o2m_out_is_valid", None) + o2m_target_is_valid_padded = nested_out.get("o2m_target_is_valid_padded", None) + + # Get a list of outputs, including auxiliary and first stage outputs + output_list = [(nested_out, "", False)] # (out, suffix, is_aux) + if "aux_outputs" in nested_out: + output_list.extend( + (aux_out, f"_aux_{i}", True) + for i, aux_out in enumerate(nested_out["aux_outputs"]) + ) + if "first_stage" in nested_out: + output_list.append((nested_out["first_stage"], "_fs", True)) + + # Compute all the requested losses + losses = {} + total_core_loss = 0.0 + for out, suffix, is_aux in output_list: + # o2o matcher indices need to be computed by the model (as the video model requires + # a specific way of matching free and locked indices beyond just calling the matcher) + indices = out["indices"] + has_o2m_out = "pred_logits_o2m" in out + if has_o2m_out: + o2m_out = { + k[: -len("_o2m")]: v for k, v in out.items() if k.endswith("_o2m") + } + # o2m targets are the same as the o2o targets (assuming repeat=1) + o2m_targets = targets + if self.use_o2m_matcher_on_o2m_aux or not is_aux: + o2m_indices = self.o2m_matcher( + o2m_out, + o2m_targets, + out_is_valid=o2m_out_is_valid, + target_is_valid_padded=o2m_target_is_valid_padded, + ) + else: + o2m_indices = self.matcher( + o2m_out, + o2m_targets, + out_is_valid=o2m_out_is_valid, + target_is_valid_padded=o2m_target_is_valid_padded, + ) + + for loss_fn in self.loss_fns_find: + l_dict = loss_fn( + outputs=out, + targets=targets, + indices=indices, + num_boxes=num_boxes, + is_aux=is_aux, + ) + total_core_loss += l_dict.pop(CORE_LOSS_KEY) + losses.update({f"{k}{suffix}": v for k, v in l_dict.items()}) + + compute_o2m_loss = has_o2m_out + # a special handling to allow turning off mask loss in o2m + # (to be compatible with the original implementation) + if isinstance(loss_fn, Masks): + compute_o2m_loss = compute_o2m_loss and "pred_masks" in o2m_out + if isinstance(loss_fn, Det2TrkAssoc): + compute_o2m_loss = False # Det2TrkAssoc does not support o2m + if compute_o2m_loss: + l_dict = loss_fn( + outputs=o2m_out, + targets=o2m_targets, + indices=o2m_indices, + num_boxes=num_boxes, + is_aux=is_aux, + ) + for k in l_dict: + l_dict[k] *= self.o2m_weight + total_core_loss += l_dict.pop(CORE_LOSS_KEY) + losses.update({f"{k}{suffix}_o2m": v for k, v in l_dict.items()}) + + losses[CORE_LOSS_KEY] = total_core_loss + return losses + + def forward(self, find_stages: SAM3Output, find_targets): + if find_stages.loss_stages is not None: + find_targets = [find_targets[i] for i in find_stages.loss_stages] + with SAM3Output.iteration_mode( + find_stages, iter_mode=SAM3Output.IterMode.ALL_STEPS_PER_STAGE + ) as find_stages: + assert len(find_stages) == len(find_targets) + total_losses = {} + for stage_outputs, stage_targets in zip(find_stages, find_targets): + stage_targets = [stage_targets] * len(stage_outputs) + # If there are multiple steps within a stage, compute the loss for all of them (e.g. interactivity) + for outputs, targets in zip(stage_outputs, stage_targets): + cur_losses = self.compute_loss(outputs, targets) + + if self.loss_fn_semantic_seg is not None: + cur_losses_semantic = self.loss_fn_semantic_seg( + outputs, targets + ) + cur_losses[CORE_LOSS_KEY] += cur_losses_semantic.pop( + CORE_LOSS_KEY + ) + # make sure the semantic losses don't overlap with the find losses + assert set(cur_losses).isdisjoint(set(cur_losses_semantic)) + cur_losses.update(cur_losses_semantic) + + # Optionally, normalize the loss by the number of find stages (training video frames) so that + # image batches and video batches have similar loss scales. (Otherwise video batches would + # have a much higher loss scale due to summing the losses over all the find stages.) + if self.normalize_by_stage_num: + cur_losses[CORE_LOSS_KEY] /= len(find_stages) + + if self.scale_by_find_batch_size: + bs = targets["num_boxes"].shape[0] + # sqrt scaling based on the "effective" batch size + cur_losses[CORE_LOSS_KEY] *= bs**0.5 + + for k, v in cur_losses.items(): + if k not in total_losses: + total_losses[k] = v + else: + total_losses[k] += v + + return total_losses diff --git a/third_party/sam3/sam3/train/loss/sigmoid_focal_loss.py b/third_party/sam3/sam3/train/loss/sigmoid_focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a56f03a250501f3bb71b3ec8bb417b3f59ec44c6 --- /dev/null +++ b/third_party/sam3/sam3/train/loss/sigmoid_focal_loss.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Triton kernel for faster and memory efficient sigmoid focal loss""" + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime.triton_helpers import libdevice + +""" + +The sigmoid focal loss is defined as: + + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * ce_loss * ((1 - p_t) ** gamma) + +Where alpha and gamma are scalar parameters, inputs are the logits, targets the float targets. + +We implement two versions of the sigmoid focal loss: with and without sum reduction. +The latter is implemented with built-in reduction to avoid materializing wrt the output of the loss. +This can help save a bit of peak memory. + +The reduction version is implemented using somewhat of a hack. Pytorch's generated kernels usually do the point-wise operation in a first kernel, and implement the reduction another kernel launched on a grid of size 1, where the reduction happens as a for loop in the triton kernel. +Since we want to fuse those two kernels, that is not a good idea: we'd have to launch the overall kernel on a grid of size 1, which is obviously inefficient. +On the other hand, typical CUDA algorithms for reduction (eg reduction tree) are hard to implement in triton due to the lack of thread sync primitives. +We settle for a version that abuses triton's atomic_add: we can have all threads simply add to the same location. +In practice, this is not good, since it creates a massive bottleneck on the semaphore for that single memory location. So instead, we create M reduction locations. Each thread will simply write to thread_id%M. The python code can finally sum over the M reductions. +M = 32 works fine in benchmarking tests. The forward is a tiny bit slower compared to the non-reduced kernel, but the backward breaks even due to one less memory allocation. +""" + + +@triton.jit +def _inner_focal_loss_fwd(inputs, targets, alpha, gamma): + inv_targets = 1 - targets + # Sigmoid + sig = tl.sigmoid(inputs) + + # Binary cross entropy with logits + # In practice, we want the following: + # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig) + # However, the above is not numerically stable. + # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply + # The bce can be reformulated, after algebraic manipulation, to + # bce_loss = log(1 + exp(-x)) + x * (1-y) + # This is still not stable, because for large (-x) the exponential will blow up. + # We'll use the following alternate formulation: + # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x))) + # Let's show that it's equivalent: + # Case x>=0: abs(x) = x , max(x, 0) = x + # so we get x - x * y + log(1 + exp(-x)) which is equivalent + # Case x<0: abs(x) = -x, max(x, 0) = 0 + # we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x)) + # plugging it in, we get + # 0 - x * y + x + log(1 + exp(-x)), which is also equivalent + # Note that this is stable because now the exponent are guaranteed to be below 0. + max_val = tl.clamp(inputs, min=0, max=1e9) + bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs))) + + # Modulating factor + p_t = sig * targets + (1 - sig) * inv_targets + mod_factor = libdevice.pow(1 - p_t, gamma) + + # Alpha factor + alpha_t = alpha * targets + (1 - alpha) * inv_targets + + # Final loss calculation + return alpha_t * mod_factor * bce_loss + + +# Non-reduced version +@triton.jit +def sigmoid_focal_loss_fwd_kernel( + inputs_ptr, + targets_ptr, + loss_ptr, + alpha: float, + gamma: float, + n_elements: int, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offset = block_start + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + + # Load data + inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32) + targets = tl.load(targets_ptr + offset, mask=mask) + + final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) + + # Store result + tl.store(loss_ptr + offset, final_loss, mask=mask) + + +# version with reduction +@triton.jit +def sigmoid_focal_loss_fwd_kernel_reduce( + inputs_ptr, + targets_ptr, + loss_ptr, + alpha: float, + gamma: float, + n_elements: int, + BLOCK_SIZE: tl.constexpr, + REDUCE_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + reduce_loc = pid % REDUCE_SIZE + offset = block_start + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + # Load data + inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32) + targets = tl.load(targets_ptr + offset, mask=mask) + + final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask + + fl = tl.sum(final_loss) + + # Store result + tl.atomic_add(loss_ptr + reduce_loc, fl) + + +@triton.jit +def _inner_focal_loss_bwd(inputs, targets, alpha, gamma): + inv_targets = 1 - targets + + # Recompute forward + max_val = tl.clamp(inputs, min=0, max=1e9) + bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs))) + + # Sigmoid + sig = tl.sigmoid(inputs) + inv_sig = 1 - sig + + # Modulating factor + p_t = sig * targets + inv_sig * inv_targets + tmp = libdevice.pow(1 - p_t, gamma - 1) + mod_factor = tmp * (1 - p_t) + + # Alpha factor + alpha_t = alpha * targets + (1 - alpha) * inv_targets + + # Now computing the derivatives + d_pt = (2 * targets - 1) * sig * inv_sig + d_mod_factor = -gamma * d_pt * tmp + + d_bce_loss = sig - targets + + return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss) + + +@triton.jit +def sigmoid_focal_loss_bwd_kernel( + inputs_ptr, + targets_ptr, + grad_inputs_ptr, + grad_out_ptr, + alpha: float, + gamma: float, + n_elements: int, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offset = block_start + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + input_ptrs = inputs_ptr + offset + target_ptrs = targets_ptr + offset + grad_input_ptrs = grad_inputs_ptr + offset + grad_out_ptrs = grad_out_ptr + offset + # Load data + inputs = tl.load(input_ptrs, mask=mask).to(tl.float32) + targets = tl.load(target_ptrs, mask=mask) + grad_out = tl.load(grad_out_ptrs, mask=mask) + d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma) + tl.store(grad_input_ptrs, d_loss, mask=mask) + + +@triton.jit +def sigmoid_focal_loss_bwd_kernel_reduce( + inputs_ptr, + targets_ptr, + grad_inputs_ptr, + grad_out_ptr, + alpha: float, + gamma: float, + n_elements: int, + BLOCK_SIZE: tl.constexpr, +): + # The only difference is that the gradient is now a single scalar + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offset = block_start + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + input_ptrs = inputs_ptr + offset + target_ptrs = targets_ptr + offset + grad_input_ptrs = grad_inputs_ptr + offset + # Load data + inputs = tl.load(input_ptrs, mask=mask).to(tl.float32) + targets = tl.load(target_ptrs, mask=mask) + grad_out = tl.load(grad_out_ptr) + d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma) + tl.store(grad_input_ptrs, d_loss, mask=mask) + + +class SigmoidFocalLoss(torch.autograd.Function): + BLOCK_SIZE = 256 + + @staticmethod + def forward(ctx, inputs, targets, alpha=0.25, gamma=2): + n_elements = inputs.numel() + assert targets.numel() == n_elements + input_shape = inputs.shape + inputs = inputs.view(-1).contiguous() + targets = targets.view(-1).contiguous() + loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sigmoid_focal_loss_fwd_kernel[grid]( + inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE + ) + ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape)) + ctx.alpha = alpha + ctx.gamma = gamma + return loss.view(input_shape) + + @staticmethod + def backward(ctx, grad_output): + inputs, targets = ctx.saved_tensors + alpha = ctx.alpha + gamma = ctx.gamma + n_elements = inputs.numel() + input_shape = inputs.shape + grad_inputs = torch.empty( + inputs.shape, dtype=grad_output.dtype, device=grad_output.device + ) + inputs_ptr = inputs.view(-1).contiguous() + targets_ptr = targets.view(-1).contiguous() + grad_output_ptr = grad_output.view(-1).contiguous() + grad_inputs_ptr = grad_inputs + assert grad_output.numel() == n_elements + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sigmoid_focal_loss_bwd_kernel[grid]( + inputs_ptr, + targets_ptr, + grad_inputs_ptr, + grad_output_ptr, + alpha, + gamma, + n_elements, + SigmoidFocalLoss.BLOCK_SIZE, + ) + return grad_inputs.view(input_shape), None, None, None + + +triton_sigmoid_focal_loss = SigmoidFocalLoss.apply + + +class SigmoidFocalLossReduced(torch.autograd.Function): + BLOCK_SIZE = 256 + REDUCE_SIZE = 32 + + @staticmethod + def forward(ctx, inputs, targets, alpha=0.25, gamma=2): + n_elements = inputs.numel() + input_shape = inputs.shape + inputs = inputs.view(-1).contiguous() + targets = targets.view(-1).contiguous() + loss = torch.zeros( + SigmoidFocalLossReduced.REDUCE_SIZE, + device=inputs.device, + dtype=torch.float32, + ) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sigmoid_focal_loss_fwd_kernel_reduce[grid]( + inputs, + targets, + loss, + alpha, + gamma, + n_elements, + SigmoidFocalLossReduced.BLOCK_SIZE, + SigmoidFocalLossReduced.REDUCE_SIZE, + ) + ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape)) + ctx.alpha = alpha + ctx.gamma = gamma + return loss.sum() + + @staticmethod + def backward(ctx, grad_output): + inputs, targets = ctx.saved_tensors + alpha = ctx.alpha + gamma = ctx.gamma + n_elements = inputs.numel() + input_shape = inputs.shape + grad_inputs = torch.empty( + inputs.shape, dtype=grad_output.dtype, device=grad_output.device + ) + inputs_ptr = inputs.view(-1).contiguous() + targets_ptr = targets.reshape(-1).contiguous() + assert grad_output.numel() == 1 + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sigmoid_focal_loss_bwd_kernel_reduce[grid]( + inputs_ptr, + targets_ptr, + grad_inputs, + grad_output, + alpha, + gamma, + n_elements, + SigmoidFocalLossReduced.BLOCK_SIZE, + ) + return grad_inputs.view(input_shape), None, None, None + + +triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply diff --git a/third_party/sam3/sam3/train/masks_ops.py b/third_party/sam3/sam3/train/masks_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4e966bddac1608bb761f29ea5872d2b06c7486d9 --- /dev/null +++ b/third_party/sam3/sam3/train/masks_ops.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +"""Utilities for masks manipulation""" + +import numpy as np +import pycocotools.mask as maskUtils +import torch +from pycocotools import mask as mask_util + + +def instance_masks_to_semantic_masks( + instance_masks: torch.Tensor, num_instances: torch.Tensor +) -> torch.Tensor: + """This function converts instance masks to semantic masks. + It accepts a collapsed batch of instances masks (ie all instance masks are concatenated in a single tensor) and + the number of instances in each image of the batch. + It returns a mask with the same spatial dimensions as the input instance masks, where for each batch element the + semantic mask is the union of all the instance masks in the batch element. + + If for a given batch element there are no instances (ie num_instances[i]==0), the corresponding semantic mask will be a tensor of zeros. + + Args: + instance_masks (torch.Tensor): A tensor of shape (N, H, W) where N is the number of instances in the batch. + num_instances (torch.Tensor): A tensor of shape (B,) where B is the batch size. It contains the number of instances + in each image of the batch. + + Returns: + torch.Tensor: A tensor of shape (B, H, W) where B is the batch size and H, W are the spatial dimensions of the + input instance masks. + """ + + masks_per_query = torch.split(instance_masks, num_instances.tolist()) + + return torch.stack([torch.any(masks, dim=0) for masks in masks_per_query], dim=0) + + +def mask_intersection_vectorized(masks1, masks2): + """ + Vectorized computation of mask intersection using Matrix Multiplication. + + Args: + masks1: tensor of shape (N, H, W) + masks2: tensor of shape (M, H, W) + Returns: + tensor of shape (N, M) + """ + # Cast to float for Tensor Core acceleration via torch.mm + m1_flat = masks1.flatten(1).float() + m2_flat = masks2.flatten(1).float() + intersection = torch.mm(m1_flat, m2_flat.t()) + return intersection.long() + + +def mask_intersection(masks1, masks2, block_size=16): + """Compute the intersection of two sets of masks, without blowing the memory""" + + assert masks1.shape[1:] == masks2.shape[1:] + assert masks1.dtype == torch.bool and masks2.dtype == torch.bool + + result = torch.zeros( + masks1.shape[0], masks2.shape[0], device=masks1.device, dtype=torch.long + ) + for i in range(0, masks1.shape[0], block_size): + for j in range(0, masks2.shape[0], block_size): + intersection = ( + (masks1[i : i + block_size, None] * masks2[None, j : j + block_size]) + .flatten(-2) + .sum(-1) + ) + result[i : i + block_size, j : j + block_size] = intersection + return result + + +def mask_iom(masks1, masks2): + """ + Similar to IoU, except the denominator is the area of the smallest mask + """ + assert masks1.shape[1:] == masks2.shape[1:] + assert masks1.dtype == torch.bool and masks2.dtype == torch.bool + + intersection = mask_intersection_vectorized(masks1, masks2) + area1 = masks1.flatten(-2).sum(-1) + area2 = masks2.flatten(-2).sum(-1) + min_area = torch.min(area1[:, None], area2[None, :]) + return intersection / (min_area + 1e-8) + + +def compute_boundary(seg): + """ + Adapted from https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L148 + Return a 1pix wide boundary of the given mask + """ + assert seg.ndim >= 2 + e = torch.zeros_like(seg) + s = torch.zeros_like(seg) + se = torch.zeros_like(seg) + + e[..., :, :-1] = seg[..., :, 1:] + s[..., :-1, :] = seg[..., 1:, :] + se[..., :-1, :-1] = seg[..., 1:, 1:] + + b = seg ^ e | seg ^ s | seg ^ se + b[..., -1, :] = seg[..., -1, :] ^ e[..., -1, :] + b[..., :, -1] = seg[..., :, -1] ^ s[..., :, -1] + b[..., -1, -1] = 0 + return b + + +def dilation(mask, kernel_size): + """ + Implements the dilation operation. If the input is on cpu, we call the cv2 version. + Otherwise, we implement it using a convolution + + The kernel is assumed to be a square kernel + + """ + + assert mask.ndim == 3 + kernel_size = int(kernel_size) + assert ( + kernel_size % 2 == 1 + ), f"Dilation expects a odd kernel size, got {kernel_size}" + + if mask.is_cuda: + m = mask.unsqueeze(1).to(torch.float16) + k = torch.ones(1, 1, kernel_size, 1, dtype=m.dtype, device=m.device) + + result = torch.nn.functional.conv2d(m, k, padding="same") + result = torch.nn.functional.conv2d(result, k.transpose(-1, -2), padding="same") + return result.view_as(mask) > 0 + + all_masks = mask.view(-1, mask.size(-2), mask.size(-1)).numpy().astype(np.uint8) + kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) + + import cv2 + + processed = [torch.from_numpy(cv2.dilate(m, kernel)) for m in all_masks] + return torch.stack(processed).view_as(mask).to(mask) + + +def compute_F_measure( + gt_boundary_rle, gt_dilated_boundary_rle, dt_boundary_rle, dt_dilated_boundary_rle +): + """Adapted from https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L207 + + Assumes the boundary and dilated boundaries have already been computed and converted to RLE + """ + gt_match = maskUtils.merge([gt_boundary_rle, dt_dilated_boundary_rle], True) + dt_match = maskUtils.merge([dt_boundary_rle, gt_dilated_boundary_rle], True) + + n_dt = maskUtils.area(dt_boundary_rle) + n_gt = maskUtils.area(gt_boundary_rle) + # % Compute precision and recall + if n_dt == 0 and n_gt > 0: + precision = 1 + recall = 0 + elif n_dt > 0 and n_gt == 0: + precision = 0 + recall = 1 + elif n_dt == 0 and n_gt == 0: + precision = 1 + recall = 1 + else: + precision = maskUtils.area(dt_match) / float(n_dt) + recall = maskUtils.area(gt_match) / float(n_gt) + + # Compute F measure + if precision + recall == 0: + f_val = 0 + else: + f_val = 2 * precision * recall / (precision + recall) + + return f_val + + +@torch.no_grad() +def rle_encode(orig_mask, return_areas=False): + """Encodes a collection of masks in RLE format + + This function emulates the behavior of the COCO API's encode function, but + is executed partially on the GPU for faster execution. + + Args: + mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool + return_areas (bool): If True, add the areas of the masks as a part of + the RLE output dict under the "area" key. Default is False. + + Returns: + str: The RLE encoded masks + """ + assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)" + assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool" + + if orig_mask.numel() == 0: + return [] + + # First, transpose the spatial dimensions. + # This is necessary because the COCO API uses Fortran order + mask = orig_mask.transpose(1, 2) + + # Flatten the mask + flat_mask = mask.reshape(mask.shape[0], -1) + if return_areas: + mask_areas = flat_mask.sum(-1).tolist() + # Find the indices where the mask changes + differences = torch.ones( + mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool + ) + differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:] + differences[:, 0] = flat_mask[:, 0] + _, change_indices = torch.where(differences) + + try: + boundaries = torch.cumsum(differences.sum(-1), 0).cpu() + except RuntimeError as _: + boundaries = torch.cumsum(differences.cpu().sum(-1), 0) + + change_indices_clone = change_indices.clone() + # First pass computes the RLEs on GPU, in a flatten format + for i in range(mask.shape[0]): + # Get the change indices for this batch item + beg = 0 if i == 0 else boundaries[i - 1].item() + end = boundaries[i].item() + change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1] + + # Now we can split the RLES of each batch item, and convert them to strings + # No more gpu at this point + change_indices = change_indices.tolist() + + batch_rles = [] + # Process each mask in the batch separately + for i in range(mask.shape[0]): + beg = 0 if i == 0 else boundaries[i - 1].item() + end = boundaries[i].item() + run_lengths = change_indices[beg:end] + + uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])} + h, w = uncompressed_rle["size"] + rle = mask_util.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") + if return_areas: + rle["area"] = mask_areas[i] + batch_rles.append(rle) + + return batch_rles + + +def robust_rle_encode(masks): + """Encodes a collection of masks in RLE format. Uses the gpu version fist, falls back to the cpu version if it fails""" + + assert masks.ndim == 3, "Mask must be of shape (N, H, W)" + assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool" + + try: + return rle_encode(masks) + except RuntimeError as _: + masks = masks.cpu().numpy() + rles = [ + mask_util.encode( + np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F") + )[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + return rles + + +def ann_to_rle(segm, im_info): + """Convert annotation which can be polygons, uncompressed RLE to RLE. + Args: + ann (dict) : annotation object + Returns: + ann (rle) + """ + h, w = im_info["height"], im_info["width"] + if isinstance(segm, list): + # polygon -- a single object might consist of multiple parts + # we merge all parts into one mask rle code + rles = mask_util.frPyObjects(segm, h, w) + rle = mask_util.merge(rles) + elif isinstance(segm["counts"], list): + # uncompressed RLE + rle = mask_util.frPyObjects(segm, h, w) + else: + # rle + rle = segm + return rle diff --git a/third_party/sam3/sam3/train/matcher.py b/third_party/sam3/sam3/train/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..adbbd60ae6fba107c1a760df96b37f56949f6b85 --- /dev/null +++ b/third_party/sam3/sam3/train/matcher.py @@ -0,0 +1,811 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" + +import numpy as np +import torch +from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou +from scipy.optimize import linear_sum_assignment +from torch import nn + + +def _do_matching(cost, repeats=1, return_tgt_indices=False, do_filtering=False): + if repeats > 1: + cost = np.tile(cost, (1, repeats)) + + i, j = linear_sum_assignment(cost) + if do_filtering: + # filter out invalid entries (i.e. those with cost > 1e8) + valid_thresh = 1e8 + valid_ijs = [(ii, jj) for ii, jj in zip(i, j) if cost[ii, jj] < valid_thresh] + i, j = zip(*valid_ijs) if len(valid_ijs) > 0 else ([], []) + i, j = np.array(i, dtype=np.int64), np.array(j, dtype=np.int64) + if return_tgt_indices: + return i, j + order = np.argsort(j) + return i[order] + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__( + self, + cost_class: float = 1, + cost_bbox: float = 1, + cost_giou: float = 1, + focal_loss: bool = False, + focal_alpha: float = 0.25, + focal_gamma: float = 2, + ): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1) + assert ( + cost_class != 0 or cost_bbox != 0 or cost_giou != 0 + ), "all costs cant be 0" + self.focal_loss = focal_loss + self.focal_alpha = focal_alpha + self.focal_gamma = focal_gamma + + @torch.no_grad() + def forward(self, outputs, batched_targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = self.norm( + outputs["pred_logits"].flatten(0, 1) + ) # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_bbox = batched_targets["boxes"] + + if "positive_map" in batched_targets: + # In this case we have a multi-hot target + positive_map = batched_targets["positive_map"] + assert len(tgt_bbox) == len(positive_map) + + if self.focal_loss: + positive_map = positive_map > 1e-4 + alpha = self.focal_alpha + gamma = self.focal_gamma + neg_cost_class = ( + (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) + ) + pos_cost_class = ( + alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + ) + cost_class = ( + (pos_cost_class - neg_cost_class).unsqueeze(1) + * positive_map.unsqueeze(0) + ).sum(-1) + else: + # Compute the soft-cross entropy between the predicted token alignment and the GT one for each box + cost_class = -(out_prob.unsqueeze(1) * positive_map.unsqueeze(0)).sum( + -1 + ) + else: + # In this case we are doing a "standard" cross entropy + tgt_ids = batched_targets["labels"] + assert len(tgt_bbox) == len(tgt_ids) + + if self.focal_loss: + alpha = self.focal_alpha + gamma = self.focal_gamma + neg_cost_class = ( + (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) + ) + pos_cost_class = ( + alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + ) + cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] + else: + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be omitted. + cost_class = -out_prob[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + assert cost_class.shape == cost_bbox.shape + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou( + box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) + ) + + # Final cost matrix + C = ( + self.cost_bbox * cost_bbox + + self.cost_class * cost_class + + self.cost_giou * cost_giou + ) + C = C.view(bs, num_queries, -1).cpu().numpy() + + sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1] + costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))] + indices = [_do_matching(c) for c in costs] + batch_idx = torch.as_tensor( + sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long + ) + src_idx = torch.from_numpy(np.concatenate(indices)).long() + return batch_idx, src_idx + + +class BinaryHungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__( + self, + cost_class: float = 1, + cost_bbox: float = 1, + cost_giou: float = 1, + ): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + self.norm = nn.Sigmoid() + assert ( + cost_class != 0 or cost_bbox != 0 or cost_giou != 0 + ), "all costs cant be 0" + + @torch.no_grad() + def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + if repeat_batch != 1: + raise NotImplementedError("please use BinaryHungarianMatcherV2 instead") + + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = self.norm(outputs["pred_logits"].flatten(0, 1)).squeeze( + -1 + ) # [batch_size * num_queries] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_bbox = batched_targets["boxes"] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox) + + assert cost_class.shape == cost_bbox.shape + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou( + box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) + ) + + # Final cost matrix + C = ( + self.cost_bbox * cost_bbox + + self.cost_class * cost_class + + self.cost_giou * cost_giou + ) + C = C.view(bs, num_queries, -1).cpu().numpy() + + sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1] + costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))] + return_tgt_indices = False + for c in costs: + n_targ = c.shape[1] + if repeats > 1: + n_targ *= repeats + if c.shape[0] < n_targ: + return_tgt_indices = True + break + if return_tgt_indices: + indices, tgt_indices = zip( + *( + _do_matching( + c, repeats=repeats, return_tgt_indices=return_tgt_indices + ) + for c in costs + ) + ) + tgt_indices = list(tgt_indices) + for i in range(1, len(tgt_indices)): + tgt_indices[i] += sizes[i - 1].item() + tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long() + else: + indices = [_do_matching(c, repeats=repeats) for c in costs] + tgt_idx = None + + batch_idx = torch.as_tensor( + sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long + ) + src_idx = torch.from_numpy(np.concatenate(indices)).long() + return batch_idx, src_idx, tgt_idx + + +class BinaryFocalHungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__( + self, + cost_class: float = 1, + cost_bbox: float = 1, + cost_giou: float = 1, + alpha: float = 0.25, + gamma: float = 2.0, + stable: bool = False, + ): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + self.norm = nn.Sigmoid() + self.alpha = alpha + self.gamma = gamma + self.stable = stable + assert ( + cost_class != 0 or cost_bbox != 0 or cost_giou != 0 + ), "all costs cant be 0" + + @torch.no_grad() + def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + if repeat_batch != 1: + raise NotImplementedError("please use BinaryHungarianMatcherV2 instead") + + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_score = outputs["pred_logits"].flatten(0, 1).squeeze(-1) + out_prob = self.norm(out_score) # [batch_size * num_queries] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_bbox = batched_targets["boxes"] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou( + box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) + ) + + # cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox) + if self.stable: + rescaled_giou = (-cost_giou + 1) / 2 + out_prob = out_prob.unsqueeze(-1).expand_as(cost_bbox) * rescaled_giou + cost_class = -self.alpha * (1 - out_prob) ** self.gamma * torch.log( + out_prob + ) + (1 - self.alpha) * out_prob**self.gamma * torch.log(1 - out_prob) + else: + # directly computing log sigmoid (more numerically stable) + log_out_prob = torch.nn.functional.logsigmoid(out_score) + log_one_minus_out_prob = torch.nn.functional.logsigmoid(-out_score) + cost_class = ( + -self.alpha * (1 - out_prob) ** self.gamma * log_out_prob + + (1 - self.alpha) * out_prob**self.gamma * log_one_minus_out_prob + ) + if not self.stable: + cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox) + + assert cost_class.shape == cost_bbox.shape + + # Final cost matrix + C = ( + self.cost_bbox * cost_bbox + + self.cost_class * cost_class + + self.cost_giou * cost_giou + ) + C = C.view(bs, num_queries, -1).cpu().numpy() + + sizes = torch.cumsum(batched_targets["num_boxes"], -1)[:-1] + costs = [c[i] for i, c in enumerate(np.split(C, sizes.cpu().numpy(), axis=-1))] + return_tgt_indices = False + for c in costs: + n_targ = c.shape[1] + if repeats > 1: + n_targ *= repeats + if c.shape[0] < n_targ: + return_tgt_indices = True + break + if return_tgt_indices: + indices, tgt_indices = zip( + *( + _do_matching( + c, repeats=repeats, return_tgt_indices=return_tgt_indices + ) + for c in costs + ) + ) + tgt_indices = list(tgt_indices) + for i in range(1, len(tgt_indices)): + tgt_indices[i] += sizes[i - 1].item() + tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long() + else: + indices = [_do_matching(c, repeats=repeats) for c in costs] + tgt_idx = None + + batch_idx = torch.as_tensor( + sum([[i] * len(src) for i, src in enumerate(indices)], []), dtype=torch.long + ) + src_idx = torch.from_numpy(np.concatenate(indices)).long() + return batch_idx, src_idx, tgt_idx + + +class BinaryHungarianMatcherV2(nn.Module): + """ + This class computes an assignment between the targets and the predictions + of the network + + For efficiency reasons, the targets don't include the no_object. Because of + this, in general, there are more predictions than targets. In this case, we + do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + + This is a more efficient implementation of BinaryHungarianMatcher. + """ + + def __init__( + self, + cost_class: float = 1, + cost_bbox: float = 1, + cost_giou: float = 1, + focal: bool = False, + alpha: float = 0.25, + gamma: float = 2.0, + stable: bool = False, + remove_samples_with_0_gt: bool = True, + ): + """ + Creates the matcher + + Params: + - cost_class: Relative weight of the classification error in the + matching cost + - cost_bbox: Relative weight of the L1 error of the bounding box + coordinates in the matching cost + - cost_giou: This is the relative weight of the giou loss of the + bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + self.norm = nn.Sigmoid() + assert ( + cost_class != 0 or cost_bbox != 0 or cost_giou != 0 + ), "all costs cant be 0" + self.focal = focal + if focal: + self.alpha = alpha + self.gamma = gamma + self.stable = stable + self.remove_samples_with_0_gt = remove_samples_with_0_gt + + @torch.no_grad() + def forward( + self, + outputs, + batched_targets, + repeats=1, + repeat_batch=1, + out_is_valid=None, + target_is_valid_padded=None, + ): + """ + Performs the matching. The inputs and outputs are the same as + BinaryHungarianMatcher.forward, except for the optional cached_padded + flag and the optional "_boxes_padded" entry of batched_targets. + + Inputs: + - outputs: A dict with the following keys: + - "pred_logits": Tensor of shape (batch_size, num_queries, 1) with + classification logits + - "pred_boxes": Tensor of shape (batch_size, num_queries, 4) with + predicted box coordinates in cxcywh format. + - batched_targets: A dict of targets. There may be a variable number of + targets per batch entry; suppose that there are T_b targets for batch + entry 0 <= b < batch_size. It should have the following keys: + - "boxes": Tensor of shape (sum_b T_b, 4) giving ground-truth boxes + in cxcywh format for all batch entries packed into a single tensor + - "num_boxes": int64 Tensor of shape (batch_size,) giving the number + of ground-truth boxes per batch entry: num_boxes[b] = T_b + - "_boxes_padded": Tensor of shape (batch_size, max_b T_b, 4) giving + a padded version of ground-truth boxes. If this is not present then + it will be computed from batched_targets["boxes"] instead, but + caching it here can improve performance for repeated calls with the + same targets. + - out_is_valid: If not None, it should be a boolean tensor of shape + (batch_size, num_queries) indicating which predictions are valid. + Invalid predictions are ignored during matching and won't appear in + the output indices. + - target_is_valid_padded: If not None, it should be a boolean tensor of + shape (batch_size, max_num_gt_boxes) in padded format indicating + which GT boxes are valid. Invalid GT boxes are ignored during matching + and won't appear in the output indices. + + Returns: + A list of size batch_size, containing tuples of (idx_i, idx_j): + - idx_i is the indices of the selected predictions (in order) + - idx_j is the indices of the corresponding selected targets + (in order) + For each batch element, it holds: + len(index_i) = len(index_j) + = min(num_queries, num_target_boxes) + """ + _, num_queries = outputs["pred_logits"].shape[:2] + + out_score = outputs["pred_logits"].squeeze(-1) # (B, Q) + out_bbox = outputs["pred_boxes"] # (B, Q, 4)) + + device = out_score.device + + num_boxes = batched_targets["num_boxes"].cpu() + # Get a padded version of target boxes (as precomputed in the collator). + # It should work for both repeat==1 (o2o) and repeat>1 (o2m) matching. + tgt_bbox = batched_targets["boxes_padded"] + if self.remove_samples_with_0_gt: + # keep only samples w/ at least 1 GT box in targets (num_boxes and tgt_bbox) + batch_keep = num_boxes > 0 + num_boxes = num_boxes[batch_keep] + tgt_bbox = tgt_bbox[batch_keep] + if target_is_valid_padded is not None: + target_is_valid_padded = target_is_valid_padded[batch_keep] + # Repeat the targets (for the case of batched aux outputs in the matcher) + if repeat_batch > 1: + # In this case, out_prob and out_bbox will be a concatenation of + # both final and auxiliary outputs, so we also repeat the targets + num_boxes = num_boxes.repeat(repeat_batch) + tgt_bbox = tgt_bbox.repeat(repeat_batch, 1, 1) + if target_is_valid_padded is not None: + target_is_valid_padded = target_is_valid_padded.repeat(repeat_batch, 1) + + # keep only samples w/ at least 1 GT box in outputs + if self.remove_samples_with_0_gt: + if repeat_batch > 1: + batch_keep = batch_keep.repeat(repeat_batch) + out_score = out_score[batch_keep] + out_bbox = out_bbox[batch_keep] + if out_is_valid is not None: + out_is_valid = out_is_valid[batch_keep] + assert out_bbox.shape[0] == tgt_bbox.shape[0] + assert out_bbox.shape[0] == num_boxes.shape[0] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou( + box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) + ) + + out_prob = self.norm(out_score) + if not self.focal: + cost_class = -out_prob.unsqueeze(-1).expand_as(cost_bbox) + else: + if self.stable: + rescaled_giou = (-cost_giou + 1) / 2 + out_prob = out_prob.unsqueeze(-1).expand_as(cost_bbox) * rescaled_giou + cost_class = -self.alpha * (1 - out_prob) ** self.gamma * torch.log( + out_prob + ) + (1 - self.alpha) * out_prob**self.gamma * torch.log(1 - out_prob) + else: + # directly computing log sigmoid (more numerically stable) + log_out_prob = torch.nn.functional.logsigmoid(out_score) + log_one_minus_out_prob = torch.nn.functional.logsigmoid(-out_score) + cost_class = ( + -self.alpha * (1 - out_prob) ** self.gamma * log_out_prob + + (1 - self.alpha) * out_prob**self.gamma * log_one_minus_out_prob + ) + if not self.stable: + cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox) + + assert cost_class.shape == cost_bbox.shape + + # Final cost matrix + C = ( + self.cost_bbox * cost_bbox + + self.cost_class * cost_class + + self.cost_giou * cost_giou + ) + # assign a very high cost (1e9) to invalid outputs and targets, so that we can + # filter them out (in `_do_matching`) from bipartite matching results + do_filtering = out_is_valid is not None or target_is_valid_padded is not None + if out_is_valid is not None: + C = torch.where(out_is_valid[:, :, None], C, 1e9) + if target_is_valid_padded is not None: + C = torch.where(target_is_valid_padded[:, None, :], C, 1e9) + # Guard against NaN/Inf from numerical edge cases (e.g. zero-area + # boxes in GIoU, log(0) in focal cost). Assign high cost so these + # entries are never matched by the Hungarian algorithm. + C = torch.nan_to_num(C, nan=1e9, posinf=1e9, neginf=-1e9) + C = C.cpu().numpy() + costs = [C[i, :, :s] for i, s in enumerate(num_boxes.tolist())] + return_tgt_indices = ( + do_filtering or torch.any(num_queries < num_boxes * max(repeats, 1)).item() + ) + if len(costs) == 0: + # We have size 0 in the batch dimension, so we return empty matching indices + # (note that this can happen due to `remove_samples_with_0_gt=True` even if + # the original input batch size is not 0, when all queries have empty GTs). + indices = [] + tgt_idx = torch.zeros(0).long().to(device) if return_tgt_indices else None + elif return_tgt_indices: + indices, tgt_indices = zip( + *( + _do_matching( + c, + repeats=repeats, + return_tgt_indices=return_tgt_indices, + do_filtering=do_filtering, + ) + for c in costs + ) + ) + tgt_indices = list(tgt_indices) + sizes = torch.cumsum(num_boxes, -1)[:-1] + for i in range(1, len(tgt_indices)): + tgt_indices[i] += sizes[i - 1].item() + tgt_idx = torch.from_numpy(np.concatenate(tgt_indices)).long().to(device) + else: + indices = [ + _do_matching(c, repeats=repeats, do_filtering=do_filtering) + for c in costs + ] + tgt_idx = None + + if self.remove_samples_with_0_gt: + kept_inds = batch_keep.nonzero().squeeze(1) + batch_idx = torch.as_tensor( + sum([[kept_inds[i]] * len(src) for i, src in enumerate(indices)], []), + dtype=torch.long, + device=device, + ) + else: + batch_idx = torch.as_tensor( + sum([[i] * len(src) for i, src in enumerate(indices)], []), + dtype=torch.long, + device=device, + ) + + # indices could be an empty list (since we remove samples w/ 0 GT boxes) + if len(indices) > 0: + src_idx = torch.from_numpy(np.concatenate(indices)).long().to(device) + else: + src_idx = torch.empty(0, dtype=torch.long, device=device) + return batch_idx, src_idx, tgt_idx + + +class BinaryOneToManyMatcher(nn.Module): + """ + This class computes a greedy assignment between the targets and the predictions of the network. + In this formulation, several predictions can be assigned to each target, but each prediction can be assigned to + at most one target. + + See DAC-Detr for details + """ + + def __init__( + self, + alpha: float = 0.3, + threshold: float = 0.4, + topk: int = 6, + ): + """ + Creates the matcher + + Params: + alpha: relative balancing between classification and localization + threshold: threshold used to select positive predictions + topk: number of top scoring predictions to consider + """ + super().__init__() + self.norm = nn.Sigmoid() + self.alpha = alpha + self.threshold = threshold + self.topk = topk + + @torch.no_grad() + def forward( + self, + outputs, + batched_targets, + repeats=1, + repeat_batch=1, + out_is_valid=None, + target_is_valid_padded=None, + ): + """ + Performs the matching. The inputs and outputs are the same as + BinaryHungarianMatcher.forward + + Inputs: + - outputs: A dict with the following keys: + - "pred_logits": Tensor of shape (batch_size, num_queries, 1) with + classification logits + - "pred_boxes": Tensor of shape (batch_size, num_queries, 4) with + predicted box coordinates in cxcywh format. + - batched_targets: A dict of targets. There may be a variable number of + targets per batch entry; suppose that there are T_b targets for batch + entry 0 <= b < batch_size. It should have the following keys: + - "num_boxes": int64 Tensor of shape (batch_size,) giving the number + of ground-truth boxes per batch entry: num_boxes[b] = T_b + - "_boxes_padded": Tensor of shape (batch_size, max_b T_b, 4) giving + a padded version of ground-truth boxes. If this is not present then + it will be computed from batched_targets["boxes"] instead, but + caching it here can improve performance for repeated calls with the + same targets. + - out_is_valid: If not None, it should be a boolean tensor of shape + (batch_size, num_queries) indicating which predictions are valid. + Invalid predictions are ignored during matching and won't appear in + the output indices. + - target_is_valid_padded: If not None, it should be a boolean tensor of + shape (batch_size, max_num_gt_boxes) in padded format indicating + which GT boxes are valid. Invalid GT boxes are ignored during matching + and won't appear in the output indices. + Returns: + A list of size batch_size, containing tuples of (idx_i, idx_j): + - idx_i is the indices of the selected predictions (in order) + - idx_j is the indices of the corresponding selected targets + (in order) + For each batch element, it holds: + len(index_i) = len(index_j) + = min(num_queries, num_target_boxes) + """ + assert repeats <= 1 and repeat_batch <= 1 + bs, num_queries = outputs["pred_logits"].shape[:2] + + out_prob = self.norm(outputs["pred_logits"]).squeeze(-1) # (B, Q) + out_bbox = outputs["pred_boxes"] # (B, Q, 4)) + + num_boxes = batched_targets["num_boxes"] + + # Get a padded version of target boxes (as precomputed in the collator). + tgt_bbox = batched_targets["boxes_padded"] + assert len(tgt_bbox) == bs + num_targets = tgt_bbox.shape[1] + if num_targets == 0: + return ( + torch.empty(0, dtype=torch.long, device=out_prob.device), + torch.empty(0, dtype=torch.long, device=out_prob.device), + torch.empty(0, dtype=torch.long, device=out_prob.device), + ) + + iou, _ = box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + + assert iou.shape == (bs, num_queries, num_targets) + + # Final cost matrix (higher is better in `C`; this is unlike the case + # of BinaryHungarianMatcherV2 above where lower is better in its `C`) + C = self.alpha * out_prob.unsqueeze(-1) + (1 - self.alpha) * iou + if out_is_valid is not None: + C = torch.where(out_is_valid[:, :, None], C, -1e9) + if target_is_valid_padded is not None: + C = torch.where(target_is_valid_padded[:, None, :], C, -1e9) + + # Selecting topk predictions + matches = C > torch.quantile( + C, 1 - self.topk / num_queries, dim=1, keepdim=True + ) + + # Selecting predictions above threshold + matches = matches & (C > self.threshold) + if out_is_valid is not None: + matches = matches & out_is_valid[:, :, None] + if target_is_valid_padded is not None: + matches = matches & target_is_valid_padded[:, None, :] + + # Removing padding + matches = matches & ( + torch.arange(0, num_targets, device=num_boxes.device)[None] + < num_boxes[:, None] + ).unsqueeze(1) + + batch_idx, src_idx, tgt_idx = torch.nonzero(matches, as_tuple=True) + + cum_num_boxes = torch.cat( + [ + torch.zeros(1, dtype=num_boxes.dtype, device=num_boxes.device), + num_boxes.cumsum(-1)[:-1], + ] + ) + tgt_idx += cum_num_boxes[batch_idx] + + return batch_idx, src_idx, tgt_idx diff --git a/third_party/sam3/sam3/train/nms_helper.py b/third_party/sam3/sam3/train/nms_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..cf19226a7795ca410877192595d83e6739010769 --- /dev/null +++ b/third_party/sam3/sam3/train/nms_helper.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import warnings +from typing import Dict, List + +import numpy as np + +# Check if Numba is available +HAS_NUMBA = False +try: + import numba as nb + + HAS_NUMBA = True +except ImportError: + warnings.warn( + "Numba not found. Using slower pure Python implementations.", UserWarning + ) + + +# -------------------- Helper Functions -------------------- +def is_zero_box(bbox: list) -> bool: + """Check if bounding box is invalid""" + if bbox is None: + return True + return all(x <= 0 for x in bbox[:4]) or len(bbox) < 4 + + +def convert_bbox_format(bbox: list) -> List[float]: + """Convert bbox from (x,y,w,h) to (x1,y1,x2,y2)""" + x, y, w, h = bbox + return [x, y, x + w, y + h] + + +# -------------------- Track-level NMS -------------------- +def process_track_level_nms(video_groups: Dict, nms_threshold: float) -> Dict: + """Apply track-level NMS to all videos""" + for tracks in video_groups.values(): + track_detections = [] + + # Process tracks + for track_idx, track in enumerate(tracks): + if not track["bboxes"]: + continue + + converted_bboxes = [] + valid_frames = [] + for bbox in track["bboxes"]: + if bbox and not is_zero_box(bbox): + converted_bboxes.append(convert_bbox_format(bbox)) + valid_frames.append(True) + else: + converted_bboxes.append([np.nan] * 4) + valid_frames.append(False) + + if any(valid_frames): + track_detections.append( + { + "track_idx": track_idx, + "bboxes": np.array(converted_bboxes, dtype=np.float32), + "score": track["score"], + } + ) + + # Apply NMS + if track_detections: + scores = np.array([d["score"] for d in track_detections], dtype=np.float32) + keep = apply_track_nms(track_detections, scores, nms_threshold) + + # Suppress non-kept tracks + for idx, track in enumerate(track_detections): + if idx not in keep: + tracks[track["track_idx"]]["bboxes"] = [None] * len(track["bboxes"]) + + return video_groups + + +# -------------------- Frame-level NMS -------------------- +def process_frame_level_nms(video_groups: Dict, nms_threshold: float) -> Dict: + """Apply frame-level NMS to all videos""" + for tracks in video_groups.values(): + if not tracks: + continue + + num_frames = len(tracks[0]["bboxes"]) + + for frame_idx in range(num_frames): + frame_detections = [] + + # Collect valid detections + for track_idx, track in enumerate(tracks): + bbox = track["bboxes"][frame_idx] + if bbox and not is_zero_box(bbox): + frame_detections.append( + { + "track_idx": track_idx, + "bbox": np.array( + convert_bbox_format(bbox), dtype=np.float32 + ), + "score": track["score"], + } + ) + + # Apply NMS + if frame_detections: + bboxes = np.stack([d["bbox"] for d in frame_detections]) + scores = np.array( + [d["score"] for d in frame_detections], dtype=np.float32 + ) + keep = apply_frame_nms(bboxes, scores, nms_threshold) + + # Suppress non-kept detections + for i, d in enumerate(frame_detections): + if i not in keep: + tracks[d["track_idx"]]["bboxes"][frame_idx] = None + + return video_groups + + +# Track-level NMS helpers ------------------------------------------------------ +def compute_track_iou_matrix( + bboxes_stacked: np.ndarray, valid_masks: np.ndarray, areas: np.ndarray +) -> np.ndarray: + """IoU matrix computation for track-level NMS with fallback to pure Python""" + num_tracks = bboxes_stacked.shape[0] + iou_matrix = np.zeros((num_tracks, num_tracks), dtype=np.float32) + if HAS_NUMBA: + iou_matrix = _compute_track_iou_matrix_numba(bboxes_stacked, valid_masks, areas) + else: + # Pure Python implementation + for i in range(num_tracks): + for j in range(i + 1, num_tracks): + valid_ij = valid_masks[i] & valid_masks[j] + if not valid_ij.any(): + continue + bboxes_i = bboxes_stacked[i, valid_ij] + bboxes_j = bboxes_stacked[j, valid_ij] + area_i = areas[i, valid_ij] + area_j = areas[j, valid_ij] + inter_total = 0.0 + union_total = 0.0 + for k in range(bboxes_i.shape[0]): + x1 = max(bboxes_i[k, 0], bboxes_j[k, 0]) + y1 = max(bboxes_i[k, 1], bboxes_j[k, 1]) + x2 = min(bboxes_i[k, 2], bboxes_j[k, 2]) + y2 = min(bboxes_i[k, 3], bboxes_j[k, 3]) + inter = max(0, x2 - x1) * max(0, y2 - y1) + union = area_i[k] + area_j[k] - inter + inter_total += inter + union_total += union + if union_total > 0: + iou_matrix[i, j] = inter_total / union_total + iou_matrix[j, i] = iou_matrix[i, j] + return iou_matrix + + +if HAS_NUMBA: + + @nb.jit(nopython=True, parallel=True) + def _compute_track_iou_matrix_numba(bboxes_stacked, valid_masks, areas): + """Numba-optimized IoU matrix computation for track-level NMS""" + num_tracks = bboxes_stacked.shape[0] + iou_matrix = np.zeros((num_tracks, num_tracks), dtype=np.float32) + for i in nb.prange(num_tracks): + for j in range(i + 1, num_tracks): + valid_ij = valid_masks[i] & valid_masks[j] + if not valid_ij.any(): + continue + bboxes_i = bboxes_stacked[i, valid_ij] + bboxes_j = bboxes_stacked[j, valid_ij] + area_i = areas[i, valid_ij] + area_j = areas[j, valid_ij] + inter_total = 0.0 + union_total = 0.0 + for k in range(bboxes_i.shape[0]): + x1 = max(bboxes_i[k, 0], bboxes_j[k, 0]) + y1 = max(bboxes_i[k, 1], bboxes_j[k, 1]) + x2 = min(bboxes_i[k, 2], bboxes_j[k, 2]) + y2 = min(bboxes_i[k, 3], bboxes_j[k, 3]) + inter = max(0, x2 - x1) * max(0, y2 - y1) + union = area_i[k] + area_j[k] - inter + inter_total += inter + union_total += union + if union_total > 0: + iou_matrix[i, j] = inter_total / union_total + iou_matrix[j, i] = iou_matrix[i, j] + return iou_matrix + + +def apply_track_nms( + track_detections: List[dict], scores: np.ndarray, nms_threshold: float +) -> List[int]: + """Vectorized track-level NMS implementation""" + if not track_detections: + return [] + bboxes_stacked = np.stack([d["bboxes"] for d in track_detections], axis=0) + valid_masks = ~np.isnan(bboxes_stacked).any(axis=2) + areas = (bboxes_stacked[:, :, 2] - bboxes_stacked[:, :, 0]) * ( + bboxes_stacked[:, :, 3] - bboxes_stacked[:, :, 1] + ) + areas[~valid_masks] = 0 + iou_matrix = compute_track_iou_matrix(bboxes_stacked, valid_masks, areas) + keep = [] + order = np.argsort(-scores) + suppress = np.zeros(len(track_detections), dtype=bool) + for i in range(len(order)): + if not suppress[order[i]]: + keep.append(order[i]) + suppress[order[i:]] = suppress[order[i:]] | ( + iou_matrix[order[i], order[i:]] >= nms_threshold + ) + return keep + + +# Frame-level NMS helpers ------------------------------------------------------ +def compute_frame_ious(bbox: np.ndarray, bboxes: np.ndarray) -> np.ndarray: + """IoU computation for frame-level NMS with fallback to pure Python""" + if HAS_NUMBA: + return _compute_frame_ious_numba(bbox, bboxes) + else: + # Pure Python implementation + ious = np.zeros(len(bboxes), dtype=np.float32) + for i in range(len(bboxes)): + x1 = max(bbox[0], bboxes[i, 0]) + y1 = max(bbox[1], bboxes[i, 1]) + x2 = min(bbox[2], bboxes[i, 2]) + y2 = min(bbox[3], bboxes[i, 3]) + + inter = max(0, x2 - x1) * max(0, y2 - y1) + area1 = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + area2 = (bboxes[i, 2] - bboxes[i, 0]) * (bboxes[i, 3] - bboxes[i, 1]) + union = area1 + area2 - inter + + ious[i] = inter / union if union > 0 else 0.0 + return ious + + +if HAS_NUMBA: + + @nb.jit(nopython=True, parallel=True) + def _compute_frame_ious_numba(bbox, bboxes): + """Numba-optimized IoU computation""" + ious = np.zeros(len(bboxes), dtype=np.float32) + for i in nb.prange(len(bboxes)): + x1 = max(bbox[0], bboxes[i, 0]) + y1 = max(bbox[1], bboxes[i, 1]) + x2 = min(bbox[2], bboxes[i, 2]) + y2 = min(bbox[3], bboxes[i, 3]) + + inter = max(0, x2 - x1) * max(0, y2 - y1) + area1 = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + area2 = (bboxes[i, 2] - bboxes[i, 0]) * (bboxes[i, 3] - bboxes[i, 1]) + union = area1 + area2 - inter + + ious[i] = inter / union if union > 0 else 0.0 + return ious + + +def apply_frame_nms( + bboxes: np.ndarray, scores: np.ndarray, nms_threshold: float +) -> List[int]: + """Frame-level NMS implementation with fallback to pure Python""" + if HAS_NUMBA: + return _apply_frame_nms_numba(bboxes, scores, nms_threshold) + else: + # Pure Python implementation + order = np.argsort(-scores) + keep = [] + suppress = np.zeros(len(bboxes), dtype=bool) + + for i in range(len(order)): + if not suppress[order[i]]: + keep.append(order[i]) + current_bbox = bboxes[order[i]] + + remaining_bboxes = bboxes[order[i + 1 :]] + if len(remaining_bboxes) > 0: # Check if there are any remaining boxes + ious = compute_frame_ious(current_bbox, remaining_bboxes) + suppress[order[i + 1 :]] = suppress[order[i + 1 :]] | ( + ious >= nms_threshold + ) + + return keep + + +if HAS_NUMBA: + + @nb.jit(nopython=True) + def _apply_frame_nms_numba(bboxes, scores, nms_threshold): + """Numba-optimized NMS implementation""" + order = np.argsort(-scores) + keep = [] + suppress = np.zeros(len(bboxes), dtype=nb.boolean) + + for i in range(len(order)): + if not suppress[order[i]]: + keep.append(order[i]) + current_bbox = bboxes[order[i]] + + if i + 1 < len(order): # Check bounds + ious = _compute_frame_ious_numba( + current_bbox, bboxes[order[i + 1 :]] + ) + suppress[order[i + 1 :]] = suppress[order[i + 1 :]] | ( + ious >= nms_threshold + ) + + return keep diff --git a/third_party/sam3/sam3/train/optim/__init__.py b/third_party/sam3/sam3/train/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/train/optim/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/train/optim/optimizer.py b/third_party/sam3/sam3/train/optim/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e7b7520839918ac8b747ddfc5bfc0009eb63d7 --- /dev/null +++ b/third_party/sam3/sam3/train/optim/optimizer.py @@ -0,0 +1,499 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import fnmatch +import inspect +import itertools +import logging +import types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) + +import hydra +import torch +import torch.nn as nn +from omegaconf import DictConfig +from torch import Tensor + + +class Optimizer: + def __init__(self, optimizer, schedulers=None) -> None: + self.optimizer = optimizer + self.schedulers = schedulers + self._validate_optimizer_schedulers() + self.step_schedulers(0.0, 0) + + def _validate_optimizer_schedulers(self): + if self.schedulers is None: + return + for _, set_of_schedulers in enumerate(self.schedulers): + for option, _ in set_of_schedulers.items(): + assert option in self.optimizer.defaults, ( + "Optimizer option " + f"{option} not found in {self.optimizer}. Valid options are " + f"{self.optimizer.defaults.keys()}" + ) + + def step_schedulers(self, where: float, step: int) -> None: + if self.schedulers is None: + return + for i, param_group in enumerate(self.optimizer.param_groups): + for option, scheduler in self.schedulers[i].items(): + if "step" in inspect.signature(scheduler.__call__).parameters: + new_value = scheduler(step=step, where=where) + elif ( + hasattr(scheduler, "scheduler") + and "step" + in inspect.signature(scheduler.scheduler.__call__).parameters + ): + # To handle ValueScaler wrappers + new_value = scheduler(step=step, where=where) + else: + new_value = scheduler(where) + param_group[option] = new_value + + def step(self, where, step, closure=None): + self.step_schedulers(where, step) + return self.optimizer.step(closure) + + def zero_grad(self, *args, **kwargs): + return self.optimizer.zero_grad(*args, **kwargs) + + +def set_default_parameters( + scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str] +) -> None: + """Set up the "default" scheduler with the right parameters. + + Args: + scheduler_cgfs: A list of scheduler configs, where each scheduler also + specifies which parameters it applies to, based on the names of parameters + or the class of the modules. At most one scheduler is allowed to skip this + specification, which is used as a "default" specification for any remaining + parameters. + all_parameter_names: Names of all the parameters to consider. + """ + constraints = [ + scheduler_cfg.parameter_names + for scheduler_cfg in scheduler_cfgs + if scheduler_cfg.parameter_names is not None + ] + if len(constraints) == 0: + default_params = set(all_parameter_names) + else: + default_params = all_parameter_names - set.union(*constraints) + default_count = 0 + for scheduler_cfg in scheduler_cfgs: + if scheduler_cfg.parameter_names is None: + scheduler_cfg.parameter_names = default_params + default_count += 1 + assert default_count <= 1, "Only one scheduler per option can be default" + if default_count == 0: + # No default scheduler specified, add a default, but without any scheduler + # for that option + scheduler_cfgs.append({"parameter_names": default_params}) + + +def name_constraints_to_parameters( + param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor] +) -> List[torch.nn.Parameter]: + """Return parameters which match the intersection of parameter constraints. + + Note that this returns the parameters themselves, not their names. + + Args: + param_constraints: A list, with each element being a set of allowed parameters. + named_parameters: Mapping from a parameter name to the parameter itself. + + Returns: + A list containing the parameters which overlap with _each_ constraint set from + param_constraints. + """ + matching_names = set.intersection(*param_constraints) + return [value for name, value in named_parameters.items() if name in matching_names] + + +def map_scheduler_cfgs_to_param_groups( + all_scheduler_cfgs: Iterable[List[Dict]], + named_parameters: Dict[str, Tensor], +) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]: + """Produce parameter groups corresponding to all the scheduler configs. + + Takes all the scheduler configs, each of which applies to a specific optimizer + option (like "lr" or "weight_decay") and has a set of parameter names which it + applies to, and produces a final set of param groups where each param group + covers all the options which apply to a particular set of parameters. + + Args: + all_scheduler_cfgs: All the scheduler configs covering every option. + named_parameters: Mapping from a parameter name to the parameter itself. + Returns: + Tuple of lists of schedulers and param_groups, where schedulers[i] + applies to param_groups[i]. + """ + + scheduler_cfgs_per_param_group = itertools.product(*all_scheduler_cfgs) + schedulers = [] + param_groups = [] + for scheduler_cfgs in scheduler_cfgs_per_param_group: + param_constraints = [ + scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs + ] + matching_parameters = name_constraints_to_parameters( + param_constraints, named_parameters + ) + if len(matching_parameters) == 0: # If no overlap of parameters, skip + continue + schedulers_for_group = { + scheduler_cfg["option"]: scheduler_cfg["scheduler"] + for scheduler_cfg in scheduler_cfgs + if "option" in scheduler_cfg + } + schedulers.append(schedulers_for_group) + param_groups.append({"params": matching_parameters}) + return schedulers, param_groups + + +def validate_param_group_params(param_groups: List[Dict], model: nn.Module): + """Check that the param groups are non-overlapping and cover all the parameters. + + Args: + param_groups: List of all param groups + model: Model to validate against. The check ensures that all the model + parameters are part of param_groups + """ + for pg in param_groups: + # no param should be repeated within a group + assert len(pg["params"]) == len(set(pg["params"])) + parameters = [set(param_group["params"]) for param_group in param_groups] + model_parameters = {parameter for _, parameter in model.named_parameters()} + for p1, p2 in itertools.permutations(parameters, 2): + assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint" + assert set.union(*parameters) == model_parameters, ( + "Scheduler generated param_groups must include all parameters of the model." + f" Found {len(set.union(*parameters))} params whereas model has" + f" {len(model_parameters)} params" + ) + + +def unix_module_cls_pattern_to_parameter_names( + filter_module_cls_names: List[str], + module_cls_to_param_names: Dict[Type, str], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in filter_module_cls_names. + + Args: + filter_module_cls_names: A list of filter strings containing class names, like + ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"] + module_cls_to_param_names: Mapping from module classes to the parameter names + they contain. See `get_module_cls_to_param_names`. + """ + if filter_module_cls_names is None: + return set() + allowed_parameter_names = [] + for module_cls_name in filter_module_cls_names: + module_cls = hydra.utils.get_class(module_cls_name) + if module_cls not in module_cls_to_param_names: + raise AssertionError( + f"module_cls_name {module_cls_name} does not " + "match any classes in the model" + ) + matching_parameters = module_cls_to_param_names[module_cls] + assert ( + len(matching_parameters) > 0 + ), f"module_cls_name {module_cls_name} does not contain any parameters in the model" + logging.info( + f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} " + ) + allowed_parameter_names.append(matching_parameters) + return set.union(*allowed_parameter_names) + + +def unix_param_pattern_to_parameter_names( + filter_param_names: Optional[List[str]], + parameter_names: Dict[str, torch.Tensor], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in filter_param_names. + + Args: + filter_param_names: A list of unix-style filter strings with optional + wildcards, like ["block.2.*", "block.2.linear.weight"] + module_cls_to_param_names: Mapping from module classes to the parameter names + they contain. See `get_module_cls_to_param_names`. + """ + + if filter_param_names is None: + return set() + allowed_parameter_names = [] + for param_name in filter_param_names: + matching_parameters = set(fnmatch.filter(parameter_names, param_name)) + assert ( + len(matching_parameters) >= 1 + ), f"param_name {param_name} does not match any parameters in the model" + logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}") + allowed_parameter_names.append(matching_parameters) + return set.union(*allowed_parameter_names) + + +def _unix_pattern_to_parameter_names( + scheduler_cfg: DictConfig, + parameter_names: Set[str], + module_cls_to_param_names: Dict[Type, str], +) -> Union[None, Set[str]]: + """Returns param names which pass the filters specified in scheduler_cfg. + + Args: + scheduler_cfg: The config for the scheduler + parameter_names: The set of all parameter names which will be filtered + """ + if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg: + return None + return unix_param_pattern_to_parameter_names( + scheduler_cfg.get("param_names"), parameter_names + ).union( + unix_module_cls_pattern_to_parameter_names( + scheduler_cfg.get("module_cls_names"), module_cls_to_param_names + ) + ) + + +def get_module_cls_to_param_names( + model: nn.Module, param_allowlist: Set[str] = None +) -> Dict[Type, str]: + """Produce a mapping from all the modules classes to the names of parames they own. + + Only counts a parameter as part of the immediate parent module, i.e. recursive + parents do not count. + + Args: + model: Model to iterate over + param_allowlist: If specified, only these param names will be processed + """ + + module_cls_to_params = {} + for module_name, module in model.named_modules(): + module_cls = type(module) + module_cls_to_params.setdefault(module_cls, set()) + for param_name, _ in module.named_parameters(recurse=False): + full_param_name = get_full_parameter_name(module_name, param_name) + if param_allowlist is None or full_param_name in param_allowlist: + module_cls_to_params[module_cls].add(full_param_name) + return module_cls_to_params + + +def construct_optimizer( + model: torch.nn.Module, + optimizer_conf: Any, + options_conf: Mapping[str, List] = None, + param_group_modifiers_conf: List[Callable] = None, + param_allowlist: Optional[Set[str]] = None, + validate_param_groups=True, +) -> Optimizer: + """ + Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer + with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay + Batchnorm and/or no-update 1-D parameters support, based on the config. + + Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling + (LARS): https://arxiv.org/abs/1708.03888 + + Args: + model: model to perform stochastic gradient descent + optimization or ADAM optimization. + optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or + ADAM, still missing the params argument which this function provides to + produce the final optimizer + param_group_modifiers_conf: Optional user specified functions which can modify + the final scheduler configs before the optimizer's param groups are built + param_allowlist: The parameters to optimize. Parameters which are not part of + this allowlist will be skipped. + validate_param_groups: If enabled, valides that the produced param_groups don't + overlap and cover all the model parameters. + """ + if param_allowlist is None: + param_allowlist = {name for name, _ in model.named_parameters()} + + named_parameters = { + name: param + for name, param in model.named_parameters() + if name in param_allowlist + } + + if not options_conf: + optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values()) + return Optimizer(optimizer) + + all_parameter_names = { + name for name, _ in model.named_parameters() if name in param_allowlist + } + module_cls_to_all_param_names = get_module_cls_to_param_names( + model, param_allowlist + ) + + scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf) + all_scheduler_cfgs = [] + for option, scheduler_cfgs in scheduler_cfgs_per_option.items(): + for config in scheduler_cfgs: + config.option = option + config.parameter_names = _unix_pattern_to_parameter_names( + config, all_parameter_names, module_cls_to_all_param_names + ) + set_default_parameters(scheduler_cfgs, all_parameter_names) + all_scheduler_cfgs.append(scheduler_cfgs) + + if param_group_modifiers_conf: + for custom_param_modifier in param_group_modifiers_conf: + custom_param_modifier = hydra.utils.instantiate(custom_param_modifier) + all_scheduler_cfgs = custom_param_modifier( + scheduler_cfgs=all_scheduler_cfgs, model=model + ) + schedulers, param_groups = map_scheduler_cfgs_to_param_groups( + all_scheduler_cfgs, named_parameters + ) + if validate_param_groups: + validate_param_group_params(param_groups, model) + optimizer = hydra.utils.instantiate(optimizer_conf, param_groups) + return Optimizer(optimizer, schedulers) + + +def get_full_parameter_name(module_name, param_name): + if module_name == "": + return param_name + return f"{module_name}.{param_name}" + + +class GradientClipper: + """ + Gradient clipping utils that works for DDP + """ + + def __init__(self, max_norm: float = 1.0, norm_type: int = 2): + assert isinstance(max_norm, (int, float)) or max_norm is None + self.max_norm = max_norm if max_norm is None else float(max_norm) + self.norm_type = norm_type + + def __call__(self, model: nn.Module): + if self.max_norm is None: + return # no-op + + nn.utils.clip_grad_norm_( + model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type + ) + + +class ValueScaler: + def __init__(self, scheduler, mult_val: float): + self.scheduler = scheduler + self.mult_val = mult_val + + def __call__(self, *args, **kwargs): + val = self.scheduler(*args, **kwargs) + return val * self.mult_val + + +def rgetattr(obj, rattrs: str = None): + """ + Like getattr(), but supports dotted notation for nested objects. + rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2 + """ + if rattrs is None: + return obj + attrs = rattrs.split(".") + for attr in attrs: + obj = getattr(obj, attr) + return obj + + +def layer_decay_param_modifier( + scheduler_cfgs: List[List[Dict]], + model, + layer_decay_value: float, + layer_decay_min: Optional[float] = None, + apply_to: Optional[str] = None, + overrides: List[Dict] = (), +) -> List[List[Dict]]: + """ + Args + - scheduler_cfgs: a list of omegaconf.ListConfigs. + Each element in the list is a omegaconfg.DictConfig with the following structure + { + "scheduler": + "option": possible options are "lr", "weight_decay" etc. + "parameter_names": Set of str indicating param names that this scheduler applies to + } + - model: a model that implements a method `get_layer_id` that maps layer_name to an integer and + and a method get_num_layers. + Alternatively, use apply_to argument to select a specific component of the model. + - layer_decay_value: float + - layer_decay_min: min val for layer decay + - apply_to: optional arg to select which component of the model to apply the the layer decay modifier to + - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict, has keys "pattern", "value". + Returns + - scheduler_configs: same structure as the input, elements can be modified + """ + model = rgetattr(model, apply_to) + num_layers = model.get_num_layers() + 1 + layer_decays = [ + layer_decay_value ** (num_layers - i) for i in range(num_layers + 1) + ] + if layer_decay_min is not None: + layer_decays = [max(val, layer_decay_min) for val in layer_decays] + final_scheduler_cfgs = [] + # scheduler_cfgs is a list of lists + for scheduler_cfg_group in scheduler_cfgs: + curr_cfg_group = [] + # scheduler_cfg_group is a list of dictionaries + for scheduler_cfg in scheduler_cfg_group: + if scheduler_cfg["option"] != "lr": + curr_cfg_group.append(scheduler_cfg) + continue + # Need sorted so that the list of parameter names is deterministic and consistent + # across re-runs of this job. Else it was causing issues with loading the optimizer + # state during a job restart + parameter_names = sorted(scheduler_cfg["parameter_names"]) + + # Only want one cfg group per layer + layer_cfg_groups = {} + for param_name in parameter_names: + layer_id = num_layers + this_scale = layer_decays[layer_id] + if param_name.startswith(apply_to): + layer_id = model.get_layer_id(param_name) + this_scale = layer_decays[layer_id] + # Overrides + for override in overrides: + if fnmatch.fnmatchcase(param_name, override["pattern"]): + this_scale = float(override["value"]) + layer_id = override["pattern"] + break + + if layer_id not in layer_cfg_groups: + curr_param = { + "option": scheduler_cfg["option"], + "scheduler": ValueScaler( + scheduler_cfg["scheduler"], this_scale + ), + "parameter_names": {param_name}, + } + else: + curr_param = layer_cfg_groups[layer_id] + curr_param["parameter_names"].add(param_name) + layer_cfg_groups[layer_id] = curr_param + + for layer_cfg in layer_cfg_groups.values(): + curr_cfg_group.append(layer_cfg) + + final_scheduler_cfgs.append(curr_cfg_group) + return final_scheduler_cfgs diff --git a/third_party/sam3/sam3/train/optim/schedulers.py b/third_party/sam3/sam3/train/optim/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..139c9b404dfef3848d6285d318b64a8a49d47e5a --- /dev/null +++ b/third_party/sam3/sam3/train/optim/schedulers.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import math + + +class InverseSquareRootParamScheduler: + def __init__( + self, + base_lr: float, + warmup_steps: int, + cooldown_steps: int, + timescale: int, + ): + self.base_lr = base_lr + self.warmup_steps = warmup_steps + self.cooldown_steps = cooldown_steps + self.timescale = timescale + + def __call__(self, step: int, where: float): + lr = self.base_lr + + if where > 0: + total_steps = step / where + progress = (step - self.warmup_steps) / float( + total_steps - self.warmup_steps + ) + progress = max(min(progress, 1), 0) + else: + progress = 0 + total_steps = 1 + + shift = self.timescale - self.warmup_steps + if self.warmup_steps < step: + lr = lr / math.sqrt((step + shift) / self.timescale) + + if self.warmup_steps: + lr = lr * min(1.0, step / self.warmup_steps) + if self.cooldown_steps: + lr = lr * min(1.0, (total_steps - step) / self.cooldown_steps) + + return lr diff --git a/third_party/sam3/sam3/train/train.py b/third_party/sam3/sam3/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b192059a51d8c8592f8667df42a62469f9a740c0 --- /dev/null +++ b/third_party/sam3/sam3/train/train.py @@ -0,0 +1,338 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging +import os +import random +import sys +import traceback +from argparse import ArgumentParser +from copy import deepcopy + +import submitit +import torch +from hydra import compose, initialize_config_module +from hydra.utils import instantiate +from iopath.common.file_io import g_pathmgr +from omegaconf import OmegaConf +from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers +from tqdm import tqdm + + +os.environ["HYDRA_FULL_ERROR"] = "1" + + +class SlurmEvent: + QUEUED = "QUEUED" + START = "START" + FINISH = "FINISH" + JOB_ERROR = "JOB_ERROR" + SLURM_SIGNAL = "SLURM_SIGNAL" + + +def handle_custom_resolving(cfg): + # We'll resolve the config here, so we can catch mistakes early. + # However, we need to pass the un-resolved config to the launcher + # (because DVC resolving needs to be done on the node it will run on) + # First, do a copy without triggering resolving + cfg_resolved = OmegaConf.to_container(cfg, resolve=False) + cfg_resolved = OmegaConf.create(cfg_resolved) + return cfg_resolved + + +def single_proc_run(local_rank, main_port, cfg, world_size): + """Single GPU process""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(main_port) + os.environ["RANK"] = str(local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + try: + register_omegaconf_resolvers() + except Exception as e: + logging.info(e) + + trainer = instantiate(cfg.trainer, _recursive_=False) + trainer.run() + + +def single_node_runner(cfg, main_port: int): + assert cfg.launcher.num_nodes == 1 + # assert cfg.launcher.gpus_per_node == 1 + num_proc = cfg.launcher.gpus_per_node + torch.multiprocessing.set_start_method( + "spawn" + ) # CUDA runtime does not support `fork` + if num_proc == 1: + # directly call single_proc so we can easily set breakpoints + # mp.spawn does not let us set breakpoints + single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc) + else: + mp_runner = torch.multiprocessing.start_processes + args = (main_port, cfg, num_proc) + # Note: using "fork" below, "spawn" causes time and error regressions. Using + # spawn changes the default multiprocessing context to spawn, which doesn't + # interact well with the dataloaders (likely due to the use of OpenCV). + mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn") + + +def format_exception(e: Exception, limit=20): + traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit)) + return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}" + + +class SubmititRunner(submitit.helpers.Checkpointable): + """A callable which is passed to submitit to launch the jobs.""" + + def __init__(self, port, cfg): + self.cfg = cfg + self.port = port + self.has_setup = False + + def run_trainer(self): + job_env = submitit.JobEnvironment() + # Need to add this again so the hydra.job.set_env PYTHONPATH + # is also set when launching jobs. + add_pythonpath_to_sys_path() + os.environ["MASTER_ADDR"] = job_env.hostnames[0] + os.environ["MASTER_PORT"] = str(self.port) + os.environ["RANK"] = str(job_env.global_rank) + os.environ["LOCAL_RANK"] = str(job_env.local_rank) + os.environ["WORLD_SIZE"] = str(job_env.num_tasks) + + register_omegaconf_resolvers() + cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False) + cfg_resolved = OmegaConf.create(cfg_resolved) + + trainer = instantiate(cfg_resolved.trainer, _recursive_=False) + trainer.run() + + def __call__(self): + job_env = submitit.JobEnvironment() + self.setup_job_info(job_env.job_id, job_env.global_rank) + try: + self.run_trainer() + except Exception as e: + # Log the exception. Then raise it again (as what SubmititRunner currently does). + message = format_exception(e) + logging.error(message) + raise e + + def setup_job_info(self, job_id, rank): + """Set up slurm job info""" + self.job_info = { + "job_id": job_id, + "rank": rank, + "cluster": self.cfg.get("cluster", None), + "experiment_log_dir": self.cfg.launcher.experiment_log_dir, + } + + self.has_setup = True + + +def add_pythonpath_to_sys_path(): + if "PYTHONPATH" not in os.environ or not os.environ["PYTHONPATH"]: + return + sys.path = os.environ["PYTHONPATH"].split(":") + sys.path + + +def main(args) -> None: + cfg = compose(config_name=args.config) + if cfg.launcher.experiment_log_dir is None: + cfg.launcher.experiment_log_dir = os.path.join( + os.getcwd(), "sam3_logs", args.config + ) + print("###################### Train App Config ####################") + print(OmegaConf.to_yaml(cfg)) + print("############################################################") + + add_pythonpath_to_sys_path() + makedir(cfg.launcher.experiment_log_dir) + with g_pathmgr.open( + os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w" + ) as f: + f.write(OmegaConf.to_yaml(cfg)) + + cfg_resolved = OmegaConf.to_container(cfg, resolve=False) + cfg_resolved = OmegaConf.create(cfg_resolved) + + with g_pathmgr.open( + os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w" + ) as f: + f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True)) + + submitit_conf = cfg.get("submitit", None) + assert submitit_conf is not None, "Missing submitit config" + + experiment_log_dir = cfg.launcher.experiment_log_dir + print(f"Experiment Log Dir:\n{experiment_log_dir}") + submitit_dir = os.path.join(experiment_log_dir, "submitit_logs") + + # Prioritize cmd line args + cfg.launcher.gpus_per_node = ( + args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node + ) + cfg.launcher.num_nodes = ( + args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes + ) + submitit_conf.use_cluster = ( + args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster + ) + if submitit_conf.use_cluster: + executor = submitit.AutoExecutor(folder=submitit_dir) + submitit_conf.partition = ( + args.partition + if args.partition is not None + else submitit_conf.get("partition", None) + ) + submitit_conf.account = ( + args.account + if args.account is not None + else submitit_conf.get("account", None) + ) + submitit_conf.qos = ( + args.qos if args.qos is not None else submitit_conf.get("qos", None) + ) + job_kwargs = { + "timeout_min": 60 * submitit_conf.timeout_hour, + "name": ( + submitit_conf.name if hasattr(submitit_conf, "name") else args.config + ), + "slurm_partition": submitit_conf.partition, + "gpus_per_node": cfg.launcher.gpus_per_node, + "tasks_per_node": cfg.launcher.gpus_per_node, # one task per GPU + "cpus_per_task": submitit_conf.cpus_per_task, + "nodes": cfg.launcher.num_nodes, + "slurm_additional_parameters": { + "exclude": " ".join(submitit_conf.get("exclude_nodes", [])), + }, + } + if "include_nodes" in submitit_conf: + assert ( + len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes + ), "Not enough nodes" + job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join( + submitit_conf["include_nodes"] + ) + if submitit_conf.account is not None: + job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account + if submitit_conf.qos is not None: + job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos + + if submitit_conf.get("mem_gb", None) is not None: + job_kwargs["mem_gb"] = submitit_conf.mem_gb + elif submitit_conf.get("mem", None) is not None: + job_kwargs["slurm_mem"] = submitit_conf.mem + + if submitit_conf.get("constraints", None) is not None: + job_kwargs["slurm_constraint"] = submitit_conf.constraints + + if submitit_conf.get("comment", None) is not None: + job_kwargs["slurm_comment"] = submitit_conf.comment + + # Supports only cpu-bind option within srun_args. New options can be added here + if submitit_conf.get("srun_args", None) is not None: + job_kwargs["slurm_srun_args"] = [] + if submitit_conf.srun_args.get("cpu_bind", None) is not None: + job_kwargs["slurm_srun_args"].extend( + ["--cpu-bind", submitit_conf.srun_args.cpu_bind] + ) + + print("###################### SLURM Config ####################") + print(job_kwargs) + print("##########################################") + executor.update_parameters(**job_kwargs) + + if ( + "job_array" in submitit_conf + and submitit_conf.job_array.get("num_tasks", -1) > 0 + ): + num_tasks = submitit_conf.job_array.num_tasks + job_array_config_dir = os.path.join( + cfg.launcher.experiment_log_dir, "job_array_configs" + ) + makedir(job_array_config_dir) + + job_indices = range(num_tasks) + ports = random.sample( + range(submitit_conf.port_range[0], submitit_conf.port_range[1] + 1), + k=len(job_indices), + ) + + jobs_runners_configs = [] + with executor.batch(): + task_index = 0 + for indices, main_port in tqdm(zip(job_indices, ports)): + curr_cfg = deepcopy(cfg) + curr_cfg.submitit.job_array["task_index"] = task_index + curr_cfg_resolved = handle_custom_resolving(cfg) + runner = SubmititRunner(main_port, curr_cfg) + job = executor.submit(runner) + jobs_runners_configs.append( + (job, runner, curr_cfg, curr_cfg_resolved) + ) + task_index += 1 + + for job, runner, job_cfg, job_cfg_resolved in jobs_runners_configs: + print("Submitit Job ID:", job.job_id) + + # Save job specific config + job_array_config_file = os.path.join( + job_array_config_dir, "{}.config.yaml".format(job.job_id) + ) + with g_pathmgr.open(job_array_config_file, "w") as f: + f.write(OmegaConf.to_yaml(job_cfg)) + + job_array_config_resolved_file = os.path.join( + job_array_config_dir, "{}.config_resolved.yaml".format(job.job_id) + ) + with g_pathmgr.open(job_array_config_resolved_file, "w") as f: + f.write(OmegaConf.to_yaml(job_cfg_resolved, resolve=True)) + + runner.setup_job_info(job.job_id, rank=0) + # runner.log_event(event_type=SlurmEvent.QUEUED) + else: + main_port = random.randint( + submitit_conf.port_range[0], submitit_conf.port_range[1] + ) + runner = SubmititRunner(main_port, cfg) + job = executor.submit(runner) + print(f"Submitit Job ID: {job.job_id}") + runner.setup_job_info(job.job_id, rank=0) + + else: + cfg.launcher.num_nodes = 1 + main_port = random.randint( + submitit_conf.port_range[0], submitit_conf.port_range[1] + ) + single_node_runner(cfg, main_port) + + +if __name__ == "__main__": + initialize_config_module("sam3.train", version_base="1.2") + parser = ArgumentParser() + parser.add_argument( + "-c", + "--config", + required=True, + type=str, + help="path to config file (e.g. configs/roboflow_v100_full_ft_100_images.yaml)", + ) + parser.add_argument( + "--use-cluster", + type=int, + default=None, + help="whether to launch on a cluster, 0: run locally, 1: run on a cluster", + ) + parser.add_argument("--partition", type=str, default=None, help="SLURM partition") + parser.add_argument("--account", type=str, default=None, help="SLURM account") + parser.add_argument("--qos", type=str, default=None, help="SLURM qos") + parser.add_argument( + "--num-gpus", type=int, default=None, help="number of GPUS per node" + ) + parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes") + args = parser.parse_args() + args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None + register_omegaconf_resolvers() + main(args) diff --git a/third_party/sam3/sam3/train/trainer.py b/third_party/sam3/sam3/train/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5a378869b0e16080157c1a708a5a19a471c5f837 --- /dev/null +++ b/third_party/sam3/sam3/train/trainer.py @@ -0,0 +1,1189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import contextlib +import fnmatch +import gc +import json +import logging +import math +import os +import time +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Mapping, Optional + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from hydra.utils import instantiate +from iopath.common.file_io import g_pathmgr +from sam3.model.data_misc import BatchedDatapoint +from sam3.model.model_misc import SAM3Output +from sam3.model.utils.misc import copy_data_to_device +from sam3.train.optim.optimizer import construct_optimizer +from sam3.train.utils.checkpoint_utils import ( + assert_skipped_parameters_are_frozen, + exclude_params_matching_unix_pattern, + load_state_dict_into_model, + with_check_parameter_frozen, +) +from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank +from sam3.train.utils.logger import Logger, setup_logging +from sam3.train.utils.train_utils import ( + AverageMeter, + collect_dict_keys, + DurationMeter, + get_amp_type, + get_machine_local_and_dist_rank, + get_resume_checkpoint, + human_readable_time, + is_dist_avail_and_initialized, + log_env_variables, + makedir, + MemMeter, + Phase, + ProgressMeter, + set_seeds, + setup_distributed_backend, +) + + +CORE_LOSS_KEY = "core_loss" + + +def unwrap_ddp_if_wrapped(model): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + return model + + +@dataclass +class OptimAMPConf: + enabled: bool = False + amp_dtype: str = "float16" + + +@dataclass +class OptimConf: + optimizer: torch.optim.Optimizer = None + options: Optional[Dict[str, Any]] = None + param_group_modifiers: Optional[List] = None + amp: Optional[Dict[str, Any]] = None + gradient_clip: Any = None + gradient_logger: Any = None + + def __post_init__(self): + # amp + if not isinstance(self.amp, OptimAMPConf): + if self.amp is None: + self.amp = {} + assert isinstance(self.amp, Mapping) + self.amp = OptimAMPConf(**self.amp) + + +@dataclass +class DistributedConf: + backend: Optional[str] = None # inferred from accelerator type + comms_dtype: Optional[str] = None + find_unused_parameters: bool = False + timeout_mins: int = 30 + gradient_as_bucket_view: bool = False # PyTorch DDP default is False + static_graph: bool = False # PyTorch DDP default is False + + +@dataclass +class CudaConf: + cudnn_deterministic: bool = False + cudnn_benchmark: bool = True + allow_tf32: bool = False + # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul + matmul_allow_tf32: Optional[bool] = None + # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn + cudnn_allow_tf32: Optional[bool] = None + + +@dataclass +class CheckpointConf: + save_dir: str + save_freq: int + save_list: List[int] = field(default_factory=list) + model_weight_initializer: Any = None + save_best_meters: List[str] = None + skip_saving_parameters: List[str] = field(default_factory=list) + initialize_after_preemption: Optional[bool] = None + # if not None, training will be resumed from this checkpoint + resume_from: Optional[str] = None + + def infer_missing(self): + if self.initialize_after_preemption is None: + with_skip_saving = len(self.skip_saving_parameters) > 0 + self.initialize_after_preemption = with_skip_saving + return self + + +@dataclass +class LoggingConf: + log_dir: str + log_freq: int # In iterations + tensorboard_writer: Any + log_level_primary: str = "INFO" + log_level_secondary: str = "ERROR" + log_scalar_frequency: int = 100 + log_visual_frequency: int = 100 + scalar_keys_to_log: Optional[Dict[str, Any]] = None + log_batch_stats: bool = False + wandb_writer: Optional[Any] = None + + +class Trainer: + """ + Trainer supporting the DDP training strategies. + """ + + EPSILON = 1e-8 + + def __init__( + self, + *, # the order of these args can change at any time, so they are keyword-only + data: Dict[str, Any], + model: Dict[str, Any], + logging: Dict[str, Any], + checkpoint: Dict[str, Any], + max_epochs: int, + mode: str = "train", + accelerator: str = "cuda", + seed_value: int = 123, + val_epoch_freq: int = 1, + distributed: Dict[str, bool] = None, + cuda: Dict[str, bool] = None, + env_variables: Optional[Dict[str, Any]] = None, + optim: Optional[Dict[str, Any]] = None, + optim_overrides: Optional[List[Dict[str, Any]]] = None, + meters: Optional[Dict[str, Any]] = None, + loss: Optional[Dict[str, Any]] = None, + skip_first_val: bool = False, + skip_saving_ckpts: bool = False, + empty_gpu_mem_cache_after_eval: bool = True, + gradient_accumulation_steps: int = 1, + ): + self._setup_env_variables(env_variables) + self._setup_timers() + + self.data_conf = data + self.model_conf = model + self.logging_conf = LoggingConf(**logging) + self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing() + self.max_epochs = max_epochs + self.mode = mode + self.val_epoch_freq = val_epoch_freq + self.optim_conf = OptimConf(**optim) if optim is not None else OptimConf() + self.meters_conf = meters + self.loss_conf = loss + self.gradient_accumulation_steps = gradient_accumulation_steps + distributed = DistributedConf(**distributed or {}) + cuda = CudaConf(**cuda or {}) + self.where = 0.0 + + self.skip_first_val = skip_first_val + self.skip_saving_ckpts = skip_saving_ckpts + self.empty_gpu_mem_cache_after_eval = empty_gpu_mem_cache_after_eval + + self._infer_distributed_backend_if_none(distributed, accelerator) + + self._setup_device(accelerator) + + self._setup_torch_dist_and_backend(cuda, distributed) + + makedir(self.logging_conf.log_dir) + setup_logging( + __name__, + output_dir=self.logging_conf.log_dir, + rank=self.rank, + log_level_primary=self.logging_conf.log_level_primary, + log_level_secondary=self.logging_conf.log_level_secondary, + ) + + set_seeds(seed_value, self.max_epochs, self.distributed_rank) + log_env_variables() + + assert ( + is_dist_avail_and_initialized() + ), "Torch distributed needs to be initialized before calling the trainer." + + self._setup_components() # Except Optimizer everything is setup here. + self._move_to_device() + self._construct_optimizers() + self._setup_dataloaders() + + self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f") + + if self.checkpoint_conf.resume_from is not None: + assert os.path.exists( + self.checkpoint_conf.resume_from + ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!" + dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt") + if self.distributed_rank == 0 and not os.path.exists(dst): + # Copy the "resume_from" checkpoint to the checkpoint folder + # if there is not a checkpoint to resume from already there + makedir(self.checkpoint_conf.save_dir) + g_pathmgr.copy(self.checkpoint_conf.resume_from, dst) + barrier() + + self.load_checkpoint() + self._setup_ddp_distributed_training(distributed, accelerator) + barrier() + + def _setup_timers(self): + """ + Initializes counters for elapsed time and eta. + """ + self.start_time = time.time() + self.ckpt_time_elapsed = 0 + self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0) + + def _get_meters(self, phase_filters=None): + if self.meters is None: + return {} + meters = {} + for phase, phase_meters in self.meters.items(): + if phase_filters is not None and phase not in phase_filters: + continue + for key, key_meters in phase_meters.items(): + if key_meters is None: + continue + for name, meter in key_meters.items(): + meters[f"{phase}_{key}/{name}"] = meter + return meters + + def _infer_distributed_backend_if_none(self, distributed_conf, accelerator): + if distributed_conf.backend is None: + distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo" + + def _setup_env_variables(self, env_variables_conf) -> None: + if env_variables_conf is not None: + for variable_name, value in env_variables_conf.items(): + os.environ[variable_name] = value + + def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None: + if torch.cuda.is_available(): + torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic + torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark + torch.backends.cuda.matmul.allow_tf32 = ( + cuda_conf.matmul_allow_tf32 + if cuda_conf.matmul_allow_tf32 is not None + else cuda_conf.allow_tf32 + ) + torch.backends.cudnn.allow_tf32 = ( + cuda_conf.cudnn_allow_tf32 + if cuda_conf.cudnn_allow_tf32 is not None + else cuda_conf.allow_tf32 + ) + + self.rank = setup_distributed_backend( + distributed_conf.backend, distributed_conf.timeout_mins + ) + + def _setup_device(self, accelerator): + self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank() + if accelerator == "cuda": + self.device = torch.device("cuda", self.local_rank) + torch.cuda.set_device(self.local_rank) + elif accelerator == "cpu": + self.device = torch.device("cpu") + else: + raise ValueError(f"Unsupported accelerator: {accelerator}") + + def _setup_ddp_distributed_training(self, distributed_conf, accelerator): + assert isinstance(self.model, torch.nn.Module) + + self.model = nn.parallel.DistributedDataParallel( + self.model, + device_ids=[self.local_rank] if accelerator == "cuda" else [], + find_unused_parameters=distributed_conf.find_unused_parameters, + gradient_as_bucket_view=distributed_conf.gradient_as_bucket_view, + static_graph=distributed_conf.static_graph, + ) + if distributed_conf.comms_dtype is not None: # noqa + from torch.distributed.algorithms import ddp_comm_hooks + + amp_type = get_amp_type(distributed_conf.comms_dtype) + if amp_type == torch.bfloat16: + hook = ddp_comm_hooks.default_hooks.bf16_compress_hook + logging.info("Enabling bfloat16 grad communication") + else: + hook = ddp_comm_hooks.default_hooks.fp16_compress_hook + logging.info("Enabling fp16 grad communication") + process_group = None + self.model.register_comm_hook(process_group, hook) + + def _move_to_device(self): + logging.info( + f"Moving components to device {self.device} and local rank {self.local_rank}." + ) + + self.model.to(self.device) + + logging.info( + f"Done moving components to device {self.device} and local rank {self.local_rank}." + ) + + def save_checkpoint(self, epoch, checkpoint_names=None): + if self.skip_saving_ckpts: + logging.info( + "skip_saving_ckpts is set to True. So, no checkpoints have been saved." + ) + return + checkpoint_folder = self.checkpoint_conf.save_dir + makedir(checkpoint_folder) + if checkpoint_names is None: + checkpoint_names = ["checkpoint"] + if ( + self.checkpoint_conf.save_freq > 0 + and (int(epoch) % self.checkpoint_conf.save_freq == 0) + ) or int(epoch) in self.checkpoint_conf.save_list: + checkpoint_names.append(f"checkpoint_{int(epoch)}") + + checkpoint_paths = [] + for ckpt_name in checkpoint_names: + checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt")) + + state_dict = unwrap_ddp_if_wrapped(self.model).state_dict() + state_dict = exclude_params_matching_unix_pattern( + patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict + ) + + checkpoint = { + "model": state_dict, + "optimizer": self.optim.optimizer.state_dict(), + "epoch": epoch, + "loss": self.loss.state_dict(), + "steps": self.steps, + "time_elapsed": self.time_elapsed_meter.val, + "best_meter_values": self.best_meter_values, + } + if self.optim_conf.amp.enabled: + checkpoint["scaler"] = self.scaler.state_dict() + + # DDP checkpoints are only saved on rank 0 (all workers are identical) + if self.distributed_rank != 0: + return + + for checkpoint_path in checkpoint_paths: + self._save_checkpoint(checkpoint, checkpoint_path) + + def _save_checkpoint(self, checkpoint, checkpoint_path): + """ + Save a checkpoint while guarding against the job being killed in the middle + of checkpoint saving (which corrupts the checkpoint file and ruins the + entire training since usually only the last checkpoint is kept per run). + + We first save the new checkpoint to a temp file (with a '.tmp' suffix), and + and move it to overwrite the old checkpoint_path. + """ + checkpoint_path_tmp = f"{checkpoint_path}.tmp" + with g_pathmgr.open(checkpoint_path_tmp, "wb") as f: + torch.save(checkpoint, f) + # after torch.save is completed, replace the old checkpoint with the new one + if g_pathmgr.exists(checkpoint_path): + # remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails) + g_pathmgr.rm(checkpoint_path) + success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path) + assert success + + def load_checkpoint(self): + ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir) + if ckpt_path is None: + self._init_model_state() + else: + if self.checkpoint_conf.initialize_after_preemption: + self._call_model_initializer() + self._load_resuming_checkpoint(ckpt_path) + + def _init_model_state(self): + # Checking that parameters that won't be saved are indeed frozen + # We do this check here before even saving the model to catch errors + # are early as possible and not at the end of the first epoch + assert_skipped_parameters_are_frozen( + patterns=self.checkpoint_conf.skip_saving_parameters, + model=self.model, + ) + + # Checking that parameters that won't be saved are initialized from + # within the model definition, unless `initialize_after_preemption` + # is explicitly set to `True`. If not, this is a bug, and after + # preemption, the `skip_saving_parameters` will have random values + allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption + with with_check_parameter_frozen( + patterns=self.checkpoint_conf.skip_saving_parameters, + model=self.model, + disabled=allow_init_skip_parameters, + ): + self._call_model_initializer() + + def _call_model_initializer(self): + model_weight_initializer = instantiate( + self.checkpoint_conf.model_weight_initializer + ) + if model_weight_initializer is not None: + logging.info( + f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}" + ) + self.model = model_weight_initializer(model=self.model) + + def _load_resuming_checkpoint(self, ckpt_path: str): + logging.info(f"Resuming training from {ckpt_path}") + + with g_pathmgr.open(ckpt_path, "rb") as f: + checkpoint = torch.load(f, map_location="cpu") + load_state_dict_into_model( + model=self.model, + state_dict=checkpoint["model"], + ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters, + ) + + self.optim.optimizer.load_state_dict(checkpoint["optimizer"]) + self.loss.load_state_dict(checkpoint["loss"], strict=True) + self.epoch = checkpoint["epoch"] + self.steps = checkpoint["steps"] + self.ckpt_time_elapsed = checkpoint.get("time_elapsed") + + if self.optim_conf.amp.enabled and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + self.best_meter_values = checkpoint.get("best_meter_values", {}) + + if "train_dataset" in checkpoint and self.train_dataset is not None: + self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"]) + + def is_intermediate_val_epoch(self, epoch): + skip_epoch = self.skip_first_val and epoch == 0 + return ( + epoch % self.val_epoch_freq == 0 + and epoch < self.max_epochs - 1 + and not skip_epoch + ) + + def _find_loss(self, key: str): + if key in self.loss: + return self.loss[key] + + assert key != "all", "Loss must be specified for key='all'" + assert ( + "default" in self.loss + ), f"Key {key} not found in losss, and no default provided" + return self.loss["default"] + + def _find_meter(self, phase: str, key: str): + if key in self.meters[phase]: + return self.meters[phase][key] + + for cand_key, meter in self.meters[phase].items(): + if fnmatch.fnmatch(key, cand_key): + return meter + return None + + def _step( + self, + batch: BatchedDatapoint, + model: nn.Module, + phase: str, + ): + key, batch = batch.popitem() + batch = copy_data_to_device(batch, self.device, non_blocking=True) + + find_stages = model(batch) + find_targets = [ + unwrap_ddp_if_wrapped(model).back_convert(x) for x in batch.find_targets + ] + batch_size = len(batch.img_batch) + loss = self._find_loss(key)(find_stages, find_targets) + + loss_str = f"Losses/{phase}_{key}_loss" + + loss_log_str = os.path.join("Step_Losses", loss_str) + + # loss contains multiple sub-components we wish to log + step_losses = {} + if isinstance(loss, dict): + step_losses.update( + {f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()} + ) + loss = self._log_loss_detailed_and_return_core_loss( + loss, loss_log_str, self.steps[phase] + ) + + if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0: + self.logger.log( + loss_log_str, + loss, + self.steps[phase], + ) + + self.steps[phase] += 1 + + ret_tuple = {loss_str: loss}, batch_size, step_losses + + if phase not in self.meters: + return ret_tuple + + meters_dict = self._find_meter(phase, key) + if meters_dict is None: + return ret_tuple + if meters_dict is not None: + for _, meter in meters_dict.items(): + meter.update( + find_stages=find_stages, + find_metadatas=batch.find_metadatas, + model=model, + batch=batch, + key=key, + ) + # Cleanup memory + if isinstance(find_stages, SAM3Output): + for fs in find_stages: + for k in list(fs.keys()): + del fs[k] + + return ret_tuple + + def run(self): + assert self.mode in ["train", "train_only", "val"] + if self.mode == "train": + if self.epoch > 0: + logging.info(f"Resuming training from epoch: {self.epoch}") + # resuming from a checkpoint + if self.is_intermediate_val_epoch(self.epoch - 1): + logging.info("Running previous val epoch") + self.epoch -= 1 + self.run_val() + self.epoch += 1 + self.run_train() + self.run_val() + elif self.mode == "val": + self.run_val() + elif self.mode == "train_only": + self.run_train() + + def _setup_dataloaders(self): + self.train_dataset = None + self.val_dataset = None + + if self.mode in ["train", "val"]: + self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None)) + + if self.mode in ["train", "train_only"]: + self.train_dataset = instantiate(self.data_conf.train) + + def run_train(self): + while self.epoch < self.max_epochs: + dataloader = self.train_dataset.get_loader(epoch=int(self.epoch)) + barrier() + outs = self.train_epoch(dataloader) + self.logger.log_dict(outs, self.epoch) # Logged only on rank 0 + + # log train to text file. + if self.distributed_rank == 0: + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "train_stats.json"), + "a", + ) as f: + f.write(json.dumps(outs) + "\n") + + # Save checkpoint before validating + self.save_checkpoint(self.epoch + 1) + + del dataloader + gc.collect() + + # Run val, not running on last epoch since will run after the + # loop anyway + if self.is_intermediate_val_epoch(self.epoch): + self.run_val() + if torch.cuda.is_available() and self.empty_gpu_mem_cache_after_eval: + # release memory buffers held by the model during eval (which typically + # involves a lot more frames in video grounding that during training) + torch.cuda.empty_cache() + + if self.distributed_rank == 0: + self.best_meter_values.update(self._get_trainer_state("train")) + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "best_stats.json"), + "a", + ) as f: + f.write(json.dumps(self.best_meter_values) + "\n") + + self.epoch += 1 + # epoch was incremented in the loop but the val step runs out of the loop + self.epoch -= 1 + + def run_val(self): + if not self.val_dataset: + return + + dataloader = self.val_dataset.get_loader(epoch=int(self.epoch)) + outs = self.val_epoch(dataloader, phase=Phase.VAL) + del dataloader + gc.collect() + self.logger.log_dict(outs, self.epoch) # Logged only on rank 0 + + if self.distributed_rank == 0: + with g_pathmgr.open( + os.path.join(self.logging_conf.log_dir, "val_stats.json"), + "a", + ) as f: + f.write(json.dumps(outs) + "\n") + + def val_epoch(self, val_loader, phase): + batch_time = AverageMeter("Batch Time", self.device, ":.2f") + data_time = AverageMeter("Data Time", self.device, ":.2f") + mem = MemMeter("Mem (GB)", self.device, ":.2f") + + iters_per_epoch = len(val_loader) + + curr_phases = [phase] + curr_models = [self.model] + + loss_names = [] + for p in curr_phases: + for key in self.loss.keys(): + loss_names.append(f"Losses/{p}_{key}_loss") + + loss_mts = OrderedDict( + [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names] + ) + extra_loss_mts = {} + + for model in curr_models: + model.eval() + if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"): + unwrap_ddp_if_wrapped(model).on_validation_epoch_start() + + progress = ProgressMeter( + iters_per_epoch, + [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()], + self._get_meters(curr_phases), + prefix="Val Epoch: [{}]".format(self.epoch), + ) + + end = time.time() + + for data_iter, batch in enumerate(val_loader): + # measure data loading time + data_time.update(time.time() - end) + + # batch = batch.to(self.device, non_blocking=True) + + # compute output + with torch.no_grad(): + with torch.amp.autocast( + device_type="cuda", + enabled=(self.optim_conf.amp.enabled if self.optim_conf else False), + dtype=( + get_amp_type(self.optim_conf.amp.amp_dtype) + if self.optim_conf + else None + ), + ): + for phase, model in zip(curr_phases, curr_models): + loss_dict, batch_size, extra_losses = self._step( + batch, + model, + phase, + ) + + assert len(loss_dict) == 1 + loss_key, loss = loss_dict.popitem() + + if loss_key in loss_mts: + loss_mts[loss_key].update(loss.item(), batch_size) + + for k, v in extra_losses.items(): + if k not in extra_loss_mts: + extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e") + extra_loss_mts[k].update(v.item(), batch_size) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + self.time_elapsed_meter.update( + time.time() - self.start_time + self.ckpt_time_elapsed + ) + + if torch.cuda.is_available(): + mem.update(reset_peak_usage=True) + + if data_iter % self.logging_conf.log_freq == 0: + progress.display(data_iter) + + if data_iter % self.logging_conf.log_scalar_frequency == 0: + # Log progress meters. + for progress_meter in progress.meters: + self.logger.log( + os.path.join("Step_Stats", phase, progress_meter.name), + progress_meter.val, + self.steps[Phase.VAL], + ) + + if data_iter % 10 == 0: + dist.barrier() + + self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch + self._log_timers(phase) + for model in curr_models: + if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"): + unwrap_ddp_if_wrapped(model).on_validation_epoch_end() + + out_dict = self._log_meters_and_save_best_ckpts(curr_phases) + + for k, v in loss_mts.items(): + out_dict[k] = v.avg + for k, v in extra_loss_mts.items(): + out_dict[k] = v.avg + + for phase in curr_phases: + out_dict.update(self._get_trainer_state(phase)) + self._reset_meters(curr_phases) + logging.info(f"Meters: {out_dict}") + return out_dict + + def _get_trainer_state(self, phase): + return { + "Trainer/where": self.where, + "Trainer/epoch": self.epoch, + f"Trainer/steps_{phase}": self.steps[phase], + } + + def train_epoch(self, train_loader): + # Init stat meters + batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f") + data_time_meter = AverageMeter("Data Time", self.device, ":.2f") + mem_meter = MemMeter("Mem (GB)", self.device, ":.2f") + data_times = [] + phase = Phase.TRAIN + + iters_per_epoch = len(train_loader) + + loss_names = [] + for batch_key in self.loss.keys(): + loss_names.append(f"Losses/{phase}_{batch_key}_loss") + + loss_mts = OrderedDict( + [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names] + ) + extra_loss_mts = {} + + progress = ProgressMeter( + iters_per_epoch, + [ + batch_time_meter, + data_time_meter, + mem_meter, + self.time_elapsed_meter, + *loss_mts.values(), + ], + self._get_meters([phase]), + prefix="Train Epoch: [{}]".format(self.epoch), + ) + + # Model training loop + self.model.train() + end = time.time() + + for data_iter, batch in enumerate(train_loader): + # measure data loading time + data_time_meter.update(time.time() - end) + data_times.append(data_time_meter.val) + # batch = batch.to( + # self.device, non_blocking=True + # ) # move tensors in a tensorclass + + try: + self._run_step(batch, phase, loss_mts, extra_loss_mts) + + # compute gradient and do optim step + exact_epoch = self.epoch + float(data_iter) / iters_per_epoch + self.where = float(exact_epoch) / self.max_epochs + assert self.where <= 1 + self.EPSILON + if self.where < 1.0: + self.optim.step_schedulers( + self.where, step=int(exact_epoch * iters_per_epoch) + ) + else: + logging.warning( + f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]." + ) + + # Log schedulers + if data_iter % self.logging_conf.log_scalar_frequency == 0: + for j, param_group in enumerate(self.optim.optimizer.param_groups): + for option in self.optim.schedulers[j]: + optim_prefix = ( + "" + f"{j}_" + if len(self.optim.optimizer.param_groups) > 1 + else "" + ) + self.logger.log( + os.path.join("Optim", f"{optim_prefix}", option), + param_group[option], + self.steps[phase], + ) + + # Clipping gradients and detecting diverging gradients + if self.gradient_clipper is not None: + self.scaler.unscale_(self.optim.optimizer) + self.gradient_clipper(model=self.model) + + if self.gradient_logger is not None: + self.gradient_logger( + self.model, rank=self.distributed_rank, where=self.where + ) + + # Optimizer step: the scaler will make sure gradients are not + # applied if the gradients are infinite + self.scaler.step(self.optim.optimizer) + self.scaler.update() + + # measure elapsed time + batch_time_meter.update(time.time() - end) + end = time.time() + + self.time_elapsed_meter.update( + time.time() - self.start_time + self.ckpt_time_elapsed + ) + + mem_meter.update(reset_peak_usage=True) + if data_iter % self.logging_conf.log_freq == 0: + progress.display(data_iter) + + if data_iter % self.logging_conf.log_scalar_frequency == 0: + # Log progress meters. + for progress_meter in progress.meters: + self.logger.log( + os.path.join("Step_Stats", phase, progress_meter.name), + progress_meter.val, + self.steps[phase], + ) + + # Catching NaN/Inf errors in the loss + except FloatingPointError as e: + raise e + + self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch + self._log_timers(Phase.TRAIN) + self._log_sync_data_times(Phase.TRAIN, data_times) + + out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN]) + + for k, v in loss_mts.items(): + out_dict[k] = v.avg + for k, v in extra_loss_mts.items(): + out_dict[k] = v.avg + out_dict.update(self._get_trainer_state(phase)) + logging.info(f"Losses and meters: {out_dict}") + self._reset_meters([phase]) + return out_dict + + def _log_sync_data_times(self, phase, data_times): + data_times = all_reduce_max(torch.tensor(data_times)).tolist() + steps = range(self.steps[phase] - len(data_times), self.steps[phase]) + for step, data_time in zip(steps, data_times): + if step % self.logging_conf.log_scalar_frequency == 0: + self.logger.log( + os.path.join("Step_Stats", phase, "Data Time Synced"), + data_time, + step, + ) + + def _run_step( + self, + batch: BatchedDatapoint, + phase: str, + loss_mts: Dict[str, AverageMeter], + extra_loss_mts: Dict[str, AverageMeter], + raise_on_error: bool = True, + ): + """ + Run the forward / backward + """ + + # it's important to set grads to None, especially with Adam since 0 + # grads will also update a model even if the step doesn't produce + # gradients + self.optim.zero_grad(set_to_none=True) + + if self.gradient_accumulation_steps > 1: + assert isinstance( + batch, list + ), f"Expected a list of batches, got {type(batch)}" + assert ( + len(batch) == self.gradient_accumulation_steps + ), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}" + accum_steps = len(batch) + else: + accum_steps = 1 + batch = [batch] + + for i, chunked_batch in enumerate(batch): + ddp_context = ( + self.model.no_sync() + if i < accum_steps - 1 + else contextlib.nullcontext() + ) + with ddp_context: + with torch.amp.autocast( + device_type="cuda", + enabled=self.optim_conf.amp.enabled, + dtype=get_amp_type(self.optim_conf.amp.amp_dtype), + ): + loss_dict, batch_size, extra_losses = self._step( + chunked_batch, + self.model, + phase, + ) + + assert len(loss_dict) == 1 + loss_key, loss = loss_dict.popitem() + + if not math.isfinite(loss.item()): + error_msg = f"Loss is {loss.item()}, attempting to stop training" + logging.error(error_msg) + if raise_on_error: + raise FloatingPointError(error_msg) + else: + return + + self.scaler.scale(loss).backward() + loss_mts[loss_key].update(loss.item(), batch_size) + for extra_loss_key, extra_loss in extra_losses.items(): + if extra_loss_key not in extra_loss_mts: + extra_loss_mts[extra_loss_key] = AverageMeter( + extra_loss_key, self.device, ":.2e" + ) + extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size) + + def _log_meters_and_save_best_ckpts(self, phases: List[str]): + logging.info("Synchronizing meters") + out_dict = {} + checkpoint_save_keys = [] + for key, meter in self._get_meters(phases).items(): + meter_output = meter.compute_synced() + is_better_check = getattr(meter, "is_better", None) + + for meter_subkey, meter_value in meter_output.items(): + out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value + + if is_better_check is None: + continue + + tracked_meter_key = os.path.join(key, meter_subkey) + if tracked_meter_key not in self.best_meter_values or is_better_check( + meter_value, + self.best_meter_values[tracked_meter_key], + ): + self.best_meter_values[tracked_meter_key] = meter_value + + if ( + self.checkpoint_conf.save_best_meters is not None + and key in self.checkpoint_conf.save_best_meters + ): + checkpoint_save_keys.append(tracked_meter_key.replace("/", "_")) + + if len(checkpoint_save_keys) > 0: + self.save_checkpoint(self.epoch + 1, checkpoint_save_keys) + + return out_dict + + def _log_timers(self, phase): + time_remaining = 0 + epochs_remaining = self.max_epochs - self.epoch - 1 + val_epochs_remaining = sum( + n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs) + ) + + # Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with + # the end epoch. + if (self.max_epochs - 1) % self.val_epoch_freq != 0: + val_epochs_remaining += 1 + + # Remove the current val run from estimate + if phase == Phase.VAL: + val_epochs_remaining -= 1 + + time_remaining += ( + epochs_remaining * self.est_epoch_time[Phase.TRAIN] + + val_epochs_remaining * self.est_epoch_time[Phase.VAL] + ) + + self.logger.log( + os.path.join("Step_Stats", phase, self.time_elapsed_meter.name), + self.time_elapsed_meter.val, + self.steps[phase], + ) + + logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}") + + def _reset_meters(self, phases: str) -> None: + for meter in self._get_meters(phases).values(): + meter.reset() + + def _check_val_key_match(self, val_keys, phase): + if val_keys is not None: + # Check if there are any duplicates + assert len(val_keys) == len( + set(val_keys) + ), f"Duplicate keys in val datasets, keys: {val_keys}" + + # Check that the keys match the meter keys + if self.meters_conf is not None and phase in self.meters_conf: + assert set(val_keys) == set(self.meters_conf[phase].keys()), ( + f"Keys in val datasets do not match the keys in meters." + f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}" + f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}" + ) + + if self.loss_conf is not None: + loss_keys = set(self.loss_conf.keys()) - set(["all"]) + if "default" not in loss_keys: + for k in val_keys: + assert ( + k in loss_keys + ), f"Error: key {k} is not defined in the losses, and no default is set" + + def _setup_components(self): + # Get the keys for all the val datasets, if any + val_phase = Phase.VAL + val_keys = None + if self.data_conf.get(val_phase, None) is not None: + val_keys = collect_dict_keys(self.data_conf[val_phase]) + # Additional checks on the sanity of the config for val datasets + self._check_val_key_match(val_keys, phase=val_phase) + + logging.info("Setting up components: Model, loss, optim, meters etc.") + self.epoch = 0 + self.steps = {Phase.TRAIN: 0, Phase.VAL: 0} + + self.logger = Logger(self.logging_conf) + + self.model = instantiate(self.model_conf, _convert_="all") + print_model_summary(self.model) + + self.loss = None + if self.loss_conf: + self.loss = { + key: el # wrap_base_loss(el) + for (key, el) in instantiate(self.loss_conf, _convert_="all").items() + } + self.loss = nn.ModuleDict(self.loss) + + self.meters = {} + self.best_meter_values = {} + if self.meters_conf: + self.meters = instantiate(self.meters_conf, _convert_="all") + + self.scaler = torch.amp.GradScaler( + self.device, + enabled=self.optim_conf.amp.enabled if self.optim_conf else False, + ) + + self.gradient_clipper = ( + instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None + ) + self.gradient_logger = ( + instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None + ) + + logging.info("Finished setting up components: Model, loss, optim, meters etc.") + + def _construct_optimizers(self): + self.optim = construct_optimizer( + self.model, + self.optim_conf.optimizer, + self.optim_conf.options, + self.optim_conf.param_group_modifiers, + ) + + def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step): + core_loss = loss.pop(CORE_LOSS_KEY) + if step % self.logging_conf.log_scalar_frequency == 0: + for k in loss: + log_str = os.path.join(loss_str, k) + self.logger.log(log_str, loss[k], step) + return core_loss + + +def print_model_summary(model: torch.nn.Module, log_dir: str = ""): + """ + Prints the model and the number of parameters in the model. + # Multiple packages provide this info in a nice table format + # However, they need us to provide an `input` (as they also write down the output sizes) + # Our models are complex, and a single input is restrictive. + # https://github.com/sksq96/pytorch-summary + # https://github.com/nmhkahn/torchsummaryX + """ + if get_rank() != 0: + return + param_kwargs = {} + trainable_parameters = sum( + p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad + ) + total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs)) + non_trainable_parameters = total_parameters - trainable_parameters + logging.info("==" * 10) + logging.info(f"Summary for model {type(model)}") + logging.info(f"Model is {model}") + logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}") + logging.info( + f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}" + ) + logging.info( + f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}" + ) + logging.info("==" * 10) + + if log_dir: + output_fpath = os.path.join(log_dir, "model.txt") + with g_pathmgr.open(output_fpath, "w") as f: + print(model, file=f) + + +PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] + + +def get_human_readable_count(number: int) -> str: + """ + Abbreviates an integer number with K, M, B, T for thousands, millions, + billions and trillions, respectively. + Examples: + >>> get_human_readable_count(123) + '123 ' + >>> get_human_readable_count(1234) # (one thousand) + '1.2 K' + >>> get_human_readable_count(2e6) # (two million) + '2.0 M' + >>> get_human_readable_count(3e9) # (three billion) + '3.0 B' + >>> get_human_readable_count(4e14) # (four hundred trillion) + '400 T' + >>> get_human_readable_count(5e15) # (more than trillion) + '5,000 T' + Args: + number: a positive integer number + Return: + A string formatted according to the pattern described above. + """ + assert number >= 0 + labels = PARAMETER_NUM_UNITS + num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) + num_groups = int(np.ceil(num_digits / 3)) + num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions + shift = -3 * (num_groups - 1) + number = number * (10**shift) + index = num_groups - 1 + if index < 1 or number >= 100: + return f"{int(number):,d} {labels[index]}" + else: + return f"{number:,.1f} {labels[index]}" diff --git a/third_party/sam3/sam3/train/transforms/__init__.py b/third_party/sam3/sam3/train/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/train/transforms/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/train/transforms/basic.py b/third_party/sam3/sam3/train/transforms/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..b08dfae8814ce765641a684a19bee7b7c957b029 --- /dev/null +++ b/third_party/sam3/sam3/train/transforms/basic.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +Transforms and data augmentation for both image + bbox. +""" + +import math +import random +from typing import Iterable + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +from sam3.model.box_ops import box_xyxy_to_cxcywh +from sam3.model.data_misc import interpolate + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd", "positive_map"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "input_boxes" in target: + boxes = target["input_boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + target["input_boxes"] = cropped_boxes.reshape(-1, 4) + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target["masks"] = target["masks"][:, i : i + h, j : j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target["boxes"].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target["masks"].flatten(1).any(1) + + for field in fields: + if field in target: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1] + ) + torch.as_tensor([w, 0, w, 0]) + target["boxes"] = boxes + + if "input_boxes" in target: + boxes = target["input_boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1] + ) + torch.as_tensor([w, 0, w, 0]) + target["input_boxes"] = boxes + + if "masks" in target: + target["masks"] = target["masks"].flip(-1) + + if "text_input" in target: + text_input = ( + target["text_input"] + .replace("left", "[TMP]") + .replace("right", "left") + .replace("[TMP]", "right") + ) + target["text_input"] = text_input + + return flipped_image, target + + +def resize(image, target, size, max_size=None, square=False): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + if square: + size = size, size + else: + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple( + float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size) + ) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32 + ) + target["boxes"] = scaled_boxes + if "input_boxes" in target: + boxes = target["input_boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32 + ) + target["input_boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target["masks"] = ( + interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] + > 0.5 + ) + + return rescaled_image, target + + +def pad(image, target, padding): + if len(padding) == 2: + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + else: + # left, top, right, bottom + padded_image = F.pad(image, (padding[0], padding[1], padding[2], padding[3])) + if target is None: + return padded_image, None + target = target.copy() + + w, h = padded_image.size + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + if "boxes" in target and len(padding) == 4: + boxes = target["boxes"] + boxes = boxes + torch.as_tensor( + [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32 + ) + target["boxes"] = boxes + + if "input_boxes" in target and len(padding) == 4: + boxes = target["input_boxes"] + boxes = boxes + torch.as_tensor( + [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32 + ) + target["input_boxes"] = boxes + + if "masks" in target: + if len(padding) == 2: + target["masks"] = torch.nn.functional.pad( + target["masks"], (0, padding[0], 0, padding[1]) + ) + else: + target["masks"] = torch.nn.functional.pad( + target["masks"], (padding[0], padding[2], padding[1], padding[3]) + ) + return padded_image, target + + +class RandomCrop: + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop: + def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False): + self.min_size = min_size + self.max_size = max_size + self.respect_boxes = respect_boxes # if True we can't crop a box out + + def __call__(self, img: PIL.Image.Image, target: dict): + init_boxes = len(target["boxes"]) + init_boxes_tensor = target["boxes"].clone() + if self.respect_boxes and init_boxes > 0: + minW, minH, maxW, maxH = ( + min(img.width, self.min_size), + min(img.width, self.min_size), + min(img.width, self.max_size), + min(img.height, self.max_size), + ) + minX, minY = ( + target["boxes"][:, 0].max().item() + 10.0, + target["boxes"][:, 1].max().item() + 10.0, + ) + minX = min(img.width, minX) + minY = min(img.height, minY) + maxX, maxY = ( + target["boxes"][:, 2].min().item() - 10, + target["boxes"][:, 3].min().item() - 10, + ) + maxX = max(0.0, maxX) + maxY = max(0.0, maxY) + minW = max(minW, minX - maxX) + minH = max(minH, minY - maxY) + w = random.uniform(minW, max(minW, maxW)) + h = random.uniform(minH, max(minH, maxH)) + if minX > maxX: + # i = random.uniform(max(0, minX - w + 1), max(maxX, max(0, minX - w + 1))) + i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w))) + else: + i = random.uniform( + max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1)) + ) + if minY > maxY: + # j = random.uniform(max(0, minY - h + 1), max(maxY, max(0, minY - h + 1))) + j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h))) + else: + j = random.uniform( + max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1)) + ) + result_img, result_target = crop(img, target, [j, i, h, w]) + assert ( + len(result_target["boxes"]) == init_boxes + ), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}" + + return result_img, result_target + else: + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, (h, w)) + result_img, result_target = crop(img, target, region) + return result_img, result_target + + +class CenterCrop: + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip: + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize: + def __init__(self, sizes, max_size=None, square=False): + if isinstance(sizes, int): + sizes = (sizes,) + assert isinstance(sizes, Iterable) + self.sizes = list(sizes) + self.max_size = max_size + self.square = square + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size, square=self.square) + + +class RandomPad: + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class PadToSize: + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + w, h = img.size + pad_x = self.size - w + pad_y = self.size - h + assert pad_x >= 0 and pad_y >= 0 + pad_left = random.randint(0, pad_x) + pad_right = pad_x - pad_left + pad_top = random.randint(0, pad_y) + pad_bottom = pad_y - pad_top + return pad(img, target, (pad_left, pad_top, pad_right, pad_bottom)) + + +class Identity: + def __call__(self, img, target): + return img, target + + +class RandomSelect: + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + + def __init__(self, transforms1=None, transforms2=None, p=0.5): + self.transforms1 = transforms1 or Identity() + self.transforms2 = transforms2 or Identity() + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor: + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing: + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize: + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + if "input_boxes" in target: + boxes = target["input_boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["input_boxes"] = boxes + return image, target + + +class RemoveDifficult: + def __init__(self, enabled=False): + self.remove_difficult = enabled + + def __call__(self, image, target=None): + if target is None: + return image, None + target = target.copy() + keep = ~target["iscrowd"].to(torch.bool) | (not self.remove_difficult) + if "boxes" in target: + target["boxes"] = target["boxes"][keep] + target["labels"] = target["labels"][keep] + target["iscrowd"] = target["iscrowd"][keep] + return image, target + + +class Compose: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +def get_random_resize_scales(size, min_size, rounded): + stride = 128 if rounded else 32 + min_size = int(stride * math.ceil(min_size / stride)) + scales = list(range(min_size, size + 1, stride)) + return scales + + +def get_random_resize_max_size(size, ratio=5 / 3): + max_size = round(ratio * size) + return max_size diff --git a/third_party/sam3/sam3/train/transforms/basic_for_api.py b/third_party/sam3/sam3/train/transforms/basic_for_api.py new file mode 100644 index 0000000000000000000000000000000000000000..27a3d4b804734be1f586417ffdb1fe6d0bafc7a1 --- /dev/null +++ b/third_party/sam3/sam3/train/transforms/basic_for_api.py @@ -0,0 +1,1395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +""" +Transforms and data augmentation for both image + bbox. +""" + +import logging +import numbers +import random +from collections.abc import Sequence +from typing import Iterable + +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +import torchvision.transforms.v2.functional as Fv2 +from PIL import Image as PILImage +from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes +from sam3.train.data.sam3_image_dataset import Datapoint +from torchvision.transforms import InterpolationMode + + +def crop( + datapoint, + index, + region, + v2=False, + check_validity=True, + check_input_validity=True, + recompute_box_from_mask=False, +): + if v2: + rtop, rleft, rheight, rwidth = (int(round(r)) for r in region) + datapoint.images[index].data = Fv2.crop( + datapoint.images[index].data, + top=rtop, + left=rleft, + height=rheight, + width=rwidth, + ) + else: + datapoint.images[index].data = F.crop(datapoint.images[index].data, *region) + + i, j, h, w = region + + # should we do something wrt the original size? + datapoint.images[index].size = (h, w) + + for obj in datapoint.images[index].objects: + # crop the mask + if obj.segment is not None: + obj.segment = F.crop(obj.segment, int(i), int(j), int(h), int(w)) + + # crop the bounding box + if recompute_box_from_mask and obj.segment is not None: + # here the boxes are still in XYXY format with absolute coordinates (they are + # converted to CxCyWH with relative coordinates in basic_for_api.NormalizeAPI) + obj.bbox, obj.area = get_bbox_xyxy_abs_coords_from_mask(obj.segment) + else: + if recompute_box_from_mask and obj.segment is None and obj.area > 0: + logging.warning( + "Cannot recompute bounding box from mask since `obj.segment` is None. " + "Falling back to directly cropping from the input bounding box." + ) + boxes = obj.bbox.view(1, 4) + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + obj.area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + obj.bbox = cropped_boxes.reshape(-1, 4) + + for query in datapoint.find_queries: + if query.semantic_target is not None: + query.semantic_target = F.crop( + query.semantic_target, int(i), int(j), int(h), int(w) + ) + if query.image_id == index and query.input_bbox is not None: + boxes = query.input_bbox + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + + # cur_area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + # if check_input_validity: + # assert ( + # (cur_area > 0).all().item() + # ), "Some input box got cropped out by the crop transform" + + query.input_bbox = cropped_boxes.reshape(-1, 4) + if query.image_id == index and query.input_points is not None: + print( + "Warning! Point cropping with this function may lead to unexpected results" + ) + points = query.input_points + # Unlike right-lower box edges, which are exclusive, the + # point must be in [0, length-1], hence the -1 + max_size = torch.as_tensor([w, h], dtype=torch.float32) - 1 + cropped_points = points - torch.as_tensor([j, i, 0], dtype=torch.float32) + cropped_points[:, :, :2] = torch.min(cropped_points[:, :, :2], max_size) + cropped_points[:, :, :2] = cropped_points[:, :, :2].clamp(min=0) + query.input_points = cropped_points + + if check_validity: + # Check that all boxes are still valid + for obj in datapoint.images[index].objects: + assert obj.area > 0, "Box {} has no area".format(obj.bbox) + + return datapoint + + +def hflip(datapoint, index): + datapoint.images[index].data = F.hflip(datapoint.images[index].data) + + w, h = datapoint.images[index].data.size + for obj in datapoint.images[index].objects: + boxes = obj.bbox.view(1, 4) + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1] + ) + torch.as_tensor([w, 0, w, 0]) + obj.bbox = boxes + if obj.segment is not None: + obj.segment = F.hflip(obj.segment) + + for query in datapoint.find_queries: + if query.semantic_target is not None: + query.semantic_target = F.hflip(query.semantic_target) + if query.image_id == index and query.input_bbox is not None: + boxes = query.input_bbox + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( + [-1, 1, -1, 1] + ) + torch.as_tensor([w, 0, w, 0]) + query.input_bbox = boxes + if query.image_id == index and query.input_points is not None: + points = query.input_points + points = points * torch.as_tensor([-1, 1, 1]) + torch.as_tensor([w, 0, 0]) + query.input_points = points + return datapoint + + +def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = max_size * min_original_size / max_original_size + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = int(round(size)) + oh = int(round(size * h / w)) + else: + oh = int(round(size)) + ow = int(round(size * w / h)) + + return (oh, ow) + + +def resize(datapoint, index, size, max_size=None, square=False, v2=False): + # size can be min_size (scalar) or (w, h) tuple + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + if square: + size = size, size + else: + cur_size = ( + datapoint.images[index].data.size()[-2:][::-1] + if v2 + else datapoint.images[index].data.size + ) + size = get_size(cur_size, size, max_size) + + old_size = ( + datapoint.images[index].data.size()[-2:][::-1] + if v2 + else datapoint.images[index].data.size + ) + if v2: + datapoint.images[index].data = Fv2.resize( + datapoint.images[index].data, size, antialias=True + ) + else: + datapoint.images[index].data = F.resize(datapoint.images[index].data, size) + + new_size = ( + datapoint.images[index].data.size()[-2:][::-1] + if v2 + else datapoint.images[index].data.size + ) + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, old_size)) + ratio_width, ratio_height = ratios + + for obj in datapoint.images[index].objects: + boxes = obj.bbox.view(1, 4) + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32 + ) + obj.bbox = scaled_boxes + obj.area *= ratio_width * ratio_height + if obj.segment is not None: + obj.segment = F.resize(obj.segment[None, None], size).squeeze() + + for query in datapoint.find_queries: + if query.semantic_target is not None: + query.semantic_target = F.resize( + query.semantic_target[None, None], size + ).squeeze() + if query.image_id == index and query.input_bbox is not None: + boxes = query.input_bbox + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height], + dtype=torch.float32, + ) + query.input_bbox = scaled_boxes + if query.image_id == index and query.input_points is not None: + points = query.input_points + scaled_points = points * torch.as_tensor( + [ratio_width, ratio_height, 1], + dtype=torch.float32, + ) + query.input_points = scaled_points + + h, w = size + datapoint.images[index].size = (h, w) + return datapoint + + +def pad(datapoint, index, padding, v2=False): + old_h, old_w = datapoint.images[index].size + h, w = old_h, old_w + if len(padding) == 2: + # assumes that we only pad on the bottom right corners + if v2: + datapoint.images[index].data = Fv2.pad( + datapoint.images[index].data, (0, 0, padding[0], padding[1]) + ) + else: + datapoint.images[index].data = F.pad( + datapoint.images[index].data, (0, 0, padding[0], padding[1]) + ) + h += padding[1] + w += padding[0] + else: + if v2: + # left, top, right, bottom + datapoint.images[index].data = Fv2.pad( + datapoint.images[index].data, + (padding[0], padding[1], padding[2], padding[3]), + ) + else: + # left, top, right, bottom + datapoint.images[index].data = F.pad( + datapoint.images[index].data, + (padding[0], padding[1], padding[2], padding[3]), + ) + h += padding[1] + padding[3] + w += padding[0] + padding[2] + + datapoint.images[index].size = (h, w) + + for obj in datapoint.images[index].objects: + if len(padding) != 2: + obj.bbox += torch.as_tensor( + [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32 + ) + if obj.segment is not None: + if v2: + if len(padding) == 2: + obj.segment = Fv2.pad( + obj.segment[None], (0, 0, padding[0], padding[1]) + ).squeeze(0) + else: + obj.segment = Fv2.pad(obj.segment[None], tuple(padding)).squeeze(0) + else: + if len(padding) == 2: + obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1])) + else: + obj.segment = F.pad(obj.segment, tuple(padding)) + + for query in datapoint.find_queries: + if query.semantic_target is not None: + if v2: + if len(padding) == 2: + query.semantic_target = Fv2.pad( + query.semantic_target[None, None], + (0, 0, padding[0], padding[1]), + ).squeeze() + else: + query.semantic_target = Fv2.pad( + query.semantic_target[None, None], tuple(padding) + ).squeeze() + else: + if len(padding) == 2: + query.semantic_target = F.pad( + query.semantic_target[None, None], + (0, 0, padding[0], padding[1]), + ).squeeze() + else: + query.semantic_target = F.pad( + query.semantic_target[None, None], tuple(padding) + ).squeeze() + if query.image_id == index and query.input_bbox is not None: + if len(padding) != 2: + query.input_bbox += torch.as_tensor( + [padding[0], padding[1], padding[0], padding[1]], + dtype=torch.float32, + ) + if query.image_id == index and query.input_points is not None: + if len(padding) != 2: + query.input_points += torch.as_tensor( + [padding[0], padding[1], 0], dtype=torch.float32 + ) + + return datapoint + + +class RandomSizeCropAPI: + def __init__( + self, + min_size: int, + max_size: int, + respect_boxes: bool, + consistent_transform: bool, + respect_input_boxes: bool = True, + v2: bool = False, + recompute_box_from_mask: bool = False, + ): + self.min_size = min_size + self.max_size = max_size + self.respect_boxes = respect_boxes # if True we can't crop a box out + self.respect_input_boxes = respect_input_boxes + self.consistent_transform = consistent_transform + self.v2 = v2 + self.recompute_box_from_mask = recompute_box_from_mask + + def _sample_no_respect_boxes(self, img): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + return T.RandomCrop.get_params(img, (h, w)) + + def _sample_respect_boxes(self, img, boxes, points, min_box_size=10.0): + """ + Assure that no box or point is dropped via cropping, though portions + of boxes may be removed. + """ + if len(boxes) == 0 and len(points) == 0: + return self._sample_no_respect_boxes(img) + + if self.v2: + img_height, img_width = img.size()[-2:] + else: + img_width, img_height = img.size + + minW, minH, maxW, maxH = ( + min(img_width, self.min_size), + min(img_height, self.min_size), + min(img_width, self.max_size), + min(img_height, self.max_size), + ) + + # The crop box must extend one pixel beyond points to the bottom/right + # to assure the exclusive box contains the points. + minX = ( + torch.cat([boxes[:, 0] + min_box_size, points[:, 0] + 1], dim=0) + .max() + .item() + ) + minY = ( + torch.cat([boxes[:, 1] + min_box_size, points[:, 1] + 1], dim=0) + .max() + .item() + ) + minX = min(img_width, minX) + minY = min(img_height, minY) + maxX = torch.cat([boxes[:, 2] - min_box_size, points[:, 0]], dim=0).min().item() + maxY = torch.cat([boxes[:, 3] - min_box_size, points[:, 1]], dim=0).min().item() + maxX = max(0.0, maxX) + maxY = max(0.0, maxY) + minW = max(minW, minX - maxX) + minH = max(minH, minY - maxY) + w = random.uniform(minW, max(minW, maxW)) + h = random.uniform(minH, max(minH, maxH)) + if minX > maxX: + # i = random.uniform(max(0, minX - w + 1), max(maxX, max(0, minX - w + 1))) + i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w))) + else: + i = random.uniform( + max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1)) + ) + if minY > maxY: + # j = random.uniform(max(0, minY - h + 1), max(maxY, max(0, minY - h + 1))) + j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h))) + else: + j = random.uniform( + max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1)) + ) + + return [j, i, h, w] + + def __call__(self, datapoint, **kwargs): + if self.respect_boxes or self.respect_input_boxes: + if self.consistent_transform: + # Check that all the images are the same size + w, h = datapoint.images[0].data.size + for img in datapoint.images: + assert img.data.size == (w, h) + + all_boxes = [] + # Getting all boxes in all the images + if self.respect_boxes: + all_boxes += [ + obj.bbox.view(-1, 4) + for img in datapoint.images + for obj in img.objects + ] + # Get all the boxes in the find queries + if self.respect_input_boxes: + all_boxes += [ + q.input_bbox.view(-1, 4) + for q in datapoint.find_queries + if q.input_bbox is not None + ] + if all_boxes: + all_boxes = torch.cat(all_boxes, 0) + else: + all_boxes = torch.empty(0, 4) + + all_points = [ + q.input_points.view(-1, 3)[:, :2] + for q in datapoint.find_queries + if q.input_points is not None + ] + if all_points: + all_points = torch.cat(all_points, 0) + else: + all_points = torch.empty(0, 2) + + crop_param = self._sample_respect_boxes( + datapoint.images[0].data, all_boxes, all_points + ) + for i in range(len(datapoint.images)): + datapoint = crop( + datapoint, + i, + crop_param, + v2=self.v2, + check_validity=self.respect_boxes, + check_input_validity=self.respect_input_boxes, + recompute_box_from_mask=self.recompute_box_from_mask, + ) + return datapoint + else: + for i in range(len(datapoint.images)): + all_boxes = [] + # Get all boxes in the current image + if self.respect_boxes: + all_boxes += [ + obj.bbox.view(-1, 4) for obj in datapoint.images[i].objects + ] + # Get all the boxes in the find queries that correspond to this image + if self.respect_input_boxes: + all_boxes += [ + q.input_bbox.view(-1, 4) + for q in datapoint.find_queries + if q.image_id == i and q.input_bbox is not None + ] + if all_boxes: + all_boxes = torch.cat(all_boxes, 0) + else: + all_boxes = torch.empty(0, 4) + + all_points = [ + q.input_points.view(-1, 3)[:, :2] + for q in datapoint.find_queries + if q.input_points is not None + ] + if all_points: + all_points = torch.cat(all_points, 0) + else: + all_points = torch.empty(0, 2) + + crop_param = self._sample_respect_boxes( + datapoint.images[i].data, all_boxes, all_points + ) + datapoint = crop( + datapoint, + i, + crop_param, + v2=self.v2, + check_validity=self.respect_boxes, + check_input_validity=self.respect_input_boxes, + recompute_box_from_mask=self.recompute_box_from_mask, + ) + return datapoint + else: + if self.consistent_transform: + # Check that all the images are the same size + w, h = datapoint.images[0].data.size + for img in datapoint.images: + assert img.data.size == (w, h) + + crop_param = self._sample_no_respect_boxes(datapoint.images[0].data) + for i in range(len(datapoint.images)): + datapoint = crop( + datapoint, + i, + crop_param, + v2=self.v2, + check_validity=self.respect_boxes, + check_input_validity=self.respect_input_boxes, + recompute_box_from_mask=self.recompute_box_from_mask, + ) + return datapoint + else: + for i in range(len(datapoint.images)): + crop_param = self._sample_no_respect_boxes(datapoint.images[i].data) + datapoint = crop( + datapoint, + i, + crop_param, + v2=self.v2, + check_validity=self.respect_boxes, + check_input_validity=self.respect_input_boxes, + recompute_box_from_mask=self.recompute_box_from_mask, + ) + return datapoint + + +class CenterCropAPI: + def __init__(self, size, consistent_transform, recompute_box_from_mask=False): + self.size = size + self.consistent_transform = consistent_transform + self.recompute_box_from_mask = recompute_box_from_mask + + def _sample_crop(self, image_width, image_height): + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop_top, crop_left, crop_height, crop_width + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + # Check that all the images are the same size + w, h = datapoint.images[0].data.size + for img in datapoint.images: + assert img.size == (w, h) + + crop_top, crop_left, crop_height, crop_width = self._sample_crop(w, h) + for i in range(len(datapoint.images)): + datapoint = crop( + datapoint, + i, + (crop_top, crop_left, crop_height, crop_width), + recompute_box_from_mask=self.recompute_box_from_mask, + ) + return datapoint + + for i in range(len(datapoint.images)): + w, h = datapoint.images[i].data.size + crop_top, crop_left, crop_height, crop_width = self._sample_crop(w, h) + datapoint = crop( + datapoint, + i, + (crop_top, crop_left, crop_height, crop_width), + recompute_box_from_mask=self.recompute_box_from_mask, + ) + + return datapoint + + +class RandomHorizontalFlip: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for i in range(len(datapoint.images)): + datapoint = hflip(datapoint, i) + return datapoint + for i in range(len(datapoint.images)): + if random.random() < self.p: + datapoint = hflip(datapoint, i) + return datapoint + + +class RandomResizeAPI: + def __init__( + self, sizes, consistent_transform, max_size=None, square=False, v2=False + ): + if isinstance(sizes, int): + sizes = (sizes,) + assert isinstance(sizes, Iterable) + self.sizes = list(sizes) + self.max_size = max_size + self.square = square + self.consistent_transform = consistent_transform + self.v2 = v2 + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + size = random.choice(self.sizes) + for i in range(len(datapoint.images)): + datapoint = resize( + datapoint, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return datapoint + for i in range(len(datapoint.images)): + size = random.choice(self.sizes) + datapoint = resize( + datapoint, i, size, self.max_size, square=self.square, v2=self.v2 + ) + return datapoint + + +class ScheduledRandomResizeAPI(RandomResizeAPI): + def __init__(self, size_scheduler, consistent_transform, square=False): + self.size_scheduler = size_scheduler + # Just a meaningful init value for super + params = self.size_scheduler(epoch_num=0) + sizes, max_size = params["sizes"], params["max_size"] + super().__init__(sizes, consistent_transform, max_size=max_size, square=square) + + def __call__(self, datapoint, **kwargs): + assert "epoch" in kwargs, "Param scheduler needs to know the current epoch" + params = self.size_scheduler(kwargs["epoch"]) + sizes, max_size = params["sizes"], params["max_size"] + self.sizes = sizes + self.max_size = max_size + datapoint = super(ScheduledRandomResizeAPI, self).__call__(datapoint, **kwargs) + return datapoint + + +class RandomPadAPI: + def __init__(self, max_pad, consistent_transform): + self.max_pad = max_pad + self.consistent_transform = consistent_transform + + def _sample_pad(self): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad_x, pad_y + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + pad_x, pad_y = self._sample_pad() + for i in range(len(datapoint.images)): + datapoint = pad(datapoint, i, (pad_x, pad_y)) + return datapoint + + for i in range(len(datapoint.images)): + pad_x, pad_y = self._sample_pad() + datapoint = pad(datapoint, i, (pad_x, pad_y)) + return datapoint + + +class PadToSizeAPI: + def __init__(self, size, consistent_transform, bottom_right=False, v2=False): + self.size = size + self.consistent_transform = consistent_transform + self.v2 = v2 + self.bottom_right = bottom_right + + def _sample_pad(self, w, h): + pad_x = self.size - w + pad_y = self.size - h + assert pad_x >= 0 and pad_y >= 0 + pad_left = random.randint(0, pad_x) + pad_right = pad_x - pad_left + pad_top = random.randint(0, pad_y) + pad_bottom = pad_y - pad_top + return pad_left, pad_top, pad_right, pad_bottom + + def __call__(self, datapoint, **kwargs): + if self.consistent_transform: + # Check that all the images are the same size + w, h = datapoint.images[0].data.size + for img in datapoint.images: + assert img.size == (w, h) + if self.bottom_right: + pad_right = self.size - w + pad_bottom = self.size - h + padding = (pad_right, pad_bottom) + else: + padding = self._sample_pad(w, h) + for i in range(len(datapoint.images)): + datapoint = pad(datapoint, i, padding, v2=self.v2) + return datapoint + + for i, img in enumerate(datapoint.images): + w, h = img.data.size + if self.bottom_right: + pad_right = self.size - w + pad_bottom = self.size - h + padding = (pad_right, pad_bottom) + else: + padding = self._sample_pad(w, h) + datapoint = pad(datapoint, i, padding, v2=self.v2) + return datapoint + + +class RandomMosaicVideoAPI: + def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False): + self.prob = prob + self.grid_h = grid_h + self.grid_w = grid_w + self.use_random_hflip = use_random_hflip + + def __call__(self, datapoint, **kwargs): + if random.random() > self.prob: + return datapoint + + # select a random location to place the target mask in the mosaic + target_grid_y = random.randint(0, self.grid_h - 1) + target_grid_x = random.randint(0, self.grid_w - 1) + # whether to flip each grid in the mosaic horizontally + if self.use_random_hflip: + should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5 + else: + should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool) + for i in range(len(datapoint.images)): + datapoint = random_mosaic_frame( + datapoint, + i, + grid_h=self.grid_h, + grid_w=self.grid_w, + target_grid_y=target_grid_y, + target_grid_x=target_grid_x, + should_hflip=should_hflip, + ) + + return datapoint + + +def random_mosaic_frame( + datapoint, + index, + grid_h, + grid_w, + target_grid_y, + target_grid_x, + should_hflip, +): + # Step 1: downsize the images and paste them into a mosaic + image_data = datapoint.images[index].data + is_pil = isinstance(image_data, PILImage.Image) + if is_pil: + H_im = image_data.height + W_im = image_data.width + image_data_output = PILImage.new("RGB", (W_im, H_im)) + else: + H_im = image_data.size(-2) + W_im = image_data.size(-1) + image_data_output = torch.zeros_like(image_data) + + downsize_cache = {} + for grid_y in range(grid_h): + for grid_x in range(grid_w): + y_offset_b = grid_y * H_im // grid_h + x_offset_b = grid_x * W_im // grid_w + y_offset_e = (grid_y + 1) * H_im // grid_h + x_offset_e = (grid_x + 1) * W_im // grid_w + H_im_downsize = y_offset_e - y_offset_b + W_im_downsize = x_offset_e - x_offset_b + + if (H_im_downsize, W_im_downsize) in downsize_cache: + image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)] + else: + image_data_downsize = F.resize( + image_data, + size=(H_im_downsize, W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + ) + downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize + if should_hflip[grid_y, grid_x].item(): + image_data_downsize = F.hflip(image_data_downsize) + + if is_pil: + image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b)) + else: + image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = ( + image_data_downsize + ) + + datapoint.images[index].data = image_data_output + + # Step 2: downsize the masks and paste them into the target grid of the mosaic + # (note that we don't scale input/target boxes since they are not used in TA) + for obj in datapoint.images[index].objects: + if obj.segment is None: + continue + assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8 + segment_output = torch.zeros_like(obj.segment) + + target_y_offset_b = target_grid_y * H_im // grid_h + target_x_offset_b = target_grid_x * W_im // grid_w + target_y_offset_e = (target_grid_y + 1) * H_im // grid_h + target_x_offset_e = (target_grid_x + 1) * W_im // grid_w + target_H_im_downsize = target_y_offset_e - target_y_offset_b + target_W_im_downsize = target_x_offset_e - target_x_offset_b + + segment_downsize = F.resize( + obj.segment[None, None], + size=(target_H_im_downsize, target_W_im_downsize), + interpolation=InterpolationMode.BILINEAR, + antialias=True, # antialiasing for downsizing + )[0, 0] + if should_hflip[target_grid_y, target_grid_x].item(): + segment_downsize = F.hflip(segment_downsize[None, None])[0, 0] + + segment_output[ + target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e + ] = segment_downsize + obj.segment = segment_output + + return datapoint + + +class ScheduledPadToSizeAPI(PadToSizeAPI): + def __init__(self, size_scheduler, consistent_transform): + self.size_scheduler = size_scheduler + size = self.size_scheduler(epoch_num=0)["sizes"] + super().__init__(size, consistent_transform) + + def __call__(self, datapoint, **kwargs): + assert "epoch" in kwargs, "Param scheduler needs to know the current epoch" + params = self.size_scheduler(kwargs["epoch"]) + self.size = params["resolution"] + return super(ScheduledPadToSizeAPI, self).__call__(datapoint, **kwargs) + + +class IdentityAPI: + def __call__(self, datapoint, **kwargs): + return datapoint + + +class RandomSelectAPI: + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + + def __init__(self, transforms1=None, transforms2=None, p=0.5): + self.transforms1 = transforms1 or IdentityAPI() + self.transforms2 = transforms2 or IdentityAPI() + self.p = p + + def __call__(self, datapoint, **kwargs): + if random.random() < self.p: + return self.transforms1(datapoint, **kwargs) + return self.transforms2(datapoint, **kwargs) + + +class ToTensorAPI: + def __init__(self, v2=False): + self.v2 = v2 + + def __call__(self, datapoint: Datapoint, **kwargs): + for img in datapoint.images: + if self.v2: + img.data = Fv2.to_image_tensor(img.data) + # img.data = Fv2.to_dtype(img.data, torch.uint8, scale=True) + # img.data = Fv2.convert_image_dtype(img.data, torch.uint8) + else: + img.data = F.to_tensor(img.data) + return datapoint + + +class NormalizeAPI: + def __init__(self, mean, std, v2=False): + self.mean = mean + self.std = std + self.v2 = v2 + + def __call__(self, datapoint: Datapoint, **kwargs): + for img in datapoint.images: + if self.v2: + img.data = Fv2.convert_image_dtype(img.data, torch.float32) + img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std) + else: + img.data = F.normalize(img.data, mean=self.mean, std=self.std) + for obj in img.objects: + boxes = obj.bbox + cur_h, cur_w = img.data.shape[-2:] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor( + [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32 + ) + obj.bbox = boxes + + for query in datapoint.find_queries: + if query.input_bbox is not None: + boxes = query.input_bbox + cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor( + [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32 + ) + query.input_bbox = boxes + if query.input_points is not None: + points = query.input_points + cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:] + points = points / torch.tensor([cur_w, cur_h, 1.0], dtype=torch.float32) + query.input_points = points + + return datapoint + + +class ComposeAPI: + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, datapoint, **kwargs): + for t in self.transforms: + datapoint = t(datapoint, **kwargs) + return datapoint + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class RandomGrayscale: + def __init__(self, consistent_transform, p=0.5): + self.p = p + self.consistent_transform = consistent_transform + self.Grayscale = T.Grayscale(num_output_channels=3) + + def __call__(self, datapoint: Datapoint, **kwargs): + if self.consistent_transform: + if random.random() < self.p: + for img in datapoint.images: + img.data = self.Grayscale(img.data) + return datapoint + for img in datapoint.images: + if random.random() < self.p: + img.data = self.Grayscale(img.data) + return datapoint + + +class ColorJitter: + def __init__(self, consistent_transform, brightness, contrast, saturation, hue): + self.consistent_transform = consistent_transform + self.brightness = ( + brightness + if isinstance(brightness, list) + else [max(0, 1 - brightness), 1 + brightness] + ) + self.contrast = ( + contrast + if isinstance(contrast, list) + else [max(0, 1 - contrast), 1 + contrast] + ) + self.saturation = ( + saturation + if isinstance(saturation, list) + else [max(0, 1 - saturation), 1 + saturation] + ) + self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue]) + + def __call__(self, datapoint: Datapoint, **kwargs): + if self.consistent_transform: + # Create a color jitter transformation params + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for img in datapoint.images: + if not self.consistent_transform: + ( + fn_idx, + brightness_factor, + contrast_factor, + saturation_factor, + hue_factor, + ) = T.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img.data = F.adjust_brightness(img.data, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img.data = F.adjust_contrast(img.data, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img.data = F.adjust_saturation(img.data, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img.data = F.adjust_hue(img.data, hue_factor) + return datapoint + + +class RandomAffine: + def __init__( + self, + degrees, + consistent_transform, + scale=None, + translate=None, + shear=None, + image_mean=(123, 116, 103), + log_warning=True, + num_tentatives=1, + image_interpolation="bicubic", + ): + """ + The mask is required for this transform. + if consistent_transform if True, then the same random affine is applied to all frames and masks. + """ + self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees]) + self.scale = scale + self.shear = ( + shear if isinstance(shear, list) else ([-shear, shear] if shear else None) + ) + self.translate = translate + self.fill_img = image_mean + self.consistent_transform = consistent_transform + self.log_warning = log_warning + self.num_tentatives = num_tentatives + + if image_interpolation == "bicubic": + self.image_interpolation = InterpolationMode.BICUBIC + elif image_interpolation == "bilinear": + self.image_interpolation = InterpolationMode.BILINEAR + else: + raise NotImplementedError + + def __call__(self, datapoint: Datapoint, **kwargs): + for _tentative in range(self.num_tentatives): + res = self.transform_datapoint(datapoint) + if res is not None: + return res + + if self.log_warning: + logging.warning( + f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives" + ) + return datapoint + + def transform_datapoint(self, datapoint: Datapoint): + _, height, width = F.get_dimensions(datapoint.images[0].data) + img_size = [width, height] + + if self.consistent_transform: + # Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + for img_idx, img in enumerate(datapoint.images): + this_masks = [ + obj.segment.unsqueeze(0) if obj.segment is not None else None + for obj in img.objects + ] + if not self.consistent_transform: + # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation + affine_params = T.RandomAffine.get_params( + degrees=self.degrees, + translate=self.translate, + scale_ranges=self.scale, + shears=self.shear, + img_size=img_size, + ) + + transformed_bboxes, transformed_masks = [], [] + for i in range(len(img.objects)): + if this_masks[i] is None: + transformed_masks.append(None) + # Dummy bbox for a dummy target + transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]])) + else: + transformed_mask = F.affine( + this_masks[i], + *affine_params, + interpolation=InterpolationMode.NEAREST, + fill=0.0, + ) + if img_idx == 0 and transformed_mask.max() == 0: + # We are dealing with a video and the object is not visible in the first frame + # Return the datapoint without transformation + return None + transformed_bbox = masks_to_boxes(transformed_mask) + transformed_bboxes.append(transformed_bbox) + transformed_masks.append(transformed_mask.squeeze()) + + for i in range(len(img.objects)): + img.objects[i].bbox = transformed_bboxes[i] + img.objects[i].segment = transformed_masks[i] + + img.data = F.affine( + img.data, + *affine_params, + interpolation=self.image_interpolation, + fill=self.fill_img, + ) + return datapoint + + +class RandomResizedCrop: + def __init__( + self, + consistent_transform, + size, + scale=None, + ratio=None, + log_warning=True, + num_tentatives=4, + keep_aspect_ratio=False, + ): + """ + The mask is required for this transform. + if consistent_transform if True, then the same random resized crop is applied to all frames and masks. + """ + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) + elif len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + else: + self.size = size + + self.scale = scale if scale is not None else (0.08, 1.0) + self.ratio = ratio if ratio is not None else (3.0 / 4.0, 4.0 / 3.0) + self.consistent_transform = consistent_transform + self.log_warning = log_warning + self.num_tentatives = num_tentatives + self.keep_aspect_ratio = keep_aspect_ratio + + def __call__(self, datapoint: Datapoint, **kwargs): + for _tentative in range(self.num_tentatives): + res = self.transform_datapoint(datapoint) + if res is not None: + return res + + if self.log_warning: + logging.warning( + f"Skip RandomResizeCrop for zero-area mask in first frame after {self.num_tentatives} tentatives" + ) + return datapoint + + def transform_datapoint(self, datapoint: Datapoint): + if self.keep_aspect_ratio: + original_size = datapoint.images[0].size + original_ratio = original_size[1] / original_size[0] + ratio = [r * original_ratio for r in self.ratio] + else: + ratio = self.ratio + + if self.consistent_transform: + # Create a random crop transformation + crop_params = T.RandomResizedCrop.get_params( + img=datapoint.images[0].data, + scale=self.scale, + ratio=ratio, + ) + + for img_idx, img in enumerate(datapoint.images): + if not self.consistent_transform: + # Create a random crop transformation + crop_params = T.RandomResizedCrop.get_params( + img=img.data, + scale=self.scale, + ratio=ratio, + ) + + this_masks = [ + obj.segment.unsqueeze(0) if obj.segment is not None else None + for obj in img.objects + ] + + transformed_bboxes, transformed_masks = [], [] + for i in range(len(img.objects)): + if this_masks[i] is None: + transformed_masks.append(None) + # Dummy bbox for a dummy target + transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]])) + else: + transformed_mask = F.resized_crop( + this_masks[i], + *crop_params, + size=self.size, + interpolation=InterpolationMode.NEAREST, + ) + if img_idx == 0 and transformed_mask.max() == 0: + # We are dealing with a video and the object is not visible in the first frame + # Return the datapoint without transformation + return None + transformed_masks.append(transformed_mask.squeeze()) + transformed_bbox = masks_to_boxes(transformed_mask) + transformed_bboxes.append(transformed_bbox) + + # Set the new boxes and masks if all transformed masks and boxes are good. + for i in range(len(img.objects)): + img.objects[i].bbox = transformed_bboxes[i] + img.objects[i].segment = transformed_masks[i] + + img.data = F.resized_crop( + img.data, + *crop_params, + size=self.size, + interpolation=InterpolationMode.BILINEAR, + ) + return datapoint + + +class ResizeToMaxIfAbove: + # Resize datapoint image if one of its sides is larger that max_size + def __init__( + self, + max_size=None, + ): + self.max_size = max_size + + def __call__(self, datapoint: Datapoint, **kwargs): + _, height, width = F.get_dimensions(datapoint.images[0].data) + + if height <= self.max_size and width <= self.max_size: + # The original frames are small enough + return datapoint + elif height >= width: + new_height = self.max_size + new_width = int(round(self.max_size * width / height)) + else: + new_height = int(round(self.max_size * height / width)) + new_width = self.max_size + + size = new_height, new_width + + for index in range(len(datapoint.images)): + datapoint.images[index].data = F.resize(datapoint.images[index].data, size) + + for obj in datapoint.images[index].objects: + obj.segment = F.resize( + obj.segment[None, None], + size, + interpolation=InterpolationMode.NEAREST, + ).squeeze() + + h, w = size + datapoint.images[index].size = (h, w) + return datapoint + + +def get_bbox_xyxy_abs_coords_from_mask(mask): + """Get the bounding box (XYXY format w/ absolute coordinates) of a binary mask.""" + assert mask.dim() == 2 + rows = torch.any(mask, dim=1) + cols = torch.any(mask, dim=0) + row_inds = rows.nonzero().view(-1) + col_inds = cols.nonzero().view(-1) + if row_inds.numel() == 0: + # mask is empty + bbox = torch.zeros(1, 4, dtype=torch.float32) + bbox_area = 0.0 + else: + ymin, ymax = row_inds.min(), row_inds.max() + xmin, xmax = col_inds.min(), col_inds.max() + bbox = torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32).view(1, 4) + bbox_area = float((ymax - ymin) * (xmax - xmin)) + return bbox, bbox_area + + +class MotionBlur: + def __init__(self, kernel_size=5, consistent_transform=True, p=0.5): + assert kernel_size % 2 == 1, "Kernel size must be odd." + self.kernel_size = kernel_size + self.consistent_transform = consistent_transform + self.p = p + + def __call__(self, datapoint: Datapoint, **kwargs): + if random.random() >= self.p: + return datapoint + if self.consistent_transform: + # Generate a single motion blur kernel for all images + kernel = self._generate_motion_blur_kernel() + for img in datapoint.images: + if not self.consistent_transform: + # Generate a new motion blur kernel for each image + kernel = self._generate_motion_blur_kernel() + img.data = self._apply_motion_blur(img.data, kernel) + + return datapoint + + def _generate_motion_blur_kernel(self): + kernel = torch.zeros((self.kernel_size, self.kernel_size)) + direction = random.choice(["horizontal", "vertical", "diagonal"]) + if direction == "horizontal": + kernel[self.kernel_size // 2, :] = 1.0 + elif direction == "vertical": + kernel[:, self.kernel_size // 2] = 1.0 + elif direction == "diagonal": + for i in range(self.kernel_size): + kernel[i, i] = 1.0 + kernel /= kernel.sum() + return kernel + + def _apply_motion_blur(self, image, kernel): + if isinstance(image, PILImage.Image): + image = F.to_tensor(image) + channels = image.shape[0] + kernel = kernel.to(image.device).unsqueeze(0).unsqueeze(0) + blurred_image = torch.nn.functional.conv2d( + image.unsqueeze(0), + kernel.repeat(channels, 1, 1, 1), + padding=self.kernel_size // 2, + groups=channels, + ) + return F.to_pil_image(blurred_image.squeeze(0)) + + +class LargeScaleJitter: + def __init__( + self, + scale_range=(0.1, 2.0), + aspect_ratio_range=(0.75, 1.33), + crop_size=(640, 640), + consistent_transform=True, + p=0.5, + ): + """ + Args:rack + scale_range (tuple): Range of scaling factors (min_scale, max_scale). + aspect_ratio_range (tuple): Range of aspect ratios (min_aspect_ratio, max_aspect_ratio). + crop_size (tuple): Target size of the cropped region (width, height). + consistent_transform (bool): Whether to apply the same transformation across all frames. + p (float): Probability of applying the transformation. + """ + self.scale_range = scale_range + self.aspect_ratio_range = aspect_ratio_range + self.crop_size = crop_size + self.consistent_transform = consistent_transform + self.p = p + + def __call__(self, datapoint: Datapoint, **kwargs): + if random.random() >= self.p: + return datapoint + + # Sample a single scale factor and aspect ratio for all frames + log_ratio = torch.log(torch.tensor(self.aspect_ratio_range)) + scale_factor = torch.empty(1).uniform_(*self.scale_range).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + for idx, img in enumerate(datapoint.images): + if not self.consistent_transform: + # Sample a new scale factor and aspect ratio for each frame + log_ratio = torch.log(torch.tensor(self.aspect_ratio_range)) + scale_factor = torch.empty(1).uniform_(*self.scale_range).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + # Compute the dimensions of the jittered crop + original_width, original_height = img.data.size + target_area = original_width * original_height * scale_factor + crop_width = int(round((target_area * aspect_ratio) ** 0.5)) + crop_height = int(round((target_area / aspect_ratio) ** 0.5)) + + # Randomly select the top-left corner of the crop + crop_x = random.randint(0, max(0, original_width - crop_width)) + crop_y = random.randint(0, max(0, original_height - crop_height)) + + # Extract the cropped region + datapoint = crop(datapoint, idx, (crop_x, crop_y, crop_width, crop_height)) + + # Resize the cropped region to the target crop size + datapoint = resize(datapoint, idx, self.crop_size) + + return datapoint diff --git a/third_party/sam3/sam3/train/transforms/filter_query_transforms.py b/third_party/sam3/sam3/train/transforms/filter_query_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..de66c93a54bad0d2d7905ec18ebe02a036d297e6 --- /dev/null +++ b/third_party/sam3/sam3/train/transforms/filter_query_transforms.py @@ -0,0 +1,607 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging +import random +from collections import defaultdict +from typing import List, Optional, Union + +import torch +from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object + + +class FilterDataPointQueries: + find_ids_to_filter: set = None + get_ids_to_filter: set = None + obj_ids_to_filter: set = None # stored as pairs (img_id, obj_id) + + def identify_queries_to_filter(self, datapoint: Datapoint) -> None: + """ + Compute set of query ids to keep, for both find and get queries + """ + raise NotImplementedError + + def _do_filter_query(self, query: Union[FindQuery], query_id: int): + assert self.find_ids_to_filter is not None + + return query_id in self.find_ids_to_filter + + +class FilterQueryWithText(FilterDataPointQueries): + """ + Filter all datapoints which have query text in a specified list of exluded terms + """ + + def __init__( + self, exclude_find_keys: List[str] = None, exclude_get_keys: List[str] = None + ): + self.find_filter_keys = exclude_find_keys if exclude_find_keys else [] + self.get_filter_keys = exclude_get_keys if exclude_get_keys else [] + + def identify_queries_to_filter(self, datapoint): + self.obj_ids_to_filter = set() + del_find_ids = [] + del_get_ids = [] + for i, f_q in enumerate(datapoint.find_queries): + if f_q.query_text in self.find_filter_keys: + del_find_ids.append(i) + + self.find_ids_to_filter = set(del_find_ids) + + +class KeepMaxNumFindQueries(FilterDataPointQueries): + def __init__( + self, max_num_find_queries: int, retain_positive_queries: bool = False + ): + self.max_num_find_queries = max_num_find_queries + self.retain_positive_queries = retain_positive_queries + + def identify_queries_to_filter(self, datapoint: Datapoint) -> None: + self.obj_ids_to_filter = set() + num_find_queries = len(datapoint.find_queries) + if num_find_queries <= self.max_num_find_queries: + self.find_ids_to_filter = set() # keep all find queries + return + + if not self.retain_positive_queries: + all_find_query_ids = list(range(num_find_queries)) + num_queries_to_filter = max(0, num_find_queries - self.max_num_find_queries) + query_ids_to_filter = random.sample( + all_find_query_ids, k=num_queries_to_filter + ) + else: + # keep up to max_num_find_queries postive find queries and fill + # the remaining slots (if any) with negative find queries + pos_find_ids, neg_find_ids = [], [] + for i, f_q in enumerate(datapoint.find_queries): + # Negative finds return an empty list of object_ids_output + if len(f_q.object_ids_output) == 0: + neg_find_ids.append(i) + else: + pos_find_ids.append(i) + + if len(pos_find_ids) >= self.max_num_find_queries: + # we have more positive find queries than `max_num_find_queries`, + # so we subsample postive find queries and remove all negative find queries + num_queries_to_filter = len(pos_find_ids) - self.max_num_find_queries + query_ids_to_filter = random.sample( + pos_find_ids, k=num_queries_to_filter + ) + query_ids_to_filter.extend(neg_find_ids) + else: + # we have fewer positive find queries than `max_num_find_queries` + # so we need to fill the remaining with negative find queries + num_queries_to_filter = num_find_queries - self.max_num_find_queries + query_ids_to_filter = random.sample( + neg_find_ids, k=num_queries_to_filter + ) + + assert len(query_ids_to_filter) == num_find_queries - self.max_num_find_queries + self.find_ids_to_filter = set(query_ids_to_filter) + + +class KeepMaxNumFindQueriesVideo(FilterDataPointQueries): + def __init__( + self, + video_mosaic_max_num_find_queries_per_frame: int, + retain_positive_queries: bool = False, + ): + self.video_mosaic_max_num_find_queries_per_frame = ( + video_mosaic_max_num_find_queries_per_frame + ) + self.retain_positive_queries = retain_positive_queries + + def identify_queries_to_filter(self, datapoint: Datapoint) -> None: + self.obj_ids_to_filter = set() + num_find_queries = len(datapoint.find_queries) + + findQueries_to_imageIds = defaultdict(list) + max_queries_per_frame = True + for i, f_q in enumerate(datapoint.find_queries): + findQueries_to_imageIds[f_q.image_id].append(i) + if ( + len(findQueries_to_imageIds[f_q.image_id]) + > self.video_mosaic_max_num_find_queries_per_frame + ): + max_queries_per_frame = False + + if max_queries_per_frame: + self.find_ids_to_filter = set() + return + + num_frames = len(findQueries_to_imageIds) + findQueries_0 = findQueries_to_imageIds[0] + num_find_queries_0 = len(findQueries_0) + max_num_find_queries_per_frame = ( + self.video_mosaic_max_num_find_queries_per_frame + ) + if not self.retain_positive_queries: + find_query_ids_0 = list(range(num_find_queries_0)) + num_queries_to_filter = max( + 0, num_find_queries_0 - max_num_find_queries_per_frame + ) + query_ids_to_filter_0 = random.sample( + find_query_ids_0, k=num_queries_to_filter + ) + else: + # keep up to max_num_find_queries postive find queries and fill + # the remaining slots (if any) with negative find queries + pos_find_ids_0, neg_find_ids_0 = [], [] + for i, f_q_id in enumerate(findQueries_0): + f_q = datapoint.find_queries[f_q_id] + # Negative finds return an empty list of object_ids_output + if len(f_q.object_ids_output) == 0: + neg_find_ids_0.append(i) + else: + pos_find_ids_0.append(i) + + if len(pos_find_ids_0) >= max_num_find_queries_per_frame: + # we have more positive find queries than `max_num_find_queries`, + # so we subsample postive find queries and remove all negative find queries + num_queries_to_filter = ( + len(pos_find_ids_0) - max_num_find_queries_per_frame + ) + query_ids_to_filter_0 = random.sample( + pos_find_ids_0, k=num_queries_to_filter + ) + query_ids_to_filter_0.extend(neg_find_ids_0) + else: + # we have fewer positive find queries than `max_num_find_queries` + # so we need to fill the remaining with negative find queries + num_queries_to_filter = ( + num_find_queries_0 - max_num_find_queries_per_frame + ) + query_ids_to_filter_0 = random.sample( + neg_find_ids_0, k=num_queries_to_filter + ) + + # get based on frame 0 all find queries from all the frames with the same indices as in frame 0 + query_ids_to_filter = [] + for i in range(num_frames): + findQueries_i = findQueries_to_imageIds[i] + query_ids_to_filter.extend( + [findQueries_i[j] for j in query_ids_to_filter_0] + ) + + assert ( + len(query_ids_to_filter) + == num_find_queries + - self.video_mosaic_max_num_find_queries_per_frame * num_frames + ) + self.find_ids_to_filter = set(query_ids_to_filter) + + +class KeepSemanticFindQueriesOnly(FilterDataPointQueries): + def identify_queries_to_filter(self, datapoint: Datapoint) -> None: + self.obj_ids_to_filter = set() + self.find_ids_to_filter = { + i for i, q in enumerate(datapoint.find_queries) if q.input_bbox is not None + } # filter (remove) geometric find queries (whose input_bbox is not None) + + # Keep all get queries which don't depend on filtered finds + + +class KeepUnaryFindQueriesOnly(FilterDataPointQueries): + def identify_queries_to_filter(self, datapoint: Datapoint) -> None: + self.obj_ids_to_filter = set() + self.find_ids_to_filter = set() + + # Keep all get queries which don't depend on filtered finds + + +class FilterZeroBoxQueries(FilterDataPointQueries): + """ + Filters all find queries which predict a box with zero area + """ + + @staticmethod + def _is_zero_area_object(obj: Object): + # Check if height or width of bounding box is zero + bbox = obj.bbox # Assume in XYXY format + height = bbox[..., 3].item() - bbox[..., 1].item() + width = bbox[..., 2].item() - bbox[..., 0].item() + + return height == 0 or width == 0 + + def identify_queries_to_filter(self, datapoint): + self.obj_ids_to_filter = set() + + # Find objects with zero area + # Assume only one image per datapoint + image_objects = datapoint.images[0].objects + exclude_objects = { + obj_id + for obj_id, obj in enumerate(image_objects) + if self._is_zero_area_object(obj) + } + + # If a query predicts an object with zero area, drop the whole find query + del_find_ids = [] + for i, f_q in enumerate(datapoint.find_queries): + f_q_objects = set(f_q.object_ids_output) + if len(exclude_objects.intersection(f_q_objects)) > 0: + del_find_ids.append(i) + + self.find_ids_to_filter = set(del_find_ids) + + +class FilterFindQueriesWithTooManyOut(FilterDataPointQueries): + """ + Filters all find queries which have more than a specified number of objects in the output + """ + + def __init__(self, max_num_objects: int): + self.max_num_objects = max_num_objects + + def identify_queries_to_filter(self, datapoint): + self.obj_ids_to_filter = set() + + # If a query predicts more than max_num_objects, drop the whole find query + del_find_ids = [] + for i, f_q in enumerate(datapoint.find_queries): + if len(f_q.object_ids_output) > self.max_num_objects: + del_find_ids.append(i) + + self.find_ids_to_filter = set(del_find_ids) + + +class FilterEmptyTargets(FilterDataPointQueries): + """ + Filters all targets which have zero area + """ + + def identify_queries_to_filter(self, datapoint): + self.obj_ids_to_filter = set() + + for img_id in range(len(datapoint.images)): + for obj_id, obj in enumerate(datapoint.images[img_id].objects): + if obj.area < 1e-6: + self.obj_ids_to_filter.add((img_id, obj_id)) + self.find_ids_to_filter = set() + + +class FilterNonExhaustiveFindQueries(FilterDataPointQueries): + """ + Filters all find queries which are non-exhaustive + """ + + def __init__(self, exhaustivity_type: str): + """ + Args: + exhaustivity_type: Can be "pixel" or "instance": + -pixel: filter queries where the union of all segments covers every pixel belonging to target class + -instance: filter queries where there are non-separable or non annotated instances + Note that instance exhaustivity implies pixel exhaustivity + """ + assert exhaustivity_type in ["pixel", "instance"] + self.exhaustivity_type = exhaustivity_type + + def identify_queries_to_filter(self, datapoint): + self.obj_ids_to_filter = set() + + # If a query predicts more than max_num_objects, drop the whole find query + del_find_ids = [] + for i, f_q in enumerate(datapoint.find_queries): + if self.exhaustivity_type == "instance": + if not f_q.is_exhaustive: + del_find_ids.append(i) + elif self.exhaustivity_type == "pixel": + if f_q.is_pixel_exhaustive is not None and not f_q.is_pixel_exhaustive: + del_find_ids.append(i) + else: + raise RuntimeError( + f"Unknown exhaustivity type {self.exhaustivity_type}" + ) + + self.find_ids_to_filter = set(del_find_ids) + + +class FilterInvalidGeometricQueries(FilterDataPointQueries): + """ + Filters geometric queries whose output got deleted (eg due to cropping) + """ + + def identify_queries_to_filter(self, datapoint): + self.obj_ids_to_filter = set() + + # If a query predicts more than max_num_objects, drop the whole find query + del_find_ids = [] + for i, f_q in enumerate(datapoint.find_queries): + if f_q.input_bbox is not None and f_q.query_text == "geometric": + if len(f_q.object_ids_output) == 0: + del_find_ids.append(i) + self.find_ids_to_filter = set(del_find_ids) + + +class FlexibleFilterFindGetQueries: + def __init__( + self, query_filter: FilterDataPointQueries, enabled: bool = True + ) -> None: + self.query_filter = query_filter + self.enabled = enabled + + def __call__(self, datapoint, **kwargs): + if not self.enabled: + return datapoint + + # Identify all queries to filter + self.query_filter.identify_queries_to_filter(datapoint=datapoint) + + del_find_ids = [] + del_get_ids = [] + for i, f_q in enumerate(datapoint.find_queries): + if self.query_filter._do_filter_query(f_q, i): + datapoint.find_queries[i] = None + del_find_ids.append(i) + + new_find_queries = [] + new_get_queries = [] + + find_old_to_new_map = {} + get_old_to_new_map = {} + + find_counter = 0 + get_counter = 0 + + for i, f_q in enumerate(datapoint.find_queries): + if f_q is not None: + find_old_to_new_map[i] = find_counter + find_counter += 1 + new_find_queries.append(f_q) + + start_with_zero_check = False + for n_f_q in new_find_queries: + if n_f_q.query_processing_order == 0: + start_with_zero_check = True + break + + if len(new_find_queries) == 0: + start_with_zero_check = True + + assert ( + start_with_zero_check + ), "Invalid Find queries, they need to start at query_processing_order = 0" + + datapoint.find_queries = new_find_queries + + if len(datapoint.find_queries) == 0: + print("Warning: No find queries left in datapoint, this is not allowed") + print("Filtering function:", self.query_filter) + print("Datapoint:", datapoint) + raise ValueError + + # The deletion may have removed intermediate steps, so we need to remap to make them contiguous again + all_stages = sorted( + list(set(q.query_processing_order for q in datapoint.find_queries)) + ) + stage_map = {qpo: i for i, qpo in enumerate(all_stages)} + for i in range(len(datapoint.find_queries)): + qpo = datapoint.find_queries[i].query_processing_order + datapoint.find_queries[i].query_processing_order = stage_map[qpo] + + # Final step, clear up objects that are not used anymore + for img_id in range(len(datapoint.images)): + all_objects_ids = set( + i + for find in datapoint.find_queries + for i in find.object_ids_output + if find.image_id == img_id + ) + unused_ids = ( + set(range(len(datapoint.images[img_id].objects))) - all_objects_ids + ) + for tgt_img_id, tgt_obj_id in self.query_filter.obj_ids_to_filter: + if tgt_img_id == img_id: + unused_ids.add(tgt_obj_id) + + if len(unused_ids) > 0: + old_objects = datapoint.images[img_id].objects + object_old_to_new_map = {} + new_objects = [] + for i, o in enumerate(old_objects): + if i not in unused_ids: + object_old_to_new_map[i] = len(new_objects) + new_objects.append(o) + + datapoint.images[img_id].objects = new_objects + + # Remap the outputs of the find queries + affected_find_queries_ids = set() + object_old_to_new_map_per_query = {} + for fid, find in enumerate(datapoint.find_queries): + if find.image_id == img_id: + old_object_ids_output = find.object_ids_output + object_old_to_new_map_per_query[fid] = {} + find.object_ids_output = [] + for oid, old_obj_id in enumerate(old_object_ids_output): + if old_obj_id not in unused_ids: + new_obj_id = object_old_to_new_map[old_obj_id] + find.object_ids_output.append(new_obj_id) + object_old_to_new_map_per_query[fid][oid] = ( + len(find.object_ids_output) - 1 + ) + affected_find_queries_ids.add(fid) + + # finally remove unused images + all_imgs_to_keep = set() + for f_q in datapoint.find_queries: + all_imgs_to_keep.add(f_q.image_id) + + old_img_id_to_new_img_id = {} + new_images = [] + for img_id, img in enumerate(datapoint.images): + if img_id in all_imgs_to_keep: + old_img_id_to_new_img_id[img_id] = len(new_images) + new_images.append(img) + datapoint.images = new_images + + for f_q in datapoint.find_queries: + f_q.image_id = old_img_id_to_new_img_id[f_q.image_id] + + return datapoint + + +class AddPrefixSuffixToFindText: + """ + Add prefix or suffix strings to find query text on the fly. + + If `condition_on_text` is True, the prefix or suffix strings are only added + to those find query text in `condition_text_list` (case-insensitive). + """ + + def __init__( + self, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + condition_on_text: bool = False, + condition_text_list: Optional[List[str]] = None, + enabled: bool = True, + ) -> None: + self.prefix = prefix + self.suffix = suffix + self.condition_on_text = condition_on_text + if self.condition_on_text: + assert condition_text_list is not None + self.condition_text_set = {s.lower().strip() for s in condition_text_list} + self.enabled = enabled + if self.enabled: + logging.info( + f"AddPrefixSuffixToFindText: prefix={prefix}, suffix={suffix}, " + f"condition_on_text={condition_on_text}, condition_text_list={condition_text_list}" + ) + + def __call__(self, datapoint, **kwargs): + if not self.enabled: + return datapoint + + for find in datapoint.find_queries: + if find.query_text == "geometric": + # skip geometric find queries + continue + if ( + self.condition_on_text + and find.query_text.lower().strip() not in self.condition_text_set + ): + # if condition_on_text is True, skip those queries not in condition_text_set + continue + + # add prefix and/or suffix strings to the find query text + if self.prefix is not None: + find.query_text = self.prefix + find.query_text + if self.suffix is not None: + find.query_text = find.query_text + self.suffix + + return datapoint + + +class FilterCrowds(FilterDataPointQueries): + def identify_queries_to_filter(self, datapoint: Datapoint) -> None: + """ + Compute set of query ids to keep, for both find and get queries + """ + self.obj_ids_to_filter = set() + self.find_ids_to_filter = set() + # self.get_ids_to_filter = set() + for img_id, img in enumerate(datapoint.images): + for obj_id, obj in enumerate(img.objects): + if obj.is_crowd: + self.obj_ids_to_filter.add((img_id, obj_id)) + + +class TextQueryToVisual: + """ + Transform a test query to a visual query (with some proba), using any of the output targets as the prompt + """ + + def __init__(self, probability, keep_text_queries=False) -> None: + self.probability = probability + assert 0 <= probability <= 1 + self.keep_text_queries = keep_text_queries + + def __call__(self, datapoint: Datapoint, **kwargs): + for find in datapoint.find_queries: + if find.input_bbox is not None or find.input_points is not None: + # skip geometric find queries + continue + + if len(find.object_ids_output) == 0: + # Can't create a visual query, skip + continue + + if find.query_processing_order > 0: + # Second stage query, can't use + continue + + if random.random() > self.probability: + continue + + selected_vq_id = random.choice(find.object_ids_output) + img_id = find.image_id + + find.input_bbox = datapoint.images[img_id].objects[selected_vq_id].bbox + find.input_bbox_label = torch.ones(1, dtype=torch.bool) + if not self.keep_text_queries: + find.query_text = "visual" + + return datapoint + + +class RemoveInputBoxes: + """ + Remove input boxes from find queries + """ + + def __init__(self) -> None: + pass + + def __call__(self, datapoint: Datapoint, **kwargs): + for find in datapoint.find_queries: + if find.input_bbox is None: + continue + + if find.query_text == "geometric": + print("Warning: removing input box from geometric find query") + + find.input_bbox = None + return datapoint + + +class OverwriteTextQuery: + """ + With some probability, overwrite the text query with a custom text + """ + + def __init__(self, target_text, probability=1.0) -> None: + self.probability = probability + self.target_text = target_text + assert 0 <= probability <= 1 + + def __call__(self, datapoint: Datapoint, **kwargs): + for find in datapoint.find_queries: + if random.random() > self.probability: + continue + + find.query_text = self.target_text + + return datapoint diff --git a/third_party/sam3/sam3/train/transforms/point_sampling.py b/third_party/sam3/sam3/train/transforms/point_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d27c6b81525fa99874da8bd5c185b675560cc307 --- /dev/null +++ b/third_party/sam3/sam3/train/transforms/point_sampling.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import cv2 +import numpy as np +import torch +from PIL import Image as PILImage +from pycocotools import mask as mask_util +from sam3.train.data.sam3_image_dataset import Datapoint +from torchvision.ops import masks_to_boxes + + +def sample_points_from_rle(rle, n_points, mode, box=None, normalize=True): + """ + Sample random points from a mask provided in COCO RLE format. 'mode' + 'mode' is in ["centered", "random_mask", "random_box"] + "centered": points are sampled farthest from the mask edges and each other + "random_mask": points are sampled uniformly from the mask + "random_box": points are sampled uniformly from the annotation's box + 'box' must be provided if 'mode' is "random_box". + If 'normalize' is true, points are in [0,1], relative to mask h,w. + """ + mask = np.ascontiguousarray(mask_util.decode(rle)) + points = sample_points_from_mask(mask, n_points, mode, box) + + if normalize: + h, w = mask.shape + norm = np.array([w, h, 1.0])[None, :] + points = points / norm + + return points + + +def sample_points_from_mask(mask, n_points, mode, box=None): + if mode == "centered": + points = center_positive_sample(mask, n_points) + elif mode == "random_mask": + points = uniform_positive_sample(mask, n_points) + elif mode == "random_box": + assert box is not None, "'random_box' mode requires a provided box." + points = uniform_sample_from_box(mask, box, n_points) + else: + raise ValueError(f"Unknown point sampling mode {mode}.") + return points + + +def uniform_positive_sample(mask, n_points): + """ + Samples positive points uniformly from the mask. Only integer pixel + values are sampled. + """ + # Sampling directly from the uncompressed RLE would be faster but is + # likely unnecessary. + mask_points = np.stack(np.nonzero(mask), axis=0).transpose(1, 0) + assert len(mask_points) > 0, "Can't sample positive points from an empty mask." + selected_idxs = np.random.randint(low=0, high=len(mask_points), size=n_points) + selected_points = mask_points[selected_idxs] + + selected_points = selected_points[:, ::-1] # (y, x) -> (x, y) + labels = np.ones((len(selected_points), 1)) + selected_points = np.concatenate([selected_points, labels], axis=1) + + return selected_points + + +def center_positive_sample(mask, n_points): + """ + Samples points farthest from mask edges (by distance transform) + and subsequent points also farthest from each other. Each new point + sampled is treated as an edge for future points. Edges of the image are + treated as edges of the mask. + """ + + # Pad mask by one pixel on each end to assure distance transform + # avoids edges + padded_mask = np.pad(mask, 1) + + points = [] + for _ in range(n_points): + assert np.max(mask) > 0, "Can't sample positive points from an empty mask." + dist = cv2.distanceTransform(padded_mask, cv2.DIST_L2, 0) + point = np.unravel_index(dist.argmax(), dist.shape) + # Mark selected point as background so next point avoids it + padded_mask[point[0], point[1]] = 0 + points.append(point[::-1]) # (y, x) -> (x, y) + + points = np.stack(points, axis=0) + points = points - 1 # Subtract left/top padding of 1 + labels = np.ones((len(points), 1)) + points = np.concatenate([points, labels], axis=1) + + return points + + +def uniform_sample_from_box(mask, box, n_points): + """ + Sample points uniformly from the provided box. The points' labels + are determined by the provided mask. Does not guarantee a positive + point is sampled. The box is assumed unnormalized in XYXY format. + Points are sampled at integer values. + """ + + # Since lower/right edges are exclusive, ceil can be applied to all edges + int_box = np.ceil(box) + + x = np.random.randint(low=int_box[0], high=int_box[2], size=n_points) + y = np.random.randint(low=int_box[1], high=int_box[3], size=n_points) + labels = mask[y, x] + points = np.stack([x, y, labels], axis=1) + + return points + + +def rescale_box_xyxy(box, factor, imsize=None): + """ + Rescale a box providing in unnormalized XYXY format, fixing the center. + If imsize is provided, clamp to the image. + """ + cx, cy = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2 + w, h = box[2] - box[0], box[3] - box[1] + + new_w, new_h = factor * w, factor * h + + new_x0, new_y0 = cx - new_w / 2, cy - new_h / 2 + new_x1, new_y1 = cx + new_w / 2, cy + new_h / 2 + + if imsize is not None: + new_x0 = max(min(new_x0, imsize[1]), 0) + new_x1 = max(min(new_x1, imsize[1]), 0) + new_y0 = max(min(new_y0, imsize[0]), 0) + new_y1 = max(min(new_y1, imsize[0]), 0) + + return [new_x0, new_y0, new_x1, new_y1] + + +def noise_box(box, im_size, box_noise_std, box_noise_max, min_box_area): + if box_noise_std <= 0.0: + return box + noise = box_noise_std * torch.randn(size=(4,)) + w, h = box[2] - box[0], box[3] - box[1] + scale_factor = torch.tensor([w, h, w, h]) + noise = noise * scale_factor + if box_noise_max is not None: + noise = torch.clamp(noise, -box_noise_max, box_noise_max) + input_box = box + noise + # Clamp to maximum image size + img_clamp = torch.tensor([im_size[1], im_size[0], im_size[1], im_size[0]]) + input_box = torch.maximum(input_box, torch.zeros_like(input_box)) + input_box = torch.minimum(input_box, img_clamp) + if (input_box[2] - input_box[0]) * (input_box[3] - input_box[1]) <= min_box_area: + return box + + return input_box + + +class RandomGeometricInputsAPI: + """ + For geometric queries, replaces the input box or points with a random + one sampled from the GT mask. Segments must be provided for objects + that are targets of geometric queries, and must be binary masks. Existing + point and box queries in the datapoint will be ignored and completely replaced. + Will sample points and boxes in XYXY format in absolute pixel space. + + Geometry queries are currently determined by taking any query whose + query text is a set value. + + Args: + num_points (int or (int, int)): how many points to sample. If a tuple, + sample a random number of points uniformly over the inclusive range. + box_chance (float): fraction of time a box is sampled. A box will replace + one sampled point. + box_noise_std (float): if greater than 0, add noise to the sampled boxes + with this std. Noise is relative to the length of the box side. + box_noise_max (int): if not none, truncate any box noise larger than this + in terms of absolute pixels. + resample_box_from_mask (bool): if True, any sampled box will be determined + by finding the extrema of the provided mask. If False, the bbox provided + in the target object will be used. + point_sample_mode (str): In ["centered", "random_mask", "random_box"], + controlling how points are sampled: + "centered": points are sampled farthest from the mask edges and each other + "random_mask": points are sampled uniformly from the mask + "random_box": points are sampled uniformly from the annotation's box + Note that "centered" may be too slow for on-line generation. + geometric_query_str (str): what string in query_text indicates a + geometry query. + minimum_box_area (float): sampled boxes with area this size or smaller after + noising will use the original box instead. It is the input's responsibility + to avoid original boxes that violate necessary area bounds. + concat_points (bool): if True, any sampled points will be added to existing + ones instead of replacing them. + + """ + + def __init__( + self, + num_points, + box_chance, + box_noise_std=0.0, + box_noise_max=None, + minimum_box_area=0.0, + resample_box_from_mask=False, + point_sample_mode="random_mask", + sample_box_scale_factor=1.0, + geometric_query_str="geometric", + concat_points=False, + ): + self.num_points = num_points + if not isinstance(self.num_points, int): + # Convert from inclusive range to exclusive range expected by torch + self.num_points[1] += 1 + self.num_points = tuple(self.num_points) + self.box_chance = box_chance + self.box_noise_std = box_noise_std + self.box_noise_max = box_noise_max + self.minimum_box_area = minimum_box_area + self.resample_box_from_mask = resample_box_from_mask + self.point_sample_mode = point_sample_mode + assert point_sample_mode in [ + "centered", + "random_mask", + "random_box", + ], "Unknown point sample mode." + self.geometric_query_str = geometric_query_str + self.concat_points = concat_points + self.sample_box_scale_factor = sample_box_scale_factor + + def _sample_num_points_and_if_box(self): + if isinstance(self.num_points, tuple): + n_points = torch.randint( + low=self.num_points[0], high=self.num_points[1], size=(1,) + ).item() + else: + n_points = self.num_points + if self.box_chance > 0.0: + use_box = torch.rand(size=(1,)).item() < self.box_chance + n_points -= int(use_box) # box stands in for one point + else: + use_box = False + return n_points, use_box + + def _get_original_box(self, target_object): + if not self.resample_box_from_mask: + return target_object.bbox + mask = target_object.segment + return masks_to_boxes(mask[None, :, :])[0] + + def _get_target_object(self, datapoint, query): + img = datapoint.images[query.image_id] + targets = query.object_ids_output + assert ( + len(targets) == 1 + ), "Geometric queries only support a single target object." + target_idx = targets[0] + return img.objects[target_idx] + + def __call__(self, datapoint, **kwargs): + for query in datapoint.find_queries: + if query.query_text != self.geometric_query_str: + continue + + target_object = self._get_target_object(datapoint, query) + n_points, use_box = self._sample_num_points_and_if_box() + box = self._get_original_box(target_object) + + mask = target_object.segment + if n_points > 0: + # FIXME: The conversion to numpy and back to reuse code + # is awkward, but this is all in the dataloader worker anyway + # on CPU and so I don't think it should matter. + if self.sample_box_scale_factor != 1.0: + sample_box = rescale_box_xyxy( + box.numpy(), self.sample_box_scale_factor, mask.shape + ) + else: + sample_box = box.numpy() + input_points = sample_points_from_mask( + mask.numpy(), + n_points, + self.point_sample_mode, + sample_box, + ) + input_points = torch.as_tensor(input_points) + input_points = input_points[None, :, :] + if self.concat_points and query.input_points is not None: + input_points = torch.cat([query.input_points, input_points], dim=1) + else: + input_points = query.input_points if self.concat_points else None + + if use_box: + w, h = datapoint.images[query.image_id].size + input_box = noise_box( + box, + (h, w), + box_noise_std=self.box_noise_std, + box_noise_max=self.box_noise_max, + min_box_area=self.minimum_box_area, + ) + input_box = input_box[None, :] + else: + input_box = query.input_bbox if self.concat_points else None + + query.input_points = input_points + query.input_bbox = input_box + + return datapoint + + +class RandomizeInputBbox: + """ + Simplified version of the geometric transform that only deals with input boxes + """ + + def __init__( + self, + box_noise_std=0.0, + box_noise_max=None, + minimum_box_area=0.0, + ): + self.box_noise_std = box_noise_std + self.box_noise_max = box_noise_max + self.minimum_box_area = minimum_box_area + + def __call__(self, datapoint: Datapoint, **kwargs): + for query in datapoint.find_queries: + if query.input_bbox is None: + continue + + img = datapoint.images[query.image_id].data + if isinstance(img, PILImage.Image): + w, h = img.size + else: + assert isinstance(img, torch.Tensor) + h, w = img.shape[-2:] + + for box_id in range(query.input_bbox.shape[0]): + query.input_bbox[box_id, :] = noise_box( + query.input_bbox[box_id, :].view(4), + (h, w), + box_noise_std=self.box_noise_std, + box_noise_max=self.box_noise_max, + min_box_area=self.minimum_box_area, + ).view(1, 4) + + return datapoint diff --git a/third_party/sam3/sam3/train/transforms/segmentation.py b/third_party/sam3/sam3/train/transforms/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd2316585b74c6a707fe937fb8f73f33cf8b815 --- /dev/null +++ b/third_party/sam3/sam3/train/transforms/segmentation.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import numpy as np +import pycocotools.mask as mask_utils +import torch +import torchvision.transforms.functional as F +from PIL import Image as PILImage +from sam3.model.box_ops import masks_to_boxes +from sam3.train.data.sam3_image_dataset import Datapoint + + +class InstanceToSemantic(object): + """Convert instance segmentation to semantic segmentation.""" + + def __init__(self, delete_instance=True, use_rle=False): + self.delete_instance = delete_instance + self.use_rle = use_rle + + def __call__(self, datapoint: Datapoint, **kwargs): + for fquery in datapoint.find_queries: + h, w = datapoint.images[fquery.image_id].size + + if self.use_rle: + all_segs = [ + datapoint.images[fquery.image_id].objects[obj_id].segment + for obj_id in fquery.object_ids_output + ] + if len(all_segs) > 0: + # we need to double check that all rles are the correct size + # Otherwise cocotools will fail silently to an empty [0,0] mask + for seg in all_segs: + assert seg["size"] == all_segs[0]["size"], ( + "Instance segments have inconsistent sizes. " + f"Found sizes {seg['size']} and {all_segs[0]['size']}" + ) + fquery.semantic_target = mask_utils.merge(all_segs) + else: + # There is no good way to create an empty RLE of the correct size + # We resort to converting an empty box to RLE + fquery.semantic_target = mask_utils.frPyObjects( + np.array([[0, 0, 0, 0]], dtype=np.float64), h, w + )[0] + + else: + # `semantic_target` is uint8 and remains uint8 throughout the transforms + # (it contains binary 0 and 1 values just like `segment` for each object) + fquery.semantic_target = torch.zeros((h, w), dtype=torch.uint8) + for obj_id in fquery.object_ids_output: + segment = datapoint.images[fquery.image_id].objects[obj_id].segment + if segment is not None: + assert ( + isinstance(segment, torch.Tensor) + and segment.dtype == torch.uint8 + ) + fquery.semantic_target |= segment + + if self.delete_instance: + for img in datapoint.images: + for obj in img.objects: + del obj.segment + obj.segment = None + + return datapoint + + +class RecomputeBoxesFromMasks: + """Recompute bounding boxes from masks.""" + + def __call__(self, datapoint: Datapoint, **kwargs): + for img in datapoint.images: + for obj in img.objects: + # Note: if the mask is empty, the bounding box will be undefined + # The empty targets should be subsequently filtered + obj.bbox = masks_to_boxes(obj.segment) + obj.area = obj.segment.sum().item() + + return datapoint + + +class DecodeRle: + """This transform decodes RLEs into binary segments. + Implementing it as a transforms allows lazy loading. Some transforms (eg query filters) + may be deleting masks, so decoding them from the beginning is wasteful. + + This transforms needs to be called before any kind of geometric manipulation + """ + + def __call__(self, datapoint: Datapoint, **kwargs): + imgId2size = {} + warning_shown = False + for imgId, img in enumerate(datapoint.images): + if isinstance(img.data, PILImage.Image): + img_w, img_h = img.data.size + elif isinstance(img.data, torch.Tensor): + img_w, img_h = img.data.shape[-2:] + else: + raise RuntimeError(f"Unexpected image type {type(img.data)}") + + imgId2size[imgId] = (img_h, img_w) + + for obj in img.objects: + if obj.segment is not None and not isinstance( + obj.segment, torch.Tensor + ): + if mask_utils.area(obj.segment) == 0: + print("Warning, empty mask found, approximating from box") + obj.segment = torch.zeros(img_h, img_w, dtype=torch.uint8) + x1, y1, x2, y2 = obj.bbox.int().tolist() + obj.segment[y1 : max(y2, y1 + 1), x1 : max(x1 + 1, x2)] = 1 + else: + obj.segment = mask_utils.decode(obj.segment) + # segment is uint8 and remains uint8 throughout the transforms + obj.segment = torch.tensor(obj.segment).to(torch.uint8) + + if list(obj.segment.shape) != [img_h, img_w]: + # Should not happen often, but adding for security + if not warning_shown: + print( + f"Warning expected instance segmentation size to be {[img_h, img_w]} but found {list(obj.segment.shape)}" + ) + # Printing only once per datapoint to avoid spam + warning_shown = True + + obj.segment = F.resize( + obj.segment[None], (img_h, img_w) + ).squeeze(0) + + assert list(obj.segment.shape) == [img_h, img_w] + + warning_shown = False + for query in datapoint.find_queries: + if query.semantic_target is not None and not isinstance( + query.semantic_target, torch.Tensor + ): + query.semantic_target = mask_utils.decode(query.semantic_target) + # segment is uint8 and remains uint8 throughout the transforms + query.semantic_target = torch.tensor(query.semantic_target).to( + torch.uint8 + ) + if tuple(query.semantic_target.shape) != imgId2size[query.image_id]: + if not warning_shown: + print( + f"Warning expected semantic segmentation size to be {imgId2size[query.image_id]} but found {tuple(query.semantic_target.shape)}" + ) + # Printing only once per datapoint to avoid spam + warning_shown = True + + query.semantic_target = F.resize( + query.semantic_target[None], imgId2size[query.image_id] + ).squeeze(0) + + assert tuple(query.semantic_target.shape) == imgId2size[query.image_id] + + return datapoint diff --git a/third_party/sam3/sam3/train/utils/__init__.py b/third_party/sam3/sam3/train/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..726cb1eb0ff0f17f7edc4ec73f1a73d1ac87bf59 --- /dev/null +++ b/third_party/sam3/sam3/train/utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe diff --git a/third_party/sam3/sam3/train/utils/checkpoint_utils.py b/third_party/sam3/sam3/train/utils/checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..465e006b3f88934e385c580173cbca5726a59374 --- /dev/null +++ b/third_party/sam3/sam3/train/utils/checkpoint_utils.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + + +import contextlib +import fnmatch +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import numpy as np +import torch +import torch.nn as nn +from iopath.common.file_io import g_pathmgr +from torch.jit._script import RecursiveScriptModule + + +def unix_pattern_to_parameter_names( + constraints: List[str], all_parameter_names: Sequence[str] +) -> Union[None, Set[str]]: + """ + Go through the list of parameter names and select those that match + any of the provided constraints + """ + parameter_names = [] + for param_name in constraints: + matching_parameters = set(fnmatch.filter(all_parameter_names, param_name)) + assert ( + len(matching_parameters) > 0 + ), f"param_names {param_name} don't match any param in the given names." + parameter_names.append(matching_parameters) + return set.union(*parameter_names) + + +def filter_params_matching_unix_pattern( + patterns: List[str], state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Remove from the state dictionary the parameters matching the provided unix patterns + + Args: + patterns: the list of unix patterns to exclude + state_dict: the dictionary to filter + + Returns: + A new state dictionary + """ + if len(patterns) == 0: + return {} + + all_keys = list(state_dict.keys()) + included_keys = unix_pattern_to_parameter_names(patterns, all_keys) + return {k: state_dict[k] for k in included_keys} + + +def exclude_params_matching_unix_pattern( + patterns: List[str], state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Remove from the state dictionary the parameters matching the provided unix patterns + + Args: + patterns: the list of unix patterns to exclude + state_dict: the dictionary to filter + + Returns: + A new state dictionary + """ + if len(patterns) == 0: + return state_dict + + all_keys = list(state_dict.keys()) + excluded_keys = unix_pattern_to_parameter_names(patterns, all_keys) + return {k: v for k, v in state_dict.items() if k not in excluded_keys} + + +def _get_state_dict_summary(state_dict: Dict[str, torch.Tensor]): + keys = [] + trace = [] + for k, v in state_dict.items(): + keys.append(k) + trace.append(v.sum().item()) + trace = np.array(trace)[np.argsort(keys)] + return trace + + +def assert_skipped_parameters_are_frozen(model: nn.Module, patterns: List[str]): + """ + Verifies that all the parameters matching the provided patterns + are frozen - this acts as a safeguard when ignoring parameter + when saving checkpoints - if the parameters are in fact trainable + """ + if not patterns: + return + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + non_frozen_keys = { + n + for n, p in model.named_parameters() + if n in frozen_state_dict and p.requires_grad + } + if non_frozen_keys: + raise ValueError( + f"Parameters excluded with `skip_saving_parameters` should be frozen: {non_frozen_keys}" + ) + + +@contextlib.contextmanager +def with_check_parameter_frozen( + model: nn.Module, patterns: List[str], disabled: bool = True +): + """ + Context manager that inspects a model surrounding a piece of code + and verifies if the model has been updated by this piece of code + + The function will raise an exception if the model has been updated + on at least one of the parameter that matches one of the pattern + + Args: + model: the model that might have been updated + patterns: for the parameters we want to observe + allowed: + """ + if not patterns or disabled: + yield + return + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + summary_before = _get_state_dict_summary(frozen_state_dict) + + yield + + frozen_state_dict = filter_params_matching_unix_pattern( + patterns=patterns, state_dict=model.state_dict() + ) + summary_after = _get_state_dict_summary(frozen_state_dict) + + if not np.allclose(summary_before, summary_after, atol=1e-6): + raise ValueError( + f""" + The `model_weight_initializer` has initialized parameters frozen with `skip_saving_parameters`. + You can resolve this error by either initializing those parameters from within the model definition + or using the flag `trainer.checkpoint.initialize_after_preemption` to True. + """ + ) + + +class CkptExcludeKernel: + """ + Removes the keys from the given model state_dict that match the key_pattern. + + Args: + key_pattern: Patterns used to select the keys in the state_dict + that are eligible for this kernel. + """ + + def __init__(self, key_pattern: List[str]): + self.key_pattern = key_pattern + + def __call__(self, state_dict: Dict): + """ + Args: + state_dict: A dictionary representing the given checkpoint's state dict. + """ + if len(self.key_pattern) == 0: + return state_dict + exclude_keys = unix_pattern_to_parameter_names( + self.key_pattern, state_dict.keys() + ) + return {k: v for k, v in state_dict.items() if k not in exclude_keys} + + +def load_checkpoint( + path_list: List[str], + pick_recursive_keys: Optional[List[str]] = None, + map_location: str = "cpu", +) -> Any: + """ + Loads a checkpoint from the specified path. + + Args: + path_list: A list of paths which contain the checkpoint. Each element + is tried (in order) until a file that exists is found. That file is then + used to read the checkpoint. + pick_recursive_keys: Picks sub dicts from the loaded checkpoint if not None. + For pick_recursive_keys = ["a", "b"], will return checkpoint_dict["a"]["b"] + map_location (str): a function, torch.device, string or a dict specifying how to + remap storage locations + + Returns: Model with the matchin pre-trained weights loaded. + """ + path_exists = False + for path in path_list: + if g_pathmgr.exists(path): + path_exists = True + break + + if not path_exists: + raise ValueError(f"No path exists in {path_list}") + + with g_pathmgr.open(path, "rb") as f: + checkpoint = torch.load(f, map_location=map_location) + + logging.info(f"Loaded checkpoint from {path}") + if pick_recursive_keys is not None: + for key in pick_recursive_keys: + checkpoint = checkpoint[key] + return checkpoint + + +def get_state_dict(checkpoint, ckpt_state_dict_keys): + if isinstance(checkpoint, RecursiveScriptModule): + # This is a torchscript JIT model + return checkpoint.state_dict() + pre_train_dict = checkpoint + for i, key in enumerate(ckpt_state_dict_keys): + if (isinstance(pre_train_dict, Mapping) and key not in pre_train_dict) or ( + isinstance(pre_train_dict, Sequence) and key >= len(pre_train_dict) + ): + key_str = ( + '["' + '"]["'.join(list(map(ckpt_state_dict_keys[:i], str))) + '"]' + ) + raise KeyError( + f"'{key}' not found in checkpoint{key_str} " + f"with keys: {pre_train_dict.keys()}" + ) + pre_train_dict = pre_train_dict[key] + return pre_train_dict + + +def load_checkpoint_and_apply_kernels( + checkpoint_path: str, + checkpoint_kernels: List[Callable] = None, + ckpt_state_dict_keys: Tuple[str] = ("state_dict",), + map_location: str = "cpu", +) -> nn.Module: + """ + Performs checkpoint loading with a variety of pre-processing kernel applied in + sequence. + + Args: + checkpoint_path (str): Path to the checkpoint. + checkpoint_kernels List(Callable): A list of checkpoint processing kernels + to apply in the specified order. Supported kernels include `CkptIncludeKernel`, + `CkptExcludeKernel`, etc. These kernels are applied in the + given order. + ckpt_state_dict_keys (str): Keys containing the model state dict. + map_location (str): a function, torch.device, string or a dict specifying how to + remap storage locations + + Returns: Model with the matchin pre-trained weights loaded. + """ + assert g_pathmgr.exists(checkpoint_path), "Checkpoint '{}' not found".format( + checkpoint_path + ) + + # Load the checkpoint on CPU to avoid GPU mem spike. + with g_pathmgr.open(checkpoint_path, "rb") as f: + checkpoint = torch.load(f, map_location=map_location) + + pre_train_dict = get_state_dict(checkpoint, ckpt_state_dict_keys) + + # Not logging into info etc since it's a huge log + logging.debug( + "Loaded Checkpoint State Dict pre-kernel application: %s" + % str(", ".join(list(pre_train_dict.keys()))) + ) + # Apply kernels + if checkpoint_kernels is not None: + for f in checkpoint_kernels: + pre_train_dict = f(state_dict=pre_train_dict) + + logging.debug( + "Loaded Checkpoint State Dict Post-kernel application %s" + % str(", ".join(list(pre_train_dict.keys()))) + ) + + return pre_train_dict + + +def check_load_state_dict_errors( + missing_keys, + unexpected_keys, + strict: bool, + ignore_missing_keys: List[str] = None, + ignore_unexpected_keys: List[str] = None, +): + if ignore_missing_keys is not None and len(ignore_missing_keys) > 0: + ignored_keys = unix_pattern_to_parameter_names( + ignore_missing_keys, missing_keys + ) + missing_keys = [key for key in missing_keys if key not in ignored_keys] + + if ignore_unexpected_keys is not None and len(ignore_unexpected_keys) > 0: + ignored_unexpected_keys = unix_pattern_to_parameter_names( + ignore_unexpected_keys, unexpected_keys + ) + unexpected_keys = [ + key for key in unexpected_keys if key not in ignored_unexpected_keys + ] + + err = "State key mismatch." + if unexpected_keys: + err += f" Unexpected keys: {unexpected_keys}." + if missing_keys: + err += f" Missing keys: {missing_keys}." + + if unexpected_keys or missing_keys: + logging.warning(err) + if unexpected_keys or strict: + raise KeyError(err) + + +def load_state_dict_into_model( + state_dict: Dict, + model: nn.Module, + strict: bool = True, + ignore_missing_keys: List[str] = None, + ignore_unexpected_keys: List[str] = None, + checkpoint_kernels: List[Callable] = None, +): + """ + Loads a state dict into the given model. + + Args: + state_dict: A dictionary containing the model's + state dict, or a subset if strict is False + model: Model to load the checkpoint weights into + strict: raise if the state_dict has missing state keys + ignore_missing_keys: unix pattern of keys to ignore + """ + # Apply kernels + if checkpoint_kernels is not None: + for f in checkpoint_kernels: + state_dict = f(state_dict=state_dict) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + check_load_state_dict_errors( + missing_keys, + unexpected_keys, + strict=strict, + ignore_missing_keys=ignore_missing_keys, + ignore_unexpected_keys=ignore_unexpected_keys, + ) + return model diff --git a/third_party/sam3/sam3/train/utils/distributed.py b/third_party/sam3/sam3/train/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..a91c4bca42b7817f6a5aecfd423aa42610aaeeb6 --- /dev/null +++ b/third_party/sam3/sam3/train/utils/distributed.py @@ -0,0 +1,587 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import functools +import io +import logging +import os +import random +import tempfile +import time +from typing import Any, Callable, List, Tuple + +import torch +import torch.autograd as autograd +import torch.distributed as dist + + +# Default to GPU 0 +_cuda_device_index: int = 0 + +# Setting _cuda_device_index to -1 internally implies that we should use CPU +_CPU_DEVICE_INDEX = -1 +_PRIMARY_RANK = 0 + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + + if dist.get_backend() == "nccl": + # Increase timeout from 1800 sec to 43200 sec (12 hr) to avoid some processes + # being much slower than others causing a timeout (which can happen in relation + # or LVIS class mAP evaluation). + timeout = 43200 + return dist.new_group( + backend="gloo", + timeout=datetime.timedelta(seconds=timeout), + ) + + return dist.group.WORLD + + +def is_main_process(): + """Return true if the current process is the main one""" + return get_rank() == 0 + + +def all_gather_via_filesys(data, filesys_save_dir=None, gather_to_rank_0_only=False): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors), similar to + `all_gather` above, but using filesystem instead of collective ops. + + If gather_to_rank_0_only is True, only rank 0 will load the gathered object list + (and other ranks will have an empty list). + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + print("gathering via files") + cpu_group = _get_global_gloo_group() + + # if unspecified, we will save to the current python file dir + if filesys_save_dir is not None: + save_dir = filesys_save_dir + elif "EXP_DIR" in os.environ: + save_dir = os.environ["EXP_DIR"] + else: + # try the same directory where the code is stored + save_dir = filesys_save_dir or os.path.dirname(__file__) + save_dir = os.path.join(save_dir, "all_gather_via_filesys") + if is_main_process(): + os.makedirs(save_dir, exist_ok=True) + + # use a timestamp and salt to distinguish different all_gather + timestamp = int(time.time()) if is_main_process() else 0 + salt = random.randint(0, 2**31 - 1) if is_main_process() else 0 + # broadcast the timestamp and salt across ranks + # (all-reduce will do the broadcasting since only rank 0 is non-zero) + timestamp_and_salt = torch.tensor([timestamp, salt], dtype=torch.long) + dist.all_reduce(timestamp_and_salt, group=cpu_group) + timestamp, salt = timestamp_and_salt.tolist() + + # save the data to a file on the disk + rank_save = get_rank() + save_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_save}.pkl" + save_data_path = os.path.join(save_dir, save_data_filename) + assert not os.path.exists(save_data_path), f"{save_data_path} already exists" + torch.save(data, save_data_path) + dist.barrier(group=cpu_group) + + # read the data from the files + data_list = [] + if rank_save == 0 or not gather_to_rank_0_only: + for rank_load in range(world_size): + load_data_filename = f"data_to_gather_{timestamp}_{salt}_{rank_load}.pkl" + load_data_path = os.path.join(save_dir, load_data_filename) + assert os.path.exists(load_data_path), f"cannot read {save_data_path}" + data_list.append(torch.load(load_data_path, weights_only=False)) + dist.barrier(group=cpu_group) + + # delete the saved file + os.remove(save_data_path) + return data_list + + +def all_gather(data, force_cpu=False, force_filesys=False, filesys_save_dir=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + + world_size = get_world_size() + if world_size == 1: + return [data] + + if os.getenv("MDETR_FILESYS_REDUCE_RANK_0_ONLY") == "1": + return all_gather_via_filesys( + data, filesys_save_dir, gather_to_rank_0_only=True + ) + + if os.getenv("MDETR_FILESYS_REDUCE") == "1" or force_filesys: + return all_gather_via_filesys(data, filesys_save_dir) + + cpu_group = None + if os.getenv("MDETR_CPU_REDUCE") == "1" or force_cpu: + cpu_group = _get_global_gloo_group() + + buffer = io.BytesIO() + torch.save(data, buffer) + data_view = buffer.getbuffer() + device = "cuda" if cpu_group is None else "cpu" + tensor = torch.ByteTensor(data_view).to(device) + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) + size_list = [ + torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size) + ] + if cpu_group is None: + dist.all_gather(size_list, local_size) + else: + print("gathering on cpu") + dist.all_gather(size_list, local_size, group=cpu_group) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + assert isinstance(local_size.item(), int) + local_size = int(local_size.item()) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) + if local_size != max_size: + padding = torch.empty( + size=(max_size - local_size,), dtype=torch.uint8, device=device + ) + tensor = torch.cat((tensor, padding), dim=0) + if cpu_group is None: + dist.all_gather(tensor_list, tensor) + else: + dist.all_gather(tensor_list, tensor, group=cpu_group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] + buffer = io.BytesIO(tensor.cpu().numpy()) + obj = torch.load(buffer, weights_only=False) + data_list.append(obj) + + return data_list + + +def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]: + """ + For some backends, such as NCCL, communication only works if the + tensor is on the GPU. This helper function converts to the correct + device and returns the tensor + original device. + """ + orig_device = "cpu" if not tensor.is_cuda else "gpu" + if ( + torch.distributed.is_available() + and torch.distributed.get_backend() == torch.distributed.Backend.NCCL + and not tensor.is_cuda + ): + tensor = tensor.cuda() + return (tensor, orig_device) + + +def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor: + """ + For some backends, such as NCCL, communication only works if the + tensor is on the GPU. This converts the tensor back to original device. + """ + if tensor.is_cuda and orig_device == "cpu": + tensor = tensor.cpu() + return tensor + + +def is_distributed_training_run() -> bool: + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and (torch.distributed.get_world_size() > 1) + ) + + +def is_primary() -> bool: + """ + Returns True if this is rank 0 of a distributed training job OR if it is + a single trainer job. Otherwise False. + """ + return get_rank() == _PRIMARY_RANK + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing mean reduction + of tensor over all processes. + """ + return all_reduce_op( + tensor, + torch.distributed.ReduceOp.SUM, + lambda t: t / torch.distributed.get_world_size(), + ) + + +def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing sum + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.SUM) + + +def all_reduce_min(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing min + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.MIN) + + +def all_reduce_max(tensor: torch.Tensor) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing min + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + return all_reduce_op(tensor, torch.distributed.ReduceOp.MAX) + + +def all_reduce_op( + tensor: torch.Tensor, + op: torch.distributed.ReduceOp, + after_op_func: Callable[[torch.Tensor], torch.Tensor] = None, +) -> torch.Tensor: + """ + Wrapper over torch.distributed.all_reduce for performing + reduction of tensor over all processes in both distributed / + non-distributed scenarios. + """ + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + torch.distributed.all_reduce(tensor, op) + if after_op_func is not None: + tensor = after_op_func(tensor) + tensor = convert_to_normal_tensor(tensor, orig_device) + return tensor + + +def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Wrapper over torch.distributed.all_gather for performing + 'gather' of 'tensor' over all processes in both distributed / + non-distributed scenarios. + """ + if tensor.ndim == 0: + # 0 dim tensors cannot be gathered. so unsqueeze + tensor = tensor.unsqueeze(0) + + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + gathered_tensors = [ + torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(gathered_tensors, tensor) + gathered_tensors = [ + convert_to_normal_tensor(_tensor, orig_device) + for _tensor in gathered_tensors + ] + else: + gathered_tensors = [tensor] + + return gathered_tensors + + +def gather_from_all(tensor: torch.Tensor) -> torch.Tensor: + gathered_tensors = gather_tensors_from_all(tensor) + gathered_tensor = torch.cat(gathered_tensors, 0) + return gathered_tensor + + +def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """ + Wrapper over torch.distributed.broadcast for broadcasting a tensor from the source + to all processes in both distributed / non-distributed scenarios. + """ + if is_distributed_training_run(): + tensor, orig_device = convert_to_distributed_tensor(tensor) + torch.distributed.broadcast(tensor, src) + tensor = convert_to_normal_tensor(tensor, orig_device) + return tensor + + +def barrier() -> None: + """ + Wrapper over torch.distributed.barrier, returns without waiting + if the distributed process group is not initialized instead of throwing error. + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return + torch.distributed.barrier() + + +def get_world_size() -> int: + """ + Simple wrapper for correctly getting worldsize in both distributed + / non-distributed settings + """ + return ( + torch.distributed.get_world_size() + if torch.distributed.is_available() and torch.distributed.is_initialized() + else 1 + ) + + +def get_rank() -> int: + """ + Simple wrapper for correctly getting rank in both distributed + / non-distributed settings + """ + return ( + torch.distributed.get_rank() + if torch.distributed.is_available() and torch.distributed.is_initialized() + else 0 + ) + + +def get_primary_rank() -> int: + return _PRIMARY_RANK + + +def set_cuda_device_index(idx: int) -> None: + global _cuda_device_index + _cuda_device_index = idx + torch.cuda.set_device(_cuda_device_index) + + +def set_cpu_device() -> None: + global _cuda_device_index + _cuda_device_index = _CPU_DEVICE_INDEX + + +def get_cuda_device_index() -> int: + return _cuda_device_index + + +def init_distributed_data_parallel_model( + model: torch.nn.Module, + broadcast_buffers: bool = False, + find_unused_parameters: bool = True, + bucket_cap_mb: int = 25, +) -> torch.nn.parallel.DistributedDataParallel: + global _cuda_device_index + + if _cuda_device_index == _CPU_DEVICE_INDEX: + # CPU-only model, don't specify device + return torch.nn.parallel.DistributedDataParallel( + model, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + bucket_cap_mb=bucket_cap_mb, + ) + else: + # GPU model + return torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[_cuda_device_index], + output_device=_cuda_device_index, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + bucket_cap_mb=bucket_cap_mb, + ) + + +def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any: + """Broadcast an object from a source to all workers. + + Args: + obj: Object to broadcast, must be serializable + src: Source rank for broadcast (default is primary) + use_disk: If enabled, removes redundant CPU memory copies by writing to + disk + """ + # Either broadcast from primary to the fleet (default), + # or use the src setting as the original rank + if get_rank() == src: + # Emit data + buffer = io.BytesIO() + torch.save(obj, buffer) + data_view = buffer.getbuffer() + length_tensor = torch.LongTensor([len(data_view)]) + length_tensor = broadcast(length_tensor, src=src) + data_tensor = torch.ByteTensor(data_view) + data_tensor = broadcast(data_tensor, src=src) + else: + # Fetch from the source + length_tensor = torch.LongTensor([0]) + length_tensor = broadcast(length_tensor, src=src) + data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8) + data_tensor = broadcast(data_tensor, src=src) + if use_disk: + with tempfile.TemporaryFile("r+b") as f: + f.write(data_tensor.numpy()) + # remove reference to the data tensor and hope that Python garbage + # collects it + del data_tensor + f.seek(0) + obj = torch.load(f, weights_only=False) + else: + buffer = io.BytesIO(data_tensor.numpy()) + obj = torch.load(buffer, weights_only=False) + return obj + + +def all_gather_tensor(tensor: torch.Tensor, world_size=None): + if world_size is None: + world_size = get_world_size() + # make contiguous because NCCL won't gather the tensor otherwise + assert tensor.is_contiguous(), f"{tensor.shape} is not contiguous!" + tensor, orig_device = convert_to_distributed_tensor(tensor) + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_all, tensor, async_op=False) # performance opt + tensor_all = [ + convert_to_normal_tensor(tensor, orig_device) for tensor in tensor_all + ] + return tensor_all + + +def all_gather_batch(tensors: List[torch.Tensor]): + """ + Performs all_gather operation on the provided tensors. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + for tensor in tensors: + tensor_all = all_gather_tensor(tensor, world_size) + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor + + +class GatherLayer(autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + dist.all_reduce(all_gradients) + return all_gradients[dist.get_rank()] + + +def all_gather_batch_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + + for tensor in tensors: + tensor_all = GatherLayer.apply(tensor) + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor + + +def unwrap_ddp_if_wrapped(model): + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + return model + + +def create_new_process_group(group_size): + """ + Creates process groups of a gives `group_size` and returns + process group that current GPU participates in. + + `group_size` must divide the total number of GPUs (world_size). + + Modified from + https://github.com/NVIDIA/apex/blob/4e1ae43f7f7ac69113ef426dd15f37123f0a2ed3/apex/parallel/__init__.py#L60 + + Args: + group_size (int): number of GPU's to collaborate for sync bn + """ + + assert group_size > 0 + + world_size = torch.distributed.get_world_size() + if world_size <= 8: + if group_size > world_size: + logging.warning( + f"Requested group size [{group_size}] > world size [{world_size}]. " + "Assuming local debug run and capping it to world size." + ) + group_size = world_size + assert world_size >= group_size + assert world_size % group_size == 0 + + group = None + for group_num in range(world_size // group_size): + group_ids = range(group_num * group_size, (group_num + 1) * group_size) + cur_group = torch.distributed.new_group(ranks=group_ids) + if torch.distributed.get_rank() // group_size == group_num: + group = cur_group + # can not drop out and return here, every process must go through creation of all subgroups + + assert group is not None + return group + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def gather_to_rank_0_via_filesys(data, filesys_save_dir=None): + """ + Gather any picklable data to rank 0 via filesystem, using all_gather_via_filesys. + """ + return all_gather_via_filesys(data, filesys_save_dir, gather_to_rank_0_only=True) diff --git a/third_party/sam3/sam3/train/utils/logger.py b/third_party/sam3/sam3/train/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6c071deb4bb0ecf9acc71d1b4e6804e774e449 --- /dev/null +++ b/third_party/sam3/sam3/train/utils/logger.py @@ -0,0 +1,241 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import atexit +import functools +import logging +import sys +import uuid +from typing import Any, Dict, Optional, Union + +from hydra.utils import instantiate +from iopath.common.file_io import g_pathmgr +from numpy import ndarray +from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter + +Scalar = Union[Tensor, ndarray, int, float] + + +def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any): + makedir(log_dir) + summary_writer_method = SummaryWriter + return TensorBoardLogger( + path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs + ) + + +class TensorBoardWriterWrapper: + """ + A wrapper around a SummaryWriter object. + """ + + def __init__( + self, + path: str, + *args: Any, + filename_suffix: str = None, + summary_writer_method: Any = SummaryWriter, + **kwargs: Any, + ) -> None: + """Create a new TensorBoard logger. + On construction, the logger creates a new events file that logs + will be written to. If the environment variable `RANK` is defined, + logger will only log if RANK = 0. + + NOTE: If using the logger with distributed training: + - This logger can call collective operations + - Logs will be written on rank 0 only + - Logger must be constructed synchronously *after* initializing distributed process group. + + Args: + path (str): path to write logs to + *args, **kwargs: Extra arguments to pass to SummaryWriter + """ + self._writer: Optional[SummaryWriter] = None + _, self._rank = get_machine_local_and_dist_rank() + self._path: str = path + if self._rank == 0: + logging.info( + f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}" + ) + self._writer = summary_writer_method( + log_dir=path, + *args, + filename_suffix=filename_suffix or str(uuid.uuid4()), + **kwargs, + ) + else: + logging.debug( + f"Not logging meters on this host because env RANK: {self._rank} != 0" + ) + atexit.register(self.close) + + @property + def writer(self) -> Optional[SummaryWriter]: + return self._writer + + @property + def path(self) -> str: + return self._path + + def flush(self) -> None: + """Writes pending logs to disk.""" + + if not self._writer: + return + + self._writer.flush() + + def close(self) -> None: + """Close writer, flushing pending logs to disk. + Logs cannot be written after `close` is called. + """ + + if not self._writer: + return + + self._writer.close() + self._writer = None + + +class TensorBoardLogger(TensorBoardWriterWrapper): + """ + A simple logger for TensorBoard. + """ + + def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: + """Add multiple scalar values to TensorBoard. + + Args: + payload (dict): dictionary of tag name and scalar value + step (int, Optional): step value to record + """ + if not self._writer: + return + for k, v in payload.items(): + self.log(k, v, step) + + def log(self, name: str, data: Scalar, step: int) -> None: + """Add scalar data to TensorBoard. + + Args: + name (string): tag name used to group scalars + data (float/int/Tensor): scalar data to log + step (int, optional): step value to record + """ + if not self._writer: + return + self._writer.add_scalar(name, data, global_step=step, new_style=True) + + def log_hparams( + self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar] + ) -> None: + """Add hyperparameter data to TensorBoard. + + Args: + hparams (dict): dictionary of hyperparameter names and corresponding values + meters (dict): dictionary of name of meter and corersponding values + """ + if not self._writer: + return + self._writer.add_hparams(hparams, meters) + + +class Logger: + """ + A logger class that can interface with multiple loggers. It now supports tensorboard only for simplicity, but you can extend it with your own logger. + """ + + def __init__(self, logging_conf): + # allow turning off TensorBoard with "should_log: false" in config + tb_config = logging_conf.tensorboard_writer + tb_should_log = tb_config and tb_config.pop("should_log", True) + self.tb_logger = instantiate(tb_config) if tb_should_log else None + + def log_dict(self, payload: Dict[str, Scalar], step: int) -> None: + if self.tb_logger: + self.tb_logger.log_dict(payload, step) + + def log(self, name: str, data: Scalar, step: int) -> None: + if self.tb_logger: + self.tb_logger.log(name, data, step) + + def log_hparams( + self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar] + ) -> None: + if self.tb_logger: + self.tb_logger.log_hparams(hparams, meters) + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + # we tune the buffering value so that the logs are updated + # frequently. + log_buffer_kb = 10 * 1024 # 10KB + io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_kb) + atexit.register(io.close) + return io + + +def setup_logging( + name, + output_dir=None, + rank=0, + log_level_primary="INFO", + log_level_secondary="ERROR", +): + """ + Setup various logging streams: stdout and file handlers. + For file handlers, we only setup for the master gpu. + """ + # get the filename if we want to log to the file as well + log_filename = None + if output_dir: + makedir(output_dir) + if rank == 0: + log_filename = f"{output_dir}/log.txt" + + logger = logging.getLogger(name) + logger.setLevel(log_level_primary) + + # create formatter + FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s" + formatter = logging.Formatter(FORMAT) + + # Cleanup any existing handlers + for h in logger.handlers: + logger.removeHandler(h) + logger.root.handlers = [] + + # setup the console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + if rank == 0: + console_handler.setLevel(log_level_primary) + else: + console_handler.setLevel(log_level_secondary) + + # we log to file as well if user wants + if log_filename and rank == 0: + file_handler = logging.StreamHandler(_cached_log_stream(log_filename)) + file_handler.setLevel(log_level_primary) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + logging.root = logger + + +def shutdown_logging(): + """ + After training is done, we ensure to shut down all the logger streams. + """ + logging.info("Shutting down loggers...") + handlers = logging.root.handlers + for handler in handlers: + handler.close() diff --git a/third_party/sam3/sam3/train/utils/train_utils.py b/third_party/sam3/sam3/train/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..27e9b93f26b81c64ba8e572cb3fd5068cb69e853 --- /dev/null +++ b/third_party/sam3/sam3/train/utils/train_utils.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe + +import logging +import math +import os +import random +import re +from datetime import timedelta +from typing import Optional + +import hydra +import numpy as np +import omegaconf +import torch +import torch.distributed as dist +from iopath.common.file_io import g_pathmgr +from omegaconf import OmegaConf + + +def multiply_all(*args): + return np.prod(np.array(args)).item() + + +def collect_dict_keys(config): + """This function recursively iterates through a dataset configuration, and collect all the dict_key that are defined""" + val_keys = [] + # If the this config points to the collate function, then it has a key + if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]): + val_keys.append(config["dict_key"]) + else: + # Recursively proceed + for v in config.values(): + if isinstance(v, type(config)): + val_keys.extend(collect_dict_keys(v)) + elif isinstance(v, omegaconf.listconfig.ListConfig): + for item in v: + if isinstance(item, type(config)): + val_keys.extend(collect_dict_keys(item)) + return val_keys + + +class Phase: + TRAIN = "train" + VAL = "val" + + +def register_omegaconf_resolvers(): + OmegaConf.register_new_resolver("get_method", hydra.utils.get_method) + OmegaConf.register_new_resolver("get_class", hydra.utils.get_class) + OmegaConf.register_new_resolver("add", lambda x, y: x + y) + OmegaConf.register_new_resolver("times", multiply_all) + OmegaConf.register_new_resolver("divide", lambda x, y: x / y) + OmegaConf.register_new_resolver("pow", lambda x, y: x**y) + OmegaConf.register_new_resolver("subtract", lambda x, y: x - y) + OmegaConf.register_new_resolver("range", lambda x: list(range(x))) + OmegaConf.register_new_resolver("int", lambda x: int(x)) + OmegaConf.register_new_resolver("ceil_int", lambda x: int(math.ceil(x))) + OmegaConf.register_new_resolver("merge", lambda *x: OmegaConf.merge(*x)) + OmegaConf.register_new_resolver("string", lambda x: str(x)) + + +def setup_distributed_backend(backend, timeout_mins): + """ + Initialize torch.distributed and set the CUDA device. + Expects environment variables to be set as per + https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization + along with the environ variable "LOCAL_RANK" which is used to set the CUDA device. + """ + # enable TORCH_NCCL_ASYNC_ERROR_HANDLING to ensure dist nccl ops time out after timeout_mins + # of waiting + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" + logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins") + dist.init_process_group(backend=backend, timeout=timedelta(minutes=timeout_mins)) + return dist.get_rank() + + +def get_machine_local_and_dist_rank(): + """ + Get the distributed and local rank of the current gpu. + """ + local_rank = int(os.environ.get("LOCAL_RANK", None)) + distributed_rank = int(os.environ.get("RANK", None)) + assert ( + local_rank is not None and distributed_rank is not None + ), "Please the set the RANK and LOCAL_RANK environment variables." + return local_rank, distributed_rank + + +def print_cfg(cfg): + """ + Supports printing both Hydra DictConfig and also the AttrDict config + """ + logging.info("Training with config:") + logging.info(OmegaConf.to_yaml(cfg)) + + +def set_seeds(seed_value, max_epochs, dist_rank): + """ + Set the python random, numpy and torch seed for each gpu. Also set the CUDA + seeds if the CUDA is available. This ensures deterministic nature of the training. + """ + # Since in the pytorch sampler, we increment the seed by 1 for every epoch. + seed_value = (seed_value + dist_rank) * max_epochs + logging.info(f"MACHINE SEED: {seed_value}") + random.seed(seed_value) + np.random.seed(seed_value) + torch.manual_seed(seed_value) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed_value) + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_amp_type(amp_type: Optional[str] = None): + if amp_type is None: + return None + assert amp_type in ["bfloat16", "float16"], "Invalid Amp type." + if amp_type == "bfloat16": + return torch.bfloat16 + else: + return torch.float16 + + +def log_env_variables(): + env_keys = sorted(list(os.environ.keys())) + st = "" + for k in env_keys: + v = os.environ[k] + st += f"{k}={v}\n" + logging.info("Logging ENV_VARIABLES") + logging.info(st) + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, device, fmt=":f"): + self.name = name + self.fmt = fmt + self.device = device + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self._allow_updates = True + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +class MemMeter: + """Computes and stores the current, avg, and max of peak Mem usage per iteration""" + + def __init__(self, name, device, fmt=":f"): + self.name = name + self.fmt = fmt + self.device = device + self.reset() + + def reset(self): + self.val = 0 # Per iteration max usage + self.avg = 0 # Avg per iteration max usage + self.peak = 0 # Peak usage for lifetime of program + self.sum = 0 + self.count = 0 + self._allow_updates = True + + def update(self, n=1, reset_peak_usage=True): + self.val = torch.cuda.max_memory_allocated() // 1e9 + self.sum += self.val * n + self.count += n + self.avg = self.sum / self.count + self.peak = max(self.peak, self.val) + if reset_peak_usage: + torch.cuda.reset_peak_memory_stats() + + def __str__(self): + fmtstr = ( + "{name}: {val" + + self.fmt + + "} ({avg" + + self.fmt + + "}/{peak" + + self.fmt + + "})" + ) + return fmtstr.format(**self.__dict__) + + +def human_readable_time(time_seconds): + time = int(time_seconds) + minutes, seconds = divmod(time, 60) + hours, minutes = divmod(minutes, 60) + days, hours = divmod(hours, 24) + return f"{days:02}d {hours:02}h {minutes:02}m" + + +class DurationMeter: + def __init__(self, name, device, fmt=":f"): + self.name = name + self.device = device + self.fmt = fmt + self.val = 0 + + def reset(self): + self.val = 0 + + def update(self, val): + self.val = val + + def add(self, val): + self.val += val + + def __str__(self): + return f"{self.name}: {human_readable_time(self.val)}" + + +class ProgressMeter: + def __init__(self, num_batches, meters, real_meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.real_meters = real_meters + self.prefix = prefix + + def display(self, batch, enable_print=False): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + entries += [ + " | ".join( + [ + f"{os.path.join(name, subname)}: {val:.4f}" + for subname, val in meter.compute().items() + ] + ) + for name, meter in self.real_meters.items() + ] + logging.info(" | ".join(entries)) + if enable_print: + print(" | ".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def get_resume_checkpoint(checkpoint_save_dir): + if not g_pathmgr.isdir(checkpoint_save_dir): + return None + ckpt_file = os.path.join(checkpoint_save_dir, "checkpoint.pt") + if not g_pathmgr.isfile(ckpt_file): + return None + + return ckpt_file diff --git a/third_party/sam3/sam3/visualization_utils.py b/third_party/sam3/sam3/visualization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b39d8a42071f0a53be8f014ad839413c983d7617 --- /dev/null +++ b/third_party/sam3/sam3/visualization_utils.py @@ -0,0 +1,943 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +# pyre-unsafe +import json +import os +import subprocess +from pathlib import Path + +import cv2 +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pycocotools.mask as mask_utils +import torch +from matplotlib.colors import to_rgb +from PIL import Image +from skimage.color import lab2rgb, rgb2lab +from sklearn.cluster import KMeans +from torchvision.ops import masks_to_boxes +from tqdm import tqdm + + +def generate_colors(n_colors=256, n_samples=5000): + # Step 1: Random RGB samples + np.random.seed(42) + rgb = np.random.rand(n_samples, 3) + # Step 2: Convert to LAB for perceptual uniformity + # print(f"Converting {n_samples} RGB samples to LAB color space...") + lab = rgb2lab(rgb.reshape(1, -1, 3)).reshape(-1, 3) + # print("Conversion to LAB complete.") + # Step 3: k-means clustering in LAB + kmeans = KMeans(n_clusters=n_colors, n_init=10) + # print(f"Fitting KMeans with {n_colors} clusters on {n_samples} samples...") + kmeans.fit(lab) + # print("KMeans fitting complete.") + centers_lab = kmeans.cluster_centers_ + # Step 4: Convert LAB back to RGB + colors_rgb = lab2rgb(centers_lab.reshape(1, -1, 3)).reshape(-1, 3) + colors_rgb = np.clip(colors_rgb, 0, 1) + return colors_rgb + + +COLORS = generate_colors(n_colors=128, n_samples=5000) + + +def show_img_tensor(img_batch, vis_img_idx=0): + MEAN_IMG = np.array([0.5, 0.5, 0.5]) + STD_IMG = np.array([0.5, 0.5, 0.5]) + im_tensor = img_batch[vis_img_idx].detach().cpu() + assert im_tensor.dim() == 3 + im_tensor = im_tensor.numpy().transpose((1, 2, 0)) + im_tensor = (im_tensor * STD_IMG) + MEAN_IMG + im_tensor = np.clip(im_tensor, 0, 1) + plt.imshow(im_tensor) + + +def draw_box_on_image(image, box, color=(0, 255, 0)): + """ + Draws a rectangle on a given PIL image using the provided box coordinates in xywh format. + :param image: PIL.Image - The image on which to draw the rectangle. + :param box: tuple - A tuple (x, y, w, h) representing the top-left corner, width, and height of the rectangle. + :param color: tuple - A tuple (R, G, B) representing the color of the rectangle. Default is red. + :return: PIL.Image - The image with the rectangle drawn on it. + """ + # Ensure the image is in RGB mode + image = image.convert("RGB") + # Unpack the box coordinates + x, y, w, h = box + x, y, w, h = int(x), int(y), int(w), int(h) + # Get the pixel data + pixels = image.load() + # Draw the top and bottom edges + for i in range(x, x + w): + pixels[i, y] = color + pixels[i, y + h - 1] = color + pixels[i, y + 1] = color + pixels[i, y + h] = color + pixels[i, y - 1] = color + pixels[i, y + h - 2] = color + # Draw the left and right edges + for j in range(y, y + h): + pixels[x, j] = color + pixels[x + 1, j] = color + pixels[x - 1, j] = color + pixels[x + w - 1, j] = color + pixels[x + w, j] = color + pixels[x + w - 2, j] = color + return image + + +def plot_bbox( + img_height, + img_width, + box, + box_format="XYXY", + relative_coords=True, + color="r", + linestyle="solid", + text=None, + ax=None, +): + if box_format == "XYXY": + x, y, x2, y2 = box + w = x2 - x + h = y2 - y + elif box_format == "XYWH": + x, y, w, h = box + elif box_format == "CxCyWH": + cx, cy, w, h = box + x = cx - w / 2 + y = cy - h / 2 + else: + raise RuntimeError(f"Invalid box_format {box_format}") + + if relative_coords: + x *= img_width + w *= img_width + y *= img_height + h *= img_height + + if ax is None: + ax = plt.gca() + rect = patches.Rectangle( + (x, y), + w, + h, + linewidth=1.5, + edgecolor=color, + facecolor="none", + linestyle=linestyle, + ) + ax.add_patch(rect) + if text is not None: + facecolor = "w" + ax.text( + x, + y - 5, + text, + color=color, + weight="bold", + fontsize=8, + bbox={"facecolor": facecolor, "alpha": 0.75, "pad": 2}, + ) + + +def plot_mask(mask, color="r", ax=None): + im_h, im_w = mask.shape + mask_img = np.zeros((im_h, im_w, 4), dtype=np.float32) + mask_img[..., :3] = to_rgb(color) + mask_img[..., 3] = mask * 0.5 + # Use the provided ax or the current axis + if ax is None: + ax = plt.gca() + ax.imshow(mask_img) + + +def normalize_bbox(bbox_xywh, img_w, img_h): + # Assumes bbox_xywh is in XYWH format + if isinstance(bbox_xywh, list): + assert ( + len(bbox_xywh) == 4 + ), "bbox_xywh list must have 4 elements. Batching not support except for torch tensors." + normalized_bbox = bbox_xywh.copy() + normalized_bbox[0] /= img_w + normalized_bbox[1] /= img_h + normalized_bbox[2] /= img_w + normalized_bbox[3] /= img_h + else: + assert isinstance( + bbox_xywh, torch.Tensor + ), "Only torch tensors are supported for batching." + normalized_bbox = bbox_xywh.clone() + assert ( + normalized_bbox.size(-1) == 4 + ), "bbox_xywh tensor must have last dimension of size 4." + normalized_bbox[..., 0] /= img_w + normalized_bbox[..., 1] /= img_h + normalized_bbox[..., 2] /= img_w + normalized_bbox[..., 3] /= img_h + return normalized_bbox + + +def visualize_frame_output(frame_idx, video_frames, outputs, figsize=(12, 8)): + plt.figure(figsize=figsize) + plt.title(f"frame {frame_idx}") + img = load_frame(video_frames[frame_idx]) + img_H, img_W, _ = img.shape + plt.imshow(img) + for i in range(len(outputs["out_probs"])): + box_xywh = outputs["out_boxes_xywh"][i] + prob = outputs["out_probs"][i] + obj_id = outputs["out_obj_ids"][i] + binary_mask = outputs["out_binary_masks"][i] + color = COLORS[obj_id % len(COLORS)] + plot_bbox( + img_H, + img_W, + box_xywh, + text=f"(id={obj_id}, {prob=:.2f})", + box_format="XYWH", + color=color, + ) + plot_mask(binary_mask, color=color) + + +def visualize_formatted_frame_output( + frame_idx, + video_frames, + outputs_list, + titles=None, + points_list=None, + points_labels_list=None, + figsize=(12, 8), + title_suffix="", + prompt_info=None, +): + """Visualize up to three sets of segmentation masks on a video frame. + + Args: + frame_idx: Frame index to visualize + image_files: List of image file paths + outputs_list: List of {frame_idx: {obj_id: mask_tensor}} or single dict {obj_id: mask_tensor} + titles: List of titles for each set of outputs_list + points_list: Optional list of point coordinates + points_labels_list: Optional list of point labels + figsize: Figure size tuple + save: Whether to save the visualization to file + output_dir: Base output directory when saving + scenario_name: Scenario name for organizing saved files + title_suffix: Additional title suffix + prompt_info: Dictionary with prompt information (boxes, points, etc.) + """ + # Handle single output dict case + if isinstance(outputs_list, dict) and frame_idx in outputs_list: + # This is a single outputs dict with frame indices as keys + outputs_list = [outputs_list] + elif isinstance(outputs_list, dict) and not any( + isinstance(k, int) for k in outputs_list.keys() + ): + # This is a single frame's outputs {obj_id: mask} + single_frame_outputs = {frame_idx: outputs_list} + outputs_list = [single_frame_outputs] + + num_outputs = len(outputs_list) + if titles is None: + titles = [f"Set {i + 1}" for i in range(num_outputs)] + assert ( + len(titles) == num_outputs + ), "length of `titles` should match that of `outputs_list` if not None." + + _, axes = plt.subplots(1, num_outputs, figsize=figsize) + if num_outputs == 1: + axes = [axes] # Make it iterable + + img = load_frame(video_frames[frame_idx]) + img_H, img_W, _ = img.shape + + for idx in range(num_outputs): + ax, outputs_set, ax_title = axes[idx], outputs_list[idx], titles[idx] + ax.set_title(f"Frame {frame_idx} - {ax_title}{title_suffix}") + ax.imshow(img) + + if frame_idx in outputs_set: + _outputs = outputs_set[frame_idx] + else: + print(f"Warning: Frame {frame_idx} not found in outputs_set") + continue + + if prompt_info and frame_idx == 0: # Show prompts on first frame + if "boxes" in prompt_info: + for box in prompt_info["boxes"]: + # box is in [x, y, w, h] normalized format + x, y, w, h = box + plot_bbox( + img_H, + img_W, + [x, y, x + w, y + h], # Convert to XYXY + box_format="XYXY", + relative_coords=True, + color="yellow", + linestyle="dashed", + text="PROMPT BOX", + ax=ax, + ) + + if "points" in prompt_info and "point_labels" in prompt_info: + points = np.array(prompt_info["points"]) + labels = np.array(prompt_info["point_labels"]) + # Convert normalized to pixel coordinates + points_pixel = points * np.array([img_W, img_H]) + + # Draw positive points (green stars) + pos_points = points_pixel[labels == 1] + if len(pos_points) > 0: + ax.scatter( + pos_points[:, 0], + pos_points[:, 1], + color="lime", + marker="*", + s=200, + edgecolor="white", + linewidth=2, + label="Positive Points", + zorder=10, + ) + + # Draw negative points (red stars) + neg_points = points_pixel[labels == 0] + if len(neg_points) > 0: + ax.scatter( + neg_points[:, 0], + neg_points[:, 1], + color="red", + marker="*", + s=200, + edgecolor="white", + linewidth=2, + label="Negative Points", + zorder=10, + ) + + objects_drawn = 0 + for obj_id, binary_mask in _outputs.items(): + mask_sum = ( + binary_mask.sum() + if hasattr(binary_mask, "sum") + else np.sum(binary_mask) + ) + + if mask_sum > 0: # Only draw if mask has content + # Convert to torch tensor if it's not already + if not isinstance(binary_mask, torch.Tensor): + binary_mask = torch.tensor(binary_mask) + + # Find bounding box from mask + if binary_mask.any(): + box_xyxy = masks_to_boxes(binary_mask.unsqueeze(0)).squeeze() + box_xyxy = normalize_bbox(box_xyxy, img_W, img_H) + else: + # Fallback: create a small box at center + box_xyxy = [0.45, 0.45, 0.55, 0.55] + + color = COLORS[obj_id % len(COLORS)] + + plot_bbox( + img_H, + img_W, + box_xyxy, + text=f"(id={obj_id})", + box_format="XYXY", + color=color, + ax=ax, + ) + + # Convert back to numpy for plotting + mask_np = ( + binary_mask.numpy() + if isinstance(binary_mask, torch.Tensor) + else binary_mask + ) + plot_mask(mask_np, color=color, ax=ax) + objects_drawn += 1 + + if objects_drawn == 0: + ax.text( + 0.5, + 0.5, + "No objects detected", + transform=ax.transAxes, + fontsize=16, + ha="center", + va="center", + color="red", + weight="bold", + ) + + # Draw additional points if provided + if points_list is not None and points_list[idx] is not None: + show_points( + points_list[idx], points_labels_list[idx], ax=ax, marker_size=200 + ) + + ax.axis("off") + + plt.tight_layout() + plt.show() + + +def render_masklet_frame(img, outputs, frame_idx=None, alpha=0.5): + """ + Overlays masklets and bounding boxes on a single image frame. + Args: + img: np.ndarray, shape (H, W, 3), uint8 or float32 in [0,255] or [0,1] + outputs: dict with keys: out_boxes_xywh, out_probs, out_obj_ids, out_binary_masks + frame_idx: int or None, for overlaying frame index text + alpha: float, mask overlay alpha + Returns: + overlay: np.ndarray, shape (H, W, 3), uint8 + """ + if img.dtype == np.float32 or img.max() <= 1.0: + img = (img * 255).astype(np.uint8) + img = img[..., :3] # drop alpha if present + height, width = img.shape[:2] + overlay = img.copy() + + for i in range(len(outputs["out_probs"])): + obj_id = outputs["out_obj_ids"][i] + color = COLORS[obj_id % len(COLORS)] + color255 = (color * 255).astype(np.uint8) + mask = outputs["out_binary_masks"][i] + if mask.shape != img.shape[:2]: + mask = cv2.resize( + mask.astype(np.float32), + (img.shape[1], img.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + mask_bool = mask > 0.5 + for c in range(3): + overlay[..., c][mask_bool] = ( + alpha * color255[c] + (1 - alpha) * overlay[..., c][mask_bool] + ).astype(np.uint8) + + # Draw bounding boxes and text + for i in range(len(outputs["out_probs"])): + box_xywh = outputs["out_boxes_xywh"][i] + obj_id = outputs["out_obj_ids"][i] + prob = outputs["out_probs"][i] + color = COLORS[obj_id % len(COLORS)] + color255 = tuple(int(x * 255) for x in color) + x, y, w, h = box_xywh + x1 = int(x * width) + y1 = int(y * height) + x2 = int((x + w) * width) + y2 = int((y + h) * height) + cv2.rectangle(overlay, (x1, y1), (x2, y2), color255, 2) + if prob is not None: + label = f"id={obj_id}, p={prob:.2f}" + else: + label = f"id={obj_id}" + cv2.putText( + overlay, + label, + (x1, max(y1 - 10, 0)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color255, + 1, + cv2.LINE_AA, + ) + + # Overlay frame index at the top-left corner + if frame_idx is not None: + cv2.putText( + overlay, + f"Frame {frame_idx}", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 1.0, + (255, 255, 255), + 2, + cv2.LINE_AA, + ) + + return overlay + + +def save_masklet_video(video_frames, outputs, out_path, alpha=0.5, fps=10): + # Each outputs dict has keys: "out_boxes_xywh", "out_probs", "out_obj_ids", "out_binary_masks" + # video_frames: list of video frame data, same length as outputs_list + + # Read first frame to get size + first_img = load_frame(video_frames[0]) + height, width = first_img.shape[:2] + if first_img.dtype == np.float32 or first_img.max() <= 1.0: + first_img = (first_img * 255).astype(np.uint8) + # Use 'mp4v' for best compatibility with VSCode playback (.mp4 files) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + writer = cv2.VideoWriter("temp.mp4", fourcc, fps, (width, height)) + + outputs_list = [ + (video_frames[frame_idx], frame_idx, outputs[frame_idx]) + for frame_idx in sorted(outputs.keys()) + ] + + for frame, frame_idx, frame_outputs in tqdm(outputs_list): + img = load_frame(frame) + overlay = render_masklet_frame( + img, frame_outputs, frame_idx=frame_idx, alpha=alpha + ) + writer.write(cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)) + + writer.release() + + # Re-encode the video for VSCode compatibility using ffmpeg + subprocess.run(["ffmpeg", "-y", "-i", "temp.mp4", out_path]) + print(f"Re-encoded video saved to {out_path}") + + os.remove("temp.mp4") # Clean up temporary file + + +def save_masklet_image(frame, outputs, out_path, alpha=0.5, frame_idx=None): + """ + Save a single image with masklet overlays. + """ + img = load_frame(frame) + overlay = render_masklet_frame(img, outputs, frame_idx=frame_idx, alpha=alpha) + Image.fromarray(overlay).save(out_path) + print(f"Overlay image saved to {out_path}") + + +def prepare_masks_for_visualization(frame_to_output): + # frame_to_obj_masks --> {frame_idx: {'output_probs': np.array, `out_obj_ids`: np.array, `out_binary_masks`: np.array}} + for frame_idx, out in frame_to_output.items(): + _processed_out = {} + for idx, obj_id in enumerate(out["out_obj_ids"].tolist()): + if out["out_binary_masks"][idx].any(): + _processed_out[obj_id] = out["out_binary_masks"][idx] + frame_to_output[frame_idx] = _processed_out + return frame_to_output + + +def convert_coco_to_masklet_format( + annotations, img_info, is_prediction=False, score_threshold=0.5 +): + """ + Convert COCO format annotations to format expected by render_masklet_frame + """ + outputs = { + "out_boxes_xywh": [], + "out_probs": [], + "out_obj_ids": [], + "out_binary_masks": [], + } + + img_h, img_w = img_info["height"], img_info["width"] + + for idx, ann in enumerate(annotations): + # Get bounding box in relative XYWH format + if "bbox" in ann: + bbox = ann["bbox"] + if max(bbox) > 1.0: # Convert absolute to relative coordinates + bbox = [ + bbox[0] / img_w, + bbox[1] / img_h, + bbox[2] / img_w, + bbox[3] / img_h, + ] + else: + mask = mask_utils.decode(ann["segmentation"]) + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + if np.any(rows) and np.any(cols): + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + # Convert to relative XYWH + bbox = [ + cmin / img_w, + rmin / img_h, + (cmax - cmin + 1) / img_w, + (rmax - rmin + 1) / img_h, + ] + else: + bbox = [0, 0, 0, 0] + + outputs["out_boxes_xywh"].append(bbox) + + # Get probability/score + if is_prediction: + prob = ann["score"] + else: + prob = 1.0 # GT has no probability + outputs["out_probs"].append(prob) + + outputs["out_obj_ids"].append(idx) + mask = mask_utils.decode(ann["segmentation"]) + mask = (mask > score_threshold).astype(np.uint8) + + outputs["out_binary_masks"].append(mask) + + return outputs + + +def save_side_by_side_visualization(img, gt_anns, pred_anns, noun_phrase): + """ + Create side-by-side visualization of GT and predictions + """ + + # Create side-by-side visualization + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7)) + + main_title = f"Noun phrase: '{noun_phrase}'" + fig.suptitle(main_title, fontsize=16, fontweight="bold") + + gt_overlay = render_masklet_frame(img, gt_anns, alpha=0.5) + ax1.imshow(gt_overlay) + ax1.set_title("Ground Truth", fontsize=14, fontweight="bold") + ax1.axis("off") + + pred_overlay = render_masklet_frame(img, pred_anns, alpha=0.5) + ax2.imshow(pred_overlay) + ax2.set_title("Predictions", fontsize=14, fontweight="bold") + ax2.axis("off") + + plt.subplots_adjust(top=0.88) + plt.tight_layout() + + +def bitget(val, idx): + return (val >> idx) & 1 + + +def pascal_color_map(): + colormap = np.zeros((512, 3), dtype=int) + ind = np.arange(512, dtype=int) + for shift in reversed(list(range(8))): + for channel in range(3): + colormap[:, channel] |= bitget(ind, channel) << shift + ind >>= 3 + + return colormap.astype(np.uint8) + + +def draw_masks_to_frame( + frame: np.ndarray, masks: np.ndarray, colors: np.ndarray +) -> np.ndarray: + masked_frame = frame + for mask, color in zip(masks, colors): + curr_masked_frame = np.where(mask[..., None], color, masked_frame) + masked_frame = cv2.addWeighted(masked_frame, 0.75, curr_masked_frame, 0.25, 0) + + if int(cv2.__version__[0]) > 3: + contours, _ = cv2.findContours( + np.array(mask, dtype=np.uint8).copy(), + cv2.RETR_TREE, + cv2.CHAIN_APPROX_NONE, + ) + else: + _, contours, _ = cv2.findContours( + np.array(mask, dtype=np.uint8).copy(), + cv2.RETR_TREE, + cv2.CHAIN_APPROX_NONE, + ) + + cv2.drawContours( + masked_frame, contours, -1, (255, 255, 255), 7 + ) # White outer contour + cv2.drawContours( + masked_frame, contours, -1, (0, 0, 0), 5 + ) # Black middle contour + cv2.drawContours( + masked_frame, contours, -1, color.tolist(), 3 + ) # Original color inner contour + return masked_frame + + +def get_annot_df(file_path: str): + with open(file_path, "r") as f: + data = json.load(f) + + dfs = {} + + for k, v in data.items(): + if k in ("info", "licenses"): + dfs[k] = v + continue + df = pd.DataFrame(v) + dfs[k] = df + + return dfs + + +def get_annot_dfs(file_list: list[str]): + dfs = {} + for annot_file in tqdm(file_list): + dataset_name = Path(annot_file).stem + dfs[dataset_name] = get_annot_df(annot_file) + + return dfs + + +def get_media_dir(media_dir: str, dataset: str): + if dataset in ["saco_veval_sav_test", "saco_veval_sav_val"]: + return os.path.join(media_dir, "saco_sav", "JPEGImages_24fps") + elif dataset in ["saco_veval_yt1b_test", "saco_veval_yt1b_val"]: + return os.path.join(media_dir, "saco_yt1b", "JPEGImages_6fps") + elif dataset in ["saco_veval_smartglasses_test", "saco_veval_smartglasses_val"]: + return os.path.join(media_dir, "saco_sg", "JPEGImages_6fps") + elif dataset == "sa_fari_test": + return os.path.join(media_dir, "sa_fari", "JPEGImages_6fps") + else: + raise ValueError(f"Dataset {dataset} not found") + + +def get_all_annotations_for_frame( + dataset_df: pd.DataFrame, video_id: int, frame_idx: int, data_dir: str, dataset: str +): + media_dir = os.path.join(data_dir, "media") + + # Load the annotation and video data + annot_df = dataset_df["annotations"] + video_df = dataset_df["videos"] + + # Get the frame + video_df_current = video_df[video_df.id == video_id] + assert ( + len(video_df_current) == 1 + ), f"Expected 1 video row, got {len(video_df_current)}" + video_row = video_df_current.iloc[0] + file_name = video_row.file_names[frame_idx] + file_path = os.path.join( + get_media_dir(media_dir=media_dir, dataset=dataset), file_name + ) + frame = cv2.imread(file_path) + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Get the masks and noun phrases annotated in this video in this frame + annot_df_current_video = annot_df[annot_df.video_id == video_id] + if len(annot_df_current_video) == 0: + print(f"No annotations found for video_id {video_id}") + return frame, None, None + else: + empty_mask = np.zeros(frame.shape[:2], dtype=np.uint8) + mask_np_pairs = annot_df_current_video.apply( + lambda row: ( + ( + mask_utils.decode(row.segmentations[frame_idx]) + if row.segmentations[frame_idx] + else empty_mask + ), + row.noun_phrase, + ), + axis=1, + ) + # sort based on noun_phrases + mask_np_pairs = sorted(mask_np_pairs, key=lambda x: x[1]) + masks, noun_phrases = zip(*mask_np_pairs) + + return frame, masks, noun_phrases + + +def visualize_prompt_overlay( + frame_idx, + video_frames, + title="Prompt Visualization", + text_prompt=None, + point_prompts=None, + point_labels=None, + bounding_boxes=None, + box_labels=None, + obj_id=None, +): + """Simple prompt visualization function""" + img = Image.fromarray(load_frame(video_frames[frame_idx])) + fig, ax = plt.subplots(1, figsize=(6, 4)) + ax.imshow(img) + + img_w, img_h = img.size + + if text_prompt: + ax.text( + 0.02, + 0.98, + f'Text: "{text_prompt}"', + transform=ax.transAxes, + fontsize=12, + color="white", + weight="bold", + bbox=dict(boxstyle="round,pad=0.3", facecolor="red", alpha=0.7), + verticalalignment="top", + ) + + if point_prompts: + for i, point in enumerate(point_prompts): + x, y = point + # Convert relative to absolute coordinates + x_img, y_img = x * img_w, y * img_h + + # Use different colors for positive/negative points + if point_labels and len(point_labels) > i: + color = "green" if point_labels[i] == 1 else "red" + marker = "o" if point_labels[i] == 1 else "x" + else: + color = "green" + marker = "o" + + ax.plot( + x_img, + y_img, + marker=marker, + color=color, + markersize=10, + markeredgewidth=2, + markeredgecolor="white", + ) + ax.text( + x_img + 5, + y_img - 5, + f"P{i + 1}", + color=color, + fontsize=10, + weight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.8), + ) + + if bounding_boxes: + for i, box in enumerate(bounding_boxes): + x, y, w, h = box + # Convert relative to absolute coordinates + x_img, y_img = x * img_w, y * img_h + w_img, h_img = w * img_w, h * img_h + + # Use different colors for positive/negative boxes + if box_labels and len(box_labels) > i: + color = "green" if box_labels[i] == 1 else "red" + else: + color = "green" + + rect = patches.Rectangle( + (x_img, y_img), + w_img, + h_img, + linewidth=2, + edgecolor=color, + facecolor="none", + ) + ax.add_patch(rect) + ax.text( + x_img, + y_img - 5, + f"B{i + 1}", + color=color, + fontsize=10, + weight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.8), + ) + + # Add object ID info if provided + if obj_id is not None: + ax.text( + 0.02, + 0.02, + f"Object ID: {obj_id}", + transform=ax.transAxes, + fontsize=10, + color="white", + weight="bold", + bbox=dict(boxstyle="round,pad=0.3", facecolor="blue", alpha=0.7), + verticalalignment="bottom", + ) + + ax.set_title(title) + ax.axis("off") + plt.tight_layout() + plt.show() + + +def plot_results(img, results): + plt.figure(figsize=(12, 8)) + plt.imshow(img) + nb_objects = len(results["scores"]) + print(f"found {nb_objects} object(s)") + for i in range(nb_objects): + color = COLORS[i % len(COLORS)] + plot_mask(results["masks"][i].squeeze(0).cpu(), color=color) + w, h = img.size + prob = results["scores"][i].item() + plot_bbox( + h, + w, + results["boxes"][i].cpu(), + text=f"(id={i}, {prob=:.2f})", + box_format="XYXY", + color=color, + relative_coords=False, + ) + + +def single_visualization(img, anns, title): + """ + Create a single image visualization with overlays. + """ + fig, ax = plt.subplots(figsize=(7, 7)) + fig.suptitle(title, fontsize=16, fontweight="bold") + overlay = render_masklet_frame(img, anns, alpha=0.5) + ax.imshow(overlay) + ax.axis("off") + plt.tight_layout() + + +def show_mask(mask, ax, obj_id=None, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + cmap = plt.get_cmap("tab10") + cmap_idx = 0 if obj_id is None else obj_id + color = np.array([*cmap(cmap_idx)[:3], 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch( + plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2) + ) + + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels == 1] + neg_points = coords[labels == 0] + ax.scatter( + pos_points[:, 0], + pos_points[:, 1], + color="green", + marker="*", + s=marker_size, + edgecolor="white", + linewidth=1.25, + ) + ax.scatter( + neg_points[:, 0], + neg_points[:, 1], + color="red", + marker="*", + s=marker_size, + edgecolor="white", + linewidth=1.25, + ) + + +def load_frame(frame): + if isinstance(frame, np.ndarray): + img = frame + elif isinstance(frame, Image.Image): + img = np.array(frame) + elif isinstance(frame, str) and os.path.isfile(frame): + img = plt.imread(frame) + else: + raise ValueError(f"Invalid video frame type: {type(frame)=}") + return img diff --git a/vis3d_glb.py b/vis3d_glb.py new file mode 100644 index 0000000000000000000000000000000000000000..ec435b5d22119a194c476634221d81a0b3bcac6e --- /dev/null +++ b/vis3d_glb.py @@ -0,0 +1,648 @@ +"""Generate GLB scenes with colored point clouds / textured meshes and 3D boxes. + +Uses pygltflib for point cloud + wireframe boxes (GL_POINTS/GL_LINES), +and trimesh + utils3d for textured mesh generation. + +Usage: + from vis3d_glb import depth_to_pointcloud, create_scene_glb + from vis3d_glb import create_mesh_scene_glb + + # Point cloud mode + points, colors = depth_to_pointcloud(depth_map, image, intrinsics) + create_scene_glb(points, colors, boxes3d_list, output_path) + + # Textured mesh mode (like MoGe2) + create_mesh_scene_glb(depth_map, image, intrinsics, boxes3d_list, output_path) +""" + +import numpy as np +import pygltflib + + +def depth_to_pointcloud( + depth_map: np.ndarray, + image: np.ndarray, + intrinsics: np.ndarray, + max_depth: float = 20.0, + subsample: int = 4, + padding: tuple[int, int, int, int] | None = None, + remove_edge: bool = True, + edge_rtol: float = 0.04, + confidence_map: np.ndarray | None = None, + confidence_threshold: float = 0.0, +) -> tuple[np.ndarray, np.ndarray]: + """Convert depth map + RGB image to colored point cloud. + + Args: + depth_map: (H, W) or (1, H, W) depth in meters. + image: (H, W, 3) RGB image, uint8 [0-255]. + intrinsics: (3, 3) camera intrinsics matrix. + max_depth: Discard points beyond this depth. + subsample: Take every Nth pixel to reduce point count. + padding: (left, right, top, bottom) CenterPad offsets to exclude. + remove_edge: Remove points at depth discontinuity edges + (like MoGe2). Uses utils3d.np.depth_map_edge. + edge_rtol: Relative tolerance for edge detection. Larger + values remove more aggressive edges. + confidence_map: (H, W) or (1, H, W) per-pixel confidence in + [0, 1]. Points below confidence_threshold are discarded. + confidence_threshold: Minimum confidence to keep a point. + + Returns: + points: (N, 3) float32 xyz in camera frame. + colors: (N, 4) uint8 RGBA. + """ + # Handle various depth_map shapes + while depth_map.ndim > 2: + depth_map = depth_map.squeeze(0) # (1, 1, H, W) -> (H, W) + + H, W = depth_map.shape + + # Handle confidence_map shape + if confidence_map is not None: + while confidence_map.ndim > 2: + confidence_map = confidence_map.squeeze(0) + + # Handle various image shapes: (1, H, W, 3), (1, 1, H, W) etc + while image.ndim > 3: + image = image.squeeze(0) + # If image is (3, H, W), transpose to (H, W, 3) + if image.ndim == 3 and image.shape[0] in (1, 3): + image = np.transpose(image, (1, 2, 0)) + # If grayscale (H, W), repeat to (H, W, 3) + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + + # Match image size to depth map + if image.shape[0] != H or image.shape[1] != W: + from PIL import Image as PILImage + img_pil = PILImage.fromarray(image) + img_pil = img_pil.resize((W, H), PILImage.BILINEAR) + image = np.array(img_pil) + + # Build full-resolution valid mask before subsampling + full_valid = (depth_map > 0.01) & (depth_map < max_depth) & np.isfinite(depth_map) + + # Exclude padding regions (full resolution) + if padding is not None: + pad_left, pad_right, pad_top, pad_bottom = padding + pad_mask = np.ones((H, W), dtype=bool) + pad_mask[:, :pad_left] = False + pad_mask[:pad_top, :] = False + if pad_right > 0: + pad_mask[:, W - pad_right:] = False + if pad_bottom > 0: + pad_mask[H - pad_bottom:, :] = False + full_valid &= pad_mask + + # Remove depth discontinuity edges (MoGe2 style) + if remove_edge: + import utils3d + edge_mask = utils3d.np.depth_map_edge(depth_map, rtol=edge_rtol) + full_valid &= ~edge_mask + + # Filter by confidence + if confidence_map is not None and confidence_threshold > 0: + full_valid &= (confidence_map >= confidence_threshold) + + # Subsample grid + ys = np.arange(0, H, subsample) + xs = np.arange(0, W, subsample) + xx, yy = np.meshgrid(xs, ys) + + depth_sub = depth_map[yy, xx] + rgb_sub = image[yy, xx] # (h, w, 3) + valid = full_valid[yy, xx] + + # Unproject to 3D + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + + x3d = (xx[valid] - cx) * depth_sub[valid] / fx + y3d = (yy[valid] - cy) * depth_sub[valid] / fy + z3d = depth_sub[valid] + + # OpenCV (x-right, y-down, z-away) to glTF (x, -y, -z) + points = np.stack([x3d, -y3d, -z3d], axis=-1).astype(np.float32) + + # Colors + rgb = rgb_sub[valid] # (N, 3) uint8 + alpha = np.full((rgb.shape[0], 1), 255, dtype=np.uint8) + colors = np.concatenate([rgb, alpha], axis=-1) # (N, 4) + + return points, colors + + +def _quaternion_to_rotation_matrix(qw, qx, qy, qz): + """Convert quaternion to 3x3 rotation matrix.""" + return np.array([ + [1 - 2*(qy*qy + qz*qz), 2*(qx*qy - qz*qw), 2*(qx*qz + qy*qw)], + [2*(qx*qy + qz*qw), 1 - 2*(qx*qx + qz*qz), 2*(qy*qz - qx*qw)], + [2*(qx*qz - qy*qw), 2*(qy*qz + qx*qw), 1 - 2*(qx*qx + qy*qy)], + ], dtype=np.float32) + + +def boxes3d_to_corners(boxes3d: np.ndarray) -> list[np.ndarray]: + """Convert 3D box params to 8 corner points in GLB coords. + + Args: + boxes3d: (N, 10) boxes in OpenCV camera frame. + Format: [cx, cy, cz, w, h, l, qw, qx, qy, qz] + + Returns: + List of (8, 3) corner arrays in GLB/Three.js coords (y-up, z-backward). + """ + corners_list = [] + # Same transform as point cloud: + # OpenCV (x,y,z) -> glTF (x, -y, -z) + T = np.diag([1.0, -1.0, -1.0]).astype(np.float32) + + for box in boxes3d: + cx, cy, cz = box[0], box[1], box[2] + # Omni3D format: [width, length, height] not [w, h, l] + # width = x-extent, length = z-extent, height = y-extent + bw, bl, bh = box[3], box[4], box[5] + qw, qx, qy, qz = box[6], box[7], box[8], box[9] + + hw, hl, hh = bw / 2, bl / 2, bh / 2 + + # 8 local corners: x=length, y=height, z=width + local_corners = np.array([ + [-hl, -hh, -hw], + [ hl, -hh, -hw], + [ hl, hh, -hw], + [-hl, hh, -hw], + [-hl, -hh, hw], + [ hl, -hh, hw], + [ hl, hh, hw], + [-hl, hh, hw], + ], dtype=np.float32) + + # Rotate by quaternion and translate (in OpenCV coords) + R_cv = _quaternion_to_rotation_matrix(qw, qx, qy, qz) + corners_cv = (R_cv @ local_corners.T).T + np.array([cx, cy, cz]) + + # Convert OpenCV -> glTF: (-z, -y, x) + corners = (T @ corners_cv.T).T + + corners_list.append(corners.astype(np.float32)) + + return corners_list + + +def _generate_box_colors(n_boxes: int) -> list[list[int]]: + """Generate distinct colors for boxes.""" + base_colors = [ + [255, 0, 0, 255], # red + [0, 255, 0, 255], # green + [0, 100, 255, 255], # blue + [255, 255, 0, 255], # yellow + [255, 0, 255, 255], # magenta + [0, 255, 255, 255], # cyan + [255, 128, 0, 255], # orange + [128, 0, 255, 255], # purple + ] + colors = [] + for i in range(n_boxes): + colors.append(base_colors[i % len(base_colors)]) + return colors + + +def _pad_to_4(data: bytes) -> bytes: + """Pad binary data to 4-byte alignment (glTF requirement).""" + remainder = len(data) % 4 + if remainder: + data += b"\x00" * (4 - remainder) + return data + + +def create_scene_glb( + points: np.ndarray, + point_colors: np.ndarray, + boxes3d_list: list[np.ndarray], + output_path: str, + max_points: int = 500000, +) -> str: + """Create a GLB file with colored point cloud + wireframe 3D boxes. + + Args: + points: (N, 3) float32 point cloud xyz. + point_colors: (N, 4) uint8 RGBA colors. + boxes3d_list: List of (M, 10) box arrays (one per image). + output_path: Where to save the .glb file. + max_points: Max number of points to include. + + Returns: + output_path. + """ + # Subsample points if too many + if len(points) > max_points: + idx = np.random.choice(len(points), max_points, replace=False) + points = points[idx] + point_colors = point_colors[idx] + + points = np.ascontiguousarray(points, dtype=np.float32) + point_colors = np.ascontiguousarray(point_colors, dtype=np.uint8) + n_points = len(points) + + # Build box geometry + all_corners_list = [] + for boxes3d in boxes3d_list: + if len(boxes3d) > 0: + corners = boxes3d_to_corners(boxes3d) + all_corners_list.extend(corners) + + n_boxes = len(all_corners_list) + box_colors_rgba = _generate_box_colors(n_boxes) + + # Box vertices and indices + all_box_verts = [] + all_box_colors = [] + all_box_indices = [] + vertex_offset = 0 + + edge_pairs = [ + (0, 1), (1, 2), (2, 3), (3, 0), # bottom face + (4, 5), (5, 6), (6, 7), (7, 4), # top face + (0, 4), (1, 5), (2, 6), (3, 7), # vertical edges + ] + + for i, corners in enumerate(all_corners_list): + all_box_verts.append(corners) + color = box_colors_rgba[i] + all_box_colors.append( + np.tile(np.array(color, dtype=np.uint8), (8, 1)) + ) + indices = np.array( + [(a + vertex_offset, b + vertex_offset) for a, b in edge_pairs], + dtype=np.uint16, + ) + all_box_indices.append(indices) + vertex_offset += 8 + + has_boxes = n_boxes > 0 + + if has_boxes: + box_verts = np.concatenate(all_box_verts, axis=0).astype(np.float32) + box_vert_colors = np.concatenate(all_box_colors, axis=0).astype(np.uint8) + box_indices = np.concatenate(all_box_indices, axis=0).flatten().astype(np.uint16) + else: + box_verts = np.zeros((0, 3), dtype=np.float32) + box_vert_colors = np.zeros((0, 4), dtype=np.uint8) + box_indices = np.zeros(0, dtype=np.uint16) + + # Build binary blob + points_bin = _pad_to_4(points.tobytes()) + colors_bin = _pad_to_4(point_colors.tobytes()) + box_verts_bin = _pad_to_4(box_verts.tobytes()) + box_colors_bin = _pad_to_4(box_vert_colors.tobytes()) + box_indices_bin = _pad_to_4(box_indices.tobytes()) + + blob = points_bin + colors_bin + box_verts_bin + box_colors_bin + box_indices_bin + + # Build glTF structure + buffer_views = [] + accessors = [] + offset = 0 + + # BV0: point positions + buffer_views.append(pygltflib.BufferView( + buffer=0, byteOffset=offset, byteLength=len(points_bin), + target=pygltflib.ARRAY_BUFFER, + )) + accessors.append(pygltflib.Accessor( + bufferView=0, componentType=pygltflib.FLOAT, + count=n_points, type=pygltflib.VEC3, + max=points.max(axis=0).tolist() if n_points > 0 else [0, 0, 0], + min=points.min(axis=0).tolist() if n_points > 0 else [0, 0, 0], + )) + offset += len(points_bin) + + # BV1: point colors + buffer_views.append(pygltflib.BufferView( + buffer=0, byteOffset=offset, byteLength=len(colors_bin), + target=pygltflib.ARRAY_BUFFER, + )) + accessors.append(pygltflib.Accessor( + bufferView=1, componentType=pygltflib.UNSIGNED_BYTE, + count=n_points, type=pygltflib.VEC4, + normalized=True, + )) + offset += len(colors_bin) + + nodes = [] + meshes = [] + + # Point cloud mesh (GL_POINTS = mode 0) + meshes.append(pygltflib.Mesh( + primitives=[pygltflib.Primitive( + attributes=pygltflib.Attributes(POSITION=0, COLOR_0=1), + mode=0, + )] + )) + nodes.append(pygltflib.Node(mesh=0)) + + if has_boxes: + # BV2: box vertices + buffer_views.append(pygltflib.BufferView( + buffer=0, byteOffset=offset, byteLength=len(box_verts_bin), + target=pygltflib.ARRAY_BUFFER, + )) + accessors.append(pygltflib.Accessor( + bufferView=2, componentType=pygltflib.FLOAT, + count=len(box_verts), type=pygltflib.VEC3, + max=box_verts.max(axis=0).tolist(), + min=box_verts.min(axis=0).tolist(), + )) + offset += len(box_verts_bin) + + # BV3: box colors + buffer_views.append(pygltflib.BufferView( + buffer=0, byteOffset=offset, byteLength=len(box_colors_bin), + target=pygltflib.ARRAY_BUFFER, + )) + accessors.append(pygltflib.Accessor( + bufferView=3, componentType=pygltflib.UNSIGNED_BYTE, + count=len(box_vert_colors), type=pygltflib.VEC4, + normalized=True, + )) + offset += len(box_colors_bin) + + # BV4: box indices + buffer_views.append(pygltflib.BufferView( + buffer=0, byteOffset=offset, byteLength=len(box_indices_bin), + target=pygltflib.ELEMENT_ARRAY_BUFFER, + )) + accessors.append(pygltflib.Accessor( + bufferView=4, componentType=pygltflib.UNSIGNED_SHORT, + count=len(box_indices), type=pygltflib.SCALAR, + max=[int(box_indices.max())], + min=[int(box_indices.min())], + )) + offset += len(box_indices_bin) + + # Box wireframe mesh (GL_LINES = mode 1) + meshes.append(pygltflib.Mesh( + primitives=[pygltflib.Primitive( + attributes=pygltflib.Attributes(POSITION=2, COLOR_0=3), + indices=4, + mode=1, + )] + )) + nodes.append(pygltflib.Node(mesh=1)) + + gltf = pygltflib.GLTF2( + scene=0, + scenes=[pygltflib.Scene(nodes=list(range(len(nodes))))], + nodes=nodes, + meshes=meshes, + accessors=accessors, + bufferViews=buffer_views, + buffers=[pygltflib.Buffer(byteLength=len(blob))], + ) + gltf.set_binary_blob(blob) + gltf.save(output_path) + + return output_path + + +def _create_edge_cylinder(p1, p2, radius=0.01, sections=6): + """Create a thin cylinder mesh between two 3D points. + + Args: + p1, p2: (3,) endpoints. + radius: cylinder radius. + sections: number of radial segments. + + Returns: + trimesh.Trimesh or None if edge is degenerate. + """ + import trimesh + + segment = p2 - p1 + length = float(np.linalg.norm(segment)) + if length < 1e-6: + return None + + cyl = trimesh.creation.cylinder( + radius=radius, height=length, sections=sections + ) + direction = segment / length + + # Align cylinder Z-axis to segment direction + z_axis = np.array([0, 0, 1], dtype=np.float64) + cross = np.cross(z_axis, direction) + dot = np.dot(z_axis, direction) + cross_len = np.linalg.norm(cross) + + if cross_len < 1e-6: + R = np.eye(3) if dot > 0 else np.diag([1.0, -1.0, -1.0]) + else: + cross_n = cross / cross_len + angle = np.arccos(np.clip(dot, -1, 1)) + K = np.array([ + [0, -cross_n[2], cross_n[1]], + [cross_n[2], 0, -cross_n[0]], + [-cross_n[1], cross_n[0], 0], + ]) + R = np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K) + + T = np.eye(4) + T[:3, :3] = R + T[:3, 3] = (p1 + p2) / 2.0 + cyl.apply_transform(T) + return cyl + + +def _create_wireframe_box_trimesh(corners, color_rgba, radius=0.015): + """Create wireframe box as thin cylinders. + + Args: + corners: (8, 3) corner positions in glTF coords. + color_rgba: [R, G, B, A] uint8 color. + radius: cylinder radius in meters. + + Returns: + trimesh.Trimesh or None. + """ + import trimesh + + edge_pairs = [ + (0, 1), (1, 2), (2, 3), (3, 0), + (4, 5), (5, 6), (6, 7), (7, 4), + (0, 4), (1, 5), (2, 6), (3, 7), + ] + parts = [] + for a, b in edge_pairs: + cyl = _create_edge_cylinder( + corners[a].astype(np.float64), + corners[b].astype(np.float64), + radius=radius, + sections=6, + ) + if cyl is not None: + cyl.visual.face_colors = color_rgba + parts.append(cyl) + + if parts: + return trimesh.util.concatenate(parts) + return None + + +def create_mesh_scene_glb( + depth_map: np.ndarray, + image: np.ndarray, + intrinsics: np.ndarray, + boxes3d_list: list[np.ndarray], + output_path: str, + max_depth: float = 20.0, + padding: tuple[int, int, int, int] | None = None, + remove_edge: bool = True, + edge_rtol: float = 0.04, +) -> str: + """Create GLB with textured mesh (MoGe2 style) + wireframe 3D boxes. + + Args: + depth_map: (H, W) or (1, H, W) depth in meters. + image: (H, W, 3) RGB uint8 [0-255]. + intrinsics: (3, 3) camera intrinsics. + boxes3d_list: List of (M, 10) box arrays. + output_path: Where to save .glb. + max_depth: Max depth cutoff. + padding: (left, right, top, bottom) to exclude. + remove_edge: Remove depth discontinuity edges. + edge_rtol: Edge detection tolerance. + + Returns: + output_path. + """ + import utils3d + import trimesh + from PIL import Image as PILImage + + # Prepare depth + while depth_map.ndim > 2: + depth_map = depth_map.squeeze(0) + depth_map = depth_map.astype(np.float32) + H, W = depth_map.shape + + # Prepare image + while image.ndim > 3: + image = image.squeeze(0) + if image.ndim == 3 and image.shape[0] in (1, 3): + image = np.transpose(image, (1, 2, 0)) + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + if image.shape[0] != H or image.shape[1] != W: + img_pil = PILImage.fromarray(image) + img_pil = img_pil.resize((W, H), PILImage.BILINEAR) + image = np.array(img_pil) + + # Build valid mask + valid = ( + (depth_map > 0.01) + & (depth_map < max_depth) + & np.isfinite(depth_map) + ) + + if padding is not None: + pad_left, pad_right, pad_top, pad_bottom = padding + if pad_left > 0: + valid[:, :pad_left] = False + if pad_right > 0: + valid[:, W - pad_right:] = False + if pad_top > 0: + valid[:pad_top, :] = False + if pad_bottom > 0: + valid[H - pad_bottom:, :] = False + + if remove_edge: + edge = utils3d.np.depth_map_edge(depth_map, rtol=edge_rtol) + valid &= ~edge + + # Unproject to 3D in OpenCV coords (x-right, y-down, z-forward) + # Build mesh in OpenCV space first so triangle winding is correct, + # then transform vertices to glTF coords afterwards. + fx, fy = float(intrinsics[0, 0]), float(intrinsics[1, 1]) + cx, cy = float(intrinsics[0, 2]), float(intrinsics[1, 2]) + u, v = np.meshgrid(np.arange(W), np.arange(H)) + x3d = (u - cx) * depth_map / fx + y3d = (v - cy) * depth_map / fy + z3d = depth_map + points_cv = np.stack([x3d, y3d, z3d], axis=-1).astype(np.float32) + + # UV map + uv = np.stack( + [u / max(W - 1, 1), v / max(H - 1, 1)], axis=-1 + ).astype(np.float32) + + # Colors normalized [0, 1] + colors = image.astype(np.float32) / 255.0 + + # Build triangulated mesh in OpenCV coords (preserves correct winding) + faces, vertices, vertex_colors, vertex_uvs = ( + utils3d.np.build_mesh_from_map( + points_cv, colors, uv, mask=valid, tri=True + ) + ) + + print( + f"[Mesh] {vertices.shape[0]} vertices, " + f"{faces.shape[0]} faces, " + f"valid pixels: {valid.sum()}/{valid.size}" + ) + + if len(vertices) == 0: + # Fallback to empty file + scene = trimesh.Scene() + scene.export(output_path) + return output_path + + # Transform vertices: OpenCV (x, y, z) -> glTF (x, -y, -z) + # This is a 180-degree rotation around x-axis (det=+1), + # so it preserves triangle winding order. + vertices = vertices * np.array([1.0, -1.0, -1.0], dtype=np.float32) + + # Trimesh flips UV v when exporting to GLB (OpenGL v=0 at bottom + # vs glTF v=0 at top). Our UVs are already in image convention + # (v=0 at top), so pre-flip to compensate for trimesh's flip. + vertex_uvs = vertex_uvs.copy() + vertex_uvs[:, 1] = 1.0 - vertex_uvs[:, 1] + + # Create textured mesh (process=False to avoid trimesh modifying geometry) + texture_img = PILImage.fromarray(image) + material = trimesh.visual.material.PBRMaterial( + baseColorTexture=texture_img, + metallicFactor=0.0, + roughnessFactor=1.0, + ) + visuals = trimesh.visual.TextureVisuals( + uv=vertex_uvs, material=material + ) + mesh = trimesh.Trimesh( + vertices=vertices, faces=faces, visual=visuals, + process=False, + ) + + scene = trimesh.Scene() + scene.add_geometry(mesh, node_name="scene_mesh") + + # Add wireframe 3D boxes as thin cylinder geometry + all_corners = [] + for boxes3d in boxes3d_list: + if len(boxes3d) > 0: + corners = boxes3d_to_corners(boxes3d) + all_corners.extend(corners) + + box_colors = _generate_box_colors(len(all_corners)) + for i, corners in enumerate(all_corners): + box_mesh = _create_wireframe_box_trimesh( + corners, box_colors[i], radius=0.015 + ) + if box_mesh is not None: + scene.add_geometry( + box_mesh, node_name=f"box_{i}" + ) + + scene.export(output_path) + return output_path diff --git a/vis4d/__init__.py b/vis4d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0269c2bb34659444b7fff6be42c8a1717b79c53 --- /dev/null +++ b/vis4d/__init__.py @@ -0,0 +1,20 @@ +"""Vis4D is a modular library for 4D scene understanding. + +It contains common operators and models, data pipelines and training recipes +for a number of contemporary methods and provides a compositional framework +for further research and development of 4D Vision algorithms. +""" + +import logging + +__version__ = "1.0.0" + +_root_logger = logging.getLogger() +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +# if root logger has handlers, propagate messages up and let root logger +# process them +if not _root_logger.hasHandlers(): # pragma: no cover + _logger.addHandler(logging.StreamHandler()) + _logger.propagate = False diff --git a/vis4d/__main__.py b/vis4d/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad87305020df07bbbd4e0ea70c693390dd282a3 --- /dev/null +++ b/vis4d/__main__.py @@ -0,0 +1,5 @@ +"""Entry point for the vis4d package.""" + +from .engine.run import entrypoint + +entrypoint() diff --git a/vis4d/__pycache__/__init__.cpython-311.pyc b/vis4d/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..946a1c7206cab7b576ac4a8c9e30c91a10b5c530 Binary files /dev/null and b/vis4d/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/common/__init__.py b/vis4d/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07af597215587a7a60addb5907a9ce08f174885d --- /dev/null +++ b/vis4d/common/__init__.py @@ -0,0 +1 @@ +"""Contains common functions and types that are used across modules.""" diff --git a/vis4d/common/__pycache__/__init__.cpython-311.pyc b/vis4d/common/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f28687211c2a81583431e68784f0662f39256efd Binary files /dev/null and b/vis4d/common/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/common/__pycache__/array.cpython-311.pyc b/vis4d/common/__pycache__/array.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f84a2e6e477b3e4512c9e0feb659db35f5319b8d Binary files /dev/null and b/vis4d/common/__pycache__/array.cpython-311.pyc differ diff --git a/vis4d/common/__pycache__/dict.cpython-311.pyc b/vis4d/common/__pycache__/dict.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1ec25ba8b18ab08c85164a29b966f9115d73705 Binary files /dev/null and b/vis4d/common/__pycache__/dict.cpython-311.pyc differ diff --git a/vis4d/common/__pycache__/distributed.cpython-311.pyc b/vis4d/common/__pycache__/distributed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e916296fd9d05f9d11b8bb6939b0029bd50ef37c Binary files /dev/null and b/vis4d/common/__pycache__/distributed.cpython-311.pyc differ diff --git a/vis4d/common/__pycache__/imports.cpython-311.pyc b/vis4d/common/__pycache__/imports.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c702875830f751cf9308ad3a32206cc08beed4f Binary files /dev/null and b/vis4d/common/__pycache__/imports.cpython-311.pyc differ diff --git a/vis4d/common/__pycache__/logging.cpython-311.pyc b/vis4d/common/__pycache__/logging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df48a7daaff7e26064cdc3e3de82da9dfc6a1a70 Binary files /dev/null and b/vis4d/common/__pycache__/logging.cpython-311.pyc differ diff --git a/vis4d/common/__pycache__/named_tuple.cpython-311.pyc b/vis4d/common/__pycache__/named_tuple.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6eec1af8646d32c50d49fa757643ea305ee1f32 Binary files /dev/null and b/vis4d/common/__pycache__/named_tuple.cpython-311.pyc differ diff --git a/vis4d/common/__pycache__/typing.cpython-311.pyc b/vis4d/common/__pycache__/typing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b1b73cee98e71571d11a91a75d6495e16784a98 Binary files /dev/null and b/vis4d/common/__pycache__/typing.cpython-311.pyc differ diff --git a/vis4d/common/array.py b/vis4d/common/array.py new file mode 100644 index 0000000000000000000000000000000000000000..c520155eca15485aba1688f389f45e9ae5ba10f0 --- /dev/null +++ b/vis4d/common/array.py @@ -0,0 +1,166 @@ +"""This module contains array utility functions.""" + +from __future__ import annotations + +from typing import overload + +import numpy as np +import torch + +from vis4d.common.typing import ( + ArrayLike, + NDArrayBool, + NDArrayF32, + NDArrayF64, + NDArrayI32, + NDArrayI64, + NDArrayNumber, + NDArrayUI8, + NDArrayUI16, + NDArrayUI32, +) + + +# Bool dtypes +@overload +def array_to_numpy( + data: ArrayLike, n_dims: int | None, dtype: type[np.bool_] +) -> NDArrayBool: ... + + +# Float dtypes +@overload +def array_to_numpy( + data: ArrayLike | None, n_dims: int | None, dtype: type[np.float32] +) -> NDArrayF32: ... + + +@overload +def array_to_numpy( + data: ArrayLike | None, n_dims: int | None, dtype: type[np.float64] +) -> NDArrayF64: ... + + +# Int dtypes +@overload +def array_to_numpy( + data: ArrayLike | None, n_dims: int | None, dtype: type[np.int32] +) -> NDArrayI32: ... + + +@overload +def array_to_numpy( + data: ArrayLike | None, n_dims: int | None, dtype: type[np.int64] +) -> NDArrayI64: ... + + +# UInt dtypes +@overload +def array_to_numpy( + data: ArrayLike | None, n_dims: int | None, dtype: type[np.uint8] +) -> NDArrayUI8: ... + + +@overload +def array_to_numpy( + data: ArrayLike | None, n_dims: int | None, dtype: type[np.uint16] +) -> NDArrayUI16: ... + + +@overload +def array_to_numpy( + data: ArrayLike | None, n_dims: int | None, dtype: type[np.uint32] +) -> NDArrayUI32: ... + + +# Union of all dtypes +@overload +def array_to_numpy( + data: ArrayLike | None, n_dims: int | None +) -> NDArrayNumber: ... + + +@overload +def array_to_numpy(data: None) -> None: ... + + +def array_to_numpy( + data: ArrayLike | None, + n_dims: int | None = None, + dtype: ( + type[np.bool_] + | type[np.float32] + | type[np.float64] + | type[np.int32] + | type[np.int64] + | type[np.uint8] + | type[np.uint16] + | type[np.uint32] + ) = np.float32, +) -> NDArrayNumber | None: + """Converts a given array like object to a numpy array. + + Helper function to convert an array like object to a numpy array. + This functions converts torch.Tensors or Sequences to numpy arrays. + + If the argument is None, None will be returned. + + Examples: + >>> convert_to_array([1,2,3]) + >>> # -> array([1,2,3]) + >>> convert_to_array(None) + >>> # -> None + >>> convert_to_array(torch.tensor([1,2,3]).cuda()) + >>> # -> array([1,2,3]) + >>> convert_to_array([1,2,3], n_dims = 2).shape + >>> # -> [1, 3] + + Args: + data (ArrayLike | None): ArrayLike object that should be converted + to numpy. + + n_dims (int | None, optional): Target number of dimension of the array. + If the provided array does not have this shape, it will be + squeezed or exanded (from the left). If it still does not match, + an error is raised. + + dtype (SUPPORTED_DTYPES, optional): Target dtype of the array. Defaults + to np.float32. + + Raises: + ValueError: If the provied array like objects can not be converted + with the target dimensions. + + Returns: + NDArrayNumber | None: The converted numpy array or None if None was + provided. + """ + if data is None: + return data + + if isinstance(data, np.ndarray): + array = data + elif isinstance(data, torch.Tensor): + array = np.asarray(data.detach().cpu().numpy()) + else: + array = np.asarray(data) + + if n_dims is not None: + # Squeeze if needed + for _ in range(len(array.shape) - n_dims): + if array.shape[0] == 1: + array = array.squeeze(0) + elif array.shape[-1] == 1: + array = array.squeeze(-1) + + # expand if needed + for _ in range(n_dims - len(array.shape)): + array = np.expand_dims(array, 0) + + if len(array.shape) != n_dims: + raise ValueError( + f"Failed to convert target array of shape {array.shape} to" + f"have {n_dims} dimensions." + ) + + return array.astype(dtype) # type: ignore diff --git a/vis4d/common/ckpt.py b/vis4d/common/ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..6c610ac5369a63cc306cf36e2095baace191a75a --- /dev/null +++ b/vis4d/common/ckpt.py @@ -0,0 +1,435 @@ +"""This module contains convenience functions for checkpoint loading. + +The code is based on https://github.com/open-mmlab/mmcv/ +""" + +from __future__ import annotations + +import os.path as osp +import re +from collections import OrderedDict +from typing import Callable, Union + +import torch +import torchvision +from torch import nn +from torch.hub import load_state_dict_from_url as load_url + +from vis4d.common.distributed import ( + get_rank, + get_world_size, + is_module_wrapper, + synchronize, +) +from vis4d.common.logging import rank_zero_info, rank_zero_warn +from vis4d.common.typing import TorchCheckpoint + +CheckpointLoadFunc = Callable[ + [str, Union[str, torch.device, None]], TorchCheckpoint +] + +# Define mapping for specific model checkpoints +BDD100K_MODEL_PREFIX = "https://dl.cv.ethz.ch/bdd100k/" +MM_MODEL_MAP = { + "mmdet://": "https://download.openmmlab.com/mmdetection/v2.0/", + "mmseg://": "https://download.openmmlab.com/mmsegmentation/v0.5/", +} +MM_CFG_MAP = { + "mmdet://": "syscv/mmdetection/master/configs/", + "mmseg://": "open-mmlab/mmsegmentation/master/configs/", +} +MM_ZIP_MAP = { + "mmdet://": "mmdetection-master/configs/", + "mmseg://": "mmsegmentation-master/configs/", +} + + +def load_model_checkpoint( + model: nn.Module, + weights: str, + strict: bool = False, + rev_keys: None | list[tuple[str, str]] = None, + map_location: str | torch.device | None = "cpu", +) -> None: + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + weights (str): Accept local filepath, URL, or e.g.``torchvision://xxx`` + strict (bool): Whether to allow different params for the model and + checkpoint. + rev_keys (tuple[tuple[str, str]]): A tuple of customized keywords to + modify the state_dict in checkpoint. Each item is a + (pattern, replacement) pair of the regular expression operations. + Default: strip the prefix 'module.' by [(r'^module.', '')]. + map_location (str | torch.device | None): Same as :func:`torch.load`. + Default: 'cpu'. + """ + if rev_keys is None: # pragma: no cover + rev_keys = [(r"^module\.", "")] + if re.compile(r"^mm(det|seg)://").search(weights): + pre = weights[:8] + weights = MM_MODEL_MAP[pre] + weights.split(pre)[-1] + _load_checkpoint( + model, weights, map_location, strict=strict, revise_keys=rev_keys + ) + elif weights.startswith("bdd100k://"): + weights = BDD100K_MODEL_PREFIX + weights.split("bdd100k://")[-1] + _load_checkpoint( + model, weights, map_location, strict=strict, revise_keys=rev_keys + ) + else: # pragma: no cover + _load_checkpoint( + model, weights, map_location, strict=strict, revise_keys=rev_keys + ) + + +class CheckpointLoader: + """A general checkpoint loader to manage all schemes.""" + + _schemes: dict[str, CheckpointLoadFunc] = {} + + @classmethod + def _register_scheme( + cls, + prefixes: str | tuple[str, ...], + loader: CheckpointLoadFunc, + force: bool = False, + ) -> None: + """Register a scheme.""" + if isinstance(prefixes, str): + prefixes = (prefixes,) + + assert isinstance(prefixes, (list, tuple)) + for prefix in prefixes: + if (prefix not in cls._schemes) or force: + cls._schemes[prefix] = loader + else: + raise KeyError( + f"{prefix} is already registered as a loader backend, " + 'add "force=True" if you want to override it' + ) + # sort, longer prefixes take priority + cls._schemes = OrderedDict( + sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True) + ) + + @classmethod + def register_scheme( + cls, + prefixes: str | tuple[str, ...], + force: bool = False, + ) -> Callable[[CheckpointLoadFunc], CheckpointLoadFunc]: + """Register a loader to CheckpointLoader. + + This method should be used as a decorator. + + Args: + prefixes (str or Sequence[str]): The register prefix of the loader. + force (bool, optional): Whether to override the loader if the + prefix has already been registered. Defaults to False. + """ + + def _register( + loader_cls: CheckpointLoadFunc, + ) -> CheckpointLoadFunc: + cls._register_scheme(prefixes, loader_cls, force=force) + return loader_cls + + return _register + + @classmethod + def _get_checkpoint_loader(cls, path: str) -> CheckpointLoadFunc: + """Finds a loader that supports the given path. + + Falls back to the local loader if no other loader is found, since it is + registered with an empty prefix. + + Args: + path (str): checkpoint path. + + Raises: + ValueError: If the path cannot be matched to any prefix, raise an + error. This should usually not happen, since the local loader + is registered with an empty prefix. + + Returns: + CheckpointLoadFunc: checkpoint loader. + """ + for prefix, func in cls._schemes.items(): + if re.match(prefix, path) is not None: + return func + raise ValueError("Invalid path! No prefix matched.") + + @classmethod + def load_checkpoint( + cls, + filename: str, + map_location: str | torch.device | None = None, + ) -> TorchCheckpoint: + """Load checkpoint through URL scheme path. + + Args: + filename (str): checkpoint file name with given prefix + map_location (str, optional): Same as :func:`torch.load`. + Default: None + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint_loader = cls._get_checkpoint_loader(filename) + class_name = checkpoint_loader.__name__ + rank_zero_info( + f"Load checkpoint from {class_name[10:]} path: {filename}" + ) + return checkpoint_loader(filename, map_location) + + +@CheckpointLoader.register_scheme(prefixes="") +def load_from_local( + filename: str, + map_location: str | torch.device | None = None, +) -> TorchCheckpoint: + """Load checkpoint by local file path. + + Args: + filename (str): local checkpoint file path + map_location (str, optional): Same as :func:`torch.load`. + + Raises: + FileNotFoundError: If file not found. + + Returns: + TorchCheckpoint: The loaded checkpoint. + """ + filename = osp.expanduser(filename) + if not osp.isfile(filename): + raise FileNotFoundError(f"{filename} can not be found.") + checkpoint = torch.load( + filename, weights_only=True, map_location=map_location + ) + return checkpoint + + +@CheckpointLoader.register_scheme(prefixes=("http://", "https://")) +def load_from_http( + filename: str, map_location: str | torch.device | None = None +) -> TorchCheckpoint: + """Load checkpoint through HTTP or HTTPS scheme path. + + In distributed setting, this function only download checkpoint at local + rank 0. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`. + + Returns: + TorchCheckpoint: The loaded checkpoint. + """ + rank, world_size = get_rank(), get_world_size() + if rank == 0: + checkpoint = load_url(filename, map_location=map_location) + if world_size > 1: + synchronize() + if rank > 0: + checkpoint = load_url(filename, map_location=map_location) + return checkpoint # pylint: disable=used-before-assignment + + +def get_torchvision_models() -> dict[str, str]: + """Get full URLs of all torchvision paths. + + Requires torchvision >= 0.14.0a0. + """ + model_urls: dict[str, str] = {} + weights_list = [ + torchvision.models.get_model_weights(model) + for model in torchvision.models.list_models(torchvision.models) + ] + for model_cls in weights_list: + # The name of torchvision model weights classes ends with + # `_Weights` such as `ResNet18_Weights`. However, some model weight + # classes, such as `MNASNet0_75_Weights` does not have any urls in + # torchvision 0.13.0 and cannot be iterated. Here we simply check + # `DEFAULT` attribute to ensure the class is not empty. + if not hasattr(model_cls, "DEFAULT"): + continue + # Since `cls.DEFAULT` can not be accessed by iterating cls, we set + # default urls explicitly. + cls_name = model_cls.__name__ + cls_key = cls_name.replace("_Weights", "").lower() + model_urls[f"{cls_key}.default"] = model_cls.DEFAULT.url + for weight_enum in model_cls: + cls_key = cls_name.replace("_Weights", "").lower() + cls_key = f"{cls_key}.{weight_enum.name.lower()}" + model_urls[cls_key] = weight_enum.url + + return model_urls + + +@CheckpointLoader.register_scheme(prefixes="torchvision://") +def load_from_torchvision( + filename: str, map_location: str | torch.device | None = None +) -> TorchCheckpoint: + """Load checkpoint through the file path prefixed with torchvision. + + Args: + filename (str): checkpoint file path with modelzoo or + torchvision prefix + map_location (str, optional): Same as :func:`torch.load`.' + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + model_urls = get_torchvision_models() + model_name = filename[14:] + + # Support getting model urls in the same way as torchvision + # `ResNet50_Weights.IMAGENET1K_V1` will be mapped to + # resnet50.imagenet1k_v1. + model_name = model_name.lower().replace("_weights", "") + return load_from_http(model_urls[model_name], map_location) + + +def load_state_dict( + module: nn.Module, state_dict: TorchCheckpoint, strict: bool = False +) -> None: + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Raises: + RuntimeError: If strict, it will raise a runtime error if module and + state_dict do not match completely. + + Args: + module (Module): Module that receives the state_dict. + state_dict (dict or OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + """ + unexpected_keys: list[str] = [] + all_missing_keys: list[str] = [] + err_msg: list[str] = [] + + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + # pylint: disable=protected-access + state_dict._metadata = metadata # type: ignore + + # use _load_from_state_dict to enable checkpoint version control + def load(module: nn.Module, prefix: str = "") -> None: + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module # type: ignore + local_metadata = ( + {} if metadata is None else metadata.get(prefix[:-1], {}) + ) + module._load_from_state_dict( # pylint: disable=protected-access + state_dict, + prefix, + local_metadata, + True, + all_missing_keys, + unexpected_keys, + err_msg, + ) + # pylint: disable=protected-access + for name, child in module._modules.items(): + if child is not None: + # pylint: disable=not-callable + load(child, prefix + name + ".") + + load(module) + # break load->load reference cycle + load = None # type: ignore + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if "num_batches_tracked" not in key + ] + + if unexpected_keys: + err_msg.append( + "unexpected key in source " + f'state_dict: {", ".join(unexpected_keys)}\n' + ) + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n' + ) + + rank = get_rank() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, "The model and loaded state dict do not match exactly\n" + ) + err_msg = "\n".join(err_msg) # type: ignore + if strict: + raise RuntimeError(err_msg) + rank_zero_warn(err_msg) + + +def _load_checkpoint( + model: torch.nn.Module, + filename: str, + map_location: str | torch.device | None = None, + strict: bool = False, + revise_keys: tuple[tuple[str, str]] | list[tuple[str, str]] = ( + (r"^module\.", ""), + ), +) -> TorchCheckpoint: + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + revise_keys (tuple[tuple[str, str]]): A tuple of customized keywords to + modify the state_dict in checkpoint. Each item is a + (pattern, replacement) pair of the regular expression operations. + Default: strip the prefix 'module.' by [(r'^module.', '')]. + + Raises: + RuntimeError: If no state_dict is found in the checkpoint file. + + Returns: + TorchCheckpoint: The loaded checkpoint. + """ + checkpoint = CheckpointLoader.load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f"No state_dict found in checkpoint file {filename}" + ) + # get state_dict from checkpoint + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + # strip prefix of state_dict + metadata = getattr(state_dict, "_metadata", OrderedDict()) + for p, r in revise_keys: + state_dict = OrderedDict( + {re.sub(p, r, k): v for k, v in state_dict.items()} + ) + # Keep metadata in state_dict + state_dict._metadata = metadata # pylint: disable=protected-access + + # load state_dict + load_state_dict(model, state_dict, strict) + return checkpoint diff --git a/vis4d/common/dict.py b/vis4d/common/dict.py new file mode 100644 index 0000000000000000000000000000000000000000..da2746926b557b16f09c10c35cc6d238d996cc1f --- /dev/null +++ b/vis4d/common/dict.py @@ -0,0 +1,97 @@ +"""This module contains dictionary utility functions.""" + +from __future__ import annotations + +from typing import Any + +from vis4d.common.typing import DictStrAny + + +def flatten_dict(dictionary: DictStrAny, seperator: str) -> list[str]: + """Flatten a nested dictionary. + + Args: + dictionary (DictStrAny): The dictionary to flatten. + seperator (str): The seperator to use between keys. + + Returns: + List[str]: A list of flattened keys. + + Examples: + >>> d = {'a': {'b': {'c': 10}}} + >>> flatten_dict(d, '.') + ['a.b.c'] + """ + flattened = [] + for key, value in dictionary.items(): + if isinstance(value, dict): + flattened.extend( + [ + f"{key}{seperator}{subkey}" + for subkey in flatten_dict(value, seperator) + ] + ) + else: + flattened.append(key) + return flattened + + +def get_dict_nested( # type: ignore + dictionary: DictStrAny, keys: list[str], allow_missing: bool = False +) -> Any: + """Get a value from a nested dictionary. + + Args: + dictionary (DictStrAny): The dictionary to get the value from. + keys (list[str]): A list of keys specifying the location in the nested + dictionary where the value is located. + allow_missing (bool, optional): Whether to allow missing keys. Defaults + to False. If False, a ValueError is raised if a key is not present, + otherwise None is returned. + + Returns: + list[str]: The value from the dictionary. + + Raises: + ValueError: If the key is not present in the dictionary. + + Examples: + >>> d = {'a': {'b': {'c': 10}}} + >>> get_dict_nested(d, ['a', 'b', 'c']) + 10 + + >>> get_dict_nested(d, ['a', 'b', 'd']) + ValueError: Key d not in dictionary! Current keys: dict_keys(['c']) + """ + for key in keys: + if key not in dictionary: + if allow_missing: + return None + raise ValueError( + f"Key {key} not in dictionary! Current keys: " + f"{dictionary.keys()}" + ) + dictionary = dictionary[key] + return dictionary + + +def set_dict_nested( # type: ignore + dictionary: DictStrAny, keys: list[str], value: Any +) -> None: + """Set a value in a nested dictionary. + + Args: + dictionary (dict[str, Any]): The dictionary to set the value in. + keys (list[str]): A list of keys specifying the location in the nested + dictionary where the value should be set. + value (Any): The value to set in the dictionary. + + Examples: + >>> d = {} + >>> set_dict_nested(d, ['a', 'b', 'c'], 10) + >>> d + {'a': {'b': {'c': 10}}} + """ + for key in keys[:-1]: + dictionary = dictionary.setdefault(key, {}) + dictionary[keys[-1]] = value diff --git a/vis4d/common/distributed.py b/vis4d/common/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..34fdebc6689902783fef10ccf53c2c475161f621 --- /dev/null +++ b/vis4d/common/distributed.py @@ -0,0 +1,402 @@ +# mypy: disable-error-code=misc +"""This module contains utilities for multiprocess parallelism.""" +from __future__ import annotations + +import logging +import os +import pickle +import shutil +import tempfile +from collections import OrderedDict +from functools import wraps +from typing import Any + +import torch +import torch.distributed as dist +from torch import Tensor, nn +from torch.distributed import broadcast_object_list +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from vis4d.common.typing import ArgsType, DictStrAny, GenericFunc + + +# no coverage for these functions, since we don't unittest distributed setting +def get_world_size() -> int: # pragma: no cover + """Get the world size (number of processes) of torch.distributed. + + Returns: + int: The world size. + """ + if os.environ.get("WORLD_SIZE", None): + return int(os.environ["WORLD_SIZE"]) + + # In interactive job not using slurm ntasks + if os.environ.get("SLURM_JOB_NAME", None) != "bash": + if os.environ.get("SLURM_NTASKS", None): + return int(os.environ["SLURM_NTASKS"]) + + return 1 + + +def get_rank() -> int: # pragma: no cover + """Get the global rank of the current process in torch.distributed. + + Returns: + int: The global rank. + """ + # For torchrun + if os.environ.get("RANK", None): + return int(os.environ["RANK"]) + + # Because pl don't set global rank, use local rank for interactive job and + # slurm process id for submitted job + if os.environ.get("SLURM_JOB_NAME", None) == "bash": + return get_local_rank() + if os.environ.get("SLURM_PROCID", None): + return int(os.environ["SLURM_PROCID"]) + + # Return local rank + return get_local_rank() + + +def get_local_rank() -> int: # pragma: no cover + """Get the local rank of the current process in torch.distributed. + + Returns: + int: The local rank. + """ + if os.environ.get("LOCAL_RANK", None): + return int(os.environ["LOCAL_RANK"]) + if os.environ.get("SLURM_LOCALID", None): + return int(os.environ["SLURM_LOCALID"]) + + return 0 + + +def distributed_available() -> bool: # pragma: no cover + """Check if torch.distributed is available. + + Returns: + bool: Whether torch.distributed is available. + """ + return dist.is_available() and dist.is_initialized() + + +def synchronize() -> None: # pragma: no cover + """Sync (barrier) among all processes when using distributed training.""" + if not distributed_available(): + return + if get_world_size() == 1: + return + dist.barrier(group=dist.group.WORLD, device_ids=[get_local_rank()]) + + +def broadcast(obj: Any, src: int = 0) -> Any: # type: ignore + """Broadcast an object from a source to all processes.""" + if not distributed_available(): + return obj + obj = [obj] + rank = get_rank() + if rank != src: + obj = [None] + broadcast_object_list(obj, src, group=dist.group.WORLD) + return obj[0] + + +def serialize_to_tensor(data: Any) -> Tensor: # type: ignore + """Serialize arbitrary picklable data to a Tensor. + + Args: + data (Any): The data to serialize. + + Returns: + Tensor: The serialized data as a Tensor. + + Raises: + AssertionError: If the backend of torch.distributed is not gloo or + nccl. + """ + backend = dist.get_backend() + assert backend in { + "gloo", + "nccl", + }, "_serialize_to_tensor only supports gloo and nccl backends." + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024**3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank %s tries all-gather %.2f GB of data on device %s", + get_rank(), + len(buffer) / (1024**3), + device, + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def rank_zero_only(func: GenericFunc) -> GenericFunc: + """Allows the decorated function to be called only on global rank 0. + + Args: + func(GenericFunc): The function to decorate. + + Returns: + GenericFunc: The decorated function. + + """ + + @wraps(func) + def wrapped_fn(*args: ArgsType, **kwargs: ArgsType) -> Any: # type: ignore + rank = get_rank() + if rank == 0: + return func(*args, **kwargs) + return None + + return wrapped_fn + + +def pad_to_largest_tensor( + tensor: Tensor, +) -> tuple[list[int], Tensor]: # pragma: no cover + """Pad tensor to largest size among the tensors in each process. + + Args: + tensor: tensor to be padded. + + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = get_world_size() + assert ( + world_size >= 1 + ), "_pad_to_largest_tensor requires distributed setting!" + local_size = torch.tensor( + [tensor.numel()], dtype=torch.int64, device=tensor.device + ) + local_size_list = [local_size.clone() for _ in range(world_size)] + dist.all_gather_object(local_size_list, local_size) + size_list = [int(size.item()) for size in local_size_list] + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros( + (max_size - local_size,), dtype=torch.uint8, device=tensor.device + ) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather_object_gpu( # type: ignore + data: Any, rank_zero_return_only: bool = True +) -> list[Any] | None: # pragma: no cover + """Run pl_module.all_gather on arbitrary picklable data. + + Args: + data: any picklable object + rank_zero_return_only: if results should only be returned on rank 0 + + Returns: + list[Any]: list of data gathered from each process + """ + rank, world_size = get_rank(), get_world_size() + if world_size == 1: + return [data] + + # encode + tensor = serialize_to_tensor(data) + size_list, tensor = pad_to_largest_tensor(tensor) + tensor_list = [tensor.clone() for _ in range(world_size)] + dist.all_gather_object(tensor_list, tensor) # (world_size, N) + + if rank_zero_return_only and rank != 0: + return None + + # decode + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def create_tmpdir( + rank: int, tmpdir: None | str = None, use_system_tmp: bool = True +) -> str: # pragma: no cover + """Create and distribute a temporary directory across all processes.""" + if tmpdir is not None: + os.makedirs(tmpdir, exist_ok=True) + return tmpdir + if rank == 0: + # create a temporary directory + default_tmpdir = tempfile.gettempdir() + if default_tmpdir is not None and use_system_tmp: + dist_tmpdir = os.path.join(default_tmpdir, ".dist_tmp") + else: + dist_tmpdir = os.path.join("vis4d-workspace", ".dist_tmp") + os.makedirs(dist_tmpdir, exist_ok=True) + tmpdir = tempfile.mkdtemp(dir=dist_tmpdir) + else: + tmpdir = None + return broadcast(tmpdir) + + +def all_gather_object_cpu( # type: ignore + data: Any, + tmpdir: None | str = None, + rank_zero_return_only: bool = True, + use_system_tmp: bool = False, +) -> list[Any] | None: # pragma: no cover + """Share arbitrary picklable data via file system caching. + + Args: + data: any picklable object. + tmpdir: Save path for temporary files. If None, safely create tmpdir. + rank_zero_return_only: if results should only be returned on rank 0. + use_system_tmp: if use system tmpdir or not. + + Returns: + list[Any]: list of data gathered from each process. + """ + rank, world_size = get_rank(), get_world_size() + if world_size == 1: + return [data] + + # make tmp dir + tmpdir = create_tmpdir(rank, tmpdir, use_system_tmp) + + # encode & save + with open(os.path.join(tmpdir, f"part_{rank}.pkl"), "wb") as f: + pickle.dump(data, f) + synchronize() + + if rank_zero_return_only and rank != 0: + return None + + # load & decode + data_list = [] + for i in range(world_size): + with open(os.path.join(tmpdir, f"part_{i}.pkl"), "rb") as f: + data_list.append(pickle.load(f)) + + # remove dir + if not rank_zero_return_only: + # wait for all processes to finish loading before removing tmpdir + synchronize() + if rank == 0: + shutil.rmtree(tmpdir) + + return data_list + + +def reduce_mean(tensor: Tensor) -> Tensor: + """Obtain the mean of tensor on different GPUs.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +def obj2tensor( # type: ignore + pyobj: Any, device: torch.device = torch.device("cuda") +) -> Tensor: + """Serialize picklable python object to tensor. + + Args: + pyobj (Any): Any picklable python object. + device (torch.device): Device to put on. Defaults to "cuda". + """ + storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj)) + return torch.ByteTensor(storage).to(device=device) + + +def tensor2obj(tensor: Tensor) -> Any: # type: ignore + """Deserialize tensor to picklable python object. + + Args: + tensor (Tensor): Tensor to be deserialized. + """ + return pickle.loads(tensor.cpu().numpy().tobytes()) + + +def all_reduce_dict( + py_dict: DictStrAny, reduce_op: str = "sum", to_float: bool = True +) -> DictStrAny: # pragma: no cover + """Apply all reduce function for python dict object. + + The code is modified from + https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py. + + NOTE: make sure that py_dict in different ranks has the same keys and + the values should be in the same shape. Currently only supports + NCCL backend. + + Args: + py_dict (DictStrAny): Dict to be applied all reduce op. + reduce_op (str): Operator, could be 'sum' or 'mean'. Default: 'sum'. + to_float (bool): Whether to convert all values of dict to float. + Default: True. + + Returns: + DictStrAny: reduced python dict object. + """ + world_size = get_world_size() + if world_size == 1: + return py_dict + + # all reduce logic across different devices. + py_key = list(py_dict.keys()) + if not isinstance(py_dict, OrderedDict): + py_key_tensor = obj2tensor(py_key) + dist.broadcast(py_key_tensor, src=0) + py_key = tensor2obj(py_key_tensor) + + tensor_shapes = [py_dict[k].shape for k in py_key] + tensor_numels = [py_dict[k].numel() for k in py_key] + + if to_float: + flatten_tensor = torch.cat( + [py_dict[k].flatten().float() for k in py_key] + ) + else: + flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key]) + + dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM) + if reduce_op == "mean": + flatten_tensor /= world_size + + split_tensors = [ + x.reshape(shape) + for x, shape in zip( + torch.split(flatten_tensor, tensor_numels), tensor_shapes + ) + ] + out_dict: DictStrAny = dict(zip(py_key, split_tensors)) + if isinstance(py_dict, OrderedDict): + out_dict = OrderedDict(out_dict) + return out_dict + + +def is_module_wrapper(module: nn.Module) -> bool: + """Checks recursively if a module is wrapped. + + Two modules are regarded as wrapper: DataParallel, DistributedDataParallel. + + Args: + module (nn.Module): The module to be checked. + + Returns: + bool: True if the input module is a module wrapper. + """ + if isinstance(module, (DataParallel, DistributedDataParallel)): + return True + if any(is_module_wrapper(child) for child in module.children()): + return True + return False diff --git a/vis4d/common/imports.py b/vis4d/common/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..c008358b2e5f60548c911a9e44b9a31821fdcaca --- /dev/null +++ b/vis4d/common/imports.py @@ -0,0 +1,62 @@ +"""Check if optional packages required by some modules are available.""" + +from functools import lru_cache +from importlib.util import find_spec + +import torch +from packaging import version + + +@lru_cache() +def package_available(package_name: str) -> bool: + """Check if a package is available in your environment.""" + try: + return find_spec(package_name) is not None + except ModuleNotFoundError: # pragma: no cover + return False + + +# io +H5PY_AVAILABLE = package_available("h5py") + +# vision +MMCV_AVAILABLE = package_available("mmcv") or package_available("mmcv-full") +MMDET_AVAILABLE = package_available("mmdet") +MMSEG_AVAILABLE = package_available("mmseg") +DETECTRON2_AVAILABLE = package_available("detectron2") +TIMM_AVAILABLE = package_available("timm") +FVCORE_AVAILABLE = package_available("fvcore") + +# datasets +WAYMO_AVAILABLE = package_available("waymo") +NUSCENES_AVAILABLE = package_available("nuscenes") +SCALABEL_AVAILABLE = package_available("scalabel") +BDD100K_AVAILABLE = package_available("bdd100k") + +# visualization +OPENCV_AVAILABLE = package_available("cv2") +DASH_AVAILABLE = package_available("dash") +OPEN3D_AVAILABLE = package_available("open3d") +PLOTLY_AVAILABLE = package_available("plotly") + +# vis4d cuda ops +VIS4D_CUDA_OPS_AVAILABLE = package_available("vis4d_cuda_ops") + +# logging +TENSORBOARD_AVAILABLE = package_available("tensorboardX") or package_available( + "tensorboard" +) + + +def is_torch_tf32_available() -> bool: # pragma: no cover + """Check if torch TF32 is available. + + Returns: + bool: True if torch TF32 is available. + """ + return not ( + not torch.cuda.is_available() + or torch.version.cuda is None + or int(torch.version.cuda.split(".", maxsplit=1)[0]) < 11 + or version.parse(torch.__version__) < version.parse("1.7") + ) diff --git a/vis4d/common/logging.py b/vis4d/common/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c62e82b9321dee0a77ced21ef80acd0735a57c --- /dev/null +++ b/vis4d/common/logging.py @@ -0,0 +1,139 @@ +"""This module contains logging utility functions. + +We provide utilities for setting up a logger and logging in a distributed +setting. +""" + +from __future__ import annotations + +import logging +import os +import sys +import warnings + +from termcolor import colored + +from vis4d.common.distributed import rank_zero_only +from vis4d.common.typing import ArgsType +from vis4d.config.typing import ExperimentConfig + + +def _debug(*args: ArgsType, stacklevel: int = 2, **kwargs: ArgsType) -> None: + """Function used to log debug-level messages.""" + log = logging.getLogger(__name__) + kwargs["stacklevel"] = stacklevel + log.debug(*args, **kwargs) + + +@rank_zero_only +def rank_zero_debug( + *args: ArgsType, stacklevel: int = 4, **kwargs: ArgsType +) -> None: + """Function used to log debug-level messages only on global rank 0.""" + _debug(*args, stacklevel=stacklevel, **kwargs) + + +def _info(*args: ArgsType, stacklevel: int = 2, **kwargs: ArgsType) -> None: + """Function used to log info-level messages.""" + kwargs["stacklevel"] = stacklevel + log = logging.getLogger(__name__) + log.info(*args, **kwargs) + + +@rank_zero_only +def rank_zero_info( + *args: ArgsType, stacklevel: int = 4, **kwargs: ArgsType +) -> None: + """Function used to log info-level messages only on global rank 0.""" + _info(*args, stacklevel=stacklevel, **kwargs) + + +def _warn( + message: str | Warning, stacklevel: int = 2, **kwargs: ArgsType +) -> None: + """Function used to log warn-level messages.""" + warnings.warn(message, stacklevel=stacklevel, **kwargs) + + +@rank_zero_only +def rank_zero_warn( + message: str | Warning, stacklevel: int = 4, **kwargs: ArgsType +) -> None: + """Function used to log warn-level messages only on global rank 0.""" + _warn(message, stacklevel=stacklevel, **kwargs) + + +class _ColorFormatter(logging.Formatter): + """Formatter for terminal messages with colors.""" + + def formatMessage(self, record: logging.LogRecord) -> str: + """Add appropriate color to log message.""" + log = super().formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno in [logging.ERROR, logging.CRITICAL]: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +@rank_zero_only +def setup_logger( + logger: logging.Logger, + filepath: None | str = None, + color: bool = True, + std_out_level: int = logging.INFO, +) -> None: + """Configure logging for Vis4D. + + Args: + logger (logging.Logger): The logger instance to be configured. + filepath (None | str, optional): The filepath to the log file that + stores the console output. Defaults to None. + color (bool, optional): Whether to use a colored console output. + Defaults to True. + std_out_level (int, optional): Which logging level to output to the + console. Defaults to logging.INFO. Note that all levels will be + logged to file. + """ + # get logger, remove handlers to re-define behavior + for h in logger.handlers: + logger.removeHandler(h) + + # console logger + plain_formatter = logging.Formatter( + "[%(asctime)s] Vis4D %(levelname)s: %(message)s", + datefmt="%m/%d %H:%M:%S", + ) + + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(std_out_level) + if color: + formatter = _ColorFormatter( + colored("[%(asctime)s Vis4D]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + ) + ch.setFormatter(formatter) + else: + ch.setFormatter(plain_formatter) + logger.addHandler(ch) + + # file logger + if filepath is not None: + os.makedirs(os.path.dirname(filepath), exist_ok=True) + fh = logging.FileHandler(filepath) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + +@rank_zero_only +def dump_config(config: ExperimentConfig, config_file: str) -> None: + """Dump the configuration to a file. + + Args: + config (ExperimentConfig): The configuration to dump. + config_file (str): The path to the file to dump the configuration to. + """ + config.dump(config_file) diff --git a/vis4d/common/named_tuple.py b/vis4d/common/named_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..8f724c755ecd94d2626c7c163eb8e95240f4a8d3 --- /dev/null +++ b/vis4d/common/named_tuple.py @@ -0,0 +1,52 @@ +"""This module contains dictionary utility functions.""" + +from __future__ import annotations + +from typing import Any, NamedTuple + + +def get_all_keys(entry: NamedTuple) -> list[str]: + """Get all keys in a NamedTuple.""" + keys = [] + for key in entry._fields: + if is_namedtuple(getattr(entry, key)): + keys.extend( + [f"{key}.{k}" for k in get_all_keys(getattr(entry, key))] + ) + else: + keys.append(key) + return keys + + +def get_from_namedtuple(entry: NamedTuple, key: str) -> Any: # type: ignore + """Get a value from a nested Named tuple. + + Example passing key = "test.my.data" will resolve the value of the + named tuple at 'test' 'my' 'data'. + + Raises: + ValueError: If the key is not present in the named tuple. + """ + keys = key.split(".") + first_key = keys[0] + if not hasattr(entry, first_key): + raise ValueError( + f"Key {first_key} not in named tuple! Current keys: " + f"{get_all_keys(entry)}" + ) + if len(keys) == 1: + return getattr(entry, first_key) + + return get_from_namedtuple(getattr(entry, first_key), ".".join(keys[1:])) + + +def is_namedtuple(obj: object) -> bool: + """Check if obj is namedtuple. + + https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8 + """ + return ( + isinstance(obj, tuple) + and hasattr(obj, "_asdict") + and hasattr(obj, "_fields") + ) diff --git a/vis4d/common/prettyprint.py b/vis4d/common/prettyprint.py new file mode 100644 index 0000000000000000000000000000000000000000..ca655d9feb5ed6e92c87d24c80335050f9055983 --- /dev/null +++ b/vis4d/common/prettyprint.py @@ -0,0 +1,81 @@ +"""This module contains utilities for pretty printing.""" + +from typing import Any + +import numpy as np +import torch + + +class PrettyRepMixin: + """Creates a pretty string representation of a class with parameters. + + Examples: + >>> class TestClass(PrettyRepMixin): + ... def __init__(self, a: int, b: str): + ... self.a = a + ... self.b = b + >>> obj = TestClass(1, 'hello') + >>> str(obj) + 'TestClass(a=1, b=hello)' + """ + + def __repr__(self) -> str: + """Return a string representation of the class and its parameters. + + Returns: + The string representation of the class and its parameters. + + Examples: + >>> class TestClass(PrettyRepMixin): + ... def __init__(self, a: int, b: str): + ... self.a = a + ... self.b = b + >>> obj = TestClass(1, 'hello') + >>> obj.__repr__() + 'TestClass(a=1, b=hello)' + """ + attr_str = "" + for k, v in vars(self).items(): + if k != "type" and not k.startswith("_"): + attr_str += f"{k}={str(v)}, " + attr_str = attr_str.rstrip(", ") + return f"{self.__class__.__name__}({attr_str})" + + +def describe_shape(obj: Any) -> str: # type: ignore + """Recursively output the shape of tensors in an object's structure. + + Args: + obj (Any): The object to describe the shape of. Can be a dictionary, + list, torch.Tensor, numpy.ndarray, float, or any other type. + + Returns: + str: A string representing the shapes of all tensors in the object's + structure. + + Examples: + >>> describe_shape({'a': torch.rand(2, 3)}) + "{a: shape[2, 3]}" + >>> describe_shape({'a': [torch.rand(2, 3), torch.rand(4, 5)]}) + "{a: [shape[2, 3], shape[4, 5]]}" + >>> describe_shape([torch.rand(2, 3), {'a': torch.rand(4, 5)}]) + "[shape[2, 3], {a: shape[4, 5]}]" + """ + log_str = "" + if isinstance(obj, dict): + log_str += "{" + log_str += ", ".join( + [f"{k}: {describe_shape(obj[k])}" for k in obj.keys()] + ) + log_str += "}" + elif isinstance(obj, list): + log_str += "[" + log_str += ", ".join([describe_shape(v) for v in obj]) + log_str += "]" + elif isinstance(obj, (torch.Tensor, np.ndarray)): + log_str += f"shape[{', '.join([str(s) for s in obj.shape])}]" + elif isinstance(obj, float): + log_str += f"{obj:.4f}" + else: + log_str += str(obj) + return log_str diff --git a/vis4d/common/progress.py b/vis4d/common/progress.py new file mode 100644 index 0000000000000000000000000000000000000000..984d80556c45a2364590b22654556a28f2a9f3c5 --- /dev/null +++ b/vis4d/common/progress.py @@ -0,0 +1,54 @@ +"""This module contains utilities for progress bar.""" + +from __future__ import annotations + +import datetime + +from torch import Tensor + +from .time import Timer +from .typing import MetricLogs + + +def compose_log_str( + prefix: str, + cur_iter: int, + total_iters: int, + timer: Timer, + metrics: None | MetricLogs = None, +) -> str: + """Compose log str from given information.""" + time_sec_tot = timer.time() + time_sec_avg = time_sec_tot / cur_iter + eta_sec = time_sec_avg * (total_iters - cur_iter) + if not eta_sec == float("inf"): + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + else: # pragma: no cover + eta_str = "---" + + metrics_list: list[str] = [] + if metrics is not None: + for k, v in metrics.items(): + name = k.split("/")[-1] # remove prefix, e.g. train/loss + if isinstance(v, (Tensor, float)): + # display more digits for small values + if abs(v) < 1e-3: # type: ignore[operator] + kv_str = f"{name}: {v:.3e}" + else: + kv_str = f"{name}: {v:.4f}" + else: + kv_str = f"{name}: {v}" + if name == "loss": # put total loss first + metrics_list.insert(0, kv_str) + else: + metrics_list.append(kv_str) + + time_str = f"ETA: {eta_str}, " + ( + f"{time_sec_avg:.2f}s/it" + if time_sec_avg > 1 + else f"{1/time_sec_avg:.2f}it/s" + ) + logging_str = f"{prefix}: {cur_iter}/{total_iters}, {time_str}" + if len(metrics_list) > 0: + logging_str += ", " + ", ".join(metrics_list) + return logging_str diff --git a/vis4d/common/release.py b/vis4d/common/release.py new file mode 100644 index 0000000000000000000000000000000000000000..3bcb84997d1b375fb607eb737f4e36844c6b5fc6 --- /dev/null +++ b/vis4d/common/release.py @@ -0,0 +1,69 @@ +"""Convert Vis4D model weights for release.""" + +from __future__ import annotations + +import argparse +import hashlib +import os + +import torch + + +def save_weights_with_hash( + state_dict: dict[str, torch.Tensor], + path: str, + filename: str, + digits: int = 6, +) -> None: + """Saves the model weights and append a 6-digit hash to the filename. + + Args: + state_dict (dict[str, torch.Tensor]): The model weights to save. + path (str): The directory path to save the model. + filename (str): The filename to save the model. + digits (int, optional): The number of digits to use for the hash. + Defaults to 6. + """ + os.makedirs(path, exist_ok=True) + with open(os.path.join(path, filename), "wb") as f: + torch.save(state_dict, f) + + # Create a hash of the file + sha256_hash = hashlib.sha256() + with open(os.path.join(path, filename), "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + + # Get the hexadecimal representation of the hash + short_hash = sha256_hash.hexdigest()[:digits] + os.rename( + os.path.join(path, filename), + os.path.join(path, f"{filename}_{short_hash}.pt"), + ) + + +def main() -> None: + """Main function.""" + parser = argparse.ArgumentParser( + description="Save trained model checkpoint with a filename hash." + ) + parser.add_argument("path", type=str, help="The path to the checkpoint.") + parser.add_argument( + "--outdir", + type=str, + help="The path to output the model.", + default="./vis4d-workspace/release", + ) + parser.add_argument( + "--name", type=str, help="The base name of the released file." + ) + args = parser.parse_args() + + checkpoint = torch.load(args.path, map_location=torch.device("cpu")) + state_dict = {"state_dict": checkpoint["state_dict"]} + + save_weights_with_hash(state_dict, args.outdir, args.name) + + +if __name__ == "__main__": + main() diff --git a/vis4d/common/slurm.py b/vis4d/common/slurm.py new file mode 100644 index 0000000000000000000000000000000000000000..dea2a7282af42548a3ccade1c57def82157959e3 --- /dev/null +++ b/vis4d/common/slurm.py @@ -0,0 +1,65 @@ +"""Slurm job submission. + +Code adapted from: + https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py +""" + +import os +import socket +import subprocess + +import torch + + +def _find_free_port() -> str: + """Find a free port on the current machine.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def _is_free_port(port: int) -> bool: + """Check if a port is free on the current machine.""" + ips = socket.gethostbyname_ex(socket.gethostname())[-1] + ips.append("localhost") + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return all(s.connect_ex((ip, port)) != 0 for ip in ips) + + +def init_dist_slurm() -> None: + """Initialize slurm distributed training environment.""" + proc_id = int(os.environ["SLURM_PROCID"]) + ntasks = int(os.environ["SLURM_NTASKS"]) + + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + + # WORLD_SIZE + os.environ["WORLD_SIZE"] = str(ntasks) + + # use MASTER_ADDR in the environment variable if it already exists + if "MASTER_ADDR" not in os.environ: + node_list = os.environ["SLURM_NODELIST"] + addr = subprocess.getoutput( + f"scontrol show hostname {node_list} | head -n1" + ) + os.environ["MASTER_ADDR"] = addr + + # use MASTER_PORT in the environment variable if it already exists + if "MASTER_PORT" not in os.environ: + # if torch.distributed default port(29500) is available + # then use it, else find a free port + if _is_free_port(29500): + os.environ["MASTER_PORT"] = "29500" + else: + os.environ["MASTER_PORT"] = str(_find_free_port()) + + # LOCAL RANK + os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) + + # GLOBAL RANK + os.environ["RANK"] = str(proc_id) diff --git a/vis4d/common/time.py b/vis4d/common/time.py new file mode 100644 index 0000000000000000000000000000000000000000..5c33a56c8ecedba5b5929d036dbe26222f1b4390 --- /dev/null +++ b/vis4d/common/time.py @@ -0,0 +1,62 @@ +"""This module contains utilities for tracking execution time.""" + +from __future__ import annotations + +from time import perf_counter +from typing import no_type_check + + +@no_type_check +def timeit(func): + """Function to be used as decorator to time a function.""" + + def timed(*args, **kwargs): + tic = perf_counter() + result = func(*args, **kwargs) + toc = perf_counter() + print(f"{func.__name__} {(toc - tic) * 1000:.2f} ms") + return result + + return timed + + +class Timer: # pragma: no cover + """Timer class based on perf_counter.""" + + def __init__(self) -> None: + """Creates an instance of the class.""" + self._tic = perf_counter() + self._toc: None | float = None + self.paused = False + + def reset(self) -> None: + """Reset timer.""" + self._tic = perf_counter() + self._toc = None + self.paused = False + + def pause(self) -> None: + """Pause function.""" + if self.paused: + raise ValueError("Timer already paused!") + self._toc = perf_counter() + self.paused = True + + def resume(self) -> None: + """Resume function.""" + if not self.paused: + raise ValueError("Timer is not paused!") + assert self._toc is not None + self._tic = perf_counter() - (self._toc - self._tic) + self._toc = None + self.paused = False + + def time(self, milliseconds: bool = False) -> float: + """Return elapsed time.""" + if not self.paused: + self._toc = perf_counter() + assert self._toc is not None + time_elapsed = self._toc - self._tic + if milliseconds: + return time_elapsed * 1000 + return time_elapsed diff --git a/vis4d/common/typing.py b/vis4d/common/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..0a23a96bb1430f6034e4282b230d357120f37100 --- /dev/null +++ b/vis4d/common/typing.py @@ -0,0 +1,74 @@ +"""Common type definitions. + +Here we define commonly used types like specific numpy array and tensor types. +""" + +from collections.abc import Callable +from typing import Any, Dict, Iterable, Union + +import numpy as np +import numpy.typing as npt +from torch import ( # pylint: disable=no-name-in-module + BoolTensor, + ByteTensor, + FloatTensor, + IntTensor, + Tensor, +) + +NDArrayBool = npt.NDArray[np.bool_] +NDArrayF32 = npt.NDArray[np.float32] +NDArrayF64 = npt.NDArray[np.float64] +NDArrayFloat = Union[NDArrayF32, NDArrayF64] +NDArrayI32 = npt.NDArray[np.int32] +NDArrayI64 = npt.NDArray[np.int64] +NDArrayInt = Union[NDArrayI32, NDArrayI64] +NDArrayUI8 = npt.NDArray[np.uint8] +NDArrayUI16 = npt.NDArray[np.uint16] +NDArrayUI32 = npt.NDArray[np.uint32] +NDArrayUInt = Union[ # pylint: disable=invalid-name + NDArrayUI8, NDArrayUI16, NDArrayUI32 +] +NDArrayNumber = Union[NDArrayBool, NDArrayFloat, NDArrayInt, NDArrayUInt] + +MetricLogs = Dict[str, Union[float, int, Tensor]] +DictStrAny = Dict[str, Any] # type: ignore +DictStrArrNested = Dict[str, Union[Tensor, Dict[str, Tensor]]] +ArgsType = Any # type: ignore +ModelOutput = DictStrAny +TorchCheckpoint = DictStrAny +LossesType = Dict[str, Tensor] +TorchLossFunc = Callable[..., Any] # type: ignore +GenericFunc = Callable[..., Any] # type: ignore + +ArrayIterableFloat = Iterable[Union[float, "ArrayIterableFloat"]] +ArrayIterableBool = Iterable[Union[bool, "ArrayIterableBool"]] +ArrayIterableInt = Iterable[Union[int, "ArrayIterableInt"]] +ArrayIterableUInt = Iterable[Union[int, "ArrayIterableUInt"]] + +ArrayLikeFloat = Union[ArrayIterableFloat, NDArrayF32, NDArrayF64, FloatTensor] +ArrayLikeBool = Union[ArrayIterableBool, NDArrayBool, BoolTensor] +ArrayLikeInt = Union[ArrayIterableInt, NDArrayInt, IntTensor] +ArrayLikeUInt = Union[ # pylint: disable=invalid-name + ArrayIterableUInt, NDArrayUInt, ByteTensor +] +ArrayLike = Union[ArrayLikeBool, ArrayLikeFloat, ArrayLikeInt, ArrayLikeUInt] + +ListAny = list[Any] # type: ignore + + +# Trick mypy into not applying contravariance rules to inputs by defining +# forward as a value, rather than a function. See also +# https://github.com/python/mypy/issues/8795 +def unimplemented(self, *args: Any) -> None: # type: ignore + r"""Define the computation performed at every call. + + Should be overridden by all subclasses. + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + registered hooks while the latter silently ignores them. + """ + raise NotImplementedError() diff --git a/vis4d/common/util.py b/vis4d/common/util.py new file mode 100644 index 0000000000000000000000000000000000000000..38dca5d7db13a56d73bdf43276410e9a4da9f11a --- /dev/null +++ b/vis4d/common/util.py @@ -0,0 +1,95 @@ +"""Utility functions for common usage.""" + +import random +from difflib import get_close_matches + +import numpy as np +import torch +from packaging import version + +from .imports import is_torch_tf32_available +from .logging import rank_zero_info, rank_zero_warn + + +def create_did_you_mean_msg(keys: list[str], query: str) -> str: + """Create a did you mean message. + + Args: + keys (list[str]): List of available keys. + query (str): Query. + + Returns: + str: Did you mean message. + + Examples: + >>> keys = ["foo", "bar", "baz"] + >>> query = "fo" + >>> print(create_did_you_mean_msg(keys, query)) + Did you mean: + foo + """ + msg = "" + if len(keys) > 0: + msg = "Did you mean:\n\t" + msg += "\n\t".join(get_close_matches(query, keys, cutoff=0.75)) + return msg + + +def set_tf32(use_tf32: bool, precision: str) -> None: # pragma: no cover + """Set torch TF32. + + Args: + use_tf32: Whether to use torch TF32. Details: + https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere + precision: Internal precision of float32 matrix multiplications. + Details: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision # pylint: disable=line-too-long + """ + if use_tf32: # pragma: no cover + rank_zero_info( + "Using Torch TF32. " + + "It might harm the performance due to the precision. " + + "You can turn it off by setting config.use_tf32=False." + ) + if not is_torch_tf32_available(): + rank_zero_warn("Torch TF32 is not available.") + elif ( + version.parse("1.11") + >= version.parse(torch.__version__) + >= version.parse("1.7") + ): + rank_zero_info("Torch TF32 is turned on by default!") + else: + rank_zero_info("Turn on Torch TF32 on matmul.") + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + else: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + # Control the precision of matmul operations. + # Equivalent to setting torch.backends.cuda.matmul.allow_tf32. + torch.set_float32_matmul_precision(precision) + + +def init_random_seed() -> int: + """Initialize random seed for the experiment.""" + return int(np.random.randint(2**31)) + + +def set_random_seed(seed: int, deterministic: bool = False) -> None: + """Set random seed. + + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/vis4d/config/__init__.py b/vis4d/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..872ddf1dbdd087cbe35f22595989c40cc5ce44d6 --- /dev/null +++ b/vis4d/config/__init__.py @@ -0,0 +1,17 @@ +"""Config modules.""" + +from .config_dict import ( + FieldConfigDict, + class_config, + copy_and_resolve_references, + delay_instantiation, + instantiate_classes, +) + +__all__ = [ + "copy_and_resolve_references", + "class_config", + "FieldConfigDict", + "delay_instantiation", + "instantiate_classes", +] diff --git a/vis4d/config/__pycache__/__init__.cpython-311.pyc b/vis4d/config/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44ed03e246150a1d51d7af85d1fca04f9210b219 Binary files /dev/null and b/vis4d/config/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/config/__pycache__/config_dict.cpython-311.pyc b/vis4d/config/__pycache__/config_dict.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98a2c69a227b0391110a9c8b3bac50dac4422df0 Binary files /dev/null and b/vis4d/config/__pycache__/config_dict.cpython-311.pyc differ diff --git a/vis4d/config/__pycache__/typing.cpython-311.pyc b/vis4d/config/__pycache__/typing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d2ebd0661b444ae5ae70b3c58137d51f48ca508 Binary files /dev/null and b/vis4d/config/__pycache__/typing.cpython-311.pyc differ diff --git a/vis4d/config/config_dict.py b/vis4d/config/config_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b90eebb30800492a4ce2465e2af46ab4d0f94c --- /dev/null +++ b/vis4d/config/config_dict.py @@ -0,0 +1,566 @@ +"""Config dict module.""" + +from __future__ import annotations + +import importlib +from collections.abc import Callable, Iterable, Mapping +from typing import Any + +import yaml +from ml_collections import ConfigDict, FieldReference, FrozenConfigDict + +from vis4d.common.named_tuple import get_all_keys, is_namedtuple +from vis4d.common.typing import ArgsType + + +# NOTE: Most of these functions need to deal with unknown parameters and are +# therefore not strictly typed +class FieldConfigDict(ConfigDict): # type: ignore # pylint: disable=too-many-instance-attributes, line-too-long + """A configuration dict which allows to access fields via dot notation. + + This class is a subclass of ConfigDict and overwrites the dot notation to + return a FieldReference instead of a dict. + + For more information on the ConfigDict class, see: + ml_collections.ConfigDict. + + Examples of using the ref and value mode: + >>> config = FieldConfigDict({"a": 1, "b": 2}) + >>> type(config.a) + + >>> config.value_mode() # Set the config to return values + >>> type(config.a) + + """ + + def __init__( # type: ignore + self, + initial_dictionary: Mapping[str, Any] | None = None, + type_safe: bool = True, + convert_dict: bool = True, + ): + """Creates an instance of FieldConfigDict. + + Args: + initial_dictionary: May be one of the following: + + 1) dict. In this case, all values of initial_dictionary that are + dictionaries are also be converted to ConfigDict. However, + dictionaries within values of non-dict type are untouched. + + 2) ConfigDict. In this case, all attributes are uncopied, and only + the top-level object (self) is re-addressed. This is the same + behavior as Python dict, list, and tuple. + + 3) FrozenConfigDict. In this case, initial_dictionary is converted + to a ConfigDict version of the initial dictionary for the + FrozenConfigDict (reversing any mutability changes FrozenConfigDict + made). + + type_safe: If set to True, once an attribute value is assigned, its + type cannot be overridden without .ignore_type() context manager. + + convert_dict: If set to True, all dict used as value in the + ConfigDict will automatically be converted to ConfigDict. + """ + super().__init__(initial_dictionary, type_safe, convert_dict) + object.__setattr__(self, "_return_refs", True) + + @classmethod + def from_yaml(cls, path: str) -> FieldConfigDict: + """Creates a config from a .yaml file. + + Args: + path: The path to the .yaml file that should be loaded. + """ + return cls( + yaml.load( + open(path, "r", encoding="utf-8"), Loader=yaml.UnsafeLoader + ) + ) + + def to_yaml(self, **kwargs: ArgsType) -> str: + """Returns a YAML representation of the object. + + ConfigDict serializes types of fields as well as the values of fields + themselves. Deserializing the YAML representation hence requires using + YAML's UnsafeLoader: + + ``` + yaml.load(cfg.to_yaml(), Loader=yaml.UnsafeLoader) + ``` + + or equivalently: + + ``` + yaml.unsafe_load(cfg.to_yaml()) + ``` + + Please see the PyYAML documentation and https://msg.pyyaml.org/load + for more details on the consequences of this. + + Args: + **kwargs: Keyword arguments for yaml.dump. + + Returns: + YAML representation of the object. + """ + return copy_and_resolve_references(self.value_mode()).to_yaml(**kwargs) + + def dump(self, output_path: str) -> None: + """Writes the config to a .yaml file. + + Args: + output_path: The path to the output file. + """ + with open(output_path, "w", encoding="utf-8") as file: + file.write(self.to_yaml()) + + def set_ref_mode(self, ref_mode: bool) -> None: + """Sets the config to return references instead of values.""" + + def _rec_resolve_iterable( # type: ignore + iterable: Iterable[Any], cfgs: list[FieldConfigDict] + ) -> None: + """Recursively adds all FieldConfigDicts to a list.""" + for item in iterable: + if isinstance(item, FieldConfigDict): + cfgs.append(item) + elif isinstance(item, (list, tuple)): + _rec_resolve_iterable(item, cfgs) + elif isinstance(item, (dict, ConfigDict)): + _rec_resolve_iterable(item.values(), cfgs) + + # Update value of this dict + object.__setattr__(self, "_return_refs", ref_mode) + + # propagate to sub configs + for value in self.values(): + if isinstance(value, FieldConfigDict): + value = value.value_mode() + elif isinstance(value, (list, tuple, ConfigDict, dict)): + cfgs: list[FieldConfigDict] = [] + _rec_resolve_iterable(value, cfgs) + for cfg in cfgs: + cfg.set_ref_mode(ref_mode) + + def ref_mode(self) -> FieldConfigDict: + """Sets the config to return references instead of values.""" + self.set_ref_mode(True) + return self + + def value_mode(self) -> FieldConfigDict: + """Sets the config to return values instead of references.""" + self.set_ref_mode(False) + return self + + def __getitem__(self, key: str) -> FieldReference: + """Returns the reference for the given key.""" + # private properties are always returned as values + + if self._return_refs: + try: + return super().get_ref(key) + except ValueError: + pass + return super().__getitem__(key) + + +def resolve_class_name(clazz: type | Callable[Any, Any] | str) -> str: # type: ignore # pylint: disable=line-too-long + """Resolves the full class name of the given class object, callable or str. + + This function takes a class object and returns the class name as a string. + Args: + clazz (type | Callable[[Any], Any] | str): The object to resolve + the full path of. + + Returns: + str: The full path of the given object. + + Raises: + ValueError: If the given object is a lambda function. + + Examples: + >>> class MyClass: pass + >>> resolve_class_name(MyClass) + '__main__.MyClass' + >>> resolve_class_name("path.to.MyClass") + 'path.to.MyClass' + >>> def my_function(): pass + >>> resolve_class_name(my_function) + '__main__.my_function' + """ + if isinstance(clazz, str): + return clazz + + if clazz.__name__ == "lambda": + raise ValueError( + "Resolving the full class path of lambda functions" + "is not supported. Please define a inline function instead." + ) + + module = clazz.__module__ + if module is None or module == str.__class__.__module__: + return clazz.__name__ + return module + "." + clazz.__name__ + + +def class_config( + clazz: type | Callable[Any, Any] | str, # type: ignore + **kwargs: ArgsType, +) -> ConfigDict: + """Creates a configuration which can be instantiated as a class. + + This function creates a configuration dict which can be passed to + 'instantiate_classes' to create a instance of the given class or functor. + + Example: + >>> class_cfg_obj = class_config("your.module.Module", arg1="arg1", arg2=2) + >>> print(class_cfg_obj) + >>> # Prints : + >>> class_path: your.module.Module + >>> init_args: + >>> arg1: arg1 + >>> arg2: 2 + + >>> # instantiate object + >>> inst_obj = instantiate_classes(class_cfg_obj) + >>> print(type(inst_obj)) # -> Will print + + >>> # Example by directly passing objects: + >>> class MyClass: + >>> def __init__(self, name: str, age: int): + >>> self.name = name + >>> self.age = age + >>> class_cfg_obj = class_config(MyClass, name="John", age= 25) + + >>> print(class_cfg_obj) + >>> # Prints : + >>> class_path: __main__.MyClass + >>> init_args: + >>> name: John + >>> age: 25 + + >>> # instantiate object + >>> inst_obj = instantiate_classes(class_cfg_obj) + >>> print(type(inst_obj)) # -> Will print + >>> print(inst_obj.name) # -> Will print John + + Args: + clazz (type | Callable[[Any], Any] | str): class type or functor or + class string path. + **kwargs (ArgsType): Kwargs to pass to the class constructor. + + Returns: + ConfigDict: _description_ + """ + class_path = resolve_class_name(clazz) + if class_path is None or len(kwargs) == 0: + return ConfigDict({"class_path": class_path}) + return ConfigDict( + {"class_path": class_path, "init_args": ConfigDict(kwargs)} + ) + + +def delay_instantiation(instantiable: ConfigDict) -> ConfigDict: + """Delays the instantiation of the given configuration object. + + This is a somewhat hacky way to delay the initialization of the optimizer + configuration object. It works by replacing the class_path with _class_path + which basically tells the instantiate_classes function to not instantiate + the class. Instead, it returns a function that can be called to instantiate + the class + + Args: + instantiable (ConfigDict): The configuration object to delay the + instantiation of. + """ + instantiable["_class_path"] = instantiable["class_path"] + del instantiable["class_path"] + + return class_config(DelayedInstantiator, instantiable=instantiable) + + +class DelayedInstantiator: + """Class that delays the instantiation of the given configuration object. + + This is a somewhat hacky way to delay the initialization of the optimizer + configuration object. It works by replacing the class_path with _class_path + which basically tells the instantiate_classes function to not instantiate + the class. Instead, it returns a function that can be called to instantiate + the class. + + Args: + instantiable (ConfigDict): The configuration object to delay the + instantiation of. + """ + + def __init__(self, instantiable: ConfigDict) -> None: + """Instantiates the DelayedInstantiator.""" + self.instantiable = instantiable + + def __call__(self, **kwargs: ArgsType) -> Any: # type: ignore + """Instantiates the configuration object.""" + instantiable = class_config( + self.instantiable["_class_path"], + **self.instantiable.get("init_args", {}), + ) + + return instantiate_classes(instantiable, **kwargs) + + +def instantiate_classes(data: ConfigDict | FieldReference, **kwargs: ArgsType) -> ConfigDict | Any: # type: ignore # pylint: disable=line-too-long + """Instantiates all classes in a given ConfigDict. + + This function iterates over the configuration data and instantiates + all classes. Class defintions are provided by a config dict that has + the following structure: + + { + 'data_path': 'path.to.my.class.Class', + 'init_args': ConfigDict( + { + 'arg1': 'value1', + 'arg2': 'value2', + } + ) + } + + Args: + data (ConfigDict | FieldReference): The general configuration object. + **kwargs (ArgsType): Additional arguments to pass to the class + constructor. + + Returns: + ConfigDict | Any: The instantiated objects. + """ + if isinstance(data, FieldReference): # De-Reference the field reference + data = data.get() + + assert isinstance(data, ConfigDict), "Data must be a ConfigDict." + + if isinstance(data, FieldConfigDict): + data.value_mode() # make sure data is in value mode + + if len(kwargs) > 0: + if "init_args" not in data: + data["init_args"] = ConfigDict(kwargs) + else: + for k, v in kwargs.items(): + data["init_args"][k] = v + + resolved_data = copy_and_resolve_references(data) + instantiated_objects = _instantiate_classes(resolved_data) + return instantiated_objects + + +def copy_and_resolve_references( # type: ignore + data: Any, visit_map: dict[int, Any] | None = None +) -> Any: + """Returns a ConfigDict copy with FieldReferences replaced by values. + + If the object is a FrozenConfigDict, the copy returned is also a + FrozenConfigDict. However, note that FrozenConfigDict should already have + FieldReferences resolved to values, so this method effectively produces + a deep copy. + + Note: This method is overwritten from the ConfigDict class and allows to + also resolve FieldReferences in list, tuple and dict. + + Args: + data (Any): object to copy. + visit_map (dict[int, Any]): A mapping from ConfigDict object ids to + their copy. Method is recursive in nature, and it will call + "copy_and_resolve_references(visit_map)" on each encountered + object, unless it is already in visit_map. + + Returns: + Any: ConfigDict copy with previous FieldReferences replaced by values. + """ + if isinstance(data, FieldReference): + data = data.get() + + if is_namedtuple(data): + return type(data)( + **{ + key: copy_and_resolve_references(getattr(data, key)) + for key in get_all_keys(data) + } + ) + + if isinstance(data, (list, tuple)): + return type(data)( + copy_and_resolve_references(value, visit_map) for value in data + ) + + if isinstance(data, dict): + return { + k: copy_and_resolve_references(v, visit_map) + for k, v in data.items() + } + + if not isinstance(data, ConfigDict): + return data + + visit_map = visit_map or {} + config_dict = ConfigDict() + + # copy attributes + super(ConfigDict, config_dict).__setattr__( + "_convert_dict", config_dict.convert_dict + ) + visit_map[id(config_dict)] = config_dict + + for key, value in data._fields.items(): + if isinstance(value, FieldReference): + value = value.get() + + if id(value) in visit_map: + value = visit_map[id(value)] + + elif isinstance(value, ConfigDict): + value = copy_and_resolve_references(value, visit_map) + + elif is_namedtuple(value): + value = type(value)( + **{ + key: copy_and_resolve_references(getattr(value, key)) + for key in get_all_keys(value) + } + ) + + elif isinstance(value, (list, tuple)): + value = type(value)( + copy_and_resolve_references(v, visit_map) for v in value + ) + + elif isinstance(value, dict): + value = { + k: copy_and_resolve_references(v, visit_map) + for k, v in value.items() + } + + if isinstance(data, FrozenConfigDict): + config_dict._frozen_setattr( # pylint:disable=protected-access + key, value + ) + else: + config_dict[key] = value + + # copy attributes + super(ConfigDict, config_dict).__setattr__("_locked", data.is_locked) + super(ConfigDict, config_dict).__setattr__("_type_safe", data.is_type_safe) + return config_dict + + +def _get_index(data: Any) -> Any: # type: ignore + """Internal function to generate a Sequence of indexes for a given object. + + Example: + >>> [data[idx] for idx in _get_index(data)] + + Args: + data (Any): The data entry to get an index for. + + Returns: + Any: Iterable that can be used to index the data entry using e.g. + [data[idx] for idx in _get_index(data)] + """ + if isinstance(data, (list, tuple)): + return range(len(data)) + return data + + +def _instantiate_classes(data: Any) -> Any: # type: ignore + """Instantiates all classes in a given data. + + Data could be ConfigDict, FieldReference, tuple, list or dict. + + This is the recursive implementation of the 'instantiate_classes'. + + This function iterates over the configuration data and instantiates + all classes. Class defintions are provided by a config dict that has + the following structure: + + { + 'data_path': 'path.to.my.class.Class', + 'init_args': ConfigDict( + { + 'arg1': 'value1', + 'arg2': 'value2', + } + ) + } + + Args: + data (Any): The general configuration object. + + Returns: + Any: The ConfigDict with all classes intialized. Or, if the top level + element is a class config, the returned element will be the + instantiated class. + """ + if isinstance(data, FieldReference): + data = data.get() + + if not isinstance(data, (ConfigDict, dict, list, tuple)): + return data + + for key in _get_index(data): + value = data[key] + + if isinstance(value, FieldReference): + value = value.get() + + if isinstance(value, (ConfigDict, dict)): + if isinstance(data, ConfigDict): + with data.ignore_type(): + data[key] = _instantiate_classes(value) + else: + data[key] = _instantiate_classes(value) + + elif is_namedtuple(value): + if isinstance(data, ConfigDict): + with data.ignore_type(): + data[key] = type(value)( + **{ + key: _instantiate_classes(getattr(value, key)) + for key in get_all_keys(value) + } + ) + else: + data[key] = type(value)( + **{ + key: _instantiate_classes(getattr(value, key)) + for key in get_all_keys(value) + } + ) + + elif isinstance(value, (list, tuple)): + if isinstance(data, ConfigDict): + with data.ignore_type(): + data[key] = type(value)( + _instantiate_classes(value[idx]) + for idx in range(len(value)) + ) + else: + data[key] = type(value)( + _instantiate_classes(value[idx]) + for idx in range(len(value)) + ) + + # Instantiate classs + if "class_path" in data and not isinstance(data["class_path"], ConfigDict): + module_name, class_name = data["class_path"].rsplit(".", 1) + init_args = data.get("init_args", {}) + + # Convert ConfigDict to normal dictionary + if isinstance(init_args, ConfigDict): + init_args = init_args.to_dict() + + module = importlib.import_module(module_name) + # Instantiate class + clazz = getattr(module, class_name)(**init_args) + return clazz + + return data diff --git a/vis4d/config/registry.py b/vis4d/config/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..2d39870a9829dbef145b39199a72bad652f8d7da --- /dev/null +++ b/vis4d/config/registry.py @@ -0,0 +1,262 @@ +"""Utility function for registering config files.""" + +from __future__ import annotations + +import glob +import os +import pathlib +import warnings +from typing import Callable, Union + +import yaml +from ml_collections import ConfigDict +from ml_collections.config_flags.config_flags import _LoadConfigModule + +from vis4d.common.dict import flatten_dict, get_dict_nested +from vis4d.common.typing import ArgsType +from vis4d.common.util import create_did_you_mean_msg +from vis4d.config.config_dict import FieldConfigDict +from vis4d.zoo import AVAILABLE_MODELS + +MODEL_ZOO_FOLDER = str( + (pathlib.Path(os.path.dirname(__file__)) / ".." / "zoo").resolve() +) + +# Paths that are used to search for config files. +REGISTERED_CONFIG_PATHS = [MODEL_ZOO_FOLDER] + + +TFunc = Callable[[ArgsType], ArgsType] +TfuncConfDict = Union[Callable[[ArgsType], ConfigDict], type] + + +def register_config( + category: str, name: str +) -> Callable[[TfuncConfDict], None]: + """Register a config in the model zoo for the given name and category. + + The config will then be available via `get_config_by_name` utilities and + located in the AVAILABLE_MODELS dictionary located at + [category][name]. + + Args: + category: Category of the config. + name: Name of the config. + + Returns: + The decorator. + """ + + def decorator(fnc_or_clazz: TfuncConfDict) -> None: + """Decorator for registering a config. + + Args: + fnc_or_clazz: Function or class to register. If a function is + passed, it will be wrapped in a class and the class will be + registered. If a class is passed, it will be registered + directly. + """ + if callable(fnc_or_clazz): + # Directly annotated get_config function. Wrap it and register it. + class Wrapper: + """Wrapper class.""" + + def get_config( + self, *args: ArgsType, **kwargs: ArgsType + ) -> ConfigDict: + """Resolves the get_config function.""" + return fnc_or_clazz(*args, **kwargs) + + module = Wrapper() + else: + # Directly annotated class. Register it. + module = fnc_or_clazz + + # Register the config + if category not in AVAILABLE_MODELS: + AVAILABLE_MODELS[category] = {} + + assert isinstance(AVAILABLE_MODELS[category], dict) + + AVAILABLE_MODELS[category][name] = module + + return decorator + + +def _resolve_config_path(path: str) -> str: + """Resolve the path of a config file. + + Args: + path: Name or path of the config. + If the config is not found at this location, + the function will look for the config in the model zoo folder. + + Returns: + The resolved path of the config. + + Raises: + ValueError: If the config is not found. + """ + if os.path.exists(path): + return path + + # Check for duplicate paths. + found_paths: list[str] = [] + all_paths = [] + + for p in REGISTERED_CONFIG_PATHS: + paths = sorted( + glob.glob( + os.path.join(p, f"**/*{ os.path.splitext(path)[-1]}"), + recursive=True, + ) + ) + print( + paths, + "lookup", + os.path.join(p, f"**/*{ os.path.splitext(path)[-1]}"), + ) + for cfg_path in paths: + if cfg_path.endswith(path): + found_paths.append(cfg_path) + all_paths.extend(paths) + + if len(found_paths) > 1: + warnings.warn( + f"Found multiple paths for config {path}:" + f"{found_paths}. Will load the config from the first one!" + ) + elif len(found_paths) == 0: + hint = create_did_you_mean_msg( + [*all_paths, *[os.path.basename(p) for p in all_paths]], path + ) + raise ValueError( + f"Could not find config {path}. \n" + f"The file does not exists at the path {path} or " + f"in the dedicated locations at {REGISTERED_CONFIG_PATHS}. \n" + f"Please check the path or add the config to the model zoo. \n" + f"Current working directory: {os.getcwd()}\n {hint}" + ) + return found_paths[0] + + +def _load_yaml_config(name_or_path: str) -> FieldConfigDict: + """Loads a .yaml configuration file. + + Args: + name_or_path: Name or path of the config. + If the config is not found at this location, $ + the function will look for the config in the model zoo folder. + + Returns: + The config for the experiment. + """ + path = _resolve_config_path(name_or_path) + with open(path, "r", encoding="utf-8") as yaml_file: + return FieldConfigDict(yaml.load(yaml_file, Loader=yaml.UnsafeLoader)) + + +def _load_py_config( + name_or_path: str, *args: ArgsType, method_name: str = "get_config" +) -> ConfigDict: + """Loads a .py configuration file. + + Args: + name_or_path: Name or path of the config. + If the config is not found at this location, + the function will look for the config in the model zoo folder. + *args: Additional arguments to pass to the config. + method_name: Name of the method to call from the file to get the + config. Defaults to "get_config". + + Returns: + The config for the experiment. + """ + path = _resolve_config_path(name_or_path) + config_module = _LoadConfigModule(f"{os.path.basename(path)}_config", path) + cfg = getattr(config_module, method_name)(*args) + assert isinstance(cfg, ConfigDict) + return cfg + + +def _get_registered_configs( + config_name: str, *args: ArgsType, method_name: str = "get_config" +) -> ConfigDict: + """Get a model from the registered config locations. + + Args: + config_name: Name of the config. This can either be + the full path of the config relative to the registered locations + or the name of the config. + If the config matches multiple configs (e.g. if there are two + conflicting config a/cfg and b/cfg) or if it is not found, + a ValueError is raised. + *args: Additional arguments to pass to the config. + method_name: Name of the method to call from the file to get the + config. Defaults to "get_config". + + Raises: + ValueError: If the config is not found. + + Returns: + The Config. + """ + models = flatten_dict(AVAILABLE_MODELS, os.path.sep) + # check if there is an absolute match for the config + if config_name in models: + module = get_dict_nested( + AVAILABLE_MODELS, config_name.split(os.path.sep) + ) + return getattr(module, method_name)(*args) + # check if there is a partial match for the config + matches = {} + for model in models: + if model.endswith(config_name): + matches[model] = get_dict_nested( + AVAILABLE_MODELS, model.split(os.path.sep) + ) + + if len(matches) > 1: + raise ValueError( + f"Found multiple configs matching {config_name}:" + f"{matches.keys()}.\nPlease specify a unique config name." + ) + if len(matches) == 0: + msg = create_did_you_mean_msg( + [*models, *[os.path.basename(m) for m in models]], config_name + ) + raise ValueError(msg) + + module = list(matches.values())[0] + return getattr(module, method_name)(*args) + + +def get_config_by_name( + name_or_path: str, *args: ArgsType, method_name: str = "get_config" +) -> ConfigDict: + """Get a config by name or path. + + Args: + name_or_path: Name or path of the config. + If the path has a .yaml or .py extension, the function will + load the config from the file. + Otherwise, the function will try to resolve the config from the + registered config locations. You can specify a config by its full + registered path (e.g. "a/b/cfg") or by its name (e.g. "cfg"). + *args: Additional arguments to pass to the config. + method_name: Name of the method to call from the file to get the + config. Defaults to "get_config". + + Returns: + The config. + + Raises: + ValueError: If the config is not found. + """ + if name_or_path.endswith(".yaml"): + return _load_yaml_config(name_or_path) + if name_or_path.endswith(".py"): + return _load_py_config(name_or_path, *args, method_name=method_name) + return _get_registered_configs( + name_or_path, *args, method_name=method_name + ) diff --git a/vis4d/config/show_connection.py b/vis4d/config/show_connection.py new file mode 100644 index 0000000000000000000000000000000000000000..e348815f9a2d12e46f16469d0631f3ff59e74414 --- /dev/null +++ b/vis4d/config/show_connection.py @@ -0,0 +1,551 @@ +"""Show connected components in the config.""" + +from __future__ import annotations + +import inspect +from typing import Any, TypedDict, get_type_hints + +from absl import app # pylint: disable=no-name-in-module +from torch import nn + +from vis4d.common.typing import ArgsType +from vis4d.engine.callbacks import ( + Callback, + EvaluatorCallback, + VisualizerCallback, +) +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.engine.flag import _CONFIG +from vis4d.engine.loss_module import LossModule +from vis4d.eval.base import Evaluator +from vis4d.vis.base import Visualizer + +from .config_dict import instantiate_classes + + +# Types +class DataConnectionInfo(TypedDict): + """Internal type def for visualization. + + This defines a block component + """ + + in_keys: list[str] + out_keys: list[str] + name: str + + +# Private Functions +def _rename_ds(name: str) -> str: + """Replaces data with d and prediction with p. + + Use this to remap the datasources to shorter names. + + Args: + name: Name to remap + + Returns: + remapped name + """ + return name.replace("data", "d").replace("prediction", "p") + + +def _get_model_conn_infos( + model: nn.Module, +) -> dict[str, DataConnectionInfo]: + """Returns the connection infos for a pytorch Model. + + Requires "forward_train" and "forward_test" to be defined and properly + typed! + + Args: + model: Model to extract data from + + Returns: + train_connections, test_connections + """ + train_t = get_type_hints(model.forward_train)["return"] + test_t = get_type_hints(model.forward_test)["return"] + + train_connection_info = DataConnectionInfo( + in_keys=sorted( + list(inspect.signature(model.forward).parameters.keys()) + ), + out_keys=[ + "

-" + e for e in sorted(resolve_named_tuple(train_t, prefix="")) + ], + name=model.__class__.__name__, + ) + + test_connection_info = DataConnectionInfo( + in_keys=sorted( + list(inspect.signature(model.forward).parameters.keys()) + ), + out_keys=[ + "

-" + e for e in sorted(resolve_named_tuple(test_t, prefix="")) + ], + name=model.__class__.__name__, + ) + return {"train": train_connection_info, "test": test_connection_info} + + +def _get_loss_connection_infos(loss: LossModule) -> list[DataConnectionInfo]: + """Returns the connection infos for a loss. + + Args: + loss (LossModule): Custom loss module with .forward() + + Returns: + DataConnectionInfo for the loss. + """ + loss_connection_info = [] + for l in loss.losses: + loss_out = [] + loss_in = [] + for entry, value in l["connector"].key_mapping.items(): + loss_out.append(f"{entry}") + loss_in.append(f"<{_rename_ds(value['source'])}>-" + value["key"]) + + loss_connection_info.append( + DataConnectionInfo( + in_keys=loss_in, out_keys=loss_out, name=l["name"] + ) + ) + + return loss_connection_info + + +def _get_vis_connection_infos( + visualizer: Visualizer, +) -> DataConnectionInfo: + """Returns the connection infos for a visualizer. + + Args: + visualizer: Visualizer to extract data from + + Returns: + DataConnectionInfo for the visualizer. + """ + return DataConnectionInfo( + in_keys=sorted( + list(inspect.signature(visualizer.process).parameters.keys()) + ), + out_keys=[], + name=visualizer.__class__.__name__, + ) + + +def _get_evaluator_connection_infos( + evaluator: Evaluator, +) -> DataConnectionInfo: + """Returns the connection infos for an evaluator. + + Args: + evaluator: Evaluator to extract data from + + Returns: + DataConnectionInfo for the evaluator. + """ + return DataConnectionInfo( + in_keys=sorted( + list(inspect.signature(evaluator.process).parameters.keys()) + ), + out_keys=[], + name=evaluator.__class__.__name__, + ) + + +def _get_data_connector_infos( + data_connector: DataConnector, name: str +) -> DataConnectionInfo: + """Returns the connection infos for a DataConnector. + + Args: + data_connector (DataConnector): Data connector to extract data. + name (str): Name of the data connector. + + Returns: + DataConnectionInfo for the data connector. + """ + return DataConnectionInfo( + in_keys=["-" + e for e in list(data_connector.key_mapping.keys())], + out_keys=list(data_connector.key_mapping.values()), + name=name, + ) + + +def _get_cb_connection_infos( + name: str, + cb_data_connector: None | CallbackConnector = None, +) -> DataConnectionInfo | None: + """Returns the connection infos for a callback.""" + if cb_data_connector is not None: + eval_out = [] + eval_in = [] + for entry, value in cb_data_connector.key_mapping.items(): + eval_out.append(f"{entry}") + eval_in.append(f"<{_rename_ds(value['source'])}>-" + value["key"]) + return DataConnectionInfo( + in_keys=eval_in, out_keys=eval_out, name=name + ) + return None + + +def _get_with_color(key: str, warn_unconnected: bool = True) -> str: + """Prepends colors for internal vsiualization.""" + if "*" in key: + # We connected this one + return f"\033[94m{key}\033[00m" + if "" in key: # key comes from data + return f"\033[90m{key}\033[00m" + + # comes from prediction and is not connected + if warn_unconnected: + return f"\u001b[33m{key}\033[00m" + return f"\033[00m{key}\033[00m" + + +# API Functions +def print_box( + title: str, inputs: list[str], outputs: list[str], use_color: bool = True +) -> str: + """Prints a box with title and in/outputs. + + Args: + title: Title to plot in the middle. + inputs: inputs to plot on the left. + outputs: Outputs to plot on the right. + use_color: Whether to use color in the output. + + Returns: + str: The box as a string. + + Example: + -------------- + -boxes2d | | *boxes2d + -boxes2d_classes | | *boxes2d_classes + -images | Train Data | *images + -input_hw | | *input_hw + -------------- + """ + len_title = len(title) + 4 + + n_lines = max(len(inputs), len(outputs)) + + max_len_inputs = max([0] + [len(inp) for inp in inputs]) + max_len_outputs = max([0] + [len(out) for out in outputs]) + + divider = ( + " " * (max_len_inputs + 1) + + "-" * len_title + + " " * (max_len_outputs + 1) + ) + lines = divider + "\n" + for idx in range(n_lines): + in_data = inputs[idx] if len(inputs) > idx else "" + # left pad + in_key = " " * (max_len_inputs - len(in_data)) + in_data + + out_data = outputs[idx] if len(outputs) > idx else "" + # right pad + out_key = out_data + " " * (max_len_outputs - len(out_data)) + + # title in middle + line = "" + line += _get_with_color(in_key) + line += " | " + line += " " * len(title) if idx != n_lines // 2 else title + line += " | " + line += _get_with_color(out_key) if use_color else out_key + + lines += line + "\n" + + lines += divider + "\n" + return lines + + +def resolve_named_tuple( # type:ignore + clazz: Any, prefix: str = "" +) -> list[str]: + """Returns all fields defined in the clazz t. + + Use this to get all fields defined for an e.g. Named Tuple. + + Args: + clazz: Class that should be resolved + prefix: Prefix to prepend (will be prefix.) + + Returns: + List with all fields and prefixes prepended. + + Examples: + >>> Person = namedtuple("Person", ["name", "age", "gender"]) + >>> Address = namedtuple("Address", ["street", "city", "zipcode"]) + + >>> resolve_named_tuple(clazz=Person, prefix="person") + ["person.name", "person.age", "person.gender"] + + >>> resolve_named_tuple(clazz=Address, prefix="address") + ["address.street", "address.city", "address.zipcode"] + + >>> resolve_named_tuple(clazz=Person, prefix="") + ["name", "age", "gender"] + + With more complex types: + >>> User = namedtuple("User", ["name", "address"]) + >>> user = User(name=Person(name="John"), address=Address(street="str", city="zrh", zipcode="1")) + + >>> resolve_named_tuple(clazz=user, prefix="user") + ["user.name.name", "user.address.street", "user.address.city", + "user.address.zipcode"] + + + + """ + fields = [] + if hasattr(clazz, "_fields"): + for f in clazz._fields: + p = f"{prefix}.{f}" if len(prefix) > 0 else f + fields += resolve_named_tuple(getattr(clazz, f), prefix=p) + return fields + return [prefix] + + +def connect_components( + in_info: DataConnectionInfo, out_info: DataConnectionInfo +) -> None: + """Marks two components as connected. + + Checks if they have intersecting keys and marks them as matched. + Updates the components inplace. + + Args: + in_info (DataConnectionInfo): Input DataConnection + out_info (DataConnectionInfo): Ouput DataConnection + """ + out_keys = [] + for out in out_info["in_keys"]: + out = out.replace("*", "") + out_keys.append(out.split(".")[0]) + + # Check connection + for idx, key in enumerate(in_info["out_keys"]): + key = key.replace("*", "") + for o_idx, o_key in enumerate(out_keys): + if key == o_key: + in_info["out_keys"][idx] = "*" + key + out_info["in_keys"][o_idx] = ( + " " + out_info["in_keys"][o_idx].replace("*", "") + "*" + ) + + +def prints_datagraph_for_config( + model: nn.Module, + train_data_connector: DataConnector, + test_data_connector: DataConnector, + loss: LossModule, + callbacks: list[Callback], +) -> str: + """Shows the setup of the configuration objects. + + For each components, plots which inputs is connected to which output. + Connected components are marked with "*". Use this to debug your + configuration setup. + + Note, that data loaded from the dataset are highlighted with and data + from model predictions with

. + + Args: + model (nn.Module): Model to plot. + train_data_connector (DataConnector): Train data connector to plot. + test_data_connector (DataConnector): Test data connector to plot. + loss (LossModule): Loss to plot. + callbacks (list[Callback]): Callbacks to plot. + + Returns: + str: The datagraph as a string, that can be printed to the console. + + Example: + The following is train datagraph for FasterRCNN with COCO. + Inputs loaded from dataset are marked with and predictions + with

. Unconnected inputs are missing a (*) sign. + + >>> dg = prints_datagraph_for_config(model, train_data_connector, test_data_connector, loss, callbacks))) + >>> print(dg) + ``` + # TODO: check if this is correct + =================================== + = Training Loop = + =================================== + -------------- + -boxes2d | | *boxes2d + -boxes2d_classes | | *boxes2d_classes + -images | Train Data | *images + -input_hw | | *input_hw + -------------- + -------------- + boxes2d* | |

-proposals + boxes2d_classes* | |

-roi + images* | | *

-rpn + input_hw* | FasterRCNN |

-sampled_proposals + original_hw | |

-sampled_target_indices + | |

-sampled_targets + -------------- + ----------- +

-rpn.cls* | | cls_outs + -input_hw | | images_hw +

-rpn.box* | RPNLoss | reg_outs + -boxes2d | | target_boxes + ----------- + ------------ +

-sampled_proposals.boxes | | boxes +

-sampled_targets.labels | | boxes_mask +

-roi.cls_score | | class_outs +

-roi.bbox_pred | RCNNLoss | regression_outs +

-sampled_targets.boxes | | target_boxes +

-sampled_targets.classes | | target_classes + ------------ + =================================== + = Testing Loop = + =================================== + ------------- + -images | | *images + -input_hw | Test Data | *input_hw + -original_hw | | *original_hw + ------------- + -------------- + boxes2d | |

-boxes + boxes2d_classes | |

-class_ids + images* | FasterRCNN |

-scores + input_hw* | | + original_hw* | | + -------------- + =================================== + = Callbacks = + =================================== + ------------------------- + -original_images | | *images + -sample_names | | *image_names +

-boxes | BoundingBoxVisualizer | *boxes +

-scores | | *scores +

-class_ids | | *class_ids + ------------------------- + ---------------------- + -sample_names | | *coco_image_id +

-boxes | | *pred_boxes +

-scores | COCODetectEvaluator | *pred_scores +

-class_ids | | *pred_classes + ---------------------- + ``` + """ + model_connection_info = _get_model_conn_infos(model) + + # TODO: support more data connectors + assert isinstance(train_data_connector, DataConnector) and isinstance( + test_data_connector, DataConnector + ), "Only DataConnector is supported." + train_data_connection_info = _get_data_connector_infos( + train_data_connector, name="Train Data" + ) + test_data_connection_info = _get_data_connector_infos( + test_data_connector, name="Test Data" + ) + + loss_info = _get_loss_connection_infos(loss) + log_str = "" + + # connect components + log_str += "=" * 35 + "\n" + log_str += "=" + " " * 10 + "Training Loop" + " " * 10 + "=" + "\n" + log_str += "=" * 35 + "\n" + + train_components = [ + train_data_connection_info, + model_connection_info["train"], + ] + loss_info + + for inp, out in zip(train_components[:-1], train_components[1:]): + connect_components(inp, out) + for e in train_components: + log_str += print_box(e["name"], e["in_keys"], e["out_keys"]) + + log_str += "=" * 35 + "\n" + log_str += "=" + " " * 10 + "Testing Loop " + " " * 10 + "=" + "\n" + log_str += "=" * 35 + "\n" + + test_components = [ + test_data_connection_info, + model_connection_info["test"], + ] + + for inp, out in zip(test_components[:-1], test_components[1:]): + connect_components(inp, out) + + for e in test_components: + log_str += print_box(e["name"], e["in_keys"], e["out_keys"]) + + # TODO: Add support for more callbacks and handle train_connector + log_str += "=" * 35 + "\n" + log_str += "=" + " " * 12 + "Callbacks" + " " * 12 + "=" + "\n" + log_str += "=" * 35 + "\n" + + # evaluator and visualizer + callback_components: list[DataConnectionInfo] = [] + + for cb in callbacks: + if isinstance(cb, EvaluatorCallback): + evaluator = cb.evaluator + + connect_info = _get_evaluator_connection_infos(evaluator) + component = _get_cb_connection_infos( + cb.evaluator.__class__.__name__, cb.test_connector + ) + + # found matching connector + if component is not None: + connect_components(component, connect_info) + callback_components.append(component) + + if isinstance(cb, VisualizerCallback): + visualizer = cb.visualizer + + connect_info = _get_vis_connection_infos(visualizer) + + component = _get_cb_connection_infos( + cb.visualizer.__class__.__name__, cb.test_connector + ) + + # found matching connector + if component is not None: + connect_components(component, connect_info) + callback_components.append(component) + + for e in callback_components: + log_str += print_box(e["name"], e["in_keys"], e["out_keys"]) + + return log_str + + +def main( + argv: ArgsType, # pylint: disable=unused-argument +) -> None: # pragma: no cover + """Main entry point to show connected components in the config. + + >>> python -m vis4d.config.show_connection --config vis4d/zoo/faster_rcnn/faster_rcnn_coco.py + """ + config = _CONFIG.value + + train_data_connector = instantiate_classes(config.train_data_connector) + test_data_connector = instantiate_classes(config.test_data_connector) + loss = instantiate_classes(config.loss) + model = instantiate_classes(config.model) + callbacks = [instantiate_classes(cb) for cb in config.callbacks] + + dg = prints_datagraph_for_config( + model, train_data_connector, test_data_connector, loss, callbacks + ) + print(dg) + + +if __name__ == "__main__": # pragma: no cover + app.run(main) diff --git a/vis4d/config/sweep.py b/vis4d/config/sweep.py new file mode 100644 index 0000000000000000000000000000000000000000..83ed954cb6a5695ab05adb5fd4b97398812ce14f --- /dev/null +++ b/vis4d/config/sweep.py @@ -0,0 +1,44 @@ +"""Helper functions for creating sweep configurations.""" + +from __future__ import annotations + +from ml_collections import ConfigDict + +from vis4d.common.typing import ArgsType + + +def grid_search( + param_names: list[str] | str, + param_values: list[ArgsType] | list[list[ArgsType]], +) -> ConfigDict: + """Linear grid search configuration over a list of parameters. + + Returns a configuration object that can be used to perform a grid search + over a list of parameters. + + Args: + param_names (list[str] | str): The name of the parameters to be + sampled. + param_values (list[Any] | list[list[Any]]): The values which + should be sampled. + + Example: + >>> grid_search("params.lr", [0.001, 0.01, 0.1]) + + + Returns: + ConfigDict: The configuration object that can be used to perform a grid + search. + """ + if isinstance(param_names, str): + param_names = [param_names] + param_values = [param_values] + + assert len(param_names) == len(param_values) + + config = ConfigDict() + config.method = "grid" + config.sampling_args = [] + for name, values in zip(param_names, param_values): + config.sampling_args.append([name, values]) + return config diff --git a/vis4d/config/typing.py b/vis4d/config/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..f9161dfd05f5420d1af6629321b78f8562bf0e1e --- /dev/null +++ b/vis4d/config/typing.py @@ -0,0 +1,194 @@ +"""Type definitions for configuration files.""" + +from __future__ import annotations + +from typing import Any, TypedDict + +from ml_collections import ConfigDict, FieldReference +from typing_extensions import NotRequired + +from .config_dict import FieldConfigDict + + +class ParamGroupCfg(TypedDict): + """Parameter group config. + + Attributes: + custom_keys (list[str]): List of custom keys. + lr_mult (NotRequired[float]): Learning rate multiplier. + decay_mult (NotRequired[float]): Weight Decay multiplier. + """ + + custom_keys: list[str] + lr_mult: NotRequired[float] + decay_mult: NotRequired[float] + norm_decay_mult: NotRequired[float] + bias_decay_mult: NotRequired[float] + + +class DataConfig(ConfigDict): # type: ignore + """Configuration for a data set. + + This data object is used to configure the training and test data of an + experiment. In particular, the train_dataloader and test_dataloader + need to be config dicts that can be instantiated as a dataloader. + + Attributes: + train_dataloader (ConfigDict): Configuration for the training + dataloader. + test_dataloader (ConfigDict): Configuration for the test dataloader. + + + Example: + >>> from vis4d.config.types import DataConfig + >>> from vis4d.zoo.base import class_config + >>> from my_package.data import MyDataLoader + >>> cfg = DataConfig() + >>> cfg.train_dataloader = class_config(MyDataLoader, ...) + """ + + train_dataloader: ConfigDict + test_dataloader: ConfigDict + + +class LrSchedulerConfig(ConfigDict): # type: ignore + """Configuration for a learning rate scheduler. + + Attributes: + scheduler (ConfigDict): Configuration for the learning rate scheduler. + begin (int): Begin epoch. + end (int): End epoch. + epoch_based (bool): Whether the learning rate scheduler is epoch based + or step based. + convert_epochs_to_steps (bool): Whether to convert the begin and end + for a step based scheduler to steps automatically based on length + of train dataloader. Enables users to set the iteration breakpoints + as epochs. Defaults to False. + convert_attributes (list[str] | None): List of attributes in the + scheduler that should be converted to steps. Defaults to None. + """ + + scheduler: ConfigDict + begin: int + end: int + epoch_based: bool + convert_epochs_to_steps: bool = False + convert_attributes: list[str] | None = None + + +class OptimizerConfig(ConfigDict): # type: ignore + """Configuration for an optimizer. + + Attributes: + optimizer (ConfigDict): Configuration for the optimizer. + lr_scheduler (list[LrSchedulerConfig] | None): Configuration for the + learning rate scheduler. + param_groups (list[ParamGroupCfg] | None): Configuration for the + parameter groups. + """ + + optimizer: ConfigDict + lr_scheduler: list[LrSchedulerConfig] | None + param_groups: list[ParamGroupCfg] | None + + +class ExperimentParameters(FieldConfigDict): + """Parameters for an experiment. + + Attributes: + samples_per_gpu (int): Number of samples per GPU. + workers_per_gpu (int): Number of workers per GPU. + """ + + samples_per_gpu: int + workers_per_gpu: int + + +class ExperimentConfig(FieldConfigDict): + """Configuration for an experiment. + + This data object is used to configure an experiment. It contains the + minimal required configuration to run an experiment. In particular, the + data, model, optimizers, and loss need to be config dicts that can be + instantiated as a data set, model, optimizer, and loss function, + respectively. + + Attributes: + work_dir (str | FieldReference): The working directory for the + experiment. + experiment_name (str | FieldReference): The name of the experiment. + timestamp (str | FieldReference): The timestamp of the experiment. + version (str | FieldReference): The version of the experiment. + output_dir (str | FieldReference): The output directory for the + experiment. + seed (int | FieldReference): The random seed for the experiment. + log_every_n_steps (int | FieldReference): The number of steps after + which the logs should be written. + use_tf32 (bool | FieldReference): Whether to use tf32. + benchmark (bool | FieldReference): Whether to enable benchmarking. + params (ExperimentParameters): Configuration for the experiment + parameters. + data (DataConfig): Configuration for the dataset. + model (FieldConfigDictOrRef): Configuration for the model. + loss (FieldConfigDictOrRef): Configuration for the loss function. + optimizers (list[OptimizerConfig]): Configuration for the optimizers. + data_connector (FieldConfigDictOrRef): Configuration for the data + connector. + callbacks (list[FieldConfigDictOrRef]): Configuration for the + callbacks which are used in the engine. + """ + + # General + work_dir: str | FieldReference + experiment_name: str | FieldReference + timestamp: str | FieldReference + version: str | FieldReference + output_dir: str | FieldReference + seed: int | FieldReference + log_every_n_steps: int | FieldReference + use_tf32: bool | FieldReference + benchmark: bool | FieldReference + tf32_matmul_precision: str | FieldReference + + params: ExperimentParameters + + # Data + data: DataConfig + + # Model + model: ConfigDict + + # Loss + loss: ConfigDict + + # Optimizer + optimizers: list[OptimizerConfig] + + # Data connector + train_data_connector: ConfigDict + test_data_connector: ConfigDict + + # Callbacks + callbacks: list[ConfigDict] + + +class ParameterSweepConfig(FieldConfigDict): + """Configuration for a parameter sweep. + + Confguration object for a parameter sweep. It contains the minimal required + configuration to run a parameter sweep. + + Attributes: + method (str): Sweep method that should be used (e.g. grid) + sampling_args (list[tuple[str, Any]]): Arguments that should be passed + to the sweep method. E.g. for grid, this would be a list of tuples + of the form (parameter_name, parameter_values). + suffix (str): Suffix that should be appended to the output directory. + This will be interpreted as a string template and can contain + references to the sampling_args. + E.g. "lr_{lr:.2e}_bs_{batch_size}". + """ + + method: str | FieldReference + sampling_args: list[tuple[str, Any]] | FieldReference # type: ignore + suffix: str | FieldReference = "" diff --git a/vis4d/data/__init__.py b/vis4d/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..efa31a3fb5f1097d5b7e0863488ee8c19141f916 --- /dev/null +++ b/vis4d/data/__init__.py @@ -0,0 +1,10 @@ +"""The data package defines the full data pipeline. + +We provide dataset implementations in the `datasets` submodule that return a +common data format `DictData`. This data format is used by the pre-processing +functions in the submodule `transforms`. The preprocessing functions are +composed with the datasets in `DataPipe`. Optionally, a reference view sampler +can be added here. The `DataPipe` is input to `torch.data.DataLoader`, for +which we provide utility functions for instantiation that handle also +batch-wise preprocessing and batch collation. +""" diff --git a/vis4d/data/__pycache__/__init__.cpython-311.pyc b/vis4d/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1356948d291f0dc6d86759bd8ad5f72df5baae05 Binary files /dev/null and b/vis4d/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/data/__pycache__/const.cpython-311.pyc b/vis4d/data/__pycache__/const.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c51a3c233bb63003ef80eaa7deded1a5318907e Binary files /dev/null and b/vis4d/data/__pycache__/const.cpython-311.pyc differ diff --git a/vis4d/data/__pycache__/typing.cpython-311.pyc b/vis4d/data/__pycache__/typing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c8957ee83a181a87b23cc17254e1252d25ae595 Binary files /dev/null and b/vis4d/data/__pycache__/typing.cpython-311.pyc differ diff --git a/vis4d/data/cbgs.py b/vis4d/data/cbgs.py new file mode 100644 index 0000000000000000000000000000000000000000..d087ad5419aac8158fef8aeb9490fdf9ee820ad7 --- /dev/null +++ b/vis4d/data/cbgs.py @@ -0,0 +1,153 @@ +"""Class-balanced Grouping and Sampling for 3D Object Detection. + +Implementation of `Class-balanced Grouping and Sampling for Point Cloud 3D +Object Detection `_. +""" + +from __future__ import annotations + +import numpy as np +from torch.utils.data import Dataset + +from vis4d.common.distributed import broadcast, rank_zero_only +from vis4d.common.logging import rank_zero_info +from vis4d.common.time import Timer + +from .datasets.util import print_class_histogram +from .reference import MultiViewDataset +from .typing import DictDataOrList + + +# TODO: Support sensor selection. +class CBGSDataset(Dataset[DictDataOrList]): + """Balance the number of scenes under different classes.""" + + def __init__( + self, + dataset: Dataset[DictDataOrList], + class_map: dict[str, int], + ignore: int = -1, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.dataset = dataset + self.has_reference = isinstance(dataset, MultiViewDataset) + self.cat2id = dict(sorted(class_map.items(), key=lambda x: x[1])) + self.ignore = ignore + + rank_zero_info("Wrapping dataset with CBGS...") + sample_indices = self._get_sample_indices() + self.sample_indices = broadcast(sample_indices) + + def _show_histogram( + self, + sample_indices: list[int], + sample_frequencies: list[dict[str, int]], + ) -> None: + """Show class histogram.""" + frequencies = {cat: 0 for cat in self.cat2id.keys()} + + for idx in sample_indices: + freq = sample_frequencies[idx] + for box3d_class in freq: + frequencies[box3d_class] += freq[box3d_class] + + print_class_histogram(frequencies) + + def _get_class_sample_indices( + self, + ) -> tuple[dict[int, list[int]], list[dict[str, int]]]: + """Get sample indices.""" + class_sample_indices: dict[int, list[int]] = { + cat_id: [] for cat_id in self.cat2id.values() + } + sample_frequencies = [] + inv_class_map = {v: k for k, v in self.cat2id.items()} + + # Handle the case that dataset is already wrapped. + if hasattr(self.dataset, "dataset"): + dataset = self.dataset.dataset + else: + dataset = self.dataset + + for idx in range(len(dataset)): + assert hasattr( + dataset, "get_cat_ids" + ), "The dataset must have a method `get_cat_ids` to get cat ids." + cat_ids = dataset.get_cat_ids(idx) + cur_cats = {} + frequencies = {cat: 0 for cat in self.cat2id.keys()} + + for cat_id in cat_ids: + if cat_id != self.ignore: + cur_cats[cat_id] = [idx] + frequencies[inv_class_map[cat_id]] += 1 + + sample_frequencies.append(frequencies) + for cat_id, index in cur_cats.items(): + class_sample_indices[cat_id] += index + + return class_sample_indices, sample_frequencies + + @rank_zero_only + def _get_sample_indices(self) -> list[int]: + """Load sample indices. + + Returns: + list[int]: List of indices after class sampling. + """ + t = Timer() + ( + class_sample_indices, + sample_frequencies, + ) = self._get_class_sample_indices() + + duplicated_samples = sum( + len(v) for _, v in class_sample_indices.items() + ) + class_distribution = { + k: len(v) / duplicated_samples + for k, v in class_sample_indices.items() + } + + sample_indices = [] + + frac = 1.0 / len(self.cat2id) + ratios = [ + frac / v if v > 0 else 1 for v in class_distribution.values() + ] + for cls_inds, ratio in zip( + list(class_sample_indices.values()), ratios + ): + sample_indices += np.random.choice( + cls_inds, int(len(cls_inds) * ratio) + ).tolist() + + self._show_histogram(sample_indices, sample_frequencies) + + rank_zero_info( + f"Generating {len(sample_indices)} CBGS samples takes " + + f"{t.time():.2f} seconds." + ) + + return sample_indices + + def __len__(self) -> int: + """Return the length of sample indices. + + Returns: + int: Length of sample indices. + """ + return len(self.sample_indices) + + def __getitem__(self, idx: int) -> DictDataOrList: + """Get original dataset idx according to the given index. + + Args: + idx (int): The index of self.sample_indices. + + Returns: + DictDataOrList: Data of the corresponding index. + """ + ori_index = self.sample_indices[idx] + return self.dataset[ori_index] diff --git a/vis4d/data/const.py b/vis4d/data/const.py new file mode 100644 index 0000000000000000000000000000000000000000..fb26daad3479d70c8bb366af928f9a302cc9166f --- /dev/null +++ b/vis4d/data/const.py @@ -0,0 +1,179 @@ +"""Defines data related constants. + +While the datasets can hold arbitrary data types and formats, this file +provides some constants that are used to define a common data format which is +helpful to use for better data transformation. +""" + +from dataclasses import dataclass +from enum import Enum + +# A custom value to distinguish instance ID and category ID; need to be greater +# than the number of categories. For a pixel in the panoptic result map: +# panaptic_id = instance_id * INSTANCE_OFFSET + category_id +INSTANCE_OFFSET = 1000 + + +class AxisMode(Enum): + """Enum for choosing among different coordinate frame conventions. + + ROS: The coordinate frame aligns with the right hand rule: + - x axis points forward. + - y axis points left. + - z axis points up. + See also: https://www.ros.org/reps/rep-0103.html#axis-orientation + + OpenCV: The coordinate frame aligns with a camera coordinate system: + - x axis points right. + - y axis points down. + - z axis points forward. + See also: https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html + + LiDAR: The coordinate frame aligns with a LiDAR coordinate system: + - x axis points right. + - y axis points forward. + - z axis points up. + See also: https://www.nuscenes.org/nuscenes#data-collection + """ + + ROS = 0 + OPENCV = 1 + LIDAR = 2 + + +@dataclass +class CommonKeys: + """Common supported keys for DictData. + + While DictData can hold arbitrary keys of data, we define a common set of + keys where we expect a pre-defined format to enable the usage of common + data pre-processing operations among different datasets. + + General Info: + - sample_names (str): Name of the sample. + + If the dataset contains videos: + - sequence_names (str): The name of the sequence. + - frame_ids (int): The temporal frame index of the sample. + + Image Based Inputs: + - images (NDArrayF32): Image of shape [1, H, W, C]. + - input_hw (Tuple[int, int]): Shape of image in (height, width) after + transformations. + - original_images (NDArrayF32): Original image of shape [1, H, W, C]. + - original_hw (Tuple[int, int]): Shape of original image in + (height, width). + + Image Classification: + - categories (NDArrayI64): Class labels of shape [1, ]. + + 2D Object Detection: + - boxes2d (NDArrayF32): 2D bounding boxes of shape [N, 4] in xyxy + format. + - boxes2d_classes (NDArrayI64): Classes of 2D bounding boxes of shape + [N,]. + - boxes2d_names (List[str]): Names of 2D bounding box classes, same + order as `boxes2d_classes`. + + 2D Object Tracking: + - boxes2d_track_ids (NDArrayI64): Tracking IDs of 2D bounding boxes of + shape [N,]. + + Segmentation: + - masks (NDArrayUI8): Segmentation masks of shape [N, H, W]. + - seg_masks (NDArrayUI8): Semantic segmentation masks [H, W]. + - instance_masks (NDArrayUI8): Instance segmentation masks of shape + [N, H, W]. + - panoptic_masks (NDArrayI64): Panoptic segmentation masks [H, W]. + + Depth Estimation: + - depth_maps (NDArrayF32): Depth maps of shape [H, W]. + + Optical Flow: + - optical_flows (NDArrayF32): Optical flow maps of shape [H, W, 2]. + + Sensor Calibration: + - intrinsics (NDArrayF32): Intrinsic sensor calibration. Shape [3, 3]. + - extrinsics (NDArrayF32): Extrinsic sensor calibration, transformation + of sensor to world coordinate frame. Shape [4, 4]. + - axis_mode (AxisMode): Coordinate convention of the current sensor. + - timestamp (int): Sensor timestamp in Unix format. + + 3D Point Cloud Data: + - points3d (NDArrayF32): 3D pointcloud data, assumed to be [N, 3] and + in sensor frame. + - colors3d (NDArrayF32): Associated color values for each point [N, 3]. + + 3D Point Cloud Annotations: + - semantics3d (NDArrayI64): Semantic classes of 3D points [N, 1]. + - instances3d (NDArrayI64): Instance IDs of 3D points [N, 1]. + + 3D Object Detection: + - boxes3d (NDArrayF32): 3D bounding boxes of shape [N, 10], each + consists of center (XYZ), dimensions (WLH), and orientation + quaternion (WXYZ). + - boxes3d_classes (NDArrayI64): Associated semantic classes of 3D + bounding boxes of shape [N,]. + - boxes3d_names (List[str]): Names of 3D bounding box classes, same + order as `boxes3d_classes`. + - boxes3d_track_ids (NDArrayI64): Associated tracking IDs of 3D + bounding boxes of shape [N,]. + - boxes3d_velocities (NDArrayF32): Associated velocities of 3D bounding + boxes of shape [N, 3], where each velocity is in the form of + (vx, vy, vz). + """ + + # General Info + sample_names = "sample_names" + sequence_names = "sequence_names" + frame_ids = "frame_ids" + + # image based inputs + images = "images" + input_hw = "input_hw" + original_images = "original_images" + original_hw = "original_hw" + + # Image Classification + categories = "categories" + + # 2D Object Detection + boxes2d = "boxes2d" + boxes2d_classes = "boxes2d_classes" + boxes2d_names = "boxes2d_names" + + # 2D Object Tracking + boxes2d_track_ids = "boxes2d_track_ids" + + # Segmentation + masks = "masks" + seg_masks = "seg_masks" + instance_masks = "instance_masks" + panoptic_masks = "panoptic_masks" + + # Depth Estimation + depth_maps = "depth_maps" + + # Optical Flow + optical_flows = "optical_flows" + + # Sensor Calibration + intrinsics = "intrinsics" + extrinsics = "extrinsics" + axis_mode = "axis_mode" + timestamp = "timestamp" + + # 3D Point Cloud Data + points3d = "points3d" + colors3d = "colors3d" + + # 3D Point Cloud Annotations + semantics3d = "semantics3d" + instances3d = "instances3d" + + # 3D Object Detection + boxes3d = "boxes3d" + boxes3d_classes = "boxes3d_classes" + boxes3d_names = "boxes3d_names" + boxes3d_track_ids = "boxes3d_track_ids" + boxes3d_velocities = "boxes3d_velocities" diff --git a/vis4d/data/data_pipe.py b/vis4d/data/data_pipe.py new file mode 100644 index 0000000000000000000000000000000000000000..323c74f4bb63b1ef6f3afef784161bf36a27e8f6 --- /dev/null +++ b/vis4d/data/data_pipe.py @@ -0,0 +1,139 @@ +"""DataPipe wraps datasets to share the prepossessing pipeline.""" + +from __future__ import annotations + +import random +from collections.abc import Callable, Iterable + +from torch.utils.data import ConcatDataset, Dataset + +from .reference import MultiViewDataset +from .transforms.base import TFunctor +from .typing import DictData, DictDataOrList + + +class DataPipe(ConcatDataset[DictDataOrList]): + """DataPipe class. + + This class wraps one or multiple instances of a PyTorch Dataset so that the + preprocessing steps can be shared across those datasets. Composes dataset + and the preprocessing pipeline. + """ + + def __init__( + self, + datasets: Dataset[DictDataOrList] | Iterable[Dataset[DictDataOrList]], + preprocess_fn: Callable[ + [list[DictData]], list[DictData] + ] = lambda x: x, + ): + """Creates an instance of the class. + + Args: + datasets (Dataset | Iterable[Dataset]): Dataset(s) to be wrapped by + this data pipeline. + preprocess_fn (Callable[[list[DictData]], list[DictData]]): + Preprocessing function of a single sample. It takes a list of + samples and returns a list of samples. Defaults to identity + function. + """ + if isinstance(datasets, Dataset): + datasets = [datasets] + super().__init__(datasets) + self.preprocess_fn = preprocess_fn + + self.has_reference = any( + _check_reference(dataset) for dataset in datasets + ) + + if self.has_reference and not all( + _check_reference(dataset) for dataset in datasets + ): + raise ValueError( + "All datasets must be MultiViewDataset / has reference if " + + "one of them is." + ) + + def __getitem__(self, idx: int) -> DictDataOrList: + """Wrap getitem to apply augmentations.""" + samples = super().__getitem__(idx) + if isinstance(samples, list): + return self.preprocess_fn(samples) + + return self.preprocess_fn([samples])[0] + + +class MultiSampleDataPipe(DataPipe): + """MultiSampleDataPipe class. + + This class wraps DataPipe to support augmentations that require multiple + images (e.g., Mosaic and Mixup) by sampling additional indices for each + image. NUM_SAMPLES needs to be defined as a class attribute for transforms + that require multi-sample augmentation. + """ + + def __init__( + self, + datasets: Dataset[DictDataOrList] | Iterable[Dataset[DictDataOrList]], + preprocess_fn: list[list[TFunctor]], + ): + """Creates an instance of the class. + + Args: + datasets (Dataset | Iterable[Dataset]): Dataset(s) to be wrapped by + this data pipeline. + preprocess_fn (list[list[TFunctor]]): Preprocessing functions of a + single sample. Different than DataPipe, this is a list of lists + of transformation functions. The inner list is for transforms + that needs to share the same sampled indices (e.g., + GenMosaicParameters and MosaicImages), and the outer list is + for different transforms. + """ + super().__init__(datasets) + self.preprocess_fns = preprocess_fn + + def _sample_indices(self, idx: int, num_samples: int) -> list[int]: + """Sample additional indices for multi-sample augmentation.""" + indices = [idx] + for _ in range(1, num_samples): + indices.append(random.randint(0, len(self) - 1)) + return indices + + def __getitem__(self, idx: int) -> DictDataOrList: + """Wrap getitem to apply augmentations.""" + samples = super(DataPipe, self).__getitem__(idx) + if not isinstance(samples, list): + samples = [samples] + single_view = True + else: + single_view = False + + for preprocess_fn in self.preprocess_fns: + if hasattr(preprocess_fn[0], "NUM_SAMPLES"): + num_samples = preprocess_fn[0].NUM_SAMPLES + aug_inds = self._sample_indices(idx, num_samples) + add_samples = [ + super(DataPipe, self).__getitem__(ind) + for ind in aug_inds[1:] + ] + prep_samples = [] + for i, samp in enumerate(samples): + prep_samples.append(samp) + prep_samples += [ + s[i] if isinstance(s, list) else s for s in add_samples + ] + else: + num_samples = 1 + prep_samples = samples + for prep_fn in preprocess_fn: + prep_samples = prep_fn.apply_to_data(prep_samples) # type: ignore # pylint: disable=line-too-long + samples = prep_samples[::num_samples] + return samples[0] if single_view else samples + + +def _check_reference(dataset: Dataset[DictDataOrList]) -> bool: + """Check if the datasets have reference.""" + has_reference = ( + dataset.has_reference if hasattr(dataset, "has_reference") else False + ) + return has_reference or isinstance(dataset, MultiViewDataset) diff --git a/vis4d/data/datasets/__init__.py b/vis4d/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb7083d7b05f96c2fa6e262cb48a4c5377122c42 --- /dev/null +++ b/vis4d/data/datasets/__init__.py @@ -0,0 +1 @@ +"""Datasets module.""" diff --git a/vis4d/data/datasets/base.py b/vis4d/data/datasets/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a0647f262534618ba40394ad44f1fbf7e1da0834 --- /dev/null +++ b/vis4d/data/datasets/base.py @@ -0,0 +1,118 @@ +"""Base dataset classes. + +We implement a typed version of the PyTorch dataset class here. In addition, we +provide a number of Mixin classes which a dataset can inherit from to implement +additional functionality. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TypedDict + +from torch.utils.data import Dataset as TorchDataset + +from vis4d.common.typing import ArgsType +from vis4d.data.io.base import DataBackend +from vis4d.data.io.file import FileBackend +from vis4d.data.typing import DictData + + +class Dataset(TorchDataset[DictData]): + """Basic pytorch dataset with defined return type.""" + + # Dataset metadata. + DESCRIPTION = "" + HOMEPAGE = "" + PAPER = "" + LICENSE = "" + + # List of all keys supported by this dataset. + KEYS: Sequence[str] = [] + + def __init__( + self, + image_channel_mode: str = "RGB", + data_backend: None | DataBackend = None, + ) -> None: + """Initialize dataset. + + Args: + image_channel_mode (str): Image channel mode to use. Default: RGB. + data_backend (None | DataBackend): Data backend to use. + Default: None. + """ + self.image_channel_mode = image_channel_mode + self.data_backend = ( + data_backend if data_backend is not None else FileBackend() + ) + + def __len__(self) -> int: + """Return length of dataset.""" + raise NotImplementedError + + def __getitem__(self, idx: int) -> DictData: + """Convert single element at given index into Vis4D data format.""" + raise NotImplementedError + + def validate_keys(self, keys_to_load: Sequence[str]) -> None: + """Validate that all keys to load are supported. + + Args: + keys_to_load (list[str]): List of keys to load. + + Raises: + ValueError: Raise if any key is not defined in AVAILABLE_KEYS. + """ + for k in keys_to_load: + if k not in self.KEYS: + raise ValueError(f"Key '{k}' is not supported!") + + +class VideoMapping(TypedDict): + """Grouped dataset sample indices and frame indices.""" + + video_to_indices: dict[str, list[int]] + video_to_frame_ids: dict[str, list[int]] + + +class VideoDataset(Dataset): + """Video datasets. + + Provides video_mapping attribute for video based interface and reference + view samplers. + """ + + def __init__(self, *args: ArgsType, **kwargs: ArgsType) -> None: + """Initialize dataset.""" + super().__init__(*args, **kwargs) + self.video_mapping: VideoMapping = { + "video_to_indices": {}, + "video_to_frame_ids": {}, + } + + def _sort_video_mapping(self, video_mapping: VideoMapping) -> VideoMapping: + """Sort video mapping by frame ids.""" + video_to_indices = video_mapping["video_to_indices"] + video_to_frame_ids = video_mapping["video_to_frame_ids"] + + for seq in video_to_indices: + sorted_zipped = sorted( + list(zip(video_to_indices[seq], video_to_frame_ids[seq])), + key=lambda x: x[1], + ) + sorted_indices, sorted_frame_ids = zip(*sorted_zipped) + video_mapping["video_to_indices"][seq] = list(sorted_indices) + video_mapping["video_to_frame_ids"][seq] = list(sorted_frame_ids) + + return video_mapping + + def _generate_video_mapping(self) -> VideoMapping: + """Group dataset sample by their associated video ID. + + The sample index is an integer while video IDs are string. + + Returns: + VideoMapping: Mapping of video IDs to sample indices and frame IDs. + """ + raise NotImplementedError diff --git a/vis4d/data/datasets/bdd100k.py b/vis4d/data/datasets/bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d2c1919cb4a7b4443777f654526e3ddd48ef1d --- /dev/null +++ b/vis4d/data/datasets/bdd100k.py @@ -0,0 +1,126 @@ +"""BDD100K dataset.""" + +from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE + +from .scalabel import Scalabel + +bdd100k_det_map = { + "pedestrian": 0, + "rider": 1, + "car": 2, + "truck": 3, + "bus": 4, + "train": 5, + "motorcycle": 6, + "bicycle": 7, + "traffic light": 8, + "traffic sign": 9, +} +bdd100k_track_map = { + "pedestrian": 0, + "rider": 1, + "car": 2, + "truck": 3, + "bus": 4, + "train": 5, + "motorcycle": 6, + "bicycle": 7, +} +bdd100k_seg_map = { + "road": 0, + "sidewalk": 1, + "building": 2, + "wall": 3, + "fence": 4, + "pole": 5, + "traffic light": 6, + "traffic sign": 7, + "vegetation": 8, + "terrain": 9, + "sky": 10, + "person": 11, + "rider": 12, + "car": 13, + "truck": 14, + "bus": 15, + "train": 16, + "motorcycle": 17, + "bicycle": 18, +} +bdd100k_panseg_map = { + "dynamic": 0, + "ego vehicle": 1, + "ground": 2, + "static": 3, + "parking": 4, + "rail track": 5, + "road": 6, + "sidewalk": 7, + "bridge": 8, + "building": 9, + "fence": 10, + "garage": 11, + "guard rail": 12, + "tunnel": 13, + "wall": 14, + "banner": 15, + "billboard": 16, + "lane divider": 17, + "parking sign": 18, + "pole": 19, + "polegroup": 20, + "street light": 21, + "traffic cone": 22, + "traffic device": 23, + "traffic light": 24, + "traffic sign": 25, + "traffic sign frame": 26, + "terrain": 27, + "vegetation": 28, + "sky": 29, + "person": 30, + "rider": 31, + "bicycle": 32, + "bus": 33, + "car": 34, + "caravan": 35, + "motorcycle": 36, + "trailer": 37, + "train": 38, + "truck": 39, +} + +if BDD100K_AVAILABLE and SCALABEL_AVAILABLE: + from bdd100k.common.utils import load_bdd100k_config + from bdd100k.label.to_scalabel import bdd100k_to_scalabel + from scalabel.label.io import load + from scalabel.label.typing import Dataset as ScalabelData +else: + raise ImportError("bdd100k or scalabel is not installed.") + + +class BDD100K(Scalabel): + """BDD100K type dataset, based on Scalabel.""" + + DESCRIPTION = """BDD100K is a large-scale dataset for driving scene + understanding.""" + HOMEPAGE = "https://www.bdd100k.com/" + PAPER = "https://arxiv.org/abs/1805.04687" + LICENSE = "https://www.bdd100k.com/license" + + def _generate_mapping(self) -> ScalabelData: + """Generate data mapping.""" + bdd100k_anns = load(self.annotation_path) + if self.config_path is None: + return bdd100k_anns # pragma: no cover + frames = bdd100k_anns.frames + assert isinstance(self.config_path, str) + bdd100k_cfg = load_bdd100k_config(self.config_path) + scalabel_frames = bdd100k_to_scalabel(frames, bdd100k_cfg) + return ScalabelData( + frames=scalabel_frames, config=bdd100k_cfg.scalabel, groups=None + ) + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return f"BDD100KDataset {self.data_root}" diff --git a/vis4d/data/datasets/coco.py b/vis4d/data/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..44d60528d5f360048fba388ebb10a66d2dd3cd40 --- /dev/null +++ b/vis4d/data/datasets/coco.py @@ -0,0 +1,365 @@ +"""COCO dataset.""" + +from __future__ import annotations + +import contextlib +import io +import os +from collections.abc import Sequence + +import numpy as np +import pycocotools.mask as maskUtils +from pycocotools.coco import COCO as COCOAPI + +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.const import CommonKeys as K +from vis4d.data.typing import DictData + +from .base import Dataset +from .util import CacheMappingMixin, get_category_names, im_decode + +# COCO detection +coco_det_map = { + "person": 0, + "bicycle": 1, + "car": 2, + "motorcycle": 3, + "airplane": 4, + "bus": 5, + "train": 6, + "truck": 7, + "boat": 8, + "traffic light": 9, + "fire hydrant": 10, + "stop sign": 11, + "parking meter": 12, + "bench": 13, + "bird": 14, + "cat": 15, + "dog": 16, + "horse": 17, + "sheep": 18, + "cow": 19, + "elephant": 20, + "bear": 21, + "zebra": 22, + "giraffe": 23, + "backpack": 24, + "umbrella": 25, + "handbag": 26, + "tie": 27, + "suitcase": 28, + "frisbee": 29, + "skis": 30, + "snowboard": 31, + "sports ball": 32, + "kite": 33, + "baseball bat": 34, + "baseball glove": 35, + "skateboard": 36, + "surfboard": 37, + "tennis racket": 38, + "bottle": 39, + "wine glass": 40, + "cup": 41, + "fork": 42, + "knife": 43, + "spoon": 44, + "bowl": 45, + "banana": 46, + "apple": 47, + "sandwich": 48, + "orange": 49, + "broccoli": 50, + "carrot": 51, + "hot dog": 52, + "pizza": 53, + "donut": 54, + "cake": 55, + "chair": 56, + "couch": 57, + "potted plant": 58, + "bed": 59, + "dining table": 60, + "toilet": 61, + "tv": 62, + "laptop": 63, + "mouse": 64, + "remote": 65, + "keyboard": 66, + "cell phone": 67, + "microwave": 68, + "oven": 69, + "toaster": 70, + "sink": 71, + "refrigerator": 72, + "book": 73, + "clock": 74, + "vase": 75, + "scissors": 76, + "teddy bear": 77, + "hair drier": 78, + "toothbrush": 79, +} + +# COCO segmentation categories +coco_seg_map = { + "background": 0, + "airplane": 1, + "bicycle": 2, + "bird": 3, + "boat": 4, + "bottle": 5, + "bus": 6, + "car": 7, + "cat": 8, + "chair": 9, + "cow": 10, + "dining table": 11, + "dog": 12, + "horse": 13, + "motorcycle": 14, + "person": 15, + "potted plant": 16, + "sheep": 17, + "couch": 18, + "train": 19, + "tv": 20, +} + + +class COCO(CacheMappingMixin, Dataset): + """COCO dataset class.""" + + DESCRIPTION = """COCO is a large-scale object detection, segmentation, and + captioning dataset.""" + HOMEPAGE = "http://cocodataset.org" + PAPER = "http://arxiv.org/abs/1405.0312" + LICENSE = "BY-NC-SA 2.0" + + KEYS = [ + K.images, + K.input_hw, + K.original_images, + K.original_hw, + K.sample_names, + K.boxes2d, + K.boxes2d_classes, + K.instance_masks, + K.seg_masks, + ] + + def __init__( + self, + data_root: str, + keys_to_load: Sequence[str] = ( + K.images, + K.boxes2d, + K.boxes2d_classes, + K.instance_masks, + ), + split: str = "train2017", + remove_empty: bool = False, + use_pascal_voc_cats: bool = False, + cache_as_binary: bool = False, + cached_file_path: str | None = None, + **kwargs: ArgsType, + ) -> None: + """Initialize the COCO dataset. + + Args: + data_root (str): Path to the root directory of the dataset. + keys_to_load (tuple[str, ...]): Keys to load from the dataset. + split (split): Which split to load. Default: "train2017". + remove_empty (bool): Whether to remove images with no annotations. + use_pascal_voc_cats (bool): Whether to use Pascal VOC categories. + cache_as_binary (bool): Whether to cache the dataset as binary. + Default: False. + cached_file_path (str | None): Path to a cached file. If cached + file exist then it will load it instead of generating the data + mapping. Default: None. + """ + super().__init__(**kwargs) + + self.data_root = data_root + self.keys_to_load = keys_to_load + self.split = split + self.remove_empty = remove_empty + self.use_pascal_voc_cats = use_pascal_voc_cats + + # handling keys to load + self.validate_keys(keys_to_load) + + self.load_annotations = ( + K.boxes2d in keys_to_load + or K.boxes2d_classes in keys_to_load + or K.instance_masks in keys_to_load + or K.seg_masks in keys_to_load + ) + + self.data, _ = self._load_mapping( + self._generate_data_mapping, + self._filter_data, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + if self.use_pascal_voc_cats: + self.category_names = get_category_names(coco_seg_map) + else: + self.category_names = get_category_names(coco_det_map) + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return ( + f"COCODataset(root={self.data_root}, split={self.split}, " + f"use_pascal_voc_cats={self.use_pascal_voc_cats})" + ) + + def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]: + """Remove empty samples.""" + if self.remove_empty: + samples = [] + for sample in data: + if len(sample["anns"]) > 0: + samples.append(sample) + return samples + return data + + def _generate_data_mapping(self) -> list[DictStrAny]: + """Generate coco dataset mapping.""" + annotation_file = os.path.join( + self.data_root, "annotations", "instances_" + self.split + ".json" + ) + with contextlib.redirect_stdout(io.StringIO()): + coco_api = COCOAPI(annotation_file) + cat_ids = sorted(coco_api.getCatIds()) + cats_map = {c["id"]: c["name"] for c in coco_api.loadCats(cat_ids)} + if self.use_pascal_voc_cats: + voc_cats = set(coco_seg_map.keys()) + + img_ids = sorted(coco_api.imgs.keys()) + imgs = coco_api.loadImgs(img_ids) + samples = [] + for img_id, img in zip(img_ids, imgs): + anns = coco_api.imgToAnns[img_id] + if self.use_pascal_voc_cats: + anns = [ + ann + for ann in anns + if cats_map[ann["category_id"]] in voc_cats + ] + for ann in anns: + cat_name = cats_map[ann["category_id"]] + if self.use_pascal_voc_cats: + ann["category_id"] = coco_seg_map[cat_name] + else: + ann["category_id"] = coco_det_map[cat_name] + samples.append({"img_id": img_id, "img": img, "anns": anns}) + return samples + + def __len__(self) -> int: + """Return length of dataset.""" + return len(self.data) + + def __getitem__(self, idx: int) -> DictData: + """Transform coco sample to vis4d input format. + + Returns: + DataDict[DataKeys, Union[torch.Tensor, Dict[Any]]] + """ + data = self.data[idx] + img_h, img_w = data["img"]["height"], data["img"]["width"] + + dict_data: DictData = {} + + if K.images in self.keys_to_load: + img_path = os.path.join( + self.data_root, self.split, data["img"]["file_name"] + ) + im_bytes = self.data_backend.get(img_path) + img = im_decode(im_bytes, mode=self.image_channel_mode) + img_ = np.ascontiguousarray(img, dtype=np.float32)[None] + assert (img_h, img_w) == img_.shape[ + 1:3 + ], "Image's shape doesn't match annotation." + + dict_data[K.sample_names] = data["img"]["id"] + dict_data[K.images] = img_ + dict_data[K.input_hw] = [img_h, img_w] + + if K.original_images in self.keys_to_load: + dict_data[K.original_images] = img_ + dict_data[K.original_hw] = [img_h, img_w] + + if self.load_annotations: + boxes = [] + classes = [] + masks = [] + + for ann in data["anns"]: + if K.boxes2d in self.keys_to_load: + x1, y1, width, height = ann["bbox"] + x2, y2 = x1 + width, y1 + height + boxes.append((x1, y1, x2, y2)) + if ( + K.boxes2d in self.keys_to_load + or K.boxes2d_classes in self.keys_to_load + or K.seg_masks in self.keys_to_load + ): + classes.append(ann["category_id"]) + + if ( + K.seg_masks in self.keys_to_load + or K.instance_masks in self.keys_to_load + ): + mask_ann = ann.get("segmentation", None) + if mask_ann is not None: + if isinstance(mask_ann, list): + rles = maskUtils.frPyObjects( + mask_ann, img_h, img_w + ) + rle = maskUtils.merge(rles) + elif isinstance(mask_ann["counts"], list): + # uncompressed RLE + rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) + else: + # RLE + rle = mask_ann + masks.append(maskUtils.decode(rle)) + else: # pragma: no cover + masks.append(np.empty((img_h, img_w), dtype=np.uint8)) + + box_tensor = ( + np.empty((0, 4), dtype=np.float32) + if not boxes + else np.array(boxes, dtype=np.float32) + ) + mask_tensor = ( + np.empty((0, img_h, img_w), dtype=np.uint8) + if not masks + else np.ascontiguousarray(masks, dtype=np.uint8) + ) + + if K.boxes2d in self.keys_to_load: + dict_data[K.boxes2d] = box_tensor + + if K.boxes2d_classes in self.keys_to_load: + dict_data[K.boxes2d_classes] = np.array( + classes, dtype=np.int64 + ) + + if K.instance_masks in self.keys_to_load: + dict_data[K.instance_masks] = mask_tensor + + if K.seg_masks in self.keys_to_load: + seg_masks = ( + mask_tensor * np.array(classes)[:, None, None] + ).max(axis=0) + seg_masks = seg_masks.astype(np.int64) + seg_masks[mask_tensor.sum(0) > 1] = 255 # discard overlapped + dict_data[K.seg_masks] = seg_masks[None] + + dict_data[K.boxes2d_names] = self.category_names + + return dict_data diff --git a/vis4d/data/datasets/imagenet.py b/vis4d/data/datasets/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae6622154262f222110b49bea5e3b0a48790051 --- /dev/null +++ b/vis4d/data/datasets/imagenet.py @@ -0,0 +1,145 @@ +"""ImageNet 1k dataset.""" + +from __future__ import annotations + +import os +import pickle +import tarfile +from collections.abc import Sequence + +import numpy as np + +from vis4d.common.logging import rank_zero_info +from vis4d.common.time import Timer +from vis4d.common.typing import ArgsType +from vis4d.data.const import CommonKeys as K +from vis4d.data.typing import DictData + +from .base import Dataset +from .util import im_decode, to_onehot + + +class ImageNet(Dataset): + """ImageNet 1K dataset.""" + + DESCRIPTION = """ImageNet is a large visual database designed for use in + visual object recognition software research.""" + HOMEPAGE = "http://www.image-net.org/" + PAPER = "http://www.image-net.org/papers/imagenet_cvpr09.pdf" + LICENSE = "http://www.image-net.org/terms-of-use" + + KEYS = [K.images, K.categories] + + def __init__( + self, + data_root: str, + keys_to_load: Sequence[str] = (K.images, K.categories), + split: str = "train", + num_classes: int = 1000, + use_sample_lists: bool = False, + **kwargs: ArgsType, + ) -> None: + """Initialize ImageNet dataset. + + Args: + data_root (str): Path to root directory of dataset. + keys_to_load (list[str], optional): List of keys to load. Defaults + to (K.images, K.categories). + split (str, optional): Dataset split to load. Defaults to "train". + num_classes (int, optional): Number of classes to load. Defaults to + 1000. + use_sample_lists (bool, optional): Whether to use sample lists for + loading the dataset. Defaults to False. + + NOTE: The dataset is expected to be in the following format: + data_root + ├── train.pkl # Sample lists for training set (optional) + ├── val.pkl # Sample lists for validation set (optional) + ├── train + │ ├── n01440764.tar + │ ├── ... + └── val + ├── n01440764.tar + ├── ... + With each tar file containing the images of a single class. The + images are expected to be in ".JPEG" extension. + + Currently, we are not using the DataBackend for loading the tars to + avoid keeping too many file pointers open at the same time. + """ + super().__init__(**kwargs) + self.data_root = data_root + self.keys_to_load = keys_to_load + self.split = split + self.num_classes = num_classes + self.use_sample_lists = use_sample_lists + self.data_infos: list[tuple[tarfile.TarInfo, int]] = [] + self._classes: list[str] = [] + self._load_data_infos() + + def _load_data_infos(self) -> None: + """Load data infos from disk.""" + timer = Timer() + # Load tar files + for file in os.listdir(os.path.join(self.data_root, self.split)): + if file.endswith(".tar"): + self._classes.append(file) + assert len(self._classes) == self.num_classes, ( + f"Expected {self.num_classes} classes, but found " + f"{len(self._classes)} tar files." + ) + self._classes = sorted(self._classes) + + sample_list_path = os.path.join(self.data_root, f"{self.split}.pkl") + if self.use_sample_lists and os.path.exists(sample_list_path): + with open(sample_list_path, "rb") as f: + sample_list = pickle.load(f)[0] + if sample_list[-1][1] == self.num_classes - 1: + self.data_infos = sample_list + else: + raise ValueError( + "Sample list does not match the number of classes. " + "Please regenerate the sample list or set " + "use_sample_lists=False." + ) + # If sample lists are not available, generate them on the fly. + else: + for class_idx, file in enumerate(self._classes): + with tarfile.open( + os.path.join(self.data_root, self.split, file) + ) as f: + members = f.getmembers() + for member in members: + if member.isfile() and member.name.endswith(".JPEG"): + self.data_infos.append((member, class_idx)) + + rank_zero_info(f"Loading {self} takes {timer.time():.2f} seconds.") + + def __len__(self) -> int: + """Return length of dataset.""" + return len(self.data_infos) + + def __getitem__(self, idx: int) -> DictData: + """Convert single element at given index into Vis4D data format.""" + member, class_idx = self.data_infos[idx] + with tarfile.open( + os.path.join(self.data_root, self.split, self._classes[class_idx]), + mode="r:*", # unexclusive read mode + ) as f: + im_bytes = f.extractfile(member) + assert im_bytes is not None, f"Could not extract {member.name}!" + image = im_decode(im_bytes.read()) + + data_dict: DictData = {} + if K.images in self.keys_to_load: + data_dict[K.images] = np.ascontiguousarray( + image, dtype=np.float32 + )[np.newaxis, ...] + image_hw = image.shape[:2] + data_dict[K.input_hw] = image_hw + data_dict[K.original_hw] = image_hw + if K.categories in self.keys_to_load: + data_dict[K.categories] = to_onehot( + np.array(class_idx, dtype=np.int64), self.num_classes + ) + return data_dict diff --git a/vis4d/data/datasets/nuscenes.py b/vis4d/data/datasets/nuscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..a80270bfce21436046bcc3d335a208fff6819e8a --- /dev/null +++ b/vis4d/data/datasets/nuscenes.py @@ -0,0 +1,1011 @@ +"""NuScenes multi-sensor video dataset.""" + +from __future__ import annotations + +import os +from collections import defaultdict +from collections.abc import Sequence + +import numpy as np +import torch +from scipy.spatial.transform import Rotation as R +from tqdm import tqdm + +from vis4d.common.imports import NUSCENES_AVAILABLE +from vis4d.common.logging import rank_zero_info +from vis4d.common.time import Timer +from vis4d.common.typing import ( + ArgsType, + DictStrAny, + NDArrayBool, + NDArrayF32, + NDArrayI64, +) +from vis4d.data.const import AxisMode +from vis4d.data.const import CommonKeys as K +from vis4d.data.typing import DictData +from vis4d.op.geometry.projection import generate_depth_map +from vis4d.op.geometry.transform import ( + inverse_rigid_transform, + transform_points, +) + +from .base import VideoDataset, VideoMapping +from .util import CacheMappingMixin, im_decode, print_class_histogram + +if NUSCENES_AVAILABLE: + from nuscenes import NuScenes as NuScenesDevkit + from nuscenes.can_bus.can_bus_api import NuScenesCanBus + from nuscenes.eval.common.utils import quaternion_yaw + from nuscenes.eval.detection.utils import category_to_detection_name + from nuscenes.scripts.export_2d_annotations_as_json import ( + post_process_coords, + ) + from nuscenes.utils.data_classes import Quaternion + from nuscenes.utils.geometry_utils import ( + box_in_image, + transform_matrix, + view_points, + ) + from nuscenes.utils.splits import create_splits_scenes +else: + raise ImportError("nusenes-devkit is not available.") + +nuscenes_class_map = { + "bicycle": 0, + "motorcycle": 1, + "pedestrian": 2, + "bus": 3, + "car": 4, + "trailer": 5, + "truck": 6, + "construction_vehicle": 7, + "traffic_cone": 8, + "barrier": 9, +} + +nuscenes_attribute_map = { + "cycle.with_rider": 0, + "cycle.without_rider": 1, + "pedestrian.moving": 2, + "pedestrian.standing": 3, + "pedestrian.sitting_lying_down": 4, + "vehicle.moving": 5, + "vehicle.parked": 6, + "vehicle.stopped": 7, + "": 8, +} + +nuscenes_detection_range_map = { + "bicycle": 40, + "motorcycle": 40, + "pedestrian": 40, + "bus": 50, + "car": 50, + "trailer": 50, + "truck": 50, + "construction_vehicle": 50, + "traffic_cone": 30, + "barrier": 30, +} + + +def _get_extrinsics( + ego_pose: DictStrAny, car_from_sensor: DictStrAny +) -> NDArrayF32: + """Get NuScenes sensor to global extrinsics.""" + global_from_car = transform_matrix( + ego_pose["translation"], + Quaternion(ego_pose["rotation"]), + inverse=False, + ) + car_from_sensor_ = transform_matrix( + car_from_sensor["translation"], + Quaternion(car_from_sensor["rotation"]), + inverse=False, + ) + extrinsics = np.dot(global_from_car, car_from_sensor_).astype(np.float32) + return extrinsics + + +class NuScenes(CacheMappingMixin, VideoDataset): + """NuScenes multi-sensor video dataset. + + This dataset loads both LiDAR and camera inputs from the NuScenes dataset + into the Vis4D expected format for multi-sensor, video datasets. + """ + + DESCRIPTION = "NuScenes multi-sensor driving video dataset." + HOMEPAGE = "https://www.nuscenes.org/" + PAPER = "https://arxiv.org/abs/1903.11027" + LICENSE = "https://www.nuscenes.org/license" + + KEYS = [ + K.images, + K.original_hw, + K.input_hw, + K.intrinsics, + K.extrinsics, + K.timestamp, + K.axis_mode, + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + K.boxes3d, + K.boxes3d_classes, + K.boxes3d_track_ids, + ] + + SENSORS = [ + "LIDAR_TOP", + "CAM_FRONT", + "CAM_FRONT_LEFT", + "CAM_FRONT_RIGHT", + "CAM_BACK", + "CAM_BACK_LEFT", + "CAM_BACK_RIGHT", + ] + + CAMERAS = [ + "CAM_FRONT", + "CAM_FRONT_LEFT", + "CAM_FRONT_RIGHT", + "CAM_BACK", + "CAM_BACK_LEFT", + "CAM_BACK_RIGHT", + ] + + def __init__( + self, + data_root: str, + keys_to_load: Sequence[str] = ( + K.images, + K.boxes2d, + K.boxes3d, + ), + sensors: Sequence[str] = ( + "LIDAR_TOP", + "CAM_FRONT", + "CAM_FRONT_LEFT", + "CAM_FRONT_RIGHT", + "CAM_BACK", + "CAM_BACK_LEFT", + "CAM_BACK_RIGHT", + ), + version: str = "v1.0-trainval", + split: str = "train", + max_sweeps: int = 10, + skip_empty_samples: bool = False, + point_based_filter: bool = False, + distance_based_filter: bool = False, + cache_as_binary: bool = False, + cached_file_path: str | None = None, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class. + + Args: + data_root (str): Root directory of nuscenes data in original + format. + keys_to_load (tuple[str, ...]): Keys to load from the dataset. + Defaults to (K.images, K.boxes2d, K.boxes3d). + sensors (Sequence[str, ...]): Which sensor to load. Defaults + to ("LIDAR_TOP", "CAM_FRONT", "CAM_FRONT_LEFT", + "CAM_FRONT_RIGHT", "CAM_BACK", "CAM_BACK_LEFT", + "CAM_BACK_RIGHT"). + version (str, optional): Version of the data to load. Defaults to + "v1.0-trainval". + split (str, optional): Split of the data to load. Defaults to + "train". + max_sweeps (int, optional): Maximum number of sweeps for a single + key-frame to load. Defaults to 10. + skip_empty_samples (bool, optional): Whether to skip samples + without annotations. Defaults to False. + point_based_filter (bool, optional): Whether to filter out + samples based on the number of points in the point cloud. + Defaults to False. + distance_based_filter (bool, optional): Whether to filter out + samples based on the distance of the object from the ego + vehicle. Defaults to False. + cache_as_binary (bool): Whether to cache the dataset as binary. + Default: False. + cached_file_path (str | None): Path to a cached file. If cached + file exist then it will load it instead of generating the data + mapping. Default: None. + """ + super().__init__(**kwargs) + self.data_root = data_root + self.keys_to_load = keys_to_load + self.sensors = sensors + self._check_version_and_split(version, split) + self.max_sweeps = max_sweeps + self.skip_empty_samples = skip_empty_samples + + self.point_based_filter = point_based_filter + self.distance_based_filter = distance_based_filter + + # Load annotations + self.samples, self.original_len = self._load_mapping( + self._generate_data_mapping, + self._filter_data, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + # Generate video mapping + self.video_mapping = self._generate_video_mapping() + + # Needed for CBGS + def get_cat_ids(self, idx: int) -> list[int]: + """Return the samples.""" + return self.samples[idx]["LIDAR_TOP"]["annotations"]["boxes3d_classes"] + + def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]: + """Remove empty samples.""" + if self.split == "test": + return data + + samples = [] + frequencies = {cat: 0 for cat in nuscenes_class_map} + inv_nuscenes_class_map = {v: k for k, v in nuscenes_class_map.items()} + + t = Timer() + for sample in data: + ( + _, + boxes3d, + boxes3d_classes, + boxes3d_attributes, + boxes3d_track_ids, + boxes3d_velocities, + ) = self._filter_boxes(sample["LIDAR_TOP"]["annotations"]) + + sample["LIDAR_TOP"]["annotations"]["boxes3d"] = boxes3d + sample["LIDAR_TOP"]["annotations"][ + "boxes3d_classes" + ] = boxes3d_classes + sample["LIDAR_TOP"]["annotations"][ + "boxes3d_attributes" + ] = boxes3d_attributes + sample["LIDAR_TOP"]["annotations"][ + "boxes3d_track_ids" + ] = boxes3d_track_ids + sample["LIDAR_TOP"]["annotations"][ + "boxes3d_velocities" + ] = boxes3d_velocities + + for box3d_class in boxes3d_classes: + frequencies[inv_nuscenes_class_map[box3d_class]] += 1 + + for cam in NuScenes.CAMERAS: + ( + mask, + boxes3d, + boxes3d_classes, + boxes3d_attributes, + boxes3d_track_ids, + boxes3d_velocities, + ) = self._filter_boxes(sample[cam]["annotations"]) + + sample[cam]["annotations"]["boxes3d"] = boxes3d + sample[cam]["annotations"]["boxes3d_classes"] = boxes3d_classes + sample[cam]["annotations"][ + "boxes3d_attributes" + ] = boxes3d_attributes + sample[cam]["annotations"][ + "boxes3d_track_ids" + ] = boxes3d_track_ids + sample[cam]["annotations"][ + "boxes3d_velocities" + ] = boxes3d_velocities + sample[cam]["annotations"]["boxes2d"] = sample[cam][ + "annotations" + ]["boxes2d"][mask] + + if self.skip_empty_samples: + if len(sample["LIDAR_TOP"]["annotations"]["boxes3d"]) > 0: + samples.append(sample) + else: + samples.append(sample) + + rank_zero_info( + f"Preprocessing {len(data)} frames takes {t.time():.2f}" + " seconds." + ) + + print_class_histogram(frequencies) + + if self.skip_empty_samples: + rank_zero_info( + f"Filtered {len(data) - len(samples)} empty frames." + ) + + return samples + + def _check_version_and_split(self, version: str, split: str) -> None: + """Check that the version and split are valid.""" + assert version in { + "v1.0-trainval", + "v1.0-test", + "v1.0-mini", + }, f"Invalid version {version} for NuScenes!" + self.version = version + + if "mini" in version: + valid_splits = {"mini_train", "mini_val"} + elif "test" in version: + valid_splits = {"test"} + else: + valid_splits = {"train", "val"} + + assert ( + split in valid_splits + ), f"Invalid split {split} for NuScenes {version}!" + self.split = split + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return f"NuScenesDataset {self.version} {self.split}" + + def _generate_video_mapping(self) -> VideoMapping: + """Group dataset sample indices by their associated video ID. + + The sample index is an integer while video IDs are string. + + Returns: + VideoMapping: Mapping of video IDs to sample indices and frame IDs. + """ + video_to_indices: dict[str, list[int]] = defaultdict(list) + video_to_frame_ids: dict[str, list[int]] = defaultdict(list) + for i, sample in enumerate(self.samples): # type: ignore + seq = sample["scene_name"] + video_to_indices[seq].append(i) + video_to_frame_ids[seq].append(sample["frame_ids"]) + + return self._sort_video_mapping( + { + "video_to_indices": video_to_indices, + "video_to_frame_ids": video_to_frame_ids, + } + ) + + def _generate_data_mapping(self) -> list[DictStrAny]: + """Generate data mapping. + + Returns: + List[DictStrAny]: List of items required to load for a single + dataset sample. + """ + data = NuScenesDevkit( + version=self.version, dataroot=self.data_root, verbose=False + ) + + can_bus_data = NuScenesCanBus(dataroot=self.data_root) + + frames = [] + instance_tokens: list[str] = [] + + scene_names_per_split = create_splits_scenes() + + scenes = [ + scene + for scene in data.scene + if scene["name"] in scene_names_per_split[self.split] + ] + + for scene in tqdm(scenes): + scene_name = scene["name"] + frame_ids = 0 + sample_token = scene["first_sample_token"] + while sample_token: + frame = {} + sample = data.get("sample", sample_token) + + frame["scene_name"] = scene_name + frame["token"] = sample["token"] + frame["frame_ids"] = frame_ids + + sd_rec = data.get("sample_data", sample["data"]["LIDAR_TOP"]) + + # Can bus data + can_bus = self._load_can_bus_data( + scene_name, can_bus_data, sample["timestamp"] + ) + + pose_record = data.get("ego_pose", sd_rec["ego_pose_token"]) + rotation = Quaternion(pose_record["rotation"]) + translation = pose_record["translation"] + + can_bus[:3] = translation + can_bus[3:7] = rotation + patch_angle = quaternion_yaw(rotation) / np.pi * 180 + patch_angle += 360 if patch_angle < 0 else 0 + can_bus[-2] = patch_angle / 180 * np.pi + can_bus[-1] = patch_angle + + frame["can_bus"] = can_bus + + # LIDAR data + lidar_token = sample["data"]["LIDAR_TOP"] + + frame["LIDAR_TOP"] = self._load_lidar_data(data, lidar_token) + + if self.split != "test": + frame["LIDAR_TOP"]["annotations"] = self._load_annotations( + data, + frame["LIDAR_TOP"]["extrinsics"], + sample["anns"], + instance_tokens, + axis_mode=AxisMode.LIDAR, + ) + + # obtain sweeps for a single key-frame + sweeps: list[DictStrAny] = [] + while len(sweeps) < self.max_sweeps: + if sd_rec["prev"] != "": + sweep = self._load_lidar_data(data, sd_rec["prev"]) + sweeps.append(sweep) + sd_rec = data.get("sample_data", sd_rec["prev"]) + else: + break + frame["LIDAR_TOP"]["sweeps"] = sweeps + + # Get the sample data for each camera + for cam in self.CAMERAS: + cam_token = sample["data"][cam] + + frame[cam] = self._load_cam_data(data, cam_token) + + if self.split != "test": + frame[cam]["annotations"] = self._load_annotations( + data, + frame[cam]["extrinsics"], + sample["anns"], + instance_tokens, + axis_mode=AxisMode.OPENCV, + export_2d_annotations=True, + intrinsics=frame[cam]["intrinsics"], + image_hw=frame[cam]["image_hw"], + ) + + # TODO add RADAR, Map + + frames.append(frame) + + sample_token = sample["next"] + frame_ids += 1 + + return frames + + def _load_can_bus_data( + self, + scene_name: str, + can_bus_data: NuScenesCanBus, + sample_timestamp: int, + ) -> list[float]: + """Load can bus data.""" + try: + pose_list = can_bus_data.get_messages(scene_name, "pose") + except: # pylint: disable=bare-except + # server scenes do not have can bus information. + return [0.0] * 18 + + # during each scene, the first timestamp of can_bus may be large than + # the first sample's timestamp + can_bus = [] + last_pose = pose_list[0] + for pose in pose_list: + if pose["utime"] > sample_timestamp: + break + last_pose = pose + + last_pose.pop("utime") + pos = last_pose.pop("pos") + rotation = last_pose.pop("orientation") + can_bus.extend(pos) + can_bus.extend(rotation) + + # 16 elements + for key in last_pose.keys(): + can_bus.extend(last_pose[key]) + can_bus.extend([0.0, 0.0]) + + return can_bus + + def _load_lidar_data( + self, data: NuScenesDevkit, lidar_token: str + ) -> DictStrAny: + """Load LiDAR data. + + Args: + data (NuScenesDevkit): NuScenes toolkit. + lidar_token (str): LiDAR token. + + Returns: + DictStrAny: LiDAR data. + """ + lidar_data = data.get("sample_data", lidar_token) + + sample_name = ( + lidar_data["filename"].split("/")[-1].replace(".pcd.bin", "") + ) + + lidar_path = os.path.join(self.data_root, lidar_data["filename"]) + + calibration_lidar = data.get( + "calibrated_sensor", lidar_data["calibrated_sensor_token"] + ) + + ego_pose = data.get("ego_pose", lidar_data["ego_pose_token"]) + + extrinsics = _get_extrinsics(ego_pose, calibration_lidar) + + return { + "sample_name": sample_name, + "lidar_path": lidar_path, + "extrinsics": extrinsics, + "timestamp": lidar_data["timestamp"], + } + + def _load_cam_data( + self, data: NuScenesDevkit, cam_token: str + ) -> DictStrAny: + """Load camera data. + + Args: + data (NuScenesDevkit): NuScenes toolkit. + cam_token (str): Camera token. + + Returns: + DictStrAny: Camera data containing the sample name, image path, + image height and width, intrinsics, extrinsics, and + timestamp. + """ + cam_data = data.get("sample_data", cam_token) + + sample_name = ( + cam_data["filename"] + .split("/")[-1] + .replace(f".{cam_data['fileformat']}", "") + ) + + image_path = os.path.join(self.data_root, cam_data["filename"]) + + calibration_cam = data.get( + "calibrated_sensor", cam_data["calibrated_sensor_token"] + ) + + intrinsics = np.array( + calibration_cam["camera_intrinsic"], dtype=np.float32 + ) + + ego_pose = data.get("ego_pose", cam_data["ego_pose_token"]) + extrinsics = _get_extrinsics(ego_pose, calibration_cam) + + return { + "sample_name": sample_name, + "image_path": image_path, + "image_hw": (cam_data["height"], cam_data["width"]), + "intrinsics": intrinsics, + "extrinsics": extrinsics, + "timestamp": cam_data["timestamp"], + } + + def _load_annotations( + self, + data: NuScenesDevkit, + extrinsics: NDArrayF32, + ann_tokens: list[str], + instance_tokens: list[str], + axis_mode: AxisMode = AxisMode.ROS, + export_2d_annotations: bool = False, + intrinsics: NDArrayF32 | None = None, + image_hw: tuple[int, int] | None = None, + ) -> DictStrAny: + """Load annonations.""" + boxes3d = np.empty((1, 10), dtype=np.float32)[1:] + boxes3d_classes = np.empty((1,), dtype=np.int64)[1:] + boxes3d_attributes = np.empty((1,), dtype=np.int64)[1:] + boxes3d_track_ids = np.empty((1,), dtype=np.int64)[1:] + boxes3d_velocities = np.empty((1, 3), dtype=np.float32)[1:] + boxes3d_num_lidar_pts = np.empty((1,), dtype=np.int64)[1:] + boxes3d_num_radar_pts = np.empty((1,), dtype=np.int64)[1:] + + if export_2d_annotations: + assert ( + axis_mode == AxisMode.OPENCV + ), "2D annotations are only supported in camera coordinates." + assert intrinsics is not None, "Intrinsics must be provided." + boxes2d = np.empty((1, 4), dtype=np.float32)[1:] + + sensor_from_global = inverse_rigid_transform( + torch.from_numpy(extrinsics) + ) + translation = sensor_from_global[:3, 3].numpy() + rotation = Quaternion( + matrix=sensor_from_global[:3, :3].numpy(), atol=1e-5 + ) + + for ann_token in ann_tokens: + ann_info = data.get("sample_annotation", ann_token) + box3d_class = category_to_detection_name(ann_info["category_name"]) + + if box3d_class is None: + continue + + # 3D box in global coordinates + box3d = data.get_box(ann_info["token"]) + + # Get 3D box velocity + box3d.velocity = data.box_velocity(ann_info["token"]) + + # Move 3D box to sensor coordinates + box3d.rotate(rotation) + box3d.translate(translation) + + if export_2d_annotations: + assert ( + image_hw is not None + ), "Image height and width must be provided." + if not box_in_image( + box3d, intrinsics, (image_hw[1], image_hw[0]) + ): + continue + + # Number of points in the 3D box + boxes3d_num_lidar_pts = np.concatenate( + [ + boxes3d_num_lidar_pts, + np.array([ann_info["num_lidar_pts"]], dtype=np.int64), + ] + ) + boxes3d_num_radar_pts = np.concatenate( + [ + boxes3d_num_radar_pts, + np.array([ann_info["num_radar_pts"]], dtype=np.int64), + ] + ) + + # Get 2D box + if export_2d_annotations: + corner_coords = ( + view_points(box3d.corners(), intrinsics, True) + .T[:, :2] + .tolist() + ) + + boxes2d = np.concatenate( + [ + boxes2d, + np.array( + [post_process_coords(corner_coords)], + dtype=np.float32, + ), + ] + ) + + # Get 3D box yaw. Use extrinsic rotation to align with PyTorch3D. + if axis_mode == AxisMode.OPENCV: + yaw = -box3d.orientation.yaw_pitch_roll[0] + x, y, z, w = R.from_euler("XYZ", [0, yaw, 0]).as_quat() + else: + yaw = box3d.orientation.yaw_pitch_roll[0] + x, y, z, w = R.from_euler("XYZ", [0, 0, yaw]).as_quat() + + orientation = Quaternion([w, x, y, z]) + + boxes3d = np.concatenate( + [ + boxes3d, + np.array( + [[*box3d.center, *box3d.wlh, *orientation.elements]], + dtype=np.float32, + ), + ] + ) + + # Get 3D box class id + boxes3d_classes = np.concatenate( + [ + boxes3d_classes, + np.array( + [nuscenes_class_map[box3d_class]], dtype=np.int64 + ), + ] + ) + + # Get 3D box attribute id + if len(ann_info["attribute_tokens"]) == 0: + box3d_attr = "" + else: + box3d_attr = data.get( + "attribute", ann_info["attribute_tokens"][0] + )["name"] + boxes3d_attributes = np.concatenate( + [ + boxes3d_attributes, + np.array( + [nuscenes_attribute_map[box3d_attr]], dtype=np.int64 + ), + ] + ) + + # Get 3D box track id + instance_token = data.get("sample_annotation", box3d.token)[ + "instance_token" + ] + if not instance_token in instance_tokens: + instance_tokens.append(instance_token) + track_id = instance_tokens.index(instance_token) + + boxes3d_track_ids = np.concatenate( + [boxes3d_track_ids, np.array([track_id], dtype=np.int64)] + ) + + # 3D bounding box velocity + velocity = box3d.velocity.astype(np.float32) + if np.any(np.isnan(velocity)): + velocity = np.zeros(3, dtype=np.float32) + + boxes3d_velocities = np.concatenate( + [boxes3d_velocities, velocity[None]] + ) + + annotations = { + "boxes3d": boxes3d, + "boxes3d_classes": boxes3d_classes, + "boxes3d_attributes": boxes3d_attributes, + "boxes3d_track_ids": boxes3d_track_ids, + "boxes3d_velocities": boxes3d_velocities, + "boxes3d_num_lidar_pts": boxes3d_num_lidar_pts, + "boxes3d_num_radar_pts": boxes3d_num_radar_pts, + } + + if export_2d_annotations: + annotations["boxes2d"] = boxes2d + + return annotations + + def _accumulate_sweeps( + self, + points: NDArrayF32, + lidar2global: NDArrayF32, + sweeps: list[DictStrAny], + ) -> NDArrayF32: + """Accumulate LiDAR sweeps.""" + if len(sweeps) == 0: + return points + + global2lidar = inverse_rigid_transform(torch.from_numpy(lidar2global)) + + points_sweeps = [torch.from_numpy(points)] + for sweep in sweeps: + points_bytes = self.data_backend.get(sweep["lidar_path"]) + lidar_points = np.frombuffer( + bytearray(points_bytes), dtype=np.float32 + ) + lidar_points = lidar_points.reshape(-1, 5)[:, :3] + + # Transform LiDAR points to global frame + global_lidar_points = transform_points( + torch.from_numpy(lidar_points), + torch.from_numpy(sweep["extrinsics"]), + ) + + # Transform LiDAR points to current LiDAR frame + current_lidar_points = transform_points( + global_lidar_points, global2lidar + ) + + points_sweeps.append(current_lidar_points) + + return torch.cat(points_sweeps).numpy() + + def _load_depth_map( + self, + points_lidar: NDArrayF32, + lidar2global: NDArrayF32, + cam2global: NDArrayF32, + intrinsics: NDArrayF32, + image_hw: tuple[int, int], + ) -> NDArrayF32: + """Load depth map. + + Args: + points_lidar (NDArrayF32): LiDAR points. + lidar2global (NDArrayF32): LiDAR to global extrinsics. + cam2global (NDArrayF32): Camera to global extrinsics. + intrinsics (NDArrayF32): Camera intrinsic matrix. + image_hw (tuple[int, int]): Image height and width. + + Returns: + NDArrayF32: Depth map. + """ + cam2global_ = torch.from_numpy(cam2global) + lidar2global_ = torch.from_numpy(lidar2global) + intrinsics_ = torch.from_numpy(intrinsics) + points_lidar_ = torch.from_numpy(np.copy(points_lidar)) + + lidar2cam = torch.matmul(torch.inverse(cam2global_), lidar2global_) + cam2img = torch.eye(4, 4) + cam2img[:3, :3] = intrinsics_ + points_cam = points_lidar_[:, :3] @ (lidar2cam[:3, :3].T) + lidar2cam[ + :3, 3 + ].unsqueeze(0) + + depth_map = generate_depth_map(points_cam, intrinsics_, image_hw) + return depth_map.numpy() + + def _filter_boxes( + self, annotations: DictStrAny + ) -> tuple[ + NDArrayBool, NDArrayF32, NDArrayI64, NDArrayI64, NDArrayI64, NDArrayF32 + ]: + """Load boxes.""" + valid_mask = np.full(annotations["boxes3d"].shape[0], True) + + if self.point_based_filter: + boxes3d_num_lidar_pts = annotations["boxes3d_num_lidar_pts"] + boxes3d_num_radar_pts = annotations["boxes3d_num_radar_pts"] + valid_mask = np.logical_and( + (boxes3d_num_lidar_pts + boxes3d_num_radar_pts) > 0, valid_mask + ) + + if self.distance_based_filter: + raise NotImplementedError( + "Distance based filter not implemented yet" + ) + + boxes3d = annotations["boxes3d"][valid_mask] + boxes3d_classes = annotations["boxes3d_classes"][valid_mask] + boxes3d_attributes = annotations["boxes3d_attributes"][valid_mask] + boxes3d_track_ids = annotations["boxes3d_track_ids"][valid_mask] + boxes3d_velocities = annotations["boxes3d_velocities"][valid_mask] + + return ( + valid_mask, + boxes3d, + boxes3d_classes, + boxes3d_attributes, + boxes3d_track_ids, + boxes3d_velocities, + ) + + def __len__(self) -> int: + """Length.""" + return len(self.samples) + + def __getitem__(self, idx: int) -> DictData: + """Get single sample. + + Args: + idx (int): Index of sample. + + Returns: + DictData: sample at index in Vis4D input format. + """ + sample = self.samples[idx] + data_dict: DictData = {} + + # metadata + data_dict["token"] = sample["token"] + data_dict[K.frame_ids] = sample["frame_ids"] + data_dict[K.sequence_names] = sample["scene_name"] + data_dict["can_bus"] = sample["can_bus"] + + if "LIDAR_TOP" in self.sensors: + lidar_data = sample["LIDAR_TOP"] + + # load LiDAR frame + data_dict["LIDAR_TOP"] = { + K.sample_names: lidar_data["sample_name"], + K.timestamp: lidar_data["timestamp"], + K.extrinsics: lidar_data["extrinsics"], + K.axis_mode: AxisMode.LIDAR, + } + + if ( + K.points3d in self.keys_to_load + or K.depth_maps in self.keys_to_load + ): + points_bytes = self.data_backend.get(lidar_data["lidar_path"]) + lidar_points = np.frombuffer( + bytearray(points_bytes), dtype=np.float32 + ) + lidar_points = lidar_points.reshape(-1, 5)[:, :3] + + lidar_points = self._accumulate_sweeps( + lidar_points, + lidar_data["extrinsics"], + lidar_data["sweeps"], + ) + + if K.points3d in self.keys_to_load: + data_dict["LIDAR_TOP"][K.points3d] = lidar_points + + if K.boxes3d in self.keys_to_load: + data_dict["LIDAR_TOP"][K.boxes3d] = lidar_data["annotations"][ + "boxes3d" + ] + data_dict["LIDAR_TOP"][K.boxes3d_classes] = lidar_data[ + "annotations" + ]["boxes3d_classes"] + data_dict["LIDAR_TOP"][K.boxes3d_track_ids] = lidar_data[ + "annotations" + ]["boxes3d_track_ids"] + data_dict["LIDAR_TOP"][K.boxes3d_velocities] = lidar_data[ + "annotations" + ]["boxes3d_velocities"] + data_dict["LIDAR_TOP"]["attributes"] = lidar_data[ + "annotations" + ]["boxes3d_attributes"] + + # load camera frame + for cam in NuScenes.CAMERAS: + if cam in self.sensors: + cam_data = sample[cam] + + data_dict[cam] = {K.timestamp: cam_data["timestamp"]} + + if K.images in self.keys_to_load: + im_bytes = self.data_backend.get(cam_data["image_path"]) + image = np.ascontiguousarray( + im_decode(im_bytes, mode=self.image_channel_mode), + dtype=np.float32, + )[None] + + data_dict[cam][K.images] = image + data_dict[cam][K.input_hw] = cam_data["image_hw"] + data_dict[cam][K.sample_names] = cam_data["sample_name"] + data_dict[cam][K.intrinsics] = cam_data["intrinsics"] + data_dict[cam][K.extrinsics] = cam_data["extrinsics"] + data_dict[cam][K.axis_mode] = AxisMode.OPENCV + + if K.original_images in self.keys_to_load: + data_dict[cam][K.original_images] = image + data_dict[cam][K.original_hw] = cam_data["image_hw"] + + if ( + K.boxes3d in self.keys_to_load + or K.boxes2d in self.keys_to_load + ): + if K.boxes3d in self.keys_to_load: + data_dict[cam][K.boxes3d] = cam_data["annotations"][ + "boxes3d" + ] + data_dict[cam][K.boxes3d_classes] = cam_data[ + "annotations" + ]["boxes3d_classes"] + data_dict[cam][K.boxes3d_track_ids] = cam_data[ + "annotations" + ]["boxes3d_track_ids"] + data_dict[cam][K.boxes3d_velocities] = cam_data[ + "annotations" + ]["boxes3d_velocities"] + data_dict[cam]["attributes"] = cam_data["annotations"][ + "boxes3d_attributes" + ] + + if K.boxes2d in self.keys_to_load: + boxes2d = cam_data["annotations"]["boxes2d"] + + data_dict[cam][K.boxes2d] = boxes2d + data_dict[cam][K.boxes2d_classes] = data_dict[cam][ + K.boxes3d_classes + ] + data_dict[cam][K.boxes2d_track_ids] = data_dict[cam][ + K.boxes3d_track_ids + ] + + if K.depth_maps in self.keys_to_load: + depth_maps = self._load_depth_map( + lidar_points, + lidar_data["extrinsics"], + cam_data["extrinsics"], + cam_data["intrinsics"], + cam_data["image_hw"], + ) + + data_dict[cam][K.depth_maps] = depth_maps + + return data_dict diff --git a/vis4d/data/datasets/nuscenes_detection.py b/vis4d/data/datasets/nuscenes_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..decd7e7b5f885717f2d9e06cf4fe33e60c89f58e --- /dev/null +++ b/vis4d/data/datasets/nuscenes_detection.py @@ -0,0 +1,113 @@ +"""NuScenes multi-sensor video dataset.""" + +from __future__ import annotations + +import json + +import numpy as np + +from vis4d.common.typing import ArgsType, DictStrAny, NDArrayF32, NDArrayI64 +from vis4d.data.typing import DictData + +from .nuscenes import NuScenes, nuscenes_class_map + + +class NuScenesDetection(NuScenes): + """NuScenes detection dataset.""" + + def __init__( + self, + pure_detection: str, + score_thres: float = 0.05, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + self.pure_detection = pure_detection + self.score_thres = score_thres + + with open(self.pure_detection, encoding="utf-8") as f: + self.predictions = json.load(f) + + super().__init__(**kwargs) + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return ( + f"NuScenesDetection {self.version} {self.split} using " + + f"{self.pure_detection}" + ) + + def _load_pred( + self, preds: list[DictStrAny] + ) -> tuple[NDArrayF32, NDArrayI64, NDArrayF32, NDArrayF32]: + """Load nuscenes format prediction.""" + boxes3d = np.empty((1, 10), dtype=np.float32)[1:] + boxes3d_classes = np.empty((1,), dtype=np.int64)[1:] + boxes3d_scores = np.empty((1,), dtype=np.float32)[1:] + boxes3d_velocities = np.empty((1, 3), dtype=np.float32)[1:] + + for pred in preds: + if pred["detection_name"] not in nuscenes_class_map: + continue + + if float(pred["detection_score"]) <= self.score_thres: + continue + + boxes3d = np.concatenate( + [ + boxes3d, + np.array( + [ + [ + *pred["translation"], + *pred["size"], + *pred["rotation"], + ] + ], + dtype=np.float32, + ), + ] + ) + boxes3d_classes = np.concatenate( + [ + boxes3d_classes, + np.array( + [nuscenes_class_map[pred["detection_name"]]], + dtype=np.int64, + ), + ] + ) + boxes3d_scores = np.concatenate( + [ + boxes3d_scores, + np.array([pred["detection_score"]], dtype=np.float32), + ] + ) + boxes3d_velocities = np.concatenate( + [ + boxes3d_velocities, + np.array([[*pred["velocity"], 0]], dtype=np.float32), + ] + ) + + return boxes3d, boxes3d_classes, boxes3d_scores, boxes3d_velocities + + def __getitem__(self, idx: int) -> DictData: + """Get single sample. + + Args: + idx (int): Index of sample. + + Returns: + DictData: sample at index in Vis4D input format. + """ + data_dict = super().__getitem__(idx) + + ( + data_dict["LIDAR_TOP"]["pred_boxes3d"], + data_dict["LIDAR_TOP"]["pred_boxes3d_classes"], + data_dict["LIDAR_TOP"]["pred_boxes3d_scores"], + data_dict["LIDAR_TOP"]["pred_boxes3d_velocities"], + ) = self._load_pred(self.predictions["results"][data_dict["token"]]) + + return data_dict diff --git a/vis4d/data/datasets/nuscenes_mono.py b/vis4d/data/datasets/nuscenes_mono.py new file mode 100644 index 0000000000000000000000000000000000000000..9a48f9222cefba099e7a445c6486acca34ba385c --- /dev/null +++ b/vis4d/data/datasets/nuscenes_mono.py @@ -0,0 +1,248 @@ +"""NuScenes monocular dataset.""" + +from __future__ import annotations + +import numpy as np +from tqdm import tqdm + +from vis4d.common.imports import NUSCENES_AVAILABLE +from vis4d.common.logging import rank_zero_info +from vis4d.common.time import Timer +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.const import AxisMode +from vis4d.data.const import CommonKeys as K +from vis4d.data.typing import DictData + +from .nuscenes import NuScenes, nuscenes_class_map +from .util import im_decode, print_class_histogram + +if NUSCENES_AVAILABLE: + from nuscenes import NuScenes as NuScenesDevkit + from nuscenes.utils.splits import create_splits_scenes +else: + raise ImportError("nusenes-devkit is not available.") + + +class NuScenesMono(NuScenes): + """NuScenes monocular dataset.""" + + def __init__(self, *args: ArgsType, **kwargs: ArgsType) -> None: + """Initialize the dataset.""" + super().__init__(*args, **kwargs) + + # Needed for CBGS + def get_cat_ids(self, idx: int) -> list[int]: + """Return the samples.""" + return self.samples[idx]["CAM"]["annotations"]["boxes3d_classes"] + + def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]: + """Remove empty samples.""" + samples = [] + frequencies = {cat: 0 for cat in nuscenes_class_map} + inv_nuscenes_class_map = {v: k for k, v in nuscenes_class_map.items()} + + t = Timer() + for sample in data: + ( + mask, + boxes3d, + boxes3d_classes, + boxes3d_attributes, + boxes3d_track_ids, + boxes3d_velocities, + ) = self._filter_boxes(sample["CAM"]["annotations"]) + + sample["CAM"]["annotations"]["boxes3d"] = boxes3d + sample["CAM"]["annotations"]["boxes3d_classes"] = boxes3d_classes + sample["CAM"]["annotations"][ + "boxes3d_attributes" + ] = boxes3d_attributes + sample["CAM"]["annotations"][ + "boxes3d_track_ids" + ] = boxes3d_track_ids + sample["CAM"]["annotations"][ + "boxes3d_velocities" + ] = boxes3d_velocities + sample["CAM"]["annotations"]["boxes2d"] = sample["CAM"][ + "annotations" + ]["boxes2d"][mask] + + for box3d_class in boxes3d_classes: + frequencies[inv_nuscenes_class_map[box3d_class]] += 1 + + if self.skip_empty_samples: + if len(sample["CAM"]["annotations"]["boxes3d"]) > 0: + samples.append(sample) + else: + samples.append(sample) + + rank_zero_info( + f"Preprocessing {len(data)} frames takes {t.time():.2f}" + " seconds." + ) + + print_class_histogram(frequencies) + + if self.skip_empty_samples: + rank_zero_info( + f"Filtered {len(data) - len(samples)} empty frames." + ) + + return samples + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return f"NuScenes Monocular Dataset {self.version} {self.split}" + + def _generate_data_mapping(self) -> list[DictStrAny]: + """Generate data mapping. + + Returns: + List[DictStrAny]: List of items required to load for a single + dataset sample. + """ + data = NuScenesDevkit( + version=self.version, dataroot=self.data_root, verbose=False + ) + + frames = [] + instance_tokens: list[str] = [] + + scene_names_per_split = create_splits_scenes() + + scenes = [ + scene + for scene in data.scene + if scene["name"] in scene_names_per_split[self.split] + ] + + for scene in tqdm(scenes): + scene_name = scene["name"] + frame_ids = 0 + sample_token = scene["first_sample_token"] + while sample_token: + sample = data.get("sample", sample_token) + + # LIDAR data + lidar_token = sample["data"]["LIDAR_TOP"] + + lidar_data = self._load_lidar_data(data, lidar_token) + lidar_data["annotations"] = self._load_annotations( + data, + lidar_data["extrinsics"], + sample["anns"], + instance_tokens, + ) + + # TODO add RADAR, Map data + + # Get the sample data for each camera + for cam in self.CAMERAS: + frame: DictStrAny = {} + frame["scene_name"] = f"{scene_name}_{cam}" + frame["token"] = sample["token"] + frame["frame_ids"] = frame_ids + + frame["LIDAR_TOP"] = lidar_data + + cam_token = sample["data"][cam] + + frame["CAM"] = self._load_cam_data(data, cam_token) + frame["CAM"]["annotations"] = self._load_annotations( + data, + frame["CAM"]["extrinsics"], + sample["anns"], + instance_tokens, + axis_mode=AxisMode.OPENCV, + export_2d_annotations=True, + intrinsics=frame["CAM"]["intrinsics"], + image_hw=frame["CAM"]["image_hw"], + ) + + frames.append(frame) + + sample_token = sample["next"] + frame_ids += 1 + + return frames + + def __getitem__(self, idx: int) -> DictData: + """Get single sample. + + Args: + idx (int): Index of sample. + + Returns: + DictData: sample at index in Vis4D input format. + """ + sample = self.samples[idx] + data_dict: DictData = {} + + if K.depth_maps in self.keys_to_load: + lidar_data = sample["LIDAR_TOP"] + + points_bytes = self.data_backend.get(lidar_data["lidar_path"]) + points = np.frombuffer(points_bytes, dtype=np.float32) + points = points.reshape(-1, 5)[:, :3] + + if K.depth_maps in self.keys_to_load: + lidar_to_global = lidar_data["extrinsics"] + + # load camera frame + data_dict = { + "token": sample["token"], + K.sequence_names: sample["scene_name"], + K.frame_ids: sample["frame_ids"], + K.timestamp: sample["CAM"]["timestamp"], + } + + if K.images in self.keys_to_load: + im_bytes = self.data_backend.get(sample["CAM"]["image_path"]) + image = np.ascontiguousarray( + im_decode(im_bytes), dtype=np.float32 + )[None] + + data_dict[K.images] = image + data_dict[K.input_hw] = sample["CAM"]["image_hw"] + data_dict[K.sample_names] = sample["CAM"]["sample_name"] + data_dict[K.intrinsics] = sample["CAM"]["intrinsics"] + + if K.original_images in self.keys_to_load: + data_dict[K.original_images] = image + data_dict[K.original_hw] = sample["CAM"]["image_hw"] + + if K.boxes3d in self.keys_to_load or K.boxes2d in self.keys_to_load: + if K.boxes3d in self.keys_to_load: + data_dict[K.boxes3d] = sample["CAM"]["annotations"]["boxes3d"] + data_dict[K.boxes3d_classes] = sample["CAM"]["annotations"][ + "boxes3d_classes" + ] + data_dict[K.boxes3d_track_ids] = sample["CAM"]["annotations"][ + "boxes3d_track_ids" + ] + data_dict[K.boxes3d_velocities] = sample["CAM"]["annotations"][ + "boxes3d_velocities" + ] + data_dict["attributes"] = sample["CAM"]["annotations"][ + "boxes3d_attributes" + ] + data_dict[K.extrinsics] = sample["CAM"]["extrinsics"] + data_dict[K.axis_mode] = AxisMode.OPENCV + + if K.boxes2d in self.keys_to_load: + data_dict[K.boxes2d] = sample["CAM"]["annotations"]["boxes2d"] + data_dict[K.boxes2d_classes] = data_dict[K.boxes3d_classes] + data_dict[K.boxes2d_track_ids] = data_dict[K.boxes3d_track_ids] + + if K.depth_maps in self.keys_to_load: + depth_maps = self._load_depth_map( + points, + lidar_to_global, + sample["CAM"]["extrinsics"], + sample["CAM"]["intrinsics"], + sample["CAM"]["image_hw"], + ) + + data_dict[K.depth_maps] = depth_maps + + return data_dict diff --git a/vis4d/data/datasets/nuscenes_trajectory.py b/vis4d/data/datasets/nuscenes_trajectory.py new file mode 100644 index 0000000000000000000000000000000000000000..a155ac837d63c1b7aa22de795cab09bf2be4319d --- /dev/null +++ b/vis4d/data/datasets/nuscenes_trajectory.py @@ -0,0 +1,264 @@ +"""NuScenes trajectory dataset.""" + +from __future__ import annotations + +import json + +import numpy as np +from scipy.spatial.distance import cdist +from tqdm import tqdm + +from vis4d.common.imports import NUSCENES_AVAILABLE +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import DictStrAny, NDArrayF32 +from vis4d.data.typing import DictData + +from .base import Dataset +from .util import CacheMappingMixin + +if NUSCENES_AVAILABLE: + from nuscenes import NuScenes as NuScenesDevkit + from nuscenes.eval.detection.utils import category_to_detection_name + from nuscenes.utils.data_classes import Quaternion + from nuscenes.utils.splits import create_splits_scenes +else: + raise ImportError("nusenes-devkit is not available.") + + +class NuScenesTrajectory(CacheMappingMixin, Dataset): + """NuScenes Trajectory dataset with given detection results. + + It will generate a trajectory data pair with minimum sequence length. The + detection results will be matched with the ground truth trajectory + according to the BEV distance. + """ + + def __init__( + self, + detector: str, + pure_detection: str, + data_root: str, + version: str = "v1.0-trainval", + split: str = "train", + min_seq_len: int = 10, + cache_as_binary: bool = False, + cached_file_path: str | None = None, + ) -> None: + """Init dataset. + + Args: + detector (str): The detector name. + pure_detection (str): The path to the pure detection results. It + should be the same format as nuScenes submission format. + data_root (str): The root path of the dataset. + version (str, optional): The version of the dataset. Defaults to + "v1.0-trainval". + split (str, optional): The split of the dataset. Defaults to + "train". + min_seq_len (int, optional): The minimum sequence length of the + trajectory. Defaults to 10. + cache_as_binary (bool, optional): Whether to cache the dataset as + binary. Defaults to False. + cached_file_path (str | None, optional): The path to the cached + file. Defaults to None. + """ + super().__init__() + self.data_root = data_root + self.version = version + self.split = split + + self.detector = detector + self.min_seq_len = min_seq_len + + self.pure_detection = pure_detection + + # Load trajectories + self.samples, _ = self._load_mapping( + self._generate_data_mapping, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + rank_zero_info(f"Generated {len(self.samples)} trajectories.") + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return f"NuScenes Trajectory Data with {self.detector} detection" + + def _match_gt_pred( + self, + gt_world: NDArrayF32, + gt_class: str, + predictions: list[DictStrAny], + ) -> tuple[NDArrayF32, bool]: + """Match gt and pred according to BEV center distance. + + If the distance is less than 2 meters, the prediction will be used + instead of the ground truth. + """ + if len(predictions) > 0: + same_class_preds = [ + pred + for pred in predictions + if pred["detection_name"] == gt_class + ] + + if len(same_class_preds) > 0: + preds_center = [ + pred["translation"][:2] for pred in same_class_preds + ] + distance_matrix = ( + cdist( # pylint: disable=unsubscriptable-object + gt_world[:, :2], + np.array(preds_center).reshape(-1, 2), + )[0] + ) + + if distance_matrix[distance_matrix.argmin()] <= 2: + match_pred = same_class_preds[distance_matrix.argmin()] + + # WLH -> HWL + w, l, h = match_pred["size"] + dimensions = [h, w, l] + yaw = Quaternion(match_pred["rotation"]).yaw_pitch_roll[0] + + pred_world = np.array( + [ + [ + *match_pred["translation"], + *dimensions, + yaw, + match_pred["detection_score"], + ] + ], + dtype=np.float32, + ) + + return pred_world, False + + return gt_world, True + + def _generate_data_mapping(self) -> list[dict[str, NDArrayF32]]: + """Generate trajectories predction and groundtruth. + + Trajectories will be generated for each scene. Each trajectory consists + of [x, y, z, h, w, l, yaw, score] in world coordinate. + + Returns: + list[dict[str, NDArrayF32]]: The list of trajectories. + """ + data = NuScenesDevkit( + version=self.version, dataroot=self.data_root, verbose=False + ) + + scene_names_per_split = create_splits_scenes() + + scenes = [ + scene + for scene in data.scene + if scene["name"] in scene_names_per_split[self.split] + ] + + instance_tokens = [] + + with open(self.pure_detection, "r", encoding="utf-8") as f: + predictions = json.load(f) + + num_gt_boxes = 0 + num_pred_boxes = 0 + total_traj = [] + for scene in tqdm(scenes): + local_traj: dict[int, dict[str, list[NDArrayF32]]] = {} + + sample_token = scene["first_sample_token"] + while sample_token: + sample = data.get("sample", sample_token) + + preds = predictions["results"][sample_token] + + for ann_token in sample["anns"]: + ann_info = data.get("sample_annotation", ann_token) + box3d_class = category_to_detection_name( + ann_info["category_name"] + ) + + if box3d_class is None: + continue + + box3d = data.get_box(ann_info["token"]) + + instance_token = data.get( + "sample_annotation", box3d.token + )["instance_token"] + + if not instance_token in instance_tokens: + instance_tokens.append(instance_token) + track_id = instance_tokens.index(instance_token) + + if track_id not in local_traj: + local_traj[track_id] = {"gt": [], "pred": []} + + # WLH -> HWL + w, l, h = box3d.wlh + dimensions = [h, w, l] + yaw = box3d.orientation.yaw_pitch_roll[0] + + gt_world = np.array( + [[*box3d.center, *dimensions, yaw, 1.0]], + dtype=np.float32, + ) + + local_traj[track_id]["gt"].append(gt_world) + + matched_pred, is_gt = self._match_gt_pred( + gt_world, box3d_class, preds + ) + local_traj[track_id]["pred"].append(matched_pred) + + if is_gt: + num_gt_boxes += 1 + else: + num_pred_boxes += 1 + + sample_token = sample["next"] + + for _, traj in local_traj.items(): + if len(traj["gt"]) >= self.min_seq_len: + trajectory = { + "gt": np.concatenate(traj["gt"]), + "pred": np.concatenate(traj["pred"]), + } + total_traj.append(trajectory) + + rank_zero_info(f"Use {num_gt_boxes} gt boxes.") + rank_zero_info(f"Use {num_pred_boxes} pred boxes.") + + return total_traj + + def __len__(self) -> int: + """Return the length of the dataset.""" + return len(self.samples) + + def __getitem__(self, idx: int) -> DictData: + """Return the item at the given index. + + The trajectory will be randomly cropped to the minimum sequence length. + """ + trajectory = self.samples[idx] + data_dict: DictData = {} + + traj_len = len(trajectory["gt"]) + + if traj_len > self.min_seq_len: + first_frame = np.random.randint(traj_len - self.min_seq_len) + else: + first_frame = 0 + + data_dict["gt_traj"] = trajectory["gt"][ + first_frame : first_frame + self.min_seq_len + ] + + data_dict["pred_traj"] = trajectory["pred"][ + first_frame : first_frame + self.min_seq_len + ] + + return data_dict diff --git a/vis4d/data/datasets/s3dis.py b/vis4d/data/datasets/s3dis.py new file mode 100644 index 0000000000000000000000000000000000000000..ad2657d0a69d32e4b2fba4d53bfaa10f21e1c82e --- /dev/null +++ b/vis4d/data/datasets/s3dis.py @@ -0,0 +1,276 @@ +"""Stanford 3D indoor dataset.""" + +from __future__ import annotations + +import copy +import glob +import os +from collections.abc import Sequence +from io import BytesIO + +import numpy as np +import pandas as pd +import torch + +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.const import CommonKeys as K +from vis4d.data.typing import DictData + +from .base import Dataset +from .util import CacheMappingMixin + + +class S3DIS(CacheMappingMixin, Dataset): + """S3DIS dataset class.""" + + DESCRIPTION = """S3DIS is a large-scale indoor pointcloud dataset.""" + HOMEPAGE = "https://buildingparser.stanford.edu/dataset.html" + PAPER = ( + "https://openaccess.thecvf.com/content_cvpr_2016/papers/" + "Armeni_3D_Semantic_Parsing_CVPR_2016_paper.pdf" + ) + LICENSE = "CC BY-NC-SA 4.0" + + KEYS = [ + K.points3d, + K.colors3d, + K.semantics3d, + K.instances3d, + ] + + CLASS_NAME_TO_IDX = { + "ceiling": 0, + "floor": 1, + "wall": 2, + "beam": 3, + "column": 4, + "window": 5, + "door": 6, + "chair": 7, + "table": 8, + "bookcase": 9, + "sofa": 10, + "board": 11, + "clutter": 12, + } + + CLASS_COUNTS = torch.Tensor( + [ + 3370714, + 2856755, + 4919229, + 318158, + 375640, + 478001, + 974733, + 650464, + 791496, + 88727, + 1284130, + 229758, + 2272837, + ] + ) + + AVAILABLE_KEYS: Sequence[str] = ( + K.points3d, + K.colors3d, + K.semantics3d, + K.instances3d, + ) + + COLOR_MAPPING = torch.tensor( + [ + [152, 223, 138], + [31, 119, 180], + [188, 189, 34], + [140, 86, 75], + [255, 152, 150], + [214, 39, 40], + [197, 176, 213], + [23, 190, 207], + [178, 76, 76], + [247, 182, 210], + [66, 188, 102], + [219, 219, 141], + [140, 57, 197], + [202, 185, 52], + ] + ) + + def __init__( + self, + data_root: str, + split: str = "trainNoArea5", + keys_to_load: Sequence[str] = ( + K.points3d, + K.colors3d, + K.semantics3d, + K.instances3d, + ), + cache_points: bool = True, + cache_as_binary: bool = False, + cached_file_path: str | None = None, + **kwargs: ArgsType, + ) -> None: + """Creates a new S3DIS dataset. + + Args: + data_root (str): Path to S3DIS folder + split (str): which split to load. Must either be trainNoArea[1-6] + or testArea[1-6]. e.g. trainNoArea5 will load all areas except + area 5 and testArea5 will only load area 5. + keys_to_load (list[str]): What kind of data should be loaded + (e.g. colors, xyz, semantics, ...) + cache_points (bool): If true caches loaded points instead of + reading them from the disk every time. + cache_as_binary (bool): Whether to cache the dataset as binary. + Default: False. + cached_file_path (str | None): Path to a cached file. If cached + file exist then it will load it instead of generating the data + mapping. Default: None. + + Raises: + ValueError: If requested split is malformed. + """ + super().__init__(**kwargs) + + self.data_root = data_root + self.split = split + + self.areas: list[str] = [ + "Area_1", + "Area_2", + "Area_3", + "Area_4", + "Area_5", + "Area_6", + ] + area_number = int(self.split.split("Area")[-1]) + if "trainNoArea" in self.split: + self.areas.remove(self.areas[area_number - 1]) + elif "testArea" in self.split: + self.areas = [self.areas[area_number - 1]] + else: + raise ValueError("Unknown split: ", self.split) + + self.data, _ = self._load_mapping( + self._generate_data_mapping, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + self.keys_to_load = keys_to_load + + # Cache + self.cache_points = cache_points + self._cache: dict[int, DictData] = {} + + @property + def num_classes(self) -> int: + """The number of classes int he datset.""" + return len(S3DIS.CLASS_NAME_TO_IDX) + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return f"S3DIS(root={self.data_root}, split={self.split})" + + def _generate_data_mapping(self) -> list[DictStrAny]: + """Generate 3dis dataset mapping.""" + data: list[DictStrAny] = [] + for area in self.areas: + for room_path in glob.glob( + os.path.join(self.data_root, area + "/*") + ): + room_data: DictStrAny = {} + if not os.path.isdir(room_path): + continue + + for anns in glob.glob( + os.path.join(room_path, "Annotations/*.txt") + ): + instance_id = os.path.basename(anns.replace(".txt", "")) + sem_name = instance_id.split("_")[0] + room_data[instance_id] = { + "class_label": S3DIS.CLASS_NAME_TO_IDX.get( + sem_name, 12 + ), + "path": anns, + } + data.append(room_data) + + return data + + def __len__(self) -> int: + """Length of the datset.""" + return len(self.data) + + def __getitem__(self, idx: int) -> DictData: + """Transform s3dis sample to vis4d input format. + + Returns: + coordinates: 3D Poitns coordinate Shape(n x 3) + colors: 3D Point colors Shape(n x 3) + Semantic Classes: 3D Point classes Shape(n x 1) + + Raises: + ValueError: If a requested key does not exist in this dataset. + """ + data = self.data[idx] + + # Cache data + if self.cache_points and idx in self._cache: + return copy.deepcopy(self._cache[idx]) + + coords = np.zeros((0, 3), dtype=np.float32) + color = np.zeros((0, 3), dtype=np.float32) + semantic_ids = np.zeros((0, 1), dtype=int) + instance_ids = np.zeros((0, 1), dtype=int) + + for values in data.values(): + data_path = values["path"] + instance_id = int( + values["path"].split("_")[-1].replace(".txt", "") + ) + np_data = pd.read_csv( + BytesIO(self.data_backend.get(data_path)), + header=None, + delimiter=" ", + ).values.astype(np.float32) + + if K.points3d in self.keys_to_load: + coords = np.vstack([coords, np_data[:, :3]]) + if K.colors3d in self.keys_to_load: + color = np.vstack([color, np_data[:, 3:]]) + if K.semantics3d in self.keys_to_load: + semantic_ids = np.vstack( + [ + semantic_ids, + np.ones((np_data.shape[0], 1), dtype=int) + * values["class_label"], + ] + ) + if K.instances3d in self.keys_to_load: + instance_ids = np.vstack( + [ + instance_ids, + np.ones((np_data.shape[0], 1), dtype=int) + * instance_id, + ] + ) + + data = {} + for key in self.keys_to_load: + if key == K.points3d: + data[key] = coords + elif key == K.colors3d: + data[key] = color / 255.0 + elif key == K.semantics3d: + data[key] = semantic_ids.squeeze(-1) + elif key == K.instances3d: + data[key] = instance_ids.squeeze(-1) + else: + raise ValueError(f"Can not load data for key: {key}") + + if self.cache_points: + self._cache[idx] = copy.deepcopy(data) + return data diff --git a/vis4d/data/datasets/scalabel.py b/vis4d/data/datasets/scalabel.py new file mode 100644 index 0000000000000000000000000000000000000000..746cd2be84473897d6c8ba5e41681589569b3eb1 --- /dev/null +++ b/vis4d/data/datasets/scalabel.py @@ -0,0 +1,753 @@ +"""Scalabel type dataset.""" + +from __future__ import annotations + +import os +from collections import defaultdict +from collections.abc import Callable, Sequence +from typing import Union + +import numpy as np +import torch + +from vis4d.common.distributed import broadcast +from vis4d.common.imports import SCALABEL_AVAILABLE +from vis4d.common.logging import rank_zero_info +from vis4d.common.time import Timer +from vis4d.common.typing import ( + ArgsType, + ListAny, + NDArrayF32, + NDArrayI64, + NDArrayUI8, +) +from vis4d.data.const import AxisMode +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.util import CacheMappingMixin, DatasetFromList +from vis4d.data.io import DataBackend +from vis4d.data.typing import DictData +from vis4d.op.geometry.rotation import ( + euler_angles_to_matrix, + matrix_to_quaternion, +) + +from .base import VideoDataset, VideoMapping +from .util import DatasetFromList, im_decode, ply_decode, print_class_histogram + +if SCALABEL_AVAILABLE: + from scalabel.label.io import load, load_label_config + from scalabel.label.transforms import ( + box2d_to_xyxy, + poly2ds_to_mask, + rle_to_mask, + ) + from scalabel.label.typing import ( + Config, + ) + from scalabel.label.typing import Dataset as ScalabelData + from scalabel.label.typing import ( + Extrinsics, + Frame, + ImageSize, + Intrinsics, + Label, + ) + from scalabel.label.utils import ( + check_crowd, + check_ignored, + get_leaf_categories, + get_matrix_from_extrinsics, + get_matrix_from_intrinsics, + ) +else: + raise ImportError("scalabel is not installed.") + + +def load_intrinsics(intrinsics: Intrinsics) -> NDArrayF32: + """Transform intrinsic camera matrix according to augmentations.""" + return get_matrix_from_intrinsics(intrinsics).astype(np.float32) + + +def load_extrinsics(extrinsics: Extrinsics) -> NDArrayF32: + """Transform extrinsics from Scalabel to Vis4D.""" + return get_matrix_from_extrinsics(extrinsics).astype(np.float32) + + +def load_image( + url: str, backend: DataBackend, image_channel_mode: str +) -> NDArrayF32: + """Load image tensor from url.""" + im_bytes = backend.get(url) + image = im_decode(im_bytes, mode=image_channel_mode) + return np.ascontiguousarray(image, dtype=np.float32)[None] + + +def load_pointcloud(url: str, backend: DataBackend) -> NDArrayF32: + """Load pointcloud tensor from url.""" + assert url.endswith(".ply"), "Only PLY files are supported now." + ply_bytes = backend.get(url) + pointcloud = ply_decode(ply_bytes) + return pointcloud.astype(np.float32) + + +def instance_ids_to_global( + frames: list[Frame], local_instance_ids: dict[str, list[str]] +) -> None: + """Use local (per video) instance ids to produce global ones.""" + video_names = list(local_instance_ids.keys()) + for frame_id, ann in enumerate(frames): + if ann.labels is None: # pragma: no cover + continue + for label in ann.labels: + assert label.attributes is not None + if not check_crowd(label) and not check_ignored(label): + video_name = ( + ann.videoName + if ann.videoName is not None + else "no-video-" + str(frame_id) + ) + sum_previous_vids = sum( + ( + len(local_instance_ids[v]) + for v in video_names[: video_names.index(video_name)] + ) + ) + label.attributes["instance_id"] = ( + sum_previous_vids + + local_instance_ids[video_name].index(label.id) + ) + + +def add_data_path(data_root: str, frames: list[Frame]) -> None: + """Add filepath to frame using data_root.""" + for ann in frames: + assert ann.name is not None + if ann.url is None: + if ann.videoName is not None: + ann.url = os.path.join(data_root, ann.videoName, ann.name) + else: + ann.url = os.path.join(data_root, ann.name) + else: + ann.url = os.path.join(data_root, ann.url) + + +def discard_labels_outside_set( + dataset: list[Frame], class_set: list[str] +) -> None: + """Discard labels outside given set of classes. + + Args: + dataset (list[Frame]): List of frames to filter. + class_set (list[str]): List of classes to keep. + """ + for frame in dataset: + remove_anns = [] + if frame.labels is not None: + for i, ann in enumerate(frame.labels): + if not ann.category in class_set: + remove_anns.append(i) + for i in reversed(remove_anns): + frame.labels.pop(i) + + +def remove_empty_samples(frames: list[Frame]) -> list[Frame]: + """Remove empty samples.""" + new_frames = [] + for frame in frames: + if frame.labels is None: + continue + labels_used = [] + for label in frame.labels: + assert label.attributes is not None and label.category is not None + if not check_crowd(label) and not check_ignored(label): + labels_used.append(label) + + if len(labels_used) != 0: + frame.labels = labels_used + new_frames.append(frame) + rank_zero_info(f"Filtered {len(frames) - len(new_frames)} empty frames.") + del frames + return new_frames + + +def prepare_labels( + frames: list[Frame], + class_list: list[str], + global_instance_ids: bool = False, +) -> dict[str, int]: + """Add category id and instance id to labels, return class frequencies. + + Args: + frames (list[Frame]): List of frames. + class_list (list[str]): List of classes. + global_instance_ids (bool): Whether to use global instance ids. + Defaults to False. + """ + instance_ids: dict[str, list[str]] = defaultdict(list) + frequencies = {cat: 0 for cat in class_list} + for frame_id, ann in enumerate(frames): + if ann.labels is None: # pragma: no cover + continue + + for label in ann.labels: + attr: dict[str, bool | int | float | str] = {} + if label.attributes is not None: + attr = label.attributes + + if check_crowd(label) or check_ignored(label): + continue + + assert label.category is not None + frequencies[label.category] += 1 + video_name = ( + ann.videoName + if ann.videoName is not None + else "no-video-" + str(frame_id) + ) + if label.id not in instance_ids[video_name]: + instance_ids[video_name].append(label.id) + attr["instance_id"] = instance_ids[video_name].index(label.id) + label.attributes = attr + + if global_instance_ids: + instance_ids_to_global(frames, instance_ids) + + return frequencies + + +def filter_frames_by_attributes( + frames: list[Frame], + attributes_to_load: Sequence[dict[str, str | float]] | None, +) -> list[Frame]: + """Filter frames based on attributes.""" + if attributes_to_load is None: + return frames + filtered_frames: list[Frame] = [] + for frame in frames: + for attribute_dict in attributes_to_load: + if hasattr(frame, "attributes") and frame.attributes is not None: + if all( + frame.attributes.get(key) == value + for key, value in attribute_dict.items() + ): + filtered_frames.append(frame) + break + else: + raise ValueError( + "Attribute to load is specified but no attributes " + "are found in the frame." + ) + return filtered_frames + + +# Not using | operator because of a bug in Python 3.9 +# https://bugs.python.org/issue42233 +CategoryMap = Union[dict[str, int], dict[str, dict[str, int]]] + + +class Scalabel(CacheMappingMixin, VideoDataset): + """Scalabel type dataset. + + This class loads scalabel format data into Vis4D. + """ + + def __init__( + self, + data_root: str, + annotation_path: str, + keys_to_load: Sequence[str] = (K.images, K.boxes2d), + category_map: None | CategoryMap = None, + config_path: None | str | Config = None, + global_instance_ids: bool = False, + bg_as_class: bool = False, + skip_empty_samples: bool = False, + attributes_to_load: Sequence[dict[str, str | float]] | None = None, + cache_as_binary: bool = False, + cached_file_path: str | None = None, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class. + + Args: + data_root (str): Root directory of the data. + annotation_path (str): Path to the annotation json(s). + keys_to_load (Sequence[str, ...], optional): Keys to load from the + dataset. Defaults to (K.images, K.boxes2d). + category_map (None | CategoryMap, optional): Mapping from a + Scalabel category string to an integer index. If None, the + standard mapping in the dataset config will be used. Defaults + to None. + config_path (None | str | Config, optional): Path to the dataset + config, can be added if it is not provided together with the + labels or should be modified. Defaults to None. + global_instance_ids (bool): Whether to convert tracking IDs of + annotations into dataset global IDs or stay with local, + per-video IDs. Defaults to false. + bg_as_class (bool): Whether to include background pixels as an + additional class for masks. + skip_empty_samples (bool): Whether to skip samples without + annotations. + attributes_to_load (Sequence[dict[str, str]]): List of attributes + dictionaries to load. Each dictionary is a mapping from the + attribute name to its desired value. If any of the attributes + dictionaries is matched, the corresponding frame will be + loaded. Defaults to None. + cache_as_binary (bool): Whether to cache the dataset as binary. + Default: False. + cached_file_path (str | None): Path to a cached file. If cached + file exist then it will load it instead of generating the data + mapping. Default: None. + """ + super().__init__(**kwargs) + assert SCALABEL_AVAILABLE, "Scalabel is not installed." + self.data_root = data_root + self.annotation_path = annotation_path + self.keys_to_load = keys_to_load + self.global_instance_ids = global_instance_ids + self.bg_as_class = bg_as_class + self.config_path = config_path + self.skip_empty_samples = skip_empty_samples + + self.cats_name2id: dict[str, dict[str, int]] = {} + self.category_map = category_map + + self.attributes_to_load = attributes_to_load + + self.frames, self.cfg = self._load_mapping( + self._generate_mapping, + remove_empty_samples, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + assert self.cfg is not None, ( + "No dataset configuration found. Please provide a configuration " + "via config_path." + ) + + if self.category_map is None: + class_list = list( + c.name for c in get_leaf_categories(self.cfg.categories) + ) + self.category_map = {c: i for i, c in enumerate(class_list)} + self._setup_categories() + self.video_mapping = self._generate_video_mapping() + + def _generate_video_mapping(self) -> VideoMapping: + """Group all dataset sample indices (int) by their video ID (str). + + Returns: + VideoMapping: Mapping of video IDs to sample indices and frame IDs. + """ + video_to_indices: dict[str, list[int]] = defaultdict(list) + video_to_frame_ids: dict[str, list[int]] = defaultdict(list) + for idx, frame in enumerate(self.frames): # type: ignore + if frame.videoName is not None: + assert ( + frame.frameIndex is not None + ), "found videoName but no frameIndex!" + video_to_indices[frame.videoName].append(idx) + video_to_frame_ids[frame.videoName].append(frame.frameIndex) + + return self._sort_video_mapping( + { + "video_to_indices": video_to_indices, + "video_to_frame_ids": video_to_frame_ids, + } + ) + + def _setup_categories(self) -> None: + """Setup categories.""" + assert self.category_map is not None + for target in self.keys_to_load: + if isinstance(list(self.category_map.values())[0], int): + self.cats_name2id[target] = self.category_map # type: ignore + else: + assert ( + target in self.category_map + ), f"Target={target} not specified in category_mapping" + target_map = self.category_map[target] + assert isinstance(target_map, dict) + self.cats_name2id[target] = target_map + + def _load_mapping( # type: ignore + self, + generate_map_func: Callable[[], ScalabelData], + filter_func: Callable[[ListAny], ListAny] = lambda x: x, + cache_as_binary: bool = True, + cached_file_path: str | None = None, + ) -> tuple[DatasetFromList, Config]: + """Load cached mapping or generate if not exists.""" + timer = Timer() + data = self._load_mapping_data( + generate_map_func, cache_as_binary, cached_file_path + ) + if data is not None: + frames, cfg = data.frames, data.config + + add_data_path(self.data_root, frames) + rank_zero_info(f"Loading {self} takes {timer.time():.2f} seconds.") + + if self.category_map is None: + class_list = list( + c.name for c in get_leaf_categories(cfg.categories) + ) + self.category_map = {c: i for i, c in enumerate(class_list)} + else: + class_list = list(self.category_map.keys()) + + assert len(set(class_list)) == len( + class_list + ), "Class names are not unique!" + + discard_labels_outside_set(frames, class_list) + + frames = filter_frames_by_attributes( + frames, self.attributes_to_load + ) + + if self.skip_empty_samples: + frames = filter_func(frames) + + t = Timer() + frequencies = prepare_labels( + frames, + class_list, + global_instance_ids=self.global_instance_ids, + ) + rank_zero_info( + f"Preprocessing {len(frames)} frames takes {t.time():.2f}" + " seconds." + ) + print_class_histogram(frequencies) + frames_dataset = DatasetFromList(frames) + else: + frames_dataset = None + cfg = None + frames_dataset = broadcast(frames_dataset) + cfg = broadcast(cfg) + assert frames_dataset is not None + return frames_dataset, cfg + + def _generate_mapping(self) -> ScalabelData: + """Generate data mapping.""" + data = load(self.annotation_path) + if self.config_path is not None: + if isinstance(self.config_path, str): + data.config = load_label_config(self.config_path) + else: + data.config = self.config_path + return data + + def _load_inputs(self, frame: Frame) -> DictData: + """Load inputs given a scalabel frame.""" + data: DictData = {} + if K.images in self.keys_to_load: + assert frame.url is not None, "url is None!" + image = load_image( + frame.url, self.data_backend, self.image_channel_mode + ) + input_hw = (image.shape[1], image.shape[2]) + data[K.images] = image + data[K.input_hw] = input_hw + + # Original image + data[K.original_images] = image + data[K.original_hw] = input_hw + + data[K.axis_mode] = AxisMode.OPENCV + data[K.frame_ids] = frame.frameIndex + + data[K.sample_names] = frame.name + data[K.sequence_names] = frame.videoName + + if K.points3d in self.keys_to_load: + assert frame.url is not None, "url is None!" + data[K.points3d] = load_pointcloud(frame.url, self.data_backend) + + if frame.intrinsics is not None and K.intrinsics in self.keys_to_load: + data[K.intrinsics] = load_intrinsics(frame.intrinsics) + + if frame.extrinsics is not None and K.extrinsics in self.keys_to_load: + data[K.extrinsics] = load_extrinsics(frame.extrinsics) + return data + + def _add_annotations(self, frame: Frame, data: DictData) -> None: + """Add annotations given a scalabel frame and a data dictionary.""" + labels_used, instid_map = [], {} + if frame.labels is not None: + for label in frame.labels: + assert ( + label.attributes is not None and label.category is not None + ) + if not check_crowd(label) and not check_ignored(label): + labels_used.append(label) + if label.id not in instid_map: + instid_map[label.id] = int( + label.attributes["instance_id"] + ) + + image_size = ( + ImageSize(height=data[K.input_hw][0], width=data[K.input_hw][1]) + if K.input_hw in data + else frame.size + ) + + if K.boxes2d in self.keys_to_load: + cats_name2id = self.cats_name2id[K.boxes2d] + boxes2d, classes, track_ids = boxes2d_from_scalabel( + labels_used, cats_name2id, instid_map + ) + data[K.boxes2d] = boxes2d + data[K.boxes2d_classes] = classes + data[K.boxes2d_track_ids] = track_ids + + if K.instance_masks in self.keys_to_load: + # NOTE: instance masks' mapping is consistent with boxes2d + cats_name2id = self.cats_name2id[K.instance_masks] + instance_masks = instance_masks_from_scalabel( + labels_used, cats_name2id, image_size + ) + data[K.instance_masks] = instance_masks + + if K.seg_masks in self.keys_to_load: + sem_map = self.cats_name2id[K.seg_masks] + semantic_masks = semantic_masks_from_scalabel( + labels_used, sem_map, image_size, self.bg_as_class + ) + data[K.seg_masks] = semantic_masks + + if K.boxes3d in self.keys_to_load: + boxes3d, classes, track_ids = boxes3d_from_scalabel( + labels_used, self.cats_name2id[K.boxes3d], instid_map + ) + data[K.boxes3d] = boxes3d + data[K.boxes3d_classes] = classes + data[K.boxes3d_track_ids] = track_ids + + def __len__(self) -> int: + """Length of dataset.""" + return len(self.frames) + + def __getitem__(self, index: int) -> DictData: + """Get item from dataset at given index.""" + frame = self.frames[index] + data = self._load_inputs(frame) + + # load annotations to input sample + self._add_annotations(frame, data) + + return data + + +def boxes2d_from_scalabel( + labels: list[Label], + class_to_idx: dict[str, int], + label_id_to_idx: dict[str, int] | None = None, +) -> tuple[NDArrayF32, NDArrayI64, NDArrayI64]: + """Convert from scalabel format to Vis4D. + + NOTE: The box definition in Scalabel includes x2y2 in the box area, whereas + Vis4D and other software libraries like detectron2 and mmdet do not include + this, which is why we convert via box2d_to_xyxy. + + Args: + labels (list[Label]): list of scalabel labels. + class_to_idx (dict[str, int]): mapping from class name to index. + label_id_to_idx (dict[str, int] | None, optional): mapping from label + id to index. Defaults to None. + + Returns: + tuple[NDArrayF32, NDArrayI64, NDArrayI64]: boxes, classes, track_ids + """ + box_list, cls_list, idx_list = [], [], [] + for i, label in enumerate(labels): + box, box_cls, l_id = label.box2d, label.category, label.id + if box is None: + continue + if box_cls in class_to_idx: + cls_list.append(class_to_idx[box_cls]) + else: + continue + + box_list.append(box2d_to_xyxy(box)) + idx = label_id_to_idx[l_id] if label_id_to_idx is not None else i + idx_list.append(idx) + + if len(box_list) == 0: + return ( + np.empty((0, 4), dtype=np.float32), + np.empty((0,), dtype=np.int64), + np.empty((0,), dtype=np.int64), + ) + + box_tensor = np.array(box_list, dtype=np.float32) + class_ids = np.array(cls_list, dtype=np.int64) + track_ids = np.array(idx_list, dtype=np.int64) + return box_tensor, class_ids, track_ids + + +def instance_masks_from_scalabel( + labels: list[Label], + class_to_idx: dict[str, int], + image_size: ImageSize | None = None, +) -> NDArrayUI8: + """Convert instance masks from scalabel format to Vis4D. + + Args: + labels (list[Label]): list of scalabel labels. + class_to_idx (dict[str, int]): mapping from class name to index. + image_size (ImageSize, optional): image size. Defaults to None. + + Returns: + NDArrayUI8: instance masks. + """ + bitmask_list = [] + for _, label in enumerate(labels): + if label.category not in class_to_idx: # pragma: no cover + continue # skip unknown classes + if label.poly2d is None and label.rle is None: + continue + if label.rle is not None: + bitmask = rle_to_mask(label.rle) + elif label.poly2d is not None: + assert ( + image_size is not None + ), "image size must be specified for masks with polygons!" + bitmask_raw = poly2ds_to_mask(image_size, label.poly2d) + bitmask: NDArrayUI8 = (bitmask_raw > 0).astype( # type: ignore + bitmask_raw.dtype + ) + else: + raise ValueError("No mask found in label.") + bitmask_list.append(bitmask) + if len(bitmask_list) == 0: # pragma: no cover + return np.empty((0, 0, 0), dtype=np.uint8) + mask_array = np.array(bitmask_list, dtype=np.uint8) + return mask_array + + +def nhw_to_hwc_mask( + masks: NDArrayUI8, class_ids: NDArrayI64, ignore_class: int = 255 +) -> NDArrayUI8: + """Convert N binary HxW masks to HxW semantic mask. + + Args: + masks (NDArrayUI8): Masks with shape [N, H, W]. + class_ids (NDArrayI64): Class IDs with shape [N, 1]. + ignore_class (int, optional): Ignore label. Defaults to 255. + + Returns: + NDArrayUI8: Masks with shape [H, W], where each location indicate the + class label. + """ + hwc_mask = np.full(masks.shape[1:], ignore_class, dtype=masks.dtype) + for mask, cat_id in zip(masks, class_ids): + hwc_mask[mask > 0] = cat_id + return hwc_mask + + +def semantic_masks_from_scalabel( + labels: list[Label], + class_to_idx: dict[str, int], + image_size: ImageSize | None = None, + bg_as_class: bool = False, +) -> NDArrayUI8: + """Convert masks from scalabel format to Vis4D. + + Args: + labels (list[Label]): list of scalabel labels. + class_to_idx (dict[str, int]): mapping from class name to index. + image_size (ImageSize, optional): image size. Defaults to None. + bg_as_class (bool, optional): whether to include background as a class. + Defaults to False. + + Returns: + NDArrayUI8: instance masks. + """ + bitmask_list, cls_list = [], [] + if bg_as_class: + foreground: NDArrayUI8 | None = None + for _, label in enumerate(labels): + if label.poly2d is None and label.rle is None: + continue + mask_cls = label.category + if mask_cls in class_to_idx: + cls_list.append(class_to_idx[mask_cls]) + else: # pragma: no cover + continue # skip unknown classes + if label.rle is not None: + bitmask = rle_to_mask(label.rle) + elif label.poly2d is not None: + assert ( + image_size is not None + ), "image size must be specified for masks with polygons!" + bitmask_raw = poly2ds_to_mask(image_size, label.poly2d) + bitmask: NDArrayUI8 = (bitmask_raw > 0).astype( # type: ignore + bitmask_raw.dtype + ) + else: + raise ValueError("No mask found in label.") + bitmask_list.append(bitmask) + if bg_as_class: + foreground = ( + bitmask + if foreground is None + else np.logical_or(foreground, bitmask) + ) + if bg_as_class: + if foreground is None: # pragma: no cover + assert image_size is not None + foreground = np.zeros( + (image_size.height, image_size.width), dtype=np.uint8 + ) + bitmask_list.append(np.logical_not(foreground)) + assert "background" in class_to_idx, ( + '"bg_as_class" requires "background" class to be ' + "in category_mapping" + ) + cls_list.append(class_to_idx["background"]) + if len(bitmask_list) == 0: # pragma: no cover + return np.empty((0, 0), dtype=np.uint8) + mask_array = np.array(bitmask_list, dtype=np.uint8) + class_ids = np.array(cls_list, dtype=np.int64) + return nhw_to_hwc_mask(mask_array, class_ids) + + +def boxes3d_from_scalabel( + labels: list[Label], + class_to_idx: dict[str, int], + label_id_to_idx: dict[str, int] | None = None, +) -> tuple[NDArrayF32, NDArrayI64, NDArrayI64]: + """Convert 3D bounding boxes from scalabel format to Vis4D.""" + box_list, cls_list, idx_list = [], [], [] + for i, label in enumerate(labels): + box, box_cls, l_id = label.box3d, label.category, label.id + if box is None: + continue + if box_cls in class_to_idx: + cls_list.append(class_to_idx[box_cls]) + else: + continue + + quaternion = ( + matrix_to_quaternion( + euler_angles_to_matrix(torch.tensor([box.orientation])) + )[0] + .numpy() + .tolist() + ) + box_list.append([*box.location, *box.dimension, *quaternion]) + idx = label_id_to_idx[l_id] if label_id_to_idx is not None else i + idx_list.append(idx) + + if len(box_list) == 0: + return ( + np.empty((0, 10), dtype=np.float32), + np.empty((0,), dtype=np.int64), + np.empty((0,), dtype=np.int64), + ) + box_tensor = np.array(box_list, dtype=np.float32) + class_ids = np.array(cls_list, dtype=np.int64) + track_ids = np.array(idx_list, dtype=np.int64) + return box_tensor, class_ids, track_ids diff --git a/vis4d/data/datasets/shift.py b/vis4d/data/datasets/shift.py new file mode 100644 index 0000000000000000000000000000000000000000..fd79e2fced301293f937e01579b54cc57bc834e6 --- /dev/null +++ b/vis4d/data/datasets/shift.py @@ -0,0 +1,621 @@ +"""SHIFT dataset.""" + +from __future__ import annotations + +import json +import multiprocessing +import os +from collections.abc import Sequence +from functools import partial + +import numpy as np +from tqdm import tqdm + +from vis4d.common.imports import SCALABEL_AVAILABLE +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import NDArrayF32, NDArrayI64, NDArrayNumber +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.base import VideoDataset +from vis4d.data.datasets.util import im_decode, npy_decode +from vis4d.data.io import DataBackend, FileBackend, HDF5Backend, ZipBackend +from vis4d.data.typing import DictData + +from .base import VideoDataset, VideoMapping +from .scalabel import Scalabel + +shift_det_map = { + "pedestrian": 0, + "car": 1, + "truck": 2, + "bus": 3, + "motorcycle": 4, + "bicycle": 5, +} +shfit_track_map = { + "pedestrian": 0, + "car": 1, + "truck": 2, + "bus": 3, + "motorcycle": 4, + "bicycle": 5, +} +shift_seg_map = { + "unlabeled": 0, + "building": 1, + "fence": 2, + "other": 3, + "pedestrian": 4, + "pole": 5, + "road line": 6, + "road": 7, + "sidewalk": 8, + "vegetation": 9, + "vehicle": 10, + "wall": 11, + "traffic sign": 12, + "sky": 13, + "ground": 14, + "bridge": 15, + "rail track": 16, + "guard rail": 17, + "traffic light": 18, + "static": 19, + "dynamic": 20, + "water": 21, + "terrain": 22, +} +shift_seg_ignore = [ + "unlabeled", + "other", + "ground", + "bridge", + "rail track", + "guard rail", + "static", + "dynamic", + "water", +] + +if SCALABEL_AVAILABLE: + from scalabel.label.io import parse + from scalabel.label.typing import Config + from scalabel.label.typing import Dataset as ScalabelData +else: + raise ImportError("scalabel is not installed.") + + +def _get_extension(backend: DataBackend) -> str: + """Get the appropriate file extension for the given backend.""" + if isinstance(backend, HDF5Backend): + return ".hdf5" + if isinstance(backend, ZipBackend): + return ".zip" + if isinstance(backend, FileBackend): # pragma: no cover + return "" + raise ValueError(f"Unsupported backend {backend}.") # pragma: no cover + + +class _SHIFTScalabelLabels(Scalabel): + """Helper class for labels in SHIFT that are stored in Scalabel format.""" + + VIEWS = [ + "front", + "center", + "left_45", + "left_90", + "right_45", + "right_90", + "left_stereo", + ] + + def __init__( + self, + data_root: str, + split: str, + data_file: str = "", + keys_to_load: Sequence[str] = (K.images, K.boxes2d), + attributes_to_load: Sequence[dict[str, str | float]] | None = None, + annotation_file: str = "", + view: str = "front", + framerate: str = "images", + shift_type: str = "discrete", + skip_empty_frames: bool = False, + backend: DataBackend = HDF5Backend(), + verbose: bool = False, + num_workers: int = 1, + ) -> None: + """Initialize SHIFT dataset for one view. + + Args: + data_root (str): Path to the root directory of the dataset. + split (str): Which data split to load. + data_file (str): Path to the data archive file. Default: "". + keys_to_load (Sequence[str]): List of keys to load. + Default: (K.images, K.boxes2d). + attributes_to_load (Sequence[dict[str, str | float]] | None): + List of attributes to load. Default: None. + annotation_file (str): Path to the annotation file. Default: "". + view (str): Which view to load. Default: "front". Options: "front", + "center", "left_45", "left_90", "right_45", "right_90", and + "left_stereo". + framerate (str): Which framerate to load. Default: "images". + shift_type (str): Which shift type to load. Default: "discrete". + Options: "discrete", "continuous/1x", "continuous/10x", and + "continuous/100x". + skip_empty_frames (bool): Whether to skip frames with no + instance annotations. Default: False. + backend (DataBackend): Backend to use for loading data. Default: + HDF5Backend(). + verbose (bool): Whether to print verbose logs. Default: False. + num_workers (int): Number of workers to use for loading data. + Default: 1. + """ + self.verbose = verbose + self.num_workers = num_workers + + # Validate input + assert split in {"train", "val", "test"}, f"Invalid split '{split}'" + assert view in _SHIFTScalabelLabels.VIEWS, f"Invalid view '{view}'" + + # Set attributes + ext = _get_extension(backend) + if shift_type.startswith("continuous"): + shift_speed = shift_type.split("/")[-1] + annotation_path = os.path.join( + data_root, + "continuous", + framerate, + shift_speed, + split, + view, + annotation_file, + ) + data_path = os.path.join( + data_root, + "continuous", + framerate, + shift_speed, + split, + view, + f"{data_file}{ext}", + ) + else: + annotation_path = os.path.join( + data_root, "discrete", framerate, split, view, annotation_file + ) + data_path = os.path.join( + data_root, + "discrete", + framerate, + split, + view, + f"{data_file}{ext}", + ) + super().__init__( + data_path, + annotation_path, + data_backend=backend, + keys_to_load=keys_to_load, + attributes_to_load=attributes_to_load, + skip_empty_samples=skip_empty_frames, + ) + + def _generate_mapping(self) -> ScalabelData: + """Generate data mapping.""" + # Skipping validation for much faster loading + if self.verbose: + rank_zero_info( + "Loading annotation from '%s' ...", self.annotation_path + ) + return self._load(self.annotation_path) + + def _load(self, filepath: str) -> ScalabelData: + """Load labels from a json file or a folder of json files.""" + raw_frames: list[DictData] = [] + raw_groups: list[DictData] = [] + if not os.path.exists(filepath): + raise FileNotFoundError(f"{filepath} does not exist.") + + def process_file(filepath: str) -> DictData | None: + raw_cfg = None + with open(filepath, mode="r", encoding="utf-8") as fp: + content = json.load(fp) + if isinstance(content, dict): + raw_frames.extend(content["frames"]) + if "groups" in content and content["groups"] is not None: + raw_groups.extend(content["groups"]) + if "config" in content and content["config"] is not None: + raw_cfg = content["config"] + elif isinstance(content, list): + raw_frames.extend(content) + else: + raise TypeError( + "The input file contains neither dict nor list." + ) + + rank_zero_info( + "Loading SHIFT annotation from '%s' Done.", filepath + ) + return raw_cfg + + cfg = None + if os.path.isfile(filepath) and filepath.endswith("json"): + ret_cfg = process_file(filepath) + if ret_cfg is not None: + cfg = ret_cfg + else: + raise TypeError("Inputs must be a folder or a JSON file.") + + config = None + if cfg is not None: + config = Config(**cfg) + + parse_func = partial(parse, validate_frames=False) + if self.num_workers > 1: + with multiprocessing.Pool(self.num_workers) as pool: + frames = [] + with tqdm(total=len(raw_frames)) as pbar: + for result in pool.imap_unordered( + parse_func, raw_frames, chunksize=1000 + ): + frames.append(result) + pbar.update() + else: + frames = [parse_func(frame) for frame in raw_frames] + return ScalabelData(frames=frames, config=config, groups=None) + + +class SHIFT(VideoDataset): + """SHIFT dataset class, supporting multiple tasks and views.""" + + DESCRIPTION = """SHIFT Dataset, a synthetic driving dataset for continuous + multi-task domain adaptation""" + HOMEPAGE = "https://www.vis.xyz/shift/" + PAPER = "https://arxiv.org/abs/2206.08367" + LICENSE = "CC BY-NC-SA 4.0" + + KEYS = [ + # Inputs + K.images, + K.original_hw, + K.input_hw, + K.points3d, + # Scalabel formatted annotations + K.intrinsics, + K.extrinsics, + K.timestamp, + K.axis_mode, + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + K.instance_masks, + K.boxes3d, + K.boxes3d_classes, + K.boxes3d_track_ids, + # Bit masks + K.seg_masks, + K.depth_maps, + K.optical_flows, + ] + + VIEWS = [ + "front", + "center", + "left_45", + "left_90", + "right_45", + "right_90", + "left_stereo", + ] + + DATA_GROUPS = { + "img": [ + K.images, + K.original_hw, + K.input_hw, + K.intrinsics, + ], + "det_2d": [ + K.timestamp, + K.axis_mode, + K.extrinsics, + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + ], + "det_3d": [ + K.boxes3d, + K.boxes3d_classes, + K.boxes3d_track_ids, + ], + "det_insseg_2d": [ + K.instance_masks, + ], + "semseg": [ + K.seg_masks, + ], + "depth": [ + K.depth_maps, + ], + "flow": [ + K.optical_flows, + ], + "lidar": [ + K.points3d, + ], + } + + GROUPS_IN_SCALABEL = ["det_2d", "det_3d", "det_insseg_2d"] + + def __init__( + self, + data_root: str, + split: str, + keys_to_load: Sequence[str] = (K.images, K.boxes2d), + views_to_load: Sequence[str] = ("front",), + attributes_to_load: Sequence[dict[str, str | float]] | None = None, + framerate: str = "images", + shift_type: str = "discrete", + skip_empty_frames: bool = False, + backend: DataBackend = HDF5Backend(), + num_workers: int = 1, + verbose: bool = False, + ) -> None: + """Initialize SHIFT dataset.""" + super().__init__(data_backend=backend) + # Validate input + assert split in {"train", "val", "test"}, f"Invalid split '{split}'." + assert framerate in { + "images", + "videos", + }, f"Invalid framerate '{framerate}'. Must be 'images' or 'videos'." + assert shift_type in { + "discrete", + "continuous/1x", + "continuous/10x", + "continuous/100x", + }, ( + f"Invalid shift_type '{shift_type}'. " + "Must be one of 'discrete', 'continuous/1x', 'continuous/10x', " + "or 'continuous/100x'." + ) + self.validate_keys(keys_to_load) + + # Set attributes + self.data_root = data_root + self.split = split + self.keys_to_load = keys_to_load + self.views_to_load = views_to_load + self.attributes_to_load = attributes_to_load + self.framerate = framerate + self.shift_type = shift_type + self.backend = backend + self.verbose = verbose + self.ext = _get_extension(backend) + if self.shift_type.startswith("continuous"): + shift_speed = self.shift_type.split("/")[-1] + self.annotation_base = os.path.join( + self.data_root, + "continuous", + self.framerate, + shift_speed, + self.split, + ) + else: + self.annotation_base = os.path.join( + self.data_root, self.shift_type, self.framerate, self.split + ) + if self.verbose: + print(f"Base: {self.annotation_base}. Backend: {self.backend}") + + # Get the data groups' classes that need to be loaded + self._data_groups_to_load = self._get_data_groups(keys_to_load) + if "det_2d" not in self._data_groups_to_load: + raise ValueError( + "In current implementation, the 'det_2d' data group must be " + "loaded to load any other data group." + ) + + self.scalabel_datasets = {} + for view in self.views_to_load: + if view == "center": + # Load lidar data, only available for center view + self.scalabel_datasets["center/lidar"] = _SHIFTScalabelLabels( + data_root=self.data_root, + split=self.split, + data_file="lidar", + annotation_file="det_3d.json", + view=view, + framerate=self.framerate, + shift_type=self.shift_type, + keys_to_load=(K.points3d, *self.DATA_GROUPS["det_3d"]), + attributes_to_load=self.attributes_to_load, + skip_empty_frames=skip_empty_frames, + backend=backend, + num_workers=num_workers, + verbose=verbose, + ) + else: + # Skip the lidar data group, which is loaded separately + image_loaded = False + for group in self._data_groups_to_load: + name = f"{view}/{group}" + keys_to_load = list(self.DATA_GROUPS[group]) + # Load the image data group only once + if not image_loaded: + keys_to_load.extend(self.DATA_GROUPS["img"]) + image_loaded = True + self.scalabel_datasets[name] = _SHIFTScalabelLabels( + data_root=self.data_root, + split=self.split, + data_file="img", + annotation_file=f"{group}.json", + view=view, + framerate=self.framerate, + shift_type=self.shift_type, + keys_to_load=keys_to_load, + attributes_to_load=self.attributes_to_load, + skip_empty_frames=skip_empty_frames, + backend=backend, + num_workers=num_workers, + verbose=verbose, + ) + + self.video_mapping = self._generate_video_mapping() + + def validate_keys(self, keys_to_load: Sequence[str]) -> None: + """Validate that all keys to load are supported.""" + for k in keys_to_load: + if k not in self.KEYS: + raise ValueError(f"Key '{k}' is not supported!") + + def _get_data_groups(self, keys_to_load: Sequence[str]) -> list[str]: + """Get the data groups that need to be loaded from Scalabel.""" + data_groups = ["det_2d"] + for data_group, group_keys in self.DATA_GROUPS.items(): + if data_group in self.GROUPS_IN_SCALABEL: + # If the data group is loaded by Scalabel, add it to the list + if any(key in group_keys for key in keys_to_load): + data_groups.append(data_group) + return list(set(data_groups)) + + def _load( + self, view: str, data_group: str, file_ext: str, video: str, frame: str + ) -> NDArrayNumber: + """Load data from the given data group.""" + frame_number = frame.split("_")[0] + filepath = os.path.join( + self.annotation_base, + view, + f"{data_group}{self.ext}", + video, + f"{frame_number}_{data_group}_{view}.{file_ext}", + ) + if data_group == "semseg": + return self._load_semseg(filepath) + if data_group == "depth": + return self._load_depth(filepath) + if data_group == "flow": + return self._load_flow(filepath) + raise ValueError( + f"Invalid data group '{data_group}'" + ) # pragma: no cover + + def _load_semseg(self, filepath: str) -> NDArrayI64: + """Load semantic segmentation data.""" + im_bytes = self.backend.get(filepath) + image = im_decode(im_bytes)[..., 0] + return image.astype(np.int64) + + def _load_depth( + self, filepath: str, depth_factor: float = 16777.216 # 256 ^ 3 / 1000 + ) -> NDArrayF32: + """Load depth data.""" + assert depth_factor > 0, "Max depth value must be greater than 0." + + im_bytes = self.backend.get(filepath) + image = im_decode(im_bytes) + if image.shape[2] > 3: # pragma: no cover + image = image[:, :, :3] + image = image.astype(np.float32) + + # Convert to depth + depth = ( + image[:, :, 2] * 256 * 256 + image[:, :, 1] * 256 + image[:, :, 0] + ) + return np.ascontiguousarray(depth / depth_factor, dtype=np.float32) + + def _load_flow(self, filepath: str) -> NDArrayF32: + """Load optical flow data.""" + npy_bytes = self.backend.get(filepath) + flow = npy_decode(npy_bytes, key="flow") + flow = flow[:, :, [1, 0]] # Convert to (u, v) format + flow *= flow.shape[1] # Scale to image size (1280) + if self.framerate == "images": + flow *= 10.0 # NOTE: Scale to 1 fps approximately + return flow.astype(np.float32) + + def _get_frame_key(self, idx: int) -> tuple[str, str]: + """Get the frame identifier (video name, frame name) by index.""" + if len(self.scalabel_datasets) > 0: + frames = self.scalabel_datasets[ + list(self.scalabel_datasets.keys())[0] + ].frames + return frames[idx].videoName, frames[idx].name + raise ValueError("No Scalabel file has been loaded.") + + def __len__(self) -> int: + """Get the number of samples in the dataset.""" + if len(self.scalabel_datasets) > 0: + return len( + self.scalabel_datasets[list(self.scalabel_datasets.keys())[0]] + ) + raise ValueError( + "No Scalabel file has been loaded." + ) # pragma: no cover + + def _generate_video_mapping(self) -> VideoMapping: + """Group all dataset sample indices (int) by their video ID (str). + + Returns: + VideoMapping: Mapping of video IDs to sample indices and frame IDs. + + Raises: + ValueError: If no Scalabel file has been loaded. + """ + if len(self.scalabel_datasets) > 0: + return self.scalabel_datasets[ + list(self.scalabel_datasets.keys())[0] + ].video_mapping + raise ValueError("No Scalabel file has been loaded.") + + def __getitem__(self, idx: int) -> DictData: + """Get single sample. + + Args: + idx (int): Index of sample. + + Returns: + DictData: sample at index in Vis4D input format. + """ + # load camera frames + data_dict = {} + + # metadata + video_name, frame_name = self._get_frame_key(idx) + data_dict[K.sample_names] = frame_name + data_dict[K.sequence_names] = video_name + data_dict[K.frame_ids] = frame_name.split("_")[0] + + for view in self.views_to_load: + data_dict_view = {} + + if view == "center": + # Lidar is only available in the center view + if K.points3d in self.keys_to_load: + data_dict_view.update( + self.scalabel_datasets["center/lidar"][idx] + ) + else: + # Load data from Scalabel + for group in self._data_groups_to_load: + data_dict_view.update( + self.scalabel_datasets[f"{view}/{group}"][idx] + ) + + # Load data from bit masks + if K.seg_masks in self.keys_to_load: + data_dict_view[K.seg_masks] = self._load( + view, "semseg", "png", video_name, frame_name + ) + if K.depth_maps in self.keys_to_load: + data_dict_view[K.depth_maps] = self._load( + view, "depth", "png", video_name, frame_name + ) + if K.optical_flows in self.keys_to_load: + data_dict_view[K.optical_flows] = self._load( + view, "flow", "npz", video_name, frame_name + ) + data_dict[view] = data_dict_view # type: ignore + + return data_dict diff --git a/vis4d/data/datasets/torchvision.py b/vis4d/data/datasets/torchvision.py new file mode 100644 index 0000000000000000000000000000000000000000..a23c974b56f472079261807bee5bc816850e5a20 --- /dev/null +++ b/vis4d/data/datasets/torchvision.py @@ -0,0 +1,130 @@ +"""Provides functionalities to wrap torchvision datasets.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import numpy as np +from PIL.Image import Image +from torchvision.datasets import VisionDataset +from torchvision.transforms import ToTensor + +from ..const import CommonKeys as K +from ..typing import DictData +from .base import Dataset + + +class TorchvisionDataset(Dataset): + """Wrapper for torchvision datasets. + + This class wraps torchvision datasets and converts them to the format that + is expected by the vis4d framework. + + The return of the torchvisons dataset is passed to the data_converter, + which needs to be provided by the user. The data_converter is expected to + return a DictData object following the vis4d conventions. + + For well defined dataformats, such as classification, there + are already implemented wrappers that can be used. See + `TorchvisionClassificationDataset` for an example. + """ + + def __init__( # type: ignore + self, + torchvision_ds: VisionDataset, + data_converter: Callable[[Any], DictData], + ) -> None: + """Creates a new instance of the class. + + Args: + torchvision_ds (VisionDataset): Torchvision dataset that should be + converted. + data_converter (Callable[[Any], DictData]): Function that + converts the output of the torchvision datasets __getitem__ + to the format expected by the vis4d framework. + """ + super().__init__() + self.torchvision_ds = torchvision_ds + self.data_converter = data_converter + + def __getitem__(self, idx: int) -> DictData: + """Returns a new sample from the dataset. + + Args: + idx (int): Index of the sample. + + Returns: + DictData: Data in vis4d format. + """ + return self.data_converter(self.torchvision_ds[idx]) + + def __len__(self) -> int: + """Returns the number of samples in the dataset. + + Returns: + int: Length of the dataset. + """ + return len(self.torchvision_ds) + + +class TorchvisionClassificationDataset(TorchvisionDataset): + """Wrapper for torchvision classification datasets. + + This class wraps torchvision classification datasets and converts them to + the format that is expected by the vis4d framework. + + It expects the torchvision dataset to return a tuple of (image, class_id) + where the image is a PIL Image and the class_id is an integer. + + If you want to use a torchvision dataset that returns a different format, + you can provide a custom data_converter function to the + `TorchvisionDataset` class. + + The returned sample will have the following key, values: + images: ndarray of dimension (1, H, W, C) + categories: ndarray of dimension 1. + + Example: + >>> from torchvision.datasets.mnist import MNIST + >>> ds = TorchvisionClassificationDataset( + >>> MNIST("data/mnist_ds", train=False) + >>> ) + >>> data = next(iter(ds)) + >>> print(data.keys) + dict_keys(['images', 'categories']) + """ + + def __init__(self, detection_ds: VisionDataset) -> None: + """Creates a new instance of the class. + + Args: + detection_ds (VisionDataset): Torchvision dataset that + returns a tuple of (image, class_id) where the image is a PIL + Image and the class_id is an integer. + """ + img_to_tensor = ToTensor() + + def _data_converter(img_and_target: tuple[Image, int]) -> DictData: + """Converts the output of a torchvision dataset. + + The output is converted to the format expected by the vis4d + framework. + + Args: + img_and_target (tuple[Image, int]): Output of the datasets + __getitem__ method. + + Returns: + DictData: Sample in vis4d format. + """ + img, class_id = img_and_target + data: DictData = {} + data[K.images] = ( + img_to_tensor(img).unsqueeze(0).permute(0, 2, 3, 1).numpy() + ) + data[K.categories] = np.array([class_id], dtype=np.int64) + + return data + + super().__init__(detection_ds, _data_converter) diff --git a/vis4d/data/datasets/util.py b/vis4d/data/datasets/util.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb232259b9f8a584f8d7ca4db3ba8f8b08bd1d5 --- /dev/null +++ b/vis4d/data/datasets/util.py @@ -0,0 +1,367 @@ +"""Utility functions for datasets.""" + +from __future__ import annotations + +import copy +import itertools +import os +import pickle +from collections.abc import Callable, Sequence +from datetime import datetime +from io import BytesIO +from typing import Any + +import numpy as np +import plyfile +from PIL import Image, ImageOps +from tabulate import tabulate +from termcolor import colored +from torch.utils.data import Dataset + +from vis4d.common.distributed import broadcast, rank_zero_only +from vis4d.common.imports import OPENCV_AVAILABLE +from vis4d.common.logging import rank_zero_info +from vis4d.common.time import Timer +from vis4d.common.typing import ( + DictStrAny, + ListAny, + NDArrayFloat, + NDArrayI64, + NDArrayUI8, +) + +from ..typing import DictData + +if OPENCV_AVAILABLE: + from cv2 import ( # pylint: disable=no-member,no-name-in-module + COLOR_BGR2RGB, + IMREAD_COLOR, + IMREAD_GRAYSCALE, + cvtColor, + imdecode, + ) +else: + raise ImportError("cv2 is not installed.") + + +def im_decode( + im_bytes: bytes, mode: str = "RGB", backend: str = "PIL" +) -> NDArrayUI8: + """Decode to image (numpy array, RGB) from bytes.""" + assert mode in { + "BGR", + "RGB", + "L", + }, f"{mode} not supported for image decoding!" + if backend == "PIL": + pil_img_file = Image.open(BytesIO(bytearray(im_bytes))) + pil_img = ImageOps.exif_transpose(pil_img_file) + assert pil_img is not None, "Image could not be loaded!" + if pil_img.mode == "L": # pragma: no cover + if mode == "L": + img: NDArrayUI8 = np.array(pil_img)[..., None] + else: + # convert grayscale image to RGB + pil_img = pil_img.convert("RGB") + elif mode == "L": # pragma: no cover + raise ValueError("Cannot convert colorful image to grayscale!") + if mode == "BGR": # pragma: no cover + img = np.array(pil_img)[..., [2, 1, 0]] + elif mode == "RGB": + img = np.array(pil_img) + elif backend == "cv2": # pragma: no cover + if not OPENCV_AVAILABLE: + raise ImportError( + "Please install opencv-python to use cv2 backend!" + ) + img_np: NDArrayUI8 = np.frombuffer(im_bytes, np.uint8) + img = imdecode( # type: ignore + img_np, IMREAD_GRAYSCALE if mode == "L" else IMREAD_COLOR + ) + if mode == "RGB": + cvtColor(img, COLOR_BGR2RGB, img) + else: + raise NotImplementedError(f"Image backend {backend} not known!") + return img + + +def ply_decode(ply_bytes: bytes, mode: str = "XYZI") -> NDArrayFloat: + """Decode to point clouds (numpy array) from bytes. + + Args: + ply_bytes (bytes): The bytes of the ply file. + mode (str, optional): The point format of the ply file. If "XYZI", the + intensity channel will be included, otherwise only the XYZ + coordinates. Defaults to "XYZI". + """ + assert mode in { + "XYZ", + "XYZI", + }, f"{mode} not supported for points decoding!" + + plydata = plyfile.PlyData.read(BytesIO(bytearray(ply_bytes))) + num_points = plydata["vertex"].count + num_channels = 3 if mode == "XYZ" else 4 + points = np.zeros((num_points, num_channels), dtype=np.float32) + + points[:, 0] = plydata["vertex"].data["x"] + points[:, 1] = plydata["vertex"].data["y"] + points[:, 2] = plydata["vertex"].data["z"] + if mode == "XYZI": + points[:, 3] = plydata["vertex"].data["intensity"] + return points + + +def npy_decode(npy_bytes: bytes, key: str | None = None) -> NDArrayFloat: + """Decode to numpy array from npy/npz file bytes.""" + data = np.load(BytesIO(bytearray(npy_bytes))) + if key is not None: + data = data[key] + return data + + +def filter_by_keys( + data_dict: DictData, keys_to_keep: Sequence[str] +) -> DictData: + """Filter a dictionary by keys. + + Args: + data_dict (DictData): The dictionary to filter. + keys_to_keep (list[str]): The keys to keep. + + Returns: + DictData: The filtered dictionary. + """ + return {key: data_dict[key] for key in keys_to_keep if key in data_dict} + + +def get_used_data_groups( + data_groups: dict[str, list[str]], keys: list[str] +) -> list[str]: + """Get the data groups that are used by the given keys. + + Args: + data_groups (dict[str, list[str]]): The data groups. + keys (list[str]): The keys to check. + + Returns: + list[str]: The used data groups. + """ + used_groups = [] + for group_name, group_keys in data_groups.items(): + if not group_keys: + continue + if any(key in keys for key in group_keys): + used_groups.append(group_name) + return used_groups + + +def to_onehot(categories: NDArrayI64, num_classes: int) -> NDArrayFloat: + """Transform integer categorical labels to onehot vectors. + + Args: + categories (NDArrayI64): Integer categorical labels of shape (N, ). + num_classes (int): Number of classes. + + Returns: + NDArrayFloat: Onehot vector of shape (N, num_classes). + """ + _eye = np.eye(num_classes, dtype=np.float32) + return _eye[categories] + + +class CacheMappingMixin: + """Caches a mapping for fast I/O and multi-processing. + + This class provides functionality for caching a mapping from dataset index + requested by a call on __getitem__ to a dictionary that holds relevant + information for loading the sample in question from the disk. + Caching the mapping reduces startup time by loading the mapping instead of + re-computing it at every startup. + + NOTE: Make sure your annotations file is up-to-date. Otherwise, the mapping + will be wrong and you will get wrong samples. + """ + + @rank_zero_only + def _load_mapping_data( + self, + generate_map_func: Callable[[], list[DictStrAny]], + cache_as_binary: bool, + cached_file_path: str | None, + ) -> ListAny: + """Load possibly cached mapping via generate_map_func. + + Args: + generate_map_func (Callable[[], list[DictStrAny]]): The function + that generates the mapping. + cache_as_binary (bool): Whether to cache the mapping as binary. + cached_file_path (str | None): The path to the cached mapping file. + """ + if cache_as_binary: + assert ( + cached_file_path is not None + ), "cached_file_path must be set if cache_as_binary is True!" + if not os.path.exists(cached_file_path): + rank_zero_info( + f"Did not find {cached_file_path}, generating it..." + ) + data = generate_map_func() + os.makedirs(os.path.dirname(cached_file_path), exist_ok=True) + with open(cached_file_path, "wb") as file: + file.write(pickle.dumps(data)) + else: + dt = datetime.fromtimestamp(os.stat(cached_file_path).st_mtime) + rank_zero_info( + f"Found {cached_file_path} generated at {dt.isoformat()} " + + "and loading it..." + ) + with open(cached_file_path, "rb") as file: + data = pickle.loads(file.read()) + else: + rank_zero_info(f"Generating {self} data mapping...") + data = generate_map_func() + return data + + def _load_mapping( + self, + generate_map_func: Callable[[], list[DictStrAny]], + filter_func: Callable[[ListAny], ListAny] = lambda x: x, + cache_as_binary: bool = False, + cached_file_path: str | None = None, + ) -> tuple[DatasetFromList, int]: + """Load cached mapping or generate if not exists. + + Args: + generate_map_func (Callable[[], list[DictStrAny]]): The function + that generates the mapping. + filter_func (Callable[[ListAny], ListAny], optional): The function + that filters the mapping. Defaults to lambda x: x. + cache_as_binary (bool, optional): Whether to cache the mapping as + binary. Defaults to True. + cached_file_path (str | None, optional): The path to the cached + mapping file. Defaults to None. + """ + timer = Timer() + dataset = self._load_mapping_data( + generate_map_func, cache_as_binary, cached_file_path + ) + original_len = 0 + if dataset is not None: + original_len = len(dataset) + dataset = filter_func(dataset) + dataset = DatasetFromList(dataset) + dataset = broadcast(dataset) + original_len = broadcast(original_len) + rank_zero_info(f"Loading {self} takes {timer.time():.2f} seconds.") + return dataset, original_len + + +# reference: +# https://github.com/facebookresearch/detectron2/blob/7f8f29deae278b75625872c8a0b00b74129446ac/detectron2/data/common.py#L109 +class DatasetFromList(Dataset): # type: ignore + """Wrap a list to a torch Dataset. + + We serialize and wrap big python objects in a torch.Dataset due to a + memory leak when dealing with large python objects using multiple workers. + See: https://github.com/pytorch/pytorch/issues/13246 + """ + + def __init__( + self, lst: ListAny, deepcopy: bool = False, serialize: bool = True + ): + """Creates an instance of the class. + + Args: + lst: a list which contains elements to produce. + deepcopy: whether to deepcopy the element when producing it, s.t. + the result can be modified in place without affecting the source + in the list. + serialize: whether to hold memory using serialized objects. When + enabled, data loader workers can use shared RAM from master + process instead of making a copy. + """ + self._copy = deepcopy + self._serialize = serialize + + def _serialize(data: Any) -> NDArrayUI8: # type: ignore + """Serialize python object to numpy array.""" + buffer = pickle.dumps(data, protocol=-1) + return np.frombuffer(buffer, dtype=np.uint8) + + if self._serialize: + self._lst = [_serialize(x) for x in lst] + self._addr: NDArrayI64 = np.asarray( + [len(x) for x in self._lst], dtype=np.int64 + ) + self._addr = np.cumsum(self._addr) + self._lst = np.concatenate(self._lst) # type: ignore + else: + self._lst = lst # pragma: no cover + + def __len__(self) -> int: + """Return len of list.""" + if self._serialize: + return len(self._addr) + return len(self._lst) # pragma: no cover + + def __getitem__(self, idx: int) -> Any: # type: ignore + """Return item of list at idx.""" + if self._serialize: + start_addr = 0 if idx == 0 else self._addr[idx - 1].item() + end_addr = self._addr[idx].item() + bytes_ = memoryview(self._lst[start_addr:end_addr]) # type: ignore + return pickle.loads(bytes_) + if self._copy: # pragma: no cover + return copy.deepcopy(self._lst[idx]) + + return self._lst[idx] # pragma: no cover + + +def print_class_histogram(class_frequencies: dict[str, int]) -> None: + """Prints out given class frequencies.""" + if len(class_frequencies) == 0: # pragma: no cover + return + + class_names = list(class_frequencies.keys()) + frequencies = list(class_frequencies.values()) + num_classes = len(class_names) + + n_cols = min(6, len(class_names) * 2) + + def short_name(name: str) -> str: + """Make long class names shorter.""" + if len(name) > 13: + return name[:11] + ".." # pragma: no cover + return name + + data = list( + itertools.chain( + *[ + [short_name(class_names[i]), int(v)] + for i, v in enumerate(frequencies) + ] + ) + ) + total_num_instances = sum(data[1::2]) # type: ignore + data.extend([None] * (n_cols - (len(data) % n_cols))) + if num_classes > 1: + data.extend(["total", total_num_instances]) + + table = tabulate( + itertools.zip_longest(*[data[i::n_cols] for i in range(n_cols)]), + headers=["category", "#instances"] * (n_cols // 2), + tablefmt="pipe", + numalign="left", + stralign="center", + ) + + rank_zero_info( + f"Distribution of instances among all {num_classes} categories:\n" + + colored(table, "cyan") + ) + + +def get_category_names(det_mapping: dict[str, int]) -> list[str]: + """Get category names from a mapping of category names to ids.""" + return sorted(det_mapping, key=det_mapping.get) # type: ignore diff --git a/vis4d/data/io/__init__.py b/vis4d/data/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f167ac5fc413a17d4248f900d3eb0545fb8fe2f3 --- /dev/null +++ b/vis4d/data/io/__init__.py @@ -0,0 +1,13 @@ +"""Init io module.""" + +from .base import DataBackend +from .file import FileBackend +from .hdf5 import HDF5Backend +from .zip import ZipBackend + +__all__ = [ + "DataBackend", + "HDF5Backend", + "FileBackend", + "ZipBackend", +] diff --git a/vis4d/data/io/base.py b/vis4d/data/io/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7e2c7fee9066ce6b575df96f412e2ff50ff31948 --- /dev/null +++ b/vis4d/data/io/base.py @@ -0,0 +1,84 @@ +"""Backends for the data types a dataset of interest is saved in. + +Those can be used to load data from diverse storage backends, e.g. from HDF5 +files which are more suitable for data centers. The naive backend is the +FileBackend, which loads from / saves to file naively. +""" + +from abc import abstractmethod +from typing import Literal + + +class DataBackend: + """Abstract class of storage backends. + + All backends need to implement three functions: get(), set() and exists(). + get() reads the file as a byte stream and set() writes a byte stream to a + file. exists() checks if a certain filepath exists. + """ + + @abstractmethod + def set( + self, filepath: str, content: bytes, mode: Literal["w", "a"] + ) -> None: + """Set the file content at the given filepath. + + Args: + filepath (str): The filepath to store the data at. + content (bytes): The content to store as bytes. + mode (str): The mode to open the file in. + """ + raise NotImplementedError + + @abstractmethod + def get(self, filepath: str) -> bytes: + """Get the file content at the given filepath as bytes. + + Args: + filepath (str): The filepath to retrieve the data from." + + Returns: + bytes: The content of the file as bytes. + """ + raise NotImplementedError + + @abstractmethod + def exists(self, filepath: str) -> bool: + """Check if filepath exists. + + Args: + filepath (str): The filepath to check. + + Returns: + bool: True if the filepath exists, False otherwise. + """ + raise NotImplementedError + + @abstractmethod + def isfile(self, filepath: str) -> bool: + """Check if filepath is a file. + + Args: + filepath (str): The filepath to check. + + Returns: + bool: True if the filepath is a file, False otherwise. + """ + raise NotImplementedError + + @abstractmethod + def listdir(self, filepath: str) -> list[str]: + """List all files in a directory. + + Args: + filepath (str): The directory to list. + + Returns: + list[str]: A list of all files in the directory. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close all opened files in the backend.""" + raise NotImplementedError diff --git a/vis4d/data/io/file.py b/vis4d/data/io/file.py new file mode 100644 index 0000000000000000000000000000000000000000..abfbe85dde75ea07fcdb33c62510234ba8fa68cc --- /dev/null +++ b/vis4d/data/io/file.py @@ -0,0 +1,83 @@ +"""Standard backend for local files on a hard drive. + +This backends loads data from and saves data to the local hard drive. +""" + +import os +from typing import Literal + +from .base import DataBackend + + +class FileBackend(DataBackend): + """Raw file from hard disk data backend.""" + + def isfile(self, filepath: str) -> bool: + """Check if filepath is a file. + + Args: + filepath (str): Path to file. + + Returns: + bool: True if file exists, False otherwise. + """ + return os.path.isfile(filepath) + + def listdir(self, filepath: str) -> list[str]: + """List all files in the directory. + + Args: + filepath (str): Path to file. + + Returns: + list[str]: List of all files in the directory. + """ + return sorted(os.listdir(filepath)) + + def exists(self, filepath: str) -> bool: + """Check if filepath exists. + + Args: + filepath (str): Path to file. + + Returns: + bool: True if file exists, False otherwise. + """ + return os.path.exists(filepath) + + def set( + self, filepath: str, content: bytes, mode: Literal["w", "a"] = "w" + ) -> None: + """Write the file content to disk. + + Args: + filepath (str): Path to file. + content (bytes): Content to write in bytes. + mode (Literal["w", "a"], optional): Overwrite or append mode. + Defaults to "w". + """ + os.makedirs(os.path.dirname(filepath), exist_ok=True) + mode_binary: Literal["wb", "ab"] = "wb" if mode == "w" else "ab" + with open(filepath, mode_binary) as f: + f.write(content) + + def get(self, filepath: str) -> bytes: + """Get file content as bytes. + + Args: + filepath (str): Path to file. + + Raises: + FileNotFoundError: If filepath does not exist. + + Returns: + bytes: File content as bytes. + """ + if not self.exists(filepath): + raise FileNotFoundError(f"File not found:" f" {filepath}") + with open(filepath, "rb") as f: + value_buf = f.read() + return value_buf + + def close(self) -> None: + """No need to close manually.""" diff --git a/vis4d/data/io/hdf5.py b/vis4d/data/io/hdf5.py new file mode 100644 index 0000000000000000000000000000000000000000..c7eee55a75b830d3ed40055416cd56d8261eec0f --- /dev/null +++ b/vis4d/data/io/hdf5.py @@ -0,0 +1,242 @@ +"""Hdf5 data backend. + +This backend works with filepaths pointing to valid HDF5 files. We assume that +the given HDF5 file contains the whole dataset associated to this backend. +""" + +from __future__ import annotations + +import os +from typing import Literal + +import numpy as np + +from vis4d.common.imports import H5PY_AVAILABLE + +from .base import DataBackend + +if H5PY_AVAILABLE: + import h5py + from h5py import File +else: + raise ImportError("Please install h5py to enable HDF5Backend.") + + +class HDF5Backend(DataBackend): + """Backend for loading data from HDF5 files. + + This backend works with filepaths pointing to valid HDF5 files. We assume + that the given HDF5 file contains the whole dataset associated to this + backend. + + You can use the provided script at vis4d/data/datasets/to_hdf5.py to + convert your dataset to the expected hdf5 format before using this backend. + """ + + def __init__(self) -> None: + """Creates an instance of the class.""" + super().__init__() + if not H5PY_AVAILABLE: + raise ImportError("Please install h5py to enable HDF5Backend.") + self.db_cache: dict[str, File] = {} + + @staticmethod + def _get_hdf5_path( + filepath: str, allow_omitted_ext: bool = True + ) -> tuple[str, list[str]]: + """Get .hdf5 path and keys from filepath. + + Args: + filepath (str): The filepath to retrieve the data from. + Should have the following format: 'path/to/file.hdf5/key1/key2' + allow_omitted_ext (bool, optional): Whether to allow omitted + extension, in which case the backend will try to append + '.hdf5' to the filepath. Defaults to True. + + Returns: + tuple[str, list[str]]: The .hdf5 path and the keys to retrieve. + + Examples: + >>> HDF5Backend._get_hdf5_path("path/to/file.hdf5/key1/key2") + ("path/to/file.hdf5", ["key2", "key1"]) + >>> HDF5Backend._get_hdf5_path("path/to/file/key1/key2", True) + ("path/to/file.hdf5", ["key2", "key1"]) # if file.hdf5 exists and + # is a valid hdf5 file + """ + filepath_as_list = filepath.split("/") + keys = [] + + while True: + if filepath.endswith(".hdf5") or filepath == "": + break + if allow_omitted_ext and h5py.is_hdf5(filepath + ".hdf5"): + filepath = filepath + ".hdf5" + break + keys.append(filepath_as_list.pop()) + filepath = "/".join(filepath_as_list) + return filepath, keys + + def exists(self, filepath: str) -> bool: + """Check if filepath exists. + + Args: + filepath (str): Path to file. + + Returns: + bool: True if file exists, False otherwise. + """ + hdf5_path, keys = self._get_hdf5_path(filepath) + if not os.path.exists(hdf5_path): + return False + value_buf = self._get_client(hdf5_path, "r") + + while keys: + value_buf = value_buf.get(keys.pop()) + if value_buf is None: + return False + return True + + def set( + self, filepath: str, content: bytes, mode: Literal["w", "a"] = "a" + ) -> None: + """Set the file content. + + Args: + filepath: path/to/file.hdf5/key1/key2/key3 + content: Bytes to be written to entry key3 within group key2 + within another group key1, for example. + mode: "w" to overwrite the file, "a" to append to it. + + Raises: + ValueError: If filepath is not a valid .hdf5 file + """ + if ".hdf5" not in filepath: + raise ValueError(f"{filepath} not a valid .hdf5 filepath!") + hdf5_path, keys_str = filepath.split(".hdf5") + key_list = keys_str.split("/") + file = self._get_client(hdf5_path + ".hdf5", mode) + if len(key_list) > 1: + group_str = "/".join(key_list[:-1]) + if group_str == "": + group_str = "/" + + group = file[group_str] + key = key_list[-1] + group.create_dataset( + key, data=np.frombuffer(content, dtype="uint8") + ) + + def _get_client(self, hdf5_path: str, mode: str) -> File: + """Get HDF5 client from path. + + Args: + hdf5_path (str): Path to HDF5 file. + mode (str): Mode to open the file in. + + Returns: + File: the hdf5 file. + """ + if hdf5_path not in self.db_cache: + client = File(hdf5_path, mode, swmr=True, libver="latest") + self.db_cache[hdf5_path] = [client, mode] + else: + client, current_mode = self.db_cache[hdf5_path] + if current_mode != mode: + client.close() + client = File(hdf5_path, mode, swmr=True, libver="latest") + self.db_cache[hdf5_path] = [client, mode] + return client + + def get(self, filepath: str) -> bytes: + """Get values according to the filepath as bytes. + + Args: + filepath (str): The path to the file. It consists of an HDF5 path + together with the relative path inside it, e.g.: "/path/to/ + file.hdf5/key/subkey/data". If no .hdf5 given inside filepath, + the function will search for the first .hdf5 file present in + the path, i.e. "/path/to/file/key/subkey/data" will also /key/ + subkey/data from /path/to/file.hdf5. + + Raises: + FileNotFoundError: If no suitable file exists. + ValueError: If key not found inside hdf5 file. + + Returns: + bytes: The file content in bytes + """ + hdf5_path, keys = self._get_hdf5_path(filepath) + + if not os.path.exists(hdf5_path): + raise FileNotFoundError( + f"Corresponding HDF5 file not found:" f" {filepath}" + ) + value_buf = self._get_client(hdf5_path, "r") + url = "/".join(reversed(keys)) + while keys: + value_buf = value_buf.get(keys.pop()) + if value_buf is None: + raise ValueError(f"Value {url} not found in {hdf5_path}!") + + return bytes(value_buf[()]) + + def isfile(self, filepath: str) -> bool: + """Check if filepath is a file. + + Args: + filepath (str): Path to file. + + Raises: + FileNotFoundError: If no suitable file exists. + ValueError: If key not found inside hdf5 file. + + Returns: + bool: True if file exists, False otherwise. + """ + hdf5_path, keys = self._get_hdf5_path(filepath) + if not os.path.exists(hdf5_path): + raise FileNotFoundError( + f"Corresponding HDF5 file not found:" f" {filepath}" + ) + value_buf = self._get_client(hdf5_path, "r") + url = "/".join(reversed(keys)) + while keys: + value_buf = value_buf.get(keys.pop()) + if value_buf is None: + raise ValueError(f"Value {url} not found in {hdf5_path}!") + return not isinstance(value_buf, h5py.Group) + + def listdir(self, filepath: str) -> list[str]: + """List all files in the given directory. + + Args: + filepath (str): Path to directory. + + Raises: + FileNotFoundError: If no suitable file exists. + ValueError: If key not found inside hdf5 file. + + Returns: + list[str]: List of files in the given directory. + """ + hdf5_path, keys = self._get_hdf5_path(filepath) + if not os.path.exists(hdf5_path): + raise FileNotFoundError( + f"Corresponding HDF5 file not found:" f" {filepath}" + ) + value_buf = self._get_client(hdf5_path, "r") + url = "/".join(reversed(keys)) + while keys: + value_buf = value_buf.get(keys.pop()) + if value_buf is None: + raise ValueError(f"Value {url} not found in {hdf5_path}!") + if not isinstance(value_buf, h5py.Group): + raise ValueError(f"Value {url} is not a group in {hdf5_path}!") + + return sorted(list(value_buf.keys())) + + def close(self) -> None: + """Close all opened HDF5 files.""" + for client, _ in self.db_cache.values(): + client.close() + self.db_cache.clear() diff --git a/vis4d/data/io/to_hdf5.py b/vis4d/data/io/to_hdf5.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2161a5ba4d73de179b80347a639de443022acc --- /dev/null +++ b/vis4d/data/io/to_hdf5.py @@ -0,0 +1,76 @@ +"""Script to convert a dataset to hdf5 format.""" + +from __future__ import annotations + +import argparse +import os + +import numpy as np +from tqdm import tqdm + +from vis4d.common.imports import H5PY_AVAILABLE + +if H5PY_AVAILABLE: + import h5py +else: + raise ImportError("Please install h5py to enable HDF5Backend.") + + +def convert_dataset(source_dir: str) -> None: + """Convert a dataset to HDF5 format. + + This function converts an arbitary dictionary to an HDF5 file. The keys + inside the HDF5 file preserve the directory structure of the original. + + As an example, if you convert "/path/to/dataset" to HDF5, the resulting + file will be: "/path/to/dataset.hdf5". The file "relative/path/to/file" + will be stored at "relative/path/to/file" inside /path/to/dataset.hdf5. + + Args: + source_dir (str): The path to the dataset to convert. + """ + if not os.path.exists(source_dir): + raise FileNotFoundError(f"No such file or directory: {source_dir}") + + source_dir = os.path.join(source_dir, "") # must end with trailing slash + hdf5_path = source_dir.rstrip("/") + ".hdf5" + if os.path.exists(hdf5_path): + print(f"File {hdf5_path} already exists! Skipping {source_dir}") + return + + print(f"Converting dataset at: {source_dir}") + hdf5_file = h5py.File(hdf5_path, mode="w") + sub_dirs = list(os.walk(source_dir)) + file_count = sum(len(files) for (_, _, files) in sub_dirs) + + with tqdm(total=file_count) as pbar: + for root, _, files in sub_dirs: + g_name = root.replace(source_dir, "") + g = hdf5_file.create_group(g_name) if g_name else hdf5_file + for f in files: + filepath = os.path.join(root, f) + if os.path.isfile(filepath): + with open(filepath, "rb") as fp: + file_content = fp.read() + g.create_dataset( + f, data=np.frombuffer(file_content, dtype="uint8") + ) + pbar.update() + + hdf5_file.close() + print("done.") + + +if __name__ == "__main__": # pragma: no cover + parser = argparse.ArgumentParser( + description="Converts a dataset at the specified path to hdf5. The " + "local directory structure is preserved in the hdf5 file." + ) + parser.add_argument( + "-p", + "--path", + required=True, + help="path to the root folder of a specific dataset to convert", + ) + args = parser.parse_args() + convert_dataset(args.path) diff --git a/vis4d/data/io/util.py b/vis4d/data/io/util.py new file mode 100644 index 0000000000000000000000000000000000000000..6d23dbc6bf8df17c54f730346c9e5d4231f3f89c --- /dev/null +++ b/vis4d/data/io/util.py @@ -0,0 +1,21 @@ +"""Data I/O Utilities.""" + +from __future__ import annotations + +import sys + + +def str_decode(str_bytes: bytes, encoding: None | str = None) -> str: + """Decode to string from bytes. + + Args: + str_bytes (bytes): Bytes to decode. + encoding (None | str): Encoding to use. Defaults to None which is + equivalent to sys.getdefaultencoding(). + + Returns: + str: Decoded string. + """ + if encoding is None: + encoding = sys.getdefaultencoding() + return str_bytes.decode(encoding) diff --git a/vis4d/data/io/zip.py b/vis4d/data/io/zip.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea4f7cddc438ca8699bdc5e557081504956d647 --- /dev/null +++ b/vis4d/data/io/zip.py @@ -0,0 +1,206 @@ +"""Zip data backend. + +This backend works with filepaths pointing to valid Zip files. We assume that +the given Zip file contains the whole dataset associated to this backend. +""" + +from __future__ import annotations + +import os +import zipfile +from typing import Literal +from zipfile import ZipFile + +from .base import DataBackend + + +class ZipBackend(DataBackend): + """Backend for loading data from Zip files. + + This backend works with filepaths pointing to valid Zip files. We assume + that the given Zip file contains the whole dataset associated to this + backend. + """ + + def __init__(self) -> None: + """Creates an instance of the class.""" + super().__init__() + self.db_cache: dict[str, tuple[ZipFile, str]] = {} + + @staticmethod + def _get_zip_path( + filepath: str, allow_omitted_ext: bool = True + ) -> tuple[str, list[str]]: + """Get .zip path and keys from filepath. + + Args: + filepath (str): The filepath to retrieve the data from. + Should have the following format: 'path/to/file.zip/key1/key2' + allow_omitted_ext (bool, optional): Whether to allow omitted + extension, in which case the backend will try to append + '.zip' to the filepath. Defaults to True. + + Returns: + tuple[str, list[str]]: The .hdf5 path and the keys to retrieve. + + Examples: + >>> _get_zip_path("path/to/file.zip/key1/key2") + ("path/to/file.zip", ["key2", "key1"]) + >>> _get_zip_path("path/to/file/key1/key2", True) + ("path/to/file.zip", ["key2", "key1"]) # if file.hdf5 exists and + # is a valid hdf5 file + """ + filepath_as_list = filepath.split("/") + keys = [] + + while True: + if filepath.endswith(".zip") or filepath == "": + break + if allow_omitted_ext and zipfile.is_zipfile(filepath + ".zip"): + filepath = filepath + ".zip" + break + keys.append(filepath_as_list.pop()) + filepath = "/".join(filepath_as_list) + return filepath, keys + + def exists(self, filepath: str) -> bool: + """Check if filepath exists. + + Args: + filepath (str): Path to file. + + Returns: + bool: True if file exists, False otherwise. + """ + zip_path, keys = self._get_zip_path(filepath) + if not os.path.exists(zip_path): + return False + file = self._get_client(zip_path, "r") + url = "/".join(reversed(keys)) + return url in file.namelist() + + def set( + self, filepath: str, content: bytes, mode: Literal["w", "a"] = "w" + ) -> None: + """Write the file content to the zip file. + + Args: + filepath: path/to/file.zip/key1/key2/key3 + content: Bytes to be written to entry key3 within group key2 + within another group key1, for example. + mode: Mode to open the file in. "w" for writing a file, "a" for + appending to existing file. + + Raises: + ValueError: If filepath is not a valid .zip file + NotImplementedError: If the method is not implemented. + """ + if ".zip" not in filepath: + raise ValueError(f"{filepath} not a valid .zip filepath!") + + zip_path, keys = self._get_zip_path(filepath) + zip_file = self._get_client(zip_path, mode) + url = "/".join(reversed(keys)) + zip_file.writestr(url, content) + + def _get_client( + self, zip_path: str, mode: Literal["r", "w", "a", "x"] + ) -> ZipFile: + """Get Zip client from path. + + Args: + zip_path (str): Path to Zip file. + mode (str): Mode to open the file in. + + Returns: + ZipFile: the hdf5 file. + """ + assert len(mode) == 1, "Mode must be a single character for zip file." + if zip_path not in self.db_cache: + os.makedirs(os.path.dirname(zip_path), exist_ok=True) + client = ZipFile(zip_path, mode) + self.db_cache[zip_path] = (client, mode) + else: + client, current_mode = self.db_cache[zip_path] + if current_mode != mode: + client.close() + client = ZipFile( # pylint:disable=consider-using-with + zip_path, mode + ) + self.db_cache[zip_path] = (client, mode) + return client + + def get(self, filepath: str) -> bytes: + """Get values according to the filepath as bytes. + + Args: + filepath (str): The path to the file. It consists of an Zip path + together with the relative path inside it, e.g.: "/path/to/ + file.zip/key/subkey/data". If no .zip given inside filepath, + the function will search for the first .zip file present in + the path, i.e. "/path/to/file/key/subkey/data" will also /key/ + subkey/data from /path/to/file.zip. + + Raises: + ZipFileNotFoundError: If no suitable file exists. + OSError: If the file cannot be opened. + ValueError: If key not found inside zip file. + + Returns: + bytes: The file content in bytes + """ + zip_path, keys = self._get_zip_path(filepath) + + if not os.path.exists(zip_path): + raise FileNotFoundError( + f"Corresponding zip file not found:" f" {filepath}" + ) + zip_file = self._get_client(zip_path, "r") + url = "/".join(reversed(keys)) + try: + with zip_file.open(url) as zf: + content = zf.read() + except KeyError as e: + raise ValueError(f"Value '{url}' not found in {zip_path}!") from e + return bytes(content) + + def listdir(self, filepath: str) -> list[str]: + """List all files in the given directory. + + Args: + filepath (str): The path to the directory. + + Returns: + list[str]: List of all files in the given directory. + """ + zip_path, keys = self._get_zip_path(filepath) + zip_file = self._get_client(zip_path, "r") + url = "/".join(reversed(keys)) + files = [ + os.path.basename(key) + for key in zip_file.namelist() + if key.startswith(url) and os.path.basename(key) != "" + ] + return sorted(files) + + def isfile(self, filepath: str) -> bool: + """Check if filepath is a file. + + Args: + filepath (str): Path to file. + + Returns: + bool: True if file exists, False otherwise. + """ + zip_path, keys = self._get_zip_path(filepath) + if not os.path.exists(zip_path): + return False + zip_file = self._get_client(zip_path, "r") + url = "/".join(reversed(keys)) + return url in zip_file.namelist() + + def close(self) -> None: + """Close all opened Zip files.""" + for client, _ in self.db_cache.values(): + client.close() + self.db_cache = {} diff --git a/vis4d/data/iterable.py b/vis4d/data/iterable.py new file mode 100644 index 0000000000000000000000000000000000000000..6caca4656dc11088e8abdd25e19ae9c001dc097b --- /dev/null +++ b/vis4d/data/iterable.py @@ -0,0 +1,100 @@ +"""Iterable datasets.""" + +from __future__ import annotations + +import math +from collections.abc import Callable, Iterator + +import numpy as np +from torch.utils.data import Dataset, IterableDataset, get_worker_info + +from .typing import DictData + + +class SubdividingIterableDataset(IterableDataset[DictData]): + """Subdivides a given dataset into smaller chunks. + + This also adds a field called 'index' (DataKeys.index) to the data + struct in order to relate the data to the source index. + + Example: Given a dataset (ds) that outputs tensors of the shape (10, 3): + sub_ds = SubdividingIterableDataset(ds, n_samples_per_batch = 5) + + next(iter(sub_ds))['key'].shape + >> torch.Size([5, 3]) + + next(DataLoader(sub_ds, batch_size = 4))['key'].shape + >> torch.size([4,5,3]) + + Assuming the dataset returns two entries with shape (10,3): + [e['index'].item() for e in sub_ds] + >> [0,0,1,1] + """ + + def __init__( + self, + dataset: Dataset[DictData], + n_samples_per_batch: int, + preprocess_fn: Callable[ + [list[DictData]], list[DictData] + ] = lambda x: x, + ) -> None: + """Creates a new Dataset. + + Args: + dataset (Dataset): The dataset which should be subdivided. + n_samples_per_batch: How many samples each batch should contain. + The first dimension of dataset[0].shape must be divisible by + this number. + preprocess_fn (Callable[[list[DictData]], list[DictData]): + Preprocessing function. Defaults to identity. + """ + super().__init__() + self.dataset = dataset + self.n_samples_per_batch = n_samples_per_batch + self.preprocess_fn = preprocess_fn + + def __getitem__(self, index: int) -> DictData: + """Indexing is not supported for IterableDatasets.""" + raise NotImplementedError("IterableDataset does not support indeing") + + def __iter__(self) -> Iterator[DictData]: + """Iterates over the dataset, supporting distributed sampling.""" + worker_info = get_worker_info() + if worker_info is None: + # not distributed + num_workers = 1 + worker_id = 0 + else: # pragma: no cover + num_workers = worker_info.num_workers + worker_id = worker_info.id + + assert hasattr( + self.dataset, "__len__" + ), "Dataset must have __len__ in order to be subdivided." + n_samples = len(self.dataset) + for i in range(math.ceil(n_samples / num_workers)): + data_idx = i * num_workers + worker_id + if data_idx >= n_samples: + continue + data_sample = self.dataset[data_idx] + + n_elements = list((data_sample.values()))[0].shape[0] + for idx in range(int(n_elements / self.n_samples_per_batch)): + # This is kind of ugly + # this field defines from which source the data was loaded + # (first entry, second entry, ...) + # this is required if we e.g. want to subdivide a room that is + # too big into equal sized chunks and stick them back together + # for visualizaton + out_data: DictData = {"source_index": np.ndarray([data_idx])} + for key in data_sample: + start_idx = idx * self.n_samples_per_batch + end_idx = (idx + 1) * self.n_samples_per_batch + if (len(data_sample[key])) < self.n_samples_per_batch: + out_data[key] = data_sample[key] + else: + out_data[key] = data_sample[key][ + start_idx:end_idx, ... + ] + yield self.preprocess_fn([out_data])[0] diff --git a/vis4d/data/loader.py b/vis4d/data/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..2ecf3c08953a7fe61507cc185176329629111586 --- /dev/null +++ b/vis4d/data/loader.py @@ -0,0 +1,281 @@ +"""Dataloader utility functions.""" + +from __future__ import annotations + +import random +import warnings +from collections.abc import Callable, Sequence + +import numpy as np +import torch +from torch.utils.data import ( + DataLoader, + Dataset, + RandomSampler, + SequentialSampler, +) +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import Sampler + +from vis4d.common.distributed import get_rank, get_world_size + +from .const import CommonKeys as K +from .data_pipe import DataPipe +from .datasets.base import VideoDataset +from .samplers import AspectRatioBatchSampler, VideoInferenceSampler +from .transforms import compose +from .transforms.to_tensor import ToTensor +from .typing import DictData, DictDataOrList + +DEFAULT_COLLATE_KEYS = ( + K.seg_masks, + K.extrinsics, + K.intrinsics, + K.depth_maps, + K.optical_flows, + K.categories, +) + + +def default_collate( + batch: list[DictData], + collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS, + sensors: Sequence[str] | None = None, +) -> DictData: + """Default batch collate. + + It will concatenate images and stack seg_masks, extrinsics, intrinsics, + and depth_maps. Other keys will be put into a list. + + Args: + batch (list[DictData]): List of data dicts. + collate_keys (Sequence[str]): Keys to be collated. Default is + DEFAULT_COLLATE_KEYS. + sensors (Sequence[str] | None): List of sensors to collate. If is not + None will raise an error. Default is None. + + Returns: + DictData: Collated data dict. + """ + assert sensors is None, "If specified sensors, use multi_sensor_collate." + + data: DictData = {} + for key in batch[0]: + try: + if key == "transforms": # skip transform parameters + continue + if key in [K.images]: + data[key] = torch.cat([b[key] for b in batch]) + elif key in collate_keys: + data[key] = torch.stack([b[key] for b in batch], 0) + else: + data[key] = [b[key] for b in batch] + except RuntimeError as e: + raise RuntimeError(f"Error collating key {key}") from e + return data + + +def multi_sensor_collate( + batch: list[DictData], + collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS, + sensors: Sequence[str] | None = None, +) -> DictData: + """Default multi-sensor batch collate. + + Args: + batch (list[DictData]): List of data dicts. Each data dict contains + data from multiple sensors. + collate_keys (Sequence[str]): Keys to be collated. Default is + DEFAULT_COLLATE_KEYS. + sensors (Sequence[str] | None): List of sensors to collate. If None, + will raise an error. Default is None. + + Returns: + DictData: Collated data dict. + """ + assert ( + sensors is not None + ), "If not specified sensors, use default_collate." + + collated_batch: DictData = {} + + # For each sensor, collate the batch. Other keys will be put into a list. + for key in batch[0]: + inner_batch = [b[key] for b in batch] + if key in sensors: + collated_batch[key] = default_collate(inner_batch, collate_keys) + else: + collated_batch[key] = inner_batch + return collated_batch + + +def default_pipeline(data: list[DictData]) -> list[DictData]: + """Default data pipeline.""" + return compose([ToTensor()])(data) + + +def build_train_dataloader( + dataset: DataPipe, + samples_per_gpu: int = 1, + workers_per_gpu: int = 1, + batchprocess_fn: Callable[ + [list[DictData]], list[DictData] + ] = default_pipeline, + collate_fn: Callable[ + [list[DictData], Sequence[str]], DictData + ] = default_collate, + collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS, + sensors: Sequence[str] | None = None, + pin_memory: bool = True, + shuffle: bool | None = True, + drop_last: bool = False, + seed: int | None = None, + aspect_ratio_grouping: bool = False, + sampler: Sampler | None = None, # type: ignore + disable_subprocess_warning: bool = False, +) -> DataLoader[DictDataOrList]: + """Build training dataloader.""" + assert isinstance(dataset, DataPipe), "dataset must be a DataPipe" + + def _collate_fn_single(data: list[DictData]) -> DictData: + """Collates data from single view dataset.""" + return collate_fn( # type: ignore + batch=batchprocess_fn(data), + collate_keys=collate_keys, + sensors=sensors, + ) + + def _collate_fn_multi(data: list[list[DictData]]) -> list[DictData]: + """Collates data from multi view dataset.""" + views = [] + for view_idx in range(len(data[0])): + view = collate_fn( # type: ignore + batch=batchprocess_fn([d[view_idx] for d in data]), + collate_keys=collate_keys, + sensors=sensors, + ) + views.append(view) + return views + + def _worker_init_fn(worker_id: int) -> None: + """Will be called on each worker after seeding and before data loading. + + Args: + worker_id (int): Worker id in [0, num_workers - 1]. + """ + if seed is not None: + # The seed of each worker equals to + # num_workers * rank + worker_id + user_seed + worker_seed = workers_per_gpu * get_rank() + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) + if disable_subprocess_warning and worker_id != 0: + warnings.simplefilter("ignore") + + if sampler is None: + if get_world_size() > 1: + assert isinstance( + shuffle, bool + ), "When using distributed training, shuffle must be a boolean." + sampler = DistributedSampler( + dataset, shuffle=shuffle, drop_last=drop_last + ) + shuffle = False + drop_last = False + elif shuffle: + sampler = RandomSampler(dataset) + shuffle = False + else: + sampler = SequentialSampler(dataset) + + batch_sampler = None + if aspect_ratio_grouping: + batch_sampler = AspectRatioBatchSampler( + sampler, batch_size=samples_per_gpu, drop_last=drop_last + ) + samples_per_gpu = 1 + shuffle = None + drop_last = False + sampler = None + + dataloader = DataLoader( + dataset, + batch_size=samples_per_gpu, + num_workers=workers_per_gpu, + collate_fn=( + _collate_fn_multi if dataset.has_reference else _collate_fn_single + ), + sampler=sampler, + batch_sampler=batch_sampler, + worker_init_fn=_worker_init_fn, + persistent_workers=workers_per_gpu > 0, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + ) + return dataloader + + +def build_inference_dataloaders( + datasets: Dataset[DictDataOrList] | list[Dataset[DictDataOrList]], + samples_per_gpu: int = 1, + workers_per_gpu: int = 1, + video_based_inference: bool = False, + batchprocess_fn: Callable[ + [list[DictData]], list[DictData] + ] = default_pipeline, + collate_fn: Callable[ + [list[DictData], Sequence[str]], DictData + ] = default_collate, + collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS, + sensors: Sequence[str] | None = None, +) -> list[DataLoader[DictDataOrList]]: + """Build dataloaders for test / predict.""" + + def _collate_fn(data: list[DictData]) -> DictData: + """Collates data for inference.""" + return collate_fn( # type: ignore + batch=batchprocess_fn(data), + collate_keys=collate_keys, + sensors=sensors, + ) + + if isinstance(datasets, Dataset): + datasets_ = [datasets] + else: + datasets_ = datasets + + dataloaders = [] + for dataset in datasets_: + sampler: DistributedSampler[list[int]] | None + if get_world_size() > 1: + if video_based_inference: + if isinstance(dataset, DataPipe): + assert ( + len(dataset.datasets) == 1 + ), "DDP Vdieo Inference only support a single dataset." + current_dataset = dataset.datasets[0] + else: + current_dataset = dataset + + assert isinstance( + current_dataset, VideoDataset + ), "Video based inference needs a VideoDataset." + sampler = VideoInferenceSampler(current_dataset) + else: + sampler = DistributedSampler(dataset) + else: + sampler = None + + test_dataloader = DataLoader( + dataset, + batch_size=samples_per_gpu, + num_workers=workers_per_gpu, + sampler=sampler, + shuffle=False, + collate_fn=_collate_fn, + persistent_workers=workers_per_gpu > 0, + ) + dataloaders.append(test_dataloader) + return dataloaders diff --git a/vis4d/data/reference.py b/vis4d/data/reference.py new file mode 100644 index 0000000000000000000000000000000000000000..0b27f7862f7b8c3e21de1b770363f22be67422ac --- /dev/null +++ b/vis4d/data/reference.py @@ -0,0 +1,239 @@ +"""Reference View Sampling. + +These Classes sample reference views from a dataset that contains videos. +This is usually used when a model needs multiple samples of a video during +training. +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Callable, List + +import numpy as np +from torch.utils.data import Dataset + +from .const import CommonKeys as K +from .datasets.base import VideoDataset +from .typing import DictData + +SortingFunc = Callable[[DictData, list[DictData]], List[DictData]] + + +def sort_key_first( + cur_sample: DictData, ref_data: list[DictData] +) -> list[DictData]: + """Sort views as key first.""" + return [cur_sample, *ref_data] + + +def sort_temporal( + cur_sample: DictData, ref_data: list[DictData] +) -> list[DictData]: + """Sort views temporally.""" + return sorted([cur_sample, *ref_data], key=lambda x: x[K.frame_ids]) + + +class ReferenceViewSampler: + """Base reference view sampler.""" + + def __init__(self, num_ref_samples: int) -> None: + """Creates an instance of the class. + + Args: + num_ref_samples (int): Number of reference views to sample. + """ + self.num_ref_samples = num_ref_samples + + @abstractmethod + def __call__( + self, + key_dataset_index: int, + indices_in_video: list[int], + frame_ids: list[int], + ) -> list[int]: + """Sample num_ref_samples reference view indices. + + Args: + key_index (int): Index of key view in the video. + indices_in_video (list[int]): All dataset indices in the video. + frame_ids (list[int]): Frame ids of all views in the video. + + Returns: + list[int]: dataset indices of reference views. + """ + raise NotImplementedError + + +class SequentialViewSampler(ReferenceViewSampler): + """Sequential View Sampler.""" + + def __call__( + self, + key_dataset_index: int, + indices_in_video: list[int], + frame_ids: list[int], + ) -> list[int]: + """Sample sequential reference views.""" + assert len(frame_ids) >= self.num_ref_samples + 1 + + key_index = indices_in_video.index(key_dataset_index) + + right = key_index + 1 + self.num_ref_samples + if right <= len(indices_in_video): + ref_dataset_indices = indices_in_video[key_index + 1 : right] + else: + left = key_index - (right - len(indices_in_video)) + ref_dataset_indices = ( + indices_in_video[left:key_index] + + indices_in_video[key_index + 1 :] + ) + return ref_dataset_indices + + +class UniformViewSampler(ReferenceViewSampler): + """View Sampler that chooses reference views uniform at random.""" + + def __init__(self, scope: int, num_ref_samples: int) -> None: + """Creates an instance of the class. + + Args: + scope (int): Define scope of neighborhood to key view to sample + from. + num_ref_samples (int): Number of reference views to sample. + """ + super().__init__(num_ref_samples) + if scope != 0 and scope < num_ref_samples // 2: + raise ValueError("Scope must be higher than num_ref_imgs / 2.") + self.scope = scope + + def _get_valid_indices( + self, key_index: int, indices_in_video: list[int], frame_ids: list[int] + ) -> list[int]: + """Get valid indices in video.""" + key_fid = frame_ids[key_index] + min_fid = max(0, key_fid - self.scope) + max_fid = min(key_fid + self.scope, frame_ids[-1]) + + return [ + ind + for i, ind in enumerate(indices_in_video) + if min_fid <= frame_ids[i] <= max_fid and i != key_index + ] + + def __call__( + self, + key_dataset_index: int, + indices_in_video: list[int], + frame_ids: list[int], + ) -> list[int]: + """Uniformly sample reference views.""" + if self.scope > 0: + key_index = indices_in_video.index(key_dataset_index) + + valid_indices = self._get_valid_indices( + key_index, indices_in_video, frame_ids + ) + + if len(valid_indices) > 0: + assert len(valid_indices) >= self.num_ref_samples + return np.random.choice( + valid_indices, self.num_ref_samples, replace=False + ).tolist() + + return [key_dataset_index] * self.num_ref_samples + + +class MultiViewDataset(Dataset[list[DictData]]): + """Dataset that samples reference views from a video dataset.""" + + def __init__( + self, + dataset: VideoDataset, + sampler: ReferenceViewSampler, + sort_fn: SortingFunc = sort_key_first, + num_retry: int = 3, + match_key: str = K.boxes2d_track_ids, + skip_nomatch_samples: bool = False, + ) -> None: + """Creates an instance of the class. + + Args: + dataset (Dataset): Video dataset to sample from. + sampler (ReferenceViewSampler): Sampler that samples reference + views. + sort_fn (SortingFunc, optional): Function that sorts key and + reference views. Defaults to sort_key_first. + num_retry (int, optional): Number of retries if no match is found. + Defaults to 3. + match_key (str, optional): Key to match reference views with key + view. Defaults to K.boxes2d_track_ids. + skip_nomatch_samples (bool, optional): Whether to skip samples + where no match is found. Defaults to False. + """ + self.dataset = dataset + self.sampler = sampler + self.sort_fn = sort_fn + self.num_retry = num_retry + self.match_key = match_key + self.skip_nomatch_samples = skip_nomatch_samples + + def has_matches( + self, key_data: DictData, ref_data: list[DictData] + ) -> bool: + """Check if key / ref data have matches.""" + key_target = key_data[self.match_key] + for ref_view in ref_data: + ref_target = ref_view[self.match_key] + match = np.equal( + np.expand_dims(key_target, axis=1), ref_target[None] + ) + if match.any(): + return True + return False # pragma: no cover + + def __len__(self) -> int: + """Get length of dataset.""" + return len(self.dataset) + + def get_ref_data(self, ref_indices: list[int]) -> list[DictData]: + """Get reference data from dataset.""" + ref_data = [] + for ref_index in ref_indices: + ref_sample = self.dataset[ref_index] + ref_sample["keyframes"] = False + ref_data.append(ref_sample) + + assert self.sampler.num_ref_samples == len(ref_data) + return ref_data + + def __getitem__(self, index: int) -> list[DictData]: + """Get item from dataset.""" + cur_sample = self.dataset[index] + cur_sample["keyframes"] = True + + indices_in_video = self.dataset.video_mapping["video_to_indices"][ + cur_sample[K.sequence_names] + ] + frame_ids = self.dataset.video_mapping["video_to_frame_ids"][ + cur_sample[K.sequence_names] + ] + + if self.sampler.num_ref_samples > 0: + for _ in range(self.num_retry): + ref_indices = self.sampler(index, indices_in_video, frame_ids) + + ref_data = self.get_ref_data(ref_indices) + + if self.skip_nomatch_samples and not ( + self.has_matches(cur_sample, ref_data) + ): + continue + + return self.sort_fn(cur_sample, ref_data) + + ref_indices = [index] * self.sampler.num_ref_samples + ref_data = self.get_ref_data(ref_indices) + return [cur_sample, *ref_data] + + return [cur_sample] diff --git a/vis4d/data/resample.py b/vis4d/data/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..e31b0c6361f6e78c30cd4eb20f135e903151dad7 --- /dev/null +++ b/vis4d/data/resample.py @@ -0,0 +1,77 @@ +"""Resample index to recover the original dataset length.""" + +from __future__ import annotations + +import numpy as np +from torch.utils.data import Dataset + +from vis4d.common.logging import rank_zero_info + +from .reference import MultiViewDataset +from .typing import DictDataOrList + + +class ResampleDataset(Dataset[DictDataOrList]): + """Dataset wrapper to recover the filtered samples through resampling. + + In MMEngine and Detectron2, the dataset might return None when the sample + has no valid annotations. They will resample the index and try to get the + valid training data. The length of dataset will be different depends on + whether filtering the empty samples first. + + This dataset wrapper resamples the index to recover the original dataset + length (before filter empty frames) to align with the other codebases' + implementation. + + https://github.com/open-mmlab/mmengine/blob/main/mmengine/dataset/base_dataset.py#L411 + https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/common.py#L96 + """ + + def __init__(self, dataset: Dataset[DictDataOrList]) -> None: + """Creates an instance of the class.""" + super().__init__() + self.dataset = dataset + self.has_reference = isinstance(dataset, MultiViewDataset) + self.valid_len = len(dataset) # type: ignore + + # Handle the case that dataset is already wrapped. + if hasattr(self.dataset, "dataset"): + _dataset = self.dataset.dataset + else: + _dataset = self.dataset + + assert hasattr(_dataset, "original_len"), ( + "The dataset must have the attribute `original_len` to resample " + + "index to recover the original length." + ) + self.original_len = _dataset.original_len + + rank_zero_info( + f"Recover {_dataset} to {self.original_len} samples by resampling " + + "index." + ) + + def __len__(self) -> int: + """Return the length of dataset. + + Returns: + int: Length of dataset. + """ + return self.original_len + + def __getitem__(self, idx: int) -> DictDataOrList: + """Get original dataset idx according to the given index. + + Resample index to recover the original dataset length. + + Args: + idx (int): The index of original dataset length. + + Returns: + DictDataOrList: Data of the corresponding index. + """ + if idx < self.valid_len: + index = idx + else: + index = np.random.randint(0, self.valid_len) + return self.dataset[index] diff --git a/vis4d/data/samplers.py b/vis4d/data/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..b7401fcc2db25c7c5f356278674e5e8949e0f443 --- /dev/null +++ b/vis4d/data/samplers.py @@ -0,0 +1,156 @@ +"""Vis4D data samplers.""" + +from __future__ import annotations + +from collections.abc import Iterator + +import numpy as np +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import BatchSampler, Sampler + +from vis4d.data.const import CommonKeys as K + +from .datasets.base import VideoDataset +from .typing import DictDataOrList + + +class VideoInferenceSampler( + DistributedSampler[list[int]] +): # pragma: no cover # No unittest for distributed setting. + """Produce sequence ordered indices for inference across all workers. + + Inference needs to run on the __exact__ set of sequences and their + respective samples, therefore if the sequences are not divisible by the + number of workers or if they have different length, the sampler + produces different number of samples on different workers. + """ + + def __init__( + self, + dataset: Dataset[DictDataOrList], + num_replicas: None | int = None, + rank: None | int = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + """Creates an instance of the class. + + Args: + dataset (Dataset): Inference dataset. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`world_size` is + retrieved from the current distributed group. + rank (int, optional): Rank of the current process within + :attr:`num_replicas`. By default, :attr:`rank` is retrieved + from the current distributed group. + shuffle (bool, optional): If ``True`` (default), sampler will + shuffle the indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across + all processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop + the tail of the data to make it evenly divisible across the + number of replicas. If ``False``, the sampler will add extra + indices to make the data evenly divisible across the replicas. + Default: ``False``. + """ + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + assert isinstance(dataset, VideoDataset) + self.sequences = list(dataset.video_mapping["video_to_indices"]) + self.num_seqs = len(self.sequences) + assert self.num_seqs >= self.num_replicas, ( + f"Number of sequences ({self.num_seqs}) must be greater or " + f"equal to number of replicas ({self.num_replicas})!" + ) + chunks = np.array_split(self.sequences, self.num_replicas) + self._local_seqs = chunks[self.rank] + self._local_idcs: list[int] = [] + for seq in self._local_seqs: + self._local_idcs.extend( + dataset.video_mapping["video_to_indices"][seq] + ) + + def __iter__(self) -> Iterator[list[int]]: + """Iteration method.""" + return iter(self._local_idcs) # type: ignore + + def __len__(self) -> int: + """Return length of sampler instance.""" + return len(self._local_idcs) + + +class AspectRatioBatchSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio. + + Moidified from: + https://github.com/open-mmlab/mmdetection/blob/main/mmdet/datasets/samplers/batch_sampler.py + + Args: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + """ + + def __init__( + self, + sampler: Sampler, # type: ignore + batch_size: int, + drop_last: bool = False, + ) -> None: + """Creates an instance of the class.""" + if not isinstance(sampler, Sampler): + raise TypeError( + "sampler should be an instance of ``Sampler``, " + f"but got {sampler}" + ) + + super().__init__(sampler, batch_size, drop_last) + + # two groups for w < h and w >= h + self._aspect_ratio_buckets: list[list[int]] = [[] for _ in range(2)] + + def __iter__(self) -> Iterator[list[int]]: + """Iteration method.""" + for idx in self.sampler: + if hasattr(self.sampler, "dataset"): + data_dict = self.sampler.dataset[idx] + elif hasattr(self.sampler, "data_source"): + data_dict = self.sampler.data_source[idx] + else: + raise ValueError( + "sampler should have dataset or data_source attribute" + ) + height, width = data_dict[K.input_hw] + bucket_id = 0 if width < height else 1 + bucket = self._aspect_ratio_buckets[bucket_id] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + + # yield the rest data and reset the bucket + left_data = ( + self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[1] + ) + self._aspect_ratio_buckets = [[] for _ in range(2)] + while len(left_data) > 0: + if len(left_data) <= self.batch_size: + if not self.drop_last: + yield left_data[:] + left_data = [] + else: + yield left_data[: self.batch_size] + left_data = left_data[self.batch_size :] + + def __len__(self) -> int: + """Return length of sampler instance.""" + if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore + + return ( + len(self.sampler) + self.batch_size - 1 # type: ignore + ) // self.batch_size diff --git a/vis4d/data/transforms/__init__.py b/vis4d/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59f70b93c8bacdc7e7f439214c81236985465264 --- /dev/null +++ b/vis4d/data/transforms/__init__.py @@ -0,0 +1,5 @@ +"""Transforms.""" + +from .base import RandomApply, Transform, compose + +__all__ = ["Transform", "RandomApply", "compose"] diff --git a/vis4d/data/transforms/__pycache__/__init__.cpython-311.pyc b/vis4d/data/transforms/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f97d8d001485ca71033275df79285446c239c2f Binary files /dev/null and b/vis4d/data/transforms/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/data/transforms/__pycache__/base.cpython-311.pyc b/vis4d/data/transforms/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b999e4559784cda0663ad91e77b6d31c1731273b Binary files /dev/null and b/vis4d/data/transforms/__pycache__/base.cpython-311.pyc differ diff --git a/vis4d/data/transforms/__pycache__/normalize.cpython-311.pyc b/vis4d/data/transforms/__pycache__/normalize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e9f57f46b17f4f29d1a15a9974ee0786c28cbfe Binary files /dev/null and b/vis4d/data/transforms/__pycache__/normalize.cpython-311.pyc differ diff --git a/vis4d/data/transforms/__pycache__/pad.cpython-311.pyc b/vis4d/data/transforms/__pycache__/pad.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51e47e25b5f4feae685eb67e8c8e0bc7da04afb6 Binary files /dev/null and b/vis4d/data/transforms/__pycache__/pad.cpython-311.pyc differ diff --git a/vis4d/data/transforms/__pycache__/resize.cpython-311.pyc b/vis4d/data/transforms/__pycache__/resize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffd0b059ad4bf3ce2e6c060d786f91e749f1cd32 Binary files /dev/null and b/vis4d/data/transforms/__pycache__/resize.cpython-311.pyc differ diff --git a/vis4d/data/transforms/__pycache__/to_tensor.cpython-311.pyc b/vis4d/data/transforms/__pycache__/to_tensor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..646112b7bdd3340b1f0b820c33ea7920c7558605 Binary files /dev/null and b/vis4d/data/transforms/__pycache__/to_tensor.cpython-311.pyc differ diff --git a/vis4d/data/transforms/affine.py b/vis4d/data/transforms/affine.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c1bacc109b220fee73fa8cac2b083b93e7ff70 --- /dev/null +++ b/vis4d/data/transforms/affine.py @@ -0,0 +1,314 @@ +"""Affine transformation. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import math +import random +from typing import TypedDict + +import numpy as np +import torch + +from vis4d.common.imports import OPENCV_AVAILABLE +from vis4d.common.typing import NDArrayF32, NDArrayI64 +from vis4d.data.const import CommonKeys as K +from vis4d.op.box.box2d import bbox_clip, bbox_project + +from .base import Transform +from .crop import _get_keep_mask + +if OPENCV_AVAILABLE: + import cv2 +else: + raise ImportError("Please install opencv-python to use this module.") + + +class AffineParam(TypedDict): + """Parameters for Affine.""" + + warp_matrix: NDArrayF32 + height: int + width: int + + +def get_rotation_matrix(rotate_degrees: float) -> NDArrayF32: + """Generate rotation matrix. + + Args: + rotate_degrees (float): Rotation degrees. + """ + radian = math.radians(rotate_degrees) + rotation_matrix = np.array( + [ + [np.cos(radian), -np.sin(radian), 0.0], + [np.sin(radian), np.cos(radian), 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + return rotation_matrix + + +def get_scaling_matrix(scale_ratio: float) -> NDArrayF32: + """Generate scaling matrix. + + Args: + scale_ratio (float): Scale ratio. + """ + scaling_matrix = np.array( + [[scale_ratio, 0.0, 0.0], [0.0, scale_ratio, 0.0], [0.0, 0.0, 1.0]], + dtype=np.float32, + ) + return scaling_matrix + + +def get_shear_matrix( + x_shear_degrees: float, y_shear_degrees: float +) -> NDArrayF32: + """Generate shear matrix. + + Args: + x_shear_degrees (float): X shear degrees. + y_shear_degrees (float): Y shear degrees. + """ + x_radian = math.radians(x_shear_degrees) + y_radian = math.radians(y_shear_degrees) + shear_matrix = np.array( + [ + [1, np.tan(x_radian), 0.0], + [np.tan(y_radian), 1, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + return shear_matrix + + +def get_translation_matrix(x_trans: float, y_trans: float) -> NDArrayF32: + """Generate translation matrix. + + Args: + x_trans (float): X translation. + y_trans (float): Y translation. + """ + translation_matrix = np.array( + [[1, 0.0, x_trans], [0.0, 1, y_trans], [0.0, 0.0, 1.0]], + dtype=np.float32, + ) + return translation_matrix + + +@Transform(K.input_hw, ["transforms.affine"]) +class GenAffineParameters: + """Random affine transform data augmentation. + + This operation randomly generates affine transform matrix which including + rotation, translation, shear, and scaling transforms. + """ + + def __init__( + self, + max_rotate_degree: float = 10.0, + max_translate_ratio: float = 0.1, + scaling_ratio_range: tuple[float, float] = (0.5, 1.5), + max_shear_degree: float = 2.0, + border: tuple[int, int] = (0, 0), + ) -> None: + """Creates an instance of the class. + + Args: + max_rotate_degree (float): Maximum degrees of rotation transform. + Defaults to 10. + max_translate_ratio (float): Maximum ratio of translation. + Defaults to 0.1. + scaling_ratio_range (tuple[float]): Min and max ratio of + scaling transform. Defaults to (0.5, 1.5). + max_shear_degree (float): Maximum degrees of shear + transform. Defaults to 2. + border (tuple[int, int]): Distance from height and width sides of + input image to adjust output shape. Only used in mosaic + dataset. Defaults to (0, 0). + """ + assert 0 <= max_translate_ratio <= 1 + assert scaling_ratio_range[0] <= scaling_ratio_range[1] + assert scaling_ratio_range[0] > 0 + self.max_rotate_degree = max_rotate_degree + self.max_translate_ratio = max_translate_ratio + self.scaling_ratio_range = scaling_ratio_range + self.max_shear_degree = max_shear_degree + self.border = border + + def _get_random_homography_matrix( + self, height: int, width: int + ) -> NDArrayF32: + """Generate random homography matrix.""" + # Rotation + rotation_degree = random.uniform( + -self.max_rotate_degree, self.max_rotate_degree + ) + rotation_matrix = get_rotation_matrix(rotation_degree) + + # Scaling + scaling_ratio = random.uniform( + self.scaling_ratio_range[0], self.scaling_ratio_range[1] + ) + scaling_matrix = get_scaling_matrix(scaling_ratio) + + # Shear + x_degree = random.uniform( + -self.max_shear_degree, self.max_shear_degree + ) + y_degree = random.uniform( + -self.max_shear_degree, self.max_shear_degree + ) + shear_matrix = get_shear_matrix(x_degree, y_degree) + + # Translation + trans_x = ( + random.uniform(-self.max_translate_ratio, self.max_translate_ratio) + * width + ) + trans_y = ( + random.uniform(-self.max_translate_ratio, self.max_translate_ratio) + * height + ) + translate_matrix = get_translation_matrix(trans_x, trans_y) + + warp_matrix = ( + translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix + ) + return warp_matrix + + def __call__(self, input_hw: list[tuple[int, int]]) -> list[AffineParam]: + """Compute the parameters and put them in the data dict.""" + img_shape = input_hw[0] + height = img_shape[0] + self.border[0] * 2 + width = img_shape[1] + self.border[1] * 2 + + warp_matrix = self._get_random_homography_matrix(height, width) + return [ + AffineParam(warp_matrix=warp_matrix, height=height, width=width) + ] * len(input_hw) + + +@Transform( + [ + K.images, + "transforms.affine.warp_matrix", + "transforms.affine.height", + "transforms.affine.width", + ], + [K.images, K.input_hw], +) +class AffineImages: + """Affine Images.""" + + def __init__( + self, + border_val: tuple[int, int, int] = (114, 114, 114), + as_int: bool = False, + ) -> None: + """Creates an instance of the class. + + Args: + border_val (tuple[int, int, int]): Border padding values of 3 + channels. Defaults to (114, 114, 114). + as_int (bool): Whether to convert the image to int. Defaults to + False. + """ + self.border_val = border_val + self.as_int = as_int + + def __call__( + self, + images: list[NDArrayF32], + warp_matrix_list: list[NDArrayF32], + height_list: list[int], + width_list: list[int], + ) -> tuple[list[NDArrayF32], list[tuple[int, int]]]: + """Affine a list of image of dimensions [N, H, W, C].""" + input_hw_list = [] + for i, (image, warp_matrix, height, width) in enumerate( + zip(images, warp_matrix_list, height_list, width_list) + ): + image = image[0].astype(np.uint8) if self.as_int else image[0] + image = cv2.warpPerspective( # pylint: disable=no-member, unsubscriptable-object, line-too-long + image, + warp_matrix, + dsize=(width, height), + borderValue=self.border_val, + )[ + None, ... + ].astype( + np.float32 + ) + + images[i] = image + input_hw_list.append((height, width)) + return images, input_hw_list + + +@Transform( + in_keys=[ + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + "transforms.affine.warp_matrix", + "transforms.affine.height", + "transforms.affine.width", + ], + out_keys=[K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids], +) +class AffineBoxes2D: + """Apply Affine to a list of 2D bounding boxes.""" + + def __init__(self, bbox_clip_border: bool = True) -> None: + """Creates an instance of the class. + + Args: + bbox_clip_border (bool, optional): Whether to clip the objects + outside the border of the image. In some dataset like MOT17, + the gt bboxes are allowed to cross the border of images. + Therefore, we don't need to clip the gt bboxes in these cases. + Defaults to True. + """ + self.bbox_clip_border = bbox_clip_border + + def __call__( + self, + boxes: list[NDArrayF32], + classes: list[NDArrayI64], + track_ids: list[NDArrayI64] | None, + warp_matrix_list: list[NDArrayF32], + height_list: list[int], + width_list: list[int], + ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]: + """Apply Affine to 2D bounding boxes.""" + for i, (box, class_, warp_matrix, height, width) in enumerate( + zip( + boxes, + classes, + warp_matrix_list, + height_list, + width_list, + ) + ): + box_ = bbox_project( + torch.from_numpy(box), torch.from_numpy(warp_matrix) + ) + if self.bbox_clip_border: + box_ = bbox_clip(box_, (height, width)) + boxes[i] = box_.numpy() + + keep_mask = _get_keep_mask( + boxes[i], np.array([0, 0, width, height]) + ) + boxes[i] = boxes[i][keep_mask] + classes[i] = class_[keep_mask] + if track_ids is not None: + track_ids[i] = track_ids[i][keep_mask] + + return boxes, classes, track_ids diff --git a/vis4d/data/transforms/autoaugment.py b/vis4d/data/transforms/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..6e47ad9f98846af3cce04aae27f5c4e9d57c0d5e --- /dev/null +++ b/vis4d/data/transforms/autoaugment.py @@ -0,0 +1,209 @@ +"""A wrap for timm transforms.""" + +from __future__ import annotations + +from typing import Union + +import numpy as np +from PIL import Image + +from vis4d.common.imports import TIMM_AVAILABLE +from vis4d.common.typing import NDArrayUI8 +from vis4d.data.const import CommonKeys as K + +from .base import Transform + +if TIMM_AVAILABLE: + from timm.data.auto_augment import ( + _RAND_INCREASING_TRANSFORMS, + _RAND_TRANSFORMS, + AugMixAugment, + AutoAugment, + RandAugment, + augmix_ops, + auto_augment_policy, + rand_augment_ops, + ) +else: + raise ImportError("timm is not installed.") + +AugOp = Union[AutoAugment, RandAugment, AugMixAugment] + + +def _apply_aug(images: NDArrayUI8, aug_op: AugOp) -> NDArrayUI8: + """Apply augmentation to a batch of images with shape [N, H, W, C].""" + assert images.shape[-1] == 3, "Images must be in RGB format." + imgs: list[Image.Image] = [] + for img in images: + # convert to uint8 if necessary + if img.dtype != np.uint8: + img = img.astype(np.uint8) + imgs.append(aug_op(Image.fromarray(img))) + return np.stack([np.array(img).astype(np.float32) for img in imgs]) + + +@Transform(K.images, K.images) +class _AutoAug: + """Apply Timm's AutoAugment to a image array.""" + + def __init__(self) -> None: + self.aug_op: AugOp | None = None + + def _create(self, policy: str, hparams: dict[str, float]) -> AugOp: + """Create augmentation op.""" + aa_policy = auto_augment_policy(policy, hparams=hparams) + return AutoAugment(aa_policy) + + def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]: + """Execute the transform.""" + assert self.aug_op is not None, "Augmentation op is not created." + for i, img in enumerate(images): + images[i] = _apply_aug(img, self.aug_op) + return images + + +class AutoAugV0(_AutoAug): + """Apply Timm's AutoAugment (policy=v0) to a image array.""" + + def __init__(self, magnitude_std: float = 0.5): + """Create an instance of AutoAug. + + Args: + magnitude_std (float, optional): Standard deviation of the + magnitude for random autoaugment. Defaults to 0.5. + """ + super().__init__() + self.aug_op = self._create("v0", {"magnitude_std": magnitude_std}) + + +class AutoAugOriginal(_AutoAug): + """Apply Timm's AutoAugment (policy=original) to a image array.""" + + def __init__(self, magnitude_std: float = 0.5): + """Create an instance of AutoAug. + + Args: + magnitude_std (float, optional): Standard deviation of the + magnitude for random autoaugment. Defaults to 0.5. + """ + super().__init__() + self.aug_op = self._create( + "original", {"magnitude_std": magnitude_std} + ) + + +@Transform(K.images, K.images) +class RandAug: + """Apply Timm's RandomAugment to a image tensor.""" + + def __init__( + self, + magnitude: int = 10, + num_layers: int = 2, + use_increasing: bool = False, + magnitude_std: float = 0.5, + ): + """Create an instance of RandAug. + + Args: + magnitude (int): Level of magnitude for augments, ranging from 1 to + 9. + num_layers (int, optional): Number of layers for rand augment. + Defaults to 2. + use_increasing (bool, optional): Whether to use increasing setting + for transforms. Defaults to False. + magnitude_std (float, optional): Standard deviation of the + magnitude for random autoaugment. Defaults to 0.5. + + Returns: + Callable: A function that takes a tensor of shape [N, C, H, W] and + returns a tensor of the same shape. + + Example: + Rand augment with magnitude 9. (`https://arxiv.org/abs/1909.13719`) + >>> rand_augment(magnitude=9) + """ + super().__init__() + assert TIMM_AVAILABLE, "timm is not installed." + self.magnitude = magnitude + self.num_layers = num_layers + self.use_increasing = use_increasing + self.magnitude_std = magnitude_std + hparams = {"magnitude_std": self.magnitude_std} + + if self.use_increasing: + transforms = _RAND_INCREASING_TRANSFORMS + else: + transforms = _RAND_TRANSFORMS + ra_ops = rand_augment_ops( + magnitude=self.magnitude, hparams=hparams, transforms=transforms + ) + self.aug_op = RandAugment(ra_ops, self.num_layers) + + def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]: + """Execute the transform.""" + for i, img in enumerate(images): + images[i] = _apply_aug(img, self.aug_op) + return images + + +@Transform(K.images, K.images) +class AugMix: + """Apply Timm's AugMix to a image tensor.""" + + def __init__( + self, + magnitude: int = 10, + width: int = 3, + alpha: float = 1.0, + depth: int = -1, + blended: bool = True, + magnitude_std: float = 0.5, + ): + """Create an instance of AugMix. + + Args: + magnitude (int): Level of magnitude, ranging from 1 to 9. + width (int, optional): Width of the augmentation chain. Defaults to + 3. + alpha (float, optional): Alpha for beta distribution. Defaults to + 1.0. + depth (int, optional): Depth of the augmentation chain. Defaults to + -1. + blended (bool, optional): Whether to blend the original image with + the augmented image. Defaults to True. + magnitude_std (float, optional): Standard deviation of the + magnitude for random autoaugment. Defaults to 0.5. + + Returns: + Callable: A function that takes a tensor of shape [N, C, H, W] and + returns a tensor of the same shape. + + Example: + Augmix with magnitude 9. (`https://arxiv.org/abs/1912.02781`) + >>> augmix(magnitude=9) + """ + super().__init__() + assert TIMM_AVAILABLE, "timm is not installed." + self.magnitude = magnitude + self.width = width + self.alpha = alpha + self.depth = depth + self.blended = blended + self.magnitude_std = magnitude_std + hparams = {"magnitude_std": self.magnitude_std} + + am_ops = augmix_ops(magnitude=self.magnitude, hparams=hparams) + self.aug_op = AugMixAugment( + am_ops, + alpha=self.alpha, + width=self.width, + depth=self.depth, + blended=self.blended, + ) + + def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]: + """Execute the transform.""" + for i, img in enumerate(images): + images[i] = _apply_aug(img, self.aug_op) + return images diff --git a/vis4d/data/transforms/base.py b/vis4d/data/transforms/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ec1b2d8b52a10eecf2bbb79d512070bada4532e9 --- /dev/null +++ b/vis4d/data/transforms/base.py @@ -0,0 +1,227 @@ +"""Basic data augmentation class.""" + +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import TypeVar, no_type_check + +import torch + +from vis4d.common.dict import get_dict_nested, set_dict_nested +from vis4d.data.typing import DictData + +TFunctor = TypeVar("TFunctor", bound=object) # pylint: disable=invalid-name +TransformFunction = Callable[[list[DictData]], list[DictData]] + + +class Transform: + """Transforms Decorator. + + This class stores which `in_keys` are input to a transformation function + and which `out_keys` are overwritten in the data dictionary by the output + of this transformation. + Nested keys in the data dictionary can be accessed via key.subkey1.subkey2 + If any of `in_keys` is 'data', the full data dictionary will be forwarded + to the transformation. + If the only entry in `out_keys` is 'data', the full data dictionary will + be updated with the return value of the transformation. + For the case of multi-sensor data, the sensors that the transform should be + applied can be set via the 'sensors' attribute. By default, we assume + a transformation is applied to all sensors. + This class will add a 'apply_to_data' method to a given Functor which is + used to call it on a DictData object. NOTE: This is an issue for static + checking and is not recognized by pylint. It will usually be called in the + compose() function and will not be called directly. + + Example: + >>> @Transform(in_keys="images", out_keys="images") + >>> class MyTransform: + >>> def __call__(images: list[np.array]) -> list[np.array]: + >>> images = do_something(images) + >>> return images + >>> my_transform = MyTransform() + >>> data = my_transform.apply_to_data(data) + """ + + def __init__( + self, + in_keys: Sequence[str] | str, + out_keys: Sequence[str] | str, + sensors: Sequence[str] | str | None = None, + same_on_batch: bool = True, + ) -> None: + """Creates an instance of Transform. + + Args: + in_keys (Sequence[str] | str): Specifies one or multiple (if any) + input keys of the data dictionary which should be remapeed to + another key. Defaults to None. + out_keys (Sequence[str] | str): Specifies one or multiple (if any) + output keys of the data dictionary which should be remaped to + another key. Defaults to None. + sensors (Sequence[str] | str | None, optional): Specifies the + sensors this transformation should be applied to. If None, it + will be applied to all available sensors. Defaults to None. + same_on_batch (bool, optional): Whether to use the same + transformation parameters to all sensors / view. Defaults to + True. + """ + if isinstance(in_keys, str): + in_keys = [in_keys] + self.in_keys = in_keys + + if isinstance(out_keys, str): + out_keys = [out_keys] + self.out_keys = out_keys + + if isinstance(sensors, str): + sensors = [sensors] + self.sensors = sensors + + self.same_on_batch = same_on_batch + + @no_type_check + def __call__(self, transform: TFunctor) -> TFunctor: + """Add in_keys / out_keys / sensors / apply_to_data attributes. + + Args: + transform (TFunctor): A given Functor. + + Returns: + TFunctor: The decorated Functor. + """ + original_init = transform.__init__ + + def apply_to_data( + self_, input_batch: list[DictData] + ) -> list[DictData]: + """Wrap function with a handler for input / output keys. + + We use the specified in_keys in order to extract the positional + input arguments of a function from the data dictionary, and the + out_keys to replace the corresponding values in the output + dictionary. + """ + + def _transform_fn(batch: list[DictData]) -> list[DictData]: + in_batch = [] + for key in self_.in_keys: + key_data = [] + for data in batch: + # Optionally allow the function to get the full data + # dict as aux input and set default value to None if + # key is not found + key_data += [ + ( + get_dict_nested( + data, key.split("."), allow_missing=True + ) + if key != "data" + else data + ) + ] + if any(d is None for d in key_data): + # If any of the data in the batch is None, replace + # the input of the key with None. + in_batch.append(None) + else: + in_batch.append(key_data) + + result = self_(*in_batch) + + if len(self_.out_keys) == 1: + if self_.out_keys[0] == "data": + return result + result = [result] + + for key, values in zip(self_.out_keys, result): + if values is None: + continue + for data, value in zip(batch, values): + if value is not None: + set_dict_nested(data, key.split("."), value) + return batch + + if self_.sensors is not None: + if self_.same_on_batch: + for sensor in self_.sensors: + batch_sensor = _transform_fn( + [d[sensor] for d in input_batch] + ) + for i, d in enumerate(batch_sensor): + input_batch[i][sensor] = d + else: + for i, data in enumerate(input_batch): + for sensor in self_.sensors: + input_batch[i][sensor] = _transform_fn( + [data[sensor]] + ) + elif self_.same_on_batch: + input_batch = _transform_fn(input_batch) + else: + for i, data in enumerate(input_batch): + input_batch[i] = _transform_fn([data])[0] + + return input_batch + + def init( + *args, + in_keys: Sequence[str] = self.in_keys, + out_keys: Sequence[str] = self.out_keys, + sensors: Sequence[str] | None = self.sensors, + same_on_batch: bool = self.same_on_batch, + **kwargs, + ): + self_ = args[0] + original_init(*args, **kwargs) + self_.in_keys = in_keys + self_.out_keys = out_keys + self_.sensors = sensors + self_.same_on_batch = same_on_batch + self_.apply_to_data = lambda *args, **kwargs: apply_to_data( + self_, *args, **kwargs + ) + + transform.__init__ = init + return transform + + +def compose(transforms: list[TFunctor]) -> TransformFunction: + """Compose transformations. + + This function composes a given set of transformation functions, i.e. any + functor decorated with Transform, into a single transform. + """ + + def _preprocess_func(batch: list[DictData]) -> list[DictData]: + for op in transforms: + batch = op.apply_to_data(batch) # type: ignore + return batch + + return _preprocess_func + + +@Transform("data", "data") +class RandomApply: + """Randomize the application of a given set of transformations.""" + + def __init__( + self, transforms: list[TFunctor], probability: float = 0.5 + ) -> None: + """Creates an instance of RandomApply. + + Args: + transforms (list[TFunctor]): Transformations that are applied with + a given probability. + probability (float, optional): Probability to apply + transformations. Defaults to 0.5. + """ + self.transforms = transforms + self.probability = probability + + def __call__(self, batch: list[DictData]) -> list[DictData]: + """Apply transforms with a given probability.""" + if torch.rand(1) < self.probability: + for op in self.transforms: + batch = op.apply_to_data(batch) # type: ignore + return batch diff --git a/vis4d/data/transforms/crop.py b/vis4d/data/transforms/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..a1fa1154e1ac33bfa9415a4b2edc69477f795b89 --- /dev/null +++ b/vis4d/data/transforms/crop.py @@ -0,0 +1,529 @@ +"""Crop transformation.""" + +from __future__ import annotations + +import math +from collections.abc import Callable +from typing import List, Tuple, TypedDict, Union + +import numpy as np +import torch + +from vis4d.common.logging import rank_zero_warn +from vis4d.common.typing import ( + NDArrayBool, + NDArrayF32, + NDArrayI32, + NDArrayI64, + NDArrayUI8, +) +from vis4d.data.const import CommonKeys as K +from vis4d.op.box.box2d import bbox_intersection + +from .base import Transform + +CropShape = Union[ + Tuple[float, float], + Tuple[int, int], + List[Tuple[float, float]], + List[Tuple[int, int]], +] +CropFunc = Callable[[int, int, CropShape], Tuple[int, int]] + + +class CropParam(TypedDict): + """Parameters for Crop.""" + + crop_box: NDArrayI32 + keep_mask: NDArrayBool + + +def absolute_crop(im_h: int, im_w: int, shape: CropShape) -> tuple[int, int]: + """Absolute crop.""" + assert isinstance(shape, tuple) + assert shape[0] > 0 and shape[1] > 0 + return (min(int(shape[0]), im_h), min(int(shape[1]), im_w)) + + +def absolute_range_crop( + im_h: int, im_w: int, shape: CropShape +) -> tuple[int, int]: + """Absolute range crop.""" + assert isinstance(shape, list) + assert len(shape) == 2 + assert shape[1][0] >= shape[0][0] + assert shape[1][1] >= shape[0][1] + + for crop in shape: + assert crop[0] > 0 and crop[1] > 0 + shape_min: tuple[int, int] = (int(shape[0][0]), int(shape[0][1])) + shape_max: tuple[int, int] = (int(shape[1][0]), int(shape[1][1])) + + crop_h = np.random.randint( + min(im_h, shape_min[0]), min(im_h, shape_max[0]) + 1 + ) + crop_w = np.random.randint( + min(im_w, shape_min[1]), min(im_w, shape_max[1]) + 1 + ) + return int(crop_h), int(crop_w) + + +def relative_crop(im_h: int, im_w: int, shape: CropShape) -> tuple[int, int]: + """Relative crop.""" + assert isinstance(shape, tuple) + assert 0 < shape[0] <= 1 and 0 < shape[1] <= 1 + crop_h, crop_w = shape + return int(im_h * crop_h + 0.5), int(im_w * crop_w + 0.5) + + +def relative_range_crop( + im_h: int, im_w: int, shape: CropShape +) -> tuple[int, int]: + """Relative range crop.""" + assert isinstance(shape, list) + assert len(shape) == 2 + assert shape[1][0] >= shape[0][0] + assert shape[1][1] >= shape[0][1] + for crop in shape: + assert 0 < crop[0] <= 1 and 0 < crop[1] <= 1 + scale_min: tuple[float, float] = shape[0] + scale_max: tuple[float, float] = shape[1] + + crop_h = np.random.rand() * (scale_max[0] - scale_min[0]) + scale_min[0] + crop_w = np.random.rand() * (scale_max[1] - scale_min[1]) + scale_min[1] + return int(im_h * crop_h + 0.5), int(im_w * crop_w + 0.5) + + +@Transform( + in_keys=[K.input_hw, K.boxes2d, K.seg_masks], + out_keys="transforms.crop", +) +class GenCropParameters: + """Generate the parameters for a crop operation.""" + + def __init__( + self, + shape: CropShape, + crop_func: CropFunc = absolute_crop, + allow_empty_crops: bool = True, + cat_max_ratio: float = 1.0, + ignore_index: int = 255, + ) -> None: + """Creates an instance of the class. + + Args: + shape (CropShape): Image shape to be cropped to in [H, W]. + crop_func (CropFunc, optional): Function used to generate the size + of the crop. Defaults to absolute_crop. + allow_empty_crops (bool, optional): Allow crops which result in + empty labels. Defaults to True. + cat_max_ratio (float, optional): Maximum ratio of a particular + class in segmentation masks after cropping. Defaults to 1.0. + ignore_index (int, optional): The index to ignore. Defaults to 255. + """ + self.shape = shape + self.crop_func = crop_func + self.allow_empty_crops = allow_empty_crops + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + def _get_crop( + self, im_h: int, im_w: int, boxes: NDArrayF32 | None = None + ) -> tuple[NDArrayI32, NDArrayBool]: + """Get the crop parameters.""" + crop_size = self.crop_func(im_h, im_w, self.shape) + crop_box = _sample_crop(im_h, im_w, crop_size) + keep_mask = _get_keep_mask(boxes, crop_box) + return crop_box, keep_mask + + def __call__( + self, + input_hw_list: list[tuple[int, int]], + boxes_list: list[NDArrayF32] | None, + masks_list: list[NDArrayUI8] | None, + ) -> list[CropParam]: + """Compute the parameters and put them in the data dict.""" + im_h, im_w = input_hw_list[0] + boxes = boxes_list[0] if boxes_list is not None else None + masks = masks_list[0] if masks_list is not None else None + + crop_box, keep_mask = self._get_crop(im_h, im_w, boxes) + if (boxes is not None and len(boxes) > 0) or masks is not None: + # resample crop if conditions not satisfied + found_crop = False + for _ in range(10): + # try resampling 10 times, otherwise use last crop + if (self.allow_empty_crops or keep_mask.sum() != 0) and ( + _check_seg_max_cat( + masks, crop_box, self.cat_max_ratio, self.ignore_index + ) + ): + found_crop = True + break + crop_box, keep_mask = self._get_crop(im_h, im_w, boxes) + if not found_crop: + rank_zero_warn("Random crop not found within 10 resamples.") + + crop_params = [ + CropParam(crop_box=crop_box, keep_mask=keep_mask) + ] * len(input_hw_list) + + return crop_params + + +@Transform([K.input_hw, K.boxes2d], "transforms.crop") +class GenCentralCropParameters: + """Generate the parameters for a central crop operation.""" + + def __init__( + self, + shape: CropShape, + crop_func: CropFunc = absolute_crop, + ) -> None: + """Creates an instance of the class. + + Args: + shape (CropShape): Image shape to be cropped to. + crop_func (CropFunc, optional): Function used to generate the size + of the crop. Defaults to absolute_crop. + """ + self.shape = shape + self.crop_func = crop_func + + def __call__( + self, + input_hw_list: list[tuple[int, int]], + boxes_list: list[NDArrayF32] | None, + ) -> list[CropParam]: + """Compute the parameters and put them in the data dict.""" + im_h, im_w = input_hw_list[0] + boxes = boxes_list[0] if boxes_list is not None else None + + crop_size = self.crop_func(im_h, im_w, self.shape) + crop_box = _get_central_crop(im_h, im_w, crop_size) + keep_mask = _get_keep_mask(boxes, crop_box) + crop_params = [ + CropParam(crop_box=crop_box, keep_mask=keep_mask) + ] * len(input_hw_list) + + return crop_params + + +@Transform([K.input_hw, K.boxes2d], "transforms.crop") +class GenRandomSizeCropParameters: + """Generate the parameters for a random size crop operation. + + A crop of the original image is made: the crop has a random area (H * W) + and a random aspect ratio. Code adapted from torchvision. + """ + + def __init__( + self, + scale: tuple[float, float] = (0.08, 1.0), + ratio: tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), + ): + """Creates an instance of the class. + + Args: + scale (tuple[float, float], optional): Scale range of the cropped + area. Defaults to (0.08, 1.0). + ratio (tuple[float, float], optional): Aspect ratio range of the + cropped area. Defaults to (3.0 / 4.0, 4.0 / 3.0). + """ + self.scale = scale + self.ratio = np.array(ratio) + self.log_ratio = np.log(self.ratio) + + def get_params(self, height: int, width: int) -> NDArrayI32: + """Get parameters for the random size crop.""" + area = height * width + for _ in range(10): + target_area = area * np.random.uniform( + self.scale[0], self.scale[1] + ) + aspect_ratio = np.exp( + np.random.uniform(self.log_ratio[0], self.log_ratio[1]) + ) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = np.random.randint(0, height - h + 1) + j = np.random.randint(0, width - w + 1) + crop_x1, crop_y1, crop_x2, crop_y2 = i, j, i + h, j + w + return np.array([crop_x1, crop_y1, crop_x2, crop_y2]) + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + crop_x1, crop_y1, crop_x2, crop_y2 = i, j, i + h, j + w + return np.array([crop_x1, crop_y1, crop_x2, crop_y2]) + + def __call__( + self, + input_hw_list: list[tuple[int, int]], + boxes_list: list[NDArrayF32] | None, + ) -> list[CropParam]: + """Compute the parameters and put them in the data dict.""" + im_h, im_w = input_hw_list[0] + boxes = boxes_list[0] if boxes_list is not None else None + + crop_box = self.get_params(im_h, im_w) + keep_mask = _get_keep_mask(boxes, crop_box) + + crop_params = [ + CropParam(crop_box=crop_box, keep_mask=keep_mask) + ] * len(input_hw_list) + + return crop_params + + +@Transform([K.images, "transforms.crop.crop_box"], [K.images, K.input_hw]) +class CropImages: + """Crop Images.""" + + def __call__( + self, images: list[NDArrayF32], crop_box_list: list[NDArrayI32] + ) -> tuple[list[NDArrayF32], list[tuple[int, int]]]: + """Crop a list of image of dimensions [N, H, W, C]. + + Args: + images (list[NDArrayF32]): The list of image. + crop_box (list[NDArrayI32]): The list of box to crop. + + Returns: + list[NDArrayF32]: List of cropped image according to parameters. + """ + input_hw_list = [] + for i, (image, crop_box) in enumerate(zip(images, crop_box_list)): + h, w = image.shape[1], image.shape[2] + x1, y1, x2, y2 = crop_box + crop_w, crop_h = x2 - x1, y2 - y1 + image = image[:, y1:y2, x1:x2, :] + input_hw = (min(crop_h, h), min(crop_w, w)) + + images[i] = image + input_hw_list.append(input_hw) + return images, input_hw_list + + +@Transform( + in_keys=[ + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + "transforms.crop.crop_box", + "transforms.crop.keep_mask", + ], + out_keys=[K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids], +) +class CropBoxes2D: + """Crop 2D bounding boxes.""" + + def __call__( + self, + boxes_list: list[NDArrayF32], + classes_list: list[NDArrayI64], + track_ids_list: list[NDArrayI64] | None, + crop_box_list: list[NDArrayI32], + keep_mask_list: list[NDArrayBool], + ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]: + """Crop 2D bounding boxes. + + Args: + boxes_list (list[NDArrayF32]): The list of bounding boxes to be + cropped. + classes_list (list[NDArrayI64]): The list of the corresponding + classes. + track_ids_list (list[NDArrayI64] | None, optional): The list of + corresponding tracking IDs. Defaults to None. + crop_box_list (list[NDArrayI32]): The list of box to crop. + keep_mask_list (list[NDArrayBool]): Which boxes to keep. + + Returns: + tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64]] | None: + List of cropped bounding boxes according to parameters. + """ + for i, (boxes, classes, crop_box, keep_mask) in enumerate( + zip( + boxes_list, + classes_list, + crop_box_list, + keep_mask_list, + ) + ): + x1, y1 = crop_box[:2] + boxes -= np.array([x1, y1, x1, y1]) + + boxes_list[i] = boxes[keep_mask] + classes_list[i] = classes[keep_mask] + + if track_ids_list is not None: + track_ids_list[i] = track_ids_list[i][keep_mask] + + return boxes_list, classes_list, track_ids_list + + +@Transform([K.seg_masks, "transforms.crop.crop_box"], K.seg_masks) +class CropSegMasks: + """Crop segmentation masks.""" + + def __call__( + self, masks_list: list[NDArrayUI8], crop_box_list: list[NDArrayI32] + ) -> list[NDArrayUI8]: + """Crop masks.""" + for i, (masks, crop_box) in enumerate(zip(masks_list, crop_box_list)): + x1, y1, x2, y2 = crop_box + masks_list[i] = masks[y1:y2, x1:x2] + return masks_list + + +@Transform( + in_keys=[ + K.instance_masks, + "transforms.crop.crop_box", + "transforms.crop.keep_mask", + ], + out_keys=[K.instance_masks], +) +class CropInstanceMasks: + """Crop instance segmentation masks.""" + + def __call__( + self, + masks_list: list[NDArrayUI8], + crop_box_list: list[NDArrayI32], + keep_mask_list: list[NDArrayBool], + ) -> list[NDArrayUI8]: + """Crop masks.""" + for i, (masks, crop_box) in enumerate(zip(masks_list, crop_box_list)): + x1, y1, x2, y2 = crop_box + masks = masks[:, y1:y2, x1:x2] + masks_list[i] = masks[keep_mask_list[i]] + return masks_list + + +@Transform([K.depth_maps, "transforms.crop.crop_box"], K.depth_maps) +class CropDepthMaps: + """Crop depth maps.""" + + def __call__( + self, depth_maps: list[NDArrayF32], crop_box_list: list[NDArrayI32] + ) -> list[NDArrayF32]: + """Crop depth maps.""" + for i, (depth_map, crop_box) in enumerate( + zip(depth_maps, crop_box_list) + ): + x1, y1, x2, y2 = crop_box + depth_maps[i] = depth_map[y1:y2, x1:x2] + return depth_maps + + +@Transform([K.optical_flows, "transforms.crop.crop_box"], K.optical_flows) +class CropOpticalFlows: + """Crop optical flows.""" + + def __call__( + self, optical_flows: list[NDArrayF32], crop_box_list: NDArrayI32 + ) -> list[NDArrayF32]: + """Crop optical flows.""" + for i, (optical_flow, crop_box) in enumerate( + zip(optical_flows, crop_box_list) + ): + x1, y1, x2, y2 = crop_box + optical_flows[i] = optical_flow[y1:y2, x1:x2] + return optical_flows + + +@Transform([K.intrinsics, "transforms.crop.crop_box"], K.intrinsics) +class CropIntrinsics: + """Crop Intrinsics.""" + + def __call__( + self, + intrinsics_list: list[NDArrayF32], + crop_box_list: list[NDArrayI32], + ) -> list[NDArrayF32]: + """Crop camera intrinsics.""" + for i, crop_box in enumerate(crop_box_list): + x1, y1 = crop_box[:2] + intrinsics_list[i][0, 2] -= x1 + intrinsics_list[i][1, 2] -= y1 + return intrinsics_list + + +def _sample_crop( + im_h: int, im_w: int, crop_size: tuple[int, int] +) -> NDArrayI32: + """Sample crop parameters according to config.""" + margin_h = max(im_h - crop_size[0], 0) + margin_w = max(im_w - crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] + return np.array([crop_x1, crop_y1, crop_x2, crop_y2]) + + +def _get_central_crop( + im_h: int, im_w: int, crop_size: tuple[int, int] +) -> NDArrayI32: + """Get central crop parameters.""" + margin_h = max(im_h - crop_size[0], 0) + margin_w = max(im_w - crop_size[1], 0) + offset_h = margin_h // 2 + offset_w = margin_w // 2 + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] + return np.array([crop_x1, crop_y1, crop_x2, crop_y2]) + + +def _get_keep_mask( + boxes: NDArrayF32 | None, crop_box: NDArrayI32 +) -> NDArrayBool: + """Get mask for 2D annotations to keep.""" + if boxes is None or len(boxes) == 0: + return np.array([], dtype=bool) + # will be better to compute mask intersection (if exists) instead + overlap = bbox_intersection( + torch.tensor(boxes), torch.tensor(crop_box).unsqueeze(0) + ).numpy() + return overlap.squeeze(-1) > 0 + + +def _check_seg_max_cat( + masks: NDArrayUI8 | None, + crop_box: NDArrayI32, + cat_max_ratio: float, + ignore_index: int = 255, +) -> bool: + """Check if any category occupies more than cat_max_ratio. + + Args: + masks (NDArrayUI8 | None): Segmentation masks. + crop_box (NDArrayI32): The box to crop. + cat_max_ratio (float): Maximum category ratio. + ignore_index (int, optional): The index to ignore. Defaults to 255. + + Returns: + bool: True if no category occupies more than cat_max_ratio. + """ + if cat_max_ratio >= 1.0 or masks is None: + return True + x1, y1, x2, y2 = crop_box + crop_masks = masks[y1:y2, x1:x2] + cls_ids, cnts = np.unique(crop_masks, return_counts=True) + cnts = cnts[cls_ids != ignore_index] + + return (cnts.max() / cnts.sum()) < cat_max_ratio diff --git a/vis4d/data/transforms/flip.py b/vis4d/data/transforms/flip.py new file mode 100644 index 0000000000000000000000000000000000000000..42a314e17ef82fe208eecb4016356c87b27f83d0 --- /dev/null +++ b/vis4d/data/transforms/flip.py @@ -0,0 +1,359 @@ +"""Horizontal flip augmentation.""" + +import numpy as np +import torch + +from vis4d.common.typing import NDArrayF32, NDArrayUI8 +from vis4d.data.const import AxisMode +from vis4d.data.const import CommonKeys as K +from vis4d.op.geometry.rotation import ( + euler_angles_to_matrix, + matrix_to_euler_angles, + matrix_to_quaternion, + quaternion_to_matrix, +) + +from .base import Transform + + +@Transform(K.images, K.images) +class FlipImages: + """Flip a list of numpy image array of shape [N, H, W, C].""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipImage. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + + Raises: + ValueError: If direction is not horizontal or vertical. + """ + if direction not in ["horizontal", "vertical"]: + raise ValueError(f"Direction {direction} not known!") + self.direction = direction + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Execute flipping op. + + Args: + image (NDArrayF32): [N, H, W, C] array of image. + + Returns: + list[NDArrayF32]: [N, H, W, C] array of flipped image. + """ + for i, image in enumerate(images): + image_ = torch.from_numpy(image) + if self.direction == "horizontal": + images[i] = image_.flip(2).numpy() + if self.direction == "vertical": + images[i] = image_.flip(1).numpy() + return images + + +@Transform(in_keys=(K.boxes2d, K.images), out_keys=(K.boxes2d,)) +class FlipBoxes2D: + """Flip a list of 2D bounding boxes.""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipBoxes2D. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + + Raises: + ValueError: If direction is not horizontal or vertical. + """ + if direction not in ["horizontal", "vertical"]: + raise ValueError(f"Direction {direction} not known!") + self.direction = direction + + def __call__( + self, boxes_list: list[NDArrayF32], images: list[NDArrayF32] + ) -> list[NDArrayF32]: + """Execute flipping op. + + Args: + boxes (list[NDArrayF32]): List of [M, 4] array of boxes. + image (list[NDArrayF32]): List of [N, H, W, C] array of image. + + Returns: + list[NDArrayF32]: List of [M, 4] array of flipped boxes. + """ + for i, (boxes, image) in enumerate(zip(boxes_list, images)): + if self.direction == "horizontal": + im_width = image.shape[2] + tmp = im_width - boxes[..., 2::4] + boxes[..., 2::4] = im_width - boxes[..., 0::4] + boxes[..., 0::4] = tmp + elif self.direction == "vertical": + im_height = image.shape[1] + tmp = im_height - boxes[..., 3::4] + boxes[..., 3::4] = im_height - boxes[..., 1::4] + boxes[..., 1::4] = tmp + boxes_list[i] = boxes + return boxes_list + + +@Transform(K.seg_masks, K.seg_masks) +class FlipSegMasks: + """Flip segmentation masks.""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipSemanticMasks. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + + Raises: + ValueError: If direction is not horizontal or vertical. + """ + if direction not in ["horizontal", "vertical"]: + raise ValueError(f"Direction {direction} not known!") + self.direction = direction + + def __call__(self, masks: list[NDArrayUI8]) -> list[NDArrayUI8]: + """Execute flipping op. + + Args: + masks (NDArrayUI8): [H, W] array of masks. + + Returns: + list[NDArrayUI8]: [H, W] array of flipped masks. + """ + for i, mask in enumerate(masks): + mask_ = torch.from_numpy(mask) + if self.direction == "horizontal": + mask = mask_.flip(1).numpy() + if self.direction == "vertical": + mask = mask_.flip(0).numpy() + masks[i] = mask + return masks + + +@Transform(K.depth_maps, K.depth_maps) +class FlipDepthMaps: + """Flip depth map.""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipDepth. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + """ + self.direction = direction + if direction not in ["horizontal", "vertical"]: + raise ValueError(f"Direction {self.direction} not known!") + + def __call__(self, depths: list[NDArrayF32]) -> list[NDArrayF32]: + """Execute flipping op. + + Args: + depths (list[NDArrayF32]): Each is a [H, W] array of depth. + + Returns: + list[NDArrayF32]: Each is a [H, W] array of flipped depth. + """ + for i, depth in enumerate(depths): + depth_ = torch.from_numpy(depth) + if self.direction == "horizontal": + depths[i] = depth_.flip(1).numpy() + if self.direction == "vertical": + depths[i] = depth_.flip(0).numpy() + + return depths + + +@Transform(K.optical_flows, K.optical_flows) +class FlipOpticalFlows: + """Flip optical flow map.""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipOpticalFlow. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + """ + self.direction = direction + if direction not in ["horizontal", "vertical"]: + raise ValueError(f"Direction {self.direction} not known!") + + def __call__(self, flows: list[NDArrayF32]) -> list[NDArrayF32]: + """Execute flipping op. + + Args: + flows (NDArrayF32): Each is a [H, W, 2] array of optical flow. + + Returns: + list[NDArrayF32]: Each is a [H, W, 2] array of flipped optical + flow. + """ + for i, flow in enumerate(flows): + flow_ = torch.from_numpy(flow) + if self.direction == "horizontal": + image_flipped = flow_.flip(1).numpy() + image_flipped[..., 0] *= -1 + flows[i] = image_flipped + if self.direction == "vertical": + image_flipped = flow_.flip(0).numpy() + image_flipped[..., 1] *= -1 + flows[i] = image_flipped + return flows + + +@Transform(K.instance_masks, K.instance_masks) +class FlipInstanceMasks: + """Flip instance masks.""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipInstanceMasks. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + + Raises: + ValueError: If direction is not horizontal or vertical. + """ + if direction not in ["horizontal", "vertical"]: + raise ValueError(f"Direction {direction} not known!") + self.direction = direction + + def __call__(self, masks: list[NDArrayUI8]) -> list[NDArrayUI8]: + """Execute flipping op. + + Args: + masks (list[NDArrayUI8]): List of [N, H, W] array of masks. + + Returns: + list[NDArrayUI8]: List of [N, H, W] array of flipped masks. + """ + for i, mask in enumerate(masks): + mask_ = torch.from_numpy(mask) + if self.direction == "horizontal": + mask = mask_.flip(2).numpy() + if self.direction == "vertical": + mask = mask_.flip(1).numpy() + masks[i] = mask + return masks + + +def get_axis(direction: str, axis_mode: AxisMode) -> int: + """Get axis number of certain direction given axis mode. + + Args: + direction (str): One of horizontal, vertical and lateral. + axis_mode (AxisMode): axis mode. + + Returns: + int: Number of axis in certain direction. + """ + if direction not in {"horizontal", "lateral", "vertical"}: + raise ValueError(f"Direction {direction} not known!") + coord_mapping = { + AxisMode.ROS: {"horizontal": 0, "lateral": 1, "vertical": 2}, + AxisMode.OPENCV: {"horizontal": 0, "vertical": 1, "lateral": 2}, + } + return coord_mapping[axis_mode][direction] + + +@Transform(in_keys=(K.boxes3d, K.axis_mode), out_keys=(K.boxes3d,)) +class FlipBoxes3D: + """Flip 3D bounding box array.""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipBoxes3D. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + """ + self.direction = direction + + def __call__( + self, boxes_list: list[NDArrayF32], axis_mode_list: list[AxisMode] + ) -> list[NDArrayF32]: + """Execute flipping.""" + for i, (boxes, axis_mode) in enumerate( + zip(boxes_list, axis_mode_list) + ): + axis = get_axis(self.direction, axis_mode) + angle_dir = ( + "vertical" if self.direction == "horizontal" else "lateral" + ) + angles_axis = get_axis(angle_dir, axis_mode) + boxes[:, axis] *= -1.0 + angles = matrix_to_euler_angles( + quaternion_to_matrix(torch.from_numpy(boxes[:, 6:])) + ) + angles[:, angles_axis] = np.pi - angles[:, angles_axis] + boxes[:, 6:] = matrix_to_quaternion( + euler_angles_to_matrix(angles) + ).numpy() + + boxes_list[i] = boxes + + return boxes_list + + +@Transform(in_keys=(K.points3d, K.axis_mode), out_keys=(K.points3d,)) +class FlipPoints3D: + """Flip pointcloud array.""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipBoxes2D. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + """ + self.direction = direction + + def __call__( + self, points3d_list: list[NDArrayF32], axis_mode_list: list[AxisMode] + ) -> list[NDArrayF32]: + """Execute flipping.""" + for i, (points3d, axis_mode) in enumerate( + zip(points3d_list, axis_mode_list) + ): + points3d[:, get_axis(self.direction, axis_mode)] *= -1.0 + points3d_list[i] = points3d + return points3d_list + + +@Transform(in_keys=(K.intrinsics, K.images), out_keys=(K.intrinsics,)) +class FlipIntrinsics: + """Modify intrinsics for image flip.""" + + def __init__(self, direction: str = "horizontal"): + """Creates an instance of FlipIntrinsics. + + Args: + direction (str, optional): Either vertical or horizontal. + Defaults to "horizontal". + + Raises: + ValueError: If direction is not horizontal or vertical. + """ + if direction not in ["horizontal", "vertical"]: + raise ValueError(f"Direction {direction} not known!") + self.direction = direction + + def __call__( + self, intrinsics_list: list[NDArrayF32], images: list[NDArrayF32] + ) -> list[NDArrayF32]: + """Execute flipping.""" + for i, (intrinsics, image) in enumerate(zip(intrinsics_list, images)): + if self.direction == "horizontal": + center = image.shape[2] / 2 + intrinsics[0, 2] = center - intrinsics[0, 2] + center + elif self.direction == "vertical": + center = image.shape[1] / 2 + intrinsics[1, 2] = center - intrinsics[1, 2] + center + intrinsics_list[i] = intrinsics + return intrinsics_list diff --git a/vis4d/data/transforms/mask.py b/vis4d/data/transforms/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b422b6483560b3e8c696b2dcf5cf87c1b2a849 --- /dev/null +++ b/vis4d/data/transforms/mask.py @@ -0,0 +1,74 @@ +"""Segmentation/Instance Mask Transform.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.typing import NDArrayI64, NDArrayUI8 +from vis4d.data.const import CommonKeys as K + +from .base import Transform + + +@Transform( + in_keys=(K.boxes2d_classes, K.instance_masks), + out_keys=K.seg_masks, +) +class ConvertInstanceMaskToSegMask: + """Merge all instance masks into a single segmentation map.""" + + def __call__( + self, classes_list: list[NDArrayI64], masks_list: list[NDArrayUI8] + ) -> list[NDArrayUI8]: + """Execute conversion. + + Args: + classes_list (list[NDArrayI64]): List of Array of class ids, shape + [N,]. + masks_list (NDArrayU8): List of array of instance masks, shape + [N, H, W]. + + Returns: + list[NDArrayU8]: List of Segmentation mask, shape [H, W]. + """ + seg_masks = [] + for classes, masks in zip(classes_list, masks_list): + classes = np.asarray(classes, dtype=masks.dtype) + target = np.max(masks * classes[:, None, None], axis=0) + # discard overlapping instances + target[np.sum(masks, axis=0) > 1] = 255 + + seg_masks.append(target) + return seg_masks + + +@Transform( + in_keys=K.boxes2d_classes, + out_keys=K.boxes2d_classes, +) +class RemappingCategories: + """Remap classes using a mapping list.""" + + def __init__(self, mapping: list[int]): + """Initialize remapping. + + Args: + mapping (List[int]): List of class ids, such that classes will be + mapped to its location in the list. + """ + self.mapping = mapping + + def __call__(self, classes_list: list[NDArrayI64]) -> list[NDArrayI64]: + """Execute remapping. + + Args: + classes_list (list[NDArrayI64]): List of array of class ids, shape + [N,]. + + Returns: + list[NDArrayI64]: List of array of remapped class ids, shape [N,]. + """ + for i, classes in enumerate(classes_list): + for j, class_ in enumerate(classes): + classes_list[i][j] = self.mapping.index(class_) + return classes_list diff --git a/vis4d/data/transforms/mixup.py b/vis4d/data/transforms/mixup.py new file mode 100644 index 0000000000000000000000000000000000000000..584d45462261c90b14a6eaa337ddb077713a41c0 --- /dev/null +++ b/vis4d/data/transforms/mixup.py @@ -0,0 +1,376 @@ +"""Mixup data augmentation.""" + +from __future__ import annotations + +import random +from typing import TypedDict + +import numpy as np +import torch + +from vis4d.common.typing import NDArrayF32, NDArrayI64 +from vis4d.data.const import CommonKeys as K +from vis4d.op.box.box2d import bbox_intersection + +from .base import Transform +from .resize import get_resize_shape, resize_image + + +class MixupParam(TypedDict): + """Typed dict for mixup parameters. + + The parameters are used to mixup a pair of items in a batch. Usually, the + pairs are selected as follows: + (0, bs - 1), (1, bs - 2), ..., (bs // 2, bs // 2) + where bs is the batch size. The batch size must be even for mixup to work. + """ + + ratio: float + im_shape: tuple[int, int] + im_scale: tuple[float, float] + other_ori_hw: tuple[int, int] + other_new_hw: tuple[int, int] + crop_coord: tuple[int, int, int, int] + pad_hw: tuple[int, int] + pad_value: float + + +@Transform(in_keys=(K.images,), out_keys=("transforms.mixup",)) +class GenMixupParameters: + """Generate the parameters for a mixup operation.""" + + NUM_SAMPLES = 2 + + def __init__( + self, + out_shape: tuple[int, int], + mixup_ratio_dist: str = "beta", + alpha: float = 1.0, + const_ratio: float = 0.5, + scale_range: tuple[float, float] = (1.0, 1.0), + pad_value: float = 0.0, + ) -> None: + """Init function. + + Args: + out_shape (tuple[int, int]): Output shape of the mixed up images. + mixup_ratio_dist (str, optional): Distribution for sampling the + mixup ratio (i.e., lambda). Options are "beta" and "const". + Defaults to "beta". If "const", the mixup ratio will be fixed + to the value of `const_ratio`. Otherwise, the mixup ratio will + be sampled from a beta distribution with parameters `alpha`. + alpha (float, optional): Parameter for beta distribution used for + sampling the mixup ratio (i.e., lambda). Defaults to 1.0. + const_ratio (float, optional): Constant mixup ratio. Defaults to + 0.5. + scale_range (tuple[float, float], optional): Range for + random scale jitter. Defaults to (1.0, 1.0). + pad_value (float, optional): Value for padding the mixed up image. + Defaults to 0.0. + """ + assert mixup_ratio_dist in { + "beta", + "const", + }, "Mixup ratio distribution must be either 'beta' or 'const'." + self.out_shape = out_shape + self.mixup_ratio_dist = mixup_ratio_dist + self.alpha = alpha + self.const_ratio = const_ratio + self.scale_range = scale_range + self.pad_value = pad_value + + def __call__(self, images: list[NDArrayF32]) -> list[MixupParam]: + """Generate parameters for MixUp.""" + batch_size = len(images) + assert batch_size % 2 == 0, "MixUp only supports even batch size." + + if self.mixup_ratio_dist == "beta": + ratio = np.random.beta(self.alpha, self.alpha) + else: + ratio = self.const_ratio + jit_factor = random.uniform(*self.scale_range) + + h, w = self.out_shape + ori_img, other_img = images[0], images[1] + ori_h, ori_w = ori_img.shape[1], ori_img.shape[2] + other_ori_h, other_ori_w = other_img.shape[1], other_img.shape[2] + other_ori_hw = (other_ori_h, other_ori_w) + h_i, w_i = get_resize_shape(other_ori_hw, (h, w), keep_ratio=True) + h_i, w_i = int(jit_factor * h_i), int(jit_factor * w_i) + pad_shape = (max(h_i, ori_h), max(w_i, ori_w)) + + x_offset, y_offset = 0, 0 + if pad_shape[0] > ori_h: + y_offset = random.randint(0, pad_shape[0] - ori_h) + if pad_shape[1] > ori_w: + x_offset = random.randint(0, pad_shape[1] - ori_w) + + parameter_list = [ + MixupParam( + ratio=ratio, + im_scale=(h_i / other_ori_h, w_i / other_ori_w), + im_shape=(h_i, w_i), + other_ori_hw=other_ori_hw, + other_new_hw=(min(h_i, ori_h), min(w_i, ori_w)), + pad_hw=pad_shape, + pad_value=self.pad_value, + crop_coord=( + x_offset, + y_offset, + x_offset + ori_w, + y_offset + ori_h, + ), + ) + for _ in range(batch_size) + ] + return parameter_list + + +@Transform(in_keys=(K.images, "transforms.mixup"), out_keys=(K.images,)) +class MixupImages: + """Mixup a batch of images.""" + + NUM_SAMPLES = 2 + + def __init__( + self, interpolation: str = "bilinear", imresize_backend: str = "torch" + ) -> None: + """Init function. + + Args: + interpolation (str, optional): Interpolation method for resizing + the other image. Defaults to "bilinear". + imresize_backend (str): One of torch, cv2. Defaults to torch. + """ + self.interpolation = interpolation + self.imresize_backend = imresize_backend + assert imresize_backend in { + "torch", + "cv2", + }, f"Invalid imresize backend: {imresize_backend}" + + def __call__( + self, images: list[NDArrayF32], mixup_parameters: list[MixupParam] + ) -> list[NDArrayF32]: + """Execute image mixup operation.""" + batch_size = len(images) + assert ( + batch_size % self.NUM_SAMPLES == 0 + ), "Batch size must be even for mixup!" + + mixup_images = [] + for i in range(0, batch_size, self.NUM_SAMPLES): + j = i + 1 + ori_img, other_img = images[i], images[j] + h_i, w_i = mixup_parameters[i]["im_shape"] + c = ori_img.shape[-1] + + # resize, scale jitter other image + other_img = resize_image( + other_img, + (h_i, w_i), + self.interpolation, + backend=self.imresize_backend, + ) + + # pad, optionally random crop other image + padded_img = np.full( + (1, *mixup_parameters[i]["pad_hw"], c), + fill_value=mixup_parameters[i]["pad_value"], + dtype=np.float32, + ) + padded_img[:, :h_i, :w_i, :] = other_img + x1_c, y1_c, x2_c, y2_c = mixup_parameters[i]["crop_coord"] + padded_cropped_img = padded_img[:, y1_c:y2_c, x1_c:x2_c, :] + + # mix ori and other + ratio = mixup_parameters[i]["ratio"] + mixup_image = ratio * ori_img + (1 - ratio) * padded_cropped_img + mixup_images += [mixup_image for _ in range(self.NUM_SAMPLES)] + return mixup_images + + +@Transform( + in_keys=(K.categories, "transforms.mixup"), out_keys=(K.categories,) +) +class MixupCategories: + """Mixup a batch of categories.""" + + NUM_SAMPLES = 2 + + def __init__(self, num_classes: int, label_smoothing: float = 0.1) -> None: + """Creates an instance of MixupCategories. + + Args: + num_classes (int): Number of classes. + label_smoothing (float, optional): Label smoothing parameter for + the mixup of categories. Defaults to 0.1. + """ + self.num_classes = num_classes + self.label_smoothing = label_smoothing + + def _label_smoothing( + self, + cat_1: NDArrayF32, + cat_2: NDArrayF32, + ratio: float, + ) -> NDArrayF32: + """Apply label smoothing to two category labels.""" + lam = np.array(ratio, dtype=np.float32) + off_value = np.array( + self.label_smoothing / self.num_classes, dtype=np.float32 + ) + on_value = np.array( + 1 - self.label_smoothing + off_value, dtype=np.float32 + ) + categories_1: NDArrayF32 = ( + np.zeros((self.num_classes,), dtype=np.float32) + off_value + ) + categories_2: NDArrayF32 = ( + np.zeros((self.num_classes,), dtype=np.float32) + off_value + ) + categories_1 = cat_1 * on_value + categories_2 = cat_2 * on_value + mixed = categories_1 * lam + categories_2 * (1 - lam) + return mixed.astype(np.float32) + + def __call__( + self, + categories: list[NDArrayF32], + mixup_parameters: list[MixupParam], + ) -> list[NDArrayF32]: + """Execute the categories mixup operation.""" + batch_size = len(categories) + assert ( + batch_size % self.NUM_SAMPLES == 0 + ), "Batch size must be even for mixup!" + + smooth_categories = [np.empty(0, dtype=np.float32)] * batch_size + for i in range(0, batch_size, self.NUM_SAMPLES): + j = i + 1 + smooth_categories[i] = self._label_smoothing( + categories[i], categories[j], mixup_parameters[i]["ratio"] + ) + smooth_categories[j] = smooth_categories[i] + return smooth_categories + + +@Transform( + in_keys=( + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + "transforms.mixup", + ), + out_keys=(K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids), +) +class MixupBoxes2D: + """Mixup a batch of boxes.""" + + NUM_SAMPLES = 2 + + def __init__( + self, clip_inside_image: bool = True, max_track_ids: int = 1000 + ) -> None: + """Creates an instance of the class. + + Args: + clip_inside_image (bool): Whether to clip the boxes to be inside + the image. Defaults to True. + max_track_ids (int): The maximum number of track ids. Defaults to + 1000. + """ + self.clip_inside_image = clip_inside_image + self.max_track_ids = max_track_ids + + def __call__( + self, + boxes_list: list[NDArrayF32], + classes_list: list[NDArrayI64], + track_ids_list: list[NDArrayI64] | None, + mixup_parameters: list[MixupParam], + ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]: + """Execute the boxes2d mixup operation.""" + batch_size = len(boxes_list) + assert ( + batch_size % self.NUM_SAMPLES == 0 + ), "Batch size must be even for mixup!" + + mixup_boxes_list = [] + mixup_classes_list = [] + mixup_track_ids_list: list[NDArrayI64] | None = ( + [] if track_ids_list is not None else None + ) + for i in range(0, batch_size, self.NUM_SAMPLES): + j = i + 1 + ori_boxes, other_boxes = boxes_list[i].copy(), boxes_list[j].copy() + ori_classes, other_classes = ( + classes_list[i].copy(), + classes_list[j].copy(), + ) + + crop_coord = mixup_parameters[i]["crop_coord"] + im_scale = mixup_parameters[i]["im_scale"] + x1_c, y1_c, _, _ = crop_coord + + if len(other_boxes) == 0: + continue + # adjust boxes to new image size and origin coord + other_boxes[:, [0, 2]] = ( + im_scale[1] * other_boxes[:, [0, 2]] - x1_c + ) + other_boxes[:, [1, 3]] = ( + im_scale[0] * other_boxes[:, [1, 3]] - y1_c + ) + # filter boxes outside other image + crop_box = torch.tensor(crop_coord).unsqueeze(0) + is_overlap = ( + bbox_intersection(torch.from_numpy(other_boxes), crop_box) + .squeeze(-1) + .numpy() + ) + other_boxes = other_boxes[is_overlap > 0] + other_classes = other_classes[is_overlap > 0] + + # mixup track ids if available + if track_ids_list is not None: + assert mixup_track_ids_list is not None + ori_track_ids = track_ids_list[i].copy() + other_track_ids = track_ids_list[j].copy() + if ( + len(ori_track_ids) > 0 + and max(ori_track_ids) >= self.max_track_ids + ) or ( + len(other_track_ids) > 0 + and max(other_track_ids) >= self.max_track_ids + ): + raise ValueError( + f"Track id exceeds maximum track id" + f"{self.max_track_ids}!" + ) + other_track_ids += self.max_track_ids + other_track_ids = other_track_ids[is_overlap > 0] + mixup_track_ids: NDArrayI64 = np.concatenate( + (ori_track_ids, other_track_ids), 0 + ) + mixup_track_ids_list += [ + mixup_track_ids for _ in range(self.NUM_SAMPLES) + ] + + if self.clip_inside_image: + new_h, new_w = mixup_parameters[i]["other_new_hw"] + other_boxes[:, [0, 2]] = np.clip( + other_boxes[:, [0, 2]], 0, new_w + ) + other_boxes[:, [1, 3]] = np.clip( + other_boxes[:, [1, 3]], 0, new_h + ) + mixup_boxes = np.concatenate((ori_boxes, other_boxes), axis=0) + mixup_classes = np.concatenate( + (ori_classes, other_classes), axis=0 + ) + mixup_boxes_list += [mixup_boxes for _ in range(self.NUM_SAMPLES)] + mixup_classes_list += [ + mixup_classes for _ in range(self.NUM_SAMPLES) + ] + return mixup_boxes_list, mixup_classes_list, mixup_track_ids_list diff --git a/vis4d/data/transforms/mosaic.py b/vis4d/data/transforms/mosaic.py new file mode 100644 index 0000000000000000000000000000000000000000..60845b868c4408c7ecc48516006709a3fbf1bbd0 --- /dev/null +++ b/vis4d/data/transforms/mosaic.py @@ -0,0 +1,358 @@ +"""Mosaic transformation. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import random +from typing import TypedDict + +import numpy as np + +from vis4d.common.typing import NDArrayF32, NDArrayI64 +from vis4d.data.const import CommonKeys as K + +from .base import Transform +from .crop import _get_keep_mask +from .resize import resize_image + + +class MosaicParam(TypedDict): + """Parameters for Mosaic.""" + + out_shape: tuple[int, int] + paste_coords: list[tuple[int, int, int, int]] + crop_coords: list[tuple[int, int, int, int]] + im_shapes: list[tuple[int, int]] + im_scales: list[tuple[float, float]] + + +def mosaic_combine( + index: int, + center: tuple[int, int], + im_hw: tuple[int, int], + out_shape: tuple[int, int], +) -> tuple[tuple[int, int, int, int], tuple[int, int, int, int]]: + """Compute the mosaic parameters for the image at the current index. + + Index: + 0 = top_left, 1 = top_right, 3 = bottom_left, 4 = bottom_right + """ + assert index in {0, 1, 2, 3} + if index == 0: + # index0 to top left part of image + x1, y1, x2, y2 = ( + max(center[1] - im_hw[1], 0), + max(center[0] - im_hw[0], 0), + center[1], + center[0], + ) + crop_coord = ( + im_hw[1] - (x2 - x1), + im_hw[0] - (y2 - y1), + im_hw[1], + im_hw[0], + ) + elif index == 1: + # index1 to top right part of image + x1, y1, x2, y2 = ( + center[1], + max(center[0] - im_hw[0], 0), + min(center[1] + im_hw[1], out_shape[1] * 2), + center[0], + ) + crop_coord = ( + 0, + im_hw[0] - (y2 - y1), + min(im_hw[1], x2 - x1), + im_hw[0], + ) + elif index == 2: + # index2 to bottom left part of image + x1, y1, x2, y2 = ( + max(center[1] - im_hw[1], 0), + center[0], + center[1], + min(out_shape[0] * 2, center[0] + im_hw[0]), + ) + crop_coord = ( + im_hw[1] - (x2 - x1), + 0, + im_hw[1], + min(y2 - y1, im_hw[0]), + ) + else: + # index3 to bottom right part of image + x1, y1, x2, y2 = ( + center[1], + center[0], + min(center[1] + im_hw[1], out_shape[1] * 2), + min(out_shape[0] * 2, center[0] + im_hw[0]), + ) + crop_coord = 0, 0, min(im_hw[1], x2 - x1), min(y2 - y1, im_hw[0]) + + paste_coord = x1, y1, x2, y2 + return paste_coord, crop_coord + + +@Transform(K.input_hw, ["transforms.mosaic"]) +class GenMosaicParameters: + """Generate the parameters for a mosaic operation. + + Given 4 images, mosaic transform combines them into + one output image. The output image is composed of the parts from each sub- + image. + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The mosaic transform steps are as follows: + + 1. Choose the mosaic center as the intersections of 4 images. + 2. Get the left top image according to the index, and randomly + sample another 3 images from the dataset. + 3. Sub image will be cropped if image is larger than mosaic patch. + + Args: + out_shape (tuple[int, int]): The output shape of the mosaic transform. + center_ratio_range (tuple[float, float]): The range of the ratio of + the center of the mosaic patch to the output image size. + """ + + NUM_SAMPLES = 4 + + def __init__( + self, + out_shape: tuple[int, int], + center_ratio_range: tuple[float, float] = (0.5, 1.5), + ) -> None: + """Creates an instance of the class.""" + self.out_shape = out_shape + self.center_ratio_range = center_ratio_range + + def __call__(self, input_hw: list[tuple[int, int]]) -> list[MosaicParam]: + """Compute the parameters and put them in the data dict.""" + assert ( + len(input_hw) % self.NUM_SAMPLES == 0 + ), "Input number of images must be a multiple of 4 for Mosaic." + h, w = self.out_shape + # mosaic center x, y + center_y = int(random.uniform(*self.center_ratio_range) * h) + center_x = int(random.uniform(*self.center_ratio_range) * w) + center = (center_y, center_x) + + mosaic_params = [] + for i in range(0, len(input_hw), self.NUM_SAMPLES): + paste_coords, crop_coords, im_scales, im_shapes = [], [], [], [] + for idx, ori_hw in enumerate(input_hw[i : i + self.NUM_SAMPLES]): + # compute the resize shape + scale_ratio_i = min(h / ori_hw[0], w / ori_hw[1]) + h_i = int(ori_hw[0] * scale_ratio_i) + w_i = int(ori_hw[1] * scale_ratio_i) + + # compute the combine parameters + paste_coord, crop_coord = mosaic_combine( + idx, center, (h_i, w_i), self.out_shape + ) + paste_coords.append(paste_coord) + crop_coords.append(crop_coord) + im_shapes.append((h_i, w_i)) + im_scales.append((scale_ratio_i, scale_ratio_i)) + mosaic_params += [ + MosaicParam( + out_shape=self.out_shape, + paste_coords=paste_coords, + crop_coords=crop_coords, + im_shapes=im_shapes, + im_scales=im_scales, + ) + for _ in range(self.NUM_SAMPLES) + ] + + return mosaic_params + + +@Transform( + in_keys=[ + K.images, + "transforms.mosaic.out_shape", + "transforms.mosaic.paste_coords", + "transforms.mosaic.crop_coords", + "transforms.mosaic.im_shapes", + ], + out_keys=[K.images, K.input_hw], +) +class MosaicImages: + """Apply Mosaic to images.""" + + NUM_SAMPLES = 4 + + def __init__( + self, + pad_value: float = 114.0, + interpolation: str = "bilinear", + imresize_backend: str = "torch", + ) -> None: + """Creates an instance of the class. + + Args: + pad_value (float): The value to pad the image with. Defaults to + 114.0. + interpolation (str): Interpolation mode for resizing image. + Defaults to bilinear. + imresize_backend (str): One of torch, cv2. Defaults to torch. + """ + self.pad_value = pad_value + self.interpolation = interpolation + self.imresize_backend = imresize_backend + assert imresize_backend in { + "torch", + "cv2", + }, f"Invalid imresize backend: {imresize_backend}" + + def __call__( + self, + images: list[NDArrayF32], + out_shape: list[tuple[int, int]], + paste_coords: list[list[tuple[int, int, int, int]]], + crop_coords: list[list[tuple[int, int, int, int]]], + im_shapes: list[list[tuple[int, int]]], + ) -> tuple[list[NDArrayF32], list[tuple[int, int]]]: + """Resize an image of dimensions [N, H, W, C].""" + h, w = out_shape[0] + c = images[0].shape[-1] + + mosaic_imgs = [] + for i in range(0, len(images), self.NUM_SAMPLES): + mosaic_img = np.full( + (1, h * 2, w * 2, c), self.pad_value, dtype=np.float32 + ) + for idx, img in enumerate(images[i : i + self.NUM_SAMPLES]): + # resize current image + h_i, w_i = im_shapes[i][idx] + img_ = resize_image( + img, + (h_i, w_i), + self.interpolation, + backend=self.imresize_backend, + ) + + x1_p, y1_p, x2_p, y2_p = paste_coords[i][idx] + x1_c, y1_c, x2_c, y2_c = crop_coords[i][idx] + + # crop and paste image + mosaic_img[:, y1_p:y2_p, x1_p:x2_p, :] = img_[ + :, y1_c:y2_c, x1_c:x2_c, : + ] + mosaic_imgs += [mosaic_img for _ in range(self.NUM_SAMPLES)] + return mosaic_imgs, [(m.shape[1], m.shape[2]) for m in mosaic_imgs] + + +@Transform( + in_keys=[ + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + "transforms.mosaic.paste_coords", + "transforms.mosaic.crop_coords", + "transforms.mosaic.im_scales", + ], + out_keys=[K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids], +) +class MosaicBoxes2D: + """Apply Mosaic to a list of 2D bounding boxes.""" + + NUM_SAMPLES = 4 + + def __init__( + self, clip_inside_image: bool = True, max_track_ids: int = 1000 + ) -> None: + """Creates an instance of the class. + + Args: + clip_inside_image (bool): Whether to clip the boxes to be inside + the image. Defaults to True. + max_track_ids (int): The maximum number of track ids. Defaults to + 1000. + """ + self.clip_inside_image = clip_inside_image + self.max_track_ids = max_track_ids + + def __call__( + self, + boxes: list[NDArrayF32], + classes: list[NDArrayI64], + track_ids: list[NDArrayI64] | None, + paste_coords: list[list[tuple[int, int, int, int]]], + crop_coords: list[list[tuple[int, int, int, int]]], + im_scales: list[list[tuple[float, float]]], + ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]: + """Apply Mosaic to 2D bounding boxes.""" + new_boxes, new_classes = [], [] + new_track_ids: list[NDArrayI64] | None = ( + [] if track_ids is not None else None + ) + for i in range(0, len(boxes), self.NUM_SAMPLES): + for idx in range(self.NUM_SAMPLES): + j = i + idx + + x1_p, y1_p, x2_p, y2_p = paste_coords[i][idx] + x1_c, y1_c, _, _ = crop_coords[i][idx] + + pw = x1_p - x1_c + ph = y1_p - y1_c + boxes[j][:, [0, 2]] = ( + im_scales[i][idx][1] * boxes[j][:, [0, 2]] + pw + ) + boxes[j][:, [1, 3]] = ( + im_scales[i][idx][0] * boxes[j][:, [1, 3]] + ph + ) + + keep_mask = _get_keep_mask( + boxes[j], np.array([x1_p, y1_p, x2_p, y2_p]) + ) + boxes[j] = boxes[j][keep_mask] + classes[j] = classes[j][keep_mask] + if track_ids is not None: + track_ids[j] = track_ids[j][keep_mask].copy() + if len(track_ids[j]) > 0: + if max(track_ids[j]) >= self.max_track_ids: + raise ValueError( + f"Track id exceeds maximum track id" + f"{self.max_track_ids}!" + ) + track_ids[j] += self.max_track_ids * idx + + if self.clip_inside_image: + boxes[j][:, [0, 2]] = boxes[j][:, [0, 2]].clip(x1_p, x2_p) + boxes[j][:, [1, 3]] = boxes[j][:, [1, 3]].clip(y1_p, y2_p) + new_boxes += [ + np.concatenate(boxes[i : i + self.NUM_SAMPLES]) + for _ in range(self.NUM_SAMPLES) + ] + new_classes += [ + np.concatenate(classes[i : i + self.NUM_SAMPLES]) + for _ in range(self.NUM_SAMPLES) + ] + if track_ids is not None: + assert new_track_ids is not None + new_track_ids += [ + np.concatenate(track_ids[i : i + self.NUM_SAMPLES]) + for _ in range(self.NUM_SAMPLES) + ] + return new_boxes, new_classes, new_track_ids diff --git a/vis4d/data/transforms/normalize.py b/vis4d/data/transforms/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..79d3845def8b6ab103550794c6cb3cfabee8fd2e --- /dev/null +++ b/vis4d/data/transforms/normalize.py @@ -0,0 +1,50 @@ +"""Normalize Transform.""" + +from __future__ import annotations + +import torch + +from vis4d.common.typing import NDArrayF32 + +from ..const import CommonKeys as K +from .base import Transform + + +@Transform(K.images, K.images) +class NormalizeImages: + """Normalize a list of image tensor with given mean and std. + + Image tensor is of shape [N, H, W, C] and range (0, 255). + """ + + def __init__( + self, + mean: tuple[float, float, float] = (123.675, 116.28, 103.53), + std: tuple[float, float, float] = (58.395, 57.12, 57.375), + epsilon: float = 1e-08, + ) -> None: + """Creates an instance of NormalizeImage. + + Args: + mean (Tuple[float, float, float], optional): Mean value. Defaults + to (123.675, 116.28, 103.53). + std (Tuple[float, float, float], optional): Standard deviation + value. Defaults to (58.395, 57.12, 57.375). + epsilon (float, optional): Epsilon for numerical stability of + division. Defaults to 1e-08. + """ + self.mean = mean + self.std = std + self.epsilon = epsilon + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Normalize image tensor.""" + for i, image in enumerate(images): + img = torch.from_numpy(image).permute(0, 3, 1, 2) + pixel_mean = torch.tensor(self.mean).view(-1, 1, 1) + pixel_std = torch.tensor(self.std).view(-1, 1, 1) + img = (img - pixel_mean) / (pixel_std + self.epsilon) + + images[i] = img.permute(0, 2, 3, 1).numpy() + + return images diff --git a/vis4d/data/transforms/pad.py b/vis4d/data/transforms/pad.py new file mode 100644 index 0000000000000000000000000000000000000000..bda0f4c99341903adaa71c591ed65bfe8c52abf7 --- /dev/null +++ b/vis4d/data/transforms/pad.py @@ -0,0 +1,155 @@ +"""Pad transformation.""" + +from __future__ import annotations + +import numpy as np +import torch +import torch.nn.functional as F + +from vis4d.common.typing import NDArrayF32, NDArrayUI8 +from vis4d.data.const import CommonKeys as K + +from .base import Transform + + +@Transform(K.images, K.images) +class PadImages: + """Pad batch of images at the bottom right.""" + + def __init__( + self, + stride: int = 32, + mode: str = "constant", + value: float = 0.0, + shape: tuple[int, int] | None = None, + pad2square: bool = False, + ) -> None: + """Creates an instance of PadImage. + + Args: + stride (int, optional): Chooses padding size so that the input will + be divisible by stride. Defaults to 32. + mode (str, optional): Padding mode. One of constant, reflect, + replicate or circular. Defaults to "constant". + value (float, optional): Value for constant padding. + Defaults to 0.0. + shape (tuple[int, int], optional): Shape of the padded image + (H, W). Defaults to None. + pad2square (bool, optional): Pad to square. Defaults to False. + """ + if pad2square: + assert ( + shape is None + ), "Cannot specify shape when pad2square is True." + self.stride = stride + self.mode = mode + self.value = value + self.shape = shape + self.pad2square = pad2square + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Pad images to consistent size.""" + heights = [im.shape[1] for im in images] + widths = [im.shape[2] for im in images] + max_hw = _get_max_shape( + heights, widths, self.stride, self.shape, self.pad2square + ) + + # generate params for torch pad + for i, (image, h, w) in enumerate(zip(images, heights, widths)): + pad_param = (0, max_hw[1] - w, 0, max_hw[0] - h) + image_ = torch.from_numpy(image).permute(0, 3, 1, 2) + image_ = F.pad( # pylint: disable=not-callable + image_, pad_param, self.mode, self.value + ) + images[i] = image_.permute(0, 2, 3, 1).numpy() + return images + + +@Transform(K.seg_masks, K.seg_masks) +class PadSegMasks: + """Pad batch of segmentation masks at the bottom right.""" + + def __init__( + self, + stride: int = 32, + mode: str = "constant", + value: int = 255, + shape: tuple[int, int] | None = None, + pad2square: bool = False, + ) -> None: + """Creates an instance of PadSegMasks. + + Args: + stride (int, optional): Chooses padding size so that the input will + be divisible by stride. Defaults to 32. + mode (str, optional): Padding mode. One of constant, reflect, + replicate or circular. Defaults to "constant". + value (float, optional): Value for constant padding. + Defaults to 0.0. + shape (tuple[int, int], optional): Shape of the padded image + (H, W). Defaults to None. + pad2square (bool, optional): Pad to square. Defaults to False. + """ + if pad2square: + assert ( + shape is None + ), "Cannot specify shape when pad2square is True." + self.stride = stride + self.mode = mode + self.value = value + self.shape = shape + self.pad2square = pad2square + + def __call__(self, masks: list[NDArrayUI8]) -> list[NDArrayUI8]: + """Pad images to consistent size.""" + heights = [mask.shape[0] for mask in masks] + widths = [mask.shape[1] for mask in masks] + max_hw = _get_max_shape( + heights, widths, self.stride, self.shape, self.pad2square + ) + + # generate params for torch pad + for i, (mask, h, w) in enumerate(zip(masks, heights, widths)): + pad_param = ((0, max_hw[0] - h), (0, max_hw[1] - w)) + masks[i] = np.pad( # type: ignore + mask, pad_param, mode=self.mode, constant_values=self.value + ) + return masks + + +def _get_max_shape( + heights: list[int], + widths: list[int], + stride: int, + shape: tuple[int, int] | None, + pad2square: bool, +) -> tuple[int, int]: + """Get max shape for padding. + + Args: + stride (int): Chooses padding size so that the input will be divisible + by stride. + shape (tuple[int, int], optional): Shape of the padded image (H, W). + Defaults to None. + heights (list[int]): List of heights of input. + widths (list[int]): List of widths of input. + pad2square (bool): Pad to square. + + Returns: + tuple[int, int]: Max shape for padding. + """ + if pad2square: + max_size = max(heights + widths) + max_hw = (max_size, max_size) + elif shape is not None: + max_hw = shape + else: + max_hw = max(heights), max(widths) + max_hw = tuple(_make_divisible(x, stride) for x in max_hw) # type: ignore # pylint: disable=line-too-long + return max_hw + + +def _make_divisible(x: int, stride: int) -> int: + """Ensure divisibility by stride.""" + return (x + (stride - 1)) // stride * stride diff --git a/vis4d/data/transforms/photometric.py b/vis4d/data/transforms/photometric.py new file mode 100644 index 0000000000000000000000000000000000000000..a0abd454e25e1cea76b6b1dcf297e17c81f78947 --- /dev/null +++ b/vis4d/data/transforms/photometric.py @@ -0,0 +1,355 @@ +"""Photometric transforms.""" + +from __future__ import annotations + +from collections.abc import Callable + +import numpy as np +import torch +import torchvision.transforms.v2.functional as TF +from torch import Tensor + +from vis4d.common.imports import OPENCV_AVAILABLE +from vis4d.common.typing import NDArrayF32 +from vis4d.data.const import CommonKeys as K + +from .base import Transform + +if OPENCV_AVAILABLE: + import cv2 +else: + raise ImportError("cv2 is not installed.") + + +@Transform(K.images, K.images) +class RandomGamma: + """Apply Gamma transformation to images. + + Args: + gamma_range (tuple[float, float]): Range of gamma values. + image_channel_mode (str, optional): Image channel mode. Defaults to + "RGB". + """ + + def __init__( + self, + gamma_range: tuple[float, float] = (1.0, 1.0), + image_channel_mode: str = "RGB", + ) -> None: + """Init function for Gamma.""" + self.gamma_range = gamma_range + self.image_channel_mode = image_channel_mode + assert image_channel_mode in {"RGB", "BGR"}, ( + "image_channel_mode should be 'RGB' or 'BGR', " + f"got {image_channel_mode}." + ) + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Call function for Gamma transformation.""" + factor = np.random.uniform(self.gamma_range[0], self.gamma_range[1]) + return _adjust_images( + images, TF.adjust_gamma, factor, self.image_channel_mode + ) + + +@Transform(K.images, K.images) +class RandomBrightness: + """Apply Brightness transformation to images. + + Args: + brightness_range (tuple[float, float]): Range of brightness values. + image_channel_mode (str, optional): Image channel mode. Defaults to + "RGB". + """ + + def __init__( + self, + brightness_range: tuple[float, float] = (1.0, 1.0), + image_channel_mode: str = "RGB", + ) -> None: + """Init function for Brightness.""" + self.brightness_range = brightness_range + self.image_channel_mode = image_channel_mode + assert image_channel_mode in {"RGB", "BGR"}, ( + "image_channel_mode should be 'RGB' or 'BGR', " + f"got {image_channel_mode}." + ) + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Call function for Brightness transformation.""" + factor = np.random.uniform( + self.brightness_range[0], self.brightness_range[1] + ) + return _adjust_images( + images, TF.adjust_brightness, factor, self.image_channel_mode + ) + + +@Transform(K.images, K.images) +class RandomContrast: + """Apply Contrast transformation to images. + + Args: + contrast_range (tuple[float, float]): Range of contrast values. + image_channel_mode (str, optional): Image channel mode. Defaults to + "RGB". + """ + + def __init__( + self, + contrast_range: tuple[float, float] = (1.0, 1.0), + image_channel_mode: str = "RGB", + ): + """Init function for Contrast.""" + self.contrast_range = contrast_range + self.image_channel_mode = image_channel_mode + assert image_channel_mode in {"RGB", "BGR"}, ( + "image_channel_mode should be 'RGB' or 'BGR', " + f"got {image_channel_mode}." + ) + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Call function for Contrast transformation.""" + factor = np.random.uniform( + self.contrast_range[0], self.contrast_range[1] + ) + return _adjust_images( + images, TF.adjust_contrast, factor, self.image_channel_mode + ) + + +@Transform(K.images, K.images) +class RandomSaturation: + """Apply saturation transformation to images. + + Args: + saturation_range (tuple[float, float]): Range of saturation values. + image_channel_mode (str, optional): Image channel mode. Defaults to + "RGB". + """ + + def __init__( + self, + saturation_range: tuple[float, float] = (1.0, 1.0), + image_channel_mode: str = "RGB", + ): + """Init function for saturation.""" + self.saturation_range = saturation_range + self.image_channel_mode = image_channel_mode + assert image_channel_mode in {"RGB", "BGR"}, ( + "image_channel_mode should be 'RGB' or 'BGR', " + f"got {image_channel_mode}." + ) + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Call function for saturation transformation.""" + factor = np.random.uniform( + self.saturation_range[0], self.saturation_range[1] + ) + return _adjust_images( + images, TF.adjust_saturation, factor, self.image_channel_mode + ) + + +@Transform(K.images, K.images) +class RandomHue: + """Apply hue transformation to images. + + Args: + hue_range (tuple[float, float]): Range of hue values. + image_channel_mode (str, optional): Image channel mode. Defaults to + "RGB". + """ + + def __init__( + self, + hue_range: tuple[float, float] = (0.0, 0.0), + image_channel_mode: str = "RGB", + ): + """Init function for hue.""" + self.hue_range = hue_range + self.image_channel_mode = image_channel_mode + assert image_channel_mode in {"RGB", "BGR"}, ( + "image_channel_mode should be 'RGB' or 'BGR', " + f"got {image_channel_mode}." + ) + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Call function for Hue transformation.""" + factor = np.random.uniform(self.hue_range[0], self.hue_range[1]) + return _adjust_images( + images, TF.adjust_hue, factor, self.image_channel_mode + ) + + +@Transform(K.images, K.images) +class ColorJitter: + """Apply color jitter to images. + + Args: + brightness_range (tuple[float, float]): Range of brightness values. + contrast_range (tuple[float, float]): Range of contrast values. + saturation_range (tuple[float, float]): Range of saturation values. + hue_range (tuple[float, float]): Range of hue values. + image_channel_mode (str, optional): Image channel mode. Defaults to + "RGB". + """ + + def __init__( + self, + brightness_range: tuple[float, float] = (0.875, 1.125), + contrast_range: tuple[float, float] = (0.5, 1.5), + saturation_range: tuple[float, float] = (0.5, 1.5), + hue_range: tuple[float, float] = (-0.05, 0.05), + image_channel_mode: str = "RGB", + ): + """Init function for color jitter.""" + self.brightness_range = brightness_range + self.contrast_range = contrast_range + self.saturation_range = saturation_range + self.hue_range = hue_range + self.image_channel_mode = image_channel_mode + assert image_channel_mode in {"RGB", "BGR"}, ( + "image_channel_mode should be 'RGB' or 'BGR', " + f"got {image_channel_mode}." + ) + + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Call function for Hue transformation.""" + transform_order = np.random.permutation(4) + for transform in transform_order: + # apply photometric transforms in a random order + if transform == 0: + # random brightness + bfactor = np.random.uniform( + self.brightness_range[0], self.brightness_range[1] + ) + images = _adjust_images( + images, + TF.adjust_brightness, + bfactor, + self.image_channel_mode, + ) + elif transform == 1: + # random contrast + cfactor = np.random.uniform( + self.contrast_range[0], self.contrast_range[1] + ) + images = _adjust_images( + images, + TF.adjust_contrast, + cfactor, + self.image_channel_mode, + ) + elif transform == 2: + # random saturation + sfactor = np.random.uniform( + self.saturation_range[0], self.saturation_range[1] + ) + images = _adjust_images( + images, + TF.adjust_saturation, + sfactor, + self.image_channel_mode, + ) + elif transform == 3: + # random hue + hfactor = np.random.uniform( + self.hue_range[0], self.hue_range[1] + ) + images = _adjust_images( + images, TF.adjust_hue, hfactor, self.image_channel_mode + ) + return images + + +def _adjust_images( + images: list[NDArrayF32], + adjust_func: Callable[[Tensor, float], Tensor], + adj_factor: float, + image_channel_mode: str = "RGB", +) -> list[NDArrayF32]: + """Apply color transformation to images. + + Args: + images (list[NDArrayF32]): Image to be transformed. + adjust_func (Callable[[Tensor, float], Tensor]): Function to apply. + adj_factor (float): Adjustment factor. + image_channel_mode (str, optional): Image channel mode. Defaults to + "RGB". + + Returns: + list[NDArrayF32]: Transformed image. + """ + for i, image in enumerate(images): + if image_channel_mode == "BGR": + image = image[..., [2, 1, 0]] # convert to RGB + image_ = torch.from_numpy(image).permute(0, 3, 1, 2) / 255.0 + image_ = adjust_func(image_, adj_factor) * 255.0 + images[i] = image_.permute(0, 2, 3, 1).numpy() + if image_channel_mode == "BGR": + images[i] = images[i][..., [2, 1, 0]] # convert back to BGR + return images + + +@Transform(K.images, K.images) +class RandomHSV: + """Apply HSV transformation to images. + + Used by YOLOX. Modifed from: https://github.com/Megvii-BaseDetection/YOLOX. + + Args: + hue_delta (int): Delta for hue. + saturation_delta (int): Delta for saturation. + value_delta (int): Delta for value. + image_channel_mode (str, optional): Image channel mode. Defaults to + "BGR". + """ + + def __init__( + self, + hue_delta: int = 5, + saturation_delta: int = 30, + value_delta: int = 30, + image_channel_mode: str = "BGR", + ): + """Init function for HSV transformation.""" + assert OPENCV_AVAILABLE, "RandomHSV requires OpenCV to be installed." + self.hue_delta = hue_delta + self.saturation_delta = saturation_delta + self.value_delta = value_delta + self.image_channel_mode = image_channel_mode + assert image_channel_mode in {"RGB", "BGR"}, ( + "image_channel_mode should be 'RGB' or 'BGR', " + f"got {image_channel_mode}." + ) + + # pylint: disable=no-member + def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]: + """Call function for Hue transformation.""" + for i, image in enumerate(images): + image = image[0].astype(np.uint8) + if self.image_channel_mode == "BGR": + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + else: + image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + image = image.astype(np.int16) + hsv_gains = np.random.uniform(-1, 1, 3) * [ + self.hue_delta, + self.saturation_delta, + self.value_delta, + ] + # random selection of h, s, v + hsv_gains = (hsv_gains * np.random.randint(0, 2, 3)).astype( + np.int16 + ) + image[..., 0] = (image[..., 0] + hsv_gains[0]) % 180 + image[..., 1] = np.clip(image[..., 1] + hsv_gains[1], 0, 255) + image[..., 2] = np.clip(image[..., 2] + hsv_gains[2], 0, 255) + image = image.astype(np.uint8) + if self.image_channel_mode == "BGR": + cv2.cvtColor(image, cv2.COLOR_HSV2BGR, dst=image) + else: + cv2.cvtColor(image, cv2.COLOR_HSV2RGB, dst=image) + images[i] = image[None, ...].astype(np.float32) + return images diff --git a/vis4d/data/transforms/point_sampling.py b/vis4d/data/transforms/point_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d83f39d51f3306ac3adb1841ebc76ae6dc0e8723 --- /dev/null +++ b/vis4d/data/transforms/point_sampling.py @@ -0,0 +1,253 @@ +"""Contains different Sampling Trasnforms for pointclouds.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.typing import NDArrayInt, NDArrayNumber +from vis4d.data.const import CommonKeys as K + +from .base import Transform + + +@Transform(K.points3d, "transforms.sampling_idxs") +class GenerateSamplingIndices: + """Samples num_pts from the first dim of the provided data tensor. + + If num_pts > data.shape[0], the indices will be upsampled with + replacement. If num_pts < data.shape[0], the indices will be sampled + without replacement. + """ + + def __init__(self, num_pts: int) -> None: + """Creates an instance of the class. + + Args: + num_pts (int): Number of indices to sample + """ + self.num_pts = num_pts + + def __call__(self, data_list: list[NDArrayNumber]) -> list[NDArrayInt]: + """Samples num_pts from the first dim of the provided data tensor. + + If num_pts > data.shape[0], the indices will be upsampled with + replacement. If num_pts < data.shape[0], the indices will be sampled + without replacement. + + Args: + data_list (list[NDArrayNumber]): Data from which to sample indices. + + Returns: + list[NDArrayInt]: List of indices. + + Raises: + ValueError: If data is empty. + """ + data = data_list[0] + + if len(data) == 0: + raise ValueError("Data sample was empty!") + + if self.num_pts > len(data): + return [ + np.concatenate( + [ + np.arange(len(data)), + np.random.randint( + 0, len(data), self.num_pts - len(data) + ), + ] + ) + ] * len(data_list) + return [ + np.random.choice(len(data), self.num_pts, replace=False) + ] * len(data_list) + + +@Transform(K.points3d, "transforms.sampling_idxs") +class GenerateBlockSamplingIndices: + """Samples num_pts from the first dim of the provided data tensor. + + Makes sure that the sampled points are within a block of size block_size + centered around center_xyz. If num_pts > data.shape[0], the indices will + be upsampled with replacement. If num_pts < data.shape[0], the indices + will be sampled without replacement. + """ + + def __init__( + self, + num_pts: int, + block_dimensions: tuple[float, float, float], + center_point: tuple[float, float, float] | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + num_pts (int): Number of indices to sample + block_dimensions (tuple[float, float, float]): Dimensions of the + block in x,y,z + center_point (tuple[float, float, float] | None): Center point of + the block in x,y,z. If None, the center will be sampled + randomly. + """ + self.block_dimensions = np.asarray(block_dimensions) + self.center_point = ( + np.asarray(center_point) if center_point is not None else None + ) + + self._idx_sampler = GenerateSamplingIndices(num_pts) + + def __call__(self, data_list: list[NDArrayNumber]) -> list[NDArrayInt]: + """Samples num_pts from the first dim of the provided data tensor.""" + data = data_list[0] + + if self.center_point is None: + center_point = data[np.random.choice(len(data), 1)] + else: + center_point = self.center_point + + max_box = center_point + self.block_dimensions / 2.0 + min_box = center_point - self.block_dimensions / 2.0 + + box_mask = np.logical_and( + np.all(data >= min_box, axis=1), + np.all(data <= max_box, axis=1), + ) + if box_mask.sum().item() == 0: # No valid data sample found! + return [np.array([], dtype=np.int32)] * len(data_list) + + idxs = self._idx_sampler([data[box_mask, ...]])[0] + + masked_idxs = np.arange(data.shape[0])[box_mask] + selected_idxs_global = masked_idxs[idxs] + return [selected_idxs_global] * len(data_list) + + +@Transform(K.points3d, "transforms.sampling_idxs") +class GenFullCovBlockSamplingIndices: + """Subsamples the pointcloud using blocks of a given size.""" + + def __init__( + self, + num_pts: int, + block_dimensions: tuple[float, float, float], + min_pts: int = 32, + ) -> None: + """Creates an instance of the class. + + Args: + num_pts (int): Number of points to sample for each block + block_dimensions (tuple[float, float, float]): Dimensions of the + block in x,y,z + min_pts (int): Minimum number of points in a block to be considered + valid + """ + self.num_pts = num_pts + self.min_pts = min_pts + self.block_dimensions = np.asarray(block_dimensions) + self._idx_sampler = GenerateBlockSamplingIndices( + num_pts=self.num_pts, + block_dimensions=block_dimensions, + ) + + def __call__( + self, coordinates_list: list[NDArrayNumber] + ) -> list[NDArrayInt]: + """Subsamples the pointcloud using blocks of a given size.""" + coordinates = coordinates_list[0] + + # Get bounding box for sampling + coord_min, coord_max = ( + np.min(coordinates, axis=0), + np.max(coordinates, axis=0), + ) + sampled_idxs = [] + hwl = coord_max - coord_min + num_blocks = np.ceil(hwl / self.block_dimensions).astype(np.int32) + + for idx_x in range(num_blocks[0]): + for idx_y in range(num_blocks[1]): + for idx_z in range(num_blocks[2]): + center_pt = ( + coord_min + + np.array([idx_x, idx_y, idx_z]) + * self.block_dimensions + + self.block_dimensions / 2.0 + ) + + self._idx_sampler.center_point = center_pt + selected_idxs = self._idx_sampler([coordinates])[0] + if selected_idxs.sum() >= self.min_pts: + sampled_idxs.append(selected_idxs) + return [np.stack(sampled_idxs)] * len(coordinates_list) # type: ignore + + +@Transform([K.points3d, "transforms.sampling_idxs"], K.points3d) +class SamplePoints: + """Subsamples points randomly. + + Samples 'num_pts' randomly from the provided data tensors using the + provided sampling indices. + + This transform is used to sample points from a pointcloud. The indices + are generated by the GenerateSamplingIndices transform. + + """ + + def __call__( + self, + data_list: list[NDArrayNumber], + selected_idxs_list: list[NDArrayInt], + ) -> list[NDArrayNumber]: + """Returns data[selected_idxs]. + + If the provided indices have two dimension (i.e n_masks, 64), then + this operation indices the data n_masks times and returns an array + """ + for i, (data, selected_idxs) in enumerate( + zip(data_list, selected_idxs_list) + ): + assert selected_idxs.ndim <= 2, "Indices must be 1D or 2D" + if selected_idxs.ndim == 2: + data_list[i] = np.stack( + [data[idxs, ...] for idxs in selected_idxs] + ) + else: + data_list[i] = data[selected_idxs, ...] + return data_list + + +@Transform([K.colors3d, "transforms.sampling_idxs"], K.colors3d) +class SampleColors(SamplePoints): + """Subsamples colors randomly. + + Samples 'num_pts' randomly from the provided data tensors using the + provided sampling indices. + + This transform is used to sample colors from a pointcloud. The indices + are generated by the GenerateSamplingIndices transform. + """ + + +@Transform([K.semantics3d, "transforms.sampling_idxs"], K.semantics3d) +class SampleSemantics(SamplePoints): + """Subsamples semantics randomly. + + Samples 'num_pts' randomly from the provided data tensors using the + provided sampling indices. + + This transform is used to sample semantics from a pointcloud. The indices + are generated by the GenerateSamplingIndices transform. + """ + + +@Transform([K.instances3d, "transforms.sampling_idxs"], K.instances3d) +class SampleInstances(SamplePoints): + """Subsamples instances randomly. + + Samples 'num_pts' randomly from the provided data tensors using the + provided sampling indices. + + This transform is used to sample instances from a pointcloud. The indices + are generated by the GenerateSamplingIndices transform. + """ diff --git a/vis4d/data/transforms/points.py b/vis4d/data/transforms/points.py new file mode 100644 index 0000000000000000000000000000000000000000..d981f25927c222daae4476ae966ffa32180d5ec7 --- /dev/null +++ b/vis4d/data/transforms/points.py @@ -0,0 +1,269 @@ +"""Pointwise transformations.""" + +from __future__ import annotations + +from typing import TypedDict + +import numpy as np + +from vis4d.common.typing import NDArrayFloat +from vis4d.data.const import CommonKeys as K + +from .base import Transform + + +@Transform(in_keys=K.points3d, out_keys="transforms.pc_bounds") +class GenPcBounds: + """Extracts the max and min values of the loaded points.""" + + def __call__( + self, coordinates_list: list[NDArrayFloat] + ) -> list[NDArrayFloat]: + """Extracts the max and min values of the pointcloud.""" + coordinates = coordinates_list[0] + + pc_bounds = [np.stack([coordinates.min(0), coordinates.max(0)])] * len( + coordinates_list + ) + + return pc_bounds + + +@Transform(in_keys=(K.points3d, "trasforms.pc_bounds"), out_keys=K.points3d) +class NormalizeByMaxBounds: + """Normalizes the pointcloud by the max bounds.""" + + def __init__(self, axes: tuple[int, int, int] = (0, 1, 2)) -> None: + """Creates an instance of the class. + + Args: + axes (tuple[int, int, int]): Over which axes to apply + normalization. + """ + self.axes = axes + + def __call__( + self, + coords_list: list[NDArrayFloat], + pc_bounds_list: list[NDArrayFloat], + ) -> list[NDArrayFloat]: + """Applies the normalization.""" + for i, (coords, pc_bounds) in enumerate( + zip(coords_list, pc_bounds_list) + ): + max_bound = np.max(np.abs(pc_bounds), axis=0) + for ax in self.axes: + coords[:, ax] = coords[:, ax] / max_bound[ax] + coords_list[i] = coords + return coords_list + + +@Transform(in_keys=K.points3d, out_keys=K.points3d) +class CenterAndNormalize: + """Centers and normalizes the pointcloud.""" + + def __init__(self, centering: bool = True, normalize: bool = True) -> None: + """Creates an instance of the class. + + Args: + centering (bool): Whether to center the pointcloud + normalize (bool): Whether to normalize the pointcloud + """ + self.centering = centering + self.normalize = normalize + + def __call__(self, coords_list: list[NDArrayFloat]) -> list[NDArrayFloat]: + """Applies the Center and Normalization operations.""" + for i, coords in enumerate(coords_list): + if self.centering: + coords = coords - np.mean(coords, axis=0) + if self.normalize: + coords = coords / np.max(np.sqrt(np.sum(coords**2, axis=-1))) + coords_list[i] = coords + return coords_list + + +@Transform(in_keys=K.points3d, out_keys=K.points3d) +class AddGaussianNoise: + """Adds random normal distributed noise with given std to the data. + + Args: + std (float): Standard Deviation of the noise + """ + + def __init__(self, noise_level: float = 0.01): + """Creates an instance of the class. + + Args: + noise_level (float): The noise level. Standard deviation for + the gaussian noise. + """ + self.noise_level = noise_level + + def __call__( + self, coordinates_list: list[NDArrayFloat] + ) -> list[NDArrayFloat]: + """Adds gaussian noise to the coordiantes.""" + for i, coordinates in enumerate(coordinates_list): + coordinates[i] = ( + coordinates + + np.random.randn(*coordinates.shape) * self.noise_level + ) + return coordinates_list + + +@Transform(in_keys=K.points3d, out_keys=K.points3d) +class AddUniformNoise: + """Adds random normal distributed noise with given std to the data. + + Args: + std (float): Standard Deviation of the noise + """ + + def __init__(self, noise_level: float = 0.01): + """Creates an instance of the class. + + Args: + noise_level (float): The noise level. Half the range of the + uniform noise. The noise is sampled from + [-noise_level, noise_level]. + """ + self.noise_level = noise_level + + def __call__( + self, coordinates_list: list[NDArrayFloat] + ) -> list[NDArrayFloat]: + """Adds uniform noise to the coordinates.""" + for i, coordinates in enumerate(coordinates_list): + coordinates_list[i] = coordinates + np.random.uniform( + -self.noise_level, self.noise_level, coordinates.shape + ) + return coordinates_list + + +class SE3Transform(TypedDict): + """Parameters for Resize.""" + + translation: NDArrayFloat + rotation: NDArrayFloat + + +def _gen_random_se3_transform( + translation_min: NDArrayFloat, + translation_max: NDArrayFloat, + rotation_min: NDArrayFloat, + rotation_max: NDArrayFloat, +) -> SE3Transform: + """Creates a random SE3 Transforms. + + The transform is generated by sampling a random translation and + rotation from a uniform distribution. + """ + angle = np.random.uniform(rotation_min, rotation_max) + translation = np.random.uniform(translation_min, translation_max) + cos_x, sin_x = np.cos(angle[0]), np.sin(angle[0]) + cos_y, sin_y = np.cos(angle[1]), np.sin(angle[1]) + cos_z, sin_z = np.cos(angle[2]), np.sin(angle[2]) + rotx = np.array([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]]) + roty = np.array([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]]) + rotz = np.array([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]]) + rot = np.dot(rotz, np.dot(roty, rotx)) + return SE3Transform(translation=translation, rotation=rot) + + +@Transform(in_keys=K.points3d, out_keys=K.points3d) +class ApplySE3Transform: + """Applies a given SE3 Transform to the data.""" + + def __init__( + self, + translation_min: tuple[float, float, float] = (0.0, 0.0, 0.0), + translation_max: tuple[float, float, float] = (0.0, 0.0, 0.0), + rotation_min: tuple[float, float, float] = (0.0, 0.0, 0.0), + rotation_max: tuple[float, float, float] = (0.0, 0.0, 0.0), + ) -> None: + """Creates an instance of the class. + + Args: + translation_min (tuple[float, float, float]): Minimum translation. + translation_max (tuple[float, float, float]): Maximum translation. + rotation_min (tuple[float, float, float]): Minimum euler rotation + angles [rad]. Applied in the order rot_x -> rot_y -> rot_z. + rotation_max (tuple[float, float, float]): Maximum euler rotation + angles [rad]. Applied in the order rot_x -> rot_y -> rot_z. + """ + self.translation_min = np.asarray(translation_min) + self.translation_max = np.asarray(translation_max) + self.rotation_min = np.asarray(rotation_min) + self.rotation_max = np.asarray(rotation_max) + + def __call__( + self, coordinates_list: list[NDArrayFloat] + ) -> list[NDArrayFloat]: + """Applies a SE3 Transform.""" + for i, coordinates in enumerate(coordinates_list): + transform = _gen_random_se3_transform( + self.translation_min, + self.translation_max, + self.rotation_min, + self.rotation_max, + ) + if coordinates.shape[-1] == 3: + coordinates_list[i] = ( + transform["rotation"] @ coordinates.T + ).T + transform["translation"] + elif coordinates.shape[-2] == 3: + coordinates_list[i] = ( + transform["rotation"] @ coordinates + ).T + transform["translation"] + else: + raise ValueError( + f"Invalid shape for coordinates: {coordinates.shape}" + ) + return coordinates_list + + +class ApplySO3Transform(ApplySE3Transform): + """Applies a given SO3 Transform to the data.""" + + def __call__( + self, coordinates_list: list[NDArrayFloat] + ) -> list[NDArrayFloat]: + """Applies a given SO3 Transform to the data.""" + for i, coordinates in enumerate(coordinates_list): + transform = _gen_random_se3_transform( + self.translation_min, + self.translation_max, + self.rotation_min, + self.rotation_max, + )["rotation"] + if coordinates.shape[-1] == 3: + coordinates_list[i] = (transform @ coordinates.T).T + elif coordinates.shape[-2] == 3: + coordinates_list[i] = (transform @ coordinates).T + else: + raise ValueError( + f"Invalid shape for coordinates: {coordinates.shape}" + ) + return coordinates_list + + +@Transform(in_keys=K.points3d, out_keys=K.points3d) +class TransposeChannels: + """Transposes some predifined channels.""" + + def __init__(self, channels: tuple[int, int] = (-1, -2)): + """Creates an instance of the class. + + Args: + channels (tuple[int, int]): Tuple of channels to transpose + """ + self.channels = channels + + def __call__( + self, coordinates_list: list[NDArrayFloat] + ) -> list[NDArrayFloat]: + """Transposes some predifined channels.""" + for i, coordinates in enumerate(coordinates_list): + coordinates_list[i] = coordinates.transpose(*self.channels) + return coordinates_list diff --git a/vis4d/data/transforms/post_process.py b/vis4d/data/transforms/post_process.py new file mode 100644 index 0000000000000000000000000000000000000000..b6605c201a8b8e43412cc7597dacb7f7c9ad821a --- /dev/null +++ b/vis4d/data/transforms/post_process.py @@ -0,0 +1,161 @@ +"""Post process after transformation.""" + +from __future__ import annotations + +import torch + +from vis4d.common.typing import NDArrayF32, NDArrayI64 +from vis4d.data.const import CommonKeys as K +from vis4d.op.box.box2d import bbox_area, bbox_clip + +from .base import Transform + + +@Transform( + in_keys=[ + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + K.input_hw, + K.boxes3d, + K.boxes3d_classes, + K.boxes3d_track_ids, + ], + out_keys=[ + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + K.boxes3d, + K.boxes3d_classes, + K.boxes3d_track_ids, + ], +) +class PostProcessBoxes2D: + """Post process after transformation.""" + + def __init__( + self, min_area: float = 7.0 * 7.0, clip_bboxes_to_image: bool = True + ) -> None: + """Creates an instance of the class. + + Args: + min_area (float): Minimum area of the bounding box. Defaults to + 7.0 * 7.0. + clip_bboxes_to_image (bool): Whether to clip the bounding boxes to + the image size. Defaults to True. + """ + self.min_area = min_area + self.clip_bboxes_to_image = clip_bboxes_to_image + + def __call__( + self, + boxes_list: list[NDArrayF32], + classes_list: list[NDArrayI64], + track_ids_list: list[NDArrayI64] | None, + input_hw_list: list[tuple[int, int]], + boxes3d_list: list[NDArrayF32] | None, + boxes3d_classes_list: list[NDArrayI64] | None, + boxes3d_track_ids_list: list[NDArrayI64] | None, + ) -> tuple[ + list[NDArrayF32], + list[NDArrayI64], + list[NDArrayI64] | None, + list[NDArrayF32] | None, + list[NDArrayI64] | None, + list[NDArrayI64] | None, + ]: + """Post process according to boxes2D after transformation. + + Args: + boxes_list (list[NDArrayF32]): The bounding boxes to be post + processed. + classes_list (list[NDArrayF32]): The classes of the bounding boxes. + track_ids_list (list[NDArrayI64] | None): The track ids of the + bounding boxes. + input_hw_list (list[tuple[int, int]]): The height and width of the + input image. + boxes3d_list (list[NDArrayF32] | None): The 3D bounding boxes to be + post processed. + boxes3d_classes_list (list[NDArrayI64] | None): The classes of the + 3D bounding boxes. + boxes3d_track_ids_list (list[NDArrayI64] | None): The track ids of + the 3D bounding boxes. + + Returns: + tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None, + list[NDArrayF32] | None, list[NDArrayI64] | None, + list[NDArrayI64] | None]: The post processed results. + """ + new_track_ids: list[NDArrayI64] | None = ( + [] if track_ids_list is not None else None + ) + new_boxes3d: list[NDArrayF32] | None = ( + [] if boxes3d_list is not None else None + ) + new_boxes3d_classes: list[NDArrayI64] | None = ( + [] if boxes3d_classes_list is not None else None + ) + new_boxes3d_track_ids: list[NDArrayI64] | None = ( + [] if boxes3d_track_ids_list is not None else None + ) + for i, (boxes, classes) in enumerate(zip(boxes_list, classes_list)): + boxes_ = torch.from_numpy(boxes) + if self.clip_bboxes_to_image: + boxes_ = bbox_clip(boxes_, input_hw_list[i]) + + keep = (bbox_area(boxes_) >= self.min_area).numpy() + + boxes_list[i] = boxes[keep] + classes_list[i] = classes[keep] + + if track_ids_list is not None: + assert new_track_ids is not None + new_track_ids.append(track_ids_list[i][keep]) + + if boxes3d_list is not None: + assert new_boxes3d is not None + new_boxes3d.append(boxes3d_list[i][keep]) + + if boxes3d_classes_list is not None: + assert new_boxes3d_classes is not None + new_boxes3d_classes.append(boxes3d_classes_list[i][keep]) + + if boxes3d_track_ids_list is not None: + assert new_boxes3d_track_ids is not None + new_boxes3d_track_ids.append(boxes3d_track_ids_list[i][keep]) + + return ( + boxes_list, + classes_list, + new_track_ids, + new_boxes3d, + new_boxes3d_classes, + new_boxes3d_track_ids, + ) + + +@Transform(in_keys=[K.boxes2d_track_ids], out_keys=[K.boxes2d_track_ids]) +class RescaleTrackIDs: + """Rescale track ids.""" + + def __call__(self, track_ids_list: list[NDArrayI64]) -> list[NDArrayI64]: + """Rescale the track ids. + + Args: + track_ids_list (list[NDArrayI64]): The track ids to be + rescaled. + + Returns: + list[NDArrayI64]: The rescaled track ids. + """ + track_ids_all: dict[int, int] = {} + for track_ids in track_ids_list: + for track_id in track_ids: + if track_id not in track_ids_all: + track_ids_all[track_id] = len(track_ids_all) + + for track_ids in track_ids_list: + for i, track_id in enumerate(track_ids): + track_ids[i] = track_ids_all[track_id] + + return track_ids_list diff --git a/vis4d/data/transforms/random_erasing.py b/vis4d/data/transforms/random_erasing.py new file mode 100644 index 0000000000000000000000000000000000000000..ace6af900b6629625f6ed60b2f4cec6623c73be4 --- /dev/null +++ b/vis4d/data/transforms/random_erasing.py @@ -0,0 +1,91 @@ +"""Random erasing data augmentation.""" + +import numpy as np + +from vis4d.common.typing import NDArrayNumber +from vis4d.data.const import CommonKeys as K + +from .base import Transform + + +@Transform(in_keys=K.images, out_keys=K.images) +class RandomErasing: + """Randomly erase a rectangular region in an image tensor.""" + + def __init__( + self, + min_area: float = 0.02, + max_area: float = 0.4, + min_aspect_ratio: float = 0.3, + max_aspect_ratio: float = 1 / 0.3, + mean: tuple[float, float, float] = (0.0, 0.0, 0.0), + num_attempt: int = 10, + ): + """Creates an instance of RandomErasing. + + Recommended to use this transform after normalization. The erased + region will be filled with the mean value. See + `https://arxiv.org/abs/1708.04896`. + + Args: + min_area (float, optional): Minimum area of the erased region. + Defaults to 0.02. + max_area (float, optional): Maximum area of the erased region. + Defaults to 0.4. + min_aspect_ratio (float, optional): Minimum aspect ratio of the + erased region. Defaults to 0.3. + max_aspect_ratio (float, optional): Maximum aspect ratio of the + erased region. Defaults to 1 / 0.3. + mean (tuple[float, float, float], optional): Mean of the dataset. + Defaults to (0.0, 0.0, 0.0). + num_attempt (int, optional): Number of maximum attempts to find a + valid erased region. This is used to avoid infinite attempts of + resampling the region, though such cases are very unlikely to + happen. Defaults to 10. + + Returns: + Callable: A function that takes a tensor of shape [N, H, W, C] and + returns a tensor of the same shape. + """ + self.min_area = min_area + self.max_area = max_area + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + self.mean = mean + self.num_attempt = num_attempt + + def do_erasing(self, images: NDArrayNumber) -> NDArrayNumber: + """Execute the random erasing.""" + fill = np.array(self.mean) + for i in range(images.shape[0]): + image = images[i] + h, w = image.shape[0:2] + area = h * w + + for _ in range(self.num_attempt): + target_area = ( + np.random.uniform(self.min_area, self.max_area) * area + ) + aspect_ratio = np.random.uniform( + self.min_aspect_ratio, self.max_aspect_ratio + ) + h_erase = int(round(np.sqrt(target_area * aspect_ratio))) + w_erase = int(round(np.sqrt(target_area / aspect_ratio))) + if w_erase < w and h_erase < h: + x_erase = np.random.randint(0, w - w_erase) + y_erase = np.random.randint(0, h - h_erase) + image[ + y_erase : y_erase + h_erase, + x_erase : x_erase + w_erase, + :, + ] = fill + break + return images + + def __call__( + self, images_list: list[NDArrayNumber] + ) -> list[NDArrayNumber]: + """Execute the transform.""" + for i, images in enumerate(images_list): + images_list[i] = self.do_erasing(images) + return images_list diff --git a/vis4d/data/transforms/resize.py b/vis4d/data/transforms/resize.py new file mode 100644 index 0000000000000000000000000000000000000000..c829f7a0935e771578613e47cdd7dae9f58fc4b0 --- /dev/null +++ b/vis4d/data/transforms/resize.py @@ -0,0 +1,539 @@ +"""Resize transformation.""" + +from __future__ import annotations + +import random +from typing import TypedDict + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + +from vis4d.common.imports import OPENCV_AVAILABLE +from vis4d.common.typing import NDArrayF32 +from vis4d.data.const import CommonKeys as K +from vis4d.op.box.box2d import transform_bbox + +from .base import Transform + +if OPENCV_AVAILABLE: + import cv2 + from cv2 import ( # pylint: disable=no-member,no-name-in-module + INTER_AREA, + INTER_CUBIC, + INTER_LANCZOS4, + INTER_LINEAR, + INTER_NEAREST, + ) +else: + raise ImportError("Please install opencv-python to use this module.") + + +class ResizeParam(TypedDict): + """Parameters for Resize.""" + + target_shape: tuple[int, int] + scale_factor: tuple[float, float] + + +@Transform(K.images, ["transforms.resize", K.input_hw]) +class GenResizeParameters: + """Generate the parameters for a resize operation.""" + + def __init__( + self, + shape: tuple[int, int] | list[tuple[int, int]], + keep_ratio: bool = False, + multiscale_mode: str = "range", + scale_range: tuple[float, float] = (1.0, 1.0), + align_long_edge: bool = False, + resize_short_edge: bool = False, + allow_overflow: bool = False, + fixed_scale: bool = False, + ) -> None: + """Creates an instance of the class. + + Args: + shape (tuple[int, int] | list[tuple[int, int]]): Image shape to + be resized to in (H, W) format. In multiscale mode 'list', + shape represents the list of possible shapes for resizing. + keep_ratio (bool, optional): If aspect ratio of the original image + should be kept, the new shape will modified to fit the aspect + ratio of the original image. Defaults to False. + multiscale_mode (str, optional): one of [range, list]. Defaults to + "range". + scale_range (tuple[float, float], optional): Range of sampled image + scales in range mode, e.g. (0.8, 1.2), indicating minimum of + 0.8 * shape and maximum of 1.2 * shape. Defaults to (1.0, 1.0). + align_long_edge (bool, optional): If keep_ratio=true, this option + indicates if shape should be automatically aligned with the + long edge of the original image, e.g. original shape=(100, 80), + shape to be resized=(100, 200) will yield (125, 100) as new + shape. Defaults to False. + resize_short_edge (bool, optional): If keep_ratio=true, this option + scale the image according to the short edge. e.g. original + shape=(80, 100), shape to be resized=(100, 200) will yield + (100, 125) as new shape. Defaults to False. + allow_overflow (bool, optional): If set to True, we scale the image + to the smallest size such that it is no smaller than shape. + Otherwise, we scale the image to the largest size such that it + is no larger than shape. Defaults to False. + fixed_scale (bool, optional): If set to True, we scale the image + without offset. Defaults to False. + """ + self.shape = shape + self.keep_ratio = keep_ratio + + assert multiscale_mode in {"list", "range"} + self.multiscale_mode = multiscale_mode + + assert ( + scale_range[0] <= scale_range[1] + ), f"Invalid scale range: {scale_range[1]} < {scale_range[0]}" + self.scale_range = scale_range + + self.align_long_edge = align_long_edge + self.resize_short_edge = resize_short_edge + self.allow_overflow = allow_overflow + self.fixed_scale = fixed_scale + + def _get_target_shape( + self, input_shape: tuple[int, int] + ) -> tuple[int, int]: + """Generate possibly random target shape.""" + if self.multiscale_mode == "range": + assert isinstance( + self.shape, tuple + ), "Specify shape as tuple when using multiscale mode range." + if self.scale_range[0] < self.scale_range[1]: # do multi-scale + w_scale = ( + random.uniform(0, 1) + * (self.scale_range[1] - self.scale_range[0]) + + self.scale_range[0] + ) + h_scale = ( + random.uniform(0, 1) + * (self.scale_range[1] - self.scale_range[0]) + + self.scale_range[0] + ) + else: + h_scale = w_scale = 1.0 + + shape = int(self.shape[0] * h_scale), int(self.shape[1] * w_scale) + else: + assert isinstance( + self.shape, list + ), "Specify shape as list when using multiscale mode list." + shape = random.choice(self.shape) + + return get_resize_shape( + input_shape, + shape, + self.keep_ratio, + self.align_long_edge, + self.resize_short_edge, + self.allow_overflow, + self.fixed_scale, + ) + + def __call__( + self, images: list[NDArrayF32] + ) -> tuple[list[ResizeParam], list[tuple[int, int]]]: + """Compute the parameters and put them in the data dict.""" + image = images[0] + + im_shape = (image.shape[1], image.shape[2]) + target_shape = self._get_target_shape(im_shape) + scale_factor = ( + target_shape[1] / im_shape[1], + target_shape[0] / im_shape[0], + ) + + resize_params = [ + ResizeParam(target_shape=target_shape, scale_factor=scale_factor) + ] * len(images) + target_shapes = [target_shape] * len(images) + + return resize_params, target_shapes + + +def get_resize_shape( + original_shape: tuple[int, int], + new_shape: tuple[int, int], + keep_ratio: bool = True, + align_long_edge: bool = False, + resize_short_edge: bool = False, + allow_overflow: bool = False, + fixed_scale: bool = False, +) -> tuple[int, int]: + """Get shape for resize, considering keep_ratio and align_long_edge. + + Args: + original_shape (tuple[int, int]): Original shape in [H, W]. + new_shape (tuple[int, int]): New shape in [H, W]. + keep_ratio (bool, optional): Whether to keep the aspect ratio. + Defaults to True. + align_long_edge (bool, optional): Whether to align the long edge of + the original shape with the long edge of the new shape. + Defaults to False. + resize_short_edge (bool, optional): Whether to resize according to the + short edge. Defaults to False. + allow_overflow (bool, optional): Whether to allow overflow. + Defaults to False. + fixed_scale (bool, optional): Whether to use fixed scale. + + Returns: + tuple[int, int]: The new shape in [H, W]. + """ + h, w = original_shape + new_h, new_w = new_shape + + if keep_ratio: + if allow_overflow: + comp_fn = max + else: + comp_fn = min + + if align_long_edge: + long_edge, short_edge = max(new_shape), min(new_shape) + scale_factor = comp_fn( + long_edge / max(h, w), short_edge / min(h, w) + ) + elif resize_short_edge: + short_edge = min(original_shape) + new_short_edge = min(new_shape) + scale_factor = new_short_edge / short_edge + else: + scale_factor = comp_fn(new_w / w, new_h / h) + + if fixed_scale: + offset = 0.0 + else: + offset = 0.5 + + new_h = int(h * scale_factor + offset) + new_w = int(w * scale_factor + offset) + + return new_h, new_w + + +@Transform([K.images, "transforms.resize.target_shape"], K.images) +class ResizeImages: + """Resize Images.""" + + def __init__( + self, + interpolation: str = "bilinear", + antialias: bool = False, + imresize_backend: str = "torch", + ) -> None: + """Creates an instance of the class. + + Args: + interpolation (str, optional): Interpolation method. One of + ["nearest", "bilinear", "bicubic"]. Defaults to "bilinear". + antialias (bool): Whether to use antialiasing. Defaults to False. + imresize_backend (str): One of torch, cv2. Defaults to torch. + """ + self.interpolation = interpolation + self.antialias = antialias + self.imresize_backend = imresize_backend + assert imresize_backend in { + "torch", + "cv2", + }, f"Invalid imresize backend: {imresize_backend}" + + def __call__( + self, images: list[NDArrayF32], target_shapes: list[tuple[int, int]] + ) -> list[NDArrayF32]: + """Resize an image of dimensions [N, H, W, C]. + + Args: + image (Tensor): The image. + target_shape (tuple[int, int]): The target shape after resizing. + + Returns: + list[NDArrayF32]: Resized images according to parameters in resize. + """ + for i, (image, target_shape) in enumerate(zip(images, target_shapes)): + images[i] = resize_image( + image, + target_shape, + interpolation=self.interpolation, + antialias=self.antialias, + backend=self.imresize_backend, + ) + return images + + +def resize_image( + inputs: NDArrayF32, + shape: tuple[int, int], + interpolation: str = "bilinear", + antialias: bool = False, + backend: str = "torch", +) -> NDArrayF32: + """Resize image.""" + if backend == "torch": + image = torch.from_numpy(inputs).permute(0, 3, 1, 2) + image = resize_tensor(image, shape, interpolation, antialias) + return image.permute(0, 2, 3, 1).numpy() + + if backend == "cv2": + cv2_interp_codes = { + "nearest": INTER_NEAREST, + "bilinear": INTER_LINEAR, + "bicubic": INTER_CUBIC, + "area": INTER_AREA, + "lanczos": INTER_LANCZOS4, + } + return cv2.resize( # pylint: disable=no-member, unsubscriptable-object + inputs[0].astype(np.uint8), + (shape[1], shape[0]), + interpolation=cv2_interp_codes[interpolation], + )[None, ...].astype(np.float32) + + raise ValueError(f"Invalid imresize backend: {backend}") + + +@Transform([K.boxes2d, "transforms.resize.scale_factor"], K.boxes2d) +class ResizeBoxes2D: + """Resize list of 2D bounding boxes.""" + + def __call__( + self, + boxes_list: list[NDArrayF32], + scale_factors: list[tuple[float, float]], + ) -> list[NDArrayF32]: + """Resize 2D bounding boxes. + + Args: + boxes_list: (list[NDArrayF32]): The bounding boxes to be resized. + scale_factors (list[tuple[float, float]]): scaling factors. + + Returns: + list[NDArrayF32]: Resized bounding boxes according to parameters in + resize. + """ + for i, (boxes, scale_factor) in enumerate( + zip(boxes_list, scale_factors) + ): + boxes_ = torch.from_numpy(boxes) + scale_matrix = torch.eye(3) + scale_matrix[0, 0] = scale_factor[0] + scale_matrix[1, 1] = scale_factor[1] + boxes_list[i] = transform_bbox(scale_matrix, boxes_).numpy() + return boxes_list + + +@Transform( + [ + K.depth_maps, + "transforms.resize.target_shape", + "transforms.resize.scale_factor", + ], + K.depth_maps, +) +class ResizeDepthMaps: + """Resize depth maps.""" + + def __init__( + self, + interpolation: str = "nearest", + rescale_depth_values: bool = False, + check_scale_factors: bool = False, + ): + """Initialize the transform. + + Args: + interpolation (str, optional): Interpolation method. One of + ["nearest", "bilinear", "bicubic"]. Defaults to "nearest". + rescale_depth_values (bool, optional): If the depth values should + be rescaled according to the new scale factor. Defaults to + False. This is useful if we want to keep the intrinsic + parameters of the camera the same. + check_scale_factors (bool, optional): If the scale factors should + be checked to ensure they are the same. Defaults to False. + If False, the scale factor is assumed to be the same for both + dimensions and will just use the first scale factor. + """ + self.interpolation = interpolation + self.rescale_depth_values = rescale_depth_values + self.check_scale_factors = check_scale_factors + + def __call__( + self, + depth_maps: list[NDArrayF32], + target_shapes: list[tuple[int, int]], + scale_factors: list[tuple[float, float]], + ) -> list[NDArrayF32]: + """Resize depth maps.""" + for i, (depth_map, target_shape, scale_factor) in enumerate( + zip(depth_maps, target_shapes, scale_factors) + ): + depth_map_ = torch.from_numpy(depth_map) + depth_map_ = ( + resize_tensor( + depth_map_.float().unsqueeze(0).unsqueeze(0), + target_shape, + interpolation=self.interpolation, + ) + .type(depth_map_.dtype) + .squeeze(0) + .squeeze(0) + ) + if self.rescale_depth_values: + if self.check_scale_factors: + assert np.isclose( + scale_factor[0], scale_factor[1], atol=1e-4 + ), "Depth map scale factors must be the same" + depth_map_ /= scale_factor[0] + depth_maps[i] = depth_map_.numpy() + return depth_maps + + +@Transform( + [ + K.optical_flows, + "transforms.resize.target_shape", + "transforms.resize.scale_factor", + ], + K.optical_flows, +) +class ResizeOpticalFlows: + """Resize optical flows.""" + + def __init__(self, normalized_flow: bool = True): + """Create a ResizeOpticalFlows instance. + + Args: + normalized_flow (bool): Whether the optical flow is normalized. + Defaults to True. If false, the optical flow will be scaled + according to the scale factor. + """ + self.normalized_flow = normalized_flow + + def __call__( + self, + optical_flows: list[NDArrayF32], + target_shapes: list[tuple[int, int]], + scale_factors: list[tuple[float, float]], + ) -> list[NDArrayF32]: + """Resize optical flows.""" + for i, (optical_flow, target_shape, scale_factor) in enumerate( + zip(optical_flows, target_shapes, scale_factors) + ): + optical_flow_ = torch.from_numpy(optical_flow).permute(2, 0, 1) + optical_flow_ = ( + resize_tensor( + optical_flow_.float().unsqueeze(0), + target_shape, + interpolation="bilinear", + ) + .type(optical_flow_.dtype) + .squeeze(0) + .permute(1, 2, 0) + ) + # scale optical flows + if not self.normalized_flow: + optical_flow_[:, :, 0] *= scale_factor[0] + optical_flow_[:, :, 1] *= scale_factor[1] + optical_flows[i] = optical_flow_.numpy() + return optical_flows + + +@Transform( + [K.instance_masks, "transforms.resize.target_shape"], K.instance_masks +) +class ResizeInstanceMasks: + """Resize instance segmentation masks.""" + + def __call__( + self, + masks_list: list[NDArrayF32], + target_shapes: list[tuple[int, int]], + ) -> list[NDArrayF32]: + """Resize masks.""" + for i, (masks, target_shape) in enumerate( + zip(masks_list, target_shapes) + ): + if len(masks) == 0: # handle empty masks + continue + masks_ = torch.from_numpy(masks) + masks_ = ( + resize_tensor( + masks_.float().unsqueeze(1), + target_shape, + interpolation="nearest", + ) + .type(masks_.dtype) + .squeeze(1) + ) + masks_list[i] = masks_.numpy() + return masks_list + + +@Transform([K.seg_masks, "transforms.resize.target_shape"], K.seg_masks) +class ResizeSegMasks: + """Resize segmentation masks.""" + + def __call__( + self, + masks_list: list[NDArrayF32], + target_shape_list: list[tuple[int, int]], + ) -> list[NDArrayF32]: + """Resize masks.""" + for i, (masks, target_shape) in enumerate( + zip(masks_list, target_shape_list) + ): + masks_ = torch.from_numpy(masks) + masks_ = ( + resize_tensor( + masks_.float().unsqueeze(0).unsqueeze(0), + target_shape, + interpolation="nearest", + ) + .type(masks_.dtype) + .squeeze(0) + .squeeze(0) + ) + masks_list[i] = masks_.numpy() + return masks_list + + +@Transform([K.intrinsics, "transforms.resize.scale_factor"], K.intrinsics) +class ResizeIntrinsics: + """Resize Intrinsics.""" + + def __call__( + self, + intrinsics: list[NDArrayF32], + scale_factors: list[tuple[float, float]], + ) -> list[NDArrayF32]: + """Scale camera intrinsics when resizing.""" + for i, scale_factor in enumerate(scale_factors): + scale_matrix = np.eye(3, dtype=np.float32) + scale_matrix[0, 0] *= scale_factor[0] + scale_matrix[1, 1] *= scale_factor[1] + intrinsics[i] = scale_matrix @ intrinsics[i] + return intrinsics + + +def resize_tensor( + inputs: Tensor, + shape: tuple[int, int], + interpolation: str = "bilinear", + antialias: bool = False, +) -> Tensor: + """Resize Tensor.""" + assert interpolation in {"nearest", "bilinear", "bicubic"} + align_corners = None if interpolation == "nearest" else False + output = F.interpolate( + inputs, + shape, + mode=interpolation, + align_corners=align_corners, + antialias=antialias, + ) + return output diff --git a/vis4d/data/transforms/select_sensor.py b/vis4d/data/transforms/select_sensor.py new file mode 100644 index 0000000000000000000000000000000000000000..534ef6ef3f464d1ca8f9ab688a15657d423c55a5 --- /dev/null +++ b/vis4d/data/transforms/select_sensor.py @@ -0,0 +1,52 @@ +# pylint: disable=no-member +"""Select Sensor transformation.""" +from vis4d.data.typing import DictData + +from .base import Transform + + +@Transform("data", "data") +class SelectSensor: + """Keep data from one sensor only but keep shared data. + + Note: The input data is assumed to be in the format of DictData[DictData], + i.e. a list of data dictionaries, each of which contains a dictionary of + either the data from a sensor or the shared data (metadata) for all + sensors. + + Example: + >>> data = [ + { + "sensor1": {"image": 1, "label": 2}, + "sensor2": {"image": 1, "label": 2}, + "meta": 3}, + }, + ] + >>> tsfm = SelectSensor( + sensor="sensor1", sensors=["sensor1", "sensor2"] + ) + >>> tsfm(data) + [{"image": 1, "label": 2, "meta": 3},] + """ + + def __init__(self, selected_sensor: str) -> None: + """Creates an instance of SelectSensor. + + Args: + selected_sensor (str): The name of the sensor to keep. + """ + self.selected_sensor = selected_sensor + + def __call__(self, batch: list[DictData]) -> list[DictData]: + """Select data from one sensor only.""" + output_batch = [] + for data in batch: + output_data = {} + for key in data.keys(): + if key in self.sensors: # type: ignore + if key == self.selected_sensor: + output_data.update(data[key]) + else: + output_data[key] = data[key] + output_batch.append(output_data) + return output_batch diff --git a/vis4d/data/transforms/to_tensor.py b/vis4d/data/transforms/to_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ec5e3e3b22491a4471bc21aed64fffff07f8d8 --- /dev/null +++ b/vis4d/data/transforms/to_tensor.py @@ -0,0 +1,48 @@ +"""ToTensor transformation.""" + +import numpy as np +import torch + +from vis4d.data.const import CommonKeys as K +from vis4d.data.typing import DictData + +from .base import Transform + + +def _replace_arrays(data: DictData) -> None: + """Replace numpy arrays with tensors.""" + for key in data.keys(): + if key in [K.images, K.original_images]: + if not data[key].flags.c_contiguous: + data[key] = np.ascontiguousarray( + data[key].transpose(0, 3, 1, 2) + ) + data[key] = torch.from_numpy(data[key]) + else: + data[key] = ( + torch.from_numpy(data[key]) + .permute(0, 3, 1, 2) + .contiguous() + ) + elif isinstance(data[key], np.ndarray): + data[key] = torch.from_numpy(data[key]) + elif isinstance(data[key], dict): + _replace_arrays(data[key]) + elif isinstance(data[key], list): + for i, entry in enumerate(data[key]): + if isinstance(entry, np.ndarray): + data[key][i] = torch.from_numpy(entry) + + +@Transform("data", "data") +class ToTensor: + """Transform all entries in a list of DataDict from numpy to torch. + + Note that we reshape K.images from NHWC to NCHW. + """ + + def __call__(self, batch: list[DictData]) -> list[DictData]: + """Transform all entries to tensor.""" + for data in batch: + _replace_arrays(data) + return batch diff --git a/vis4d/data/typing.py b/vis4d/data/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..a3da0b03b9b11c6976ab84f7270f3d4301707fc1 --- /dev/null +++ b/vis4d/data/typing.py @@ -0,0 +1,17 @@ +"""Type definitions related to the data pipeline. + +This file defines the data format `DictData` as an arbitrary dictionary that +can, in principle, hold arbitrary data. However, we provide `CommonKeys` in +`vis4d.data.const` to define the input format for commonly used input types, +so that the data pre-processing pipeline can take advantage of pre-defined +data formats that are necessary to properly pre-process a given data sample. +""" + +from __future__ import annotations + +from typing import Union + +from vis4d.common.typing import DictStrAny + +DictData = DictStrAny +DictDataOrList = Union[DictData, list[DictData]] diff --git a/vis4d/engine/__init__.py b/vis4d/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f099edf7f3185b4f5f5c2ce63a0e4a37244aee --- /dev/null +++ b/vis4d/engine/__init__.py @@ -0,0 +1 @@ +"""Vis4D run package.""" diff --git a/vis4d/engine/callbacks/__init__.py b/vis4d/engine/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e95dc07eef537b48ba383a38a40715afe7983037 --- /dev/null +++ b/vis4d/engine/callbacks/__init__.py @@ -0,0 +1,25 @@ +"""Callback modules.""" + +from .base import Callback +from .ema import EMACallback +from .evaluator import EvaluatorCallback +from .logging import LoggingCallback +from .scheduler import LRSchedulerCallback +from .visualizer import VisualizerCallback +from .yolox_callbacks import ( + YOLOXModeSwitchCallback, + YOLOXSyncNormCallback, + YOLOXSyncRandomResizeCallback, +) + +__all__ = [ + "Callback", + "EMACallback", + "EvaluatorCallback", + "LoggingCallback", + "VisualizerCallback", + "LRSchedulerCallback", + "YOLOXModeSwitchCallback", + "YOLOXSyncNormCallback", + "YOLOXSyncRandomResizeCallback", +] diff --git a/vis4d/engine/callbacks/base.py b/vis4d/engine/callbacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce9cea86007b43bfbb75408d4be8574dd654f9a --- /dev/null +++ b/vis4d/engine/callbacks/base.py @@ -0,0 +1,85 @@ +"""Base module for callbacks.""" + +from __future__ import annotations + +import lightning.pytorch as pl +from torch import Tensor + +from vis4d.common.typing import DictStrArrNested +from vis4d.data.typing import DictData +from vis4d.engine.connectors import CallbackConnector + + +class Callback(pl.Callback): + """Base class for Callbacks.""" + + def __init__( + self, + epoch_based: bool = True, + train_connector: None | CallbackConnector = None, + test_connector: None | CallbackConnector = None, + ) -> None: + """Init callback. + + Args: + epoch_based (bool, optional): Whether the callback is epoch based. + Defaults to False. + train_connector (None | CallbackConnector, optional): Defines which + kwargs to use during training for different callbacks. Defaults + to None. + test_connector (None | CallbackConnector, optional): Defines which + kwargs to use during testing for different callbacks. Defaults + to None. + """ + self.epoch_based = epoch_based + self.train_connector = train_connector + self.test_connector = test_connector + + def setup( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str + ) -> None: + """Setup callback.""" + + def get_train_callback_inputs( + self, outputs: DictData, batch: DictData + ) -> dict[str, Tensor | DictStrArrNested]: + """Returns the data connector results for training. + + It extracts the required data from prediction and datas and passes it + to the next component with the provided new key. + + Args: + outputs (DictData): Outputs of the model. + batch (DictData): Batch data. + + Returns: + dict[str, Tensor | DictStrArrNested]: Data connector results. + + Raises: + AssertionError: If train connector is None. + """ + assert self.train_connector is not None, "Train connector is None." + + return self.train_connector(outputs, batch) + + def get_test_callback_inputs( + self, outputs: DictData, batch: DictData + ) -> dict[str, Tensor | DictStrArrNested]: + """Returns the data connector results for inference. + + It extracts the required data from prediction and datas and passes it + to the next component with the provided new key. + + Args: + outputs (DictData): Outputs of the model. + batch (DictData): Batch data. + + Returns: + dict[str, Tensor | DictStrArrNested]: Data connector results. + + Raises: + AssertionError: If test connector is None. + """ + assert self.test_connector is not None, "Test connector is None." + + return self.test_connector(outputs, batch) diff --git a/vis4d/engine/callbacks/ema.py b/vis4d/engine/callbacks/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..fb732fd99804031350dbd5e3e96122b6a79133e3 --- /dev/null +++ b/vis4d/engine/callbacks/ema.py @@ -0,0 +1,39 @@ +"""Callback for updating EMA model.""" + +from __future__ import annotations + +import lightning.pytorch as pl + +from vis4d.common.distributed import is_module_wrapper +from vis4d.data.typing import DictData +from vis4d.model.adapter import ModelEMAAdapter + +from .base import Callback +from .util import get_model + + +class EMACallback(Callback): + """Callback for EMA.""" + + def on_train_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: DictData, + batch: DictData, + batch_idx: int, + ) -> None: + """Hook to run at the end of a training batch.""" + model = get_model(pl_module) + + if is_module_wrapper(model): + module = model.module + else: + module = model + + assert isinstance(module, ModelEMAAdapter), ( + "Model should be wrapped with ModelEMAAdapter when using " + "EMACallback." + ) + + module.update(trainer.global_step) diff --git a/vis4d/engine/callbacks/evaluator.py b/vis4d/engine/callbacks/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..2430edf10126eb22f5105d674c30bd0163c7c41f --- /dev/null +++ b/vis4d/engine/callbacks/evaluator.py @@ -0,0 +1,193 @@ +"""This module contains utilities for callbacks.""" + +from __future__ import annotations + +import os +from typing import Any + +import lightning.pytorch as pl + +from vis4d.common.distributed import ( + all_gather_object_cpu, + broadcast, + rank_zero_only, + synchronize, +) +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import ArgsType, MetricLogs +from vis4d.data.typing import DictData +from vis4d.eval.base import Evaluator + +from .base import Callback + + +class EvaluatorCallback(Callback): + """Callback for model evaluation.""" + + def __init__( + self, + *args: ArgsType, + evaluator: Evaluator, + metrics_to_eval: list[str] | None = None, + save_predictions: bool = False, + save_prefix: None | str = None, + output_dir: str | None = None, + **kwargs: ArgsType, + ) -> None: + """Init callback. + + Args: + evaluator (Evaluator): Evaluator. + metrics_to_eval (list[str], Optional): Metrics to evaluate. If + None, all metrics in the evaluator will be evaluated. Defaults + to None. + save_predictions (bool): If the predictions should be saved. + Defaults to False. + save_prefix (str, Optional): Output directory for saving the + evaluation results. Defaults to None. + output_dir (str, Optional): Output directory for saving the + evaluation results. + """ + super().__init__(*args, **kwargs) + self.evaluator = evaluator + self.save_predictions = save_predictions + self.metrics_to_eval = metrics_to_eval or self.evaluator.metrics + + if self.save_predictions: + assert ( + output_dir is not None + ), "If save_predictions is True, save_prefix must be provided." + + output_dir = os.path.join(output_dir, "eval") + + self.output_dir = output_dir + self.save_prefix = save_prefix + + def setup( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str + ) -> None: # pragma: no cover + """Setup callback.""" + if self.save_predictions: + self.output_dir = broadcast(self.output_dir) + + if self.save_prefix is not None: + self.output_dir = os.path.join( + self.output_dir, self.save_prefix + ) + + for metric in self.metrics_to_eval: + output_dir = os.path.join(self.output_dir, metric) + os.makedirs(output_dir, exist_ok=True) + self.evaluator.reset() + + def on_validation_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Hook to run at the end of a validation batch.""" + self.on_test_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + ) + + def on_validation_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Wait for on_validation_epoch_end PL hook to call 'evaluate'.""" + log_dict = self.run_eval() + + for k, v in log_dict.items(): + pl_module.log(f"val/{k}", v, sync_dist=True, rank_zero_only=True) + + def on_test_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: DictData, + batch: DictData, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Hook to run at the end of a testing batch.""" + self.evaluator.process_batch( + **self.get_test_callback_inputs(outputs, batch) + ) + for metric in self.metrics_to_eval: + # Save output predictions in current batch. + if self.save_predictions: + output_dir = os.path.join(self.output_dir, metric) + self.evaluator.save_batch(metric, output_dir) + + def on_test_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Hook to run at the end of a testing epoch.""" + log_dict = self.run_eval() + + for k, v in log_dict.items(): + pl_module.log(f"test/{k}", v, sync_dist=True, rank_zero_only=True) + + def run_eval(self) -> MetricLogs: + """Run evaluation for the given evaluator.""" + self.evaluator.gather(all_gather_object_cpu) + + synchronize() + self.process() + + log_dict: MetricLogs = {} + for metric in self.metrics_to_eval: + metric_dict = self.evaluate(metric) + metric_dict = broadcast(metric_dict) + assert isinstance(metric_dict, dict) + log_dict.update(metric_dict) + + self.evaluator.reset() + + return log_dict + + @rank_zero_only + def process(self) -> None: + """Process the evaluator.""" + self.evaluator.process() + + @rank_zero_only + def evaluate(self, metric: str) -> MetricLogs: + """Evaluate the performance after processing all input/output pairs. + + Returns: + MetricLogs: A dictionary containing the evaluation results. The + keys are formatted as {metric_name}/{key_name}, and the + values are the corresponding evaluated values. + """ + rank_zero_info( + f"Running evaluator {str(self.evaluator)} with {metric} metric... " + ) + log_dict = {} + + # Save output predictions. This is done here instead of + # on_test_batch_end because the evaluator may not have processed + # all batches yet. + if self.save_predictions: + output_dir = os.path.join(self.output_dir, metric) + self.evaluator.save(metric, output_dir) + + # Evaluate metric + metric_dict, metric_str = self.evaluator.evaluate(metric) + for k, v in metric_dict.items(): + log_k = metric + "/" + k + rank_zero_info("%s: %.4f", log_k, v) + log_dict[f"{metric}/{k}"] = v + + rank_zero_info("Showing results for metric: %s", metric) + rank_zero_info(metric_str) + + return log_dict diff --git a/vis4d/engine/callbacks/logging.py b/vis4d/engine/callbacks/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..86fcb6e5886b528ae34f7932eaed6b10e62e5e92 --- /dev/null +++ b/vis4d/engine/callbacks/logging.py @@ -0,0 +1,165 @@ +"""This module contains utilities for callbacks.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Any + +import lightning.pytorch as pl + +from vis4d.common.logging import rank_zero_info +from vis4d.common.progress import compose_log_str +from vis4d.common.time import Timer +from vis4d.common.typing import ArgsType, MetricLogs + +from .base import Callback + + +class LoggingCallback(Callback): + """Callback for logging.""" + + def __init__( + self, *args: ArgsType, refresh_rate: int = 50, **kwargs: ArgsType + ) -> None: + """Init callback.""" + super().__init__(*args, **kwargs) + self._refresh_rate = refresh_rate + self._metrics: dict[str, list[float]] = defaultdict(list) + self.train_timer = Timer() + self.test_timer = Timer() + self.last_step = 0 + + def on_train_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Hook to run at the start of a training epoch.""" + if self.epoch_based: + self.train_timer.reset() + self.last_step = 0 + self._metrics.clear() + elif trainer.global_step == 0: + self.train_timer.reset() + + def on_train_batch_start( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + batch: Any, + batch_idx: int, + ) -> None: + """Hook to run at the start of a training batch.""" + if self.train_timer.paused: + self.train_timer.resume() + + def on_train_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + ) -> None: + """Hook to run at the end of a training batch.""" + if "metrics" in outputs: + for k, v in outputs["metrics"].items(): + self._metrics[k].append(v) + + if self.epoch_based: + cur_iter = batch_idx + 1 + + # Resolve float("inf") to -1 + if isinstance(trainer.num_training_batches, float): + total_iters = -1 + else: + total_iters = trainer.num_training_batches + else: + cur_iter = trainer.global_step + 1 + total_iters = trainer.max_steps + + if cur_iter % self._refresh_rate == 0 and cur_iter != self.last_step: + prefix = ( + f"Epoch {pl_module.current_epoch + 1}" + if self.epoch_based + else "Iter" + ) + + log_dict: MetricLogs = { + k: sum(v) / len(v) if len(v) > 0 else float("NaN") + for k, v in self._metrics.items() + } + + rank_zero_info( + compose_log_str( + prefix, cur_iter, total_iters, self.train_timer, log_dict + ) + ) + + self._metrics.clear() + self.last_step = cur_iter + + for k, v in log_dict.items(): + pl_module.log(f"train/{k}", v, rank_zero_only=True) + + def on_validation_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Hook to run at the start of a validation epoch.""" + self.test_timer.reset() + self.train_timer.pause() + + def on_validation_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Wait for on_validation_batch_end PL hook to call 'process'.""" + cur_iter = batch_idx + 1 + + # Resolve float("inf") to -1 + if isinstance(trainer.num_val_batches[dataloader_idx], int): + total_iters = int(trainer.num_val_batches[dataloader_idx]) + else: + total_iters = -1 + + if cur_iter % self._refresh_rate == 0: + rank_zero_info( + compose_log_str( + "Validation", cur_iter, total_iters, self.test_timer + ) + ) + + def on_test_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Hook to run at the start of a testing epoch.""" + self.test_timer.reset() + self.train_timer.pause() + + def on_test_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Hook to run at the end of a testing batch.""" + cur_iter = batch_idx + 1 + + # Resolve float("inf") to -1 + if isinstance(trainer.num_test_batches[dataloader_idx], int): + total_iters = int(trainer.num_test_batches[dataloader_idx]) + else: + total_iters = -1 + + if cur_iter % self._refresh_rate == 0: + rank_zero_info( + compose_log_str( + "Testing", cur_iter, total_iters, self.test_timer + ) + ) diff --git a/vis4d/engine/callbacks/scheduler.py b/vis4d/engine/callbacks/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..c8cb0a541b22ffb9cf368aea7b5c03f1d1a86bd4 --- /dev/null +++ b/vis4d/engine/callbacks/scheduler.py @@ -0,0 +1,44 @@ +"""Callback to configure learning rate during training.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import lightning.pytorch as pl + +from vis4d.engine.optim.scheduler import LRSchedulerWrapper + +from .base import Callback + + +class LRSchedulerCallback(Callback): + """Callback to configure learning rate during training.""" + + def __init__(self) -> None: + """Initialize the callback.""" + super().__init__() + self.last_step = 0 + + def on_train_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + ) -> None: + """Hook on training batch end.""" + schedulers = pl_module.lr_schedulers() + + if not isinstance(schedulers, Iterable): + schedulers = [schedulers] # type: ignore + + if trainer.global_step != self.last_step: + for scheduler in schedulers: + if scheduler is None: + continue + assert isinstance(scheduler, LRSchedulerWrapper) + scheduler.step_on_batch(trainer.global_step) + + self.last_step = trainer.global_step diff --git a/vis4d/engine/callbacks/util.py b/vis4d/engine/callbacks/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e46038a8ef696a679777bb806d5c008e4a92ab80 --- /dev/null +++ b/vis4d/engine/callbacks/util.py @@ -0,0 +1,24 @@ +"""PyTorch Lightning callbacks utilities.""" + +from __future__ import annotations + +import lightning.pytorch as pl +from torch import nn + +from vis4d.engine.loss_module import LossModule +from vis4d.engine.training_module import TrainingModule + + +def get_model(model: pl.LightningModule) -> nn.Module: + """Get model from pl module.""" + if isinstance(model, TrainingModule): + return model.model + return model + + +def get_loss_module(loss_module: pl.LightningModule) -> LossModule: + """Get loss_module from pl module.""" + assert hasattr(loss_module, "loss_module") and isinstance( + loss_module.loss_module, LossModule + ), "Loss module is not set in the training module." + return loss_module.loss_module diff --git a/vis4d/engine/callbacks/visualizer.py b/vis4d/engine/callbacks/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..30179597fb3e653a66457409f34576b82ccc7b38 --- /dev/null +++ b/vis4d/engine/callbacks/visualizer.py @@ -0,0 +1,165 @@ +"""This module contains utilities for callbacks.""" + +from __future__ import annotations + +import os +from typing import Any + +import lightning.pytorch as pl + +from vis4d.common.distributed import broadcast, synchronize +from vis4d.common.typing import ArgsType +from vis4d.vis.base import Visualizer + +from .base import Callback + + +class VisualizerCallback(Callback): + """Callback for model visualization.""" + + def __init__( + self, + *args: ArgsType, + visualizer: Visualizer, + visualize_train: bool = False, + show: bool = False, + save_to_disk: bool = True, + save_prefix: str | None = None, + output_dir: str | None = None, + **kwargs: ArgsType, + ) -> None: + """Init callback. + + Args: + visualizer (Visualizer): Visualizer. + visualize_train (bool): If the training data should be visualized. + Defaults to False. + show (bool): If the visualizations should be shown. Defaults to + False. + save_to_disk (bool): If the visualizations should be saved to disk. + Defaults to True. + save_prefix (str): Output directory prefix for distinguish + different visualizations. + output_dir (str): Output directory for saving the visualizations. + """ + super().__init__(*args, **kwargs) + self.visualizer = visualizer + self.visualize_train = visualize_train + self.save_prefix = save_prefix + self.show = show + self.save_to_disk = save_to_disk + + if self.save_to_disk: + assert ( + output_dir is not None + ), "If save_to_disk is True, output_dir must be provided." + + output_dir = os.path.join(output_dir, "vis") + + self.output_dir = output_dir + self.save_prefix = save_prefix + + def setup( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str + ) -> None: # pragma: no cover + """Setup callback.""" + if self.save_to_disk: + self.output_dir = broadcast(self.output_dir) + + def on_train_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + ) -> None: + """Hook to run at the end of a training batch.""" + cur_iter = batch_idx + 1 + + if self.visualize_train: + self.visualizer.process( + cur_iter=cur_iter, + **self.get_train_callback_inputs(outputs, batch), + ) + + if self.show: + self.visualizer.show(cur_iter=cur_iter) + + if self.save_to_disk: + self.save(cur_iter=cur_iter, stage="train") + + self.visualizer.reset() + + def on_validation_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Hook to run at the end of a validation batch.""" + cur_iter = batch_idx + 1 + + self.visualizer.process( + cur_iter=cur_iter, + **self.get_test_callback_inputs(outputs, batch), + ) + + if self.show: + self.visualizer.show(cur_iter=cur_iter) + + if self.save_to_disk: + self.save(cur_iter=cur_iter, stage="val") + + self.visualizer.reset() + + def on_test_batch_end( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Hook to run at the end of a testing batch.""" + cur_iter = batch_idx + 1 + + self.visualizer.process( + cur_iter=cur_iter, + **self.get_test_callback_inputs(outputs, batch), + ) + + if self.show: + self.visualizer.show(cur_iter=cur_iter) + + if self.save_to_disk: + self.save(cur_iter=cur_iter, stage="test") + + self.visualizer.reset() + + def save(self, cur_iter: int, stage: str) -> None: + """Save the visualizer state.""" + output_folder = os.path.join(self.output_dir, stage) + + if self.save_prefix is not None: + output_folder = os.path.join(output_folder, self.save_prefix) + + os.makedirs(output_folder, exist_ok=True) + + self.visualizer.save_to_disk( + cur_iter=cur_iter, output_folder=output_folder + ) + + # TODO: Add support for logging images to WandB. + # if get_rank() == 0: + # if isinstance(trainer.logger, WandbLogger) and image is not None: + # trainer.logger.log_image( + # key=f"{self.visualizer}/{cur_iter}", + # images=[image], + # ) + + synchronize() diff --git a/vis4d/engine/callbacks/yolox_callbacks.py b/vis4d/engine/callbacks/yolox_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..868e825dc41c52bb930d76b1798d8d503cd9286d --- /dev/null +++ b/vis4d/engine/callbacks/yolox_callbacks.py @@ -0,0 +1,196 @@ +"""YOLOX-specific callbacks.""" + +from __future__ import annotations + +import random +from collections import OrderedDict +from typing import Any + +import lightning.pytorch as pl +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.modules.batchnorm import _NormBase +from torch.utils.data import DataLoader + +from vis4d.common.distributed import ( + all_reduce_dict, + broadcast, + get_rank, + get_world_size, + synchronize, +) +from vis4d.common.logging import rank_zero_info, rank_zero_warn +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.op.detect.yolox import YOLOXHeadLoss +from vis4d.op.loss.common import l1_loss + +from .base import Callback +from .util import get_loss_module, get_model + + +class YOLOXModeSwitchCallback(Callback): + """Callback for switching the mode of YOLOX training.""" + + def __init__( + self, *args: ArgsType, switch_epoch: int, **kwargs: ArgsType + ) -> None: + """Init callback. + + Args: + switch_epoch (int): Epoch to switch the mode. + """ + super().__init__(*args, **kwargs) + self.switch_epoch = switch_epoch + self.switched = False + + def on_train_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Hook to run at the end of a training epoch.""" + if pl_module.current_epoch < self.switch_epoch - 1 or self.switched: + # TODO: Make work with resume. + return + + loss_module = get_loss_module(pl_module) + + found_loss = False + for loss in loss_module.losses: + if isinstance(loss["loss"], YOLOXHeadLoss): + found_loss = True + yolox_loss = loss["loss"] + break + rank_zero_info( + "Switching YOLOX training mode starting next training epoch " + "(turning off strong augmentations, adding L1 loss, switching to " + "validation every epoch)." + ) + if found_loss: + yolox_loss.loss_l1 = l1_loss # set L1 loss function + else: + rank_zero_warn("YOLOXHeadLoss should be in LossModule.") + # Set data pipeline to default DataPipe to skip strong augs. + # Switch to checking validation every epoch. + dataloader = trainer.train_dataloader + assert dataloader is not None + new_dataloader = DataLoader( + DataPipe(dataloader.dataset.datasets), + batch_size=dataloader.batch_size, + num_workers=dataloader.num_workers, + collate_fn=dataloader.collate_fn, + sampler=dataloader.sampler, + persistent_workers=dataloader.persistent_workers, + pin_memory=dataloader.pin_memory, + ) + + pl_module.check_val_every_n_epoch = 1 # type: ignore + + # Override train_dataloader method in PL datamodule. + # Set reload_dataloaders_every_n_epochs to 1 to use the new + # dataloader. + def train_dataloader() -> DataLoader: # type: ignore + """Return dataloader for training.""" + return new_dataloader + + pl_module.datamodule.train_dataloader = train_dataloader # type: ignore # pylint: disable=line-too-long + pl_module.reload_dataloaders_every_n_epochs = self.switch_epoch # type: ignore # pylint: disable=line-too-long + + self.switched = True + + +def get_norm_states(module: nn.Module) -> DictStrAny: + """Get the state_dict of batch norms in the module. + + Args: + module (nn.Module): Module to get batch norm states from. + """ + async_norm_states = OrderedDict() + for name, child in module.named_modules(): + if isinstance(child, _NormBase): + for k, v in child.state_dict().items(): + async_norm_states[".".join([name, k])] = v + return async_norm_states + + +class YOLOXSyncNormCallback(Callback): + """Callback for syncing the norm states of YOLOX training.""" + + def on_test_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """Hook to run at the beginning of a testing epoch.""" + if get_world_size() > 1: + model = get_model(pl_module) + norm_states = get_norm_states(model) + + if len(norm_states) > 0: + rank_zero_info("Synced norm states across all processes.") + norm_states = all_reduce_dict(norm_states, reduce_op="mean") + model.load_state_dict(norm_states, strict=False) + + +class YOLOXSyncRandomResizeCallback(Callback): + """Callback for syncing random resize during YOLOX training.""" + + def __init__( + self, + *args: ArgsType, + size_list: list[tuple[int, int]], + interval: int, + **kwargs: ArgsType, + ) -> None: + """Init callback.""" + super().__init__(*args, **kwargs) + self.size_list = size_list + self.interval = interval + self.random_shape = size_list[-1] + + def _get_random_shape(self, device: torch.device) -> tuple[int, int]: + """Randomly generate shape from size_list and sync across ranks.""" + shape_tensor = torch.zeros(2, dtype=torch.int).to(device) + if get_rank() == 0: + random_shape = random.choice(self.size_list) + shape_tensor[0], shape_tensor[1] = random_shape[0], random_shape[1] + synchronize() + shape_tensor = broadcast(shape_tensor, 0) + return (int(shape_tensor[0].item()), int(shape_tensor[1].item())) + + def on_train_batch_start( # type: ignore + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + batch: Any, + batch_idx: int, + ) -> None: + """Hook to run at the start of a training batch.""" + if not isinstance(batch, list): + batch = [batch] + if (trainer.global_step + 1) % self.interval == 0: + self.random_shape = self._get_random_shape( + batch[0][K.images].device + ) + + for b in batch: + scale_y = self.random_shape[0] / b[K.images].shape[-2] + scale_x = self.random_shape[1] / b[K.images].shape[-1] + + if scale_y == 1 and scale_x == 1: + return + + # resize images + b[K.images] = F.interpolate( + b[K.images], + size=self.random_shape, + mode="bilinear", + align_corners=False, + ) + b[K.input_hw] = [ + self.random_shape for _ in range(b[K.images].size(0)) + ] + + # resize boxes + for boxes in b[K.boxes2d]: + boxes[..., ::2] = boxes[..., ::2] * scale_x + boxes[..., 1::2] = boxes[..., 1::2] * scale_y diff --git a/vis4d/engine/connectors/__init__.py b/vis4d/engine/connectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18454f0ffa2f7f1455e24dcb3ad7f321801661f4 --- /dev/null +++ b/vis4d/engine/connectors/__init__.py @@ -0,0 +1,31 @@ +"""Data connector for data connection.""" + +from .base import CallbackConnector, DataConnector, LossConnector +from .multi_sensor import ( + MultiSensorCallbackConnector, + MultiSensorDataConnector, + MultiSensorLossConnector, + get_multi_sensor_inputs, +) +from .util import ( + SourceKeyDescription, + data_key, + get_inputs_for_pred_and_data, + pred_key, + remap_pred_keys, +) + +__all__ = [ + "CallbackConnector", + "DataConnector", + "data_key", + "get_multi_sensor_inputs", + "get_inputs_for_pred_and_data", + "LossConnector", + "MultiSensorDataConnector", + "MultiSensorCallbackConnector", + "MultiSensorLossConnector", + "pred_key", + "remap_pred_keys", + "SourceKeyDescription", +] diff --git a/vis4d/engine/connectors/base.py b/vis4d/engine/connectors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..75e38d6e35e60767e41cca39c599ff98d97a775d --- /dev/null +++ b/vis4d/engine/connectors/base.py @@ -0,0 +1,108 @@ +"""Base data connector to define data structures for data connection.""" + +from __future__ import annotations + +from typing import NamedTuple + +from torch import Tensor + +from vis4d.common.typing import DictStrArrNested +from vis4d.data.typing import DictData, DictDataOrList + +from .util import SourceKeyDescription, get_inputs_for_pred_and_data + + +class DataConnector: + """Defines which data to pass to which component. + + It extracts the required data from a 'DictData' objects and passes it to + the next component with the provided new key. + """ + + def __init__(self, key_mapping: dict[str, str]): + """Initializes the data connector with static remapping of the keys. + + Args: + key_mapping (dict[str, str]): Defines which kwargs to pass onto the + module. + + Simple Example Configuration: + + >>> train = dict(images = "images", gt = "gt_images) + >>> train_data_connector = DataConnector(train) + >>> test = dict(images = "images") + >>> test_data_connector = DataConnector(test) + """ + self.key_mapping = key_mapping + + def __call__(self, data: DictDataOrList) -> DictData: + """Returns the kwargs that are passed to the module. + + Args: + data (DictDataorList): The data (e.g. from the dataloader) which + contains all data that was loaded. + + Returns: + DictData: kwargs that are passed onto the model. + """ + if isinstance(data, list): + return { + k: [d[v] for d in data] for k, v in self.key_mapping.items() + } + return {k: data[v] for k, v in self.key_mapping.items()} + + +class LossConnector: + """Defines which data to pass to loss module of the training pipeline. + + It extracts the required data from prediction and data and passes it to + the next component with the provided new key. + """ + + def __init__(self, key_mapping: dict[str, SourceKeyDescription]) -> None: + """Initializes the data connector with static remapping of the keys.""" + self.key_mapping = key_mapping + + def __call__( + self, prediction: DictData | NamedTuple, data: DictData + ) -> dict[str, Tensor | DictStrArrNested]: + """Returns the kwargs that are passed to the loss module. + + Args: + prediction (DictData | NamedTuple): The output from model. + data (DictData): The data dictionary from the dataloader which + contains all data that was loaded. + + Returns: + dict[str, Tensor | DictStrArrNested]: kwargs that are passed + onto the loss. + """ + return get_inputs_for_pred_and_data(self.key_mapping, prediction, data) + + +class CallbackConnector: + """Data connector for the callback. + + It extracts the required data from prediction and datas and passes it to + the next component with the provided new key. + """ + + def __init__(self, key_mapping: dict[str, SourceKeyDescription]) -> None: + """Initializes the data connector with static remapping of the keys.""" + self.key_mapping = key_mapping + + def __call__( + self, prediction: DictData | NamedTuple, data: DictData + ) -> dict[str, Tensor | DictStrArrNested]: + """Returns the kwargs that are passed to the callback. + + Args: + prediction (DictData | NamedTuple): The output from model. + data (DictData): The data dictionary from the dataloader which + contains all data that was loaded. + + Returns: + dict[str, Tensor | DictStrArrNested]: kwargs that are passed + onto the callback. + """ + return get_inputs_for_pred_and_data(self.key_mapping, prediction, data) diff --git a/vis4d/engine/connectors/multi_sensor.py b/vis4d/engine/connectors/multi_sensor.py new file mode 100644 index 0000000000000000000000000000000000000000..abfa46fdd8270421fef659e249998071a9cbe892 --- /dev/null +++ b/vis4d/engine/connectors/multi_sensor.py @@ -0,0 +1,147 @@ +"""Data connector for multi-sensor dataset.""" + +from __future__ import annotations + +from typing import NamedTuple + +from vis4d.data.typing import DictData, DictDataOrList + +from .base import CallbackConnector, DataConnector, LossConnector +from .util import SourceKeyDescription, get_field_from_prediction + + +class MultiSensorDataConnector(DataConnector): + """Data connector for multi-sensor data dict.""" + + def __init__(self, key_mapping: dict[str, str | SourceKeyDescription]): + """Initializes the data connector with static remapping of the keys. + + Args: + key_mapping (dict[str, | SourceKeyDescription]): Defines which + kwargs to pass onto the module. + + TODO: Add Simple Example Configuration: + """ + _key_mapping = {} + multi_sensor_key_mapping = {} + + for k, v in key_mapping.items(): + if isinstance(v, dict): + sensors = v.get("sensors") + if sensors is not None: + multi_sensor_key_mapping[k] = v + else: + _key_mapping[k] = v["key"] + else: + _key_mapping[k] = v + + super().__init__(_key_mapping) + self.multi_sensor_key_mapping = multi_sensor_key_mapping + + def __call__(self, data: DictDataOrList) -> DictData: + """Returns the train input for the model.""" + input_dict = super().__call__(data) + + for target_key, source_key in self.multi_sensor_key_mapping.items(): + key = source_key["key"] + sensors = source_key["sensors"] + + if isinstance(data, list): + input_dict[target_key] = [ + [d[sensor][key] for sensor in sensors] for d in data + ] + else: + input_dict[target_key] = [ + data[sensor][key] for sensor in sensors + ] + return input_dict + + +class MultiSensorLossConnector(LossConnector): + """Multi-sensor Data connector for loss module of the training pipeline.""" + + def __call__( + self, prediction: DictData | NamedTuple, data: DictData + ) -> DictData: + """Returns the kwargs that are passed to the loss module. + + Args: + prediction (DictData | NamedTuple): The output from model. + data (DictData): The data dictionary from the dataloader which + contains all data that was loaded. + + Returns: + DictData: kwargs that are passed onto the loss. + """ + return get_multi_sensor_inputs(self.key_mapping, prediction, data) + + +class MultiSensorCallbackConnector(CallbackConnector): + """Multi-sensor data connector for the callback.""" + + def __call__( + self, prediction: DictData | NamedTuple, data: DictData + ) -> DictData: + """Returns the kwargs that are passed to the callback. + + Args: + prediction (DictData | NamedTuple): The output from model. + data (DictData): The data dictionary from the dataloader which + contains all data that was loaded. + + Returns: + DictData: kwargs that are passed onto the callback. + """ + return get_multi_sensor_inputs(self.key_mapping, prediction, data) + + +def get_multi_sensor_inputs( + connection_dict: dict[str, SourceKeyDescription], + prediction: DictData | NamedTuple, + data: DictData, +) -> DictData: + """Extracts multi-sensor input data from the provided SourceKeyDescription. + + Args: + connection_dict (dict[str, SourceKeyDescription]): Input Key + description which is used to gather and remap data from the + two data dicts. + prediction (DictData): Dict containing the model prediction output. + data (DictData): Dict containing the dataloader output. + + Raises: + ValueError: If the datasource is invalid. + + Returns: + out (DictData): Dict containing new kwargs consisting of new key name + and data extracted from the data dicts. + """ + out: DictData = {} + for new_key_name, old_key_name in connection_dict.items(): + # Assign field from data + if old_key_name["source"] == "data": + sensors = old_key_name.get("sensors") + + if sensors is None: + if old_key_name["key"] not in data: + raise ValueError( + f"Key {old_key_name['key']} not found in data dict." + f" Available keys: {data.keys()}" + ) + out[new_key_name] = data[old_key_name["key"]] + else: + out[new_key_name] = [ + data[sensor][old_key_name["key"]] for sensor in sensors + ] + + # Assign field from prediction + elif old_key_name["source"] == "prediction": + out[new_key_name] = get_field_from_prediction( + prediction, old_key_name + ) + else: + raise ValueError( + f"Unknown data source {old_key_name['source']}." + f"Available: [prediction, data]" + ) + return out diff --git a/vis4d/engine/connectors/util.py b/vis4d/engine/connectors/util.py new file mode 100644 index 0000000000000000000000000000000000000000..1a101215c9fca108e314a9cc5cde0ea4203bbfc0 --- /dev/null +++ b/vis4d/engine/connectors/util.py @@ -0,0 +1,152 @@ +"""Utility functions for the connectors module.""" + +from __future__ import annotations + +from collections.abc import Sequence +from copy import deepcopy +from typing import NamedTuple, TypedDict + +from torch import Tensor +from typing_extensions import NotRequired + +from vis4d.common.dict import get_dict_nested +from vis4d.common.named_tuple import get_from_namedtuple, is_namedtuple +from vis4d.common.typing import DictStrArrNested +from vis4d.data.typing import DictData + + +class SourceKeyDescription(TypedDict): + """Defines a data entry by providing the key and source of the data. + + Attributes: + key (str): Key that is used to index data from the specified source + source (str): Which datasource to choose from. + Options are ['data', 'prediction'] where data referes to the + output of the dataloader and prediction refers to the model + output + sensors (Sequence[str]): Which sensors to use for the data. + """ + + key: str + source: str + sensors: NotRequired[Sequence[str]] + + +def remap_pred_keys( + info: dict[str, SourceKeyDescription], parent_key: str +) -> dict[str, SourceKeyDescription]: + """Remaps the key of a connection mapping to a new parent key. + + Args: + info (SourceKeyDescription): Description to remap. + parent_key (str): New parent_key to use. + + Returns: + SourceKeyDescription: Description with new key. + + """ + info = deepcopy(info) + + for value in info.values(): + if value["source"] == "prediction": + value["key"] = parent_key + "." + value["key"] + return info + + +def data_key( + key: str, sensors: Sequence[str] | None = None +) -> SourceKeyDescription: + """Returns a SourceKeyDescription with data as source. + + Args: + key (str): Key to use for the data entry. + sensors (Sequence[str] | None, optional): Which sensors to use for the + data. Defaults to None. + + Returns: + SourceKeyDescription: A SourceKeyDescription with data as source. + """ + if sensors is None: + return SourceKeyDescription(key=key, source="data") + + return SourceKeyDescription(key=key, source="data", sensors=sensors) + + +def pred_key(key: str) -> SourceKeyDescription: + """Returns a SourceKeyDescription with prediction as source. + + Args: + key (str): Key to use for the data entry. + + Returns: + SourceKeyDescription: A SourceKeyDescription with prediction as source. + """ + return SourceKeyDescription(key=key, source="prediction") + + +def get_field_from_prediction( + prediction: DictData | NamedTuple, + old_key_name: SourceKeyDescription, +) -> Tensor | DictStrArrNested: + """Extracts a field from the prediction dict. + + Args: + prediction (DictData): Dict containing the model prediction output. + old_key_name (SourceKeyDescription): Description of the data to + extract. + + Returns: + Tensor | DictStrArrNested: Data extracted from the prediction dict. + """ + if is_namedtuple(prediction): + return get_from_namedtuple( + prediction, old_key_name["key"] # type: ignore + ) + + old_key = old_key_name["key"] + return get_dict_nested(prediction, old_key.split(".")) # type: ignore + + +def get_inputs_for_pred_and_data( + connection_dict: dict[str, SourceKeyDescription], + prediction: DictData | NamedTuple, + data: DictData, +) -> dict[str, Tensor | DictStrArrNested]: + """Extracts input data from the provided SourceKeyDescription. + + Args: + connection_dict (dict[str, SourceKeyDescription]): Input Key + description which is used to gather and remap data from the + two data dicts. + prediction (DictData): Dict containing the model prediction output. + data (DictData): Dict containing the dataloader output. + + Raises: + ValueError: If the datasource is invalid. + + Returns: + out (dict[str, Tensor | DictStrArrNested]): Dict containing new kwargs + consisting of new key name and data extracted from the data dicts. + """ + out = {} + for new_key_name, old_key_name in connection_dict.items(): + # Assign field from data + if old_key_name["source"] == "data": + if old_key_name["key"] not in data: + raise ValueError( + f"Key {old_key_name['key']} not found in data dict." + f" Available keys: {data.keys()}" + ) + out[new_key_name] = data[old_key_name["key"]] + + # Assign field from model prediction + elif old_key_name["source"] == "prediction": + out[new_key_name] = get_field_from_prediction( + prediction, old_key_name + ) + else: + raise ValueError( + f"Unknown data source {old_key_name['source']}." + f" Available: [prediction, data]" + ) + return out diff --git a/vis4d/engine/data_module.py b/vis4d/engine/data_module.py new file mode 100644 index 0000000000000000000000000000000000000000..de307a6e1519a4de29549c1aa887ff5edfcc9478 --- /dev/null +++ b/vis4d/engine/data_module.py @@ -0,0 +1,39 @@ +"""Data module composing the data loading pipeline.""" + +from __future__ import annotations + +import lightning.pytorch as pl +from torch.utils.data import DataLoader + +from vis4d.config import instantiate_classes +from vis4d.config.typing import DataConfig +from vis4d.data.typing import DictData + + +class DataModule(pl.LightningDataModule): + """DataModule that wraps around the vis4d implementations. + + This is a wrapper around the vis4d implementations that allows to use + pytorch-lightning for training and testing. + """ + + def __init__(self, data_cfg: DataConfig) -> None: + """Creates an instance of the class.""" + super().__init__() + self.data_cfg = data_cfg + + def train_dataloader(self) -> DataLoader[DictData]: + """Return dataloader for training.""" + if self.trainer is not None and hasattr(self.trainer, "seed"): + seed = self.trainer.seed + else: + seed = None + return instantiate_classes(self.data_cfg.train_dataloader, seed=seed) + + def test_dataloader(self) -> list[DataLoader[DictData]]: + """Return dataloaders for testing.""" + return instantiate_classes(self.data_cfg.test_dataloader) + + def val_dataloader(self) -> list[DataLoader[DictData]]: + """Return dataloaders for validation.""" + return self.test_dataloader() diff --git a/vis4d/engine/flag.py b/vis4d/engine/flag.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0b9cfc3a3040ae0c29f987674e96e0edc49f6c --- /dev/null +++ b/vis4d/engine/flag.py @@ -0,0 +1,34 @@ +"""Engine Flags.""" + +from absl import flags + +from .parser import DEFINE_config_file + +_CONFIG = DEFINE_config_file("config", method_name="get_config") +_GPUS = flags.DEFINE_integer("gpus", default=0, help="Number of GPUs per node") +_NODES = flags.DEFINE_integer("nodes", default=1, help="Number of nodes") +_WANDB = flags.DEFINE_bool( + "wandb", default=False, help="If set, use Weights & Biases for logging." +) +_CKPT = flags.DEFINE_string("ckpt", default=None, help="Checkpoint path") +_RESUME = flags.DEFINE_bool("resume", default=False, help="Resume training") +_SHOW_CONFIG = flags.DEFINE_bool( + "print-config", default=False, help="If set, prints the configuration." +) +_VIS = flags.DEFINE_bool( + "vis", + default=False, + help="If set, running visualization using visualizer callback.", +) + + +__all__ = [ + "_CONFIG", + "_GPUS", + "_NODES", + "_CKPT", + "_RESUME", + "_SHOW_CONFIG", + "_WANDB", + "_VIS", +] diff --git a/vis4d/engine/loss_module.py b/vis4d/engine/loss_module.py new file mode 100644 index 0000000000000000000000000000000000000000..fc72531781898395ea559b5ff35be1fd3683cbe2 --- /dev/null +++ b/vis4d/engine/loss_module.py @@ -0,0 +1,215 @@ +"""Loss module maps loss function input keys and controls loss weight.""" + +from __future__ import annotations + +from typing import TypedDict, Union + +import torch +from torch import Tensor, nn +from typing_extensions import NotRequired + +from vis4d.common.named_tuple import is_namedtuple +from vis4d.common.typing import LossesType +from vis4d.data.typing import DictData +from vis4d.engine.connectors import LossConnector +from vis4d.op.loss.base import Loss + +NestedLossesType = Union[dict[str, "NestedLossesType"], LossesType] + + +class LossDefinition(TypedDict): + """Loss definition. + + Attributes: + loss (Loss | nn.Module): Loss function to use. + connector (LossConnector): Connector to use for the loss. + weight (float | dict[str, float], optional): Weight to use for the + loss. + name (str, optional): Name to use for the loss. + """ + + loss: Loss | nn.Module + connector: LossConnector + weight: NotRequired[float | dict[str, float]] + name: NotRequired[str] + + +def _get_tensors_nested( + loss_dict: NestedLossesType, prefix: str = "" +) -> list[tuple[str, Tensor]]: + """Get tensors from loss dict. + + Args: + loss_dict (LossesType): Loss dict. + prefix (str, optional): Prefix to add to keys. Defaults to "". + + Returns: + list[tuple[str, Tensor]]: List of tensors. + + Raises: + ValueError: If loss dict contains non-tensor or dict values. + """ + named_tensors: list[tuple[str, Tensor]] = [] + for key in loss_dict: + value = loss_dict[key] + + if isinstance(value, Tensor): + named_tensors.append((prefix + key, value)) + elif isinstance(value, dict): + named_tensors.extend( + _get_tensors_nested(value, prefix + key + ".") + ) + else: + raise ValueError( + f"Loss dict must only contain tensors or dicts. " + f"Found {type(loss_dict[key])} at {prefix + key}." + ) + return named_tensors + + +class LossModule(nn.Module): + """Loss module maps input keys and combines losses with weights. + + This loss combines multiple losses with weights. The loss values are + weighted by the corresponding weight and returned as a dictionary. + """ + + def __init__( + self, + losses: list[LossDefinition] | LossDefinition, + exclude_attributes: list[str] | None = None, + ) -> None: + """Creates an instance of the class. + + Each loss will be called with arguments matching the kwargs of the loss + function through its connector. By default, the weight is set to 1.0. + + Args: + losses (list[LossDefinition]): List of loss definitions. + exclude_attributes (list[str] | None): List of attributes returned + by the losses that should be excluded from the total loss + computation. Use it to log metrics that should not be + optimised. Defaults to None. + + Example: + >>> loss = LossModule( + >>> [ + >>> { + >>> "loss": nn.MSELoss(), + >>> "weight": 0.7, + >>> "connector": LossConnector( + >>> { + >>> "input": pred_key("input"), + >>> "target": data_key("target"), + >>> } + >>> ), + >>> }, + >>> { + >>> "loss": nn.L1Loss(), + >>> "weight": 0.3 + >>> "connector": LossConnector( + >>> { + >>> "input": pred_key("input"), + >>> "target": data_key("target"), + >>> } + >>> ), + >>> }, + >>> ] + >>> ) + """ + super().__init__() + self.losses: list[LossDefinition] = [] + + if not isinstance(losses, list): + losses = [losses] + + for loss in losses: + assert "loss" in loss, "Loss definition must contain a loss." + assert ( + "connector" in loss + ), "Loss definition must contain a connector." + + if "name" not in loss: + loss["name"] = loss["loss"].__class__.__name__ + + if "weight" not in loss: + loss["weight"] = 1.0 + + self.losses.append(loss) + + self.exclude_attributes = exclude_attributes + + def forward( + self, output: DictData, batch: DictData + ) -> tuple[Tensor, dict[str, float]]: + """Forward of loss module. + + This function will call all loss functions and return a dictionary + containing the loss values. The loss values are weighted by the + corresponding weight. + + If two losses have the same name, the name will be appended with + two underscores. + + Args: + output (DictData): Output of the model. + batch (DictData): Batch data. + + Returns: + total_loss: The total loss value. + metrics: The metrics disctionary. + """ + loss_dict: LossesType = {} + + for loss in self.losses: + loss_values_as_dict: LossesType = {} + name = loss["name"] + + loss_value = loss["loss"](**loss["connector"](output, batch)) + + # Convert loss value to one level dict. + if isinstance(loss_value, Tensor): + # Loss returned a simple tensor + loss_values_as_dict[name] = loss_value + elif isinstance(loss_value, dict): + # Loss returned a dictionary. + for loss_name, loss_value in _get_tensors_nested( + loss_value, name + "." + ): + loss_values_as_dict[loss_name] = loss_value + elif is_namedtuple(loss_value): + # Loss returned a named tuple. + for loss_name, loss_value in zip( + loss_value._fields, loss_value + ): + loss_values_as_dict[name + "." + loss_name] = loss_value + + # Assign values + for key, value in loss_values_as_dict.items(): + if value is None: + continue + + if isinstance(loss["weight"], dict): + loss_weight = loss["weight"].get(key, 1.0) + else: + loss_weight = loss["weight"] + + while key in loss_dict: + key = "__" + key + + loss_dict[key] = torch.mul(loss_weight, value) + + # Convert loss_dict to total loss and metrics dictionary + metrics: dict[str, float] = {} + keep_loss_dict: LossesType = {} + for k, v in loss_dict.items(): + metrics[k] = v.detach().cpu().item() + if ( + self.exclude_attributes is None + or k not in self.exclude_attributes + ): + keep_loss_dict[k] = v + total_loss: Tensor = sum(keep_loss_dict.values()) # type: ignore + metrics["loss"] = total_loss.detach().cpu().item() + + return total_loss, metrics diff --git a/vis4d/engine/optim/__init__.py b/vis4d/engine/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72f88d69eb6f467b93a3f7fbc1104d7c44d2880e --- /dev/null +++ b/vis4d/engine/optim/__init__.py @@ -0,0 +1,17 @@ +"""Optimizer modules.""" + +from .optimizer import set_up_optimizers +from .scheduler import ( + ConstantLR, + LRSchedulerWrapper, + PolyLR, + QuadraticLRWarmup, +) + +__all__ = [ + "set_up_optimizers", + "LRSchedulerWrapper", + "PolyLR", + "ConstantLR", + "QuadraticLRWarmup", +] diff --git a/vis4d/engine/optim/optimizer.py b/vis4d/engine/optim/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..33629ee5ba5882cc7fb3a541cbf5e950a767ca71 --- /dev/null +++ b/vis4d/engine/optim/optimizer.py @@ -0,0 +1,166 @@ +"""Optimizer.""" + +from __future__ import annotations + +from typing import TypedDict + +from torch import nn +from torch.nn import GroupNorm, LayerNorm +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.instancenorm import _InstanceNorm +from torch.optim.optimizer import Optimizer +from typing_extensions import NotRequired + +from vis4d.common.logging import rank_zero_info +from vis4d.config import instantiate_classes +from vis4d.config.typing import OptimizerConfig, ParamGroupCfg + +from .scheduler import LRSchedulerWrapper + + +class ParamGroup(TypedDict): + """Parameter dictionary. + + Attributes: + params (list[nn.Parameter]): List of parameters. + lr (NotRequired[float]): Learning rate. + weight_decay (NotRequired[float]): Weight decay. + """ + + params: list[nn.Parameter] + lr: NotRequired[float] + weight_decay: NotRequired[float] + + +# TODO: Add true support for multiple optimizers. This will need to +# modify config to specify which optimizer to use for which module. +def set_up_optimizers( + optimizers_cfg: list[OptimizerConfig], + models: list[nn.Module], + steps_per_epoch: int = -1, +) -> tuple[list[Optimizer], list[LRSchedulerWrapper]]: + """Set up optimizers.""" + optimizers = [] + lr_schedulers = [] + for optim_cfg, model in zip(optimizers_cfg, models): + optimizer = configure_optimizer(optim_cfg, model) + optimizers.append(optimizer) + + if optim_cfg.lr_schedulers is not None: + lr_schedulers.append( + LRSchedulerWrapper( + optim_cfg.lr_schedulers, optimizer, steps_per_epoch + ) + ) + + return optimizers, lr_schedulers + + +def configure_optimizer( + optim_cfg: OptimizerConfig, model: nn.Module +) -> Optimizer: + """Configure optimizer with parameter groups.""" + param_groups_cfg = optim_cfg.get("param_groups", None) + + if param_groups_cfg is None: + return instantiate_classes( + optim_cfg.optimizer, params=model.parameters() + ) + + params = [] + base_lr = optim_cfg.optimizer["init_args"].lr + weight_decay = optim_cfg.optimizer["init_args"].get("weight_decay", None) + for group in param_groups_cfg: + lr_mult = group.get("lr_mult", 1.0) + decay_mult = group.get("decay_mult", 1.0) + norm_decay_mult = group.get("norm_decay_mult", None) + bias_decay_mult = group.get("bias_decay_mult", None) + + param_group: ParamGroup = {"params": [], "lr": base_lr * lr_mult} + + if weight_decay is not None: + if norm_decay_mult is not None: + param_group["weight_decay"] = weight_decay * norm_decay_mult + elif bias_decay_mult is not None: + param_group["weight_decay"] = weight_decay * bias_decay_mult + else: + param_group["weight_decay"] = weight_decay * decay_mult + + params.append(param_group) + + # Create a param group for the rest of the parameters + param_group = {"params": [], "lr": base_lr} + if weight_decay is not None: + param_group["weight_decay"] = weight_decay + params.append(param_group) + + # Add the parameters to the param groups + add_params(params, model, param_groups_cfg) + + return instantiate_classes(optim_cfg.optimizer, params=params) + + +def add_params( + params: list[ParamGroup], + module: nn.Module, + param_groups_cfg: list[ParamGroupCfg], + prefix: str = "", +) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[DictStrAny]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + param_groups_cfg (dict[str, list[str] | float]): The configuration + of the param groups. + prefix (str): The prefix of the module. Default: ''. + """ + for name, param in module.named_parameters(recurse=False): + if not param.requires_grad: + params[-1]["params"].append(param) + continue + + is_norm = isinstance( + module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm) + ) + + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + msg = f"{prefix}.{name}" + for i, group in enumerate(param_groups_cfg): + for key in group["custom_keys"]: + if key not in f"{prefix}.{name}": + continue + norm_decay_mult = group.get("norm_decay_mult", None) + bias_decay_mult = group.get("bias_decay_mult", None) + if group.get("lr_mult", None) is not None: + msg += f" with lr_mult: {group['lr_mult']}" + if norm_decay_mult is not None: + if not is_norm: + continue + msg += f" with norm_decay_mult: {norm_decay_mult}" + if bias_decay_mult is not None: + if name != "bias": + continue + msg += f" with bias_decay_mult: {bias_decay_mult}" + if group.get("decay_mult", None) is not None: + msg += f" with decay_mult: {group['decay_mult']}" + params[i]["params"].append(param) + is_custom = True + break + if is_custom: + break + + if is_custom: + rank_zero_info(msg) + else: + # add parameter to the last param group + params[-1]["params"].append(param) + + for child_name, child_mod in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + add_params(params, child_mod, param_groups_cfg, prefix=child_prefix) diff --git a/vis4d/engine/optim/scheduler.py b/vis4d/engine/optim/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..70288f09b1a9259a7266c8e61ef397f13d167a67 --- /dev/null +++ b/vis4d/engine/optim/scheduler.py @@ -0,0 +1,262 @@ +# pylint: disable=no-member +"""LR schedulers.""" + +from __future__ import annotations + +from typing import TypedDict + +from torch.optim.lr_scheduler import LRScheduler +from torch.optim.optimizer import Optimizer + +from vis4d.common.typing import DictStrAny +from vis4d.config import copy_and_resolve_references, instantiate_classes +from vis4d.config.typing import LrSchedulerConfig + + +class LRSchedulerDict(TypedDict): + """LR scheduler.""" + + scheduler: LRScheduler + begin: int + end: int + epoch_based: bool + + +class LRSchedulerWrapper(LRScheduler): + """LR scheduler wrapper.""" + + def __init__( + self, + lr_schedulers_cfg: list[LrSchedulerConfig], + optimizer: Optimizer, + steps_per_epoch: int = -1, + ) -> None: + """Initialize LRSchedulerWrapper.""" + self.lr_schedulers_cfg: list[LrSchedulerConfig] = ( + copy_and_resolve_references(lr_schedulers_cfg) + ) + self.lr_schedulers: dict[int, LRSchedulerDict] = {} + super().__init__(optimizer) + self.steps_per_epoch = steps_per_epoch + self._convert_epochs_to_steps() + + for i, lr_scheduler_cfg in enumerate(self.lr_schedulers_cfg): + if lr_scheduler_cfg["begin"] == 0: + self._instantiate_lr_scheduler(i, lr_scheduler_cfg) + + def _convert_epochs_to_steps(self) -> None: + """Convert epochs to steps.""" + for lr_scheduler_cfg in self.lr_schedulers_cfg: + if ( + lr_scheduler_cfg["convert_epochs_to_steps"] + and not lr_scheduler_cfg["epoch_based"] + ): + lr_scheduler_cfg["begin"] *= self.steps_per_epoch + lr_scheduler_cfg["end"] *= self.steps_per_epoch + if lr_scheduler_cfg["convert_attributes"] is not None: + for attr in lr_scheduler_cfg["convert_attributes"]: + lr_scheduler_cfg["scheduler"]["init_args"][ + attr + ] *= self.steps_per_epoch + + def _instantiate_lr_scheduler( + self, scheduler_idx: int, lr_scheduler_cfg: LrSchedulerConfig + ) -> None: + """Instantiate LR schedulers.""" + # OneCycleLR needs max_lr to be set + if "max_lr" in lr_scheduler_cfg["scheduler"]["init_args"]: + lr_scheduler_cfg["scheduler"]["init_args"]["max_lr"] = [ + pg["lr"] for pg in self.optimizer.param_groups + ] + + self.lr_schedulers[scheduler_idx] = { + "scheduler": instantiate_classes( + lr_scheduler_cfg["scheduler"], optimizer=self.optimizer + ), + "begin": lr_scheduler_cfg["begin"], + "end": lr_scheduler_cfg["end"], + "epoch_based": lr_scheduler_cfg["epoch_based"], + } + + def get_lr(self) -> list[float]: + """Get current learning rate.""" + lr = [] + for lr_scheduler in self.lr_schedulers.values(): + lr.extend(lr_scheduler["scheduler"].get_lr()) + return lr + + def state_dict(self) -> dict[int, DictStrAny]: # type: ignore + """Get state dict.""" + state_dict = {} + for scheduler_idx, lr_scheduler in self.lr_schedulers.items(): + state_dict[scheduler_idx] = lr_scheduler["scheduler"].state_dict() + return state_dict + + def load_state_dict( + self, state_dict: dict[int, DictStrAny] # type: ignore + ) -> None: + """Load state dict.""" + for scheduler_idx, _state_dict in state_dict.items(): + # Instantiate the lr scheduler if it is not instantiated yet + if not scheduler_idx in self.lr_schedulers: + self._instantiate_lr_scheduler( + scheduler_idx, self.lr_schedulers_cfg[scheduler_idx] + ) + self.lr_schedulers[scheduler_idx]["scheduler"].load_state_dict( + _state_dict + ) + + def _step_lr(self, lr_scheduler: LRSchedulerDict, step: int) -> None: + """Step the learning rate.""" + if lr_scheduler["begin"] <= step and ( + lr_scheduler["end"] == -1 or lr_scheduler["end"] >= step + ): + lr_scheduler["scheduler"].step() + + def step(self, epoch: int | None = None) -> None: + """Step on training epoch end.""" + if epoch is not None: + for lr_scheduler in self.lr_schedulers.values(): + if lr_scheduler["epoch_based"]: + self._step_lr(lr_scheduler, epoch) + + for i, lr_scheduler_cfg in enumerate(self.lr_schedulers_cfg): + if lr_scheduler_cfg["epoch_based"] and ( + lr_scheduler_cfg["begin"] == epoch + 1 + ): + self._instantiate_lr_scheduler(i, lr_scheduler_cfg) + + def step_on_batch(self, step: int) -> None: + """Step on training batch end.""" + for lr_scheduler in self.lr_schedulers.values(): + if not lr_scheduler["epoch_based"]: + self._step_lr(lr_scheduler, step) + + for i, lr_scheduler_cfg in enumerate(self.lr_schedulers_cfg): + if not lr_scheduler_cfg["epoch_based"] and ( + lr_scheduler_cfg["begin"] == step + ): + self._instantiate_lr_scheduler(i, lr_scheduler_cfg) + + +class ConstantLR(LRScheduler): + """Constant learning rate scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_steps (int): Maximum number of steps. + factor (float): Scale factor. Default: 1.0 / 3.0. + last_epoch (int): The index of last epoch. Default: -1. + """ + + def __init__( + self, + optimizer: Optimizer, + max_steps: int, + factor: float = 1.0 / 3.0, + last_epoch: int = -1, + ): + """Initialize ConstantLR.""" + self.max_steps = max_steps + self.factor = factor + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> list[float]: + """Compute current learning rate.""" + step_count = self._step_count - 1 + if step_count == 0: + return [ + group["lr"] * self.factor + for group in self.optimizer.param_groups + ] + if step_count == self.max_steps: + return [ + group["lr"] * (1.0 / self.factor) + for group in self.optimizer.param_groups + ] + return [group["lr"] for group in self.optimizer.param_groups] + + +class PolyLR(LRScheduler): + """Polynomial learning rate decay. + + Example: + Assuming lr = 0.001, max_steps = 4, min_lr = 0.0, and power = 1.0, the + learning rate will be: + lr = 0.001 if step == 0 + lr = 0.00075 if step == 1 + lr = 0.00050 if step == 2 + lr = 0.00025 if step == 3 + lr = 0.0 if step >= 4 + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_steps (int): Maximum number of steps. + power (float, optional): Power factor. Default: 1.0. + min_lr (float): Minimum learning rate. Default: 0.0. + last_epoch (int): The index of last epoch. Default: -1. + """ + + def __init__( + self, + optimizer: Optimizer, + max_steps: int, + power: float = 1.0, + min_lr: float = 0.0, + last_epoch: int = -1, + ): + """Initialize PolyLRScheduler.""" + self.max_steps = max_steps + self.power = power + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> list[float]: + """Compute current learning rate.""" + step_count = self._step_count - 1 + if step_count == 0 or step_count > self.max_steps: + return [group["lr"] for group in self.optimizer.param_groups] + decay_factor = ( + (1.0 - step_count / self.max_steps) + / (1.0 - (step_count - 1) / self.max_steps) + ) ** self.power + return [ + (group["lr"] - self.min_lr) * decay_factor + self.min_lr + for group in self.optimizer.param_groups + ] + + +class QuadraticLRWarmup(LRScheduler): + """Quadratic learning rate warmup. + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_steps (int): Maximum number of steps. + last_epoch (int): The index of last epoch. Default: -1. + """ + + def __init__( + self, + optimizer: Optimizer, + max_steps: int, + last_epoch: int = -1, + ): + """Initialize QuadraticLRWarmup.""" + self.max_steps = max_steps + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> list[float]: + """Compute current learning rate.""" + step_count = self._step_count - 1 + if step_count >= self.max_steps: + return self.base_lrs + factors = [ + base_lr * (2 * step_count + 1) / self.max_steps**2 + for base_lr in self.base_lrs # pylint: disable=not-an-iterable + ] + if step_count == 0: + return factors + return [ + group["lr"] + factor + for factor, group in zip(factors, self.optimizer.param_groups) + ] diff --git a/vis4d/engine/parser.py b/vis4d/engine/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb728b57513ec09c0caf5a56ecb8e70123b16f2 --- /dev/null +++ b/vis4d/engine/parser.py @@ -0,0 +1,218 @@ +"""Parser for config files that can be used with absl flags.""" + +from __future__ import annotations + +import logging +import re +import sys +import traceback +from typing import Any + +from absl import flags +from ml_collections import ConfigDict, FieldReference +from ml_collections.config_flags.config_flags import ( + _ConfigFlag, + _ErrorConfig, + _LockConfig, +) + +from vis4d.config import copy_and_resolve_references +from vis4d.config.registry import get_config_by_name + + +class ConfigFileParser(flags.ArgumentParser): # type: ignore + """Parser for config files.""" + + def __init__( + self, + name: str, + lock_config: bool = True, + method_name: str = "get_config", + ) -> None: + """Initializes the parser. + + Args: + name (str): The name of the flag (e.g. config for --config flag) + lock_config (bool, optional): Whether or not to lock the config. + Defaults to True. + method_name (str, optional): Name of the method to call in the + config. Defaults to "get_config". + """ + self.name = name + self._lock_config = lock_config + self.method_name = method_name + + def parse( # pylint: disable=arguments-renamed + self, path: str + ) -> ConfigDict | _ErrorConfig: + """Loads a config module from `path` and returns the `method_name()`. + + This implementation is based on the original ml_collections and + modified to allow for a custom method name. + + If a colon is present in `path`, everything to the right of the first + colon is passed to `method_name` as an argument. This allows the + structure of what + is returned to be modified, which is useful when performing complex + hyperparameter sweeps. + + Args: + path: string, path pointing to the config file to execute. May also + contain a config_string argument, e.g. be of the form + "config.py:some_configuration". + + Returns: + Result of calling `method_name` in the specified module. + """ + # This will be a 2 element list iff extra configuration args are + # present. + split_path = path.split(":", 1) + + try: + config = get_config_by_name( + split_path[0], + *split_path[1:], + method_name=self.method_name, + ) + if config is None: + logging.warning( + "%s:%s() returned None, did you forget a return " + "statement?", + path, + self.method_name, + ) + except IOError as e: + # Don't raise the error unless/until the config is + # actually accessed. + return _ErrorConfig(e) + # Third party flags library catches TypeError and ValueError + # and rethrows, + # removing useful information unless it is added here (b/63877430): + except (TypeError, ValueError) as e: + error_trace = traceback.format_exc() + raise type(e)( + "Error whilst parsing config file:\n\n" + error_trace + ) + + if self._lock_config: + _LockConfig(config) + + return config + + def flag_type(self) -> str: + """Returns the type of the flag.""" + return "config object" + + +def DEFINE_config_file( # pylint: disable=invalid-name + name: str, + default: str | None = None, + help_string: str = "path to config file [.py |.yaml].", + lock_config: bool = False, + method_name: str = "get_config", +) -> flags.FlagHolder: # type: ignore + """Registers a new flag for a config file. + + Args: + name (str): The name of the flag (e.g. config for --config flag) + default (str | None, optional): Default Value. Defaults to None. + help_string (str, optional): Help String. + Defaults to "path to config file.". + lock_config (bool, optional): Whether or note to lock the returned + config. Defaults to False. + method_name (str, optional): Name of the method to call in the config. + + Returns: + flags.FlagHolder: Flag holder instance. + """ + parser = ConfigFileParser( + name=name, lock_config=lock_config, method_name=method_name + ) + flag = _ConfigFlag( + parser=parser, + serializer=flags.ArgumentSerializer(), + name=name, + default=default, + help_string=help_string, + flag_values=flags.FLAGS, + ) + + # Get the module name for the frame at depth 1 in the call stack. + module_name = sys._getframe( # pylint: disable=protected-access + 1 + ).f_globals.get("__name__", None) + module_name = sys.argv[0] if module_name == "__main__" else module_name + return flags.DEFINE_flag(flag, flags.FLAGS, module_name=module_name) + + +def pprints_config(data: ConfigDict) -> str: + """Converts a Config Dict into a string with a .yaml like structure. + + This function differs from __repr__ of ConfigDict in that it will not + encode python classes using binary formats but just prints the __repr__ + of these classes. + + Args: + data (ConfigDict): Configuration dict to convert to string + + Returns: + str: A string representation of the ConfigDict + """ + return _pprints_config(copy_and_resolve_references(data)) + + +def _pprints_config( # type: ignore + data: Any, prefix: str = "", n_indents: int = 1 +) -> str: + """Converts a ConfigDict into a string with a YAML like structure. + + This is the recursive implementation of 'pprints_config' and will be called + recursively for every element in the dict. + + This function differs from __repr__ of ConfigDict in that it will not + encode python classes using binary formats but just prints the __repr__ + of these classes. + + Args: + data (Any): Configuration dict or object to convert to + string + prefix (str): Prefix to print on each new line + n_indents (int): Number of spaces to append for each nester property. + + Returns: + str: A string representation of the ConfigDict + """ + string_repr = "" + if isinstance(data, FieldReference): + data = data.get() + + if not isinstance(data, (dict, ConfigDict, list, tuple, dict)): + return str(data) + + string_repr += "\n" + + if isinstance(data, (ConfigDict, dict)): + for key in data: + value = data[key] + string_repr += ( + prefix + + key + + ": " + + _pprints_config(value, prefix=prefix + " " * n_indents) + ) + "\n" + + elif isinstance(data, (list, tuple)): + for value in data: + string_repr += prefix + "- " + if isinstance(value, (ConfigDict, dict)): + string_repr += "\n" + + string_repr += ( + _pprints_config(value, prefix=prefix + " " + " " * n_indents) + + "\n" + ) + string_repr += " \n" # Add newline after list for better readability. + + # Clean up some formatting issues using regex. Could be done better + string_repr = re.sub("\n\n+", "\n", string_repr) + return re.sub("- +\n +", "- ", string_repr) diff --git a/vis4d/engine/run.py b/vis4d/engine/run.py new file mode 100644 index 0000000000000000000000000000000000000000..07c723b2d20d63589ce548c365bfb2c0ec97cf10 --- /dev/null +++ b/vis4d/engine/run.py @@ -0,0 +1,180 @@ +"""CLI interface using PyTorch Lightning.""" + +from __future__ import annotations + +import logging +import os.path as osp + +import torch +from absl import app # pylint: disable=no-name-in-module +from torch.utils.collect_env import get_pretty_env_info + +from vis4d.common.logging import dump_config, rank_zero_info, setup_logger +from vis4d.common.typing import ArgsType +from vis4d.common.util import set_tf32 +from vis4d.config import instantiate_classes +from vis4d.config.typing import ExperimentConfig +from vis4d.engine.callbacks import ( + Callback, + LRSchedulerCallback, + VisualizerCallback, +) +from vis4d.engine.data_module import DataModule +from vis4d.engine.flag import ( + _CKPT, + _CONFIG, + _GPUS, + _NODES, + _RESUME, + _SHOW_CONFIG, + _VIS, + _WANDB, +) +from vis4d.engine.parser import pprints_config +from vis4d.engine.trainer import PLTrainer +from vis4d.engine.training_module import TrainingModule + + +def main(argv: ArgsType) -> None: + """Main entry point for the CLI. + + Example to run this script: + >>> python -m vis4d.pl.run fit --config configs/faster_rcnn/faster_rcnn_coco.py + """ + # Get config + mode = argv[1] + assert mode in {"fit", "test"}, f"Invalid mode: {mode}" + config: ExperimentConfig = _CONFIG.value + num_gpus = _GPUS.value + num_nodes = _NODES.value + + # Setup logging + logger_vis4d = logging.getLogger("vis4d") + logger_pl = logging.getLogger("pytorch_lightning") + log_file = osp.join(config.output_dir, f"log_{config.timestamp}.txt") + setup_logger(logger_vis4d, log_file) + setup_logger(logger_pl, log_file) + + # Dump config + config_file = osp.join( + config.output_dir, f"config_{config.timestamp}.yaml" + ) + dump_config(config, config_file) + + rank_zero_info("Environment info: %s", get_pretty_env_info()) + + # PyTorch Setting + set_tf32(config.use_tf32, config.tf32_matmul_precision) + torch.hub.set_dir(f"{config.work_dir}/.cache/torch/hub") + + # Setup device + if num_gpus > 0: + config.pl_trainer.accelerator = "gpu" + config.pl_trainer.devices = num_gpus + else: + config.pl_trainer.accelerator = "cpu" + config.pl_trainer.devices = 1 + + if num_nodes > 1: + config.pl_trainer.num_nodes = num_nodes + + # Wandb + config.pl_trainer.wandb = _WANDB.value + + trainer_args = instantiate_classes(config.pl_trainer).to_dict() + + if _SHOW_CONFIG.value: + rank_zero_info(pprints_config(config)) + + # Instantiate classes + if mode == "fit": + train_data_connector = instantiate_classes(config.train_data_connector) + loss = instantiate_classes(config.loss) + else: + train_data_connector = None + loss = None + + if config.test_data_connector is not None: + test_data_connector = instantiate_classes(config.test_data_connector) + else: + test_data_connector = None + + # Callbacks + vis = _VIS.value + + callbacks: list[Callback] = [] + for cb in config.callbacks: + callback = instantiate_classes(cb) + + assert isinstance(callback, Callback), ( + "Callback must be a subclass of Callback. " + f"Provided callback: {cb} is not!" + ) + + if not vis and isinstance(callback, VisualizerCallback): + rank_zero_info( + f"{callback.visualizer} is not used. " + "Please set --vis=True to use it." + ) + continue + + callbacks.append(callback) + + # Add needed callbacks + callbacks.append(LRSchedulerCallback()) + + # Checkpoint path + ckpt_path = _CKPT.value + + # Resume training + resume = _RESUME.value + if resume: + if ckpt_path is None: + resume_ckpt_path = osp.join( + config.output_dir, "checkpoints/last.ckpt" + ) + else: + resume_ckpt_path = ckpt_path + # Check if checkpoint exists, if not start fresh + if not osp.exists(resume_ckpt_path): + print(f"[vis4d] Checkpoint not found: {resume_ckpt_path}, starting fresh training") + resume_ckpt_path = None + else: + resume_ckpt_path = None + + trainer = PLTrainer(callbacks=callbacks, **trainer_args) + + hyper_params = trainer_args + + if config.get("params", None) is not None: + hyper_params.update(config.params.to_dict()) + + training_module = TrainingModule( + config.model, + config.optimizers, + loss, + train_data_connector, + test_data_connector, + hyper_params, + config.seed, + ckpt_path if not resume else None, + config.compute_flops, + config.check_unused_parameters, + ) + data_module = DataModule(config.data) + + if mode == "fit": + trainer.fit( + training_module, datamodule=data_module, ckpt_path=resume_ckpt_path + ) + elif mode == "test": + trainer.test(training_module, datamodule=data_module, verbose=False) + + +def entrypoint() -> None: + """Entry point for the CLI.""" + app.run(main) + + +if __name__ == "__main__": + entrypoint() diff --git a/vis4d/engine/trainer.py b/vis4d/engine/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..616e593e5823da846e026fb1304449f576550236 --- /dev/null +++ b/vis4d/engine/trainer.py @@ -0,0 +1,141 @@ +"""Trainer for PyTorch Lightning.""" + +from __future__ import annotations + +import datetime +import os.path as osp + +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import Logger, TensorBoardLogger +from lightning.pytorch.loggers.wandb import WandbLogger +from lightning.pytorch.strategies.ddp import DDPStrategy + +from vis4d.common.imports import TENSORBOARD_AVAILABLE +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import ArgsType + + +class PLTrainer(Trainer): + """Trainer for PyTorch Lightning.""" + + def __init__( + self, + *args: ArgsType, + work_dir: str, + exp_name: str, + version: str, + epoch_based: bool = True, + find_unused_parameters: bool = False, + save_top_k: int = 1, + checkpoint_period: int = 1, + checkpoint_callback: ModelCheckpoint | None = None, + wandb: bool = False, + seed: int = -1, + timeout: int = 3600, + wandb_id: str | None = None, + **kwargs: ArgsType, + ) -> None: + """Perform some basic common setups at the beginning of a job. + + Args: + work_dir: Specific directory to save checkpoints, logs, etc. + Integrates with exp_name and version to get output_dir. + exp_name: Name of current experiment. + version: Version of current experiment. + epoch_based: Use epoch-based / iteration-based training. Default is + True. + find_unused_parameters: Activates PyTorch checking for unused + parameters in DDP setting. Default: False, for better + performance. + save_top_k: Save top k checkpoints. Default: 1 (save last). + checkpoint_period: After N epochs / stpes, save out checkpoints. + Default: 1. + checkpoint_callback: Custom PL checkpoint callback. Default: None. + wandb: Use weights and biases logging instead of tensorboard. + Default: False. + seed (int, optional): The integer value seed for global random + state. Defaults to -1. If -1, a random seed will be generated. + This will be set by TrainingModule. + timeout: Timeout (seconds) for DDP connection. Default: 3600. + wandb_id: If using wandb, the id of the run. If None, a new run + will be created. Default: None. + """ + self.work_dir = work_dir + self.exp_name = exp_name + self.version = version + self.seed = seed + + self.output_dir = osp.join(work_dir, exp_name, version) + + # setup experiment logging + if "logger" not in kwargs or ( + isinstance(kwargs["logger"], bool) and kwargs["logger"] + ): + exp_logger: Logger | None = None + if wandb: # pragma: no cover + exp_logger = WandbLogger( + save_dir=work_dir, + project=exp_name, + name=version, + id=wandb_id, + ) + elif TENSORBOARD_AVAILABLE: + exp_logger = TensorBoardLogger( + save_dir=work_dir, + name=exp_name, + version=version, + default_hp_metric=False, + ) + else: + rank_zero_info( + "Neither `tensorboard` nor `tensorboardX` is " + "available. Running without experiment logger. To log " + "your experiments, try `pip install`ing either." + ) + kwargs["logger"] = exp_logger + + callbacks: list[Callback] = [] + + # add learning rate / GPU stats monitor (logs to tensorboard) + if TENSORBOARD_AVAILABLE or wandb: + callbacks += [LearningRateMonitor(logging_interval="step")] + + # Model checkpointer + if checkpoint_callback is None: + if epoch_based: + checkpoint_cb = ModelCheckpoint( + dirpath=osp.join(self.output_dir, "checkpoints"), + verbose=True, + save_last=True, + save_top_k=save_top_k, + every_n_epochs=checkpoint_period, + save_on_train_epoch_end=True, + ) + else: + checkpoint_cb = ModelCheckpoint( + dirpath=osp.join(self.output_dir, "checkpoints"), + verbose=True, + save_last=True, + save_top_k=save_top_k, + every_n_train_steps=checkpoint_period, + ) + else: + checkpoint_cb = checkpoint_callback + callbacks += [checkpoint_cb] + + kwargs["callbacks"] += callbacks + + # add distributed strategy + if kwargs["devices"] == 0: + kwargs["accelerator"] = "cpu" + kwargs["devices"] = "auto" + elif kwargs["devices"] > 1: # pragma: no cover + if kwargs["accelerator"] == "gpu": + ddp_plugin = DDPStrategy( + find_unused_parameters=find_unused_parameters, + timeout=datetime.timedelta(timeout), + ) + kwargs["strategy"] = ddp_plugin + + super().__init__(*args, **kwargs) diff --git a/vis4d/engine/training_module.py b/vis4d/engine/training_module.py new file mode 100644 index 0000000000000000000000000000000000000000..99a3cf955c1abe1b807f0ec354ba8b8e51fffe84 --- /dev/null +++ b/vis4d/engine/training_module.py @@ -0,0 +1,210 @@ +"""LightningModule that wraps around the models, losses and optims.""" + +from __future__ import annotations + +from typing import Any + +import lightning.pytorch as pl +from lightning.pytorch import seed_everything +from lightning.pytorch.core.optimizer import LightningOptimizer +from ml_collections import ConfigDict +from torch import nn +from torch.optim.optimizer import Optimizer + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.common.distributed import broadcast +from vis4d.common.imports import FVCORE_AVAILABLE +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import DictStrAny, GenericFunc +from vis4d.common.util import init_random_seed +from vis4d.config import instantiate_classes +from vis4d.config.typing import OptimizerConfig +from vis4d.data.typing import DictData +from vis4d.engine.connectors import DataConnector +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import LRSchedulerWrapper, set_up_optimizers +from vis4d.model.adapter.flops import IGNORED_OPS, FlopsModelAdapter + +if FVCORE_AVAILABLE: + from fvcore.nn import FlopCountAnalysis + + +class TrainingModule(pl.LightningModule): + """LightningModule that wraps around the vis4d implementations. + + This is a wrapper around the vis4d implementations that allows to use + pytorch-lightning for training and testing. + """ + + def __init__( + self, + model_cfg: ConfigDict, + optimizers_cfg: list[OptimizerConfig], + loss_module: None | LossModule, + train_data_connector: None | DataConnector, + test_data_connector: None | DataConnector, + hyper_parameters: DictStrAny | None = None, + seed: int = -1, + ckpt_path: None | str = None, + compute_flops: bool = False, + check_unused_parameters: bool = False, + ) -> None: + """Initialize the TrainingModule. + + Args: + model_cfg: The model config. + optimizers_cfg: The optimizers config. + loss_module: The loss module. + train_data_connector: The data connector to use. + test_data_connector: The data connector to use. + data_connector: The data connector to use. + hyper_parameters (DictStrAny | None, optional): The hyper + parameters to use. Defaults to None. + seed (int, optional): The integer value seed for global random + state. Defaults to -1. If -1, a random seed will be generated. + ckpt_path (str, optional): The path to the checkpoint to load. + Defaults to None. + compute_flops (bool, optional): If to compute the FLOPs of the + model. Defaults to False. + check_unused_parameters (bool, optional): If to check the + unused parameters. Defaults to False. + """ + super().__init__() + self.model_cfg = model_cfg + self.optimizers_cfg = optimizers_cfg + self.loss_module = loss_module + self.train_data_connector = train_data_connector + self.test_data_connector = test_data_connector + self.hyper_parameters = hyper_parameters + self.seed = seed + self.ckpt_path = ckpt_path + self.compute_flops = compute_flops + self.check_unused_parameters = check_unused_parameters + + # Create model placeholder + self.model: nn.Module + + def setup(self, stage: str) -> None: + """Setup the model.""" + if stage == "fit": + if self.seed == -1: + self.seed = init_random_seed() + self.seed = broadcast(self.seed) + self.trainer.seed = self.seed # type: ignore + + seed_everything(self.seed, workers=True) + rank_zero_info(f"Global seed set to {self.seed}") + + if self.hyper_parameters is not None: + self.hyper_parameters["seed"] = self.seed + if "checkpoint_callback" in self.hyper_parameters: + self.hyper_parameters.pop("checkpoint_callback") + self.save_hyperparameters(self.hyper_parameters) + + # Instantiate the model after the seed has been set + self.model = instantiate_classes(self.model_cfg) + + if self.ckpt_path is not None: + load_model_checkpoint( + self.model, + self.ckpt_path, + rev_keys=[(r"^model\.", ""), (r"^module\.", "")], + ) + + def forward( # type: ignore # pylint: disable=arguments-differ + self, data: DictData + ) -> Any: + """Forward pass through the model.""" + if self.training: + assert self.train_data_connector is not None + return self.model(**self.train_data_connector(data)) + assert self.test_data_connector is not None + return self.model(**self.test_data_connector(data)) + + def training_step( # type: ignore # pylint: disable=arguments-differ,line-too-long,unused-argument + self, batch: DictData, batch_idx: int + ) -> Any: + """Perform a single training step.""" + assert self.train_data_connector is not None + out = self.model(**self.train_data_connector(batch)) + + assert self.loss_module is not None + total_loss, metrics = self.loss_module(out, batch) + + return { + "loss": total_loss, + "metrics": metrics, + "predictions": out, + } + + def validation_step( # pylint: disable=arguments-differ,line-too-long,unused-argument + self, batch: DictData, batch_idx: int, dataloader_idx: int = 0 + ) -> DictData: + """Perform a single validation step.""" + assert self.test_data_connector is not None + out = self.model(**self.test_data_connector(batch)) + return out + + def test_step( # pylint: disable=arguments-differ,line-too-long,unused-argument + self, batch: DictData, batch_idx: int, dataloader_idx: int = 0 + ) -> DictData: + """Perform a single test step.""" + assert self.test_data_connector is not None + + if self.compute_flops: + flatten_inputs = [ + self.test_data_connector(batch)[key] + for key in self.test_data_connector(batch) + ] + + flops_model = FlopsModelAdapter( + self.model, self.test_data_connector + ) + + if not FVCORE_AVAILABLE: + raise RuntimeError( + "Please install fvcore to compute FLOPs of the model." + ) + + flop_analyzer = FlopCountAnalysis( # pylint: disable=possibly-used-before-assignment, line-too-long + flops_model, flatten_inputs + ) + + flop_analyzer.set_op_handle(**{k: None for k in IGNORED_OPS}) + + flops = flop_analyzer.total() / 1e9 + + rank_zero_info(f"Flops: {flops:.2f} Gflops") + + out = self.model(**self.test_data_connector(batch)) + return out + + def configure_optimizers(self) -> Any: # type: ignore + """Return the optimizer to use.""" + self.trainer.fit_loop.setup_data() + steps_per_epoch = len(self.trainer.train_dataloader) # type: ignore + return set_up_optimizers( + self.optimizers_cfg, [self.model], steps_per_epoch + ) + + def lr_scheduler_step( # type: ignore # pylint: disable=arguments-differ,line-too-long,unused-argument + self, scheduler: LRSchedulerWrapper, metric: Any | None = None + ) -> None: + """Perform a step on the lr scheduler.""" + # TODO: Support metric if needed + scheduler.step(self.current_epoch) + + def optimizer_step( + self, + epoch: int, + batch_idx: int, + optimizer: Optimizer | LightningOptimizer, + optimizer_closure: GenericFunc | None = None, + ) -> None: + """Optimizer step.""" + if self.check_unused_parameters: + for name, param in self.model.named_parameters(): + if param.grad is None: + rank_zero_info(name) + + optimizer.step(closure=optimizer_closure) diff --git a/vis4d/eval/__init__.py b/vis4d/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93c79074b0775cf8cfeafa1938fc436faf38b747 --- /dev/null +++ b/vis4d/eval/__init__.py @@ -0,0 +1,5 @@ +"""Evaluation protocols and metrics for different tasks.""" + +from .base import Evaluator + +__all__ = ["Evaluator"] diff --git a/vis4d/eval/base.py b/vis4d/eval/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3b49e6f0a465d03a72b9f8e3064eda4a830cc3e0 --- /dev/null +++ b/vis4d/eval/base.py @@ -0,0 +1,121 @@ +"""Vis4D base evaluation.""" + +from __future__ import annotations + +from vis4d.common.typing import GenericFunc, MetricLogs, unimplemented + + +class Evaluator: # pragma: no cover + """Abstract evaluator class. + + The evaluator is responsible for evaluating the model on a given dataset. + At each end of batches, the process_batch() is called with the model + outputs and the batch data to accumulate the data for evaluation. An + optional save_batch() can be implemented to save the predictions in the + current batch. + + After all batches are processed, the gather() method is called to gather + the data from all ranks. Then, the process() method is used to process all + the accumulated data that are metrics-independent. Finally, the evaluate() + method is called to evaluate the model for the specified metrics and return + the results. Optionally, the save() method can be implemented to save the + predictions for the specified metrics. + + The following diagram illustrates the evaluation process:: + + RANK 0 RANK 1 ... + + x num_batches + ┌────────────────────────────────────────────────────────────────┐ + │ ┌──────────────────────────┐ ┌──────────────────────────┐ │ + │ │ process_batch(data, ...) │ │ process_batch(data, ...) │ │ <- Process a batch (predictions, labels, etc.) + │ └──────────────────────────┘ └──────────────────────────┘ │ and accumulate the data for evaluation. + │ ▼ ▼ │ + │ ┌────────────────────┐ ┌────────────────────┐ │ + │ │ save_batch(metric) │ │ save_batch(metric) │ │ <- Dump the predictions in a batch for a specified + │ └────────────────────┘ └────────────────────┘ │ metric (e.g., for online evaluation). + └────────────────┬──────────────────────────────┬────────────────┘ + ┌─────┴────┐ │ + │ gather() ├─────────────────────────┘ + └──────────┘ <- Gather the data from all ranks + ▼ + ┌───────────┐ + │ process() │ <- Process the data that are + └───────────┘ metrics-independent (if any) + ▼ + ┌──────────────────┐ + │ evaluate(metric) │ <- Evaluate for a specified metric and + └──────────────────┘ return the results. + ▼ + ┌──────────────┐ + │ save(metric) │ <- Dump the predictions for a specified + └──────────────┘ metric (e.g., for online evaluation). + + Note: + The save_batch() saves the predictions every batch, which is helpful + for reducing the memory usage, compared to saving all predictions at + once in the save() method. However, the save_batch() is optional and + can be omitted if the data can be saved only after all batches are + processed. + """ # pylint: disable=line-too-long + + @property + def metrics(self) -> list[str]: + """Return list of metrics to evaluate. + + Returns: + list[str]: Metrics to evaluate. + """ + return [] + + def gather(self, gather_func: GenericFunc) -> None: + """Gather variables in case of distributed setting (if needed). + + Args: + gather_func (Callable[[Any], Any]): Gather function. + """ + + def reset(self) -> None: + """Reset evaluator for new round of evaluation. + + Raises: + NotImplementedError: This is an abstract class method. + """ + raise NotImplementedError + + # Process a batch of data. + process_batch: GenericFunc = unimplemented + + def process(self) -> None: + """Process all accumulated data at the end of an epoch, if any.""" + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate all predictions according to given metric. + + Args: + metric (str): Metric to evaluate. + + Raises: + NotImplementedError: This is an abstract class method. + + Returns: + tuple[MetricLogs, str]: Dictionary of scores to log and a pretty + printed string. + """ + raise NotImplementedError + + def save_batch(self, metric: str, output_dir: str) -> None: + """Save batch of predictions to file. + + Args: + metric (str): Save predictions for the specified metrics. + output_dir (str): Output directory. + """ + + def save(self, metric: str, output_dir: str) -> None: + """Save all predictions to file at the end of an epoch. + + Args: + metric (str): Save predictions for the specified metrics. + output_dir (str): Output directory. + """ diff --git a/vis4d/eval/bdd100k/__init__.py b/vis4d/eval/bdd100k/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4eb30ec2752affdfb193ce2417abfdd9caa470d --- /dev/null +++ b/vis4d/eval/bdd100k/__init__.py @@ -0,0 +1,11 @@ +"""BDD100K evaluators.""" + +from .detect import BDD100KDetectEvaluator +from .seg import BDD100KSegEvaluator +from .track import BDD100KTrackEvaluator + +__all__ = [ + "BDD100KDetectEvaluator", + "BDD100KSegEvaluator", + "BDD100KTrackEvaluator", +] diff --git a/vis4d/eval/bdd100k/detect.py b/vis4d/eval/bdd100k/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..a252c9188977ff9d99027008cb565cc890efc988 --- /dev/null +++ b/vis4d/eval/bdd100k/detect.py @@ -0,0 +1,36 @@ +"""BDD100K detection evaluator.""" + +from __future__ import annotations + +from vis4d.common.imports import BDD100K_AVAILABLE +from vis4d.eval.scalabel import ScalabelDetectEvaluator + +if BDD100K_AVAILABLE: + from bdd100k.common.utils import load_bdd100k_config +else: + raise ImportError("bdd100k is not installed.") + + +class BDD100KDetectEvaluator(ScalabelDetectEvaluator): + """BDD100K 2D detection evaluation class.""" + + METRICS_DET = "Det" + METRICS_INS_SEG = "InsSeg" + + def __init__( + self, + annotation_path: str, + config_path: str, + mask_threshold: float = 0.0, + ) -> None: + """Initialize the evaluator.""" + config = load_bdd100k_config(config_path) + super().__init__( + annotation_path=annotation_path, + config=config.scalabel, + mask_threshold=mask_threshold, + ) + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "BDD100K Detection Evaluator" diff --git a/vis4d/eval/bdd100k/seg.py b/vis4d/eval/bdd100k/seg.py new file mode 100644 index 0000000000000000000000000000000000000000..a874d642bdda8bb0f6df416ac5ba8d9f8eaf49d6 --- /dev/null +++ b/vis4d/eval/bdd100k/seg.py @@ -0,0 +1,101 @@ +"""BDD100K segmentation evaluator.""" + +from __future__ import annotations + +import itertools +from collections.abc import Callable +from typing import Any + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE +from vis4d.common.typing import ArrayLike, MetricLogs +from vis4d.data.datasets.bdd100k import bdd100k_seg_map + +from ..base import Evaluator + +if SCALABEL_AVAILABLE and BDD100K_AVAILABLE: + from bdd100k.common.utils import load_bdd100k_config + from bdd100k.label.to_scalabel import bdd100k_to_scalabel + from scalabel.eval.sem_seg import evaluate_sem_seg + from scalabel.label.io import load + from scalabel.label.transforms import mask_to_rle + from scalabel.label.typing import Frame, Label +else: + raise ImportError("scalabel or bdd100k is not installed.") + + +class BDD100KSegEvaluator(Evaluator): + """BDD100K segmentation evaluation class.""" + + inverse_seg_map = {v: k for k, v in bdd100k_seg_map.items()} + + def __init__(self, annotation_path: str) -> None: + """Initialize the evaluator.""" + super().__init__() + self.annotation_path = annotation_path + self.frames: list[Frame] = [] + + bdd100k_anns = load(annotation_path) + frames = bdd100k_anns.frames + self.config = load_bdd100k_config("sem_seg") + self.gt_frames = bdd100k_to_scalabel(frames, self.config) + + self.reset() + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "BDD100K Segmentation Evaluator" + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return ["sem_seg"] + + def gather( # type: ignore # pragma: no cover + self, gather_func: Callable[[Any], Any] + ) -> None: + """Gather variables in case of distributed setting (if needed). + + Args: + gather_func (Callable[[Any], Any]): Gather function. + """ + all_preds = gather_func(self.frames) + if all_preds is not None: + self.frames = list(itertools.chain(*all_preds)) + + def reset(self) -> None: + """Reset the evaluator.""" + self.frames = [] + + def process_batch( + self, data_names: list[str], masks_list: list[ArrayLike] + ) -> None: + """Process tracking results.""" + masks_numpy = [array_to_numpy(m, None) for m in masks_list] # to numpy + for data_name, masks in zip(data_names, masks_numpy): + labels = [] + for i, class_id in enumerate(np.unique(masks)): + label = Label( + rle=mask_to_rle((masks == class_id).astype(np.uint8)), + category=self.inverse_seg_map[int(class_id)], + id=str(i), + ) + labels.append(label) + frame = Frame(name=data_name, labels=labels) + self.frames.append(frame) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate the dataset.""" + if metric == "sem_seg": + results = evaluate_sem_seg( + ann_frames=self.gt_frames, + pred_frames=self.frames, + config=self.config.scalabel, + nproc=0, + ) + else: + raise NotImplementedError + + return {}, str(results) diff --git a/vis4d/eval/bdd100k/track.py b/vis4d/eval/bdd100k/track.py new file mode 100644 index 0000000000000000000000000000000000000000..a71b08f1facc2e199575c9939e99938516483a06 --- /dev/null +++ b/vis4d/eval/bdd100k/track.py @@ -0,0 +1,81 @@ +"""BDD100K tracking evaluator.""" + +from __future__ import annotations + +from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE +from vis4d.common.typing import MetricLogs +from vis4d.data.datasets.bdd100k import bdd100k_track_map + +from ..scalabel.track import ScalabelTrackEvaluator + +if SCALABEL_AVAILABLE and BDD100K_AVAILABLE: + from bdd100k.common.utils import load_bdd100k_config + from bdd100k.label.to_scalabel import bdd100k_to_scalabel + from scalabel.eval.detect import evaluate_det + from scalabel.eval.mot import acc_single_video_mot, evaluate_track + from scalabel.label.io import group_and_sort +else: + raise ImportError("scalabel or bdd100k is not installed.") + + +class BDD100KTrackEvaluator(ScalabelTrackEvaluator): + """BDD100K 2D tracking evaluation class.""" + + METRICS_DET = "Det" + METRICS_TRACK = "Track" + + def __init__( + self, + annotation_path: str, + config_path: str = "box_track", + mask_threshold: float = 0.0, + ) -> None: + """Initialize the evaluator.""" + config = load_bdd100k_config(config_path) + super().__init__( + annotation_path=annotation_path, + config=config.scalabel, + mask_threshold=mask_threshold, + ) + self.gt_frames = bdd100k_to_scalabel(self.gt_frames, config) + self.inverse_cat_map = {v: k for k, v in bdd100k_track_map.items()} + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "BDD100K Tracking Evaluator" + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return [self.METRICS_DET, self.METRICS_TRACK] + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate the dataset.""" + assert self.config is not None, "BDD100K config is not loaded." + metrics_log: MetricLogs = {} + short_description = "" + + if metric == self.METRICS_DET: + det_results = evaluate_det( + self.gt_frames, + self.frames, + config=self.config, + nproc=0, + ) + for metric_name, metric_value in det_results.summary().items(): + metrics_log[metric_name] = metric_value + short_description += str(det_results) + "\n" + + if metric == self.METRICS_TRACK: + track_results = evaluate_track( + acc_single_video_mot, + gts=group_and_sort(self.gt_frames), + results=group_and_sort(self.frames), + config=self.config, + nproc=1, + ) + for metric_name, metric_value in track_results.summary().items(): + metrics_log[metric_name] = metric_value + short_description += str(track_results) + "\n" + + return metrics_log, short_description diff --git a/vis4d/eval/coco/__init__.py b/vis4d/eval/coco/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a05a9be3e7c60702df544c15b7935b4f0c133d5e --- /dev/null +++ b/vis4d/eval/coco/__init__.py @@ -0,0 +1,5 @@ +"""Detection evaluators.""" + +from .detect import COCODetectEvaluator + +__all__ = ["COCODetectEvaluator"] diff --git a/vis4d/eval/coco/detect.py b/vis4d/eval/coco/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..dce40302e9ed1b99d1249f55fb7393e59010460c --- /dev/null +++ b/vis4d/eval/coco/detect.py @@ -0,0 +1,289 @@ +"""COCO evaluator.""" + +from __future__ import annotations + +import contextlib +import copy +import io +import itertools + +import numpy as np +import pycocotools.mask as maskUtils +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from terminaltables import AsciiTable + +from vis4d.common.array import array_to_numpy +from vis4d.common.logging import rank_zero_warn +from vis4d.common.typing import ( + ArrayLike, + DictStrAny, + GenericFunc, + MetricLogs, + NDArrayF32, + NDArrayI64, +) +from vis4d.data.datasets.coco import coco_det_map + +from ..base import Evaluator + + +def xyxy_to_xywh(boxes: NDArrayF32) -> NDArrayF32: + """Convert Tensor [N, 4] in xyxy format into xywh. + + Args: + boxes (NDArrayF32): Bounding boxes in Vis4D format. + + Returns: + NDArrayF32: COCO format bounding boxes. + """ + boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + return boxes + + +class COCOevalV2(COCOeval): # type: ignore + """Subclass COCO eval for logging / printing.""" + + def summarize(self) -> str: + """Capture summary in string. + + Returns: + str: Pretty printed string. + """ + f = io.StringIO() + with contextlib.redirect_stdout(f): + super().summarize() + summary_str = "\n" + f.getvalue() + return summary_str + + +def predictions_to_coco( + cat_map: dict[str, int], + coco_id2name: dict[int, str], + image_id: int, + boxes: NDArrayF32, + scores: NDArrayF32, + classes: NDArrayI64, + masks: None | NDArrayF32 = None, +) -> list[DictStrAny]: + """Convert Vis4D format predictions to COCO format. + + Args: + cat_map (dict[str, int]): COCO class name to class ID mapping. + coco_id2name (dict[int, str]): COCO class ID to class name mapping. + image_id (int): ID of image. + boxes (NDArrayF32): Predicted bounding boxes. + scores (NDArrayF32): Predicted scores for each box. + classes (NDArrayI64): Predicted classes for each box. + masks (None | NDArrayF32, optional): Predicted masks. Defaults to + None. + + Returns: + list[DictStrAny]: Predictions in COCO format. + """ + predictions = [] + boxes_xyxy = copy.deepcopy(boxes) + boxes_xywh = xyxy_to_xywh(boxes_xyxy) + for i, (box, score, cls) in enumerate(zip(boxes_xywh, scores, classes)): + mask = masks[i] if masks is not None else None + xywh = box.tolist() + area = float(xywh[2] * xywh[3]) + annotation = { + "image_id": image_id, + "bbox": xywh, + "area": area, + "score": float(score), + "category_id": cat_map[coco_id2name[int(cls)]], + "iscrowd": 0, + } + if mask is not None: + annotation["segmentation"] = maskUtils.encode( + np.array(mask, order="F", dtype="uint8") + ) + annotation["segmentation"]["counts"] = annotation["segmentation"][ + "counts" + ].decode() + predictions.append(annotation) + return predictions + + +class COCODetectEvaluator(Evaluator): + """COCO detection evaluation class.""" + + METRIC_DET = "Det" + METRIC_INS_SEG = "InsSeg" + + def __init__( + self, + data_root: str, + split: str = "val2017", + per_class_eval: bool = False, + ) -> None: + """Creates an instance of the class. + + Args: + data_root (str): Root directory of data. + split (str, optional): COCO data split. Defaults to "val2017". + per_class_eval (bool, optional): Per-class evaluation. Defaults to + False. + """ + super().__init__() + self.per_class_eval = per_class_eval + self.coco_id2name = {v: k for k, v in coco_det_map.items()} + self.annotation_path = ( + f"{data_root}/annotations/instances_{split}.json" + ) + with contextlib.redirect_stdout(io.StringIO()): + self._coco_gt = COCO(self.annotation_path) + coco_gt_cats = self._coco_gt.loadCats(self._coco_gt.getCatIds()) + self.cat_map = {c["name"]: c["id"] for c in coco_gt_cats} + self._predictions: list[DictStrAny] = [] + + @property + def metrics(self) -> list[str]: + """Supported metrics. + + Returns: + list[str]: Metrics to evaluate. + """ + return [self.METRIC_DET, self.METRIC_INS_SEG] + + def gather(self, gather_func: GenericFunc) -> None: + """Accumulate predictions across processes.""" + all_preds = gather_func(self._predictions) + if all_preds is not None: + self._predictions = list(itertools.chain(*all_preds)) + + def reset(self) -> None: + """Reset the saved predictions to start new round of evaluation.""" + self._predictions = [] + + def process_batch( + self, + coco_image_id: list[int], + pred_boxes: list[ArrayLike], + pred_scores: list[ArrayLike], + pred_classes: list[ArrayLike], + pred_masks: None | list[ArrayLike] = None, + ) -> None: + """Process sample and convert detections to coco format. + + coco_image_id (list[int]): COCO image ID. + pred_boxes (list[ArrayLike]): Predicted bounding boxes. + pred_scores (list[ArrayLike]): Predicted scores for each box. + pred_classes (list[ArrayLike]): Predicted classes for each box. + pred_masks (None | list[ArrayLike], optional): Predicted masks. + """ + for i, (image_id, boxes, scores, classes) in enumerate( + zip(coco_image_id, pred_boxes, pred_scores, pred_classes) + ): + boxes_np = array_to_numpy(boxes, n_dims=None, dtype=np.float32) + scores_np = array_to_numpy(scores, n_dims=None, dtype=np.float32) + classes_np = array_to_numpy(classes, n_dims=None, dtype=np.int64) + + if pred_masks is not None: + masks_np = array_to_numpy( + pred_masks[i], n_dims=3, dtype=np.float32 + ) + else: + masks_np = None + + coco_preds = predictions_to_coco( + self.cat_map, + self.coco_id2name, + image_id, + boxes_np, + scores_np, + classes_np, + masks_np, + ) + self._predictions.extend(coco_preds) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate COCO predictions. + + Args: + metric (str): Metric to evaluate. Should be "COCO_AP". + + Raises: + NotImplementedError: Raised if metric is not "COCO_AP". + RuntimeError: Raised if no predictions are available. + + Returns: + tuple[MetricLogs, str]: Dictionary of scores to log and a pretty + printed string. + """ + if metric not in [self.METRIC_DET, self.METRIC_INS_SEG]: + raise NotImplementedError(f"Metric {metric} not known!") + + if len(self._predictions) == 0: + rank_zero_warn( + "No predictions to evaluate. Make sure to process batch first!" + ) + return { + "AP": 0.0, + "AP50": 0.0, + "AP75": 0.0, + "APs": 0.0, + "APm": 0.0, + "APl": 0.0, + }, "No predictions to evaluate." + + if metric == self.METRIC_DET: + iou_type = "bbox" + _predictions = self._predictions + else: + # remove bbox for segm evaluation so cocoapi will use mask + # area instead of box area + iou_type = "segm" + _predictions = copy.deepcopy(self._predictions) + for pred in _predictions: + pred.pop("bbox") + coco_dt = self._coco_gt.loadRes(_predictions) + + with contextlib.redirect_stdout(io.StringIO()): + assert coco_dt is not None + evaluator = COCOevalV2(self._coco_gt, coco_dt, iouType=iou_type) + evaluator.evaluate() + evaluator.accumulate() + + log_str = evaluator.summarize() + metrics = ["AP", "AP50", "AP75", "APs", "APm", "APl"] + score_dict = dict(zip(metrics, evaluator.stats)) + + if self.per_class_eval: + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/ + precisions = evaluator.eval["precision"] + # precision: (iou, recall, cls, area range, max dets) + assert len(self._coco_gt.getCatIds()) == precisions.shape[2] + + results_per_category = [] + for idx, cat_id in enumerate(self._coco_gt.getCatIds()): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = self._coco_gt.loadCats(cat_id)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision).item() + else: + ap = float("nan") + results_per_category.append((f'{nm["name"]}', f"{ap:0.3f}")) + + num_columns = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + headers = ["category", "AP"] * (num_columns // 2) + results_2d = itertools.zip_longest( + *[results_flatten[i::num_columns] for i in range(num_columns)] + ) + table_data = [headers] + list(results_2d) + table = AsciiTable(table_data) + log_str = f"\n{table.table}\n{log_str}" + + return score_dict, log_str + + def __repr__(self) -> str: + """Returns the string representation of the object.""" + return f"CocoEvaluator(annotation_path={self.annotation_path})" diff --git a/vis4d/eval/common/__init__.py b/vis4d/eval/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5434fe3032bca59d8c6f6333a2b7f42845d9e28 --- /dev/null +++ b/vis4d/eval/common/__init__.py @@ -0,0 +1,15 @@ +"""Common evaluation code.""" + +from .binary import BinaryEvaluator +from .cls import ClassificationEvaluator +from .depth import DepthEvaluator +from .flow import OpticalFlowEvaluator +from .seg import SegEvaluator + +__all__ = [ + "ClassificationEvaluator", + "DepthEvaluator", + "OpticalFlowEvaluator", + "BinaryEvaluator", + "SegEvaluator", +] diff --git a/vis4d/eval/common/binary.py b/vis4d/eval/common/binary.py new file mode 100644 index 0000000000000000000000000000000000000000..30f661d398097218b9757dae567ccb101bd7c6f7 --- /dev/null +++ b/vis4d/eval/common/binary.py @@ -0,0 +1,191 @@ +"""Binary occupancy evaluator.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArrayLike, + MetricLogs, + NDArrayBool, + NDArrayNumber, +) +from vis4d.eval.base import Evaluator + + +def threshold_and_flatten( + prediction: NDArrayNumber, target: NDArrayNumber, threshold_value: float +) -> tuple[NDArrayBool, NDArrayBool]: + """Thresholds the predictions based on the provided treshold value. + + Applies the following actions: + prediction -> prediction >= threshold_value + pred, gt = pred.ravel().bool(), gt.ravel().bool() + + Args: + prediction: Prediction array with continuous values + target: Grondgtruth values {0,1} + threshold_value: Value to use to convert the continuous prediction + into binary. + + Returns: + tuple of two boolean arrays, prediction and target + """ + prediction_bin: NDArrayBool = prediction >= threshold_value + return prediction_bin.ravel().astype(bool), target.ravel().astype(bool) + + +class BinaryEvaluator(Evaluator): + """Creates a new Evaluater that evaluates binary predictions.""" + + METRIC_BINARY = "BinaryCls" + + KEY_IOU = "IoU" + KEY_ACCURACY = "Accuracy" + KEY_F1 = "F1" + KEY_PRECISION = "Precision" + KEY_RECALL = "Recall" + + def __init__( + self, + threshold: float = 0.5, + ) -> None: + """Creates a new binary evaluator. + + Args: + threshold (float): Threshold for prediction to convert + to binary. All prediction that are higher than + this value will be assigned the 'True' label + """ + super().__init__() + self.threshold = threshold + self.reset() + + self.true_positives: list[float] = [] + self.false_positives: list[float] = [] + self.true_negatives: list[float] = [] + self.false_negatives: list[float] = [] + self.n_samples: list[float] = [] + + self.has_samples = False + + def _calc_confusion_matrix( + self, prediction: NDArrayBool, target: NDArrayBool + ) -> None: + """Calculates the confusion matrix and stores them as attributes. + + Args: + prediction: the prediction (binary) (N, Pts) + target: the groundtruth (binary) (N, Pts) + """ + tp = int(np.sum(np.logical_and(prediction == 1, target == 1))) + fp = int(np.sum(np.logical_and(prediction == 1, target == 0))) + tn = int(np.sum(np.logical_and(prediction == 0, target == 0))) + fn = int(np.sum(np.logical_and(prediction == 0, target == 1))) + self.true_positives.append(tp) + self.false_positives.append(fp) + self.true_negatives.append(tn) + self.false_negatives.append(fn) + self.n_samples.append(tp + fp + tn + fn) + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return [self.METRIC_BINARY] + + def reset(self) -> None: + """Reset the saved predictions to start new round of evaluation.""" + self.true_positives = [] + self.false_positives = [] + self.true_negatives = [] + self.false_negatives = [] + self.n_samples = [] + + def process_batch( + self, + prediction: ArrayLike, + groundtruth: ArrayLike, + ) -> None: + """Processes a new (batch) of predictions. + + Calculates the metrics and caches them internally. + + Args: + prediction: the prediction(continuous values or bin) (Batch x Pts) + groundtruth: the groundtruth (binary) (Batch x Pts) + """ + pred, gt = threshold_and_flatten( + array_to_numpy(prediction, n_dims=None, dtype=np.float32), + array_to_numpy(groundtruth, n_dims=None, dtype=np.bool_), + self.threshold, + ) + + # Confusion Matrix + self._calc_confusion_matrix(pred, gt) + self.has_samples = True + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate predictions. + + Returns a dict containing the raw data and a + short description string containing a readable result. + + Args: + metric (str): Metric to use. See @property metric + + Returns: + metric_data, description + tuple containing the metric data (dict with metric name and value) + as well as a short string with shortened information. + + Raises: + RuntimeError: if no data has been registered to be evaluated. + ValueError: if metric is not supported. + """ + if not self.has_samples: + raise RuntimeError( + """No data registered to calculate metric. + Register data using .process() first!""" + ) + metric_data: MetricLogs = {} + short_description = "" + + if metric == self.METRIC_BINARY: + # IoU + iou = sum(self.true_positives) / ( + sum(self.n_samples) - sum(self.true_negatives) + 1e-6 + ) + metric_data[self.KEY_IOU] = iou + short_description += f"IoU: {iou:.3f}\n" + + # Accuracy + acc = (sum(self.true_positives) + sum(self.true_negatives)) / sum( + self.n_samples + ) + metric_data[self.KEY_ACCURACY] = acc + short_description += f"Accuracy: {acc:.3f}\n" + + # Precision + tp_fp = sum(self.true_positives) + sum(self.false_positives) + precision = sum(self.true_positives) / tp_fp if tp_fp != 0 else 1 + metric_data[self.KEY_PRECISION] = precision + short_description += f"Precision: {precision:.3f}\n" + + # Recall + tp_fn = sum(self.true_positives) + sum(self.false_negatives) + recall = sum(self.true_positives) / tp_fn if tp_fn != 0 else 1 + metric_data[self.KEY_RECALL] = recall + short_description += f"Recall: {acc:.3f}\n" + + # F1 + f1 = 2 * precision * recall / (precision + recall + 1e-8) + metric_data[self.KEY_F1] = f1 + short_description += f"F1: {f1:.3f}\n" + + else: + raise ValueError( + f"Unsupported metric: {metric}" + ) # pragma: no cover + + return metric_data, short_description diff --git a/vis4d/eval/common/cls.py b/vis4d/eval/common/cls.py new file mode 100644 index 0000000000000000000000000000000000000000..7db6482c5d2254190f7fdd7618e508040bebf42a --- /dev/null +++ b/vis4d/eval/common/cls.py @@ -0,0 +1,137 @@ +"""Image classification evaluator.""" + +from __future__ import annotations + +import itertools + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArrayLike, + GenericFunc, + MetricLogs, + NDArrayI64, + NDArrayNumber, +) +from vis4d.eval.base import Evaluator + +from ..metrics.cls import accuracy + + +class ClassificationEvaluator(Evaluator): + """Multi-class classification evaluator.""" + + METRIC_CLASSIFICATION = "Cls" + + KEY_ACCURACY = "Acc@1" + KEY_ACCURACY_TOP5 = "Acc@5" + + def __init__(self) -> None: + """Initialize the classification evaluator.""" + super().__init__() + self._metrics_list: list[dict[str, float]] = [] + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return [ + self.KEY_ACCURACY, + self.KEY_ACCURACY_TOP5, + ] + + def reset(self) -> None: + """Reset evaluator for new round of evaluation.""" + self._metrics_list = [] + + def _is_correct( + self, pred: NDArrayNumber, target: NDArrayI64, top_k: int = 1 + ) -> bool: + """Check if the prediction is correct for top-k. + + Args: + pred (NDArrayNumber): Prediction logits, in shape (C, ). + target (NDArrayI64): Target logits, in shape (1, ). + top_k (int, optional): Top-k to check. Defaults to 1. + + Returns: + bool: Whether the prediction is correct. + """ + top_k = min(top_k, pred.shape[0]) + top_k_idx = np.argsort(pred)[-top_k:] + return bool(np.any(top_k_idx == target)) + + def process_batch( # type: ignore # pylint: disable=arguments-differ + self, prediction: ArrayLike, groundtruth: ArrayLike + ): + """Process a batch of predictions and groundtruths. + + Args: + prediction (ArrayLike): Prediction, in shape (N, C). + groundtruth (ArrayLike): Groundtruth, in shape (N, ). + """ + pred = array_to_numpy(prediction, n_dims=None, dtype=np.float32) + gt = array_to_numpy(groundtruth, n_dims=None, dtype=np.int64) + for i in range(pred.shape[0]): + self._metrics_list.append( + { + "top1_correct": accuracy(pred[i], gt[i], top_k=1), + "top5_correct": accuracy(pred[i], gt[i], top_k=5), + } + ) + + def gather(self, gather_func: GenericFunc) -> None: + """Accumulate predictions across processes.""" + all_metrics = gather_func(self._metrics_list) + if all_metrics is not None: + self._metrics_list = list(itertools.chain(*all_metrics)) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate predictions. + + Returns a dict containing the raw data and a + short description string containing a readable result. + + Args: + metric (str): Metric to use. See @property metric + + Returns: + metric_data, description + tuple containing the metric data (dict with metric name and value) + as well as a short string with shortened information. + + Raises: + RuntimeError: if no data has been registered to be evaluated. + ValueError: if the metric is not supported. + """ + if len(self._metrics_list) == 0: + raise RuntimeError( + """No data registered to calculate metric. + Register data using .process() first!""" + ) + metric_data: MetricLogs = {} + short_description = "" + + if metric == self.METRIC_CLASSIFICATION: + # Top1 accuracy + top1_correct = np.array( + [metric["top1_correct"] for metric in self._metrics_list] + ) + top1_acc = np.mean(top1_correct) + metric_data[self.KEY_ACCURACY] = top1_acc + short_description += f"Top1 Accuracy: {top1_acc:.4f}\n" + + # Top5 accuracy + top5_correct = np.array( + [metric["top5_correct"] for metric in self._metrics_list] + ) + top5_acc = np.mean(top5_correct) + metric_data[self.KEY_ACCURACY_TOP5] = top5_acc + short_description += f"Top5 Accuracy: {top5_acc:.4f}\n" + + else: + raise ValueError( + f"Unsupported metric: {metric}" + ) # pragma: no cover + + return metric_data, short_description diff --git a/vis4d/eval/common/depth.py b/vis4d/eval/common/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9b3244c3ba611663d47aacb237336675614919 --- /dev/null +++ b/vis4d/eval/common/depth.py @@ -0,0 +1,214 @@ +"""Depth estimation evaluator.""" + +from __future__ import annotations + +import itertools + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArrayLike, + GenericFunc, + MetricLogs, + NDArrayFloat, +) +from vis4d.eval.base import Evaluator + +from ..metrics.depth import ( + absolute_error, + absolute_relative_error, + delta_p, + log_10_error, + root_mean_squared_error, + root_mean_squared_error_log, + scale_invariant_log, + squared_relative_error, +) + + +class DepthEvaluator(Evaluator): + """Depth estimation evaluator.""" + + METRIC_DEPTH = "Depth" + + KEY_DELTA05 = "d05" + KEY_DELTA1 = "d1" + KEY_DELTA2 = "d2" + KEY_DELTA3 = "d3" + + KEY_ABS_REL = "AbsRel" + KEY_ABS_ERR = "AbsErr" + KEY_SQ_REL = "SqRel" + KEY_RMSE = "RMSE" + KEY_RMSE_LOG = "RMSELog" + KEY_SILOG = "SILog" + KEY_LOG10 = "Log10" + + def __init__( + self, + min_depth: float = 0.0, + max_depth: float = 80.0, + scale: float = 1.0, + epsilon: float = 1e-3, + ) -> None: + """Initialize the optical flow evaluator. + + Args: + min_depth (float): Minimum depth to evaluate. Defaults to 0.001. + max_depth (float): Maximum depth to evaluate. Defaults to 80.0. + scale (float): Scale factor for depth. Defaults to 1.0. + epsilon (float): Small value to avoid logarithms of small values. + Defaults to 1e-3. + """ + super().__init__() + self.min_depth = min_depth + self.max_depth = max_depth + self.epsilon = epsilon + self.scale = scale + self._metrics_list: list[dict[str, float]] = [] + + def __repr__(self) -> str: + """Concise representation of the evaluator.""" + return "Common Depth Evaluator" + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return [self.METRIC_DEPTH] + + def reset(self) -> None: + """Reset evaluator for new round of evaluation.""" + self._metrics_list = [] + + def gather(self, gather_func: GenericFunc) -> None: + """Accumulate predictions across processes.""" + all_metrics = gather_func(self._metrics_list) + if all_metrics is not None: + self._metrics_list = list(itertools.chain(*all_metrics)) + + def _apply_mask( + self, prediction: NDArrayFloat, target: NDArrayFloat + ) -> tuple[NDArrayFloat, NDArrayFloat]: + """Apply mask to prediction and target.""" + mask = (target > self.min_depth) & (target <= self.max_depth) + return prediction[mask], target[mask] + + def process_batch( + self, prediction: ArrayLike, groundtruth: ArrayLike + ) -> None: + """Process a batch of data. + + Args: + prediction (np.array): Prediction optical flow, in shape (B, H, W). + groundtruth (np.array): Target optical flow, in shape (B, H, W). + """ + preds = ( + array_to_numpy(prediction, n_dims=None, dtype=np.float32) + * self.scale + ) + gts = array_to_numpy(groundtruth, n_dims=None, dtype=np.float32) + + for pred, gt in zip(preds, gts): + pred, gt = self._apply_mask(pred, gt) + self._metrics_list.append( + { + self.KEY_ABS_REL: absolute_relative_error(pred, gt), + self.KEY_ABS_ERR: absolute_error(pred, gt), + self.KEY_SQ_REL: squared_relative_error(pred, gt), + self.KEY_RMSE: root_mean_squared_error(pred, gt), + self.KEY_RMSE_LOG: root_mean_squared_error_log(pred, gt), + self.KEY_SILOG: scale_invariant_log(pred, gt), + self.KEY_DELTA05: delta_p(pred, gt, 0.5), + self.KEY_DELTA1: delta_p(pred, gt, 1.0), + self.KEY_DELTA2: delta_p(pred, gt, 2.0), + self.KEY_DELTA3: delta_p(pred, gt, 3.0), + self.KEY_LOG10: log_10_error(pred, gt), + } + ) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate predictions. + + Returns a dict containing the raw data and a + short description string containing a readablae result. + + Args: + metric (str): Metric to use. See @property metric + + Returns: + metric_data, description + tuple containing the metric data (dict with metric name and value) + as well as a short string with shortened information. + + Raises: + RuntimeError: if no data has been registered to be evaluated. + ValueError: if metric is not supported. + """ + if len(self._metrics_list) == 0: + raise RuntimeError( + """No data registered to calculate metric. + Register data using .process() first!""" + ) + metric_data: MetricLogs = {} + short_description = "\n" + + if metric == self.METRIC_DEPTH: + abs_rel = np.mean( + [x[self.KEY_ABS_REL] for x in self._metrics_list] + ) + metric_data[self.KEY_ABS_REL] = float(abs_rel) + short_description += f"Absolute relative error: {abs_rel:.3f}\n" + + abs_err = np.mean( + [x[self.KEY_ABS_ERR] for x in self._metrics_list] + ) + metric_data[self.KEY_ABS_ERR] = float(abs_err) + short_description += f"Absolute error: {abs_err:.3f}\n" + + sq_rel = np.mean([x[self.KEY_SQ_REL] for x in self._metrics_list]) + metric_data[self.KEY_SQ_REL] = float(sq_rel) + short_description += f"Squared relative error: {sq_rel:.3f}\n" + + rmse = np.mean([x[self.KEY_RMSE] for x in self._metrics_list]) + metric_data[self.KEY_RMSE] = float(rmse) + short_description += f"RMSE: {rmse:.3f}\n" + + rmse_log = np.mean( + [x[self.KEY_RMSE_LOG] for x in self._metrics_list] + ) + metric_data[self.KEY_RMSE_LOG] = float(rmse_log) + short_description += f"RMSE log: {rmse_log:.3f}\n" + + silog = np.mean([x[self.KEY_SILOG] for x in self._metrics_list]) + metric_data[self.KEY_SILOG] = float(silog) + short_description += f"SILog: {silog:.3f}\n" + + delta05 = np.mean( + [x[self.KEY_DELTA05] for x in self._metrics_list] + ) + metric_data[self.KEY_DELTA05] = float(delta05) + short_description += f"Delta 0.5: {delta05:.3f}\n" + + delta1 = np.mean([x[self.KEY_DELTA1] for x in self._metrics_list]) + metric_data[self.KEY_DELTA1] = float(delta1) + short_description += f"Delta 1: {delta1:.3f}\n" + + delta2 = np.mean([x[self.KEY_DELTA2] for x in self._metrics_list]) + metric_data[self.KEY_DELTA2] = float(delta2) + short_description += f"Delta 2: {delta2:.3f}\n" + + delta3 = np.mean([x[self.KEY_DELTA3] for x in self._metrics_list]) + metric_data[self.KEY_DELTA3] = float(delta3) + short_description += f"Delta 3: {delta3:.3f}\n" + + log10 = np.mean([x[self.KEY_LOG10] for x in self._metrics_list]) + metric_data[self.KEY_LOG10] = float(log10) + short_description += f"Log10 error: {log10:.3f}\n" + + else: + raise ValueError( + f"Unsupported metric: {metric}" + ) # pragma: no cover + + return metric_data, short_description diff --git a/vis4d/eval/common/flow.py b/vis4d/eval/common/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..1a591587b2b3c4b5c9aa722ee1da0bc5c1f97f12 --- /dev/null +++ b/vis4d/eval/common/flow.py @@ -0,0 +1,152 @@ +"""Optical flow evaluator.""" + +from __future__ import annotations + +import itertools + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArrayLike, + GenericFunc, + MetricLogs, + NDArrayFloat, +) +from vis4d.eval.base import Evaluator + +from ..metrics.flow import angular_error, end_point_error + + +class OpticalFlowEvaluator(Evaluator): + """Optical flow evaluator.""" + + METRIC_FLOW = "Flow" + + KEY_ENDPOINT_ERROR = "EPE" + KEY_ANGULAR_ERROR = "AE" + + def __init__( + self, + max_flow: float = 400.0, + use_degrees: bool = False, + scale: float = 1.0, + epsilon: float = 1e-6, + ) -> None: + """Initialize the optical flow evaluator. + + Args: + max_flow (float, optional): Maximum flow value. Defaults to 400.0. + use_degrees (bool, optional): Whether to use degrees for angular + error. Defaults to False. + scale (float, optional): Scale factor for the optical flow. + Defaults to 1.0. + epsilon (float, optional): Epsilon value for numerical stability. + """ + super().__init__() + self.max_flow = max_flow + self.use_degrees = use_degrees + self.scale = scale + self.epsilon = epsilon + self._metrics_list: list[dict[str, float]] = [] + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return [ + OpticalFlowEvaluator.METRIC_FLOW, + ] + + def reset(self) -> None: + """Reset evaluator for new round of evaluation.""" + self._metrics_list = [] + + def _apply_mask( + self, prediction: NDArrayFloat, target: NDArrayFloat + ) -> tuple[NDArrayFloat, NDArrayFloat]: + """Apply mask to prediction and target.""" + mask = np.sum(np.abs(target), axis=-1) <= self.max_flow + return prediction[mask], target[mask] + + def process_batch( + self, prediction: ArrayLike, groundtruth: ArrayLike + ) -> None: + """Process a batch of data. + + Args: + prediction (NDArrayNumber): Prediction optical flow, in shape + (N, H, W, 2). + groundtruth (NDArrayNumber): Target optical flow, in shape + (N, H, W, 2). + """ + preds = ( + array_to_numpy(prediction, n_dims=None, dtype=np.float32) + * self.scale + ) + gts = array_to_numpy(groundtruth, n_dims=None, dtype=np.float32) + + for pred, gt in zip(preds, gts): + pred, gt = self._apply_mask(pred, gt) + epe = end_point_error(pred, gt) + ae = angular_error(pred, gt, self.epsilon) + self._metrics_list.append( + { + OpticalFlowEvaluator.KEY_ENDPOINT_ERROR: epe, + OpticalFlowEvaluator.KEY_ANGULAR_ERROR: ae, + } + ) + + def gather(self, gather_func: GenericFunc) -> None: + """Accumulate predictions across processes.""" + all_metrics = gather_func(self._metrics_list) + if all_metrics is not None: + self._metrics_list = list(itertools.chain(*all_metrics)) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate predictions. + + Returns a dict containing the raw data and a + short description string containing a readable result. + + Args: + metric (str): Metric to use. See @property metric + + Returns: + metric_data, description + tuple containing the metric data (dict with metric name and value) + as well as a short string with shortened information. + + Raises: + RuntimeError: if no data has been registered to be evaluated. + ValueError: if metric is not supported. + """ + if len(self._metrics_list) == 0: + raise RuntimeError( + """No data registered to calculate metric. + Register data using .process() first!""" + ) + metric_data: MetricLogs = {} + short_description = "" + + if metric == OpticalFlowEvaluator.METRIC_FLOW: + # EPE + epe = np.mean( + [x[self.KEY_ENDPOINT_ERROR] for x in self._metrics_list] + ) + metric_data[self.KEY_ENDPOINT_ERROR] = float(epe) + short_description = f"EPE: {epe:.3f}" + + # AE + ae = np.mean( + [x[self.KEY_ANGULAR_ERROR] for x in self._metrics_list] + ) + metric_data[self.KEY_ANGULAR_ERROR] = float(ae) + angular_unit = "rad" if not self.use_degrees else "deg" + short_description = f"AE: {ae:.3f}{angular_unit}" + + else: + raise ValueError( + f"Unsupported metric: {metric}" + ) # pragma: no cover + + return metric_data, short_description diff --git a/vis4d/eval/common/seg.py b/vis4d/eval/common/seg.py new file mode 100644 index 0000000000000000000000000000000000000000..28817a89cf1ef49495f532d8cc6503714bf4294f --- /dev/null +++ b/vis4d/eval/common/seg.py @@ -0,0 +1,185 @@ +"""Common segmentation evaluator.""" + +from __future__ import annotations + +import numpy as np +from terminaltables import AsciiTable + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArrayLike, + MetricLogs, + NDArrayI64, + NDArrayNumber, +) +from vis4d.eval.base import Evaluator + + +class SegEvaluator(Evaluator): + """Creates an evaluator that calculates mIoU score and confusion matrix.""" + + METRIC_MIOU = "mIoU" + METRIC_CONFUSION_MATRIX = "confusion_matrix" + + def __init__( + self, + num_classes: int | None = None, + class_to_ignore: int | None = None, + class_mapping: dict[int, str] | None = None, + ): + """Creates a new evaluator. + + Args: + num_classes (int): Number of semantic classes + class_to_ignore (int | None): Groundtruth class that should be + ignored + class_mapping (int): dict mapping each class_id to a readable name + + """ + super().__init__() + self.num_classes = num_classes + self.class_mapping = class_mapping if class_mapping is not None else {} + self.class_to_ignore = class_to_ignore + + self._confusion_matrix: NDArrayI64 | None = None + self.reset() + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return [ + self.METRIC_MIOU, + self.METRIC_CONFUSION_MATRIX, + ] + + # Taken and modified (added static N) from + # https://stackoverflow.com/questions/59080843/faster-method-of-computing-confusion-matrix + def calc_confusion_matrix( + self, prediction: NDArrayNumber, groundtruth: NDArrayI64 + ) -> NDArrayI64: + """Calculates the confusion matrix for multi class predictions. + + Args: + prediction (array): Class predictions + groundtruth (array): Groundtruth classes + + Returns: + Confusion Matrix of dimension n_classes x n_classes. + """ + y_true = groundtruth.reshape(-1) + if prediction.shape != groundtruth.shape: + y_pred = np.argmax(prediction, axis=1).reshape(-1) + else: + y_pred = prediction.reshape(-1) + y_pred = y_pred.astype(np.int64) + + if self.class_to_ignore is not None: + valid = y_true != self.class_to_ignore + y_true = y_true[valid] + y_pred = y_pred[valid] + if self.num_classes is None: + n_classes = np.max(np.max(groundtruth), np.max(y_pred)) + 1 + else: + n_classes = self.num_classes + + y = n_classes * y_true + y_pred + y = np.bincount(y, minlength=n_classes * n_classes) + return y.reshape(n_classes, n_classes) + + def reset(self) -> None: + """Reset the saved predictions to start new round of evaluation.""" + self._confusion_matrix = None + + def process_batch( + self, prediction: ArrayLike, groundtruth: ArrayLike + ) -> None: + """Process sample and update confusion matrix. + + Args: + prediction: Predictions of shape [N,C,...] or [N,...] with + C* being any number if channels. Note, C is passed, + the prediction is converted to target labels by applying + the max operations along the second axis + groundtruth: Groundtruth of shape [N_batch, ...] type int + """ + confusion_matrix = self.calc_confusion_matrix( + array_to_numpy(prediction, n_dims=None, dtype=np.float32), + array_to_numpy(groundtruth, n_dims=None, dtype=np.int64), + ) + + if self._confusion_matrix is None: + self._confusion_matrix = confusion_matrix + else: + assert ( + self._confusion_matrix.shape == confusion_matrix.shape + ), """Shape of confusion matrix changed during runtime!, + Please specify a static number of classes in constructor.""" + self._confusion_matrix += confusion_matrix + + def _get_class_name_for_idx(self, idx: int) -> str: + """Maps a class index to a unique class name. + + Args: + idx (int): class index. + + Returns: + (str) class name + """ + return self.class_mapping.get(idx, f"class_{idx}") + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate predictions. + + Returns a dict containing the raw data and a + short description string containing a readable result. + + Args: + metric (str): Metric to use. See @property metric. + + Returns: + (dict, str) containing the raw data and a short description string. + + Raises: + ValueError: If metric is not supported. + """ + assert ( + self._confusion_matrix is not None + ), """Evaluate() needs to process samples first. + Please call the process() function before calling evaluate()""" + + metric_data, short_description = {}, "" + if metric == self.METRIC_MIOU: + # Calculate miou from confusion matrix + tp = np.diag(self._confusion_matrix) + fp = np.sum(self._confusion_matrix, axis=0) - tp + fn = np.sum(self._confusion_matrix, axis=1) - tp + iou = tp / (tp + fn + fp) * 100 + m_iou = np.nanmean(iou) + + iou_class_str = ", ".join( + f"{self._get_class_name_for_idx(idx)}: ({d:.3f}%)" + for idx, d in enumerate(iou) + ) + metric_data[self.METRIC_MIOU] = m_iou + short_description += f"mIoU: {m_iou:.3f}% \n" + short_description += iou_class_str + "\n" + + elif metric == self.METRIC_CONFUSION_MATRIX: + headers = ["Confusion"] + [ + self._get_class_name_for_idx(i) + for i in range(self._confusion_matrix.shape[0]) + ] + table_data = self._confusion_matrix / ( + np.sum(self._confusion_matrix, axis=1) + ) + data = list( + [f"Class_{idx}"] + list(d) for idx, d in enumerate(table_data) + ) + table = AsciiTable([headers] + data) + # TODO, change MetricLogs type for more complex log types as e.g. + # confusion matrix + short_description += table.table + "\n" + + else: + raise ValueError(f"Metric {metric} not supported") + return metric_data, short_description diff --git a/vis4d/eval/kitti/__init__.py b/vis4d/eval/kitti/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eae643cb0372f2507a574267967dd5e7ba5814aa --- /dev/null +++ b/vis4d/eval/kitti/__init__.py @@ -0,0 +1,5 @@ +"""KITTI evaluator.""" + +from .depth import KITTIDepthEvaluator + +__all__ = ["KITTIDepthEvaluator"] diff --git a/vis4d/eval/kitti/depth.py b/vis4d/eval/kitti/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..df21b448028df460ef2cd3080724967687de899b --- /dev/null +++ b/vis4d/eval/kitti/depth.py @@ -0,0 +1,87 @@ +"""KITTI evaluation code.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.typing import NDArrayFloat, NDArrayNumber + +from ..common import DepthEvaluator + + +def apply_garg_crop(mask: NDArrayNumber) -> NDArrayNumber: + """Apply Garg ECCV16 crop to the mask. + + Args: + mask (np.array): Mask to be cropped, in shape (..., H, W). + + Returns: + np.array: Cropped mask, in shape (..., H', W'). + """ + # crop used by Garg ECCV16 + h, w = mask.shape[-2:] + crop = np.array( + [0.40810811 * h, 0.99189189 * h, 0.03594771 * w, 0.96405229 * w] + ).astype(np.int32) + mask[..., crop[0] : crop[1], crop[2] : crop[3]] = 1 + return mask + + +def apply_eigen_crop(mask: NDArrayNumber) -> NDArrayNumber: + """Apply Eigen NIPS14 crop to the mask. + + Args: + mask (np.array): Mask to be cropped, in shape (N, H, W). + + Returns: + np.array: Cropped mask, in shape (N, H', W'). + """ + # https://github.com/mrharicot/monodepth/utils/evaluate_kitti.py + h, w = mask.shape[-2:] + crop = np.array( + [0.3324324 * h, 0.91351351 * h, 0.0359477 * w, 0.96405229 * w] + ).astype(np.int32) + mask[..., crop[0] : crop[1], crop[2] : crop[3]] = 1 + return mask + + +class KITTIDepthEvaluator(DepthEvaluator): + """KITTI depth evaluation class.""" + + METRIC_DEPTH = "depth" + + def __init__( + self, + min_depth: float = 0.01, + max_depth: float = 80.0, + eval_crop: str | None = None, + ) -> None: + """Initialize KITTI depth evaluator.""" + super().__init__(min_depth, max_depth) + self.eval_crop = eval_crop + self.reset() + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "KITTI evaluation for depth" + + def _get_eval_mask(self, valid_mask: NDArrayNumber) -> NDArrayNumber: + """Do Grag or Eigen cropping for testing.""" + eval_mask = np.zeros_like(valid_mask) + if self.eval_crop == "garg_crop": + eval_mask = apply_garg_crop(eval_mask) + elif self.eval_crop == "eigen_crop": + eval_mask = apply_eigen_crop(eval_mask) + else: + eval_mask = np.ones_like(valid_mask) + return np.logical_and(valid_mask, eval_mask) + + def _apply_mask( + self, prediction: NDArrayFloat, target: NDArrayFloat + ) -> tuple[NDArrayFloat, NDArrayFloat]: + """Apply mask to prediction and target.""" + valid_mask = (target > self.min_depth) & (target < self.max_depth) + eval_mask = self._get_eval_mask(valid_mask) + prediction = prediction[eval_mask] + target = target[eval_mask] + return prediction, target diff --git a/vis4d/eval/metrics/__init__.py b/vis4d/eval/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a4dbf7b6c58249fdca883e2ed3a4f8d9268c8b --- /dev/null +++ b/vis4d/eval/metrics/__init__.py @@ -0,0 +1 @@ +"""Eval metrics.""" diff --git a/vis4d/eval/metrics/cls.py b/vis4d/eval/metrics/cls.py new file mode 100644 index 0000000000000000000000000000000000000000..9a863fb7853644166c45202b45e87c621182c791 --- /dev/null +++ b/vis4d/eval/metrics/cls.py @@ -0,0 +1,31 @@ +"""Classification metrics.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ArrayLike, ArrayLikeInt + + +def accuracy( + prediction: ArrayLike, target: ArrayLikeInt, top_k: int = 1 +) -> float: + """Calculate the accuracy of the prediction. + + Args: + prediction (ArrayLike): Probabilities (or logits) of shape (N, C) or + (C, ). + target (ArrayLikeInt): Target of shape (N, ) or (1, ). + top_k (int, optional): Top k accuracy. Defaults to 1. + + Returns: + float: Accuracy of the prediction, in range [0, 1]. + """ + prediction = array_to_numpy(prediction, n_dims=2, dtype=np.float32) + target = array_to_numpy(target, n_dims=1, dtype=np.int64) + assert prediction.shape[0] == target.shape[0], "Batch size mismatch." + top_k = min(top_k, prediction.shape[1]) + top_k_idx = np.argsort(prediction, axis=1)[:, -top_k:] + correct = np.any(top_k_idx == target[:, None], axis=1) + return float(np.mean(correct)) diff --git a/vis4d/eval/metrics/depth.py b/vis4d/eval/metrics/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..8d783b6c0ff1f55305edcde697b0846264165362 --- /dev/null +++ b/vis4d/eval/metrics/depth.py @@ -0,0 +1,146 @@ +"""Depth estimation metrics.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.typing import ArrayLike + +from ..utils import check_shape_match, dense_inputs_to_numpy + + +def absolute_error(prediction: ArrayLike, target: ArrayLike) -> float: + """Compute the absolute error. + + Args: + prediction (NDArrayNumber): Prediction depth map, in shape (..., H, W). + target (NDArrayNumber): Target depth map, in shape (..., H, W). + + Returns: + float: Absolute error. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + return np.mean(np.abs(prediction - target)).item() + + +def squared_relative_error(prediction: ArrayLike, target: ArrayLike) -> float: + """Compute the squared relative error. + + Args: + prediction (NDArrayNumber): Prediction depth map, in shape (..., H, W). + target (NDArrayNumber): Target depth map, in shape (..., H, W). + + Returns: + float: Square relative error. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + return np.mean(np.square(prediction - target) / target).item() + + +def absolute_relative_error(prediction: ArrayLike, target: ArrayLike) -> float: + """Compute the absolute relative error. + + Args: + prediction (NDArrayNumber): Prediction depth map, in shape (..., H, W). + target (NDArrayNumber): Target depth map, in shape (..., H, W). + + Returns: + float: Absolute relative error. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + return np.mean(np.abs(prediction - target) / target).item() + + +def root_mean_squared_error(prediction: ArrayLike, target: ArrayLike) -> float: + """Compute the root mean squared error. + + Args: + prediction (ArrayLike): Prediction depth map, in shape (..., H, W). + target (ArrayLike): Target depth map, in shape (..., H, W). + + Returns: + float: Root mean squared error. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + squared_diff = np.square(prediction - target) + return np.sqrt(np.mean(squared_diff)).item() + + +def root_mean_squared_error_log( + prediction: ArrayLike, target: ArrayLike, epsilon: float = 1e-8 +) -> float: + """Compute the root mean squared error in log space. + + Args: + prediction (ArrayLike): Prediction depth map, in shape (H, W). + target (ArrayLike): Target depth map, in shape (H, W). + epsilon (float, optional): Epsilon to avoid log(0). Defaults to 1e-6. + + Returns: + float: Root mean squared error in log space. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + log_pred = np.log(prediction + epsilon) + log_target = np.log(target + epsilon) + squared_diff = np.square(log_pred - log_target) + return np.sqrt(np.mean(squared_diff)).item() + + +def scale_invariant_log( + prediction: ArrayLike, target: ArrayLike, epsilon: float = 1e-8 +) -> float: + """Compute the scale invariant log error. + + Args: + prediction (ArrayLike): Prediction depth map, in shape (H, W). + target (ArrayLike): Target depth map, in shape (H, W). + epsilon (float, optional): Epsilon to avoid log(0). Defaults to 1e-6. + + Returns: + float: Scale invariant log error. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + log_diff = np.log(prediction + epsilon) - np.log(target + epsilon) + return 100.0 * float(np.sqrt(np.var(log_diff)).mean()) + + +def delta_p( + prediction: ArrayLike, target: ArrayLike, power: float = 1 +) -> float: + """Compute the delta_p metric. + + Args: + prediction (ArrayLike): Prediction depth map, in shape (H, W). + target (ArrayLike): Target depth map, in shape (H, W). + power (float, optional): Power of the threshold. Defaults to 1. + + Returns: + float: Delta_p metric. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + return np.mean( + np.maximum((target / prediction), (prediction / target)) < 1.25**power + ).item() + + +def log_10_error(prediction: ArrayLike, target: ArrayLike) -> float: + """Compute the log_10 error. + + Args: + prediction (ArrayLike): Prediction depth map, in shape (H, W). + target (ArrayLike): Target depth map, in shape (H, W). + + Returns: + float: Log_10 error. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + log10_diff = np.log10(prediction) - np.log10(target) + return np.mean(np.abs(log10_diff)).item() diff --git a/vis4d/eval/metrics/flow.py b/vis4d/eval/metrics/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9e7d13c6d7b4735c4ccd037b78f505d27dc69a --- /dev/null +++ b/vis4d/eval/metrics/flow.py @@ -0,0 +1,47 @@ +"""Depth estimation metrics.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.typing import ArrayLike + +from ..utils import check_shape_match, dense_inputs_to_numpy + + +def end_point_error(prediction: ArrayLike, target: ArrayLike) -> float: + """Compute the end point error. + + Args: + prediction (ArrayLike): Prediction UV optical flow, in shape (..., 2). + target (ArrayLike): Target UV optical flow, in shape (..., 2). + + Returns: + float: End point error. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + squared_sum = np.sum((prediction - target) ** 2, axis=-1) + return np.mean(np.sqrt(squared_sum)).item() + + +def angular_error( + prediction: ArrayLike, target: ArrayLike, epsilon: float = 1e-6 +) -> float: + """Compute the angular error. + + Args: + prediction (ArrayLike): Prediction UV optical flow, in shape (..., 2). + target (ArrayLike): Target UV optical flow, in shape (..., 2). + epsilon (float, optional): Epsilon value for numerical stability. + + Returns: + float: Angular error. + """ + prediction, target = dense_inputs_to_numpy(prediction, target) + check_shape_match(prediction, target) + product = np.sum(prediction * target, axis=-1) + pred_norm = np.linalg.norm(prediction, axis=-1) + target_norm = np.linalg.norm(target, axis=-1) + cos_angle = np.abs(product) / (pred_norm * target_norm + epsilon) + return np.mean(np.arccos(np.clip(cos_angle, 0.0, 1.0))).item() diff --git a/vis4d/eval/nuscenes/__init__.py b/vis4d/eval/nuscenes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbf75265b626beadbe81b68b09dec0162b8f623 --- /dev/null +++ b/vis4d/eval/nuscenes/__init__.py @@ -0,0 +1,6 @@ +"""NuScenes evaluator.""" + +from .detect3d import NuScenesDet3DEvaluator +from .track3d import NuScenesTrack3DEvaluator + +__all__ = ["NuScenesDet3DEvaluator", "NuScenesTrack3DEvaluator"] diff --git a/vis4d/eval/nuscenes/detect3d.py b/vis4d/eval/nuscenes/detect3d.py new file mode 100644 index 0000000000000000000000000000000000000000..c2817120eca10fea05cbfbb10225843e74cc44e7 --- /dev/null +++ b/vis4d/eval/nuscenes/detect3d.py @@ -0,0 +1,338 @@ +"""NuScenes 3D detection evaluation code.""" + +from __future__ import annotations + +import json +import os +from collections.abc import Callable +from typing import Any + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.imports import NUSCENES_AVAILABLE +from vis4d.common.logging import rank_zero_warn +from vis4d.common.typing import ArrayLike, DictStrAny, MetricLogs +from vis4d.data.datasets.nuscenes import ( + nuscenes_attribute_map, + nuscenes_class_map, +) + +from ..base import Evaluator + +if NUSCENES_AVAILABLE: + from nuscenes import NuScenes as NuScenesDevkit + from nuscenes.eval.detection.config import config_factory + from nuscenes.eval.detection.evaluate import NuScenesEval + from nuscenes.utils.data_classes import Quaternion +else: + raise ImportError("nuscenes-devkit is not installed.") + + +def _parse_high_level_metrics( + mean_ap: float, + tp_errors: dict[str, float], + nd_score: float, + eval_time: float, +) -> tuple[MetricLogs, list[str]]: + """Collect high-level metrics.""" + log_dict: MetricLogs = { + "mAP": mean_ap, + "mATE": tp_errors["trans_err"], + "mASE": tp_errors["scale_err"], + "mAOE": tp_errors["orient_err"], + "mAVE": tp_errors["vel_err"], + "mAAE": tp_errors["attr_err"], + "NDS": nd_score, + } + + str_summary_list = ["\nHigh-level metrics:"] + for k, v in log_dict.items(): + str_summary_list.append(f"{k}: {v:.4f}") + + str_summary_list.append(f"Eval time: {eval_time:.1f}s") + + return log_dict, str_summary_list + + +def _parse_per_class_metrics( + str_summary_list: list[str], + class_aps: dict[str, float], + class_tps: dict[str, dict[str, float]], +) -> list[str]: + """Collect per-class metrics.""" + str_summary_list.append("\nPer-class results:") + str_summary_list.append("Object Class\tAP\tATE\tASE\tAOE\tAVE\tAAE") + + for class_name in class_aps.keys(): + tmp_str_list = [class_name] + tmp_str_list.append(f"{class_aps[class_name]:.3f}") + tmp_str_list.append(f"{class_tps[class_name]['trans_err']:.3f}") + tmp_str_list.append(f"{class_tps[class_name]['scale_err']:.3f}") + tmp_str_list.append(f"{class_tps[class_name]['orient_err']:.3f}") + tmp_str_list.append(f"{class_tps[class_name]['vel_err']:.3f}") + tmp_str_list.append(f"{class_tps[class_name]['attr_err']:.3f}") + + str_summary_list.append("\t".join(tmp_str_list)) + return str_summary_list + + +class NuScenesDet3DEvaluator(Evaluator): + """NuScenes 3D detection evaluation class.""" + + inv_nuscenes_attribute_map = { + v: k for k, v in nuscenes_attribute_map.items() + } + + DefaultAttribute = { + "car": "vehicle.parked", + "pedestrian": "pedestrian.moving", + "trailer": "vehicle.parked", + "truck": "vehicle.parked", + "bus": "vehicle.moving", + "motorcycle": "cycle.without_rider", + "construction_vehicle": "vehicle.parked", + "bicycle": "cycle.without_rider", + "barrier": "", + "traffic_cone": "", + } + + def __init__( + self, + data_root: str, + version: str, + split: str, + save_only: bool = False, + class_map: dict[str, int] | None = None, + metadata: tuple[str, ...] = ("use_camera",), + use_default_attr: bool = False, + velocity_thres: float = 1.0, + ) -> None: + """Initialize NuScenes evaluator.""" + super().__init__() + self.data_root = data_root + self.version = version + self.split = split + self.save_only = save_only + self.use_default_attr = use_default_attr + self.velocity_thres = velocity_thres + + self.meta_data = { + "use_camera": False, + "use_lidar": False, + "use_radar": False, + "use_map": False, + "use_external": False, + } + + for m in metadata: + self.meta_data[m] = True + + class_map = class_map or nuscenes_class_map + self.inv_nuscenes_class_map = {v: k for k, v in class_map.items()} + + self.output_dir = "" + self.detect_3d: DictStrAny = {} + self.reset() + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "NuScenes 3D Detection Evaluator" + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return ["detect_3d"] + + def gather( # type: ignore + self, gather_func: Callable[[Any], Any] + ) -> None: + """Gather variables in case of distributed setting (if needed). + + Args: + gather_func (Callable[[Any], Any]): Gather function. + """ + detect_3d_list = gather_func(self.detect_3d) + if detect_3d_list is not None: + collated_detect_3d: DictStrAny = {} + for prediction in detect_3d_list: + for k, v in prediction.items(): + if k not in collated_detect_3d: + collated_detect_3d[k] = v + else: + collated_detect_3d[k].extend(v) + self.detect_3d = collated_detect_3d + + def reset(self) -> None: + """Reset evaluator.""" + self.detect_3d.clear() + + def get_attributes(self, name: str, velocity: list[float]) -> str: + """Get nuScenes attributes.""" + if self.use_default_attr: + return self.DefaultAttribute[name] + + if np.sqrt(velocity[0] ** 2 + velocity[1] ** 2) > self.velocity_thres: + if name in { + "car", + "construction_vehicle", + "bus", + "truck", + "trailer", + }: + attr = "vehicle.moving" + elif name in {"bicycle", "motorcycle"}: + attr = "cycle.with_rider" + else: + attr = self.DefaultAttribute[name] + elif name in {"pedestrian"}: + attr = "pedestrian.standing" + elif name in {"bus"}: + attr = "vehicle.stopped" + else: + attr = self.DefaultAttribute[name] + return attr + + def _process_detect_3d( + self, + token: str, + boxes_3d: ArrayLike, + velocities: ArrayLike, + scores_3d: ArrayLike, + class_ids: ArrayLike, + attributes: ArrayLike | None = None, + ) -> None: + """Process 3D detection results.""" + annos = [] + boxes_3d_np = array_to_numpy(boxes_3d, n_dims=None, dtype=np.float32) + velocities_np = array_to_numpy( + velocities, n_dims=None, dtype=np.float32 + ) + scores_3d_np = array_to_numpy(scores_3d, n_dims=None, dtype=np.float32) + class_ids_np = array_to_numpy(class_ids, n_dims=None, dtype=np.int64) + + if len(boxes_3d_np) != 0: + for i, (box_3d, velocity, score_3d, class_id) in enumerate( + zip( + boxes_3d_np, + velocities_np, + scores_3d_np, + class_ids_np, + ) + ): + category = self.inv_nuscenes_class_map[int(class_id)] + + translation = box_3d[0:3] + + dims = box_3d[3:6].tolist() + dimension = [d if d >= 0 else 0.1 for d in dims] + + rotation = Quaternion(box_3d[6:].tolist()) + + score = float(score_3d) + + velocity_list = velocity.tolist() + + if attributes is None: + attribute_name = self.get_attributes( + category, velocity_list + ) + else: + attribute = array_to_numpy( + attributes[i], n_dims=None, dtype=np.int64 # type: ignore # pylint: disable=line-too-long + ) + attribute_name = self.inv_nuscenes_attribute_map[ + int(attribute) + ] + + nusc_anno = { + "sample_token": token, + "translation": translation.tolist(), + "size": dimension, + "rotation": rotation.elements.tolist(), + "velocity": [velocity_list[0], velocity_list[1]], + "detection_name": category, + "detection_score": score, + "attribute_name": attribute_name, + } + annos.append(nusc_anno) + self.detect_3d[token] = annos + + def process_batch( + self, + tokens: list[str], + boxes_3d: list[ArrayLike], + velocities: list[ArrayLike], + class_ids: list[ArrayLike], + scores_3d: list[ArrayLike], + attributes: list[ArrayLike] | None = None, + ) -> None: + """Process the results.""" + for i, token in enumerate(tokens): + self._process_detect_3d( + token, + boxes_3d[i], + velocities[i], + scores_3d[i], + class_ids[i], + attributes[i] if attributes is not None else None, + ) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate the results.""" + assert metric == "detect_3d" + if self.save_only: + return {}, "Results are saved to the json file." + + try: + nusc = NuScenesDevkit( + version=self.version, + dataroot=self.data_root, + verbose=False, + ) + + nusc_eval = NuScenesEval( + nusc, + config=config_factory("detection_cvpr_2019"), + result_path=f"{self.output_dir}/detect_3d_predictions.json", + eval_set=self.split, + output_dir=os.path.join(self.output_dir, "detection"), + ) + metrics, _ = nusc_eval.evaluate() + metrics_summary = metrics.serialize() + + log_dict, str_summary_list = _parse_high_level_metrics( + metrics_summary["mean_ap"], + metrics_summary["tp_errors"], + metrics_summary["nd_score"], + metrics_summary["eval_time"], + ) + + class_aps = metrics_summary["mean_dist_aps"] + class_tps = metrics_summary["label_tp_errors"] + str_summary_list = _parse_per_class_metrics( + str_summary_list, class_aps, class_tps + ) + + str_summary = "\n".join(str_summary_list) + except Exception as e: # pylint: disable=broad-except + error_msg = "".join(e.args) + rank_zero_warn(f"Evaluation error: {error_msg}") + log_dict = {} + str_summary = ( + "Evaluation failure might be raised due to sanity check" + + "or all emtpy boxes." + ) + rank_zero_warn(str_summary) + return log_dict, str_summary + + def save(self, metric: str, output_dir: str) -> None: + """Save the results to json files.""" + assert metric == "detect_3d" + nusc_annos = {"results": self.detect_3d, "meta": self.meta_data} + result_file = f"{output_dir}/detect_3d_predictions.json" + + with open(result_file, mode="w", encoding="utf-8") as f: + json.dump(nusc_annos, f) + + self.output_dir = output_dir diff --git a/vis4d/eval/nuscenes/track3d.py b/vis4d/eval/nuscenes/track3d.py new file mode 100644 index 0000000000000000000000000000000000000000..77d7f89900e9d8f9995cd3503100a0f2fb3275c9 --- /dev/null +++ b/vis4d/eval/nuscenes/track3d.py @@ -0,0 +1,167 @@ +"""NuScenes 3D tracking evaluation code.""" + +from __future__ import annotations + +import json +from collections.abc import Callable +from typing import Any + +import numpy as np +from nuscenes.utils.data_classes import Quaternion + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ArrayLike, DictStrAny, MetricLogs +from vis4d.data.datasets.nuscenes import nuscenes_class_map + +from ..base import Evaluator + + +class NuScenesTrack3DEvaluator(Evaluator): + """NuScenes 3D tracking evaluation class.""" + + inv_nuscenes_class_map = {v: k for k, v in nuscenes_class_map.items()} + + tracking_cats = [ + "bicycle", + "motorcycle", + "pedestrian", + "bus", + "car", + "trailer", + "truck", + ] + + def __init__(self, metadata: tuple[str, ...] = ("use_camera",)) -> None: + """Initialize NuScenes evaluator.""" + super().__init__() + self.meta_data = { + "use_camera": False, + "use_lidar": False, + "use_radar": False, + "use_map": False, + "use_external": False, + } + + for m in metadata: + self.meta_data[m] = True + + self.tracks_3d: DictStrAny = {} + self.reset() + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "NuScenes 3D Tracking Evaluator" + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return ["track_3d"] + + def gather( # type: ignore + self, gather_func: Callable[[Any], Any] + ) -> None: + """Gather variables in case of distributed setting (if needed). + + Args: + gather_func (Callable[[Any], Any]): Gather function. + """ + tracks_3d_list = gather_func(self.tracks_3d) + if tracks_3d_list is not None: + collated_track_3d: DictStrAny = {} + for prediction in tracks_3d_list: + for k, v in prediction.items(): + if k not in collated_track_3d: + collated_track_3d[k] = v + else: + collated_track_3d[k].extend(v) + self.tracks_3d = collated_track_3d + + def reset(self) -> None: + """Reset evaluator.""" + self.tracks_3d.clear() + + def _process_track_3d( + self, + token: str, + boxes_3d: ArrayLike, + velocities: ArrayLike, + scores_3d: ArrayLike, + class_ids: ArrayLike, + track_ids: ArrayLike, + ) -> None: + """Process 3D tracking results.""" + annos = [] + boxes_3d_np = array_to_numpy(boxes_3d, n_dims=None, dtype=np.float32) + velocities_np = array_to_numpy( + velocities, n_dims=None, dtype=np.float32 + ) + scores_3d_np = array_to_numpy(scores_3d, n_dims=None, dtype=np.float32) + class_ids_np = array_to_numpy(class_ids, n_dims=None, dtype=np.int64) + track_ids_np = array_to_numpy(track_ids, n_dims=None, dtype=np.int64) + + if len(boxes_3d_np) != 0: + for box_3d, velocity, score_3d, class_id, track_id in zip( + boxes_3d_np, + velocities_np, + scores_3d_np, + class_ids_np, + track_ids_np, + ): + category = self.inv_nuscenes_class_map[int(class_id)] + if not category in self.tracking_cats: + continue + + translation = box_3d[0:3] + + dimension = box_3d[3:6] + + rotation = Quaternion(box_3d[6:].tolist()) + + score = float(score_3d) + + velocity_list = velocity.tolist() + + nusc_anno = { + "sample_token": token, + "translation": translation.tolist(), + "size": dimension.tolist(), + "rotation": rotation.elements.tolist(), + "velocity": [velocity_list[0], velocity_list[1]], + "tracking_id": int(track_id), + "tracking_name": category, + "tracking_score": score, + } + annos.append(nusc_anno) + self.tracks_3d[token] = annos + + def process_batch( + self, + tokens: list[str], + boxes_3d: list[ArrayLike], + velocities: list[ArrayLike], + class_ids: list[ArrayLike], + scores_3d: list[ArrayLike], + track_ids: list[ArrayLike], + ) -> None: + """Process the results.""" + for i, token in enumerate(tokens): + self._process_track_3d( + token, + boxes_3d[i], + velocities[i], + scores_3d[i], + class_ids[i], + track_ids[i], + ) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate the results.""" + return {}, "Currently only save the json file." + + def save(self, metric: str, output_dir: str) -> None: + """Save the results to json files.""" + nusc_annos = {"results": self.tracks_3d, "meta": self.meta_data} + result_file = f"{output_dir}/track_3d_predictions.json" + + with open(result_file, mode="w", encoding="utf-8") as f: + json.dump(nusc_annos, f) diff --git a/vis4d/eval/scalabel/__init__.py b/vis4d/eval/scalabel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d006bf0d3e6760054f6e766ee82419e0a2bf5a6 --- /dev/null +++ b/vis4d/eval/scalabel/__init__.py @@ -0,0 +1,11 @@ +"""Scalabel evaluator.""" + +from .base import ScalabelEvaluator +from .detect import ScalabelDetectEvaluator +from .track import ScalabelTrackEvaluator + +__all__ = [ + "ScalabelEvaluator", + "ScalabelDetectEvaluator", + "ScalabelTrackEvaluator", +] diff --git a/vis4d/eval/scalabel/base.py b/vis4d/eval/scalabel/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d96f60990d97a47992424e93c59e1af1f41223be --- /dev/null +++ b/vis4d/eval/scalabel/base.py @@ -0,0 +1,65 @@ +"""Scalabel base evaluator.""" + +from __future__ import annotations + +import itertools +from collections.abc import Callable +from typing import Any + +from vis4d.common.imports import SCALABEL_AVAILABLE +from vis4d.common.typing import MetricLogs +from vis4d.eval.base import Evaluator + +if SCALABEL_AVAILABLE: + from scalabel.label.io import load + from scalabel.label.typing import Config, Frame + from scalabel.label.utils import get_leaf_categories +else: + raise ImportError("scalabel is not installed.") + + +class ScalabelEvaluator(Evaluator): + """Scalabel base evaluation class.""" + + def __init__( + self, annotation_path: str, config: Config | None = None + ) -> None: + """Initialize the evaluator.""" + super().__init__() + self.annotation_path = annotation_path + self.frames: list[Frame] = [] + + dataset = load(self.annotation_path, validate_frames=False) + self.gt_frames = dataset.frames + if config is not None: + self.config: Config | None = config + else: + self.config = dataset.config + if self.config is not None and self.config.categories is not None: + categories = get_leaf_categories(self.config.categories) + self.inverse_cat_map = { + cat_id: cat.name for cat_id, cat in enumerate(categories) + } + else: + self.inverse_cat_map = {} + self.reset() + + def gather( # type: ignore # pragma: no cover + self, gather_func: Callable[[Any], Any] + ) -> None: + """Gather variables in case of distributed setting (if needed). + + Args: + gather_func (Callable[[Any], Any]): Gather function. + """ + all_preds = gather_func(self.frames) + if all_preds is not None: + self.frames = list(itertools.chain(*all_preds)) + + def reset(self) -> None: + """Reset the evaluator.""" + self.frames = [] + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate the dataset.""" + raise NotImplementedError diff --git a/vis4d/eval/scalabel/detect.py b/vis4d/eval/scalabel/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..8e0a80444a9cfbb24476f7e7183cdc28bf69a0ef --- /dev/null +++ b/vis4d/eval/scalabel/detect.py @@ -0,0 +1,139 @@ +"""Scalabel detection evaluator.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.imports import SCALABEL_AVAILABLE +from vis4d.common.typing import ArrayLike, MetricLogs + +from .base import ScalabelEvaluator + +if SCALABEL_AVAILABLE: + from scalabel.eval.detect import evaluate_det + from scalabel.eval.ins_seg import evaluate_ins_seg + from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d + from scalabel.label.typing import Config, Frame, Label +else: + raise ImportError("scalabel is not installed.") + + +class ScalabelDetectEvaluator(ScalabelEvaluator): + """Scalabel 2D detection evaluation class.""" + + METRICS_DET = "Det" + METRICS_INS_SEG = "InsSeg" + + def __init__( + self, + annotation_path: str, + config: Config | None = None, + mask_threshold: float = 0.0, + ) -> None: + """Initialize the evaluator.""" + super().__init__(annotation_path=annotation_path, config=config) + self.mask_threshold = mask_threshold + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "Scalabel Detection Evaluator" + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return [self.METRICS_DET, self.METRICS_INS_SEG] + + def process_batch( + self, + frame_ids: list[int], + sample_names: list[str], + sequence_names: list[str], + pred_boxes: list[ArrayLike], + pred_classes: list[ArrayLike], + pred_scores: list[ArrayLike], + pred_masks: list[ArrayLike] | None = None, + ) -> None: + """Process tracking results.""" + for i, ( + frame_id, + sample_name, + sequence_name, + boxes, + class_ids, + scores, + ) in enumerate( + zip( + frame_ids, + sample_names, + sequence_names, + pred_boxes, + pred_classes, + pred_scores, + ) + ): + boxes = array_to_numpy(boxes, n_dims=None, dtype=np.float32) + class_ids = array_to_numpy(class_ids, n_dims=None, dtype=np.int64) + scores = array_to_numpy(scores, n_dims=None, dtype=np.float32) + if pred_masks: + masks = array_to_numpy( + pred_masks[i], n_dims=None, dtype=np.float32 + ) + labels = [] + for label_id, (box, score, class_id) in enumerate( + zip(boxes, scores, class_ids) + ): + box2d = xyxy_to_box2d(*box.tolist()) + + if pred_masks: + rle = mask_to_rle( + (masks[label_id] > self.mask_threshold).astype( + np.uint8 + ) + ) + else: + rle = None + + label = Label( + id=str(label_id), + box2d=box2d, + category=( + self.inverse_cat_map[int(class_id)] + if self.inverse_cat_map != {} + else str(class_id) + ), + score=float(score), + rle=rle, + ) + labels.append(label) + frame = Frame( + name=sample_name, + videoName=sequence_name, + frameIndex=frame_id, + labels=labels, + ) + self.frames.append(frame) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate the dataset.""" + assert self.config is not None, "Scalabel config is not loaded." + metrics_log: MetricLogs = {} + short_description = "" + + if metric == self.METRICS_DET: + results = evaluate_det( + self.gt_frames, self.frames, config=self.config, nproc=0 + ) + for metric_name, metric_value in results.summary().items(): + metrics_log[metric_name] = metric_value + short_description += str(results) + "\n" + + if metric == self.METRICS_INS_SEG: + results = evaluate_ins_seg( + self.gt_frames, self.frames, config=self.config, nproc=0 + ) + for metric_name, metric_value in results.summary().items(): + metrics_log[metric_name] = metric_value + short_description += str(results) + "\n" + + return metrics_log, short_description diff --git a/vis4d/eval/scalabel/track.py b/vis4d/eval/scalabel/track.py new file mode 100644 index 0000000000000000000000000000000000000000..fb80132c8c45302f530949324b82ba85898cb42a --- /dev/null +++ b/vis4d/eval/scalabel/track.py @@ -0,0 +1,153 @@ +"""Scalabel tracking evaluator.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.imports import SCALABEL_AVAILABLE +from vis4d.common.typing import MetricLogs, NDArrayNumber + +from .base import ScalabelEvaluator + +if SCALABEL_AVAILABLE: + from scalabel.eval.mot import acc_single_video_mot, evaluate_track + from scalabel.eval.mots import acc_single_video_mots, evaluate_seg_track + from scalabel.label.io import group_and_sort + from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d + from scalabel.label.typing import Config, Frame, Label +else: + raise ImportError("scalabel is not installed.") + + +class ScalabelTrackEvaluator(ScalabelEvaluator): + """Scalabel 2D tracking evaluation class.""" + + METRICS_TRACK = "MOT" + METRICS_SEG_TRACK = "MOTS" + METRICS_ALL = "all" + + def __init__( + self, + annotation_path: str, + config: Config | None = None, + mask_threshold: float = 0.0, + ) -> None: + """Initialize the evaluator.""" + super().__init__(annotation_path=annotation_path, config=config) + self.mask_threshold = mask_threshold + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "Scalabel Tracking Evaluator" + + @property + def metrics(self) -> list[str]: + """Supported metrics.""" + return [self.METRICS_TRACK, self.METRICS_SEG_TRACK] + + def process_batch( + self, + frame_ids: list[int], + sample_names: list[str], + sequence_names: list[str], + pred_boxes: list[NDArrayNumber], + pred_classes: list[NDArrayNumber], + pred_scores: list[NDArrayNumber], + pred_track_ids: list[NDArrayNumber], + pred_masks: list[NDArrayNumber] | None = None, + ) -> None: + """Process tracking results.""" + for i, ( + frame_id, + sample_name, + sequence_name, + boxes, + scores, + class_ids, + track_ids, + ) in enumerate( + zip( + frame_ids, + sample_names, + sequence_names, + pred_boxes, + pred_scores, + pred_classes, + pred_track_ids, + ) + ): + boxes = array_to_numpy(boxes, n_dims=None, dtype=np.float32) + class_ids = array_to_numpy(class_ids, n_dims=None, dtype=np.int64) + scores = array_to_numpy(scores, n_dims=None, dtype=np.float32) + if pred_masks: + masks = array_to_numpy( + pred_masks[i], n_dims=None, dtype=np.float32 + ) + + labels = [] + for label_id, (box, score, class_id, track_id) in enumerate( + zip(boxes, scores, class_ids, track_ids) + ): + box2d = xyxy_to_box2d(*box.tolist()) + + if pred_masks: + rle = mask_to_rle( + (masks[label_id] > self.mask_threshold).astype( + np.uint8 + ) + ) + else: + rle = None + + label = Label( + box2d=box2d, + category=( + self.inverse_cat_map[int(class_id)] + if self.inverse_cat_map != {} + else str(class_id) + ), + score=float(score), + id=str(int(track_id)), + rle=rle, + ) + labels.append(label) + frame = Frame( + name=sample_name, + videoName=sequence_name, + frameIndex=frame_id, + labels=labels, + ) + self.frames.append(frame) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate the dataset.""" + assert self.config is not None, "config is not set" + metrics_log: MetricLogs = {} + short_description = "" + + if metric in [self.METRICS_TRACK, self.METRICS_ALL]: + results = evaluate_track( + acc_single_video_mot, + gts=group_and_sort(self.gt_frames), + results=group_and_sort(self.frames), + config=self.config, + nproc=0, + ) + for metric_name, metric_value in results.summary().items(): + metrics_log[metric_name] = metric_value + short_description += str(results) + "\n" + + if metric in [self.METRICS_SEG_TRACK, self.METRICS_ALL]: + results = evaluate_seg_track( + acc_single_video_mots, + gts=group_and_sort(self.gt_frames), + results=group_and_sort(self.frames), + config=self.config, + nproc=0, + ) + for metric_name, metric_value in results.summary().items(): + metrics_log[metric_name] = metric_value + short_description += str(results) + "\n" + + return metrics_log, short_description diff --git a/vis4d/eval/shift/__init__.py b/vis4d/eval/shift/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ff5c48a9c3320c36cec71a1b0ace0c70e9e3a4 --- /dev/null +++ b/vis4d/eval/shift/__init__.py @@ -0,0 +1,17 @@ +"""SHIFT evaluation metrics.""" + +from .depth import SHIFTDepthEvaluator +from .detect import SHIFTDetectEvaluator +from .flow import SHIFTOpticalFlowEvaluator +from .multitask_writer import SHIFTMultitaskWriter +from .seg import SHIFTSegEvaluator +from .track import SHIFTTrackEvaluator + +__all__ = [ + "SHIFTDepthEvaluator", + "SHIFTDetectEvaluator", + "SHIFTOpticalFlowEvaluator", + "SHIFTSegEvaluator", + "SHIFTTrackEvaluator", + "SHIFTMultitaskWriter", +] diff --git a/vis4d/eval/shift/depth.py b/vis4d/eval/shift/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..55e4f3708f7d7c0ef3745594910c5f3764360dde --- /dev/null +++ b/vis4d/eval/shift/depth.py @@ -0,0 +1,45 @@ +"""SHIFT depth estimation evaluator.""" + +from __future__ import annotations + +from vis4d.common.typing import NDArrayNumber + +from ..common import DepthEvaluator + + +def apply_crop(depth: NDArrayNumber) -> NDArrayNumber: + """Apply crop to depth map to match SHIFT evaluation.""" + return depth[..., 0:740, :] + + +class SHIFTDepthEvaluator(DepthEvaluator): + """SHIFT depth estimation evaluation class.""" + + def __init__(self, use_eval_crop: bool = True) -> None: + """Initialize the evaluator. + + Args: + use_eval_crop (bool): Whether to use the evaluation crop. + Default: True. + """ + super().__init__(min_depth=0.01, max_depth=80.0) + self.use_eval_crop = use_eval_crop + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "SHIFT Depth Estimation Evaluator" + + def process_batch( # type: ignore # pylint: disable=arguments-differ + self, prediction: NDArrayNumber, groundtruth: NDArrayNumber + ) -> None: + """Process sample and update confusion matrix. + + Args: + prediction: Predictions of shape (N, H, W). + groundtruth: Groundtruth of shape (N, H, W). + """ + if self.use_eval_crop: + prediction = apply_crop(prediction) + groundtruth = apply_crop(groundtruth) + print(prediction.shape, groundtruth.shape) + super().process_batch(prediction, groundtruth) diff --git a/vis4d/eval/shift/detect.py b/vis4d/eval/shift/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..ef4575600c5d5c14155fa5e5328645b9bec55132 --- /dev/null +++ b/vis4d/eval/shift/detect.py @@ -0,0 +1,18 @@ +"""SHIFT detection evaluator.""" + +from __future__ import annotations + +from vis4d.data.datasets.shift import shift_det_map + +from ..scalabel import ScalabelDetectEvaluator + + +class SHIFTDetectEvaluator(ScalabelDetectEvaluator): + """SHIFT detection evaluation class.""" + + inverse_det_map = {v: k for k, v in shift_det_map.items()} + + def __init__(self, annotation_path: str) -> None: + """Initialize the evaluator.""" + super().__init__(annotation_path=annotation_path, mask_threshold=0) + self.inverse_cat_map = self.inverse_det_map diff --git a/vis4d/eval/shift/flow.py b/vis4d/eval/shift/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c3b5b3dc404398e8852003dbf8a429147fbc24 --- /dev/null +++ b/vis4d/eval/shift/flow.py @@ -0,0 +1,19 @@ +"""SHIFT optical flow estimation evaluator.""" + +from __future__ import annotations + +from ..common import OpticalFlowEvaluator + + +class SHIFTOpticalFlowEvaluator(OpticalFlowEvaluator): + """SHIFT optical flow estimation evaluation class.""" + + def __init__( + self, + ) -> None: + """Initialize the evaluator.""" + super().__init__(max_flow=200.0, use_degrees=False, scale=1.0) + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "SHIFT Optical Flow Estimation Evaluator" diff --git a/vis4d/eval/shift/multitask_writer.py b/vis4d/eval/shift/multitask_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..e6d228bedcd9fb07ae6a909b1989c753d140373f --- /dev/null +++ b/vis4d/eval/shift/multitask_writer.py @@ -0,0 +1,279 @@ +"""SHIFT result writer.""" + +from __future__ import annotations + +import io +import itertools +import json +import os +from collections import defaultdict + +import numpy as np +from PIL import Image + +from vis4d.common.array import array_to_numpy +from vis4d.common.imports import SCALABEL_AVAILABLE +from vis4d.common.typing import ( + ArrayLike, + GenericFunc, + MetricLogs, + NDArrayNumber, +) +from vis4d.data.datasets.shift import shift_det_map +from vis4d.data.io import DataBackend, ZipBackend +from vis4d.eval.base import Evaluator + +if SCALABEL_AVAILABLE: + from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d + from scalabel.label.typing import Dataset, Frame, Label +else: + raise ImportError("scalabel is not installed.") + + +class SHIFTMultitaskWriter(Evaluator): + """SHIFT result writer for online evaluation.""" + + inverse_cat_map = {v: k for k, v in shift_det_map.items()} + + def __init__( + self, + output_dir: str, + submission_file: str = "submission.zip", + ) -> None: + """Creates a new writer. + + Args: + output_dir (str): Output directory. + submission_file (str): Submission file name. Defaults to + "submission.zip". + """ + super().__init__() + assert submission_file.endswith( + ".zip" + ), "Submission file must be a zip file." + self.backend: DataBackend = ZipBackend() + self.output_path = os.path.join(output_dir, submission_file) + self.frames_det_2d: list[Frame] = [] + self.frames_det_3d: list[Frame] = [] + self.sample_counts: defaultdict[str, int] = defaultdict(int) + + def _write_sem_mask( + self, sem_mask: NDArrayNumber, sample_name: str, video_name: str + ) -> None: + """Write semantic mask. + + Args: + sem_mask (NDArrayNumber): Predicted semantic mask, shape (H, W). + sample_name (str): Sample name. + video_name (str): Video name. + """ + image = Image.fromarray(sem_mask.astype("uint8"), mode="L") + image_bytes = io.BytesIO() + image.save(image_bytes, format="PNG") + self.backend.set( + f"{self.output_path}/semseg/{video_name}/{sample_name}", + image_bytes.getvalue(), + mode="w", + ) + + def _write_depth( + self, depth_map: NDArrayNumber, sample_name: str, video_name: str + ) -> None: + """Write depth map. + + Args: + depth_map (NDArrayNumber): Predicted depth map, shape (H, W). + sample_name (str): Sample name. + video_name (str): Video name. + """ + depth_map = np.clip(depth_map / 80.0 * 255.0, 0, 255) + image = Image.fromarray(depth_map.astype("uint8"), mode="L") + image_bytes = io.BytesIO() + image.save(image_bytes, format="PNG") + self.backend.set( + f"{self.output_path}/depth/{video_name}/{sample_name}", + image_bytes.getvalue(), + mode="w", + ) + + def _write_flow( + self, flow: NDArrayNumber, sample_name: str, video_name: str + ) -> None: + """Write semantic mask. + + Args: + flow (NDArrayNumber): Predicted optical flow, shape (H, W, 2). + sample_name (str): Sample name. + video_name (str): Video name. + """ + raise NotImplementedError + + def process_batch( + self, + frame_ids: list[int], + sample_names: list[str], + sequence_names: list[str], + pred_sem_mask: list[ArrayLike] | None = None, + pred_depth: list[ArrayLike] | None = None, + pred_flow: list[ArrayLike] | None = None, + pred_boxes2d: list[ArrayLike] | None = None, + pred_boxes2d_classes: list[ArrayLike] | None = None, + pred_boxes2d_scores: list[ArrayLike] | None = None, + pred_boxes2d_track_ids: list[ArrayLike] | None = None, + pred_instance_masks: list[ArrayLike] | None = None, + ) -> None: + """Process SHIFT results. + + You can omit some of the predictions if they are not used. + + Args: + frame_ids (list[int]): Frame IDs. + sample_names (list[str]): Sample names. + sequence_names (list[str]): Sequence names. + pred_sem_mask (list[ArrayLike], optional): Predicted semantic + masks, each in shape (C, H, W) or (H, W). Defaults to None. + pred_depth (list[ArrayLike], optional): Predicted depth maps, + each in shape (H, W), with meter unit. Defaults to None. + pred_flow (list[ArrayLike], optional): Predicted optical flows, + each in shape (H, W, 2). Defaults to None. + pred_boxes2d (list[ArrayLike], optional): Predicted 2D boxes, + each in shape (N, 4). Defaults to None. + pred_boxes2d_classes (list[ArrayLike], optional): Predicted + 2D box classes, each in shape (N,). Defaults to None. + pred_boxes2d_scores (list[ArrayLike], optional): Predicted + 2D box scores, each in shape (N,). Defaults to None. + pred_boxes2d_track_ids (list[ArrayLike], optional): Predicted + 2D box track IDs, each in shape (N,). Defaults to None. + pred_instance_masks (list[ArrayLike], optional): Predicted + instance masks, each in shape (N, H, W). Defaults to None. + """ + for i, (frame_id, sample_name, sequence_name) in enumerate( + zip(frame_ids, sample_names, sequence_names) + ): + if pred_sem_mask is not None: + sem_mask_ = array_to_numpy( + pred_sem_mask[i], + n_dims=None, + dtype=np.float32, + ) + if len(sem_mask_.shape) == 3: + sem_mask = sem_mask_.argmax(axis=0) + else: + sem_mask = sem_mask_.astype(np.uint8) + semseg_filename = sample_name.replace(".jpg", ".png").replace( + "img", "semseg" + ) + self._write_sem_mask(sem_mask, semseg_filename, sequence_name) + self.sample_counts["semseg"] += 1 + if pred_depth is not None: + depth = array_to_numpy( + pred_depth[i], n_dims=None, dtype=np.float32 + ) + depth_filename = sample_name.replace(".jpg", ".png").replace( + "img", "depth" + ) + self._write_depth(depth, depth_filename, sequence_name) + self.sample_counts["depth"] += 1 + if pred_flow is not None: + flow = array_to_numpy( + pred_flow[i], n_dims=None, dtype=np.float32 + ) + self._write_flow(flow, sample_name, sequence_name) + self.sample_counts["flow"] += 1 + if ( + pred_boxes2d is not None + and pred_boxes2d_classes is not None + and pred_boxes2d_scores is not None + ): + labels = [] + if pred_instance_masks: + masks = array_to_numpy( + pred_instance_masks[i], n_dims=None, dtype=np.float32 + ) + if pred_boxes2d_track_ids: + track_ids = array_to_numpy( + pred_boxes2d_track_ids[i], + n_dims=None, + dtype=np.int64, + ) + for box, score, class_id in zip( + pred_boxes2d[i], + pred_boxes2d_scores[i], + pred_boxes2d_classes[i], + ): + box2d = xyxy_to_box2d(*box.tolist()) + if pred_instance_masks: + rle = mask_to_rle( + (masks[class_id] > 0.0).astype(np.uint8) + ) + else: + rle = None + + if pred_boxes2d_track_ids: + track_id = str(int(track_ids[0])) + else: + track_id = None + + label = Label( + box2d=box2d, + category=( + self.inverse_cat_map[int(class_id)] + if self.inverse_cat_map != {} + else str(class_id) + ), + score=float(score), + rle=rle, + id=track_id, + ) + labels.append(label) + frame = Frame( + name=sample_name, + videoName=sequence_name, + frameIndex=frame_id, + labels=labels, + ) + self.frames_det_2d.append(frame) + self.sample_counts["det_2d"] += 1 + + def gather(self, gather_func: GenericFunc) -> None: # pragma: no cover + """Gather variables in case of distributed setting (if needed). + + Args: + gather_func (Callable[[Any], Any]): Gather function. + """ + all_preds = gather_func(self.frames_det_2d) + if all_preds is not None: + self.frames_det_2d = list(itertools.chain(*all_preds)) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """No evaluation locally.""" + return {}, "No evaluation locally." + + def save(self, metric: str, output_dir: str) -> None: + """Save scalabel output to zip file. + + Raises: + ValueError: If the number of samples in each category is not the + same. + """ + # Check if the sample counts are correct + equal_size = True + for key in self.sample_counts: + if self.sample_counts[key] != len(self.frames_det_2d): + equal_size = False + break + if not equal_size: + raise ValueError( + "The number of samples in each category is not the same." + ) + + # Save the 2D detection results + if len(self.frames_det_2d) > 0: + ds = Dataset(frames=self.frames_det_2d, groups=None, config=None) + ds_bytes = json.dumps(ds.dict()).encode("utf-8") + self.backend.set( + f"{self.output_path}/det_2d.json", ds_bytes, mode="w" + ) + + self.backend.close() + print(f"Saved the submission file at {self.output_path}.") diff --git a/vis4d/eval/shift/seg.py b/vis4d/eval/shift/seg.py new file mode 100644 index 0000000000000000000000000000000000000000..939d007a981ad57e834b64e801a05d70f9c68b2c --- /dev/null +++ b/vis4d/eval/shift/seg.py @@ -0,0 +1,48 @@ +"""SHIFT segmentation evaluator.""" + +from __future__ import annotations + +from vis4d.common.typing import NDArrayI64, NDArrayNumber +from vis4d.data.datasets.shift import shift_seg_ignore, shift_seg_map +from vis4d.eval.common.seg import SegEvaluator + + +class SHIFTSegEvaluator(SegEvaluator): + """SHIFT segmentation evaluation class.""" + + inverse_seg_map = {v: k for k, v in shift_seg_map.items()} + + def __init__(self, ignore_classes_as_cityscapes: bool = True) -> None: + """Initialize the evaluator.""" + super().__init__( + num_classes=23, + class_to_ignore=255, + class_mapping=self.inverse_seg_map, + ) + self.ignore_classes_as_cityscapes = ignore_classes_as_cityscapes + + def __repr__(self) -> str: + """Concise representation of the dataset evaluator.""" + return "SHIFT Segmentation Evaluator" + + def _prune_class(self, label: NDArrayI64) -> NDArrayI64: + """Prune class labels.""" + for cls in shift_seg_ignore: + label[label == shift_seg_map[cls]] = 255 + return label + + def process_batch( # type: ignore # pylint: disable=arguments-differ + self, prediction: NDArrayNumber, groundtruth: NDArrayI64 + ) -> None: + """Process sample and update confusion matrix. + + Args: + prediction: Predictions of shape [N,C,...] or [N,...] with + C* being any number if channels. Note, C is passed, + the prediction is converted to target labels by applying + the max operations along the second axis + groundtruth: Groundtruth of shape [N_batch, ...] type int + """ + if self.ignore_classes_as_cityscapes: + groundtruth = self._prune_class(groundtruth) + super().process_batch(prediction, groundtruth) diff --git a/vis4d/eval/shift/track.py b/vis4d/eval/shift/track.py new file mode 100644 index 0000000000000000000000000000000000000000..1a9b0d1c50d1838d9d3a29a8ee424317dfff9535 --- /dev/null +++ b/vis4d/eval/shift/track.py @@ -0,0 +1,22 @@ +"""SHIFT tracking evaluator.""" + +from __future__ import annotations + +from vis4d.data.datasets.shift import shift_det_map + +from ..scalabel import ScalabelTrackEvaluator + + +class SHIFTTrackEvaluator(ScalabelTrackEvaluator): + """SHIFT tracking evaluation class.""" + + inverse_det_map = {v: k for k, v in shift_det_map.items()} + + def __init__( + self, annotation_path: str, mask_threshold: float = 0.0 + ) -> None: + """Initialize the evaluator.""" + super().__init__( + annotation_path=annotation_path, mask_threshold=mask_threshold + ) + self.inverse_cat_map = self.inverse_det_map diff --git a/vis4d/eval/utils.py b/vis4d/eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9911ebd4d481c30b2e479b8d65d5cd3ab54bab4a --- /dev/null +++ b/vis4d/eval/utils.py @@ -0,0 +1,25 @@ +"""Utility functions for evaluation.""" + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ArrayLike, NDArrayNumber + + +def dense_inputs_to_numpy( + prediction: ArrayLike, target: ArrayLike +) -> tuple[NDArrayNumber, NDArrayNumber]: + """Convert dense prediction and target to numpy arrays.""" + prediction = array_to_numpy(prediction, n_dims=None, dtype=np.float32) + target = array_to_numpy(target, n_dims=None, dtype=np.float32) + return prediction, target + + +def check_shape_match( + prediction: NDArrayNumber, target: NDArrayNumber +) -> None: + """Check if the shape of prediction and target matches.""" + assert prediction.shape == target.shape, ( + f"Shape mismatch between prediction {prediction.shape} and target" + f"{target.shape}." + ) diff --git a/vis4d/model/__init__.py b/vis4d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9643ab24c6f0ca60c08174bb19903bc1ba1e817f --- /dev/null +++ b/vis4d/model/__init__.py @@ -0,0 +1,6 @@ +"""Model definitions that connect operators and states together. + +All the compute should go to operators and the model memories should be kept +in states. The models are supposed to do minimum job to connect the model +pipelines. +""" diff --git a/vis4d/model/adapter/__init__.py b/vis4d/model/adapter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..450a3d45261b6526bb6aa4530b207f3bcbd81ea5 --- /dev/null +++ b/vis4d/model/adapter/__init__.py @@ -0,0 +1,5 @@ +"""Model adapters.""" + +from .ema import ModelEMAAdapter, ModelExpEMAAdapter + +__all__ = ["ModelEMAAdapter", "ModelExpEMAAdapter"] diff --git a/vis4d/model/adapter/ema.py b/vis4d/model/adapter/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..573836bc4adade41103738dab475fd74aaafb60a --- /dev/null +++ b/vis4d/model/adapter/ema.py @@ -0,0 +1,118 @@ +"""Exponential Moving Average (EMA) for PyTorch models.""" + +from __future__ import annotations + +import math +from collections.abc import Callable +from copy import deepcopy +from typing import Any + +import torch +from torch import Tensor, nn + +from vis4d.common.logging import rank_zero_info + + +class ModelEMAAdapter(nn.Module): + """Torch module with Exponential Moving Average (EMA). + + Args: + model (nn.Module): Model to apply EMA. + decay (float): Decay factor for EMA. Defaults to 0.9998. + use_ema_during_test (bool): Use EMA model during testing. Defaults to + True. + device (torch.device | None): Device to use. Defaults to None. + """ + + def __init__( + self, + model: nn.Module, + decay: float = 0.9998, + use_ema_during_test: bool = True, + device: torch.device | None = None, + ): + """Init ModelEMAAdapter class.""" + super().__init__() + self.model = model + self.ema_model = deepcopy(self.model) + self.ema_model.eval() + for p in self.ema_model.parameters(): + p.requires_grad_(False) + self.decay = decay + self.use_ema_during_test = use_ema_during_test + self.device = device + if self.device is not None: + self.ema_model.to(device=device) + rank_zero_info("Using model EMA with decay rate %f", self.decay) + + def _update( + self, model: nn.Module, update_fn: Callable[[Tensor, Tensor], Tensor] + ) -> None: + """Update model params.""" + with torch.no_grad(): + for ema_v, model_v in zip( + self.ema_model.state_dict().values(), + model.state_dict().values(), + ): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, steps: int) -> None: # pylint: disable=unused-argument + """Update the internal EMA model.""" + self._update( + self.model, + update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m, + ) + + def set(self, model: nn.Module) -> None: + """Copy model params into the internal EMA.""" + self._update(model, update_fn=lambda e, m: m) + + def forward(self, *args: Any, **kwargs: Any) -> Any: # type: ignore + """Forward pass with original model.""" + if self.training or not self.use_ema_during_test: + return self.model(*args, **kwargs) + return self.ema_model(*args, **kwargs) + + +class ModelExpEMAAdapter(ModelEMAAdapter): + """Exponential Moving Average (EMA) with exponential decay strategy. + + Used by YOLOX. + + Args: + model (nn.Module): Model to apply EMA. + decay (float): Decay factor for EMA. Defaults to 0.9998. + warmup_steps (int): Number of warmup steps for decay. Use a smaller + decay early in training and gradually anneal to the set decay value + to update the EMA model smoothly. + use_ema_during_test (bool): Use EMA model during testing. Defaults to + True. + device (torch.device | None): Device to use. Defaults to None. + """ + + def __init__( + self, + model: nn.Module, + decay: float = 0.9998, + warmup_steps: int = 2000, + use_ema_during_test: bool = True, + device: torch.device | None = None, + ): + """Init ModelEMAAdapter class.""" + super().__init__(model, decay, use_ema_during_test, device) + assert ( + warmup_steps > 0 + ), f"warmup_steps must be greater than 0, got {warmup_steps}" + self.warmup_steps = warmup_steps + + def update(self, steps: int) -> None: + """Update the internal EMA model.""" + decay = self.decay * ( + 1 - math.exp(-float(1 + steps) / self.warmup_steps) + ) + self._update( + self.model, + update_fn=lambda e, m: decay * e + (1.0 - decay) * m, + ) diff --git a/vis4d/model/adapter/flops.py b/vis4d/model/adapter/flops.py new file mode 100644 index 0000000000000000000000000000000000000000..170178af8b51f462390dcd59214bc87967173b60 --- /dev/null +++ b/vis4d/model/adapter/flops.py @@ -0,0 +1,59 @@ +"""Adapter for counting flops in a model.""" + +from __future__ import annotations + +from typing import Any + +from torch import nn + +from vis4d.engine.connectors import DataConnector + +# Ops to ignore from counting, including elementwise and reduction ops +IGNORED_OPS = { + "aten::add", + "aten::add_", + "aten::argmax", + "aten::argsort", + "aten::batch_norm", + "aten::constant_pad_nd", + "aten::div", + "aten::div_", + "aten::exp", + "aten::log2", + "aten::max_pool2d", + "aten::meshgrid", + "aten::mul", + "aten::mul_", + "aten::neg", + "aten::nonzero_numpy", + "aten::reciprocal", + "aten::repeat_interleave", + "aten::rsub", + "aten::sigmoid", + "aten::sigmoid_", + "aten::softmax", + "aten::sort", + "aten::sqrt", + "aten::sub", + "torchvision::nms", +} + + +class FlopsModelAdapter(nn.Module): + """Adapter for the model to count flops.""" + + def __init__( + self, model: nn.Module, data_connector: DataConnector + ) -> None: + """Initialize the adapter.""" + super().__init__() + self.model = model + self.data_connector = data_connector + + def forward(self, *args: Any) -> Any: # type: ignore + """Forward pass through the model.""" + data_dict = {} + for i, key in enumerate(self.data_connector.key_mapping): + data_dict[key] = args[0][i] + + return self.model(**data_dict) diff --git a/vis4d/model/cls/__init__.py b/vis4d/model/cls/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4e60c12ac67a9c748ef85d3e49c4347f584d4b --- /dev/null +++ b/vis4d/model/cls/__init__.py @@ -0,0 +1,6 @@ +"""Common classes and functions for classification models.""" + +from .common import ClsOut +from .vit import ViTClassifer + +__all__ = ["ViTClassifer", "ClsOut"] diff --git a/vis4d/model/cls/common.py b/vis4d/model/cls/common.py new file mode 100644 index 0000000000000000000000000000000000000000..91aad295018f262a300ce581c6a3668310a66e3f --- /dev/null +++ b/vis4d/model/cls/common.py @@ -0,0 +1,12 @@ +"""Common types for classification models.""" + +from typing import NamedTuple + +import torch + + +class ClsOut(NamedTuple): + """Output of the classification results.""" + + logits: torch.Tensor # (N, num_classes) + probs: torch.Tensor # (N, num_classes) diff --git a/vis4d/model/cls/vit.py b/vis4d/model/cls/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea4e6366cbf5333d04736f63c9af43db55ae1ffc --- /dev/null +++ b/vis4d/model/cls/vit.py @@ -0,0 +1,122 @@ +"""ViT for classification tasks.""" + +from __future__ import annotations + +import timm.models.vision_transformer as _vision_transformer +import torch +from torch import nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.common.typing import ArgsType +from vis4d.op.base.vit import VisionTransformer, ViT_PRESET + +from .common import ClsOut + + +class ViTClassifer(nn.Module): + """ViT for classification tasks.""" + + def __init__( + self, + variant: str = "", + num_classes: int = 1000, + use_global_pooling: bool = False, + weights: str | None = None, + num_prefix_tokens: int = 1, + **kwargs: ArgsType, + ) -> None: + """Initialize the classification ViT. + + Args: + variant (str): Name of the ViT variant. Defaults to "". If the name + starts with "timm://", the variant will be loaded from timm's + model zoo. Otherwise, the variant will be loaded from the + ViT_PRESET dict. If the variant is empty, the default ViT + variant will be used. In all cases, the additional keyword + arguments will override the default arguments. + num_classes (int, optional): Number of classes. Defaults to 1000. + use_global_pooling (bool, optional): If to use global pooling. + Defaults to False. If set to True, the output of the ViT will + be averaged over the spatial dimensions. Otherwise, the first + token will be used for classification. + weights (str, optional): If to load pretrained weights. If set to + "timm", the weights will be loaded from timm's model zoo that + matches the variant. If a URL is provided, the weights will be + downloaded from the URL. Defaults to None, which means no + weights will be loaded. + num_prefix_tokens (int, optional): Number of prefix tokens. + Defaults to 1. + **kwargs: Keyword arguments passed to the ViT model. + """ + super().__init__() + self.num_classes = num_classes + self.use_global_pooling = use_global_pooling + self.num_prefix_tokens = num_prefix_tokens + + if variant != "": + assert variant in ViT_PRESET, ( + f"Unknown ViT variant: {variant}. " + f"Available ViT variants are: {list(ViT_PRESET.keys())}" + ) + preset_kwargs = ViT_PRESET[variant] + preset_kwargs["num_classes"] = num_classes + preset_kwargs.update(kwargs) + self.vit = VisionTransformer(**preset_kwargs) # type: ignore + else: + # Build ViT from scratch using kwargs + preset_kwargs = {} + self.vit = VisionTransformer(num_classes=num_classes, **kwargs) + + # Classification head + embed_dim = kwargs.get( + "embed_dim", preset_kwargs.get("embed_dim", 768) + ) + self.norm = ( + nn.LayerNorm(embed_dim) if use_global_pooling else nn.Identity() + ) + self.head = ( + nn.Linear(embed_dim, num_classes) + if num_classes > 0 + else nn.Identity() + ) + + # Load pretrain weights + if weights is not None: + if weights.startswith("timm://"): + weights = weights.removeprefix("timm://") + if "." in weights: + model_name, pretrain_tag = weights.split(".") + else: + model_name = weights + pretrain_tag = None + assert model_name in _vision_transformer.__dict__, ( + f"Unknown Timm ViT weights: {model_name}. " + f"Available Timm ViT weights are: " + f"{list(_vision_transformer.__dict__.keys())}" + ) + _model = _vision_transformer.__dict__[model_name]( + pretrained=True, pretrained_cfg=pretrain_tag, **kwargs + ) + self.vit.load_state_dict(_model.state_dict(), strict=False) + self.norm.load_state_dict( + _model.norm.state_dict(), strict=False + ) + self.head.load_state_dict( + _model.head.state_dict(), strict=False + ) + else: + load_model_checkpoint(self, weights) + + def forward(self, images: torch.Tensor) -> ClsOut: + """Forward pass.""" + feats = self.vit(images) + x = feats[-1] + if self.use_global_pooling: + x = x[:, self.num_prefix_tokens :].mean(dim=1) + else: + x = x[:, 0] + x = self.norm(x) + logits = self.head(x) + return ClsOut( + logits=logits, probs=torch.softmax(logits.detach(), dim=-1) + ) diff --git a/vis4d/model/detect/__init__.py b/vis4d/model/detect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db4d5a49062d47c24ce4d6621d5a7aa9bee7bf59 --- /dev/null +++ b/vis4d/model/detect/__init__.py @@ -0,0 +1 @@ +"""This module contains the model implementations of 2D detectors.""" diff --git a/vis4d/model/detect/faster_rcnn.py b/vis4d/model/detect/faster_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..97a672bafe271e1fc6a59aac253f48b57f4c1c8c --- /dev/null +++ b/vis4d/model/detect/faster_rcnn.py @@ -0,0 +1,178 @@ +"""Faster RCNN model implementation and runtime.""" + +from __future__ import annotations + +import torch +from torch import nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.op.base import BaseModel, ResNet +from vis4d.op.box.box2d import scale_and_clip_boxes +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder +from vis4d.op.detect.common import DetOut +from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut +from vis4d.op.detect.rcnn import RoI2Det +from vis4d.op.fpp.fpn import FPN + +REV_KEYS = [ + (r"^backbone\.", "basemodel."), + (r"^rpn_head.rpn_reg\.", "faster_rcnn_head.rpn_head.rpn_box."), + (r"^rpn_head.rpn_", "faster_rcnn_head.rpn_head.rpn_"), + (r"^roi_head.bbox_head\.", "faster_rcnn_head.roi_head."), + (r"^neck.lateral_convs\.", "fpn.inner_blocks."), + (r"^neck.fpn_convs\.", "fpn.layer_blocks."), + (r"\.conv.weight", ".weight"), + (r"\.conv.bias", ".bias"), +] + + +class FasterRCNN(nn.Module): + """Faster RCNN model.""" + + def __init__( + self, + num_classes: int, + basemodel: BaseModel | None = None, + faster_rcnn_head: FasterRCNNHead | None = None, + rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None, + weights: None | str = None, + ) -> None: + """Creates an instance of the class. + + Args: + num_classes (int): Number of object categories. + basemodel (BaseModel, optional): Base model network. Defaults to + None. If None, will use ResNet50. + faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head. + Defaults to None. if None, will use default FasterRCNNHead. + rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN + bounding boxes. Defaults to None. + weights (str, optional): Weights to load for model. If set to + "mmdet", will load MMDetection pre-trained weights. Defaults to + None. + """ + super().__init__() + self.basemodel = ( + ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3) + if basemodel is None + else basemodel + ) + + self.fpn = FPN(self.basemodel.out_channels[2:], 256) + + if faster_rcnn_head is None: + self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes) + else: + self.faster_rcnn_head = faster_rcnn_head + + self.roi2det = RoI2Det(rcnn_box_decoder) + + if weights is not None: + if weights == "mmdet": + weights = ( + "mmdet://faster_rcnn/faster_rcnn_r50_fpn_1x_coco/" + "faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth" + ) + if weights.startswith("mmdet://") or weights.startswith( + "bdd100k://" + ): + load_model_checkpoint(self, weights, rev_keys=REV_KEYS) + else: + load_model_checkpoint(self, weights) + + def forward( + self, + images: torch.Tensor, + input_hw: list[tuple[int, int]], + boxes2d: None | list[torch.Tensor] = None, + boxes2d_classes: None | list[torch.Tensor] = None, + original_hw: None | list[tuple[int, int]] = None, + ) -> FRCNNOut | DetOut: + """Forward pass. + + Args: + images (torch.Tensor): Input images. + input_hw (list[tuple[int, int]]): Input image resolutions. + boxes2d (None | list[torch.Tensor], optional): Bounding box labels. + Required for training. Defaults to None. + boxes2d_classes (None | list[torch.Tensor], optional): Class + labels. Required for training. Defaults to None. + original_hw (None | list[tuple[int, int]], optional): Original + image resolutions (before padding and resizing). Required for + testing. Defaults to None. + + Returns: + FRCNNOut | DetOut: Either raw model outputs (for training) or + predicted outputs (for testing). + """ + if self.training: + assert boxes2d is not None and boxes2d_classes is not None + return self.forward_train( + images, input_hw, boxes2d, boxes2d_classes + ) + assert original_hw is not None + return self.forward_test(images, input_hw, original_hw) + + def __call__( + self, + images: torch.Tensor, + input_hw: list[tuple[int, int]], + boxes2d: None | list[torch.Tensor] = None, + boxes2d_classes: None | list[torch.Tensor] = None, + original_hw: None | list[tuple[int, int]] = None, + ) -> FRCNNOut | DetOut: + """Type definition for call implementation.""" + return self._call_impl( + images, input_hw, boxes2d, boxes2d_classes, original_hw + ) + + def forward_train( + self, + images: torch.Tensor, + images_hw: list[tuple[int, int]], + target_boxes: list[torch.Tensor], + target_classes: list[torch.Tensor], + ) -> FRCNNOut: + """Forward training stage. + + Args: + images (torch.Tensor): Input images. + images_hw (list[tuple[int, int]]): Input image resolutions. + target_boxes (list[torch.Tensor]): Bounding box labels. + target_classes (list[torch.Tensor]): Class labels. + + Returns: + FRCNNOut: Raw model outputs. + """ + features = self.fpn(self.basemodel(images)) + return self.faster_rcnn_head( + features, images_hw, target_boxes, target_classes + ) + + def forward_test( + self, + images: torch.Tensor, + images_hw: list[tuple[int, int]], + original_hw: list[tuple[int, int]], + ) -> DetOut: + """Forward testing stage. + + Args: + images (torch.Tensor): Input images. + images_hw (list[tuple[int, int]]): Input image resolutions. + original_hw (list[tuple[int, int]]): Original image resolutions + (before padding and resizing). + + Returns: + DetOut: Predicted outputs. + """ + features = self.fpn(self.basemodel(images)) + outs = self.faster_rcnn_head(features, images_hw) + boxes, scores, class_ids = self.roi2det( + *outs.roi, outs.proposals.boxes, images_hw + ) + + for i, boxs in enumerate(boxes): + boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i]) + + return DetOut(boxes, scores, class_ids) diff --git a/vis4d/model/detect/mask_rcnn.py b/vis4d/model/detect/mask_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..22cc29b7632a36df5a4c21b855c0afac53a46d6f --- /dev/null +++ b/vis4d/model/detect/mask_rcnn.py @@ -0,0 +1,219 @@ +"""Mask RCNN model implementation and runtime.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +from torch import nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.op.base import BaseModel, ResNet +from vis4d.op.box.box2d import apply_mask, scale_and_clip_boxes +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder +from vis4d.op.detect.common import DetOut +from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut +from vis4d.op.detect.mask_rcnn import ( + Det2Mask, + MaskOut, + MaskRCNNHead, + MaskRCNNHeadOut, +) +from vis4d.op.detect.rcnn import RoI2Det +from vis4d.op.fpp.fpn import FPN + + +class MaskDetectionOut(NamedTuple): + """Mask detection output.""" + + boxes: DetOut + masks: MaskOut + + +class MaskRCNNOut(NamedTuple): + """Mask RCNN output.""" + + boxes: FRCNNOut + masks: MaskRCNNHeadOut + + +REV_KEYS = [ + (r"^backbone\.", "basemodel."), + (r"^rpn_head.rpn_reg\.", "rpn_head.rpn_box."), + (r"^roi_head.bbox_head\.", "roi_head."), + (r"^roi_head.mask_head\.", "mask_head."), + (r"^convs\.", "mask_head.convs."), + (r"^upsample\.", "mask_head.upsample."), + (r"^conv_logits\.", "mask_head.conv_logits."), + (r"^roi_head\.", "faster_rcnn_head.roi_head."), + (r"^rpn_head\.", "faster_rcnn_head.rpn_head."), + (r"^neck.lateral_convs\.", "fpn.inner_blocks."), + (r"^neck.fpn_convs\.", "fpn.layer_blocks."), + (r"\.conv.weight", ".weight"), + (r"\.conv.bias", ".bias"), +] + + +class MaskRCNN(nn.Module): + """Mask RCNN model. + + Args: + num_classes (int): Number of classes. + basemodel (BaseModel, optional): Base model network. Defaults to + None. If None, will use ResNet50. + faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head. + Defaults to None. if None, will use default FasterRCNNHead. + mask_head (MaskRCNNHead, optional): Mask RCNN head. Defaults to + None. if None, will use default MaskRCNNHead. + rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN + bounding boxes. Defaults to None. + no_overlap (bool, optional): Whether to remove overlapping pixels + between masks. Defaults to False. + weights (None | str, optional): Weights to load for model. If set + to "mmdet", will load MMDetection pre-trained weights. + Defaults to None. + """ + + def __init__( + self, + num_classes: int, + basemodel: BaseModel | None = None, + faster_rcnn_head: FasterRCNNHead | None = None, + mask_head: MaskRCNNHead | None = None, + rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None, + no_overlap: bool = False, + weights: None | str = None, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.basemodel = ( + ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3) + if basemodel is None + else basemodel + ) + + self.fpn = FPN(self.basemodel.out_channels[2:], 256) + + if faster_rcnn_head is None: + self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes) + else: + self.faster_rcnn_head = faster_rcnn_head + + if mask_head is None: + self.mask_head = MaskRCNNHead(num_classes=num_classes) + else: + self.mask_head = mask_head + + self.transform_outs = RoI2Det(rcnn_box_decoder) + self.det2mask = Det2Mask(no_overlap=no_overlap) + + if weights is not None: + if weights == "mmdet": + weights = ( + "mmdet://mask_rcnn/mask_rcnn_r50_fpn_2x_coco/" + "mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_" + "20200505_003907-3e542a40.pth" + ) + if weights.startswith("mmdet://") or weights.startswith( + "bdd100k://" + ): + load_model_checkpoint(self, weights, rev_keys=REV_KEYS) + else: + load_model_checkpoint(self, weights) + + def forward( + self, + images: torch.Tensor, + input_hw: list[tuple[int, int]], + boxes2d: None | list[torch.Tensor] = None, + boxes2d_classes: None | list[torch.Tensor] = None, + original_hw: None | list[tuple[int, int]] = None, + ) -> MaskRCNNOut | MaskDetectionOut: + """Forward pass. + + Args: + images (torch.Tensor): Input images. + input_hw (list[tuple[int, int]]): Input image resolutions. + boxes2d (None | list[torch.Tensor], optional): Bounding box + labels. Required for training. Defaults to None. + boxes2d_classes (None | list[torch.Tensor], optional): Class + labels. Required for training. Defaults to None. + original_hw (None | list[tuple[int, int]], optional): Original + image resolutions (before padding and resizing). Required for + testing. Defaults to None. + + Returns: + MaskRCNNOut | MaskDetectionOut: Either raw model + outputs (for training) or predicted outputs (for testing). + """ + if self.training: + assert boxes2d is not None and boxes2d_classes is not None + return self.forward_train( + images, input_hw, boxes2d, boxes2d_classes + ) + assert original_hw is not None + return self.forward_test(images, input_hw, original_hw) + + def forward_train( + self, + images: torch.Tensor, + images_hw: list[tuple[int, int]], + target_boxes: list[torch.Tensor], + target_classes: list[torch.Tensor], + ) -> MaskRCNNOut: + """Forward training stage. + + Args: + images (torch.Tensor): Input images. + images_hw (list[tuple[int, int]]): Input image resolutions. + target_boxes (list[torch.Tensor]): Bounding box labels. Required + for training. Defaults to None. + target_classes (list[torch.Tensor]): Class labels. Required for + training. Defaults to None. + + Returns: + MaskRCNNOut: Raw model outputs. + """ + features = self.fpn(self.basemodel(images)) + outputs = self.faster_rcnn_head( + features, images_hw, target_boxes, target_classes + ) + assert outputs.sampled_proposals is not None + assert outputs.sampled_targets is not None + pos_proposals = apply_mask( + [torch.eq(label, 1) for label in outputs.sampled_targets.labels], + outputs.sampled_proposals.boxes, + )[0] + mask_outs = self.mask_head(features, pos_proposals) + return MaskRCNNOut(outputs, mask_outs) + + def forward_test( + self, + images: torch.Tensor, + images_hw: list[tuple[int, int]], + original_hw: list[tuple[int, int]], + ) -> MaskDetectionOut: + """Forward testing stage. + + Args: + images (torch.Tensor): Input images. + images_hw (list[tuple[int, int]]): Input image resolutions. + original_hw (list[tuple[int, int]]): Original image resolutions + (before padding and resizing). + + Returns: + MaskDetectionOut: Predicted outputs. + """ + features = self.fpn(self.basemodel(images)) + outs = self.faster_rcnn_head(features, images_hw) + boxes, scores, class_ids = self.transform_outs( + *outs.roi, outs.proposals.boxes, images_hw + ) + mask_outs = self.mask_head(features, boxes) + for i, boxs in enumerate(boxes): + boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i]) + mask_preds = [m.sigmoid() for m in mask_outs.mask_pred] + masks = self.det2mask( + mask_preds, boxes, scores, class_ids, original_hw + ) + return MaskDetectionOut(DetOut(boxes, scores, class_ids), masks) diff --git a/vis4d/model/detect/retinanet.py b/vis4d/model/detect/retinanet.py new file mode 100644 index 0000000000000000000000000000000000000000..204bd866641d8c085ad40ac21a025bbfb8fcdadb --- /dev/null +++ b/vis4d/model/detect/retinanet.py @@ -0,0 +1,193 @@ +"""RetinaNet model implementation and runtime.""" + +from __future__ import annotations + +from torch import Tensor, nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.common.typing import LossesType +from vis4d.op.base.resnet import ResNet +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.box.box2d import scale_and_clip_boxes +from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder +from vis4d.op.box.matchers import Matcher +from vis4d.op.box.samplers import Sampler +from vis4d.op.detect.common import DetOut +from vis4d.op.detect.retinanet import ( + Dense2Det, + RetinaNetHead, + RetinaNetHeadLoss, + RetinaNetOut, +) +from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock + +REV_KEYS = [ + (r"^backbone\.", "basemodel."), + (r"^bbox_head\.", "retinanet_head."), + (r"^neck.lateral_convs\.", "fpn.inner_blocks."), + (r"^neck.fpn_convs\.", "fpn.layer_blocks."), + (r"^fpn.layer_blocks.3\.", "fpn.extra_blocks.convs.0."), + (r"^fpn.layer_blocks.4\.", "fpn.extra_blocks.convs.1."), + (r"\.conv.weight", ".weight"), + (r"\.conv.bias", ".bias"), +] + + +class RetinaNet(nn.Module): + """RetinaNet wrapper class for checkpointing etc.""" + + def __init__(self, num_classes: int, weights: None | str = None) -> None: + """Creates an instance of the class. + + Args: + num_classes (int): Number of classes. + weights (None | str, optional): Weights to load for model. If + set to "mmdet", will load MMDetection pre-trained weights. + Defaults to None. + """ + super().__init__() + self.basemodel = ResNet( + "resnet50", pretrained=True, trainable_layers=3 + ) + self.fpn = FPN( + self.basemodel.out_channels[3:], + 256, + ExtraFPNBlock(2, 2048, 256, add_extra_convs="on_input"), + start_index=3, + ) + self.retinanet_head = RetinaNetHead( + num_classes=num_classes, in_channels=256 + ) + self.transform_outs = Dense2Det( + self.retinanet_head.anchor_generator, + self.retinanet_head.box_decoder, + num_pre_nms=1000, + max_per_img=100, + nms_threshold=0.5, + score_thr=0.05, + ) + + if weights == "mmdet": + weights = ( + "mmdet://retinanet/retinanet_r50_fpn_2x_coco/" + "retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth" + ) + load_model_checkpoint(self, weights, rev_keys=REV_KEYS) + elif weights is not None: + load_model_checkpoint(self, weights) + + def forward( + self, + images: Tensor, + input_hw: None | list[tuple[int, int]] = None, + original_hw: None | list[tuple[int, int]] = None, + ) -> RetinaNetOut | DetOut: + """Forward pass. + + Args: + images (Tensor): Input images. + input_hw (None | list[tuple[int, int]], optional): Input image + resolutions. Defaults to None. + original_hw (None | list[tuple[int, int]], optional): Original + image resolutions (before padding and resizing). Required for + testing. Defaults to None. + + Returns: + RetinaNetOut | DetOut: Either raw model outputs (for training) or + predicted outputs (for testing). + """ + if self.training: + return self.forward_train(images) + assert input_hw is not None and original_hw is not None + return self.forward_test(images, input_hw, original_hw) + + def forward_train(self, images: Tensor) -> RetinaNetOut: + """Forward training stage. + + Args: + images (Tensor): Input images. + + Returns: + RetinaNetOut: Raw model outputs. + """ + features = self.fpn(self.basemodel(images)) + return self.retinanet_head(features[-5:]) + + def forward_test( + self, + images: Tensor, + images_hw: list[tuple[int, int]], + original_hw: list[tuple[int, int]], + ) -> DetOut: + """Forward testing stage. + + Args: + images (Tensor): Input images. + images_hw (list[tuple[int, int]]): Input image resolutions. + original_hw (list[tuple[int, int]]): Original image resolutions + (before padding and resizing). + + Returns: + DetOut: Predicted outputs. + """ + features = self.fpn(self.basemodel(images)) + outs = self.retinanet_head(features[-5:]) + boxes, scores, class_ids = self.transform_outs( + cls_outs=outs.cls_score, + reg_outs=outs.bbox_pred, + images_hw=images_hw, + ) + for i, boxs in enumerate(boxes): + boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i]) + return DetOut(boxes, scores, class_ids) + + +class RetinaNetLoss(nn.Module): + """RetinaNet Loss.""" + + def __init__( + self, + anchor_generator: AnchorGenerator, + box_encoder: DeltaXYWHBBoxEncoder, + box_matcher: Matcher, + box_sampler: Sampler, + ) -> None: + """Creates an instance of the class. + + Args: + anchor_generator (AnchorGenerator): Anchor generator for RPN. + box_encoder (BoxEncoder2D): Bounding box encoder. + box_matcher (BaseMatcher): Bounding box matcher. + box_sampler (BaseSampler): Bounding box sampler. + """ + super().__init__() + self.retinanet_loss = RetinaNetHeadLoss( + anchor_generator, box_encoder, box_matcher, box_sampler + ) + + def forward( + self, + outputs: RetinaNetOut, + images_hw: list[tuple[int, int]], + target_boxes: list[Tensor], + target_classes: list[Tensor], + ) -> LossesType: + """Forward of loss function. + + Args: + outputs (RetinaNetOut): Raw model outputs. + images_hw (list[tuple[int, int]]): Input image resolutions. + target_boxes (list[Tensor]): Bounding box labels. + target_classes (list[Tensor]): Class labels. + + Returns: + LossesType: Dictionary of model losses. + """ + losses = self.retinanet_loss( + outputs.cls_score, + outputs.bbox_pred, + target_boxes, + images_hw, + target_classes, + ) + return losses._asdict() diff --git a/vis4d/model/detect/yolox.py b/vis4d/model/detect/yolox.py new file mode 100644 index 0000000000000000000000000000000000000000..0bfd9f7195977948a3b10243aaece0426390edb6 --- /dev/null +++ b/vis4d/model/detect/yolox.py @@ -0,0 +1,154 @@ +"""YOLOX model implementation and runtime.""" + +from __future__ import annotations + +import torch +from torch import nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.op.base import BaseModel, CSPDarknet +from vis4d.op.box.box2d import scale_and_clip_boxes +from vis4d.op.detect.common import DetOut +from vis4d.op.detect.yolox import YOLOXHead, YOLOXOut, YOLOXPostprocess +from vis4d.op.fpp import YOLOXPAFPN, FeaturePyramidProcessing + +REV_KEYS = [ + (r"^backbone\.", "basemodel."), + (r"^bbox_head\.", "yolox_head."), + (r"^neck\.", "fpn."), + (r"\.bn\.", ".norm."), + (r"\.conv.weight", ".weight"), + (r"\.conv.bias", ".bias"), +] + + +class YOLOX(nn.Module): + """YOLOX detector.""" + + def __init__( + self, + num_classes: int, + basemodel: BaseModel | None = None, + fpn: FeaturePyramidProcessing | None = None, + yolox_head: YOLOXHead | None = None, + postprocessor: YOLOXPostprocess | None = None, + weights: None | str = None, + ) -> None: + """Creates an instance of the class. + + Args: + num_classes (int): Number of classes. + basemodel (BaseModel, optional): Base model. Defaults to None. If + None, will use CSPDarknet. + fpn (FeaturePyramidProcessing, optional): Feature Pyramid + Processing. Defaults to None. If None, will use YOLOXPAFPN. + yolox_head (YOLOXHead, optional): YOLOX head. Defaults to None. If + None, will use YOLOXHead. + postprocessor (YOLOXPostprocess, optional): Post processor. + Defaults to None. If None, will use YOLOXPostprocess. + weights (None | str, optional): Weights to load for model. If + set to "mmdet", will load MMDetection pre-trained weights. + Defaults to None. + """ + super().__init__() + self.basemodel = ( + CSPDarknet(deepen_factor=0.33, widen_factor=0.5) + if basemodel is None + else basemodel + ) + self.fpn = ( + YOLOXPAFPN([128, 256, 512], 128, num_csp_blocks=1) + if fpn is None + else fpn + ) + self.yolox_head = ( + YOLOXHead( + num_classes=num_classes, in_channels=128, feat_channels=128 + ) + if yolox_head is None + else yolox_head + ) + self.postprocessor = ( + YOLOXPostprocess( + self.yolox_head.point_generator, + self.yolox_head.box_decoder, + nms_threshold=0.65, + score_thr=0.01, + ) + if postprocessor is None + else postprocessor + ) + + if weights is not None: + if weights.startswith("mmdet://") or weights.startswith( + "bdd100k://" + ): + load_model_checkpoint(self, weights, rev_keys=REV_KEYS) + else: + load_model_checkpoint(self, weights) + + def forward( + self, + images: torch.Tensor, + input_hw: None | list[tuple[int, int]] = None, + original_hw: None | list[tuple[int, int]] = None, + ) -> YOLOXOut | DetOut: + """Forward pass. + + Args: + images (torch.Tensor): Input images. + input_hw (None | list[tuple[int, int]], optional): Input image + resolutions. Defaults to None. + original_hw (None | list[tuple[int, int]], optional): Original + image resolutions (before padding and resizing). Required for + testing. Defaults to None. + + Returns: + YOLOXOut | DetOut: Either raw model outputs (for training) or + predicted outputs (for testing). + """ + if self.training: + return self.forward_train(images) + assert input_hw is not None and original_hw is not None + return self.forward_test(images, input_hw, original_hw) + + def forward_train(self, images: torch.Tensor) -> YOLOXOut: + """Forward training stage. + + Args: + images (torch.Tensor): Input images. + + Returns: + YOLOXOut: Raw model outputs. + """ + features = self.fpn(self.basemodel(images.contiguous())) + return self.yolox_head(features[-3:]) + + def forward_test( + self, + images: torch.Tensor, + images_hw: list[tuple[int, int]], + original_hw: list[tuple[int, int]], + ) -> DetOut: + """Forward testing stage. + + Args: + images (torch.Tensor): Input images. + images_hw (list[tuple[int, int]]): Input image resolutions. + original_hw (list[tuple[int, int]]): Original image resolutions + (before padding and resizing). + + Returns: + DetOut: Predicted outputs. + """ + features = self.fpn(self.basemodel(images)) + outs = self.yolox_head(features[-3:]) + boxes, scores, class_ids = self.postprocessor( + cls_outs=outs.cls_score, + reg_outs=outs.bbox_pred, + obj_outs=outs.objectness, + images_hw=images_hw, + ) + for i, boxs in enumerate(boxes): + boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i]) + return DetOut(boxes, scores, class_ids) diff --git a/vis4d/model/detect3d/__init__.py b/vis4d/model/detect3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44403d249f007a38297424080538f51d87b50463 --- /dev/null +++ b/vis4d/model/detect3d/__init__.py @@ -0,0 +1 @@ +"""3D Detection Models.""" diff --git a/vis4d/model/detect3d/bevformer.py b/vis4d/model/detect3d/bevformer.py new file mode 100644 index 0000000000000000000000000000000000000000..34f6d6621283ac31d5f88e6d6648801eccfe7bac --- /dev/null +++ b/vis4d/model/detect3d/bevformer.py @@ -0,0 +1,162 @@ +"""BEVFromer model implementation. + +This file composes the operations associated with BEVFormer +`https://arxiv.org/abs/2203.17270` into the full model implementation. +""" + +from __future__ import annotations + +import copy +from typing import TypedDict + +import torch +from torch import Tensor, nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.op.base import BaseModel +from vis4d.op.detect3d.bevformer import BEVFormerHead, GridMask +from vis4d.op.detect3d.common import Detect3DOut +from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock + +REV_KEYS = [ + (r"^img_backbone\.", "basemodel."), + (r"^img_neck.lateral_convs\.", "fpn.inner_blocks."), + (r"^img_neck.fpn_convs\.", "fpn.layer_blocks."), + (r"^fpn.layer_blocks.3\.", "fpn.extra_blocks.convs.0."), + (r"\.conv.weight", ".weight"), + (r"\.conv.bias", ".bias"), +] + + +class PrevFrameInfo(TypedDict): + """Previous frame information.""" + + scene_name: str + prev_bev: Tensor | None + prev_pos: Tensor + prev_angle: Tensor + + +class BEVFormer(nn.Module): + """BEVFormer 3D Detector.""" + + def __init__( + self, + basemodel: BaseModel, + fpn: FPN | None = None, + pts_bbox_head: BEVFormerHead | None = None, + weights: str | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + basemodel (BaseModel): Base model network. + fpn (FPN, optional): Feature Pyramid Network. Defaults to None. If + None, a default FPN will be used. + pts_bbox_head (BEVFormerHead, optional): BEVFormer head. Defaults + to None. If None, a default BEVFormer head will be used. + weights (str, optional): Path to the checkpoint to load. Defaults + to None. + """ + super().__init__() + self.basemodel = basemodel + self.fpn = fpn or FPN( + self.basemodel.out_channels[3:], + 256, + extra_blocks=ExtraFPNBlock( + extra_levels=1, in_channels=256, out_channels=256 + ), + start_index=3, + ) + + self.grid_mask = GridMask( + True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7 + ) + + self.pts_bbox_head = pts_bbox_head or BEVFormerHead() + + # Temporal information + self.prev_frame_info = PrevFrameInfo( + scene_name="", + prev_bev=None, + prev_pos=torch.zeros(3), + prev_angle=torch.zeros(1), + ) + + if weights is not None: + load_model_checkpoint(self, weights, rev_keys=REV_KEYS) + + def extract_feat(self, images_list: list[Tensor]) -> list[Tensor]: + """Extract features of images.""" + n = len(images_list) # N + b = images_list[0].shape[0] # B + images = torch.stack(images_list, dim=1) # [B, N, C, H, W] + images = images.view(-1, *images.shape[2:]) # [B*N, C, H, W] + + # grid mask + if self.training: + images = self.grid_mask(images) + + features = self.basemodel(images) + features = self.fpn(features)[self.fpn.start_index :] + + img_feats = [] + for img_feat in features: + _, c, h, w = img_feat.size() + img_feats.append(img_feat.view(b, n, c, h, w)) + + return img_feats + + def forward( + self, + images: list[Tensor], + can_bus: list[list[float]], + scene_names: list[str], + cam_intrinsics: list[Tensor], + cam_extrinsics: list[Tensor], + lidar_extrinsics: list[Tensor], + ) -> Detect3DOut: + """Forward.""" + # Parse lidar extrinsics from LIDAR sensor data. + lidar_extrinsics_tensor = lidar_extrinsics[0] + can_bus_tensor = torch.tensor( + can_bus, dtype=torch.float32, device=images[0].device + ) + + if scene_names[0] != self.prev_frame_info["scene_name"]: + # the first sample of each scene is truncated + self.prev_frame_info["prev_bev"] = None + + # update idx + self.prev_frame_info["scene_name"] = scene_names[0] + + # Get the delta of ego position and angle between two timestamps. + tmp_pos = copy.deepcopy(can_bus_tensor[0][:3]) + tmp_angle = copy.deepcopy(can_bus_tensor[0][-1]) + if self.prev_frame_info["prev_bev"] is not None: + can_bus_tensor[0][:3] -= self.prev_frame_info["prev_pos"] + can_bus_tensor[0][-1] -= self.prev_frame_info["prev_angle"] + else: + can_bus_tensor[0][:3] = 0 + can_bus_tensor[0][-1] = 0 + + images_hw = (int(images[0].shape[-2]), int(images[0].shape[-1])) + img_feats = self.extract_feat(images) + + out, bev_embed = self.pts_bbox_head( + img_feats, + can_bus_tensor, + images_hw, + cam_intrinsics, + cam_extrinsics, + lidar_extrinsics_tensor, + prev_bev=self.prev_frame_info["prev_bev"], + ) + + # During inference, we save the BEV features and ego motion of each + # timestamp. + self.prev_frame_info["prev_pos"] = tmp_pos + self.prev_frame_info["prev_angle"] = tmp_angle + self.prev_frame_info["prev_bev"] = bev_embed + + return out diff --git a/vis4d/model/motion/__init__.py b/vis4d/model/motion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9db6284e04c7125a12886cd30b4c96b3556552d --- /dev/null +++ b/vis4d/model/motion/__init__.py @@ -0,0 +1 @@ +"""Motion models.""" diff --git a/vis4d/model/motion/velo_lstm.py b/vis4d/model/motion/velo_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..7c08ae5e874b30d984467731002394a8ccaa499c --- /dev/null +++ b/vis4d/model/motion/velo_lstm.py @@ -0,0 +1,309 @@ +"""VeloLSTM 3D motion model.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +from torch import Tensor, nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.op.geometry.rotation import acute_angle, normalize_angle +from vis4d.op.layer.weight_init import xavier_init + + +class VeloLSTMOut(NamedTuple): + """VeloLSTM output.""" + + loc_preds: Tensor + loc_refines: Tensor + + +class VeloLSTM(nn.Module): + """Estimating object location in world coordinates. + + Prediction LSTM: + Input: 5 frames velocity + Output: Next frame location + Updating LSTM: + Input: predicted location and observed location + Output: Refined location + """ + + def __init__( + self, + num_frames: int = 5, + feature_dim: int = 64, + hidden_size: int = 128, + num_layers: int = 2, + loc_dim: int = 7, + dropout: float = 0.1, + weights: str | None = None, + ) -> None: + """Init.""" + super().__init__() + self.num_frames = num_frames + self.feature_dim = feature_dim + self.hidden_size = hidden_size + self.num_layers = num_layers + self.loc_dim = loc_dim + + self.vel2feat = nn.Linear( + loc_dim, + feature_dim, + ) + + self.pred_lstm = nn.LSTM( + input_size=feature_dim, + hidden_size=hidden_size, + dropout=dropout, + num_layers=num_layers, + ) + + self.pred2atten = nn.Linear( + hidden_size, + loc_dim, + bias=False, + ) + + self.conf2feat = nn.Linear( + 1, + feature_dim, + bias=False, + ) + + self.refine_lstm = nn.LSTM( + input_size=3 * feature_dim, + hidden_size=hidden_size, + dropout=dropout, + num_layers=num_layers, + ) + + self.conf2atten = nn.Linear( + hidden_size, + loc_dim, + bias=False, + ) + + self._init_weights() + + if weights is not None: + load_model_checkpoint( + self, + weights, + map_location="cpu", + rev_keys=[(r"^model\.", ""), (r"^module\.", "")], + ) + + def _init_weights(self) -> None: + """Initialize model weights.""" + xavier_init(self.vel2feat) + xavier_init(self.pred2atten) + xavier_init(self.conf2feat) + xavier_init(self.conf2atten) + init_lstm_module(self.pred_lstm) + init_lstm_module(self.refine_lstm) + + def init_hidden( + self, device: torch.device, batch_size: int = 1 + ) -> tuple[Tensor, Tensor]: + """Initializae hidden state. + + The axes semantics are (num_layers, minibatch_size, hidden_dim) + """ + return ( + torch.zeros(self.num_layers, batch_size, self.hidden_size).to( + device + ), + torch.zeros(self.num_layers, batch_size, self.hidden_size).to( + device + ), + ) + + def refine( + self, + location: Tensor, + observation: Tensor, + prev_location: Tensor, + confidence: Tensor, + hc_0: tuple[Tensor, Tensor], + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: + """Refine predicted location using single frame estimation at t+1. + + Input: + location: (num_batch x loc_dim), location from prediction + observation: (num_batch x loc_dim), location from single frame + estimation + prev_location: (num_batch x loc_dim), refined location + confidence: (num_batch X 1), depth estimation confidence + hc_0: (num_layers, num_batch, hidden_size), tuple of hidden and + cell + Middle: + loc_embed: (1, num_batch x feature_dim), predicted location feature + obs_embed: (1, num_batch x feature_dim), single frame location + feature + conf_embed: (1, num_batch x feature_dim), depth estimation + confidence feature + embed: (1, num_batch x 2*feature_dim), location feature + out: (1 x num_batch x hidden_size), lstm output + Output: + hc_n: (num_layers, num_batch, hidden_size), tuple of updated + hidden, cell + output_pred: (num_batch x loc_dim), predicted location + """ + num_batch = location.shape[0] + + pred_vel = location - prev_location + obsv_vel = observation - prev_location + + # Embed feature to hidden_size + loc_embed = self.vel2feat(pred_vel).view(num_batch, self.feature_dim) + obs_embed = self.vel2feat(obsv_vel).view(num_batch, self.feature_dim) + conf_embed = self.conf2feat(confidence).view( + num_batch, self.feature_dim + ) + embed = torch.cat( + [ + loc_embed, + obs_embed, + conf_embed, + ], + dim=1, + ).view(1, num_batch, 3 * self.feature_dim) + + out, (h_n, c_n) = self.refine_lstm(embed, hc_0) + + delta_vel_atten = torch.sigmoid(self.conf2atten(out)).view( + num_batch, self.loc_dim + ) + + output_pred = ( + delta_vel_atten * obsv_vel + + (1.0 - delta_vel_atten) * pred_vel + + prev_location + ) + + return output_pred, (h_n, c_n) + + def predict( + self, + vel_history: Tensor, + location: Tensor, + hc_0: tuple[Tensor, Tensor], + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: + """Predict location at t+1 using updated location at t. + + Input: + vel_history: (num_seq, num_batch, loc_dim), velocity from previous + num_seq updates + location: (num_batch, loc_dim), location from previous update + hc_0: (num_layers, num_batch, hidden_size), tuple of hidden and + cell + Middle: + embed: (num_seq, num_batch x feature_dim), location feature + out: (num_seq x num_batch x hidden_size), lstm output + attention_logit: (num_seq x num_batch x loc_dim), the predicted + residual + Output: + hc_n: (num_layers, num_batch, hidden_size), tuple of updated + hidden, cell + output_pred: (num_batch x loc_dim), predicted location + """ + num_seq, num_batch, _ = vel_history.shape + + # Embed feature to hidden_size + embed = self.vel2feat(vel_history).view( + num_seq, num_batch, self.feature_dim + ) + + out, (h_n, c_n) = self.pred_lstm(embed, hc_0) + + attention_logit = self.pred2atten(out).view( + num_seq, num_batch, self.loc_dim + ) + attention = torch.softmax(attention_logit, dim=0) + + output_pred = torch.sum(attention * vel_history, dim=0) + location + + return output_pred, (h_n, c_n) + + def forward(self, pred_traj: Tensor) -> VeloLSTMOut: + """Forward of QD3DTrackGraph in training stage.""" + loc_preds_list = [] + loc_refines_list = [] + + hidden_predict = self.init_hidden( + pred_traj.device, batch_size=pred_traj.shape[0] + ) + hidden_refine = self.init_hidden( + pred_traj.device, batch_size=pred_traj.shape[0] + ) + + vel_history = pred_traj.new_zeros( + self.num_frames, pred_traj.shape[0], self.loc_dim + ) + + # Starting condition + pred_traj[:, :, 6] = normalize_angle(pred_traj[:, :, 6]) + prev_refine = pred_traj[:, 0, : self.loc_dim] + loc_pred = pred_traj[:, 1, : self.loc_dim] + + # LSTM + for i in range(1, pred_traj.shape[1]): + # Update + loc_pred[:, 6] = normalize_angle(loc_pred[:, 6]) + + for batch_id in range(pred_traj.shape[0]): + # acute angle + loc_pred[batch_id, 6] = acute_angle( + loc_pred[batch_id, 6], pred_traj[batch_id, i, 6] + ) + + loc_refine, hidden_refine = self.refine( + loc_pred.detach().clone(), + pred_traj[:, i, : self.loc_dim], + prev_refine.detach().clone(), + pred_traj[:, i, -1].unsqueeze(-1), + hidden_refine, + ) + loc_refine[:, 6] = normalize_angle(loc_refine[:, 6]) + + if i == 1: + vel_history = torch.cat( + [(loc_refine - prev_refine).unsqueeze(0)] * self.num_frames + ) + else: + vel_history = torch.cat( + [vel_history[1:], (loc_refine - prev_refine).unsqueeze(0)], + dim=0, + ) + prev_refine = loc_refine + + # Predict + loc_pred, hidden_predict = self.predict( + vel_history, loc_refine.detach().clone(), hidden_predict + ) + loc_pred[:, 6] = normalize_angle(loc_pred[:, 6]) + + loc_refines_list.append(loc_refine) + loc_preds_list.append(loc_pred) + + loc_refines = torch.cat(loc_refines_list, dim=1).view( + pred_traj.shape[0], -1, self.loc_dim + ) + loc_preds = torch.cat(loc_preds_list, dim=1).view( + pred_traj.shape[0], -1, self.loc_dim + ) + + return VeloLSTMOut(loc_preds=loc_preds, loc_refines=loc_refines) + + +def init_lstm_module(layer: nn.Module) -> None: + """Initialize LSTM weights and biases.""" + for name, param in layer.named_parameters(): + if "weight_ih" in name: + torch.nn.init.xavier_uniform_(param.data) + elif "weight_hh" in name: + torch.nn.init.orthogonal_(param.data) + elif "bias" in name: + param.data.fill_(0) diff --git a/vis4d/model/seg/__init__.py b/vis4d/model/seg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa068c2da72651b6be0701a3f570ca51b462c846 --- /dev/null +++ b/vis4d/model/seg/__init__.py @@ -0,0 +1 @@ +"""Semantic segmentation models.""" diff --git a/vis4d/model/seg/fcn_resnet.py b/vis4d/model/seg/fcn_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a7ff6e1feca6daba60a6b85777f9ce4d1f9a91 --- /dev/null +++ b/vis4d/model/seg/fcn_resnet.py @@ -0,0 +1,85 @@ +"""FCN Resnet Implementation.""" + +from __future__ import annotations + +import torch +from torch import nn + +from vis4d.op.base.resnet import ResNet +from vis4d.op.seg.fcn import FCNHead, FCNOut + +REV_KEYS = [ + (r"^backbone\.", "basemodel."), + (r"^aux_classifier\.", "fcn.heads.0."), + (r"^classifier\.", "fcn.heads.1."), +] + + +class FCNResNet(nn.Module): + """FCN with ResNet basemodel for semantic segmentation.""" + + def __init__( + self, + base_model: str = "resnet50", + num_classes: int = 21, + resize: None | tuple[int, int] = (520, 520), + ) -> None: + """FCN with ResNet basemodel, following torchvision implementation. + + _. + + model: FCNResNet(base_model="resnet50") + - dataset: Coco2017 + - recipe: vis4d/model/segment/FCNResNet_coco_training.py + - metrics: + - mIoU: 62.52 + - Acc: 90.50 + """ + super().__init__() + if base_model.startswith("resnet"): + self.basemodel = ResNet( + base_model, + pretrained=True, + replace_stride_with_dilation=[False, True, True], + ) + else: + raise ValueError("base model not supported!") + self.fcn = FCNHead( + self.basemodel.out_channels[4:], num_classes, resize=resize + ) + + def forward_train(self, images: torch.Tensor) -> FCNOut: + """Forward pass for training. + + Args: + images (torch.Tensor): Input images. + + Returns: + FCNOut: Raw model predictions. + """ + return self.forward(images) + + def forward_test(self, images: torch.Tensor) -> FCNOut: + """Forward pass for testing. + + Args: + images (torch.Tensor): Input images. + + Returns: + FCNOut: Raw model predictions. + """ + return self.forward(images) + + def forward(self, images: torch.Tensor) -> FCNOut: + """Forward pass. + + Args: + images (torch.Tensor): Input images. + + Returns: + FCNOut: Raw model predictions. + """ + features = self.basemodel(images) + out = self.fcn(features) + return out diff --git a/vis4d/model/seg/semantic_fpn.py b/vis4d/model/seg/semantic_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..9b8420840f100f30f88fedb0ae7f4eb7917c2b08 --- /dev/null +++ b/vis4d/model/seg/semantic_fpn.py @@ -0,0 +1,152 @@ +"""SemanticFPN Implementation.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +import torch.nn.functional as F +from torch import nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.op.base import BaseModel, ResNetV1c +from vis4d.op.fpp.fpn import FPN +from vis4d.op.mask.util import clip_mask +from vis4d.op.seg.semantic_fpn import SemanticFPNHead, SemanticFPNOut + +REV_KEYS = [ + (r"^decode_head\.", "seg_head."), + (r"^classifier\.", "fcn.heads.1."), + (r"^backbone\.", "basemodel."), + (r"^neck.lateral_convs\.", "fpn.inner_blocks."), + (r"^neck.fpn_convs\.", "fpn.layer_blocks."), + (r"\.conv.weight", ".weight"), + (r"\.conv.bias", ".bias"), +] +for ki in range(4): + for kj in range(5): + REV_KEYS += [ + ( + rf"^seg_head.scale_heads\.{ki}\.{kj}\.bn\.", + f"seg_head.scale_heads.{ki}.{kj}.norm.", + ) + ] + + +class MaskOut(NamedTuple): + """Output mask predictions.""" + + masks: list[torch.Tensor] # list of masks for each image + + +class SemanticFPN(nn.Module): + """Semantic FPN. + + Args: + num_classes (int): Number of classes. + resize (bool): Resize output to input size. + weights (None | str): Pre-trained weights. + basemodel (None | BaseModel): Base model to use. If None is passed, + this will default to ResNetV1c + """ + + def __init__( + self, + num_classes: int, + resize: bool = True, + weights: None | str = None, + basemodel: None | BaseModel = None, + ): + """Init.""" + super().__init__() + self.resize = resize + if basemodel is None: + basemodel = ResNetV1c( + "resnet50_v1c", + pretrained=True, + trainable_layers=5, + norm_frozen=False, + ) + + self.basemodel = basemodel + self.fpn = FPN(self.basemodel.out_channels[2:], 256, extra_blocks=None) + self.seg_head = SemanticFPNHead(num_classes, 256) + + if weights is not None: + if weights.startswith("mmseg://") or weights.startswith( + "bdd100k://" + ): + load_model_checkpoint(self, weights, rev_keys=REV_KEYS) + else: + load_model_checkpoint(self, weights) + + def forward_train(self, images: torch.Tensor) -> SemanticFPNOut: + """Forward pass for training. + + Args: + images (torch.Tensor): Input images. + + Returns: + SemanticFPNOut: Raw model predictions. + """ + features = self.fpn(self.basemodel(images.contiguous())) + out = self.seg_head(features) + if self.resize: + return SemanticFPNOut( + outputs=F.interpolate( + out.outputs, + scale_factor=4, + mode="bilinear", + align_corners=False, + ) + ) + return out + + def forward_test( + self, images: torch.Tensor, original_hw: list[tuple[int, int]] + ) -> MaskOut: + """Forward pass for testing. + + Args: + images (torch.Tensor): Input images. + original_hw (list[tuple[int, int]], optional): Original image + resolutions (before padding and resizing). Required for + testing. + + Returns: + SemanticFPNOut: Raw model predictions. + """ + features = self.fpn(self.basemodel(images)) + out = self.seg_head(features) + + new_masks = [] + for i, outputs in enumerate(out.outputs): + opt = F.interpolate( + outputs.unsqueeze(0), + scale_factor=4, + mode="bilinear", + align_corners=False, + ).squeeze(0) + new_masks.append(clip_mask(opt, original_hw[i]).argmax(dim=0)) + return MaskOut(masks=new_masks) + + def forward( + self, + images: torch.Tensor, + original_hw: None | list[tuple[int, int]] = None, + ) -> SemanticFPNOut | MaskOut: + """Forward pass. + + Args: + images (torch.Tensor): Input images. + original_hw (None | list[tuple[int, int]], optional): Original + image resolutions (before padding and resizing). Required for + testing. Defaults to None. + + Returns: + MaskOut: Raw model predictions. + """ + if self.training: + return self.forward_train(images) + assert original_hw is not None + return self.forward_test(images, original_hw) diff --git a/vis4d/model/segment3d/pointnet.py b/vis4d/model/segment3d/pointnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ff85d098cfda4caf5ef875d5af2223bd69053f99 --- /dev/null +++ b/vis4d/model/segment3d/pointnet.py @@ -0,0 +1,143 @@ +"""Implementation of Pointnet.""" + +from __future__ import annotations + +import torch +from torch import nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.common.typing import LossesType, ModelOutput +from vis4d.data.const import CommonKeys +from vis4d.op.base.pointnet import PointNetSegmentation, PointNetSemanticsOut +from vis4d.op.loss.orthogonal_transform_loss import ( + OrthogonalTransformRegularizationLoss, +) + + +class PointnetSegmentationModel(nn.Module): + """Simple Segmentation Model using Pointnet.""" + + def __init__( + self, + num_classes: int = 11, + in_dimensions: int = 3, + weights: str | None = None, + ) -> None: + """Simple Segmentation Model using Pointnet. + + Args: + num_classes: Number of semantic classes + in_dimensions: Input dimension + weights: Path to weight file + """ + super().__init__() + self.model = PointNetSegmentation( + n_classes=num_classes, in_dimensions=in_dimensions + ) + if weights is not None: + load_model_checkpoint(self, weights) + + def __call__( + self, data: torch.Tensor, target: torch.Tensor | None = None + ) -> PointNetSemanticsOut | ModelOutput: + """Runs the semantic model. + + Args: + data: Input Tensor Shape [N, C, n_pts] + target: Target Classes shape [N, n_pts] + """ + return self._call_impl(data, target) + + def forward( + self, data: torch.Tensor, target: torch.Tensor | None = None + ) -> PointNetSemanticsOut | ModelOutput: + """Runs the semantic model. + + Args: + data: Input Tensor Shape [N, C, n_pts] + target: Target Classes shape [N, n_pts] + """ + if target is not None: + return self.forward_train(data, target) + return self.forward_test(data) + + def forward_train( + self, + points: torch.Tensor, + target: torch.Tensor, + ) -> PointNetSemanticsOut: + """Forward training stage. + + Args: + points: Input Tensor Shape [N, C, n_pts] + target: Target Classes shape [N, n_pts] + """ + out = self.model(points) + return out + + def forward_test( + self, + points: torch.Tensor, + ) -> ModelOutput: + """Forward test stage. + + Args: + points: Input Tensor Shape [N, C, n_pts] + """ + return { + CommonKeys.semantics3d: torch.argmax( + self.model(points).class_logits, dim=1 + ) + } + + +class PointnetSegmentationLoss(nn.Module): + """PointnetSegmentationLoss Loss.""" + + def __init__( + self, + regularize_transform: bool = True, + ignore_index: int = 255, + transform_weight: float = 1e-3, + semantic_weights: torch.Tensor | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + regularize_transform: If true add transforms to loss + ignore_index: Semantic class that should be ignored + transform_weight: Loss weight factor for transform + regularization loss + semantic_weights: Classwise weights for semantic loss + """ + super().__init__() + self.segmentation_loss = nn.CrossEntropyLoss( + weight=semantic_weights, ignore_index=ignore_index + ) + self.transformation_loss = OrthogonalTransformRegularizationLoss() + self.regularize_transform = regularize_transform + self.transform_weight = transform_weight + + def forward( + self, outputs: PointNetSemanticsOut, target: torch.Tensor + ) -> LossesType: + """Calculates the losss. + + Args: + outputs: Pointnet output + target: Target Labels + """ + if not self.regularize_transform: + dict( + segmentation_loss=self.segmentation_loss( + outputs.class_logits, target + ) + ) + + return dict( + segmentation_loss=self.segmentation_loss( + outputs.class_logits, target + ), + transform_loss=self.transform_weight + * self.transformation_loss(outputs.transformations), + ) diff --git a/vis4d/model/segment3d/pointnetpp.py b/vis4d/model/segment3d/pointnetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7a65d926649ef97fdee6648404373b3fd1e029 --- /dev/null +++ b/vis4d/model/segment3d/pointnetpp.py @@ -0,0 +1,95 @@ +"""Pointnet++ Implementation.""" + +from __future__ import annotations + +import torch +from torch import Tensor, nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.common.typing import LossesType, ModelOutput +from vis4d.data.const import CommonKeys as K +from vis4d.op.base.pointnetpp import ( + PointNet2Segmentation, + PointNet2SegmentationOut, +) + + +class PointNet2SegmentationModel(nn.Module): + """PointNet++ Segmentation Model implementaiton.""" + + def __init__( + self, + num_classes: int, + in_dimensions: int = 3, + weights: str | None = None, + ): + """Creates a Pointnet+++ Model. + + Args: + num_classes (int): Number of classes + in_dimensions (int, optional): Input dimensions. Defaults to 3. + weights (str, optional): Path to weights. Defaults to None. + """ + super().__init__() + + self.segmentation_model = PointNet2Segmentation( + num_classes, in_dimensions + ) + + if weights is not None: + load_model_checkpoint(self, weights) + + def forward( + self, points3d: Tensor, semantics3d: Tensor | None = None + ) -> PointNet2SegmentationOut | ModelOutput: + """Forward pass of the model. Extract semantic predictions. + + Args: + points3d (Tensor): Input point shape [b, N, C]. + semantics3d (torch.Tenosr): Groundtruth semantic labels of + shape [b, N]. Defaults to None + + Returns: + ModelOutput: Semantic predictions of the model. + """ + x = self.segmentation_model(points3d) + if semantics3d is not None: + return x + class_pred = torch.argmax(x.class_logits, dim=1) + return {K.semantics3d: class_pred} + + +class Pointnet2SegmentationLoss(nn.Module): + """Pointnet2SegmentationLoss Loss.""" + + def __init__( + self, + ignore_index: int = 255, + semantic_weights: Tensor | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + ignore_index (int, optional): Class Index that should be ignored. + Defaults to 255. + semantic_weights (Tensor, optional): Weights for each class. + """ + super().__init__() + self.segmentation_loss = nn.CrossEntropyLoss( + weight=semantic_weights, ignore_index=ignore_index + ) + + def forward( + self, outputs: PointNet2SegmentationOut, semantics3d: Tensor + ) -> LossesType: + """Calculates the loss. + + Args: + outputs (PointNet2SegmentationOut): Model outputs. + semantics3d (Tensor): Groundtruth semantic labels. + """ + return dict( + segmentation_loss=self.segmentation_loss( + outputs.class_logits, semantics3d + ), + ) diff --git a/vis4d/model/track/__init__.py b/vis4d/model/track/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e64d68e5e16c5acb9535fb1ebc7c3245745ad4f0 --- /dev/null +++ b/vis4d/model/track/__init__.py @@ -0,0 +1 @@ +"""Contains the implementation of 2D tracking models.""" diff --git a/vis4d/model/track/qdtrack.py b/vis4d/model/track/qdtrack.py new file mode 100644 index 0000000000000000000000000000000000000000..a89bbb103f5113a6367ff4273ff3e917de93c871 --- /dev/null +++ b/vis4d/model/track/qdtrack.py @@ -0,0 +1,567 @@ +"""Quasi-dense instance similarity learning model.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +from torch import Tensor, nn + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.model.detect.yolox import REV_KEYS as YOLOX_REV_KEYS +from vis4d.op.base import BaseModel, CSPDarknet, ResNet +from vis4d.op.box.box2d import scale_and_clip_boxes +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder +from vis4d.op.box.poolers import MultiScaleRoIAlign +from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut +from vis4d.op.detect.rcnn import RoI2Det +from vis4d.op.detect.yolox import YOLOXHead, YOLOXOut, YOLOXPostprocess +from vis4d.op.fpp import FPN, YOLOXPAFPN, FeaturePyramidProcessing +from vis4d.op.track.common import TrackOut +from vis4d.op.track.qdtrack import ( + QDSimilarityHead, + QDTrackAssociation, + QDTrackHead, +) +from vis4d.state.track.qdtrack import QDTrackGraph + +from .util import split_key_ref_indices + +REV_KEYS = [ + (r"^faster_rcnn_heads\.", "faster_rcnn_head."), + (r"^backbone.body\.", "basemodel."), + (r"^qdtrack\.", "qdtrack_head."), +] + + +class FasterRCNNQDTrackOut(NamedTuple): + """Output of QDtrack model.""" + + detector_out: FRCNNOut + key_images_hw: list[tuple[int, int]] + key_target_boxes: list[Tensor] + key_embeddings: list[Tensor] + ref_embeddings: list[list[Tensor]] + key_track_ids: list[Tensor] + ref_track_ids: list[list[Tensor]] + + +class FasterRCNNQDTrack(nn.Module): + """Wrap QDTrack with Faster R-CNN detector.""" + + def __init__( + self, + num_classes: int, + basemodel: BaseModel | None = None, + faster_rcnn_head: FasterRCNNHead | None = None, + rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None, + qdtrack_head: QDTrackHead | None = None, + track_graph: QDTrackGraph | None = None, + weights: None | str = None, + ) -> None: + """Creates an instance of the class. + + Args: + num_classes (int): Number of object categories. + basemodel (BaseModel, optional): Base model network. Defaults to + None. If None, will use ResNet50. + faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head. + Defaults to None. if None, will use default FasterRCNNHead. + rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN + bounding boxes. Defaults to None. + qdtrack_head (QDTrack, optional): QDTrack head. Defaults to None. + If None, will use default QDTrackHead. + track_graph (QDTrackGraph, optional): Track graph. Defaults to + None. If None, will use default QDTrackGraph. + weights (str, optional): Weights to load for model. + """ + super().__init__() + self.basemodel = ( + ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3) + if basemodel is None + else basemodel + ) + + self.fpn = FPN(self.basemodel.out_channels[2:], 256) + + if faster_rcnn_head is None: + self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes) + else: + self.faster_rcnn_head = faster_rcnn_head + + self.roi2det = RoI2Det(rcnn_box_decoder) + + self.qdtrack_head = ( + QDTrackHead() if qdtrack_head is None else qdtrack_head + ) + + self.track_graph = ( + QDTrackGraph() if track_graph is None else track_graph + ) + + if weights is not None: + load_model_checkpoint( + self, weights, map_location="cpu", rev_keys=REV_KEYS + ) + + def forward( + self, + images: list[Tensor] | Tensor, + images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]], + original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]], + frame_ids: list[list[int]] | list[int], + boxes2d: None | list[list[Tensor]] = None, + boxes2d_classes: None | list[list[Tensor]] = None, + boxes2d_track_ids: None | list[list[Tensor]] = None, + keyframes: None | list[list[bool]] = None, + ) -> TrackOut | FasterRCNNQDTrackOut: + """Forward.""" + if self.training: + assert ( + isinstance(images, list) + and boxes2d is not None + and boxes2d_classes is not None + and boxes2d_track_ids is not None + and keyframes is not None + ) + return self._forward_train( + images, + images_hw, # type: ignore + boxes2d, + boxes2d_classes, + boxes2d_track_ids, + keyframes, + ) + return self._forward_test(images, images_hw, original_hw, frame_ids) # type: ignore # pylint: disable=line-too-long + + def _forward_train( + self, + images: list[Tensor], + images_hw: list[list[tuple[int, int]]], + target_boxes: list[list[Tensor]], + target_classes: list[list[Tensor]], + target_track_ids: list[list[Tensor]], + keyframes: list[list[bool]], + ) -> FasterRCNNQDTrackOut: + """Forward training stage. + + Args: + images (list[Tensor]): Input images. + images_hw (list[list[tuple[int, int]]]): Input image resolutions. + target_boxes (list[list[Tensor]]): Bounding box labels. + target_classes (list[list[Tensor]]): Class labels. + target_track_ids (list[list[Tensor]]): Track IDs. + keyframes (list[list[bool]]): Whether the frame is a keyframe. + + Returns: + FasterRCNNQDTrackOut: Raw model outputs. + """ + key_index, ref_indices = split_key_ref_indices(keyframes) + + # feature extraction + key_features = self.fpn(self.basemodel(images[key_index])) + ref_features = [ + self.fpn(self.basemodel(images[ref_index])) + for ref_index in ref_indices + ] + + key_detector_out = self.faster_rcnn_head( + key_features, + images_hw[key_index], + target_boxes[key_index], + target_classes[key_index], + ) + + with torch.no_grad(): + ref_detector_out = [ + self.faster_rcnn_head( + ref_features[i], + images_hw[ref_index], + target_boxes[ref_index], + target_classes[ref_index], + ) + for i, ref_index in enumerate(ref_indices) + ] + + key_proposals = key_detector_out.proposals.boxes + ref_proposals = [ref.proposals.boxes for ref in ref_detector_out] + key_target_boxes = target_boxes[key_index] + ref_target_boxes = [ + target_boxes[ref_index] for ref_index in ref_indices + ] + key_target_track_ids = target_track_ids[key_index] + ref_target_track_ids = [ + target_track_ids[ref_index] for ref_index in ref_indices + ] + + ( + key_embeddings, + ref_embeddings, + key_track_ids, + ref_track_ids, + ) = self.qdtrack_head( + features=[key_features, *ref_features], + det_boxes=[key_proposals, *ref_proposals], + target_boxes=[key_target_boxes, *ref_target_boxes], + target_track_ids=[key_target_track_ids, *ref_target_track_ids], + ) + assert ( + ref_embeddings is not None + and key_track_ids is not None + and ref_track_ids is not None + ) + + return FasterRCNNQDTrackOut( + detector_out=key_detector_out, + key_images_hw=images_hw[key_index], + key_target_boxes=key_target_boxes, + key_embeddings=key_embeddings, + ref_embeddings=ref_embeddings, + key_track_ids=key_track_ids, + ref_track_ids=ref_track_ids, + ) + + def _forward_test( + self, + images: Tensor, + images_hw: list[tuple[int, int]], + original_hw: list[tuple[int, int]], + frame_ids: list[int], + ) -> TrackOut: + """Forward inference stage.""" + features = self.basemodel(images) + features = self.fpn(features) + detector_out = self.faster_rcnn_head(features, images_hw) + + boxes, scores, class_ids = self.roi2det( + *detector_out.roi, detector_out.proposals.boxes, images_hw + ) + embeddings, _, _, _ = self.qdtrack_head(features, boxes) + + tracks = self.track_graph( + embeddings, boxes, scores, class_ids, frame_ids + ) + + for i, boxs in enumerate(tracks.boxes): + tracks.boxes[i] = scale_and_clip_boxes( + boxs, original_hw[i], images_hw[i] + ) + return tracks + + def __call__( + self, + images: list[Tensor] | Tensor, + images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]], + original_hw: list[tuple[int, int]], + frame_ids: list[list[int]] | list[int], + boxes2d: None | list[list[Tensor]] = None, + boxes2d_classes: None | list[list[Tensor]] = None, + boxes2d_track_ids: None | list[list[Tensor]] = None, + keyframes: None | list[list[bool]] = None, + ) -> TrackOut | FasterRCNNQDTrackOut: + """Type definition for call implementation.""" + return self._call_impl( + images, + images_hw, + original_hw, + frame_ids, + boxes2d, + boxes2d_classes, + boxes2d_track_ids, + keyframes, + ) + + +class YOLOXQDTrackOut(NamedTuple): + """Output of QDtrack YOLOX model.""" + + detector_out: YOLOXOut + key_images_hw: list[tuple[int, int]] + key_target_boxes: list[Tensor] + key_target_classes: list[Tensor] + key_embeddings: list[Tensor] + ref_embeddings: list[list[Tensor]] + key_track_ids: list[Tensor] + ref_track_ids: list[list[Tensor]] + + +class YOLOXQDTrack(nn.Module): + """Wrap QDTrack with YOLOX detector.""" + + def __init__( + self, + num_classes: int, + basemodel: BaseModel | None = None, + fpn: FeaturePyramidProcessing | None = None, + yolox_head: YOLOXHead | None = None, + train_postprocessor: YOLOXPostprocess | None = None, + test_postprocessor: YOLOXPostprocess | None = None, + qdtrack_head: QDTrackHead | None = None, + track_graph: QDTrackGraph | None = None, + weights: None | str = None, + ) -> None: + """Creates an instance of the class. + + Args: + num_classes (int): Number of object categories. + basemodel (BaseModel, optional): Base model. Defaults to None. If + None, will use CSPDarknet. + fpn (FeaturePyramidProcessing, optional): Feature Pyramid + Processing. Defaults to None. If None, will use YOLOXPAFPN. + yolox_head (YOLOXHead, optional): YOLOX head. Defaults to None. If + None, will use YOLOXHead. + train_postprocessor (YOLOXPostprocess, optional): Post processor + for training. Defaults to None. If None, will use + YOLOXPostprocess. + test_postprocessor (YOLOXPostprocess, optional): Post processor + for testing. Defaults to None. If None, will use + YOLOXPostprocess. + qdtrack_head (QDTrack, optional): QDTrack head. Defaults to None. + If None, will use default QDTrackHead. + track_graph (QDTrackGraph, optional): Track graph. Defaults to + None. If None, will use default QDTrackGraph. + weights (str, optional): Weights to load for model. + """ + super().__init__() + self.basemodel = ( + CSPDarknet(deepen_factor=1.33, widen_factor=1.25) + if basemodel is None + else basemodel + ) + self.fpn = ( + YOLOXPAFPN([320, 640, 1280], 320, num_csp_blocks=4) + if fpn is None + else fpn + ) + self.yolox_head = ( + YOLOXHead( + num_classes=num_classes, in_channels=320, feat_channels=320 + ) + if yolox_head is None + else yolox_head + ) + self.train_postprocessor = ( + YOLOXPostprocess( + self.yolox_head.point_generator, + self.yolox_head.box_decoder, + nms_threshold=0.7, + score_thr=0.0, + nms_pre=2000, + max_per_img=1000, + ) + if train_postprocessor is None + else train_postprocessor + ) + self.test_postprocessor = ( + YOLOXPostprocess( + self.yolox_head.point_generator, + self.yolox_head.box_decoder, + nms_threshold=0.65, + score_thr=0.1, + ) + if test_postprocessor is None + else test_postprocessor + ) + + self.qdtrack_head = ( + QDTrackHead( + QDSimilarityHead( + MultiScaleRoIAlign( + resolution=[7, 7], + strides=[8, 16, 32], + sampling_ratio=0, + ), + in_dim=320, + ) + ) + if qdtrack_head is None + else qdtrack_head + ) + + self.track_graph = ( + QDTrackGraph( + track=QDTrackAssociation( + init_score_thr=0.5, obj_score_thr=0.35 + ) + ) + if track_graph is None + else track_graph + ) + + if weights is not None: + load_model_checkpoint( + self, weights, map_location="cpu", rev_keys=YOLOX_REV_KEYS + ) + + def forward( + self, + images: list[Tensor] | Tensor, + images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]], + original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]], + frame_ids: list[list[int]] | list[int], + boxes2d: None | list[list[Tensor]] = None, + boxes2d_classes: None | list[list[Tensor]] = None, + boxes2d_track_ids: None | list[list[Tensor]] = None, + keyframes: None | list[list[bool]] = None, + ) -> TrackOut | YOLOXQDTrackOut: + """Forward.""" + if self.training: + assert ( + isinstance(images, list) + and boxes2d is not None + and boxes2d_classes is not None + and boxes2d_track_ids is not None + and keyframes is not None + ) + return self._forward_train( + images, + images_hw, # type: ignore + boxes2d, + boxes2d_classes, + boxes2d_track_ids, + keyframes, + ) + return self._forward_test(images, images_hw, original_hw, frame_ids) # type: ignore # pylint: disable=line-too-long + + def _forward_train( + self, + images: list[Tensor], + images_hw: list[list[tuple[int, int]]], + target_boxes: list[list[Tensor]], + target_classes: list[list[Tensor]], + target_track_ids: list[list[Tensor]], + keyframes: list[list[bool]], + ) -> YOLOXQDTrackOut: + """Forward training stage. + + Args: + images (list[Tensor]): Input images. + images_hw (list[list[tuple[int, int]]]): Input image resolutions. + target_boxes (list[list[Tensor]]): Bounding box labels. + target_classes (list[list[Tensor]]): Class labels. + target_track_ids (list[list[Tensor]]): Track IDs. + keyframes (list[list[bool]]): Whether the frame is a keyframe. + + Returns: + YOLOXQDTrackOut: Raw model outputs. + """ + key_index, ref_indices = split_key_ref_indices(keyframes) + + # feature extraction + key_features = self.fpn(self.basemodel(images[key_index].contiguous())) + ref_features = [ + self.fpn(self.basemodel(images[ref_index].contiguous())) + for ref_index in ref_indices + ] + + key_detector_out = self.yolox_head(key_features[-3:]) + key_proposals, _, _ = self.train_postprocessor( + cls_outs=key_detector_out.cls_score, + reg_outs=key_detector_out.bbox_pred, + obj_outs=key_detector_out.objectness, + images_hw=images_hw[key_index], + ) + + with torch.no_grad(): + ref_detector_out = [ + self.yolox_head(ref_feat[-3:]) for ref_feat in ref_features + ] + ref_proposals = [ + self.train_postprocessor( + cls_outs=ref_out.cls_score, + reg_outs=ref_out.bbox_pred, + obj_outs=ref_out.objectness, + images_hw=images_hw[ref_index], + )[0] + for ref_index, ref_out in zip(ref_indices, ref_detector_out) + ] + + key_target_boxes = target_boxes[key_index] + ref_target_boxes = [ + target_boxes[ref_index] for ref_index in ref_indices + ] + key_target_classes = target_classes[key_index] + key_target_track_ids = target_track_ids[key_index] + ref_target_track_ids = [ + target_track_ids[ref_index] for ref_index in ref_indices + ] + + ( + key_embeddings, + ref_embeddings, + key_track_ids, + ref_track_ids, + ) = self.qdtrack_head( + features=[key_features, *ref_features], + det_boxes=[key_proposals, *ref_proposals], + target_boxes=[key_target_boxes, *ref_target_boxes], + target_track_ids=[key_target_track_ids, *ref_target_track_ids], + ) + assert ( + ref_embeddings is not None + and key_track_ids is not None + and ref_track_ids is not None + ) + + return YOLOXQDTrackOut( + detector_out=key_detector_out, + key_images_hw=images_hw[key_index], + key_target_boxes=key_target_boxes, + key_target_classes=key_target_classes, + key_embeddings=key_embeddings, + ref_embeddings=ref_embeddings, + key_track_ids=key_track_ids, + ref_track_ids=ref_track_ids, + ) + + def _forward_test( + self, + images: torch.Tensor, + images_hw: list[tuple[int, int]], + original_hw: list[tuple[int, int]], + frame_ids: list[int], + ) -> TrackOut: + """Forward inference stage.""" + features = self.fpn(self.basemodel(images)) + outs = self.yolox_head(features[-3:]) + boxes, scores, class_ids = self.test_postprocessor( + cls_outs=outs.cls_score, + reg_outs=outs.bbox_pred, + obj_outs=outs.objectness, + images_hw=images_hw, + ) + + embeddings, _, _, _ = self.qdtrack_head(features, boxes) + + tracks = self.track_graph( + embeddings, boxes, scores, class_ids, frame_ids + ) + + for i, boxs in enumerate(tracks.boxes): + tracks.boxes[i] = scale_and_clip_boxes( + boxs, original_hw[i], images_hw[i] + ) + return tracks + + def __call__( + self, + images: list[Tensor] | Tensor, + images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]], + original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]], + frame_ids: list[list[int]] | list[int], + boxes2d: None | list[list[Tensor]] = None, + boxes2d_classes: None | list[list[Tensor]] = None, + boxes2d_track_ids: None | list[list[Tensor]] = None, + keyframes: None | list[list[bool]] = None, + ) -> TrackOut | FasterRCNNQDTrackOut: + """Type definition for call implementation.""" + return self._call_impl( + images, + images_hw, + original_hw, + frame_ids, + boxes2d, + boxes2d_classes, + boxes2d_track_ids, + keyframes, + ) diff --git a/vis4d/model/track/util.py b/vis4d/model/track/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f32af99ed1658d56ac2091994756642df4a441a9 --- /dev/null +++ b/vis4d/model/track/util.py @@ -0,0 +1,24 @@ +"""Utility functions for track module.""" + +from __future__ import annotations + + +def split_key_ref_indices( + keyframes: list[list[bool]], +) -> tuple[int, list[int]]: + """Get key frame from list of sample attributes.""" + key_ind = None + ref_inds = [] + for i, is_keys in enumerate(keyframes): + assert all( + is_keys[0] == is_key for is_key in is_keys + ), "Same batch should have the same view." + if is_keys[0]: + key_ind = i + else: + ref_inds.append(i) + + assert key_ind is not None, "Key frame not found." + assert len(ref_inds) > 0, "No reference frames found." + + return key_ind, ref_inds diff --git a/vis4d/model/track3d/__init__.py b/vis4d/model/track3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2d2ba631ffd586e77b6866ad8cd94e96a24419 --- /dev/null +++ b/vis4d/model/track3d/__init__.py @@ -0,0 +1 @@ +"""Contains the implementation of 3D Tracking models.""" diff --git a/vis4d/model/track3d/cc_3dt.py b/vis4d/model/track3d/cc_3dt.py new file mode 100644 index 0000000000000000000000000000000000000000..972728ccda86d586fb75ace53419d3373ea5c416 --- /dev/null +++ b/vis4d/model/track3d/cc_3dt.py @@ -0,0 +1,605 @@ +"""CC-3DT model implementation. + +This file composes the operations associated with CC-3DT +`https://arxiv.org/abs/2212.01247` into the full model implementation. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import NamedTuple + +import torch +from torch import Tensor, nn + +from vis4d.data.const import AxisMode +from vis4d.model.track.qdtrack import FasterRCNNQDTrackOut +from vis4d.op.base import BaseModel, ResNet +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.box.box2d import bbox_area, bbox_clip +from vis4d.op.box.box3d import boxes3d_to_corners, transform_boxes3d +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder +from vis4d.op.detect3d.qd_3dt import QD3DTBBox3DHead, RoI2Det3D +from vis4d.op.detect3d.util import bev_3d_nms +from vis4d.op.detect.faster_rcnn import FasterRCNNHead +from vis4d.op.detect.rcnn import RCNNHead, RoI2Det +from vis4d.op.fpp import FPN +from vis4d.op.geometry.projection import project_points +from vis4d.op.geometry.rotation import ( + quaternion_to_matrix, + rotation_matrix_yaw, +) +from vis4d.op.geometry.transform import inverse_rigid_transform +from vis4d.op.track3d.cc_3dt import ( + CC3DTrackAssociation, + cam_to_global, + get_track_3d_out, +) +from vis4d.op.track3d.common import Track3DOut +from vis4d.op.track.qdtrack import QDTrackHead +from vis4d.state.track3d.cc_3dt import CC3DTrackGraph + +from ..track.util import split_key_ref_indices + + +class FasterRCNNCC3DTOut(NamedTuple): + """Output of CC-3DT model with Faster R-CNN detector.""" + + detector_3d_out: Tensor + detector_3d_target: Tensor + detector_3d_labels: Tensor + qdtrack_out: FasterRCNNQDTrackOut + + +class FasterRCNNCC3DT(nn.Module): + """CC-3DT with Faster-RCNN detector.""" + + def __init__( + self, + num_classes: int, + basemodel: BaseModel | None = None, + faster_rcnn_head: FasterRCNNHead | None = None, + rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None, + qdtrack_head: QDTrackHead | None = None, + track_graph: CC3DTrackGraph | None = None, + pure_det: bool = False, + ) -> None: + """Creates an instance of the class. + + Args: + num_classes (int): Number of object categories. + basemodel (BaseModel, optional): Base model network. Defaults to + None. If None, will use ResNet50. + faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head. + Defaults to None. if None, will use default FasterRCNNHead. + rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN + bounding boxes. Defaults to None. + qdtrack_head (QDTrack, optional): QDTrack head. Defaults to None. + If None, will use default QDTrackHead. + track_graph (CC3DTrackGraph, optional): Track graph. Defaults to + None. If None, will use default CC3DTrackGraph. + pure_det (bool, optional): Whether to use pure detection. Defaults + to False. + """ + super().__init__() + self.basemodel = ( + ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3) + if basemodel is None + else basemodel + ) + + self.fpn = FPN(self.basemodel.out_channels[2:], 256) + + if faster_rcnn_head is None: + anchor_generator = AnchorGenerator( + scales=[4, 8], + ratios=[0.25, 0.5, 1.0, 2.0, 4.0], + strides=[4, 8, 16, 32, 64], + ) + roi_head = RCNNHead(num_shared_convs=4, num_classes=num_classes) + self.faster_rcnn_head = FasterRCNNHead( + num_classes=num_classes, + anchor_generator=anchor_generator, + roi_head=roi_head, + ) + else: + self.faster_rcnn_head = faster_rcnn_head + + self.roi2det = RoI2Det(rcnn_box_decoder) + + self.bbox_3d_head = QD3DTBBox3DHead(num_classes=num_classes) + + self.roi2det_3d = RoI2Det3D() + + self.qdtrack_head = ( + QDTrackHead() if qdtrack_head is None else qdtrack_head + ) + + self.track_graph = ( + CC3DTrackGraph() if track_graph is None else track_graph + ) + + self.pure_det = pure_det + + def forward( + self, + images: list[Tensor], + images_hw: list[list[tuple[int, int]]], + intrinsics: list[Tensor], + extrinsics: list[Tensor] | None = None, + frame_ids: list[int] | None = None, + boxes2d: list[list[Tensor]] | None = None, + boxes3d: list[list[Tensor]] | None = None, + boxes3d_classes: list[list[Tensor]] | None = None, + boxes3d_track_ids: list[list[Tensor]] | None = None, + keyframes: None | list[list[bool]] | None = None, + ) -> FasterRCNNCC3DTOut | Track3DOut: + """Forward.""" + if self.training: + assert ( + boxes2d is not None + and boxes3d is not None + and boxes3d_classes is not None + and boxes3d_track_ids is not None + and keyframes is not None + ) + return self._forward_train( + images, + images_hw, + intrinsics, + boxes2d, + boxes3d, + boxes3d_classes, + boxes3d_track_ids, + keyframes, + ) + + assert extrinsics is not None and frame_ids is not None + return self._forward_test( + images, images_hw, intrinsics, extrinsics, frame_ids + ) + + def _forward_train( + self, + images: list[Tensor], + images_hw: list[list[tuple[int, int]]], + intrinsics: list[Tensor], + target_boxes2d: list[list[Tensor]], + target_boxes3d: list[list[Tensor]], + target_classes: list[list[Tensor]], + target_track_ids: list[list[Tensor]], + keyframes: list[list[bool]], + ) -> FasterRCNNCC3DTOut: + """Foward training stage.""" + key_index, ref_indices = split_key_ref_indices(keyframes) + + # feature extraction + key_features = self.fpn(self.basemodel(images[key_index])) + ref_features = [ + self.fpn(self.basemodel(images[ref_index])) + for ref_index in ref_indices + ] + + key_detector_out = self.faster_rcnn_head( + key_features, + images_hw[key_index], + target_boxes2d[key_index], + target_classes[key_index], + ) + + with torch.no_grad(): + ref_detector_out = [ + self.faster_rcnn_head( + ref_features[i], + images_hw[ref_index], + target_boxes2d[ref_index], + target_classes[ref_index], + ) + for i, ref_index in enumerate(ref_indices) + ] + + key_proposals = key_detector_out.proposals.boxes + ref_proposals = [ref.proposals.boxes for ref in ref_detector_out] + key_target_boxes = target_boxes2d[key_index] + ref_target_boxes = [ + target_boxes2d[ref_index] for ref_index in ref_indices + ] + key_target_track_ids = target_track_ids[key_index] + ref_target_track_ids = [ + target_track_ids[ref_index] for ref_index in ref_indices + ] + + ( + key_embeddings, + ref_embeddings, + key_track_ids, + ref_track_ids, + ) = self.qdtrack_head( + features=[key_features, *ref_features], + det_boxes=[key_proposals, *ref_proposals], + target_boxes=[key_target_boxes, *ref_target_boxes], + target_track_ids=[key_target_track_ids, *ref_target_track_ids], + ) + assert ( + ref_embeddings is not None + and key_track_ids is not None + and ref_track_ids is not None + ) + + predictions, targets, labels = self.bbox_3d_head( + features=key_features, + det_boxes=key_proposals, + intrinsics=intrinsics[key_index], + target_boxes=key_target_boxes, + target_boxes3d=target_boxes3d[key_index], + target_class_ids=target_classes[key_index], + ) + detector_3d_out = torch.cat(predictions) + assert targets is not None and labels is not None + + return FasterRCNNCC3DTOut( + detector_3d_out=detector_3d_out, + detector_3d_target=targets, + detector_3d_labels=labels, + qdtrack_out=FasterRCNNQDTrackOut( + detector_out=key_detector_out, + key_images_hw=images_hw[key_index], + key_target_boxes=key_target_boxes, + key_embeddings=key_embeddings, + ref_embeddings=ref_embeddings, + key_track_ids=key_track_ids, + ref_track_ids=ref_track_ids, + ), + ) + + def _forward_test( + self, + images_list: list[Tensor], + images_hw: list[list[tuple[int, int]]], + intrinsics_list: list[Tensor], + extrinsics_list: list[Tensor], + frame_ids: list[int], + ) -> Track3DOut: + """Forward inference stage. + + Curretnly only work with single batch per gpu. + """ + # (N, 1, 3, H, W) -> (N, 3, H, W) + images = torch.cat(images_list) + # (N, 1, 3, 3) -> (N, 3, 3) + intrinsics = torch.cat(intrinsics_list) + # (N, 1, 4, 4) -> (N, 4, 4) + extrinsics = torch.cat(extrinsics_list) + # (N, 1) -> (N,) + frame_id = frame_ids[0] + images_hw_list: list[tuple[int, int]] = sum(images_hw, []) + + features = self.basemodel(images) + features = self.fpn(features) + _, roi, proposals, _, _, _ = self.faster_rcnn_head( + features, images_hw_list + ) + + boxes_2d_list, scores_2d_list, class_ids_list = self.roi2det( + *roi, proposals.boxes, images_hw_list + ) + + predictions, _, _ = self.bbox_3d_head( + features, det_boxes=boxes_2d_list + ) + + boxes_3d_list, scores_3d_list = self.roi2det_3d( + predictions, boxes_2d_list, class_ids_list, intrinsics + ) + + embeddings_list, _, _, _ = self.qdtrack_head(features, boxes_2d_list) + + # Assign camera id + camera_ids_list = [] + for i, boxes_2d in enumerate(boxes_2d_list): + camera_ids_list.append( + (torch.mul(torch.ones(len(boxes_2d)), i)).to(boxes_2d.device) + ) + + # Move 3D boxes to world coordinate + boxes_3d_list = cam_to_global(boxes_3d_list, extrinsics) + + # Merge boxes from all cameras + boxes_2d = torch.cat(boxes_2d_list) + scores_2d = torch.cat(scores_2d_list) + camera_ids = torch.cat(camera_ids_list) + boxes_3d = torch.cat(boxes_3d_list) + scores_3d = torch.cat(scores_3d_list) + class_ids = torch.cat(class_ids_list) + embeddings = torch.cat(embeddings_list) + + if self.pure_det: + return get_track_3d_out( + boxes_3d, class_ids, scores_3d, torch.zeros_like(class_ids) + ) + + # 3D NMS in world coordinate + keep_indices = bev_3d_nms( + center_x=boxes_3d[:, 0].unsqueeze(1), + center_y=boxes_3d[:, 1].unsqueeze(1), + width=boxes_3d[:, 4].unsqueeze(1), + length=boxes_3d[:, 5].unsqueeze(1), + angle=180.0 / torch.pi * boxes_3d[:, 8].unsqueeze(1), + scores=scores_2d * scores_3d, + ) + + boxes_2d = boxes_2d[keep_indices] + scores_2d = scores_2d[keep_indices] + camera_ids = camera_ids[keep_indices] + boxes_3d = boxes_3d[keep_indices] + scores_3d = scores_3d[keep_indices] + class_ids = class_ids[keep_indices] + embeddings = embeddings[keep_indices] + + outs = self.track_graph( + boxes_2d, + scores_2d, + camera_ids, + boxes_3d, + scores_3d, + class_ids, + embeddings, + frame_id, + ) + + return outs + + def __call__( + self, + images: list[Tensor] | Tensor, + images_hw: list[list[tuple[int, int]]], + intrinsics: list[Tensor] | Tensor, + extrinsics: Tensor | None = None, + frame_ids: list[list[int]] | None = None, + boxes2d: list[list[Tensor]] | None = None, + boxes3d: list[list[Tensor]] | None = None, + boxes3d_classes: list[list[Tensor]] | None = None, + boxes3d_track_ids: list[list[Tensor]] | None = None, + keyframes: None | list[list[bool]] | None = None, + ) -> FasterRCNNCC3DTOut | Track3DOut: + """Type definition for call implementation.""" + return self._call_impl( + images, + images_hw, + intrinsics, + extrinsics, + frame_ids, + boxes2d, + boxes3d, + boxes3d_classes, + boxes3d_track_ids, + keyframes, + ) + + +class CC3DT(nn.Module): + """CC-3DT with custom detection results.""" + + def __init__( + self, + basemodel: BaseModel | None = None, + qdtrack_head: QDTrackHead | None = None, + track_graph: CC3DTrackGraph | None = None, + detection_range: Sequence[float] | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + basemodel (BaseModel, optional): Base model network. Defaults to + None. If None, will use ResNet50. + qdtrack_head (QDTrack, optional): QDTrack head. Defaults to None. + If None, will use default QDTrackHead. + track_graph (CC3DTrackGraph, optional): Track graph. Defaults to + None. If None, will use default CC3DTrackGraph. + detection_range (Sequence[float], optional): Detection range for + each class. Defaults to None. + """ + super().__init__() + self.basemodel = ( + ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3) + if basemodel is None + else basemodel + ) + + self.fpn = FPN(self.basemodel.out_channels[2:], 256) + + self.qdtrack_head = ( + QDTrackHead() if qdtrack_head is None else qdtrack_head + ) + + self.track_graph = track_graph or CC3DTrackGraph( + track=CC3DTrackAssociation(init_score_thr=0.2, obj_score_thr=0.1), + update_3d_score=False, + add_backdrops=False, + ) + + self.detection_range = detection_range + + def forward( + self, + images_list: list[Tensor], + images_hw: list[list[tuple[int, int]]], + intrinsics_list: list[Tensor], + extrinsics_list: list[Tensor], + frame_ids: list[int], + pred_boxes3d: list[list[Tensor]], + pred_boxes3d_classes: list[list[Tensor]], + pred_boxes3d_scores: list[list[Tensor]], + pred_boxes3d_velocities: list[list[Tensor]], + ) -> Track3DOut: + """Forward inference stage. + + Curretnly only work with single batch per gpu. + """ + # (N, 1, 3, H, W) -> (N, 3, H, W) + images = torch.cat(images_list) + # (N, 1, 3, 3) -> (N, 3, 3) + intrinsics = torch.cat(intrinsics_list) + # (N, 1, 4, 4) -> (N, 4, 4) + extrinsics = torch.cat(extrinsics_list) + # (N, 1) -> (N,) + frame_id = frame_ids[0] + images_hw_list: list[tuple[int, int]] = sum(images_hw, []) + + features = self.basemodel(images) + features = self.fpn(features) + + # (1, 1, B,) -> (B,) + boxes_3d = pred_boxes3d[0][0] + class_ids = pred_boxes3d_classes[0][0] + scores_3d = pred_boxes3d_scores[0][0] + velocities = pred_boxes3d_velocities[0][0] + + # Get 2D boxes and assign camera id + global_to_cams = inverse_rigid_transform(extrinsics) + + boxes_3d_list = [] + boxes_2d_list = [] + class_ids_list = [] + scores_list = [] + camera_ids_list = [] + for i, global_to_cam in enumerate(global_to_cams): + boxes3d_cam = transform_boxes3d( + boxes_3d, + global_to_cam, + source_axis_mode=AxisMode.ROS, + target_axis_mode=AxisMode.OPENCV, + ) + + corners = boxes3d_to_corners( + boxes3d_cam, axis_mode=AxisMode.OPENCV + ) + + corners_2d = project_points(corners, intrinsics[i]) + + boxes_2d = self._to_boxes2d(corners_2d) + boxes_2d = bbox_clip(boxes_2d, images_hw_list[i], 1) + + mask = ( + (boxes3d_cam[:, 2] > 0) + & (bbox_area(boxes_2d) > 0) + & ( + bbox_area(boxes_2d) + < (images_hw_list[i][0] - 1) * (images_hw_list[i][1] - 1) + ) + & self._filter_distance(class_ids, boxes3d_cam) + ) + + cc_3dt_boxes_3d = boxes_3d.new_zeros(len(boxes_2d[mask]), 12) + cc_3dt_boxes_3d[:, :3] = boxes_3d[mask][:, :3] + # WLH -> HWL + cc_3dt_boxes_3d[:, 3:6] = boxes_3d[mask][:, [5, 3, 4]] + cc_3dt_boxes_3d[:, 6:9] = rotation_matrix_yaw( + quaternion_to_matrix(boxes_3d[mask][:, 6:]), AxisMode.ROS + ) + cc_3dt_boxes_3d[:, 9:] = velocities[mask] + + boxes_3d_list.append(cc_3dt_boxes_3d) + boxes_2d_list.append(boxes_2d[mask]) + class_ids_list.append(class_ids[mask]) + scores_list.append(scores_3d[mask]) + camera_ids_list.append( + (torch.mul(torch.ones(len(cc_3dt_boxes_3d)), i)).to( + boxes_2d.device + ) + ) + + embeddings_list, _, _, _ = self.qdtrack_head(features, boxes_2d_list) + + boxes_3d = torch.cat(boxes_3d_list) + boxes_2d = torch.cat(boxes_2d_list) + camera_ids = torch.cat(camera_ids_list) + scores = torch.cat(scores_list) + class_ids = torch.cat(class_ids_list) + embeddings = torch.cat(embeddings_list) + + # Select project boxes2d according to bbox area + keep_indices = embeddings.new_ones(len(boxes_3d)).bool() + boxes_2d_area = bbox_area(boxes_2d) + for i, box3d in enumerate(boxes_3d): + for same_idx in ( + (box3d[:3] == boxes_3d[:, :3]).all(dim=1).nonzero() + ): + if ( + same_idx != i + and boxes_2d_area[same_idx] > boxes_2d_area[i] + ): + keep_indices[i] = False + break + + boxes_3d = boxes_3d[keep_indices] + boxes_2d = boxes_2d[keep_indices] + camera_ids = camera_ids[keep_indices] + scores = scores[keep_indices] + class_ids = class_ids[keep_indices] + embeddings = embeddings[keep_indices] + + outs = self.track_graph( + boxes_2d, + scores, + camera_ids, + boxes_3d, + scores, + class_ids, + embeddings, + frame_id, + ) + + return outs + + def _to_boxes2d(self, corners_2d: Tensor) -> Tensor: + """Project 3D boxes (Camera coordinates) to 2D boxes.""" + min_x = torch.min(corners_2d[:, :, 0], 1).values.unsqueeze(-1) + min_y = torch.min(corners_2d[:, :, 1], 1).values.unsqueeze(-1) + max_x = torch.max(corners_2d[:, :, 0], 1).values.unsqueeze(-1) + max_y = torch.max(corners_2d[:, :, 1], 1).values.unsqueeze(-1) + + return torch.cat([min_x, min_y, max_x, max_y], dim=1) + + def _filter_distance( + self, class_ids: Tensor, boxes3d: Tensor, tolerance: float = 2.0 + ) -> Tensor: + """Filter boxes3d on distance.""" + if self.detection_range is None: + return torch.ones_like(class_ids, dtype=torch.bool) + + return torch.linalg.norm( # pylint: disable=not-callable + boxes3d[:, [0, 2]], dim=1 + ) <= torch.tensor( + [ + self.detection_range[class_id] + tolerance + for class_id in class_ids + ] + ).to( + class_ids.device + ) + + def __call__( + self, + images_list: list[Tensor], + images_hw: list[list[tuple[int, int]]], + intrinsics_list: list[Tensor], + extrinsics_list: list[Tensor], + frame_ids: list[int], + pred_boxes3d: list[list[Tensor]], + pred_boxes3d_classes: list[list[Tensor]], + pred_boxes3d_scores: list[list[Tensor]], + pred_boxes3d_velocities: list[list[Tensor]], + ) -> Track3DOut: + """Type definition for call implementation.""" + return self._call_impl( + images_list, + images_hw, + intrinsics_list, + extrinsics_list, + frame_ids, + pred_boxes3d, + pred_boxes3d_classes, + pred_boxes3d_scores, + pred_boxes3d_velocities, + ) diff --git a/vis4d/op/__init__.py b/vis4d/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1fecfaf92b8fec6c23fe16b34dfc8f9bcd95370d --- /dev/null +++ b/vis4d/op/__init__.py @@ -0,0 +1,8 @@ +"""Compositional operators used for implementing models. + +This is where most of the library APIs are implemented. +All the operators are functors. They are native PyTorch modules and only have a +forward member for function invocations. We follow the principle of functional +programming. The operators don't keep internal states besides the operator +weights. The operator computation and call has no side effects. +""" diff --git a/vis4d/op/__pycache__/__init__.cpython-311.pyc b/vis4d/op/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efd7c26650690b1ce2842f2a236cc75e94eeb4ce Binary files /dev/null and b/vis4d/op/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/op/base/__init__.py b/vis4d/op/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa6ba27c722bb9b3131185555cae7ebb80aed39 --- /dev/null +++ b/vis4d/op/base/__init__.py @@ -0,0 +1,8 @@ +"""Base model module.""" + +from .base import BaseModel +from .csp_darknet import CSPDarknet +from .dla import DLA +from .resnet import ResNet, ResNetV1c + +__all__ = ["BaseModel", "CSPDarknet", "DLA", "ResNet", "ResNetV1c"] diff --git a/vis4d/op/base/base.py b/vis4d/op/base/base.py new file mode 100644 index 0000000000000000000000000000000000000000..90222696739151aa78309f04b780a3f0a5f8eaa8 --- /dev/null +++ b/vis4d/op/base/base.py @@ -0,0 +1,58 @@ +"""Base model interface.""" + +from __future__ import annotations + +import abc + +import torch +from torch import nn + + +class BaseModel(nn.Module): + """Abstract base model for feature extraction.""" + + @abc.abstractmethod + def forward(self, images: torch.Tensor) -> list[torch.Tensor]: + """Base model forward. + + Args: + images (Tensor[N, C, H, W]): Image input to process. Expected to be + type float32. + + Raises: + NotImplementedError: This is an abstract class method. + + Returns: + fp (list[torch.Tensor]): The output feature pyramid. The list index + represents the level, which has a downsampling ratio of 2^index for + most of the cases. fp[2] is the C2 or P2 in the FPN paper + (https://arxiv.org/abs/1612.03144). fp[0] is the original image or + the feature map with the same resolution. fp[1] may be the copy of + the input image if the network doesn't generate the feature map of + the resolution. + """ + raise NotImplementedError + + @property + @abc.abstractmethod + def out_channels(self) -> list[int]: + """Get the number of channels for each level of feature pyramid. + + Raises: + NotImplementedError: This is an abstract class method. + + Returns: + list[int]: Number of channels. + """ + raise NotImplementedError + + def __call__(self, images: torch.Tensor) -> list[torch.Tensor]: + """Type definition for call implementation. + + Args: + images (torch.Tensor): Image input to process. + + Returns: + list[torch.Tensor]: The output feature pyramid. + """ + return self._call_impl(images) diff --git a/vis4d/op/base/csp_darknet.py b/vis4d/op/base/csp_darknet.py new file mode 100644 index 0000000000000000000000000000000000000000..830d171f31c5935ee5b6477e5dadac0cbf519e9b --- /dev/null +++ b/vis4d/op/base/csp_darknet.py @@ -0,0 +1,305 @@ +"""CSP-Darknet base network used in YOLOX. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import math +from collections.abc import Sequence + +import torch +from torch import nn +from torch.nn.modules.batchnorm import _BatchNorm + +from vis4d.op.layer.conv2d import Conv2d +from vis4d.op.layer.csp_layer import CSPLayer + + +class Focus(nn.Module): + """Focus width and height information into channel space. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + kernel_size (int, optional): The kernel size of the convolution. + Defaults to 1. + stride (int, optional): The stride of the convolution. Defaults to 1. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + ): + """Init.""" + super().__init__() + self.conv = Conv2d( + in_channels * 4, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2, + bias=False, + norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + features (torch.Tensor): The input tensor of shape [B, C, W, H]. + """ + patch_top_left = features[..., ::2, ::2] + patch_top_right = features[..., ::2, 1::2] + patch_bot_left = features[..., 1::2, ::2] + patch_bot_right = features[..., 1::2, 1::2] + x = torch.cat( + ( + patch_top_left, + patch_bot_left, + patch_top_right, + patch_bot_right, + ), + dim=1, + ) + return self.conv(x) + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + kernel_sizes (Sequence[int], optional): Sequential of kernel sizes of + pooling layers. Defaults to (5, 9, 13). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_sizes: Sequence[int] = (5, 9, 13), + ): + """Init.""" + super().__init__() + mid_channels = in_channels // 2 + self.conv1 = Conv2d( + in_channels, + mid_channels, + 1, + stride=1, + bias=False, + norm=nn.BatchNorm2d(mid_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + self.poolings = nn.ModuleList( + [ + nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ] + ) + conv2_channels = mid_channels * (len(kernel_sizes) + 1) + self.conv2 = Conv2d( + conv2_channels, + out_channels, + 1, + bias=False, + norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + features (torch.Tensor): Input features. + """ + x = self.conv1(features) + x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1) + x = self.conv2(x) + return x + + +class CSPDarknet(nn.Module): + """CSP-Darknet backbone used in YOLOv5 and YOLOX. + + Args: + arch (str): Architecture of CSP-Darknet, from {P5, P6}. + Default: P5. + deepen_factor (float): Depth multiplier, multiply number of + blocks in CSP layer by this amount. Default: 1.0. + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + out_indices (Sequence[int]): Output from which stages. + Default: (2, 3, 4). + frozen_stages (int): Stages to be frozen (stop grad and set eval + mode). -1 means not freezing any parameters. Default: -1. + use_depthwise (bool): Whether to use depthwise separable convolution. + Default: False. + arch_ovewrite(list[list[int]], optional): Overwrite default arch + settings. Defaults to None. + spp_kernal_sizes: (tuple[int]): Sequential of kernel sizes of SPP + layers. Default: (5, 9, 13). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + + Example: + >>> import torch + >>> from vis4d.op.base import CSPDarknet + >>> self = CSPDarknet() + >>> self.eval() + >>> inputs = torch.rand(1, 3, 416, 416) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + ... + (1, 256, 52, 52) + (1, 512, 26, 26) + (1, 1024, 13, 13) + """ + + # From left to right: + # in_channels, out_channels, num_blocks, add_identity, use_spp + arch_settings = { + "P5": [ + [64, 128, 3, True, False], + [128, 256, 9, True, False], + [256, 512, 9, True, False], + [512, 1024, 3, False, True], + ], + "P6": [ + [64, 128, 3, True, False], + [128, 256, 9, True, False], + [256, 512, 9, True, False], + [512, 768, 3, True, False], + [768, 1024, 3, False, True], + ], + } + + def __init__( + self, + arch: str = "P5", + deepen_factor: float = 1.0, + widen_factor: float = 1.0, + out_indices: Sequence[int] = (2, 3, 4), + frozen_stages: int = -1, + arch_ovewrite: list[list[int]] | None = None, + spp_kernal_sizes: Sequence[int] = (5, 9, 13), + norm_eval: bool = False, + ): + """Init.""" + super().__init__() + arch_setting = self.arch_settings[arch] + if arch_ovewrite: + arch_setting = arch_ovewrite + assert set(out_indices).issubset( + i for i in range(len(arch_setting) + 1) + ) + if frozen_stages not in range(-1, len(arch_setting) + 1): + raise ValueError( + "frozen_stages must be in range(-1, " + "len(arch_setting) + 1). But received " + f"{frozen_stages}" + ) + + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.norm_eval = norm_eval + + self.stem = Focus( + 3, int(arch_setting[0][0] * widen_factor), kernel_size=3 + ) + self.layers = ["stem"] + + for i, ( + in_channels, + out_channels, + num_blocks, + add_identity, + use_spp, + ) in enumerate(arch_setting): + in_channels = int(in_channels * widen_factor) + out_channels = int(out_channels * widen_factor) + num_blocks = max(round(num_blocks * deepen_factor), 1) + stage: list[nn.Module] = [] + conv_layer = Conv2d( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + bias=False, + norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + stage.append(conv_layer) + if use_spp: + spp = SPPBottleneck( + out_channels, out_channels, kernel_sizes=spp_kernal_sizes + ) + stage.append(spp) + csp_layer = CSPLayer( + out_channels, + out_channels, + num_blocks=num_blocks, + add_identity=bool(add_identity), + ) + stage.append(csp_layer) + self.add_module(f"stage{i + 1}", nn.Sequential(*stage)) + self.layers.append(f"stage{i + 1}") + self._init_weights() + + def _init_weights(self) -> None: + """Initialize weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_( + m.weight, + a=math.sqrt(5), + mode="fan_in", + nonlinearity="leaky_relu", + ) + + def _freeze_stages(self) -> None: + """Freeze stages.""" + if self.frozen_stages >= 0: + for i in range(self.frozen_stages + 1): + m = getattr(self, self.layers[i]) + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode: bool = True) -> CSPDarknet: + """Override the train mode for the model. + + Args: + mode (bool): Whether to set training mode to True. + """ + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() + return self + + def forward(self, images: torch.Tensor) -> list[torch.Tensor]: + """Forward pass. + + Args: + images (torch.Tensor): Input images. + """ + outs = [images, images] + x = images + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + return outs diff --git a/vis4d/op/base/dla.py b/vis4d/op/base/dla.py new file mode 100644 index 0000000000000000000000000000000000000000..98ff9d486a26c0440c680a05fe7d9fac1562405c --- /dev/null +++ b/vis4d/op/base/dla.py @@ -0,0 +1,647 @@ +"""DLA base model.""" + +from __future__ import annotations + +import math +from collections.abc import Sequence + +import torch +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint + +from vis4d.common.ckpt import load_model_checkpoint + +from .base import BaseModel + +BN_MOMENTUM = 0.1 + +DLA_MODEL_PREFIX = "http://dl.yf.io/dla/models/imagenet" + +DLA_MODEL_MAPPING = { + "dla34": "dla34-ba72cf86.pth", + "dla46_c": "dla46_c-2bfd52c3.pth", + "dla46x_c": "dla46x_c-d761bae7.pth", + "dla60x_c": "dla60x_c-b870c45c.pth", + "dla60": "dla60-24839fc4.pth", + "dla60x": "dla60x-d15cacda.pth", + "dla102": "dla102-d94d9790.pth", + "dla102x": "dla102x-ad62be81.pth", + "dla102x2": "dla102x2-262837b6.pth", + "dla169": "dla169-0914e092.pth", +} + +DLA_ARCH_SETTINGS = { # pylint: disable=consider-using-namedtuple-or-dataclass + "dla34": ( + (1, 1, 1, 2, 2, 1), + (16, 32, 64, 128, 256, 512), + False, + "BasicBlock", + ), + "dla46_c": ( + (1, 1, 1, 2, 2, 1), + (16, 32, 64, 64, 128, 256), + False, + "Bottleneck", + ), + "dla46x_c": ( + (1, 1, 1, 2, 2, 1), + (16, 32, 64, 64, 128, 256), + False, + "BottleneckX", + ), + "dla60x_c": ( + (1, 1, 1, 2, 3, 1), + (16, 32, 64, 64, 128, 256), + False, + "BottleneckX", + ), + "dla60": ( + (1, 1, 1, 2, 3, 1), + (16, 32, 128, 256, 512, 1024), + False, + "Bottleneck", + ), + "dla60x": ( + (1, 1, 1, 2, 3, 1), + (16, 32, 128, 256, 512, 1024), + False, + "BottleneckX", + ), + "dla102": ( + (1, 1, 1, 3, 4, 1), + (16, 32, 128, 256, 512, 1024), + True, + "Bottleneck", + ), + "dla102x": ( + (1, 1, 1, 3, 4, 1), + (16, 32, 128, 256, 512, 1024), + True, + "BottleneckX", + ), + "dla102x2": ( + (1, 1, 1, 3, 4, 1), + (16, 32, 128, 256, 512, 1024), + True, + "BottleneckX", + ), + "dla169": ( + (1, 1, 2, 3, 5, 1), + (16, 32, 128, 256, 512, 1024), + True, + "Bottleneck", + ), +} + + +class BasicBlock(nn.Module): + """BasicBlock.""" + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + with_cp: bool = False, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=1, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.stride = stride + self.with_cp = with_cp + + def forward( + self, input_x: Tensor, residual: None | Tensor = None + ) -> Tensor: + """Forward.""" + + def _inner_forward( + input_x: Tensor, residual: None | Tensor = None + ) -> Tensor: + if residual is None: + residual = input_x + out = self.conv1(input_x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + + return out + + if self.with_cp and input_x.requires_grad: + out = checkpoint( + _inner_forward, input_x, residual, use_reentrant=True + ) + else: + out = _inner_forward(input_x, residual) + + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck.""" + + expansion = 2 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + with_cp: bool = False, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + expansion = Bottleneck.expansion + bottle_planes = planes // expansion + self.conv1 = nn.Conv2d( + inplanes, bottle_planes, kernel_size=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn2 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + bottle_planes, planes, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + self.with_cp = with_cp + + def forward( + self, input_x: Tensor, residual: None | Tensor = None + ) -> Tensor: + """Forward.""" + + def _inner_forward( + input_x: Tensor, residual: None | Tensor = None + ) -> Tensor: + if residual is None: + residual = input_x + + out = self.conv1(input_x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + + return out + + if self.with_cp and input_x.requires_grad: + out = checkpoint( + _inner_forward, input_x, residual, use_reentrant=True + ) + else: + out = _inner_forward(input_x, residual) + + out = self.relu(out) + + return out + + +class BottleneckX(nn.Module): + """BottleneckX.""" + + expansion = 2 + cardinality = 32 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + with_cp: bool = False, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + cardinality = BottleneckX.cardinality + bottle_planes = planes * cardinality // 32 + self.conv1 = nn.Conv2d( + inplanes, bottle_planes, kernel_size=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + groups=cardinality, + ) + self.bn2 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + bottle_planes, planes, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + self.with_cp = with_cp + + def forward( + self, input_x: Tensor, residual: None | Tensor = None + ) -> Tensor: + """Forward.""" + + def _inner_forward( + input_x: Tensor, residual: None | Tensor = None + ) -> Tensor: + if residual is None: + residual = input_x + + out = self.conv1(input_x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + + return out + + if self.with_cp and input_x.requires_grad: + out = checkpoint( + _inner_forward, input_x, residual, use_reentrant=True + ) + else: + out = _inner_forward(input_x, residual) + + out = self.relu(out) + + return out + + +class Root(nn.Module): + """Root.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + residual: bool, + with_cp: bool = False, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=1, + bias=False, + padding=(kernel_size - 1) // 2, + ) + self.bn = nn.BatchNorm2d( # pylint: disable=invalid-name + out_channels, momentum=BN_MOMENTUM + ) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + self.with_cp = with_cp + + def forward(self, *input_x: Tensor) -> Tensor: + """Forward.""" + + def _inner_forward(*input_x: Tensor) -> Tensor: + feats = self.conv(torch.cat(input_x, 1)) + feats = self.bn(feats) + if self.residual: + feats += input_x[0] + return feats + + if self.with_cp and input_x[0].requires_grad: + feats = checkpoint(_inner_forward, *input_x, use_reentrant=True) + else: + feats = _inner_forward(*input_x) + + feats = self.relu(feats) + + return feats + + +class Tree(nn.Module): + """Tree.""" + + def __init__( # pylint: disable=too-many-arguments + self, + levels: int, + block: str, + in_channels: int, + out_channels: int, + stride: int = 1, + level_root: bool = False, + root_dim: int = 0, + root_kernel_size: int = 1, + dilation: int = 1, + root_residual: bool = False, + with_cp: bool = False, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + if block == "BasicBlock": + block_c = BasicBlock + elif block == "Bottleneck": + block_c = Bottleneck # type: ignore + elif block == "BottleneckX": + block_c = BottleneckX # type: ignore + else: + raise ValueError(f"Block={block} not yet supported in DLA!") + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + if levels == 1: + self.tree1: Tree | BasicBlock = block_c( + in_channels, + out_channels, + stride, + dilation=dilation, + with_cp=with_cp, + ) + self.tree2: Tree | BasicBlock = block_c( + out_channels, + out_channels, + 1, + dilation=dilation, + with_cp=with_cp, + ) + self.root = Root( + root_dim, + out_channels, + root_kernel_size, + root_residual, + with_cp=with_cp, + ) + else: + self.tree1 = Tree( + levels - 1, + block, + in_channels, + out_channels, + stride, + root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + with_cp=with_cp, + ) + self.tree2 = Tree( + levels - 1, + block, + out_channels, + out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + with_cp=with_cp, + ) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.levels = levels + if stride > 1: + self.downsample = nn.MaxPool2d(stride, stride=stride) + if in_channels != out_channels and levels == 1: + # NOTE the official impl/weights have project layers in levels > 1 + # case that are never used, hence 'levels == 1' is added but + # pretrained models will need strict=False while loading. + self.project = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False, + ), + nn.BatchNorm2d(out_channels), + ) + + def forward( + self, + input_x: Tensor, + residual: None | Tensor = None, + children: None | list[Tensor] = None, + ) -> Tensor: + """Forward.""" + children = [] if children is None else children + bottom = self.downsample(input_x) if self.downsample else input_x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + input_x1 = self.tree1(input_x, residual) + if self.levels == 1: + input_x2 = self.tree2(input_x1) + input_x = self.root(input_x2, input_x1, *children) + else: + children.append(input_x1) + input_x = self.tree2(input_x1, children=children) + return input_x + + +class DLA(BaseModel): + """DLA base model.""" + + def __init__( + self, + name: str, + out_indices: Sequence[int] = (0, 1, 2, 3), + with_cp: bool = False, + pretrained: bool = False, + weights: None | str = None, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + assert name in DLA_ARCH_SETTINGS, f"{name} is not supported!" + + levels, channels, residual_root, block = DLA_ARCH_SETTINGS[name] + + if name == "dla102x2": # pragma: no cover + BottleneckX.cardinality = 64 + + self.base_layer = nn.Sequential( + nn.Conv2d( + 3, channels[0], kernel_size=7, stride=1, padding=3, bias=False + ), + nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True), + ) + self.level0 = self._make_conv_level( + channels[0], channels[0], levels[0] + ) + self.level1 = self._make_conv_level( + channels[0], channels[1], levels[1], stride=2 + ) + self.level2 = Tree( + levels[2], + block, + channels[1], + channels[2], + 2, + level_root=False, + root_residual=residual_root, + with_cp=with_cp, + ) + self.level3 = Tree( + levels[3], + block, + channels[2], + channels[3], + 2, + level_root=True, + root_residual=residual_root, + with_cp=with_cp, + ) + self.level4 = Tree( + levels[4], + block, + channels[3], + channels[4], + 2, + level_root=True, + root_residual=residual_root, + with_cp=with_cp, + ) + self.level5 = Tree( + levels[5], + block, + channels[4], + channels[5], + 2, + level_root=True, + root_residual=residual_root, + with_cp=with_cp, + ) + + self.out_indices = out_indices + self._out_channels = [channels[i + 2] for i in out_indices] + + if pretrained: + if weights is None: # pragma: no cover + weights = f"{DLA_MODEL_PREFIX}/{DLA_MODEL_MAPPING[name]}" + + load_model_checkpoint(self, weights) + + else: + self._init_weights() + + def _init_weights(self) -> None: + """Initialize module weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + @staticmethod + def _make_conv_level( + inplanes: int, + planes: int, + convs: int, + stride: int = 1, + dilation: int = 1, + ) -> nn.Sequential: + """Build convolutional level.""" + modules = [] + for i in range(convs): + modules.extend( + [ + nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, + bias=False, + dilation=dilation, + ), + nn.BatchNorm2d(planes, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True), + ] + ) + inplanes = planes + return nn.Sequential(*modules) + + def forward(self, images: Tensor) -> list[Tensor]: + """DLA forward. + + Args: + images (Tensor[N, C, H, W]): Image input to process. Expected to + type float32 with values ranging 0..255. + + Returns: + fp (list[Tensor]): The output feature pyramid. The list index + represents the level, which has a downsampling raio of 2^index. + """ + input_x = self.base_layer(images) + + outs = [images, images] + + for i in range(6): + input_x = getattr(self, f"level{i}")(input_x) + + if i - 2 in self.out_indices: + outs.append(input_x) + + return outs + + @property + def out_channels(self) -> list[int]: + """Get the numbers of channels for each level of feature pyramid. + + Returns: + list[int]: number of channels + """ + return [3, 3] + self._out_channels diff --git a/vis4d/op/base/pointnet.py b/vis4d/op/base/pointnet.py new file mode 100644 index 0000000000000000000000000000000000000000..39eb638c84fa54ec8d2cbe42af91dc5fdbc854b7 --- /dev/null +++ b/vis4d/op/base/pointnet.py @@ -0,0 +1,408 @@ +"""Operations for PointNet. + +Code taken from +https://github.com/timothylimyl/PointNet-Pytorch/blob/master/pointnet/model.py +and modified to allow for modular configuration. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable +from typing import NamedTuple + +import torch +from torch import nn + +from vis4d.common.typing import ArgsType + + +class PointNetEncoderOut(NamedTuple): + """Output of the PointNetEncoder. + + features: Global features shape [N, feature_dim] + pointwise Features: Pointwise features shape [N, last_mlp_dim, n_pts] + transformations: list with all transformation matrixes that were used. + Shape [N, d, d] + """ + + features: torch.Tensor + pointwise_features: torch.Tensor # + transformations: list[ # list with all transformation matrices [[B, d, d]] + torch.Tensor + ] + + +class PointNetSemanticsLoss(NamedTuple): + """Losses for the pointnet semantic segmentation network.""" + + semantic_loss: torch.Tensor + regularization_loss: torch.Tensor + + +class PointNetSemanticsOut(NamedTuple): + """Output of the PointNet Segmentation network.""" + + class_logits: torch.Tensor # B, n_classes, n_pts + transformations: list[ # list with all transformation matrices [[B, d, d]] + torch.Tensor + ] + + +class LinearTransform(nn.Module): + """Module that learns a linear transformation for a input pointcloud. + + Code taken from + https://github.com/timothylimyl/PointNet-Pytorch/blob/master/pointnet/model.py + and modified to allow for modular configuration. + + See T-Net in Pointnet publication (https://arxiv.org/pdf/1612.00593.pdf) + for more information + """ + + def __init__( + self, + in_dimension: int = 3, + upsampling_dims: Iterable[int] = (64, 128, 1024), + downsampling_dims: Iterable[int] = (1024, 512, 256), + norm_cls: str | None = "BatchNorm1d", + activation_cls: str = "ReLU", + ) -> None: + """Creates a new LinearTransform. + + This learns a transformation matrix from data. + + Args: + in_dimension (int): input dimension + upsampling_dims (Iterable[int]): list of intermediate feature + shapes for upsampling + downsampling_dims (Iterable[int]): list of intermediate feature + shapes for downsampling. + Make sure this matches with the + last upsampling_dims + norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None + activation_cls (str): class for activation (nn.'activation_cls') + """ + super().__init__() + self.upsampling_dims = list(upsampling_dims) + self.downsampling_dims = list(downsampling_dims) + + assert ( + len(self.upsampling_dims) != 0 and len(self.downsampling_dims) != 0 + ) + assert self.upsampling_dims[-1] == self.downsampling_dims[0] + + self.in_dimension_ = in_dimension + self.identity: torch.Tensor + self.register_buffer( + "identity", torch.eye(in_dimension).reshape(1, in_dimension**2) + ) + + # Create activation + self.activation_ = getattr(nn, activation_cls)() + + # Create norms + norm_fn: Callable[[int], nn.Module] | None = ( + getattr(nn, norm_cls) if norm_cls is not None else None + ) + + if norm_fn is not None: + self.norms_ = nn.ModuleList( + norm_fn(feature_size) + for feature_size in ( + *upsampling_dims, + *self.downsampling_dims[1:], + ) + ) + + # Create upsampling layers + self.upsampling_layers = nn.ModuleList( + [nn.Conv1d(in_dimension, self.upsampling_dims[0], 1)] + ) + for i in range(len(self.upsampling_dims) - 1): + self.upsampling_layers.append( + nn.Conv1d( + self.upsampling_dims[i], self.upsampling_dims[i + 1], 1 + ) + ) + + # Create downsampling layers + self.downsampling_layers = nn.ModuleList( + [ + nn.Linear( + self.downsampling_dims[i], self.downsampling_dims[i + 1] + ) + for i in range(len(self.downsampling_dims) - 1) + ] + ) + self.downsampling_layers.append( + nn.Linear(self.downsampling_dims[-1], in_dimension**2) + ) + + def __call__( + self, + features: torch.Tensor, + ) -> torch.Tensor: + """Type definition for call implementation.""" + return self._call_impl(features) + + def forward( + self, + features: torch.Tensor, + ) -> torch.Tensor: + """Linear Transform forward. + + Args: + features (Tensor[B, C, N]): Input features (e.g. points) + + Returns: + Learned Canonical Transfomation Matrix for this input. + See T-Net in Pointnet publication + (https://arxiv.org/pdf/1612.00593.pdf) + for further information + """ + batchsize = features.shape[0] + # Upsample features + for idx, layer in enumerate(self.upsampling_layers): + features = layer(features) + if self.norms_ is not None: + features = self.norms_[idx](features) + features = self.activation_(features) + + features = torch.max(features, 2, keepdim=True)[0] + features = features.view(-1, self.upsampling_dims[-1]) + + # Downsample features + for idx, layer in enumerate(self.downsampling_layers): + features = layer(features) + + # Do not apply norm and activation for + # final layer + if idx != len(self.downsampling_layers) - 1: + if self.norms_ is not None: + norm_idx = idx + len(self.upsampling_layers) + features = self.norms_[norm_idx](features) + features = self.activation_(features) + + identity_batch = self.identity.repeat(batchsize, 1) + transformations = features + identity_batch + + return transformations.view( + batchsize, self.in_dimension_, self.in_dimension_ + ) + + +class PointNetEncoder(nn.Module): + """PointNetEncoder. + + Encodes a pointcloud and additional features into one feature description + + See pointnet publication for more information + (https://arxiv.org/pdf/1612.00593.pdf) + """ + + def __init__( + self, + in_dimensions: int = 3, + out_dimensions: int = 1024, + mlp_dimensions: Iterable[Iterable[int]] = ((64, 64), (64, 128)), + norm_cls: str | None = "BatchNorm1d", + activation_cls: str = "ReLU", + **kwargs: ArgsType, + ): + """Creates a new PointNetEncoder. + + Args: + in_dimensions (int): input dimension (e.g. 3 for xzy, 6 for xzyrgb) + out_dimensions (int): output dimensions + mlp_dimensions (Iterable[Iterable[int]]):(Dimensions of MLP layers) + norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None + activation_cls (str): class for activation (nn.'activation_cls') + kwargs : See arguments of @LinearTransformStn + """ + super().__init__() + + self.out_dimension = out_dimensions + + # Extend dimensions to upscale from input dimension + mlp_dim_list: list[list[int]] = [list(d) for d in mlp_dimensions] + mlp_dim_list[0].insert(0, in_dimensions) + mlp_dim_list[-1].append(out_dimensions) + self.mlp_dimensions = mlp_dim_list + + # Learnable transformation layers. + self.trans_layers_ = nn.ModuleList( + [ + LinearTransform( + in_dimension=dims[0], + norm_cls=norm_cls, + activation_cls=activation_cls, + **kwargs, + ) + for dims in mlp_dim_list + ] + ) + + # MLP layers + self.mlp_layers_ = nn.ModuleList() + + # Create activation + activation = getattr(nn, activation_cls)() + + # Create norms + norm_fn: Callable[[int], nn.Module] | None = ( + getattr(nn, norm_cls) if norm_cls is not None else None + ) + + for mlp_idx, mlp_dims in enumerate(mlp_dim_list): + layers: list[nn.Module] = [] + + for idx, (in_dim, out_dim) in enumerate( + zip(mlp_dims[:-1], mlp_dims[1:]) + ): + # Create MLP + layers.append(torch.nn.Conv1d(in_dim, out_dim, 1)) + # Create BN if needed + if norm_fn is not None: + layers.append(norm_fn(out_dim)) + + # Only add activation if not last layer + if ( + mlp_idx != len(mlp_dim_list) - 1 + and idx != len(mlp_dims) - 2 + ): + layers.append(activation) + + self.mlp_layers_.append(nn.Sequential(*layers)) + + def __call__(self, features: torch.Tensor) -> PointNetEncoderOut: + """Type definition for call implementation.""" + return self._call_impl(features) + + def forward(self, features: torch.Tensor) -> PointNetEncoderOut: + """Pointnet encoder forward. + + Args: + features (Tensor[B, C, N]): Input features stacked in channels. + e.g. raw point inputs: [B, 3, N] , w color : [B, 3+3, N], ... + + Returns: + Extracted feature representation for input and all + applied transformations. + """ + transforms: list[torch.Tensor] = [] + + for block_idx, trans_layer in enumerate(self.trans_layers_): + # Apply transformation + trans = trans_layer(features) + transforms.append(trans) + features = features.transpose(2, 1) + features = torch.bmm(features, trans) + features = features.transpose(2, 1) + + if block_idx == len(self.trans_layers_) - 1: + pointwise_features = features.clone() + + # Apply MLP + features = self.mlp_layers_[block_idx](features) + + features = torch.max(features, 2, keepdim=True)[0] + features = features.view(-1, self.out_dimension) + + return PointNetEncoderOut( + features=features, + transformations=transforms, + pointwise_features=pointwise_features, # pylint: disable=possibly-used-before-assignment, line-too-long + ) + + +class PointNetSegmentation(nn.Module): + """Segmentation network using a simple pointnet as encoder.""" + + def __init__( + self, + n_classes: int, + in_dimensions: int = 3, + feature_dimension: int = 1024, + norm_cls: str = "BatchNorm1d", + activation_cls: str = "ReLU", + ): + """Creates a new Point Net segementation network. + + Args: + n_classes (int): Number of semantic classes + in_dimensions (int): Input dimension (3 for xyz, 6 xyzrgb, ...) + feature_dimension (int): Size of feature from the encoder + norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None + activation_cls (str): class for activation (nn.'activation_cls') + + Raises: + ValueError: If dimensions are invalid + """ + super().__init__() + self.in_dimensions = in_dimensions + + self.encoder = PointNetEncoder( + in_dimensions=in_dimensions, + out_dimensions=feature_dimension, + norm_cls=norm_cls, + activation_cls=activation_cls, + ) + pc_feat_dim = self.encoder.mlp_dimensions[-1][0] + + # Create activation + activation = getattr(nn, activation_cls)() + + # Create norms + norm_fn: Callable[[int], nn.Module] = ( + getattr(nn, norm_cls) if norm_cls is not None else None + ) + self.classifier_dims = [feature_dimension + pc_feat_dim, 512, 256, 128] + # Build Model + self.classifier = nn.Sequential() + for in_dim, out_dim in zip( + self.classifier_dims[:-1], self.classifier_dims[1:] + ): + self.classifier.append(nn.Conv1d(in_dim, out_dim, 1)) + if norm_fn is not None: + self.classifier.append(norm_fn(out_dim)) + self.classifier.append(activation) + + self.classifier.append( + nn.Conv1d( + out_dim, # pylint: disable=undefined-loop-variable + n_classes, + 1, + ) + ) + + def __call__(self, points: torch.Tensor) -> PointNetSemanticsOut: + """Call function.""" + return self._call_impl(points) + + def forward(self, points: torch.Tensor) -> PointNetSemanticsOut: + """Pointnet Segmenter Forward. + + Args: + points (tensor) : inputs points dimension [B, in_dim, n_pts] + + Returns: + Returns a list of tensors where the first element is + the desired segmentation [B, n_classes, n_pts] and the other + elements are the linear transformation matrices which + have been used to transform the pointclouds + @see LinearTransform + """ + assert points.size(-2) == self.in_dimensions + n_pts = points.size(-1) + bs = points.size(0) + encoder_out = self.encoder(points) + global_features = encoder_out.features.view(bs, -1, 1).repeat( + 1, 1, n_pts + ) + + x = torch.cat([global_features, encoder_out.pointwise_features], 1) + + x = self.classifier(x) + return PointNetSemanticsOut( + class_logits=x, transformations=encoder_out.transformations + ) diff --git a/vis4d/op/base/pointnetpp.py b/vis4d/op/base/pointnetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f4105fa97f2aafd6eae667c19017d13ef63244 --- /dev/null +++ b/vis4d/op/base/pointnetpp.py @@ -0,0 +1,498 @@ +"""Pointnet++ implementation. + +based on https://github.com/yanx27/Pointnet_Pointnet2_pytorch +Added typing and named tuples for convenience. + +#TODO write tests +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import NamedTuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class PointNetSetAbstractionOut(NamedTuple): + """Ouput of PointNet set abstraction.""" + + coordinates: Tensor # [B, C, S] + features: Tensor # [B, D', S] + + +def square_distance(src: Tensor, dst: Tensor) -> Tensor: + """Calculate Euclid distance between each two points. + + src^T * dst = xn * xm + yn * ym + zn * zm; + sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; + sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; + dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 + = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst + + Input: + src: source points, [B, N, C] + dst: target points, [B, M, C] + + Output: + dist: per-point square distance, [B, N, M] + """ + bs, n_pts_in, _ = src.shape + _, n_pts_out, _ = dst.shape + dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) + dist += torch.sum(src**2, -1).view(bs, n_pts_in, 1) + dist += torch.sum(dst**2, -1).view(bs, 1, n_pts_out) + return dist + + +def index_points(points: Tensor, idx: Tensor) -> Tensor: + """Indexes points. + + Input: + points: input points data, [B, N, C] + idx: sample index data, [B, S] + + Return: + new_points:, indexed points data, [B, S, C] + """ + device = points.device + bs = points.shape[0] + view_shape = list(idx.shape) + view_shape[1:] = [1] * (len(view_shape) - 1) + repeat_shape = list(idx.shape) + repeat_shape[0] = 1 + batch_indices = ( + torch.arange(bs, dtype=torch.long) + .to(device) + .view(view_shape) + .repeat(repeat_shape) + ) + new_points = points[batch_indices, idx, :] + return new_points + + +def farthest_point_sample(xyz: Tensor, npoint: int) -> Tensor: + """Farthest point sampling. + + Input: + xyz: pointcloud data, [B, N, 3] + npoint: number of samples + + Return: + centroids: sampled pointcloud index, [B, npoint] + """ + device = xyz.device + bs, n_pts, _ = xyz.shape + centroids = torch.zeros(bs, npoint, dtype=torch.long).to(device) + distance = torch.ones(bs, n_pts).to(device) * 1e10 + farthest = torch.randint(0, n_pts, (bs,), dtype=torch.long).to(device) + batch_indices = torch.arange(bs, dtype=torch.long).to(device) + for i in range(npoint): + centroids[:, i] = farthest + centroid = xyz[batch_indices, farthest, :].view(bs, 1, 3) + dist = torch.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = torch.max(distance, -1)[1] + return centroids + + +def query_ball_point( + radius: float, nsample: int, xyz: Tensor, new_xyz: Tensor +) -> Tensor: + """Query around a ball with given radius. + + Input: + radius: local region radius + nsample: max sample number in local region + xyz: all points, [B, N, 3] + new_xyz: query points, [B, S, 3] + + Return: + group_idx: grouped points index, [B, S, nsample] + """ + device = xyz.device + bs, n_pts_in, _ = xyz.shape + _, n_pts_out, _ = new_xyz.shape + group_idx = ( + torch.arange(n_pts_in, dtype=torch.long) + .to(device) + .view(1, 1, n_pts_in) + .repeat([bs, n_pts_out, 1]) + ) + sqrdists = square_distance(new_xyz, xyz) + group_idx[sqrdists > radius**2] = n_pts_in + group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] + group_first = ( + group_idx[:, :, 0].view(bs, n_pts_out, 1).repeat([1, 1, nsample]) + ) + mask = group_idx == n_pts_in + group_idx[mask] = group_first[mask] + return group_idx + + +def sample_and_group( + npoint: int, + radius: float, + nsample: int, + xyz: Tensor, + points: Tensor, +) -> tuple[Tensor, Tensor]: + """Samples and groups. + + Input: + npoint: Number of center to sample + radius: Grouping Radius + nsample: Max number of points to sample for each circle + xyz: input points position data, [B, N, 3] + points: input points data, [B, N, D] + + Return: + new_xyz: sampled points position data, [B, npoint, nsample, 3] + new_points: sampled points data, [B, npoint, nsample, 3+D] + """ + bs, _, channels = xyz.shape + fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] + new_xyz = index_points(xyz, fps_idx) + idx = query_ball_point(radius, nsample, xyz, new_xyz) + grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] + grouped_xyz_norm = grouped_xyz - new_xyz.view(bs, npoint, 1, channels) + + if points is not None: + grouped_points = index_points(points, idx) + new_points = torch.cat( + [grouped_xyz_norm, grouped_points], dim=-1 + ) # [B, npoint, nsample, C+D] + else: + new_points = grouped_xyz_norm + return new_xyz, new_points + + +def sample_and_group_all(xyz: Tensor, points: Tensor) -> tuple[Tensor, Tensor]: + """Sample and groups all. + + Input: + xyz: input points position data, [B, N, 3] + points: input points data, [B, N, D] + + Return: + new_xyz: sampled points position data, [B, 1, 3] + new_points: sampled points data, [B, 1, N, 3+D] + """ + device = xyz.device + bs, n_pts, channels = xyz.shape + new_xyz = torch.zeros(bs, 1, channels).to(device) + grouped_xyz = xyz.view(bs, 1, n_pts, channels) + if points is not None: + new_points = torch.cat( + [grouped_xyz, points.view(bs, 1, n_pts, -1)], dim=-1 + ) + else: + new_points = grouped_xyz + return new_xyz, new_points + + +class PointNetSetAbstraction(nn.Module): + """PointNet set abstraction layer.""" + + def __init__( + self, + npoint: int, + radius: float, + nsample: int, + in_channel: int, + mlp: list[int], + group_all: bool, + norm_cls: str | None = "BatchNorm2d", + ): + """Set Abstraction Layer from the Pointnet Architecture. + + Args: + npoint: How many points to sample + radius: Size of the ball query + nsample: Max number of points to group inside circle + in_channel: Input channel dimension + mlp: Input channel dimension of the mlp layers. + E.g. [32 , 32, 64] will use a MLP with three layers + group_all: If true, groups all point inside the ball, otherwise + samples 'nsample' points. + norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None + """ + super().__init__() + self.npoint = npoint + self.radius = radius + self.nsample = nsample + self.mlp_convs = nn.ModuleList() + self.mlp_bns = nn.ModuleList() + last_channel = in_channel + + # Create norms + norm_fn: Callable[[int], nn.Module] | None = ( + getattr(nn, norm_cls) if norm_cls is not None else None + ) + + for out_channel in mlp: + self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) + if norm_fn is not None: + self.mlp_bns.append(norm_fn(out_channel)) + last_channel = out_channel + self.group_all = group_all + + def __call__( + self, coordinates: Tensor, features: Tensor + ) -> PointNetSetAbstractionOut: + """Call function. + + Input: + coordinates: input points position data, [B, C, N] + features: input points data, [B, D, N] + + Return: + PointNetSetAbstractionOut with: + coordinates: sampled points position data, [B, C, S] + features: sample points feature data, [B, D', S] + """ + return self._call_impl(coordinates, features) + + def forward( + self, xyz: Tensor, points: Tensor + ) -> PointNetSetAbstractionOut: + """Pointnet++ set abstraction layer forward. + + Input: + xyz: input points position data, [B, C, N] + points: input points data, [B, D, N] + + Return: + PointNetSetAbstractionOut with: + coordinates: sampled points position data, [B, C, S] + features: sample points feature data, [B, D', S] + """ + xyz = xyz.permute(0, 2, 1) + if points is not None: + points = points.permute(0, 2, 1) + + if self.group_all: + new_xyz, new_points = sample_and_group_all(xyz, points) + else: + new_xyz, new_points = sample_and_group( + self.npoint, self.radius, self.nsample, xyz, points + ) + # new_xyz: sampled points position data, [B, npoint, C] + # new_points: sampled points data, [B, npoint, nsample, C+D] + new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] + for i, conv in enumerate(self.mlp_convs): + bn = self.mlp_bns[i] if len(self.mlp_bns) != 0 else lambda x: x + new_points = F.relu(bn(conv(new_points))) + + new_points = torch.max(new_points, 2)[0] + new_xyz = new_xyz.permute(0, 2, 1) + return PointNetSetAbstractionOut(new_xyz, new_points) + + +class PointNetFeaturePropagation(nn.Module): + """Pointnet++ Feature Propagation Layer.""" + + def __init__( + self, + in_channel: int, + mlp: list[int], + norm_cls: str = "BatchNorm1d", + ): + """Creates a pointnet++ feature propagation layer. + + Args: + in_channel: Number of input channels + mlp: list with hidden dimensions of the MLP. + norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None + """ + super().__init__() + self.mlp_convs = nn.ModuleList() + self.mlp_bns = nn.ModuleList() + + # Create norms + norm_fn: Callable[[int], nn.Module] = ( + getattr(nn, norm_cls) if norm_cls is not None else None + ) + last_channel = in_channel + for out_channel in mlp: + self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) + if norm_cls is not None: + self.mlp_bns.append(norm_fn(out_channel)) + last_channel = out_channel + + def __call__( + self, + xyz1: Tensor, + xyz2: Tensor, + points1: Tensor | None, + points2: Tensor, + ) -> Tensor: + """Call function. + + Input: + xyz1: input points position data, [B, C, N] + xyz2: sampled input points position data, [B, C, S] + points1: input points features, [B, D, N] + points2: sampled points features, [B, D, S] + + Return: + new_points: upsampled points data, [B, D', N] + """ + return self._call_impl(xyz1, xyz2, points1, points2) + + def forward( + self, + xyz1: Tensor, + xyz2: Tensor, + points1: Tensor | None, + points2: Tensor, + ) -> Tensor: + """Forward Implementation. + + Input: + xyz1: input points position data, [B, C, N] + xyz2: sampled input points position data, [B, C, S] + points1: input points features, [B, D, N] + points2: sampled points features, [B, D, S] + + Return: + new_points: upsampled points data, [B, D', N] + """ + xyz1 = xyz1.permute(0, 2, 1) + xyz2 = xyz2.permute(0, 2, 1) + + points2 = points2.permute(0, 2, 1) + bs, n_pts, _ = xyz1.shape + _, n_out_pts, _ = xyz2.shape + + if n_out_pts == 1: + interpolated_points = points2.repeat(1, n_pts, 1) + else: + dists = square_distance(xyz1, xyz2) + dists, idx = dists.sort(dim=-1) + dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] + + dist_recip: Tensor = 1.0 / (dists + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + interpolated_points = torch.sum( + index_points(points2, idx) * weight.view(bs, n_pts, 3, 1), + dim=2, + ) + + if points1 is not None: + points1 = points1.permute(0, 2, 1) + new_points = torch.cat([points1, interpolated_points], dim=-1) + else: + new_points = interpolated_points + + new_points = new_points.permute(0, 2, 1) + for i, conv in enumerate(self.mlp_convs): + bn = self.mlp_bns[i] if len(self.mlp_bns) != 0 else lambda x: x + new_points = F.relu(bn(conv(new_points))) + return new_points + + +class PointNet2SegmentationOut(NamedTuple): + """Prediction for the pointnet++ semantic segmentation network.""" + + class_logits: Tensor + + +class PointNet2Segmentation(nn.Module): # TODO, probably move to module? + """Pointnet++ Segmentation Network.""" + + def __init__(self, num_classes: int, in_channels: int = 3): + """Creates a new Pointnet++ for segmentation. + + Args: + num_classes: Number of semantic classes + in_channels: Number of input channels + """ + super().__init__() + + self.set_abstractions = [ + PointNetSetAbstraction( + 1024, 0.1, 32, in_channels + 3, [32, 32, 64], False + ), + PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False), + PointNetSetAbstraction( + 64, 0.4, 32, 128 + 3, [128, 128, 256], False + ), + PointNetSetAbstraction( + 16, 0.8, 32, 256 + 3, [256, 256, 512], False + ), + ] + + self.feature_propagations = [ + PointNetFeaturePropagation(768, [256, 256]), + PointNetFeaturePropagation(384, [256, 256]), + PointNetFeaturePropagation(320, [256, 128]), + PointNetFeaturePropagation(128 + 3, [128, 128, 128]), + ] + + # Final convolutions + self.conv1 = nn.Conv1d(128, 128, 1) + self.bn1 = nn.BatchNorm1d(128) + self.drop1 = nn.Dropout(0.5) + self.conv2 = nn.Conv1d(128, num_classes, 1) + self.in_channels = in_channels + + def __call__(self, xyz: Tensor) -> PointNet2SegmentationOut: + """Call implementation. + + Args: + xyz: Pointcloud data shaped [N, n_feats, n_pts] + + Returns: + PointNet2SegmentationOut, class logits for each point + """ + return self._call_impl(xyz) + + def forward(self, xyz: Tensor) -> PointNet2SegmentationOut: + """Predicts the semantic class logits for each point. + + Args: + xyz: Pointcloud data shaped [N, n_feats, n_pts]$ + + Returns: + PointNet2SegmentationOut, class logits for each point + """ + assert xyz.size(1) == self.in_channels + + l0_points = xyz + l0_xyz = xyz[:, :3, :] + + set_abstraction_out = PointNetSetAbstractionOut( + coordinates=l0_xyz, features=l0_points + ) + outputs: list[PointNetSetAbstractionOut] = [set_abstraction_out] + + for set_abs_layer in self.set_abstractions: + set_abstraction_out = set_abs_layer( + set_abstraction_out.coordinates, set_abstraction_out.features + ) + + outputs.append(set_abstraction_out) + + pointwise_features = outputs[-1].features + for idx, feature_prop_layer in enumerate(self.feature_propagations): + layer_after_out = outputs[-idx - 1] # l4 + layer_out = outputs[-idx - 2] # l3 + + out_features = ( + layer_out.features if idx < len(outputs) - 1 else None + ) + pointwise_features = feature_prop_layer( + layer_out.coordinates, + layer_after_out.coordinates, + out_features, + pointwise_features, + ) + + x = self.drop1(F.relu(self.bn1(self.conv1(pointwise_features)))) + x = self.conv2(x) + return PointNet2SegmentationOut(class_logits=x) diff --git a/vis4d/op/base/resnet.py b/vis4d/op/base/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..10e9f51f34f3dfde1cc5c1a6de702abf115f8e02 --- /dev/null +++ b/vis4d/op/base/resnet.py @@ -0,0 +1,609 @@ +"""Residual networks base model. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import torchvision.models.resnet as _resnet +from torch import Tensor, nn +from torch.nn.modules.batchnorm import _BatchNorm +from torch.utils.checkpoint import checkpoint + +from vis4d.common.ckpt import load_model_checkpoint +from vis4d.common.typing import ArgsType +from vis4d.op.layer.util import build_conv_layer, build_norm_layer +from vis4d.op.layer.weight_init import constant_init, kaiming_init + +from .base import BaseModel + + +class BasicBlock(nn.Module): + """BasicBlock.""" + + expansion = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + downsample: nn.Module | None = None, + style: str = "pytorch", + use_checkpoint: bool = False, + with_dcn: bool = False, + norm: str = "BatchNorm2d", + ) -> None: + """Creates an instance of the class.""" + super().__init__() + assert style in {"pytorch", "caffe"} # No effect for BasicBlock + assert not with_dcn, "DCN is not supported for BasicBlock." + self.conv1 = build_conv_layer( + inplanes, + planes, + 3, + stride=stride, + dilation=dilation, + padding=dilation, + bias=False, + ) + self.bn1 = build_norm_layer(norm, planes) + self.conv2 = build_conv_layer(planes, planes, 3, padding=1, bias=False) + self.bn2 = build_norm_layer(norm, planes) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.use_checkpoint = use_checkpoint + + def forward(self, x: Tensor) -> Tensor: + """Forward function.""" + + def _inner_forward(x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.use_checkpoint and x.requires_grad: + out = checkpoint(_inner_forward, x, use_reentrant=True) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck.""" + + expansion = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + dilation: int = 1, + downsample: nn.Module | None = None, + style: str = "pytorch", + use_checkpoint: bool = False, + with_dcn: bool = False, + norm: str = "BatchNorm2d", + ) -> None: + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. + """ + super().__init__() + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.use_checkpoint = use_checkpoint + + assert style in {"pytorch", "caffe"} + if style == "pytorch": + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.conv1 = build_conv_layer( + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False, + ) + self.bn1 = build_norm_layer(norm, planes) + + self.conv2 = build_conv_layer( + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False, + use_dcn=with_dcn, + ) + self.bn2 = build_norm_layer(norm, planes) + + self.conv3 = build_conv_layer( + planes, + planes * self.expansion, + kernel_size=1, + bias=False, + ) + self.bn3 = build_norm_layer(norm, planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + def forward(self, x: Tensor) -> Tensor: + """Forward function.""" + + def _inner_forward(x: Tensor) -> Tensor: + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.use_checkpoint and x.requires_grad: + out = checkpoint(_inner_forward, x, use_reentrant=True) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ResNet(BaseModel): + """ResNet BaseModel.""" + + arch_settings = { + "resnet18": (18, BasicBlock, (2, 2, 2, 2)), + "resnet34": (34, BasicBlock, (3, 4, 6, 3)), + "resnet50": (50, Bottleneck, (3, 4, 6, 3)), + "resnet101": (101, Bottleneck, (3, 4, 23, 3)), + "resnet152": (152, Bottleneck, (3, 8, 36, 3)), + } + + def __init__( + self, + resnet_name: str, + in_channels: int = 3, + stem_channels: int | None = None, + base_channels: int = 64, + num_stages: int = 4, + strides: Sequence[int] = (1, 2, 2, 2), + dilations: Sequence[int] = (1, 1, 1, 1), + style: str = "pytorch", + deep_stem: bool = False, + avg_down: bool = False, + trainable_layers: int = 5, + norm: str = "BatchNorm2d", + norm_frozen: bool = True, + stages_with_dcn: Sequence[bool] = (False, False, False, False), + replace_stride_with_dilation: Sequence[bool] = (False, False, False), + use_checkpoint: bool = False, + zero_init_residual: bool = True, + pretrained: bool = False, + weights: None | str = None, + ) -> None: + """Create ResNet. + + Args: + resnet_name (str): Name of the ResNet variant. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int | None): Number of stem channels. If not + specified, it will be the same as `base_channels`. Default: + None. + base_channels (int): Number of base channels of res layer. Default: + 64. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: (1, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. Default: (1, 1, + 1, 1) + style (str): `pytorch` or `caffe`. If set to "pytorch", the + stride-two layer is the 3x3 conv layer, otherwise the + stride-two layer is the first 1x1 conv layer. Default: pytorch. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + trainable_layers (int, optional): Number layers for training or + fine-tuning. 5 means all the layers can be fine-tuned. Defaults + to 5. + norm (str): Normalization layer str. Default: BatchNorm2d, which + means using `nn.BatchNorm2d`. + norm_frozen (bool): Whether to set norm layers to eval mode. It + freezes running stats (mean and var). Note: Effect on + Batch Norm and its variants only. + stages_with_dcn (Sequence[bool]): Indices of stages with deformable + convolutions. Default: (False, False, False, False). + replace_stride_with_dilation (Sequence[bool]): Whether to replace + stride with dilation. Default: (False, False, False). + use_checkpoint (bool): Use checkpoint or not. Using checkpoint will + save some memory while slowing down the training speed. + Default: False. + zero_init_residual (bool): Whether to use zero init for last norm + layer in resblocks to let them behave as identity. + Default: True. + pretrained (bool): Whether to load pretrained weights. Default: + False. + weights (str, optional): model pretrained path. Default: None + """ + super().__init__() + self._norm = norm + + self.zero_init_residual = zero_init_residual + if resnet_name not in self.arch_settings: + raise KeyError(f"invalid architecture {resnet_name} for ResNet") + self.name = resnet_name + self.deep_stem = deep_stem + self.trainable_layers = trainable_layers + + self.use_checkpoint = use_checkpoint + self.norm_frozen = norm_frozen + + depth, self.block, stage_blocks = self.arch_settings[resnet_name] + assert isinstance(depth, int) + + self.depth = depth + stem_channels = stem_channels or base_channels + + assert 4 >= num_stages >= 1 + assert len(strides) == len(dilations) == num_stages + + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + if i > 0 and replace_stride_with_dilation[i - 1]: + dilation = strides[i] + stride = 1 + else: + stride = strides[i] + dilation = dilations[i] + planes = base_channels * 2**i + res_layer = self._make_res_layer( + block=self.block, # type: ignore + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=style, + avg_down=avg_down, + use_checkpoint=use_checkpoint, + with_dcn=stages_with_dcn[i], + ) + self.inplanes = planes * self.block.expansion # type: ignore + layer_name = f"layer{i + 1}" + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + if pretrained: + if weights is None: + # default loading the imagenet-1k v1 pre-trained model weights + weights = _resnet.__dict__[ + f"ResNet{depth}_Weights" + ].IMAGENET1K_V1.url + + load_model_checkpoint(self, weights) + else: + self._init_weights() + + self._freeze_stages() + + def _init_weights(self) -> None: + """Initialize the weights of module.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + constant_init(m, 1) + + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck) and isinstance( + m.bn3.weight, nn.Parameter + ): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock) and isinstance( + m.bn2.weight, nn.Parameter + ): + nn.init.constant_(m.bn2.weight, 0) + + def _make_stem_layer(self, in_channels: int, stem_channels: int) -> None: + """Make stem layer for ResNet.""" + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + build_norm_layer(self._norm, stem_channels // 2), + nn.ReLU(inplace=True), + build_conv_layer( + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + build_norm_layer(self._norm, stem_channels // 2), + nn.ReLU(inplace=True), + build_conv_layer( + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + build_norm_layer(self._norm, stem_channels), + nn.ReLU(inplace=True), + ) + else: + self.conv1 = build_conv_layer( + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False, + ) + self.bn1 = build_norm_layer(self._norm, stem_channels) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _make_res_layer( + self, + block: BasicBlock | Bottleneck, + inplanes: int, + planes: int, + num_blocks: int, + stride: int, + dilation: int, + style: str, + avg_down: bool, + use_checkpoint: bool, + with_dcn: bool, + ) -> nn.Sequential: + """Pack all blocks in a stage into a ``ResLayer``.""" + layers: list[BasicBlock | Bottleneck] = [] + downsample: nn.Module | None = None + if stride != 1 or inplanes != planes * block.expansion: + downsample_list: list[nn.AvgPool2d | nn.Module] = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample_list.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False, + ) + ) + downsample_list.extend( + [ + build_conv_layer( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False, + ), + build_norm_layer(self._norm, planes * block.expansion), + ] + ) + downsample = nn.Sequential(*downsample_list) + + layers = [] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + dilation=dilation, + downsample=downsample, + style=style, + use_checkpoint=use_checkpoint, + with_dcn=with_dcn, + norm=self._norm, + ) + ) + inplanes = planes * block.expansion + for _ in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + dilation=dilation, + style=style, + use_checkpoint=use_checkpoint, + with_dcn=with_dcn, + norm=self._norm, + ) + ) + return nn.Sequential(*layers) + + def _freeze_stages(self) -> None: + """Freeze stages param and norm stats.""" + if self.trainable_layers < 5: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.bn1.eval() + for m in (self.conv1, self.bn1): + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, 5 - self.trainable_layers): + m = getattr(self, f"layer{i}") + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode: bool = True) -> ResNet: + """Override the train mode for the model.""" + super().train(mode) + self._freeze_stages() + + if mode and self.norm_frozen: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + return self + + @property + def out_channels(self) -> list[int]: + """Get the number of channels for each level of feature pyramid. + + Returns: + list[int]: number of channels + """ + if self.name in {"resnet18", "resnet34"}: + # channels = [3, 3] + [64 * 2**i for i in range(4)] + channels = [3, 3, 64, 128, 256, 512] + else: + # channels = [3, 3] + [256 * 2**i for i in range(4)] + channels = [3, 3, 256, 512, 1024, 2048] + return channels + + def forward(self, images: Tensor) -> list[Tensor]: + """Forward function. + + Args: + images (Tensor[N, C, H, W]): Image input to process. Expected to + type float32 with values ranging 0..255. + + Returns: + fp (list[torch.Tensor]): The output feature pyramid. The list index + represents the level, which has a downsampling raio of 2^index. + fp[0] and fp[1] is a reference to the input images and + torchvision resnet downsamples the feature maps by 4 directly. + The last feature map downsamples the input image by 64 with a + pooling layer on the second last map. + """ + if self.deep_stem: + x = self.stem(images) + else: + x = self.conv1(images) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [images, images] + for _, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + outs.append(x) + return outs + + +class ResNetV1c(ResNet): + """ResNetV1c variant with a deeper stem. + + Compared with default ResNet, ResNetV1c replaces the 7x7 conv in the input + stem with three 3x3 convs. For more details please refer to `Bag of Tricks + for Image Classification with Convolutional Neural Networks + `. + """ + + model_urls = { + "resnet50_v1c": ( + "https://download.openmmlab.com/pretrain/third_party/" + "resnet50_v1c-2cccc1ad.pth" + ), + "resnet101_v1c": ( + "https://download.openmmlab.com/pretrain/third_party/" + "resnet101_v1c-e67eebb6.pth" + ), + } + + def __init__( + self, + resnet_name: str, + pretrained: bool = False, + weights: str | None = None, + **kwargs: ArgsType, + ): + """Initialize ResNetV1c. + + Args: + resnet_name (str): Name of the resnet model. + pretrained (bool, optional): Whether to load ImageNet pre-trained + weights. Defaults to False. + weights (str, optional): Path to custom pretrained weights. + **kwargs: Arguments for ResNet. + """ + assert resnet_name in { + "resnet18_v1c", + "resnet34_v1c", + "resnet50_v1c", + "resnet101_v1c", + } + if pretrained and weights is None: + assert resnet_name in { + "resnet50_v1c", + "resnet101_v1c", + }, "Only resnet50_v1c and resnet101_v1c have pretrained weights." + weights = self.model_urls[resnet_name] + + super().__init__( + resnet_name[:-4], + deep_stem=True, + pretrained=pretrained, + weights=weights, + **kwargs, + ) diff --git a/vis4d/op/base/unet.py b/vis4d/op/base/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfbf4d1815f59316f052372f59e6584eb9dcb88 --- /dev/null +++ b/vis4d/op/base/unet.py @@ -0,0 +1,169 @@ +"""Unet Implementation based on https://arxiv.org/abs/1505.04597. + +Code taken from https://github.com/jaxony/unet-pytorch/blob/master/model.py +and modified to include typing and custom ops. +""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +from torch import nn + +from vis4d.op.layer.conv2d import UnetDownConv, UnetUpConv + + +class UNetOut(NamedTuple): + """Output of the UNet operator. + + logits: Final output of the network without applying softmax + intermediate_features: Intermediate features of the upsampling path + at different scales. + """ + + logits: torch.Tensor + intermediate_features: list[torch.Tensor] + + +class UNet(nn.Module): + """The U-Net is a convolutional encoder-decoder neural network. + + Contextual spatial information (from the decoding, + expansive pathway) about an input tensor is merged with + information representing the localization of details + (from the encoding, compressive pathway). + + Modifications to the original paper: + (1) padding is used in 3x3 convolutions to prevent loss + of border pixels + (2) merging outputs does not require cropping due to (1) + (3) residual connections can be used by specifying + UNet(merge_mode='add') + (4) if non-parametric upsampling is used in the decoder + pathway (specified by upmode='upsample'), then an + additional 1x1 2d convolution occurs after upsampling + to reduce channel dimensionality by a factor of 2. + This channel halving happens with the convolution in + the tranpose convolution (specified by upmode='transpose') + """ + + def __init__( + self, + num_classes: int, + in_channels: int = 3, + depth: int = 5, + start_filts: int = 32, + up_mode: str = "transpose", + merge_mode: str = "concat", + ): + """Unet Operator. + + Args: + in_channels: int, number of channels in the input tensor. + Default is 3 for RGB images. + num_classes: int, number of output classes. + depth: int, number of MaxPools in the U-Net. + start_filts: int, number of convolutional filters for the + first conv. + up_mode: string, type of upconvolution. Choices: 'transpose' + for transpose convolution or 'upsample' for nearest neighbour + upsampling. + merge_mode: string, how to merge features, can be 'concat' or 'add' + + + Raises: + ValueError: if invalid modes are provided + """ + super().__init__() + + if up_mode in {"transpose", "upsample"}: + self.up_mode = up_mode + else: + raise ValueError( + f"{up_mode} is not a valid mode for upsampling. Only" + f"'transpose' and 'upsample' are allowed." + ) + + if merge_mode in {"concat", "add"}: + self.merge_mode = merge_mode + else: + raise ValueError( + f'"{up_mode}" is not a valid mode for' + f"merging up and down paths. " + f'Only "concat" and ' + f'"add" are allowed.' + ) + + # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' + if self.up_mode == "upsample" and self.merge_mode == "add": + raise ValueError( + 'up_mode "upsample" is incompatible ' + 'with merge_mode "add" at the moment ' + "because it doesn't make sense to use " + "nearest neighbour to reduce " + "depth channels (by half)." + ) + + self.num_classes = num_classes + self.in_channels = in_channels + self.start_filts = start_filts + self.depth = depth + + self.down_convs: nn.ModuleList = nn.ModuleList() + + # create the encoder pathway and add to a list + for i in range(depth): + ins = self.in_channels if i == 0 else outs # type: ignore + outs = self.start_filts * (2**i) + pooling = i < (depth - 1) + + down_conv = UnetDownConv(ins, outs, pooling=pooling) + self.down_convs.append(down_conv) + + self.up_convs: nn.ModuleList = nn.ModuleList() + + # create the decoder pathway and add to a list + # - careful! decoding only requires depth-1 blocks + for i in range(depth - 1): + ins = outs + outs = ins // 2 + up_conv = UnetUpConv( + ins, outs, up_mode=up_mode, merge_mode=merge_mode + ) + self.up_convs.append(up_conv) + self.conv_final = nn.Conv2d( + outs, num_classes, kernel_size=1, groups=1, stride=1 + ) + + def __call__(self, data: torch.Tensor) -> UNetOut: + """Applies the UNet. + + Args: + data (tensor): Input Images into the network shape [N, C, W, H] + + """ + return self._call_impl(data) + + def forward(self, data: torch.Tensor) -> UNetOut: + """Applies the UNet. + + Args: + data (tensor): Input Images into the network shape [N, C, W, H] + """ + encoder_outs: list[torch.Tensor] = [] + inter_feats: list[torch.Tensor] = [] + # encoder pathway, save outputs for merging + + for down_conv in self.down_convs: + out = down_conv(data) + data = out.pooled_features + encoder_outs.append(out.features) + + for level, up_conv in enumerate(self.up_convs): + before_pool = encoder_outs[-(level + 2)] + data = up_conv(before_pool, data) + inter_feats.append(data) + + logits = self.conv_final(data) + return UNetOut(logits=logits, intermediate_features=inter_feats) diff --git a/vis4d/op/base/vgg.py b/vis4d/op/base/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..a887ea03bb52c764adc8b56580f861665221c68b --- /dev/null +++ b/vis4d/op/base/vgg.py @@ -0,0 +1,107 @@ +"""Residual networks for classification.""" + +from __future__ import annotations + +import torch +import torchvision.models.vgg as _vgg +from torchvision.models._utils import IntermediateLayerGetter + +from .base import BaseModel + + +class VGG(BaseModel): + """Wrapper for torch vision VGG.""" + + def __init__( + self, + vgg_name: str, + trainable_layers: None | int = None, + pretrained: bool = False, + ): + """Initialize the VGG base model from torchvision. + + Args: + vgg_name (str): name of the VGG variant. Choices in ["vgg11", + "vgg13", "vgg16", "vgg19", "vgg11_bn", "vgg13_bn", "vgg16_bn", + "vgg19_bn"]. + trainable_layers (int, optional): Number layers for training or + fine-tuning. None means all the layers can be fine-tuned. + pretrained (bool, optional): Whether to load ImageNet + pre-trained weights. Defaults to False. + + Raises: + ValueError: The VGG name is not supported + """ + super().__init__() + if vgg_name not in [ + "vgg11", + "vgg13", + "vgg16", + "vgg19", + "vgg11_bn", + "vgg13_bn", + "vgg16_bn", + "vgg19_bn", + ]: + raise ValueError("The VGG name is not supported!") + + weights = "IMAGENET1K_V1" if pretrained else None + vgg = _vgg.__dict__[vgg_name](weights=weights) + use_bn = vgg_name[-3:] == "_bn" + self._out_channels: list[int] = [] + returned_layers = [] + last_channel = -1 + layer_counter = 0 + + vgg_channels = _vgg.cfgs[ + {"vgg11": "A", "vgg13": "B", "vgg16": "D", "vgg19": "E"}[ + vgg_name[:5] + ] + ] + for channel in vgg_channels: + if channel == "M": + returned_layers.append(layer_counter) + self._out_channels.append(last_channel) + layer_counter += 1 + else: + if use_bn: + layer_counter += 3 + else: + layer_counter += 2 + last_channel = channel + + if trainable_layers is not None: + for name, parameter in vgg.features.named_parameters(): + layer_ind = int(name.split(".")[0]) + if layer_ind < layer_counter - trainable_layers: + parameter.requires_grad_(False) + + return_layers = {str(v): str(i) for i, v in enumerate(returned_layers)} + self.body = IntermediateLayerGetter( + vgg.features, return_layers=return_layers + ) + self.name = vgg_name + + @property + def out_channels(self) -> list[int]: + """Get the number of channels for each level of feature pyramid. + + Returns: + list[int]: number of channels + """ + return [3, 3, *self._out_channels] + + def forward(self, images: torch.Tensor) -> list[torch.Tensor]: + """VGG feature forward without classification head. + + Args: + images (Tensor[N, C, H, W]): Image input to process. Expected to + type float32 with values ranging 0..255. + + Returns: + fp (list[torch.Tensor]): The output feature pyramid. The list index + represents the level, which has a downsampling raio of 2^index. + fp[0] and fp[1] is a reference to the input images. The last + feature map downsamples the input image by 64. + """ + return [images, images, *self.body(images).values()] diff --git a/vis4d/op/base/vit.py b/vis4d/op/base/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..186a36f4c678d8d033b5bb997051af9b40e8f553 --- /dev/null +++ b/vis4d/op/base/vit.py @@ -0,0 +1,271 @@ +"""Residual networks for classification.""" + +from __future__ import annotations + +import torch +from timm.models import named_apply +from torch import nn + +from vis4d.op.layer.patch_embed import PatchEmbed +from vis4d.op.layer.transformer import TransformerBlock + +from .base import BaseModel + + +def _init_weights_vit_timm( # pylint: disable=unused-argument + module: nn.Module, name: str +) -> None: + """Weight initialization, original timm impl (for reproducibility).""" + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() # type: ignore + + +ViT_PRESET = { # pylint: disable=consider-using-namedtuple-or-dataclass + "vit_tiny_patch16_224": { + "patch_size": 16, + "embed_dim": 192, + "depth": 12, + "num_heads": 3, + }, + "vit_small_patch16_224": { + "patch_size": 16, + "embed_dim": 384, + "depth": 12, + "num_heads": 6, + }, + "vit_base_patch16_224": { + "patch_size": 16, + "embed_dim": 768, + "depth": 12, + "num_heads": 12, + }, + "vit_large_patch16_224": { + "patch_size": 16, + "embed_dim": 1024, + "depth": 24, + "num_heads": 16, + }, + "vit_huge_patch16_224": { + "patch_size": 16, + "embed_dim": 1280, + "depth": 32, + "num_heads": 16, + }, + "vit_small_patch32_224": { + "patch_size": 32, + "embed_dim": 384, + "depth": 12, + "num_heads": 6, + }, + "vit_base_patch32_224": { + "patch_size": 32, + "embed_dim": 768, + "depth": 12, + "num_heads": 12, + }, + "vit_large_patch32_224": { + "patch_size": 32, + "embed_dim": 1024, + "depth": 24, + "num_heads": 16, + }, + "vit_huge_patch32_224": { + "patch_size": 32, + "embed_dim": 1280, + "depth": 32, + "num_heads": 16, + }, +} + + +class VisionTransformer(BaseModel): + """Vision Transformer (ViT) model without classification head. + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for + Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Adapted from: + - pytorch vision transformer impl + - timm vision transformer impl + """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + num_classes: int = 1000, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + init_values: float | None = None, + class_token: bool = True, + no_embed_class: bool = False, + pre_norm: bool = False, + pos_drop_rate: float = 0.0, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: nn.Module | None = None, + act_layer: nn.Module = nn.GELU(), + ) -> None: + """Init VisionTransformer. + + Args: + img_size (int, optional): Input image size. Defaults to 224. + patch_size (int, optional): Patch size. Defaults to 16. + in_channels (int, optional): Number of input channels. Defaults to + 3. + num_classes (int, optional): Number of classes. Defaults to 1000. + embed_dim (int, optional): Embedding dimension. Defaults to 768. + depth (int, optional): Depth. Defaults to 12. + num_heads (int, optional): Number of attention heads. Defaults to + 12. + mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding + dim. Defaults to 4.0. + qkv_bias (bool, optional): If to add bias to qkv. Defaults to True. + init_values (float, optional): Initial values for layer scale. + Defaults to None. + class_token (bool, optional): If to add a class token. Defaults to + True. + no_embed_class (bool, optional): If to not embed class token. + Defaults to False. + pre_norm (bool, optional): If to use pre-norm. Defaults to False. + pos_drop_rate (float, optional): Postional dropout rate. Defaults + to 0.0. + drop_rate (float, optional): Dropout rate. Defaults to 0.0. + attn_drop_rate (float, optional): Attention dropout rate. Defaults + to 0.0. + drop_path_rate (float, optional): Drop path rate. Defaults to 0.0. + embed_layer (nn.Module, optional): Embedding layer. Defaults to + PatchEmbed. + norm_layer (nn.Module, optional): Normalization layer. If None, + nn.LayerNorm is used. Defaults to None. + act_layer (nn.Module, optional): Activation layer. Defaults to + nn.GELU(). + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_depth = depth + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + embed_len = ( + num_patches + if no_embed_class + else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, embed_dim)) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + self.norm_pre = ( + nn.LayerNorm(embed_dim, eps=1e-6) if pre_norm else nn.Identity() + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + blocks = [ + TransformerBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + init_values=init_values, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ) + for i in range(depth) + ] + self.blocks = nn.ModuleList(blocks) + self.init_weights() + + def init_weights(self) -> None: + """Init weights using timm's implementation.""" + nn.init.trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(_init_weights_vit_timm, self) + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + """Add positional embeddings.""" + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then + # concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat( + (self.cls_token.expand(x.shape[0], -1, -1), x), dim=1 + ) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat( + (self.cls_token.expand(x.shape[0], -1, -1), x), dim=1 + ) + x = x + self.pos_embed + return self.pos_drop(x) + + @property + def out_channels(self) -> list[int]: + """Return the number of output channels per feature level.""" + return [self.embed_dim] * (self.num_depth + 1) + + def __call__(self, data: torch.Tensor) -> list[torch.Tensor]: + """Applies the ViT encoder. + + Args: + data (tensor): Input Images into the network shape [N, C, W, H] + + """ + return self._call_impl(data) + + def forward(self, images: torch.Tensor) -> list[torch.Tensor]: + """Forward pass. + + Args: + images (torch.Tensor): Input images tensor of shape (B, C, H, W). + + Returns: + feats (list[torch.Tensor]): Features of the input images extracted + by the ViT encoder. feats[0] is the input images, and feats[1] + is the output of the patch embedding layer. The rest of the + elements are the outputs of each transformer block, with the + shape (B, N, dim), where N is the number of patches, and dim + is the embedding dimension. The final element is the output of + the ViT encoder. + """ + feats = [images] + x = self.patch_embed(images) + x = self.norm_pre(self._pos_embed(x)) + feats.append(x) + for blk in self.blocks: + x = blk(x) + feats.append(x) + return feats diff --git a/vis4d/op/box/__init__.py b/vis4d/op/box/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5fa78566ff8e3e3eb9eca8a7d31f3fe85ee8ebf --- /dev/null +++ b/vis4d/op/box/__init__.py @@ -0,0 +1 @@ +"""Operations on 2D bounding boxes.""" diff --git a/vis4d/op/box/__pycache__/__init__.cpython-311.pyc b/vis4d/op/box/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c502b757743d77862f0557e3e924f015b1c7897 Binary files /dev/null and b/vis4d/op/box/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/op/box/__pycache__/box2d.cpython-311.pyc b/vis4d/op/box/__pycache__/box2d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..735c1ba37b969b4fc97635c8b423d0f479b167a9 Binary files /dev/null and b/vis4d/op/box/__pycache__/box2d.cpython-311.pyc differ diff --git a/vis4d/op/box/__pycache__/box3d.cpython-311.pyc b/vis4d/op/box/__pycache__/box3d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c21bd9e20cb638d7c30d17d8d6db614b40fae624 Binary files /dev/null and b/vis4d/op/box/__pycache__/box3d.cpython-311.pyc differ diff --git a/vis4d/op/box/anchor/__init__.py b/vis4d/op/box/anchor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95be90d588d53e64af33827bfec13970bdc5acac --- /dev/null +++ b/vis4d/op/box/anchor/__init__.py @@ -0,0 +1,6 @@ +"""Anchor and point generators.""" + +from .anchor_generator import AnchorGenerator, anchor_inside_image +from .point_generator import MlvlPointGenerator + +__all__ = ["AnchorGenerator", "anchor_inside_image", "MlvlPointGenerator"] diff --git a/vis4d/op/box/anchor/anchor_generator.py b/vis4d/op/box/anchor/anchor_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4efc383e213af2f5d5473e66103506b273f439 --- /dev/null +++ b/vis4d/op/box/anchor/anchor_generator.py @@ -0,0 +1,329 @@ +"""Anchor generator for 2D bounding boxes. + +Modified from: +https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/anchor/anchor_generator.py +""" + +from __future__ import annotations + +import numpy as np +import torch +from torch import Tensor +from torch.nn.modules.utils import _pair + +from .util import meshgrid + + +def anchor_inside_image( + flat_anchors: Tensor, img_shape: tuple[int, int], allowed_border: int = 0 +) -> Tensor: + """Check whether the anchors are inside the border. + + Args: + flat_anchors (Tensor): Flatten anchors, shape (n, 4). + img_shape (tuple(int)): Shape of current image. + allowed_border (int): The border to allow the valid anchor. + Defaults to 0. + + Returns: + Tensor: Flags indicating whether the anchors are inside a valid range. + """ + img_h, img_w = img_shape + inside_flags = ( + (flat_anchors[:, 0] >= -allowed_border) + & (flat_anchors[:, 1] >= -allowed_border) + & (flat_anchors[:, 2] < img_w + allowed_border) + & (flat_anchors[:, 3] < img_h + allowed_border) + ) + return inside_flags + + +class AnchorGenerator: + """Standard anchor generator for 2D anchor-based detectors. + + Examples: + >>> from vis4d.op.box.anchor import AnchorGenerator + >>> self = AnchorGenerator([16], [1.], [1.], [9]) + >>> all_anchors = self.grid_priors([(2, 2)], device='cpu') + >>> print(all_anchors) + [tensor([[-4.5000, -4.5000, 4.5000, 4.5000], + [11.5000, -4.5000, 20.5000, 4.5000], + [-4.5000, 11.5000, 4.5000, 20.5000], + [11.5000, 11.5000, 20.5000, 20.5000]])] + >>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18]) + >>> all_anchors = self.grid_priors([(2, 2), (1, 1)], device='cpu') + >>> print(all_anchors) + [tensor([[-4.5000, -4.5000, 4.5000, 4.5000], + [11.5000, -4.5000, 20.5000, 4.5000], + [-4.5000, 11.5000, 4.5000, 20.5000], + [11.5000, 11.5000, 20.5000, 20.5000]]), \ + tensor([[-9., -9., 9., 9.]])] + """ + + def __init__( + self, + strides: list[int] | list[tuple[int, int]], + ratios: list[float], + scales: list[int] | None = None, + base_sizes: list[int] | None = None, + scale_major: bool = True, + octave_base_scale: None | int = None, + scales_per_octave: None | int = None, + centers: list[tuple[float, float]] | None = None, + center_offset: float = 0.0, + ) -> None: + """Creates an instance of the class. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + ratios (list[float]): The list of ratios between the height and + width of anchors in a single level. + scales (list[int] | None): Anchor scales for anchors in a single + level. It cannot be set at the same time if `octave_base_scale` + and `scales_per_octave` are set. + base_sizes (list[int] | None): The basic sizes + of anchors in multiple levels. + If None is given, strides will be used as base_sizes. + (If strides are non square, the shortest stride is taken.) + scale_major (bool): Whether to multiply scales first when + generating base anchors. If true, the anchors in the same row + will have the same scales. By default it is True in V2.0 + octave_base_scale (int): The base scale of octave. + scales_per_octave (int): Number of scales for each octave. + `octave_base_scale` and `scales_per_octave` are usually used in + retinanet and the `scales` should be None when they are set. + centers (list[tuple[float, float]] | None): The centers of the + anchor relative to the feature grid center in multiple feature + levels. By default it is set to be None and not used. If a list + of tuple of float is given, they will be used to shift the + centers of anchors. + center_offset (float): The offset of center in proportion to + anchors' width and height. By default it is 0 in V2.0. + """ + # check center and center_offset + if center_offset != 0: + assert centers is None, ( + "center cannot be set when center_offset" + f"!=0, {centers} is given." + ) + if not 0 <= center_offset <= 1: + raise ValueError( + "center_offset should be in range [0, 1], " + f"{center_offset} is given." + ) + if centers is not None: + assert len(centers) == len(strides), ( + "The number of strides should be the same as centers, got " + f"{strides} and {centers}" + ) + + # calculate base sizes of anchors + self.strides = [_pair(stride) for stride in strides] + self.base_sizes = ( + [min(stride) for stride in self.strides] + if base_sizes is None + else base_sizes + ) + assert len(self.base_sizes) == len(self.strides), ( + "The number of strides should be the same as base sizes, got " + f"{self.strides} and {self.base_sizes}" + ) + + # calculate scales of anchors + assert ( + octave_base_scale is not None and scales_per_octave is not None + ) ^ (scales is not None), ( + "scales and octave_base_scale with scales_per_octave cannot" + " be set at the same time" + ) + if scales is not None: + self.scales = torch.Tensor(scales) + elif octave_base_scale is not None and scales_per_octave is not None: + octave_scales = np.array( + [ + 2 ** (i / scales_per_octave) + for i in range(scales_per_octave) + ] + ) + scales = octave_scales * octave_base_scale # type: ignore + self.scales = torch.Tensor(scales) + else: + raise ValueError( + "Either scales or octave_base_scale with " + "scales_per_octave should be set" + ) + + self.octave_base_scale = octave_base_scale + self.scales_per_octave = scales_per_octave + self.ratios = torch.Tensor(ratios) + self.scale_major = scale_major + self.centers = centers + self.center_offset = center_offset + self.base_anchors = self.gen_base_anchors() + + @property + def num_base_priors(self) -> list[int]: + """list[int]: The number of priors at a point on the feature grid.""" + return [base_anchors.size(0) for base_anchors in self.base_anchors] + + @property + def num_levels(self) -> int: + """int: number of feature levels that the generator will be applied.""" + return len(self.strides) + + def gen_base_anchors(self) -> list[Tensor]: + """Generate base anchors. + + Returns: + list(torch.Tensor): Base anchors of a feature grid in multiple \ + feature levels. + """ + multi_level_base_anchors = [] + for i, base_size in enumerate(self.base_sizes): + center = None + if self.centers is not None: + center = self.centers[i] + multi_level_base_anchors.append( + self.gen_single_level_base_anchors( + base_size, + scales=self.scales, + ratios=self.ratios, + center=center, + ) + ) + return multi_level_base_anchors + + def gen_single_level_base_anchors( + self, + base_size: int, + scales: Tensor, + ratios: Tensor, + center: tuple[float, float] | None = None, + ) -> Tensor: + """Generate base anchors of a single level. + + Args: + base_size (int): Basic size of an anchor. + scales (Tensor): Scales of the anchor. + ratios (Tensor): The ratio between between the height + and width of anchors in a single level. + center (tuple[float], optional): The center of the base anchor + related to a single feature grid. Defaults to None. + + Returns: + Tensor: Anchors in a single-level feature maps. + """ + width, height = base_size, base_size + if center is None: + x_center = self.center_offset * width + y_center = self.center_offset * height + else: + x_center, y_center = center + + h_ratios = torch.sqrt(ratios) + w_ratios = 1 / h_ratios + if self.scale_major: + ws = (width * w_ratios[:, None] * scales[None, :]).view(-1) + hs = (height * h_ratios[:, None] * scales[None, :]).view(-1) + else: + ws = (width * scales[:, None] * w_ratios[None, :]).view(-1) + hs = (height * scales[:, None] * h_ratios[None, :]).view(-1) + + # use float anchor and the anchor's center is aligned with the + # pixel center + base_anchors = [ + x_center - 0.5 * ws, + y_center - 0.5 * hs, + x_center + 0.5 * ws, + y_center + 0.5 * hs, + ] + + return torch.stack(base_anchors, dim=-1) + + def grid_priors( + self, + featmap_sizes: list[tuple[int, int]], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ) -> list[Tensor]: + """Generate grid anchors in multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels. + dtype (torch.dtype): Dtype of priors. Default: torch.float32. + device (torch.device): The device where the anchors will be put on. + + Return: + list[Tensor]: Anchors in multiple feature levels. The sizes of each + tensor should be [N, 4], where + N = width * height * num_base_anchors, width and height + are the sizes of the corresponding feature level, + num_base_anchors is the number of anchors for that level. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_anchors = [] + for i in range(self.num_levels): + anchors = self.single_level_grid_priors( + featmap_sizes[i], level_idx=i, dtype=dtype, device=device + ) + multi_level_anchors.append(anchors) + return multi_level_anchors + + def single_level_grid_priors( + self, + featmap_size: tuple[int, int], + level_idx: int, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ) -> Tensor: + """Generate grid anchors of a single level. + + Args: + featmap_size (tuple[int, int]): Size of the feature maps. + level_idx (int): The index of corresponding feature map level. + dtype (torch.dtype, optional): Data type of points. Defaults to + torch.float32. + device (torch.device): The device the tensor will be put on. + + Returns: + Tensor: Anchors in the overall feature maps. + """ + base_anchors = self.base_anchors[level_idx].to(device).to(dtype) + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + # First create Range with the default dtype, than convert to + # target `dtype` for onnx exporting. + shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w + shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h + + shift_xx, shift_yy = meshgrid(shift_x, shift_y) + shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1) + # first feat_w elements correspond to the first row of shifts + # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get + # shifted anchors (K, A, 4), reshape to (K*A, 4) + + all_anchors = base_anchors[None, :, :] + shifts[:, None, :] + all_anchors = all_anchors.view(-1, 4) + # first A rows correspond to A anchors of (0, 0) in feature map, + # then (0, 1), (0, 2), ... + return all_anchors + + def __repr__(self) -> str: + """str: a string that describes the module.""" + indent_str = " " + repr_str = self.__class__.__name__ + "(\n" + repr_str += f"{indent_str}strides={self.strides},\n" + repr_str += f"{indent_str}ratios={self.ratios},\n" + repr_str += f"{indent_str}scales={self.scales},\n" + repr_str += f"{indent_str}base_sizes={self.base_sizes},\n" + repr_str += f"{indent_str}scale_major={self.scale_major},\n" + repr_str += f"{indent_str}octave_base_scale=" + repr_str += f"{self.octave_base_scale},\n" + repr_str += f"{indent_str}scales_per_octave=" + repr_str += f"{self.scales_per_octave},\n" + repr_str += f"{indent_str}num_levels={self.num_levels}\n" + repr_str += f"{indent_str}centers={self.centers},\n" + repr_str += f"{indent_str}center_offset={self.center_offset})" + return repr_str diff --git a/vis4d/op/box/anchor/point_generator.py b/vis4d/op/box/anchor/point_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bb100595339cf94856a9e0a3ad09be5d7f89ba --- /dev/null +++ b/vis4d/op/box/anchor/point_generator.py @@ -0,0 +1,210 @@ +"""Point generator for 2D bounding boxes. + +Modified from: +https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/anchor/point_generator.py +""" + +from __future__ import annotations + +import numpy as np +import torch +from torch.nn.modules.utils import _pair + +from .util import meshgrid + + +class MlvlPointGenerator: + """Standard points generator for multi-level feature maps. + + Used for 2D points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. + """ + + def __init__( + self, strides: list[int] | list[tuple[int, int]], offset: float = 0.5 + ): + """Init.""" + self.strides = [_pair(stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self) -> int: + """Number of feature levels.""" + return len(self.strides) + + @property + def num_base_priors(self) -> list[int]: + """Number of points at a point on the feature grid.""" + return [1 for _ in range(len(self.strides))] + + def grid_priors( + self, + featmap_sizes: list[tuple[int, int]], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cuda"), + with_stride: bool = False, + ) -> list[torch.Tensor]: + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple[int, int]]): List of feature map sizes in + multiple feature levels, each (H, W). + dtype (torch.dtype): Dtype of priors. Defaults to torch.float32. + device (torch.device): The device where the anchors will be put on. + Defaults to torch.device("cuda"). + with_stride (bool): Whether to concatenate the stride to the last + dimension of points. Defaults to False, + + Return: + list[torch.Tensor]: Points of multiple feature levels. + The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors( + featmap_sizes[i], + level_idx=i, + dtype=dtype, + device=device, + with_stride=with_stride, + ) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors( + self, + featmap_size: tuple[int, int], + level_idx: int, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cuda"), + with_stride: bool = False, + ) -> torch.Tensor: + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int, int]): Size of the feature maps, (H, W). + level_idx (int): The index of corresponding feature map level. + dtype (torch.dtype): Dtype of priors. Defaults to torch.float32. + device (torch.device): The device where the tensors will be put on. + Defaults to torch.device("cuda"). + with_stride (bool): Concatenate the stride to the last dimension + of points. Defaults to False, + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = ( + torch.arange(0, feat_w, device=device) + self.offset + ) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = ( + torch.arange(0, feat_h, device=device) + self.offset + ) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = meshgrid(shift_x, shift_y) + if not with_stride: + shifts = torch.stack([shift_xx, shift_yy], dim=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to( + dtype + ) + stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to( + dtype + ) + shifts = torch.stack( + [shift_xx, shift_yy, stride_w, stride_h], dim=-1 + ) + all_points = shifts.to(device) + return all_points + + def valid_flags( + self, + featmap_sizes: list[tuple[int, int]], + pad_shape: tuple[int, int], + device: torch.device = torch.device("cuda"), + ) -> list[torch.Tensor]: + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list[tuple[int, int]]): List of feature map sizes in + multiple feature levels, each (H, W). + pad_shape (tuple[int, int]): The padded shape of the image, (H, W). + device (torch.device): The device where the anchors will be put on. + Defaults to torch.device("cuda"). + + Return: + list(torch.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags( + (feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device + ) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags( + self, + featmap_size: tuple[int, int], + valid_size: tuple[int, int], + device: torch.device = torch.device("cuda"), + ) -> torch.Tensor: + """Generate the valid flags of points of a single feature map. + + Args: + featmap_size (tuple[int, int]): The size of feature maps, (H, W). + valid_size (tuple[int, int]): The valid size of the feature maps, + (H, W). + device (torch.device, optional): The device where the flags will + be put on. Defaults to torch.device("cuda"). + + Returns: + torch.Tensor: The valid flags of each points in a single level + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid diff --git a/vis4d/op/box/anchor/util.py b/vis4d/op/box/anchor/util.py new file mode 100644 index 0000000000000000000000000000000000000000..fa314c86a3bd28cc832d67d11561a9632f241e23 --- /dev/null +++ b/vis4d/op/box/anchor/util.py @@ -0,0 +1,27 @@ +"""Anchor utils.""" + +from __future__ import annotations + +from torch import Tensor + + +def meshgrid( + x_grid: Tensor, y_grid: Tensor, row_major: bool = True +) -> tuple[Tensor, Tensor]: + """Generate mesh grid of x and y. + + Args: + x_grid (Tensor): Grids of x dimension. + y_grid (Tensor): Grids of y dimension. + row_major (bool, optional): Whether to return y grids first. + Defaults to True. + + Returns: + tuple[Tensor]: The mesh grids of x and y. + """ + # use shape instead of len to keep tracing while exporting to onnx + xx = x_grid.repeat(y_grid.shape[0]) + yy = y_grid.view(-1, 1).repeat(1, x_grid.shape[0]).view(-1) + if row_major: + return xx, yy + return yy, xx diff --git a/vis4d/op/box/box2d.py b/vis4d/op/box/box2d.py new file mode 100644 index 0000000000000000000000000000000000000000..906f876434d20cc1926b9ed561c42b34bf610a62 --- /dev/null +++ b/vis4d/op/box/box2d.py @@ -0,0 +1,467 @@ +"""Utility functions for bounding boxes.""" + +from __future__ import annotations + +import torch +from torch import Tensor +from torchvision.ops import batched_nms, nms + +from vis4d.common.logging import rank_zero_warn +from vis4d.op.geometry.transform import transform_points + + +def bbox_scale( + boxes: torch.Tensor, scale_factor_xy: tuple[float, float] +) -> torch.Tensor: + """Scale bounding box tensor. + + Args: + boxes (torch.Tensor): Bounding boxes with shape [N, 4] + scale_factor_xy (tuple[float, float]): Scaling factor for x and y + + Returns: + torch.Tensor with bounding boxes scaled by the given factors in + x and y direction + """ + boxes[:, [0, 2]] *= scale_factor_xy[0] + boxes[:, [1, 3]] *= scale_factor_xy[1] + return boxes + + +def bbox_clip( + boxes: torch.Tensor, + image_hw: tuple[float, float], + epsilon: int = 0, +) -> torch.Tensor: + """Clip bounding boxes to image dims. + + Args: + boxes (torch.Tensor): Bounding boxes with shape [N, 4] + image_hw (tuple[float, float]): Image dimensions. + epsilon (int): Epsilon for clipping. + Defaults to 0. + + Returns: + torch.Tensor: Clipped bounding boxes. + """ + boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(0, image_hw[1] - epsilon) + boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(0, image_hw[0] - epsilon) + return boxes + + +def scale_and_clip_boxes( + boxes: torch.Tensor, + original_hw: tuple[int, int], + current_hw: tuple[int, int], + clip: bool = True, +) -> torch.Tensor: + """Postprocess boxes by scaling and clipping to given image dims. + + Args: + boxes (torch.Tensor): Bounding boxes with shape [N, 4]. + original_hw (tuple[int, int]): Original height / width of image. + current_hw (tuple[int, int]): Current height / width of image. + clip (bool): If true, clips box corners to image bounds. + + Returns: + torch.Tensor: Rescaled and possibly clipped bounding boxes. + """ + scale_factor = ( + original_hw[1] / current_hw[1], + original_hw[0] / current_hw[0], + ) + boxes = bbox_scale(boxes, scale_factor) + if clip: + boxes = bbox_clip(boxes, original_hw) + return boxes + + +def bbox_area(boxes: torch.Tensor) -> torch.Tensor: + """Compute bounding box areas. + + Args: + boxes (torch.Tensor): [N, 4] tensor of 2D boxes + in format (x1, y1, x2, y2). + + Returns: + torch.Tensor: [N,] tensor of box areas. + """ + return (boxes[:, 2] - boxes[:, 0]).clamp(0) * ( + boxes[:, 3] - boxes[:, 1] + ).clamp(0) + + +def bbox_intersection(boxes1: Tensor, boxes2: Tensor) -> torch.Tensor: + """Given two lists of boxes of size N and M, compute N x M intersection. + + Args: + boxes1: N 2D boxes in format (x1, y1, x2, y2) + boxes2: M 2D boxes in format (x1, y1, x2, y2) + + Returns: + Tensor: intersection (N, M). + """ + width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max( + boxes1[:, None, :2], boxes2[:, :2] + ) + width_height.clamp_(min=0) + intersection = width_height.prod(dim=2) + return intersection + + +def bbox_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor: + """Compute IoU between all pairs of boxes. + + Args: + boxes1: N 2D boxes in format (x1, y1, x2, y2) + boxes2: M 2D boxes in format (x1, y1, x2, y2) + + Returns: + Tensor: IoU (N, M). + """ + area1 = bbox_area(boxes1) + area2 = bbox_area(boxes2) + inter = bbox_intersection(boxes1, boxes2) + + union = area1[:, None] + area2 - inter + + inter = torch.where( + union > 0, + inter, + torch.zeros(1, dtype=inter.dtype, device=inter.device), + ) + + iou = torch.where( + inter > 0, + inter / (area1[:, None] + area2 - inter), + torch.zeros(1, dtype=inter.dtype, device=inter.device), + ) + return iou + + +def bbox_intersection_aligned(boxes1: Tensor, boxes2: Tensor) -> torch.Tensor: + """Given two lists of boxes both of size N, compute N intersection. + + Args: + boxes1: N 2D boxes in format (x1, y1, x2, y2) + boxes2: N 2D boxes in format (x1, y1, x2, y2) + + Returns: + Tensor: intersection (N). + """ + width_height = torch.min(boxes1[:, 2:], boxes2[:, 2:]) - torch.max( + boxes1[:, :2], boxes2[:, :2] + ) + width_height.clamp_(min=0) + intersection = width_height.prod(dim=1) + return intersection + + +def bbox_iou_aligned( + boxes1: torch.Tensor, boxes2: torch.Tensor +) -> torch.Tensor: + """Compute IoU between aligned pairs of boxes. + + The number of boxes in both inputs must be the same. + + Args: + boxes1: N 2D boxes in format (x1, y1, x2, y2) + boxes2: N 2D boxes in format (x1, y1, x2, y2) + + Returns: + Tensor: IoU (N). + """ + area1 = bbox_area(boxes1) + area2 = bbox_area(boxes2) + inter = bbox_intersection_aligned(boxes1, boxes2) + + iou = torch.where( + inter > 0, + inter / (area1 + area2 - inter), + torch.zeros(1, dtype=inter.dtype, device=inter.device), + ) + return iou + + +def transform_bbox( + trans_mat: torch.Tensor, boxes: torch.Tensor +) -> torch.Tensor: + """Apply trans_mat (3, 3) / (B, 3, 3) to (N, 4) / (B, N, 4) xyxy boxes. + + Args: + trans_mat (torch.Tensor): Transformation matrix + of shape (3,3) or (B,3,3) + boxes (torch.Tensor): Bounding boxes of shape (N,4) or (B,N,4) + + Returns: + torch.Tensor containing linear transformed bounding boxes. (B?, N, 4) + """ + assert len(trans_mat.shape) == len( + boxes.shape + ), "trans_mat and boxes must have same number of dimensions!" + x1y1 = boxes[..., :2] + x2y1 = torch.stack((boxes[..., 2], boxes[..., 1]), -1) + x2y2 = boxes[..., 2:] + x1y2 = torch.stack((boxes[..., 0], boxes[..., 3]), -1) + + x1y1 = transform_points(x1y1, trans_mat) + x2y1 = transform_points(x2y1, trans_mat) + x2y2 = transform_points(x2y2, trans_mat) + x1y2 = transform_points(x1y2, trans_mat) + + x_all = torch.stack( + (x1y1[..., 0], x2y2[..., 0], x2y1[..., 0], x1y2[..., 0]), -1 + ) + y_all = torch.stack( + (x1y1[..., 1], x2y2[..., 1], x2y1[..., 1], x1y2[..., 1]), -1 + ) + transformed_boxes = torch.stack( + ( + x_all.min(dim=-1)[0], + y_all.min(dim=-1)[0], + x_all.max(dim=-1)[0], + y_all.max(dim=-1)[0], + ), + -1, + ) + + if len(boxes.shape) == 2: + transformed_boxes.squeeze(0) + return transformed_boxes + + +# TODO, refactor? move to utils? +def random_choice(tensor: torch.Tensor, sample_size: int) -> torch.Tensor: + """Randomly choose elements from a tensor. + + If sample_size < len(tensor) this function will sample without repetition + otherwise certain elements will be repeated. + + Args: + tensor (torch.Tensor): Tensor to sample from + sample_size (int): Number of elements to sample + + Returns: + torch.Tensor containing sample_size randomly sampled entries. + """ + perm = torch.randperm(len(tensor), device=tensor.device)[:sample_size] + + # Additionally sample with repetition + if sample_size > len(tensor): + remaining_samples = sample_size - len(tensor) + perm = torch.concat( + [ + torch.randint( + remaining_samples, + (remaining_samples,), + device=tensor.device, + ), + perm, + ] + ) + + return tensor[perm] + + +def non_intersection( + tensor_a: torch.Tensor, tensor_b: torch.Tensor +) -> torch.Tensor: + """Get the elements of tensor_a that are not present in tensor_b. + + Args: + tensor_a (torch.Tensor): First tensor + tensor_b (torch.Tensor): Second tensor + + Returns: + torch.Tensor containing all elements that occur in both tensors + """ + compareview = tensor_b.repeat(tensor_a.shape[0], 1).T + return tensor_a[(compareview != tensor_a).T.prod(1) == 1] + + +def apply_mask( + masks: list[torch.Tensor], *args: list[torch.Tensor] +) -> tuple[list[torch.Tensor], ...]: + """Apply given masks (either bool or indices) to given list of tensors. + + Args: + masks (list[torch.Tensor]): Masks to apply on tensors. + *args (list[torch.Tensor]): List of tensors to apply the masks on. + + Returns: + tuple[list[torch.Tensor], ...]: Masked tensor lists. + """ + return tuple( + [t[m] if len(t) > 0 else t for t, m in zip(t_list, masks)] + for t_list in args + ) + + +def filter_boxes_by_area( + boxes: torch.Tensor, min_area: float = 0.0 +) -> tuple[torch.Tensor, torch.Tensor]: + """Filter a set of 2D bounding boxes given a minimum area. + + Args: + boxes (Tensor): 2D bounding boxes [N, 4]. + min_area (float, optional): Minimum area. Defaults to 0.0. + + Returns: + tuple[Tensor, Tensor]: filtered boxes, boolean mask + """ + if min_area > 0.0: + w = boxes[:, 2] - boxes[:, 0] + h = boxes[:, 3] - boxes[:, 1] + valid_mask = w * h >= min_area + if not valid_mask.all(): + return boxes[valid_mask], valid_mask + return boxes, boxes.new_ones((len(boxes),), dtype=torch.bool) + + +def hbox2corner(boxes: Tensor) -> Tensor: + """Convert box coordinates from boxes to corners. + + Boxes are represented as (x1, y1, x2, y2). + Corners are represented as ((x1, y1), (x2, y1), (x1, y2), (x2, y2)). + + Args: + boxes (Tensor): Horizontal box tensor with shape of (..., 4). + + Returns: + Tensor: Corner tensor with shape of (..., 4, 2). + """ + x1, y1, x2, y2 = torch.split(boxes, 1, dim=-1) + corners = torch.cat([x1, y1, x2, y1, x1, y2, x2, y2], dim=-1) + return corners.reshape(*corners.shape[:-1], 4, 2) + + +def corner2hbox(corners: Tensor) -> Tensor: + """Convert box coordinates from corners to boxes. + + Boxes are represented as (x1, y1, x2, y2). + Corners are represented as ((x1, y1), (x2, y1), (x1, y2), (x2, y2)). + + Args: + corners (Tensor): Corner tensor with shape of (..., 4, 2). + + Returns: + Tensor: Horizontal box tensor with shape of (..., 4). + """ + if corners.numel() == 0: + return corners.new_zeros((0, 4)) + min_xy = corners.min(dim=-2)[0] + max_xy = corners.max(dim=-2)[0] + return torch.cat([min_xy, max_xy], dim=-1) + + +def bbox_project(boxes: Tensor, homography_matrix: Tensor) -> Tensor: + """Apply geometric transform to boxes in-place. + + Args: + boxes (Tensor): Horizontal box tensor with shape of (..., 4). + homography_matrix (Tensor): Shape (3, 3) for geometric transformation. + """ + corners = hbox2corner(boxes) + corners = torch.cat( + [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1 + ) + corners_t = torch.transpose(corners, -1, -2) + corners_t = torch.matmul(homography_matrix, corners_t) + corners = torch.transpose(corners_t, -1, -2) + # Convert to homogeneous coordinates by normalization + corners = corners[..., :2] / corners[..., 2:3] + return corner2hbox(corners) + + +def multiclass_nms( + multi_bboxes: Tensor, + multi_scores: Tensor, + score_thr: float, + iou_thr: float, + max_num: int = -1, + class_agnostic: bool = False, + split_thr: int = 100000, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Non-maximum suppression with multiple classes. + + Args: + multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) + multi_scores (Tensor): shape (n, #class), where the last column + contains scores of the background class, but this will be ignored. + score_thr (float): bbox threshold, bboxes with scores lower than it + will not be considered. + iou_thr (float): NMS IoU threshold + max_num (int, optional): if there are more than max_num bboxes after + NMS, only top max_num will be kept. Defaults to -1. + class_agnostic (bool, optional): whether apply class_agnostic NMS. + Defaults to False. + split_thr (int, optional): If the number of bboxes is less than + split_thr, use class agnostic NMS with class_agnostic=True. + Defaults to 100000. + + Returns: + tuple: (Tensor, Tensor, Tensor, Tensor): detections (k, 5), scores + (k), classes (k) and indices (k). + + Raises: + RuntimeError: If there is a onnx error, + """ + num_classes = multi_scores.size(1) - 1 + # exclude background category + if multi_bboxes.shape[1] > 4: + bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4) + else: + bboxes = multi_bboxes[:, None].expand( + multi_scores.size(0), num_classes, 4 + ) + + scores = multi_scores[:, :-1] + + labels = torch.arange(num_classes, dtype=torch.long, device=scores.device) + labels = labels.view(1, -1).expand_as(scores) + + bboxes = bboxes.reshape(-1, 4) + scores = scores.reshape(-1) + labels = labels.reshape(-1) + + if not torch.onnx.is_in_onnx_export(): + # NonZero not supported in TensorRT + # remove low scoring boxes + valid_mask = scores > score_thr + + if not torch.onnx.is_in_onnx_export(): + # NonZero not supported in TensorRT + inds = valid_mask.nonzero(as_tuple=False).squeeze(1) + bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] + else: + # TensorRT NMS plugin has invalid output filled with -1 + # add dummy data to make detection output correct. + bboxes = torch.cat([bboxes, bboxes.new_zeros(1, 4)], dim=0) + scores = torch.cat([scores, scores.new_zeros(1)], dim=0) + labels = torch.cat([labels, labels.new_zeros(1)], dim=0) + + if bboxes.numel() == 0: + if torch.onnx.is_in_onnx_export(): + raise RuntimeError( + "[ONNX Error] Can not record NMS " + "as it has not been executed this time" + ) + return bboxes, scores, labels, inds + + if class_agnostic and bboxes.shape[0] < split_thr: + keep = nms(bboxes, scores, iou_thr) + else: + if class_agnostic: + rank_zero_warn( + f"Number of bboxes is larger than {split_thr}, " + "using per-class NMS instead" + ) + keep = batched_nms(bboxes, scores, labels, iou_thr) + + if max_num > 0: + keep = keep[:max_num] + + bboxes = bboxes[keep] + scores = scores[keep] + labels = labels[keep] + return bboxes, scores, labels, inds[keep] diff --git a/vis4d/op/box/box3d.py b/vis4d/op/box/box3d.py new file mode 100644 index 0000000000000000000000000000000000000000..756d2172709c322e315451dc53d3a69754307a60 --- /dev/null +++ b/vis4d/op/box/box3d.py @@ -0,0 +1,144 @@ +"""Utility functions for 3D bounding boxes.""" + +from __future__ import annotations + +import torch +from torch import Tensor + +from vis4d.data.const import AxisMode +from vis4d.op.geometry.projection import project_points +from vis4d.op.geometry.rotation import ( + euler_angles_to_matrix, + matrix_to_quaternion, + quaternion_multiply, + quaternion_to_matrix, + rotate_orientation, + rotation_matrix_yaw, +) +from vis4d.op.geometry.transform import get_transform_matrix, transform_points + + +def boxes3d_to_corners(boxes3d: Tensor, axis_mode: AxisMode) -> Tensor: + """Convert a Tensor of 3D boxes to its respective corner points. + + Args: + boxes3d (Tensor): Box parameters. Tensor of shape [N, 10]. + axis_mode (AxisMode): Coordinate system convention. + + Returns: + Tensor: [N, 8, 3] 3D bounding box corner coordinates, in this order: + + (back) + (6) +---------+. (7) + | ` . | ` . + | (4) +---+-----+ (5) + | | | | + (2) +-----+---+. (3)| + ` . | ` . | + (0) ` +---------+ (1) + (front) + """ + w, l, h = boxes3d[:, 3], boxes3d[:, 4], boxes3d[:, 5] + rotation_matrix = quaternion_to_matrix(boxes3d[:, 6:]) + + if axis_mode == AxisMode.OPENCV: + x_corners = torch.stack( + [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], + dim=-1, + ) + y_corners = torch.stack( + [h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], + dim=-1, + ) + z_corners = torch.stack( + [-w / 2, w / 2, -w / 2, w / 2, -w / 2, w / 2, -w / 2, w / 2], + dim=-1, + ) + else: + x_corners = torch.stack( + [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], + dim=-1, + ) + y_corners = torch.stack( + [-w / 2, w / 2, -w / 2, w / 2, -w / 2, w / 2, -w / 2, w / 2], + dim=-1, + ) + z_corners = torch.stack( + [-h / 2, -h / 2, -h / 2, -h / 2, h / 2, h / 2, h / 2, h / 2], + dim=-1, + ) + + corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) + corners = transform_points( + corners, get_transform_matrix(rotation_matrix, boxes3d[:, :3]) + ) + return corners + + +def boxes3d_in_image( + box_corners: Tensor, cam_intrinsics: Tensor, image_hw: tuple[int, int] +) -> Tensor: + """Check if a 3D bounding box is (partially) in an image. + + Args: + box_corners (Tensor): [N, 8, 3] Tensor of 3D boxes corners. In OpenCV + coordinate frame. + cam_intrinsics (Tensor): [3, 3] Camera matrix. + image_hw (tuple[int, int]): image height / width. + + Returns: + Tensor: [N,] boolean values. + """ + points = project_points(box_corners.view(-1, 3), cam_intrinsics).view( + -1, 8, 2 + ) + mask = (points[..., 0] >= 0) * (points[..., 0] < image_hw[1]) * ( + points[..., 1] >= 0 + ) * (points[..., 1] < image_hw[0]) * box_corners[..., 2] > 0.0 + mask = mask.any(dim=-1) + return mask + + +def transform_boxes3d( + boxes3d: Tensor, + transform_matrix: Tensor, + source_axis_mode: AxisMode, + target_axis_mode: AxisMode, + only_yaw: bool = True, +) -> Tensor: + """Transform 3D boxes using given transform matrix. + + Args: + boxes3d (Tensor): [N, 10] Tensor of 3D boxes. + transform_matrix (Tensor): [4, 4] Transform matrix. + source_axis_mode (AxisMode): Source coordinate system convention of the + boxes. + target_axis_mode (AxisMode): Target coordinate system convention of the + boxes. + only_yaw (bool): Whether to only care about yaw rotation. + """ + boxes3d_transformed = boxes3d.new_zeros(boxes3d.shape) + boxes3d_transformed[:, :3] = transform_points( + boxes3d[:, :3], transform_matrix + ) + boxes3d_transformed[:, 3:6] = boxes3d[:, 3:6] + + if only_yaw: + orientation = rotation_matrix_yaw( + quaternion_to_matrix(boxes3d[:, 6:]), source_axis_mode + ) + + orientation = rotate_orientation( + orientation, transform_matrix, axis_mode=target_axis_mode + ) + + boxes3d_transformed[:, 6:] = matrix_to_quaternion( + euler_angles_to_matrix(orientation) + ) + else: + rot_quat = matrix_to_quaternion(transform_matrix[:3, :3]) + boxes3d_transformed[:, 6:] = quaternion_multiply( + rot_quat, boxes3d[:, 6:] + ) + + return boxes3d_transformed diff --git a/vis4d/op/box/encoder/__init__.py b/vis4d/op/box/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76cb645aae18da37a436e0b2333f4a3648ccb8ad --- /dev/null +++ b/vis4d/op/box/encoder/__init__.py @@ -0,0 +1,12 @@ +"""Init box coder module.""" + +from .delta_xywh import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder +from .qd_3dt import QD3DTBox3DDecoder +from .yolox import YOLOXBBoxDecoder + +__all__ = [ + "DeltaXYWHBBoxEncoder", + "DeltaXYWHBBoxDecoder", + "QD3DTBox3DDecoder", + "YOLOXBBoxDecoder", +] diff --git a/vis4d/op/box/encoder/bevformer.py b/vis4d/op/box/encoder/bevformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6991757366f82d25fd210dc99d032850339c7458 --- /dev/null +++ b/vis4d/op/box/encoder/bevformer.py @@ -0,0 +1,119 @@ +"""NMS-Free bounding box coder for BEVFormer.""" + +from __future__ import annotations + +import torch +from torch import Tensor + + +class NMSFreeDecoder: + """BBox decoder for NMS-free detector.""" + + def __init__( + self, + num_classes: int, + post_center_range: list[float], + max_num: int = 100, + score_threshold: float | None = None, + ) -> None: + """Initialize NMSFreeDecoder. + + Args: + num_classes (int): Number of classes. + post_center_range (list[float]): Limit of the center. + max_num (int): Max number to be kept. Default: 100. + score_threshold (float): Threshold to filter boxes based on score. + Default: None. + """ + self.num_classes = num_classes + self.post_center_range = post_center_range + self.max_num = max_num + self.score_threshold = score_threshold + + def __call__( + self, cls_scores: Tensor, bbox_preds: Tensor + ) -> tuple[Tensor, Tensor, Tensor]: + """Decode single batch bboxes. + + Args: + cls_scores (Tensor): Outputs from the classification head, in shape + of [num_query, cls_out_channels]. Note cls_out_channels + should includes background. + bbox_preds (Tensor): Outputs from the regression + head with normalized coordinate format (cx, cy, w, l, cz, h, + rot_sine, rot_cosine, vx, vy). Shape [num_query, 9]. + + Returns: + tuple[Tensor, Tensor, Tensor]: Decoded boxes (x, y, z, l, w, h, + yaw, vx, vy), scores and labels. + """ + cls_scores = cls_scores.sigmoid() + scores, indexs = cls_scores.view(-1).topk(self.max_num) + labels = indexs % self.num_classes + bbox_index = indexs // self.num_classes + bbox_preds = bbox_preds[bbox_index] + + final_box_preds = _denormalize_bbox(bbox_preds) + final_scores = scores + final_preds = labels + + # use score threshold + if self.score_threshold is not None: + thresh_mask = final_scores > self.score_threshold + tmp_score = self.score_threshold + while thresh_mask.sum() == 0: + tmp_score *= 0.9 + if tmp_score < 0.01: + thresh_mask = final_scores > -1 + break + thresh_mask = final_scores >= tmp_score + + post_center_range = torch.tensor( + self.post_center_range, device=scores.device + ) + mask = (final_box_preds[..., :3] >= post_center_range[:3]).all(1) + mask &= (final_box_preds[..., :3] <= post_center_range[3:]).all(1) + + if self.score_threshold: + mask &= thresh_mask + + boxes3d = final_box_preds[mask] + scores = final_scores[mask] + + labels = final_preds[mask] + + return boxes3d, scores, labels + + +def _denormalize_bbox(normalized_bboxes: Tensor) -> Tensor: + """Denormalize bounding boxes.""" + # rotation + rot_sine = normalized_bboxes[..., 6:7] + + rot_cosine = normalized_bboxes[..., 7:8] + rot = torch.atan2(rot_sine, rot_cosine) + + # center in the bev + cx = normalized_bboxes[..., 0:1] + cy = normalized_bboxes[..., 1:2] + cz = normalized_bboxes[..., 4:5] + + # size + w = normalized_bboxes[..., 2:3] + l = normalized_bboxes[..., 3:4] + h = normalized_bboxes[..., 5:6] + + w = w.exp() + l = l.exp() + h = h.exp() + if normalized_bboxes.size(-1) > 8: + # velocity + vx = normalized_bboxes[:, 8:9] + vy = normalized_bboxes[:, 9:10] + denormalized_bboxes = torch.cat( + [cx, cy, cz, w, l, h, rot, vx, vy], dim=-1 + ) + else: + denormalized_bboxes = torch.cat([cx, cy, cz, w, l, h, rot], dim=-1) + + return denormalized_bboxes diff --git a/vis4d/op/box/encoder/delta_xywh.py b/vis4d/op/box/encoder/delta_xywh.py new file mode 100644 index 0000000000000000000000000000000000000000..65944620088cb7174938bdd06e6153e81f85be28 --- /dev/null +++ b/vis4d/op/box/encoder/delta_xywh.py @@ -0,0 +1,215 @@ +"""XYWH Delta coder for 2D boxes. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import math + +import torch +from torch import Tensor + + +class DeltaXYWHBBoxEncoder: + """Delta XYWH BBox encoder. + + Following the practice in `R-CNN `_, + it encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh). + """ + + def __init__( + self, + target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + target_stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0), + ) -> None: + """Creates an instance of the class. + + Args: + target_means (tuple, optional): Denormalizing means of target for + delta coordinates. Defaults to (0.0, 0.0, 0.0, 0.0). + target_stds (tuple, optional): Denormalizing standard deviation of + target for delta coordinates. Defaults to (1.0, 1.0, 1.0, 1.0). + """ + self.means = target_means + self.stds = target_stds + + def __call__(self, boxes: Tensor, targets: Tensor) -> Tensor: + """Get box regression transformation deltas. + + Used to transform target boxes into target regression parameters. + + Args: + boxes (Tensor): Source boxes, e.g., object proposals. + targets (Tensor): Target of the transformation, e.g., + ground-truth boxes. + + Returns: + Tensor: Box transformation deltas + """ + assert boxes.size(0) == targets.size(0) + assert boxes.size(-1) == targets.size(-1) == 4 + encoded_bboxes = bbox2delta(boxes, targets, self.means, self.stds) + return encoded_bboxes + + +class DeltaXYWHBBoxDecoder: + """Delta XYWH BBox decoder. + + Following the practice in `R-CNN `_, + it decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2). + """ + + def __init__( + self, + target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + target_stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0), + wh_ratio_clip: float = 16 / 1000, + ) -> None: + """Creates an instance of the class. + + Args: + target_means (tuple, optional): Denormalizing means of target for + delta coordinates. Defaults to (0.0, 0.0, 0.0, 0.0). + target_stds (tuple, optional): Denormalizing standard deviation of + target for delta coordinates. Defaults to (1.0, 1.0, 1.0, 1.0). + wh_ratio_clip (float, optional): Maximum aspect ratio for boxes. + Defaults to 16/1000. + """ + self.means = target_means + self.stds = target_stds + self.wh_ratio_clip = wh_ratio_clip + + def __call__(self, boxes: Tensor, box_deltas: Tensor) -> Tensor: + """Apply box offset energies box_deltas to boxes. + + Args: + boxes (Tensor): Basic boxes. Shape (B, N, 4) or (N, 4) + box_deltas (Tensor): Encoded offsets with respect to each roi. + Has shape (B, N, num_classes * 4) or (B, N, 4) or + (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H + when rois is a grid of anchors.Offset encoding follows [1]_. + + Returns: + Tensor: Decoded boxes. + """ + assert box_deltas.size(0) == boxes.size(0) + decoded_boxes = delta2bbox( + boxes, box_deltas, self.means, self.stds, self.wh_ratio_clip + ) + return decoded_boxes + + +def bbox2delta( + proposals: torch.Tensor, + gt_boxes: torch.Tensor, + means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0), +) -> Tensor: + """Compute deltas of proposals w.r.t. gt. + + We usually compute the deltas of x, y, w, h of proposals w.r.t ground + truth boxes to get regression target. + This is the inverse function of :func:`delta2bbox`. + + Args: + proposals (Tensor): Boxes to be transformed, shape (N, ..., 4). + gt_boxes (Tensor): Gt boxes to be used as base, shape (N, ..., 4). + means (Sequence[float]): Denormalizing means for delta coordinates. + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates. + + Returns: + Tensor: deltas with shape (N, 4), where columns represent dx, dy, + dw, dh. + """ + assert proposals.size() == gt_boxes.size() + + proposals = proposals.float() + gt = gt_boxes.float() + px = (proposals[..., 0] + proposals[..., 2]) * 0.5 + py = (proposals[..., 1] + proposals[..., 3]) * 0.5 + pw = proposals[..., 2] - proposals[..., 0] + ph = proposals[..., 3] - proposals[..., 1] + + gx = (gt[..., 0] + gt[..., 2]) * 0.5 + gy = (gt[..., 1] + gt[..., 3]) * 0.5 + gw = gt[..., 2] - gt[..., 0] + gh = gt[..., 3] - gt[..., 1] + + dx = (gx - px) / pw + dy = (gy - py) / ph + dw = torch.log(gw / pw) + dh = torch.log(gh / ph) + deltas = torch.stack([dx, dy, dw, dh], dim=-1) + + mean_tensor = torch.tensor(means, dtype=deltas.dtype, device=deltas.device) + std_tensor = torch.tensor(stds, dtype=deltas.dtype, device=deltas.device) + deltas = deltas.sub_(mean_tensor.view(1, -1)).div_(std_tensor.view(1, -1)) + + return deltas + + +def delta2bbox( + rois: torch.Tensor, + deltas: torch.Tensor, + means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0), + wh_ratio_clip: float = 16 / 1000, +) -> Tensor: + """Apply deltas to shift/scale base boxes. + + Typically the rois are anchor or proposed bounding boxes and the deltas are + network outputs used to shift/scale those boxes. + This is the inverse function of :func:`bbox2delta`. + + Args: + rois (Tensor): Boxes to be transformed. Has shape (N, 4). + deltas (Tensor): Encoded offsets relative to each roi. + Has shape (N, num_classes * 4) or (N, 4). Note + N = num_base_anchors * W * H, when rois is a grid of + anchors. Offset encoding follows [1]_. + means (Sequence[float]): Denormalizing means for delta coordinates. + Default (0., 0., 0., 0.). + stds (Sequence[float]): Denormalizing standard deviation for delta + coordinates. Default (1., 1., 1., 1.). + wh_ratio_clip (float): Maximum aspect ratio for boxes. Default + 16 / 1000. + + Returns: + Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4 + represent tl_x, tl_y, br_x, br_y. + + References: + .. [1] https://arxiv.org/abs/1311.2524 + """ + num_boxes, num_classes = deltas.size(0), deltas.size(1) // 4 + if num_boxes == 0: + return deltas + + deltas = deltas.reshape(-1, 4) + + mean_tensor = torch.tensor(means, dtype=deltas.dtype, device=deltas.device) + std_tensor = torch.tensor(stds, dtype=deltas.dtype, device=deltas.device) + denorm_deltas = deltas * std_tensor.view(1, -1) + mean_tensor.view(1, -1) + + dxy = denorm_deltas[:, :2] + dwh = denorm_deltas[:, 2:] + + # Compute width/height of each roi + rois_ = rois.repeat(1, num_classes).reshape(-1, 4) + pxy = (rois_[:, :2] + rois_[:, 2:]) * 0.5 + pwh = rois_[:, 2:] - rois_[:, :2] + + dxy_wh = pwh * dxy + + max_ratio = abs(math.log(wh_ratio_clip)) + dwh = dwh.clamp(min=-max_ratio, max=max_ratio) + + gxy = pxy + dxy_wh + gwh = pwh * dwh.exp() + x1y1 = gxy - (gwh * 0.5) + x2y2 = gxy + (gwh * 0.5) + boxes = torch.cat([x1y1, x2y2], dim=-1) + boxes = boxes.reshape(num_boxes, -1) + return boxes diff --git a/vis4d/op/box/encoder/qd_3dt.py b/vis4d/op/box/encoder/qd_3dt.py new file mode 100644 index 0000000000000000000000000000000000000000..7258f10b98a2f04a6407d484f77f8d2416a3e2d7 --- /dev/null +++ b/vis4d/op/box/encoder/qd_3dt.py @@ -0,0 +1,159 @@ +"""3D bounding box coder.""" + +from __future__ import annotations + +import numpy as np +import torch +from torch import Tensor + +from vis4d.data.const import AxisMode +from vis4d.op.geometry.projection import project_points, unproject_points +from vis4d.op.geometry.rotation import ( + alpha2yaw, + normalize_angle, + quaternion_to_matrix, + rotation_matrix_yaw, + rotation_output_to_alpha, + yaw2alpha, +) + + +class QD3DTBox3DEncoder: + """3D bounding box encoder based on qd_3dt.""" + + def __init__( + self, + center_scale: float = 10.0, + depth_log_scale: float = 2.0, + dim_log_scale: float = 2.0, + num_rotation_bins: int = 2, + bin_overlap: float = 1 / 6, + ) -> None: + """Init.""" + self.center_scale = center_scale + self.depth_log_scale = depth_log_scale + self.dim_log_scale = dim_log_scale + self.num_rotation_bins = num_rotation_bins + self.bin_overlap = bin_overlap + + def __call__( + self, boxes: Tensor, boxes3d: Tensor, intrinsics: Tensor + ) -> Tensor: + """Encode deltas between 2D boxes and 3D boxes given intrinsics.""" + # delta center 2d + projected_center_3d = project_points(boxes3d[:, :3], intrinsics) + ctr_x = (boxes[:, 0] + boxes[:, 2]) / 2 + ctr_y = (boxes[:, 1] + boxes[:, 3]) / 2 + center_2d = torch.stack([ctr_x, ctr_y], -1) + delta_center = (projected_center_3d - center_2d) / self.center_scale + + # depth + depth = torch.where( + boxes3d[:, 2] > 0, + torch.log(boxes3d[:, 2]) * self.depth_log_scale, + -boxes3d[:, 2].new_ones(1), + ) + depth = depth.unsqueeze(-1) + + # dimensions + dims = torch.where( + boxes3d[:, 3:6] > 0, + torch.log(boxes3d[:, 3:6]) * self.dim_log_scale, + boxes3d[:, 3:6].new_ones(1) * 100.0, + ) + + # WLH -> HWL + dims = dims[:, [2, 0, 1]] + + # rotation + yaw = rotation_matrix_yaw( + quaternion_to_matrix(boxes3d[:, 6:]), axis_mode=AxisMode.OPENCV + )[:, 1] + alpha = yaw2alpha(yaw, boxes3d[:, :3]) + bin_cls = torch.zeros( + (alpha.shape[0], self.num_rotation_bins), device=alpha.device + ) + bin_res = torch.zeros( + (alpha.shape[0], self.num_rotation_bins), device=alpha.device + ) + bin_centers = torch.arange( + -np.pi, + np.pi, + 2 * np.pi / self.num_rotation_bins, + device=alpha.device, + ) + bin_centers += np.pi / self.num_rotation_bins + for i in range(alpha.shape[0]): + overlap_value = ( + np.pi * 2 / self.num_rotation_bins * self.bin_overlap + ) + alpha_hi = normalize_angle(alpha[i] + overlap_value) + alpha_lo = normalize_angle(alpha[i] - overlap_value) + for bin_idx in range(self.num_rotation_bins): + bin_min = bin_centers[bin_idx] - np.pi / self.num_rotation_bins + bin_max = bin_centers[bin_idx] + np.pi / self.num_rotation_bins + if ( + bin_min <= alpha_lo <= bin_max + or bin_min <= alpha_hi <= bin_max + ): + bin_cls[i, bin_idx] = 1 + bin_res[i, bin_idx] = alpha[i] - bin_centers[bin_idx] + + return torch.cat([delta_center, depth, dims, bin_cls, bin_res], -1) + + +class QD3DTBox3DDecoder: + """3D bounding box decoder based on qd_3dt.""" + + def __init__( + self, + center_scale: float = 10.0, + depth_log_scale: float = 2.0, + dim_log_scale: float = 2.0, + num_rotation_bins: int = 2, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.center_scale = center_scale + self.depth_log_scale = depth_log_scale + self.dim_log_scale = dim_log_scale + self.num_rotation_bins = num_rotation_bins + + def __call__( + self, boxes_2d: Tensor, boxes_deltas: Tensor, intrinsics: Tensor + ) -> Tensor: + """Decode the predicted boxes_deltas according to given 2D boxes.""" + # center + delta_center = boxes_deltas[:, 0:2] * self.center_scale + ctr_x = (boxes_2d[:, 0] + boxes_2d[:, 2]) / 2 + ctr_y = (boxes_2d[:, 1] + boxes_2d[:, 3]) / 2 + boxes_2d_center = torch.stack([ctr_x, ctr_y], -1) + center_2d = boxes_2d_center + delta_center + depth = torch.exp(boxes_deltas[:, 2:3] / self.depth_log_scale) + center_3d = unproject_points(center_2d, depth, intrinsics) + + # dimensions + dimensions = torch.exp(boxes_deltas[:, 3:6] / self.dim_log_scale) + + # rot_y + alpha = rotation_output_to_alpha( + boxes_deltas[:, 6:-1], self.num_rotation_bins + ) + rot_y = alpha2yaw(alpha, center_3d) + orientation = torch.stack( + [torch.zeros_like(rot_y), rot_y, torch.zeros_like(rot_y)], -1 + ) + + velocities = torch.zeros( + (boxes_deltas.shape[0], 3), device=boxes_deltas.device + ) + + return torch.cat( + [ + center_3d, + dimensions, + orientation, + velocities, + ], + 1, + ) diff --git a/vis4d/op/box/encoder/yolox.py b/vis4d/op/box/encoder/yolox.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c7dbd7c5168bb9203488f30b17232cd88bbb45 --- /dev/null +++ b/vis4d/op/box/encoder/yolox.py @@ -0,0 +1,34 @@ +"""YOLOX decoder for 2D boxes. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import torch +from torch import Tensor + + +class YOLOXBBoxDecoder: + """YOLOX BBox decoder.""" + + def __call__(self, points: Tensor, offsets: Tensor) -> Tensor: + """Apply box offsets to points, used by YOLOX. + + Args: + points (Tensor): Points. Shape (B, N, 4) or (N, 4). + offsets (Tensor): Offsets. Has shape (B, N, 4) or (N, 4). + + Returns: + Tensor: Decoded boxes. + """ + xys = (offsets[..., :2] * points[:, 2:]) + points[:, :2] + whs = offsets[..., 2:].exp() * points[:, 2:] + + tl_x = xys[..., 0] - whs[..., 0] / 2 + tl_y = xys[..., 1] - whs[..., 1] / 2 + br_x = xys[..., 0] + whs[..., 0] / 2 + br_y = xys[..., 1] + whs[..., 1] / 2 + + decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1) + return decoded_bboxes diff --git a/vis4d/op/box/matchers/__init__.py b/vis4d/op/box/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1dd157c1e4b3a9ebaf9e9445399b857c0b4003 --- /dev/null +++ b/vis4d/op/box/matchers/__init__.py @@ -0,0 +1,7 @@ +"""Matchers package.""" + +from .base import Matcher, MatchResult +from .max_iou import MaxIoUMatcher +from .sim_ota import SimOTAMatcher + +__all__ = ["Matcher", "MaxIoUMatcher", "MatchResult", "SimOTAMatcher"] diff --git a/vis4d/op/box/matchers/base.py b/vis4d/op/box/matchers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..52a6b27735e70101fdc82fe0cb8b9fa5c99098d8 --- /dev/null +++ b/vis4d/op/box/matchers/base.py @@ -0,0 +1,37 @@ +"""Matchers.""" + +import abc +from typing import NamedTuple + +import torch +from torch import nn + + +class MatchResult(NamedTuple): + """Match result class. Stores expected result tensors. + + assigned_gt_indices: torch.Tensor - Tensor of [0, M) where M = num gt + assigned_gt_iou: torch.Tensor - Tensor with IoU to assigned GT + assigned_labels: torch.Tensor - Tensor of {0, -1, 1} = {neg, ignore, pos} + """ + + assigned_gt_indices: torch.Tensor + assigned_gt_iou: torch.Tensor + assigned_labels: torch.Tensor + + +class Matcher(nn.Module): + """Base class for box / target matchers.""" + + @abc.abstractmethod + def forward( + self, boxes: torch.Tensor, targets: torch.Tensor + ) -> MatchResult: + """Match bounding boxes according to their struct.""" + raise NotImplementedError + + def __call__( + self, boxes: torch.Tensor, targets: torch.Tensor + ) -> MatchResult: + """Type declaration for forward.""" + return self._call_impl(boxes, targets) diff --git a/vis4d/op/box/matchers/max_iou.py b/vis4d/op/box/matchers/max_iou.py new file mode 100644 index 0000000000000000000000000000000000000000..506098db36a210fb67d547214b81eaa931436c70 --- /dev/null +++ b/vis4d/op/box/matchers/max_iou.py @@ -0,0 +1,126 @@ +"""Match predictions and targets according to maximum 2D IoU.""" + +from __future__ import annotations + +import torch +from torch import Tensor + +from ..box2d import bbox_iou +from .base import Matcher, MatchResult + + +# implementation modified from: +# https://github.com/facebookresearch/detectron2/ +class MaxIoUMatcher(Matcher): + """MaxIoUMatcher class.""" + + def __init__( + self, + thresholds: list[float], + labels: list[int], + allow_low_quality_matches: bool, + min_positive_iou: float = 0.0, + ): + """Creates an instance of the class.""" + super().__init__() + self.allow_low_quality_matches = allow_low_quality_matches + self.min_positive_iou = min_positive_iou + if not thresholds[0] > 0: + raise ValueError( + f"Lowest threshold {thresholds[0]} must be greater than 0!" + ) + eps = 1e-4 + thresholds.insert(0, 0.0 - eps) + thresholds.append(1.0 + eps) + if not all( + (lo <= hi for (lo, hi) in zip(thresholds[:-1], thresholds[1:])) + ): + raise ValueError("Thresholds must be in ascending order!") + + assert all( + (v in [-1, 0, 1] for v in labels) + ), "labels must be in [-1, 0, 1]!" + assert ( + len(labels) == len(thresholds) - 1 + ), "Labels must be of len(thresholds) + 1." + self.thresholds = thresholds + self.labels = labels + + def forward(self, boxes: Tensor, targets: Tensor) -> MatchResult: + """Match all boxes to targets based on maximum IoU.""" + if len(targets) == 0: + matches = boxes.new_zeros((len(boxes),), dtype=torch.int64) + match_labels = boxes.new_zeros((len(boxes),), dtype=torch.int8) + match_iou = boxes.new_zeros((len(boxes),)) + else: + # M x N matrix, where M = num gt, N = num proposals + match_quality_matrix = bbox_iou(targets, boxes) + + # matches N x 1 = index of assigned gt i.e. range [0, M) + # match_labels N x 1, 0 = negative, -1 = ignore, 1 = positive + matches, match_labels = self._compute_matches(match_quality_matrix) + match_iou = match_quality_matrix[ + matches, torch.arange(0, len(boxes), device=boxes.device) + ] + + return MatchResult( + assigned_gt_indices=matches, + assigned_labels=match_labels, + assigned_gt_iou=match_iou, + ) + + def _compute_matches( + self, match_quality_matrix: Tensor + ) -> tuple[Tensor, Tensor]: + """Compute matching boxes and their labels w/ match_quality_matrix.""" + assert match_quality_matrix.dim() == 2 + if match_quality_matrix.numel() == 0: + default_matches = match_quality_matrix.new_full( + (match_quality_matrix.shape[1],), 0, dtype=torch.int64 + ) + default_match_labels = match_quality_matrix.new_full( + (match_quality_matrix.shape[1],), + self.labels[0], + dtype=torch.int8, + ) + return default_matches, default_match_labels + + assert torch.all(torch.greater_equal(match_quality_matrix, 0)) + + # Max over gt elements (dim 0) --> best gt for each prediction + matched_vals, matches = match_quality_matrix.max(dim=0) + + match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8) + + for l, low, high in zip( + self.labels, self.thresholds[:-1], self.thresholds[1:] + ): + low_high = (matched_vals >= low) & (matched_vals < high) + match_labels[low_high] = l + + if self.allow_low_quality_matches: + _set_low_quality_matches( + match_labels, match_quality_matrix, self.min_positive_iou + ) + + return matches, match_labels + + +def _set_low_quality_matches( + match_labels: Tensor, + match_quality_matrix: Tensor, + min_positive_iou: float = 0.0, +) -> None: + """Set matches for predictions that have only low-quality matches. + + See Sec. 3.1.2 of Faster R-CNN: https://arxiv.org/abs/1506.01497 + """ + highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) + if min_positive_iou > 0: + highest_quality_foreach_gt = highest_quality_foreach_gt.clamp( + min_positive_iou + ) + pred_inds_with_highest_quality = ( + match_quality_matrix == highest_quality_foreach_gt[:, None] + ).nonzero()[:, 1] + match_labels[pred_inds_with_highest_quality] = 1 diff --git a/vis4d/op/box/matchers/sim_ota.py b/vis4d/op/box/matchers/sim_ota.py new file mode 100644 index 0000000000000000000000000000000000000000..d940ce12c5ac3c77e396aab6de1f1d4ee04ac874 --- /dev/null +++ b/vis4d/op/box/matchers/sim_ota.py @@ -0,0 +1,252 @@ +"""SimOTA label assigner. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from vis4d.op.box.box2d import bbox_iou + +from .base import MatchResult + +INF = 100000.0 +EPS = 1.0e-7 + + +class SimOTAMatcher(nn.Module): + """SimOTA label assigner used by YOLOX. + + Args: + center_radius (float, optional): Ground truth center size to judge + whether a prior is in center. Defaults to 2.5. + candidate_topk (int, optional): The candidate top-k which used to + get top-k ious to calculate dynamic-k. Defaults to 10. + iou_weight (float, optional): The scale factor for regression + iou cost. Defaults to 3.0. + cls_weight (float, optional): The scale factor for classification + cost. Defaults to 1.0. + """ + + def __init__( + self, + center_radius: float = 2.5, + candidate_topk: int = 10, + iou_weight: float = 3.0, + cls_weight: float = 1.0, + ): + """Init.""" + super().__init__() + self.center_radius = center_radius + self.candidate_topk = candidate_topk + self.iou_weight = iou_weight + self.cls_weight = cls_weight + + def forward( # pylint: disable=arguments-differ # type: ignore[override] + self, + pred_scores: Tensor, + priors: Tensor, + decoded_bboxes: Tensor, + gt_bboxes: Tensor, + gt_labels: Tensor, + ) -> MatchResult: + """Assign gt to priors using SimOTA. + + Args: + pred_scores (Tensor): Classification scores of one image, + a 2D-Tensor with shape [num_priors, num_classes] + priors (Tensor): All priors of one image, a 2D-Tensor with shape + [num_priors, 4] in [cx, xy, stride_w, stride_y] format. + decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape + [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format. + gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor + with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format. + gt_labels (Tensor): Ground truth labels of one image, a Tensor + with shape [num_gts]. + + Returns: + MatchResult: The assigned result. + """ + num_gt = gt_bboxes.size(0) + num_bboxes = decoded_bboxes.size(0) + + # assign 0 by default + assigned_gt_inds = decoded_bboxes.new_full( + (num_bboxes,), 0, dtype=torch.long + ) + valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info( + priors, gt_bboxes + ) + valid_decoded_bbox = decoded_bboxes[valid_mask] + valid_pred_scores = pred_scores[valid_mask] + num_valid = valid_decoded_bbox.size(0) + + if num_gt == 0 or num_bboxes == 0 or num_valid == 0: + # No ground truth or boxes, return empty assignment + assigned_gt_iou = decoded_bboxes.new_zeros((num_bboxes,)) + if num_gt == 0: + # No truth, assign everything to background + assigned_gt_inds[:] = 0 + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = decoded_bboxes.new_full( + (num_bboxes,), -1, dtype=torch.long + ) + return MatchResult( + assigned_gt_indices=assigned_gt_inds, + assigned_labels=assigned_labels, + assigned_gt_iou=assigned_gt_iou, + ) + + pairwise_ious = bbox_iou(valid_decoded_bbox, gt_bboxes) + iou_cost = -torch.log(pairwise_ious + EPS) + + gt_onehot_label = ( + F.one_hot( # pylint: disable=not-callable + gt_labels.to(torch.int64), pred_scores.shape[-1] + ) + .float() + .unsqueeze(0) + .repeat(num_valid, 1, 1) + ) + + valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1) + # disable AMP autocast and calculate BCE with FP32 to avoid overflow + with torch.cuda.amp.autocast(enabled=False): + cls_cost = ( + F.binary_cross_entropy( + valid_pred_scores.to(dtype=torch.float32), + gt_onehot_label, + reduction="none", + ) + .sum(-1) + .to(dtype=valid_pred_scores.dtype) + ) + + cost_matrix = ( + cls_cost * self.cls_weight + + iou_cost * self.iou_weight + + (~is_in_boxes_and_center) * INF + ) + + matched_pred_ious, matched_gt_inds = self.dynamic_k_matching( + cost_matrix, pairwise_ious, num_gt, valid_mask + ) + + # convert to MatchResult format + assigned_gt_inds[valid_mask] = matched_gt_inds + assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1) + assigned_labels[valid_mask] = 1 + assigned_gt_iou = assigned_gt_inds.new_full( + (num_bboxes,), -INF, dtype=torch.float32 + ) + assigned_gt_iou[valid_mask] = matched_pred_ious + return MatchResult( + assigned_gt_indices=assigned_gt_inds, + assigned_labels=assigned_labels, + assigned_gt_iou=assigned_gt_iou, + ) + + def get_in_gt_and_in_center_info( + self, priors: Tensor, gt_bboxes: Tensor + ) -> tuple[Tensor, Tensor]: + """Get whether the priors are in gt bboxes and in centers.""" + num_gt = gt_bboxes.size(0) + + repeated_x = priors[:, 0].unsqueeze(1).repeat(1, num_gt) + repeated_y = priors[:, 1].unsqueeze(1).repeat(1, num_gt) + repeated_stride_x = priors[:, 2].unsqueeze(1).repeat(1, num_gt) + repeated_stride_y = priors[:, 3].unsqueeze(1).repeat(1, num_gt) + + # is prior centers in gt bboxes, shape: [n_prior, n_gt] + l_ = repeated_x - gt_bboxes[:, 0] + t_ = repeated_y - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - repeated_x + b_ = gt_bboxes[:, 3] - repeated_y + + deltas = torch.stack([l_, t_, r_, b_], dim=1) + is_in_gts = deltas.min(dim=1).values > 0 + is_in_gts_all = is_in_gts.sum(dim=1) > 0 + + # is prior centers in gt centers + gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 + gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 + ct_box_l = gt_cxs - self.center_radius * repeated_stride_x + ct_box_t = gt_cys - self.center_radius * repeated_stride_y + ct_box_r = gt_cxs + self.center_radius * repeated_stride_x + ct_box_b = gt_cys + self.center_radius * repeated_stride_y + + cl_ = repeated_x - ct_box_l + ct_ = repeated_y - ct_box_t + cr_ = ct_box_r - repeated_x + cb_ = ct_box_b - repeated_y + + ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1) + is_in_cts = ct_deltas.min(dim=1).values > 0 + is_in_cts_all = is_in_cts.sum(dim=1) > 0 + + # in boxes or in centers, shape: [num_priors] + is_in_gts_or_centers = is_in_gts_all | is_in_cts_all + + # both in boxes and centers, shape: [num_fg, num_gt] + is_in_boxes_and_centers = ( + is_in_gts[is_in_gts_or_centers, :] + & is_in_cts[is_in_gts_or_centers, :] + ) + return is_in_gts_or_centers, is_in_boxes_and_centers + + def dynamic_k_matching( + self, + cost: Tensor, + pairwise_ious: Tensor, + num_gt: int, + valid_mask: Tensor, + ) -> tuple[Tensor, Tensor]: + """Dynamic K matching strategy.""" + matching_matrix = torch.zeros_like(cost, dtype=torch.uint8) + # select candidate topk ious for dynamic-k calculation + candidate_topk = min(self.candidate_topk, pairwise_ious.size(0)) + topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0) + # calculate dynamic k for each gt + dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1) + for gt_idx in range(num_gt): + _, pos_idx = torch.topk( + cost[:, gt_idx], + k=dynamic_ks[gt_idx].item(), # type: ignore + largest=False, + ) + matching_matrix[:, gt_idx][pos_idx] = 1 + + del topk_ious, dynamic_ks, pos_idx + + prior_match_gt_mask = matching_matrix.sum(1) > 1 + if prior_match_gt_mask.sum() > 0: + _, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1) + matching_matrix[prior_match_gt_mask, :] *= 0 + matching_matrix[prior_match_gt_mask, cost_argmin] = 1 + # get foreground mask inside box and center prior + fg_mask_inboxes = matching_matrix.sum(1) > 0 + valid_mask[valid_mask.clone()] = fg_mask_inboxes + + matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1) + matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[ + fg_mask_inboxes + ] + return matched_pred_ious, matched_gt_inds + + def __call__( + self, + pred_scores: Tensor, + priors: Tensor, + decoded_bboxes: Tensor, + gt_bboxes: Tensor, + gt_labels: Tensor, + ) -> MatchResult: + """Type declaration for forward.""" + return self._call_impl( + pred_scores, priors, decoded_bboxes, gt_bboxes, gt_labels + ) diff --git a/vis4d/op/box/poolers/__init__.py b/vis4d/op/box/poolers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef125f3eac772daf3e0e279812fbd64213f5f16 --- /dev/null +++ b/vis4d/op/box/poolers/__init__.py @@ -0,0 +1,15 @@ +"""Init sampler module.""" + +from .base import RoIPooler +from .roi_pooler import ( + MultiScaleRoIAlign, + MultiScaleRoIPool, + MultiScaleRoIPooler, +) + +__all__ = [ + "RoIPooler", + "MultiScaleRoIAlign", + "MultiScaleRoIPool", + "MultiScaleRoIPooler", +] diff --git a/vis4d/op/box/poolers/base.py b/vis4d/op/box/poolers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..55cc64ed8f281fc29837cf5a7a5765b515865423 --- /dev/null +++ b/vis4d/op/box/poolers/base.py @@ -0,0 +1,24 @@ +"""RoI Pooling module base.""" + +from __future__ import annotations + +import abc + +import torch +from torch import nn + + +class RoIPooler(nn.Module): + """Base class for RoI poolers.""" + + def __init__(self, resolution: tuple[int, int]) -> None: + """Creates an instance of the class.""" + super().__init__() + self.resolution = resolution + + @abc.abstractmethod + def forward( + self, features: list[torch.Tensor], boxes: list[torch.Tensor] + ) -> torch.Tensor: + """Pool features in input bounding boxes from given feature maps.""" + raise NotImplementedError diff --git a/vis4d/op/box/poolers/roi_pooler.py b/vis4d/op/box/poolers/roi_pooler.py new file mode 100644 index 0000000000000000000000000000000000000000..12a133a0e54e3249fc9d6c39680db142d28efe79 --- /dev/null +++ b/vis4d/op/box/poolers/roi_pooler.py @@ -0,0 +1,193 @@ +"""Vis4D RoI Pooling module.""" + +from __future__ import annotations + +import abc +import math + +import torch +from torchvision.ops import roi_align, roi_pool + +from vis4d.common.typing import ArgsType + +from .base import RoIPooler +from .utils import assign_boxes_to_levels, boxes_to_tensor + + +# implementation modified from: +# https://github.com/facebookresearch/detectron2/ +class MultiScaleRoIPooler(RoIPooler): + """Wrapper for roi pooling that supports multi-scale feature maps.""" + + def __init__( + self, + resolution: tuple[int, int], + strides: list[int], + canonical_box_size: int = 224, + canonical_level: int = 4, + aligned: bool = True, + ): + """Multi-scale version of arbitrary RoI pooling operations. + + Args: + resolution: Pooler resolution. + strides: Feature map strides relative to the input. + The strides must be powers of 2 and a monotically decreasing + geometric sequence with a factor of 1/2. + canonical_box_size: Canonical box size in pixels (sqrt(box area)). + The default is heuristically defined as 224 pixels in the FPN + paper (based on ImageNet pre-training). + canonical_level: The feature map level index from which a canonical + sized box should be placed. The default is defined as level 4 + (stride=16) in the FPN paper, i.e., a box of size 224x224 will + be placed on the feature with stride=16. + The box placement for all boxes will be determined from their + sizes w.r.t canonical_box_size. For example, a box whose area + is 4x that of a canonical box should be used to pool features + from feature level ``canonical_level+1``. + aligned (bool): For roi_align op. Shift the box coordinates it by + -0.5 for a better alignment with the two neighboring pixel + indices. + """ + super().__init__(resolution) + self.canonical_level = canonical_level + self.canonical_box_size = canonical_box_size + self.aligned = aligned + self.strides = strides + + # Map scale (defined as 1 / stride) to its feature map level under the + # assumption that stride is a power of 2. + self.scales = [1 / s for s in self.strides] + + min_level = -(math.log2(self.scales[0])) + max_level = -(math.log2(self.scales[-1])) + assert math.isclose(min_level, int(min_level)) and math.isclose( + max_level, int(max_level) + ), "Featuremap stride is not power of 2!" + self.min_level = int(min_level) + self.max_level = int(max_level) + assert ( + len(self.scales) == self.max_level - self.min_level + 1 + ), "[ROIPooler] Sizes of input NamedTensors do not form a pyramid!" + assert self.min_level >= 0 and self.min_level <= self.max_level + assert self.canonical_box_size > 0 + + def forward( + self, features: list[torch.Tensor], boxes: list[torch.Tensor] + ) -> torch.Tensor: + """Torchvision based roi pooling operation. + + Args: + features: List of image feature tensors (e.g., fpn levels) - NCHW + format. + boxes: List of proposals (per image). + + Returns: + torch.Tensor: NCHW format, where N = num boxes (total), + HW is roi size, C is feature dim. Boxes are concatenated along + dimension 0 for all batch elements. + """ + assert len(features) == len(self.scales), ( + f"unequal value, len(strides)={len(self.scales)}, " + f"but x is list of {len(features)} Tensors" + ) + + assert len(boxes) == features[0].shape[0], ( + f"unequal value, x[0] batch dim 0 is {features[0].shape[0]}, " + f"but box_list has length {len(boxes)}" + ) + if len(boxes) == 0: + return torch.zeros( + (0, features[0].shape[1]) + self.resolution, + device=features[0].device, + dtype=features[0].dtype, + ) + + pooler_fmt_boxes = boxes_to_tensor(boxes) + if len(self.scales) == 1: + return self._pooling_op( + features[0], + pooler_fmt_boxes, + spatial_scale=self.scales[0], + ) + + level_assignments = assign_boxes_to_levels( + boxes, + self.min_level, + self.max_level, + self.canonical_box_size, + self.canonical_level, + ) + + num_boxes = pooler_fmt_boxes.shape[0] + num_channels = features[0].shape[1] + output_size = self.resolution[0] + + dtype, device = features[0].dtype, features[0].device + output = torch.zeros( + (num_boxes, num_channels, output_size, output_size), + dtype=dtype, + device=device, + ) + + for level, scale in enumerate(self.scales): + inds = torch.eq(level_assignments, level).nonzero()[:, 0] + pooler_fmt_boxes_level = pooler_fmt_boxes[inds] + pooled_features = self._pooling_op( + features[level], pooler_fmt_boxes_level, spatial_scale=scale + ) + # Use index_put_ instead of advance indexing + # avoids pytorch/issues/49852 + output.index_put_((inds,), pooled_features) + + return output + + @abc.abstractmethod + def _pooling_op( + self, + inputs: torch.Tensor, + boxes: torch.Tensor, + spatial_scale: float = 1.0, + ) -> torch.Tensor: + """Execute pooling op defined in config.""" + raise NotImplementedError + + +class MultiScaleRoIAlign(MultiScaleRoIPooler): + """RoI Align supporting multi-scale inputs.""" + + def __init__( + self, sampling_ratio: int, *args: ArgsType, **kwargs: ArgsType + ) -> None: + """Creates an instance of the class.""" + super().__init__(*args, **kwargs) + self.sampling_ratio = sampling_ratio + + def _pooling_op( + self, + inputs: torch.Tensor, + boxes: torch.Tensor, + spatial_scale: float = 1.0, + ) -> torch.Tensor: + """Roialign wrapper.""" + return roi_align( + inputs, + boxes, + self.resolution, + spatial_scale, + self.sampling_ratio, + self.aligned, + ) + + +class MultiScaleRoIPool(MultiScaleRoIPooler): + """RoI Pool supporting multi-scale inputs.""" + + def _pooling_op( + self, + inputs: torch.Tensor, + boxes: torch.Tensor, + spatial_scale: float = 1.0, + ) -> torch.Tensor: + """Roipool wrapper.""" + return roi_pool(inputs, boxes, self.resolution, spatial_scale) diff --git a/vis4d/op/box/poolers/utils.py b/vis4d/op/box/poolers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9dcb5f749cfdac0a9f3fc1e20b975094c7417b08 --- /dev/null +++ b/vis4d/op/box/poolers/utils.py @@ -0,0 +1,73 @@ +"""Utility functions for RoI poolers.""" + +from __future__ import annotations + +import torch + +from ..box2d import bbox_area + + +def assign_boxes_to_levels( + box_lists: list[torch.Tensor], + min_level: int, + max_level: int, + canonical_box_size: int, + canonical_level: int, +) -> torch.Tensor: + """Map each box to a feature map level index and return the assignment. + + Args: + box_lists: List of Boxes + min_level: Smallest feature map level index. The input is considered + index 0, the output of stage 1 is index 1, and so. + max_level: Largest feature map level index. + canonical_box_size: A canonical box size in pixels (sqrt(box area)). + canonical_level: The feature map level index on which a + canonically-sized box should be placed. + + Returns: + Tensor (M,), where M is the total number of boxes in the list. Each + element is the feature map index, as an offset from min_level, for the + corresponding box (so value i means the box is at self.min_level + i). + """ + box_sizes = torch.sqrt( + torch.cat([bbox_area(boxes) for boxes in box_lists]) + ) + # Eqn.(1) in FPN paper + level_assignments = torch.floor( + canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8) + ) + # clamp level to (min, max), in case the box size is too large or too small + # for the available feature maps + level_assignments = torch.clamp( + level_assignments, min=min_level, max=max_level + ) + return level_assignments.to(torch.int64) - min_level + + +def boxes_to_tensor(boxes: list[torch.Tensor]) -> torch.Tensor: + """Convert all boxes into the tensor format used by ROI pooling ops. + + Args: + boxes: List of Boxes + + Returns: + A tensor of shape (M, 5), where M is the total number of boxes + aggregated over all N batch images. The 5 columns are + (batch index, x0, y0, x1, y1), where batch index is in [0, N). + """ + + def _fmt_box_list(box_tensor: torch.Tensor, batch_i: int) -> torch.Tensor: + repeated_index = torch.full_like( + box_tensor[:, :1], + batch_i, + dtype=box_tensor.dtype, + device=box_tensor.device, + ) + return torch.cat((repeated_index, box_tensor), dim=1) + + pooler_fmt_boxes = torch.cat( + [_fmt_box_list(boxs[:, :4], i) for i, boxs in enumerate(boxes)], + dim=0, + ) + return pooler_fmt_boxes diff --git a/vis4d/op/box/samplers/__init__.py b/vis4d/op/box/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..716b8f076e87b0fc31e00eb5233ba80b0674b9e5 --- /dev/null +++ b/vis4d/op/box/samplers/__init__.py @@ -0,0 +1,15 @@ +"""Init sampler module.""" + +from .base import Sampler, SamplingResult, match_and_sample_proposals +from .combined import CombinedSampler +from .pseudo import PseudoSampler +from .random import RandomSampler + +__all__ = [ + "Sampler", + "CombinedSampler", + "RandomSampler", + "PseudoSampler", + "SamplingResult", + "match_and_sample_proposals", +] diff --git a/vis4d/op/box/samplers/base.py b/vis4d/op/box/samplers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..2e41b14eb6dcecee94558e9640cbd3cb0fd7355e --- /dev/null +++ b/vis4d/op/box/samplers/base.py @@ -0,0 +1,71 @@ +"""Interface for Vis4D bounding box samplers.""" + +from __future__ import annotations + +import abc +from typing import NamedTuple + +import torch +from torch import Tensor, nn + +from ..matchers import Matcher, MatchResult + + +class SamplingResult(NamedTuple): + """Sampling result class. Stores expected result tensors. + + sampled_box_indices (Tensor): Index of sampled boxes from input. + sampled_target_indices (Tensor): Index of assigned target for each + positive sampled box. + sampled_labels (Tensor): {0, -1, 1} = {neg, ignore, pos}. + """ + + sampled_box_indices: Tensor + sampled_target_indices: Tensor + sampled_labels: Tensor + + +class Sampler(nn.Module): + """Sampler base class.""" + + def __init__(self, batch_size: int, positive_fraction: float) -> None: + """Creates an instance of the class.""" + super().__init__() + self.batch_size = batch_size + self.positive_fraction = positive_fraction + + @abc.abstractmethod + def forward(self, matching: MatchResult) -> SamplingResult: + """Sample bounding boxes according to their struct.""" + raise NotImplementedError + + def __call__(self, matching: MatchResult) -> SamplingResult: + """Type declaration.""" + return self._call_impl(matching) + + +def match_and_sample_proposals( + matcher: Matcher, + sampler: Sampler, + proposal_boxes: list[Tensor], + target_boxes: list[Tensor], +) -> tuple[list[Tensor], list[Tensor], list[Tensor]]: + """Match proposals to targets and subsample. + + First, match the proposals to targets (ground truth labels) using the + matcher. It is usually IoU matcher. The matching labels the proposals with + positive or negative to show whether they are matched to an object. + Second, the sampler will choose proposals based on certain criteria such as + total proposal number and ratio of postives and negatives. + """ + with torch.no_grad(): + matchings = tuple( + matcher(prop_box, tgt_box) + for prop_box, tgt_box in zip(proposal_boxes, target_boxes) + ) + sampling_results = tuple(sampler(matchs) for matchs in matchings) + return ( + [s.sampled_box_indices for s in sampling_results], + [s.sampled_target_indices for s in sampling_results], + [s.sampled_labels for s in sampling_results], + ) diff --git a/vis4d/op/box/samplers/combined.py b/vis4d/op/box/samplers/combined.py new file mode 100644 index 0000000000000000000000000000000000000000..622f7591fd0bcfd9b6205f57ea36f8616a99a1f4 --- /dev/null +++ b/vis4d/op/box/samplers/combined.py @@ -0,0 +1,210 @@ +"""Combined Sampler.""" + +from __future__ import annotations + +import torch +from torch import Tensor + +from vis4d.common.typing import ArgsType + +from ..box2d import non_intersection, random_choice +from ..matchers.base import MatchResult +from .base import Sampler, SamplingResult + + +class CombinedSampler(Sampler): + """Combined sampler. Can have different strategies for pos/neg samples.""" + + def __init__( + self, + *args: ArgsType, + pos_strategy: str, + neg_strategy: str, + neg_pos_ub: float = 3.0, + floor_thr: float = -1.0, + floor_fraction: float = 0.0, + num_bins: int = 3, + bg_label: int = 0, + **kwargs: ArgsType, + ): + """Creates an instance of the class.""" + super().__init__(*args, **kwargs) + self.neg_pos_ub = neg_pos_ub + self.floor_thr = floor_thr + self.floor_fraction = floor_fraction + self.num_bins = num_bins + self.bg_label = bg_label + + if not pos_strategy in { + "instance_balanced", + "iou_balanced", + } or not neg_strategy in {"instance_balanced", "iou_balanced"}: + raise ValueError( + "strategies must be in [instance_balanced, iou_balanced]" + ) + + self.pos_strategy = getattr(self, pos_strategy + "_sampling") + self.neg_strategy = getattr(self, neg_strategy + "_sampling") + + @staticmethod + def instance_balanced_sampling( + idx_tensor: Tensor, + assigned_gts: Tensor, + assigned_gt_ious: Tensor, # pylint: disable=unused-argument + sample_size: int, + ) -> Tensor: + """Sample indices with balancing according to matched GT instance.""" + if idx_tensor.numel() <= sample_size: + return idx_tensor + + unique_gt_inds = assigned_gts.unique() + num_gts = len(unique_gt_inds) + num_per_gt = int(sample_size / float(num_gts)) + sampled_inds_list = [] + # sample specific amount per gt instance + for i in unique_gt_inds: + inds = torch.nonzero(assigned_gts == i, as_tuple=False) + inds = inds.squeeze(1) + if len(inds) > num_per_gt: + inds = random_choice(inds, num_per_gt) + sampled_inds_list.append(inds) + sampled_inds = torch.cat(sampled_inds_list) + + # deal with edge cases + if len(sampled_inds) < sample_size: + num_extra = sample_size - len(sampled_inds) + extra_inds = non_intersection(idx_tensor, sampled_inds) + if len(extra_inds) > num_extra: + extra_inds = random_choice(extra_inds, num_extra) + sampled_inds = torch.cat([sampled_inds, extra_inds]) + return sampled_inds + + def iou_balanced_sampling( + self, + idx_tensor: Tensor, + assigned_gts: Tensor, # pylint: disable=unused-argument + assigned_gt_ious: Tensor, + sample_size: int, + ) -> Tensor: + """Sample indices with balancing according to IoU with matched GT.""" + if idx_tensor.numel() <= sample_size: + return idx_tensor + + # define 'floor' set - set with low iou samples + if self.floor_thr >= 0: + floor_set = idx_tensor[assigned_gt_ious <= self.floor_thr] + iou_sampling_set = idx_tensor[assigned_gt_ious > self.floor_thr] + else: + floor_set = None + iou_sampling_set = idx_tensor[assigned_gt_ious > self.floor_thr] + + num_iou_set_samples = int(sample_size * (1 - self.floor_fraction)) + if len(iou_sampling_set) > num_iou_set_samples: + if self.num_bins >= 2: + iou_sampled_inds = self.sample_within_intervals( + idx_tensor, assigned_gt_ious, num_iou_set_samples + ) + else: + iou_sampled_inds = random_choice( + iou_sampling_set, num_iou_set_samples + ) + else: + iou_sampled_inds = iou_sampling_set # pragma: no cover + + if floor_set is not None: + num_floor_set_samples = sample_size - len(iou_sampled_inds) + if len(floor_set) > num_floor_set_samples: + sampled_floor_inds = random_choice( + floor_set, num_floor_set_samples + ) + else: + sampled_floor_inds = floor_set # pragma: no cover + sampled_inds = torch.cat([sampled_floor_inds, iou_sampled_inds]) + else: + sampled_inds = iou_sampled_inds + + if len(sampled_inds) < sample_size: # pragma: no cover + num_extra = sample_size - len(sampled_inds) + extra_inds = non_intersection(idx_tensor, sampled_inds) + if len(extra_inds) > num_extra: + extra_inds = random_choice(extra_inds, num_extra) + sampled_inds = torch.cat([sampled_inds, extra_inds]) + + return sampled_inds + + def forward(self, matching: MatchResult) -> SamplingResult: + """Sample boxes according to strategies defined in cfg.""" + pos_sample_size = int(self.batch_size * self.positive_fraction) + + positive_mask: Tensor = (matching.assigned_labels != -1) & ( + matching.assigned_labels != self.bg_label + ) + negative_mask = torch.eq(matching.assigned_labels, self.bg_label) + + positive = positive_mask.nonzero()[:, 0] + negative = negative_mask.nonzero()[:, 0] + + num_pos = min(positive.numel(), pos_sample_size) + num_neg = self.batch_size - num_pos + + if self.neg_pos_ub >= 0: + neg_upper_bound = int(self.neg_pos_ub * num_pos) + num_neg = min(num_neg, neg_upper_bound) + + pos_idx = self.pos_strategy( + idx_tensor=positive, + assigned_gts=matching.assigned_gt_indices[positive_mask], + assigned_gt_ious=matching.assigned_gt_iou[positive_mask], + sample_size=num_pos, + ) + + neg_idx = self.neg_strategy( + idx_tensor=negative, + assigned_gts=matching.assigned_gt_indices[negative_mask], + assigned_gt_ious=matching.assigned_gt_iou[negative_mask], + sample_size=num_neg, + ) + sampled_idcs = torch.cat([pos_idx, neg_idx], dim=0) + + return SamplingResult( + sampled_box_indices=sampled_idcs, + sampled_target_indices=matching.assigned_gt_indices[sampled_idcs], + sampled_labels=matching.assigned_labels[sampled_idcs], + ) + + def sample_within_intervals( + self, + idx_tensor: Tensor, + assigned_gt_ious: Tensor, + sample_size: int, + ) -> Tensor: + """Sample according to N iou intervals where N = num bins.""" + floor_thr = max(self.floor_thr, 0.0) + max_iou = assigned_gt_ious.max() + iou_interval = (max_iou - floor_thr) / self.num_bins + per_bin_samples = int(sample_size / self.num_bins) + + sampled_inds_list = [] + for i in range(self.num_bins): + start_iou = floor_thr + i * iou_interval + end_iou = floor_thr + (i + 1) * iou_interval + tmp_set = ( + (start_iou <= assigned_gt_ious) & (assigned_gt_ious < end_iou) + ).nonzero()[:, 0] + if len(tmp_set) > per_bin_samples: + tmp_sampled_set = random_choice( + idx_tensor[tmp_set], per_bin_samples + ) + else: + tmp_sampled_set = idx_tensor[tmp_set] # pragma: no cover + sampled_inds_list.append(tmp_sampled_set) + + sampled_inds = torch.cat(sampled_inds_list) + if len(sampled_inds) < sample_size: + num_extra = sample_size - len(sampled_inds) + extra_inds = non_intersection(idx_tensor, sampled_inds) + if len(extra_inds) > num_extra: + extra_inds = random_choice(extra_inds, num_extra) + sampled_inds = torch.cat([sampled_inds, extra_inds]) + + return sampled_inds diff --git a/vis4d/op/box/samplers/pseudo.py b/vis4d/op/box/samplers/pseudo.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e22431c9f606e442b05a3450f246140f814e16 --- /dev/null +++ b/vis4d/op/box/samplers/pseudo.py @@ -0,0 +1,35 @@ +"""Pseudo Sampler.""" + +from __future__ import annotations + +import torch + +from ..matchers.base import MatchResult +from .base import Sampler, SamplingResult + + +class PseudoSampler(Sampler): + """Pseudo sampler class (does nothing).""" + + def __init__(self) -> None: + """Init.""" + super(Sampler, self).__init__() + + def forward(self, matching: MatchResult) -> SamplingResult: + """Sample boxes randomly.""" + pos_idx, neg_idx = self._sample_labels(matching.assigned_labels) + sampled_idcs = torch.cat([pos_idx, neg_idx], dim=0) + return SamplingResult( + sampled_box_indices=sampled_idcs, + sampled_target_indices=matching.assigned_gt_indices[sampled_idcs], + sampled_labels=matching.assigned_labels[sampled_idcs], + ) + + @staticmethod + def _sample_labels( + labels: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Randomly sample indices from given labels.""" + positive = ((labels != -1) & (labels != 0)).nonzero()[:, 0] + negative = torch.eq(labels, 0).nonzero()[:, 0] + return positive, negative diff --git a/vis4d/op/box/samplers/random.py b/vis4d/op/box/samplers/random.py new file mode 100644 index 0000000000000000000000000000000000000000..d0bc6d202982f3b096e2519e8fdcf45b5b6dc965 --- /dev/null +++ b/vis4d/op/box/samplers/random.py @@ -0,0 +1,63 @@ +"""Random Sampler.""" + +from __future__ import annotations + +import torch + +from vis4d.common.typing import ArgsType + +from ..matchers.base import MatchResult +from .base import Sampler, SamplingResult + + +class RandomSampler(Sampler): + """Random sampler class.""" + + def __init__( + self, + *args: ArgsType, + bg_label: int = 0, + **kwargs: ArgsType, + ): + """Creates an instance of the class.""" + super().__init__(*args, **kwargs) + self.bg_label = bg_label + + def forward( + self, + matching: MatchResult, + ) -> SamplingResult: + """Sample boxes randomly.""" + pos_idx, neg_idx = self._sample_labels(matching.assigned_labels) + sampled_idcs = torch.cat([pos_idx, neg_idx], dim=0) + return SamplingResult( + sampled_box_indices=sampled_idcs, + sampled_target_indices=matching.assigned_gt_indices[sampled_idcs], + sampled_labels=matching.assigned_labels[sampled_idcs], + ) + + def _sample_labels( + self, labels: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Randomly sample indices from given labels.""" + positive = ((labels != -1) & (labels != self.bg_label)).nonzero()[:, 0] + negative = torch.eq(labels, self.bg_label).nonzero()[:, 0] + + num_pos = int(self.batch_size * self.positive_fraction) + # protect against not enough positive examples + num_pos = min(positive.numel(), num_pos) + num_neg = self.batch_size - num_pos + # protect against not enough negative examples + num_neg = min(negative.numel(), num_neg) + + # randomly select positive and negative examples + perm1 = torch.randperm(positive.numel(), device=positive.device)[ + :num_pos + ] + perm2 = torch.randperm(negative.numel(), device=negative.device)[ + :num_neg + ] + + pos_idx = positive[perm1] + neg_idx = negative[perm2] + return pos_idx, neg_idx diff --git a/vis4d/op/detect/__init__.py b/vis4d/op/detect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bddfadc65f29aee0f133843cf877ca84b67eba2 --- /dev/null +++ b/vis4d/op/detect/__init__.py @@ -0,0 +1 @@ +"""Detector module.""" diff --git a/vis4d/op/detect/common.py b/vis4d/op/detect/common.py new file mode 100644 index 0000000000000000000000000000000000000000..377366c121d1cdb7fd819e69cf5c2fd3590aa992 --- /dev/null +++ b/vis4d/op/detect/common.py @@ -0,0 +1,18 @@ +"""Common classes and functions for detection.""" + +from typing import NamedTuple + +from torch import Tensor + + +class DetOut(NamedTuple): + """Output of the detection model. + + boxes (list[Tensor]): 2D bounding boxes of shape [N, 4] in xyxy format. + scores (list[Tensor]): confidence scores of shape [N,]. + class_ids (list[Tensor]): class ids of shape [N,]. + """ + + boxes: list[Tensor] + scores: list[Tensor] + class_ids: list[Tensor] diff --git a/vis4d/op/detect/dense_anchor.py b/vis4d/op/detect/dense_anchor.py new file mode 100644 index 0000000000000000000000000000000000000000..7d8dd92ee6832223d6afabbaffd584dac289f5fd --- /dev/null +++ b/vis4d/op/detect/dense_anchor.py @@ -0,0 +1,347 @@ +"""Dense anchor-based head.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from vis4d.common.typing import TorchLossFunc +from vis4d.op.box.anchor import AnchorGenerator, anchor_inside_image +from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder +from vis4d.op.box.matchers import Matcher +from vis4d.op.box.samplers import Sampler +from vis4d.op.loss.reducer import SumWeightedLoss +from vis4d.op.util import unmap + + +class DetectorTargets(NamedTuple): + """Targets for first-stage detection.""" + + labels: Tensor + label_weights: Tensor + bbox_targets: Tensor + bbox_weights: Tensor + + +def images_to_levels( + targets: list[ + tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]] + ], +) -> list[list[Tensor]]: + """Convert targets by image to targets by feature level.""" + targets_per_level = [] + for lvl_id in range(len(targets[0][0])): + targets_single_level = [] + for tgt_id in range(len(targets[0])): + targets_single_level.append( + torch.stack([tgt[tgt_id][lvl_id] for tgt in targets], 0) + ) + targets_per_level.append(targets_single_level) + return targets_per_level + + +def get_targets_per_image( + target_boxes: Tensor, + anchors: Tensor, + matcher: Matcher, + sampler: Sampler, + box_encoder: DeltaXYWHBBoxEncoder, + image_hw: tuple[int, int], + target_class: Tensor | float = 1.0, + allowed_border: int = 0, +) -> tuple[DetectorTargets, int, int]: + """Get targets per batch element, all scales. + + Args: + target_boxes (Tensor): (N, 4) Tensor of target boxes for a single + image. + anchors (Tensor): (M, 4) box priors + matcher (Matcher): box matcher matching anchors to targets. + sampler (Sampler): box sampler sub-sampling matches. + box_encoder (DeltaXYWHBBoxEncoder): Encodes boxes into target + regression parameters. + image_hw (tuple[int, int]): input image height and width. + target_class (Tensor | float, optional): class label(s) of target + boxes. Defaults to 1.0. + allowed_border (int, optional): Allowed border for sub-sampling anchors + that lie inside the input image. Defaults to 0. + + Returns: + tuple[DetectorTargets, Tensor, Tensor]: Targets, sum of positives, sum + of negatives. + """ + inside_flags = anchor_inside_image( + anchors, image_hw, allowed_border=allowed_border + ) + # assign gt and sample anchors + anchors = anchors[inside_flags, :] + + matching = matcher(anchors, target_boxes) + sampling_result = sampler(matching) + + num_valid_anchors = anchors.size(0) + bbox_targets = torch.zeros_like(anchors) + bbox_weights = torch.zeros_like(anchors) + labels = anchors.new_zeros((num_valid_anchors,)) + label_weights = anchors.new_zeros(num_valid_anchors) + + positives = torch.eq(sampling_result.sampled_labels, 1) + negatives = torch.eq(sampling_result.sampled_labels, 0) + pos_inds = sampling_result.sampled_box_indices[positives] + pos_target_inds = sampling_result.sampled_target_indices[positives] + neg_inds = sampling_result.sampled_box_indices[negatives] + if len(pos_inds) > 0: + pos_bbox_targets = box_encoder( + anchors[pos_inds], target_boxes[pos_target_inds] + ) + bbox_targets[pos_inds] = pos_bbox_targets + bbox_weights[pos_inds] = 1.0 + if isinstance(target_class, float): + labels[pos_inds] = target_class + else: + labels[pos_inds] = target_class[pos_target_inds].float() + label_weights[pos_inds] = 1.0 + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + num_total_anchors = inside_flags.size(0) + labels = unmap(labels, num_total_anchors, inside_flags) + label_weights = unmap(label_weights, num_total_anchors, inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + + return ( + DetectorTargets(labels, label_weights, bbox_targets, bbox_weights), + int(positives.sum()), + int(negatives.sum()), + ) + + +def get_targets_per_batch( + featmap_sizes: list[tuple[int, int]], + target_boxes: list[Tensor], + target_class_ids: list[Tensor | float], + images_hw: list[tuple[int, int]], + anchor_generator: AnchorGenerator, + box_encoder: DeltaXYWHBBoxEncoder, + box_matcher: Matcher, + box_sampler: Sampler, + allowed_border: int = 0, +) -> tuple[list[list[Tensor]], int]: + """Get targets for all batch elements, all scales.""" + device = target_boxes[0].device + + anchor_grids = anchor_generator.grid_priors(featmap_sizes, device=device) + num_level_anchors = [anchors.size(0) for anchors in anchor_grids] + anchors_all_levels = torch.cat(anchor_grids) + + targets: list[ + tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]] + ] = [] + num_total_pos, num_total_neg = 0, 0 + for tgt_box, tgt_cls, image_hw in zip( + target_boxes, target_class_ids, images_hw + ): + target, num_pos, num_neg = get_targets_per_image( + tgt_box, + anchors_all_levels, + box_matcher, + box_sampler, + box_encoder, + image_hw, + tgt_cls, + allowed_border, + ) + num_total_pos += num_pos + num_total_neg += num_neg + bbox_targets_per_level = target.bbox_targets.split(num_level_anchors) + bbox_weights_per_level = target.bbox_weights.split(num_level_anchors) + labels_per_level = target.labels.split(num_level_anchors) + label_weights_per_level = target.label_weights.split(num_level_anchors) + targets.append( + ( + bbox_targets_per_level, + bbox_weights_per_level, + labels_per_level, + label_weights_per_level, + ) + ) + targets_per_level = images_to_levels(targets) + num_samples = num_total_pos + num_total_neg + return targets_per_level, num_samples + + +class DenseAnchorHeadLosses(NamedTuple): + """Dense anchor head loss container.""" + + loss_cls: Tensor + loss_bbox: Tensor + + +class DenseAnchorHeadLoss(nn.Module): + """Loss of dense anchor heads. + + For a given set of multi-scale dense outputs, compute the desired target + outputs and apply classification and regression losses. + The targets are computed with the given target bounding boxes, the + anchor grid defined by the anchor generator and the given box encoder. + """ + + def __init__( + self, + anchor_generator: AnchorGenerator, + box_encoder: DeltaXYWHBBoxEncoder, + box_matcher: Matcher, + box_sampler: Sampler, + loss_cls: TorchLossFunc, + loss_bbox: TorchLossFunc, + allowed_border: int = 0, + ) -> None: + """Creates an instance of the class. + + Args: + anchor_generator (AnchorGenerator): Generates anchor grid priors. + box_encoder (DeltaXYWHBBoxEncoder): Encodes bounding boxes to + the desired network output. + box_matcher (Matcher): Box matcher. + box_sampler (Sampler): Box sampler. + loss_cls (TorchLossFunc): Classification loss. + loss_bbox (TorchLossFunc): Bounding box regression loss. + allowed_border (int): The border to allow the valid anchor. + Defaults to 0. + """ + super().__init__() + self.anchor_generator = anchor_generator + self.box_encoder = box_encoder + self.allowed_border = allowed_border + self.matcher = box_matcher + self.sampler = box_sampler + self.loss_cls = loss_cls + self.loss_bbox = loss_bbox + + def _loss_single_scale( + self, + cls_out: Tensor, + reg_out: Tensor, + bbox_targets: Tensor, + bbox_weights: Tensor, + labels: Tensor, + label_weights: Tensor, + num_total_samples: int, + ) -> tuple[Tensor, Tensor]: + """Compute losses per scale, all batch elements. + + Args: + cls_out (Tensor): [N, C, H, W] tensor of class logits. + reg_out (Tensor): [N, C, H, W] tensor of regression params. + bbox_targets (Tensor): [H * W, 4] bounding box targets + bbox_weights (Tensor): [H * W] per-sample weighting for loss. + labels (Tensor): [H * W] classification targets. + label_weights (Tensor): [H * W] per-sample weighting for loss. + num_total_samples (int): average factor of loss. + + Returns: + tuple[Tensor, Tensor]: classification and regression losses. + """ + # classification loss + labels = labels.reshape(-1) + label_weights = label_weights.reshape(-1) + cls_score = cls_out.permute(0, 2, 3, 1).reshape(labels.size(0), -1) + if cls_score.size(1) > 1: + labels = F.one_hot( # pylint: disable=not-callable + labels.long(), num_classes=cls_score.size(1) + 1 + )[:, : cls_score.size(1)].float() + label_weights = label_weights.repeat(cls_score.size(1)).reshape( + -1, cls_score.size(1) + ) + else: + cls_score = cls_score.squeeze(1) + + loss_cls = self.loss_cls(cls_score, labels, reduction="none") + loss_cls = SumWeightedLoss(label_weights, num_total_samples)(loss_cls) + + # regression loss + bbox_targets = bbox_targets.reshape(-1, 4) + bbox_weights = bbox_weights.reshape(-1, 4) + bbox_pred = reg_out.permute(0, 2, 3, 1).reshape(-1, 4) + + loss_bbox = self.loss_bbox( + pred=bbox_pred, + target=bbox_targets, + reducer=SumWeightedLoss(bbox_weights, num_total_samples), + ) + return loss_cls, loss_bbox + + def forward( + self, + cls_outs: list[Tensor], + reg_outs: list[Tensor], + target_boxes: list[Tensor], + images_hw: list[tuple[int, int]], + target_class_ids: list[Tensor | float] | None = None, + ) -> DenseAnchorHeadLosses: + """Compute RetinaNet classification and regression losses. + + Args: + cls_outs (list[Tensor]): Network classification outputs + at all scales. + reg_outs (list[Tensor]): Network regression outputs + at all scales. + target_boxes (list[Tensor]): Target bounding boxes. + images_hw (list[tuple[int, int]]): Image dimensions without + padding. + target_class_ids (list[Tensor] | None, optional): Target + class labels. + + Returns: + DenseAnchorHeadLosses: Classification and regression losses. + """ + featmap_sizes = [ + (featmap.size()[-2], featmap.size()[-1]) for featmap in cls_outs + ] + assert len(featmap_sizes) == self.anchor_generator.num_levels + if target_class_ids is None: + target_class_ids = [1.0 for _ in range(len(target_boxes))] + + targets_per_level, num_samples = get_targets_per_batch( + featmap_sizes, + target_boxes, + target_class_ids, + images_hw, + self.anchor_generator, + self.box_encoder, + self.matcher, + self.sampler, + self.allowed_border, + ) + + device = cls_outs[0].device + loss_cls_all = torch.tensor(0.0, device=device) + loss_bbox_all = torch.tensor(0.0, device=device) + for level_id, (cls_out, reg_out) in enumerate(zip(cls_outs, reg_outs)): + box_tgt, box_wgt, lbl, lbl_wgt = targets_per_level[level_id] + loss_cls, loss_bbox = self._loss_single_scale( + cls_out, reg_out, box_tgt, box_wgt, lbl, lbl_wgt, num_samples + ) + loss_cls_all += loss_cls + loss_bbox_all += loss_bbox + return DenseAnchorHeadLosses( + loss_cls=loss_cls_all, loss_bbox=loss_bbox_all + ) + + def __call__( + self, + cls_outs: list[Tensor], + reg_outs: list[Tensor], + target_boxes: list[Tensor], + images_hw: list[tuple[int, int]], + target_class_ids: list[Tensor] | None = None, + ) -> DenseAnchorHeadLosses: + """Type definition.""" + return self._call_impl( + cls_outs, reg_outs, target_boxes, images_hw, target_class_ids + ) diff --git a/vis4d/op/detect/faster_rcnn.py b/vis4d/op/detect/faster_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..08062d72bdbc09067fd156c60c213618a0a5e4fa --- /dev/null +++ b/vis4d/op/detect/faster_rcnn.py @@ -0,0 +1,228 @@ +"""Faster RCNN detector.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +from torch import nn + +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.box.box2d import apply_mask +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder +from vis4d.op.box.matchers import Matcher, MaxIoUMatcher +from vis4d.op.box.samplers import ( + RandomSampler, + Sampler, + match_and_sample_proposals, +) + +from .rcnn import RCNNHead, RCNNOut +from .rpn import RPN2RoI, RPNHead, RPNOut +from .typing import Proposals, Targets + + +class FRCNNOut(NamedTuple): + """Faster RCNN function call outputs.""" + + rpn: RPNOut + roi: RCNNOut + proposals: Proposals + sampled_proposals: Proposals | None + sampled_targets: Targets | None + sampled_target_indices: list[torch.Tensor] | None + + +class FasterRCNNHead(nn.Module): + """This class composes RPN and RCNN head components. + + It generates proposals via RPN and samples those, and runs the RCNN head + on the sampled proposals. During training, the sampling process is based + on the GT bounding boxes, during inference it is based on objectness score + of the proposals. + """ + + def __init__( + self, + num_classes: int, + anchor_generator: None | AnchorGenerator = None, + rpn_box_decoder: None | DeltaXYWHBBoxDecoder = None, + box_matcher: None | Matcher = None, + box_sampler: None | Sampler = None, + roi_head: None | RCNNHead = None, + proposal_append_gt: bool = True, + ) -> None: + """Creates an instance of the class. + + Args: + num_classes (int): Number of object categories. + anchor_generator (AnchorGenerator, optional): Custom generator for + RPN. Defaults to None. + rpn_box_decoder (DeltaXYWHBBoxDecoder, optional): Custom rpn box + decoder. Defaults to None. + box_matcher (Matcher, optional): Custom box matcher for RCNN stage. + Defaults to None. + box_sampler (Sampler, optional): Custom box sampler for RCNN stage. + Defaults to None. + roi_head (RCNNHead, optional): Custom ROI head. Defaults to None. + proposal_append_gt (bool): If to append the ground truth boxes for + proposal sampling during training. Defaults to True. + """ + super().__init__() + if anchor_generator is None: + anchor_generator = AnchorGenerator( + scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64] + ) + + self.box_matcher = ( + MaxIoUMatcher( + thresholds=[0.5], + labels=[0, 1], + allow_low_quality_matches=False, + ) + if box_matcher is None + else box_matcher + ) + + self.box_sampler = ( + RandomSampler(batch_size=512, positive_fraction=0.25) + if box_sampler is None + else box_sampler + ) + + self.proposal_append_gt = proposal_append_gt + self.rpn_head = RPNHead(anchor_generator.num_base_priors[0]) + self.rpn2roi = RPN2RoI(anchor_generator, rpn_box_decoder) + + self.roi_head = ( + RCNNHead(num_classes=num_classes) if roi_head is None else roi_head + ) + + @torch.no_grad() + def _sample_proposals( + self, + proposal_boxes: list[torch.Tensor], + scores: list[torch.Tensor], + target_boxes: list[torch.Tensor], + target_classes: list[torch.Tensor], + ) -> tuple[Proposals, Targets, list[torch.Tensor]]: + """Sample proposals for training of Faster RCNN. + + Args: + proposal_boxes (list[torch.Tensor]): Proposals decoded from RPN. + scores (list[torch.Tensor]): Scores decoded from RPN. + target_boxes (list[torch.Tensor]): All target boxes. + target_classes (list[torch.Tensor]): According class labels. + + Returns: + tuple[Proposals, Targets]: Sampled proposals, associated targets. + """ + if self.proposal_append_gt: + proposal_boxes = [ + torch.cat([p, t]) for p, t in zip(proposal_boxes, target_boxes) + ] + scores = [ + torch.cat([s, s.new_ones(len(t))]) + for s, t in zip(scores, target_boxes) + ] + + ( + sampled_box_indices, + sampled_target_indices, + sampled_labels, + ) = match_and_sample_proposals( + self.box_matcher, self.box_sampler, proposal_boxes, target_boxes + ) + + sampled_boxes, sampled_scores = apply_mask( + sampled_box_indices, proposal_boxes, scores + ) + + sampled_target_boxes, sampled_target_classes = apply_mask( + sampled_target_indices, target_boxes, target_classes + ) + + sampled_proposals = Proposals( + boxes=sampled_boxes, scores=sampled_scores + ) + sampled_targets = Targets( + boxes=sampled_target_boxes, + classes=sampled_target_classes, + labels=sampled_labels, + ) + return sampled_proposals, sampled_targets, sampled_target_indices + + def forward( + self, + features: list[torch.Tensor], + images_hw: list[tuple[int, int]], + target_boxes: None | list[torch.Tensor] = None, + target_classes: None | list[torch.Tensor] = None, + ) -> FRCNNOut: + """Faster RCNN forward. + + Args: + features (list[torch.Tensor]): Feature pyramid. + images_hw (list[tuple[int, int]]): Image sizes without padding. + This is necessary for removing the erroneous boxes on the + padded regions. + target_boxes (None | list[torch.Tensor], optional): Ground truth + bounding box locations. Defaults to None. + target_classes (None | list[torch.Tensor], optional): Ground truth + bounding box classes. Defaults to None. + + Returns: + FRCNNReturn: Proposal and RoI outputs. + """ + if target_boxes is not None: + assert target_classes is not None + + rpn_out = self.rpn_head(features) + + if target_boxes is not None: + assert ( + target_classes is not None + ), "Need target classes for target boxes!" + proposal_boxes, scores = self.rpn2roi( + rpn_out.cls, rpn_out.box, images_hw + ) + + ( + sampled_proposals, + sampled_targets, + sampled_target_indices, + ) = self._sample_proposals( + proposal_boxes, scores, target_boxes, target_classes + ) + roi_out = self.roi_head(features, sampled_proposals.boxes) + else: + proposal_boxes, scores = self.rpn2roi( + rpn_out.cls, rpn_out.box, images_hw + ) + sampled_proposals, sampled_targets, sampled_target_indices = ( + None, + None, + None, + ) + roi_out = self.roi_head(features, proposal_boxes) + + return FRCNNOut( + roi=roi_out, + rpn=rpn_out, + proposals=Proposals(proposal_boxes, scores), + sampled_proposals=sampled_proposals, + sampled_targets=sampled_targets, + sampled_target_indices=sampled_target_indices, + ) + + def __call__( + self, + features: list[torch.Tensor], + images_hw: list[tuple[int, int]], + target_boxes: list[torch.Tensor] | None = None, + target_classes: list[torch.Tensor] | None = None, + ) -> FRCNNOut: + """Type definition for call implementation.""" + return self._call_impl( + features, images_hw, target_boxes, target_classes + ) diff --git a/vis4d/op/detect/mask_rcnn.py b/vis4d/op/detect/mask_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a528d8693b9789052af7cada4f8df69459b20c --- /dev/null +++ b/vis4d/op/detect/mask_rcnn.py @@ -0,0 +1,420 @@ +"""Mask RCNN detector.""" + +from __future__ import annotations + +from typing import NamedTuple, Protocol + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torchvision.ops import roi_align + +from vis4d.op.box.box2d import apply_mask +from vis4d.op.box.poolers import MultiScaleRoIAlign +from vis4d.op.mask.util import paste_masks_in_image, remove_overlap + +from .typing import Proposals, Targets + + +class MaskRCNNHeadOut(NamedTuple): + """Mask R-CNN RoI head outputs.""" + + # logits for mask prediction. The dimension is number of masks x number of + # classes x H_mask x W_mask + mask_pred: list[torch.Tensor] + + +class MaskRCNNHead(nn.Module): + """Mask R-CNN RoI head. + + Args: + num_classes (int, optional): Number of classes. Defaults to 80. + num_convs (int, optional): Number of convolution layers. Defaults to 4. + roi_size (tuple[int, int], optional): Size of RoI after pooling. + Defaults to (14, 14). + in_channels (int, optional): Input feature channels. Defaults to 256. + conv_kernel_size (int, optional): Kernel size of convolution. Defaults + to 3. + conv_out_channels (int, optional): Output channels of convolution. + Defaults to 256. + scale_factor (int, optional): Scaling factor of upsampling. Defaults + to 2. + class_agnostic (bool, optional): Whether to do class agnostic mask + prediction. Defaults to False. + """ + + def __init__( + self, + num_classes: int = 80, + num_convs: int = 4, + roi_size: tuple[int, int] = (14, 14), + in_channels: int = 256, + conv_kernel_size: int = 3, + conv_out_channels: int = 256, + scale_factor: int = 2, + class_agnostic: bool = False, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.roi_pooler = MultiScaleRoIAlign( + sampling_ratio=0, resolution=roi_size, strides=[4, 8, 16, 32] + ) + + self.convs = nn.ModuleList() + for i in range(num_convs): + in_channels = in_channels if i == 0 else conv_out_channels + padding = (conv_kernel_size - 1) // 2 + self.convs.append( + nn.Conv2d( + in_channels, + conv_out_channels, + conv_kernel_size, + padding=padding, + ) + ) + + upsample_in_channels = ( + conv_out_channels if num_convs > 0 else in_channels + ) + self.upsample = nn.ConvTranspose2d( + upsample_in_channels, + conv_out_channels, + scale_factor, + stride=scale_factor, + ) + + out_channels = 1 if class_agnostic else num_classes + self.conv_logits = nn.Conv2d(conv_out_channels, out_channels, 1) + self.relu = nn.ReLU(inplace=True) + + self._init_weights(self.convs) + self._init_weights(self.upsample, mode="fan_out") + self._init_weights(self.conv_logits, mode="fan_out") + + @staticmethod + def _init_weights(module: nn.Module, mode: str = "fan_in") -> None: + """Initialize weights.""" + if hasattr(module, "weight") and hasattr(module, "bias"): + assert isinstance(module.weight, torch.Tensor) and isinstance( + module.bias, torch.Tensor + ) + nn.init.kaiming_normal_( + module.weight, mode=mode, nonlinearity="relu" # type: ignore + ) + nn.init.constant_(module.bias, 0) + + def forward( + self, features: list[torch.Tensor], boxes: list[torch.Tensor] + ) -> MaskRCNNHeadOut: + """Forward pass. + + Args: + features (list[torch.Tensor]): Feature pyramid. + boxes (list[torch.Tensor]): Proposal boxes. + + Returns: + MaskRCNNHeadOut: Mask prediction outputs. + """ + # Take stride 4, 8, 16, 32 features + mask_feats = self.roi_pooler(features[2:6], boxes) + for conv in self.convs: + mask_feats = self.relu(conv(mask_feats)) + mask_feats = self.relu(self.upsample(mask_feats)) + mask_pred = self.conv_logits(mask_feats) + num_dets_per_img = tuple(len(d) for d in boxes) + mask_preds = mask_pred.split(num_dets_per_img, 0) + return MaskRCNNHeadOut(mask_pred=mask_preds) + + +class MaskOut(NamedTuple): + """Output of the final detections from Mask RCNN.""" + + masks: list[torch.Tensor] # N, H, W + scores: list[torch.Tensor] + class_ids: list[torch.Tensor] + + +class Det2Mask(nn.Module): + """Post processing of mask predictions. + + Args: + mask_threshold (float, optional): Positive threshold. Defaults to 0.5. + no_overlap (bool, optional): Whether to remove overlapping pixels + between masks. Defaults to False. + """ + + def __init__( + self, mask_threshold: float = 0.5, no_overlap: bool = False + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.mask_threshold = mask_threshold + self.no_overlap = no_overlap + + def forward( + self, + mask_outs: list[torch.Tensor], + det_boxes: list[torch.Tensor], + det_scores: list[torch.Tensor], + det_class_ids: list[torch.Tensor], + original_hw: list[tuple[int, int]], + ) -> MaskOut: + """Paste mask predictions back into original image resolution. + + Args: + mask_outs (list[torch.Tensor]): List of mask outputs for each batch + element. + det_boxes (list[torch.Tensor]): List of detection boxes for each + batch element. + det_scores (list[torch.Tensor]): List of detection scores for each + batch element. + det_class_ids (list[torch.Tensor]): List of detection classeds for + each batch element. + original_hw (list[tuple[int, int]]): Original image resolution. + + Returns: + MaskOut: Post-processed mask predictions. + """ + all_masks = [] + all_scores = [] + all_class_ids = [] + for mask_out, boxes, scores, class_ids, orig_hw in zip( + mask_outs, det_boxes, det_scores, det_class_ids, original_hw + ): + pasted_masks = paste_masks_in_image( + mask_out[torch.arange(len(mask_out)), class_ids], + boxes, + orig_hw[::-1], + self.mask_threshold, + ) + if self.no_overlap: + pasted_masks = remove_overlap(pasted_masks, scores) + all_masks.append(pasted_masks) + all_scores.append(scores) + all_class_ids.append(class_ids) + return MaskOut( + masks=all_masks, scores=all_scores, class_ids=all_class_ids + ) + + def __call__( + self, + mask_outs: list[torch.Tensor], + det_boxes: list[torch.Tensor], + det_scores: list[torch.Tensor], + det_class_ids: list[torch.Tensor], + original_hw: list[tuple[int, int]], + ) -> MaskOut: + """Type definition for function call.""" + return self._call_impl( + mask_outs, det_boxes, det_scores, det_class_ids, original_hw + ) + + +class MaskRCNNHeadLosses(NamedTuple): + """Mask RoI head loss container.""" + + rcnn_loss_mask: torch.Tensor + + +class MaskRCNNHeadLoss(nn.Module): + """Mask RoI head loss function. + + Args: + num_classes (int): number of object categories. + """ + + def __init__(self, num_classes: int) -> None: + """Creates an instance of the class.""" + super().__init__() + self.num_classes = num_classes + + @staticmethod + def _get_targets_per_image( + boxes: Tensor, + tgt_masks: Tensor, + out_shape: tuple[int, int], + binarize: bool = True, + ) -> Tensor: + """Get aligned mask targets for each proposal. + + Args: + boxes (Tensor): proposal boxes. + tgt_masks (Tensor): target masks. + out_shape (tuple[int, int]): output shape. + binarize (bool, optional): whether to convert target mask to + binary. Defaults to True. + + Returns: + Tensor: aligned mask targets. + """ + fake_inds = torch.arange(len(boxes), device=boxes.device)[:, None] + rois = torch.cat([fake_inds, boxes], dim=1) # Nx5 + gt_masks_th = tgt_masks[:, None, :, :].type(rois.dtype) + targets = roi_align( + gt_masks_th, rois, out_shape, 1.0, 0, True + ).squeeze(1) + resized_masks = targets >= 0.5 if binarize else targets + return resized_masks + + def forward( + self, + mask_preds: list[torch.Tensor], + proposal_boxes: list[torch.Tensor], + target_classes: list[torch.Tensor], + target_masks: list[torch.Tensor], + ) -> MaskRCNNHeadLosses: + """Calculate losses of Mask RCNN head. + + Args: + mask_preds (list[torch.Tensor]): [M, C, H', W'] mask outputs per + batch element. + proposal_boxes (list[torch.Tensor]): [M, 4] proposal boxes per + batch element. + target_classes (list[torch.Tensor]): list of [M, 4] assigned + target boxes for each proposal. + target_masks (list[torch.Tensor]): list of [M, H, W] assigned + target masks for each proposal. + + Returns: + MaskRCNNHeadLosses: mask loss. + """ + mask_pred = torch.cat(mask_preds) + mask_size = (mask_pred.shape[2], mask_pred.shape[3]) + # get targets + targets = [] + for boxes, tgt_masks in zip(proposal_boxes, target_masks): + if len(tgt_masks) == 0: + targets.append( + torch.empty((0, *mask_size), device=tgt_masks.device) + ) + else: + targets.append( + self._get_targets_per_image(boxes, tgt_masks, mask_size) + ) + mask_targets = torch.cat(targets) + mask_labels = torch.cat(target_classes) + + if len(mask_targets) > 0: + num_rois = mask_pred.shape[0] + inds = torch.arange( + 0, num_rois, dtype=torch.long, device=mask_pred.device + ) + pred_slice = mask_pred[inds, mask_labels[inds].long()].squeeze(1) + loss_mask = F.binary_cross_entropy_with_logits( + pred_slice, mask_targets.float(), reduction="mean" + ) + else: + loss_mask = mask_targets.sum() + + return MaskRCNNHeadLosses(rcnn_loss_mask=loss_mask) + + +class MaskSampler(Protocol): + """Type definition for mask sampler.""" + + def __call__( + self, + target_masks: list[Tensor], + sampled_target_indices: list[Tensor], + sampled_targets: Targets, + sampled_proposals: Proposals, + ) -> tuple[list[Tensor], list[Tensor], list[Tensor]]: + """Type definition for function call. + + Args: + target_masks (list[Tensor]): list of [N, H, W] target masks per + batch element. + sampled_target_indices (list[Tensor]): list of [M] indices of + sampled targets per batch element. + sampled_targets (Targets): sampled targets. + sampled_proposals (Proposals): sampled proposals. + + Returns: + tuple[list[Tensor], list[Tensor], list[Tensor]]: sampled masks, + sampled target indices, sampled targets. + """ + + +def positive_mask_sampler( + target_masks: list[Tensor], + sampled_target_indices: list[Tensor], + sampled_targets: Targets, + sampled_proposals: Proposals, +) -> tuple[list[Tensor], list[Tensor], list[Tensor]]: + """Sample only positive masks from target masks. + + Args: + target_masks (list[Tensor]): list of [N, H, W] target masks per + batch element. + sampled_target_indices (list[Tensor]): list of [M] indices of + sampled targets per batch element. + sampled_targets (Targets): sampled targets. + sampled_proposals (Proposals): sampled proposals. + + Returns: + tuple[list[Tensor], list[Tensor], list[Tensor]]: sampled masks, + sampled target indices, sampled targets. + """ + sampled_masks = apply_mask(sampled_target_indices, target_masks)[0] + + pos_proposals, pos_classes, pos_mask_targets = apply_mask( + [torch.eq(label, 1) for label in sampled_targets.labels], + sampled_proposals.boxes, + sampled_targets.classes, + sampled_masks, + ) + return pos_proposals, pos_classes, pos_mask_targets + + +class SampledMaskLoss(nn.Module): + """Sampled Mask RCNN head loss function.""" + + def __init__( + self, + mask_sampler: MaskSampler, + loss: MaskRCNNHeadLoss, + ) -> None: + """Initialize sampled mask loss. + + Args: + mask_sampler (MaskSampler): mask sampler. + loss (MaskRCNNHeadLoss): mask loss. + """ + super().__init__() + self.loss = loss + self.mask_sampler = mask_sampler + + def forward( + self, + mask_preds: list[Tensor], + target_masks: list[Tensor], + sampled_target_indices: list[Tensor], + sampled_targets: Targets, + sampled_proposals: Proposals, + ) -> MaskRCNNHeadLosses: + """Calculate losses of Mask RCNN head. + + Args: + mask_preds (list[torch.Tensor]): [M, C, H', W'] mask outputs per + batch element. + target_masks (list[torch.Tensor]): list of [M, H, W] assigned + target masks for each proposal. + sampled_target_indices (list[Tensor]): list of [M, 4] assigned + target boxes for each proposal. + sampled_targets (Targets): list of [M, 4] assigned + target boxes for each proposal. + sampled_proposals (Proposals): list of [M, 4] assigned + target boxes for each proposal. + + Returns: + MaskRCNNHeadLosses: mask loss. + """ + pos_proposals, pos_classes, pos_mask_targets = self.mask_sampler( + target_masks, + sampled_target_indices, + sampled_targets, + sampled_proposals, + ) + return self.loss( + mask_preds, pos_proposals, pos_classes, pos_mask_targets + ) diff --git a/vis4d/op/detect/rcnn.py b/vis4d/op/detect/rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..bb4fa5cbd1f7f57ead00bcf7614741cd4b636982 --- /dev/null +++ b/vis4d/op/detect/rcnn.py @@ -0,0 +1,452 @@ +"""Faster R-CNN RoI head.""" + +from __future__ import annotations + +from math import prod +from typing import NamedTuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from vis4d.common.typing import TorchLossFunc +from vis4d.op.box.box2d import bbox_clip, multiclass_nms +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder +from vis4d.op.box.poolers import MultiScaleRoIAlign +from vis4d.op.detect.common import DetOut +from vis4d.op.layer.conv2d import add_conv_branch +from vis4d.op.layer.weight_init import kaiming_init, normal_init, xavier_init +from vis4d.op.loss.common import l1_loss +from vis4d.op.loss.reducer import SumWeightedLoss + + +class RCNNOut(NamedTuple): + """Faster R-CNN RoI head outputs.""" + + # Logits for box classication. The logit dimension is number of classes + # plus 1 for the background. + cls_score: torch.Tensor + # Each box has regression for all classes. So the tensor dimention is + # [batch_size, number of boxes, number of classes x 4] + bbox_pred: torch.Tensor + + +def get_default_rcnn_box_codec( + target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + target_stds: tuple[float, float, float, float] = (0.1, 0.1, 0.2, 0.2), +) -> tuple[DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder]: + """Get the default bounding box encoder and decoder for RCNN.""" + return ( + DeltaXYWHBBoxEncoder(target_means, target_stds), + DeltaXYWHBBoxDecoder(target_means, target_stds), + ) + + +class RCNNHead(nn.Module): + """Faster R-CNN RoI head. + + This head pools the RoIs from a set of feature maps and processes them + into classification / regression outputs. + + Args: + num_shared_convs (int, optional): number of shared conv layers. + Defaults to 0. + num_shared_fcs (int, optional): number of shared fc layers. Defaults + to 2. + conv_out_channels (int, optional): number of output channels for + shared conv layers. Defaults to 256. + in_channels (int, optional): Number of channels in input feature maps. + Defaults to 256. + fc_out_channels (int, optional): Output channels of shared linear + layers. Defaults to 1024. + num_classes (int, optional): number of categories. Defaults to 80. + roi_size (tuple[int, int], optional): size of pooled RoIs. Defaults + to (7, 7). + """ + + def __init__( + self, + num_shared_convs: int = 0, + num_shared_fcs: int = 2, + conv_out_channels: int = 256, + in_channels: int = 256, + fc_out_channels: int = 1024, + num_classes: int = 80, + roi_size: tuple[int, int] = (7, 7), + start_level: int = 2, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.roi_pooler = MultiScaleRoIAlign( + sampling_ratio=0, resolution=roi_size, strides=[4, 8, 16, 32] + ) + + # Used feature layers are [start_level, end_level) + self.start_level = start_level + self.end_level = start_level + len(self.roi_pooler.scales) + + self.num_shared_convs = num_shared_convs + self.num_shared_fcs = num_shared_fcs + self.conv_out_channels = conv_out_channels + self.fc_out_channels = fc_out_channels + + # add shared convs and fcs + ( + self.shared_convs, + self.shared_fcs, + last_layer_dim, + ) = self._add_conv_fc_branch( + self.num_shared_convs, self.num_shared_fcs, in_channels, True + ) + self.shared_out_channels = last_layer_dim + + in_channels *= prod(roi_size) + + self.fc_cls = nn.Linear( + in_features=fc_out_channels, out_features=num_classes + 1 + ) + self.fc_reg = nn.Linear( + in_features=fc_out_channels, out_features=4 * num_classes + ) + self.relu = nn.ReLU(inplace=True) + + self._init_weights() + + def _add_conv_fc_branch( + self, + num_branch_convs: int = 0, + num_branch_fcs: int = 0, + in_channels: int = 0, + is_shared: bool = False, + ) -> tuple[nn.ModuleList, nn.ModuleList, int]: + """Add shared or separable branch.""" + convs, last_layer_dim = add_conv_branch( + num_branch_convs, + in_channels, + self.conv_out_channels, + True, + None, + None, + ) + + fcs = nn.ModuleList() + if num_branch_fcs > 0: + if is_shared or num_branch_fcs == 0: + last_layer_dim *= int(np.prod(self.roi_pooler.resolution)) + for i in range(num_branch_fcs): + fc_in_dim = last_layer_dim if i == 0 else self.fc_out_channels + fcs.append(nn.Linear(fc_in_dim, self.fc_out_channels)) + return convs, fcs, last_layer_dim + + def _init_weights(self) -> None: + """Init weights.""" + for m in self.shared_convs.modules(): + kaiming_init(m) + + for m in self.shared_fcs.modules(): + xavier_init(m, distribution="uniform") + + normal_init(self.fc_cls, std=0.01) + normal_init(self.fc_reg, std=0.001) + + def forward( + self, features: list[torch.Tensor], boxes: list[torch.Tensor] + ) -> RCNNOut: + """Forward pass during training stage.""" + bbox_feats = self.roi_pooler( + features[self.start_level : self.end_level], boxes + ) + if self.num_shared_convs > 0: + for conv in self.shared_convs: + bbox_feats = conv(bbox_feats) + + bbox_feats = bbox_feats.flatten(start_dim=1) + + for fc in self.shared_fcs: + bbox_feats = self.relu(fc(bbox_feats)) + cls_score = self.fc_cls(bbox_feats) + bbox_pred = self.fc_reg(bbox_feats) + return RCNNOut(cls_score, bbox_pred) + + def __call__( + self, features: list[torch.Tensor], boxes: list[torch.Tensor] + ) -> RCNNOut: + """Type definition for function call.""" + return self._call_impl(features, boxes) + + +class RoI2Det(nn.Module): + """Post processing of RCNN results and detection generation. + + It does the following: + 1. Take the classification and regression outputs from the RCNN heads. + 2. Take the proposal boxes that are RCNN inputs. + 3. Determine the final box classes and take the according box regression + parameters. + 4. Adjust the box sizes and offsets according the regression parameters. + 5. Return the final boxes. + """ + + def __init__( + self, + box_decoder: None | DeltaXYWHBBoxDecoder = None, + score_threshold: float = 0.05, + iou_threshold: float = 0.5, + max_per_img: int = 100, + class_agnostic_nms: bool = False, + ) -> None: + """Creates an instance of the class. + + Args: + box_decoder (DeltaXYWHBBoxDecoder, optional): Decodes regression + parameters to detected boxes. Defaults to None. If None, it + will use the default decoder. + score_threshold (float, optional): Minimum score of a detection. + Defaults to 0.05. + iou_threshold (float, optional): IoU threshold of NMS + post-processing step. Defaults to 0.5. + max_per_img (int, optional): Maximum number of detections per + image. Defaults to 100. + class_agnostic_nms (bool, optional): Whether to use class agnostic + NMS. Defaults to False. + """ + super().__init__() + if box_decoder is None: + _, self.box_decoder = get_default_rcnn_box_codec() + else: + self.box_decoder = box_decoder + self.score_threshold = score_threshold + self.max_per_img = max_per_img + self.iou_threshold = iou_threshold + self.class_agnostic_nms = class_agnostic_nms + + def forward( + self, + class_outs: torch.Tensor, + regression_outs: torch.Tensor, + boxes: list[torch.Tensor], + images_hw: list[tuple[int, int]], + ) -> DetOut: + """Convert RCNN network outputs to detections. + + Args: + class_outs (torch.Tensor): [B, num_classes] batched tensor of + classifiation scores. + regression_outs (torch.Tensor): [B, num_classes * 4] predicted + box offsets. + boxes (list[torch.Tensor]): Initial boxes (RoIs). + images_hw (list[tuple[int, int]]): Image sizes. + + Returns: + DetOut: boxes, scores and class ids of detections per image. + """ + num_proposals_per_img = tuple(len(p) for p in boxes) + regression_outs = regression_outs.split(num_proposals_per_img, 0) + class_outs = class_outs.split(num_proposals_per_img, 0) + all_det_boxes = [] + all_det_scores = [] + all_det_class_ids = [] + for cls_out, reg_out, boxs, image_hw in zip( + class_outs, regression_outs, boxes, images_hw + ): + scores = F.softmax(cls_out, dim=-1) + bboxes = bbox_clip( + self.box_decoder(boxs[:, :4], reg_out).view(-1, 4), + image_hw, + ).view(reg_out.shape) + det_bbox, det_scores, det_label, _ = multiclass_nms( + bboxes, + scores, + self.score_threshold, + self.iou_threshold, + self.max_per_img, + self.class_agnostic_nms, + ) + all_det_boxes.append(det_bbox) + all_det_scores.append(det_scores) + all_det_class_ids.append(det_label) + + return DetOut( + boxes=all_det_boxes, + scores=all_det_scores, + class_ids=all_det_class_ids, + ) + + def __call__( + self, + class_outs: torch.Tensor, + regression_outs: torch.Tensor, + boxes: list[torch.Tensor], + images_hw: list[tuple[int, int]], + ) -> DetOut: + """Type definition for function call.""" + return self._call_impl(class_outs, regression_outs, boxes, images_hw) + + +class RCNNTargets(NamedTuple): + """Target container.""" + + labels: Tensor + label_weights: Tensor + bbox_targets: Tensor + bbox_weights: Tensor + + +class RCNNLosses(NamedTuple): + """RCNN loss container.""" + + rcnn_loss_cls: torch.Tensor + rcnn_loss_bbox: torch.Tensor + + +class RCNNLoss(nn.Module): + """RCNN loss in Faster R-CNN. + + This class computes the loss of RCNN given proposal boxes and their + corresponding target boxes with the given box encoder. + """ + + def __init__( + self, + box_encoder: DeltaXYWHBBoxEncoder, + num_classes: int = 80, + loss_cls: TorchLossFunc = F.cross_entropy, + loss_bbox: TorchLossFunc = l1_loss, + ) -> None: + """Creates an instance of the class. + + Args: + box_encoder (DeltaXYWHBBoxEncoder): Decodes box regression + parameters into detected boxes. + num_classes (int, optional): number of object categories. Defaults + to 80. + loss_cls (TorchLossFunc, optional): Classification loss function. + Defaults to F.cross_entropy. + loss_bbox (TorchLossFunc, optional): Regression loss function. + Defaults to l1_loss. + """ + super().__init__() + self.num_classes = num_classes + self.box_encoder = box_encoder + self.loss_cls = loss_cls + self.loss_bbox = loss_bbox + + def _get_targets_per_image( + self, + boxes: Tensor, + labels: Tensor, + target_boxes: Tensor, + target_classes: Tensor, + ) -> RCNNTargets: + """Generate targets per image. + + Args: + boxes (Tensor): [N, 4] tensor of proposal boxes + labels (Tensor): [N,] tensor of positive / negative / ignore labels + target_boxes (Tensor): [N, 4] Assigned target boxes. + target_classes (Tensor): [N,] Assigned target class labels. + + Returns: + RCNNTargets: Box / class label tensors and weights. + """ + pos_mask, neg_mask = torch.eq(labels, 1), torch.eq(labels, 0) + num_pos, num_neg = int(pos_mask.sum()), int(neg_mask.sum()) + num_samples = num_pos + num_neg + + # original implementation uses new_zeros since BG are set to be 0 + # now use empty & fill because BG cat_id = num_classes, + # FG cat_id = [0, num_classes-1] + labels = boxes.new_full( + (num_samples,), self.num_classes, dtype=torch.long + ) + label_weights = boxes.new_zeros(num_samples) + box_targets = boxes.new_zeros(num_samples, 4) + box_weights = boxes.new_zeros(num_samples, 4) + if num_pos > 0: + pos_target_boxes = target_boxes[pos_mask] + pos_target_classes = target_classes[pos_mask] + labels[:num_pos] = pos_target_classes + label_weights[:num_pos] = 1.0 + pos_box_targets = self.box_encoder( + boxes[pos_mask], pos_target_boxes + ) + box_targets[:num_pos, :] = pos_box_targets + box_weights[:num_pos, :] = 1 + if num_neg > 0: + label_weights[-num_neg:] = 1.0 + return RCNNTargets(labels, label_weights, box_targets, box_weights) + + def forward( + self, + class_outs: torch.Tensor, + regression_outs: torch.Tensor, + boxes: list[torch.Tensor], + boxes_mask: list[torch.Tensor], + target_boxes: list[torch.Tensor], + target_classes: list[torch.Tensor], + ) -> RCNNLosses: + """Calculate losses of RCNN head. + + Args: + class_outs (torch.Tensor): [M*B, num_classes] classification + outputs. + regression_outs (torch.Tensor): Tensor[M*B, regression_params] + regression outputs. + boxes (list[torch.Tensor]): [M, 4] proposal boxes per batch + element. + boxes_mask (list[torch.Tensor]): positive (1), ignore (-1), + negative (0). + target_boxes (list[torch.Tensor]): list of [M, 4] assigned target + boxes for each proposal. + target_classes (list[torch.Tensor]): list of [M,] assigned target + classes for each proposal. + + Returns: + RCNNLosses: classification and regression losses. + """ + # get targets + targets = [] + for boxs, boxs_mask, tgt_boxs, tgt_cls in zip( + boxes, boxes_mask, target_boxes, target_classes + ): + targets.append( + self._get_targets_per_image(boxs, boxs_mask, tgt_boxs, tgt_cls) + ) + + labels = torch.cat([tgt.labels for tgt in targets], 0) + label_weights = torch.cat([tgt.label_weights for tgt in targets], 0) + bbox_targets = torch.cat([tgt.bbox_targets for tgt in targets], 0) + bbox_weights = torch.cat([tgt.bbox_weights for tgt in targets], 0) + + # compute losses + avg_factor = torch.sum(torch.greater(label_weights, 0)).clamp(1.0) + if class_outs.numel() > 0: + loss_cls = SumWeightedLoss(label_weights, avg_factor)( + self.loss_cls(class_outs, labels, reduction="none") + ) + else: + loss_cls = class_outs.sum() + + bg_class_ind = self.num_classes + # 0~self.num_classes-1 are FG, self.num_classes is BG + pos_inds = torch.logical_and( + torch.greater_equal(labels, 0), torch.less(labels, bg_class_ind) + ) + # do not perform bounding box regression for BG anymore. + if pos_inds.any(): + pos_reg_outs = regression_outs.view( + regression_outs.size(0), -1, 4 + )[pos_inds.type(torch.bool), labels[pos_inds.type(torch.bool)]] + loss_bbox = self.loss_bbox( + pred=pos_reg_outs, + target=bbox_targets[pos_inds.type(torch.bool)], + reducer=SumWeightedLoss( + bbox_weights[pos_inds.type(torch.bool)], + bbox_targets.size(0), + ), + ) + else: + loss_bbox = regression_outs[pos_inds].sum() + + return RCNNLosses(rcnn_loss_cls=loss_cls, rcnn_loss_bbox=loss_bbox) diff --git a/vis4d/op/detect/retinanet.py b/vis4d/op/detect/retinanet.py new file mode 100644 index 0000000000000000000000000000000000000000..28e847c56336fb1a62b308e93d8a015a89069481 --- /dev/null +++ b/vis4d/op/detect/retinanet.py @@ -0,0 +1,410 @@ +"""RetinaNet.""" + +from __future__ import annotations + +from math import prod +from typing import NamedTuple + +import torch +from torch import nn +from torchvision.ops import batched_nms, sigmoid_focal_loss + +from vis4d.common.typing import TorchLossFunc +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.box.box2d import bbox_clip, filter_boxes_by_area +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder +from vis4d.op.box.matchers import Matcher, MaxIoUMatcher +from vis4d.op.box.samplers import PseudoSampler, Sampler +from vis4d.op.loss.common import l1_loss + +from .common import DetOut +from .dense_anchor import DenseAnchorHeadLoss + + +class RetinaNetOut(NamedTuple): + """RetinaNet head outputs.""" + + # Logits for box classification for each feature level. The logit + # dimention is [batch_size, number of anchors * number of classes, height, + # width]. + cls_score: list[torch.Tensor] + # Each box has regression for all classes for each feature level. So the + # tensor dimension is [batch_size, number of anchors * 4, height, width]. + bbox_pred: list[torch.Tensor] + + +def get_default_anchor_generator() -> AnchorGenerator: + """Get default anchor generator.""" + return AnchorGenerator( + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128], + ) + + +def get_default_box_codec() -> ( + tuple[DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder] +): + """Get the default bounding box encoder.""" + return ( + DeltaXYWHBBoxEncoder( + target_means=(0.0, 0.0, 0.0, 0.0), target_stds=(1.0, 1.0, 1.0, 1.0) + ), + DeltaXYWHBBoxDecoder( + target_means=(0.0, 0.0, 0.0, 0.0), target_stds=(1.0, 1.0, 1.0, 1.0) + ), + ) + + +def get_default_box_matcher() -> MaxIoUMatcher: + """Get default bounding box matcher.""" + return MaxIoUMatcher( + thresholds=[0.4, 0.5], + labels=[0, -1, 1], + allow_low_quality_matches=True, + ) + + +def get_default_box_sampler() -> PseudoSampler: + """Get default bounding box sampler.""" + return PseudoSampler() + + +class RetinaNetHead(nn.Module): # TODO: Refactor to use the new API + """RetinaNet Head.""" + + def __init__( + self, + num_classes: int, + in_channels: int, + feat_channels: int = 256, + stacked_convs: int = 4, + use_sigmoid_cls: bool = True, + anchor_generator: AnchorGenerator | None = None, + box_decoder: DeltaXYWHBBoxDecoder | None = None, + box_matcher: Matcher | None = None, + box_sampler: Sampler | None = None, + ): + """Creates an instance of the class.""" + super().__init__() + self.anchor_generator = ( + anchor_generator + if anchor_generator is not None + else get_default_anchor_generator() + ) + if box_decoder is None: + _, self.box_decoder = get_default_box_codec() + else: + self.box_decoder = box_decoder + self.box_matcher = ( + box_matcher + if box_matcher is not None + else get_default_box_matcher() + ) + self.box_sampler = ( + box_sampler + if box_sampler is not None + else get_default_box_sampler() + ) + num_base_priors = self.anchor_generator.num_base_priors[0] + + if use_sigmoid_cls: + cls_out_channels = num_classes + else: + cls_out_channels = num_classes + 1 + self.relu = nn.ReLU(inplace=True) + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + for i in range(stacked_convs): + chn = in_channels if i == 0 else feat_channels + self.cls_convs.append( + nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1), + ) + self.reg_convs.append( + nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1), + ) + self.retina_cls = nn.Conv2d( + feat_channels, num_base_priors * cls_out_channels, 3, padding=1 + ) + self.retina_reg = nn.Conv2d( + feat_channels, num_base_priors * 4, 3, padding=1 + ) + + def forward(self, features: list[torch.Tensor]) -> RetinaNetOut: + """Forward pass of RetinaNet. + + Args: + features (list[torch.Tensor]): Feature pyramid + + Returns: + RetinaNetOut: classification score and box prediction. + """ + cls_scores, bbox_preds = [], [] + for feat in features: + cls_feat = feat + reg_feat = feat + for cls_conv in self.cls_convs: + cls_feat = self.relu(cls_conv(cls_feat)) + for reg_conv in self.reg_convs: + reg_feat = self.relu(reg_conv(reg_feat)) + cls_scores.append(self.retina_cls(cls_feat)) + bbox_preds.append(self.retina_reg(reg_feat)) + return RetinaNetOut(cls_score=cls_scores, bbox_pred=bbox_preds) + + def __call__(self, features: list[torch.Tensor]) -> RetinaNetOut: + """Type definition for call implementation.""" + return self._call_impl(features) + + +def get_params_per_level( + cls_out: torch.Tensor, + reg_out: torch.Tensor, + anchors: torch.Tensor, + num_pre_nms: int = 2000, + score_thr: float = 0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get topk params from feature output per level per image before nms. + + Params include flattened classification scores, box energies, and + corresponding anchors. + + Args: + cls_out (torch.Tensor): + [C, H, W] classification scores at a particular scale. + reg_out (torch.Tensor): + [C, H, W] regression parameters at a particular scale. + anchors (torch.Tensor): [H * W, 4] anchor boxes per cell. + num_pre_nms (int): number of predictions before nms. + score_thr (float): score threshold for filtering predictions. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: topk + flattened classification, regression outputs, and corresponding + anchors. + """ + assert cls_out.size()[-2:] == reg_out.size()[-2:], ( + f"Shape mismatch: cls_out({cls_out.size()[-2:]}), reg_out(" + f"{reg_out.size()[-2:]})." + ) + reg_out = reg_out.permute(1, 2, 0).reshape(-1, 4) + cls_out = cls_out.permute(1, 2, 0).reshape(reg_out.size(0), -1).sigmoid() + valid_mask = torch.greater(cls_out, score_thr) + valid_idxs = torch.nonzero(valid_mask) + num_topk = min(num_pre_nms, valid_idxs.size(0)) + cls_out_filt = cls_out[valid_mask] + cls_out_ranked, rank_inds = cls_out_filt.sort(descending=True) + topk_inds = valid_idxs[rank_inds[:num_topk]] + keep_inds, labels = topk_inds.unbind(dim=1) + cls_out = cls_out_ranked[:num_topk] + reg_out = reg_out[keep_inds, :] + anchors = anchors[keep_inds, :] + + return cls_out, labels, reg_out, anchors + + +def decode_multi_level_outputs( + cls_out_all: list[torch.Tensor], + lbl_out_all: list[torch.Tensor], + reg_out_all: list[torch.Tensor], + anchors_all: list[torch.Tensor], + image_hw: tuple[int, int], + box_decoder: DeltaXYWHBBoxDecoder, + max_per_img: int = 1000, + nms_threshold: float = 0.7, + min_box_size: tuple[int, int] = (0, 0), +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Decode box energies into detections for a single image. + + Detections are post-processed via NMS. NMS is performed per level. + Afterwards, select topk detections. + + Args: + cls_out_all (list[torch.Tensor]): topk class scores per level. + lbl_out_all (list[torch.Tensor]): topk class labels per level. + reg_out_all (list[torch.Tensor]): topk regression params per level. + anchors_all (list[torch.Tensor]): topk anchor boxes per level. + image_hw (tuple[int, int]): image size. + box_decoder (DeltaXYWHBBoxDecoder): bounding box encoder. + max_per_img (int, optional): maximum predictions per image. + Defaults to 1000. + nms_threshold (float, optional): iou threshold for NMS. + Defaults to 0.7. + min_box_size (tuple[int, int], optional): minimum box size. + Defaults to (0, 0). + + Returns: + tuple[torch.Tensor, torch.Tensor]: decoded proposal boxes & scores. + """ + scores, labels = torch.cat(cls_out_all), torch.cat(lbl_out_all) + boxes = bbox_clip( + box_decoder(torch.cat(anchors_all), torch.cat(reg_out_all)), + image_hw, + ) + + boxes, mask = filter_boxes_by_area(boxes, min_area=prod(min_box_size)) + scores, labels = scores[mask], labels[mask] + + if boxes.numel() > 0: + keep = batched_nms(boxes, scores, labels, iou_threshold=nms_threshold)[ + :max_per_img + ] + return boxes[keep], scores[keep], labels[keep] + return (boxes.new_zeros(0, 4), scores.new_zeros(0), labels.new_zeros(0)) + + +class Dense2Det(nn.Module): + """Compute detections from dense network outputs. + + This class acts as a stateless functor that does the following: + 1. Create anchor grid for feature grids (classification and regression + outputs) at all scales. + For each image + For each level + 2. Get a topk pre-selection of flattened classification scores and + box energies from feature output before NMS. + 3. Decode class scores and box energies into detection boxes, + apply NMS. + Return detection boxes for all images. + """ + + def __init__( + self, + anchor_generator: AnchorGenerator, + box_decoder: DeltaXYWHBBoxDecoder, + num_pre_nms: int = 2000, + max_per_img: int = 1000, + nms_threshold: float = 0.7, + min_box_size: tuple[int, int] = (0, 0), + score_thr: float = 0.0, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.anchor_generator = anchor_generator + self.box_decoder = box_decoder + self.num_pre_nms = num_pre_nms + self.max_per_img = max_per_img + self.nms_threshold = nms_threshold + self.min_box_size = min_box_size + self.score_thr = score_thr + + def forward( + self, + cls_outs: list[torch.Tensor], + reg_outs: list[torch.Tensor], + images_hw: list[tuple[int, int]], + ) -> DetOut: + """Compute detections from dense network outputs. + + Generate anchor grid for all scales. + For each batch element: + Compute classification, regression, and anchor pairs for all + scales. Decode those pairs into proposals, post-process with NMS. + + Args: + cls_outs (list[torch.Tensor]): [N, C * A, H, W] per scale. + reg_outs (list[torch.Tensor]): [N, 4 * A, H, W] per scale. + images_hw (list[tuple[int, int]]): list of image sizes. + + Returns: + DetOut: Detection outputs. + """ + # since feature map sizes of all images are the same, we only compute + # anchors for one time + device = cls_outs[0].device + featmap_sizes: list[tuple[int, int]] = [ + featmap.size()[-2:] for featmap in cls_outs # type: ignore + ] + assert len(featmap_sizes) == self.anchor_generator.num_levels + anchor_grids = self.anchor_generator.grid_priors( + featmap_sizes, device=device + ) + proposals, scores, labels = [], [], [] + for img_id, image_hw in enumerate(images_hw): + cls_out_all, lbl_out_all, reg_out_all, anchors_all = [], [], [], [] + for cls_out, reg_out, anchor_grid in zip( + cls_outs, reg_outs, anchor_grids + ): + cls_out_, lbl_out, reg_out_, anchors = get_params_per_level( + cls_out[img_id], + reg_out[img_id], + anchor_grid, + self.num_pre_nms, + self.score_thr, + ) + cls_out_all += [cls_out_] + lbl_out_all += [lbl_out] + reg_out_all += [reg_out_] + anchors_all += [anchors] + + box, score, label = decode_multi_level_outputs( + cls_out_all, + lbl_out_all, + reg_out_all, + anchors_all, + image_hw, + self.box_decoder, + self.max_per_img, + self.nms_threshold, + self.min_box_size, + ) + proposals.append(box) + scores.append(score) + labels.append(label) + return DetOut(proposals, scores, labels) + + def __call__( + self, + cls_outs: list[torch.Tensor], + reg_outs: list[torch.Tensor], + images_hw: list[tuple[int, int]], + ) -> DetOut: + """Type definition for function call.""" + return self._call_impl(cls_outs, reg_outs, images_hw) + + +class RetinaNetHeadLoss(DenseAnchorHeadLoss): + """Loss of RetinaNet head.""" + + def __init__( + self, + anchor_generator: AnchorGenerator, + box_encoder: DeltaXYWHBBoxEncoder, + box_matcher: None | Matcher = None, + box_sampler: None | Sampler = None, + loss_cls: TorchLossFunc = sigmoid_focal_loss, + loss_bbox: TorchLossFunc = l1_loss, + ) -> None: + """Creates an instance of the class. + + Args: + anchor_generator (AnchorGenerator): Generates anchor grid priors. + box_encoder (DeltaXYWHBBoxEncoder): Encodes bounding boxes to the + desired network output. + box_matcher (None | Matcher, optional): Box matcher. Defaults to + None. + box_sampler (None | Sampler, optional): Box sampler. Defaults to + None. + loss_cls (TorchLossFunc, optional): Classification loss function. + Defaults to sigmoid_focal_loss. + loss_bbox (TorchLossFunc, optional): Regression loss function. + Defaults to l1_loss. + """ + matcher = ( + box_matcher + if box_matcher is not None + else get_default_box_matcher() + ) + sampler = ( + box_sampler + if box_sampler is not None + else get_default_box_sampler() + ) + super().__init__( + anchor_generator, + box_encoder, + matcher, + sampler, + loss_cls, + loss_bbox, + ) diff --git a/vis4d/op/detect/rpn.py b/vis4d/op/detect/rpn.py new file mode 100644 index 0000000000000000000000000000000000000000..ae646d47b2473588f96a82c25be0db202aa9f13a --- /dev/null +++ b/vis4d/op/detect/rpn.py @@ -0,0 +1,421 @@ +"""Faster RCNN RPN Head.""" + +from __future__ import annotations + +from math import prod +from typing import NamedTuple + +import torch +import torch.nn.functional as F +from torch import nn +from torchvision.ops import batched_nms + +from vis4d.common.typing import TorchLossFunc +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.box.box2d import bbox_clip, filter_boxes_by_area +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder +from vis4d.op.box.matchers import Matcher, MaxIoUMatcher +from vis4d.op.box.samplers import RandomSampler, Sampler +from vis4d.op.layer.conv2d import Conv2d +from vis4d.op.loss.common import l1_loss + +from .dense_anchor import DenseAnchorHeadLoss, DenseAnchorHeadLosses +from .typing import Proposals + + +class RPNOut(NamedTuple): + """Output of RPN head.""" + + # Sigmoid input for binary classification of the anchor + # Positive means there is an object in that anchor. + # Each list item is for on feature pyramid level. + cls: list[torch.Tensor] + # 4 x number of anchors for center offets and sizes (width, height) of the + # boxes under the anchor. + # Each list item is for on feature pyramid level. + box: list[torch.Tensor] + + +def get_default_rpn_box_codec( + target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + target_stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0), +) -> tuple[DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder]: + """Get the default bounding box encoder and decoder for RPN.""" + return ( + DeltaXYWHBBoxEncoder(target_means, target_stds), + DeltaXYWHBBoxDecoder(target_means, target_stds), + ) + + +class RPNHead(nn.Module): + """Faster RCNN RPN Head. + + Creates RPN network output from a multi-scale feature map input. + """ + + rpn_conv: nn.Module + + def __init__( + self, + num_anchors: int, + num_convs: int = 1, + in_channels: int = 256, + feat_channels: int = 256, + start_level: int = 2, + ) -> None: + """Creates an instance of the class. + + Args: + num_anchors (int): Number of anchors per cell. + num_convs (int, optional): Number of conv layers before RPN heads. + Defaults to 1. + in_channels (int, optional): Feature channel size of input feature + maps. Defaults to 256. + feat_channels (int, optional): Feature channel size of conv layers. + Defaults to 256. + start_level (int, optional): starting level of feature maps. + Defaults to 2. + """ + super().__init__() + self.start_level = start_level + + if num_convs > 1: + rpn_convs = [] + for i in range(num_convs): + if i > 0: + in_channels = feat_channels + rpn_convs.append( + Conv2d( + in_channels, + feat_channels, + kernel_size=3, + padding=1, + activation=nn.ReLU(inplace=False), + ) + ) + self.rpn_conv = nn.Sequential(*rpn_convs) + else: + self.rpn_conv = Conv2d( + in_channels, + feat_channels, + kernel_size=3, + padding=1, + activation=nn.ReLU(inplace=True), + ) + self.rpn_cls = Conv2d(feat_channels, num_anchors, 1) + self.rpn_box = Conv2d(feat_channels, num_anchors * 4, 1) + + self.apply(self._init_weights) + + @staticmethod + def _init_weights(module: nn.Module) -> None: + """Init RPN weights.""" + if isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=0.01) + if module.bias is not None: + module.bias.data.zero_() + + def forward(self, features: list[torch.Tensor]) -> RPNOut: + """Forward pass of RPN.""" + cls_outs, box_outs = [], [] + for feat in features[self.start_level :]: + feat = self.rpn_conv(feat) + cls_outs += [self.rpn_cls(feat)] + box_outs += [self.rpn_box(feat)] + return RPNOut(cls=cls_outs, box=box_outs) + + def __call__(self, features: list[torch.Tensor]) -> RPNOut: + """Type definition.""" + return self._call_impl(features) + + +class RPN2RoI(nn.Module): + """Generate Proposals (RoIs) from RPN network output. + + This class acts as a stateless functor that does the following: + 1. Create anchor grid for feature grids (classification and regression + outputs) at all scales. + For each image + For each level + 2. Get a topk pre-selection of flattened classification scores and + box energies from feature output before NMS. + 3. Decode class scores and box energies into proposal boxes, apply NMS. + Return proposal boxes for all images. + """ + + def __init__( + self, + anchor_generator: AnchorGenerator, + box_decoder: None | DeltaXYWHBBoxDecoder = None, + num_proposals_pre_nms_train: int = 2000, + num_proposals_pre_nms_test: int = 1000, + max_per_img: int = 1000, + proposal_nms_threshold: float = 0.7, + min_proposal_size: tuple[int, int] = (0, 0), + ) -> None: + """Creates an instance of the class. + + Args: + anchor_generator (AnchorGenerator): Creates anchor grid serving as + for bounding box regression. + box_decoder (DeltaXYWHBBoxDecoder, optional): decodes box energies + predicted by the network into 2D bounding box parameters. + Defaults to None. If None, uses the default decoder. + num_proposals_pre_nms_train (int, optional): How many boxes are + kept prior to NMS during training. Defaults to 2000. + num_proposals_pre_nms_test (int, optional): How many boxes are + kept prior to NMS during inference. Defaults to 1000. + max_per_img (int, optional): Maximum boxes per image. + Defaults to 1000. + proposal_nms_threshold (float, optional): NMS threshold on proposal + boxes. Defaults to 0.7. + min_proposal_size (tuple[int, int], optional): Minimum size of a + proposal box. Defaults to (0, 0). + """ + super().__init__() + self.anchor_generator = anchor_generator + + if box_decoder is None: + _, self.box_decoder = get_default_rpn_box_codec() + else: + self.box_decoder = box_decoder + + self.max_per_img = max_per_img + self.min_proposal_size = min_proposal_size + self.num_proposals_pre_nms_train = num_proposals_pre_nms_train + self.num_proposals_pre_nms_test = num_proposals_pre_nms_test + self.proposal_nms_threshold = proposal_nms_threshold + + def _get_params_per_level( + self, + cls_out: torch.Tensor, + reg_out: torch.Tensor, + anchors: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Get a topk pre-selection of parameters. + + The parameters include flattened classification scores and box + energies from feature output per level per image before nms. + + Args: + cls_out (torch.Tensor): [C, H, W] classification scores at a + particular scale. + reg_out (torch.Tensor): [C, H, W] regression parameters at a + particular scale. + anchors (torch.Tensor): [H*W, 4] anchor boxes per cell. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Topk flattened + classification, regression outputs and corresponding anchors. + """ + assert cls_out.size()[-2:] == reg_out.size()[-2:], ( + f"Shape mismatch: cls_out({cls_out.size()[-2:]}), reg_out(" + f"{reg_out.size()[-2:]})." + ) + cls_out = cls_out.permute(1, 2, 0).reshape(-1).sigmoid() + reg_out = reg_out.permute(1, 2, 0).reshape(-1, 4) + if self.training: + num_proposals_pre_nms = self.num_proposals_pre_nms_train + else: + num_proposals_pre_nms = self.num_proposals_pre_nms_test + + if 0 < num_proposals_pre_nms < cls_out.shape[0]: + cls_out_ranked, rank_inds = cls_out.sort(descending=True) + topk_inds = rank_inds[:num_proposals_pre_nms] + cls_out = cls_out_ranked[:num_proposals_pre_nms] + reg_out = reg_out[topk_inds, :] + anchors = anchors[topk_inds, :] + + return cls_out, reg_out, anchors + + def _decode_multi_level_outputs( + self, + cls_out_all: list[torch.Tensor], + reg_out_all: list[torch.Tensor], + anchors_all: list[torch.Tensor], + level_all: list[torch.Tensor], + image_hw: tuple[int, int], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Decode box energies into proposals for a single image, post-process. + + Post-processing happens via NMS. NMS is performed per level. + Afterwards, select topk proposals. + + Args: + cls_out_all (list[torch.Tensor]): topk class scores per level. + reg_out_all (list[torch.Tensor]): topk regression params per level. + anchors_all (list[torch.Tensor]): topk anchor boxes per level. + level_all (list[torch.Tensor]): tensors indicating level per entry. + image_hw (tuple[int, int]): image size. + + Returns: + tuple[torch.Tensor, torch.Tensor]: decoded proposal boxes & scores. + """ + scores = torch.cat(cls_out_all) + levels = torch.cat(level_all) + + proposals = bbox_clip( + self.box_decoder(torch.cat(anchors_all), torch.cat(reg_out_all)), + image_hw, + ) + + proposals, mask = filter_boxes_by_area( + proposals, min_area=prod(self.min_proposal_size) + ) + scores = scores[mask] + levels = levels[mask] + + if proposals.numel() > 0: + keep = batched_nms( + proposals, + scores, + levels, + iou_threshold=self.proposal_nms_threshold, + )[: self.max_per_img] + proposals = proposals[keep] + scores = scores[keep] + else: # pragma: no cover + return proposals.new_zeros(0, 4), scores.new_zeros(0) + return proposals, scores + + def forward( + self, + class_outs: list[torch.Tensor], + regression_outs: list[torch.Tensor], + images_hw: list[tuple[int, int]], + ) -> Proposals: + """Compute proposals from RPN network outputs. + + Generate anchor grid for all scales. + For each batch element: + Compute classification, regression and anchor pairs for all scales. + Decode those pairs into proposals, post-process with NMS. + + Args: + class_outs (list[torch.Tensor]): [N, 1 * A, H, W] per scale. + regression_outs (list[torch.Tensor]): [N, 4 * A, H, W] per scale. + images_hw (list[tuple[int, int]]): list of image sizes. + + Returns: + Proposals: proposal boxes and scores. + """ + # since feature map sizes of all images are the same, we only compute + # anchors for one time + device = class_outs[0].device + featmap_sizes: list[tuple[int, int]] = [ + featmap.size()[-2:] for featmap in class_outs # type: ignore + ] + assert len(featmap_sizes) == self.anchor_generator.num_levels + anchor_grids = self.anchor_generator.grid_priors( + featmap_sizes, device=device + ) + proposals, scores = [], [] + for img_id, image_hw in enumerate(images_hw): + cls_out_all, reg_out_all, anchors_all, level_all = [], [], [], [] + for level, (cls_outs, reg_outs, anchor_grid) in enumerate( + zip(class_outs, regression_outs, anchor_grids) + ): + cls_out, reg_out, anchors = self._get_params_per_level( + cls_outs[img_id], reg_outs[img_id], anchor_grid + ) + cls_out_all += [cls_out] + reg_out_all += [reg_out] + anchors_all += [anchors] + level_all += [ + cls_out.new_full((len(cls_out),), level, dtype=torch.long) + ] + + box, score = self._decode_multi_level_outputs( + cls_out_all, reg_out_all, anchors_all, level_all, image_hw + ) + proposals.append(box) + scores.append(score) + return Proposals(proposals, scores) + + +class RPNLosses(NamedTuple): + """RPN loss container.""" + + rpn_loss_cls: torch.Tensor + rpn_loss_bbox: torch.Tensor + + +class RPNLoss(DenseAnchorHeadLoss): + """Loss of region proposal network.""" + + def __init__( + self, + anchor_generator: AnchorGenerator, + box_encoder: DeltaXYWHBBoxEncoder, + matcher: Matcher | None = None, + sampler: Sampler | None = None, + loss_cls: TorchLossFunc = F.binary_cross_entropy_with_logits, + loss_bbox: TorchLossFunc = l1_loss, + ): + """Creates an instance of the class. + + Args: + anchor_generator (AnchorGenerator): Generates anchor grid priors. + box_encoder (DeltaXYWHBBoxEncoder): Encodes bounding boxes to the + desired network output. + matcher (Matcher): Matches ground truth boxes to anchor grid + priors. Defaults to None. If None, uses MaxIoUMatcher. + sampler (Sampler): Samples anchors for training. Defaults to None. + If None, uses RandomSampler. + loss_cls (TorchLossFunc): Classification loss function. Defaults to + F.binary_cross_entropy_with_logits. + loss_bbox (TorchLossFunc): Regression loss function. Defaults to + l1_loss. + """ + matcher = ( + MaxIoUMatcher( + thresholds=[0.3, 0.7], + labels=[0, -1, 1], + allow_low_quality_matches=True, + min_positive_iou=0.3, + ) + if matcher is None + else matcher + ) + + sampler = ( + RandomSampler(batch_size=256, positive_fraction=0.5) + if sampler is None + else sampler + ) + + super().__init__( + anchor_generator, + box_encoder, + matcher, + sampler, + loss_cls, + loss_bbox, + ) + + def forward( + self, + cls_outs: list[torch.Tensor], + reg_outs: list[torch.Tensor], + target_boxes: list[torch.Tensor], + images_hw: list[tuple[int, int]], + target_class_ids: list[torch.Tensor | float] | None = None, + ) -> DenseAnchorHeadLosses: + """Compute RPN classification and regression losses. + + Args: + cls_outs (list[torch.Tensor]): Network classification outputs + at all scales. + reg_outs (list[torch.Tensor]): Network regression outputs + at all scales. + target_boxes (list[torch.Tensor]): Target bounding boxes. + images_hw (list[tuple[int, int]]): Image dimensions + without padding. + target_class_ids (list[torch.Tensor] | None): Target class labels. + + Returns: + DenseAnchorHeadLosses: Classification and regression losses. + """ + return super().forward( + cls_outs, reg_outs, target_boxes, images_hw, target_class_ids + ) diff --git a/vis4d/op/detect/typing.py b/vis4d/op/detect/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9dc21d21b22a2f804e4e538fbe283566db9768 --- /dev/null +++ b/vis4d/op/detect/typing.py @@ -0,0 +1,22 @@ +"""Detect op typing.""" + +from __future__ import annotations + +from typing import NamedTuple + +from torch import Tensor + + +class Proposals(NamedTuple): + """Output structure for 2D bounding box proposals.""" + + boxes: list[Tensor] + scores: list[Tensor] + + +class Targets(NamedTuple): + """Output structure for targets.""" + + boxes: list[Tensor] + classes: list[Tensor] + labels: list[Tensor] diff --git a/vis4d/op/detect/yolox.py b/vis4d/op/detect/yolox.py new file mode 100644 index 0000000000000000000000000000000000000000..269565db94f29ac85ff651d3e2d700b7d22d2e41 --- /dev/null +++ b/vis4d/op/detect/yolox.py @@ -0,0 +1,714 @@ +"""YOLOX detection head. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import math +from collections.abc import Sequence +from typing import NamedTuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torchvision.ops import batched_nms + +from vis4d.common.distributed import reduce_mean +from vis4d.common.typing import TorchLossFunc +from vis4d.op.box.anchor import MlvlPointGenerator +from vis4d.op.box.encoder import YOLOXBBoxDecoder +from vis4d.op.box.matchers import SimOTAMatcher +from vis4d.op.box.samplers import PseudoSampler +from vis4d.op.layer.conv2d import Conv2d +from vis4d.op.layer.weight_init import bias_init_with_prob +from vis4d.op.loss import IoULoss +from vis4d.op.loss.reducer import SumWeightedLoss + +from .common import DetOut + + +class YOLOXOut(NamedTuple): + """YOLOX head outputs.""" + + # Logits for box classification for each feature level. The logit + # dimention is [batch_size, number of classes, height, width]. + cls_score: list[torch.Tensor] + # Each box has regression for all classes for each feature level. So the + # tensor dimension is [batch_size, 4, height, width]. + bbox_pred: list[torch.Tensor] + # Objectness scores for each feature level. The tensor dimension is + # [batch_size, 1, height, width] + objectness: list[torch.Tensor] + + +def get_default_point_generator() -> MlvlPointGenerator: + """Get default point generator.""" + return MlvlPointGenerator(strides=[8, 16, 32], offset=0) + + +class YOLOXHead(nn.Module): + """YOLOX Head. + + Args: + num_classes (int): Number of classes. + in_channels (int): Number of input channels. + feat_channels (int, optional): Number of feature channels. Defaults to + 256. + stacked_convs (int, optional): Number of stacked convolutions. Defaults + to 2. + strides (Sequence[int], optional): Strides for each feature level. + Defaults to (8, 16, 32). + point_generator (MlvlPointGenerator, optional): Point generator. + Defaults to None. + box_decoder (YOLOXBBoxDecoder, optional): Bounding box decoder. + Defaults to None. + box_matcher (Matcher, optional): Bounding box matcher. Defaults to + None. + box_sampler (Sampler, optional): Bounding box sampler. Defaults to + None. + """ + + def __init__( + self, + num_classes: int, + in_channels: int, + feat_channels: int = 256, + stacked_convs: int = 2, + strides: Sequence[int] = (8, 16, 32), + point_generator: MlvlPointGenerator | None = None, + box_decoder: YOLOXBBoxDecoder | None = None, + ): + """Creates an instance of the class.""" + super().__init__() + self.point_generator = ( + point_generator + if point_generator is not None + else get_default_point_generator() + ) + if box_decoder is None: + self.box_decoder = YOLOXBBoxDecoder() + else: + self.box_decoder = box_decoder + + self.multi_level_cls_convs = nn.ModuleList() + self.multi_level_reg_convs = nn.ModuleList() + self.multi_level_conv_cls = nn.ModuleList() + self.multi_level_conv_reg = nn.ModuleList() + self.multi_level_conv_obj = nn.ModuleList() + for _ in strides: + self.multi_level_cls_convs.append( + self._build_stacked_convs( + in_channels, feat_channels, stacked_convs + ) + ) + self.multi_level_reg_convs.append( + self._build_stacked_convs( + in_channels, feat_channels, stacked_convs + ) + ) + conv_cls, conv_reg, conv_obj = self._build_predictor( + feat_channels, num_classes + ) + self.multi_level_conv_cls.append(conv_cls) + self.multi_level_conv_reg.append(conv_reg) + self.multi_level_conv_obj.append(conv_obj) + self._init_weights() + + def _build_stacked_convs( + self, in_channels: int, feat_channels: int, stacked_convs: int + ) -> nn.Module: + """Initialize conv layers of a single level head. + + Args: + in_channels (int): Number of input channels. + feat_channels (int): Number of feature channels. + stacked_convs (int): Number of stacked conv layers. + """ + stacked_conv_layers = [] + for i in range(stacked_convs): + chn = in_channels if i == 0 else feat_channels + stacked_conv_layers.append( + Conv2d( + chn, + feat_channels, + 3, + stride=1, + padding=1, + norm=nn.BatchNorm2d( + feat_channels, eps=0.001, momentum=0.03 + ), + activation=nn.SiLU(inplace=True), + bias=False, + ) + ) + return nn.Sequential(*stacked_conv_layers) + + def _build_predictor( + self, feat_channels: int, num_classes: int + ) -> tuple[nn.Module, nn.Module, nn.Module]: + """Initialize predictor layers of a single level head. + + Args: + feat_channels (int): Number of input channels. + num_classes (int): Number of classes. + """ + conv_cls = nn.Conv2d(feat_channels, num_classes, 1) + conv_reg = nn.Conv2d(feat_channels, 4, 1) + conv_obj = nn.Conv2d(feat_channels, 1, 1) + return conv_cls, conv_reg, conv_obj + + def _init_weights(self) -> None: + """Initialize weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_( + m.weight, + a=math.sqrt(5), + mode="fan_in", + nonlinearity="leaky_relu", + ) + bias_init = bias_init_with_prob(0.01) + for conv_cls, conv_obj in zip( + self.multi_level_conv_cls, self.multi_level_conv_obj + ): + conv_cls.bias.data.fill_(bias_init) # type: ignore + conv_obj.bias.data.fill_(bias_init) # type: ignore + + def forward(self, features: list[torch.Tensor]) -> YOLOXOut: + """Forward pass of YOLOX head. + + Args: + features (list[torch.Tensor]): Input features. + + Returns: + YOLOXOut: Classification, box, and objectness predictions. + """ + cls_score, bbox_pred, objectness = [], [], [] + for feature, cls_conv, reg_conv, conv_cls, conv_reg, conv_obj in zip( + features, + self.multi_level_cls_convs, + self.multi_level_reg_convs, + self.multi_level_conv_cls, + self.multi_level_conv_reg, + self.multi_level_conv_obj, + ): + cls_feat = cls_conv(feature) + reg_feat = reg_conv(feature) + + cls_score.append(conv_cls(cls_feat)) + bbox_pred.append(conv_reg(reg_feat)) + objectness.append(conv_obj(reg_feat)) + return YOLOXOut( + cls_score=cls_score, bbox_pred=bbox_pred, objectness=objectness + ) + + def __call__(self, features: list[torch.Tensor]) -> YOLOXOut: + """Type definition for call implementation.""" + return self._call_impl(features) + + +def bboxes_nms( + cls_scores: torch.Tensor, + bboxes: torch.Tensor, + objectness: torch.Tensor, + nms_threshold: float = 0.65, + score_thr: float = 0.01, + nms_pre: int = -1, + max_per_img: int = -1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Decode box energies into detections for a single image. + + Detections are post-processed via NMS. NMS is performed per level. + Afterwards, select topk detections. + + Args: + cls_scores (torch.Tensor): topk class scores per level. + bboxes (torch.Tensor): topk class labels per level. + objectness (torch.Tensor): topk regression params per level. + nms_threshold (float, optional): iou threshold for NMS. + Defaults to 0.65. + score_thr (float, optional): score threshold to filter detections. + Defaults to 0.01. + nms_pre (int, optional): number of topk results before NMS. + Defaults to -1 (all). + max_per_img (int, optional): number of topk results after NMS. + Defaults to -1 (all). + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: decoded boxes, scores, + and labels. + """ + if nms_pre == -1: + nms_pre = len(cls_scores) + if max_per_img == -1: + max_per_img = len(cls_scores) + max_scores, labels = torch.max(cls_scores, 1) + valid_mask = objectness * max_scores >= score_thr + valid_idxs = valid_mask.nonzero()[:, 0] + num_topk = min(nms_pre, valid_mask.sum()) # type: ignore + + scores, idxs = (max_scores[valid_mask] * objectness[valid_mask]).sort( + descending=True + ) + scores = scores[:num_topk] + topk_idxs = valid_idxs[idxs[:num_topk]] + + bboxes = bboxes[topk_idxs] + labels = labels[topk_idxs] + + if labels.numel() > 0: + keep = batched_nms(bboxes, scores, labels, nms_threshold)[:max_per_img] + return bboxes[keep], scores[keep], labels[keep] + return bboxes.new_zeros(0, 4), scores.new_zeros(0), labels.new_zeros(0) + + +def preprocess_outputs( + cls_outs: list[torch.Tensor], + reg_outs: list[torch.Tensor], + obj_outs: list[torch.Tensor], + images_hw: list[tuple[int, int]], + point_generator: MlvlPointGenerator, + box_decoder: YOLOXBBoxDecoder, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Preprocess model outputs before postprocessing/loss computation. + + Args: + cls_outs (list[torch.Tensor]): [N, C, H, W] per scale. + reg_outs (list[torch.Tensor]): [N, 4, H, W] per scale. + obj_outs (list[torch.Tensor]): [N, 1, H, W] per scale. + images_hw (list[tuple[int, int]]): List of image sizes. + point_generator (MlvlPointGenerator): Point generator. + box_decoder (YOLOXBBoxDecoder): Box decoder. + + Returns: + tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: Flattened outputs. + """ + dtype, device = cls_outs[0].dtype, cls_outs[0].device + num_imgs = len(images_hw) + num_classes = cls_outs[0].shape[1] + featmap_sizes: list[tuple[int, int]] = [ + tuple(featmap.size()[-2:]) for featmap in cls_outs # type: ignore + ] + assert len(featmap_sizes) == point_generator.num_levels + mlvl_points = point_generator.grid_priors( + featmap_sizes, dtype=dtype, device=device, with_stride=True + ) + + # flatten cls_outs, reg_outs and obj_outs + cls_list = [ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, num_classes) + for cls_score in cls_outs + ] + reg_list = [ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + for bbox_pred in reg_outs + ] + obj_list = [ + objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1) + for objectness in obj_outs + ] + + flatten_cls = torch.cat(cls_list, dim=1) + flatten_reg = torch.cat(reg_list, dim=1) + flatten_obj = torch.cat(obj_list, dim=1) + flatten_points = torch.cat(mlvl_points) + + flatten_boxes = box_decoder(flatten_points, flatten_reg) + return flatten_cls, flatten_reg, flatten_obj, flatten_points, flatten_boxes + + +class YOLOXPostprocess(nn.Module): + """Postprocess detections from YOLOX detection head.""" + + def __init__( + self, + point_generator: MlvlPointGenerator, + box_decoder: YOLOXBBoxDecoder, + nms_threshold: float = 0.65, + score_thr: float = 0.01, + nms_pre: int = -1, + max_per_img: int = -1, + ) -> None: + """Creates an instance of the class. + + Args: + point_generator (MlvlPointGenerator): Point generator. + box_decoder (YOLOXBBoxDecoder): Box decoder. + nms_threshold (float, optional): IoU threshold for NMS. Defaults to + 0.65. + score_thr (float, optional): Score threshold to filter detections. + Defaults to 0.01. + nms_pre (int, optional): Number of topk results before NMS. + Defaults to -1 (all). + max_per_img (int, optional): Number of topk results after NMS. + Defaults to -1 (all). + """ + super().__init__() + self.point_generator = point_generator + self.box_decoder = box_decoder + self.nms_threshold = nms_threshold + self.score_thr = score_thr + self.nms_pre = nms_pre + self.max_per_img = max_per_img + + def forward( + self, + cls_outs: list[torch.Tensor], + reg_outs: list[torch.Tensor], + obj_outs: list[torch.Tensor], + images_hw: list[tuple[int, int]], + ) -> DetOut: + """Forward pass. + + Args: + cls_outs (list[torch.Tensor]): [N, C, H, W] per scale. + reg_outs (list[torch.Tensor]): [N, 4, H, W] per scale. + obj_outs (list[torch.Tensor]): [N, 1, H, W] per scale. + images_hw (list[tuple[int, int]]): list of image sizes. + + Returns: + DetOut: Detection outputs. + """ + flatten_cls, _, flatten_obj, _, flatten_boxes = preprocess_outputs( + cls_outs, + reg_outs, + obj_outs, + images_hw, + self.point_generator, + self.box_decoder, + ) + flatten_cls, flatten_obj = flatten_cls.sigmoid(), flatten_obj.sigmoid() + + bbox_list, score_list, label_list = [], [], [] + for img_id, _ in enumerate(images_hw): + bboxes, scores, labels = bboxes_nms( + flatten_cls[img_id], + flatten_boxes[img_id], + flatten_obj[img_id], + nms_threshold=self.nms_threshold, + score_thr=self.score_thr, + nms_pre=self.nms_pre, + max_per_img=self.max_per_img, + ) + bbox_list.append(bboxes) + score_list.append(scores) + label_list.append(labels) + return DetOut(bbox_list, score_list, label_list) + + def __call__( + self, + cls_outs: list[torch.Tensor], + reg_outs: list[torch.Tensor], + obj_outs: list[torch.Tensor], + images_hw: list[tuple[int, int]], + ) -> DetOut: + """Type definition for function call.""" + return self._call_impl(cls_outs, reg_outs, obj_outs, images_hw) + + +class YOLOXHeadLosses(NamedTuple): + """YOLOX head loss container.""" + + loss_cls: Tensor + loss_bbox: Tensor + loss_obj: Tensor + loss_l1: Tensor | None + + +def bbox_xyxy_to_cxcywh(bbox: Tensor) -> Tensor: + """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h). + + Args: + bbox (Tensor): Shape (n, 4) for bboxes. + + Returns: + Tensor: Converted bboxes. + """ + x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1) + bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)] + return torch.cat(bbox_new, dim=-1) + + +def get_l1_target( + bbox_target: Tensor, priors: Tensor, eps: float = 1e-8 +) -> Tensor: + """Convert gt bboxes to center offset and log width height. + + Args: + bbox_target (Tensor): Shape (n, 4) for ground-truth bboxes. + priors (Tensor): Shape (n, 4) for prior boxes. + eps (float, optional): Epsilon for numerical stability. Defaults to + 1e-8. + """ + l1_target = bbox_target.new_zeros((len(bbox_target), 4)) + gt_cxcywh = bbox_xyxy_to_cxcywh(bbox_target) + l1_target[:, :2] = (gt_cxcywh[:, :2] - priors[:, :2]) / priors[:, 2:] + l1_target[:, 2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps) + return l1_target + + +class YOLOXHeadLoss(nn.Module): + """Loss of YOLOX head.""" + + def __init__( + self, + num_classes: int, + point_generator: MlvlPointGenerator | None = None, + box_decoder: YOLOXBBoxDecoder | None = None, + loss_cls: TorchLossFunc = F.binary_cross_entropy_with_logits, + loss_bbox: TorchLossFunc = IoULoss(mode="square", eps=1e-16), + loss_obj: TorchLossFunc = F.binary_cross_entropy_with_logits, + loss_l1: TorchLossFunc | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + num_classes (int): Number of classes. + point_generator (MlvlPointGenerator): Point generator. + box_decoder (YOLOXBBoxDecoder): Box decoder. + loss_cls (TorchLossFunc, optional): Classification loss function. + Defaults to sigmoid_focal_loss. + loss_bbox (TorchLossFunc, optional): Regression loss function. + Defaults to l1_loss. + loss_obj (TorchLossFunc, optional): Objectness loss function. + Defaults to sigmoid_focal_loss. + loss_l1 (TorchLossFunc | None, optional): L1 loss function. + Defaults to None. Only used during the final few epochs. + """ + super().__init__() + self.num_classes = num_classes + self.point_generator = ( + point_generator + if point_generator is not None + else get_default_point_generator() + ) + if box_decoder is None: + self.box_decoder = YOLOXBBoxDecoder() + else: + self.box_decoder = box_decoder + self.box_matcher = SimOTAMatcher() + self.box_sampler = PseudoSampler() + self.loss_cls = loss_cls + self.loss_bbox = loss_bbox + self.loss_obj = loss_obj + self.loss_l1 = loss_l1 + + def _get_target_single( + self, + cls_preds: Tensor, + objectness: Tensor, + priors: Tensor, + decoded_bboxes: Tensor, + gt_bboxes: Tensor, + gt_labels: Tensor, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, int]: + """Compute YOLOX training targets in a single image. + + Args: + cls_preds (Tensor): Classification predictions of one image, + a 2D-Tensor with shape [num_priors, num_classes] + objectness (Tensor): Objectness predictions of one image, + a 1D-Tensor with shape [num_priors] + priors (Tensor): All priors of one image, a 2D-Tensor with shape + [num_priors, 4] in [cx, xy, stride_w, stride_y] format. + decoded_bboxes (Tensor): Decoded bboxes predictions of one image, + a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y, + br_x, br_y] format. + gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor + with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format. + gt_labels (Tensor): Ground truth labels of one image, a Tensor + with shape [num_gts]. + """ + num_priors = priors.size(0) + num_gts = gt_labels.size(0) + gt_bboxes = gt_bboxes.to(decoded_bboxes.dtype) + # No target + if num_gts == 0: + cls_target = cls_preds.new_zeros((0, self.num_classes)) + bbox_target = cls_preds.new_zeros((0, 4)) + l1_target = cls_preds.new_zeros((0, 4)) + obj_target = cls_preds.new_zeros((num_priors, 1)) + foreground_mask = cls_preds.new_zeros(num_priors).bool() + return ( + foreground_mask, + cls_target, + obj_target, + bbox_target, + l1_target, + 0, + ) + + # YOLOX uses center priors with 0.5 offset to assign targets, + # but use center priors without offset to regress bboxes. + offset_priors = torch.cat( + [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1 + ) + + scores = cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid() + match_result = self.box_matcher( + scores.sqrt_(), + offset_priors, + decoded_bboxes, + gt_bboxes, + gt_labels, + ) + sampling_result = self.box_sampler(match_result) + positives = sampling_result.sampled_labels == 1 + pos_inds = sampling_result.sampled_box_indices[positives] + pos_tgt_inds = sampling_result.sampled_target_indices[positives] + num_pos_per_img = pos_inds.size(0) + + pos_ious = match_result.assigned_gt_iou[pos_inds] + # IOU aware classification score + cls_target = F.one_hot( # pylint: disable=not-callable + gt_labels[pos_tgt_inds], self.num_classes + ) * pos_ious.unsqueeze(-1) + obj_target = torch.zeros_like(objectness).unsqueeze(-1) + obj_target[pos_inds] = 1 + bbox_target = gt_bboxes[pos_tgt_inds] + if self.loss_l1 is not None: + l1_target = get_l1_target(bbox_target, priors[pos_inds]) + else: + l1_target = bbox_target.new_zeros((len(bbox_target), 4)) + foreground_mask = torch.zeros_like(objectness).to(torch.bool) + foreground_mask[pos_inds] = 1 + return ( + foreground_mask, + cls_target, + obj_target, + bbox_target, + l1_target, + num_pos_per_img, + ) + + def forward( + self, + cls_outs: list[Tensor], + reg_outs: list[Tensor], + obj_outs: list[Tensor], + target_boxes: list[Tensor], + target_class_ids: list[Tensor], + images_hw: list[tuple[int, int]], + ) -> YOLOXHeadLosses: + """Compute YOLOX classification, regression, and objectness losses. + + Args: + cls_outs (list[Tensor]): Network classification outputs at all + scales. + reg_outs (list[Tensor]): Network regression outputs at all scales. + obj_outs (list[Tensor]): Network objectness outputs at all scales. + target_boxes (list[Tensor]): Target bounding boxes. + images_hw (list[tuple[int, int]]): Image dimensions without + padding. + target_class_ids (list[Tensor]): Target class labels. + + Returns: + YOLOXHeadLosses: YOLOX losses. + """ + ( + flatten_cls, + flatten_reg, + flatten_obj, + flatten_points, + flatten_boxes, + ) = preprocess_outputs( + cls_outs, + reg_outs, + obj_outs, + images_hw, + self.point_generator, + self.box_decoder, + ) + + num_imgs = len(images_hw) + pos_masks_list, cls_targets_list, obj_targets_list = [], [], [] + bbox_targets_list, l1_targets_list, num_fg_imgs_list = [], [], [] + for flat_cls, flat_obj, flat_pts, flat_bxs, tgt_bxs, tgt_cls in zip( + flatten_cls.detach(), + flatten_obj.detach(), + flatten_points.unsqueeze(0).repeat(num_imgs, 1, 1), + flatten_boxes.detach(), + target_boxes, + target_class_ids, + ): + targets = self._get_target_single( + flat_cls, flat_obj, flat_pts, flat_bxs, tgt_bxs, tgt_cls + ) + pos_masks_list.append(targets[0]) + cls_targets_list.append(targets[1]) + obj_targets_list.append(targets[2]) + bbox_targets_list.append(targets[3]) + l1_targets_list.append(targets[4]) + num_fg_imgs_list.append(targets[5]) + + num_pos = torch.tensor( + sum(num_fg_imgs_list), dtype=torch.float, device=flatten_cls.device + ) + num_total_samples: Tensor | float = max( # type: ignore + reduce_mean(num_pos), 1.0 + ) + + pos_masks = torch.cat(pos_masks_list, 0) + cls_targets = torch.cat(cls_targets_list, 0) + obj_targets = torch.cat(obj_targets_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + if self.loss_l1 is not None: + l1_targets = torch.cat(l1_targets_list, 0) + + loss_obj = self.loss_obj( + flatten_obj.view(-1, 1), obj_targets, reduction="none" + ) + loss_obj = SumWeightedLoss(1.0, num_total_samples)(loss_obj) + + if num_pos > 0: + loss_cls = self.loss_cls( + flatten_cls.view(-1, self.num_classes)[pos_masks], + cls_targets, + reduction="none", + ) + loss_cls = SumWeightedLoss(1.0, num_total_samples)(loss_cls) + loss_bbox = self.loss_bbox( + flatten_boxes.view(-1, 4)[pos_masks], bbox_targets + ) + loss_bbox = SumWeightedLoss(5.0, num_total_samples)(loss_bbox) + else: + loss_cls = flatten_cls.sum() * 0 + loss_bbox = flatten_boxes.sum() * 0 + + if self.loss_l1 is not None: + if num_pos > 0: + loss_l1 = self.loss_l1( + flatten_reg.view(-1, 4)[pos_masks], l1_targets + ) + loss_l1 = SumWeightedLoss(1.0, num_total_samples)(loss_l1) + else: + loss_l1 = flatten_reg.sum() * 0 + else: + loss_l1 = None + + return YOLOXHeadLosses( + loss_cls=loss_cls, + loss_bbox=loss_bbox, + loss_obj=loss_obj, + loss_l1=loss_l1, + ) + + def __call__( + self, + cls_outs: list[Tensor], + reg_outs: list[Tensor], + obj_outs: list[Tensor], + target_boxes: list[Tensor], + target_class_ids: list[Tensor], + images_hw: list[tuple[int, int]], + ) -> YOLOXHeadLosses: + """Type definition.""" + return self._call_impl( + cls_outs, + reg_outs, + obj_outs, + target_boxes, + target_class_ids, + images_hw, + ) diff --git a/vis4d/op/detect3d/__init__.py b/vis4d/op/detect3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c523fd6406a8767bafa438e90cc4db159e9ab032 --- /dev/null +++ b/vis4d/op/detect3d/__init__.py @@ -0,0 +1 @@ +"""3D detector module.""" diff --git a/vis4d/op/detect3d/bevformer/__init__.py b/vis4d/op/detect3d/bevformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..efdc356055d2a9619afc603c6748401054239d68 --- /dev/null +++ b/vis4d/op/detect3d/bevformer/__init__.py @@ -0,0 +1,6 @@ +"""BEVFormer ops.""" + +from .bevformer import BEVFormerHead +from .grid_mask import GridMask + +__all__ = ["BEVFormerHead", "GridMask"] diff --git a/vis4d/op/detect3d/bevformer/bevformer.py b/vis4d/op/detect3d/bevformer/bevformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bdcf87a8a592d07d452ee69de3a12c253b20f07b --- /dev/null +++ b/vis4d/op/detect3d/bevformer/bevformer.py @@ -0,0 +1,298 @@ +"""BEVFormer head.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import numpy as np +import torch +from torch import Tensor, nn + +from vis4d.data.const import AxisMode +from vis4d.op.box.box3d import transform_boxes3d +from vis4d.op.box.encoder.bevformer import NMSFreeDecoder +from vis4d.op.geometry.rotation import ( + euler_angles_to_matrix, + matrix_to_quaternion, + rotate_velocities, +) +from vis4d.op.layer.positional_encoding import LearnedPositionalEncoding +from vis4d.op.layer.transformer import get_clones, inverse_sigmoid +from vis4d.op.layer.weight_init import bias_init_with_prob + +from ..common import Detect3DOut +from .transformer import PerceptionTransformer + + +def bbox3d2result( + bbox_list: list[tuple[Tensor, Tensor, Tensor]], lidar2global: Tensor +) -> Detect3DOut: + """Convert BEVFormer detection results to Detect3DOut. + + Args: + bbox_list (list[tuple[Tensor, Tensor, Tensor]): List of bounding boxes, + scores and labels. + lidar2global (Tensor): Lidar to global transformation (B, 4, 4). + + Returns: + Detect3DOut: Detection results. + """ + boxes_3d = [] + velocities = [] + class_ids = [] + scores_3d = [] + for i, (bboxes, scores, labels) in enumerate(bbox_list): + # move boxes from lidar to global coordinate system + yaw = bboxes.new_zeros(bboxes.shape[0], 3) + yaw[:, 2] = bboxes[:, 6] + orientation = matrix_to_quaternion(euler_angles_to_matrix(yaw)) + + boxes3d_lidar = torch.cat([bboxes[:, :6], orientation], dim=1) + boxes_3d.append( + transform_boxes3d( + boxes3d_lidar, lidar2global[i], AxisMode.LIDAR, AxisMode.ROS + ) + ) + + _velocities = bboxes.new_zeros(bboxes.shape[0], 3) + _velocities[:, :2] = bboxes[:, -2:] + velocities.append(rotate_velocities(_velocities, lidar2global[i])) + + class_ids.append(labels) + scores_3d.append(scores) + + return Detect3DOut(boxes_3d, velocities, class_ids, scores_3d) + + +class BEVFormerHead(nn.Module): + """BEVFormer 3D detection head.""" + + def __init__( + self, + num_classes: int = 10, + embed_dims: int = 256, + num_query: int = 900, + transformer: PerceptionTransformer | None = None, + num_reg_fcs: int = 2, + num_cls_fcs: int = 2, + point_cloud_range: Sequence[float] = ( + -51.2, + -51.2, + -5.0, + 51.2, + 51.2, + 3.0, + ), + bev_h: int = 200, + bev_w: int = 200, + ) -> None: + """Initialize BEVFormerHead. + + Args: + num_classes (int, optional): Number of classes. Defaults to 10. + embed_dims (int, optional): Embedding dimensions. Defaults to 256. + num_query (int, optional): Number of queries. Defaults to 900. + transformer (PerceptionTransformer, optional): Transformer. + Defaults to None. If None, a default transformer will be + created. + num_reg_fcs (int, optional): Number of fully connected layers in + regression branch. Defaults to 2. + num_cls_fcs (int, optional): Number of fully connected layers in + classification branch. Defaults to 2. + point_cloud_range (Sequence[float], optional): Point cloud range. + Defaults to (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0). + bev_h (int, optional): BEV height. Defaults to 200. + bev_w (int, optional): BEV width. Defaults to 200. + """ + super().__init__() + self.embed_dims = embed_dims + self.num_reg_fcs = num_reg_fcs + self.bev_h = bev_h + self.bev_w = bev_w + + self.positional_encoding = LearnedPositionalEncoding( + num_feats=embed_dims // 2, row_num_embed=bev_h, col_num_embed=bev_w + ) + + self.cls_out_channels = num_classes + + self.transformer = transformer or PerceptionTransformer( + embed_dims=embed_dims + ) + + self.code_size = 10 + self.num_query = num_query + + self.box_decoder = NMSFreeDecoder( + num_classes=num_classes, + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_num=300, + ) + self.pc_range = list(point_cloud_range) + self.real_w = self.pc_range[3] - self.pc_range[0] + self.real_h = self.pc_range[4] - self.pc_range[1] + self.num_cls_fcs = num_cls_fcs - 1 + + self.code_weights = nn.Parameter( + torch.tensor( + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2], + requires_grad=False, + ), + requires_grad=False, + ) + + self._init_layers() + self._init_weights() + + def _init_layers(self) -> None: + """Initialize classification branch and regression branch of head.""" + cls_branch: list[nn.Module] = [] + for _ in range(self.num_reg_fcs): + cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims)) + cls_branch.append(nn.LayerNorm(self.embed_dims)) + cls_branch.append(nn.ReLU(inplace=True)) + cls_branch.append(nn.Linear(self.embed_dims, self.cls_out_channels)) + fc_cls = nn.Sequential(*cls_branch) + + reg_branch: list[nn.Module] = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(nn.Linear(self.embed_dims, self.code_size)) + fc_reg = nn.Sequential(*reg_branch) + + num_pred = self.transformer.decoder.num_layers + + self.cls_branches = get_clones(fc_cls, num_pred) + self.reg_branches = get_clones(fc_reg, num_pred) + + self.bev_embedding = nn.Embedding( + self.bev_h * self.bev_w, self.embed_dims + ) + self.query_embedding = nn.Embedding( + self.num_query, self.embed_dims * 2 + ) + + def _init_weights(self) -> None: + """Initialize weights.""" + bias_init = bias_init_with_prob(0.01) + for m in self.cls_branches: + nn.init.constant_(m[-1].bias, bias_init) # type: ignore + + def forward( + self, + mlvl_feats: list[Tensor], + can_bus: Tensor, + images_hw: tuple[int, int], + cam_intrinsics: list[Tensor], + cam_extrinsics: list[Tensor], + lidar_extrinsics: Tensor, + prev_bev: Tensor | None = None, + ) -> tuple[Detect3DOut, Tensor]: + """Forward function. + + Args: + mlvl_feats (list[Tensor]): Features from the upstream network, each + is with shape (B, N, C, H, W). + can_bus (Tensor): CAN bus data, with shape (B, 18). + images_hw (tuple[int, int]): Image height and width. + cam_intrinsics (list[Tensor]): Camera intrinsics. + cam_extrinsics (list[Tensor]): Camera extrinsics. + lidar_extrinsics (list[Tensor]): LiDAR extrinsics. + prev_bev (Tensor, optional): Previous BEV feature map, with shape + (B, C, H, W). Defaults to None. + + Returns: + tuple[Detect3DOut, Tensor]: Detection results and BEV feature map. + """ + batch_size = mlvl_feats[0].shape[0] + dtype = mlvl_feats[0].dtype + object_query_embeds = self.query_embedding.weight.to(dtype) + bev_queries = self.bev_embedding.weight.to(dtype) + + bev_mask = bev_queries.new_zeros((batch_size, self.bev_h, self.bev_w)) + bev_pos = self.positional_encoding(bev_mask) + + bev_embed, hs, init_reference, inter_references = self.transformer( + mlvl_feats, + can_bus, + bev_queries, + object_query_embeds, + self.bev_h, + self.bev_w, + images_hw=images_hw, + cam_intrinsics=cam_intrinsics, + cam_extrinsics=cam_extrinsics, + lidar_extrinsics=lidar_extrinsics, + grid_length=(self.real_h / self.bev_h, self.real_w / self.bev_w), + bev_pos=bev_pos, + reg_branches=self.reg_branches, + prev_bev=prev_bev, + ) + + hs = hs.permute(0, 2, 1, 3) + outputs_classes = [] + outputs_coords = [] + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.cls_branches[lvl](hs[lvl]) + outputs_coord = self.reg_branches[lvl](hs[lvl]) + + assert reference.shape[-1] == 3 + outputs_coord[..., 0:2] += reference[..., 0:2] + outputs_coord[..., 0:2] = outputs_coord[..., 0:2].sigmoid() + outputs_coord[..., 4:5] += reference[..., 2:3] + outputs_coord[..., 4:5] = outputs_coord[..., 4:5].sigmoid() + outputs_coord[..., 0:1] = ( + outputs_coord[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) + + self.pc_range[0] + ) + outputs_coord[..., 1:2] = ( + outputs_coord[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) + + self.pc_range[1] + ) + outputs_coord[..., 4:5] = ( + outputs_coord[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) + + self.pc_range[2] + ) + + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + ret_list: list[tuple[Tensor, Tensor, Tensor]] = [] + for cls_scores, bbox_preds in zip( + outputs_classes[-1], outputs_coords[-1] + ): + bboxes, scores, labels = self.box_decoder(cls_scores, bbox_preds) + + # mapping MMDetection3D's coordinate to our LIDAR coordinate + bboxes[:, 6] = -(bboxes[:, 6] + np.pi / 2) + + ret_list.append((bboxes, scores, labels)) + + return bbox3d2result(ret_list, lidar_extrinsics), bev_embed + + def __call__( + self, + mlvl_feats: list[Tensor], + can_bus: Tensor, + images_hw: tuple[int, int], + cam_intrinsics: list[Tensor], + cam_extrinsics: list[Tensor], + lidar_extrinsics: Tensor, + prev_bev: Tensor | None = None, + ) -> tuple[Detect3DOut, Tensor]: + """Type definition.""" + return self._call_impl( + mlvl_feats, + can_bus, + images_hw, + cam_intrinsics, + cam_extrinsics, + lidar_extrinsics, + prev_bev, + ) diff --git a/vis4d/op/detect3d/bevformer/decoder.py b/vis4d/op/detect3d/bevformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d14ed64df2c9f4f33330d1b34580c548bfd2cb --- /dev/null +++ b/vis4d/op/detect3d/bevformer/decoder.py @@ -0,0 +1,415 @@ +"""BEVFormer decoder.""" + +from __future__ import annotations + +import math + +import torch +from torch import Tensor, nn + +from vis4d.op.layer.attention import MultiheadAttention +from vis4d.op.layer.ms_deform_attn import ( + MSDeformAttentionFunction, + is_power_of_2, + ms_deformable_attention_cpu, +) +from vis4d.op.layer.transformer import FFN, inverse_sigmoid +from vis4d.op.layer.weight_init import constant_init, xavier_init + + +class BEVFormerDecoder(nn.Module): + """Implements the decoder in DETR3D transformer.""" + + def __init__( + self, + num_layers: int = 6, + embed_dims: int = 256, + return_intermediate: bool = True, + ) -> None: + """Init. + + Args: + num_layers (int): The number of decoder layers. Default: 6. + embed_dims (int): The embedding dimension. Default: 256. + return_intermediate (bool): Whether to return intermediate + results. Default: True. + """ + super().__init__() + self.num_layers = num_layers + self.return_intermediate = return_intermediate + + self.layers = nn.ModuleList( + [ + (BEVFormerDecoderLayer(embed_dims=embed_dims)) + for _ in range(num_layers) + ] + ) + + def forward( + self, + query: Tensor, + value: Tensor, + reference_points: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + query_pos: Tensor, + reg_branches: list[nn.Module], + ) -> tuple[Tensor, Tensor]: + """Forward function. + + Args: + query (Tensor): Input query with shape (num_query, bs, embed_dims). + value (Tensor): Input value with shape (bs, num_query, embed_dims). + reference_points (Tensor): The reference points of offset. In shape + (bs, num_query, 4) when as_two_stage, otherwise has shape (bs, + num_query, 2). + spatial_shapes (Tensor): The spatial shapes of feature maps. + level_start_index (Tensor): The start index of each level. + query_pos (Tensor): The query position embedding. + reg_branches: (list[nn.Module]): Used for refining the regression + results. + + Returns: + tuple[Tensor, Tensor]: The output of the decoder with reference + points. If return_intermediate is True, the output and + reference points of each layer will be stacked and return. + """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + # BS, NUM_QUERY, NUM_LEVEL, 2 + reference_points_input = reference_points[..., :2].unsqueeze(2) + output = layer( + output, + reference_points=reference_points_input, + value=value, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + query_pos=query_pos, + ) + output = output.permute(1, 0, 2) + + tmp = reg_branches[lid](output) + + assert reference_points.shape[-1] == 3 + new_reference_points = torch.zeros_like(reference_points) + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid( + reference_points[..., :2] + ) + new_reference_points[..., 2:3] = tmp[..., 4:5] + inverse_sigmoid( + reference_points[..., 2:3] + ) + + new_reference_points = new_reference_points.sigmoid() + + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points + ) + + return output, reference_points + + +class BEVFormerDecoderLayer(nn.Module): + """Implements decoder layer in DETR transformer.""" + + def __init__( + self, + embed_dims: int = 256, + feedforward_channels: int = 512, + drop_out: float = 0.1, + ) -> None: + """Init. + + Args: + embed_dims (int): The embedding dimension. + feedforward_channels (int): The hidden dimension of FFNs. + drop_out (float): The dropout rate of FFNs. + """ + super().__init__() + self.attentions = nn.ModuleList() + + self.attentions.append( + MultiheadAttention( + embed_dims=embed_dims, + num_heads=8, + attn_drop=0.1, + proj_drop=0.1, + ) + ) + self.attentions.append( + DecoderCrossAttention(embed_dims=embed_dims, num_levels=1) + ) + + self.embed_dims = embed_dims + + self.ffns = nn.ModuleList() + self.ffns.append( + FFN( + embed_dims=self.embed_dims, + feedforward_channels=feedforward_channels, + dropout=drop_out, + ) + ) + + self.norms = nn.ModuleList() + for _ in range(3): + self.norms.append(nn.LayerNorm(self.embed_dims)) + + def forward( + self, + query: Tensor, + reference_points: Tensor, + value: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + query_pos: Tensor | None = None, + ) -> Tensor: + """Forward. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + reference_points (Tensor): The reference points of offset. In shape + (bs, num_query, 4) when as_two_stage, otherwise has shape (bs, + num_query, 2). + value (Tensor, optional): The input value, has shape (bs, num_keys, + dim). + spatial_shapes (Tensor): The spatial shapes of feature maps. + level_start_index (Tensor): The start index of each level. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be added + to `query` before forward function. Defaults to `None`. + + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). + """ + query = self.attentions[0]( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + ) + + query = self.norms[0](query) + + query = self.attentions[1]( + query=query, + reference_points=reference_points, + value=value, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + query_pos=query_pos, + ) + + query = self.norms[1](query) + + query = self.ffns[0](query) + + query = self.norms[2](query) + + return query + + +class DecoderCrossAttention(nn.Module): + """Custom Multi-Scale Deformable Attention.""" + + def __init__( + self, + embed_dims: int = 256, + num_heads: int = 8, + num_levels: int = 4, + num_points: int = 4, + im2col_step: int = 64, + dropout: float = 0.1, + batch_first: bool = False, + ) -> None: + """Initialization. + + Args: + embed_dims (int): The embedding dimension of Attention. + Default: 256. + num_heads (int): Parallel attention heads. Default: 8. + num_levels (int): The number of feature map used in Attention. + Default: 4. + num_points (int): The number of sampling points for each query in + each head. Default: 4. + im2col_step (int): The step used in image_to_column. + Default: 64. + dropout (float): A Dropout layer on `inp_identity`. + Default: 0.1. + batch_first (bool): Key, Query and Value are shape of (batch, n, + embed_dim) or (n, batch, embed_dim). Default to False. + """ + super().__init__() + if embed_dims % num_heads != 0: + raise ValueError( + f"embed_dims must be divisible by num_heads, " + f"but got {embed_dims} and {num_heads}" + ) + self.dropout = nn.Dropout(dropout) + self.batch_first = batch_first + + is_power_of_2(embed_dims // num_heads) + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.sampling_offsets = nn.Linear( + embed_dims, num_heads * num_levels * num_points * 2 + ) + self.attention_weights = nn.Linear( + embed_dims, num_heads * num_levels * num_points + ) + self.value_proj = nn.Linear(embed_dims, embed_dims) + self.output_proj = nn.Linear(embed_dims, embed_dims) + self.init_weights() + + def init_weights(self) -> None: + """Default initialization for Parameters of Module.""" + constant_init(self.sampling_offsets, 0.0) + thetas = torch.mul( + torch.arange(self.num_heads, dtype=torch.float32), + (2.0 * math.pi / self.num_heads), + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.num_heads, 1, 1, 2) + .repeat(1, self.num_levels, self.num_points, 1) + ) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + + self.sampling_offsets.bias.data = grid_init.view(-1) + constant_init(self.attention_weights, val=0.0, bias=0.0) + xavier_init(self.value_proj, distribution="uniform", bias=0.0) + xavier_init(self.output_proj, distribution="uniform", bias=0.0) + + def forward( # pylint: disable=duplicate-code + self, + query: Tensor, + reference_points: Tensor, + value: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + key_padding_mask: Tensor | None = None, + query_pos: Tensor | None = None, + identity: Tensor | None = None, + ) -> Tensor: + """Forward. + + Args: + query (Tensor): Query of Transformer with shape (num_query, bs, + embed_dims). + reference_points (Tensor): The normalized reference points with + shape (bs, num_query, num_levels, 2), all elements is range in + [0, 1], top-left (0,0), bottom-right (1, 1), including padding + area. or (N, Length_{query}, num_levels, 4), add additional two + dimensions is (w, h) to form reference boxes. + value (Tensor): The value tensor with shape (num_key, bs, + embed_dims). + spatial_shapes (Tensor): Spatial shape of features in + different levels. With shape (num_levels, 2), + last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_key]. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + identity (Tensor): The tensor used for addition, with the + same shape as `query`. Default None. If None, + `query` will be used. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + if identity is None: + identity = query + + if query_pos is not None: + query = query + query_pos + + # change to (bs, num_query ,embed_dims) + if not self.batch_first: + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2 + ) + + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points + ) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view( + bs, num_query, self.num_heads, self.num_levels, self.num_points + ) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1 + ) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets + / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / self.num_points + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + raise ValueError( + f"Last dim of reference_points must be" + f" 2 or 4, but get {reference_points.shape[-1]} instead." + ) + + if torch.cuda.is_available() and value.is_cuda: + output = MSDeformAttentionFunction.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + else: + output = ms_deformable_attention_cpu( + value, spatial_shapes, sampling_locations, attention_weights + ) + + output = self.output_proj(output) + + # (num_query, bs ,embed_dims) + if not self.batch_first: + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity diff --git a/vis4d/op/detect3d/bevformer/encoder.py b/vis4d/op/detect3d/bevformer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ddda533aedea9c0c216bb4ce51e49e58d4fb2949 --- /dev/null +++ b/vis4d/op/detect3d/bevformer/encoder.py @@ -0,0 +1,432 @@ +"""BEVFormer Encoder.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +from torch import Tensor, nn + +from vis4d.op.geometry.transform import inverse_rigid_transform +from vis4d.op.layer.transformer import FFN, get_clones + +from .spatial_cross_attention import SpatialCrossAttention +from .temporal_self_attention import TemporalSelfAttention + + +class BEVFormerEncoder(nn.Module): + """Attention with both self and cross attention.""" + + def __init__( + self, + num_layers: int = 6, + layer: BEVFormerEncoderLayer | None = None, + embed_dims: int = 256, + num_points_in_pillar: int = 4, + point_cloud_range: Sequence[float] = ( + -51.2, + -51.2, + -5.0, + 51.2, + 51.2, + 3.0, + ), + return_intermediate: bool = False, + ) -> None: + """Init. + + Args: + num_layers (int): Number of layers in the encoder. + layer (BEVFormerEncoderLayer, optional): Encoder layer. Defaults to + None. If None, a default layer will be used. + embed_dims (int): Embedding dimension. + num_points_in_pillar (int): Number of points in each pillar. + point_cloud_range (Sequence[float]): Range of the point cloud. + Defaults to (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0). + return_intermediate (bool): Whether to return intermediate outputs. + """ + super().__init__() + self.num_layers = num_layers + self.embed_dims = embed_dims + self.num_points_in_pillar = num_points_in_pillar + self.pc_range = point_cloud_range + self.return_intermediate = return_intermediate + + layer = layer or BEVFormerEncoderLayer(embed_dims=embed_dims) + + self.layers = get_clones(layer, num=self.num_layers) + + self.eps = 1e-5 + + def get_reference_points( + self, + bev_h: int, + bev_w: int, + dim: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> Tensor: + """Get the reference points used in SCA and TSA. + + Args: + bev_h (int): Height of the BEV feature map. + bev_w (int): Width of the BEV feature map. + dim (int): Dimension of the reference points. + batch_size (int): Batch size. + device (torch.device): The device where reference_points should be. + dtype (torch.dtype): The dtype of reference_points. + + Returns: + Tensor: reference points used in decoder, has shape (batch_size, + num_keys, num_levels, dim). + """ + assert dim in {2, 3}, f"Unknown dim {dim}." + # Reference points in 3D space for spatial cross-attention (SCA) + if dim == 3: + height_z = self.pc_range[5] - self.pc_range[2] + zs = ( + torch.linspace( + 0.5, + height_z - 0.5, + self.num_points_in_pillar, + dtype=dtype, + device=device, + ) + .view(-1, 1, 1) + .expand(self.num_points_in_pillar, bev_h, bev_w) + / height_z + ) + xs = ( + torch.linspace( + 0.5, bev_w - 0.5, bev_w, dtype=dtype, device=device + ) + .view(1, 1, bev_w) + .expand(self.num_points_in_pillar, bev_h, bev_w) + / bev_w + ) + ys = ( + torch.linspace( + 0.5, bev_h - 0.5, bev_h, dtype=dtype, device=device + ) + .view(1, bev_h, 1) + .expand(self.num_points_in_pillar, bev_h, bev_w) + / bev_h + ) + ref_3d = torch.stack((xs, ys, zs), -1) + ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1) + ref_3d = ref_3d[None].repeat(batch_size, 1, 1, 1) + return ref_3d + + # Reference points on 2D bev plane for temporal self-attention (TSA) + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, bev_h - 0.5, bev_h, dtype=dtype, device=device + ), + torch.linspace( + 0.5, bev_w - 0.5, bev_w, dtype=dtype, device=device + ), + indexing="ij", + ) + ref_y = ref_y.reshape(-1)[None] / bev_h + ref_x = ref_x.reshape(-1)[None] / bev_w + ref_2d = torch.stack((ref_x, ref_y), -1) + ref_2d = ref_2d.repeat(batch_size, 1, 1).unsqueeze(2) + return ref_2d + + def point_sampling( + self, + reference_points: Tensor, + images_hw: tuple[int, int], + cam_intrinsics: list[Tensor], + cam_extrinsics: list[Tensor], + lidar_extrinsics: Tensor, + ) -> tuple[Tensor, Tensor]: + """Sample points from reference points.""" + lidar2img_list = [] + for i, _cam_intrinsics in enumerate(cam_intrinsics): + viewpad = torch.eye(4, device=_cam_intrinsics.device) + viewpad[:3, :3] = _cam_intrinsics + + lidar2img = ( + viewpad + @ inverse_rigid_transform(cam_extrinsics[i]) + @ lidar_extrinsics + ) + + lidar2img_list.append(lidar2img) + + lidar2img = torch.stack(lidar2img_list, dim=1) # (B, N, 4, 4) + + reference_points = reference_points.clone() + reference_points[..., 0:1] = ( + reference_points[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) + + self.pc_range[0] + ) + reference_points[..., 1:2] = ( + reference_points[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) + + self.pc_range[1] + ) + reference_points[..., 2:3] = ( + reference_points[..., 2:3] * (self.pc_range[5] - self.pc_range[2]) + + self.pc_range[2] + ) + + reference_points = torch.cat( + (reference_points, torch.ones_like(reference_points[..., :1])), -1 + ) + + reference_points = reference_points.permute(1, 0, 2, 3) + d, b, num_query, _ = reference_points.shape + num_cam = lidar2img.size(1) + + reference_points = ( + reference_points.view(d, b, 1, num_query, 4) + .repeat(1, 1, num_cam, 1, 1) + .unsqueeze(-1) + ) + + lidar2img = lidar2img.view(1, b, num_cam, 1, 4, 4).repeat( + d, 1, 1, num_query, 1, 1 + ) + + reference_points_cam = torch.matmul( + lidar2img, reference_points + ).squeeze(-1) + + bev_mask = reference_points_cam[..., 2:3] > self.eps + + reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum( + reference_points_cam[..., 2:3], + torch.mul( + torch.ones_like(reference_points_cam[..., 2:3]), self.eps + ), + ) + + reference_points_cam[..., 0] /= images_hw[1] + reference_points_cam[..., 1] /= images_hw[0] + + bev_mask = ( + bev_mask + & (reference_points_cam[..., 1:2] > 0.0) + & (reference_points_cam[..., 1:2] < 1.0) + & (reference_points_cam[..., 0:1] < 1.0) + & (reference_points_cam[..., 0:1] > 0.0) + ) + + reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4) + bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1) + + return reference_points_cam, bev_mask + + def forward( + self, + bev_query: Tensor, + value: Tensor, + bev_h: int, + bev_w: int, + bev_pos: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + prev_bev: Tensor | None, + shift: Tensor, + images_hw: tuple[int, int], + cam_intrinsics: list[Tensor], + cam_extrinsics: list[Tensor], + lidar_extrinsics: Tensor, + ) -> Tensor: + """Forward. + + Args: + bev_query (Tensor): Input BEV query with shape (num_query, + batch_size, embed_dims). + value (Tensor): Input multi-cameta features with shape (num_cam, + num_value, batch_size, embed_dims). + bev_h (int): BEV height. + bev_w (int): BEV width. + bev_pos (Tensor): BEV positional encoding with shape (batch_size, + embed_dims). + spatial_shapes (Tensor): Spatial shapes of multi-level + features with shape (num_levels, 2). + level_start_index (Tensor): Start index of each level with shape + (num_levels, ). + prev_bev (Tensor | None): Previous BEV features with shape + (batch_size, embed_dims). + shift (Tensor): Shift of each level with shape (num_levels, 2). + images_hw (tuple[int, int]): List of image height and width. + cam_intrinsics (list[Tensor]): List of camera intrinsics. In shape + (num_cam, batch_size, 3, 3) + cam_extrinsics (list[Tensor]): List of camera extrinsics. In shape + (num_cam, batch_size, 4, 4) + lidar_extrinsics (Tensor): LiDAR extrinsics. In shape (batch_size, + 4, 4) + + Returns: + Tensor: Results with shape [batch_size, num_query, embed_dims] + when return_intermediate is False, otherwise it has shape + [num_layers, batch_size, num_query, embed_dims]. + """ + intermediate = [] + + ref_3d = self.get_reference_points( + bev_h, + bev_w, + dim=3, + batch_size=bev_query.size(1), + device=bev_query.device, + dtype=bev_query.dtype, + ) + + ref_2d = self.get_reference_points( + bev_h, + bev_w, + dim=2, + batch_size=bev_query.size(1), + device=bev_query.device, + dtype=bev_query.dtype, + ) + + reference_points_img, bev_mask = self.point_sampling( + ref_3d, + images_hw, + cam_intrinsics, + cam_extrinsics, + lidar_extrinsics, + ) + + shift_ref_2d = ref_2d.clone() + shift_ref_2d += shift[:, None, None, :] + + bev_query = bev_query.permute(1, 0, 2) + bev_pos = bev_pos.permute(1, 0, 2) + + batch_size, len_bev, num_bev_level, _ = ref_2d.shape + if prev_bev is not None: + prev_bev = prev_bev.permute(1, 0, 2) + prev_bev = torch.stack([prev_bev, bev_query], 1).reshape( + batch_size * 2, len_bev, -1 + ) + hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape( + batch_size * 2, len_bev, num_bev_level, 2 + ) + else: + hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape( + batch_size * 2, len_bev, num_bev_level, 2 + ) + + for _, layer in enumerate(self.layers): + output = layer( + bev_query, + value, + bev_pos=bev_pos, + ref_2d=hybird_ref_2d, + bev_h=bev_h, + bev_w=bev_w, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + reference_points_img=reference_points_img, + bev_mask=bev_mask, + prev_bev=prev_bev, + ) + + bev_query = output + + if self.return_intermediate: + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + + +class BEVFormerEncoderLayer(nn.Module): + """BEVFormer encoder layer.""" + + def __init__( + self, + embed_dims: int = 256, + self_attn: TemporalSelfAttention | None = None, + cross_attn: SpatialCrossAttention | None = None, + feedforward_channels: int = 512, + drop_out: float = 0.1, + ) -> None: + """Init.""" + super().__init__() + self.attentions = nn.ModuleList() + + self_attn = self_attn or TemporalSelfAttention( + embed_dims=embed_dims, num_levels=1 + ) + self.attentions.append(self_attn) + + cross_attn = cross_attn or SpatialCrossAttention(embed_dims=embed_dims) + self.attentions.append(cross_attn) + + self.embed_dims = embed_dims + + self.ffns = nn.ModuleList() + self.ffns.append( + FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + dropout=drop_out, + ) + ) + + self.norms = nn.ModuleList() + for _ in range(3): + self.norms.append(nn.LayerNorm(self.embed_dims)) + + def forward( + self, + query: Tensor, + value: Tensor, + bev_pos: Tensor, + ref_2d: Tensor, + bev_h: int, + bev_w: int, + spatial_shapes: Tensor, + level_start_index: Tensor, + reference_points_img: Tensor, + bev_mask: Tensor, + prev_bev: Tensor | None = None, + ) -> Tensor: + """Forward function. + + self_attn -> norm -> cross_attn -> norm -> ffn -> norm + + Returns: + Tensor: forwarded results with shape [num_queries, batch_size, + embed_dims]. + """ + # Temporal self attention + query = self.attentions[0]( + query, + ref_2d, + prev_bev, + spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device), + level_start_index=torch.tensor([0], device=query.device), + query_pos=bev_pos, + ) + + query = self.norms[0](query) + + # Spaital cross attention + query = self.attentions[1]( + query, + reference_points_img, + value, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + bev_mask=bev_mask, + ) + + query = self.norms[1](query) + + # FFN + query = self.ffns[0](query) + + query = self.norms[2](query) + + return query diff --git a/vis4d/op/detect3d/bevformer/grid_mask.py b/vis4d/op/detect3d/bevformer/grid_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..01efd8d91a0b1915b83688db95b16447f2407e37 --- /dev/null +++ b/vis4d/op/detect3d/bevformer/grid_mask.py @@ -0,0 +1,82 @@ +"""Grid mask for BEVFormer.""" + +import numpy as np +import torch +from PIL import Image +from torch import Tensor, nn + + +class GridMask(nn.Module): + """Grid Mask Layer.""" + + def __init__( + self, + use_h: bool, + use_w: bool, + rotate: int = 1, + offset: bool = False, + ratio: float = 0.5, + mode: int = 0, + prob: float = 1.0, + ) -> None: + """Init.""" + super().__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.st_prob = prob + self.prob = prob + + def forward(self, x: Tensor) -> Tensor: + """Forward.""" + if np.random.rand() > self.prob: + return x + + device = x.device + n, c, h, w = x.size() + x = x.view(-1, h, w) + hh = int(1.5 * h) + ww = int(1.5 * w) + d = np.random.randint(2, h) + l = min(max(int(d * self.ratio + 0.5), 1), d - 1) + mask = np.ones((hh, ww), np.float32) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + if self.use_h: + for i in range(hh // d): + s = d * i + st_h + t = min(s + l, hh) + mask[s:t, :] *= 0 + if self.use_w: + for i in range(ww // d): + s = d * i + st_w + t = min(s + l, ww) + mask[:, s:t] *= 0 + + r = np.random.randint(self.rotate) + mask_img = Image.fromarray(np.uint8(mask)) + mask_img = mask_img.rotate(r) + mask = np.asarray(mask_img) + mask = mask[ + (hh - h) // 2 : (hh - h) // 2 + h, + (ww - w) // 2 : (ww - w) // 2 + w, + ] + + mask_tensor = torch.from_numpy(mask).to(x.dtype).to(device) + if self.mode == 1: + mask_tensor = 1 - mask_tensor + mask_tensor = mask_tensor.expand_as(x) + if self.offset: + offset = ( + torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)) + .to(x.dtype) + .to(device) + ) + x = x * mask_tensor + offset * (1 - mask_tensor) + else: + x = x * mask_tensor + + return x.view(n, c, h, w) diff --git a/vis4d/op/detect3d/bevformer/spatial_cross_attention.py b/vis4d/op/detect3d/bevformer/spatial_cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1840d797af164bb254d53bddea3e5b347361ecdc --- /dev/null +++ b/vis4d/op/detect3d/bevformer/spatial_cross_attention.py @@ -0,0 +1,375 @@ +"""Spatial Cross Attention Module for BEVFormer.""" + +from __future__ import annotations + +import math + +import torch +from torch import Tensor, nn + +from vis4d.op.layer.ms_deform_attn import ( + MSDeformAttentionFunction, + is_power_of_2, + ms_deformable_attention_cpu, +) +from vis4d.op.layer.weight_init import constant_init, xavier_init + + +class SpatialCrossAttention(nn.Module): + """An attention module used in BEVFormer.""" + + def __init__( + self, + embed_dims: int = 256, + num_cams: int = 6, + dropout: float = 0.1, + deformable_attention: MSDeformableAttention3D | None = None, + ) -> None: + """Init. + + Args: + embed_dims (int): The embedding dimension of Attention. Default: + 256. + num_cams (int): The number of cameras. Default: 6. + dropout (float): A Dropout layer on `inp_residual`. Default: 0.1. + deformable_attention (MSDeformableAttention3D, optional): + The deformable attention module. Default: None. If None, + we will use `MSDeformableAttention3D` with default + parameters. + """ + super().__init__() + self.dropout = nn.Dropout(dropout) + self.deformable_attention = ( + deformable_attention or MSDeformableAttention3D() + ) + self.embed_dims = embed_dims + self.num_cams = num_cams + self.output_proj = nn.Linear(embed_dims, embed_dims) + self.init_weight() + + def init_weight(self) -> None: + """Default initialization for Parameters of Module.""" + xavier_init(self.output_proj, distribution="uniform", bias=0.0) + + def forward( + self, + query: Tensor, + reference_points: Tensor, + value: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + bev_mask: Tensor, + query_pos: Tensor | None = None, + ) -> Tensor: + """Forward Function of Detr3DCrossAtten. + + Args: + query (Tensor): Query of Transformer with shape + (num_query, bs, embed_dims). + reference_points (Tensor): The normalized reference points with + shape (bs, num_query, 4), all elements is range in [0, 1], + top-left (0,0), bottom-right (1, 1), including padding area. + Or (N, Length_{query}, num_levels, 4), add additional two + dimensions is (w, h) to form reference boxes. + value (Tensor): The value tensor with shape `(num_key, bs, + embed_dims)`. (B, N, C, H, W) + spatial_shapes (Tensor): Spatial shape of features in different + level. With shape (num_levels, 2), last dimension represent + (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + bev_mask (Tensor): The mask of BEV features with shape + (num_query, bs, num_levels, h, w). + query_pos (Tensor): The positional encoding for `query`. Default + None. + + Returns: + Tensor: Forwarded results with shape [num_query, bs, embed_dims]. + """ + inp_residual = query + slots = torch.zeros_like(query) + + if query_pos is not None: + query = query + query_pos + + bs = query.shape[0] + d = reference_points.shape[3] + + indexes = [] + for i, mask_per_img in enumerate(bev_mask): + index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1) + indexes.append(index_query_per_img) + max_len = max(len(each) for each in indexes) + + # Each camera only interacts with its corresponding BEV queries. + # This step can greatly save GPU memory. + queries_rebatch = query.new_zeros( + [bs, self.num_cams, max_len, self.embed_dims] + ) + reference_points_rebatch = reference_points.new_zeros( + [bs, self.num_cams, max_len, d, 2] + ) + + for j in range(bs): + for i, _reference_points in enumerate(reference_points): + index_query_per_img = indexes[i] + queries_rebatch[j, i, : len(index_query_per_img)] = query[ + j, index_query_per_img + ] + reference_points_rebatch[j, i, : len(index_query_per_img)] = ( + _reference_points[j, index_query_per_img] + ) + + _, l, bs, _ = value.shape + + value = value.permute(2, 0, 1, 3).reshape( + bs * self.num_cams, l, self.embed_dims + ) + + queries = self.deformable_attention( + query=queries_rebatch.view( + bs * self.num_cams, max_len, self.embed_dims + ), + reference_points=reference_points_rebatch.view( + bs * self.num_cams, max_len, d, 2 + ), + value=value, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + ).view(bs, self.num_cams, max_len, self.embed_dims) + + for j in range(bs): + for i, index_query_per_img in enumerate(indexes): + slots[j, index_query_per_img] += queries[ + j, i, : len(index_query_per_img) + ] + + count = bev_mask.sum(-1) > 0 + count = count.permute(1, 2, 0).sum(-1) + count = torch.clamp(count, min=1.0) + slots = slots / count[..., None] + slots = self.output_proj(slots) + + return self.dropout(slots) + inp_residual + + +class MSDeformableAttention3D(nn.Module): + """An attention module used in BEVFormer based on Deformable-Detr. + + `Deformable DETR: Deformable Transformers for End-to-End Object Detection. + `_. + """ + + def __init__( + self, + embed_dims: int = 256, + num_heads: int = 8, + num_levels: int = 4, + num_points: int = 8, + im2col_step: int = 64, + batch_first: bool = True, + ) -> None: + """Init. + + Args: + embed_dims (int): The embedding dimension of Attention. Default: + 256. + num_heads (int): Parallel attention heads. Default: 64. + num_levels (int): The number of feature map used in + Attention. Default: 4. + num_points (int): The number of sampling points for each query in + each head. Default: 4. + im2col_step (int): The step used in image_to_column. + Default: 64. + batch_first (bool): Key, Query and Value are shape of (batch, n, + embed_dim) or (n, batch, embed_dim). Default to True. + """ + super().__init__() + if embed_dims % num_heads != 0: + raise ValueError( + f"embed_dims must be divisible by num_heads, " + f"but got {embed_dims} and {num_heads}" + ) + + self.batch_first = batch_first + + is_power_of_2(embed_dims // num_heads) + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.sampling_offsets = nn.Linear( + embed_dims, num_heads * num_levels * num_points * 2 + ) + self.attention_weights = nn.Linear( + embed_dims, num_heads * num_levels * num_points + ) + self.value_proj = nn.Linear(embed_dims, embed_dims) + + self.init_weights() + + def init_weights(self) -> None: + """Default initialization for Parameters of Module.""" + constant_init(self.sampling_offsets, 0.0) + thetas = torch.mul( + torch.arange(self.num_heads, dtype=torch.float32), + (2.0 * math.pi / self.num_heads), + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.num_heads, 1, 1, 2) + .repeat(1, self.num_levels, self.num_points, 1) + ) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + + self.sampling_offsets.bias.data = grid_init.view(-1) + constant_init(self.attention_weights, val=0.0, bias=0.0) + xavier_init(self.value_proj, distribution="uniform", bias=0.0) + + def forward( # pylint: disable=duplicate-code + self, + query: Tensor, + reference_points: Tensor, + value: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + key_padding_mask: Tensor | None = None, + query_pos: Tensor | None = None, + ) -> Tensor: + """Forward. + + Args: + query (Tensor): Query of Transformer with shape (bs, num_query, + embed_dims). + reference_points (Tensor): The normalized reference points with + shape (bs, num_query, num_levels, 2), all elements is range in + [0, 1], top-left (0,0), bottom-right (1, 1), including padding + area. Or (N, Length_{query}, num_levels, 4), add additional two + dimensions is (w, h) to form reference boxes. + value (Tensor): The value tensor with shape `(bs, num_key, + embed_dims)`. + spatial_shapes (Tensor): Spatial shape of features in different + levels. With shape (num_levels, 2), last dimension represents + (h, w). + level_start_index (Tensor): The start index of each level. A tensor + has shape ``(num_levels, )`` and can be represented as [0, + h_0*w_0, h_0*w_0+h_1*w_1, ...]. + key_padding_mask (Tensor): ByteTensor for value, with shape [bs, + num_key]. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + if query_pos is not None: + query = query + query_pos + + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2 + ) + + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points + ) + + attention_weights = attention_weights.softmax(-1) + + # bs, num_query, num_heads, num_levels, num_all_points + attention_weights = attention_weights.view( + bs, num_query, self.num_heads, self.num_levels, self.num_points + ) + + # For each BEV query, it owns `num_z_anchors` in 3D space that + # having different heights. After proejcting, each BEV query has + # `num_z_anchors` reference points in each 2D image. For each + # referent point, we sample `num_points` sampling points. + # For `num_z_anchors` reference points, it has overall `num_points + # * num_z_anchors` sampling points. + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1 + ) + + bs, num_query, num_z_anchors, xy = reference_points.shape + reference_points = reference_points[:, :, None, None, None, :, :] + sampling_offsets = ( + sampling_offsets + / offset_normalizer[None, None, None, :, None, :] + ) + ( + bs, + num_query, + num_heads, + num_levels, + num_all_points, + xy, + ) = sampling_offsets.shape + sampling_offsets = sampling_offsets.view( + bs, + num_query, + num_heads, + num_levels, + num_all_points // num_z_anchors, + num_z_anchors, + xy, + ) + sampling_locations = reference_points + sampling_offsets + ( + bs, + num_query, + num_heads, + num_levels, + num_points, + num_z_anchors, + xy, + ) = sampling_locations.shape + assert num_all_points == num_points * num_z_anchors + + # bs, num_query, num_heads, num_levels, num_all_points, 2 + sampling_locations = sampling_locations.view( + bs, num_query, num_heads, num_levels, num_all_points, xy + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 , but get " + + f"{reference_points.shape[-1]} instead." + ) + + if torch.cuda.is_available() and value.is_cuda: + output = MSDeformAttentionFunction.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + else: + output = ms_deformable_attention_cpu( + value, spatial_shapes, sampling_locations, attention_weights + ) + + if not self.batch_first: + output = output.permute(1, 0, 2) + + return output diff --git a/vis4d/op/detect3d/bevformer/temporal_self_attention.py b/vis4d/op/detect3d/bevformer/temporal_self_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..00fafdead8979eb39386b53c4fa2520dab174433 --- /dev/null +++ b/vis4d/op/detect3d/bevformer/temporal_self_attention.py @@ -0,0 +1,285 @@ +"""An attention module used in BEVFormer based on Deformable-Detr.""" + +from __future__ import annotations + +import math + +import torch +from torch import Tensor, nn + +from vis4d.op.layer.ms_deform_attn import ( + MSDeformAttentionFunction, + is_power_of_2, + ms_deformable_attention_cpu, +) +from vis4d.op.layer.weight_init import constant_init, xavier_init + + +class TemporalSelfAttention(nn.Module): + """Temperal Self Attention.""" + + def __init__( + self, + embed_dims: int = 256, + num_heads: int = 8, + num_levels: int = 4, + num_points: int = 4, + num_bev_queue: int = 2, + im2col_step: int = 64, + dropout: float = 0.1, + batch_first: bool = True, + ) -> None: + """Init. + + Args: + embed_dims (int): The embedding dimension of Attention. Default: + 256. + num_heads (int): Parallel attention heads. Default: 64. + num_levels (int): The number of feature map used in Attention. + Default: 4. + num_points (int): The number of sampling points for each query in + each head. Default: 4. + num_bev_queue (int): In this version, we only use one history BEV + and one currenct BEV. The length of BEV queue is 2. + im2col_step (int): The step used in image_to_column. Default: 64. + dropout (float): A Dropout layer on `inp_identity`. Default: 0.1. + batch_first (bool): Key, Query and Value are shape of (batch, n, + embed_dim) or (n, batch, embed_dim). Default to True. + """ + super().__init__() + if embed_dims % num_heads != 0: + raise ValueError( + f"embed_dims must be divisible by num_heads, " + f"but got {embed_dims} and {num_heads}" + ) + + is_power_of_2(embed_dims // num_heads) + + self.dropout = nn.Dropout(dropout) + self.batch_first = batch_first + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.num_bev_queue = num_bev_queue + self.sampling_offsets = nn.Linear( + embed_dims * self.num_bev_queue, + num_bev_queue * num_heads * num_levels * num_points * 2, + ) + self.attention_weights = nn.Linear( + embed_dims * self.num_bev_queue, + num_bev_queue * num_heads * num_levels * num_points, + ) + self.value_proj = nn.Linear(embed_dims, embed_dims) + self.output_proj = nn.Linear(embed_dims, embed_dims) + self.init_weights() + + def init_weights(self) -> None: + """Default initialization for Parameters of Module.""" + constant_init(self.sampling_offsets, 0.0) + thetas = torch.mul( + torch.arange(self.num_heads, dtype=torch.float32), + (2.0 * math.pi / self.num_heads), + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.num_heads, 1, 1, 2) + .repeat( + 1, self.num_levels * self.num_bev_queue, self.num_points, 1 + ) + ) + + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + + self.sampling_offsets.bias.data = grid_init.view(-1) + constant_init(self.attention_weights, val=0.0, bias=0.0) + xavier_init(self.value_proj, distribution="uniform", bias=0.0) + xavier_init(self.output_proj, distribution="uniform", bias=0.0) + + def forward( + self, + query: Tensor, + reference_points: Tensor, + value: Tensor | None, + spatial_shapes: Tensor, + level_start_index: Tensor, + key_padding_mask: Tensor | None = None, + identity: Tensor | None = None, + query_pos: Tensor | None = None, + ) -> Tensor: + """Forward Function of MultiScaleDeformAttention. + + Args: + query (Tensor): Query of Transformer with shape (num_query, bs, + embed_dims). + reference_points (Tensor): The normalized reference points with + shape (bs, num_query, num_levels, 2), all elements is range in + [0, 1], top-left (0,0), bottom-right (1, 1), including padding + area. or (N, Length_{query}, num_levels, 4), add additional two + dimensions is (w, h) to form reference boxes. + value (Tensor): The value tensor with shape (num_key, bs, + embed_dims). + spatial_shapes (Tensor): Spatial shape of features in different + levels. With shape (num_levels, 2), last dimension represents + (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + key_padding_mask (Tensor): ByteTensor for value, with shape [bs, + num_key]. + identity (Tensor): The tensor used for addition, with the + same shape as query. Default None. If None, query will be used. + query_pos (Tensor, optional): The positional encoding for query. + Default: None. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + if value is None: + assert self.batch_first + bs, len_bev, c = query.shape + value = torch.stack([query, query], 1).reshape(bs * 2, len_bev, c) + + if identity is None: + identity = query + + if query_pos is not None: + query = query + query_pos + + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, embed_dims = query.shape + _, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + assert self.num_bev_queue == 2 + + query = torch.cat([value[:bs], query], -1) + value = self.value_proj(value) + assert isinstance(value, Tensor) + + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + + value = value.reshape( + bs * self.num_bev_queue, num_value, self.num_heads, -1 + ) + + sampling_offsets = self.sampling_offsets(query) + sampling_offsets = sampling_offsets.view( + bs, + num_query, + self.num_heads, + self.num_bev_queue, + self.num_levels, + self.num_points, + 2, + ) + attention_weights = self.attention_weights(query).view( + bs, + num_query, + self.num_heads, + self.num_bev_queue, + self.num_levels * self.num_points, + ) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view( + bs, + num_query, + self.num_heads, + self.num_bev_queue, + self.num_levels, + self.num_points, + ) + + attention_weights = ( + attention_weights.permute(0, 3, 1, 2, 4, 5) + .reshape( + bs * self.num_bev_queue, + num_query, + self.num_heads, + self.num_levels, + self.num_points, + ) + .contiguous() + ) + + sampling_offsets = sampling_offsets.permute( + 0, 3, 1, 2, 4, 5, 6 + ).reshape( + bs * self.num_bev_queue, + num_query, + self.num_heads, + self.num_levels, + self.num_points, + 2, + ) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1 + ) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets + / offset_normalizer[None, None, None, :, None, :] + ) + + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / self.num_points + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + raise ValueError( + f"Last dim of reference_points must be" + f" 2 or 4, but get {reference_points.shape[-1]} instead." + ) + + if torch.cuda.is_available() and value.is_cuda: + output = MSDeformAttentionFunction.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + else: + output = ms_deformable_attention_cpu( + value, + spatial_shapes, + sampling_locations, + attention_weights, + ) + + # output shape (bs*num_bev_queue, num_query, embed_dims) + # (bs*num_bev_queue, num_query, embed_dims) + # -> (num_query, embed_dims, bs*num_bev_queue) + output = output.permute(1, 2, 0) + + # fuse history value and current value + # (num_query, embed_dims, bs*num_bev_queue) + # -> (num_query, embed_dims, bs, num_bev_queue) + output = output.view(num_query, embed_dims, bs, self.num_bev_queue) + output = output.mean(-1) + + # (num_query, embed_dims, bs)-> (bs, num_query, embed_dims) + output = output.permute(2, 0, 1) + + output = self.output_proj(output) + + if not self.batch_first: + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity diff --git a/vis4d/op/detect3d/bevformer/transformer.py b/vis4d/op/detect3d/bevformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..489c713f3dc141467daa637d497e8dc33e7d5c09 --- /dev/null +++ b/vis4d/op/detect3d/bevformer/transformer.py @@ -0,0 +1,271 @@ +"""BEVFormer transformer.""" + +from __future__ import annotations + +import numpy as np +import torch +from torch import Tensor, nn +from torchvision.transforms.functional import rotate + +from vis4d.op.layer.weight_init import xavier_init + +from .decoder import BEVFormerDecoder +from .encoder import BEVFormerEncoder + + +class PerceptionTransformer(nn.Module): + """Perception Transformer.""" + + def __init__( + self, + num_cams: int = 6, + encoder: BEVFormerEncoder | None = None, + decoder: BEVFormerDecoder | None = None, + embed_dims: int = 256, + num_feature_levels: int = 4, + rotate_center: tuple[int, int] = (100, 100), + ) -> None: + """Init.""" + super().__init__() + self.num_cams = num_cams + self.embed_dims = embed_dims + self.num_feature_levels = num_feature_levels + self.rotate_center = list(rotate_center) + + self.encoder = encoder or BEVFormerEncoder(embed_dims=self.embed_dims) + self.decoder = decoder or BEVFormerDecoder(embed_dims=self.embed_dims) + + self._init_layers() + self._init_weights() + + def _init_layers(self) -> None: + """Initialize layers of the Detr3DTransformer.""" + self.level_embeds = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims) + ) + self.cams_embeds = nn.Parameter( + torch.Tensor(self.num_cams, self.embed_dims) + ) + self.reference_points = nn.Linear(self.embed_dims, 3) + + self.can_bus_mlp = nn.Sequential( + nn.Linear(18, self.embed_dims // 2), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims // 2, self.embed_dims), + nn.ReLU(inplace=True), + ) + self.can_bus_mlp.add_module("norm", nn.LayerNorm(self.embed_dims)) + + def _init_weights(self) -> None: + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + nn.init.normal_(self.level_embeds) + nn.init.normal_(self.cams_embeds) + xavier_init(self.reference_points, distribution="uniform", bias=0.0) + xavier_init(self.can_bus_mlp, distribution="uniform", bias=0.0) + + def get_bev_features( + self, + mlvl_feats: list[Tensor], + can_bus: Tensor, + bev_queries: Tensor, + bev_h: int, + bev_w: int, + images_hw: tuple[int, int], + cam_intrinsics: list[Tensor], + cam_extrinsics: list[Tensor], + lidar_extrinsics: Tensor, + grid_length: tuple[float, float], + bev_pos: Tensor, + prev_bev: Tensor | None = None, + ) -> Tensor: + """Obtain bev features.""" + batch_size = mlvl_feats[0].shape[0] + bev_queries = bev_queries.unsqueeze(1).repeat(1, batch_size, 1) + bev_pos = bev_pos.flatten(2).permute(2, 0, 1) + + # obtain rotation angle and shift with ego motion + delta_x = can_bus[:, 0].unsqueeze(1) + delta_y = can_bus[:, 1].unsqueeze(1) + ego_angle = can_bus[:, -2] / np.pi * 180 + + translation_length = torch.sqrt(delta_x**2 + delta_y**2) + translation_angle = torch.arctan2(delta_y, delta_x) / np.pi * 180 + bev_angle = ego_angle - translation_angle + + shift_y = ( + translation_length + * torch.cos(bev_angle / 180 * np.pi) + / grid_length[0] + / bev_h + ) + shift_x = ( + translation_length + * torch.sin(bev_angle / 180 * np.pi) + / grid_length[1] + / bev_w + ) + + # B, xy + shift = torch.cat([shift_x, shift_y], dim=1) + + if prev_bev is not None: + if prev_bev.shape[1] == bev_h * bev_w: + prev_bev = prev_bev.permute(1, 0, 2) + + # rotate prev_bev + for i in range(batch_size): + rotation_angle = float(can_bus[i][-1]) + tmp_prev_bev = ( + prev_bev[:, i].reshape(bev_h, bev_w, -1).permute(2, 0, 1) + ) + tmp_prev_bev = rotate( + tmp_prev_bev, rotation_angle, center=self.rotate_center + ) + tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape( + bev_h * bev_w, 1, -1 + ) + prev_bev[:, i] = tmp_prev_bev[:, 0] + + # add can bus signals + bev_queries = bev_queries + self.can_bus_mlp(can_bus)[None, :, :] + + feat_flatten_list = [] + spatial_shapes_list = [] + for lvl, feat in enumerate(mlvl_feats): + spatial_shape = feat.shape[-2:] + feat = feat.flatten(3).permute(1, 0, 3, 2) + + # Add cams_embeds and level_embeds + feat += self.cams_embeds[:, None, None, :].to(feat.dtype) + feat += self.level_embeds[None, None, lvl : lvl + 1, :].to( + feat.dtype + ) + + spatial_shapes_list.append(spatial_shape) + feat_flatten_list.append(feat) + + feat_flatten = torch.cat(feat_flatten_list, 2) + spatial_shapes = torch.as_tensor( + spatial_shapes_list, dtype=torch.long, device=bev_pos.device + ) + level_start_index = torch.cat( + ( + spatial_shapes.new_zeros((1,)), + spatial_shapes.prod(1).cumsum(0)[:-1], + ) + ) + + # (num_cam, H*W, bs, embed_dims) + feat_flatten = feat_flatten.permute(0, 2, 1, 3) + + bev_embed = self.encoder( + bev_queries, + feat_flatten, + bev_h=bev_h, + bev_w=bev_w, + bev_pos=bev_pos, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + prev_bev=prev_bev, + shift=shift, + images_hw=images_hw, + cam_intrinsics=cam_intrinsics, + cam_extrinsics=cam_extrinsics, + lidar_extrinsics=lidar_extrinsics, + ) + return bev_embed + + def forward( + self, + mlvl_feats: list[Tensor], + can_bus: Tensor, + bev_queries: Tensor, + object_query_embed: Tensor, + bev_h: int, + bev_w: int, + images_hw: tuple[int, int], + cam_intrinsics: list[Tensor], + cam_extrinsics: list[Tensor], + lidar_extrinsics: Tensor, + grid_length: tuple[float, float], + bev_pos: Tensor, + reg_branches: list[nn.Module], + prev_bev: Tensor | None = None, + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Forward function for BEVFormer transformer. + + Args: + mlvl_feats (list(Tensor)): Input queries from different level. Each + element has shape [bs, num_cams, embed_dims, h, w]. + can_bus (Tensor): The can bus signals, has shape [bs, 18]. + bev_queries (Tensor): (bev_h * bev_w, embed_dims). + object_query_embed (Tensor): The query embedding for decoder, + with shape [num_query, embed_dims * 2]. + bev_h (int): The height of BEV feature map. + bev_w (int): The width of BEV feature map. + images_hw (tuple[int, int]): The height and width of images. + cam_intrinsics (list[Tensor]): The camera intrinsics. + cam_extrinsics (list[Tensor]): The camera extrinsics. + lidar_extrinsics (Tensor): The lidar extrinsics. + grid_length (tuple[float, float]): The length of grid in x and y + direction. + bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w) + reg_branches (list[nn.Module]): Regression heads for feature maps + from each decoder layer. + prev_bev (Tensor, optional): The previous BEV feature map, has + shape [bev_h * bev_w, bs, embed_dims]. Defaults to None. + + Returns: + bev_embed (Tensor): BEV features has shape [bev_h *bev_w, bs, + embed_dims]. + inter_states: Outputs from decoder has shape [1, bs, num_query, + embed_dims]. + reference_points: As the initial reference has shape [bs, + num_queries, 4]. + inter_references: The internal value of reference points in the + decoder, has shape [num_dec_layers, bs,num_query, embed_dims]. + """ + # bs, bev_h*bev_w, embed_dims + bev_embed = self.get_bev_features( + mlvl_feats, + can_bus, + bev_queries, + bev_h, + bev_w, + images_hw=images_hw, + cam_intrinsics=cam_intrinsics, + cam_extrinsics=cam_extrinsics, + lidar_extrinsics=lidar_extrinsics, + grid_length=grid_length, + bev_pos=bev_pos, + prev_bev=prev_bev, + ) + + bs = mlvl_feats[0].shape[0] + query_pos, query = torch.split( + object_query_embed, self.embed_dims, dim=1 + ) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos) + reference_points = reference_points.sigmoid() + + query = query.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + bev_embed = bev_embed.permute(1, 0, 2) + + inter_states, inter_references = self.decoder( + query=query, + value=bev_embed, + reference_points=reference_points, + spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device), + level_start_index=torch.tensor([0], device=query.device), + query_pos=query_pos, + reg_branches=reg_branches, + ) + + return bev_embed, inter_states, reference_points, inter_references diff --git a/vis4d/op/detect3d/common.py b/vis4d/op/detect3d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..33003af712f87f544b8e9c4689bc45281fa0f84b --- /dev/null +++ b/vis4d/op/detect3d/common.py @@ -0,0 +1,23 @@ +"""Common classes and functions for 3D detection.""" + +from __future__ import annotations + +from typing import NamedTuple + +from torch import Tensor + + +class Detect3DOut(NamedTuple): + """Output of detect 3D model. + + Attributes: + boxes_3d (list[Tensor]): List of bounding boxes (B, N, 10). + velocities (list[Tensor]): List of velocities (B, N, 3). + class_ids (list[Tensor]): List of class ids (B, N). + scores_3d (list[Tensor]): List of scores (B, N). + """ + + boxes_3d: list[Tensor] + velocities: list[Tensor] + class_ids: list[Tensor] + scores_3d: list[Tensor] diff --git a/vis4d/op/detect3d/qd_3dt.py b/vis4d/op/detect3d/qd_3dt.py new file mode 100644 index 0000000000000000000000000000000000000000..d92370b8e38d5861c38d15e220767682a7784d79 --- /dev/null +++ b/vis4d/op/detect3d/qd_3dt.py @@ -0,0 +1,699 @@ +"""QD-3DT detector.""" + +from __future__ import annotations + +from typing import NamedTuple + +import numpy as np +import torch +from torch import Tensor, nn + +from vis4d.common.typing import LossesType +from vis4d.op.box.encoder.qd_3dt import QD3DTBox3DDecoder, QD3DTBox3DEncoder +from vis4d.op.box.matchers import Matcher, MaxIoUMatcher +from vis4d.op.box.poolers import MultiScaleRoIAlign, MultiScaleRoIPooler +from vis4d.op.box.samplers import ( + CombinedSampler, + Sampler, + match_and_sample_proposals, +) +from vis4d.op.geometry.rotation import generate_rotation_output +from vis4d.op.layer.conv2d import Conv2d, add_conv_branch +from vis4d.op.layer.weight_init import kaiming_init, xavier_init +from vis4d.op.loss.base import Loss +from vis4d.op.loss.common import rotation_loss, smooth_l1_loss +from vis4d.op.loss.reducer import LossReducer, SumWeightedLoss, mean_loss + + +class QD3DTBBox3DHeadOutput(NamedTuple): + """QD-3DT bounding box 3D head training output.""" + + predictions: list[Tensor] + targets: Tensor | None + labels: Tensor | None + + +class QD3DTDet3DOut(NamedTuple): + """Output of QD-3DT bounding box 3D head. + + Attributes: + boxes_3d (list[Tensor]): Predicted 3D bounding boxes. Each tensor has + shape (N, 12) and contains x,y,z,h,w,l,rx,ry,rz,vx,vy,vz. + depth_uncertainty (list[Tensor]): Predicted depth uncertainty. Each + tensor has shape (N, 1). + """ + + boxes_3d: list[Tensor] + depth_uncertainty: list[Tensor] + + +def get_default_proposal_pooler() -> MultiScaleRoIAlign: + """Get default proposal pooler of QD-3DT bounding box 3D head.""" + return MultiScaleRoIAlign( + resolution=[7, 7], strides=[4, 8, 16, 32], sampling_ratio=0 + ) + + +def get_default_box_sampler() -> CombinedSampler: + """Get default box sampler of QD-3DT bounding box 3D head.""" + return CombinedSampler( + batch_size=512, + positive_fraction=0.25, + pos_strategy="instance_balanced", + neg_strategy="iou_balanced", + ) + + +def get_default_box_matcher() -> MaxIoUMatcher: + """Get default box matcher of QD-3DT bounding box 3D head.""" + return MaxIoUMatcher( + thresholds=[0.5, 0.5], + labels=[0, -1, 1], + allow_low_quality_matches=False, + ) + + +def get_default_box_codec( + center_scale: float = 10.0, + depth_log_scale: float = 2.0, + dim_log_scale: float = 2.0, + num_rotation_bins: int = 2, + bin_overlap: float = 1 / 6, +) -> tuple[QD3DTBox3DEncoder, QD3DTBox3DDecoder]: + """Get the default bounding box encoder and decoder.""" + return ( + QD3DTBox3DEncoder( + center_scale=center_scale, + depth_log_scale=depth_log_scale, + dim_log_scale=dim_log_scale, + num_rotation_bins=num_rotation_bins, + bin_overlap=bin_overlap, + ), + QD3DTBox3DDecoder( + center_scale=center_scale, + depth_log_scale=depth_log_scale, + dim_log_scale=dim_log_scale, + num_rotation_bins=num_rotation_bins, + ), + ) + + +class QD3DTBBox3DHead(nn.Module): + """This class implements the QD-3DT bounding box 3D head.""" + + def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments, line-too-long + self, + num_classes: int, + proposal_pooler: None | MultiScaleRoIPooler = None, + box_matcher: None | Matcher = None, + box_sampler: None | Sampler = None, + box_encoder: None | QD3DTBox3DEncoder = None, + proposal_append_gt: bool = True, + num_shared_convs: int = 2, + num_shared_fcs: int = 0, + num_dep_convs: int = 4, + num_dep_fcs: int = 0, + num_dim_convs: int = 4, + num_dim_fcs: int = 0, + num_rot_convs: int = 4, + num_rot_fcs: int = 0, + num_cen_2d_convs: int = 4, + num_cen_2d_fcs: int = 0, + in_channels: int = 256, + conv_out_dim: int = 256, + fc_out_dim: int = 1024, + roi_feat_size: int = 7, + conv_has_bias: bool = True, + norm: None | str = None, + num_groups: int = 32, + num_rotation_bins: int = 2, + start_level: int = 2, + ): + """Initialize the QD-3DT bounding box 3D head.""" + super().__init__() + self.proposal_pooler = ( + proposal_pooler + if proposal_pooler is not None + else get_default_proposal_pooler() + ) + self.box_matcher = ( + box_matcher + if box_matcher is not None + else get_default_box_matcher() + ) + self.box_sampler = ( + box_sampler + if box_sampler is not None + else get_default_box_sampler() + ) + self.box_encoder = ( + box_encoder if box_encoder is not None else QD3DTBox3DEncoder() + ) + self.num_shared_convs = num_shared_convs + self.num_shared_fcs = num_shared_fcs + self.num_rotation_bins = num_rotation_bins + self.proposal_append_gt = proposal_append_gt + self.cls_out_channels = num_classes + + # Used feature layers are [start_level, end_level) + self.start_level = start_level + num_strides = len(self.proposal_pooler.scales) + self.end_level = start_level + num_strides + + # add shared convs and fcs + ( + self.shared_convs, + self.shared_fcs, + self.shared_out_channels, + ) = self._add_conv_fc_branch( + num_shared_convs, + num_shared_fcs, + in_channels, + conv_out_dim, + fc_out_dim, + conv_has_bias, + norm, + num_groups, + True, + ) + + # add depth specific branch + ( + self.dep_convs, + self.dep_fcs, + self.dep_last_dim, + ) = self._add_conv_fc_branch( + num_dep_convs, + num_dep_fcs, + self.shared_out_channels, + conv_out_dim, + fc_out_dim, + conv_has_bias, + norm, + num_groups, + ) + + # add dim specific branch + ( + self.dim_convs, + self.dim_fcs, + self.dim_last_dim, + ) = self._add_conv_fc_branch( + num_dim_convs, + num_dim_fcs, + self.shared_out_channels, + conv_out_dim, + fc_out_dim, + conv_has_bias, + norm, + num_groups, + ) + + # add rot specific branch + ( + self.rot_convs, + self.rot_fcs, + self.rot_last_dim, + ) = self._add_conv_fc_branch( + num_rot_convs, + num_rot_fcs, + self.shared_out_channels, + conv_out_dim, + fc_out_dim, + conv_has_bias, + norm, + num_groups, + ) + + # add delta 2D center specific branch + ( + self.cen_2d_convs, + self.cen_2d_fcs, + self.cen_2d_last_dim, + ) = self._add_conv_fc_branch( + num_cen_2d_convs, + num_cen_2d_fcs, + self.shared_out_channels, + conv_out_dim, + fc_out_dim, + conv_has_bias, + norm, + num_groups, + ) + + if num_shared_fcs == 0: + if num_dep_fcs == 0: + self.dep_last_dim *= roi_feat_size * roi_feat_size + if num_dim_fcs == 0: + self.dim_last_dim *= roi_feat_size * roi_feat_size + if num_rot_fcs == 0: + self.rot_last_dim *= roi_feat_size * roi_feat_size + if num_cen_2d_fcs == 0: + self.cen_2d_last_dim *= roi_feat_size * roi_feat_size + + self.relu = nn.ReLU(inplace=True) + # reconstruct fc_cls and fc_reg since input channels are changed + out_dim_dep = self.cls_out_channels + self.fc_dep = nn.Linear(self.dep_last_dim, out_dim_dep) + + self.fc_dep_uncer = nn.Linear(self.dep_last_dim, out_dim_dep) + + out_dim_size = 3 * self.cls_out_channels + self.fc_dim = nn.Linear(self.dim_last_dim, out_dim_size) + + out_rot_size = 3 * num_rotation_bins * self.cls_out_channels + self.fc_rot = nn.Linear(self.rot_last_dim, out_rot_size) + + out_cen_2d_size = 2 * self.cls_out_channels + self.fc_cen_2d = nn.Linear(self.cen_2d_last_dim, out_cen_2d_size) + + self._init_weights() + + def _init_weights(self) -> None: + """Init weights of modules in head.""" + module_lists: list[nn.ModuleList | nn.Linear | Conv2d] = [] + module_lists += [self.shared_convs] + module_lists += [self.shared_fcs] + module_lists += [self.dep_convs] + module_lists += [self.fc_dep_uncer] + module_lists += [self.fc_dep, self.dep_fcs] + module_lists += [self.dim_convs] + module_lists += [self.fc_dim, self.dim_fcs] + module_lists += [self.rot_convs] + module_lists += [self.fc_rot, self.rot_fcs] + module_lists += [self.cen_2d_convs] + module_lists += [self.fc_cen_2d, self.cen_2d_fcs] + + for module_list in module_lists: + for m in module_list.modules(): + if isinstance(m, nn.Linear): + xavier_init(m, distribution="uniform") + elif isinstance(m, Conv2d): + kaiming_init(m) + + def _add_conv_fc_branch( + self, + num_branch_convs: int, + num_branch_fcs: int, + in_channels: int, + conv_out_dim: int, + fc_out_dim: int, + conv_has_bias: bool, + norm: None | str, + num_groups: int, + is_shared: bool = False, + ) -> tuple[nn.ModuleList, nn.ModuleList, int]: + """Init modules of head.""" + convs, last_layer_dim = add_conv_branch( + num_branch_convs, + in_channels, + conv_out_dim, + conv_has_bias, + norm, + num_groups, + ) + + fcs = nn.ModuleList() + if num_branch_fcs > 0: + if is_shared or num_branch_fcs == 0: + last_layer_dim *= int(np.prod(self.proposal_pooler.resolution)) + for i in range(num_branch_fcs): + fc_in_dim = last_layer_dim if i == 0 else fc_out_dim + fcs.append( + nn.Sequential( + nn.Linear(fc_in_dim, fc_out_dim), + nn.ReLU(inplace=True), + ) + ) + last_layer_dim = fc_out_dim + return convs, fcs, last_layer_dim + + def get_embeds( + self, feat: Tensor + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Generate embedding from bbox feature.""" + # shared part + if self.num_shared_convs > 0: + for conv in self.shared_convs: + feat = conv(feat) + + if self.num_shared_fcs > 0: + feat = feat.view(feat.size(0), -1) + for fc in self.shared_fcs: + feat = self.relu(fc(feat)) + + # separate branches + x_dep = feat + x_dim = feat + x_rot = feat + x_cen_2d = feat + + for conv in self.dep_convs: + x_dep = conv(x_dep) + if x_dep.dim() > 2: + x_dep = x_dep.view(x_dep.size(0), -1) + for fc in self.dep_fcs: + x_dep = self.relu(fc(x_dep)) + + for conv in self.dim_convs: + x_dim = conv(x_dim) + if x_dim.dim() > 2: + x_dim = x_dim.view(x_dim.size(0), -1) + for fc in self.dim_fcs: + x_dim = self.relu(fc(x_dim)) + + for conv in self.rot_convs: + x_rot = conv(x_rot) + if x_rot.dim() > 2: + x_rot = x_rot.view(x_rot.size(0), -1) + for fc in self.rot_fcs: + x_rot = self.relu(fc(x_rot)) + + for conv in self.cen_2d_convs: + x_cen_2d = conv(x_cen_2d) + if x_cen_2d.dim() > 2: + x_cen_2d = x_cen_2d.view(x_cen_2d.size(0), -1) + for fc in self.cen_2d_fcs: + x_cen_2d = self.relu(fc(x_cen_2d)) + + return x_dep, x_dim, x_rot, x_cen_2d + + def get_outputs( + self, x_dep: Tensor, x_dim: Tensor, x_rot: Tensor, x_cen_2d: Tensor + ) -> Tensor: + """Generate output 3D bounding box parameters.""" + depth = self.fc_dep(x_dep).view(-1, self.cls_out_channels, 1) + depth_uncertainty = self.fc_dep_uncer(x_dep).view( + -1, self.cls_out_channels, 1 + ) + dim = self.fc_dim(x_dim).view(-1, self.cls_out_channels, 3) + alpha = generate_rotation_output( + self.fc_rot(x_rot), self.num_rotation_bins + ) + delta_cen_2d = self.fc_cen_2d(x_cen_2d).view( + -1, self.cls_out_channels, 2 + ) + return torch.cat( + [delta_cen_2d, depth, dim, alpha, depth_uncertainty], -1 + ) + + def get_predictions( + self, features: list[Tensor], boxes_2d: list[Tensor] + ) -> list[Tensor]: + """Get 3D bounding box prediction parameters.""" + if sum(len(b) for b in boxes_2d) == 0: # pragma: no cover + return [ + torch.empty( + ( + 0, + self.cls_out_channels, + 6 + 3 * self.num_rotation_bins + 1, + ), + device=boxes_2d[0].device, + ) + ] * len(boxes_2d) + + roi_feats = self.proposal_pooler( + features[self.start_level : self.end_level], boxes_2d + ) + x_dep, x_dim, x_rot, x_cen_2d = self.get_embeds(roi_feats) + + outputs: list[Tensor] = list( + self.get_outputs(x_dep, x_dim, x_rot, x_cen_2d).split( + [len(b) for b in boxes_2d] + ) + ) + return outputs + + def get_targets( + self, + pos_assigned_gt_inds: list[Tensor], + target_boxes: list[Tensor], + target_boxes3d: list[Tensor], + target_class_ids: list[Tensor], + intrinsics: Tensor, + ) -> tuple[Tensor, Tensor]: + """Get 3D bounding box targets for training.""" + targets = [] + labels = [] + for i, (tgt_boxes, tgt_boxes3d, intrinsics_) in enumerate( + zip(target_boxes, target_boxes3d, intrinsics) + ): + bbox_target = self.box_encoder(tgt_boxes, tgt_boxes3d, intrinsics_) + targets.append(bbox_target[pos_assigned_gt_inds[i]]) + + labels.append(target_class_ids[i][pos_assigned_gt_inds[i]]) + + return torch.cat(targets), torch.cat(labels) + + def forward( + self, + features: list[Tensor], + det_boxes: list[Tensor], + intrinsics: Tensor | None = None, + target_boxes: list[Tensor] | None = None, + target_boxes3d: list[Tensor] | None = None, + target_class_ids: list[Tensor] | None = None, + ) -> QD3DTBBox3DHeadOutput: + """Forward.""" + if ( + intrinsics is not None + and target_boxes is not None + and target_boxes3d is not None + and target_class_ids is not None + ): + if self.proposal_append_gt: + det_boxes = [ + torch.cat([d, t]) for d, t in zip(det_boxes, target_boxes) + ] + + ( + sampled_box_indices, + sampled_target_indices, + sampled_labels, + ) = match_and_sample_proposals( + self.box_matcher, self.box_sampler, det_boxes, target_boxes + ) + positives = [torch.eq(l, 1) for l in sampled_labels] + pos_assigned_gt_inds = [ + i[p] if len(p) != 0 else p + for i, p in zip(sampled_target_indices, positives) + ] + pos_boxes = [ + b[s_i][p] + for b, s_i, p in zip(det_boxes, sampled_box_indices, positives) + ] + predictions = self.get_predictions(features, pos_boxes) + + targets, labels = self.get_targets( + pos_assigned_gt_inds, + target_boxes, + target_boxes3d, + target_class_ids, + intrinsics, + ) + + return QD3DTBBox3DHeadOutput( + predictions=predictions, targets=targets, labels=labels + ) + + predictions = self.get_predictions(features, det_boxes) + + return QD3DTBBox3DHeadOutput(predictions, None, None) + + def __call__( + self, + features: list[Tensor], + det_boxes: list[Tensor], + intrinsics: Tensor | None = None, + target_boxes: list[Tensor] | None = None, + target_boxes3d: list[Tensor] | None = None, + target_class_ids: list[Tensor] | None = None, + ) -> QD3DTBBox3DHeadOutput: + """Type definition.""" + return self._call_impl( + features, + det_boxes, + intrinsics, + target_boxes, + target_boxes3d, + target_class_ids, + ) + + +class RoI2Det3D: + """Post processing for QD3DTBBox3DHead.""" + + def __init__(self, box_decoder: None | QD3DTBox3DDecoder = None) -> None: + """Initialize.""" + self.box_decoder = ( + QD3DTBox3DDecoder() if box_decoder is None else box_decoder + ) + + def __call__( + self, + predictions: list[Tensor], + boxes_2d: list[Tensor], + class_ids: list[Tensor], + intrinsics: Tensor, + ) -> QD3DTDet3DOut: + """Forward pass during testing stage. + + Args: + predictions(list[Tensor]): Predictions. + boxes_2d(list[Tensor]): 2D boxes. + class_ids(list[Tensor]): Class IDs. + intrinsics(Tensor): Camera intrinsics. + + Returns: + QD3DTDet3DOut: QD3DT 3D detection output. + """ + boxes_3d = [] + depth_uncertainty = [] + device = boxes_2d[0].device + for _boxes_2d, _class_ids, _boxes_deltas, _intrinsics in zip( + boxes_2d, class_ids, predictions, intrinsics + ): + if len(_boxes_2d) == 0: + boxes_3d.append(torch.empty(0, 12).to(device)) + depth_uncertainty.append(torch.empty(0).to(device)) + continue + + _boxes_deltas = _boxes_deltas[ + torch.arange(_boxes_deltas.shape[0]), _class_ids + ] + + depth_uncertainty.append( + _boxes_deltas[:, -1].clamp(min=0.0, max=1.0) + ) + boxes_3d.append( + self.box_decoder(_boxes_2d, _boxes_deltas, _intrinsics) + ) + + return QD3DTDet3DOut( + boxes_3d=boxes_3d, depth_uncertainty=depth_uncertainty + ) + + +class Box3DUncertaintyLoss(Loss): + """Box3d loss for QD-3DT.""" + + def __init__( + self, + reducer: LossReducer = mean_loss, + center_loss_weight: float = 1.0, + depth_loss_weight: float = 1.0, + dimension_loss_weight: float = 1.0, + rotation_loss_weight: float = 1.0, + uncertainty_loss_weight: float = 1.0, + num_rotation_bins: int = 2, + ) -> None: + """Creates an instance of the class. + + Args: + reducer (LossReducer): Reducer for the loss function. + center_loss_weight (float): Weight for center loss. + depth_loss_weight (float): Weight for depth loss. + dimension_loss_weight (float): Weight for dimension loss. + rotation_loss_weight (float): Weight for rotation loss. + uncertainty_loss_weight (float): Weight for uncertainty loss. + num_rotation_bins (int): Number of rotation bins. + """ + super().__init__(reducer) + self.center_loss_weight = center_loss_weight + self.depth_loss_weight = depth_loss_weight + self.dimension_loss_weight = dimension_loss_weight + self.rotation_loss_weight = rotation_loss_weight + self.uncertainty_loss_weight = uncertainty_loss_weight + self.num_rotation_bins = num_rotation_bins + + def forward( + self, pred: Tensor, target: Tensor, labels: Tensor + ) -> LossesType: + """Compute box3d loss. + + Args: + pred (Tensor): Box predictions of shape [N, num_classes, + 6 + 3 * num_rotations_bins]. + target (torcch.Tensor): Target boxes of shape [N, + 6 + num_rotation_bins]. + labels (Tensor): Target Labels of shape [N]. + + Returns: + dict[str, Tensor] containing 'delta 2dc', 'dimension', 'depth', + 'rotation' and 'uncertainty' loss. + """ + if pred.size(0) == 0: + loss_ctr3d = loss_dep3d = loss_dim3d = loss_rot3d = loss_conf3d = ( + pred.sum() * 0 + ) + result_dict = { + "loss_ctr3d": loss_ctr3d, + "loss_dep3d": loss_dep3d, + "loss_dim3d": loss_dim3d, + "loss_rot3d": loss_rot3d, + "loss_conf3d": loss_conf3d, + } + + return result_dict + + pred = pred[torch.arange(pred.shape[0], device=pred.device), labels] + + # delta 2dc loss + loss_cen = smooth_l1_loss( + pred[:, :2], target[:, :2], reducer=self.reducer, beta=1 / 9 + ) + + # dimension loss + dim_mask = target[:, 3:6] != 100.0 + loss_dim = smooth_l1_loss( + pred[:, 3:6][dim_mask], + target[:, 3:6][dim_mask], + reducer=self.reducer, + beta=1 / 9, + ) + + # depth loss + depth_mask = target[:, 2] > 0 + loss_dep = smooth_l1_loss( + pred[:, 2][depth_mask], + target[:, 2][depth_mask], + reducer=self.reducer, + beta=1 / 9, + ) + + # rotation loss + loss_rot = rotation_loss( + pred[:, 6 : 6 + self.num_rotation_bins * 3], + target[:, 6 : 6 + self.num_rotation_bins], + target[:, 6 + self.num_rotation_bins :], + self.num_rotation_bins, + reducer=self.reducer, + ) + + # uncertainty loss + pos_depth_self_labels = torch.exp( + -torch.mul(torch.abs(pred[:, 2] - target[:, 2]), 5.0) + ) + pos_depth_self_weights = torch.where( + pos_depth_self_labels > 0.8, + pos_depth_self_labels.new_ones(1) * 5.0, + pos_depth_self_labels.new_ones(1) * 0.1, + ) + + loss_unc3d = smooth_l1_loss( + pred[:, -1], + pos_depth_self_labels.detach().clone(), + reducer=SumWeightedLoss( + pos_depth_self_weights, len(pos_depth_self_weights) + ), + beta=1 / 9, + ) + + return { + "loss_ctr3d": torch.mul(self.center_loss_weight, loss_cen), + "loss_dep3d": torch.mul(self.depth_loss_weight, loss_dep), + "loss_dim3d": torch.mul(self.dimension_loss_weight, loss_dim), + "loss_rot3d": torch.mul(self.rotation_loss_weight, loss_rot), + "loss_unc3d": torch.mul(self.uncertainty_loss_weight, loss_unc3d), + } diff --git a/vis4d/op/detect3d/util.py b/vis4d/op/detect3d/util.py new file mode 100644 index 0000000000000000000000000000000000000000..80c646200cff4107f5063982363164d305d873bf --- /dev/null +++ b/vis4d/op/detect3d/util.py @@ -0,0 +1,117 @@ +"""Utilitiy functions for detection 3D ops.""" + +from __future__ import annotations + +import torch +from torch import Tensor + +from vis4d.common.imports import VIS4D_CUDA_OPS_AVAILABLE + +if VIS4D_CUDA_OPS_AVAILABLE: + from vis4d_cuda_ops import nms_rotated # pylint: disable=no-name-in-module + + +def bev_3d_nms( + center_x: Tensor, + center_y: Tensor, + width: Tensor, + length: Tensor, + angle: Tensor, + scores: Tensor, + class_ids: Tensor | None = None, + iou_threshold: float = 0.1, +) -> Tensor: + """BEV 3D NMS. + + Args: + center_x (Tensor): Center x of boxes. In shape (N, 1). + center_y (Tensor): Center y of boxes. In shape (N, 1). + width (Tensor): Width of boxes. In shape (N, 1). + length (Tensor): Length of boxes. In shape (N, 1). + angle (Tensor): Angle of boxes. In shape (N, 1). + scores (Tensor): Scores of boxes. In shape (N, 1). + class_ids (Tensor | None, optional): Class ids of boxes. In shape + (N,). Defaults to None. If None, class_agnostic NMS will be + performed. + iou_threshold (float, optional): IoU threshold. Defaults to 0.1. + + Returns: + Tensor: Indices of boxes that have been kept by NMS. + """ + class_ids = ( + torch.zeros_like(scores, dtype=torch.int64) # class_agnostic + if class_ids is None + else class_ids + ) + + return batched_nms_rotated( + torch.cat([center_x, center_y, width, length, angle], dim=-1), + scores, + class_ids, + iou_threshold, + ) + + +def batched_nms_rotated( + boxes: Tensor, + scores: Tensor, + idxs: Tensor, + iou_threshold: float, +) -> Tensor: + """Performs non-maximum suppression in a batched fashion. + + Each index value correspond to a category, and NMS + will not be applied between elements of different categories. + + Args: + boxes (Tensor): Boxes where NMS will be performed. They are expected to + be in (x_ctr, y_ctr, width, height, angle_degrees) format. In shape + (N, 5). + scores (Tensor): Scores for each one of the boxes. In shape (N,). + idxs (Tensor): Indices of the categories for each one of the boxes. + In shape (N,). + iou_threshold (float): Discards all overlapping boxes with IoU < + iou_threshold. + + Returns: + Tensor: Int64 tensor with the indices of the elements that have been + kept by NMS, sorted in decreasing order of scores + """ + assert boxes.shape[-1] == 5 + + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + + boxes = boxes.float() # fp16 does not have enough range for batched NMS + + # Strategy: in order to perform NMS independently per class, + # we add an offset to all the boxes. The offset is dependent + # only on the class idx, and is large enough so that boxes + # from different classes do not overlap + + # Note that batched_nms in torchvision/ops/boxes.py only uses + # max_coordinate, which won't handle negative coordinates correctly. + # Here by using min_coordinate we can make sure the negative coordinates + # are correctly handled. + max_coordinate = ( + torch.max(boxes[:, 0], boxes[:, 1]) + + torch.max(boxes[:, 2], boxes[:, 3]) / 2 + ).max() + min_coordinate = ( + torch.min(boxes[:, 0], boxes[:, 1]) + - torch.max(boxes[:, 2], boxes[:, 3]) / 2 + ).min() + offsets = idxs.to(boxes) * (max_coordinate - min_coordinate + 1) + boxes_for_nms = ( + boxes.clone() + ) # avoid modifying the original values in boxes + boxes_for_nms[:, :2] += offsets[:, None] + + if not VIS4D_CUDA_OPS_AVAILABLE: + raise RuntimeError( + "Please install vis4d_cuda_ops to use batched_nms_rotated" + ) + keep = nms_rotated( # pylint: disable=possibly-used-before-assignment + boxes_for_nms, scores, iou_threshold + ) + return keep diff --git a/vis4d/op/fpp/__init__.py b/vis4d/op/fpp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb06dca6c0f7362ae45c7119b8f24a5cd8a216b4 --- /dev/null +++ b/vis4d/op/fpp/__init__.py @@ -0,0 +1,12 @@ +"""Vis4D modules for feature pyramid processing. + +Feature pyramid processing is usually used for augmenting the existing feature +maps and/or upsampling the feature maps. +""" + +from .base import FeaturePyramidProcessing +from .dla_up import DLAUp +from .fpn import FPN +from .yolox_pafpn import YOLOXPAFPN + +__all__ = ["DLAUp", "FPN", "FeaturePyramidProcessing", "YOLOXPAFPN"] diff --git a/vis4d/op/fpp/base.py b/vis4d/op/fpp/base.py new file mode 100644 index 0000000000000000000000000000000000000000..76d998ea3238f119a9a378ebe60cbc2afee6b53d --- /dev/null +++ b/vis4d/op/fpp/base.py @@ -0,0 +1,31 @@ +"""Feature pyramid processing base class.""" + +from __future__ import annotations + +import abc + +from torch import Tensor, nn + + +class FeaturePyramidProcessing(nn.Module): + """Base Neck class.""" + + @abc.abstractmethod + def forward(self, features: list[Tensor]) -> list[Tensor]: + """Feature pyramid processing. + + This module do a further processing for the hierarchical feature + representation extracted by the base models. + + Args: + features (list[Tensor]): Feature pyramid as outputs of the + base model. + + Returns: + list[Tensor]: Feature pyramid after the processing. + """ + raise NotImplementedError + + def __call__(self, features: list[Tensor]) -> list[Tensor]: + """Type definition for call implementation.""" + return self._call_impl(features) diff --git a/vis4d/op/fpp/dla_up.py b/vis4d/op/fpp/dla_up.py new file mode 100644 index 0000000000000000000000000000000000000000..7e37cacf8ffc64eee49d4dd34d6f35bf65cae3d8 --- /dev/null +++ b/vis4d/op/fpp/dla_up.py @@ -0,0 +1,171 @@ +"""DLA-UP. + +TODO(fyu) need clean up and update to the latest interface. +""" + +from __future__ import annotations + +import math + +import numpy as np +import torch +from torch import nn + +from vis4d.common.typing import NDArrayI64 +from vis4d.op.layer.conv2d import Conv2d +from vis4d.op.layer.deform_conv import DeformConv + +from .base import FeaturePyramidProcessing + + +def fill_up_weights(up_layer: nn.ConvTranspose2d) -> None: + """Initialize weights of upsample layer.""" + w = up_layer.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2.0 * f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * ( + 1 - math.fabs(j / f - c) + ) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + + +class IDAUp(nn.Module): + """IDAUp.""" + + def __init__( + self, use_dc: bool, o: int, channels: list[int], up_f: list[int] + ) -> None: + """Creates an instance of the class.""" + super().__init__() + for i in range(1, len(channels)): + c = channels[i] + f = int(up_f[i]) + if use_dc: + proj: Conv2d | DeformConv = DeformConv( + c, + o, + kernel_size=3, + padding=1, + norm=nn.BatchNorm2d(o), + activation=nn.ReLU(inplace=True), + ) + node: Conv2d | DeformConv = DeformConv( + o, + o, + kernel_size=3, + padding=1, + norm=nn.BatchNorm2d(o), + activation=nn.ReLU(inplace=True), + ) + else: + proj = Conv2d( + c, + o, + kernel_size=1, + stride=1, + bias=False, + norm=nn.BatchNorm2d(o), + activation=nn.ReLU(inplace=True), + ) + node = Conv2d( + o, + o, + kernel_size=3, + stride=1, + padding=1, + bias=False, + norm=nn.BatchNorm2d(o), + activation=nn.ReLU(inplace=True), + ) + + up = nn.ConvTranspose2d( + o, + o, + f * 2, + stride=f, + padding=f // 2, + output_padding=0, + groups=o, + bias=False, + ) + fill_up_weights(up) + + setattr(self, "proj_" + str(i), proj) + setattr(self, "up_" + str(i), up) + setattr(self, "node_" + str(i), node) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward( + self, layers: list[torch.Tensor], startp: int, endp: int + ) -> None: + """Forward.""" + for i in range(startp + 1, endp): + upsample = getattr(self, "up_" + str(i - startp)) + project = getattr(self, "proj_" + str(i - startp)) + layers[i] = upsample(project(layers[i])) + node = getattr(self, "node_" + str(i - startp)) + layers[i] = node(layers[i] + layers[i - 1]) + + +class DLAUp(FeaturePyramidProcessing): + """DLAUp.""" + + def __init__( + self, + in_channels: list[int], + out_channels: None | int = None, + start_level: int = 0, + end_level: int = -1, + use_deformable_convs: bool = True, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.start_level = start_level + self.end_level = end_level + if self.end_level == -1: + self.end_level = len(in_channels) + in_channels = in_channels[self.start_level : self.end_level] + channels = list(in_channels) + scales: NDArrayI64 = np.array( + [2**i for i, _ in enumerate(in_channels)], dtype=np.int64 + ) + for i in range(len(channels) - 1): + j = -i - 2 + idaup = IDAUp( + use_deformable_convs, + channels[j], + in_channels[j:], + scales[j:] // scales[j], + ) + setattr(self, f"ida_{i}", idaup) + scales[j + 1 :] = scales[j] + in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]] + if out_channels is None: + out_channels = channels[0] + self.ida_final = IDAUp( + use_deformable_convs, + out_channels, + channels, + [2**i for i in range(self.end_level - self.start_level)], + ) + + def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]: + """Forward.""" + outs = [features[self.end_level - 1]] + for i in range(self.end_level - self.start_level - 1): + ida = getattr(self, f"ida_{i}") + ida(features, self.end_level - i - 2, self.end_level) + outs.insert(0, features[self.end_level - 1]) + self.ida_final(outs, 0, len(outs)) + outs = [outs[-1]] + return outs diff --git a/vis4d/op/fpp/fpn.py b/vis4d/op/fpp/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..0433437bb9529ffb6d696cba15a3a43538afd204 --- /dev/null +++ b/vis4d/op/fpp/fpn.py @@ -0,0 +1,138 @@ +"""Feature Pyramid Network. + +This is based on `"Feature Pyramid Network for Object Detection" +`_. +""" + +from __future__ import annotations + +from collections import OrderedDict + +import torch.nn.functional as F +from torch import Tensor, nn +from torchvision.ops import FeaturePyramidNetwork as _FPN +from torchvision.ops.feature_pyramid_network import ( + ExtraFPNBlock as _ExtraFPNBlock, +) +from torchvision.ops.feature_pyramid_network import ( + LastLevelMaxPool, +) + +from .base import FeaturePyramidProcessing + + +class FPN(_FPN, FeaturePyramidProcessing): # type: ignore + """Feature Pyramid Network. + + This is a wrapper of the torchvision implementation. + """ + + def __init__( + self, + in_channels_list: list[int], + out_channels: int, + extra_blocks: _ExtraFPNBlock | None = LastLevelMaxPool(), + start_index: int = 2, + ) -> None: + """Init without additional components. + + Args: + in_channels_list (list[int]): List of input channels. + out_channels (int): Output channels. + extra_blocks (_ExtraFPNBlock, optional): Extra block. Defaults to + LastLevelMaxPool(). + start_index (int, optional): Start index of base model feature + maps. Defaults to 2. + """ + super().__init__( + in_channels_list, out_channels, extra_blocks=extra_blocks + ) + self.start_index = start_index + + def forward(self, x: list[Tensor]) -> list[Tensor]: + """Process the input features with FPN. + + Because by default, FPN doesn't upsample the first two feature maps in + the pyramid, we keep the first two feature maps intact. + + Args: + x (list[Tensor]): Feature pyramid as outputs of the + base model. + + Returns: + list[Tensor]: Feature pyramid after FPN processing. + """ + feat_dict = OrderedDict( + (k, v) + for k, v in zip( + [str(i) for i in range(len(x) - self.start_index)], + x[self.start_index :], + ) + ) + outs = super().forward(feat_dict) # type: ignore + return [*x[: self.start_index], *outs.values()] # type: ignore + + def __call__(self, x: list[Tensor]) -> list[Tensor]: + """Type definition for call implementation.""" + return self._call_impl(x) + + +class ExtraFPNBlock(_ExtraFPNBlock): # type: ignore + """Extra block in the FPN. + + This is a wrapper of the torchvision implementation. + """ + + def __init__( + self, + extra_levels: int, + in_channels: int, + out_channels: int, + add_extra_convs: str = "on_output", + extra_relu: bool = False, + ) -> None: + """Create an instance of the class.""" + super().__init__() + self.extra_levels = extra_levels + self.add_extra_convs = add_extra_convs + self.extra_relu = extra_relu + + self.convs = nn.ModuleList() + if extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == "on_input": + _in_channels = in_channels + else: + _in_channels = out_channels + + extra_fpn_conv = nn.Conv2d( + _in_channels, + out_channels, + 3, + stride=2, + padding=1, + ) + self.convs.append(extra_fpn_conv) + + def forward( + self, results: list[Tensor], x: list[Tensor], names: list[str] + ) -> tuple[list[Tensor], list[str]]: + """Forward.""" + if self.add_extra_convs == "on_input": + extra_source = x[-1] + elif self.add_extra_convs == "on_output": + extra_source = results[-1] + else: + raise NotImplementedError + + results.append(self.convs[0](extra_source)) + names.append(str(int(names[-1]) + 1)) + + for i in range(1, self.extra_levels): + if self.extra_relu: + results.append(self.convs[i](F.relu(results[-1]))) + else: + results.append(self.convs[i](results[-1])) + names.append(str(int(names[-1]) + 1)) + + return results, names diff --git a/vis4d/op/fpp/yolox_pafpn.py b/vis4d/op/fpp/yolox_pafpn.py new file mode 100644 index 0000000000000000000000000000000000000000..ce5794f359d918215061753cb79462df63a583ba --- /dev/null +++ b/vis4d/op/fpp/yolox_pafpn.py @@ -0,0 +1,175 @@ +"""YOLOX PAFPN. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import math + +import torch +from torch import nn + +from vis4d.op.layer.conv2d import Conv2d +from vis4d.op.layer.csp_layer import CSPLayer + +from .base import FeaturePyramidProcessing + + +class YOLOXPAFPN(FeaturePyramidProcessing): + """Path Aggregation Network used in YOLOX. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_csp_blocks (int, optional): Number of bottlenecks in CSPLayer. + Defaults to 3. + start_index (int, optional): Index of the first input feature map. + Defaults to 2. + """ + + def __init__( + self, + in_channels: list[int], + out_channels: int, + num_csp_blocks: int = 3, + start_index: int = 2, + ): + """Init.""" + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.start_index = start_index + + # build top-down blocks + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.reduce_layers = nn.ModuleList() + self.top_down_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1, 0, -1): + self.reduce_layers.append( + Conv2d( + in_channels[idx], + in_channels[idx - 1], + 1, + bias=False, + norm=nn.BatchNorm2d( + in_channels[idx - 1], eps=0.001, momentum=0.03 + ), + activation=nn.SiLU(inplace=True), + ) + ) + self.top_down_blocks.append( + CSPLayer( + in_channels[idx - 1] * 2, + in_channels[idx - 1], + num_blocks=num_csp_blocks, + add_identity=False, + ) + ) + + # build bottom-up blocks + self.downsamples = nn.ModuleList() + self.bottom_up_blocks = nn.ModuleList() + for idx in range(len(in_channels) - 1): + self.downsamples.append( + Conv2d( + in_channels[idx], + in_channels[idx], + 3, + stride=2, + padding=1, + bias=False, + norm=nn.BatchNorm2d( + in_channels[idx], eps=0.001, momentum=0.03 + ), + activation=nn.SiLU(inplace=True), + ) + ) + self.bottom_up_blocks.append( + CSPLayer( + in_channels[idx] * 2, + in_channels[idx + 1], + num_blocks=num_csp_blocks, + add_identity=False, + ) + ) + + self.out_convs = nn.ModuleList() + for _, inc in enumerate(in_channels): + self.out_convs.append( + Conv2d( + inc, + out_channels, + 1, + bias=False, + norm=nn.BatchNorm2d( + out_channels, eps=0.001, momentum=0.03 + ), + activation=nn.SiLU(inplace=True), + ) + ) + self._init_weights() + + def _init_weights(self) -> None: + """Initialize weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_( + m.weight, + a=math.sqrt(5), + mode="fan_in", + nonlinearity="leaky_relu", + ) + + def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]: + """Forward pass. + + Args: + features (tuple[Tensor]): Input features. + + Returns: + list[Tensor]: YOLOXPAFPN features. + """ + images, features = ( + features[: self.start_index], + features[self.start_index :], + ) + assert len(features) == len(self.in_channels) + + # top-down path + inner_outs = [features[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = features[idx - 1] + feat_heigh = self.reduce_layers[len(self.in_channels) - 1 - idx]( + feat_heigh + ) + inner_outs[0] = feat_heigh + + upsample_feat = self.upsample(feat_heigh) + + inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx]( + torch.cat([upsample_feat, feat_low], 1) + ) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsamples[idx](feat_low) + out = self.bottom_up_blocks[idx]( + torch.cat([downsample_feat, feat_height], 1) + ) + outs.append(out) + + # out convs + for idx, conv in enumerate(self.out_convs): + outs[idx] = conv(outs[idx]) + + return images + outs + + def __call__(self, features: list[torch.Tensor]) -> list[torch.Tensor]: + """Type definition for call implementation.""" + return self._call_impl(features) diff --git a/vis4d/op/geometry/__init__.py b/vis4d/op/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07be652039b98f91f13deba9d270f48db32da255 --- /dev/null +++ b/vis4d/op/geometry/__init__.py @@ -0,0 +1 @@ +"""Init geometry module.""" diff --git a/vis4d/op/geometry/__pycache__/__init__.cpython-311.pyc b/vis4d/op/geometry/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03bf2ee18cf669d7fdb220ff7eb658b93fc9eb6a Binary files /dev/null and b/vis4d/op/geometry/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/op/geometry/__pycache__/projection.cpython-311.pyc b/vis4d/op/geometry/__pycache__/projection.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6edbb527072ab7e765916660ade8b89a246daf54 Binary files /dev/null and b/vis4d/op/geometry/__pycache__/projection.cpython-311.pyc differ diff --git a/vis4d/op/geometry/__pycache__/rotation.cpython-311.pyc b/vis4d/op/geometry/__pycache__/rotation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0de841b3bee8f6e089a0a5bf44ec28f792326eab Binary files /dev/null and b/vis4d/op/geometry/__pycache__/rotation.cpython-311.pyc differ diff --git a/vis4d/op/geometry/__pycache__/transform.cpython-311.pyc b/vis4d/op/geometry/__pycache__/transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb8060a4c244977f08bb59eb25178f2e8ba753b8 Binary files /dev/null and b/vis4d/op/geometry/__pycache__/transform.cpython-311.pyc differ diff --git a/vis4d/op/geometry/projection.py b/vis4d/op/geometry/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..ab2c352dddbafaa5c8b845a804146f4201bdcbc4 --- /dev/null +++ b/vis4d/op/geometry/projection.py @@ -0,0 +1,138 @@ +"""Projection utilities.""" + +from __future__ import annotations + +import torch + +from .transform import inverse_pinhole + + +def project_points( + points: torch.Tensor, intrinsics: torch.Tensor +) -> torch.Tensor: + """Project points to pixel coordinates with given intrinsics. + + Args: + points: (N, 3) or (B, N, 3) 3D coordinates. + intrinsics: (3, 3) or (B, 3, 3) intrinsic camera matrices. + + Returns: + torch.Tensor: (N, 2) or (B, N, 2) 2D pixel coordinates. + + Raises: + ValueError: Shape of input points is not valid for computation. + """ + assert points.shape[-1] == 3, "Input coordinates must be 3 dimensional!" + hom_coords = points / points[..., 2:3] + if len(hom_coords.shape) == 2: + assert ( + len(intrinsics.shape) == 2 + ), "Got multiple intrinsics for single point set!" + intrinsics = intrinsics.T + elif len(hom_coords.shape) == 3: + if len(intrinsics.shape) == 2: + intrinsics = intrinsics.unsqueeze(0) + intrinsics = intrinsics.permute(0, 2, 1) + else: + raise ValueError(f"Shape of input points not valid: {points.shape}") + pts_2d = hom_coords @ intrinsics + return pts_2d[..., :2] + + +def unproject_points( + points: torch.Tensor, depths: torch.Tensor, intrinsics: torch.Tensor +) -> torch.Tensor: + """Un-projects pixel coordinates to 3D coordinates with given intrinsics. + + Args: + points: (N, 2) or (B, N, 2) 2D pixel coordinates. + depths: (N,) / (N, 1) or (B, N,) / (B, N, 1) depth values. + intrinsics: (3, 3) or (B, 3, 3) intrinsic camera matrices. + + Returns: + torch.Tensor: (N, 3) or (B, N, 3) 3D coordinates. + + Raises: + ValueError: Shape of input points is not valid for computation. + """ + if len(points.shape) == 2: + assert ( + len(intrinsics.shape) == 2 or intrinsics.shape[0] == 1 + ), "Got multiple intrinsics for single point set!" + if len(intrinsics.shape) == 3: + intrinsics = intrinsics.squeeze(0) + inv_intrinsics = inverse_pinhole(intrinsics).transpose(0, 1) + if len(depths.shape) == 1: + depths = depths.unsqueeze(-1) + assert len(depths.shape) == 2, "depths must have same dims as points" + elif len(points.shape) == 3: + inv_intrinsics = inverse_pinhole(intrinsics).transpose(-2, -1) + if len(depths.shape) == 2: + depths = depths.unsqueeze(-1) + assert len(depths.shape) == 3, "depths must have same dims as points" + else: + raise ValueError(f"Shape of input points not valid: {points.shape}") + hom_coords = torch.cat([points, torch.ones_like(points)[..., 0:1]], -1) + pts_3d = hom_coords @ inv_intrinsics + pts_3d *= depths + return pts_3d + + +def points_inside_image( + points_coord: torch.Tensor, + depths: torch.Tensor, + images_hw: torch.Tensor | tuple[int, int], +) -> torch.Tensor: + """Generate binary mask. + + Creates a mask that is true for all point coordiantes that lie inside the + image, + + Args: + points_coord (torch.Tensor): 2D pixel coordinates of shape [..., 2]. + depths (torch.Tensor): Associated depth of each 2D pixel coordinate. + images_hw: (torch.Tensor| tuple[int, int]]) Associated tensor of image + dimensions, shape [..., 2] or single height, width pair. + + Returns: + torch.Tensor: Binary mask of points inside an image. + """ + mask = torch.ones_like(depths) + h: int | torch.Tensor + w: int | torch.Tensor + + if isinstance(images_hw, tuple): + h, w = images_hw + else: + h, w = images_hw[..., 0], images_hw[..., 1] + mask = torch.logical_and(mask, torch.greater(depths, 0)) + mask = torch.logical_and(mask, points_coord[..., 0] > 0) + mask = torch.logical_and(mask, points_coord[..., 0] < w - 1) + mask = torch.logical_and(mask, points_coord[..., 1] > 0) + mask = torch.logical_and(mask, points_coord[..., 1] < h - 1) + return mask + + +def generate_depth_map( + points: torch.Tensor, + intrinsics: torch.Tensor, + image_hw: tuple[int, int], +) -> torch.Tensor: + """Generate depth map for given pointcloud. + + Args: + points: (N, 3) coordinates. + intrinsics: (3, 3) intrinsic camera matrices. + image_hw: (tuple[int,int]) height, width of the image + + Returns: + torch.Tensor: Projected depth map of the given pointcloud. + Invalid depth has 0 values + """ + pts_2d = project_points(points, intrinsics).round() + depths = points[:, 2] + depth_map = points.new_zeros(image_hw) + mask = points_inside_image(pts_2d, depths, image_hw) + pts_2d = pts_2d[mask].long() + depth_map[pts_2d[:, 1], pts_2d[:, 0]] = depths[mask] + return depth_map diff --git a/vis4d/op/geometry/rotation.py b/vis4d/op/geometry/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..b605ecc567b7a5d19f985d62f32ab7b8252861b2 --- /dev/null +++ b/vis4d/op/geometry/rotation.py @@ -0,0 +1,539 @@ +"""Rotation utilities.""" + +import functools + +import torch +import torch.nn.functional as F +from torch import Tensor + +from vis4d.data.const import AxisMode + + +def normalize_angle(input_angles: Tensor) -> Tensor: + """Normalize content of input_angles to range [-pi, pi]. + + Args: + input_angles: (Tensor) tensor of any shape containing + unnormalized angles. + + Returns: + Tensor with angles normalized to +/- pi + """ + return torch.sub((input_angles + torch.pi) % (2 * torch.pi), torch.pi) + + +def acute_angle(theta_1: Tensor, theta_2: Tensor) -> Tensor: + """Update theta_1 to mkae the agnle between two thetas is acute.""" + # Make sure the angle between two thetas is acute + if torch.pi / 2.0 < abs(theta_2 - theta_1) < torch.pi * 3 / 2.0: + theta_1 += torch.pi + if theta_1 > torch.pi: + theta_1 -= torch.pi * 2 + if theta_1 < -torch.pi: + theta_1 += torch.pi * 2 + + # Convert the case of > 270 to < 90 + if abs(theta_2 - theta_1) >= torch.pi * 3 / 2.0: + if theta_2 > 0: + theta_1 += torch.pi * 2 + else: + theta_1 -= torch.pi * 2 + return theta_1 + + +def yaw2alpha(rot_y: Tensor, center: Tensor) -> Tensor: + """Get alpha by vertical rotation - theta. + + Args: + rot_y: Rotation around Y-axis in camera coordinates [-pi..pi] + center: 3D object center in camera coordinates + + Returns: + alpha: Observation angle of object, ranging [-pi..pi] + """ + alpha = rot_y - torch.atan2(center[..., 0], center[..., 2]) + return normalize_angle(alpha) + + +def alpha2yaw(alpha: Tensor, center: Tensor) -> Tensor: + """Get vertical rotation by alpha + theta. + + Args: + alpha: Observation angle of object, ranging [-pi..pi] + center: 3D object center in camera coordinates + + Returns: + rot_y: Vertical rotation in camera coordinates [-pi..pi] + """ + rot_y = alpha + torch.atan2(center[..., 0], center[..., 2]) + return normalize_angle(rot_y) + + +def rotation_output_to_alpha(output: Tensor, num_bins: int = 2) -> Tensor: + """Get alpha from bin-based regression output. + + Uses method described in (with two bins): + See: 3D Bounding Box Estimation Using Deep Learning and Geometry, + Mousavian et al., CVPR'17 + + Args: + output: (Tensor) bin based regressed output. + num_bins: (int) number of bins to use + + Returns: + Tensor containing the angle from the bin-based regression output + """ + out_range = torch.tensor(list(range(len(output))), device=output.device) + bin_idx = output[:, :num_bins].argmax(dim=-1) + res_idx = num_bins + 2 * bin_idx + bin_centers = torch.arange( + -torch.pi, torch.pi, 2 * torch.pi / num_bins, device=output.device + ) + bin_centers += torch.pi / num_bins + alpha = ( + torch.atan(output[out_range, res_idx] / output[out_range, res_idx + 1]) + + bin_centers[bin_idx] + ) + return alpha + + +def generate_rotation_output(pred: Tensor, num_bins: int = 2) -> Tensor: + """Convert output to bin confidence and cos / sin of residual. + + The viewpoint (alpha) prediction (N, num_bins + 2 * num_bins) consists of: + bin confidences (N, num_bins): softmax logits for bin probability. + 1st entry is probability for orientation being in bin 1, + 2nd entry is probability for orientation being in bin 2, + and so on. + bin residual (N, num_bins * 2): angle residual w.r.t. bin N orientation, + represented as sin and cos values. + + See: 3D Bounding Box Estimation Using Deep Learning and Geometry, + Mousavian et al., CVPR'17 + """ + pred = pred.view(pred.size(0), -1, 3 * num_bins) + bin_logits = pred[..., :num_bins] + + bin_residuals = [] + for i in range(num_bins): + res_idx = num_bins + 2 * i + norm = pred[..., res_idx : res_idx + 2].norm(dim=-1, keepdim=True) + bsin = pred[..., res_idx : res_idx + 1] / norm + bcos = pred[..., res_idx + 1 : res_idx + 2] / norm + bin_residuals.append(bsin) + bin_residuals.append(bcos) + + rot = torch.cat([bin_logits, *bin_residuals], -1) + return rot + + +# Rotation conversion functions adapted from: +# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py +def _axis_angle_rotation(axis: str, angle: Tensor) -> Tensor: + """Get rotation matrix for an angle around an axis. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + assert axis in {"X", "Y", "Z"}, f"Invalid axis {axis}." + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + rot_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + rot_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + else: + rot_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + + return torch.stack(rot_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix( + euler_angles: Tensor, convention: str = "XYZ" +) -> Tensor: + """Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + "X", "Y", and "Z". + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + + Raises: + ValueError: if convention string is not a combination of XYZ + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, a) + for c, a in zip(convention, torch.unbind(euler_angles, -1)) + ] + return functools.reduce(torch.matmul, matrices) + + +def _index_from_letter(letter: str) -> int: # pragma: no cover + """Return index from letter. + + Args: + letter: (str) letter in [X,Y,Z] + + Returns: + int mapping of the corresponding letter [0,1,2] + + Raises: + ValueError: if the given letter is not valid + """ + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter not valid!") + + +def _angle_from_tan( + axis: str, + other_axis: str, + data: Tensor, + horizontal: bool, + tait_bryan: bool, +) -> Tensor: + """Helper function for matrix_to_euler_angles. + + Extracts the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = axis + other_axis in {"XY", "YZ", "ZX"} + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def matrix_to_euler_angles(matrix: Tensor, convention: str = "XYZ") -> Tensor: + """Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + + Raises: + ValueError: if convention string is not a combination of XYZ + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + rads = matrix[..., i0, i2] + # safety for nan + rads[torch.where(rads > 1.0)] = rads.new_tensor([1.0]).to(rads.device) + rads[torch.where(rads < -1.0)] = rads.new_tensor([-1.0]).to( + rads.device + ) + central_angle = torch.asin( + rads * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +def quaternion_to_matrix(quaternions: Tensor) -> Tensor: + """Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _sqrt_positive_part(quat: Tensor) -> Tensor: + """Returns sqrt(max(0, x)) but with a zero subgradient where x is 0.""" + ret = torch.zeros_like(quat) + positive_mask = quat > 0 + ret[positive_mask] = torch.sqrt(quat[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: Tensor) -> Tensor: + """Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + + Raises: + ValueError: If shape of input matrix is not correct. + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(*batch_dim, 9), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack( + [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1 + ), + torch.stack( + [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1 + ), + torch.stack( + [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1 + ), + torch.stack( + [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1 + ), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is + # small, the candidate won't be picked. + quat_candidates = quat_by_rijk / ( + 2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1)) + ) + + # if not for numerical problems, quat_candidates[i] should be same + # (up to a sign), forall i; we pick the best-conditioned one + # (with the largest denominator) + + return quat_candidates[ + F.one_hot( # pylint: disable=not-callable + q_abs.argmax(dim=-1), num_classes=4 + ) + > 0.5, + :, # pyre-ignore[16] + ].reshape(*batch_dim, 4) + + +def standardize_quaternion(quaternions: Tensor) -> Tensor: + """Convert a unit quaternion to a standard form. + + Standard form: One in which the real part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(quat1: Tensor, quat2: Tensor) -> Tensor: + """Multiply two quaternions. + + Usual torch rules for broadcasting apply. + + Args: + quat1: Quaternions as tensor of shape (..., 4), real part first. + quat2: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of quat1 and quat2, tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(quat1, -1) + bw, bx, by, bz = torch.unbind(quat2, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(quat1: Tensor, quat2: Tensor) -> Tensor: + """Multiply two quaternions representing rotations. + + Returns the quaternion representing their composition, i.e. the version + with nonnegative real part. Usual torch rules for broadcasting apply. + + Args: + quat1: Quaternions as tensor of shape (..., 4), real part first. + quat2: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of quat1 and quat2, tensor of quaternions shape (..., 4). + """ + return standardize_quaternion(quaternion_raw_multiply(quat1, quat2)) + + +def quaternion_invert(quaternion: Tensor) -> Tensor: + """Return quaternion that represents inverse rotation. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + return quaternion * quaternion.new_tensor([1, -1, -1, -1]) + + +def quaternion_apply(quaternion: Tensor, points: Tensor) -> Tensor: + """Apply the rotation given by a quaternion to a 3D point. + + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + points: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + + Raises: + ValueError: If points is not a valid 3D point set. + """ + if points.size(-1) != 3: + raise ValueError(f"Points are not in 3D, {points.shape}.") + real_parts = points.new_zeros(points.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, points), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), + quaternion_invert(quaternion), + ) + return out[..., 1:] + + +def rotation_matrix_yaw( + rotation_matrix: Tensor, axis_mode: AxisMode +) -> Tensor: + """Get yaw of 3D boxes in euler angle under given axis mode. + + Args: + rotation_matrix (Tensor): [N, 3, 3] Rotation matrix of the object. + axis_mode (AxisMode): Coordinate system convention. + + Returns: + orientation (Tensor): [N, 3] Yaw in euler angle. + """ + orientation = rotation_matrix.new_zeros(rotation_matrix.shape[0], 3) + + if axis_mode == AxisMode.OPENCV: + orientation[:, 1] = matrix_to_euler_angles(rotation_matrix, "YZX")[ + :, 0 + ] + else: + orientation[:, 2] = matrix_to_euler_angles(rotation_matrix, "ZYX")[ + :, 0 + ] + return orientation + + +def rotate_orientation( + orientation: Tensor, extrinsics: Tensor, axis_mode: AxisMode = AxisMode.ROS +) -> Tensor: + """Rotate the orientation of the object in different coordinate. + + Args: + orientation (Tensor): [N, 3] Orientation of the object in euler angles. + extrinsics (Tensor): [4, 4] Extrinsic matrix of the object. + axis_mode (AxisMode): Coordinate system convention. Default: + AxisMode.ROS + """ + rot = extrinsics[:3, :3] @ euler_angles_to_matrix(orientation) + return rotation_matrix_yaw(rot, axis_mode) + + +def rotate_velocities(velocities: Tensor, extrinsics: Tensor) -> Tensor: + """Rotate the velocities of the object in different coordinate.""" + return (extrinsics[:3, :3] @ velocities.unsqueeze(-1)).squeeze(-1) diff --git a/vis4d/op/geometry/transform.py b/vis4d/op/geometry/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..2276419a55cb386f840dc6f536dff26756d009df --- /dev/null +++ b/vis4d/op/geometry/transform.py @@ -0,0 +1,120 @@ +"""Vis4D geometric transformation functions.""" + +import torch +from torch import Tensor + + +def transform_points(points: Tensor, transform: Tensor) -> Tensor: + """Applies transform to points. + + Args: + points (Tensor): points of shape (N, D) or (B, N, D). + transform (Tensor): transforms of shape (D+1, D+1) or (B, D+1, D+1). + + Returns: + Tensor: (N, D) / (B, N, D) transformed points. + + Raises: + ValueError: Either points or transform have incorrect shape + """ + hom_coords = torch.cat([points, torch.ones_like(points[..., 0:1])], -1) + if len(points.shape) == 2: + if len(transform.shape) == 3: + assert ( + transform.shape[0] == 1 + ), "Got multiple transforms for single point set!" + transform = transform.squeeze(0) + transform = transform.T + elif len(points.shape) == 3: + if len(transform.shape) == 2: + transform = transform.T.unsqueeze(0) + elif len(transform.shape) == 3: + transform = transform.permute(0, 2, 1) + else: + raise ValueError(f"Shape of transform invalid: {transform.shape}") + else: + raise ValueError(f"Shape of input points invalid: {points.shape}") + points_transformed = hom_coords @ transform + return points_transformed[..., : points.shape[-1]] + + +def inverse_pinhole(intrinsic_matrix: Tensor) -> Tensor: + """Calculate inverse of pinhole projection matrix. + + Args: + intrinsic_matrix (Tensor): [..., 3, 3] intrinsics or single [3, 3] + intrinsics. + + Returns: + Tensor: Inverse of input intrinisics. + """ + squeeze = False + inv = intrinsic_matrix.clone() + if len(intrinsic_matrix.shape) == 2: + inv = inv.unsqueeze(0) + squeeze = True + + inv[..., 0, 0] = 1.0 / inv[..., 0, 0] + inv[..., 1, 1] = 1.0 / inv[..., 1, 1] + inv[..., 0, 2] = -inv[..., 0, 2] * inv[..., 0, 0] + inv[..., 1, 2] = -inv[..., 1, 2] * inv[..., 1, 1] + + if squeeze: + inv = inv.squeeze(0) + return inv + + +def inverse_rigid_transform(transformation: Tensor) -> Tensor: + """Calculate inverse of rigid body transformation(s). + + Args: + transformation (Tensor): [N, 4, 4] transformations or single [4, 4] + transformation. + + Returns: + Tensor: Inverse of input transformation(s). + """ + squeeze = False + if len(transformation.shape) == 2: + transformation = transformation.unsqueeze(0) + squeeze = True + rotation, translation = transformation[:, :3, :3], transformation[:, :3, 3] + rot = rotation.permute(0, 2, 1) + t = -rot @ translation[:, :, None] + inv = torch.cat([torch.cat([rot, t], -1), transformation[:, 3:4]], 1) + if squeeze: + inv = inv.squeeze(0) + return inv + + +def get_transform_matrix(rotation: Tensor, translation: Tensor) -> Tensor: + """Assembles 4x4 transformation from rotation / translation pair(s). + + Args: + rotation (Tensor): [N, 3, 3] or [3, 3] rotation(s). + translation (Tensor): [N, 3] or [3,] translation(s). + + Returns: + Tensor: [N, 4, 4] or [4, 4] transformation. + """ + squeeze = False + if len(rotation.shape) == 2: + assert len(translation.shape) == 1 + rotation = rotation.unsqueeze(0) + translation = translation.unsqueeze(0) + squeeze = True + batch_size = 1 + else: + assert len(rotation.shape) == 3 and len(translation.shape) == 2 + assert rotation.shape[0] == translation.shape[0] + batch_size = rotation.shape[0] + assert ( + rotation.shape[-2] == rotation.shape[-1] == translation.shape[-1] == 3 + ) + transforms = rotation.new_zeros((batch_size, 4, 4)) + transforms[:, :3, :3] = rotation + transforms[:, :3, 3] = translation + transforms[:, 3, 3] = 1.0 + if squeeze: + transforms = transforms.squeeze(0) + return transforms diff --git a/vis4d/op/layer/__init__.py b/vis4d/op/layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7baaa368a725604b71292dab2dd125a9e3799a94 --- /dev/null +++ b/vis4d/op/layer/__init__.py @@ -0,0 +1 @@ +"""layers op module.""" diff --git a/vis4d/op/layer/__pycache__/__init__.cpython-311.pyc b/vis4d/op/layer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6956b9af70383dd46441c309a28ca27c9063113a Binary files /dev/null and b/vis4d/op/layer/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/op/layer/__pycache__/attention.cpython-311.pyc b/vis4d/op/layer/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6243d27908d23a077773034e27e4ec7482dea5a8 Binary files /dev/null and b/vis4d/op/layer/__pycache__/attention.cpython-311.pyc differ diff --git a/vis4d/op/layer/__pycache__/conv2d.cpython-311.pyc b/vis4d/op/layer/__pycache__/conv2d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9138db1a7adb5e568b9a9cff3a1fe87c1387cc48 Binary files /dev/null and b/vis4d/op/layer/__pycache__/conv2d.cpython-311.pyc differ diff --git a/vis4d/op/layer/__pycache__/deform_conv.cpython-311.pyc b/vis4d/op/layer/__pycache__/deform_conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f56497013dc1b5e804d8c716b445938dd79c7c69 Binary files /dev/null and b/vis4d/op/layer/__pycache__/deform_conv.cpython-311.pyc differ diff --git a/vis4d/op/layer/__pycache__/drop.cpython-311.pyc b/vis4d/op/layer/__pycache__/drop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ba08d87d4986851f8e6769b1cb7fe350092636c Binary files /dev/null and b/vis4d/op/layer/__pycache__/drop.cpython-311.pyc differ diff --git a/vis4d/op/layer/__pycache__/mlp.cpython-311.pyc b/vis4d/op/layer/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d15262fd304789cb7dcbff07aa5882d39fd92f7 Binary files /dev/null and b/vis4d/op/layer/__pycache__/mlp.cpython-311.pyc differ diff --git a/vis4d/op/layer/__pycache__/transformer.cpython-311.pyc b/vis4d/op/layer/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d2d61ef5c2f90236af5f07376f660283924ff4 Binary files /dev/null and b/vis4d/op/layer/__pycache__/transformer.cpython-311.pyc differ diff --git a/vis4d/op/layer/__pycache__/util.cpython-311.pyc b/vis4d/op/layer/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af205966e5dc357cced6e5aa3816db2f5e0ff6e3 Binary files /dev/null and b/vis4d/op/layer/__pycache__/util.cpython-311.pyc differ diff --git a/vis4d/op/layer/__pycache__/weight_init.cpython-311.pyc b/vis4d/op/layer/__pycache__/weight_init.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4eee10c16a20de1ce4e8676847aa97100c10eea Binary files /dev/null and b/vis4d/op/layer/__pycache__/weight_init.cpython-311.pyc differ diff --git a/vis4d/op/layer/attention.py b/vis4d/op/layer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f07212d0e3bd39a3bb338d565bfe845a16e50e4e --- /dev/null +++ b/vis4d/op/layer/attention.py @@ -0,0 +1,241 @@ +"""Attention layer.""" + +from __future__ import annotations + +from torch import Tensor, nn + +from vis4d.common.logging import rank_zero_warn +from vis4d.common.typing import ArgsType + + +class Attention(nn.Module): + """ViT Attention Layer. + + Modified from timm (https://github.com/huggingface/pytorch-image-models). + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Init attention layer. + + Args: + dim (int): Input tensor's dimension. + num_heads (int, optional): Number of attention heads. Defaults to + 8. + qkv_bias (bool, optional): If to add bias to qkv. Defaults to + False. + attn_drop (float, optional): Dropout rate for attention. Defaults + to 0.0. + proj_drop (float, optional): Dropout rate for projection. Defaults + to 0.0. + """ + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def __call__(self, data: Tensor) -> Tensor: + """Applies the layer. + + Args: + data (Tensor): Input tensor of shape (B, N, dim). + + Returns: + Tensor: Output tensor of the same shape as input. + """ + return self._call_impl(data) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + batch_size, num_samples, dim = x.shape + qkv = ( + self.qkv(x) + .reshape( + batch_size, + num_samples, + 3, + self.num_heads, + dim // self.num_heads, + ) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind( + 0 + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(batch_size, num_samples, dim) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MultiheadAttention(nn.Module): + """A wrapper for ``torch.nn.MultiheadAttention``. + + This module implements MultiheadAttention with identity connection, + and positional encoding is also passed as input. + """ + + def __init__( + self, + embed_dims: int, + num_heads: int, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + dropout_layer: nn.Module | None = None, + batch_first: bool = False, + need_weights: bool = False, + **kwargs: ArgsType, + ) -> None: + """Init MultiheadAttention. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (nn.Module | None, optional): The dropout_layer used + when adding the shortcut. Defaults to None. + batch_first (bool): When it is True, Key, Query and Value are + shape of (batch, n, embed_dim), otherwise (n, batch, + embed_dim). Default to False. + need_weights (bool): Whether to return the attention weights. + If True, the output will be a tuple of (attn_output, + attn_output_weights) and not using FlashAttention. If False, + only the attn_output will be returned. Default to False. + """ + super().__init__() + self.batch_first = batch_first + self.embed_dims = embed_dims + self.num_heads = num_heads + self.need_weights = need_weights + + self.attn = nn.MultiheadAttention( + embed_dims, num_heads, dropout=attn_drop, **kwargs + ) + + self.proj_drop = nn.Dropout(proj_drop) + + self.dropout_layer = dropout_layer or nn.Identity() + + def forward( + self, + query: Tensor, + key: Tensor | None = None, + value: Tensor | None = None, + identity: Tensor | None = None, + query_pos: Tensor | None = None, + key_pos: Tensor | None = None, + attn_mask: Tensor | None = None, + key_padding_mask: Tensor | None = None, + ) -> Tensor: + """Forward function for `MultiheadAttention`. + + **kwargs allow passing a more general data flow when combining + with other operations in `transformerlayer`. + + Args: + query (Tensor): The input query with shape [num_queries, bs, + embed_dims] if self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + If None, the ``query`` will be used. Defaults to None. + value (Tensor): The value tensor with same shape as `key`. + Same in `nn.MultiheadAttention.forward`. Defaults to None. + If None, the `key` will be used. + identity (Tensor): This tensor, with the same shape as query, + will be used for the identity link. + If None, `query` will be used. Defaults to None. + query_pos (Tensor): The positional encoding for query, with + the same shape as `query`. If not None, it will + be added to `query` before forward function. Defaults to None. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. Defaults to None. If not None, it will + be added to `key` before forward function. If None, and + `query_pos` has the same shape as `key`, then `query_pos` + will be used for `key_pos`. Defaults to None. + attn_mask (Tensor): ByteTensor mask with shape [num_queries, + num_keys]. Same in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. + Defaults to None. + + Returns: + Tensor: forwarded results with shape [num_queries, bs, embed_dims] + if self.batch_first is False, else [bs, num_queries, + embed_dims]. + """ + if key is None: + key = query + + if value is None: + value = key + + if identity is None: + identity = query + + if key_pos is None and query_pos is not None: + # use query_pos if key_pos is not available + if query_pos.shape == key.shape: + key_pos = query_pos + else: + rank_zero_warn( + f"Position encoding of key in {self.__class__.__name__}" + + "is missing, and positional encodeing of query has " + + "has different shape and cannot be usde for key. " + + "It it is not desired, please provide key_pos." + ) + + if query_pos is not None: + query = query + query_pos + + if key_pos is not None: + key = key + key_pos + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query, batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + out = self.attn( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=self.need_weights, + ) + + if isinstance(out, tuple): + out = out[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) diff --git a/vis4d/op/layer/conv2d.py b/vis4d/op/layer/conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..81521b5565080285c422b95d6cb68e096285884d --- /dev/null +++ b/vis4d/op/layer/conv2d.py @@ -0,0 +1,283 @@ +"""Wrapper for conv2d.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from vis4d.common.typing import ArgsType + +from .weight_init import constant_init + + +class Conv2d(nn.Conv2d): + """Wrapper around Conv2d to support empty inputs and norm/activation.""" + + def __init__( + self, + *args: ArgsType, + norm: nn.Module | None = None, + activation: nn.Module | None = None, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class. + + If norm is specified, it is initialized with 1.0 and bias with 0.0. + """ + super().__init__(*args, **kwargs) + self.norm = norm + self.activation = activation + + if self.norm is not None: + constant_init(self.norm, 1.0, bias=0.0) + + def forward( # pylint: disable=arguments-renamed + self, x: Tensor + ) -> Tensor: + """Forward pass.""" + if not torch.jit.is_scripting(): # type: ignore + # https://github.com/pytorch/pytorch/issues/12013 + if ( + x.numel() == 0 + and self.training + and isinstance(self.norm, nn.SyncBatchNorm) + ): + raise ValueError( + "SyncBatchNorm does not support empty inputs!" + ) + + x = F.conv2d( # pylint: disable=not-callable + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + +def add_conv_branch( + num_branch_convs: int, + last_layer_dim: int, + conv_out_dim: int, + conv_has_bias: bool, + norm_cfg: str | None, + num_groups: int | None, +) -> tuple[nn.ModuleList, int]: + """Init conv branch for head.""" + convs = nn.ModuleList() + if norm_cfg is not None: + norm = getattr(nn, norm_cfg) + else: + norm = None + + if norm == nn.GroupNorm: + assert num_groups is not None, "num_groups must be specified" + norm = lambda x: nn.GroupNorm( # pylint: disable=unnecessary-lambda-assignment + num_groups, x + ) + if num_branch_convs > 0: + for i in range(num_branch_convs): + conv_in_dim = last_layer_dim if i == 0 else conv_out_dim + convs.append( + Conv2d( + conv_in_dim, + conv_out_dim, + kernel_size=3, + padding=1, + bias=conv_has_bias, + norm=norm(conv_out_dim) if norm is not None else norm, + activation=nn.ReLU(inplace=True), + ) + ) + last_layer_dim = conv_out_dim + + return convs, last_layer_dim + + +class UnetDownConvOut(NamedTuple): + """Output of the UnetDownConv operator. + + features: Features before applying the pooling operator + pooled_features: Features after applying the pooling operator + """ + + features: Tensor + pooled_features: Tensor + + +class UnetDownConv(nn.Module): + """Downsamples a feature map by applying two convolutions and maxpool.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + pooling: bool = True, + activation: str = "ReLU", + ): + """Creates a new downsampling convolution operator. + + This operator consists of two convolutions followed by a maxpool + operator. + + Args: + in_channels (int): input channesl + out_channels (int): output channesl + pooling (bool): If pooling should be applied + activation (str): Activation that should be applied + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.pooling = pooling + activation = getattr(nn, activation)() + + self.conv1 = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=3, + padding=1, + stride=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + self.out_channels, + self.out_channels, + kernel_size=3, + padding=1, + stride=1, + bias=True, + ) + + if self.pooling: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + def __call__(self, data: Tensor) -> UnetDownConvOut: + """Applies the operator. + + Args: + data (Tensor): Input data. + + Returns: + UnetDownConvOut: Containing the features before the pooling + operation (features) and after (pooled_features). + """ + return self._call_impl(data) + + def forward(self, data: Tensor) -> UnetDownConvOut: + """Applies the operator. + + Args: + data (Tensor): Input data. + + Returns: + UnetDownConvOut: containing the features before the pooling + operation (features) and after (pooled_features). + """ + x = F.relu(self.conv1(data)) + x = F.relu(self.conv2(x)) + before_pool = x + if self.pooling: + x = self.pool(x) + return UnetDownConvOut(features=before_pool, pooled_features=x) + + +class UnetUpConv(nn.Module): + """An operator that performs 2 convolutions and 1 UpConvolution. + + A ReLU activation follows each convolution. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + merge_mode: str = "concat", + up_mode: str = "transpose", + ): + """Creates a new UpConv operator. + + This operator merges two inputs by upsampling one and combining it with + the other. + + Args: + in_channels: Number of input channels (low res) + out_channels: Number of output channels (high res) + merge_mode: How to merge both input channels + up_mode: How to upsample the channel with lower resolution + + Raises: + ValueError: If upsampling mode is unknown + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.merge_mode = merge_mode + self.up_mode = up_mode + + # Upsampling + if self.up_mode == "transpose": + self.upconv: nn.Module = nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=2, stride=2 + ) + elif self.up_mode == "upsample": + self.upconv = nn.Sequential( + nn.Upsample(mode="bilinear", scale_factor=2), + nn.Conv2d(in_channels, out_channels, kernel_size=1), + ) + else: + raise ValueError(f"Unknown upsampling mode: {up_mode}") + + if self.merge_mode == "concat": + self.conv1 = nn.Conv2d( + 2 * self.out_channels, self.out_channels, 3, padding=1 + ) + else: + # num of input channels to conv2 is same + self.conv1 = nn.Conv2d( + self.out_channels, self.out_channels, 3, padding=1 + ) + self.conv2 = nn.Conv2d( + self.out_channels, self.out_channels, 3, padding=1 + ) + + def __call__(self, from_down: Tensor, from_up: Tensor) -> Tensor: + """Forward pass. + + Arguments: + from_down (Tensor): Tensor from the encoder pathway. Assumed to + have dimension 'out_channels' + from_up (Tensor): Upconv'd tensor from the decoder pathway. Assumed + to have dimension 'in_channels' + """ + return self._call_impl(from_down, from_up) + + def forward(self, from_down: Tensor, from_up: Tensor) -> Tensor: + """Forward pass. + + Arguments: + from_down (Tensor): Tensor from the encoder pathway. Assumed to + have dimension 'out_channels' + from_up (Tensor): Upconv'd tensor from the decoder pathway. Assumed + to have dimension 'in_channels' + """ + from_up = self.upconv(from_up) + if self.merge_mode == "concat": + x = torch.cat((from_up, from_down), 1) + else: + x = from_up + from_down + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + return x diff --git a/vis4d/op/layer/csp_layer.py b/vis4d/op/layer/csp_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..47dc5e6f7678ccd220a8bf3203295c448a4e4bf0 --- /dev/null +++ b/vis4d/op/layer/csp_layer.py @@ -0,0 +1,146 @@ +"""Cross Stage Partial Layer. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import torch +from torch import nn + +from .conv2d import Conv2d + + +class DarknetBottleneck(nn.Module): + """The basic bottleneck block used in Darknet. + + Each ResBlock consists of two Conv blocks and the input is added to the + final output. Each block is composed of Conv, BN, and SiLU. + The first convolutional layer has filter size of 1x1 and the second one + has filter size of 3x3. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + expansion (float, optional): The kernel size of the convolution. + Defaults to 0.5. + add_identity (bool, optional): Whether to add identity to the output. + Defaults to True. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: float = 0.5, + add_identity: bool = True, + ): + """Init.""" + super().__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = Conv2d( + in_channels, + hidden_channels, + 1, + bias=False, + norm=nn.BatchNorm2d(hidden_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + self.conv2 = Conv2d( + hidden_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=False, + norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + self.add_identity = add_identity and in_channels == out_channels + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + features (torch.Tensor): Input features. + """ + identity = features + out = self.conv1(features) + out = self.conv2(out) + + if self.add_identity: + return out + identity + return out + + +class CSPLayer(nn.Module): + """Cross Stage Partial Layer. + + Args: + in_channels (int): The input channels of the CSP layer. + out_channels (int): The output channels of the CSP layer. + expand_ratio (float, optional): Ratio to adjust the number of channels + of the hidden layer. Defaults to 0.5. + num_blocks (int, optional): Number of blocks. Defaults to 1. + add_identity (bool, optional): Whether to add identity in blocks. + Defaults to True. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expand_ratio: float = 0.5, + num_blocks: int = 1, + add_identity: bool = True, + ): + """Init.""" + super().__init__() + mid_channels = int(out_channels * expand_ratio) + self.main_conv = Conv2d( + in_channels, + mid_channels, + 1, + bias=False, + norm=nn.BatchNorm2d(mid_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + self.short_conv = Conv2d( + in_channels, + mid_channels, + 1, + bias=False, + norm=nn.BatchNorm2d(mid_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + self.final_conv = Conv2d( + 2 * mid_channels, + out_channels, + 1, + bias=False, + norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03), + activation=nn.SiLU(inplace=True), + ) + + self.blocks = nn.Sequential( + *[ + DarknetBottleneck( + mid_channels, mid_channels, 1.0, add_identity + ) + for _ in range(num_blocks) + ] + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + features (torch.Tensor): Input features. + """ + x_short = self.short_conv(features) + + x_main = self.main_conv(features) + x_main = self.blocks(x_main) + + x_final = torch.cat((x_main, x_short), dim=1) + return self.final_conv(x_final) diff --git a/vis4d/op/layer/deform_conv.py b/vis4d/op/layer/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3a52baf2063130799cf6f57e10b988f13545f8 --- /dev/null +++ b/vis4d/op/layer/deform_conv.py @@ -0,0 +1,93 @@ +"""Wrapper for deformable convolution.""" + +from __future__ import annotations + +import torch +from torch import Tensor, nn +from torchvision.ops import DeformConv2d + +from .weight_init import constant_init + + +class DeformConv(DeformConv2d): # type: ignore + """Wrapper around Deformable Convolution operator with norm/activation. + + If norm is specified, it is initialized with 1.0 and bias with 0.0. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + norm: nn.Module | None = None, + activation: nn.Module | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + kernel_size (int): Size of convolutional kernel. + stride (int, optional): Stride of convolutional layer. Defaults to + 1. + padding (int, optional): Padding of convolutional layer. Defaults + to 0. + dilation (int, optional): Dilation of convolutional layer. Defaults + to 1. + groups (int, optional): Number of deformable groups. Defaults to 1. + bias (bool, optional): Whether to use bias in convolutional layer. + Defaults to True. + norm (nn.Module, optional): Normalization layer. Defaults to None. + activation (nn.Module, optional): Activation layer. Defaults to + None. + """ + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=True, + ) + self.norm = norm + self.activation = activation + self.init_weights() + + def init_weights(self) -> None: + """Initialize weights of offset conv layer.""" + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() # type: ignore + if self.norm is not None: + constant_init(self.norm, 1.0, bias=0.0) + + def forward( # pylint: disable=arguments-differ + self, input_x: Tensor + ) -> Tensor: + """Forward.""" + out = self.conv_offset(input_x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + input_x = super().forward(input_x, offset, mask) + if self.norm is not None: + input_x = self.norm(input_x) + if self.activation is not None: + input_x = self.activation(input_x) + return input_x diff --git a/vis4d/op/layer/drop.py b/vis4d/op/layer/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a1cef8b7f5627c158df26fc8b34d480e0709cb --- /dev/null +++ b/vis4d/op/layer/drop.py @@ -0,0 +1,68 @@ +"""DropPath (Stochastic Depth) regularization layers. + +Modified from timm (https://github.com/huggingface/pytorch-image-models). +""" + +from __future__ import annotations + +import torch +from torch import nn + + +def drop_path( + x: torch.Tensor, + drop_prob: float = 0.0, + training: bool = False, + scale_by_keep: bool = True, +) -> torch.Tensor: + """Drop path regularizer (Stochastic Depth) per sample. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, ...). + drop_prob (float, optional): Probability of an element to be zeroed. + Defaults to 0.0. + training (bool, optional): If to apply drop path. Defaults to False. + scale_by_keep (bool, optional): If to scale by keep probability. + Defaults to True. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """DropPath regularizer (Stochastic Depth) per sample.""" + + def __init__( + self, drop_prob: float = 0.0, scale_by_keep: bool = True + ) -> None: + """Init DropPath. + + Args: + drop_prob (float, optional): Probability of an item to be masked. + Defaults to 0.0. + scale_by_keep (bool, optional): If to scale by keep probability. + Defaults to True. + """ + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def __call__(self, data: torch.Tensor) -> torch.Tensor: + """Applies the layer. + + Args: + data: (tensor) input shape [N, ...] + """ + return self._call_impl(data) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) diff --git a/vis4d/op/layer/mlp.py b/vis4d/op/layer/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..ea157511ab5075f93a4e04c9eadb6787b6168ad7 --- /dev/null +++ b/vis4d/op/layer/mlp.py @@ -0,0 +1,62 @@ +"""MLP Layers.""" + +from __future__ import annotations + +from torch import Tensor, nn + + +class TransformerBlockMLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks.""" + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: nn.Module = nn.GELU(), + bias: bool = True, + drop: float = 0.0, + ): + """Init MLP. + + Args: + in_features (int): Number of input features. + hidden_features (int, optional): Number of hidden features. + Defaults to None. + out_features (int, optional): Number of output features. + Defaults to None. + act_layer (nn.Module, optional): Activation layer. + Defaults to nn.GELU. + bias (bool, optional): If bias should be used. Defaults to True. + drop (float, optional): Dropout probability. Defaults to 0.0. + """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer + self.drop1 = nn.Dropout(drop) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop2 = nn.Dropout(drop) + + def __call__(self, data: Tensor) -> Tensor: + """Applies the layer. + + Args: + data: (tensor) input shape [N, C] + """ + return self._call_impl(data) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Args: + x: (tensor) input shape [N, C] + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x diff --git a/vis4d/op/layer/ms_deform_attn.py b/vis4d/op/layer/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..cd91ec0e4c84774b7cbdc2bbd0e0f1acac7acfb3 --- /dev/null +++ b/vis4d/op/layer/ms_deform_attn.py @@ -0,0 +1,563 @@ +# pylint: disable=no-name-in-module, abstract-method, arguments-differ +"""Multi-Scale Deformable Attention Module. + +Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py) # pylint: disable=line-too-long +""" +from __future__ import annotations + +import math + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.init import constant_, xavier_uniform_ + +from vis4d.common.imports import VIS4D_CUDA_OPS_AVAILABLE +from vis4d.common.logging import rank_zero_warn + +if VIS4D_CUDA_OPS_AVAILABLE: + from vis4d_cuda_ops import ms_deform_attn_backward, ms_deform_attn_forward +else: + raise ImportError("vis4d_cuda_ops is not installed.") + + +class MSDeformAttentionFunction(Function): # pragma: no cover + """Multi-Scale Deformable Attention Function module.""" + + @staticmethod + def forward( # type: ignore + ctx, + value: Tensor, + value_spatial_shapes: Tensor, + value_level_start_index: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + im2col_step: int, + ) -> Tensor: + """Forward pass.""" + if not VIS4D_CUDA_OPS_AVAILABLE: + raise RuntimeError( + "MSDeformAttentionFunction requires vis4d cuda ops to run." + ) + ctx.im2col_step = im2col_step + output = ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ctx.im2col_step, + ) + ctx.save_for_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) + return output + + @staticmethod + @once_differentiable # type: ignore + def backward( # type: ignore + ctx, grad_output: Tensor + ) -> tuple[Tensor, None, None, Tensor, Tensor, None]: + """Backward pass.""" + if not VIS4D_CUDA_OPS_AVAILABLE: + raise RuntimeError( + "MSDeformAttentionFunction requires vis4d cuda ops to run." + ) + ( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) = ctx.saved_tensors + ( + grad_value, + grad_sampling_loc, + grad_attn_weight, + ) = ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step, + ) + + return ( + grad_value, + None, + None, + grad_sampling_loc, + grad_attn_weight, + None, + ) + + +def ms_deformable_attention_cpu( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, +) -> Tensor: + """CPU version of multi-scale deformable attention. + + Args: + value (Tensor): The value has shape (bs, num_keys, mum_heads, + embed_dims // num_heads) + value_spatial_shapes (Tensor): Spatial shape of each feature map, has + shape (num_levels, 2), last dimension 2 represent (h, w). + sampling_locations (Tensor): The location of sampling points, has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), the last + dimension 2 represent (x, y). + attention_weights (Tensor): The weight of sampling points used when + calculate the attention, has shape (bs ,num_queries, num_heads, + num_levels, num_points), + + Returns: + Tensor: has shape (bs, num_queries, embed_dims). + """ + bs, _, num_heads, embed_dims = value.shape + ( + _, + num_queries, + num_heads, + num_levels, + num_points, + _, + ) = sampling_locations.shape + value_list = value.split([h * w for h, w in value_spatial_shapes], dim=1) + sampling_grids: Tensor = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (h, w) in enumerate(value_spatial_shapes): + # bs, h*w, num_heads, embed_dims -> + # bs, h*w, num_heads*embed_dims -> + # bs, num_heads*embed_dims, h*w -> + # bs*num_heads, embed_dims, h, w + value_l_ = ( + value_list[level] + .flatten(2) + .transpose(1, 2) + .reshape(bs * num_heads, embed_dims, h, w) + ) + # bs, num_queries, num_heads, num_points, 2 -> + # bs, num_heads, num_queries, num_points, 2 -> + # bs*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = ( + sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) + ) + # bs*num_heads, embed_dims, num_queries, num_points + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ) + sampling_value_list.append(sampling_value_l_) + # (bs, num_queries, num_heads, num_levels, num_points) -> + # (bs, num_heads, num_queries, num_levels, num_points) -> + # (bs, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + ( + torch.stack(sampling_value_list, dim=-2).flatten(-2) + * attention_weights + ) + .sum(-1) + .view(bs, num_heads * embed_dims, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +def is_power_of_2(number: int) -> None: + """Check if a number is a power of 2.""" + if (not isinstance(number, int)) or (number < 0): + raise ValueError( + f"invalid input for is_power_of_2: {number} (type: {type(number)})" + ) + if not ((number & (number - 1) == 0) and number != 0): + rank_zero_warn( + "You'd better set hidden dimensions in MultiScaleDeformAttention" + "to make the dimension of each attention head a power of 2, " + "which is more efficient in our CUDA implementation." + ) + + +class MSDeformAttention(nn.Module): + """Multi-Scale Deformable Attention Module. + + This is the original implementation from Deformable DETR. + """ + + def __init__( + self, + d_model: int = 256, + n_levels: int = 4, + n_heads: int = 8, + n_points: int = 4, + im2col_step: int = 64, + ) -> None: + """Creates an instance of the class. + + Args: + d_model (int): Hidden dimensions. + n_levels (int): Number of feature levels. + n_heads (int): Number of attention heads. + n_points (int): Number of sampling points per attention head per + feature level. + im2col_step (int): The step used in image_to_column. Default: 64. + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError( + "d_model must be divisible by n_heads, but got " + + f"{d_model} and {n_heads}." + ) + + is_power_of_2(d_model // n_heads) + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + self.im2col_step = im2col_step + + self.sampling_offsets = nn.Linear( + d_model, n_heads * n_levels * n_points * 2 + ) + self.attention_weights = nn.Linear( + d_model, n_heads * n_levels * n_points + ) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + """Reset parameters.""" + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.mul( + torch.arange(self.n_heads, dtype=torch.float32), + (2.0 * math.pi / self.n_heads), + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query: Tensor, + reference_points: Tensor, + input_flatten: Tensor, + input_spatial_shapes: Tensor, + input_level_start_index: Tensor, + input_padding_mask: Tensor | None = None, + ) -> Tensor: + r"""Forward function. + + Args: + query (Tensor): (n, length_{query}, C). + reference_points (Tensor): (n, length_{query}, n_levels, 2), + range in [0, 1], top-left (0,0), bottom-right (1, 1), including + padding area or (n, length_{query}, n_levels, 4), add + additional (w, h) to form reference boxes. + input_flatten (Tensor): (n, \sum_{l=0}^{L-1} H_l \cdot W_l, C). + input_spatial_shapes (Tensor): (n_levels, 2), [(H_0, W_0), + (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + input_level_start_index (Tensor): (n_levels, ), [0, H_0*W_0, + H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., + H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + input_padding_mask (Tensor): (n, \sum_{l=0}^{L-1} H_l \cdot W_l), + True for padding elements, False for non-padding elements. + + Retrun + output (Tensor): (n, length_{query}, C). + """ + n, len_q, _ = query.shape + n, len_in, _ = input_flatten.shape + assert ( + input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] + ).sum() == len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view( + n, len_in, self.n_heads, self.d_model // self.n_heads + ) + sampling_offsets = self.sampling_offsets(query).view( + n, len_q, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(query).view( + n, len_q, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + n, len_q, self.n_heads, self.n_levels, self.n_points + ) + # n, len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], + -1, + ) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets + / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / self.n_points + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, " + + f"but get {reference_points.shape[-1]} instead." + ) + + if torch.cuda.is_available() and value.is_cuda: + output = MSDeformAttentionFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + else: + output = ms_deformable_attention_cpu( + value, + input_spatial_shapes, + sampling_locations, + attention_weights, + ) + + output = self.output_proj(output) + + return output + + def __call__( + self, + query: Tensor, + reference_points: Tensor, + input_flatten: Tensor, + input_spatial_shapes: Tensor, + input_level_start_index: Tensor, + input_padding_mask: Tensor | None = None, + ) -> Tensor: + """Type definition for call implementation.""" + return self._call_impl( + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask, + ) + + +class MultiScaleDeformableAttention(nn.Module): + """A wrapper for ``MSDeformAttention``. + + This module implements MSDeformAttention with identity connection, + and positional encoding is also passed as input. + """ + + def __init__( + self, + embed_dims: int = 256, + num_heads: int = 8, + num_levels: int = 4, + num_points: int = 4, + im2col_step: int = 64, + dropout: float = 0.0, + ) -> None: + """Init.""" + super().__init__() + if embed_dims % num_heads != 0: + raise ValueError( + "embed_dims must be divisible by num_heads, but got " + + f"{embed_dims} and {num_heads}." + ) + + is_power_of_2(embed_dims // num_heads) + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_levels = num_levels + self.num_points = num_points + self.im2col_step = im2col_step + + self.sampling_offsets = nn.Linear( + embed_dims, num_heads * num_levels * num_points * 2 + ) + self.attention_weights = nn.Linear( + embed_dims, num_heads * num_levels * num_points + ) + self.value_proj = nn.Linear(embed_dims, embed_dims) + self.output_proj = nn.Linear(embed_dims, embed_dims) + + self.dropout = nn.Dropout(dropout) + + self._init_weights() + + def _init_weights(self) -> None: + """Initialize weights.""" + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.mul( + torch.arange(self.num_heads, dtype=torch.float32), + (2.0 * math.pi / self.num_heads), + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.num_heads, 1, 1, 2) + .repeat(1, self.num_levels, self.num_points, 1) + ) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query: Tensor, + reference_points: Tensor, + input_flatten: Tensor, + input_spatial_shapes: Tensor, + input_level_start_index: Tensor, + query_pos: Tensor | None = None, + identity: Tensor | None = None, + input_padding_mask: Tensor | None = None, + ) -> Tensor: + r"""Forward function. + + Args: + query (Tensor): The input query with shape [bs, num_queries, + embed_dims]. + reference_points (Tensor): (bs, num_queries, num_levels, 2), + range in [0, 1], top-left (0,0), bottom-right (1, 1), including + padding area or (bs, num_queries, num_levels, 4), add + additional (w, h) to form reference boxes. + input_flatten (Tensor): (bs, \sum_{l=0}^{L-1} H_l \cdot W_l, C). + input_spatial_shapes (Tensor): (num_levels, 2), [(H_0, W_0), + (H_1, W_1), ..., (H_{L-1}, W_{L-1})]. + input_level_start_index (Tensor): (num_levels, ), [0, H_0*W_0, + H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., + H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]. + query_pos (Tensor | None): The positional encoding for query, with + the same shape as `query`. If not None, it will + be added to `query` before forward function. Defaults to None. + identity (Tensor | None): With the same shape as query, it will be + used for the identity link. If None, `query` will be used. + Defaults to None. + input_padding_mask (Tensor): (bs, \sum_{l=0}^{L-1} H_l \cdot W_l), + True for padding elements, False for non-padding elements. + + Returns + output (Tensor): (bs, num_queries, C). + """ + if identity is None: + identity = query + + if query_pos is not None: + query = query + query_pos + + n, len_q, _ = query.shape + n, len_in, _ = input_flatten.shape + assert ( + input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] + ).sum() == len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view( + n, len_in, self.num_heads, self.embed_dims // self.num_heads + ) + sampling_offsets = self.sampling_offsets(query).view( + n, len_q, self.num_heads, self.num_levels, self.num_points, 2 + ) + attention_weights = self.attention_weights(query).view( + n, len_q, self.num_heads, self.num_levels * self.num_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + n, len_q, self.num_heads, self.num_levels, self.num_points + ) + # n, len_q, num_heads, num_levels, num_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], + -1, + ) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets + / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets + / self.num_points + * reference_points[:, :, None, :, None, 2:] + * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, " + + f"but get {reference_points.shape[-1]} instead." + ) + + if torch.cuda.is_available() and value.is_cuda: + output = MSDeformAttentionFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + else: + output = ms_deformable_attention_cpu( + value, + input_spatial_shapes, + sampling_locations, + attention_weights, + ) + + output = self.output_proj(output) + + return self.dropout(output) + identity diff --git a/vis4d/op/layer/patch_embed.py b/vis4d/op/layer/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..ba775a63cd4f030ef6d2c1feb6a73977d0b45b3d --- /dev/null +++ b/vis4d/op/layer/patch_embed.py @@ -0,0 +1,91 @@ +"""Image to Patch Embedding using Conv2d. + +Modified from vision_transformer +(https://github.com/google-research/vision_transformer). +""" + +from __future__ import annotations + +import torch +from torch import nn + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding.""" + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + embed_dim: int = 768, + norm_layer: nn.Module | None = None, + flatten: bool = True, + bias: bool = True, + ): + """Init PatchEmbed. + + Args: + img_size (int, optional): Input image's size. Defaults to 224. + patch_size (int, optional): Patch size. Defaults to 16. + in_channels (int, optional): Number of input image's channels. + Defaults to 3. + embed_dim (int, optional): Patch embedding's dim. Defaults to 768. + norm_layer (nn.Module, optional): Normalization layer. Defaults to + None, which means no normalization layer. + flatten (bool, optional): If to flatten the output tensor. + Defaults to True. + bias (bool, optional): If to add bias to the convolution layer. + Defaults to True. + + Raises: + ValueError: If the input image's size is not divisible by the patch + size. + """ + super().__init__() + self.img_size = (img_size, img_size) + self.patch_size = (patch_size, patch_size) + self.grid_size = ( + self.img_size[0] // self.patch_size[0], + self.img_size[1] // self.patch_size[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def __call__(self, data: torch.Tensor) -> torch.Tensor: + """Applies the layer. + + Args: + data (torch.Tensor): Input tensor of shape (B, C, H, W). + + Returns: + torch.Tensor: Output tensor of shape (B, N, C), where N is the + number of patches (N = H * W). + """ + return self._call_impl(data) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + _, _, height, width = x.shape + assert height == self.img_size[0], ( + f"Input image height ({height}) doesn't match model" + f"({self.img_size})." + ) + assert width == self.img_size[1], ( + f"Input image width ({width}) doesn't match model" + f"({self.img_size})." + ) + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # (B, C, H, W) -> (B, N, C) + x = self.norm(x) + return x diff --git a/vis4d/op/layer/positional_encoding.py b/vis4d/op/layer/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..37ccfa7c0e5be18feec72d9d4cbda8b9d7b5ac9e --- /dev/null +++ b/vis4d/op/layer/positional_encoding.py @@ -0,0 +1,192 @@ +"""Positional encoding for transformer. + +Modified from mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import math + +import torch +from torch import Tensor, nn + +from .weight_init import uniform_init + + +class SinePositionalEncoding(nn.Module): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + `_ for details. + """ + + def __init__( + self, + num_feats: int, + temperature: int = 10000, + normalize: bool = False, + scale: float = 2 * math.pi, + eps: float = 1e-6, + offset: float = 0.0, + ) -> None: + """Initialization for `SinePositionalEncoding`. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when normalize is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float, optional): offset add to embed when do the + normalization. Defaults to 0. + """ + super().__init__() + if normalize: + assert isinstance(scale, (float, int)), ( + "when normalize is set," + "scale should be provided and in float or int type, " + f"found {type(scale)}" + ) + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + self.offset = offset + + def forward( + self, mask: Tensor | None, inputs: Tensor | None = None + ) -> Tensor: + """Forward function for `SinePositionalEncoding`. + + Args: + mask (Tensor | None): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. If None, it means single + image or batch image with no padding. + inputs (Tensor | None): The input tensor. It mask is None, this + input tensor is required to get the shape of the input image. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + if mask is not None: + # For convenience of exporting to ONNX, it's required to convert + # `masks` from bool to int. + mask = mask.to(torch.int) + b, h, w = mask.size() + device = mask.device + not_mask = 1 - mask # logical_not + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + else: + # single image or batch image with no padding + assert isinstance(inputs, Tensor) + b, _, h, w = inputs.shape + device = inputs.device + x_embed = torch.arange( + 1, w + 1, dtype=torch.float32, device=device + ) + x_embed = x_embed.view(1, 1, -1).repeat(b, h, 1) + y_embed = torch.arange( + 1, h + 1, dtype=torch.float32, device=device + ) + y_embed = y_embed.view(1, -1, 1).repeat(b, 1, w) + if self.normalize: + y_embed = ( + (y_embed + self.offset) + / (y_embed[:, -1:, :] + self.eps) + * self.scale + ) + x_embed = ( + (x_embed + self.offset) + / (x_embed[:, :, -1:] + self.eps) + * self.scale + ) + dim_t = torch.arange( + self.num_feats, dtype=torch.float32, device=device + ) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + # use `view` instead of `flatten` for dynamically exporting to ONNX + + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).view(b, h, w, -1) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).view(b, h, w, -1) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class LearnedPositionalEncoding(nn.Module): + """Position embedding with learnable embedding weights.""" + + def __init__( + self, num_feats: int, row_num_embed: int = 50, col_num_embed: int = 50 + ) -> None: + """Initialization for LearnedPositionalEncoding. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row + embeddings. Defaults to 50. + col_num_embed (int, optional): The dictionary size of col + embeddings. Defaults to 50. + """ + super().__init__() + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + + self.init_weights() + + def init_weights(self) -> None: + """Initialize the weights of position embedding.""" + uniform_init(self.row_embed, lower=0, upper=1) + uniform_init(self.col_embed, lower=0, upper=1) + + def forward(self, mask: Tensor) -> Tensor: + """Forward function for `LearnedPositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = ( + torch.cat( + ( + x_embed.unsqueeze(0).repeat(h, 1, 1), + y_embed.unsqueeze(1).repeat(1, w, 1), + ), + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(mask.shape[0], 1, 1, 1) + ) + return pos diff --git a/vis4d/op/layer/transformer.py b/vis4d/op/layer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2600ade3df5b1d0fe085b452cc088c41d0c49616 --- /dev/null +++ b/vis4d/op/layer/transformer.py @@ -0,0 +1,255 @@ +"""Transformer layer. + +Modified from timm (https://github.com/huggingface/pytorch-image-models) and +mmdetection (https://github.com/open-mmlab/mmdetection). +""" + +from __future__ import annotations + +import copy + +import torch +from torch import Tensor, nn + +from .attention import Attention +from .drop import DropPath +from .mlp import TransformerBlockMLP +from .util import build_activation_layer + + +def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor: + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the inverse. + eps (float): EPS avoid numerical overflow. Defaults 1e-5. + + Returns: + Tensor: The x has passed the inverse function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def get_clones(module: nn.Module, num: int) -> nn.ModuleList: + """Create N identical layers.""" + return nn.ModuleList([copy.deepcopy(module) for _ in range(num)]) + + +class LayerScale(nn.Module): + """Layer scaler.""" + + def __init__( + self, + dim: int, + inplace: bool = False, + data_format: str = "channels_last", + init_values: float = 1e-5, + ): + """Init layer scaler. + + Args: + dim (int): Input tensor's dimension. + inplace (bool): Whether performs operation in-place. Default: + False. + data_format (str): The input data format, could be 'channels_last' + or 'channels_first', representing (B, C, H, W) and (B, N, C) + format data respectively. Default: channels_last. + init_values (float, optional): Initial values for layer scale. + Defaults to 1e-5. + """ + super().__init__() + assert data_format in { + "channels_last", + "channels_first", + }, "data_format could only be channels_last or channels_first." + self.inplace = inplace + self.data_format = data_format + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + if self.data_format == "channels_first": + shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2)))) + else: + shape = tuple((*(1 for _ in range(x.dim() - 1)), -1)) + + if self.inplace: + return x.mul_(self.gamma.view(*shape)) + + return x * self.gamma.view(*shape) + + +class TransformerBlock(nn.Module): + """Transformer block for Vision Transformer.""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values: float | None = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU(), + norm_layer: nn.Module | None = None, + ): + """Init transformer block. + + Args: + dim (int): Input tensor's dimension. + num_heads (int): Number of attention heads. + mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding + dim. Defaults to 4.0. + qkv_bias (bool, optional): If to add bias to qkv. Defaults to + False. + drop (float, optional): Dropout rate for attention and projection. + Defaults to 0.0. + attn_drop (float, optional): Dropout rate for attention. Defaults + to 0.0. + init_values (tuple[float, float] | None, optional): Initial values + for layer scale. Defaults to None. + drop_path (float, optional): Dropout rate for drop path. Defaults + to 0.0. + act_layer (nn.Module, optional): Activation layer. Defaults to + nn.GELU. + norm_layer (nn.Module, optional): Normalization layer. If None, use + nn.LayerNorm. + """ + super().__init__() + self.norm1 = ( + norm_layer(dim) if norm_layer else nn.LayerNorm(dim, eps=1e-6) + ) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) + if init_values + else nn.Identity() + ) + self.drop_path1 = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + + self.norm2 = ( + norm_layer(dim) if norm_layer else nn.LayerNorm(dim, eps=1e-6) + ) + self.mlp = TransformerBlockMLP( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) + if init_values + else nn.Identity() + ) + self.drop_path2 = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + ) + + def __call__(self, data: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + data (torch.Tensor): Input tensor of shape (B, N, dim). + + Returns: + torch.Tensor: Output tensor of shape (B, N, dim). + """ + return self._call_impl(data) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class FFN(nn.Module): + """Implements feed-forward networks (FFNs) with identity connection.""" + + def __init__( + self, + embed_dims: int = 256, + feedforward_channels: int = 1024, + num_fcs: int = 2, + dropout: float = 0.0, + activation: str = "ReLU", + inplace: bool = True, + dropout_layer: nn.Module | None = None, + add_identity: bool = True, + layer_scale_init_value: float = 0.0, + ) -> None: + """Init FFN. + + Args: + embed_dims (int): The feature dimension. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int): The number of fully-connected layers in FFNs. + Defaults: 2. + dropout (float): The dropout rate of FFNs. + activation (str): The activation function of FFNs. + inplace (bool): Whether to set inplace for activation. + dropout_layer (nn.Module | None, optional): The dropout_layer used + when adding the shortcut. Defaults to None. If None, Identity + is used. + add_identity (bool, optional): Whether to add the identity + connection. Default: True. + layer_scale_init_value (float): Initial value of scale factor in + LayerScale. Default: 0.0 + """ + super().__init__() + self.embed_dims = embed_dims + + layers: list[nn.Module] = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + nn.Sequential( + nn.Linear(in_channels, feedforward_channels), + build_activation_layer(activation, inplace), + nn.Dropout(dropout), + ) + ) + in_channels = feedforward_channels + layers.append(nn.Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(dropout)) + self.layers = nn.Sequential(*layers) + + self.dropout_layer = dropout_layer or nn.Identity() + self.add_identity = add_identity + self.layer_scale_init_value = layer_scale_init_value + + if self.layer_scale_init_value > 0: + self.gamma2 = LayerScale( + embed_dims, init_values=self.layer_scale_init_value + ) + + def forward(self, x: Tensor, identity: Tensor | None = None) -> None: + """Forward function for FFN. + + The function would add x to the output tensor if residue is None. + """ + out = self.layers(x) + + if self.layer_scale_init_value > 0: + out = self.gamma2(out) + + if self.add_identity: + identity = x if identity is None else identity + return identity + self.dropout_layer(out) + + return self.dropout_layer(out) diff --git a/vis4d/op/layer/util.py b/vis4d/op/layer/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6c7488ebd8bc0ae60ca1e06b243ce5e5ed0433 --- /dev/null +++ b/vis4d/op/layer/util.py @@ -0,0 +1,89 @@ +"""Utility functions for layer ops.""" + +from __future__ import annotations + +from torch import nn + +from .conv2d import Conv2d +from .deform_conv import DeformConv + + +def build_conv_layer( + in_planes: int, + out_planes: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = False, + norm: nn.Module | None = None, + activation: nn.Module | None = None, + use_dcn: bool = False, +) -> nn.Module: + """Build a convolution layer.""" + if use_dcn: + return DeformConv( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + norm=norm, + activation=activation, + ) + + return Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + norm=norm, + activation=activation, + ) + + +def build_activation_layer( + activation: str, inplace: bool = False +) -> nn.Module: + """Build activation layer. + + Args: + activation (str): Activation layer type. + inplace (bool, optional): If to set inplace. Defaults to False. It will + be ignored if the activation layer is not inplace. + """ + activation_layer = getattr(nn, activation) + + if activation_layer in {nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU}: + return activation_layer() + + return activation_layer(inplace=inplace) + + +def build_norm_layer( + norm: str, out_channels: int, num_groups: int | None = None +) -> nn.Module: + """Build normalization layer. + + Args: + norm (str): Normalization layer type. + out_channels (int): Number of output channels. + num_groups (int | None, optional): Number of groups for GroupNorm. + Defaults to None. + """ + norm_layer = getattr(nn, norm) + if norm_layer == nn.GroupNorm: + assert ( + num_groups is not None + ), "num_groups must be specified when using Group Norm" + return norm_layer(num_groups, out_channels) + + return norm_layer(out_channels) diff --git a/vis4d/op/layer/weight_init.py b/vis4d/op/layer/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..5e58e24e0d2400cadd2b733cca6617075a120b49 --- /dev/null +++ b/vis4d/op/layer/weight_init.py @@ -0,0 +1,120 @@ +"""Model weight initialization.""" + +from typing import Literal + +import numpy as np +from torch import nn + +NonlinearityType = Literal[ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "sigmoid", + "tanh", + "relu", + "leaky_relu", + "selu", +] +FanMode = Literal["fan_in", "fan_out"] + + +def constant_init(module: nn.Module, val: float, bias: float = 0.0) -> None: + """Initialize module with constant value.""" + if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter): + nn.init.constant_(module.weight, val) + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + nn.init.constant_(module.bias, bias) + + +def xavier_init( + module: nn.Module, + gain: float = 1.0, + bias: float = 0.0, + distribution: str = "normal", +) -> None: + """Initialize module with Xavier initialization.""" + assert distribution in {"uniform", "normal"} + if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter): + if distribution == "uniform": + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + nn.init.constant_(module.bias, bias) + + +def kaiming_init( + module: nn.Module, + negative_slope: float = 0.0, + mode: FanMode = "fan_out", + nonlinearity: NonlinearityType = "relu", + bias: float = 0.0, + distribution: str = "normal", +) -> None: + """Initialize module with Kaiming initialization. + + Args: + module (nn.Module): Module to initialize. + negative_slope (float, optional): The negative slope of the rectifier + used after this layer (only used with ``'leaky_relu'``). Defaults + to 0.0. + mode (FanMode, optional): Either `"fan_in"` (default) or `"fan_out"``. + Choosing `"fan_in"` preserves the magnitude of the variance of + the weights in the forward pass. Choosing `"fan_out"` preserves + magnitudes in the backwards pass. Defaults to "fan_out". + nonlinearity (NonlinearityType, optional): The non-linear function + (`nn.functional` name). Defaults to "relu". + bias (float, optional): The bias to use. Defaults to 0.0. + distribution (str, optional): Either ``'uniform'`` or ``'normal'``. + Defaults to "normal". + """ + assert distribution in {"uniform", "normal"} + if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter): + if distribution == "uniform": + nn.init.kaiming_uniform_( + module.weight, + a=negative_slope, + mode=mode, + nonlinearity=nonlinearity, + ) + else: + nn.init.kaiming_normal_( + module.weight, + a=negative_slope, + mode=mode, + nonlinearity=nonlinearity, + ) + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + nn.init.constant_(module.bias, bias) + + +def normal_init( + module: nn.Module, mean: float = 0.0, std: float = 1.0, bias: float = 0 +) -> None: + """Initialize module with normal distribution.""" + if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter): + nn.init.normal_(module.weight, mean, std) + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + nn.init.constant_(module.bias, bias) + + +def bias_init_with_prob(prior_prob: float) -> float: + """Initialize conv/fc bias value according to a given probability value.""" + return float(-np.log((1 - prior_prob) / prior_prob)) + + +def uniform_init( + module: nn.Module, + lower: float = 0.0, + upper: float = 1.0, + bias: float = 0.0, +) -> None: + """Initialize module with uniform distribution.""" + if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter): + nn.init.uniform_(module.weight, lower, upper) + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): + nn.init.constant_(module.bias, bias) diff --git a/vis4d/op/loss/__init__.py b/vis4d/op/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8846c92abc71540756ffdbdb3b3fc086ecda395 --- /dev/null +++ b/vis4d/op/loss/__init__.py @@ -0,0 +1,23 @@ +"""This module contains commonly used loss functions. + +The losses do not follow a common API, but have a reducer as attribute, +which is a function to aggregate loss values into a single tensor value. +""" + +from .base import Loss +from .embedding_distance import EmbeddingDistanceLoss +from .iou_loss import IoULoss +from .multi_level_seg_loss import MultiLevelSegLoss +from .multi_pos_cross_entropy import MultiPosCrossEntropyLoss +from .orthogonal_transform_loss import OrthogonalTransformRegularizationLoss +from .seg_cross_entropy_loss import SegCrossEntropyLoss + +__all__ = [ + "Loss", + "EmbeddingDistanceLoss", + "IoULoss", + "MultiLevelSegLoss", + "MultiPosCrossEntropyLoss", + "OrthogonalTransformRegularizationLoss", + "SegCrossEntropyLoss", +] diff --git a/vis4d/op/loss/base.py b/vis4d/op/loss/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ecc27a1e87b1d11f82c62ba70439a7afe97767 --- /dev/null +++ b/vis4d/op/loss/base.py @@ -0,0 +1,28 @@ +"""Base class for meta architectures.""" + +import abc + +from torch import nn + +from vis4d.op.loss.reducer import identity_loss + +from .reducer import LossReducer + + +class Loss(nn.Module, abc.ABC): + """Base loss class.""" + + def __init__(self, reducer: LossReducer = identity_loss) -> None: + """Initialize a loss functor. + + Args: + reducer (LossReducer): A function to aggregate the loss values into + a single tensor value. It is commonly used for dense prediction + tasks to merge pixel-wise loss to a final loss. + + Example:: + def mean_loss(loss: torch.Tensor) -> torch.Tensor: + return loss.mean() + """ + super().__init__() + self.reducer = reducer diff --git a/vis4d/op/loss/common.py b/vis4d/op/loss/common.py new file mode 100644 index 0000000000000000000000000000000000000000..991947c61a7b4dfb401240f47d745f60888dfe65 --- /dev/null +++ b/vis4d/op/loss/common.py @@ -0,0 +1,129 @@ +"""Common loss functions.""" + +import torch +import torch.nn.functional as F +from torch import Tensor + +from vis4d.op.loss.reducer import LossReducer, identity_loss + + +def smooth_l1_loss( + pred: Tensor, + target: Tensor, + reducer: LossReducer = identity_loss, + beta: float = 1.0, +) -> Tensor: + """Smooth L1 loss. + + L1 loss that uses a squared term if the absolute element-wise error + falls below beta. + + Args: + pred (Tensor): Model predictions + target (Tensor): Ground truth value + reducer (LossReducer): Reducer to reduce the loss value. Defaults to + identy_loss, which is no reduction. + beta (float): Specifies the threshold at which to change between L1 + and L2 loss. The value must be non-negative. Default: 1.0 + + Returns: + Tensor : The reduced smooth l1 loss: + |pred - target| - 0.5*beta if |pred - target| < 0.5*beta + (pred - target)^2 * 0.5/beta else + """ + assert beta > 0 + assert pred.size() == target.size() and target.numel() > 0 + diff = torch.abs(pred - target) + loss = torch.where( + diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta + ) + return reducer(loss) + + +def l1_loss( + pred: Tensor, target: Tensor, reducer: LossReducer = identity_loss +) -> Tensor: + """L1 loss. + + Args: + pred (Tensor): Model predictions + target (Tensor): Ground truth value + reducer (LossReducer): Reducer to reduce the loss value. Defaults to + identy_loss, which is no reduction. + + Returns: + Tensor : The reduced L1 loss (reduce(|pred - target|)) + """ + assert pred.size() == target.size() and target.numel() > 0 + loss = torch.abs(pred - target) + return reducer(loss) + + +def l2_loss( + pred: Tensor, target: Tensor, reducer: LossReducer = identity_loss +) -> Tensor: + """L2 loss. + + Args: + pred (Tensor): Model predictions + target (Tensor): Ground truth value + reducer (LossReducer): Reducer to reduce the loss value. Defaults to + identy_loss, which is no reduction. + + Returns: + Tensor : The reduced L2 loss (reduce((pred - target)**2)) + """ + assert pred.size() == target.size() and target.numel() > 0 + loss = (pred - target) ** 2 + return reducer(loss) + + +def rotation_loss( + pred: Tensor, + target_bin: Tensor, + target_res: Tensor, + num_bins: int, + reducer: LossReducer = identity_loss, +) -> Tensor: + """Rotation loss. + + Consists of bin-based classification loss and residual-based regression + loss. + + Args: + pred (Tensor): Prediction shape [B, num_bins * 3] + target_bin (Tensor): Target bins shape [B, num_bin] + target_res (Tensor): Target residual shape [B, num_bin] + num_bins (int): Number of bins + reducer (LossReducer, optional): Loss Reducer. + Defaults to identity_loss. + + Returns: + Tensor: The reduced loss value + """ + loss_bins = ( + F.binary_cross_entropy_with_logits( + pred[:, :num_bins], target_bin, reduction="none" + ) + .mean(dim=0) + .sum() + ) + + loss_res = torch.zeros_like(loss_bins) + for i in range(num_bins): + bin_mask = target_bin[:, i] == 1 + res_idx = num_bins + 2 * i + if bin_mask.any(): + loss_sin = smooth_l1_loss( + pred[bin_mask, res_idx], + torch.sin(target_res[bin_mask, i]), + reducer=reducer, + ) + loss_cos = smooth_l1_loss( + pred[bin_mask, res_idx + 1], + torch.cos(target_res[bin_mask, i]), + reducer=reducer, + ) + loss_res += loss_sin + loss_cos + + return loss_bins + loss_res diff --git a/vis4d/op/loss/cross_entropy.py b/vis4d/op/loss/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..e51ead860b2207d17a2c89a5cb1c45db128c8150 --- /dev/null +++ b/vis4d/op/loss/cross_entropy.py @@ -0,0 +1,89 @@ +"""Cross entropy loss.""" + +from __future__ import annotations + +import torch.nn.functional as F +from torch import Tensor + +from .base import Loss +from .reducer import LossReducer, mean_loss + + +class CrossEntropyLoss(Loss): + """Cross entropy loss class.""" + + def __init__( + self, + reducer: LossReducer = mean_loss, + class_weights: list[float] | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + reducer (LossReducer): Reducer for the loss function. Defaults to + mean_loss. + class_weights (list[float], optional): Class weights for the loss + function. Defaults to None. + """ + super().__init__(reducer) + self.class_weights = class_weights + + def forward( + self, + output: Tensor, + target: Tensor, + reducer: LossReducer | None = None, + ignore_index: int = 255, + ) -> Tensor: + """Forward pass. + + Args: + output (list[Tensor]): Model output. + target (Tensor): Assigned segmentation target mask. + reducer (LossReducer, optional): Reducer for the loss function. + Defaults to None. + ignore_index (int): Ignore class id. Default to 255. + + Returns: + Tensor: Computed loss. + """ + if self.class_weights is not None: + class_weights = output.new_tensor( + self.class_weights, device=output.device + ) + else: + class_weights = None + reducer = reducer or self.reducer + + return reducer( + cross_entropy( + output, target, class_weights, ignore_index=ignore_index + ) + ) + + +def cross_entropy( + output: Tensor, + target: Tensor, + class_weights: Tensor | None = None, + ignore_index: int = 255, +) -> Tensor: + """Cross entropy loss function. + + Args: + output (Tensor): Model output. + target (Tensor): Assigned segmentation target mask. + class_weights (Tensor | None, optional): Class weights for the loss + function. Defaults to None. + ignore_index (int): Ignore class id. Default to 255. + + Returns: + Tensor: Computed loss. + """ + return F.cross_entropy( + output, + target.long(), + weight=class_weights, + ignore_index=ignore_index, + reduction="none", + ) diff --git a/vis4d/op/loss/embedding_distance.py b/vis4d/op/loss/embedding_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..9c614541e96c3dba7810810e9386129f59824cf7 --- /dev/null +++ b/vis4d/op/loss/embedding_distance.py @@ -0,0 +1,103 @@ +"""Embedding distance loss.""" + +from __future__ import annotations + +import torch + +from vis4d.op.box.box2d import random_choice + +from .base import Loss +from .common import l2_loss +from .reducer import LossReducer, SumWeightedLoss, identity_loss + + +class EmbeddingDistanceLoss(Loss): + """Embedding distance loss for learning appearance similarity. + + Computes the difference between the target distances and the predicted + distances of two sets of embedding vectors. Uses hard negative mining based + on the loss values to select pairs for overall loss computation. + """ + + def __init__( + self, + reducer: LossReducer = identity_loss, + neg_pos_ub: float = 3.0, + pos_margin: float = 0.0, + neg_margin: float = 0.3, + hard_mining: bool = True, + ): + """Creates an instance of the class.""" + super().__init__(reducer) + self.neg_pos_ub = neg_pos_ub + self.neg_margin = neg_margin + self.pos_margin = pos_margin + self.hard_mining = hard_mining + + def forward( # pylint: disable=arguments-differ + self, + pred: torch.Tensor, + target: torch.Tensor, + weight: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward function. + + Args: + pred (torch.Tensor): The predicted distances between two sets of + predictions. Shape [N, M]. + target (torch.Tensor): The corresponding target distances. Either + zero (different identity) or one (same identity). + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + + Returns: + loss_bbox (torch.Tensor): embedding distance loss. + """ + if weight is None: + weight = target.new_ones(target.size()) + pred, weight, avg_factor = self.update_weight(pred, target, weight) + return l2_loss( + pred, target, reducer=SumWeightedLoss(weight, avg_factor) + ) + + def update_weight( + self, pred: torch.Tensor, target: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Update element-wise loss weights. + + Exclude negatives according to maximum fraction of samples and/or + hard negative mining. + """ + invalid_inds = weight <= 0 + target[invalid_inds] = -1 + pos_inds = torch.eq(target, 1) + neg_inds = torch.eq(target, 0) + + if self.pos_margin > 0: + pred[pos_inds] -= self.pos_margin + if self.neg_margin > 0: + pred[neg_inds] -= self.neg_margin + pred = torch.clamp(pred, min=0, max=1) + + num_pos = max(1, int(torch.eq(target, 1).sum())) + num_neg = int(torch.eq(target, 0).sum()) + if self.neg_pos_ub > 0 and num_neg / num_pos > self.neg_pos_ub: + num_neg = int(num_pos * self.neg_pos_ub) + neg_idx = torch.nonzero(torch.eq(target, 0), as_tuple=False) + + if self.hard_mining: + costs = l2_loss(pred, target)[ + neg_idx[:, 0], neg_idx[:, 1] + ].detach() + neg_idx = neg_idx[costs.topk(num_neg)[1], :] + else: + neg_idx = random_choice(neg_idx, num_neg) + + new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool() + new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True + + invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds) + weight[invalid_neg_inds] = 0 + + avg_factor = torch.greater(weight, 0).sum() + return pred, weight, avg_factor diff --git a/vis4d/op/loss/iou_loss.py b/vis4d/op/loss/iou_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5290d3b1d18010873299de5f0f6e1c964147104a --- /dev/null +++ b/vis4d/op/loss/iou_loss.py @@ -0,0 +1,94 @@ +"""Embedding distance loss.""" + +from __future__ import annotations + +import torch + +from vis4d.op.box.box2d import bbox_iou_aligned + +from .base import Loss +from .reducer import LossReducer, identity_loss + + +def iou_loss( + pred: torch.Tensor, + target: torch.Tensor, + reducer: LossReducer = identity_loss, + mode: str = "log", + eps: float = 1e-6, +) -> torch.Tensor: + """Compute IoU loss. + + Args: + pred (torch.Tensor): Predicted bboxes. + target (torch.Tensor): Target bboxes. + reducer (LossReducer): Reducer to reduce the loss value. Defaults to + identy_loss, which is no reduction. + mode (str, optional): Mode to calculate the loss. Defaults to "log". + eps (float, optional): Epsilon value to avoid division by zero. + + Returns: + torch.Tensor : The reduced IoU loss. + """ + assert mode in { + "linear", + "square", + "log", + }, f"Invalid mode {mode}. Must be one of 'linear', 'square', 'log'." + ious = bbox_iou_aligned(pred, target).clamp(min=eps) + if mode == "linear": + loss = 1 - ious + elif mode == "square": + loss = 1 - ious**2 + else: + loss = -ious.log() + return reducer(loss) + + +class IoULoss(Loss): + """IoU loss. + + Computing the IoU loss between a set of predicted bboxes and target bboxes. + The loss is calculated depending on the mode: + - linear: 1 - IoU + - square: 1 - IoU^2 + - log: -log(IoU) + + Args: + reducer (LossReducer): Reducer to reduce the loss value. Defaults to + identy_loss, which is no reduction. + mode (str, optional): Mode to calculate the loss. Defaults to "log". + eps (float, optional): Epsilon value to avoid division by zero. + """ + + def __init__( + self, + reducer: LossReducer = identity_loss, + mode: str = "log", + eps: float = 1e-6, + ): + """Creates an instance of the class.""" + super().__init__(reducer) + self.mode = mode + self.eps = eps + assert mode in { + "linear", + "square", + "log", + }, f"Invalid mode {mode}. Must be one of 'linear', 'square', 'log'." + + def forward( # pylint: disable=arguments-differ + self, pred: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + """Forward function. + + Args: + pred (torch.Tensor): Predicted bboxes. + target (torch.Tensor): Target bboxes. + + Returns: + torch.Tensor: The reduced IoU loss. + """ + return iou_loss( + pred, target, reducer=self.reducer, mode=self.mode, eps=self.eps + ) diff --git a/vis4d/op/loss/multi_level_seg_loss.py b/vis4d/op/loss/multi_level_seg_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1cddfc9fc330aa201843daec8ac3a5df1a2100 --- /dev/null +++ b/vis4d/op/loss/multi_level_seg_loss.py @@ -0,0 +1,72 @@ +"""Multi-level segmentation loss.""" + +from __future__ import annotations + +import torch +from torch import Tensor + +from vis4d.common.typing import LossesType + +from .base import Loss +from .cross_entropy import cross_entropy +from .reducer import LossReducer, mean_loss + + +class MultiLevelSegLoss(Loss): + """Multi-level segmentation loss class. + + Applies the segmentation loss function to multiple levels of predictions to + provide auxiliary losses for intermediate outputs in addition to the final + output, used in FCN. + """ + + def __init__( + self, + reducer: LossReducer = mean_loss, + feature_idx: tuple[int, ...] = (0,), + weights: list[float] | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + reducer (LossReducer): Reducer for the loss function. Defaults to + mean_loss. + feature_idx (tuple[int]): Indices for the level of features to + compute losses. Defaults to (0,). + weights (list[float], optional): The weights of each feature level. + If None passes, it will set to 1 for all levels. Defaults to + None. + """ + super().__init__(reducer) + self.feature_idx = feature_idx + if weights is None: + self.weights = [1.0] * len(self.feature_idx) + else: + self.weights = weights + + def forward( + self, outputs: list[Tensor], target: Tensor, ignore_index: int = 255 + ) -> LossesType: + """Forward pass. + + Args: + outputs (list[Tensor]): Multi-level outputs. + target (Tensor): Assigned segmentation target mask. + ignore_index (int): Ignore class id. Default to 255. + + Returns: + LossesType: Computed losses for each level. + """ + losses: LossesType = {} + tgt_h, tgt_w = target.shape[-2:] + for i, idx in enumerate(self.feature_idx): + loss = self.reducer( + cross_entropy( + outputs[idx][:, :, :tgt_h, :tgt_w], + target, + ignore_index=ignore_index, + ) + ) + losses[f"loss_seg_level{idx}"] = torch.mul(self.weights[i], loss) + + return losses diff --git a/vis4d/op/loss/multi_pos_cross_entropy.py b/vis4d/op/loss/multi_pos_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ab48f087fcca3dba1e62b701f9393be11b9f01 --- /dev/null +++ b/vis4d/op/loss/multi_pos_cross_entropy.py @@ -0,0 +1,60 @@ +"""Multi-positive cross entropy loss.""" + +import torch +from torch import Tensor + +from .base import Loss +from .reducer import LossReducer, SumWeightedLoss + + +class MultiPosCrossEntropyLoss(Loss): + """Multi-positive cross entropy loss. + + Used for appearance similiary learning in QDTrack. + """ + + def forward( + self, + pred: Tensor, + target: Tensor, + weight: Tensor, + avg_factor: float, + ) -> Tensor: + """Multi-positive cross entropy loss. + + Args: + pred (Tensor): Similarity scores before softmax. Shape [N, M] + target (Tensor): Target for each pair. Either one, meaning + same identity or zero, meaning different identity. Shape [N, M] + weight (Tensor): The weight of loss for each prediction. + avg_factor (float): Averaging factor for the loss. + + Returns: + Tensor: Scalar loss value. + """ + return multi_pos_cross_entropy( + pred, target, reducer=SumWeightedLoss(weight, avg_factor) + ) + + +def multi_pos_cross_entropy( + pred: Tensor, target: Tensor, reducer: LossReducer +) -> Tensor: + """Calculate multi-positive cross-entropy loss.""" + pos_inds = torch.eq(target, 1) + neg_inds = torch.eq(target, 0) + pred_pos = pred * pos_inds.float() + pred_neg = pred * neg_inds.float() + # use -inf to mask out unwanted elements. + pred_pos[neg_inds] = pred_pos[neg_inds] + float("inf") + pred_neg[pos_inds] = pred_neg[pos_inds] + float("-inf") + + _pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1) + _neg_expand = pred_neg.repeat(1, pred.shape[1]) + + x = torch.nn.functional.pad( # pylint: disable=not-callable + (_neg_expand - _pos_expand), (0, 1), "constant", 0 + ) + loss = torch.logsumexp(x, dim=1) + + return reducer(loss) diff --git a/vis4d/op/loss/orthogonal_transform_loss.py b/vis4d/op/loss/orthogonal_transform_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c97dc2718b0ed03ce9fb4d37eed5dd958fddbc0a --- /dev/null +++ b/vis4d/op/loss/orthogonal_transform_loss.py @@ -0,0 +1,61 @@ +"""Orthogonal Transform Loss.""" + +from __future__ import annotations + +import torch + +from .base import Loss + + +class OrthogonalTransformRegularizationLoss(Loss): + """Loss that punishes linear transformations that are not orthogonal. + + Calculates difference of X'*X and identity matrix using norm( X'*X - I) + """ + + def __call___(self, transforms: list[torch.Tensor]) -> torch.Tensor: + """Calculates the loss. + + Calculates difference of X'*X and the identity matrix using + norm(X'*X - I) for each transformation + + Args: + transforms: (list(torch.tensor)) list with transformation matrices + batched ([N, 3, 3], [N, x, x], ....) + + Returns: + torch.Tensor containing the mean loss value (mean(norm(X'*X - I))) + """ + return self._call_impl(transforms) + + def forward(self, transforms: list[torch.Tensor]) -> torch.Tensor: + """Calculates the loss. + + Calculates difference of X'*X and the identity matrix using + norm(X'*X - I) for each transformation + + Args: + transforms: (list(torch.tensor)) list with transformation matrices + batched ([N, 3, 3], [N, x, x], ....) + + Returns: + torch.Tensor containing the mean loss value (mean(norm(X'*X - I))) + """ + loss = torch.tensor(0.0) + for trans in transforms: + d = trans.size()[1] + + try: + identity = self.get_buffer(f"identity_{d}") + except AttributeError as _: + # Create identity buffers if not yet allocated + identity = torch.eye(d, device=trans.device) + self.register_buffer(f"identity_{d}", identity) + + loss += torch.mean( + torch.norm( + torch.bmm(trans, trans.transpose(2, 1)) - identity, + dim=(1, 2), + ) + ) + return loss diff --git a/vis4d/op/loss/reducer.py b/vis4d/op/loss/reducer.py new file mode 100644 index 0000000000000000000000000000000000000000..05a47c59631f75f2254ce97679ec6d9aa04083c0 --- /dev/null +++ b/vis4d/op/loss/reducer.py @@ -0,0 +1,69 @@ +"""Definitions of loss reducers. + +Loss reducers are usually used as the last step in loss computation to average +or sum the loss maps from dense predictions or object detections. +""" + +from __future__ import annotations + +from typing import Callable + +from torch import Tensor + +LossReducer = Callable[[Tensor], Tensor] + + +def identity_loss(loss: Tensor) -> Tensor: + """Make no change to the loss.""" + return loss + + +def mean_loss(loss: Tensor) -> Tensor: + """Average the loss tensor values to a single value. + + Args: + loss (Tensor): Input multi-dimentional tensor. + + Returns: + Tensor: Tensor containing a single loss value. + """ + return loss.mean() + + +def sum_loss(loss: Tensor) -> Tensor: + """Sum the loss tensor values to a single value. + + Args: + loss (Tensor): Input multi-dimentional tensor. + + Returns: + Tensor: Tensor containing a single loss value. + """ + return loss.sum() + + +class SumWeightedLoss: + """A loss reducer to calculated weighted sum loss.""" + + def __init__( + self, weight: float | Tensor, avg_factor: float | Tensor + ) -> None: + """Initialize the loss reducer. + + Args: + weight (float | Tensor): Weights for each loss elements + avg_factor (float | Tensor): average factor for the weighted loss + """ + self.weight = weight + self.avg_factor = avg_factor + + def __call__(self, loss: Tensor) -> Tensor: + """Weight the loss elements and take the sum with the average factor. + + Args: + loss (Tensor): input loss + + Returns: + Tensor: output loss + """ + return (loss * self.weight).sum() / self.avg_factor diff --git a/vis4d/op/loss/seg_cross_entropy_loss.py b/vis4d/op/loss/seg_cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..069affb8941933133140cd372ec2d5b5d2ed6d53 --- /dev/null +++ b/vis4d/op/loss/seg_cross_entropy_loss.py @@ -0,0 +1,50 @@ +"""Segmentation cross entropy loss.""" + +from __future__ import annotations + +from torch import Tensor + +from vis4d.common.typing import LossesType + +from .base import Loss +from .cross_entropy import cross_entropy +from .reducer import LossReducer, mean_loss + + +class SegCrossEntropyLoss(Loss): + """Segmentation cross entropy loss class. + + Wrapper for nn.CrossEntropyLoss that additionally clips the output to the + target size and converts the target mask tensor to long. + """ + + def __init__(self, reducer: LossReducer = mean_loss) -> None: + """Creates an instance of the class. + + Args: + reducer (LossReducer): Reducer for the loss function. Defaults to + mean_loss. + """ + super().__init__(reducer) + + def forward( + self, output: Tensor, target: Tensor, ignore_index: int = 255 + ) -> LossesType: + """Forward pass. + + Args: + output (list[Tensor]): Model output. + target (Tensor): Assigned segmentation target mask. + ignore_index (int): Ignore class id. Default to 255. + + Returns: + LossesType: Computed loss. + """ + losses: LossesType = {} + tgt_h, tgt_w = target.shape[-2:] + losses["loss_seg"] = self.reducer( + cross_entropy( + output[:, :, :tgt_h, :tgt_w], target, ignore_index=ignore_index + ) + ) + return losses diff --git a/vis4d/op/mask/__init__.py b/vis4d/op/mask/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5942e862bf69effca1ea4e645cea9c65df72b95b --- /dev/null +++ b/vis4d/op/mask/__init__.py @@ -0,0 +1 @@ +"""Operations on 2D segmentation masks.""" diff --git a/vis4d/op/mask/util.py b/vis4d/op/mask/util.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0b299ca5e78342f916834a437b3977b2480fda --- /dev/null +++ b/vis4d/op/mask/util.py @@ -0,0 +1,283 @@ +"""Utility functions for segmentation masks.""" + +from __future__ import annotations + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + + +def _do_paste_mask( # type: ignore + masks: Tensor, + boxes: Tensor, + img_h: int, + img_w: int, + skip_empty: bool = True, +) -> tuple[Tensor, tuple[slice, slice] | tuple[()]]: + """Paste mask onto image. + + On GPU, paste all masks together (up to chunk size) by using the entire + image to sample the masks Compared to pasting them one by one, this has + more operations but is faster on COCO-scale dataset. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): Masks with shape [N, 1, Hmask, Wmask]. + boxes (Tensor): Boxes with shape [N, 4]. + img_h (int): Image height. + img_w (int): Image width. + skip_empty (bool, optional): Only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. Defaults to True. + + Returns: + Tensor: Mask with shape [N, Himg, Wimg] if skip_empty == True, or + a mask of shape (N, H', W') and the slice object for the + corresponding region if skip_empty == False. + """ + device = masks.device + + if skip_empty: + x0_int, y0_int = torch.clamp( + boxes.min(dim=0).values.floor()[:2] - 1, min=0 + ).to(dtype=torch.int32) + x0_int, y0_int = x0_int.item(), y0_int.item() + x1_int = ( + torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w) + .to(dtype=torch.int32) + .item() + ) + y1_int = ( + torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h) + .to(dtype=torch.int32) + .item() + ) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + num_masks = masks.shape[0] + + img_y: Tensor = ( + torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5 + ) + img_x: Tensor = ( + torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5 + ) + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 # (N, h) + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 # (N, w) + + gx = img_x[:, None, :].expand(num_masks, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(num_masks, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + if not masks.dtype.is_floating_point: + masks = masks.float() + img_masks = F.grid_sample(masks, grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], ( # pylint: disable=unsubscriptable-object + slice(y0_int, y1_int), + slice(x0_int, x1_int), + ) + return img_masks[:, 0], () # pylint: disable=unsubscriptable-object + + +def paste_masks_in_image( + masks: Tensor, + boxes: Tensor, + image_shape: tuple[int, int], + threshold: float = 0.5, + bytes_per_float: int = 4, + gpu_mem_limit: int = 1024**3, +) -> Tensor: + """Paste masks that are of a fixed resolution into an image. + + The location, height, and width for pasting each mask is determined by + their corresponding bounding boxes in boxes. + + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + + Args: + masks (Tensor): Masks with shape [N, Hmask, Wmask], where N is + the number of detected object instances in the image and Hmask, + Wmask are the mask width and mask height of the predicted mask + (e.g., Hmask = Wmask = 28). Values are in [0, 1]. + boxes (Tensor): Boxes with shape [N, 4]. boxes[i] and masks[i] + correspond to the same object instance. + image_shape (tuple[int, int]): Image resolution (width, height). + threshold (float, optional): Threshold for discretization of mask. + Defaults to 0.5. + bytes_per_float (int, optional): Number of bytes per float. Defaults to + 4. + gpu_mem_limit (int, optional): GPU memory limit. Defaults to 1024**3. + + Returns: + Tensor: Masks with shape [N, Himage, Wimage], where N is the + number of detected object instances and Himage, Wimage are the + image width and height. + """ + assert ( + masks.shape[-1] == masks.shape[-2] + ), "Only square mask predictions are supported" + assert threshold >= 0 + num_masks = len(masks) + if num_masks == 0: + return masks + + img_w, img_h = image_shape + + # The actual implementation split the input into chunks, + # and paste them chunk by chunk. + if masks.device.type == "cpu": + # CPU is most efficient when they are pasted one by one with + # skip_empty=True so that it performs minimal number of operations. + num_chunks = num_masks + else: # pragma: no cover + # GPU benefits from parallelism for larger chunks, but may have + # memory issue int(img_h) because shape may be tensors in tracing + num_chunks = int( + np.ceil( + num_masks + * int(img_h) + * int(img_w) + * bytes_per_float + / gpu_mem_limit + ) + ) + assert ( + num_chunks <= num_masks + ), "Default gpu_mem_limit is too small; try increasing it" + chunks = torch.chunk( + torch.arange(num_masks, device=masks.device), num_chunks + ) + + img_masks = torch.zeros( + num_masks, img_h, img_w, device=masks.device, dtype=torch.bool + ) + for inds in chunks: + ( + masks_chunk, + spatial_inds, + ) = _do_paste_mask( + masks[inds, None, :, :], + boxes[inds, :4], + img_h, + img_w, + skip_empty=masks.device.type == "cpu", + ) + masks_chunk = torch.greater_equal(masks_chunk, threshold).to( + dtype=torch.bool + ) + img_masks[(inds,) + spatial_inds] = masks_chunk + return img_masks.type(torch.uint8) + + +def nhw_to_hwc_mask( + masks: Tensor, class_ids: Tensor, ignore_class: int = 255 +) -> Tensor: + """Convert N binary HxW masks to HxW semantic mask. + + Args: + masks (Tensor): Masks with shape [N, H, W]. + class_ids (Tensor): Class IDs with shape [N, 1]. + ignore_class (int, optional): Ignore label. Defaults to 255. + + Returns: + Tensor: Masks with shape [H, W], where each location indicate the + class label. + """ + hwc_mask = torch.full( + masks.shape[1:], ignore_class, dtype=masks.dtype, device=masks.device + ) + for mask, cat_id in zip(masks, class_ids): + hwc_mask[mask > 0] = cat_id + return hwc_mask + + +def clip_mask(mask: Tensor, target_shape: tuple[int, int]) -> Tensor: + """Clip mask. + + Args: + mask (Tensor): Mask with shape [C, H, W]. + target_shape (tuple[int, int]): Target shape (Ht, Wt). + + Returns: + Tensor: Clipped mask with shape [C, Ht, Wt]. + """ + return mask[:, : target_shape[0], : target_shape[1]] + + +def remove_overlap(mask: Tensor, score: Tensor) -> Tensor: + """Remove overlapping pixels between masks. + + Args: + mask (Tensor): Mask with shape [N, H, W]. + score (Tensor): Score with shape [N]. + + Returns: + Tensor: Mask with shape [N, H, W]. + """ + foreground = torch.zeros( + mask.shape[1:], dtype=torch.bool, device=mask.device + ) + sort_idx = score.argsort(descending=True) + for i in sort_idx: + mask[i] = torch.logical_and(mask[i], ~foreground) + foreground = torch.logical_or(mask[i], foreground) + return mask + + +def postprocess_segms( + segms: Tensor, + images_hw: list[tuple[int, int]], + original_hw: list[tuple[int, int]], +) -> Tensor: + """Postprocess segmentations. + + Args: + segms (Tensor): Segmentations with shape [B, C, H, W]. + images_hw (list[tuple[int, int]]): Image resolutions. + original_hw (list[tuple[int, int]]): Original image resolutions. + + Returns: + Tensor: Post-processed segmentations. + """ + post_segms = [] + for segm, image_hw, orig_hw in zip(segms, images_hw, original_hw): + post_segms.append( + F.interpolate( + segm[:, : image_hw[0], : image_hw[1]].unsqueeze(1), + size=(orig_hw[0], orig_hw[1]), + mode="bilinear", + ).squeeze(1) + ) + return torch.stack(post_segms).argmax(dim=1) + + +def masks2boxes(masks: Tensor) -> Tensor: + """Obtain the tight bounding boxes of binary masks. + + Args: + masks (Tensor): Binary mask of shape (N, H, W). + + Returns: + Tensor: Boxes with shape (N, 4) of positive region in binary mask. + """ + num_masks = masks.shape[0] + bboxes = masks.new_zeros((num_masks, 4), dtype=torch.float32) + x_any = torch.any(masks, dim=1) + y_any = torch.any(masks, dim=2) + for i in range(num_masks): + x = torch.where(x_any[i, :])[0] + y = torch.where(y_any[i, :])[0] + if len(x) > 0 and len(y) > 0: + bboxes[i, :] = bboxes.new_tensor( + [x[0], y[0], x[-1] + 1, y[-1] + 1] + ) + return bboxes diff --git a/vis4d/op/motion/__init__.py b/vis4d/op/motion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..725a2159d956ef4aaf41f115a8b1237d36c9e706 --- /dev/null +++ b/vis4d/op/motion/__init__.py @@ -0,0 +1 @@ +"""Motion operations.""" diff --git a/vis4d/op/motion/kalman_filter.py b/vis4d/op/motion/kalman_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3a05963afba0c14fa77e21da4c0cebee864d21 --- /dev/null +++ b/vis4d/op/motion/kalman_filter.py @@ -0,0 +1,84 @@ +"""Kalman Filter PyTorch implementation.""" + +from __future__ import annotations + +import torch +from torch import Tensor + + +def predict( + motion_mat: Tensor, + cov_motion_q: Tensor, + mean: Tensor, + covariance: Tensor, +) -> tuple[Tensor, Tensor]: + """Run Kalman filter prediction step.""" + # x = Fx + mean = torch.matmul(motion_mat, mean) + + # P = (FP)F + Q + covariance = ( + torch.matmul(motion_mat, torch.matmul(covariance, motion_mat.T)) + + cov_motion_q + ) + + return mean, covariance + + +def project( + update_mat: Tensor, cov_project_r: Tensor, mean: Tensor, covariance: Tensor +) -> tuple[Tensor, Tensor]: + """Project state distribution to measurement space.""" + # Hx + mean = torch.matmul(update_mat, mean) + + # HPH^T + R + covariance = torch.matmul( + update_mat, torch.matmul(covariance, update_mat.T) + ) + projected_cov = covariance + cov_project_r + return mean, projected_cov + + +def update( + update_mat: Tensor, + cov_project_r: Tensor, + mean: Tensor, + covariance: Tensor, + measurement: Tensor, +) -> tuple[Tensor, Tensor]: + """Run Kalman filter correction step.""" + # Hx, S = HPH^T + R + projected_mean, projected_cov = project( + update_mat, cov_project_r, mean, covariance + ) + + # K = PHT * S^-1 + chol_factor = torch.linalg.cholesky( # pylint: disable=not-callable + projected_cov + ) + kalman_gain = torch.cholesky_solve( + torch.matmul(covariance, update_mat.T).T, + chol_factor, + upper=False, + ).T + + # y = z - Hx + innovation = measurement - projected_mean + + # x = x + Ky + new_mean = mean + torch.matmul(innovation, kalman_gain.T) + + # P = (I-KH)P(I-KH)' + KRK' + # This is more numerically stable + # and works for non-optimal K vs the equation + # P = (I-KH)P usually seen in the literature. + i_kh = torch.eye(mean.shape[-1]).to( + device=measurement.device + ) - torch.matmul(kalman_gain, update_mat) + + new_covariance = torch.matmul( + torch.matmul(i_kh, covariance), i_kh.T + ) + torch.matmul(torch.matmul(kalman_gain, cov_project_r), kalman_gain.T) + + return new_mean, new_covariance diff --git a/vis4d/op/motion/velo_lstm.py b/vis4d/op/motion/velo_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..c1b8ec715ae28a644778327f20711aa6d89c362f --- /dev/null +++ b/vis4d/op/motion/velo_lstm.py @@ -0,0 +1,56 @@ +"""VeloLSTM operations.""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import Tensor + +from vis4d.common.typing import LossesType +from vis4d.op.loss.base import Loss + + +class VeloLSTMLoss(Loss): + """Loss term for VeloLSTM.""" + + def __init__(self, loc_dim: int = 7, smooth_weight: float = 0.001) -> None: + """Initialize the loss term.""" + super().__init__() + self.loc_dim = loc_dim + self.smooth_weight = smooth_weight + + @staticmethod + def linear_motion_loss(outputs: Tensor) -> Tensor: + """Linear motion loss. + + Loss: |(loc_t - loc_t-1), (loc_t-1, loc_t-2)|_1 for t = [2, s_len] + """ + s_len = outputs.shape[1] + + loss = outputs.new_zeros(1) + past_motion = outputs[:, 1, :] - outputs[:, 0, :] + for idx in range(2, s_len, 1): + curr_motion = outputs[:, idx, :] - outputs[:, idx - 1, :] + loss += F.l1_loss(past_motion, curr_motion, reduction="mean") + past_motion = curr_motion + return loss / (s_len - 2) + + def forward( + self, loc_preds: Tensor, loc_refines: Tensor, gt_traj: Tensor + ) -> LossesType: + """Loss term for VeloLSTM.""" + refine_loss = F.smooth_l1_loss( + loc_refines, gt_traj[:, 1:, : self.loc_dim], reduction="mean" + ) + pred_loss = F.smooth_l1_loss( + loc_preds[:, :-1, :], + gt_traj[:, 2:, : self.loc_dim], + reduction="mean", + ) + linear_loss = self.linear_motion_loss(loc_preds[:, :-1, :]) + + return { + "refine_loss": refine_loss, + "pred_loss": pred_loss, + "linear_loss": torch.mul(self.smooth_weight, linear_loss), + } diff --git a/vis4d/op/seg/__init__.py b/vis4d/op/seg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2abad433ff740ed5313977b11976abcf913e7957 --- /dev/null +++ b/vis4d/op/seg/__init__.py @@ -0,0 +1 @@ +"""Segmentor module.""" diff --git a/vis4d/op/seg/fcn.py b/vis4d/op/seg/fcn.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9ba09927214f68b4da6fd6c96f5b99b30cd66c --- /dev/null +++ b/vis4d/op/seg/fcn.py @@ -0,0 +1,117 @@ +"""FCN Head for semantic segmentation.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch +import torch.nn.functional as F +from torch import nn + + +class FCNOut(NamedTuple): + """Output of the FCN prediction.""" + + pred: torch.Tensor # logits for final prediction, (N, C, H, W) + outputs: list[torch.Tensor] # transformed feature maps + + +class FCNHead(nn.Module): + """FCN Head made with ResNet base model. + + This is based on the implementation in `torchvision + `_. + """ + + def __init__( + self, + in_channels: list[int], + out_channels: int, + dropout_prob: float = 0.1, + resize: tuple[int, int] | None = None, + ) -> None: + """Creates an instance of the class. + + Args: + in_channels (list[int]): Number of channels in multi-level image + feature. + out_channels (int): Number of output channels. Usually the number + of classes. + dropout_prob (float, optional): Dropout probability. Defaults to + 0.1. + resize (tuple(int,int), optional): Target shape to resize output. + Defaults to None. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.resize = resize + self.heads = nn.ModuleList() + for in_channel in self.in_channels: + self.heads.append( + self._make_head(in_channel, self.out_channels, dropout_prob) + ) + + def _make_head( + self, in_channels: int, channels: int, dropout_prob: float + ) -> nn.Module: + """Generate FCN segmentation head. + + Args: + in_channels (int): Input feature channels. + channels (int): Output segmentation channels. + dropout_prob (float): Dropout probability. + + Returns: + nn.Module: FCN segmentation head. + """ + inter_channels = in_channels // 4 + layers = [ + nn.Conv2d( + in_channels, + inter_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.BatchNorm2d(inter_channels), + nn.ReLU(), + nn.Dropout(dropout_prob), + nn.Conv2d(inter_channels, channels, kernel_size=1), + ] + return nn.Sequential(*layers) + + def forward(self, feats: list[torch.Tensor]) -> FCNOut: + """Transforms feature maps and returns segmentation prediction. + + Args: + feats (list[torch.Tensor]): List of multi-level image features. + + Returns: + output (list[torch.Tensor]): Each tensor has shape (batch_size, + self.channels, H, W) which is prediction for each FCN stages. E.g., + + outputs[-1] ==> main output map + outputs[-2] ==> aux output map (e.g., used for training) + outputs[:-2] ==> x[:-2] + """ + outputs = feats.copy() + num_features = len(feats) + for i in range(len(self.in_channels)): + idx = num_features - len(self.in_channels) + i + feat = feats[idx] + output = self.heads[i](feat) + if self.resize: + output = F.interpolate( + output, + size=self.resize, + mode="bilinear", + align_corners=False, + ) + outputs[idx] = F.log_softmax(output, dim=1) + return FCNOut(pred=outputs[-1], outputs=outputs) + + def __call__(self, feats: list[torch.Tensor]) -> FCNOut: + """Type definition for function call.""" + return super()._call_impl(feats) diff --git a/vis4d/op/seg/semantic_fpn.py b/vis4d/op/seg/semantic_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..b646fda3d6566e4098863d9d150b893fa9d2c451 --- /dev/null +++ b/vis4d/op/seg/semantic_fpn.py @@ -0,0 +1,122 @@ +"""Semantic FPN Head for segmentation.""" + +from __future__ import annotations + +from typing import NamedTuple + +import torch.nn.functional as F +from torch import Tensor, nn + +from vis4d.op.layer.conv2d import Conv2d + + +class SemanticFPNOut(NamedTuple): + """Output of the SemanticFPN prediction.""" + + outputs: Tensor # logits for final prediction, (N, C, H, W) + + +class SemanticFPNHead(nn.Module): + """SemanticFPNHead used in Panoptic FPN.""" + + def __init__( + self, + num_classes: int = 53, + in_channels: int = 256, + inner_channels: int = 128, + start_level: int = 2, + end_level: int = 6, + dropout_ratio: float = 0.1, + ): + """Creates an instance of the class. + + Args: + num_classes (int): Number of classes. Default: 53. + in_channels (int): Number of channels in the input feature map. + inner_channels (int): Number of channels in inner features. + start_level (int): The start level of the input features used in + SemanticFPN. + end_level (int): The end level of the used features, the + ``end_level``-th layer will not be used. + dropout_ratio (float): The drop ratio of dropout layer. + Default: 0.1. + """ + super().__init__() + self.num_classes = num_classes + + # Used feature layers are [start_level, end_level) + self.start_level = start_level + self.end_level = end_level + self.num_stages = end_level - start_level + self.inner_channels = inner_channels + + self.scale_heads = nn.ModuleList() + for i in range(start_level, end_level): + head_length = max(1, i - start_level) + scale_head: list[nn.Module] = [] + for k in range(head_length): + scale_head.append( + Conv2d( + in_channels if k == 0 else inner_channels, + inner_channels, + 3, + padding=1, + stride=1, + bias=False, + norm=nn.BatchNorm2d(inner_channels), + activation=nn.ReLU(inplace=True), + ) + ) + if i > start_level: + scale_head.append( + nn.Upsample( + scale_factor=2, + mode="bilinear", + align_corners=False, + ) + ) + self.scale_heads.append(nn.Sequential(*scale_head)) + self.conv_seg = nn.Conv2d(inner_channels, num_classes, 1) + self.dropout_ratio = dropout_ratio + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + self.init_weights() + + def init_weights(self) -> None: + """Initialize weights.""" + nn.init.kaiming_normal_( + self.conv_seg.weight, mode="fan_out", nonlinearity="relu" + ) + if hasattr(self.conv_seg, "bias") and self.conv_seg.bias is not None: + nn.init.constant_(self.conv_seg.bias, 0) + + def forward(self, features: list[Tensor]) -> SemanticFPNOut: + """Transforms feature maps and returns segmentation prediction. + + Args: + features (list[Tensor]): List of multi-level image features. + + Returns: + SemanticFPNOut: Segmentation outputs. + """ + assert self.num_stages <= len( + features + ), "Number of subnets must be not more than length of features." + + output = self.scale_heads[0](features[self.start_level]) + for i in range(1, self.num_stages): + output = output + F.interpolate( + self.scale_heads[i](features[self.start_level + i]), + size=output.shape[2:], + mode="bilinear", + align_corners=False, + ) + + if self.dropout_ratio > 0: + output = self.dropout(output) + seg_preds = self.conv_seg(output) + return SemanticFPNOut(outputs=seg_preds) + + def __call__(self, feats: list[Tensor]) -> SemanticFPNOut: + """Type definition for function call.""" + return super()._call_impl(feats) diff --git a/vis4d/op/track/__init__.py b/vis4d/op/track/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2b8dff252eed5da17e7f1929da96c2806c7637 --- /dev/null +++ b/vis4d/op/track/__init__.py @@ -0,0 +1 @@ +"""Tracking models module.""" diff --git a/vis4d/op/track/assignment.py b/vis4d/op/track/assignment.py new file mode 100644 index 0000000000000000000000000000000000000000..4136c4161cd85a4521cc1ec8a841dfa3f1c11ce2 --- /dev/null +++ b/vis4d/op/track/assignment.py @@ -0,0 +1,104 @@ +"""Track assignment functions.""" + +from __future__ import annotations + +import torch +from scipy.optimize import linear_sum_assignment +from torch import Tensor + + +def greedy_assign( + detection_scores: Tensor, + tracklet_ids: Tensor, + affinity_scores: Tensor, + match_score_thr: float = 0.5, + obj_score_thr: float = 0.3, + nms_conf_thr: None | float = None, +) -> Tensor: + """Greedy assignment of detections to tracks given affinities.""" + ids = torch.full( + (len(detection_scores),), + -1, + dtype=torch.long, + device=detection_scores.device, + ) + + for i, score in enumerate(detection_scores): + conf, memo_ind = torch.max(affinity_scores[i, :], dim=0) + cur_id = tracklet_ids[memo_ind] + if conf > match_score_thr: + if cur_id > -1: + if score > obj_score_thr: + ids[i] = cur_id + affinity_scores[:i, memo_ind] = 0 + affinity_scores[(i + 1) :, memo_ind] = 0 + elif nms_conf_thr is not None and conf > nms_conf_thr: + ids[i] = -2 + return ids + + +def hungarian_assign( + detection_scores: Tensor, + tracklet_ids: Tensor, + affinity_scores: Tensor, + match_score_thr: float = 0.5, + obj_score_thr: float = 0.3, + nms_conf_thr: None | float = None, +) -> Tensor: + """Hungarian assignment of detections to tracks given affinities.""" + ids = torch.full( + (len(detection_scores),), + -1, + dtype=torch.long, + device=detection_scores.device, + ) + + matched_indices = linear_sum_assignment(-affinity_scores.cpu().numpy()) + + for idx in range(len(matched_indices[0])): + i = matched_indices[0][idx] + memo_ind = matched_indices[1][idx] + conf = affinity_scores[i, memo_ind] + tid = tracklet_ids[memo_ind] + if conf > match_score_thr and tid > -1: + if detection_scores[i] > obj_score_thr: + ids[i] = tid + affinity_scores[:i, memo_ind] = 0 + affinity_scores[i + 1 :, memo_ind] = 0 + elif nms_conf_thr is not None and conf > nms_conf_thr: + ids[i] = -2 + + return ids + + +class TrackIDCounter: + """Global counter for track ids. + + Holds a count of tracks to enable unique and contiguous track ids starting + from zero. + """ + + count: int = 0 + + @classmethod + def reset(cls) -> None: + """Reset track id counter.""" + cls.count = 0 + + @classmethod + def get_ids( + cls, num_ids: int, device: torch.device = torch.device("cpu") + ) -> Tensor: + """Generate a num_ids number of new unique tracking ids. + + Args: + num_ids (int): number of ids + device (torch.device, optional): Device to create ids on. Defaults + to torch.device("cpu"). + + Returns: + Tensor: Tensor of new contiguous track ids. + """ + new_ids = torch.arange(cls.count, cls.count + num_ids, device=device) + cls.count = cls.count + num_ids + return new_ids diff --git a/vis4d/op/track/common.py b/vis4d/op/track/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b9e1a88298769cf86f6ead619d46fec692c2e1 --- /dev/null +++ b/vis4d/op/track/common.py @@ -0,0 +1,23 @@ +"""Common classes and functions for tracking.""" + +from __future__ import annotations + +from typing import NamedTuple + +from torch import Tensor + + +class TrackOut(NamedTuple): + """Output of track model. + + Attributes: + boxes (list[Tensor]): List of bounding boxes (B, N, 4). + class_ids (list[Tensor]): List of class ids (B, N). + scores (list[Tensor]): List of scores (B, N). + track_ids (list[Tensor]): List of track ids (B, N). + """ + + boxes: list[Tensor] + class_ids: list[Tensor] + scores: list[Tensor] + track_ids: list[Tensor] diff --git a/vis4d/op/track/matching.py b/vis4d/op/track/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c1baff96a10c28ddc8e829d25891c94963fa358a --- /dev/null +++ b/vis4d/op/track/matching.py @@ -0,0 +1,48 @@ +"""Matching calculation utils.""" + +from __future__ import annotations + +import torch +from torch.nn import functional as F + + +def calc_bisoftmax_affinity( + detection_embeddings: torch.Tensor, + track_embeddings: torch.Tensor, + detection_class_ids: torch.Tensor | None = None, + track_class_ids: torch.Tensor | None = None, + with_categories: bool = False, +) -> torch.Tensor: + """Calculate affinity matrix using bisoftmax metric.""" + feats = torch.mm(detection_embeddings, track_embeddings.t()) + d2t_scores = feats.softmax(dim=1) + t2d_scores = feats.softmax(dim=0) + similarity_scores = (d2t_scores + t2d_scores) / 2 + + if with_categories: + assert ( + detection_class_ids is not None and track_class_ids is not None + ), "Please provide class ids if with_categories=True!" + cat_same = detection_class_ids.view(-1, 1) == track_class_ids.view( + 1, -1 + ) + similarity_scores *= cat_same.float() + return similarity_scores + + +def cosine_similarity( + key_embeds: torch.Tensor, + ref_embeds: torch.Tensor, + normalize: bool = True, + temperature: float = -1, +) -> torch.Tensor: + """Calculate cosine similarity.""" + if normalize: + key_embeds = F.normalize(key_embeds, p=2, dim=1) + ref_embeds = F.normalize(ref_embeds, p=2, dim=1) + + dists = torch.mm(key_embeds, ref_embeds.t()) + + if temperature > 0: + dists /= temperature # pragma: no cover + return dists diff --git a/vis4d/op/track/qdtrack.py b/vis4d/op/track/qdtrack.py new file mode 100644 index 0000000000000000000000000000000000000000..b177d4f08acc0cbbbfd59489298a0b636c4cf24c --- /dev/null +++ b/vis4d/op/track/qdtrack.py @@ -0,0 +1,681 @@ +"""Quasi-dense embedding similarity based graph.""" + +from __future__ import annotations + +import math +from typing import NamedTuple + +import torch +from torch import Tensor, nn + +from vis4d.op.box.box2d import bbox_iou +from vis4d.op.box.matchers.max_iou import MaxIoUMatcher +from vis4d.op.box.poolers import MultiScaleRoIAlign, MultiScaleRoIPooler +from vis4d.op.box.samplers import CombinedSampler, match_and_sample_proposals +from vis4d.op.layer.conv2d import add_conv_branch +from vis4d.op.loss import EmbeddingDistanceLoss, MultiPosCrossEntropyLoss + +from .assignment import TrackIDCounter, greedy_assign +from .matching import calc_bisoftmax_affinity, cosine_similarity + + +def get_default_box_sampler() -> CombinedSampler: + """Get default box sampler of qdtrack.""" + box_sampler = CombinedSampler( + batch_size=256, + positive_fraction=0.5, + pos_strategy="instance_balanced", + neg_strategy="iou_balanced", + ) + return box_sampler + + +def get_default_box_matcher() -> MaxIoUMatcher: + """Get default box matcher of qdtrack.""" + box_matcher = MaxIoUMatcher( + thresholds=[0.3, 0.7], + labels=[0, -1, 1], + allow_low_quality_matches=False, + ) + return box_matcher + + +class QDTrackOut(NamedTuple): + """Output of QDTrack during training.""" + + key_embeddings: list[Tensor] + ref_embeddings: list[list[Tensor]] | None + key_track_ids: list[Tensor] | None + ref_track_ids: list[list[Tensor]] | None + + +class QDTrackHead(nn.Module): + """QDTrack - quasi-dense instance similarity learning.""" + + def __init__( + self, + similarity_head: QDSimilarityHead | None = None, + box_sampler: CombinedSampler | None = None, + box_matcher: MaxIoUMatcher | None = None, + proposal_append_gt: bool = True, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.similarity_head = ( + QDSimilarityHead() if similarity_head is None else similarity_head + ) + + self.box_sampler = ( + box_sampler + if box_sampler is not None + else get_default_box_sampler() + ) + + self.box_matcher = ( + box_matcher + if box_matcher is not None + else get_default_box_matcher() + ) + + self.proposal_append_gt = proposal_append_gt + + @torch.no_grad() + def _sample_proposals( + self, + det_boxes: list[list[Tensor]], + target_boxes: list[list[Tensor]], + target_track_ids: list[list[Tensor]], + ) -> tuple[list[list[Tensor]], list[list[Tensor]]]: + """Sample proposals for instance similarity learning.""" + sampled_boxes, sampled_track_ids = [], [] + for i, (boxes, tgt_boxes) in enumerate(zip(det_boxes, target_boxes)): + if self.proposal_append_gt: + boxes = [torch.cat([d, t]) for d, t in zip(boxes, tgt_boxes)] + + ( + sampled_box_indices, + sampled_target_indices, + sampled_labels, + ) = match_and_sample_proposals( + self.box_matcher, self.box_sampler, boxes, tgt_boxes + ) + + positives = [l == 1 for l in sampled_labels] + if i == 0: # key view: take only positives + sampled_box = [ + b[s_i][p] + for b, s_i, p in zip(boxes, sampled_box_indices, positives) + ] + sampled_tr_id = [ + t[s_i][p] + for t, s_i, p in zip( + target_track_ids[i], sampled_target_indices, positives + ) + ] + else: # set track_ids to -1 for all negatives + sampled_box = [ + b[s_i] for b, s_i in zip(boxes, sampled_box_indices) + ] + sampled_tr_id = [ + t[s_i] + for t, s_i in zip( + target_track_ids[i], sampled_target_indices + ) + ] + for pos, samp_tgt in zip(positives, sampled_tr_id): + samp_tgt[~pos] = -1 + + sampled_boxes.append(sampled_box) + sampled_track_ids.append(sampled_tr_id) + return sampled_boxes, sampled_track_ids + + def forward( + self, + features: list[Tensor] | list[list[Tensor]], + det_boxes: list[Tensor] | list[list[Tensor]], + target_boxes: None | list[list[Tensor]] = None, + target_track_ids: None | list[list[Tensor]] = None, + ) -> QDTrackOut: + """Forward function.""" + if target_boxes is not None and target_track_ids is not None: + sampled_boxes, sampled_track_ids = self._sample_proposals( + det_boxes, # type: ignore + target_boxes, + target_track_ids, + ) + + embeddings = [] + for feats, boxes in zip(features, sampled_boxes): + assert isinstance(feats, list) and isinstance(boxes, list) + embeddings.append(self.similarity_head(feats, boxes)) + + return QDTrackOut( + embeddings[0], + embeddings[1:], + sampled_track_ids[0], + sampled_track_ids[1:], + ) + + key_embeddings = self.similarity_head(features, det_boxes) # type: ignore # pylint: disable=line-too-long + + return QDTrackOut(key_embeddings, None, None, None) + + def __call__( + self, + features: list[Tensor] | list[list[Tensor]], + det_boxes: list[Tensor] | list[list[Tensor]], + target_boxes: None | list[list[Tensor]] = None, + target_track_ids: None | list[list[Tensor]] = None, + ) -> QDTrackOut: + """Type definition for call implementation.""" + return self._call_impl( + features, det_boxes, target_boxes, target_track_ids + ) + + +class QDTrackAssociation: + """Data association relying on quasi-dense instance similarity. + + This class assigns detection candidates to a given memory of existing + tracks and backdrops. + Backdrops are low-score detections kept in case they have high + similarity with a high-score detection in succeeding frames. + + Attributes: + init_score_thr: Confidence threshold for initializing a new track + obj_score_thr: Confidence treshold s.t. a detection is considered in + the track / det matching process. + match_score_thr: Similarity score threshold for matching a detection to + an existing track. + memo_backdrop_frames: Number of timesteps to keep backdrops. + memo_momentum: Momentum of embedding memory for smoothing embeddings. + nms_backdrop_iou_thr: Maximum IoU of a backdrop with another detection. + nms_class_iou_thr: Maximum IoU of a high score detection with another + of a different class. + with_cats: If to consider category information for tracking (i.e. all + detections within a track must have consistent category labels). + """ + + def __init__( + self, + init_score_thr: float = 0.7, + obj_score_thr: float = 0.3, + match_score_thr: float = 0.5, + nms_conf_thr: float = 0.5, + nms_backdrop_iou_thr: float = 0.3, + nms_class_iou_thr: float = 0.7, + with_cats: bool = True, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + self.init_score_thr = init_score_thr + self.obj_score_thr = obj_score_thr + self.match_score_thr = match_score_thr + self.nms_backdrop_iou_thr = nms_backdrop_iou_thr + self.nms_class_iou_thr = nms_class_iou_thr + self.nms_conf_thr = nms_conf_thr + self.with_cats = with_cats + + def _filter_detections( + self, + detections: Tensor, + scores: Tensor, + class_ids: Tensor, + embeddings: Tensor, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Remove overlapping objects across classes via nms. + + Args: + detections (Tensor): [N, 4] Tensor of boxes. + scores (Tensor): [N,] Tensor of confidence scores. + class_ids (Tensor): [N,] Tensor of class ids. + embeddings (Tensor): [N, C] tensor of appearance embeddings. + + Returns: + tuple[Tensor]: filtered detections, scores, class_ids, + embeddings, and filtered indices. + """ + scores, inds = scores.sort(descending=True) + detections, embeddings, class_ids = ( + detections[inds], + embeddings[inds], + class_ids[inds], + ) + valids = embeddings.new_ones((len(detections),), dtype=torch.bool) + ious = bbox_iou(detections, detections) + for i in range(1, len(detections)): + if scores[i] < self.obj_score_thr: + thr = self.nms_backdrop_iou_thr + else: + thr = self.nms_class_iou_thr + + if (ious[i, :i] > thr).any(): + valids[i] = False + detections = detections[valids] + scores = scores[valids] + class_ids = class_ids[valids] + embeddings = embeddings[valids] + return detections, scores, class_ids, embeddings, inds[valids] + + def __call__( + self, + detections: Tensor, + detection_scores: Tensor, + detection_class_ids: Tensor, + detection_embeddings: Tensor, + memory_track_ids: Tensor | None = None, + memory_class_ids: Tensor | None = None, + memory_embeddings: Tensor | None = None, + ) -> tuple[Tensor, Tensor]: + """Process inputs, match detections with existing tracks. + + Args: + detections (Tensor): [N, 4] detected boxes. + detection_scores (Tensor): [N,] confidence scores. + detection_class_ids (Tensor): [N,] class indices. + detection_embeddings (Tensor): [N, C] appearance embeddings. + memory_track_ids (Tensor): [M,] track ids in memory. + memory_class_ids (Tensor): [M,] class indices in memory. + memory_embeddings (Tensor): [M, C] appearance embeddings in + memory. + + Returns: + tuple[Tensor, Tensor]: track ids of active tracks and selected + detection indices corresponding to tracks. + """ + ( + detections, + detection_scores, + detection_class_ids, + detection_embeddings, + permute_inds, + ) = self._filter_detections( + detections, + detection_scores, + detection_class_ids, + detection_embeddings, + ) + + # match if buffer is not empty + if len(detections) > 0 and memory_track_ids is not None: + assert ( + memory_class_ids is not None and memory_embeddings is not None + ) + + affinity_scores = calc_bisoftmax_affinity( + detection_embeddings, + memory_embeddings, + detection_class_ids, + memory_class_ids, + self.with_cats, + ) + ids = greedy_assign( + detection_scores, + memory_track_ids, + affinity_scores, + self.match_score_thr, + self.obj_score_thr, + self.nms_conf_thr, + ) + else: + ids = torch.full( + (len(detections),), + -1, + dtype=torch.long, + device=detections.device, + ) + new_inds = (ids == -1) & (detection_scores > self.init_score_thr) + ids[new_inds] = TrackIDCounter.get_ids( + new_inds.sum(), device=ids.device # type: ignore + ) + return ids, permute_inds + + +class QDSimilarityHead(nn.Module): + """Instance embedding head for quasi-dense similarity learning. + + Given a set of input feature maps and RoIs, pool RoI representations from + feature maps and process them to a per-RoI embeddings vector. + """ + + def __init__( + self, + proposal_pooler: None | MultiScaleRoIPooler = None, + in_dim: int = 256, + num_convs: int = 4, + conv_out_dim: int = 256, + conv_has_bias: bool = False, + num_fcs: int = 1, + fc_out_dim: int = 1024, + embedding_dim: int = 256, + norm: str = "GroupNorm", + num_groups: int = 32, + start_level: int = 2, + ) -> None: + """Creates an instance of the class. + + Args: + proposal_pooler (None | MultiScaleRoIPooler, optional): RoI pooling + module. Defaults to None. + in_dim (int, optional): Input feature dimension. Defaults to 256. + num_convs (int, optional): Number of convolutional layers inside + the head. Defaults to 4. + conv_out_dim (int, optional): Output dimension of the last conv + layer. Defaults to 256. + conv_has_bias (bool, optional): If the conv layers have a bias + parameter. Defaults to False. + num_fcs (int, optional): Number of fully connected layers following + the conv layers. Defaults to 1. + fc_out_dim (int, optional): Output dimension of the last fully + connected layer. Defaults to 1024. + embedding_dim (int, optional): Dimensionality of the output + instance embedding. Defaults to 256. + norm (str, optional): Normalization of the layers inside the head. + One of BatchNorm2d, GroupNorm. Defaults to "GroupNorm". + num_groups (int, optional): Number of groups for the GroupNorm + normalization. Defaults to 32. + start_level (int, optional): starting level of feature maps. + Defaults to 2. + """ + super().__init__() + self.in_dim = in_dim + self.num_convs = num_convs + self.conv_out_dim = conv_out_dim + self.conv_has_bias = conv_has_bias + self.num_fcs = num_fcs + self.fc_out_dim = fc_out_dim + self.norm = norm + self.num_groups = num_groups + + if proposal_pooler is not None: + self.roi_pooler = proposal_pooler + else: + self.roi_pooler = MultiScaleRoIAlign( + resolution=[7, 7], strides=[4, 8, 16, 32], sampling_ratio=0 + ) + + # Used feature layers are [start_level, end_level) + self.start_level = start_level + num_strides = len(self.roi_pooler.scales) + self.end_level = start_level + num_strides + + self.convs, self.fcs, last_layer_dim = self._init_embedding_head() + self.fc_embed = nn.Linear(last_layer_dim, embedding_dim) + self._init_weights() + + def _init_weights(self) -> None: + """Init weights of modules in head.""" + for m in self.convs: + nn.init.kaiming_uniform_(m.weight, a=1) # type: ignore + if m.bias is not None: + nn.init.constant_(m.bias, 0) # type: ignore + + for m in self.fcs: + if isinstance(m[0], nn.Linear): # type: ignore + nn.init.xavier_uniform_(m[0].weight) # type: ignore + nn.init.constant_(m[0].bias, 0) # type: ignore + + nn.init.normal_(self.fc_embed.weight, 0, 0.01) + nn.init.constant_(self.fc_embed.bias, 0) + + def _init_embedding_head( + self, + ) -> tuple[torch.nn.ModuleList, torch.nn.ModuleList, int]: + """Init modules of head.""" + convs, last_layer_dim = add_conv_branch( + self.num_convs, + self.in_dim, + self.conv_out_dim, + self.conv_has_bias, + self.norm, + self.num_groups, + ) + + fcs = nn.ModuleList() + if self.num_fcs > 0: + last_layer_dim *= math.prod(self.roi_pooler.resolution) + for i in range(self.num_fcs): + fc_in_dim = last_layer_dim if i == 0 else self.fc_out_dim + fcs.append( + nn.Sequential( + nn.Linear(fc_in_dim, self.fc_out_dim), + nn.ReLU(inplace=True), + ) + ) + last_layer_dim = self.fc_out_dim + return convs, fcs, last_layer_dim + + def forward( + self, features: list[Tensor], boxes: list[Tensor] + ) -> list[Tensor]: + """Similarity head forward pass. + + Args: + features (list[Tensor]): A feature pyramid. The list index + represents the level, which has a downsampling raio of 2^index. + fp[0] is a feature map with the image resolution instead of the + original image. + boxes (list[Tensor]): A list of [N, 4] 2D bounding boxes per + batch element. + + Returns: + list[Tensor]: An embedding vector per input box, . + """ + # RoI pooling + x = self.roi_pooler(features[self.start_level : self.end_level], boxes) + + # convs + if self.num_convs > 0: + for conv in self.convs: + x = conv(x) + + # fcs + x = torch.flatten(x, start_dim=1) + if self.num_fcs > 0: + for fc in self.fcs: + x = fc(x) + + embeddings: list[Tensor] = list( + self.fc_embed(x).split([len(b) for b in boxes]) + ) + return embeddings + + def __call__( + self, features: list[Tensor], boxes: list[Tensor] + ) -> list[Tensor]: + """Type definition.""" + return self._call_impl(features, boxes) + + +class QDTrackInstanceSimilarityLosses(NamedTuple): + """QDTrack losses return type. Consists of two scalar loss tensors.""" + + track_loss: Tensor + track_loss_aux: Tensor + + +class QDTrackInstanceSimilarityLoss(nn.Module): + """Instance similarity loss as in QDTrack. + + Given a number of key frame embeddings and a number of reference frame + embeddings along with their track identities, compute two losses: + 1. Multi-positive cross-entropy loss. + 2. Cosine similarity loss (auxiliary). + """ + + def __init__(self, softmax_temp: float = -1): + """Creates an instance of the class. + + Args: + softmax_temp (float, optional): Temperature parameter for + multi-positive cross-entropy loss. Defaults to -1. + """ + super().__init__() + self.softmax_temp = softmax_temp + self.track_loss = MultiPosCrossEntropyLoss() + self.track_loss_aux = EmbeddingDistanceLoss() + self.track_loss_weight = 0.25 + + def forward( + self, + key_embeddings: list[Tensor], + ref_embeddings: list[list[Tensor]], + key_track_ids: list[Tensor], + ref_track_ids: list[list[Tensor]], + ) -> QDTrackInstanceSimilarityLosses: + """The QDTrack instance similarity loss. + + Key inputs are of type list[Tensor/Boxes2D] (Lists are length N) + Ref inputs are of type list[list[Tensor/Boxes2D]] where the lists + are of length MxN. + Where M is the number of reference views and N is the + number of batch elements. + + NOTE: this only works if key only contains positives and all + negatives in ref have track_id -1 + + Args: + key_embeddings (list[Tensor]): key frame embeddings. + ref_embeddings (list[list[Tensor]]): reference frame + embeddings. + key_track_ids (list[Tensor]): associated track ids per + embedding in key frame. + ref_track_ids (list[list[Tensor]]): associated track ids per + embedding in reference frame(s). + + Returns: + QDTrackInstanceSimilarityLosses: Scalar loss tensors. + """ + if sum(len(e) for e in key_embeddings) == 0: # pragma: no cover + dummy_loss = sum(e.sum() * 0.0 for e in key_embeddings) + return QDTrackInstanceSimilarityLosses(dummy_loss, dummy_loss) # type: ignore # pylint: disable=line-too-long + + loss_track = torch.tensor(0.0, device=key_embeddings[0].device) + loss_track_aux = torch.tensor(0.0, device=key_embeddings[0].device) + dists, cos_dists = self._match(key_embeddings, ref_embeddings) + track_targets, track_weights = self._get_targets( + key_track_ids, ref_track_ids + ) + # for each reference view + for curr_dists, curr_cos_dists, curr_targets, curr_weights in zip( + dists, cos_dists, track_targets, track_weights + ): + # for each batch element + for _dists, _cos_dists, _targets, _weights in zip( + curr_dists, curr_cos_dists, curr_targets, curr_weights + ): + if all(_dists.shape): + loss_track += ( + self.track_loss( + _dists, + _targets, + _weights, + avg_factor=_weights.sum() + 1e-5, + ) + * self.track_loss_weight + ) + if self.track_loss_aux is not None: + loss_track_aux += self.track_loss_aux( + _cos_dists, _targets + ) + + num_pairs = len(dists) * len(dists[0]) + loss_track = torch.div(loss_track, num_pairs) + loss_track_aux = torch.div(loss_track_aux, num_pairs) + + return QDTrackInstanceSimilarityLosses( + track_loss=loss_track, track_loss_aux=loss_track_aux + ) + + def __call__( + self, + key_embeddings: list[Tensor], + ref_embeddings: list[list[Tensor]], + key_track_ids: list[Tensor], + ref_track_ids: list[list[Tensor]], + ) -> QDTrackInstanceSimilarityLosses: + """Type definition.""" + return self._call_impl( + key_embeddings, ref_embeddings, key_track_ids, ref_track_ids + ) + + @staticmethod + def _get_targets( + key_track_ids: list[Tensor], + ref_track_ids: list[list[Tensor]], + ) -> tuple[list[list[Tensor]], list[list[Tensor]]]: + """Create tracking target tensors. + + Args: + key_track_ids (list[Tensor]): A List of Tensors [N,] per + batch element containing the corresponding track ids of each + box in the key frame. + ref_track_ids (list[list[Tensor]]): A nested list fo Tensors + [N,] per batch element, per reference view. The inner list + denotes the batch index, the outer list the reference view + index. Contains track ids of boxes in all reference views + across the batch. + + Returns: + tuple[list[list[Tensor]], list[list[Tensor]]]: The + target tensors per key-reference pair containing 1 if the + identities of two boxes across the key and a reference view + match, and 0 otherwise and the loss reduction weights for + a certain box. + """ + # for each reference view + track_targets, track_weights = [], [] + for ref_target in ref_track_ids: + # for each batch element + curr_targets, curr_weights = [], [] + for key_target, ref_target_ in zip(key_track_ids, ref_target): + # target shape: len(key_target) x len(ref_target_) + # NOTE: this only works if key only contains positives and all + # negatives in ref have track_id -1 + target = ( + key_target.view(-1, 1) == ref_target_.view(1, -1) + ).int() + weight = (target.sum(dim=1) > 0).float() + curr_targets.append(target) + curr_weights.append(weight) + track_targets.append(curr_targets) + track_weights.append(curr_weights) + return track_targets, track_weights + + def _match( + self, + key_embeds: list[Tensor], + ref_embeds: list[list[Tensor]], + ) -> tuple[list[list[Tensor]], list[list[Tensor]]]: + """Calculate distances for all pairs of key / ref embeddings. + + Args: + key_embeds (list[Tensor]): Embeddings for boxes in key frame. + ref_embeds (list[list[Tensor]]): Embeddings for boxes in + all reference frames. + + Returns: + tuple[list[list[Tensor]], list[list[Tensor]]]: + Embedding distances for all embedding pairs, first normalized + via softmax, then normal cosine similary. + """ + # for each reference view + dists, cos_dists = [], [] + for ref_embed in ref_embeds: + # for each batch element + dists_curr, cos_dists_curr = [], [] + for key_embed, ref_embed_ in zip(key_embeds, ref_embed): + dist = cosine_similarity( + key_embed, + ref_embed_, + normalize=False, + temperature=self.softmax_temp, + ) + dists_curr.append(dist) + if self.track_loss_aux is not None: + cos_dist = cosine_similarity(key_embed, ref_embed_) + cos_dists_curr.append(cos_dist) + + dists.append(dists_curr) + cos_dists.append(cos_dists_curr) + return dists, cos_dists diff --git a/vis4d/op/track3d/__init__.py b/vis4d/op/track3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb3d1797016ad1ca3b483e81bcf2f79a0ba3d0c --- /dev/null +++ b/vis4d/op/track3d/__init__.py @@ -0,0 +1 @@ +"""3D tracking models module.""" diff --git a/vis4d/op/track3d/cc_3dt.py b/vis4d/op/track3d/cc_3dt.py new file mode 100644 index 0000000000000000000000000000000000000000..837b2b76f8a01ec9db6d151f939bb09e0169e897 --- /dev/null +++ b/vis4d/op/track3d/cc_3dt.py @@ -0,0 +1,446 @@ +"""CC-3DT graph.""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import Tensor + +from vis4d.op.box.box2d import bbox_iou +from vis4d.op.geometry.rotation import ( + euler_angles_to_matrix, + matrix_to_quaternion, + rotate_orientation, + rotate_velocities, +) +from vis4d.op.geometry.transform import transform_points +from vis4d.op.track.assignment import TrackIDCounter, greedy_assign +from vis4d.op.track.matching import calc_bisoftmax_affinity + +from .common import Track3DOut + + +def get_track_3d_out( + boxes_3d: Tensor, class_ids: Tensor, scores_3d: Tensor, track_ids: Tensor +) -> Track3DOut: + """Get track 3D output. + + Args: + boxes_3d (Tensor): (N, 12): x,y,z,h,w,l,rx,ry,rz,vx,vy,vz + class_ids (Tensor): (N,) + scores_3d (Tensor): (N,) + track_ids (Tensor): (N,) + + Returns: + Track3DOut: output + """ + center = boxes_3d[:, :3] + # HWL -> WLH + dims = boxes_3d[:, [4, 5, 3]] + orientation = matrix_to_quaternion( + euler_angles_to_matrix(boxes_3d[:, 6:9]) + ) + + return Track3DOut( + boxes_3d=[torch.cat([center, dims, orientation], dim=1)], + velocities=[boxes_3d[:, 9:12]], + class_ids=[class_ids], + scores_3d=[scores_3d], + track_ids=[track_ids], + ) + + +class CC3DTrackAssociation: + """Data association relying on quasi-dense instance similarity and 3D clue. + + This class assigns detection candidates to a given memory of existing + tracks and backdrops. + Backdrops are low-score detections kept in case they have high + similarity with a high-score detection in succeeding frames. + """ + + def __init__( + self, + init_score_thr: float = 0.8, + obj_score_thr: float = 0.5, + match_score_thr: float = 0.5, + nms_backdrop_iou_thr: float = 0.3, + nms_class_iou_thr: float = 0.7, + nms_conf_thr: float = 0.5, + with_cats: bool = True, + with_velocities: bool = False, + bbox_affinity_weight: float = 0.5, + ) -> None: + """Creates an instance of the class. + + Args: + init_score_thr (float): Confidence threshold for initializing a new + track. + obj_score_thr (float): Confidence treshold s.t. a detection is + considered in the track / det matching process. + match_score_thr (float): Similarity score threshold for matching a + detection to an existing track. + nms_backdrop_iou_thr (float): Maximum IoU of a backdrop with + another detection. + nms_class_iou_thr (float): Maximum IoU of a high score detection + with another of a different class. + nms_conf_thr (float): Confidence threshold for NMS. + with_cats (bool): If to consider category information for + tracking (i.e. all detections within a track must have + consistent category labels). + with_velocities (bool): If to use predicted velocities for + matching. + bbox_affinity_weight (float): Weight of bbox affinity in the + overall affinity score. + """ + super().__init__() + self.init_score_thr = init_score_thr + self.obj_score_thr = obj_score_thr + self.match_score_thr = match_score_thr + self.nms_backdrop_iou_thr = nms_backdrop_iou_thr + self.nms_class_iou_thr = nms_class_iou_thr + self.nms_conf_thr = nms_conf_thr + self.with_cats = with_cats + self.with_velocities = with_velocities + self.bbox_affinity_weight = bbox_affinity_weight + self.feat_affinity_weight = 1 - bbox_affinity_weight + + def _filter_detections( + self, + detections: Tensor, + camera_ids: Tensor, + scores: Tensor, + detections_3d: Tensor, + scores_3d: Tensor, + class_ids: Tensor, + embeddings: Tensor, + velocities: Tensor | None = None, + ) -> tuple[ + Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor | None, Tensor + ]: + """Remove overlapping objects across classes via nms. + + Args: + detections (Tensor): [N, 4] Tensor of boxes. + camera_ids (Tensor): [N,] Tensor of camera ids. + scores (Tensor): [N,] Tensor of confidence scores. + detections_3d (Tensor): [N, 7] Tensor of 3D boxes. + scores_3d (Tensor): [N,] Tensor of 3D confidence scores. + class_ids (Tensor): [N,] Tensor of class ids. + embeddings (Tensor): [N, C] tensor of appearance embeddings. + velocities (Tensor | None): [N, 3] Tensor of velocities. + + Returns: + tuple[Tensor]: filtered detections, scores, class_ids, + embeddings, and filtered indices. + """ + scores, inds = scores.sort(descending=True) + ( + detections, + camera_ids, + embeddings, + class_ids, + detections_3d, + scores_3d, + ) = ( + detections[inds], + camera_ids[inds], + embeddings[inds], + class_ids[inds], + detections_3d[inds], + scores_3d[inds], + ) + + if velocities is not None: + velocities = velocities[inds] + + valids = embeddings.new_ones((len(detections),), dtype=torch.bool) + + ious = bbox_iou(detections, detections) + valid_ious = torch.eq( + camera_ids.unsqueeze(1), camera_ids.unsqueeze(0) + ).int() + ious *= valid_ious + + for i in range(1, len(detections)): + if scores[i] < self.obj_score_thr: + thr = self.nms_backdrop_iou_thr + else: + thr = self.nms_class_iou_thr + + if (ious[i, :i] > thr).any(): + valids[i] = False + + detections = detections[valids] + scores = scores[valids] + detections_3d = detections_3d[valids] + scores_3d = scores_3d[valids] + class_ids = class_ids[valids] + embeddings = embeddings[valids] + + if velocities is not None: + velocities = velocities[valids] + + return ( + detections, + scores, + detections_3d, + scores_3d, + class_ids, + embeddings, + velocities, + inds[valids], + ) + + def depth_ordering( + self, + obsv_boxes_3d: Tensor, + obsv_velocities: Tensor | None, + memory_boxes_3d_predict: Tensor, + memory_boxes_3d: Tensor, + memory_velocities: Tensor, + ) -> Tensor: + """Depth ordering matching.""" + # Centroid + centroid_weight_list = [] + for memory_box_3d_predict in memory_boxes_3d_predict: + centroid_weight_list.append( + F.pairwise_distance( # pylint: disable=not-callable + obsv_boxes_3d[:, :3], + memory_box_3d_predict[:3], + keepdim=True, + ) + ) + centroid_weight = torch.cat(centroid_weight_list, dim=1) + centroid_weight = torch.exp(-torch.div(centroid_weight, 10.0)) + + # Moving distance should be aligned + motion_weight_list = [] + moving_dist = ( + obsv_boxes_3d[:, :3, None] + - memory_boxes_3d[:, :3, None].transpose(2, 0) + ).transpose(1, 2) + for v in moving_dist: + motion_weight_list.append( + F.pairwise_distance( # pylint: disable=not-callable + v, memory_velocities[:, :3] + ).unsqueeze(0) + ) + motion_weight = torch.cat(motion_weight_list, dim=0) + motion_weight = torch.exp(-torch.div(motion_weight, 5.0)) + + # Velocity scores + if self.with_velocities: + assert ( + obsv_velocities is not None + ), "Please provide velocities if with_velocities=True!" + + velsim_weight_list = [] + obsvvv_velocities = obsv_velocities.unsqueeze(1).expand_as( + moving_dist + ) + for v in obsvvv_velocities: + velsim_weight_list.append( + F.pairwise_distance( # pylint: disable=not-callable + v, memory_velocities[:, -3:] + ).unsqueeze(0) + ) + velsim_weight = torch.cat(velsim_weight_list, dim=0) + cos_sim = torch.exp(-velsim_weight / 5.0) + else: + # Moving direction should be aligned + # Set to 0.5 when two vector not within +-90 degree + cos_sim_list = [] + obsv_direct = ( + obsv_boxes_3d[:, :2, None] + - memory_boxes_3d[:, :2, None].transpose(2, 0) + ).transpose(1, 2) + for d in obsv_direct: + cos_sim_list.append( + F.cosine_similarity( # pylint: disable=not-callable + d, memory_velocities[:, :2] + ).unsqueeze(0) + ) + cos_sim = torch.cat(cos_sim_list, dim=0) + cos_sim = torch.add(cos_sim, 1.0) + cos_sim = torch.div(cos_sim, 2.0) + + scores_depth = ( + cos_sim * centroid_weight + (1.0 - cos_sim) * motion_weight + ) + + return scores_depth + + def __call__( + self, + detections: Tensor, + camera_ids: Tensor, + detection_scores: Tensor, + detections_3d: Tensor, + detection_scores_3d: Tensor, + detection_class_ids: Tensor, + detection_embeddings: Tensor, + obs_velocities: Tensor | None = None, + memory_boxes_3d: Tensor | None = None, + memory_track_ids: Tensor | None = None, + memory_class_ids: Tensor | None = None, + memory_embeddings: Tensor | None = None, + memory_boxes_3d_predict: Tensor | None = None, + memory_velocities: Tensor | None = None, + with_depth_confidence: bool = True, + ) -> tuple[Tensor, Tensor]: + """Process inputs, match detections with existing tracks. + + Args: + detections (Tensor): [N, 4] detected boxes. + camera_ids (Tensor): [N,] camera ids. + detection_scores (Tensor): [N,] confidence scores. + detections_3d (Tensor): [N, 7] detected boxes in 3D. + detection_scores_3d (Tensor): [N,] confidence scores in 3D. + detection_class_ids (Tensor): [N,] class indices. + detection_embeddings (Tensor): [N, C] appearance embeddings. + obs_velocities (Tensor | None): [N, 3] velocities of detections. + memory_boxes_3d (Tensor): [M, 7] boxes in memory. + memory_track_ids (Tensor): [M,] track ids in memory. + memory_class_ids (Tensor): [M,] class indices in memory. + memory_embeddings (Tensor): [M, C] appearance embeddings in + memory. + memory_boxes_3d_predict (Tensor): [M, 7] predicted boxes in + memory. + memory_velocities (Tensor): [M, 7] velocities in memory. + + Returns: + tuple[Tensor, Tensor]: track ids of active tracks and selected + detection indices corresponding to tracks. + """ + ( + detections, + detection_scores, + detections_3d, + detection_scores_3d, + detection_class_ids, + detection_embeddings, + obs_velocities, + permute_inds, + ) = self._filter_detections( + detections, + camera_ids, + detection_scores, + detections_3d, + detection_scores_3d, + detection_class_ids, + detection_embeddings, + obs_velocities, + ) + + if with_depth_confidence: + depth_confidence = detection_scores_3d + else: + depth_confidence = detection_scores_3d.new_ones( + len(detection_scores_3d) + ) + + # match if buffer is not empty + if len(detections) > 0 and memory_boxes_3d is not None: + assert ( + memory_track_ids is not None + and memory_class_ids is not None + and memory_embeddings is not None + and memory_boxes_3d_predict is not None + and memory_velocities is not None + ) + + # Box 3D + bbox3d_weight_list = [] + for memory_box_3d_predict in memory_boxes_3d_predict: + bbox3d_weight_list.append( + F.pairwise_distance( # pylint: disable=not-callable + detections_3d, + memory_box_3d_predict, + keepdim=True, + ) + ) + bbox3d_weight = torch.cat(bbox3d_weight_list, dim=1) + scores_iou = torch.exp(-torch.div(bbox3d_weight, 10.0)) + + # Depth Ordering + scores_depth = self.depth_ordering( + detections_3d, + obs_velocities, + memory_boxes_3d_predict, + memory_boxes_3d, + memory_velocities, + ) + + # match using bisoftmax metric + similarity_scores = calc_bisoftmax_affinity( + detection_embeddings, + memory_embeddings, + detection_class_ids, + memory_class_ids, + ) + + if self.with_cats: + assert ( + detection_class_ids is not None + and memory_class_ids is not None + ), "Please provide class ids if with_categories=True!" + cat_same = detection_class_ids.view( + -1, 1 + ) == memory_class_ids.view(1, -1) + scores_cats = cat_same.float() + + affinity_scores = ( + self.bbox_affinity_weight * scores_iou * scores_depth + + self.feat_affinity_weight * similarity_scores + ) + affinity_scores /= ( + self.bbox_affinity_weight + self.feat_affinity_weight + ) + affinity_scores = torch.mul( + affinity_scores, torch.greater(scores_iou, 0.0).float() + ) + affinity_scores = torch.mul( + affinity_scores, torch.greater(scores_depth, 0.0).float() + ) + if self.with_cats: + affinity_scores = torch.mul(affinity_scores, scores_cats) + + ids = greedy_assign( + detection_scores * depth_confidence, + memory_track_ids, + affinity_scores, + self.match_score_thr, + self.obj_score_thr, + self.nms_conf_thr, + ) + else: + ids = torch.full( + (len(detections),), + -1, + dtype=torch.long, + device=detections.device, + ) + new_inds = (ids == -1) & (detection_scores > self.init_score_thr) + ids[new_inds] = TrackIDCounter.get_ids( + new_inds.sum(), device=ids.device # type: ignore + ) + return ids, permute_inds + + +def cam_to_global( + boxes_3d_list: list[Tensor], extrinsics: Tensor +) -> list[Tensor]: + """Convert camera coordinates to global coordinates.""" + for i, boxes_3d in enumerate(boxes_3d_list): + if len(boxes_3d) != 0: + boxes_3d_list[i][:, :3] = transform_points( + boxes_3d_list[i][:, :3], extrinsics[i] + ) + boxes_3d_list[i][:, 6:9] = rotate_orientation( + boxes_3d_list[i][:, 6:9], extrinsics[i] + ) + boxes_3d_list[i][:, 9:12] = rotate_velocities( + boxes_3d_list[i][:, 9:12], extrinsics[i] + ) + return boxes_3d_list diff --git a/vis4d/op/track3d/common.py b/vis4d/op/track3d/common.py new file mode 100644 index 0000000000000000000000000000000000000000..77a2d289a84a4110bf54e1e76f0a3d748611b086 --- /dev/null +++ b/vis4d/op/track3d/common.py @@ -0,0 +1,25 @@ +"""Common classes and functions for 3D tracking.""" + +from __future__ import annotations + +from typing import NamedTuple + +from torch import Tensor + + +class Track3DOut(NamedTuple): + """Output of track 3D model. + + Attributes: + boxes_3d (list[Tensor]): List of bounding boxes (B, N, 10). + velocities (list[Tensor]): List of velocities (B, N, 3). + class_ids (list[Tensor]): List of class ids (B, N). + scores_3d (list[Tensor]): List of scores (B, N). + track_ids (list[Tensor]): List of track ids (B, N). + """ + + boxes_3d: list[Tensor] + velocities: list[Tensor] + class_ids: list[Tensor] + scores_3d: list[Tensor] + track_ids: list[Tensor] diff --git a/vis4d/op/util.py b/vis4d/op/util.py new file mode 100644 index 0000000000000000000000000000000000000000..eb683e7516b1bdd16d4a58df4becf737a156cd48 --- /dev/null +++ b/vis4d/op/util.py @@ -0,0 +1,28 @@ +"""Utilities for op.""" + +from __future__ import annotations + +import torch +from torch import Tensor + + +def unmap(data: Tensor, count: int, inds: Tensor, fill: int = 0) -> Tensor: + """Unmap a subset of data back to the original data (of size count). + + Args: + data (Tensor): Subset of the original data. + count (int): Length of the original data. + inds (Tensor): Indices of the subset entries in the original set. + fill (int, optional): Fill value for other entries. Defaults to 0. + + Returns: + Tensor: Tensor sized like original data that contains the subset. + """ + if data.dim() == 1: + ret = data.new_full((count,), fill) + ret[inds.type(torch.bool)] = data + else: + new_size = (count,) + data.size()[1:] + ret = data.new_full(new_size, fill) + ret[inds.type(torch.bool), :] = data + return ret diff --git a/vis4d/state/__init__.py b/vis4d/state/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..034dac3be7266d41b216e2fa2781d09c75211038 --- /dev/null +++ b/vis4d/state/__init__.py @@ -0,0 +1 @@ +"""Memory and internal states needed for models.""" diff --git a/vis4d/state/track/__init__.py b/vis4d/state/track/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d677fa8378e3cae1ed73fe5d51f9be0a0e1734 --- /dev/null +++ b/vis4d/state/track/__init__.py @@ -0,0 +1 @@ +"""Memory and state for tracking algorithms.""" diff --git a/vis4d/state/track/qdtrack.py b/vis4d/state/track/qdtrack.py new file mode 100644 index 0000000000000000000000000000000000000000..316608954c022312312da242837857f32cae5ef8 --- /dev/null +++ b/vis4d/state/track/qdtrack.py @@ -0,0 +1,327 @@ +"""Memory for QDTrack inference.""" + +from __future__ import annotations + +from typing import TypedDict + +import torch +from torch import Tensor + +from vis4d.op.box.box2d import bbox_iou +from vis4d.op.track.assignment import TrackIDCounter +from vis4d.op.track.common import TrackOut +from vis4d.op.track.qdtrack import QDTrackAssociation + + +class Track(TypedDict): + """QDTrack Track state. + + Attributes: + box (Tensor): In shape (4,) and contains x1, y1, x2, y2. + score (Tensor): In shape (1,). + class_id (Tensor): In shape (1,). + embedding (Tensor): In shape (E,). E is the embedding dimension. + last_frame (int): Last frame id. + """ + + box: Tensor + score: Tensor + class_id: Tensor + embed: Tensor + last_frame: int + + +class QDTrackGraph: + """Quasi-dense embedding similarity based graph.""" + + def __init__( + self, + track: QDTrackAssociation | None = None, + memory_size: int = 10, + memory_momentum: float = 0.8, + nms_backdrop_iou_thr: float = 0.3, + backdrop_memory_size: int = 1, + ) -> None: + """Init.""" + assert memory_size >= 0 + self.memory_size = memory_size + assert 0 <= memory_momentum <= 1.0 + self.memory_momentum = memory_momentum + assert backdrop_memory_size >= 0 + self.backdrop_memory_size = backdrop_memory_size + self.nms_backdrop_iou_thr = nms_backdrop_iou_thr + + self.tracker = QDTrackAssociation() if track is None else track + + self.tracklets: dict[int, Track] = {} + self.backdrops: list[dict[str, Tensor]] = [] + + def reset(self) -> None: + """Empty the memory.""" + self.tracklets.clear() + self.backdrops.clear() + + def is_empty(self) -> bool: + """Check if the memory is empty.""" + return len(self.tracklets) == 0 + + def get_tracks( + self, + device: torch.device, + frame_id: int | None = None, + add_backdrops: bool = False, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Get tracklests. + + If the frame_id is not provided, will return the latest state of all + tracklets. Otherwise, will return the state of all tracklets at the + given frame_id. If add_backdrops is True, will also return the + backdrops. + + Args: + device (torch.device): Device to put the tensors on. + frame_id (int, optional): Frame id to query. Defaults to None. + add_backdrops (bool, optional): Whether to add backdrops to the + output. Defaults to False. + + Returns: + boxes (Tensor): 2D boxes in shape (N, 4). + scores (Tensor): 2D scores in shape (N,). + class_ids (Tensor): Class ids in shape (N,). + track_ids (Tensor): Track ids in shape (N,). + embeddings (Tensor): Embeddings in shape (N, E). + """ + ( + boxes_list, + scores_list, + class_ids_list, + embeddings_list, + track_ids_list, + ) = ([], [], [], [], []) + + for track_id, track in self.tracklets.items(): + if frame_id is None or track["last_frame"] == frame_id: + boxes_list.append(track["box"].unsqueeze(0)) + scores_list.append(track["score"].unsqueeze(0)) + class_ids_list.append(track["class_id"].unsqueeze(0)) + embeddings_list.append(track["embed"].unsqueeze(0)) + track_ids_list.append(track_id) + + boxes = ( + torch.cat(boxes_list) + if len(boxes_list) > 0 + else torch.empty((0, 4), device=device) + ) + scores = ( + torch.cat(scores_list) + if len(scores_list) > 0 + else torch.empty((0,), device=device) + ) + class_ids = ( + torch.cat(class_ids_list) + if len(class_ids_list) > 0 + else torch.empty((0,), device=device) + ) + embeddings = ( + torch.cat(embeddings_list) + if len(embeddings_list) > 0 + else torch.empty((0,), device=device) + ) + track_ids = torch.tensor(track_ids_list, device=device) + + if add_backdrops: + for backdrop in self.backdrops: + backdrop_ids = torch.full( + (len(backdrop["embeddings"]),), + -1, + dtype=torch.long, + device=device, + ) + track_ids = torch.cat([track_ids, backdrop_ids]) + boxes = torch.cat([boxes, backdrop["boxes"]]) + scores = torch.cat([scores, backdrop["scores"]]) + class_ids = torch.cat([class_ids, backdrop["class_ids"]]) + embeddings = torch.cat([embeddings, backdrop["embeddings"]]) + + return boxes, scores, class_ids, track_ids, embeddings + + def __call__( + self, + embeddings_list: list[Tensor], + det_boxes_list: list[Tensor], + det_scores_list: list[Tensor], + class_ids_list: list[Tensor], + frame_id_list: list[int], + ) -> TrackOut: + """Forward during test.""" + ( + batched_boxes, + batched_scores, + batched_class_ids, + batched_track_ids, + ) = ([], [], [], []) + + for frame_id, det_boxes, det_scores, class_ids, embeddings in zip( + frame_id_list, + det_boxes_list, + det_scores_list, + class_ids_list, + embeddings_list, + ): + # reset graph at begin of sequence + if frame_id == 0: + self.reset() + TrackIDCounter.reset() + + if not self.is_empty(): + ( + _, + _, + memo_class_ids, + memo_track_ids, + memo_embeds, + ) = self.get_tracks(det_boxes.device, add_backdrops=True) + else: + memo_class_ids = None + memo_track_ids = None + memo_embeds = None + + track_ids, filter_indices = self.tracker( + det_boxes, + det_scores, + class_ids, + embeddings, + memo_track_ids, + memo_class_ids, + memo_embeds, + ) + + self.update( + frame_id, + track_ids, + det_boxes[filter_indices], + det_scores[filter_indices], + class_ids[filter_indices], + embeddings[filter_indices], + ) + + ( + boxes, + scores, + class_ids, + track_ids, + _, + ) = self.get_tracks(det_boxes.device, frame_id=frame_id) + + batched_boxes.append(boxes) + batched_scores.append(scores) + batched_class_ids.append(class_ids) + batched_track_ids.append(track_ids) + + return TrackOut( + boxes=batched_boxes, + class_ids=batched_class_ids, + scores=batched_scores, + track_ids=batched_track_ids, + ) + + def update( + self, + frame_id: int, + track_ids: Tensor, + boxes: Tensor, + scores: Tensor, + class_ids: Tensor, + embeddings: Tensor, + ) -> None: + """Update the track memory with a new state.""" + valid_tracks = track_ids > -1 + + # update memo + for track_id, box, score, class_id, embed in zip( + track_ids[valid_tracks], + boxes[valid_tracks], + scores[valid_tracks], + class_ids[valid_tracks], + embeddings[valid_tracks], + ): + track_id = int(track_id) + if track_id in self.tracklets: + self.update_track( + track_id, box, score, class_id, embed, frame_id + ) + else: + self.create_track( + track_id, box, score, class_id, embed, frame_id + ) + + # backdrops + backdrop_inds = torch.nonzero( + torch.eq(track_ids, -1), as_tuple=False + ).squeeze(1) + + ious = bbox_iou(boxes[backdrop_inds], boxes) + + for i, ind in enumerate(backdrop_inds): + if (ious[i, :ind] > self.nms_backdrop_iou_thr).any(): + backdrop_inds[i] = -1 + backdrop_inds = backdrop_inds[backdrop_inds > -1] + + self.backdrops.insert( + 0, + { + "boxes": boxes[backdrop_inds], + "scores": scores[backdrop_inds], + "class_ids": class_ids[backdrop_inds], + "embeddings": embeddings[backdrop_inds], + }, + ) + + # delete invalid tracks from memory + invalid_ids = [] + for k, v in self.tracklets.items(): + if frame_id - v["last_frame"] >= self.memory_size: + invalid_ids.append(k) + for invalid_id in invalid_ids: + self.tracklets.pop(invalid_id) + + if len(self.backdrops) > self.backdrop_memory_size: + self.backdrops.pop() + + def update_track( + self, + track_id: int, + box: Tensor, + score: Tensor, + class_id: Tensor, + embedding: Tensor, + frame_id: int, + ) -> None: + """Update a specific track with a new models.""" + self.tracklets[track_id]["box"] = box + self.tracklets[track_id]["score"] = score + self.tracklets[track_id]["class_id"] = class_id + self.tracklets[track_id]["embed"] = ( + 1 - self.memory_momentum + ) * self.tracklets[track_id][ + "embed" + ] + self.memory_momentum * embedding + self.tracklets[track_id]["last_frame"] = frame_id + + def create_track( + self, + track_id: int, + box: Tensor, + score: Tensor, + class_id: Tensor, + embedding: Tensor, + frame_id: int, + ) -> None: + """Create a new track from a models.""" + self.tracklets[track_id] = Track( + box=box, + score=score, + class_id=class_id, + embed=embedding, + last_frame=frame_id, + ) diff --git a/vis4d/state/track3d/__init__.py b/vis4d/state/track3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b145e08204552b43e29a0ad9058381c565f0453 --- /dev/null +++ b/vis4d/state/track3d/__init__.py @@ -0,0 +1 @@ +"""Memory and state for 3D tracking algorithms.""" diff --git a/vis4d/state/track3d/cc_3dt.py b/vis4d/state/track3d/cc_3dt.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9928695a6037d8dd283640fd5bdf60c83911ac --- /dev/null +++ b/vis4d/state/track3d/cc_3dt.py @@ -0,0 +1,577 @@ +"""Memory for CC-3DT inference.""" + +from __future__ import annotations + +from typing import TypedDict + +import torch +from torch import Tensor, nn + +from vis4d.common.typing import DictStrAny +from vis4d.op.box.box2d import bbox_iou +from vis4d.op.track3d.cc_3dt import CC3DTrackAssociation, get_track_3d_out +from vis4d.op.track3d.common import Track3DOut +from vis4d.op.track.assignment import TrackIDCounter + +from .motion import BaseMotionModel, KF3DMotionModel, LSTM3DMotionModel + + +class Track(TypedDict): + """CC-3DT Track state. + + Attributes: + box_2d (Tensor): In shape (4,) and contains x1, y1, x2, y2. + score_2d (Tensor): In shape (1,). + box_3d (Tensor): In shape (12,) contains x,y,z,h,w,l,rx,ry,rz,vx,vy,vz. + score_3d (Tensor): In shape (1,). + class_id (Tensor): In shape (1,). + embed (Tensor): In shape (E,). E is the embedding dimension. + motion_model (BaseMotionModel): The motion model. + velocity (Tensor): In shape (motion_dims,). + last_frame (int): The last frame the track was updated. + acc_frame (int): The number of frames the track was updated. + """ + + box_2d: Tensor + score_2d: Tensor + box_3d: Tensor + score_3d: Tensor + class_id: Tensor + embed: Tensor + motion_model: BaseMotionModel + velocity: Tensor + last_frame: int + acc_frame: int + + +class CC3DTrackGraph: + """CC-3DT tracking graph.""" + + def __init__( + self, + track: CC3DTrackAssociation | None = None, + memory_size: int = 10, + memory_momentum: float = 0.8, + backdrop_memory_size: int = 1, + nms_backdrop_iou_thr: float = 0.3, + motion_model: str = "KF3D", + lstm_model: nn.Module | None = None, + motion_dims: int = 7, + num_frames: int = 5, + fps: int = 2, + update_3d_score: bool = True, + use_velocities: bool = False, + add_backdrops: bool = True, + ) -> None: + """Creates an instance of the class.""" + assert memory_size >= 0 + self.memory_size = memory_size + assert 0 <= memory_momentum <= 1.0 + self.memory_momentum = memory_momentum + assert backdrop_memory_size >= 0 + self.backdrop_memory_size = backdrop_memory_size + self.nms_backdrop_iou_thr = nms_backdrop_iou_thr + + self.tracker = CC3DTrackAssociation() if track is None else track + + self.tracklets: dict[int, Track] = {} + self.backdrops: list[DictStrAny] = [] + + if motion_model == "VeloLSTM": + assert ( + lstm_model is not None + ), "lstm_model must be provided for VeloLSTM" + self.lstm_model = lstm_model + + self.motion_model = motion_model + self.motion_dims = motion_dims + self.num_frames = num_frames + self.fps = fps + self.update_3d_score = update_3d_score + self.add_backdrops = add_backdrops + self.use_velocities = use_velocities + + def reset(self) -> None: + """Empty the memory.""" + self.tracklets.clear() + self.backdrops.clear() + + def is_empty(self) -> bool: + """Check if the memory is empty.""" + return len(self.tracklets) == 0 + + def get_tracks( + self, + device: torch.device, + frame_id: int | None = None, + add_backdrops: bool = False, + ) -> tuple[ + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + list[BaseMotionModel], + Tensor, + ]: + """Get tracklests. + + If the frame_id is not provided, will return the latest state of all + tracklets. Otherwise, will return the state of all tracklets at the + given frame_id. If add_backdrops is True, will also return the + backdrops. + + Args: + device (torch.device): Device to put the tensors on. + frame_id (int, optional): Frame id to query. Defaults to None. + add_backdrops (bool, optional): Whether to add backdrops to the + output. Defaults to False. + + Returns: + boxes_2d (Tensor): 2D boxes in shape (N, 4). + scores_2d (Tensor): 2D scores in shape (N,). + boxes_3d (Tensor): 3D boxes in shape (N, 12). + scores_3d (Tensor): 3D scores in shape (N,). + class_ids (Tensor): Class ids in shape (N,). + track_ids (Tensor): Track ids in shape (N,). + embeds (Tensor): Embeddings in shape (N, E). + motion_models (list[BaseMotionModel]): Motion models. + velocities (Tensor): Velocities in shape (N, 3). + """ + ( + boxes_2d_list, + scores_2d_list, + boxes_3d_list, + scores_3d_list, + class_ids_list, + embeds_list, + motion_models, + velocities_list, + track_ids_list, + ) = ([], [], [], [], [], [], [], [], []) + + for track_id, track in self.tracklets.items(): + if frame_id is None or track["last_frame"] == frame_id: + boxes_2d_list.append(track["box_2d"].unsqueeze(0)) + scores_2d_list.append(track["score_2d"].unsqueeze(0)) + boxes_3d_list.append(track["box_3d"].unsqueeze(0)) + scores_3d_list.append(track["score_3d"].unsqueeze(0)) + class_ids_list.append(track["class_id"].unsqueeze(0)) + embeds_list.append(track["embed"].unsqueeze(0)) + motion_models.append(track["motion_model"]) + velocities_list.append(track["velocity"].unsqueeze(0)) + track_ids_list.append(track_id) + + boxes_2d = ( + torch.cat(boxes_2d_list) + if len(boxes_2d_list) > 0 + else torch.empty((0, 4), device=device) + ) + scores_2d = ( + torch.cat(scores_2d_list) + if len(scores_2d_list) > 0 + else torch.empty((0,), device=device) + ) + boxes_3d = ( + torch.cat(boxes_3d_list) + if len(boxes_3d_list) > 0 + else torch.empty((0, 12), device=device) + ) + scores_3d = ( + torch.cat(scores_3d_list) + if len(scores_3d_list) > 0 + else torch.empty((0,), device=device) + ) + class_ids = ( + torch.cat(class_ids_list) + if len(class_ids_list) > 0 + else torch.empty((0,), device=device) + ) + embeds = ( + torch.cat(embeds_list) + if len(embeds_list) > 0 + else torch.empty((0,), device=device) + ) + velocities = ( + torch.cat(velocities_list) + if len(velocities_list) > 0 + else torch.empty((0, self.motion_dims), device=device) + ) + track_ids = torch.tensor(track_ids_list, device=device) + + if add_backdrops: + for backdrop in self.backdrops: + backdrop_ids = torch.full( + (len(backdrop["embeddings"]),), + -1, + dtype=torch.long, + device=device, + ) + track_ids = torch.cat([track_ids, backdrop_ids]) + boxes_2d = torch.cat([boxes_2d, backdrop["boxes_2d"]]) + scores_2d = torch.cat([scores_2d, backdrop["scores_2d"]]) + boxes_3d = torch.cat([boxes_3d, backdrop["boxes_3d"]]) + scores_3d = torch.cat([scores_3d, backdrop["scores_3d"]]) + class_ids = torch.cat([class_ids, backdrop["class_ids"]]) + embeds = torch.cat([embeds, backdrop["embeddings"]]) + motion_models.extend(backdrop["motion_models"]) + backdrop_vs = torch.zeros_like( + backdrop["boxes_3d"][:, : self.motion_dims] + ) + velocities = torch.cat([velocities, backdrop_vs]) + + return ( + boxes_2d, + scores_2d, + boxes_3d, + scores_3d, + class_ids, + track_ids, + embeds, + motion_models, + velocities, + ) + + def __call__( + self, + boxes_2d: Tensor, + scores_2d: Tensor, + camera_ids: Tensor, + boxes_3d: Tensor, + scores_3d: Tensor, + class_ids: Tensor, + embeddings: Tensor, + frame_id: int, + ) -> Track3DOut: + """Update the tracker with new detections.""" + if frame_id == 0: + self.reset() + TrackIDCounter.reset() + + if not self.is_empty(): + ( + _, + _, + memo_boxes_3d, + _, + memo_class_ids, + memo_track_ids, + memo_embeds, + memo_motion_models, + memo_velocities, + ) = self.get_tracks( + boxes_2d.device, add_backdrops=self.add_backdrops + ) + + memory_boxes_3d = torch.cat( + [memo_boxes_3d[:, :6], memo_boxes_3d[:, 8].unsqueeze(1)], + dim=1, + ) + + memory_track_ids = memo_track_ids + memory_class_ids = memo_class_ids + memory_embeddings = memo_embeds + + memory_boxes_3d_predict = memory_boxes_3d.clone() + for i, memo_motion_model in enumerate(memo_motion_models): + pd_box_3d = memo_motion_model.predict( + update_state=memo_motion_model.age != 0 + ) + memory_boxes_3d_predict[i, :3] += pd_box_3d[self.motion_dims :] + + memory_velocities = memo_velocities + + else: + memory_boxes_3d = None + memory_track_ids = None + memory_class_ids = None + memory_embeddings = None + memory_boxes_3d_predict = None + memory_velocities = None + + obs_velocities = boxes_3d[:, 9:] + obs_boxes_3d = torch.cat( + [boxes_3d[:, :6], boxes_3d[:, 8].unsqueeze(1)], dim=1 + ) + + track_ids, filter_indices = self.tracker( + boxes_2d, + camera_ids, + scores_2d, + obs_boxes_3d, + scores_3d, + class_ids, + embeddings, + obs_velocities, + memory_boxes_3d, + memory_track_ids, + memory_class_ids, + memory_embeddings, + memory_boxes_3d_predict, + memory_velocities, + self.update_3d_score, + ) + + self.update( + frame_id, + track_ids, + boxes_2d[filter_indices], + scores_2d[filter_indices], + camera_ids[filter_indices], + boxes_3d[filter_indices], + scores_3d[filter_indices], + class_ids[filter_indices], + embeddings[filter_indices], + obs_boxes_3d[filter_indices], + ) + + ( + _, + scores_2d, + boxes_3d, + scores_3d, + class_ids, + track_ids, + _, + _, + _, + ) = self.get_tracks(boxes_2d.device, frame_id=frame_id) + + # update 3D score + if self.update_3d_score: + track_scores_3d = scores_2d * scores_3d + else: + track_scores_3d = scores_3d + + return get_track_3d_out( + boxes_3d, class_ids, track_scores_3d, track_ids + ) + + def update( + self, + frame_id: int, + track_ids: Tensor, + boxes_2d: Tensor, + scores_2d: Tensor, + camera_ids: Tensor, + boxes_3d: Tensor, + scores_3d: Tensor, + class_ids: Tensor, + embeddings: Tensor, + obs_boxes_3d: Tensor, + ) -> None: + """Update the track memory with a new state.""" + valid_tracks = track_ids > -1 + + # update memo + for ( + track_id, + box_2d, + score_2d, + box_3d, + score_3d, + class_id, + embed, + obs_box_3d, + ) in zip( + track_ids[valid_tracks], + boxes_2d[valid_tracks], + scores_2d[valid_tracks], + boxes_3d[valid_tracks], + scores_3d[valid_tracks], + class_ids[valid_tracks], + embeddings[valid_tracks], + obs_boxes_3d[valid_tracks], + ): + track_id = int(track_id) + if track_id in self.tracklets: + self.update_track( + track_id, + box_2d, + score_2d, + box_3d, + score_3d, + class_id, + embed, + obs_box_3d, + frame_id, + ) + else: + self.create_track( + track_id, + box_2d, + score_2d, + box_3d, + score_3d, + class_id, + embed, + obs_box_3d, + frame_id, + ) + + # Handle vanished tracklets + for track_id, track in self.tracklets.items(): + if frame_id > track["last_frame"] and track_id > -1: + pd_box_3d = track["motion_model"].predict() + track["box_3d"][:6] = pd_box_3d[:6] + track["box_3d"][8] = pd_box_3d[6] + + # Backdrops + backdrop_inds = torch.nonzero( + torch.eq(track_ids, -1), as_tuple=False + ).squeeze(1) + + valid_ious = torch.eq( + camera_ids[backdrop_inds].unsqueeze(1), + camera_ids.unsqueeze(0), + ).int() + ious = bbox_iou(boxes_2d[backdrop_inds], boxes_2d) + ious *= valid_ious + + for i, ind in enumerate(backdrop_inds): + if (ious[i, :ind] > self.nms_backdrop_iou_thr).any(): + backdrop_inds[i] = -1 + backdrop_inds = backdrop_inds[backdrop_inds > -1] + + backdrop_motion_model = [] + for bd_ind in backdrop_inds: + backdrop_motion_model.append( + self.build_motion_model(obs_boxes_3d[bd_ind]) + ) + + self.backdrops.insert( + 0, + { + "boxes_2d": boxes_2d[backdrop_inds], + "scores_2d": scores_2d[backdrop_inds], + "boxes_3d": boxes_3d[backdrop_inds], + "scores_3d": scores_3d[backdrop_inds], + "class_ids": class_ids[backdrop_inds], + "embeddings": embeddings[backdrop_inds], + "motion_models": backdrop_motion_model, + }, + ) + + # delete invalid tracks from memory + invalid_ids = [] + for k, v in self.tracklets.items(): + if frame_id - v["last_frame"] >= self.memory_size: + invalid_ids.append(k) + for invalid_id in invalid_ids: + self.tracklets.pop(invalid_id) + + if len(self.backdrops) > self.backdrop_memory_size: + self.backdrops.pop() + + def update_track( + self, + track_id: int, + box_2d: Tensor, + score_2d: Tensor, + box_3d: Tensor, + score_3d: Tensor, + class_id: Tensor, + embed: Tensor, + obs_box_3d: Tensor, + frame_id: int, + ) -> None: + """Update a track.""" + self.tracklets[track_id]["box_2d"] = box_2d + self.tracklets[track_id]["score_2d"] = score_2d + self.tracklets[track_id]["motion_model"].update(obs_box_3d, score_3d) + + pd_box_3d = self.tracklets[track_id]["motion_model"].get_state()[ + : self.motion_dims + ] + + prev_obs = torch.cat( + [ + self.tracklets[track_id]["box_3d"][:6], + self.tracklets[track_id]["box_3d"][8].unsqueeze(0), + ] + ) + + self.tracklets[track_id]["box_3d"] = box_3d + self.tracklets[track_id]["box_3d"][:6] = pd_box_3d[:6] + self.tracklets[track_id]["box_3d"][8] = pd_box_3d[6] + self.tracklets[track_id]["box_3d"][9:12] = self.tracklets[track_id][ + "motion_model" + ].predict_velocity() + self.tracklets[track_id]["score_3d"] = score_3d + self.tracklets[track_id]["class_id"] = class_id + + self.tracklets[track_id]["embed"] = ( + 1 - self.memory_momentum + ) * self.tracklets[track_id]["embed"] + self.memory_momentum * embed + + velocity = (pd_box_3d - prev_obs) / ( + frame_id - self.tracklets[track_id]["last_frame"] + ) + + self.tracklets[track_id]["velocity"] = ( + self.tracklets[track_id]["velocity"] + * self.tracklets[track_id]["acc_frame"] + + velocity + ) / (self.tracklets[track_id]["acc_frame"] + 1) + + # Use predicted velocity if available + if self.use_velocities: + self.tracklets[track_id]["velocity"][4:] = self.tracklets[ + track_id + ]["box_3d"][9:12] + + self.tracklets[track_id]["last_frame"] = frame_id + self.tracklets[track_id]["acc_frame"] += 1 + + def create_track( + self, + track_id: int, + box_2d: Tensor, + score_2d: Tensor, + box_3d: Tensor, + score_3d: Tensor, + class_id: Tensor, + embed: Tensor, + obs_box_3d: Tensor, + frame_id: int, + ) -> None: + """Create a new track.""" + motion_model = self.build_motion_model(obs_box_3d) + + self.tracklets[track_id] = Track( + box_2d=box_2d, + score_2d=score_2d, + box_3d=box_3d, + score_3d=score_3d, + class_id=class_id, + embed=embed, + motion_model=motion_model, + velocity=torch.zeros(self.motion_dims, device=box_3d.device), + last_frame=frame_id, + acc_frame=0, + ) + + def build_motion_model(self, obs_3d: Tensor) -> BaseMotionModel: + """Build motion model.""" + if self.motion_model == "KF3D": + return KF3DMotionModel( + num_frames=self.num_frames, + obs_3d=obs_3d, + motion_dims=self.motion_dims, + fps=self.fps, + ) + + if self.motion_model == "VeloLSTM": + return LSTM3DMotionModel( + num_frames=self.num_frames, + lstm_model=self.lstm_model, + obs_3d=obs_3d, + motion_dims=self.motion_dims, + fps=self.fps, + ) + + raise NotImplementedError( + f"Motion model: {self.motion_model} not known!" + ) diff --git a/vis4d/state/track3d/motion/__init__.py b/vis4d/state/track3d/motion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..585785bbed43e6fb75a904c39e13760de71a3ef1 --- /dev/null +++ b/vis4d/state/track3d/motion/__init__.py @@ -0,0 +1,7 @@ +"""3D Motional Models.""" + +from .base import BaseMotionModel +from .kf3d import KF3DMotionModel +from .lstm_3d import LSTM3DMotionModel + +__all__ = ["BaseMotionModel", "KF3DMotionModel", "LSTM3DMotionModel"] diff --git a/vis4d/state/track3d/motion/base.py b/vis4d/state/track3d/motion/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c83fa6357d9ca977a8e44d76a7655ec0dfb79676 --- /dev/null +++ b/vis4d/state/track3d/motion/base.py @@ -0,0 +1,50 @@ +"""Motion model base class.""" + +from torch import Tensor + + +class BaseMotionModel: + """Base class for motion model.""" + + def __init__( + self, + num_frames: int, + motion_dims: int, + hits: int = 1, + hit_streak: int = 0, + time_since_update: int = 0, + age: int = 0, + fps: int = 1, + ) -> None: + """Creates an instance of the class.""" + self.num_frames = num_frames + self.motion_dims = motion_dims + self.hits = hits + self.hit_streak = hit_streak + self.time_since_update = time_since_update + self.age = age + self.fps = fps + + def update(self, obs_3d: Tensor, info: Tensor) -> None: + """Update the state.""" + raise NotImplementedError() + + def predict_velocity(self) -> Tensor: + """Predict velocity.""" + raise NotImplementedError() + + def predict(self, update_state: bool = True) -> Tensor: + """Predict the state.""" + raise NotImplementedError() + + def get_state(self) -> Tensor: + """Get the state.""" + raise NotImplementedError() + + +def update_array(origin_array: Tensor, input_array: Tensor) -> Tensor: + """Update array according the input.""" + new_array = origin_array.clone() + new_array[:-1] = origin_array[1:] + new_array[-1:] = input_array + return new_array diff --git a/vis4d/state/track3d/motion/kf3d.py b/vis4d/state/track3d/motion/kf3d.py new file mode 100644 index 0000000000000000000000000000000000000000..c102cba4df9069d401d262b833e5f760e16017eb --- /dev/null +++ b/vis4d/state/track3d/motion/kf3d.py @@ -0,0 +1,133 @@ +"""Kalman Filter 3D motion model.""" + +from __future__ import annotations + +import torch +from torch import Tensor + +from vis4d.common.typing import ArgsType +from vis4d.op.geometry.rotation import acute_angle, normalize_angle +from vis4d.op.motion.kalman_filter import predict, update + +from .base import BaseMotionModel + + +class KF3DMotionModel(BaseMotionModel): + """Kalman filter 3D motion model.""" + + def __init__( + self, + *args: ArgsType, + obs_3d: Tensor, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__(*args, **kwargs) + self.device = obs_3d.device + + # F, H, Q, R + ( + self._motion_mat, + self._update_mat, + self._cov_motion_q, + self._cov_project_r, + ) = self._kf3d_init() + + self._motion_mat = self._motion_mat.to(self.device) + self._update_mat = self._update_mat.to(self.device) + self._cov_motion_q = self._cov_motion_q.to(self.device) + self._cov_project_r = self._cov_project_r.to(self.device) + + self.mean, self.covariance = self._init_mean_cov(obs_3d) + + def _kf3d_init(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """KF3D init function.""" + motion_mat = torch.Tensor( + [ + [1, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + ] + ) + + update_mat = torch.Tensor( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ] + ) + + cov_motion_q = torch.eye(self.motion_dims + 3) + cov_motion_q[self.motion_dims :, self.motion_dims :] *= 0.01 + + cov_project_r = torch.eye(self.motion_dims) + return motion_mat, update_mat, cov_motion_q, cov_project_r + + def _init_mean_cov(self, obs_3d: Tensor) -> tuple[Tensor, Tensor]: + """Init KF3D mean and covariance.""" + mean = torch.zeros(self.motion_dims + 3).to(obs_3d.device) + mean[: self.motion_dims] = obs_3d + covariance = torch.eye(self.motion_dims + 3).to(obs_3d.device) * 10.0 + covariance[self.motion_dims :, self.motion_dims :] *= 1000.0 + return mean, covariance + + def update(self, obs_3d: Tensor, info: Tensor) -> None: + """Update the state.""" + self.time_since_update = 0 + self.hits += 1 + self.hit_streak += 1 + + self.mean[6] = normalize_angle(self.mean[6]) + obs_3d[6] = normalize_angle(obs_3d[6]) + + self.mean[6] = acute_angle(self.mean[6], obs_3d[6]) + + self.mean, self.covariance = update( + self._update_mat, + self._cov_project_r, + self.mean, + self.covariance, + obs_3d, + ) + self.mean[6] = normalize_angle(self.mean[6]) + + def predict_velocity(self) -> Tensor: + """Predict velocity.""" + pred_loc, _ = predict( + self._motion_mat, + self._cov_motion_q, + self.mean, + self.covariance, + ) + return (pred_loc[:3] - self.mean[:3]) * self.fps + + def predict(self, update_state: bool = True) -> Tensor: + """Predict the state.""" + self.mean, self.covariance = predict( + self._motion_mat, self._cov_motion_q, self.mean, self.covariance + ) + + self.mean[6] = normalize_angle(self.mean[6]) + + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + + return self.mean + + def get_state(self) -> Tensor: + """Returns the current bounding box estimate.""" + return self.mean diff --git a/vis4d/state/track3d/motion/lstm_3d.py b/vis4d/state/track3d/motion/lstm_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..d943e7c1c714701817beedde32ed52ceccd4d297 --- /dev/null +++ b/vis4d/state/track3d/motion/lstm_3d.py @@ -0,0 +1,149 @@ +"""LSTM 3D motion model.""" + +from __future__ import annotations + +import torch +from torch import Tensor, nn + +from vis4d.common.typing import ArgsType +from vis4d.model.motion.velo_lstm import VeloLSTM +from vis4d.op.geometry.rotation import acute_angle, normalize_angle + +from .base import BaseMotionModel, update_array + + +class LSTM3DMotionModel(BaseMotionModel): + """LSTM 3D motion model.""" + + def __init__( + self, + *args: ArgsType, + lstm_model: nn.Module, + obs_3d: Tensor, + init_flag: bool = True, + **kwargs: ArgsType, + ) -> None: + """Initialize a motion model using initial bounding box.""" + super().__init__(*args, **kwargs) + self.init_flag = init_flag + self.device = obs_3d.device + + assert isinstance( + lstm_model, VeloLSTM + ), "Currently only support VeloLSTM motion model!" + self.lstm_model = lstm_model + self.lstm_model.to(self.device) + self.lstm_model.eval() + + self.obj_state = torch.cat([obs_3d, obs_3d.new_zeros(3)]) + self.history = obs_3d.new_zeros(self.num_frames, self.motion_dims) + self.ref_history = torch.cat( + [obs_3d.view(1, self.motion_dims)] * (self.num_frames + 1) + ) + self.prev_ref = obs_3d.clone() + self.hidden_pred = self.lstm_model.init_hidden( + self.device, batch_size=1 + ) + self.hidden_ref = self.lstm_model.init_hidden( + self.device, batch_size=1 + ) + + def _update_history(self, bbox_3d: Tensor) -> None: + """Update velocity history.""" + self.ref_history = update_array(self.ref_history, bbox_3d) + self.history = update_array( + self.history, self.ref_history[-1] - self.ref_history[-2] + ) + self.prev_ref[: self.motion_dims] = self.obj_state[: self.motion_dims] + + def _init_history(self, bbox_3d: Tensor) -> None: + """Initialize velocity history.""" + self.ref_history = update_array(self.ref_history, bbox_3d) + self.history = torch.cat( + [ + (self.ref_history[-1] - self.ref_history[-2]).view( + 1, self.motion_dims + ) + ] + * self.num_frames + ) + self.prev_ref[: self.motion_dims] = self.obj_state[: self.motion_dims] + + def update(self, obs_3d: Tensor, info: Tensor) -> None: + """Updates the state vector with observed bbox.""" + self.time_since_update = 0 + self.hits += 1 + self.hit_streak += 1 + + if self.age == 1: + self.obj_state[: self.motion_dims] = obs_3d.clone() + + self.obj_state[6] = normalize_angle(self.obj_state[6]) + obs_3d[6] = normalize_angle(obs_3d[6]) + + # acute angle + self.obj_state[6] = acute_angle(self.obj_state[6], obs_3d[6]) + + with torch.no_grad(): + refined_loc, self.hidden_ref = self.lstm_model.refine( + self.obj_state[: self.motion_dims].unsqueeze(0), + obs_3d.unsqueeze(0), + self.prev_ref.unsqueeze(0), + info.unsqueeze(0).unsqueeze(0), + self.hidden_ref, + ) + + refined_obj = refined_loc.view(self.motion_dims) + refined_obj[6] = normalize_angle(refined_obj[6]) + + self.obj_state[: self.motion_dims] = refined_obj + + if self.init_flag: + self._init_history(refined_obj) + self.init_flag = False + else: + self._update_history(refined_obj) + + def predict_velocity(self) -> Tensor: + """Predict velocity.""" + with torch.no_grad(): + pred_loc, _ = self.lstm_model.predict( + self.history[..., : self.motion_dims].view( + self.num_frames, -1, self.motion_dims + ), + self.obj_state[: self.motion_dims], + self.hidden_pred, + ) + return (pred_loc[0][:3] - self.prev_ref[:3]) * self.fps + + def predict(self, update_state: bool = True) -> Tensor: + """Advances the state vector and returns the predicted bounding box.""" + with torch.no_grad(): + pred_loc, hidden_pred = self.lstm_model.predict( + self.history[..., : self.motion_dims].view( + self.num_frames, -1, self.motion_dims + ), + self.obj_state[: self.motion_dims], + self.hidden_pred, + ) + + pred_state = self.obj_state.clone() + pred_state[: self.motion_dims] = pred_loc.view(self.motion_dims) + pred_state[self.motion_dims :] = pred_state[:3] - self.prev_ref[:3] + + pred_state[6] = normalize_angle(pred_state[6]) + + if update_state: + self.hidden_pred = hidden_pred + self.obj_state = pred_state + + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + + return pred_state + + def get_state(self) -> Tensor: + """Returns the current bounding box estimate.""" + return self.obj_state diff --git a/vis4d/vis/__init__.py b/vis4d/vis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f08cdc210651555f0038f2f178dc863c06e047df --- /dev/null +++ b/vis4d/vis/__init__.py @@ -0,0 +1 @@ +"""Contains visualization tools for a variety of data types.""" diff --git a/vis4d/vis/__pycache__/__init__.cpython-311.pyc b/vis4d/vis/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..838c5141e849e7b06b452cb976958e6cf8a39473 Binary files /dev/null and b/vis4d/vis/__pycache__/__init__.cpython-311.pyc differ diff --git a/vis4d/vis/__pycache__/util.cpython-311.pyc b/vis4d/vis/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c49c968e4486f9a0572cb093c44c6d619f70c2f Binary files /dev/null and b/vis4d/vis/__pycache__/util.cpython-311.pyc differ diff --git a/vis4d/vis/base.py b/vis4d/vis/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1985f00435521160ec52aa377aa2f5385fe0ea96 --- /dev/null +++ b/vis4d/vis/base.py @@ -0,0 +1,53 @@ +"""Visualizer base class.""" + +from vis4d.common.typing import ArgsType + + +class Visualizer: + """Base visualizer class.""" + + def __init__(self, vis_freq: int = 50, image_mode: str = "RGB") -> None: + """Initialize the visualizer. + + Args: + vis_freq (int): Visualization frequency. Defaults to 0. + image_mode (str): Image channel mode (RGB or BGR). + """ + self.vis_freq = vis_freq + self.image_mode = image_mode + assert image_mode in {"RGB", "BGR"} + + def _run_on_batch(self, cur_iter: int) -> bool: + """Return whether to run on current iteration. + + Args: + cur_iter (int): Current iteration. + """ + return cur_iter % self.vis_freq == 0 + + def reset(self) -> None: + """Reset visualizer for new round of evaluation.""" + raise NotImplementedError() + + def process(self, cur_iter: int, *args: ArgsType) -> None: + """Process data of single sample.""" + raise NotImplementedError() + + def show(self, cur_iter: int, blocking: bool = True) -> None: + """Shows the visualization. + + Args: + cur_iter (int): Current iteration. + blocking (bool): If the visualization should be blocking and wait + for human input. Defaults to True. + """ + raise NotImplementedError() + + def save_to_disk(self, cur_iter: int, output_folder: str) -> None: + """Saves the visualization to disk. + + Args: + cur_iter (int): Current iteration. + output_folder (str): Folder where the output should be written. + """ + raise NotImplementedError() diff --git a/vis4d/vis/image/__init__.py b/vis4d/vis/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7525db917d3bd5b8bd108a354f9fe215a7480d34 --- /dev/null +++ b/vis4d/vis/image/__init__.py @@ -0,0 +1,6 @@ +"""Image Visualization.""" + +from .bounding_box_visualizer import BoundingBoxVisualizer +from .seg_mask_visualizer import SegMaskVisualizer + +__all__ = ["BoundingBoxVisualizer", "SegMaskVisualizer"] diff --git a/vis4d/vis/image/bbox3d_visualizer.py b/vis4d/vis/image/bbox3d_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1f336f48bbb91c7288f4588c5139902de1bab943 --- /dev/null +++ b/vis4d/vis/image/bbox3d_visualizer.py @@ -0,0 +1,465 @@ +"""Bounding box 3D visualizer.""" + +from __future__ import annotations + +import os +from collections import defaultdict +from collections.abc import Sequence +from dataclasses import dataclass + +import numpy as np +import torch + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArgsType, + ArrayLike, + ArrayLikeFloat, + ArrayLikeInt, + NDArrayF32, + NDArrayUI8, +) +from vis4d.data.const import AxisMode +from vis4d.op.geometry.transform import inverse_rigid_transform +from vis4d.vis.base import Visualizer +from vis4d.vis.util import generate_color_map + +from .canvas import CanvasBackend, PillowCanvasBackend +from .util import preprocess_boxes3d, preprocess_image, project_point +from .viewer import ImageViewerBackend, MatplotlibImageViewer + + +@dataclass +class DetectionBox3D: + """Dataclass storing box informations.""" + + corners: list[tuple[float, float, float]] + label: str + color: tuple[int, int, int] + track_id: int | None + + +@dataclass +class DataSample: + """Dataclass storing a data sample that can be visualized.""" + + image: NDArrayUI8 + image_name: str + intrinsics: NDArrayF32 + extrinsics: NDArrayF32 | None + sequence_name: str | None + camera_name: str | None + boxes: list[DetectionBox3D] + + +class BoundingBox3DVisualizer(Visualizer): + """Bounding box 3D visualizer class.""" + + def __init__( + self, + *args: ArgsType, + n_colors: int = 50, + cat_mapping: dict[str, int] | None = None, + file_type: str = "png", + image_mode: str = "RGB", + width: int = 2, + camera_near_clip: float = 0.15, + plot_heading: bool = True, + axis_mode: AxisMode = AxisMode.ROS, + trajectory_length: int = 10, + plot_trajectory: bool = True, + save_boxes3d: bool = False, + canvas: CanvasBackend | None = None, + viewer: ImageViewerBackend | None = None, + **kwargs: ArgsType, + ) -> None: + """Creates a new Visualizer for Image and 3D Bounding Boxes. + + Args: + n_colors (int): How many colors should be used for the internal + color map. Defaults to 100. + cat_mapping (dict[str, int]): Mapping from class names to class + ids. Defaults to None. + file_type (str): Desired file type. Defaults to "png". + image_mode (str): Image channel mode (RGB or BGR). Defaults to + "RGB". + width (int): Width of the drawn bounding boxes. Defaults to 2. + camera_near_clip (float): Near clipping plane of the camera. + Defaults to 0.15. + plot_heading (bool): If the heading should be plotted. Defaults to + True. + axis_mode (AxisMode): Axis mode for the input bboxes. Defaults to + AxisMode.ROS (i.e. global coordinate). + trajectory_length (int): How many past frames should be used to + draw the trajectory. Defaults to 10. + plot_trajectory (bool): If the trajectory should be plotted. + Defaults to True. + save_boxes3d (bool): If the corners of 3D boxes should be saved to + disk in the format of npy. Defaults to False. + canvas (CanvasBackend): Backend that is used to draw on images. If + None a PillowCanvasBackend is used. + viewer (ImageViewerBackend): Backend that is used show images. If + None a MatplotlibImageViewer is used. + """ + super().__init__(*args, **kwargs) + self._samples: list[DataSample] = [] + self.axis_mode = axis_mode + self.trajectories: dict[int, list[tuple[float, float, float]]] = ( + defaultdict(list) + ) + self.trajectory_length = trajectory_length + self.plot_trajectory = plot_trajectory + + self.color_palette = generate_color_map(n_colors) + + self.class_id_mapping = ( + {v: k for k, v in cat_mapping.items()} + if cat_mapping is not None + else {} + ) + + self.file_type = file_type + self.image_mode = image_mode + self.width = width + + self.camera_near_clip = camera_near_clip + self.plot_heading = plot_heading + self.save_boxes3d = save_boxes3d + + self.canvas = canvas if canvas is not None else PillowCanvasBackend() + self.viewer = viewer if viewer is not None else MatplotlibImageViewer() + + def reset(self) -> None: + """Reset visualizer.""" + self._samples.clear() + + def __repr__(self) -> str: + """Return string representation.""" + return "BoundingBox3DVisualizer" + + def process( # pylint: disable=arguments-differ + self, + cur_iter: int, + images: list[ArrayLike], + image_names: list[str], + boxes3d: list[ArrayLikeFloat], + intrinsics: ArrayLikeFloat, + extrinsics: None | ArrayLikeFloat = None, + scores: None | list[ArrayLikeFloat] = None, + class_ids: None | list[ArrayLikeInt] = None, + track_ids: None | list[ArrayLikeInt] = None, + sequence_names: None | list[str] = None, + categories: None | list[list[str]] = None, + ) -> None: + """Processes a batch of data. + + Args: + cur_iter (int): Current iteration. + images (list[ArrayLike]): Images to show. + image_names (list[str]): Image names. + boxes3d (list[ArrayLikeFloat]): List of predicted bounding boxes + with shape [B, N, 10]. + intrinsics (ArrayLikeFloat): Camera intrinsics with shape + [B, 3, 3]. + extrinsics (None | ArrayLikeFloat, optional): Camera extrinsics + with shape [B, 4, 4]. Defaults to None. + scores (None | list[ArrayLikeFloat], optional): List of predicted + box scores each of shape [B, N]. Defaults to None. + class_ids (None | list[ArrayLikeInt], optional): List of predicted + class ids each of shape [B, N]. Defaults to None. + track_ids (None | list[ArrayLikeInt], optional): List of predicted + track ids each of shape [B, N]. Defaults to None. + sequence_names (None | list[str], optional): List of sequence + names of shape [B,]. Defaults to None. + categories (None | list[list[str]], optional): List of categories + for each image. Instead of class ids, the categories will be + used to label the boxes. Defaults to None. + """ + if self._run_on_batch(cur_iter): + for batch, image in enumerate(images): + self.process_single_image( + image, + image_names[batch], + boxes3d[batch], + intrinsics[batch], # type: ignore + ( + None if extrinsics is None else extrinsics[batch] # type: ignore # pylint: disable=line-too-long + ), + None if scores is None else scores[batch], + None if class_ids is None else class_ids[batch], + None if track_ids is None else track_ids[batch], + None if sequence_names is None else sequence_names[batch], + None if categories is None else categories[batch], + ) + + for tid in self.trajectories: + if len(self.trajectories[tid]) > self.trajectory_length: + self.trajectories[tid].pop(0) + + def process_single_image( + self, + image: ArrayLike, + image_name: str, + boxes3d: ArrayLikeFloat, + intrinsics: ArrayLikeFloat, + extrinsics: None | ArrayLikeFloat = None, + scores: None | ArrayLikeFloat = None, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + sequence_name: None | str = None, + categories: None | list[str] = None, + camera_name: None | str = None, + ) -> None: + """Processes a single image entry. + + Args: + image (ArrayLike): Image to show. + image_name (str): Image name. + boxes3d (ArrayLikeFloat): Predicted bounding boxes with shape + [N, 10], where N is the number of boxes. + intrinsics (ArrayLikeFloat): Camera intrinsics with shape [3, 3]. + extrinsics (None | ArrayLikeFloat, optional): Camera extrinsics + with shape [4, 4]. Defaults to None. + scores (None | ArrayLikeFloat, optional): Predicted box scores of + shape [N]. Defaults to None. + class_ids (None | ArrayLikeInt, optional): Predicted class ids of + shape [N]. Defaults to None. + track_ids (None | ArrayLikeInt, optional): Predicted track ids of + shape [N]. Defaults to None. + sequence_name (None | str, optional): Sequence name. Defaults to + None. + categories (None | list[str], optional): List of categories for + each box. Instead of class ids, the categories will be used to + label the boxes. Defaults to None. + camera_name (None | str, optional): Camera name. Defaults to None. + """ + img_normalized = preprocess_image(image, mode=self.image_mode) + image_hw = (img_normalized.shape[0], img_normalized.shape[1]) + + intrinsics_np = array_to_numpy(intrinsics, n_dims=2, dtype=np.float32) + extrinsics_np = ( + array_to_numpy(extrinsics, n_dims=2, dtype=np.float32) + if extrinsics is not None + else None + ) + data_sample = DataSample( + img_normalized, + image_name, + intrinsics_np, + extrinsics_np, + sequence_name, + camera_name, + [], + ) + + if len(boxes3d) != 0: # type: ignore + for center, corners, label, color, track_id in zip( + *preprocess_boxes3d( + image_hw, + boxes3d, + intrinsics, + extrinsics, + scores, + class_ids, + track_ids, + self.color_palette, + self.class_id_mapping, + axis_mode=self.axis_mode, + categories=categories, + ) + ): + data_sample.boxes.append( + DetectionBox3D( + corners=corners, + label=label, + color=color, + track_id=track_id, + ) + ) + if track_id is not None: + self.trajectories[track_id].append(center) + + self._samples.append(data_sample) + + def show(self, cur_iter: int, blocking: bool = True) -> None: + """Shows the processed images in a interactive window. + + Args: + cur_iter (int): Current iteration. + blocking (bool): If the visualizer should be blocking i.e. wait for + human input for each image. Defaults to True. + """ + if self._run_on_batch(cur_iter): + image_data = [self._draw_image(d) for d in self._samples] + self.viewer.show_images(image_data, blocking=blocking) + + def _draw_image(self, sample: DataSample) -> NDArrayUI8: + """Visualizes the datasample and returns is as numpy image. + + Args: + sample (DataSample): The data sample to visualize. + + Returns: + NDArrayUI8: A image with the visualized data sample. + """ + self.canvas.create_canvas(sample.image) + + if self.plot_trajectory: + assert ( + sample.extrinsics is not None + ), "Extrinsics is needed to plot trajectory." + global_to_cam = inverse_rigid_transform( + torch.from_numpy(sample.extrinsics) + ).numpy() + + for box in sample.boxes: + self.canvas.draw_box_3d( + box.corners, + box.color, + sample.intrinsics, + self.width, + self.camera_near_clip, + self.plot_heading, + ) + + selected_corner = project_point(box.corners[0], sample.intrinsics) + self.canvas.draw_text( + (selected_corner[0], selected_corner[1]), box.label, box.color + ) + + if self.plot_trajectory: + assert ( + box.track_id is not None + ), "track id must be set to plot trajectory." + + trajectory = self.trajectories[box.track_id] + for center in trajectory: + # Move global center to current camera frame + center_cam = np.dot(global_to_cam, [*center, 1])[:3] + + if center_cam[2] > 0: + projected_center = project_point( + center_cam, sample.intrinsics + ) + self.canvas.draw_circle( + projected_center, box.color, self.width * 2 + ) + + return self.canvas.as_numpy_image() + + def save_to_disk(self, cur_iter: int, output_folder: str) -> None: + """Saves the visualization to disk. + + Writes all processes samples to the output folder naming each image + .. + + Args: + cur_iter (int): Current iteration. + output_folder (str): Folder where the output should be written. + """ + if self._run_on_batch(cur_iter): + for sample in self._samples: + output_dir = output_folder + image_name = f"{sample.image_name}.{self.file_type}" + + self._draw_image(sample) + + if sample.sequence_name is not None: + output_dir = os.path.join(output_dir, sample.sequence_name) + + if sample.camera_name is not None: + output_dir = os.path.join(output_dir, sample.camera_name) + + os.makedirs(output_dir, exist_ok=True) + self.canvas.save_to_disk(os.path.join(output_dir, image_name)) + + if self.save_boxes3d: + corners = np.array([box.corners for box in sample.boxes]) + + np.save( + os.path.join(output_dir, f"{sample.image_name}.npy"), + corners, + ) + + +class MultiCameraBBox3DVisualizer(BoundingBox3DVisualizer): + """Bounding box 3D visualizer class for multi-camera datasets.""" + + def __init__( + self, *args: ArgsType, cameras: Sequence[str], **kwargs: ArgsType + ) -> None: + """Creates a new Visualizer for Image and 3D Bounding Boxes. + + Args: + cameras (Sequence[str]): Camera names. + """ + super().__init__(*args, **kwargs) + + self.cameras = cameras + + def __repr__(self) -> str: + """Return string representation.""" + return "MultiCameraBBox3DVisualizer" + + def process( # type: ignore # pylint: disable=arguments-differ + self, + cur_iter: int, + images: list[list[ArrayLike]], + image_names: list[list[str]], + boxes3d: list[ArrayLikeFloat], + intrinsics: list[ArrayLikeFloat], + extrinsics: list[ArrayLikeFloat] | None = None, + scores: list[ArrayLikeFloat] | None = None, + class_ids: list[ArrayLikeInt] | None = None, + track_ids: list[ArrayLikeInt] | None = None, + sequence_names: list[str] | None = None, + categories: None | list[list[str]] = None, + ) -> None: + """Processes a batch of data. + + Args: + cur_iter (int): Current iteration. + images (list[ArrayLike]): Images to show. + image_names (list[str]): Image names. + boxes3d (list[ArrayLikeFloat]): List of predicted bounding boxes + with shape [B, N, 10]. + intrinsics (ArrayLikeFloat): Camera intrinsics with shape + [num_cam, B, 3, 3]. + extrinsics (None | ArrayLikeFloat, optional): Camera extrinsics + with shape [num_cam, B, 4, 4]. Defaults to None. + scores (None | list[ArrayLikeFloat], optional): List of predicted + box scores each of shape [B, N]. Defaults to None. + class_ids (None | list[ArrayLikeInt], optional): List of predicted + class ids each of shape [B, N]. Defaults to None. + track_ids (None | list[ArrayLikeInt], optional): List of predicted + track ids each of shape [B, N]. Defaults to None. + sequence_names (None | list[str], optional): List of sequence + names of shape [B,]. Defaults to None. + categories (None | list[list[str]], optional): List of categories + for each image. Instead of class ids, the categories will be + used to label the boxes. Defaults to None. + """ + if self._run_on_batch(cur_iter): + for idx, batch_images in enumerate(images): + for batch, image in enumerate(batch_images): + self.process_single_image( + image, + image_names[idx][batch], + boxes3d[batch], + intrinsics[idx][batch], # type: ignore + ( + None + if extrinsics is None + else extrinsics[idx][batch] # type: ignore + ), + None if scores is None else scores[batch], + None if class_ids is None else class_ids[batch], + None if track_ids is None else track_ids[batch], + ( + None + if sequence_names is None + else sequence_names[batch] + ), + None if categories is None else categories[batch], + self.cameras[idx], + ) diff --git a/vis4d/vis/image/bev_visualizer.py b/vis4d/vis/image/bev_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c68c901becd9c96df710c581dd00a4e4671c72 --- /dev/null +++ b/vis4d/vis/image/bev_visualizer.py @@ -0,0 +1,361 @@ +"""BEV Bounding box 3D visualizer.""" + +from __future__ import annotations + +import os +from collections import defaultdict +from dataclasses import dataclass + +import numpy as np +import torch +from torch import Tensor + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArgsType, + ArrayLikeFloat, + ArrayLikeInt, + NDArrayF32, + NDArrayUI8, +) +from vis4d.data.const import AxisMode +from vis4d.op.box.box3d import boxes3d_to_corners, transform_boxes3d +from vis4d.op.geometry.transform import inverse_rigid_transform +from vis4d.vis.base import Visualizer +from vis4d.vis.util import generate_color_map + +from .canvas import CanvasBackend, PillowCanvasBackend +from .viewer import ImageViewerBackend, MatplotlibImageViewer + + +@dataclass +class BEVBox: + """Dataclass storing box informations.""" + + corners: list[tuple[float, float]] + color: tuple[int, int, int] + track_id: int | None + + +@dataclass +class DataSample: + """Dataclass storing a data sample that can be visualized.""" + + name: str + extrinsics: NDArrayF32 + sequence_name: str | None + boxes: list[BEVBox] + + +class BEVBBox3DVisualizer(Visualizer): + """BEV Bounding box 3D visualizer class.""" + + def __init__( + self, + *args: ArgsType, + n_colors: int = 50, + file_type: str = "png", + max_range: float = 60, + scale: float = 10, + width: int = 2, + margin: int = 10, + axis_mode: AxisMode = AxisMode.ROS, + trajectory_length: int = 10, + plot_trajectory: bool = True, + canvas: CanvasBackend | None = None, + viewer: ImageViewerBackend | None = None, + **kwargs: ArgsType, + ) -> None: + """Creates a new Visualizer for BEV Image and Bounding Boxes. + + Args: + n_colors (int): How many colors should be used for the internal + color map. Defaults to 100. + file_type (str): Desired file type. Defaults to "png". + max_range (float): Maximum range (meters) of the BEV image. + Defaults to 60. + scale (float): Scale of the BEV image. Defaults to 10. Means that + 1m in the BEV image is 10px. + width (int): Width of the drawn bounding boxes. Defaults to 2. + margin (int): Margin of the BEV image. Defaults to 10. + axis_mode (AxisMode): Axis mode for the input bboxes. Defaults to + AxisMode.ROS (i.e. global coordinate). + trajectory_length (int): How many past frames should be used to + draw the trajectory. Defaults to 10. + plot_trajectory (bool): If the trajectory should be plotted. + Defaults to True. + canvas (CanvasBackend): Backend that is used to draw on images. If + None a PillowCanvasBackend is used. + viewer (ImageViewerBackend): Backend that is used show images. If + None a MatplotlibImageViewer is used. + """ + super().__init__(*args, **kwargs) + self._samples: list[DataSample] = [] + self.axis_mode = axis_mode + self.trajectories: dict[int, list[tuple[float, float, float]]] = ( + defaultdict(list) + ) + self.trajectory_length = trajectory_length + self.plot_trajectory = plot_trajectory + + self.color_palette = generate_color_map(n_colors) + + self.file_type = file_type + self.max_range = max_range + self.scale = scale + + # Generate figure size + self.figure_hw = ( + int(max_range * scale + margin) * 2, + int(max_range * scale + margin) * 2, + ) + + self.width = width + + self.canvas = canvas if canvas is not None else PillowCanvasBackend() + self.viewer = viewer if viewer is not None else MatplotlibImageViewer() + + def __repr__(self) -> str: + """Return string representation.""" + return "BEVBBox3DVisualizer" + + def reset(self) -> None: + """Reset visualizer.""" + self._samples.clear() + + def process( # pylint: disable=arguments-differ + self, + cur_iter: int, + sample_names: list[list[str]] | list[str], + boxes3d: list[ArrayLikeFloat], + extrinsics: list[ArrayLikeFloat] | ArrayLikeFloat, + class_ids: None | list[ArrayLikeInt] = None, + track_ids: None | list[ArrayLikeInt] = None, + sequence_names: None | list[str] = None, + ) -> None: + """Processes a batch of data.""" + # Handle multi-sensor connector results from multi-sensor data dict + if isinstance(sample_names[0], list) and isinstance(extrinsics, list): + sample_names = sample_names[0] + extrinsics = extrinsics[0] + + if self._run_on_batch(cur_iter): + for batch, sample_name in enumerate(sample_names): + self.process_single( + sample_name, # type: ignore + boxes3d[batch], + extrinsics[batch], # type: ignore + class_ids[batch] if class_ids is not None else None, + track_ids[batch] if track_ids is not None else None, + ( + sequence_names[batch] + if sequence_names is not None + else None + ), + ) + + for tid in self.trajectories: + if len(self.trajectories[tid]) > self.trajectory_length: + self.trajectories[tid].pop(0) + + def process_single( + self, + sample_name: str, + boxes3d: ArrayLikeFloat, + extrinsics: ArrayLikeFloat, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + sequence_name: None | str = None, + ) -> None: + """Process single batch.""" + boxes3d = array_to_numpy(boxes3d, n_dims=2, dtype=np.float32) + extrinsics_np = array_to_numpy(extrinsics, n_dims=2, dtype=np.float32) + data_sample = DataSample( + sample_name, + extrinsics_np, + sequence_name, + [], + ) + + boxes3d_lidar, boxes3d = self._get_lidar_and_global_boxes3d( + boxes3d, extrinsics_np + ) + + corners = boxes3d_to_corners( + boxes3d_lidar, axis_mode=AxisMode.LIDAR + ).numpy() + + track_ids_np = array_to_numpy(track_ids, n_dims=1, dtype=np.int32) + class_ids_np = array_to_numpy(class_ids, n_dims=1, dtype=np.int32) + + for i in range(corners.shape[0]): + track_id = None if track_ids_np is None else int(track_ids_np[i]) + class_id = None if class_ids_np is None else int(class_ids_np[i]) + + if track_id is not None: + color = self.color_palette[track_id % len(self.color_palette)] + self.trajectories[track_id].append( + tuple(boxes3d[i][:3].tolist()) + ) + elif class_id is not None: + color = self.color_palette[class_id % len(self.color_palette)] + else: + color = (255, 0, 0) + + data_sample.boxes.append( + BEVBox( + [tuple(pts) for pts in corners[i, :4, :2]], + color, + track_id=track_id, + ) + ) + + self._samples.append(data_sample) + + def _get_lidar_and_global_boxes3d( + self, boxes3d: NDArrayF32, extrinsics: NDArrayF32 + ) -> tuple[Tensor, NDArrayF32]: + """Get boxes3d in lidar and global frame.""" + if self.axis_mode == AxisMode.ROS: + global_to_lidar = inverse_rigid_transform( + torch.from_numpy(extrinsics) + ) + + boxes3d_global = boxes3d + + boxes3d_lidar = transform_boxes3d( + torch.from_numpy(boxes3d), + global_to_lidar, + source_axis_mode=self.axis_mode, + target_axis_mode=AxisMode.LIDAR, + ) + elif self.axis_mode == AxisMode.LIDAR: + boxes3d_global = transform_boxes3d( + torch.from_numpy(boxes3d), + torch.from_numpy(extrinsics), + source_axis_mode=self.axis_mode, + target_axis_mode=AxisMode.ROS, + ).numpy() + + boxes3d_lidar = torch.from_numpy(boxes3d) + else: + raise NotImplementedError( + f"Axis mode {self.axis_mode} not supported" + ) + return boxes3d_lidar, boxes3d_global + + def show(self, cur_iter: int, blocking: bool = True) -> None: + """Shows the processed images in a interactive window. + + Args: + cur_iter (int): Current iteration. + blocking (bool): If the visualizer should be blocking i.e. wait for + human input for each image. Defaults to True. + """ + if self._run_on_batch(cur_iter): + image_data = [self._draw_image(d) for d in self._samples] + self.viewer.show_images(image_data, blocking=blocking) + + def _map_lidar_to_bev_image( + self, point_x: float, point_y: float + ) -> tuple[float, float]: + """Maps a point from lidar frame to BEV image frame.""" + return ( + self.scale * point_x + self.figure_hw[1] // 2, + self.scale * -point_y + self.figure_hw[0] // 2, + ) + + def _draw_image(self, sample: DataSample) -> NDArrayUI8: + """Visualizes the datasample and returns is as numpy image. + + Args: + sample (DataSample): The data sample to visualize. + + Returns: + NDArrayUI8: A image with the visualized data sample. + """ + self.canvas.create_canvas(image_hw=self.figure_hw) + + img_center = self._map_lidar_to_bev_image(0, 0) + + # Mark range every 10m + for i in range(int(self.max_range / 10), 0, -1): + distance = int(10 * self.scale * i) + grey_level = 140 + i * 10 + self.canvas.draw_circle( + img_center, (grey_level, grey_level, grey_level), distance + ) + + self.canvas.draw_text( + (img_center[0] + distance - 25, img_center[1]), + f"{10 * i} m", + color=(0, 0, 0), + ) + + # Draw ego car + self.canvas.draw_rotated_box( + [ + (img_center[0] - self.scale, img_center[1] - self.scale * 2), + (img_center[0] + self.scale, img_center[1] - self.scale * 2), + (img_center[0] - self.scale, img_center[1] + self.scale * 2), + (img_center[0] + self.scale, img_center[1] + self.scale * 2), + ], + (0, 0, 0), + self.width, + ) + + global_to_lidar = inverse_rigid_transform( + torch.from_numpy(sample.extrinsics) + ).numpy() + + for box in sample.boxes: + corners = [ + self._map_lidar_to_bev_image(pts[0], pts[1]) + for pts in box.corners + ] + self.canvas.draw_rotated_box(corners, box.color, self.width) + + if self.plot_trajectory: + assert ( + box.track_id is not None + ), "Track id must be set to plot trajectory." + + trajectory = self.trajectories[box.track_id] + for center in trajectory: + # Move global center to current lidar frame + center_lidar = np.dot(global_to_lidar, [*center, 1])[:3] + + bev_center = self._map_lidar_to_bev_image( + center_lidar[0], center_lidar[1] + ) + + self.canvas.draw_circle( + bev_center, box.color, self.width * 2 + ) + + return self.canvas.as_numpy_image() + + def save_to_disk(self, cur_iter: int, output_folder: str) -> None: + """Saves the visualization to disk. + + Writes all processes samples to the output folder naming each image + .. + + Args: + cur_iter (int): Current iteration. + output_folder (str): Folder where the output should be written. + """ + if self._run_on_batch(cur_iter): + for sample in self._samples: + output_dir = output_folder + sample_name = f"{sample.name}.{self.file_type}" + + self._draw_image(sample) + + if sample.sequence_name is not None: + output_dir = os.path.join(output_dir, sample.sequence_name) + + output_dir = os.path.join(output_dir, "BEV") + + os.makedirs(output_dir, exist_ok=True) + self.canvas.save_to_disk(os.path.join(output_dir, sample_name)) diff --git a/vis4d/vis/image/bounding_box_visualizer.py b/vis4d/vis/image/bounding_box_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5a436c89dba735ef788b568a7c7e6c70c38f2eaf --- /dev/null +++ b/vis4d/vis/image/bounding_box_visualizer.py @@ -0,0 +1,226 @@ +"""Bounding box visualizer.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + +from vis4d.common.typing import ( + ArgsType, + ArrayLike, + ArrayLikeFloat, + ArrayLikeInt, + NDArrayUI8, +) +from vis4d.vis.base import Visualizer +from vis4d.vis.util import generate_color_map + +from .canvas import CanvasBackend, PillowCanvasBackend +from .util import preprocess_boxes, preprocess_image +from .viewer import ImageViewerBackend, MatplotlibImageViewer + + +@dataclass +class DetectionBox2D: + """Dataclass storing box informations.""" + + corners: tuple[float, float, float, float] + label: str + color: tuple[int, int, int] + + +@dataclass +class DataSample: + """Dataclass storing a data sample that can be visualized.""" + + image: NDArrayUI8 + image_name: str + boxes: list[DetectionBox2D] + + +class BoundingBoxVisualizer(Visualizer): + """Bounding box visualizer class.""" + + def __init__( + self, + *args: ArgsType, + n_colors: int = 50, + cat_mapping: dict[str, int] | None = None, + file_type: str = "png", + width: int = 2, + canvas: CanvasBackend = PillowCanvasBackend(), + viewer: ImageViewerBackend = MatplotlibImageViewer(), + **kwargs: ArgsType, + ) -> None: + """Creates a new Visualizer for Image and Bounding Boxes. + + Args: + n_colors (int): How many colors should be used for the internal + color map + cat_mapping (dict[str, int]): Mapping from class names to class + ids. Defaults to None. + file_type (str): Desired file type. Defaults to "png". + width (int): Width of the bounding box lines. Defaults to 2. + canvas (CanvasBackend): Backend that is used to draw on images. + viewer (ImageViewerBackend): Backend that is used show images. + """ + super().__init__(*args, **kwargs) + self._samples: list[DataSample] = [] + self.color_palette = generate_color_map(n_colors) + self.class_id_mapping = ( + {v: k for k, v in cat_mapping.items()} + if cat_mapping is not None + else {} + ) + self.file_type = file_type + self.width = width + self.canvas = canvas + self.viewer = viewer + + def __repr__(self) -> str: + """Return string representation of the visualizer.""" + return "BoundingBoxVisualizer" + + def reset(self) -> None: + """Reset visualizer.""" + self._samples.clear() + + def process( # pylint: disable=arguments-differ + self, + cur_iter: int, + images: list[ArrayLike], + image_names: list[str], + boxes: list[ArrayLikeFloat], + scores: None | list[ArrayLikeFloat] = None, + class_ids: None | list[ArrayLikeInt] = None, + track_ids: None | list[ArrayLikeInt] = None, + categories: None | list[list[str]] = None, + ) -> None: + """Processes a batch of data. + + Args: + cur_iter (int): Current iteration. + images (list[ArrayLike]): Images to show. + image_names (list[str]): Image names. + boxes (list[ArrayLikeFloat]): List of predicted bounding boxes with + shape [N, (x1, y1, x2, y2)], where N is the number of boxes. + scores (None | list[ArrayLikeFloat], optional): List of predicted + box scores each of shape [N]. Defaults to None. + class_ids (None | list[ArrayLikeInt], optional): List of predicted + class ids each of shape [N]. Defaults to None. + track_ids (None | list[ArrayLikeInt], optional): List of predicted + track ids each of shape [N]. Defaults to None. + categories (None | list[list[str]], optional): List of categories + for each image. Instead of class ids, the categories will be + used to label the boxes. Defaults to None. + """ + if self._run_on_batch(cur_iter): + for idx, image in enumerate(images): + self.process_single_image( + image, + image_names[idx], + boxes[idx], + None if scores is None else scores[idx], + None if class_ids is None else class_ids[idx], + None if track_ids is None else track_ids[idx], + None if categories is None else categories[idx], + ) + + def process_single_image( + self, + image: ArrayLike, + image_name: str, + boxes: ArrayLikeFloat, + scores: None | ArrayLikeFloat = None, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + categories: None | list[str] = None, + ) -> None: + """Processes a single image entry. + + Args: + image (ArrayLike): Image to show. + image_name (str): Image name. + boxes (ArrayLikeFloat): Predicted bounding boxes with shape + [N, (x1,y1,x2,y2)], where N is the number of boxes. + scores (None | ArrayLikeFloat, optional): Predicted box scores of + shape [N]. Defaults to None. + class_ids (None | ArrayLikeInt, optional): Predicted class ids of + shape [N]. Defaults to None. + track_ids (None | ArrayLikeInt, optional): Predicted track ids of + shape [N]. Defaults to None. + categories (None | list[str], optional): List of categories for + each box. Instead of class ids, the categories will be used to + label the boxes. Defaults to None. + """ + img_normalized = preprocess_image(image, mode=self.image_mode) + data_sample = DataSample(img_normalized, image_name, []) + + for corners, label, color in zip( + *preprocess_boxes( + boxes, + scores, + class_ids, + track_ids, + self.color_palette, + self.class_id_mapping, + categories=categories, + ) + ): + data_sample.boxes.append( + DetectionBox2D( + corners=(corners[0], corners[1], corners[2], corners[3]), + label=label, + color=color, + ) + ) + + self._samples.append(data_sample) + + def show(self, cur_iter: int, blocking: bool = True) -> None: + """Shows the processed images in a interactive window. + + Args: + cur_iter (int): Current iteration. + blocking (bool): If the visualizer should be blocking i.e. wait for + human input for each image. Defaults to True. + """ + if self._run_on_batch(cur_iter): + image_data = [self._draw_image(d) for d in self._samples] + self.viewer.show_images(image_data, blocking=blocking) + + def _draw_image(self, sample: DataSample) -> NDArrayUI8: + """Visualizes the datasample and returns is as numpy image. + + Args: + sample (DataSample): The data sample to visualize. + + Returns: + NDArrayUI8: A image with the visualized data sample. + """ + self.canvas.create_canvas(sample.image) + for box in sample.boxes: + self.canvas.draw_box(box.corners, box.color, width=self.width) + self.canvas.draw_text(box.corners[:2], box.label, box.color) + + return self.canvas.as_numpy_image() + + def save_to_disk(self, cur_iter: int, output_folder: str) -> None: + """Saves the visualization to disk. + + Writes all processes samples to the output folder naming each image + .. + + Args: + cur_iter (int): Current iteration. + output_folder (str): Folder where the output should be written. + """ + if self._run_on_batch(cur_iter): + for sample in self._samples: + image_name = f"{sample.image_name}.{self.file_type}" + + _ = self._draw_image(sample) + + self.canvas.save_to_disk( + os.path.join(output_folder, image_name) + ) diff --git a/vis4d/vis/image/canvas/__init__.py b/vis4d/vis/image/canvas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8e856bc8e8319e9576592b6fe763add88e5b56 --- /dev/null +++ b/vis4d/vis/image/canvas/__init__.py @@ -0,0 +1,6 @@ +"""Vis4D image canvas backends.""" + +from .base import CanvasBackend +from .pillow_backend import PillowCanvasBackend + +__all__ = ["CanvasBackend", "PillowCanvasBackend"] diff --git a/vis4d/vis/image/canvas/base.py b/vis4d/vis/image/canvas/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb31eda31dafea2da0418407c0bb9a8d83e0a1e --- /dev/null +++ b/vis4d/vis/image/canvas/base.py @@ -0,0 +1,175 @@ +"""Base class of canvas for image based visualization.""" + +from __future__ import annotations + +from vis4d.common.typing import NDArrayBool, NDArrayF32, NDArrayUI8 + + +class CanvasBackend: + """Abstract interface that allows to draw on images. + + Supports drawing different bounding boxes on top of an image. + """ + + def create_canvas( + self, + image: NDArrayUI8 | None = None, + image_hw: tuple[int, int] | None = None, + ) -> None: + """Creates a new canvas with a given image or shape internally. + + Either provide a background image or the desired height, width + of the canvas. + + Args: + image (np.array[uint8] | None): Numpy array with a background image + image_hw (tuple[int, int] | None): height, width of the canvas + """ + raise NotImplementedError + + def draw_bitmap( + self, + bitmap: NDArrayBool, + color: tuple[int, int, int], + top_left_corner: tuple[float, float] = (0, 0), + alpha: float = 0.5, + ) -> None: + """Draws a binary mask onto the given canvas. + + Args: + bitmap (ndarray): The binary mask to draw + color (tuple[int, int, int]): Color of the box [0,255]. + top_left_corner (tuple(float, float)): Coordinates of top left + corner of the bitmap. Defaults to (0, 0). + alpha (float, optional): Alpha value for transparency of this mask. + Defaults to 0.5. + """ + raise NotImplementedError + + def draw_text( + self, + position: tuple[float, float], + text: str, + color: tuple[int, int, int] = (255, 255, 255), + ) -> None: + """Draw text onto canvas at given position. + + Args: + position (tuple[float, float]): x,y position where the text will + start. + text (str): Text to be placed at the given location. + color (tuple[int, int, int], optional): Text color. Defaults to + (255, 255, 255). + """ + raise NotImplementedError + + def draw_line( + self, + point1: tuple[float, float], + point2: tuple[float, float], + color: tuple[int, int, int], + width: int = 0, + ) -> None: + """Draw a line onto canvas from point 1 to 2. + + Args: + point1 (tuple[float, float]): Start point (2D pixel coordinates). + point2 (tuple[float, float]): End point (2D pixel coordinates). + color (ttuple[int, int, int]): Color of the line. + width (int, optional): Line width. Defaults to 0. + """ + raise NotImplementedError + + def draw_circle( + self, + center: tuple[float, float], + color: tuple[int, int, int], + radius: int = 2, + ) -> None: + """Draw a circle onto canvas. + + Args: + center (tuple[float, float]): Center of the circle. + color (tuple[int, int, int]): Color of the circle. + radius (int, optional): Radius of the circle. Defaults to 2. + """ + raise NotImplementedError + + def draw_box( + self, + corners: tuple[float, float, float, float], + color: tuple[int, int, int], + width: int = 1, + ) -> None: + """Draws a box onto the given canvas. + + Args: + corners (list[float]): Containing [x1,y1,x2,y2] the corners of + the box. + color (tuple[int, int, int]): Color of the box [0,255]. + width (int, optional): Line width. Defaults to 1. + + Raises: + ValueError: If the canvas is not initialized. + """ + raise NotImplementedError + + def draw_rotated_box( + self, + corners: list[tuple[float, float]], + color: tuple[int, int, int], + width: int = 0, + ) -> None: + """Draws a box onto the given canvas. + + Corner ordering: + + (2) +---------+ (3) + | | + | | + | | + (0) +---------+ (1) + + Args: + corners (list[tuple[float, float]]): Containing the four corners of + the box. + color (tuple[int, int, int]): Color of the box [0,255]. + width (int, optional): Line width. Defaults to 0. + """ + raise NotImplementedError + + def draw_box_3d( + self, + corners: list[tuple[float, float, float]], + color: tuple[int, int, int], + intrinsics: NDArrayF32, + width: int = 0, + camera_near_clip: float = 0.15, + plot_heading: bool = True, + ) -> None: + """Draws a line between two points. + + Args: + corners (list[tuple[float, float, float]]): Containing the eight + corners of the box. + color (tuple[int, int, int]): Color of the line. + intrinsics (NDArrayF32): Camera intrinsics matrix. + width (int, optional): The width of the line. Defaults to 0. + camera_near_clip (float, optional): The near clipping plane of the + camera. Defaults to 0.15. + plot_heading (bool, optional): If True, the heading of the box will + be plotted as a line. Defaults to True. + """ + raise NotImplementedError + + def as_numpy_image(self) -> NDArrayUI8: + """Returns the current canvas as numpy image.""" + raise NotImplementedError + + def save_to_disk(self, image_path: str) -> None: + """Writes the current canvas to disk. + + Args: + image_path (str): Full image path (with file name and ending). + """ + raise NotImplementedError diff --git a/vis4d/vis/image/canvas/pillow_backend.py b/vis4d/vis/image/canvas/pillow_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfd0104097cbaabf5087f18ba0cc4b6b951ac4e --- /dev/null +++ b/vis4d/vis/image/canvas/pillow_backend.py @@ -0,0 +1,376 @@ +"""Pillow backend implementation to draw on images.""" + +from __future__ import annotations + +import numpy as np +from PIL import Image, ImageDraw +from PIL.ImageFont import ImageFont, load_default + +from vis4d.common.typing import NDArrayBool, NDArrayF32, NDArrayF64, NDArrayUI8 + +from ..util import get_intersection_point, project_point +from .base import CanvasBackend + + +class PillowCanvasBackend(CanvasBackend): + """Canvas backend using Pillow.""" + + def __init__( + self, font: ImageFont | None = None, font_size: int | None = None + ) -> None: + """Creates a new canvas backend. + + Args: + font (ImageFont): Pillow font to use for the label. + font_size (int): Font size to use for the label. + """ + self._image_draw: ImageDraw.ImageDraw | None = None + self._font = font if font is not None else load_default(font_size) + self._image: Image.Image | None = None + + def create_canvas( + self, + image: NDArrayUI8 | None = None, + image_hw: tuple[int, int] | None = None, + ) -> None: + """Creates a new canvas with a given image or shape internally. + + Either provide a background image or the desired height, width + of the canvas. + + Args: + image (np.array[uint8] | None): Numpy array with a background image + image_hw (tuple[int, int] | None): height, width of the canvas + + Raises: + ValueError: If the canvas is not initialized. + """ + if image_hw is not None: + white_image = np.ones([*image_hw, 3]) * 255 + image = white_image.astype(np.uint8) + else: + assert ( + image is not None + ), "Image or Image Shapes required to create canvas" + + self._image = Image.fromarray(image) + self._image_draw = ImageDraw.Draw(self._image) + + def draw_bitmap( + self, + bitmap: NDArrayBool, + color: tuple[int, int, int], + top_left_corner: tuple[float, float] = (0, 0), + alpha: float = 0.5, + ) -> None: + """Draws a binary mask onto the given canvas. + + Args: + bitmap (ndarray): The binary mask to draw. + color (tuple[int, int, int]): Color of the box [0,255]. + top_left_corner (tuple(float, float)): Coordinates of top left + corner of the bitmap. + alpha (float): Alpha value for transparency of this mask. + + Raises: + ValueError: If the canvas is not initialized. + """ + if self._image_draw is None: + raise ValueError( + "No Image Draw initialized! Did you call 'create_canvas'?" + ) + mask = np.squeeze(bitmap) + assert len(mask.shape) == 2, "Bitmap expected to have shape [h,w]" + + bitmap_with_alpha: NDArrayF64 = np.repeat( + mask[:, :, None], 4, axis=2 + ).astype(np.float64) + bitmap_with_alpha[..., -1] = bitmap_with_alpha[..., -1] * alpha * 255 + bitmap_pil = Image.fromarray( + bitmap_with_alpha.astype(np.uint8), mode="RGBA" + ) + self._image_draw.bitmap( + top_left_corner, bitmap_pil, fill=color # type: ignore + ) + + def draw_text( + self, + position: tuple[float, float], + text: str, + color: tuple[int, int, int] = (255, 255, 255), + ) -> None: + """Draw text onto canvas at given position. + + Args: + position (tuple[float, float]): x,y position where the text will + start. + text (str): Text to be placed at the given location. + color (tuple[int, int, int], optional): Text color. Defaults to + (255, 255, 255). + + Raises: + ValueError: If the canvas is not initialized. + """ + if self._image_draw is None: + raise ValueError( + "No Image Draw initialized! Did you call 'create_canvas'?" + ) + left, top, right, bottom = self._image_draw.textbbox( + position, text, font=self._font + ) + self._image_draw.rectangle( + (left - 2, top - 2, right + 2, bottom + 2), fill=color + ) + self._image_draw.text(position, text, (255, 255, 255), font=self._font) + + def draw_box( + self, + corners: tuple[float, float, float, float], + color: tuple[int, int, int], + width: int = 1, + ) -> None: + """Draws a box onto the given canvas. + + Args: + corners (list[float]): Containing [x1,y2,x2,y2] the corners of + the box. + color (tuple[int, int, int]): Color of the box [0,255]. + width (int, optional): Line width. Defaults to 1. + + Raises: + ValueError: If the canvas is not initialized. + """ + if self._image_draw is None: + raise ValueError( + "No Image Draw initialized! Did you call 'create_canvas'?" + ) + + self._image_draw.rectangle(corners, outline=color, width=width) + + def draw_rotated_box( + self, + corners: list[tuple[float, float]], + color: tuple[int, int, int], + width: int = 0, + ) -> None: + """Draws a box onto the given canvas. + + Corner ordering: + + (2) +---------+ (3) + | | + | | + | | + (0) +---------+ (1) + + Args: + corners (list[tuple[float, float]]): Containing the four corners of + the box. + color (tuple[int, int, int]): Color of the box [0,255]. + width (int, optional): Line width. Defaults to 0. + + Raises: + ValueError: If the canvas is not initialized. + """ + assert len(corners) == 4, "2D box must consist of 4 corner points." + if self._image_draw is None: + raise ValueError( + "No Image Draw initialized! Did you call 'create_canvas'?" + ) + + self.draw_line(corners[0], corners[1], color, 2 * width) + self.draw_line(corners[0], corners[2], color, width) + self.draw_line(corners[1], corners[3], color, width) + self.draw_line(corners[2], corners[3], color, width) + + center_forward = np.mean(corners[:2], axis=0, dtype=np.float32) + center = np.mean(corners, axis=0, dtype=np.float32) + self.draw_line( + tuple(center.tolist()), + tuple(center_forward.tolist()), + color, + width, + ) + + def draw_line( + self, + point1: tuple[float, float], + point2: tuple[float, float], + color: tuple[int, int, int], + width: int = 0, + ) -> None: + """Draw a line onto canvas from point 1 to 2. + + Args: + point1 (tuple[float, float]): Start point (2D pixel coordinates). + point2 (tuple[float, float]): End point (2D pixel coordinates). + color (tuple[int, int, int]): Color of the line. + width (int, optional): Line width. Defaults to 0. + + Raises: + ValueError: If the canvas is not initialized. + """ + if self._image_draw is None: + raise ValueError( + "No Image Draw initialized! Did you call 'create_canvas'?" + ) + self._image_draw.line((point1, point2), width=width, fill=color) + + def draw_circle( + self, + center: tuple[float, float], + color: tuple[int, int, int], + radius: int = 2, + ) -> None: + """Draw a circle onto canvas. + + Args: + center (tuple[float, float]): Center of the circle. + color (tuple[int, int, int]): Color of the circle. + radius (int, optional): Radius of the circle. Defaults to 2. + """ + x1 = center[0] - radius + y1 = center[1] - radius + x2 = center[0] + radius + y2 = center[1] + radius + if self._image_draw is None: + raise ValueError( + "No Image Draw initialized! Did you call 'create_canvas'?" + ) + self._image_draw.ellipse((x1, y1, x2, y2), fill=color, outline=color) + + def _draw_box_3d_line( + self, + point1: tuple[float, float, float], + point2: tuple[float, float, float], + color: tuple[int, int, int], + intrinsics: NDArrayF32, + width: int = 0, + camera_near_clip: float = 0.15, + ) -> None: + """Draws a line between two points. + + Args: + point1 (tuple[float, float, float]): The first point. The third + coordinate is the depth. + point2 (tuple[float, float, float]): The first point. The third + coordinate is the depth. + color (tuple[int, int, int]): Color of the line. + intrinsics (NDArrayF32): Camera intrinsics matrix. + width (int, optional): The width of the line. Defaults to 0. + camera_near_clip (float, optional): The near clipping plane of the + camera. Defaults to 0.15. + + Raises: + ValueError: If the canvas is not initialized. + """ + if point1[2] < camera_near_clip and point2[2] < camera_near_clip: + return + + if point1[2] < camera_near_clip: + point1 = get_intersection_point(point1, point2, camera_near_clip) + elif point2[2] < camera_near_clip: + point2 = get_intersection_point(point1, point2, camera_near_clip) + + pt1 = project_point(point1, intrinsics) + pt2 = project_point(point2, intrinsics) + + if self._image_draw is None: + raise ValueError( + "No Image Draw initialized! Did you call 'create_canvas'?" + ) + self._image_draw.line((pt1, pt2), width=width, fill=color) + + def draw_box_3d( + self, + corners: list[tuple[float, float, float]], + color: tuple[int, int, int], + intrinsics: NDArrayF32, + width: int = 0, + camera_near_clip: float = 0.15, + plot_heading: bool = True, + ) -> None: + """Draws a 3D box onto the given canvas.""" + # Draw Front + self._draw_box_3d_line( + corners[0], corners[1], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[1], corners[5], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[5], corners[4], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[4], corners[0], color, intrinsics, width, camera_near_clip + ) + + # Draw Sides + self._draw_box_3d_line( + corners[0], corners[2], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[1], corners[3], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[4], corners[6], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[5], corners[7], color, intrinsics, width, camera_near_clip + ) + + # Draw Back + self._draw_box_3d_line( + corners[2], corners[3], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[3], corners[7], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[7], corners[6], color, intrinsics, width, camera_near_clip + ) + self._draw_box_3d_line( + corners[6], corners[2], color, intrinsics, width, camera_near_clip + ) + + # Draw line indicating the front + if plot_heading: + center_bottom_forward = np.mean( + corners[:2], axis=0, dtype=np.float32 + ) + center_bottom = np.mean(corners[:4], axis=0, dtype=np.float32) + self._draw_box_3d_line( + tuple(center_bottom.tolist()), + tuple(center_bottom_forward.tolist()), + color, + intrinsics, + width, + camera_near_clip, + ) + + def as_numpy_image(self) -> NDArrayUI8: + """Returns the current canvas as numpy image. + + Raises: + ValueError: If the canvas is not initialized. + """ + if self._image is None: + raise ValueError( + "No Image initialized! Did you call 'create_canvas'?" + ) + return np.asarray(self._image) + + def save_to_disk(self, image_path: str) -> None: + """Writes the current canvas to disk. + + Args: + image_path (str): Full image path (with file name and ending). + + Raises: + ValueError: If the canvas is not initialized. + """ + if self._image is None: + raise ValueError( + "No Image initialized! Did you call 'create_canvas'?" + ) + self._image.save(image_path) diff --git a/vis4d/vis/image/functional.py b/vis4d/vis/image/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..5113b198be71e932f38dd1d98a842ca3f7f5f75e --- /dev/null +++ b/vis4d/vis/image/functional.py @@ -0,0 +1,430 @@ +"""Function interface for image visualization functions.""" + +from __future__ import annotations + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArrayLike, + ArrayLikeBool, + ArrayLikeFloat, + ArrayLikeInt, + NDArrayF32, + NDArrayUI8, +) + +from ..util import generate_color_map +from .canvas import CanvasBackend, PillowCanvasBackend +from .util import ( + preprocess_boxes, + preprocess_boxes3d, + preprocess_image, + preprocess_masks, + project_point, +) +from .viewer import ImageViewerBackend, MatplotlibImageViewer + + +def imshow( + image: ArrayLike, + image_mode: str = "RGB", + image_viewer: ImageViewerBackend = MatplotlibImageViewer(), +) -> None: + """Shows a single image. + + Args: + image (NDArrayNumber): The image to show. + image_mode (str, optional): Image Mode. Defaults to "RGB". + image_viewer (ImageViewerBackend, optional): The Image viewer backend + to use. Defaults to MatplotlibImageViewer(). + """ + image = preprocess_image(image, image_mode) + image_viewer.show_images([image]) + + +def draw_masks( + image: ArrayLike, + masks: ArrayLikeBool, + class_ids: ArrayLikeInt | None, + n_colors: int = 50, + image_mode: str = "RGB", + canvas: CanvasBackend = PillowCanvasBackend(), +) -> NDArrayUI8: + """Draws semantic masks into the given image. + + Args: + image (ArrayLike): The image to draw the bboxes into. + masks (ArrayLikeBool): The semantic masks with the same shape as the + image. + class_ids (ArrayLikeInt, optional): Predicted class ids. + Defaults to None. + n_colors (int, optional): Number of colors to use for color palette. + Defaults to 50. + image_mode (str, optional): Image Mode. Defaults to "RGB". + canvas (CanvasBackend, optional): Canvas backend to use. + Defaults to PillowCanvasBackend(). + + Returns: + NDArrayUI8: The image with semantic masks drawn into it, + """ + image = preprocess_image(image, mode=image_mode) + canvas.create_canvas(image) + for m, c in zip( + *preprocess_masks(masks, class_ids, generate_color_map(n_colors)) + ): + canvas.draw_bitmap(m, c) + return canvas.as_numpy_image() + + +def draw_bboxes( + image: ArrayLike, + boxes: ArrayLikeFloat, + scores: None | ArrayLikeFloat = None, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + class_id_mapping: None | dict[int, str] = None, + n_colors: int = 50, + image_mode: str = "RGB", + box_width: int = 1, + canvas: CanvasBackend = PillowCanvasBackend(), +) -> CanvasBackend: + """Draws the predicted bounding boxes into the given image. + + Args: + image (ArrayLike): The image to draw the bboxes into. + boxes (ArrayLikeFloat): Predicted bounding boxes. + scores (None | ArrayLikeFloat, optional): Predicted scores. + Defaults to None. + class_ids (ArrayLikeInt, optional): Predicted class ids. + Defaults to None. + track_ids (ArrayLikeInt, optional): Predicted track ids. + Defaults to None. + class_id_mapping (dict[int, str], optional): Mapping from class id to + name. Defaults to None. + n_colors (int, optional): Number of colors to use for color palette. + Defaults to 50. + image_mode (str, optional): Image Mode. Defaults to "RGB". + box_width (int, optional): Width of the box border. Defaults to 1. + canvas (CanvasBackend, optional): Canvas backend to use. + Defaults to PillowCanvasBackend(). + + Returns: + NDArrayUI8: The image with boxes drawn into it, + """ + image = preprocess_image(image, image_mode) + box_data = preprocess_boxes( + boxes, + scores, + class_ids, + track_ids, + color_palette=generate_color_map(n_colors), + class_id_mapping=class_id_mapping, + ) + canvas.create_canvas(image) + + for corners, label, color in zip(*box_data): + canvas.draw_box(corners, color, box_width) + + if len(label) > 0: + canvas.draw_text((corners[0], corners[1]), label, color=color) + return canvas + + +def imshow_bboxes( + image: ArrayLike, + boxes: ArrayLikeFloat, + scores: None | ArrayLikeFloat = None, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + class_id_mapping: None | dict[int, str] = None, + n_colors: int = 50, + image_mode: str = "RGB", + box_width: int = 1, + image_viewer: ImageViewerBackend = MatplotlibImageViewer(), + file_path: str | None = None, +) -> None: + """Shows the bounding boxes overlayed on the given image. + + Args: + image (ArrayLike): Background Image + boxes (ArrayLikeFloat): Boxes to show. Shape [N, 4] with + (x1,y1,x2,y2) as corner convention + scores (ArrayLikeFloat, optional): Score for each box shape [N] + class_ids (ArrayLikeInt, optional): Class id for each box shape [N] + track_ids (ArrayLikeInt, optional): Track id for each box shape [N] + class_id_mapping (dict[int, str], optional): Mapping to convert + class id to class name + n_colors (int, optional): Number of distinct colors used to color the + boxes. Defaults to 50. + image_mode (str, optional): Image channel mode (RGB or BGR). + box_width (int, optional): Width of the box border. Defaults to 1. + image_viewer (ImageViewerBackend, optional): The Image viewer backend + to use. Defaults to MatplotlibImageViewer(). + file_path (str): The path to save the image to. Defaults to None. + """ + image = preprocess_image(image, mode=image_mode) + canvas = draw_bboxes( + image, + boxes, + scores, + class_ids, + track_ids, + class_id_mapping, + n_colors, + image_mode, + box_width, + ) + imshow(canvas.as_numpy_image(), image_mode, image_viewer) + + if file_path is not None: + canvas.save_to_disk(file_path) + + +def draw_bbox3d( + image: NDArrayUI8, + boxes3d: ArrayLikeFloat, + intrinsics: NDArrayF32, + extrinsics: NDArrayF32 | None = None, + scores: None | ArrayLikeFloat = None, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + class_id_mapping: None | dict[int, str] = None, + n_colors: int = 50, + image_mode: str = "RGB", + canvas: CanvasBackend = PillowCanvasBackend(), + width: int = 4, + camera_near_clip: float = 0.15, +) -> CanvasBackend: + """Draw 3D box onto image.""" + image = preprocess_image(image, image_mode) + image_hw = (image.shape[0], image.shape[1]) + _, corners, labels, colors, _ = preprocess_boxes3d( + image_hw, + boxes3d, + intrinsics, + extrinsics, + scores, + class_ids, + track_ids, + color_palette=generate_color_map(n_colors), + class_id_mapping=class_id_mapping, + ) + canvas.create_canvas(image) + + for corner, label, color in zip(corners, labels, colors): + canvas.draw_box_3d(corner, color, intrinsics, width, camera_near_clip) + + selected_corner = project_point(corner[0], intrinsics) + + if len(label) > 0: + canvas.draw_text( + (selected_corner[0], selected_corner[1]), label, color=color + ) + + return canvas + + +def imshow_bboxes3d( + image: ArrayLike, + boxes3d: ArrayLikeFloat, + intrinsics: NDArrayF32, + extrinsics: NDArrayF32 | None = None, + scores: None | ArrayLikeFloat = None, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + class_id_mapping: None | dict[int, str] = None, + n_colors: int = 50, + image_mode: str = "RGB", + image_viewer: ImageViewerBackend = MatplotlibImageViewer(), + file_path: str | None = None, +) -> None: + """Show image with bounding boxes.""" + image = preprocess_image(image, mode=image_mode) + canvas = draw_bbox3d( + image, + boxes3d, + intrinsics, + extrinsics, + scores, + class_ids, + track_ids, + class_id_mapping=class_id_mapping, + n_colors=n_colors, + image_mode=image_mode, + ) + imshow(canvas.as_numpy_image(), image_mode, image_viewer) + + if file_path is not None: + canvas.save_to_disk(file_path) + + +def imshow_masks( + image: ArrayLike, + masks: ArrayLikeBool, + class_ids: ArrayLikeInt | None, + n_colors: int = 50, + image_mode: str = "RGB", + canvas: CanvasBackend = PillowCanvasBackend(), +) -> None: + """Shows semantic masks overlayed over the given image. + + Args: + image (ArrayLike): The image to draw the bboxes into. + masks (ArrayLikeBool): The semantic masks with the same shape as the + image. + class_ids (ArrayLikeInt, optional): Predicted class ids. + Defaults to None. + n_colors (int, optional): Number of colors to use for color palette. + Defaults to 50. + image_mode (str, optional): Image Mode.. Defaults to "RGB". + canvas (CanvasBackend, optional): Canvas backend to use. + Defaults to PillowCanvasBackend(). + """ + imshow( + draw_masks(image, masks, class_ids, n_colors, image_mode, canvas), + image_mode, + ) + + +def imshow_topk_bboxes( + image: ArrayLike, + boxes: ArrayLikeFloat, + scores: ArrayLikeFloat, + topk: int = 100, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + class_id_mapping: None | dict[int, str] = None, + n_colors: int = 50, + image_mode: str = "RGB", + box_width: int = 1, + image_viewer: ImageViewerBackend = MatplotlibImageViewer(), + file_path: str | None = None, +) -> None: + """Visualize the 'topk' bounding boxes with highest score. + + Args: + image (ArrayLike): Background Image + boxes (ArrayLikeFloat): Boxes to show. Shape [N, 4] with + (x1,y1,x2,y2) as corner convention + scores (ArrayLikeFloat): Score for each box shape [N] + topk (int): Number of boxes to visualize + class_ids (ArrayLikeInt, optional): Class id for each box shape [N] + track_ids (ArrayLikeInt, optional): Track id for each box shape [N] + class_id_mapping (dict[int, str], optional): Mapping to convert + class id to class name + n_colors (int, optional): Number of distinct colors used to color the + boxes. Defaults to 50. + image_mode (str, optional): Image channel mode (RGB or BGR). + box_width (int, optional): Width of the box border. Defaults to 1. + image_viewer (ImageViewerBackend, optional): The Image viewer backend + to use. Defaults to MatplotlibImageViewer(). + file_path (str): The path to save the image to. Defaults to None. + + """ + scores = array_to_numpy(scores, n_dims=1, dtype=np.float32) + top_k_idxs = np.argpartition(scores.ravel(), -topk)[-topk:] + + boxes_np = array_to_numpy(boxes, n_dims=2, dtype=np.float32) + class_ids_np = array_to_numpy(class_ids, n_dims=1, dtype=np.int32) + track_ids_np = array_to_numpy(track_ids, n_dims=1, dtype=np.int32) + imshow_bboxes( + image, + boxes_np[top_k_idxs], + scores[top_k_idxs], + class_ids_np[top_k_idxs] if class_ids_np is not None else None, + track_ids_np[top_k_idxs] if track_ids_np is not None else None, + class_id_mapping, + n_colors, + image_mode, + box_width, + image_viewer, + file_path, + ) + + +def imshow_track_matches( + key_imgs: list[ArrayLike], + ref_imgs: list[ArrayLike], + key_boxes: list[ArrayLikeFloat], + ref_boxes: list[ArrayLikeFloat], + key_track_ids: list[ArrayLikeInt], + ref_track_ids: list[ArrayLikeInt], + image_mode: str = "RGB", + image_viewer: ImageViewerBackend = MatplotlibImageViewer(), +) -> None: + """Visualize paired bounding boxes successively for batched frame pairs. + + Args: + key_imgs (list[ArrayLike]): Key Images. + ref_imgs (list[ArrayLike]): Reference Images. + key_boxes (list[ArrayLikeFloat]): Predicted Boxes for the key image. + Shape [N, 4] + ref_boxes (list[ArrayLikeFloat]): Predicted Boxes for the key image. + Shape [N, 4] + key_track_ids (list[ArrayLikeInt]): Predicted ids for the key images. + ref_track_ids (list[ArrayLikeInt]): Predicted ids for the reference + images. + image_mode (str, optional): Color mode if the image. Defaults to "RGB". + image_viewer (ImageViewerBackend, optional): The Image viewer backend + to use. Defaults to MatplotlibImageViewer(). + """ + key_imgs_np = tuple( + array_to_numpy(img, n_dims=3, dtype=np.float32) for img in key_imgs + ) + ref_imgs_np = tuple( + array_to_numpy(img, n_dims=3, dtype=np.float32) for img in ref_imgs + ) + key_boxes_np = tuple( + array_to_numpy(b, n_dims=2, dtype=np.float32) for b in key_boxes + ) + ref_boxes_np = tuple( + array_to_numpy(b, n_dims=2, dtype=np.float32) for b in ref_boxes + ) + key_track_ids_np = tuple( + array_to_numpy(t, n_dims=1, dtype=np.int32) for t in key_track_ids + ) + ref_track_ids_np = tuple( + array_to_numpy(t, n_dims=1, dtype=np.int32) for t in ref_track_ids + ) + + for batch_i, (key_box, ref_box) in enumerate( + zip(key_boxes_np, ref_boxes_np) + ): + target = key_track_ids_np[batch_i].reshape(-1, 1) == ref_track_ids_np[ + batch_i + ].reshape(1, -1) + for key_i in range(target.shape[0]): + if target[key_i].sum() == 0: + continue + ref_i = np.argmax(target[key_i]).item() + ref_image = ref_imgs_np[batch_i] + key_image = key_imgs_np[batch_i] + + if ref_image.shape != key_image.shape: + # Can not stack images together + imshow_bboxes( + key_image, + key_box[key_i], + image_mode=image_mode, + image_viewer=image_viewer, + ) + imshow_bboxes( + ref_image, + ref_box[ref_i], + image_mode=image_mode, + image_viewer=image_viewer, + ) + else: + # stack imgs horizontal + k_canvas = draw_bboxes( + key_image, key_box[batch_i], image_mode=image_mode + ) + r_canvas = draw_bboxes( + ref_image, ref_box[batch_i], image_mode=image_mode + ) + k_np_img = k_canvas.as_numpy_image() + r_np_img = r_canvas.as_numpy_image() + stacked_img = np.vstack([k_np_img, r_np_img]) + + imshow(stacked_img, image_mode, image_viewer) diff --git a/vis4d/vis/image/seg_mask_visualizer.py b/vis4d/vis/image/seg_mask_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b95451a0d3228aa7d85da20152abf22a20051f2 --- /dev/null +++ b/vis4d/vis/image/seg_mask_visualizer.py @@ -0,0 +1,213 @@ +"""Segmentation mask visualizer.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + +from vis4d.common.typing import ( + ArgsType, + ArrayLikeFloat, + ArrayLikeInt, + ArrayLikeUInt, + NDArrayBool, + NDArrayUI8, +) +from vis4d.vis.base import Visualizer +from vis4d.vis.image.canvas import CanvasBackend, PillowCanvasBackend +from vis4d.vis.image.util import preprocess_image, preprocess_masks +from vis4d.vis.image.viewer import ImageViewerBackend, MatplotlibImageViewer +from vis4d.vis.util import generate_color_map + + +@dataclass +class SegMask2D: + """Dataclass storing mask information.""" + + mask: NDArrayBool + color: tuple[int, int, int] + + +@dataclass +class ImageWithSegMask: + """Dataclass storing a data sample that can be visualized.""" + + image: NDArrayUI8 + image_name: str + masks: list[SegMask2D] + + +class SegMaskVisualizer(Visualizer): + """Segmentation mask visualizer class.""" + + def __init__( + self, + *args: ArgsType, + n_colors: int = 50, + class_id_mapping: dict[int, str] | None = None, + file_type: str = "png", + color_palette: list[tuple[int, int, int]] | None = None, + canvas: CanvasBackend = PillowCanvasBackend(), + viewer: ImageViewerBackend = MatplotlibImageViewer(), + **kwargs: ArgsType, + ) -> None: + """Creates a new Visualizer for Image and Bounding Boxes. + + Args: + n_colors (int): How many colors should be used for the color map. + class_id_mapping (dict[int, str]): Mapping from class id to + human readable name. + file_type (str): Desired file type + color_palette (list[tuple[int, int, int]]): Color palette for each + class, in RGB format (0-255). If None, a random color palette + with n_colors is generated automatically. Defaults to None. + canvas (CanvasBackend): Backend that is used to draw on images + viewer (ImageViewerBackend): Backend that is used show images + """ + super().__init__(*args, **kwargs) + self._samples: list[ImageWithSegMask] = [] + self.color_palette = ( + generate_color_map(n_colors) + if color_palette is None + else color_palette + ) + self.class_id_mapping = ( + class_id_mapping if class_id_mapping is not None else {} + ) + self.file_type = file_type + self.canvas = canvas + self.viewer = viewer + + def reset(self) -> None: + """Reset visualizer for new round of evaluation.""" + self._samples.clear() + + def _add_masks( + self, + data_sample: ImageWithSegMask, + masks: ArrayLikeUInt, + class_ids: ArrayLikeInt | None = None, + ) -> None: + """Adds a mask to the current data sample. + + Args: + data_sample (ImageWithSegMask): Data sample to add mask to. + masks (ArrayLikeUInt): Binary masks shape [N, H, W] or [H, W]. + class_ids (NDArrayInt, optional): Class ids for each mask, with + shape [N]. Defaults to None. + """ + if class_ids is not None: + assert ( + class_ids.shape[0] == masks.shape[0] # type: ignore + ), "The amount of masks must match the given class count!" + + for mask, color in zip( + *preprocess_masks(masks, class_ids, self.color_palette) + ): + data_sample.masks.append(SegMask2D(mask=mask, color=color)) + + def _draw_image(self, sample: ImageWithSegMask) -> NDArrayUI8: + """Visualizes the datasample and returns is as numpy image. + + Args: + sample (DataSample): The data sample to visualize. + + Returns: + NDArrayUI8: A image with the visualized data sample. + """ + self.canvas.create_canvas(sample.image) + for mask in sample.masks: + self.canvas.draw_bitmap(mask.mask, mask.color) + return self.canvas.as_numpy_image() + + def process( # pylint: disable=arguments-differ + self, + cur_iter: int, + images: list[ArrayLikeFloat], + image_names: list[str], + masks: list[ArrayLikeUInt], + class_ids: list[ArrayLikeInt] | None = None, + ) -> None: + """Processes a batch of data. + + Args: + cur_iter (int): Current iteration. + images (list[ArrayLikeFloat]): Images to show. + image_names (list[str]): Image names. + masks (list[ArrayLikeUInt]): Segmentation masks to show, each + with shape [H, W] or [N, H, W]. If the shape is [H, W], the + mask is assumed to be a semantic segmentation mask with each + pixel being the class id. If the shape is [N, H, W], each mask + is assumed to be a binary mask with each pixel being either 0 + or 1. + class_ids (list[ArrayLikeInt], optional): Class ids for each mask, + with shape [N]. If set, the masks are assumed to be binary + masks and the length of class_ids must match the amount of + masks. Defaults to None. + """ + if not self._run_on_batch(cur_iter): + return + + for idx, image in enumerate(images): + self.process_single_image( + image, + image_names[idx], + masks[idx], + None if class_ids is None else class_ids[idx], + ) + + def process_single_image( + self, + image: ArrayLikeFloat, + image_name: str, + masks: ArrayLikeUInt, + class_ids: ArrayLikeInt | None = None, + ) -> None: + """Processes a single image entry. + + Args: + image (ArrayLikeFloat): Images to show. + image_name (str): Name of the image. + masks (ArrayLikeUInt): Binary masks to show, each with shape + [N, H, W] or [H, W]. + class_ids (ArrayLikeInt, optional): Class ids for each mask, with + shape [N]. Defaults to None. + """ + img_normalized = preprocess_image(image, mode=self.image_mode) + data_sample = ImageWithSegMask(img_normalized, image_name, []) + self._add_masks(data_sample, masks, class_ids) + self._samples.append(data_sample) + + def show(self, cur_iter: int, blocking: bool = True) -> None: + """Shows the processed images in a interactive window. + + Args: + cur_iter (int): Current iteration. + blocking (bool): If the visualizer should be blocking i.e. wait for + human input for each image + """ + if not self._run_on_batch(cur_iter): + return + image_data = [self._draw_image(d) for d in self._samples] + self.viewer.show_images(image_data, blocking=blocking) + + def save_to_disk(self, cur_iter: int, output_folder: str) -> None: + """Saves the visualization to disk. + + Writes all processes samples to the output folder naming each image + .. + + Args: + cur_iter (int): Current iteration. + output_folder (str): Folder where the output should be written. + """ + if not self._run_on_batch(cur_iter): + return + for sample in self._samples: + image_name = f"{sample.image_name}.{self.file_type}" + + self.canvas.create_canvas(sample.image) + for mask in sample.masks: + self.canvas.draw_bitmap(mask.mask, mask.color) + + self.canvas.save_to_disk(os.path.join(output_folder, image_name)) diff --git a/vis4d/vis/image/util.py b/vis4d/vis/image/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a66dc65b82450c878c2a2f0b36867b77378887d2 --- /dev/null +++ b/vis4d/vis/image/util.py @@ -0,0 +1,423 @@ +"""Utility functions for image processing operations.""" + +from __future__ import annotations + +import numpy as np +import torch + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArrayLike, + ArrayLikeFloat, + ArrayLikeInt, + ArrayLikeUInt, + NDArrayBool, + NDArrayF32, + NDArrayUI8, +) +from vis4d.data.const import AxisMode +from vis4d.op.box.box3d import ( + boxes3d_in_image, + boxes3d_to_corners, + transform_boxes3d, +) +from vis4d.op.geometry.projection import project_points +from vis4d.op.geometry.transform import inverse_rigid_transform +from vis4d.vis.util import DEFAULT_COLOR_MAPPING + + +def _get_box_label( + category: str | None, + score: float | None, + track_id: int | None, +) -> str: + """Gets a unique string representation for a box definition. + + Args: + category (str): The category name + score (float): The confidence score + track_id (int): The track id + + Returns: + str: Label for this box of format + 'class_name, track_id, score%' + """ + labels = [] + + if category is not None: + labels.append(category) + if track_id is not None: + labels.append(str(track_id)) + if score is not None: + labels.append(f"{score * 100:.1f}%") + return ", ".join(labels) + + +def _to_binary_mask( + mask: NDArrayUI8, ignore_class: int = 255 +) -> tuple[NDArrayUI8, NDArrayUI8]: + """Converts a mask to binary masks. + + Args: + mask (NDArrayUI8): The mask to convert with shape [H, W]. + ignore_class (int): The class id to ignore. Defaults to 255. + + Returns: + NDArrayUI8: The binary masks with shape [N, H, W]. + NDArrayUI8: The class ids for each binary mask. + """ + binary_masks = [] + class_ids = [] + for class_id in np.unique(mask): + if class_id == ignore_class: + continue + binary_masks.append(mask == class_id) + class_ids.append(class_id) + return np.stack(binary_masks, axis=0), np.array(class_ids, dtype=np.uint8) + + +def preprocess_boxes( + boxes: ArrayLikeFloat, + scores: None | ArrayLikeFloat = None, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + color_palette: list[tuple[int, int, int]] = DEFAULT_COLOR_MAPPING, + class_id_mapping: dict[int, str] | None = None, + default_color: tuple[int, int, int] = (255, 0, 0), + categories: None | list[str] = None, +) -> tuple[ + list[tuple[float, float, float, float]], + list[str], + list[tuple[int, int, int]], +]: + """Preprocesses bounding boxes. + + Converts the given predicted bounding boxes and class/track information + into lists of corners, labels and colors. + + Args: + boxes (ArrayLikeFloat): Boxes of shape [N, 4] where N is the number of + boxes and the second channel consists of + (x1,y1,x2,y2) box coordinates. + scores (ArrayLikeFloat): Scores for each box shape [N] + class_ids (ArrayLikeInt): Class id for each box shape [N] + track_ids (ArrayLikeInt): Track id for each box shape [N] + color_palette (list[tuple[float, float, float]]): Color palette for + each id. + class_id_mapping(dict[int, str], optional): Mapping from class id + to color tuple (0-255). + default_color (tuple[int, int, int]): fallback color for boxes of no + class or track id is given. + categories (None | list[str], optional): List of categories for each + box. + + Returns: + boxes_proc (list[tuple[float, float, float, float]]): List of box + corners. + labels_proc (list[str]): List of labels. + colors_proc (list[tuple[int, int, int]]): List of colors. + """ + if class_id_mapping is None: + class_id_mapping = {} + + boxes = array_to_numpy(boxes, n_dims=2, dtype=np.float32) + + scores_np = array_to_numpy(scores, n_dims=1, dtype=np.float32) + class_ids_np = array_to_numpy(class_ids, n_dims=1, dtype=np.int32) + track_ids_np = array_to_numpy(track_ids, n_dims=1, dtype=np.int32) + + boxes_proc: list[tuple[float, float, float, float]] = [] + colors_proc: list[tuple[int, int, int]] = [] + labels_proc: list[str] = [] + + # Only one box provided + if len(boxes.shape) == 1: + # unsqueeze one dimension + boxes = boxes.reshape(1, -1) + + for idx in range(boxes.shape[0]): + class_id = None if class_ids_np is None else class_ids_np[idx].item() + score = None if scores_np is None else scores_np[idx].item() + track_id = None if track_ids_np is None else track_ids_np[idx].item() + + if track_id is not None: + color = color_palette[track_id % len(color_palette)] + elif class_id is not None: + color = color_palette[class_id % len(color_palette)] + else: + color = default_color + + boxes_proc.append( + ( + boxes[idx][0].item(), + boxes[idx][1].item(), + boxes[idx][2].item(), + boxes[idx][3].item(), + ) + ) + colors_proc.append(color) + + if categories is not None: + category = categories[idx] + elif class_id is not None: + category = class_id_mapping.get(class_id, str(class_id)) + else: + category = None + + labels_proc.append(_get_box_label(category, score, track_id)) + return boxes_proc, labels_proc, colors_proc + + +def preprocess_boxes3d( + image_hw: tuple[int, int], + boxes3d: ArrayLikeFloat, + intrinsics: ArrayLikeFloat, + extrinsics: ArrayLikeFloat | None = None, + scores: None | ArrayLikeFloat = None, + class_ids: None | ArrayLikeInt = None, + track_ids: None | ArrayLikeInt = None, + color_palette: list[tuple[int, int, int]] = DEFAULT_COLOR_MAPPING, + class_id_mapping: dict[int, str] | None = None, + default_color: tuple[int, int, int] = (255, 0, 0), + axis_mode: AxisMode = AxisMode.OPENCV, + categories: None | list[str] = None, +) -> tuple[ + list[tuple[float, float, float]], + list[list[tuple[float, float, float]]], + list[str], + list[tuple[int, int, int]], + list[int | None], +]: + """Preprocesses bounding boxes. + + Converts the given predicted bounding boxes and class/track information + into lists of centers, corners, labels, colors and track_ids. + """ + if class_id_mapping is None: + class_id_mapping = {} + + boxes3d = array_to_numpy(boxes3d, n_dims=2, dtype=np.float32) + intrinsics = array_to_numpy(intrinsics, n_dims=2, dtype=np.float32) + + boxes3d = torch.from_numpy(boxes3d) + intrinsics = torch.from_numpy(intrinsics) + + if axis_mode != AxisMode.OPENCV: + assert ( + extrinsics is not None + ), "extrinsics must be provided to move boxes to camera coordiante." + extrinsics = array_to_numpy(extrinsics, n_dims=2, dtype=np.float32) + extrinsics = torch.from_numpy(extrinsics) + global_to_cam = inverse_rigid_transform(extrinsics) + boxes3d_cam = transform_boxes3d( + boxes3d, + global_to_cam, + source_axis_mode=AxisMode.ROS, + target_axis_mode=AxisMode.OPENCV, + ) + else: + boxes3d_cam = boxes3d + + corners = boxes3d_to_corners(boxes3d_cam, axis_mode=AxisMode.OPENCV) + + mask = boxes3d_in_image(corners, intrinsics, image_hw) + + boxes3d_np = boxes3d.numpy() + corners_np = corners.numpy() + + scores_np = array_to_numpy(scores, n_dims=1, dtype=np.float32) + class_ids_np = array_to_numpy(class_ids, n_dims=1, dtype=np.int32) + track_ids_np = array_to_numpy(track_ids, n_dims=1, dtype=np.int32) + + centers_proc: list[tuple[float, float, float]] = [] + corners_proc: list[list[tuple[float, float, float]]] = [] + colors_proc: list[tuple[int, int, int]] = [] + labels_proc: list[str] = [] + track_ids_proc: list[int | None] = [] + + if len(mask) == 1: + if not mask[0]: + return ( + centers_proc, + corners_proc, + labels_proc, + colors_proc, + track_ids_proc, + ) + else: + boxes3d_np = boxes3d_np[mask] + corners_np = corners_np[mask] + scores_np = scores_np[mask] if scores_np is not None else None + class_ids_np = class_ids_np[mask] if class_ids_np is not None else None + track_ids_np = track_ids_np[mask] if track_ids_np is not None else None + + for idx in range(corners_np.shape[0]): + class_id = None if class_ids_np is None else class_ids_np[idx].item() + score = None if scores_np is None else scores_np[idx].item() + track_id = None if track_ids_np is None else track_ids_np[idx].item() + + if track_id is not None: + color = color_palette[track_id % len(color_palette)] + elif class_id is not None: + color = color_palette[class_id % len(color_palette)] + else: + color = default_color + + centers_proc.append( + ( + boxes3d_np[idx][0].item(), + boxes3d_np[idx][1].item(), + boxes3d_np[idx][2].item(), + ) + ) + corners_proc.append([tuple(pts) for pts in corners_np[idx].tolist()]) + colors_proc.append(color) + + if categories is not None: + category = categories[idx] + elif class_id is not None: + category = class_id_mapping.get(class_id, str(class_id)) + else: + category = None + + labels_proc.append(_get_box_label(category, score, track_id)) + track_ids_proc.append(track_id) + return centers_proc, corners_proc, labels_proc, colors_proc, track_ids_proc + + +def preprocess_masks( + masks: ArrayLikeUInt, + class_ids: ArrayLikeInt | None = None, + color_mapping: list[tuple[int, int, int]] = DEFAULT_COLOR_MAPPING, +) -> tuple[list[NDArrayBool], list[tuple[int, int, int]]]: + """Preprocesses predicted semantic or instance segmentation masks. + + Args: + masks (ArrayLikeUInt): Masks of shape [H, W] or [N, H, W]. If the + masks are of shape [H, W], they are assumed to be semantic + segmentation masks, i.e. each pixel contains the class id. + If the masks are of shape [N, H, W], they are assumed to be + the binary masks of N instances. + class_ids (ArrayLikeInt, None): An array with class ids for each mask + shape [N]. If None, then the masks must be semantic segmentation + masks and the class ids are extracted from the masks. + color_mapping (list[tuple[int, int, int]]): Color mapping for + each class. + + Returns: + tuple[list[masks], list[colors]]: Returns a list with all masks of + shape [H, W] as well as a list with the corresponding colors. + + Raises: + ValueError: If the masks have an invalid shape. + """ + masks_np = array_to_numpy(masks, n_dims=None, dtype=np.uint8) + + if len(masks_np.shape) == 2: + masks_np, class_ids = _to_binary_mask(masks_np) + elif len(masks_np.shape) == 3: + if class_ids is not None: + class_ids = array_to_numpy(class_ids, n_dims=1, dtype=np.int32) + else: + raise ValueError( + f"Expected masks to have 2 or 3 dimensions, but got " + f"{len(masks_np.shape)}" + ) + + masks_binary = masks_np.astype(bool) + mask_list: list[NDArrayBool] = [] + color_list: list[tuple[int, int, int]] = [] + + for idx in range(masks_binary.shape[0]): + mask = masks_binary[idx, ...] + + class_id = None if class_ids is None else class_ids[idx].item() + if class_id is not None: + color = color_mapping[class_id % len(color_mapping)] + else: + color = color_mapping[idx % len(color_mapping)] + mask_list.append(mask) + color_list.append(color) + return mask_list, color_list + + +def preprocess_image(image: ArrayLike, mode: str = "RGB") -> NDArrayUI8: + """Validate and convert input image. + + Args: + image: CHW or HWC image (ArrayLike) with C = 3. + mode: input channel format (e.g. BGR, HSV). + + Returns: + np.array[uint8]: Processed image_np in RGB. + """ + image_np = array_to_numpy(image, n_dims=3, dtype=np.float32) + # Convert torch to numpy + assert len(image_np.shape) == 3 + assert image_np.shape[0] == 3 or image_np.shape[-1] == 3 + + # Convert torch to numpy convention + if not image_np.shape[-1] == 3: + image_np = np.transpose(image_np, (1, 2, 0)) + + # Convert image_np to [0, 255] + min_val, max_val = ( + np.min(image_np, axis=(0, 1)), + np.max(image_np, axis=(0, 1)), + ) + image_np = image_np.astype(np.float32) + image_np = (image_np - min_val) / (max_val - min_val) * 255.0 + + if mode == "BGR": + image_np = image_np[..., [2, 1, 0]] + + return image_np.astype(np.uint8) + + +def get_intersection_point( + point1: tuple[float, float, float], + point2: tuple[float, float, float], + camera_near_clip: float, +) -> tuple[float, float, float]: + """Get point intersecting with camera near plane on line point1 -> point2. + + The line is defined by two points in camera coordinates and their depth. + + Args: + point1 (tuple[float x 3]): First point in camera coordinates. + point2 (tuple[float x 3]): Second point in camera coordinates + camera_near_clip (float): camera_near_clip + + Returns: + tuple[float, float, float]: The intersection point in camera + coordiantes. + """ + c1, c2, c3 = 0, 0, camera_near_clip + a1, a2, a3 = 0, 0, 1 + x1, y1, z1 = point1 + x2, y2, z2 = point2 + + k_up = abs(a1 * (x1 - c1) + a2 * (y1 - c2) + a3 * (z1 - c3)) + k_down = abs(a1 * (x1 - x2) + a2 * (y1 - y2) + a3 * (z1 - z2)) + if k_up > k_down: + k = 1.0 + else: + k = k_up / k_down + + return ((1 - k) * x1 + k * x2, (1 - k) * y1 + k * y2, camera_near_clip) + + +def project_point( + point: tuple[float, float, float], intrinsics: NDArrayF32 +) -> tuple[float, float]: + """Project single point into the image plane.""" + projected_x, projected_y = ( + project_points( + torch.from_numpy(np.array([point], dtype=np.float32)), + torch.from_numpy(intrinsics), + ) + .squeeze(0) + .numpy() + .tolist() + ) + return projected_x, projected_y diff --git a/vis4d/vis/image/viewer/__init__.py b/vis4d/vis/image/viewer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..207225e6dace026ad9eec9aa3bf778e93fc84f4d --- /dev/null +++ b/vis4d/vis/image/viewer/__init__.py @@ -0,0 +1,6 @@ +"""Viewer implementations to display images.""" + +from .base import ImageViewerBackend +from .matplotlib_viewer import MatplotlibImageViewer + +__all__ = ["ImageViewerBackend", "MatplotlibImageViewer"] diff --git a/vis4d/vis/image/viewer/base.py b/vis4d/vis/image/viewer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..59aed95cda1b3ac17a8e0aff9fd91ba8c2147798 --- /dev/null +++ b/vis4d/vis/image/viewer/base.py @@ -0,0 +1,32 @@ +"""Base class of image viewer for image based visualization.""" + +from __future__ import annotations + +from vis4d.common.typing import NDArrayUI8 + + +class ImageViewerBackend: + """Abstract interface that allows to show images.""" + + def show_images( + self, images: list[NDArrayUI8], blocking: bool = True + ) -> None: + """Shows a list of images. + + Args: + images (list[NDArrayUI8]): Images to display. + blocking (bool, optional): If the viewer should be blocking and + wait for input after each image. Defaults to True. + """ + raise NotImplementedError + + def save_images( + self, images: list[NDArrayUI8], file_paths: list[str] + ) -> None: + """Saves a list of images. + + Args: + images (list[NDArrayUI8]): Images to save. + file_paths (list[str]): File paths to save the images to. + """ + raise NotImplementedError diff --git a/vis4d/vis/image/viewer/matplotlib_viewer.py b/vis4d/vis/image/viewer/matplotlib_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..cf11ec78d8fe600c1b462a2e1d868a578f9972f4 --- /dev/null +++ b/vis4d/vis/image/viewer/matplotlib_viewer.py @@ -0,0 +1,42 @@ +"""Matplotlib based image viewer.""" + +from __future__ import annotations + +import matplotlib.pyplot as plt + +from vis4d.common.typing import NDArrayUI8 + +from .base import ImageViewerBackend + + +class MatplotlibImageViewer(ImageViewerBackend): + """A image viewer using matplotlib.pyplot.""" + + def show_images( + self, images: list[NDArrayUI8], blocking: bool = True + ) -> None: + """Shows a list of images. + + Args: + images (list[NDArrayUI8]): Images to display. + blocking (bool): If the viewer should be blocking and wait + for human input after each image. + """ + for image in images: + plt.imshow(image) + plt.axis("off") + plt.show(block=blocking) + + def save_images( + self, images: list[NDArrayUI8], file_paths: list[str] + ) -> None: + """Saves a list of images. + + Args: + images (list[NDArrayUI8]): Images to save. + file_paths (list[str]): File paths to save the images to. + """ + for i, image in enumerate(images): + plt.imshow(image) + plt.axis("off") + plt.savefig(f"{file_paths[i]}", bbox_inches="tight") diff --git a/vis4d/vis/pointcloud/__init__.py b/vis4d/vis/pointcloud/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..147260a84af3a18416d898f4a4ee814e33f7efcf --- /dev/null +++ b/vis4d/vis/pointcloud/__init__.py @@ -0,0 +1,5 @@ +"""Pointcloud Visualization Package.""" + +from .pointcloud_visualizer import PointCloudVisualizer + +__all__ = ["PointCloudVisualizer"] diff --git a/vis4d/vis/pointcloud/functional.py b/vis4d/vis/pointcloud/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..fec7fd83af77f826c1adb94749499acc152df952 --- /dev/null +++ b/vis4d/vis/pointcloud/functional.py @@ -0,0 +1,87 @@ +"""Function interface for point cloud visualization functions.""" + +from __future__ import annotations + +from vis4d.common.typing import ArrayLikeFloat, ArrayLikeInt + +from ..util import DEFAULT_COLOR_MAPPING +from .scene import Scene3D +from .viewer import Open3DVisualizationBackend, PointCloudVisualizerBackend + + +def show_3d( + scene: Scene3D, + viewer: PointCloudVisualizerBackend = Open3DVisualizationBackend( + class_color_mapping=DEFAULT_COLOR_MAPPING + ), +) -> None: + """Shows a given 3D scene. + + This method shows a 3D visualization of a given 3D scene. Use the viewer + attribute to use different visualization backends (e.g. open3d) + + Args: + scene (Scene3D): The 3D scene that should be visualized. + viewer (PointCloudVisualizerBackend, optional): The Visualization + backend that should be used to visualize the scene. + Defaults to Open3DVisualizationBackend. + """ + viewer.add_scene(scene) + viewer.show() + viewer.reset() + + +def draw_points( + points_xyz: ArrayLikeFloat, + colors: ArrayLikeFloat | None = None, + classes: ArrayLikeInt | None = None, + instances: ArrayLikeInt | None = None, + transform: ArrayLikeFloat | None = None, + scene: Scene3D | None = None, +) -> Scene3D: + """Adds pointcloud data to a 3D scene for visualization purposes. + + Args: + points_xyz: xyz coordinates of the points shape [N, 3] + classes: semantic ids of the points shape [N, 1] + instances: instance ids of the points shape [N, 1] + colors: colors of the points shape [N,3] and ranging from [0,1] + transform: Optional 4x4 SE3 transform that transforms the point data + into a static reference frame. + scene (Scene3D | None): Visualizer that should be used to display the + data. + """ + if scene is None: + scene = Scene3D() + + return scene.add_pointcloud( + points_xyz, colors, classes, instances, transform + ) + + +def show_points( + points_xyz: ArrayLikeFloat, + colors: ArrayLikeFloat | None = None, + classes: ArrayLikeInt | None = None, + instances: ArrayLikeInt | None = None, + transform: ArrayLikeFloat | None = None, + viewer: PointCloudVisualizerBackend = Open3DVisualizationBackend( + class_color_mapping=DEFAULT_COLOR_MAPPING + ), +) -> None: + """Visualizes a pointcloud with color and semantic information. + + Args: + points_xyz: xyz coordinates of the points shape [N, 3] + classes: semantic ids of the points shape [N, 1] + instances: instance ids of the points shape [N, 1] + colors: colors of the points shape [N,3] and ranging from [0,1] + transform: Optional 4x4 SE3 transform that transforms the point data + into a static reference frame + viewer (PointCloudVisualizerBackend, optional): The Visualization + backend that should be used to visualize the scene. + Defaults to Open3DVisualizationBackend. + """ + show_3d( + draw_points(points_xyz, colors, classes, instances, transform), viewer + ) diff --git a/vis4d/vis/pointcloud/pointcloud_visualizer.py b/vis4d/vis/pointcloud/pointcloud_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ae27211340c8a5409217286ad09ae48b8053132e --- /dev/null +++ b/vis4d/vis/pointcloud/pointcloud_visualizer.py @@ -0,0 +1,181 @@ +"""Vis4D Visualization tools for analysis and debugging.""" + +from __future__ import annotations + +from vis4d.common.imports import OPEN3D_AVAILABLE +from vis4d.common.typing import ArgsType, NDArrayF64, NDArrayI64 +from vis4d.vis.base import Visualizer +from vis4d.vis.pointcloud.scene import Scene3D +from vis4d.vis.pointcloud.viewer import PointCloudVisualizerBackend +from vis4d.vis.util import DEFAULT_COLOR_MAPPING + +if OPEN3D_AVAILABLE: + from .viewer.open3d_viewer import Open3DVisualizationBackend + + +# TODO: Check typing +class PointCloudVisualizer(Visualizer): + """Visualizer that visualizes pointclouds.""" + + def __init__( + self, + *args: ArgsType, + backend: str = "open3d", + class_color_mapping: list[ + tuple[int, int, int] + ] = DEFAULT_COLOR_MAPPING, + instance_color_mapping: list[ + tuple[int, int, int] + ] = DEFAULT_COLOR_MAPPING, + **kwargs: ArgsType, + ) -> None: + """Creates a new Pointcloud visualizer. + + Args: + backend (str): Visualization backend that should be used. Choice + of [open3d]. + class_color_mapping (list[tuple[int, int, int]], optional): List + of length n_classes that assigns each class a unique color. + instance_color_mapping (list[tuple[int, int, int]], optional): List + of length n_classes that assigns each class a unique color. + """ + super().__init__(*args, **kwargs) + if backend == "open3d": + if not OPEN3D_AVAILABLE: + raise ValueError( + "You have specified the open3d backend." + "But open3d is not installed on this system!" + ) + self.visualization_backend: PointCloudVisualizerBackend = ( + Open3DVisualizationBackend( + class_color_mapping=class_color_mapping, + instance_color_mapping=instance_color_mapping, + ) + ) + else: + raise ValueError(f"Unknown Point Visualization Backend {backend}") + + self.current_scene_idx: int | None = None + self.current_scene: Scene3D | None = None + + def process_single( + self, + points_xyz: NDArrayF64, + semantics: NDArrayI64 | None = None, + instances: NDArrayI64 | None = None, + colors: NDArrayF64 | None = None, + scene_index: NDArrayI64 | int | None = None, + ) -> None: + """Processes data and adds it to the visualizer. + + Args: + points_xyz: xyz coordinates of the points shape [B, N, 3] + semantics: semantic ids of the points shape [B, N, 1] + instances: instance ids of the points shape [B, N, 1] + colors: colors of the points shape [B, N,3] and ranging from [0,1] + scene_index: Scene index for visualization of shape [B, 1]. + This allows to plot multiple predictions in the same scene + if e.g. for memory reasons it had to be split up in multiple + channels.. + + Raises: + ValueError: If shapes of the arrays missmatch. + """ + # Load correct scene + if scene_index is None: + # No scene index given. Create new scene for each call + self.current_scene = self.visualization_backend.create_new_scene() + else: + # Scene index given, check if we should update given scene + # or create a new one + new_scene_idx = ( + scene_index + if isinstance(scene_index, int) + else scene_index.item() + ) + if ( + self.current_scene_idx is None + or self.current_scene_idx != new_scene_idx + ): + self.current_scene = ( + self.visualization_backend.create_new_scene() + ) + self.current_scene_idx = new_scene_idx + + if self.current_scene is None: + self.current_scene = self.visualization_backend.create_new_scene() + + # Add data to scene + self.current_scene.add_pointcloud( + points_xyz, colors=colors, classes=semantics, instances=instances + ) + + def process( # pylint: disable=arguments-differ + self, + cur_iter: int, + points_xyz: NDArrayF64, + semantics: NDArrayI64 | None = None, + instances: NDArrayI64 | None = None, + colors: NDArrayF64 | None = None, + scene_index: NDArrayI64 | None = None, + ) -> None: + """Processes a batch of data and adds it to the visualizer. + + Args: + cur_iter: Current iteration. + points_xyz: xyz coordinates of the points shape [N, 3] + semantics: semantic ids of the points shape [N, 1] + instances: instance ids of the points shape [N, 1] + colors: colors of the points shape [N,3] and ranging from [0,1] + scene_index: Scene index for visualization of sape [1] or int. + This allows to plot multiple predictions in the same scene + if e.g. for memory reasons it had to be split up in multiple + chunls. + + Raises: + ValueError: If shapes of the arrays missmatch. + """ + if self._run_on_batch(cur_iter): + if len(points_xyz.shape) == 2: # Data is not batched + self.process_single( + points_xyz, semantics, instances, colors, scene_index + ) + elif len(points_xyz.shape) == 3: + for idx in range(points_xyz.shape[0]): + self.process_single( + points_xyz[idx, ...], + semantics[idx, ...] if semantics is not None else None, + instances[idx, ...] if instances is not None else None, + colors[idx, ...] if colors is not None else None, + ( + scene_index[idx, ...] + if scene_index is not None + else None + ), + ) + + else: + raise ValueError( + f"Invalid shape for point data: {points_xyz.shape}" + ) + + def show(self, cur_iter: int, blocking: bool = True) -> None: + """Shows the visualization. + + Args: + cur_iter (int): Current iteration. + blocking (bool): If the visualization should be blocking and wait + for human input + """ + self.visualization_backend.show(blocking) + + def reset(self) -> None: + """Clears all saved data.""" + self.visualization_backend.reset() + self.current_scene_idx = None + self.current_scene = None + + def save_to_disk(self, cur_iter: int, output_folder: str) -> None: + """Saves the visualization to disk.""" + if self._run_on_batch(cur_iter): + self.visualization_backend.save_to_disk(output_folder) diff --git a/vis4d/vis/pointcloud/scene.py b/vis4d/vis/pointcloud/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..963a99f7cc200e9ecf9aecdd5849b8ae2ab5d43c --- /dev/null +++ b/vis4d/vis/pointcloud/scene.py @@ -0,0 +1,279 @@ +"""Data structures to store 3D data.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ArrayLike, NDArrayFloat, NDArrayInt + + +@dataclass +class BoundingBoxData: + """Stores bounding box data for visualization. + + Attributes: + corners (NDArrayFloat): Corners of the bounding box shape [8, 3]. + color (NDArrayFloat): Colors of the bounding box shape [3]. + class (int | None): Class id of the bounding box. Defaults to None. + instance (int | None): Instance id of the bounding box. + Defaults to None. + score (float | None): Score of the bounding box. Defaults to None. + """ + + corners: NDArrayFloat + color: NDArrayFloat | None + class_: int | None + instance: int | None + score: float | None + + def transform(self, transform: NDArrayFloat) -> BoundingBoxData: + """Transforms the bounding box. + + Args: + transform (NDArrayFloat): Transformation matrix shape [4,4] that + transforms points from the current local frame to a fixed + global frame. + + Returns: + BoundingBoxData: Returns a new bounding box with the transformed + points. + """ + assert transform.shape == ( + 4, + 4, + ), "Shape of the provided transform not valid." + return BoundingBoxData( + (transform[:3, :3] @ self.corners.T).T + transform[:3, -1], + self.color, + self.class_, + self.instance, + self.score, + ) + + +@dataclass +class PointcloudData: + """Stores pointcloud data for visualization. + + Attributes: + xyz: Point Coordinates shape [n_pts,3]. + colors: Point Colors shape [n_pts, 3] or None. + classes: Class ids shape [n_pts] or None. + instances: Instance ids shape [n_pts] or None. + num_points: Total number of points. + num_classes: Total number of classes. + num_instances: Total number of unique class, instance combinations. + """ + + xyz: NDArrayFloat + colors: NDArrayFloat | None + classes: NDArrayInt | None + instances: NDArrayInt | None + + num_points: int + num_classes: int + num_instances: int + + def __init__( + self, + xyz: ArrayLike, + colors: ArrayLike | None = None, + classes: ArrayLike | None = None, + instances: ArrayLike | None = None, + ) -> None: + """Creates a new pointcloud. + + Args: + xyz (ArrayLike): Coordinates for each point shape [n_pts, 3] + colors (ArrayLike | None, optional): Colors for each point encoded + as rgb [n_pts, 3] in the range (0,255). Defaults to None. + classes (ArrayLike | None, optional): Class id for each point + shape [n_pts]. Defaults to None. + instances (ArrayLike | None, optional): Instance id for each point. + shape [n_pts]. Defaults to None. + """ + self.xyz = array_to_numpy(xyz, n_dims=2, dtype=np.float32) + self.colors = array_to_numpy(colors, n_dims=2, dtype=np.float32) + self.classes = array_to_numpy(classes, n_dims=1, dtype=np.int32) + self.instances = array_to_numpy(instances, n_dims=1, dtype=np.int32) + + # Assing other properties. Number points, ... + self.num_points = self.xyz.shape[0] + + if self.classes is not None: + self.num_classes = len(np.unique(self.classes)) + + if self.instances is not None: + if self.classes is None: + self.num_instances = len(np.unique(self.instances)) + else: + self.num_instances = len( + np.unique( + self.classes * np.max(self.instances) + self.instances + ) + ) + + def transform(self, transform: NDArrayFloat) -> PointcloudData: + """Transforms the pointcloud. + + Args: + transform (NDArrayFloat): Transformation matrix shape [4,4] that + transforms points from the current local frame to a fixed + global frame. + + Returns: + PointcloudData: Returns a new pointcloud with the transformed + points. + """ + assert transform.shape == ( + 4, + 4, + ), "Shape of the provided transform not valid." + return PointcloudData( + (transform[:3, :3] @ self.xyz.T).T + transform[:3, -1], + self.colors, + self.classes, + self.instances, + ) + + +class Scene3D: + """Stores the data for a 3D scene. + + This Scene3D object can be used to be visualized by any 3D viewer. + + Attributes: + pointclouds (list[PointcloudData]): Stores all pointclouds that + have been registered for this scene so far. + pointclouds (list[NDArrayFloat]): Stores a transformation matrix + (SE3, shape (4,4)) for each pointcloud. + """ + + def __init__(self) -> None: + """Creates a new, empty scene.""" + self._pointclouds: list[tuple[PointcloudData, NDArrayFloat]] = [] + self._bounding_boxes: list[tuple[BoundingBoxData, NDArrayFloat]] = [] + + @staticmethod + def _parse_se3_transform(transform: ArrayLike | None) -> NDArrayFloat: + """Parses a SE3 transformation matrix. + + Args: + transform (ArrayLike | None): Transformation matrix shape [4,4] + that transforms points from the current local frame to a fixed + global frame. + + Returns: + NDArrayFloat: Returns a valid SE3 transformation matrix. + """ + tf = array_to_numpy(transform, n_dims=2, dtype=np.float32) + + if tf is None: + return np.eye(4) + + assert tf.shape == ( + 4, + 4, + ), "Shape of the provided transform not valid." + return tf + + def add_bounding_box( + self, + corners: ArrayLike, + color: ArrayLike | None, + class_: int | None, + instance: int | None, + score: float | None, + transform: ArrayLike | None = None, + ) -> Scene3D: + """Adds a bounding box to the 3D Scene. + + Args: + corners (ArrayLike): Corners of the bounding box shape [8, 3]. + color (ArrayLike | None): Color of the bounding box shape [3]. + class_ (int | None): Class id of the bounding box. + Defaults to None. + instance (int | None): Instance id of the bounding box. + Defaults to None. + score (float | None): Score of the bounding box. Defaults to None. + transform (ArrayLike | None): Transformation matrix shape [4,4] + that transforms points from the current local frame to a fixed + global frame. + + Returns: + Scene3D: Returns 'self' to chain calls. + """ + corners_np = array_to_numpy(corners, n_dims=2, dtype=np.float32) + colors_np = array_to_numpy(color, n_dims=1, dtype=np.float32) + self._bounding_boxes.append( + ( + BoundingBoxData( + corners_np, + colors_np, + class_, + instance, + score, + ), + self._parse_se3_transform(transform), + ), + ) + return self + + def add_pointcloud( + self, + xyz: ArrayLike, + colors: ArrayLike | None = None, + classes: ArrayLike | None = None, + instances: ArrayLike | None = None, + transform: ArrayLike | None = None, + ) -> Scene3D: + """Adds a pointcloud to the 3D Scene. + + Args: + xyz (ArrayLike): Coordinates for each point shape [n_pts, 3] in the + current local frame. + colors (ArrayLike | None, optional): Colors for each point encoded + as rgb [n_pts, 3] in the range (0,255) or (0,1). + Defaults to None. + classes (ArrayLike | None, optional): Class id for each point + shape [n_pts]. Defaults to None. + instances (ArrayLike | None, optional): Instance id for each point. + shape [n_pts]. Defaults to None. + transform (ArrayLike | None, optional): Transformation matrix + shape [4,4] that transforms points from the current local frame + to a fixed global frame. Defaults to None which is the identity + matrix. + + Returns: + Scene3D: Returns 'self' to chain calls. + """ + self._pointclouds.append( + ( + PointcloudData(xyz, colors, classes, instances), + self._parse_se3_transform(transform), + ) + ) + return self + + @property + def bounding_boxes(self) -> list[BoundingBoxData]: + """Returns all bounding boxes in the scene. + + Returns: + list[BoundingBoxData]: List of all bounding boxes in the scene. + """ + return [bbox.transform(tf) for (bbox, tf) in self._bounding_boxes] + + @property + def points(self) -> list[PointcloudData]: + """Returns all points of all pointclouds in the scene. + + Returns: + List[PointcloudData]: Data information for all points in the scene. + Providing information about the points, colors, classes and + instances. + """ + return [pc.transform(tf) for (pc, tf) in self._pointclouds] diff --git a/vis4d/vis/pointcloud/viewer/__init__.py b/vis4d/vis/pointcloud/viewer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d6326dea18ed46acb6be4c7587330c609e4235 --- /dev/null +++ b/vis4d/vis/pointcloud/viewer/__init__.py @@ -0,0 +1,6 @@ +"""Viewer implementations to display pointcloud.""" + +from .base import PointCloudVisualizerBackend +from .open3d_viewer import Open3DVisualizationBackend + +__all__ = ["PointCloudVisualizerBackend", "Open3DVisualizationBackend"] diff --git a/vis4d/vis/pointcloud/viewer/base.py b/vis4d/vis/pointcloud/viewer/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8bce996c7d255de2793f585a2e8fc84e442fa476 --- /dev/null +++ b/vis4d/vis/pointcloud/viewer/base.py @@ -0,0 +1,86 @@ +"""Generic classes to visualize and save pointcloud data.""" + +from __future__ import annotations + +import numpy as np + +from ..scene import Scene3D + + +class PointCloudVisualizerBackend: + """Visualization Backen Interface for Pointclouds.""" + + def __init__( + self, + class_color_mapping: list[tuple[int, int, int]], + instance_color_mapping: list[tuple[int, int, int]] | None = None, + ) -> None: + """Creates a new Open3D visualization backend. + + Args: + class_color_mapping (list[tuple[int, int ,int]]): List of length + n_classes that maps each class index to a unique color. + instance_color_mapping (list[tuple[int, int ,int]], optional): List + of length n_instances that maps each instance id to a unique + color. Defaults to None. + """ + self.scenes: list[Scene3D] = [] + + self.class_color_mapping = np.asarray(class_color_mapping) + + if np.any(self.class_color_mapping > 1): # Color mapping from [0, 255] + self.class_color_mapping = self.class_color_mapping / 255 + + if instance_color_mapping is None: + self.instance_color_mapping = self.class_color_mapping + else: + self.instance_color_mapping = np.asarray(instance_color_mapping) + if np.any(self.instance_color_mapping > 1): + self.instance_color_mapping = self.instance_color_mapping / 255 + + def create_new_scene(self) -> Scene3D: + """Creates a new empty scene.""" + self.scenes.append(Scene3D()) + return self.get_current_scene() + + def get_current_scene(self) -> Scene3D: + """Returns the currently active scene. + + If no scene is available, an new empty one is created. + + Returns: + Scene3D: current pointcloud scene + """ + if (len(self.scenes)) == 0: + return self.create_new_scene() + + return self.scenes[-1] + + def show(self, blocking: bool = True) -> None: + """Shows the visualization. + + Args: + blocking (bool): If the visualization should be blocking + and wait for human input + """ + raise NotImplementedError() + + def reset(self) -> None: + """Clears all stored data.""" + self.scenes = [] + + def add_scene(self, scene: Scene3D) -> None: + """Adds a given Scene3D to the visualization. + + Args: + scene (Scene3D): 3D scene that should be added. + """ + self.scenes.append(scene) + + def save_to_disk(self, path_to_out_folder: str) -> None: + """Saves the visualization to disk. + + Args: + path_to_out_folder (str): Path to output folder + """ + raise NotImplementedError() diff --git a/vis4d/vis/pointcloud/viewer/open3d_viewer.py b/vis4d/vis/pointcloud/viewer/open3d_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..942878b26ee34ddeee70d99b01de89085a32057f --- /dev/null +++ b/vis4d/vis/pointcloud/viewer/open3d_viewer.py @@ -0,0 +1,184 @@ +"""Open3d visualization backend.""" + +from __future__ import annotations + +import os +from typing import TypedDict + +import numpy as np + +from vis4d.common.imports import OPEN3D_AVAILABLE +from vis4d.common.typing import NDArrayF64 +from vis4d.vis.pointcloud.scene import Scene3D + +from .base import PointCloudVisualizerBackend + +if OPEN3D_AVAILABLE: + import open3d as o3d + + +class PointcloudVisEntry(TypedDict): + """Entry for a pointcloud to visualize with open3d. + + Only used for typing. + """ + + name: str + geometry: o3d.geometry.PointCloud + + +class Open3DVisualizationBackend(PointCloudVisualizerBackend): + """Backend that uses open3d to visualize potincloud data.""" + + def __init__( + self, + class_color_mapping: list[tuple[int, int, int]], + instance_color_mapping: list[tuple[int, int, int]] | None = None, + ) -> None: + """Creates a new Open3D visualization backend. + + Args: + color_mapping (NDArrayF64): array of size [n_classes, 3] that maps + each class index to a unique color. + class_color_mapping (list[tuple[int, int, int]]): List of length + n_classes that assigns each class a unique color. + instance_color_mapping (list[tuple[int, int, int]], optional): List + of length n_classes that maps each instance id to unqiue color. + Defaults to None. + """ + super().__init__( + class_color_mapping=class_color_mapping, + instance_color_mapping=instance_color_mapping, + ) + + def save_to_disk(self, path_to_out_folder: str) -> None: + """Saves the visualization to disk. + + Creates files [colors.ply, classes.ply, instances.ply] for each scene + + Args: + path_to_out_folder (str): Path to output folder + """ + for idx, scene in enumerate(self.scenes): + out_folder = os.path.join(path_to_out_folder, f"scene_{idx:03d}") + os.makedirs(out_folder, exist_ok=True) + + for vis_pc in self._get_pc_data_for_scene(scene): + name = vis_pc["name"] + pc = vis_pc["geometry"] + o3d.io.write_point_cloud( + os.path.join(out_folder, f"{name}.ply"), pc + ) + print("written", f"{name}.ply") + + def show(self, blocking: bool = False) -> None: + """Shows the visualization. + + Args: + blocking (bool): If the visualization should be blocking + and wait for human input. + """ + for scene in self.scenes: + vis_data = [] + vis_data += self._get_pc_data_for_scene(scene) + + o3d.visualization.draw( + vis_data, non_blocking_and_return_uid=not blocking + ) + + def _get_pc_data_for_scene( + self, scene: Scene3D + ) -> list[PointcloudVisEntry]: + """Converts a given scene to a list of o3d data to visualize. + + Args: + scene (PointcloudVisEntry): Point cloud scene to visualize + Returns: + list[dict[str, Any]]: List of o3d geometries primitives to show. + """ + xyz, colors, classes, instances = [], [], [], [] + has_classes = False + has_instances = False + + for pc in scene.points: + n_pts = pc.xyz.shape[0] + + xyz.append(pc.xyz) + colors.append( + pc.colors if pc.colors is not None else np.zeros((n_pts, 3)) + ) + + if pc.classes is not None: + has_classes = True + col = self.class_color_mapping[ + pc.classes.squeeze() % self.class_color_mapping.shape[0] + ] + classes.append(col) + else: + classes.append(np.zeros((n_pts, 3))) + + if pc.instances is not None: + has_instances = True + col = self.instance_color_mapping[ + pc.instances.squeeze() + % self.instance_color_mapping.shape[0] + ] + instances.append(col) + else: + instances.append(np.zeros((n_pts, 3))) + + data: list[PointcloudVisEntry] = [] + + data += [ + { + "name": "colors", + "geometry": self._create_o3d_cloud( + np.concatenate(xyz), np.concatenate(colors) + ), + } + ] + if has_instances: + data += [ + { + "name": "instances", + "geometry": self._create_o3d_cloud( + np.concatenate(xyz), np.concatenate(instances) + ), + } + ] + if has_classes: + data += [ + { + "name": "classes", + "geometry": self._create_o3d_cloud( + np.concatenate(xyz), np.concatenate(classes) + ), + } + ] + + return data + + @staticmethod + def _create_o3d_cloud( + points: NDArrayF64, + colors: NDArrayF64 | None = None, + normals: NDArrayF64 | None = None, + ) -> o3d.geometry.PointCloud: + """Creates a o3d pointcloud from poitns and colors. + + Args: + points (NDArrayF64): xyz coordinates of the points + colors (NDArrayF64, optional): Colors of the points + normals (NDArrayF64, optional): Surface normals + + Returns: + o3d.geometry.PointCloud: o3d pointcloud with the given attributes + """ + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + if colors is not None and len(colors) > 0: + pcd.colors = o3d.utility.Vector3dVector(colors) + if normals is not None and len(normals) > 0: + pcd.normals = o3d.utility.Vector3dVector(normals) + + return pcd diff --git a/vis4d/vis/util.py b/vis4d/vis/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9242f0a9391b6e1b38cc4ec8ae37d24ce9543e46 --- /dev/null +++ b/vis4d/vis/util.py @@ -0,0 +1,33 @@ +"""Utilities for visualization.""" + +from __future__ import annotations + +import colorsys + +import numpy as np + + +def generate_color_map(length: int) -> list[tuple[int, int, int]]: + """Generate a color palette of [length] colors. + + Args: + length (int): Number of colors to generate. + + Returns: + list[tuple[int, int, int]]: List with different colors ranging + from [0,255]. + """ + brightness = 0.7 + hsv = [(i / length, 1, brightness) for i in range(length)] + colors_float = [colorsys.hsv_to_rgb(*c) for c in hsv] + colors: list[int] = ( + (np.array(colors_float) * 255).astype(np.uint8).tolist() + ) + s = np.random.get_state() + np.random.seed(0) + result = [tuple(colors[i]) for i in np.random.permutation(len(colors))] + np.random.set_state(s) + return result + + +DEFAULT_COLOR_MAPPING = generate_color_map(50) diff --git a/vis4d/zoo/__init__.py b/vis4d/zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a31eb9aacef6c1e4da8c7f43df5e7d702cd0da --- /dev/null +++ b/vis4d/zoo/__init__.py @@ -0,0 +1,31 @@ +"""Model Zoo.""" + +from __future__ import annotations + +from vis4d.common.typing import ArgsType + +from .bdd100k import AVAILABLE_MODELS as BDD100K_MODELS +from .bevformer import AVAILABLE_MODELS as BEVFORMER_MODELS +from .cc_3dt import AVAILABLE_MODELS as CC_3DT_MODELS +from .faster_rcnn import AVAILABLE_MODELS as FASTER_RCNN_MODELS +from .fcn_resnet import AVAILABLE_MODELS as FCN_RESNET_MODELS +from .mask_rcnn import AVAILABLE_MODELS as MASK_RCNN_MODELS +from .qdtrack import AVAILABLE_MODELS as QDTRACK_MODELS +from .retinanet import AVAILABLE_MODELS as RETINANET_MODELS +from .shift import AVAILABLE_MODELS as SHIFT_MODELS +from .vit import AVAILABLE_MODELS as VIT_MODELS +from .yolox import AVAILABLE_MODELS as YOLOX_MODELS + +AVAILABLE_MODELS: dict[str, dict[str, ArgsType]] = { + "bdd100k": BDD100K_MODELS, + "cc_3dt": CC_3DT_MODELS, + "bevformer": BEVFORMER_MODELS, + "faster_rcnn": FASTER_RCNN_MODELS, + "fcn_resnet": FCN_RESNET_MODELS, + "mask_rcnn": MASK_RCNN_MODELS, + "qdtrack": QDTRACK_MODELS, + "retinanet": RETINANET_MODELS, + "shift": SHIFT_MODELS, + "vit": VIT_MODELS, + "yolox": YOLOX_MODELS, +} diff --git a/vis4d/zoo/base/__init__.py b/vis4d/zoo/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b920b1c675f7cb72fdcac830cfee95700df7f75d --- /dev/null +++ b/vis4d/zoo/base/__init__.py @@ -0,0 +1,18 @@ +"""Model Zoo base.""" + +from .callable import get_callable_cfg +from .dataloader import get_inference_dataloaders_cfg, get_train_dataloader_cfg +from .optimizer import get_lr_scheduler_cfg, get_optimizer_cfg +from .pl_trainer import get_default_pl_trainer_cfg +from .runtime import get_default_callbacks_cfg, get_default_cfg + +__all__ = [ + "get_callable_cfg", + "get_train_dataloader_cfg", + "get_inference_dataloaders_cfg", + "get_optimizer_cfg", + "get_lr_scheduler_cfg", + "get_default_cfg", + "get_default_callbacks_cfg", + "get_default_pl_trainer_cfg", +] diff --git a/vis4d/zoo/base/callable.py b/vis4d/zoo/base/callable.py new file mode 100644 index 0000000000000000000000000000000000000000..a50553f109671084bc35b30dac322ce933947185 --- /dev/null +++ b/vis4d/zoo/base/callable.py @@ -0,0 +1,19 @@ +"""Callable objects for use in config files.""" + +from ml_collections import ConfigDict + +from vis4d.common.typing import ArgsType, GenericFunc +from vis4d.config import class_config, delay_instantiation + + +def get_callable_cfg(func: GenericFunc, **kwargs: ArgsType) -> ConfigDict: + """Return callable config. + + Args: + func (GenericFunc): Callable object. + **kwargs (ArgsType): Keyword arguments to pass to the callable. + + Returns: + ConfigDict: Config for the callable. + """ + return delay_instantiation(class_config(func, **kwargs)) diff --git a/vis4d/zoo/base/data_connectors/__init__.py b/vis4d/zoo/base/data_connectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb607050d8bd36610def74b22b6148ff4d1801a0 --- /dev/null +++ b/vis4d/zoo/base/data_connectors/__init__.py @@ -0,0 +1,20 @@ +"""Base data connectors.""" + +from .common import CONN_IMAGES_TEST, CONN_IMAGES_TRAIN +from .detection import CONN_BBOX_2D_TEST, CONN_BBOX_2D_TRAIN, CONN_BOX_LOSS_2D +from .visualizers import ( + CONN_BBOX_2D_TRACK_VIS, + CONN_BBOX_2D_VIS, + CONN_INS_MASK_2D_VIS, +) + +__all__ = [ + "CONN_IMAGES_TEST", + "CONN_IMAGES_TRAIN", + "CONN_BBOX_2D_TEST", + "CONN_BBOX_2D_TRAIN", + "CONN_BOX_LOSS_2D", + "CONN_BBOX_2D_VIS", + "CONN_BBOX_2D_TRACK_VIS", + "CONN_INS_MASK_2D_VIS", +] diff --git a/vis4d/zoo/base/data_connectors/cls.py b/vis4d/zoo/base/data_connectors/cls.py new file mode 100644 index 0000000000000000000000000000000000000000..d76edfb5d949ad9bd22e91f80ecc0cf9c9d11873 --- /dev/null +++ b/vis4d/zoo/base/data_connectors/cls.py @@ -0,0 +1,18 @@ +"""Data connectors for classification.""" + +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import data_key, pred_key + +CONN_CLS_TRAIN = {K.images: K.images} + +CONN_CLS_TEST = {K.images: K.images} + +CONN_CLS_LOSS = { + "input": pred_key("logits"), + "target": data_key("categories"), +} + +CONN_CLS_EVAL = { + "prediction": pred_key("probs"), + "groundtruth": data_key("categories"), +} diff --git a/vis4d/zoo/base/data_connectors/common.py b/vis4d/zoo/base/data_connectors/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5738c731a4a965c296ee8c05a6723ddee514c3b5 --- /dev/null +++ b/vis4d/zoo/base/data_connectors/common.py @@ -0,0 +1,11 @@ +"""Data connectors for common tasks.""" + +from vis4d.data.const import CommonKeys as K + +CONN_IMAGES_TRAIN = {"images": K.images, "input_hw": K.input_hw} + +CONN_IMAGES_TEST = { + "images": K.images, + "input_hw": K.input_hw, + "original_hw": K.original_hw, +} diff --git a/vis4d/zoo/base/data_connectors/detection.py b/vis4d/zoo/base/data_connectors/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec2921da2705575797f93454b267b83ff25af57 --- /dev/null +++ b/vis4d/zoo/base/data_connectors/detection.py @@ -0,0 +1,24 @@ +"""Data connectors for detection.""" + +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import data_key, pred_key + +CONN_BBOX_2D_TRAIN = { + "images": K.images, + "input_hw": K.input_hw, + "boxes2d": K.boxes2d, + "boxes2d_classes": K.boxes2d_classes, +} + +CONN_BBOX_2D_TEST = { + "images": K.images, + "input_hw": K.input_hw, + "original_hw": K.original_hw, +} + +CONN_BOX_LOSS_2D = { + "cls_outs": pred_key("cls_score"), + "reg_outs": pred_key("bbox_pred"), + "target_boxes": data_key(K.boxes2d), + "images_hw": data_key(K.input_hw), +} diff --git a/vis4d/zoo/base/data_connectors/seg.py b/vis4d/zoo/base/data_connectors/seg.py new file mode 100644 index 0000000000000000000000000000000000000000..b6899f725856a1b6c6dad8e28f2d961bd2611f29 --- /dev/null +++ b/vis4d/zoo/base/data_connectors/seg.py @@ -0,0 +1,29 @@ +"""Data connectors for segmentation.""" + +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import data_key, pred_key + +CONN_MASKS_TRAIN = {"images": K.images} + +CONN_MASKS_TEST = {"images": K.images, K.original_hw: "original_hw"} + +CONN_SEG_LOSS = { + "output": pred_key("outputs"), + "target": data_key(K.seg_masks), +} + +CONN_MULTI_SEG_LOSS = { + "outputs": pred_key("outputs"), + "target": data_key(K.seg_masks), +} + +CONN_SEG_EVAL = { + "prediction": pred_key(K.seg_masks), + "groundtruth": data_key(K.seg_masks), +} + +CONN_SEG_VIS = { + K.images: data_key(K.images), + "image_names": data_key(K.sample_names), + "masks": pred_key("masks"), +} diff --git a/vis4d/zoo/base/data_connectors/visualizers.py b/vis4d/zoo/base/data_connectors/visualizers.py new file mode 100644 index 0000000000000000000000000000000000000000..0b29e35af4fe996899c08c67fff1cccdb9d60a78 --- /dev/null +++ b/vis4d/zoo/base/data_connectors/visualizers.py @@ -0,0 +1,27 @@ +"""Default data connectors for visualizers.""" + +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import data_key, pred_key + +CONN_BBOX_2D_VIS = { + "images": data_key(K.original_images), + "image_names": data_key(K.sample_names), + "boxes": pred_key("boxes"), + "scores": pred_key("scores"), + "class_ids": pred_key("class_ids"), +} + +CONN_BBOX_2D_TRACK_VIS = { + "images": data_key(K.original_images), + "image_names": data_key(K.sample_names), + "boxes": pred_key("boxes"), + "scores": pred_key("scores"), + "class_ids": pred_key("class_ids"), + "track_ids": pred_key("track_ids"), +} + +CONN_INS_MASK_2D_VIS = { + "images": data_key(K.original_images), + "image_names": data_key(K.sample_names), + "masks": pred_key("masks.masks"), +} diff --git a/vis4d/zoo/base/dataloader.py b/vis4d/zoo/base/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..09c6371b46b0f17732a4d713d4dde360ab1c50e7 --- /dev/null +++ b/vis4d/zoo/base/dataloader.py @@ -0,0 +1,134 @@ +"""Dataloader configuration.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict, FieldReference + +from vis4d.common.typing import GenericFunc +from vis4d.config import class_config +from vis4d.data.data_pipe import DataPipe +from vis4d.data.loader import ( + DEFAULT_COLLATE_KEYS, + build_inference_dataloaders, + build_train_dataloader, + default_collate, +) +from vis4d.data.transforms.to_tensor import ToTensor + +from .callable import get_callable_cfg + + +def get_train_dataloader_cfg( + datasets_cfg: ConfigDict | list[ConfigDict], + samples_per_gpu: int | FieldReference = 1, + workers_per_gpu: int | FieldReference = 1, + batchprocess_cfg: ConfigDict | None = None, + collate_fn: GenericFunc = default_collate, + collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS, + sensors: Sequence[str] | None = None, + pin_memory: bool | FieldReference = True, + shuffle: bool | FieldReference = True, + aspect_ratio_grouping: bool | FieldReference = False, +) -> ConfigDict: + """Creates dataloader configuration given dataset and preprocessing. + + Args: + datasets_cfg (ConfigDict | list[ConfigDict]): The configuration + contains the single dataset or datasets. If it is a list, + it will be wrapped into a DataPipe. + samples_per_gpu (int | FieldReference, optional): How many samples each + GPU will process. Defaults to 1. + workers_per_gpu (int | FieldReference, optional): How many workers to + spawn per GPU. Defaults to 1. + batchprocess_cfg (ConfigDict, optional): The config that contains the + batch processing operations. Defaults to None. If None, ToTensor + will be used. + collate_fn (GenericFunc, optional): The collate function to use. + Defaults to default_collate. + collate_keys (Sequence[str], optional): The keys to collate. Defaults + to DEFAULT_COLLATE_KEYS. + sensors (Sequence[str], optional): The sensors to collate. Defaults to + None. + pin_memory (bool | FieldReference, optional): Whether to pin memory. + Defaults to True. + shuffle (bool | FieldReference, optional): Whether to shuffle the + dataset. Defaults to True. + aspect_ratio_grouping (bool | FieldReference, optional): Whether to + group the samples by aspect ratio. Defaults to False. + + Returns: + ConfigDict: Configuration that can be instantiate as a dataloader. + """ + if batchprocess_cfg is None: + batchprocess_cfg = class_config(ToTensor) + + if isinstance(datasets_cfg, list): + dataset = class_config(DataPipe, datasets=datasets_cfg) + else: + dataset = datasets_cfg + + return class_config( + build_train_dataloader, + dataset=dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + batchprocess_fn=batchprocess_cfg, + collate_fn=get_callable_cfg(collate_fn), + collate_keys=collate_keys, + sensors=sensors, + pin_memory=pin_memory, + shuffle=shuffle, + aspect_ratio_grouping=aspect_ratio_grouping, + ) + + +def get_inference_dataloaders_cfg( + datasets_cfg: ConfigDict | list[ConfigDict], + samples_per_gpu: int | FieldReference = 1, + workers_per_gpu: int | FieldReference = 1, + video_based_inference: bool | FieldReference = False, + batchprocess_cfg: ConfigDict | None = None, + collate_fn: GenericFunc = default_collate, + collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS, + sensors: Sequence[str] | None = None, +) -> ConfigDict: + """Creates dataloader configuration given dataset for inference. + + Args: + datasets_cfg (ConfigDict | list[ConfigDict]): The configuration + contains the single dataset or datasets. + samples_per_gpu (int | FieldReference, optional): How many samples each + GPU will process per batch. Defaults to 1. + workers_per_gpu (int | FieldReference, optional): How many workers each + GPU will spawn. Defaults to 1. + video_based_inference (bool | FieldReference , optional): Whether to + split dataset by sequences. Defaults to False. + batchprocess_cfg (ConfigDict, optional): The config that contains the + batch processing operations. Defaults to None. If None, ToTensor + will be used. + collate_fn (GenericFunc, optional): The collate function that will be + used to stack the batch. Defaults to default_collate. + collate_keys (Sequence[str], optional): The keys to collate. Defaults + to DEFAULT_COLLATE_KEYS. + sensors (Sequence[str], optional): The sensors to collate. Defaults to + None. + + Returns: + ConfigDict: The dataloader configuration. + """ + if batchprocess_cfg is None: + batchprocess_cfg = class_config(ToTensor) + + return class_config( + build_inference_dataloaders, + datasets=datasets_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + video_based_inference=video_based_inference, + batchprocess_fn=batchprocess_cfg, + collate_fn=get_callable_cfg(collate_fn), + collate_keys=collate_keys, + sensors=sensors, + ) diff --git a/vis4d/zoo/base/datasets/__init__.py b/vis4d/zoo/base/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..085c5df6ca1c7a1d779444fb412de95c703d1d2d --- /dev/null +++ b/vis4d/zoo/base/datasets/__init__.py @@ -0,0 +1 @@ +"""Model Zoo base datasets.""" diff --git a/vis4d/zoo/base/datasets/bdd100k/__init__.py b/vis4d/zoo/base/datasets/bdd100k/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66ea6eb77ddbc574d1b77853f5d6926b28d2aba4 --- /dev/null +++ b/vis4d/zoo/base/datasets/bdd100k/__init__.py @@ -0,0 +1,19 @@ +"""BDD100K dataset config.""" + +from .detect import ( + CONN_BDD100K_DET_EVAL, + CONN_BDD100K_INS_EVAL, + get_bdd100k_detection_config, +) +from .sem_seg import CONN_BDD100K_SEG_EVAL, get_bdd100k_sem_seg_cfg +from .track import CONN_BDD100K_TRACK_EVAL, get_bdd100k_track_cfg + +__all__ = [ + "CONN_BDD100K_DET_EVAL", + "CONN_BDD100K_INS_EVAL", + "get_bdd100k_detection_config", + "get_bdd100k_sem_seg_cfg", + "CONN_BDD100K_SEG_EVAL", + "get_bdd100k_track_cfg", + "CONN_BDD100K_TRACK_EVAL", +] diff --git a/vis4d/zoo/base/datasets/bdd100k/detect.py b/vis4d/zoo/base/datasets/bdd100k/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..777323029a78d139ced7c8066fec62312e8a90cc --- /dev/null +++ b/vis4d/zoo/base/datasets/bdd100k/detect.py @@ -0,0 +1,243 @@ +# pylint: disable=duplicate-code +"""BDD100K dataset config for object detection.""" +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.bdd100k import BDD100K +from vis4d.data.io import DataBackend +from vis4d.data.transforms.base import RandomApply, compose +from vis4d.data.transforms.flip import ( + FlipBoxes2D, + FlipImages, + FlipInstanceMasks, +) +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.pad import PadImages +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeBoxes2D, + ResizeImages, + ResizeInstanceMasks, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.engine.connectors import data_key, pred_key +from vis4d.zoo.base import ( + get_inference_dataloaders_cfg, + get_train_dataloader_cfg, +) + +CONN_BDD100K_DET_EVAL = { + "frame_ids": data_key("frame_ids"), + "sample_names": data_key("sample_names"), + "sequence_names": data_key("sequence_names"), + "pred_boxes": pred_key("boxes"), + "pred_scores": pred_key("scores"), + "pred_classes": pred_key("class_ids"), +} +CONN_BDD100K_INS_EVAL = { + "frame_ids": data_key("frame_ids"), + "sample_names": data_key("sample_names"), + "sequence_names": data_key("sequence_names"), + "pred_boxes": pred_key("boxes.boxes"), + "pred_scores": pred_key("boxes.scores"), + "pred_classes": pred_key("boxes.class_ids"), + "pred_masks": pred_key("masks.masks"), +} + + +def get_train_dataloader( + data_root: str, + anno_path: str, + keys_to_load: Sequence[str] = (K.images, K.boxes2d), + ins_seg: bool = False, + data_backend: None | DataBackend = None, + image_size: tuple[int, int] = (720, 1280), + multi_scale: bool = False, + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> ConfigDict: + """Get the default train dataloader for BDD100K segmentation.""" + # Train Dataset + train_dataset_cfg = class_config( + BDD100K, + data_root=data_root, + annotation_path=anno_path, + config_path="ins_seg" if ins_seg else "det", + keys_to_load=keys_to_load, + data_backend=data_backend, + skip_empty_samples=True, + ) + + # Train Preprocessing + if multi_scale: + ms_shapes = [(image_size[0] - 24 * i, image_size[1]) for i in range(6)] + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=ms_shapes, + keep_ratio=True, + multiscale_mode="list", + align_long_edge=True, + ) + ] + else: + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + align_long_edge=True, + ) + ] + preprocess_transforms += [ + class_config(ResizeImages), + class_config(ResizeBoxes2D), + ] + if K.instance_masks in keys_to_load: + preprocess_transforms.append(class_config(ResizeInstanceMasks)) + + flip_transforms = [class_config(FlipImages), class_config(FlipBoxes2D)] + if K.instance_masks in keys_to_load: + flip_transforms.append(class_config(FlipInstanceMasks)) + + preprocess_transforms.append( + class_config( + RandomApply, + transforms=flip_transforms, + probability=0.5, + ) + ) + + preprocess_transforms.append(class_config(NormalizeImages)) + + train_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + train_batchprocess_cfg = class_config( + compose, + transforms=[class_config(PadImages), class_config(ToTensor)], + ) + + return get_train_dataloader_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=train_dataset_cfg, + preprocess_fn=train_preprocess_cfg, + ), + batchprocess_cfg=train_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_test_dataloader( + data_root: str, + anno_path: str, + keys_to_load: Sequence[str] = (K.images, K.original_images), + ins_seg: bool = False, + data_backend: None | DataBackend = None, + image_size: tuple[int, int] = (720, 1280), + samples_per_gpu: int = 1, + workers_per_gpu: int = 1, +) -> ConfigDict: + """Get the default test dataloader for BDD100K segmentation.""" + # Test Dataset + test_dataset_cfg = class_config( + BDD100K, + data_root=data_root, + annotation_path=anno_path, + config_path="ins_seg" if ins_seg else "det", + keys_to_load=keys_to_load, + data_backend=data_backend, + ) + + # Test Preprocessing + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + align_long_edge=True, + ), + class_config(ResizeImages), + ] + + preprocess_transforms.append(class_config(NormalizeImages)) + + test_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + test_batchprocess_cfg = class_config( + compose, + transforms=[class_config(PadImages), class_config(ToTensor)], + ) + + # Test Dataset Config + test_dataset_cfg = class_config( + DataPipe, datasets=test_dataset_cfg, preprocess_fn=test_preprocess_cfg + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + batchprocess_cfg=test_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_bdd100k_detection_config( + data_root: str = "data/bdd100k/images/100k", + train_split: str = "train", + train_keys_to_load: Sequence[str] = (K.images, K.boxes2d), + test_split: str = "val", + test_keys_to_load: Sequence[str] = (K.images, K.original_images), + ins_seg: bool = False, + data_backend: None | ConfigDict = None, + image_size: tuple[int, int] = (720, 1280), + multi_scale: bool = False, + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> DataConfig: + """Get the default config for BDD100K detection.""" + data = DataConfig() + + if K.instance_masks in train_keys_to_load: + train_anno_path = "data/bdd100k/labels/ins_seg_train_rle.json" + test_anno_path = "data/bdd100k/labels/ins_seg_val_rle.json" + else: + train_anno_path = "data/bdd100k/labels/det_20/det_train.json" + test_anno_path = "data/bdd100k/labels/det_20/det_val.json" + + data.train_dataloader = get_train_dataloader( + data_root=f"{data_root}/{train_split}", + anno_path=train_anno_path, + keys_to_load=train_keys_to_load, + ins_seg=ins_seg, + data_backend=data_backend, + image_size=image_size, + multi_scale=multi_scale, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + data.test_dataloader = get_test_dataloader( + data_root=f"{data_root}/{test_split}", + anno_path=test_anno_path, + keys_to_load=test_keys_to_load, + ins_seg=ins_seg, + data_backend=data_backend, + image_size=image_size, + samples_per_gpu=1, + workers_per_gpu=1, + ) + + return data diff --git a/vis4d/zoo/base/datasets/bdd100k/sem_seg.py b/vis4d/zoo/base/datasets/bdd100k/sem_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..d2351a5b8e4d17599eecd4cae2e57cfa6e023189 --- /dev/null +++ b/vis4d/zoo/base/datasets/bdd100k/sem_seg.py @@ -0,0 +1,216 @@ +# pylint: disable=duplicate-code +"""BDD100K dataset config for semantic segmentation.""" +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.bdd100k import BDD100K +from vis4d.data.io import DataBackend +from vis4d.data.transforms.base import RandomApply, compose +from vis4d.data.transforms.crop import ( + CropImages, + CropSegMasks, + GenCropParameters, +) +from vis4d.data.transforms.flip import FlipImages, FlipSegMasks +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.pad import PadImages, PadSegMasks +from vis4d.data.transforms.photometric import ColorJitter +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeImages, + ResizeSegMasks, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.engine.connectors import data_key, pred_key +from vis4d.zoo.base import ( + get_inference_dataloaders_cfg, + get_train_dataloader_cfg, +) + +CONN_BDD100K_SEG_EVAL = { + "data_names": data_key("sample_names"), + "masks_list": pred_key("masks"), +} + + +def get_train_dataloader( + data_root: str, + anno_path: str, + keys_to_load: Sequence[str] = (K.images, K.seg_masks), + data_backend: None | DataBackend = None, + image_size: tuple[int, int] = (720, 1280), + crop_size: tuple[int, int] = (512, 1024), + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> ConfigDict: + """Get the default train dataloader for BDD100K segmentation.""" + # Train Dataset + train_dataset_cfg = class_config( + BDD100K, + data_root=data_root, + annotation_path=anno_path, + config_path="sem_seg", + keys_to_load=keys_to_load, + data_backend=data_backend, + ) + + # Train Preprocessing + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + scale_range=(0.5, 2.0), + ), + class_config(ResizeImages), + class_config(ResizeSegMasks), + ] + + preprocess_transforms = [ + class_config(GenCropParameters, shape=crop_size, cat_max_ratio=0.75), + class_config(CropImages), + class_config(CropSegMasks), + ] + + preprocess_transforms.append( + class_config( + RandomApply, + transforms=[class_config(FlipImages), class_config(FlipSegMasks)], + probability=0.5, + ) + ) + + preprocess_transforms.append( + class_config( + RandomApply, + transforms=[class_config(ColorJitter)], + probability=0.5, + ) + ) + + preprocess_transforms.append(class_config(NormalizeImages)) + + train_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + train_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages, shape=crop_size), + class_config(PadSegMasks, shape=crop_size), + class_config(ToTensor), + ], + ) + + return get_train_dataloader_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=train_dataset_cfg, + preprocess_fn=train_preprocess_cfg, + ), + batchprocess_cfg=train_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_test_dataloader( + data_root: str, + anno_path: str, + keys_to_load: Sequence[str] = (K.images, K.seg_masks), + data_backend: None | DataBackend = None, + image_size: tuple[int, int] = (720, 1280), + samples_per_gpu: int = 1, + workers_per_gpu: int = 1, +) -> ConfigDict: + """Get the default test dataloader for BDD100K segmentation.""" + # Test Dataset + test_dataset_cfg = class_config( + BDD100K, + data_root=data_root, + annotation_path=anno_path, + config_path="sem_seg", + keys_to_load=keys_to_load, + data_backend=data_backend, + ) + + # Test Preprocessing + preprocess_transforms = [ + class_config(GenResizeParameters, shape=image_size, keep_ratio=True), + class_config(ResizeImages), + class_config(ResizeSegMasks), + ] + + preprocess_transforms.append(class_config(NormalizeImages)) + + test_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + test_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages, shape=image_size), + class_config(PadSegMasks, shape=image_size), + class_config(ToTensor), + ], + ) + + # Test Dataset Config + test_dataset_cfg = class_config( + DataPipe, datasets=test_dataset_cfg, preprocess_fn=test_preprocess_cfg + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + batchprocess_cfg=test_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_bdd100k_sem_seg_cfg( + data_root: str = "data/bdd100k/images/10k", + train_split: str = "train", + train_keys_to_load: Sequence[str] = (K.images, K.seg_masks), + test_split: str = "val", + test_keys_to_load: Sequence[str] = (K.images, K.seg_masks), + data_backend: None | ConfigDict = None, + image_size: tuple[int, int] = (720, 1280), + crop_size: tuple[int, int] = (512, 1024), + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> DataConfig: + """Get the default config for BDD100K semantic segmentation.""" + data = DataConfig() + + data.train_dataloader = get_train_dataloader( + data_root=f"{data_root}/{train_split}", + anno_path=f"data/bdd100k/labels/sem_seg_{train_split}_rle.json", + keys_to_load=train_keys_to_load, + data_backend=data_backend, + image_size=image_size, + crop_size=crop_size, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + data.test_dataloader = get_test_dataloader( + data_root=f"{data_root}/{test_split}", + anno_path=f"data/bdd100k/labels/sem_seg_{test_split}_rle.json", + keys_to_load=test_keys_to_load, + data_backend=data_backend, + image_size=image_size, + samples_per_gpu=1, + workers_per_gpu=workers_per_gpu, + ) + + return data diff --git a/vis4d/zoo/base/datasets/bdd100k/track.py b/vis4d/zoo/base/datasets/bdd100k/track.py new file mode 100644 index 0000000000000000000000000000000000000000..ed3bc2b5228c9eff6e83c008ec7f31f03b058f48 --- /dev/null +++ b/vis4d/zoo/base/datasets/bdd100k/track.py @@ -0,0 +1,212 @@ +"""BDD100K tracking dataset configs.""" + +from __future__ import annotations + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.bdd100k import BDD100K, bdd100k_track_map +from vis4d.data.reference import MultiViewDataset, UniformViewSampler +from vis4d.data.transforms import RandomApply, compose +from vis4d.data.transforms.flip import FlipBoxes2D, FlipImages +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.pad import PadImages +from vis4d.data.transforms.post_process import PostProcessBoxes2D +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeBoxes2D, + ResizeImages, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.engine.connectors import data_key, pred_key +from vis4d.zoo.base import ( + get_inference_dataloaders_cfg, + get_train_dataloader_cfg, +) + +CONN_BDD100K_TRACK_EVAL = { + "frame_ids": data_key("frame_ids"), + "sample_names": data_key(K.sample_names), + "sequence_names": data_key(K.sequence_names), + "pred_boxes": pred_key("boxes"), + "pred_classes": pred_key("class_ids"), + "pred_scores": pred_key("scores"), + "pred_track_ids": pred_key("track_ids"), +} + + +def get_train_dataloader( + data_backend: None | ConfigDict, + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default train dataloader for BDD100K tracking.""" + bdd100k_det_train = class_config( + BDD100K, + data_root="data/bdd100k/images/100k/train/", + keys_to_load=(K.images, K.boxes2d), + annotation_path="data/bdd100k/labels/det_20/det_train.json", + config_path="det", + data_backend=data_backend, + category_map=bdd100k_track_map, + skip_empty_samples=True, + cache_as_binary=True, + cached_file_path="data/bdd100k/det_train.pkl", + ) + + bdd100k_track_train = class_config( + BDD100K, + data_root="data/bdd100k/images/track/train/", + keys_to_load=(K.images, K.boxes2d), + annotation_path="data/bdd100k/labels/box_track_20/train/", + config_path="box_track", + data_backend=data_backend, + category_map=bdd100k_track_map, + skip_empty_samples=True, + cache_as_binary=True, + cached_file_path="data/bdd100k/track_train.pkl", + ) + + train_dataset_cfg = [ + class_config( + MultiViewDataset, + dataset=bdd100k_det_train, + sampler=class_config( + UniformViewSampler, scope=0, num_ref_samples=1 + ), + ), + class_config( + MultiViewDataset, + dataset=bdd100k_track_train, + sampler=class_config( + UniformViewSampler, scope=3, num_ref_samples=1 + ), + ), + ] + + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=(720, 1280), + keep_ratio=True, + ), + class_config(ResizeImages), + class_config(ResizeBoxes2D), + ] + + preprocess_transforms.append( + class_config( + RandomApply, + transforms=[ + class_config(FlipImages), + class_config(FlipBoxes2D), + ], + probability=0.5, + ) + ) + + preprocess_transforms.append(class_config(NormalizeImages)) + preprocess_transforms.append(class_config(PostProcessBoxes2D)) + + train_preprocess_cfg = class_config( + compose, + transforms=preprocess_transforms, + ) + + train_batchprocess_cfg = class_config( + compose, + transforms=[class_config(PadImages), class_config(ToTensor)], + ) + + return get_train_dataloader_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=train_dataset_cfg, + preprocess_fn=train_preprocess_cfg, + ), + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + batchprocess_cfg=train_batchprocess_cfg, + ) + + +def get_test_dataloader( + data_backend: None | ConfigDict, + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default test dataloader for BDD100K tracking.""" + test_dataset = class_config( + BDD100K, + data_root="data/bdd100k/images/track/val/", + keys_to_load=(K.images, K.original_images), + annotation_path="data/bdd100k/labels/box_track_20/val/", + config_path="box_track", + category_map=bdd100k_track_map, + data_backend=data_backend, + cache_as_binary=True, + cached_file_path="data/bdd100k/track_val.pkl", + ) + + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=(720, 1280), + keep_ratio=True, + ), + class_config(ResizeImages), + class_config(NormalizeImages), + ] + + test_preprocess_cfg = class_config( + compose, + transforms=preprocess_transforms, + ) + + test_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages), + class_config(ToTensor), + ], + ) + + test_dataset_cfg = class_config( + DataPipe, + datasets=test_dataset, + preprocess_fn=test_preprocess_cfg, + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + video_based_inference=True, + batchprocess_cfg=test_batchprocess_cfg, + ) + + +def get_bdd100k_track_cfg( + data_backend: None | ConfigDict = None, + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> DataConfig: + """Get the default config for BDD100K tracking.""" + data = DataConfig() + + data.train_dataloader = get_train_dataloader( + data_backend=data_backend, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + data.test_dataloader = get_test_dataloader( + data_backend=data_backend, + samples_per_gpu=1, + workers_per_gpu=1, + ) + + return data diff --git a/vis4d/zoo/base/datasets/coco/__init__.py b/vis4d/zoo/base/datasets/coco/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7015d84df0e5b1b05b2aace1aa7393cd49bc17 --- /dev/null +++ b/vis4d/zoo/base/datasets/coco/__init__.py @@ -0,0 +1,15 @@ +"""COCO dataset config.""" + +from .detection import ( + CONN_COCO_BBOX_EVAL, + CONN_COCO_MASK_EVAL, + get_coco_detection_cfg, +) +from .sem_seg import get_coco_sem_seg_cfg + +__all__ = [ + "get_coco_detection_cfg", + "CONN_COCO_BBOX_EVAL", + "CONN_COCO_MASK_EVAL", + "get_coco_sem_seg_cfg", +] diff --git a/vis4d/zoo/base/datasets/coco/detection.py b/vis4d/zoo/base/datasets/coco/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..33330dbc52e6b1486e6bb18e0828ce11f7a31365 --- /dev/null +++ b/vis4d/zoo/base/datasets/coco/detection.py @@ -0,0 +1,246 @@ +# pylint: disable=duplicate-code +"""COCO data loading config for object detection.""" +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.coco import COCO +from vis4d.data.io import DataBackend +from vis4d.data.transforms.base import RandomApply, compose +from vis4d.data.transforms.flip import ( + FlipBoxes2D, + FlipImages, + FlipInstanceMasks, +) +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.pad import PadImages +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeBoxes2D, + ResizeImages, + ResizeInstanceMasks, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.engine.connectors import data_key, pred_key +from vis4d.zoo.base import ( + get_inference_dataloaders_cfg, + get_train_dataloader_cfg, +) + +CONN_COCO_BBOX_EVAL = { + "coco_image_id": data_key(K.sample_names), + "pred_boxes": pred_key("boxes"), + "pred_scores": pred_key("scores"), + "pred_classes": pred_key("class_ids"), +} + +CONN_COCO_MASK_EVAL = { + "coco_image_id": data_key(K.sample_names), + "pred_boxes": pred_key("boxes.boxes"), + "pred_scores": pred_key("boxes.scores"), + "pred_classes": pred_key("boxes.class_ids"), + "pred_masks": pred_key("masks"), +} + + +def get_train_dataloader( + data_root: str, + split: str, + keys_to_load: Sequence[str], + data_backend: None | DataBackend, + image_size: tuple[int, int], + samples_per_gpu: int, + workers_per_gpu: int, + cache_as_binary: bool, + cached_file_path: str | None = None, +) -> ConfigDict: + """Get the default train dataloader for COCO detection.""" + # Train Dataset + train_dataset_cfg = class_config( + COCO, + keys_to_load=keys_to_load, + data_root=data_root, + split=split, + remove_empty=True, + data_backend=data_backend, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + # Train Preprocessing + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + align_long_edge=True, + ), + class_config(ResizeImages), + class_config(ResizeBoxes2D), + ] + + if K.instance_masks in keys_to_load: + preprocess_transforms.append(class_config(ResizeInstanceMasks)) + + flip_transforms = [class_config(FlipImages), class_config(FlipBoxes2D)] + + if K.instance_masks in keys_to_load: + flip_transforms.append(class_config(FlipInstanceMasks)) + + preprocess_transforms.append( + class_config( + RandomApply, + transforms=flip_transforms, + probability=0.5, + ) + ) + + preprocess_transforms.append(class_config(NormalizeImages)) + + train_preprocess_cfg = class_config( + compose, + transforms=preprocess_transforms, + ) + + train_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages), + class_config(ToTensor), + ], + ) + + return get_train_dataloader_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=train_dataset_cfg, + preprocess_fn=train_preprocess_cfg, + ), + batchprocess_cfg=train_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_test_dataloader( + data_root: str, + split: str, + keys_to_load: Sequence[str], + data_backend: None | DataBackend, + image_size: tuple[int, int], + samples_per_gpu: int, + workers_per_gpu: int, + cache_as_binary: bool, + cached_file_path: str | None = None, +) -> ConfigDict: + """Get the default test dataloader for COCO detection.""" + # Test Dataset + test_dataset = class_config( + COCO, + keys_to_load=keys_to_load, + data_root=data_root, + split=split, + data_backend=data_backend, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + # Test Preprocessing + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + align_long_edge=True, + ), + class_config(ResizeImages), + class_config(ResizeBoxes2D), + ] + + preprocess_transforms.append(class_config(NormalizeImages)) + + test_preprocess_cfg = class_config( + compose, + transforms=preprocess_transforms, + ) + + test_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages), + class_config(ToTensor), + ], + ) + + # Test Dataset Config + test_dataset_cfg = class_config( + DataPipe, + datasets=test_dataset, + preprocess_fn=test_preprocess_cfg, + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + batchprocess_cfg=test_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_coco_detection_cfg( + data_root: str = "data/coco", + train_split: str = "train2017", + train_keys_to_load: Sequence[str] = ( + K.images, + K.boxes2d, + K.boxes2d_classes, + ), + train_cached_file_path: str | None = "data/coco/train.pkl", + test_split: str = "val2017", + test_keys_to_load: Sequence[str] = ( + K.images, + K.original_images, + K.boxes2d, + K.boxes2d_classes, + ), + test_cached_file_path: str | None = "data/coco/val.pkl", + cache_as_binary: bool = True, + data_backend: None | ConfigDict = None, + image_size: tuple[int, int] = (800, 1333), + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> DataConfig: + """Get the default config for COCO detection.""" + data = DataConfig() + + data.train_dataloader = get_train_dataloader( + data_root=data_root, + split=train_split, + keys_to_load=train_keys_to_load, + data_backend=data_backend, + image_size=image_size, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + cache_as_binary=cache_as_binary, + cached_file_path=train_cached_file_path, + ) + + data.test_dataloader = get_test_dataloader( + data_root=data_root, + split=test_split, + keys_to_load=test_keys_to_load, + data_backend=data_backend, + image_size=image_size, + samples_per_gpu=1, + workers_per_gpu=workers_per_gpu, + cache_as_binary=cache_as_binary, + cached_file_path=test_cached_file_path, + ) + + return data diff --git a/vis4d/zoo/base/datasets/coco/sem_seg.py b/vis4d/zoo/base/datasets/coco/sem_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..36a4e489c80cd2634adc26afda1a226259542ca9 --- /dev/null +++ b/vis4d/zoo/base/datasets/coco/sem_seg.py @@ -0,0 +1,195 @@ +# pylint: disable=duplicate-code +"""COCO data loading config for for semantic segmentation.""" +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.coco import COCO +from vis4d.data.io import DataBackend +from vis4d.data.transforms.base import RandomApply, compose +from vis4d.data.transforms.flip import FlipImages, FlipSegMasks +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.pad import PadImages, PadSegMasks +from vis4d.data.transforms.photometric import ColorJitter +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeImages, + ResizeSegMasks, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.zoo.base import ( + get_inference_dataloaders_cfg, + get_train_dataloader_cfg, +) + + +def get_train_dataloader( + data_root: str, + split: str, + keys_to_load: Sequence[str], + data_backend: None | DataBackend, + image_size: tuple[int, int], + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default train dataloader for COCO detection.""" + # Train Dataset + train_dataset_cfg = class_config( + COCO, + keys_to_load=keys_to_load, + data_root=data_root, + split=split, + remove_empty=True, + data_backend=data_backend, + ) + + # Train Preprocessing + preprocess_transforms = [ + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + scale_range=(0.5, 2.0), + ), + class_config(ResizeImages), + class_config(ResizeSegMasks), + ] + + preprocess_transforms.append( + class_config( + RandomApply, + transforms=[class_config(FlipImages), class_config(FlipSegMasks)], + probability=0.5, + ) + ) + + preprocess_transforms.append( + class_config( + RandomApply, + transforms=[class_config(ColorJitter)], + probability=0.5, + ) + ) + + preprocess_transforms.append(class_config(NormalizeImages)) + + train_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + train_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages), + class_config(PadSegMasks), + class_config(ToTensor), + ], + ) + + return get_train_dataloader_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=train_dataset_cfg, + preprocess_fn=train_preprocess_cfg, + ), + batchprocess_cfg=train_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_test_dataloader( + data_root: str, + split: str, + keys_to_load: Sequence[str], + data_backend: None | DataBackend, + image_size: tuple[int, int], + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default test dataloader for COCO detection.""" + # Test Dataset + test_dataset = class_config( + COCO, + keys_to_load=keys_to_load, + data_root=data_root, + split=split, + data_backend=data_backend, + ) + + # Test Preprocessing + preprocess_transforms = [ + class_config(GenResizeParameters, shape=image_size, keep_ratio=True), + class_config(ResizeImages), + class_config(ResizeSegMasks), + ] + + preprocess_transforms.append(class_config(NormalizeImages)) + + test_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + test_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages, shape=image_size), + class_config(PadSegMasks, shape=image_size), + class_config(ToTensor), + ], + ) + + # Test Dataset Config + test_dataset_cfg = class_config( + DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + batchprocess_cfg=test_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_coco_sem_seg_cfg( + data_root: str = "data/coco", + train_split: str = "train2017", + train_keys_to_load: Sequence[str] = (K.images, K.seg_masks), + test_split: str = "val2017", + test_keys_to_load: Sequence[str] = (K.images, K.seg_masks), + data_backend: None | ConfigDict = None, + image_size: tuple[int, int] = (520, 520), + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> DataConfig: + """Get the default config for COCO semantic segmentation.""" + data = DataConfig() + + data.train_dataloader = get_train_dataloader( + data_root=data_root, + split=train_split, + keys_to_load=train_keys_to_load, + data_backend=data_backend, + image_size=image_size, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + data.test_dataloader = get_test_dataloader( + data_root=data_root, + split=test_split, + keys_to_load=test_keys_to_load, + data_backend=data_backend, + image_size=image_size, + samples_per_gpu=1, + workers_per_gpu=workers_per_gpu, + ) + + return data diff --git a/vis4d/zoo/base/datasets/imagenet.py b/vis4d/zoo/base/datasets/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..6f73f02ee51a613ffa5c34216045fdcde93205bf --- /dev/null +++ b/vis4d/zoo/base/datasets/imagenet.py @@ -0,0 +1,217 @@ +"""ImageNet classification config.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.imagenet import ImageNet +from vis4d.data.transforms.autoaugment import RandAug +from vis4d.data.transforms.base import RandomApply, compose +from vis4d.data.transforms.crop import ( + CropImages, + GenCentralCropParameters, + GenRandomSizeCropParameters, +) +from vis4d.data.transforms.flip import FlipImages +from vis4d.data.transforms.mixup import ( + GenMixupParameters, + MixupCategories, + MixupImages, +) +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.random_erasing import RandomErasing +from vis4d.data.transforms.resize import GenResizeParameters, ResizeImages +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.engine.connectors import data_key, pred_key +from vis4d.zoo.base import ( + get_inference_dataloaders_cfg, + get_train_dataloader_cfg, +) + +CONN_IMAGENET_CLS_EVAL = { + "prediction": pred_key("probs"), + "groundtruth": data_key("categories"), +} + + +def get_train_dataloader( + data_root: str, + split: str, + keys_to_load: Sequence[str], + image_size: tuple[int, int], + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default train dataloader for ImageNet 1K dataset.""" + # Train Dataset + train_dataset_cfg = class_config( + ImageNet, + data_root=data_root, + split=split, + num_classes=1000, + keys_to_load=keys_to_load, + ) + + flip_trans = class_config( + RandomApply, + transforms=[class_config(FlipImages)], + probability=0.5, + ) + random_resized_crop_trans = [ + class_config(GenRandomSizeCropParameters), + class_config(CropImages), + class_config(GenResizeParameters, shape=image_size, keep_ratio=False), + class_config(ResizeImages), + ] + random_aug_trans = [ + class_config(RandAug, magnitude=10, use_increasing=True), + class_config(RandomErasing), + ] + normalize_trans = class_config(NormalizeImages) + train_preprocess_cfg = class_config( + compose, + transforms=[ + flip_trans, + *random_resized_crop_trans, + *random_aug_trans, + normalize_trans, + ], + ) + + mixup_trans = [ + class_config(GenMixupParameters, alpha=0.2, out_shape=image_size), + class_config(MixupImages), + class_config(MixupCategories, num_classes=1000, label_smoothing=0.1), + ] + train_batchprocess_cfg = class_config( + compose, + transforms=[ + *mixup_trans, + class_config(ToTensor), + ], + ) + + return get_train_dataloader_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=train_dataset_cfg, + preprocess_fn=train_preprocess_cfg, + ), + batchprocess_cfg=train_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_test_dataloader( + data_root: str, + split: str, + keys_to_load: Sequence[str], + image_size: tuple[int, int], + samples_per_gpu: int, + workers_per_gpu: int, + crop_pct: float = 0.875, +) -> ConfigDict: + """Get the default test dataloader for COCO detection.""" + # Test Dataset + + test_dataset_cfg = class_config( + ImageNet, + data_root=data_root, + split=split, + num_classes=1000, + keys_to_load=keys_to_load, + ) + + crop_size = tuple(int(size / crop_pct) for size in image_size) + resized_crop_trans = [ + class_config( + GenResizeParameters, + shape=crop_size, + keep_ratio=True, + allow_overflow=True, + ), + class_config(ResizeImages), + class_config( + GenCentralCropParameters, shape=image_size, keep_ratio=False + ), + class_config(CropImages), + ] + normalize_trans = class_config(NormalizeImages) + test_preprocess_cfg = class_config( + compose, + transforms=[ + *resized_crop_trans, + normalize_trans, + ], + ) + + mixup_trans = [ + class_config(GenMixupParameters, alpha=0.2, out_shape=image_size), + class_config(MixupImages), + class_config(MixupCategories, num_classes=1000, label_smoothing=0.1), + ] + test_batchprocess_cfg = class_config( + compose, + transforms=[ + *mixup_trans, + class_config(ToTensor), + ], + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=test_dataset_cfg, + preprocess_fn=test_preprocess_cfg, + ), + batchprocess_cfg=test_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_imagenet_cls_cfg( + data_root: str = "data/imagenet", + train_split: str = "train", + train_keys_to_load: Sequence[str] = ( + K.images, + K.categories, + ), + test_split: str = "val", + test_keys_to_load: Sequence[str] = ( + K.images, + K.categories, + ), + image_size: tuple[int, int] = (224, 224), + samples_per_gpu: int = 256, + workers_per_gpu: int = 8, +) -> DataConfig: + """Get the default config for COCO detection.""" + data = DataConfig() + + data.train_dataloader = get_train_dataloader( + data_root=data_root, + split=train_split, + keys_to_load=train_keys_to_load, + image_size=image_size, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + data.test_dataloader = get_test_dataloader( + data_root=data_root, + split=test_split, + keys_to_load=test_keys_to_load, + image_size=image_size, + samples_per_gpu=1, + workers_per_gpu=workers_per_gpu, + ) + + return data diff --git a/vis4d/zoo/base/datasets/nuscenes/__init__.py b/vis4d/zoo/base/datasets/nuscenes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35ed92ba8a9c902abb9c1e2d22b9075fd9313690 --- /dev/null +++ b/vis4d/zoo/base/datasets/nuscenes/__init__.py @@ -0,0 +1,21 @@ +"""NuScenes dataset config.""" + +from .nuscenes import ( + get_nusc_mini_train_cfg, + get_nusc_mini_val_cfg, + get_nusc_train_cfg, + get_nusc_val_cfg, +) +from .nuscenes_mono import ( + get_nusc_mono_mini_train_cfg, + get_nusc_mono_train_cfg, +) + +__all__ = [ + "get_nusc_train_cfg", + "get_nusc_mini_train_cfg", + "get_nusc_val_cfg", + "get_nusc_mini_val_cfg", + "get_nusc_mono_train_cfg", + "get_nusc_mono_mini_train_cfg", +] diff --git a/vis4d/zoo/base/datasets/nuscenes/nuscenes.py b/vis4d/zoo/base/datasets/nuscenes/nuscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..1c2491e6a19700fe00edec9258f83168a4c936a1 --- /dev/null +++ b/vis4d/zoo/base/datasets/nuscenes/nuscenes.py @@ -0,0 +1,115 @@ +"""NuScenes multi-sensor video dataset config.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.nuscenes import NuScenes + + +def get_nusc_train_cfg( + data_root: str = "data/nuscenes", + keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d), + skip_empty_samples: bool = True, + cache_as_binary: bool = True, + cached_file_path: str | None = None, + data_backend: None | ConfigDict = None, +) -> ConfigDict: + """Get the nuScenes validation dataset config.""" + if cache_as_binary and cached_file_path is None: + cached_file_path = f"{data_root}/train.pkl" + + return class_config( + NuScenes, + data_root=data_root, + keys_to_load=keys_to_load, + version="v1.0-trainval", + split="train", + skip_empty_samples=skip_empty_samples, + data_backend=data_backend, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + +def get_nusc_mini_train_cfg( + data_root: str = "data/nuscenes", + keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d), + skip_empty_samples: bool = True, + cache_as_binary: bool = True, + cached_file_path: str | None = None, + data_backend: None | ConfigDict = None, +) -> ConfigDict: + """Get the nuScenes validation dataset config.""" + if cache_as_binary and cached_file_path is None: + cached_file_path = f"{data_root}/mini_train.pkl" + + return class_config( + NuScenes, + data_root=data_root, + keys_to_load=keys_to_load, + version="v1.0-mini", + split="mini_train", + skip_empty_samples=skip_empty_samples, + data_backend=data_backend, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + +def get_nusc_val_cfg( + data_root: str = "data/nuscenes", + keys_to_load: Sequence[str] = (K.images, K.original_images, K.boxes3d), + skip_empty_samples: bool = False, + cache_as_binary: bool = True, + cached_file_path: str | None = None, + image_channel_mode: str = "RGB", + data_backend: None | ConfigDict = None, +) -> ConfigDict: + """Get the nuScenes validation dataset config.""" + if cache_as_binary and cached_file_path is None: + cached_file_path = f"{data_root}/val.pkl" + + return class_config( + NuScenes, + data_root=data_root, + image_channel_mode=image_channel_mode, + keys_to_load=keys_to_load, + version="v1.0-trainval", + split="val", + skip_empty_samples=skip_empty_samples, + data_backend=data_backend, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + +def get_nusc_mini_val_cfg( + data_root: str = "data/nuscenes", + keys_to_load: Sequence[str] = (K.images, K.original_images, K.boxes3d), + skip_empty_samples: bool = False, + cache_as_binary: bool = True, + cached_file_path: str | None = None, + image_channel_mode: str = "RGB", + data_backend: None | ConfigDict = None, +) -> ConfigDict: + """Get the nuScenes mini validation dataset config.""" + if cache_as_binary and cached_file_path is None: + cached_file_path = f"{data_root}/mini_val.pkl" + + return class_config( + NuScenes, + data_root=data_root, + image_channel_mode=image_channel_mode, + keys_to_load=keys_to_load, + version="v1.0-mini", + split="mini_val", + skip_empty_samples=skip_empty_samples, + data_backend=data_backend, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) diff --git a/vis4d/zoo/base/datasets/nuscenes/nuscenes_mono.py b/vis4d/zoo/base/datasets/nuscenes/nuscenes_mono.py new file mode 100644 index 0000000000000000000000000000000000000000..b28ad98d59c9357ed4e16f7bb48c171002c15016 --- /dev/null +++ b/vis4d/zoo/base/datasets/nuscenes/nuscenes_mono.py @@ -0,0 +1,61 @@ +"""NuScenes monocular dataset config.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.nuscenes_mono import NuScenesMono + + +def get_nusc_mono_train_cfg( + data_root: str = "data/nuscenes", + keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d), + skip_empty_samples: bool = True, + cache_as_binary: bool = True, + cached_file_path: str | None = None, + data_backend: None | ConfigDict = None, +) -> ConfigDict: + """Get the nuScenes monocular training dataset config.""" + if cache_as_binary and cached_file_path is None: + cached_file_path = f"{data_root}/mono_train.pkl" + + return class_config( + NuScenesMono, + data_root=data_root, + keys_to_load=keys_to_load, + version="v1.0-trainval", + split="train", + skip_empty_samples=skip_empty_samples, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + data_backend=data_backend, + ) + + +def get_nusc_mono_mini_train_cfg( + data_root: str = "data/nuscenes", + keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d), + skip_empty_samples: bool = True, + cache_as_binary: bool = True, + cached_file_path: str | None = None, + data_backend: None | ConfigDict = None, +) -> ConfigDict: + """Get the nuScenes monocular mini training dataset config.""" + if cache_as_binary and cached_file_path is None: + cached_file_path = f"{data_root}/mono_mini_train.pkl" + + return class_config( + NuScenesMono, + data_root=data_root, + keys_to_load=keys_to_load, + version="v1.0-mini", + split="mini_train", + skip_empty_samples=skip_empty_samples, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + data_backend=data_backend, + ) diff --git a/vis4d/zoo/base/datasets/shift/__init__.py b/vis4d/zoo/base/datasets/shift/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53686843f4f5e9bc3a47dfe806f1d48053589e79 --- /dev/null +++ b/vis4d/zoo/base/datasets/shift/__init__.py @@ -0,0 +1,27 @@ +"""SHIFT dataset config.""" + +from .tasks import ( + CONN_SHIFT_DET_EVAL, + CONN_SHIFT_INS_EVAL, + get_shift_depth_est_config, + get_shift_det_config, + get_shift_instance_seg_config, + get_shift_multitask_2d_config, + get_shift_multitask_3d_config, + get_shift_optical_flow_config, + get_shift_sem_seg_config, + get_shift_tracking_config, +) + +__all__ = [ + "CONN_SHIFT_DET_EVAL", + "CONN_SHIFT_INS_EVAL", + "get_shift_depth_est_config", + "get_shift_det_config", + "get_shift_instance_seg_config", + "get_shift_tracking_config", + "get_shift_multitask_2d_config", + "get_shift_multitask_3d_config", + "get_shift_optical_flow_config", + "get_shift_sem_seg_config", +] diff --git a/vis4d/zoo/base/datasets/shift/common.py b/vis4d/zoo/base/datasets/shift/common.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed84ec5b86306c856f6883d163884799bf3cee3 --- /dev/null +++ b/vis4d/zoo/base/datasets/shift/common.py @@ -0,0 +1,414 @@ +"""SHIFT data loading config for data augmentation.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections.config_dict import ConfigDict + +from vis4d.config import class_config +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.shift import SHIFT +from vis4d.data.loader import default_collate, multi_sensor_collate +from vis4d.data.transforms.base import RandomApply, compose +from vis4d.data.transforms.crop import ( + CropBoxes2D, + CropDepthMaps, + CropImages, + CropOpticalFlows, + CropSegMasks, + GenCropParameters, +) +from vis4d.data.transforms.flip import ( + FlipBoxes2D, + FlipDepthMaps, + FlipImages, + FlipInstanceMasks, + FlipOpticalFlows, + FlipSegMasks, +) +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.photometric import ColorJitter +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeBoxes2D, + ResizeDepthMaps, + ResizeImages, + ResizeInstanceMasks, + ResizeOpticalFlows, + ResizeSegMasks, +) +from vis4d.data.transforms.select_sensor import SelectSensor +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.zoo.base import ( + get_inference_dataloaders_cfg, + get_train_dataloader_cfg, +) + +IMAGE_MEAN = [122.884, 117.266, 110.287] +IMAGE_STD = [59.925, 59.466, 60.69] + + +def get_train_preprocessing( + image_size: tuple[int, int] = (800, 1280), + crop_size: tuple[int, int] | None = None, + horizontal_flip_prob: float = 0.5, + color_jitter_prob: float = 0.0, + keys_to_load: Sequence[str] = (K.images, K.seg_masks), + views_to_load: Sequence[str] = ("front",), +) -> ConfigDict: + """Get the default data preprocessing for SHIFT dataset. + + Args: + image_size: The image size to resize to. Defaults to (800, 1280). + crop_size: The crop size to crop to randomly, if not None. Defaults to + None. This step is applied after the resize step. + horizontal_flip_prob: The probability of horizontal flipping. Defaults + to 0.5. + color_jitter_prob: The probability of color jittering. Defaults to 0.5. + keys_to_load: The keys to load from the dataset. Defaults to + (K.images, K.seg_masks). + views_to_load: The views to load from the dataset. Defaults to + ("front",). + + Returns: + The data preprocessing config. + """ + preprocess_transforms = [] + + for key_to_load in keys_to_load: + assert key_to_load in SHIFT.KEYS, f"Invalid key: {key_to_load}" + + views_arg = {} + if len(views_to_load) == 1: + preprocess_transforms.append( + class_config( + SelectSensor, + selected_sensor=views_to_load[0], + sensors=views_to_load, + ) + ) + elif len(views_to_load) > 1: + views_arg["sensors"] = views_to_load + + # Resize + if image_size != (800, 1280): + preprocess_transforms.append( + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + **views_arg, + ) + ) + preprocess_transforms.append(class_config(ResizeImages, **views_arg)) + if K.seg_masks in keys_to_load: + preprocess_transforms.append( + class_config(ResizeSegMasks, **views_arg) + ) + if K.boxes2d in keys_to_load: + preprocess_transforms.append( + class_config(ResizeBoxes2D, **views_arg) + ) + if K.instance_masks in keys_to_load: + preprocess_transforms.append( + class_config(ResizeInstanceMasks, **views_arg) + ) + if K.depth_maps in keys_to_load: + preprocess_transforms.append( + class_config(ResizeDepthMaps, **views_arg) + ) + if K.optical_flows in keys_to_load: + preprocess_transforms.append( + class_config( + ResizeOpticalFlows, normalized_flow=False, **views_arg + ) + ) + + # Crop + if crop_size is not None: + preprocess_transforms.append( + class_config( + GenCropParameters, + shape=crop_size, + cat_max_ratio=0.75, + **views_arg, + ), + ) + preprocess_transforms.append(class_config(CropImages, **views_arg)) + if K.seg_masks in keys_to_load: + preprocess_transforms.append( + class_config(CropSegMasks, **views_arg) + ) + if K.boxes2d in keys_to_load: + preprocess_transforms.append( + class_config(CropBoxes2D, **views_arg) + ) + if K.depth_maps in keys_to_load: + preprocess_transforms.append( + class_config(CropDepthMaps, **views_arg) + ) + if K.optical_flows in keys_to_load: + preprocess_transforms.append( + class_config(CropOpticalFlows, **views_arg) + ) + + # Random flip + if horizontal_flip_prob > 0: + flip_transforms = [] + flip_transforms.append(class_config(FlipImages)) + if K.seg_masks in keys_to_load: + flip_transforms.append(class_config(FlipSegMasks)) + if K.boxes2d in keys_to_load: + flip_transforms.append(class_config(FlipBoxes2D)) + if K.instance_masks in keys_to_load: + flip_transforms.append(class_config(FlipInstanceMasks)) + if K.depth_maps in keys_to_load: + flip_transforms.append(class_config(FlipDepthMaps)) + if K.optical_flows in keys_to_load: + flip_transforms.append(class_config(FlipOpticalFlows)) + preprocess_transforms.append( + class_config( + RandomApply, + transforms=flip_transforms, + probability=horizontal_flip_prob, + **views_arg, + ) + ) + + if color_jitter_prob > 0: + preprocess_transforms.append( + class_config( + RandomApply, + transforms=[class_config(ColorJitter, **views_arg)], + probability=color_jitter_prob, + ) + ) + + preprocess_transforms.append( + class_config( + NormalizeImages, mean=IMAGE_MEAN, std=IMAGE_STD, **views_arg + ) + ) + train_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + batchprocess_transforms = [class_config(ToTensor, **views_arg)] + train_batchprocess_cfg = class_config( + compose, transforms=batchprocess_transforms + ) + + return train_preprocess_cfg, train_batchprocess_cfg + + +def get_test_preprocessing( + image_size: tuple[int, int] = (800, 1280), + keys_to_load: Sequence[str] = (K.images, K.seg_masks), + views_to_load: Sequence[str] = ("front",), +) -> ConfigDict: + """Get the default data preprocessing for SHIFT dataset. + + Args: + image_size: The image size to resize to. Defaults to (800, 1280). + keys_to_load: The keys to load from the dataset. Defaults to + (K.images, K.seg_masks). + views_to_load: The views to load from the dataset. Defaults to + ("front",). + + Returns: + The data preprocessing config. + """ + preprocess_transforms = [] + + for key_to_load in keys_to_load: + assert key_to_load in SHIFT.KEYS, f"Invalid key: {key_to_load}" + + views_arg = {} + if len(views_to_load) == 1: + preprocess_transforms.append( + class_config( + SelectSensor, + selected_sensor=views_to_load[0], + sensors=views_to_load, + ) + ) + elif len(views_to_load) > 1: + views_arg["sensors"] = views_to_load + + # Resize + if image_size != (800, 1280): + preprocess_transforms.append( + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + **views_arg, + ) + ) + preprocess_transforms.append(class_config(ResizeImages, **views_arg)) + if K.seg_masks in keys_to_load: + preprocess_transforms.append( + class_config(ResizeSegMasks, **views_arg) + ) + if K.boxes2d in keys_to_load: + preprocess_transforms.append( + class_config(ResizeBoxes2D, **views_arg) + ) + if K.depth_maps in keys_to_load: + preprocess_transforms.append( + class_config(ResizeDepthMaps, **views_arg) + ) + if K.optical_flows in keys_to_load: + preprocess_transforms.append( + class_config(ResizeOpticalFlows, **views_arg) + ) + + preprocess_transforms.append( + class_config( + NormalizeImages, mean=IMAGE_MEAN, std=IMAGE_STD, **views_arg + ) + ) + test_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + batchprocess_transforms = [class_config(ToTensor, **views_arg)] + + test_batchprocess_cfg = class_config( + compose, transforms=batchprocess_transforms + ) + + return test_preprocess_cfg, test_batchprocess_cfg + + +def get_shift_dataloader_config( + train_dataset_cfg: ConfigDict, + test_dataset_cfg: ConfigDict, + keys_to_load: Sequence[str] = (K.images, K.seg_masks), + image_size: tuple[int, int] = (800, 1280), + crop_size: tuple[int, int] | None = None, + horizontal_flip_prob: float = 0.5, + color_jitter_prob: float = 0.5, + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, + train_views_to_load: Sequence[str] = ("front",), + test_views_to_load: Sequence[str] = ("front",), +) -> ConfigDict: + """Get the default config for BDD100K segmentation.""" + data = ConfigDict() + + train_preprocess_cfg, train_batchprocess_cfg = get_train_preprocessing( + keys_to_load=keys_to_load, + image_size=image_size, + crop_size=crop_size, + horizontal_flip_prob=horizontal_flip_prob, + color_jitter_prob=color_jitter_prob, + views_to_load=train_views_to_load, + ) + + test_preprocess_cfg, test_batchprocess_cfg = get_test_preprocessing( + keys_to_load=keys_to_load, + image_size=image_size, + views_to_load=test_views_to_load, + ) + + data.train_dataloader = get_train_dataloader_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=train_dataset_cfg, + preprocess_fn=train_preprocess_cfg, + ), + batchprocess_cfg=train_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + shuffle=True, + collate_fn=( + multi_sensor_collate + if len(train_views_to_load) > 1 + else default_collate + ), + ) + + # Test Dataset Config + test_dataset_cfg = class_config( + DataPipe, datasets=test_dataset_cfg, preprocess_fn=test_preprocess_cfg + ) + data.test_dataloader = get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + batchprocess_cfg=test_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + collate_fn=( + multi_sensor_collate + if len(test_views_to_load) > 1 + else default_collate + ), + ) + return data + + +def get_shift_config( # pylint: disable=too-many-arguments, too-many-positional-arguments, line-too-long + data_root: str = "data/shift/images", + train_split: str = "train", + train_framerate: str = "images", + train_shift_type: str = "discrete", + train_views_to_load: Sequence[str] = ("front",), + train_keys_to_load: Sequence[str] = (K.images, K.seg_masks), + train_attributes_to_load: Sequence[dict[str, str | float]] | None = None, + train_skip_empty_frames: bool = False, + test_split: str = "val", + test_framerate: str = "images", + test_shift_type: str = "discrete", + test_views_to_load: Sequence[str] = ("front",), + test_keys_to_load: Sequence[str] = (K.images, K.seg_masks), + test_attributes_to_load: Sequence[dict[str, str | float]] | None = None, + test_skip_empty_frames: bool = False, + data_backend: None | ConfigDict = None, + image_size: tuple[int, int] = (800, 1280), + crop_size: tuple[int, int] | None = None, + horizontal_flip_prob: float = 0.5, + color_jitter_prob: float = 0.0, + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> ConfigDict: + """Get the default config for BDD100K segmentation.""" + train_dataset_cfg = class_config( + SHIFT, + data_root=data_root, + split=train_split, + framerate=train_framerate, + shift_type=train_shift_type, + views_to_load=train_views_to_load, + keys_to_load=train_keys_to_load, + attributes_to_load=train_attributes_to_load, + skip_empty_frames=train_skip_empty_frames, + backend=data_backend, + ) + test_dataset_cfg = class_config( + SHIFT, + data_root=data_root, + split=test_split, + framerate=test_framerate, + shift_type=test_shift_type, + views_to_load=test_views_to_load, + keys_to_load=test_keys_to_load, + attributes_to_load=test_attributes_to_load, + skip_empty_frames=test_skip_empty_frames, + backend=data_backend, + ) + + return get_shift_dataloader_config( + train_dataset_cfg=train_dataset_cfg, + test_dataset_cfg=test_dataset_cfg, + keys_to_load=train_keys_to_load, + image_size=image_size, + crop_size=crop_size, + horizontal_flip_prob=horizontal_flip_prob, + color_jitter_prob=color_jitter_prob, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + train_views_to_load=train_views_to_load, + test_views_to_load=test_views_to_load, + ) diff --git a/vis4d/zoo/base/datasets/shift/tasks.py b/vis4d/zoo/base/datasets/shift/tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0304a6e211a486f08cce84ec784ce6e2800668 --- /dev/null +++ b/vis4d/zoo/base/datasets/shift/tasks.py @@ -0,0 +1,183 @@ +"""SHIFT data loading config for segmentation.""" + +from __future__ import annotations + +from ml_collections.config_dict import ConfigDict + +from vis4d.common.typing import ArgsType +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import data_key, pred_key + +from .common import get_shift_config + +CONN_SHIFT_DET_EVAL = { + "frame_ids": data_key("frame_ids"), + "sample_names": data_key("sample_names"), + "sequence_names": data_key("sequence_names"), + "pred_boxes": pred_key("boxes"), + "pred_scores": pred_key("scores"), + "pred_classes": pred_key("class_ids"), +} +CONN_SHIFT_INS_EVAL = { + "frame_ids": data_key("frame_ids"), + "sample_names": data_key("sample_names"), + "sequence_names": data_key("sequence_names"), + "pred_boxes": pred_key("boxes.boxes"), + "pred_scores": pred_key("boxes.scores"), + "pred_classes": pred_key("boxes.class_ids"), + "pred_masks": pred_key("masks.masks"), +} + + +def get_shift_sem_seg_config(**kwargs: ArgsType) -> ConfigDict: + """Get the config for the SHIFT segmentation task.""" + keys_to_load = (K.images, K.input_hw, K.original_hw, K.seg_masks) + cfg = get_shift_config( + train_keys_to_load=keys_to_load, + test_keys_to_load=keys_to_load, + horizontal_flip_prob=0.5, + color_jitter_prob=0.5, + crop_size=kwargs.get("crop_size", (512, 1024)), + **kwargs, + ) + return cfg + + +def get_shift_det_config(**kwargs: ArgsType) -> ConfigDict: + """Get the config for the SHIFT detection task.""" + keys_to_load = ( + K.images, + K.input_hw, + K.original_hw, + K.boxes2d, + K.boxes2d_classes, + ) + cfg = get_shift_config( + train_keys_to_load=keys_to_load, + test_keys_to_load=keys_to_load, + train_skip_empty_frames=True, + test_skip_empty_frames=False, + horizontal_flip_prob=0.5, + color_jitter_prob=0.0, + crop_size=kwargs.get("crop_size", None), + **kwargs, + ) + return cfg + + +def get_shift_instance_seg_config(**kwargs: ArgsType) -> ConfigDict: + """Get the config for the SHIFT instance segmentation task.""" + keys_to_load = ( + K.images, + K.input_hw, + K.original_hw, + K.boxes2d, + K.boxes2d_classes, + K.instance_masks, + ) + cfg = get_shift_config( + train_keys_to_load=keys_to_load, + test_keys_to_load=keys_to_load, + train_skip_empty_frames=True, + test_skip_empty_frames=False, + horizontal_flip_prob=0.5, + color_jitter_prob=0.5, + crop_size=kwargs.get("crop_size", None), + **kwargs, + ) + return cfg + + +def get_shift_depth_est_config(**kwargs: ArgsType) -> ConfigDict: + """Get the config for the SHIFT depth estimation task.""" + keys_to_load = (K.images, K.input_hw, K.original_hw, K.depth_maps) + cfg = get_shift_config( + train_keys_to_load=keys_to_load, + test_keys_to_load=keys_to_load, + horizontal_flip_prob=0.5, + crop_size=kwargs.get("crop_size", None), + **kwargs, + ) + return cfg + + +def get_shift_optical_flow_config(**kwargs: ArgsType) -> ConfigDict: + """Get the config for the SHIFT optical flow task.""" + keys_to_load = (K.images, K.input_hw, K.original_hw, K.optical_flows) + cfg = get_shift_config( + train_keys_to_load=keys_to_load, + test_keys_to_load=keys_to_load, + horizontal_flip_prob=0.5, + crop_size=kwargs.get("crop_size", None), + **kwargs, + ) + return cfg + + +def get_shift_tracking_config(**kwargs: ArgsType) -> ConfigDict: + """Get the config for the SHIFT tracking task.""" + keys_to_load = ( + K.images, + K.input_hw, + K.original_hw, + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + ) + cfg = get_shift_config( + train_keys_to_load=keys_to_load, + test_keys_to_load=keys_to_load, + horizontal_flip_prob=0.5, + crop_size=kwargs.get("crop_size", None), + **kwargs, + ) + return cfg + + +def get_shift_multitask_2d_config(**kwargs: ArgsType) -> ConfigDict: + """Get the config for the SHIFT multitask 2D task.""" + keys_to_load = ( + K.images, + K.input_hw, + K.original_hw, + K.intrinsics, + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + K.seg_masks, + K.depth_maps, + ) + cfg = get_shift_config( + train_keys_to_load=keys_to_load, + test_keys_to_load=keys_to_load, + horizontal_flip_prob=0.5, + crop_size=kwargs.get("crop_size", None), + **kwargs, + ) + return cfg + + +def get_shift_multitask_3d_config(**kwargs: ArgsType) -> ConfigDict: + """Get the config for the SHIFT multitask 3D task.""" + keys_to_load = ( + K.images, + K.input_hw, + K.original_hw, + K.intrinsics, + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_track_ids, + K.boxes3d, + K.boxes3d_classes, + K.boxes3d_track_ids, + K.seg_masks, + K.depth_maps, + ) + cfg = get_shift_config( + train_keys_to_load=keys_to_load, + test_keys_to_load=keys_to_load, + horizontal_flip_prob=0.5, + crop_size=kwargs.get("crop_size", None), + **kwargs, + ) + return cfg diff --git a/vis4d/zoo/base/models/__init__.py b/vis4d/zoo/base/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95d427c92e9b91c3b2d68699c9b8f7d14828d594 --- /dev/null +++ b/vis4d/zoo/base/models/__init__.py @@ -0,0 +1 @@ +"""Model Zoo base models.""" diff --git a/vis4d/zoo/base/models/faster_rcnn.py b/vis4d/zoo/base/models/faster_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0a4a18491aab96a7242f91debe0dbdd15f34bb --- /dev/null +++ b/vis4d/zoo/base/models/faster_rcnn.py @@ -0,0 +1,148 @@ +"""Faseter R-CNN base model config.""" + +from __future__ import annotations + +from ml_collections import ConfigDict, FieldReference + +from vis4d.config import class_config +from vis4d.engine.connectors import LossConnector, data_key, pred_key +from vis4d.engine.loss_module import LossModule +from vis4d.model.detect.faster_rcnn import FasterRCNN +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder +from vis4d.op.box.matchers import MaxIoUMatcher +from vis4d.op.box.samplers import RandomSampler +from vis4d.op.detect.faster_rcnn import FasterRCNNHead +from vis4d.op.detect.rcnn import RCNNHead, RCNNLoss +from vis4d.op.detect.rpn import RPNLoss + +# Data connectors +CONN_RPN_LOSS_2D = { + "cls_outs": pred_key("rpn.cls"), + "reg_outs": pred_key("rpn.box"), + "target_boxes": data_key("boxes2d"), + "images_hw": data_key("input_hw"), +} + +CONN_ROI_LOSS_2D = { + "class_outs": pred_key("roi.cls_score"), + "regression_outs": pred_key("roi.bbox_pred"), + "boxes": pred_key("sampled_proposals.boxes"), + "boxes_mask": pred_key("sampled_targets.labels"), + "target_boxes": pred_key("sampled_targets.boxes"), + "target_classes": pred_key("sampled_targets.classes"), +} + + +def get_default_rpn_box_codec_cfg( + target_means: tuple[float, ...] = (0.0, 0.0, 0.0, 0.0), + target_stds: tuple[float, ...] = (1.0, 1.0, 1.0, 1.0), +) -> tuple[ConfigDict, ConfigDict]: + """Get default config for rpn box encoder and decoder.""" + return tuple( + class_config(x, target_means=target_means, target_stds=target_stds) + for x in (DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder) + ) + + +def get_default_rcnn_box_codec_cfg( + target_means: tuple[float, ...] = (0.0, 0.0, 0.0, 0.0), + target_stds: tuple[float, ...] = (0.1, 0.1, 0.2, 0.2), +) -> tuple[ConfigDict, ConfigDict]: + """Get default config for rcnn box encoder and decoder.""" + return tuple( + class_config(x, target_means=target_means, target_stds=target_stds) + for x in (DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder) + ) + + +def get_faster_rcnn_cfg( + num_classes: FieldReference | int, + basemodel: ConfigDict, + weights: str | None = None, +) -> tuple[ConfigDict, ConfigDict]: + """Return default config for faster_rcnn model and loss. + + This is an example for setting every component of the model and loss. + Everything is the same as the default args. + + Args: + num_classes (FieldReference | int): Number of classes. + basemodel (ConfigDict): Base model config. + weights (str | None, optional): Weights to load. Defaults to None. + """ + ###################################################### + ## MODEL ## + ###################################################### + anchor_generator = class_config( + AnchorGenerator, + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + ) + + rpn_box_encoder, rpn_box_decoder = get_default_rpn_box_codec_cfg() + rcnn_box_encoder, rcnn_box_decoder = get_default_rcnn_box_codec_cfg() + + box_matcher = class_config( + MaxIoUMatcher, + thresholds=[0.5], + labels=[0, 1], + allow_low_quality_matches=False, + ) + + box_sampler = class_config( + RandomSampler, batch_size=512, positive_fraction=0.25 + ) + + roi_head = class_config(RCNNHead, num_classes=num_classes) + + faster_rcnn_head = class_config( + FasterRCNNHead, + num_classes=num_classes, + anchor_generator=anchor_generator, + rpn_box_decoder=rpn_box_decoder, + box_matcher=box_matcher, + box_sampler=box_sampler, + roi_head=roi_head, + ) + + model = class_config( + FasterRCNN, + num_classes=num_classes, + basemodel=basemodel, + faster_rcnn_head=faster_rcnn_head, + rcnn_box_decoder=rcnn_box_decoder, + weights=weights, + ) + + ###################################################### + ## LOSS ## + ###################################################### + rpn_loss = class_config( + RPNLoss, + anchor_generator=anchor_generator, + box_encoder=rpn_box_encoder, + ) + rcnn_loss = class_config( + RCNNLoss, box_encoder=rcnn_box_encoder, num_classes=num_classes + ) + + loss = class_config( + LossModule, + losses=[ + { + "loss": rpn_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_RPN_LOSS_2D + ), + }, + { + "loss": rcnn_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_ROI_LOSS_2D + ), + }, + ], + ) + return model, loss diff --git a/vis4d/zoo/base/models/mask_rcnn.py b/vis4d/zoo/base/models/mask_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..9f556e56d3eaffacb36b0e74cc354597984e14a6 --- /dev/null +++ b/vis4d/zoo/base/models/mask_rcnn.py @@ -0,0 +1,163 @@ +"""Mask RCNN base model config.""" + +from __future__ import annotations + +from ml_collections import ConfigDict, FieldReference + +from vis4d.config import class_config +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import ( + LossConnector, + data_key, + pred_key, + remap_pred_keys, +) +from vis4d.engine.loss_module import LossModule +from vis4d.model.detect.mask_rcnn import MaskRCNN +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.box.matchers import MaxIoUMatcher +from vis4d.op.box.samplers import RandomSampler +from vis4d.op.detect.faster_rcnn import FasterRCNNHead +from vis4d.op.detect.mask_rcnn import ( + MaskRCNNHead, + MaskRCNNHeadLoss, + SampledMaskLoss, + positive_mask_sampler, +) +from vis4d.op.detect.rcnn import RCNNHead, RCNNLoss +from vis4d.op.detect.rpn import RPNLoss +from vis4d.zoo.base import get_callable_cfg +from vis4d.zoo.base.models.faster_rcnn import ( + CONN_ROI_LOSS_2D as _CONN_ROI_LOSS_2D, +) +from vis4d.zoo.base.models.faster_rcnn import ( + CONN_RPN_LOSS_2D as _CONN_RPN_LOSS_2D, +) +from vis4d.zoo.base.models.faster_rcnn import ( + get_default_rcnn_box_codec_cfg, + get_default_rpn_box_codec_cfg, +) + +# Data connectors +CONN_MASK_HEAD_LOSS_2D = { + "mask_preds": pred_key("masks.mask_pred"), + "target_masks": data_key(K.instance_masks), + "sampled_target_indices": pred_key("boxes.sampled_target_indices"), + "sampled_targets": pred_key("boxes.sampled_targets"), + "sampled_proposals": pred_key("boxes.sampled_proposals"), +} + +CONN_RPN_LOSS_2D = remap_pred_keys(_CONN_RPN_LOSS_2D, "boxes") + +CONN_ROI_LOSS_2D = remap_pred_keys(_CONN_ROI_LOSS_2D, "boxes") + + +def get_mask_rcnn_cfg( + num_classes: FieldReference | int, + basemodel: ConfigDict, + no_overlap: bool = False, + weights: str | None = None, +) -> tuple[ConfigDict, ConfigDict]: + """Return default config for mask_rcnn model and loss. + + This is an example for setting every component of the model and loss. + Everything is the same as the default args. + + Args: + num_classes (FieldReference | int): Number of classes. + basemodel (ConfigDict): Base model config. + no_overlap (bool, optional): Whether to remove overlapping pixels + between masks. Defaults to False. + weights (str | None, optional): Weights to load. Defaults to None. + """ + ###################################################### + ## MODEL ## + ###################################################### + anchor_generator = class_config( + AnchorGenerator, + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + ) + + rpn_box_encoder, rpn_box_decoder = get_default_rpn_box_codec_cfg() + rcnn_box_encoder, rcnn_box_decoder = get_default_rcnn_box_codec_cfg() + + box_matcher = class_config( + MaxIoUMatcher, + thresholds=[0.5], + labels=[0, 1], + allow_low_quality_matches=False, + ) + + box_sampler = class_config( + RandomSampler, batch_size=512, positive_fraction=0.25 + ) + + roi_head = class_config(RCNNHead, num_classes=num_classes) + + mask_head = class_config(MaskRCNNHead, num_classes=num_classes) + + faster_rcnn_head = class_config( + FasterRCNNHead, + num_classes=num_classes, + anchor_generator=anchor_generator, + rpn_box_decoder=rpn_box_decoder, + box_matcher=box_matcher, + box_sampler=box_sampler, + roi_head=roi_head, + ) + + model = class_config( + MaskRCNN, + num_classes=num_classes, + basemodel=basemodel, + faster_rcnn_head=faster_rcnn_head, + mask_head=mask_head, + rcnn_box_decoder=rcnn_box_decoder, + no_overlap=no_overlap, + weights=weights, + ) + + ###################################################### + ## LOSS ## + ###################################################### + rpn_loss = class_config( + RPNLoss, + anchor_generator=anchor_generator, + box_encoder=rpn_box_encoder, + ) + rcnn_loss = class_config( + RCNNLoss, box_encoder=rcnn_box_encoder, num_classes=num_classes + ) + + mask_loss = class_config( + SampledMaskLoss, + mask_sampler=get_callable_cfg(positive_mask_sampler), + loss=class_config(MaskRCNNHeadLoss, num_classes=num_classes), + ) + + loss = class_config( + LossModule, + losses=[ + { + "loss": rpn_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_RPN_LOSS_2D + ), + }, + { + "loss": rcnn_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_ROI_LOSS_2D + ), + }, + { + "loss": mask_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_MASK_HEAD_LOSS_2D + ), + }, + ], + ) + return model, loss diff --git a/vis4d/zoo/base/models/qdtrack.py b/vis4d/zoo/base/models/qdtrack.py new file mode 100644 index 0000000000000000000000000000000000000000..89b2e5848ac0cf5f8446a0a07762459c979f5a6a --- /dev/null +++ b/vis4d/zoo/base/models/qdtrack.py @@ -0,0 +1,219 @@ +"""QD-Track model config.""" + +from __future__ import annotations + +from ml_collections import ConfigDict, FieldReference + +from vis4d.config import class_config +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import LossConnector, pred_key, remap_pred_keys +from vis4d.engine.loss_module import LossModule +from vis4d.model.adapter import ModelExpEMAAdapter +from vis4d.model.track.qdtrack import FasterRCNNQDTrack, YOLOXQDTrack +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.box.poolers import MultiScaleRoIAlign +from vis4d.op.detect.faster_rcnn import FasterRCNNHead +from vis4d.op.detect.rcnn import RCNNLoss +from vis4d.op.detect.rpn import RPNLoss +from vis4d.op.detect.yolox import YOLOXHeadLoss +from vis4d.op.loss.common import smooth_l1_loss +from vis4d.op.track.qdtrack import ( + QDSimilarityHead, + QDTrackHead, + QDTrackInstanceSimilarityLoss, +) +from vis4d.zoo.base import get_callable_cfg +from vis4d.zoo.base.models.faster_rcnn import ( + CONN_ROI_LOSS_2D as _CONN_ROI_LOSS_2D, +) +from vis4d.zoo.base.models.faster_rcnn import ( + get_default_rcnn_box_codec_cfg, + get_default_rpn_box_codec_cfg, +) + +from .yolox import get_yolox_model_cfg + +PRED_PREFIX = "detector_out" + +CONN_BBOX_2D_TRAIN = { + "images": K.images, + "images_hw": K.input_hw, + "original_hw": K.original_hw, + "frame_ids": K.frame_ids, + "boxes2d": K.boxes2d, + "boxes2d_classes": K.boxes2d_classes, + "boxes2d_track_ids": K.boxes2d_track_ids, + "keyframes": "keyframes", +} + +CONN_BBOX_2D_TEST = { + "images": K.images, + "images_hw": K.input_hw, + "original_hw": K.original_hw, + "frame_ids": K.frame_ids, +} + +CONN_RPN_LOSS_2D = { + "cls_outs": pred_key(f"{PRED_PREFIX}.rpn.cls"), + "reg_outs": pred_key(f"{PRED_PREFIX}.rpn.box"), + "target_boxes": pred_key("key_target_boxes"), + "images_hw": pred_key("key_images_hw"), +} + +CONN_ROI_LOSS_2D = remap_pred_keys(_CONN_ROI_LOSS_2D, PRED_PREFIX) + +CONN_TRACK_LOSS_2D = { + "key_embeddings": pred_key("key_embeddings"), + "ref_embeddings": pred_key("ref_embeddings"), + "key_track_ids": pred_key("key_track_ids"), + "ref_track_ids": pred_key("ref_track_ids"), +} + +CONN_YOLOX_LOSS_2D = { + "cls_outs": pred_key(f"{PRED_PREFIX}.cls_score"), + "reg_outs": pred_key(f"{PRED_PREFIX}.bbox_pred"), + "obj_outs": pred_key(f"{PRED_PREFIX}.objectness"), + "target_boxes": pred_key("key_target_boxes"), + "target_class_ids": pred_key("key_target_classes"), + "images_hw": pred_key("key_images_hw"), +} + + +def get_qdtrack_cfg( + num_classes: int | FieldReference, + basemodel: ConfigDict, + weights: str | None = None, +) -> tuple[ConfigDict, ConfigDict]: + """Get QDTrack model config.""" + ###################################################### + ## MODEL ## + ###################################################### + anchor_generator = class_config( + AnchorGenerator, + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + ) + + rpn_box_encoder, _ = get_default_rpn_box_codec_cfg() + rcnn_box_encoder, _ = get_default_rcnn_box_codec_cfg() + + faster_rcnn_head = class_config( + FasterRCNNHead, + num_classes=num_classes, + anchor_generator=anchor_generator, + ) + + model = class_config( + FasterRCNNQDTrack, + num_classes=num_classes, + basemodel=basemodel, + faster_rcnn_head=faster_rcnn_head, + weights=weights, + ) + + rpn_loss = class_config( + RPNLoss, + anchor_generator=anchor_generator, + box_encoder=rpn_box_encoder, + loss_bbox=get_callable_cfg(smooth_l1_loss, beta=1.0 / 9.0), + ) + rcnn_loss = class_config( + RCNNLoss, + box_encoder=rcnn_box_encoder, + num_classes=num_classes, + loss_bbox=get_callable_cfg(smooth_l1_loss), + ) + + track_loss = class_config(QDTrackInstanceSimilarityLoss) + + loss = class_config( + LossModule, + losses=[ + { + "loss": rpn_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_RPN_LOSS_2D + ), + }, + { + "loss": rcnn_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_ROI_LOSS_2D + ), + }, + { + "loss": track_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_TRACK_LOSS_2D + ), + }, + ], + ) + + return model, loss + + +def get_qdtrack_yolox_cfg( + num_classes: int | FieldReference, + model_type: str, + use_ema: bool = True, + weights: str | None = None, +) -> tuple[ConfigDict, ConfigDict]: + """Get QDTrack YOLOX model config.""" + ###################################################### + ## MODEL ## + ###################################################### + basemodel, fpn, yolox_head = get_yolox_model_cfg(num_classes, model_type) + if model_type == "tiny": + in_dim = 96 + elif model_type == "small": + in_dim = 128 + elif model_type == "large": + in_dim = 256 + elif model_type == "xlarge": + in_dim = 320 + else: + raise ValueError(f"Invalid model type: {model_type}") + model = class_config( + YOLOXQDTrack, + num_classes=num_classes, + basemodel=basemodel, + fpn=fpn, + yolox_head=yolox_head, + qdtrack_head=class_config( + QDTrackHead, + similarity_head=class_config( + QDSimilarityHead, + proposal_pooler=MultiScaleRoIAlign( + resolution=(7, 7), strides=[8, 16, 32], sampling_ratio=0 + ), + in_dim=in_dim, + ), + ), + weights=weights, + ) + if use_ema: + model = class_config(ModelExpEMAAdapter, model=model) + + track_loss = class_config(QDTrackInstanceSimilarityLoss) + + loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(YOLOXHeadLoss, num_classes=num_classes), + "connector": class_config( + LossConnector, key_mapping=CONN_YOLOX_LOSS_2D + ), + }, + { + "loss": track_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_TRACK_LOSS_2D + ), + }, + ], + ) + + return model, loss diff --git a/vis4d/zoo/base/models/yolox.py b/vis4d/zoo/base/models/yolox.py new file mode 100644 index 0000000000000000000000000000000000000000..eadc3bc191210a57dd9ee0f2ec27c8b5c4e726cf --- /dev/null +++ b/vis4d/zoo/base/models/yolox.py @@ -0,0 +1,221 @@ +"""YOLOX base model config.""" + +from __future__ import annotations + +from ml_collections import ConfigDict, FieldReference +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import OptimizerConfig +from vis4d.data.const import CommonKeys as K +from vis4d.engine.callbacks import ( + EMACallback, + YOLOXModeSwitchCallback, + YOLOXSyncNormCallback, + YOLOXSyncRandomResizeCallback, +) +from vis4d.engine.connectors import LossConnector, data_key, pred_key +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim.scheduler import ConstantLR, QuadraticLRWarmup +from vis4d.model.adapter import ModelExpEMAAdapter +from vis4d.model.detect.yolox import YOLOX +from vis4d.op.base import CSPDarknet +from vis4d.op.detect.yolox import YOLOXHead, YOLOXHeadLoss +from vis4d.op.fpp import YOLOXPAFPN +from vis4d.zoo.base import get_lr_scheduler_cfg, get_optimizer_cfg + +# Data connectors +CONN_YOLOX_LOSS_2D = { + "cls_outs": pred_key("cls_score"), + "reg_outs": pred_key("bbox_pred"), + "obj_outs": pred_key("objectness"), + "target_boxes": data_key(K.boxes2d), + "target_class_ids": data_key(K.boxes2d_classes), + "images_hw": data_key(K.input_hw), +} + + +def get_yolox_optimizers_cfg( + lr: float | FieldReference, + num_epochs: int | FieldReference, + warmup_epochs: int, + num_last_epochs: int, +) -> list[OptimizerConfig]: + """Construct optimizer for YOLOX training.""" + return [ + get_optimizer_cfg( + optimizer=class_config( + SGD, + lr=lr, + momentum=0.9, + weight_decay=5e-4, + nesterov=True, + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config(QuadraticLRWarmup, max_steps=warmup_epochs), + end=warmup_epochs, + epoch_based=False, + convert_epochs_to_steps=True, + convert_attributes=["max_steps"], + ), + get_lr_scheduler_cfg( + class_config( + CosineAnnealingLR, + T_max=num_epochs - num_last_epochs - warmup_epochs, + eta_min=lr * 0.05, + ), + begin=warmup_epochs, + end=num_epochs - num_last_epochs, + epoch_based=False, + convert_epochs_to_steps=True, + convert_attributes=["T_max"], + ), + get_lr_scheduler_cfg( + class_config( + ConstantLR, max_steps=num_last_epochs, factor=1.0 + ), + begin=num_epochs - num_last_epochs, + end=num_epochs, + epoch_based=True, + ), + ], + param_groups=[ + { + "custom_keys": ["basemodel", "fpn", "yolox_head"], + "norm_decay_mult": 0.0, + }, + { + "custom_keys": ["basemodel", "fpn", "yolox_head"], + "bias_decay_mult": 0.0, + }, + ], + ) + ] + + +def get_yolox_callbacks_cfg( + switch_epoch: int, + shape: tuple[int, int] = (480, 480), + num_sizes: int = 11, + use_ema: bool = True, +) -> list[ConfigDict]: + """Get YOLOX callbacks for training.""" + callbacks = [] + if num_sizes > 0: + callbacks.append( + class_config( + YOLOXSyncRandomResizeCallback, + size_list=[ + (shape[0] + i * 32, shape[1] + i * 32) + for i in range(num_sizes) + ], + interval=10, + ) + ) + callbacks += [ + class_config(YOLOXModeSwitchCallback, switch_epoch=switch_epoch), + class_config(YOLOXSyncNormCallback), + ] + if use_ema: + callbacks += [class_config(EMACallback)] + return callbacks + + +def get_model_setting(model_type: str) -> tuple[float, float, int, list[int]]: + """Get YOLOX model setting.""" + if model_type == "tiny": + deepen_factor, widen_factor, num_csp_blocks = 0.33, 0.375, 1 + in_channels = [96, 192, 384] + elif model_type == "small": + deepen_factor, widen_factor, num_csp_blocks = 0.33, 0.5, 1 + in_channels = [128, 256, 512] + elif model_type == "large": + deepen_factor, widen_factor, num_csp_blocks = 1.0, 1.0, 3 + in_channels = [256, 512, 1024] + elif model_type == "xlarge": + deepen_factor, widen_factor, num_csp_blocks = 1.33, 1.25, 4 + in_channels = [320, 640, 1280] + else: + raise ValueError(f"Unknown model type: {model_type}") + return deepen_factor, widen_factor, num_csp_blocks, in_channels + + +def get_yolox_model_cfg( + num_classes: FieldReference | int, model_type: str +) -> ConfigDict: + """Get YOLOX model.""" + assert model_type in {"tiny", "small", "large", "xlarge"}, ( + f"model_type must be one of 'tiny', 'small', 'large', 'xlarge', " + f"got {model_type}." + ) + ( + deepen_factor, + widen_factor, + num_csp_blocks, + in_channels, + ) = get_model_setting(model_type) + basemodel = class_config( + CSPDarknet, deepen_factor=deepen_factor, widen_factor=widen_factor + ) + fpn = class_config( + YOLOXPAFPN, + in_channels=in_channels, + out_channels=in_channels[0], + num_csp_blocks=num_csp_blocks, + ) + yolox_head = class_config( + YOLOXHead, + num_classes=num_classes, + in_channels=in_channels[0], + feat_channels=in_channels[0], + ) + return basemodel, fpn, yolox_head + + +def get_yolox_cfg( + num_classes: FieldReference | int, + model_type: str, + use_ema: bool = True, + weights: str | None = None, +) -> tuple[ConfigDict, ConfigDict]: + """Return default config for YOLOX model and loss. + + Args: + num_classes (FieldReference | int): Number of classes. + model_type (str): Model type. Must be one of 'tiny', 'small', 'large', + 'xlarge'. + use_ema (bool, optional): Whether to use EMA. Defaults to True. + weights (str | None, optional): Weights to load. Defaults to None. + """ + ###################################################### + ## MODEL ## + ###################################################### + basemodel, fpn, yolox_head = get_yolox_model_cfg(num_classes, model_type) + model = class_config( + YOLOX, + num_classes=num_classes, + basemodel=basemodel, + fpn=fpn, + yolox_head=yolox_head, + weights=weights, + ) + if use_ema: + model = class_config(ModelExpEMAAdapter, model=model) + + ###################################################### + ## LOSS ## + ###################################################### + loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(YOLOXHeadLoss, num_classes=num_classes), + "connector": class_config( + LossConnector, key_mapping=CONN_YOLOX_LOSS_2D + ), + }, + ], + ) + return model, loss diff --git a/vis4d/zoo/base/optimizer.py b/vis4d/zoo/base/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5bebcb550e637301d508893db8ac43661a1004 --- /dev/null +++ b/vis4d/zoo/base/optimizer.py @@ -0,0 +1,82 @@ +"""Optimizer configuration.""" + +from __future__ import annotations + +from ml_collections import ConfigDict + +from vis4d.config.typing import ( + LrSchedulerConfig, + OptimizerConfig, + ParamGroupCfg, +) + + +def get_lr_scheduler_cfg( + scheduler: ConfigDict, + begin: int = 0, + end: int = -1, + epoch_based: bool = True, + convert_epochs_to_steps: bool = False, + convert_attributes: list[str] | None = None, +) -> LrSchedulerConfig: + """Default learning rate scheduler configuration. + + This creates a config object that can be initialized as a LearningRate + scheduler for training. + + Args: + scheduler (ConfigDict): Learning rate scheduler configuration. + begin (int, optional): Begin epoch. Defaults to 0. + end (int, optional): End epoch. Defaults to None. Defaults to -1. + epoch_based (bool, optional): Whether the learning rate scheduler is + epoch based or step based. Defaults to True. + convert_epochs_to_steps (bool): Whether to convert the begin and end + for a step based scheduler to steps automatically based on length + of train dataloader. Enables users to set the iteration breakpoints + as epochs. Defaults to False. + convert_attributes (list[str] | None): List of attributes in the + scheduler that should be converted to steps. Defaults to None. + + Returns: + LrSchedulerConfig: Config dict that can be instantiated as LearningRate + scheduler. + """ + lr_scheduler = LrSchedulerConfig() + + lr_scheduler.scheduler = scheduler + lr_scheduler.begin = begin + lr_scheduler.end = end + lr_scheduler.epoch_based = epoch_based + lr_scheduler.convert_epochs_to_steps = convert_epochs_to_steps + lr_scheduler.convert_attributes = convert_attributes + + return lr_scheduler + + +def get_optimizer_cfg( + optimizer: ConfigDict, + lr_schedulers: list[LrSchedulerConfig] | None = None, + param_groups: list[ParamGroupCfg] | None = None, +) -> OptimizerConfig: + """Default optimizer configuration. + + This creates a config object that can be initialized as an Optimizer for + training. + + Args: + optimizer (ConfigDict): Optimizer configuration. + lr_schedulers (list[LrSchedulerConfig] | None, optional): Learning rate + schedulers configuration. Defaults to None. + param_groups (list[ParamGroupCfg] | None, optional): Parameter groups + configuration. Defaults to None. + + Returns: + OptimizerConfig: Config dict that can be instantiated as Optimizer. + """ + optim = OptimizerConfig() + + optim.optimizer = optimizer + optim.lr_schedulers = lr_schedulers + optim.param_groups = param_groups + + return optim diff --git a/vis4d/zoo/base/pl_trainer.py b/vis4d/zoo/base/pl_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e54e625cb78d8946d7018c95fd86302d232cba97 --- /dev/null +++ b/vis4d/zoo/base/pl_trainer.py @@ -0,0 +1,38 @@ +"""Default runtime configuration for PyTorch Lightning.""" + +import inspect + +from lightning import Trainer + +from vis4d.config import FieldConfigDict +from vis4d.config.typing import ExperimentConfig + + +def get_default_pl_trainer_cfg(config: ExperimentConfig) -> ExperimentConfig: + """Get PyTorch Lightning Trainer config.""" + pl_trainer = FieldConfigDict() + + # PL Trainer arguments + for k, v in inspect.signature(Trainer).parameters.items(): + if not k in {"callbacks", "devices", "logger", "strategy"}: + pl_trainer[k] = v.default + + # PL Trainer settings + pl_trainer.benchmark = config.benchmark + pl_trainer.use_distributed_sampler = False + pl_trainer.num_sanity_val_steps = 0 + + # logger + pl_trainer.enable_progress_bar = False + pl_trainer.log_every_n_steps = config.log_every_n_steps + + # Default Trainer arguments + pl_trainer.work_dir = config.work_dir + pl_trainer.exp_name = config.experiment_name + pl_trainer.version = config.version + pl_trainer.find_unused_parameters = False + pl_trainer.checkpoint_period = 1 + pl_trainer.save_top_k = 1 + pl_trainer.wandb = False + + return pl_trainer diff --git a/vis4d/zoo/base/runtime.py b/vis4d/zoo/base/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..d127ace78dd8b54678911e07e4b9f9dd01ff4256 --- /dev/null +++ b/vis4d/zoo/base/runtime.py @@ -0,0 +1,93 @@ +"""Default runtime configuration for the project.""" + +from __future__ import annotations + +import platform +from datetime import datetime + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig +from vis4d.engine.callbacks import LoggingCallback + + +def get_default_cfg( + exp_name: str, work_dir: str = "vis4d-workspace" +) -> ExperimentConfig: + """Set default config for the project. + + Args: + exp_name (str): Experiment name. + work_dir (str, optional): Working directory. Defaults to + "vis4d-workspace". + + Returns: + ExperimentConfig: Config for the project. + """ + config = ExperimentConfig() + + config.work_dir = work_dir + config.experiment_name = exp_name + + timestamp = ( + str(datetime.now()) + .split(".", maxsplit=1)[0] + .replace(" ", "_") + .replace(":", "-") + ) + config.timestamp = timestamp + config.version = timestamp + + if platform.system() == "Windows": + path_component = "\\" + else: + path_component = "/" + + config.output_dir = ( + config.work_dir + + path_component + + config.experiment_name + + path_component + + config.version + ) + + # Set default value for the following fields + config.seed = -1 + config.log_every_n_steps = 50 + config.use_tf32 = False + config.tf32_matmul_precision = "highest" + config.benchmark = False + config.compute_flops = False + config.check_unused_parameters = False + + return config + + +def get_default_callbacks_cfg( + epoch_based: bool = True, + refresh_rate: int = 50, +) -> list[ConfigDict]: + """Get default callbacks config. + + It will return a list of callbacks config including: + - LoggingCallback + + Args: + epoch_based (bool, optional): Whether to use epoch based logging. + refresh_rate (int, optional): Refresh rate for the logging. Defaults to + 50. + + Returns: + list[ConfigDict]: List of callbacks config. + """ + callbacks = [] + + # Logger + callbacks.append( + class_config( + LoggingCallback, epoch_based=epoch_based, refresh_rate=refresh_rate + ) + ) + + return callbacks diff --git a/vis4d/zoo/bdd100k/__init__.py b/vis4d/zoo/bdd100k/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44d079c7f6048d6ca3787f42f360776ca24c0fc2 --- /dev/null +++ b/vis4d/zoo/bdd100k/__init__.py @@ -0,0 +1,27 @@ +"""BDD100K Model Zoo.""" + +from .faster_rcnn import faster_rcnn_r50_1x_bdd100k, faster_rcnn_r50_3x_bdd100k +from .mask_rcnn import ( + mask_rcnn_r50_1x_bdd100k, + mask_rcnn_r50_3x_bdd100k, + mask_rcnn_r50_5x_bdd100k, +) +from .qdtrack import qdtrack_frcnn_r50_fpn_1x_bdd100k +from .semantic_fpn import ( + semantic_fpn_r50_40k_bdd100k, + semantic_fpn_r50_80k_bdd100k, + semantic_fpn_r101_80k_bdd100k, +) + +# Lists of available models in BDD100K Model Zoo. +AVAILABLE_MODELS = { + "faster_rcnn_r50_1x_bdd100k": faster_rcnn_r50_1x_bdd100k, + "faster_rcnn_r50_3x_bdd100k": faster_rcnn_r50_3x_bdd100k, + "mask_rcnn_r50_1x_bdd100k": mask_rcnn_r50_1x_bdd100k, + "mask_rcnn_r50_3x_bdd100k": mask_rcnn_r50_3x_bdd100k, + "mask_rcnn_r50_5x_bdd100k": mask_rcnn_r50_5x_bdd100k, + "semantic_fpn_r50_40k_bdd100k": semantic_fpn_r50_40k_bdd100k, + "semantic_fpn_r50_80k_bdd100k": semantic_fpn_r50_80k_bdd100k, + "semantic_fpn_r101_80k_bdd100k": semantic_fpn_r101_80k_bdd100k, + "qdtrack_frcnn_r50_fpn_1x_bdd100k": qdtrack_frcnn_r50_fpn_1x_bdd100k, +} diff --git a/vis4d/zoo/bdd100k/faster_rcnn/__init__.py b/vis4d/zoo/bdd100k/faster_rcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc567e6ee7189a8c9ce7908e0ed3d1d68d8625e1 --- /dev/null +++ b/vis4d/zoo/bdd100k/faster_rcnn/__init__.py @@ -0,0 +1 @@ +"""Faster R-CNN for BDD100K.""" diff --git a/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.py b/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..ad720c0ec350a8feddf5bded595f7352a3c74a1e --- /dev/null +++ b/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.py @@ -0,0 +1,162 @@ +# pylint: disable=duplicate-code +"""Faster RCNN BDD100K training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.bdd100k import BDD100KDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_BBOX_2D_VIS, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_DET_EVAL, + get_bdd100k_detection_config, +) +from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Faster-RCNN config dict for the BDD100K detection task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="faster_rcnn_r50_1x_bdd100k") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 12 + params.num_classes = 10 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/bdd100k/images/100k" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_detection_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_faster_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BoundingBoxVisualizer, vis_freq=100), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KDetectEvaluator, + annotation_path="data/bdd100k/labels/det_20/det_val.json", + config_path="det", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_DET_EVAL + ), + metrics_to_eval=[BDD100KDetectEvaluator.METRICS_DET], + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_3x_bdd100k.py b/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_3x_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..de53a841171f3b551cb9126c80f7539480ca61ba --- /dev/null +++ b/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_3x_bdd100k.py @@ -0,0 +1,163 @@ +# pylint: disable=duplicate-code +"""Faster RCNN BDD100K training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.bdd100k import BDD100KDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_BBOX_2D_VIS, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_DET_EVAL, + get_bdd100k_detection_config, +) +from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Faster-RCNN config dict for the BDD100K detection task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="faster_rcnn_r50_3x_bdd100k") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 36 + params.num_classes = 10 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/bdd100k/images/100k" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_detection_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + multi_scale=True, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_faster_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[24, 33], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BoundingBoxVisualizer, vis_freq=100), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KDetectEvaluator, + annotation_path="data/bdd100k/labels/det_20/det_val.json", + config_path="det", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_DET_EVAL + ), + metrics_to_eval=[BDD100KDetectEvaluator.METRICS_DET], + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bdd100k/mask_rcnn/__init__.py b/vis4d/zoo/bdd100k/mask_rcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e738b6b71d039bbcf361451597e4850179bb1883 --- /dev/null +++ b/vis4d/zoo/bdd100k/mask_rcnn/__init__.py @@ -0,0 +1 @@ +"""Mask R-CNN for BDD100K.""" diff --git a/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_1x_bdd100k.py b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_1x_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..e6fb4c4e515c58be188081e399928f12d778a6ff --- /dev/null +++ b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_1x_bdd100k.py @@ -0,0 +1,167 @@ +# pylint: disable=duplicate-code +"""Mask RCNN BDD100K training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.const import CommonKeys as K +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.bdd100k import BDD100KDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_INS_MASK_2D_VIS, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_INS_EVAL, + get_bdd100k_detection_config, +) +from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Mask R-CNN config dict for BDD100K instance segmentation. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="mask_rcnn_r50_1x_bdd100k") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 12 + params.num_classes = 8 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/bdd100k/images/10k" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_detection_config( + data_root=data_root, + train_split=train_split, + train_keys_to_load=(K.images, K.boxes2d, K.instance_masks), + test_split=test_split, + test_keys_to_load=(K.images, K.original_images), + ins_seg=True, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_mask_rcnn_cfg( + num_classes=params.num_classes, + basemodel=basemodel, + no_overlap=True, + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=25), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KDetectEvaluator, + annotation_path="data/bdd100k/labels/ins_seg_val_rle.json", + config_path="ins_seg", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_INS_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_3x_bdd100k.py b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_3x_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..f57801ba3af4f21bbada1b6aafea2ecb81f51431 --- /dev/null +++ b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_3x_bdd100k.py @@ -0,0 +1,168 @@ +# pylint: disable=duplicate-code +"""Mask RCNN BDD100K training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.const import CommonKeys as K +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.bdd100k import BDD100KDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_INS_MASK_2D_VIS, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_INS_EVAL, + get_bdd100k_detection_config, +) +from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Mask R-CNN config dict for BDD100K instance segmentation. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="mask_rcnn_r50_3x_bdd100k") + config.check_val_every_n_epoch = 3 + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 36 + params.num_classes = 8 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/bdd100k/images/10k" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_detection_config( + data_root=data_root, + train_split=train_split, + train_keys_to_load=(K.images, K.boxes2d, K.instance_masks), + test_split=test_split, + test_keys_to_load=(K.images, K.original_images), + ins_seg=True, + multi_scale=True, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_mask_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel, no_overlap=True + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[24, 33], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=25), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KDetectEvaluator, + annotation_path="data/bdd100k/labels/ins_seg_val_rle.json", + config_path="ins_seg", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_INS_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_5x_bdd100k.py b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_5x_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e6dc4ba71bf0383eec50846883e4c48ee62b34 --- /dev/null +++ b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_5x_bdd100k.py @@ -0,0 +1,168 @@ +# pylint: disable=duplicate-code +"""Mask RCNN BDD100K training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.const import CommonKeys as K +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.bdd100k import BDD100KDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_INS_MASK_2D_VIS, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_INS_EVAL, + get_bdd100k_detection_config, +) +from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Mask R-CNN config dict for BDD100K instance segmentation. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="mask_rcnn_r50_5x_bdd100k") + config.check_val_every_n_epoch = 5 + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 60 + params.num_classes = 8 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/bdd100k/images/10k" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_detection_config( + data_root=data_root, + train_split=train_split, + train_keys_to_load=(K.images, K.boxes2d, K.instance_masks), + test_split=test_split, + test_keys_to_load=(K.images, K.original_images), + ins_seg=True, + multi_scale=True, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_mask_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel, no_overlap=True + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[40, 55], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=25), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KDetectEvaluator, + annotation_path="data/bdd100k/labels/ins_seg_val_rle.json", + config_path="ins_seg", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_INS_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bdd100k/qdtrack/__init__.py b/vis4d/zoo/bdd100k/qdtrack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd5f1cd38da271c6b636d9c8e26995e9dc86caad --- /dev/null +++ b/vis4d/zoo/bdd100k/qdtrack/__init__.py @@ -0,0 +1 @@ +"""QDTrack for BDD100k.""" diff --git a/vis4d/zoo/bdd100k/qdtrack/qdtrack_frcnn_r50_fpn_1x_bdd100k.py b/vis4d/zoo/bdd100k/qdtrack/qdtrack_frcnn_r50_fpn_1x_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..90470087f689d390ce36fc5cdbe46e97cd160817 --- /dev/null +++ b/vis4d/zoo/bdd100k/qdtrack/qdtrack_frcnn_r50_fpn_1x_bdd100k.py @@ -0,0 +1,140 @@ +# pylint: disable=duplicate-code +"""QDTrack with Faster R-CNN on BDD100K.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.datasets.bdd100k import bdd100k_track_map +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.bdd100k import BDD100KTrackEvaluator +from vis4d.op.base import ResNet +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_TRACK_EVAL, + get_bdd100k_track_cfg, +) +from vis4d.zoo.base.models.qdtrack import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + get_qdtrack_cfg, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for qdtrack on bdd100k. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="qdtrack_frcnn_r50_fpn_1x_bdd100k") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 4 # batch size = 4 GPUs * 4 samples per GPU = 16 + params.workers_per_gpu = 4 + params.lr = 0.02 + params.num_epochs = 12 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_track_cfg( + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL ## + ###################################################### + num_classes = len(bdd100k_track_map) + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_qdtrack_cfg( + num_classes=num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config(LinearLR, start_factor=0.1, total_iters=1000), + end=1000, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KTrackEvaluator, + annotation_path="data/bdd100k/labels/box_track_20/val/", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_TRACK_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + pl_trainer.gradient_clip_val = 35 + + return config.value_mode() diff --git a/vis4d/zoo/bdd100k/semantic_fpn/__init__.py b/vis4d/zoo/bdd100k/semantic_fpn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7e40eb37a174fb774bbba23d0bbe851a10b38b --- /dev/null +++ b/vis4d/zoo/bdd100k/semantic_fpn/__init__.py @@ -0,0 +1 @@ +"""Semantic FPN for BDD100K.""" diff --git a/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r101_80k_bdd100k.py b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r101_80k_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..bea81a90ad0ada011dcba9a13166c61a186be58e --- /dev/null +++ b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r101_80k_bdd100k.py @@ -0,0 +1,199 @@ +# pylint: disable=duplicate-code +"""Semantic FPN BDD100K training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import PolyLR +from vis4d.eval.bdd100k import BDD100KSegEvaluator +from vis4d.model.seg.semantic_fpn import SemanticFPN +from vis4d.op.base import ResNetV1c +from vis4d.op.loss import SegCrossEntropyLoss +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.seg import ( + CONN_MASKS_TEST, + CONN_MASKS_TRAIN, + CONN_SEG_LOSS, + CONN_SEG_VIS, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_SEG_EVAL, + get_bdd100k_sem_seg_cfg, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the BDD100K semantic segmentation task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="semantic_fpn_r101_80k_bdd100k") + config.sync_batchnorm = True + config.val_check_interval = 4000 + config.check_val_every_n_epoch = None + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_steps = 80000 + params.num_classes = 19 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/bdd100k/images/10k" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_sem_seg_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNetV1c, + resnet_name="resnet101_v1c", + pretrained=True, + trainable_layers=5, + norm_frozen=False, + ) + config.model = class_config( + SemanticFPN, num_classes=params.num_classes, basemodel=basemodel + ) + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(SegCrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_SEG_LOSS + ), + }, + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config( + PolyLR, + max_steps=params.num_steps, + min_lr=0.0001, + power=0.9, + ), + epoch_based=False, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg(epoch_based=False) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KSegEvaluator, + annotation_path="data/bdd100k/labels/sem_seg_val_rle.json", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_SEG_EVAL + ), + ) + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=20), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_VIS + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.epoch_based = False + pl_trainer.max_steps = params.num_steps + + pl_trainer.checkpoint_period = config.val_check_interval + pl_trainer.val_check_interval = config.val_check_interval + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + + pl_trainer.sync_batchnorm = config.sync_batchnorm + # pl_trainer.precision = 16 + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_40k_bdd100k.py b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_40k_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..902b7e9dd57b34c01e409fa3fc53d6250b99d09b --- /dev/null +++ b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_40k_bdd100k.py @@ -0,0 +1,189 @@ +# pylint: disable=duplicate-code +"""Semantic FPN BDD100K training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import PolyLR +from vis4d.eval.bdd100k import BDD100KSegEvaluator +from vis4d.model.seg.semantic_fpn import SemanticFPN +from vis4d.op.loss import SegCrossEntropyLoss +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.seg import ( + CONN_MASKS_TEST, + CONN_MASKS_TRAIN, + CONN_SEG_LOSS, + CONN_SEG_VIS, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_SEG_EVAL, + get_bdd100k_sem_seg_cfg, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the BDD100K semantic segmentation task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="semantic_fpn_r50_40k_bdd100k") + config.sync_batchnorm = True + config.val_check_interval = 2000 + config.check_val_every_n_epoch = None + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_steps = 40000 + params.num_classes = 19 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/bdd100k/images/10k" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_sem_seg_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config(SemanticFPN, num_classes=params.num_classes) + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(SegCrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_SEG_LOSS + ), + }, + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config( + PolyLR, + max_steps=params.num_steps, + min_lr=0.0001, + power=0.9, + ), + epoch_based=False, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg(epoch_based=False) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KSegEvaluator, + annotation_path="data/bdd100k/labels/sem_seg_val_rle.json", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_SEG_EVAL + ), + ) + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=20), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_VIS + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.epoch_based = False + pl_trainer.max_steps = params.num_steps + + pl_trainer.checkpoint_period = config.val_check_interval + pl_trainer.val_check_interval = config.val_check_interval + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + + pl_trainer.sync_batchnorm = config.sync_batchnorm + # pl_trainer.precision = 16 + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_80k_bdd100k.py b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_80k_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..61709ea1dd350215cb6fa4ef3dc5affa8e04be1c --- /dev/null +++ b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_80k_bdd100k.py @@ -0,0 +1,189 @@ +# pylint: disable=duplicate-code +"""Semantic FPN BDD100K training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import PolyLR +from vis4d.eval.bdd100k import BDD100KSegEvaluator +from vis4d.model.seg.semantic_fpn import SemanticFPN +from vis4d.op.loss import SegCrossEntropyLoss +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.seg import ( + CONN_MASKS_TEST, + CONN_MASKS_TRAIN, + CONN_SEG_LOSS, + CONN_SEG_VIS, +) +from vis4d.zoo.base.datasets.bdd100k import ( + CONN_BDD100K_SEG_EVAL, + get_bdd100k_sem_seg_cfg, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the BDD100K semantic segmentation task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="semantic_fpn_r50_80k_bdd100k") + config.sync_batchnorm = True + config.val_check_interval = 4000 + config.check_val_every_n_epoch = None + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_steps = 80000 + params.num_classes = 19 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/bdd100k/images/10k" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_sem_seg_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config(SemanticFPN, num_classes=params.num_classes) + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(SegCrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_SEG_LOSS + ), + }, + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config( + PolyLR, + max_steps=params.num_steps, + min_lr=0.0001, + power=0.9, + ), + epoch_based=False, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg(epoch_based=False) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KSegEvaluator, + annotation_path="data/bdd100k/labels/sem_seg_val_rle.json", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_SEG_EVAL + ), + ) + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=20), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_VIS + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.epoch_based = False + pl_trainer.max_steps = params.num_steps + + pl_trainer.checkpoint_period = config.val_check_interval + pl_trainer.val_check_interval = config.val_check_interval + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + + pl_trainer.sync_batchnorm = config.sync_batchnorm + # pl_trainer.precision = 16 + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bevformer/__init__.py b/vis4d/zoo/bevformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e75b3249c6d7d3c8c0f504a52e962d6b527bd5e --- /dev/null +++ b/vis4d/zoo/bevformer/__init__.py @@ -0,0 +1,9 @@ +"""BEVFormer model zoo.""" + +from . import bevformer_base, bevformer_tiny, bevformer_vis + +AVAILABLE_MODELS = { + "bevformer_base": bevformer_base, + "bevformer_tiny": bevformer_tiny, + "bevformer_vis": bevformer_vis, +} diff --git a/vis4d/zoo/bevformer/bevformer_base.py b/vis4d/zoo/bevformer/bevformer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..fd491bbc7bb68b0377cb875c08609143b0fe4413 --- /dev/null +++ b/vis4d/zoo/bevformer/bevformer_base.py @@ -0,0 +1,157 @@ +# pylint: disable=duplicate-code +"""BEVFormer base with ResNet-101-DCN backbone.""" +from __future__ import annotations + +from torch.optim.adamw import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback +from vis4d.engine.connectors import CallbackConnector, MultiSensorDataConnector +from vis4d.eval.nuscenes import NuScenesDet3DEvaluator +from vis4d.model.detect3d.bevformer import BEVFormer +from vis4d.op.base import ResNet +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.bevformer.data import ( + CONN_NUSC_BBOX_3D_TEST, + CONN_NUSC_DET3D_EVAL, + get_nusc_cfg, + nuscenes_class_map, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for BEVFormer on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="bevformer_base") + + # Hyper Parameters + params = ExperimentParameters() + params.samples_per_gpu = 1 + params.workers_per_gpu = 4 + params.lr = 2e-4 + params.num_epochs = 24 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/nuscenes" + version = "v1.0-trainval" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_nusc_cfg( + data_root=data_root, + version=version, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, + resnet_name="resnet101", + trainable_layers=3, + style="caffe", + stages_with_dcn=(False, False, True, True), + ) + + config.model = class_config( + BEVFormer, + basemodel=basemodel, + weights="https://github.com/zhiqi-li/storage/releases/download/v1.0/bevformer_r101_dcn_24ep.pth", # pylint: disable=line-too-long + ) + + config.loss = None + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config(AdamW, lr=params.lr, weight_decay=0.01), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=1.0 / 3, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(CosineAnnealingLR, T_max=params.num_epochs), + ), + ], + param_groups=[{"custom_keys": ["basemodel"], "lr_mult": 0.1}], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = None + + config.test_data_connector = class_config( + MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + NuScenesDet3DEvaluator, + data_root=data_root, + version=version, + split=test_split, + class_map=nuscenes_class_map, + velocity_thres=0.2, + ), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.gradient_clip_val = 35 + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bevformer/bevformer_tiny.py b/vis4d/zoo/bevformer/bevformer_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..e850ef4c442e77401c7c45ea265e3333828bffbf --- /dev/null +++ b/vis4d/zoo/bevformer/bevformer_tiny.py @@ -0,0 +1,195 @@ +# pylint: disable=duplicate-code +"""BEVFormer tiny with ResNet-50 backbone.""" +from __future__ import annotations + +from torch.optim.adamw import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback +from vis4d.engine.connectors import CallbackConnector, MultiSensorDataConnector +from vis4d.eval.nuscenes import NuScenesDet3DEvaluator +from vis4d.model.detect3d.bevformer import BEVFormer +from vis4d.op.base import ResNet +from vis4d.op.detect3d.bevformer import BEVFormerHead +from vis4d.op.detect3d.bevformer.encoder import ( + BEVFormerEncoder, + BEVFormerEncoderLayer, +) +from vis4d.op.detect3d.bevformer.spatial_cross_attention import ( + MSDeformableAttention3D, + SpatialCrossAttention, +) +from vis4d.op.detect3d.bevformer.transformer import PerceptionTransformer +from vis4d.op.fpp.fpn import FPN +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.bevformer.data import ( + CONN_NUSC_BBOX_3D_TEST, + CONN_NUSC_DET3D_EVAL, + get_nusc_cfg, + nuscenes_class_map, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for BEVFormer on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="bevformer_tiny") + + # Hyper Parameters + params = ExperimentParameters() + params.samples_per_gpu = 1 + params.workers_per_gpu = 4 + params.lr = 2e-4 + params.num_epochs = 24 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/nuscenes" + version = "v1.0-trainval" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_nusc_cfg( + data_root=data_root, + version=version, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + scale_factor=0.5, + style="pytorch", + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", trainable_layers=3, pretrained=True + ) + + config.model = class_config( + BEVFormer, + basemodel=basemodel, + fpn=class_config( + FPN, + in_channels_list=[2048], + out_channels=256, + extra_blocks=None, + start_index=5, + ), + pts_bbox_head=class_config( + BEVFormerHead, + transformer=class_config( + PerceptionTransformer, + encoder=class_config( + BEVFormerEncoder, + layer=class_config( + BEVFormerEncoderLayer, + cross_attn=class_config( + SpatialCrossAttention, + deformable_attention=class_config( + MSDeformableAttention3D, + num_levels=1, + ), + ), + ), + num_layers=3, + ), + ), + bev_h=50, + bev_w=50, + ), + weights="https://github.com/zhiqi-li/storage/releases/download/v1.0/bevformer_tiny_epoch_24.pth", # pylint: disable=line-too-long + ) + + config.loss = None + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config(AdamW, lr=params.lr, weight_decay=0.01), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=1.0 / 3, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(CosineAnnealingLR, T_max=params.num_epochs), + ), + ], + param_groups=[{"custom_keys": ["basemodel"], "lr_mult": 0.1}], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = None + + config.test_data_connector = class_config( + MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + NuScenesDet3DEvaluator, + data_root=data_root, + version=version, + split=test_split, + class_map=nuscenes_class_map, + velocity_thres=0.2, + ), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.gradient_clip_val = 35 + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/bevformer/bevformer_vis.py b/vis4d/zoo/bevformer/bevformer_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..75a98eeac13b4015821e3d0ef2ea56485d34170b --- /dev/null +++ b/vis4d/zoo/bevformer/bevformer_vis.py @@ -0,0 +1,63 @@ +"""BEVFormer Visualizaion for NuScenes Example.""" + +from __future__ import annotations + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig +from vis4d.engine.callbacks import VisualizerCallback +from vis4d.engine.connectors import MultiSensorCallbackConnector +from vis4d.vis.image.bbox3d_visualizer import MultiCameraBBox3DVisualizer +from vis4d.zoo.base import get_default_callbacks_cfg +from vis4d.zoo.bevformer.bevformer_base import ( + get_config as get_bevformer_config, +) +from vis4d.zoo.bevformer.data import ( + CONN_NUSC_BBOX_3D_VIS, + NUSC_CAMERAS, + nuscenes_class_map, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for BEVFormer on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_bevformer_config().ref_mode() + + config.experiment_name = "bevformer_vis" + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config( + MultiCameraBBox3DVisualizer, + cat_mapping=nuscenes_class_map, + width=2, + camera_near_clip=0.15, + cameras=NUSC_CAMERAS, + vis_freq=1, + plot_trajectory=False, + ), + output_dir=config.output_dir, + test_connector=class_config( + MultiSensorCallbackConnector, + key_mapping=CONN_NUSC_BBOX_3D_VIS, + ), + ) + ) + + config.callbacks = callbacks + + return config.value_mode() diff --git a/vis4d/zoo/bevformer/data.py b/vis4d/zoo/bevformer/data.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf1642720f4c19e2be2353e27c79c05e3614d66 --- /dev/null +++ b/vis4d/zoo/bevformer/data.py @@ -0,0 +1,199 @@ +"""BEVFormer NuScenes data config.""" + +from __future__ import annotations + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.loader import multi_sensor_collate +from vis4d.data.transforms import compose +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.pad import PadImages +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeImages, + ResizeIntrinsics, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.engine.connectors import data_key, pred_key +from vis4d.zoo.base import get_inference_dataloaders_cfg +from vis4d.zoo.base.datasets.nuscenes import ( + get_nusc_mini_val_cfg, + get_nusc_val_cfg, +) + +nuscenes_class_map = { + "car": 0, + "truck": 1, + "construction_vehicle": 2, + "bus": 3, + "trailer": 4, + "barrier": 5, + "motorcycle": 6, + "bicycle": 7, + "pedestrian": 8, + "traffic_cone": 9, +} + +NUSC_SENSORS = [ + "LIDAR_TOP", + "CAM_FRONT", + "CAM_FRONT_RIGHT", + "CAM_FRONT_LEFT", + "CAM_BACK", + "CAM_BACK_LEFT", + "CAM_BACK_RIGHT", +] + +NUSC_CAMERAS = [ + "CAM_FRONT", + "CAM_FRONT_RIGHT", + "CAM_FRONT_LEFT", + "CAM_BACK", + "CAM_BACK_LEFT", + "CAM_BACK_RIGHT", +] + +CONN_NUSC_BBOX_3D_TEST = { + "images": data_key(K.images, sensors=NUSC_CAMERAS), + "can_bus": "can_bus", + "scene_names": K.sequence_names, + "cam_intrinsics": data_key(K.intrinsics, sensors=NUSC_CAMERAS), + "cam_extrinsics": data_key(K.extrinsics, sensors=NUSC_CAMERAS), + "lidar_extrinsics": data_key(K.extrinsics, sensors=["LIDAR_TOP"]), +} + +CONN_NUSC_BBOX_3D_VIS = { + "images": data_key(K.original_images, sensors=NUSC_CAMERAS), + "image_names": data_key(K.sample_names, sensors=NUSC_CAMERAS), + "boxes3d": pred_key("boxes_3d"), + "intrinsics": data_key(K.intrinsics, sensors=NUSC_CAMERAS), + "extrinsics": data_key(K.extrinsics, sensors=NUSC_CAMERAS), + "scores": pred_key("scores_3d"), + "class_ids": pred_key("class_ids"), + "sequence_names": data_key(K.sequence_names), +} + +CONN_NUSC_DET3D_EVAL = { + "tokens": data_key("token"), + "boxes_3d": pred_key("boxes_3d"), + "velocities": pred_key("velocities"), + "class_ids": pred_key("class_ids"), + "scores_3d": pred_key("scores_3d"), +} + + +def get_test_dataloader( + test_dataset: ConfigDict, + shape: tuple[int, int], + mean: list[float], + std: list[float], + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default test dataloader for nuScenes tracking.""" + test_transforms = [ + class_config( + GenResizeParameters, + shape=shape, + keep_ratio=True, + sensors=NUSC_CAMERAS, + ), + class_config(ResizeImages, sensors=NUSC_CAMERAS), + class_config(ResizeIntrinsics, sensors=NUSC_CAMERAS), + class_config( + NormalizeImages, mean=mean, std=std, sensors=NUSC_CAMERAS + ), + ] + + test_preprocess_cfg = class_config(compose, transforms=test_transforms) + + test_batch_transforms = [ + class_config(PadImages, sensors=NUSC_CAMERAS), + class_config(ToTensor, sensors=NUSC_SENSORS), + ] + + test_batchprocess_cfg = class_config( + compose, transforms=test_batch_transforms + ) + + test_dataset_cfg = class_config( + DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + video_based_inference=True, + batchprocess_cfg=test_batchprocess_cfg, + collate_fn=multi_sensor_collate, + sensors=NUSC_SENSORS, + ) + + +def get_nusc_cfg( + data_root: str = "data/nuscenes", + version: str = "v1.0-trainval", + train_split: str = "train", + test_split: str = "val", + data_backend: None | ConfigDict = None, + scale_factor: float = 1.0, + style: str = "caffe", + samples_per_gpu: int = 1, + workers_per_gpu: int = 4, +) -> DataConfig: + """Get the default config for nuScenes tracking.""" + data = DataConfig() + + shape = (int(900 * scale_factor), int(1600 * scale_factor)) + + if style == "pytorch": + mean = [123.675, 116.28, 103.53] + std = [58.395, 57.12, 57.375] + image_channel_mode = "RGB" + elif style == "caffe": + mean = [103.530, 116.280, 123.675] + std = [1.0, 1.0, 1.0] + image_channel_mode = "BGR" + else: + raise ValueError(f"Unknown style {style}") + + if version == "v1.0-mini": # pragma: no cover + assert train_split == "mini_train" + assert test_split == "mini_val" + test_dataset = get_nusc_mini_val_cfg( + data_root=data_root, + image_channel_mode=image_channel_mode, + data_backend=data_backend, + cached_file_path=f"{data_root}/bevformer_mini_val.pkl", + ) + elif version == "v1.0-trainval": + assert train_split == "train" + assert test_split == "val" + test_dataset = get_nusc_val_cfg( + data_root=data_root, + image_channel_mode=image_channel_mode, + data_backend=data_backend, + cached_file_path=f"{data_root}/bevformer_val.pkl", + ) + else: + # TODO: Add support for v1.0-test + raise ValueError(f"Unknown version {version}") + + # TODO: Add train dataloader + data.train_dataloader = None + + data.test_dataloader = get_test_dataloader( + test_dataset, + shape, + mean, + std, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + return data diff --git a/vis4d/zoo/cc_3dt/__init__.py b/vis4d/zoo/cc_3dt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9ba88f85d296f5fd6972145e5e3107cfd942f47 --- /dev/null +++ b/vis4d/zoo/cc_3dt/__init__.py @@ -0,0 +1,15 @@ +"""CC-3DT Model Zoo.""" + +from . import ( + cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc, + cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc, + cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc, + cc_3dt_nusc_vis, +) + +AVAILABLE_MODELS = { + "cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc": cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc, + "cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc": cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc, + "cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc": cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc, # pylint: disable=line-too-long + "cc_3dt_nusc_vis": cc_3dt_nusc_vis, +} diff --git a/vis4d/zoo/cc_3dt/cc_3dt_bevformer_base_velo_lstm_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_bevformer_base_velo_lstm_nusc.py new file mode 100644 index 0000000000000000000000000000000000000000..ddf9740fcd307a6e9eb6317153c40e7f4a815c37 --- /dev/null +++ b/vis4d/zoo/cc_3dt/cc_3dt_bevformer_base_velo_lstm_nusc.py @@ -0,0 +1,113 @@ +# pylint: disable=duplicate-code +"""CC-3DT with BEV detector on nuScenes.""" +from __future__ import annotations + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig, ExperimentConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.nuscenes import NuScenes +from vis4d.data.datasets.nuscenes_detection import NuScenesDetection +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.connectors import MultiSensorDataConnector, data_key +from vis4d.model.motion.velo_lstm import VeloLSTM +from vis4d.model.track3d.cc_3dt import CC3DT +from vis4d.op.base import ResNet +from vis4d.op.track3d.cc_3dt import CC3DTrackAssociation +from vis4d.state.track3d.cc_3dt import CC3DTrackGraph +from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc import ( + get_config as get_velo_lstm_cfg, +) +from vis4d.zoo.cc_3dt.data import CONN_NUSC_BBOX_3D_TEST, get_test_dataloader + +CONN_NUSC_BBOX_3D_TEST = { + "images_list": data_key(K.images, sensors=NuScenes.CAMERAS), + "images_hw": data_key(K.original_hw, sensors=NuScenes.CAMERAS), + "intrinsics_list": data_key(K.intrinsics, sensors=NuScenes.CAMERAS), + "extrinsics_list": data_key(K.extrinsics, sensors=NuScenes.CAMERAS), + "frame_ids": K.frame_ids, + "pred_boxes3d": data_key("pred_boxes3d", sensors=["LIDAR_TOP"]), + "pred_boxes3d_classes": data_key( + "pred_boxes3d_classes", sensors=["LIDAR_TOP"] + ), + "pred_boxes3d_scores": data_key( + "pred_boxes3d_scores", sensors=["LIDAR_TOP"] + ), + "pred_boxes3d_velocities": data_key( + "pred_boxes3d_velocities", sensors=["LIDAR_TOP"] + ), +} + + +def get_config() -> ExperimentConfig: + """Returns the config dict for CC-3DT on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_velo_lstm_cfg().ref_mode() + + config.experiment_name = "cc_3dt_bevformer_base_velo_lstm_nusc" + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + config.pure_detection = "" + + data = DataConfig() + + data.train_dataloader = None + + test_dataset = class_config( + NuScenesDetection, + data_root="data/nuscenes", + version="v1.0-trainval", + split="val", + keys_to_load=[K.images, K.original_images, K.boxes3d], + data_backend=class_config(HDF5Backend), + pure_detection=config.pure_detection, + cache_as_binary=True, + cached_file_path="data/nuscenes/val.pkl", + ) + + data.test_dataloader = get_test_dataloader( + test_dataset=test_dataset, samples_per_gpu=1, workers_per_gpu=4 + ) + + config.data = data + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3 + ) + + track_graph = class_config( + CC3DTrackGraph, + track=class_config( + CC3DTrackAssociation, init_score_thr=0.2, obj_score_thr=0.1 + ), + motion_model="VeloLSTM", + lstm_model=class_config(VeloLSTM, weights=config.velo_lstm_ckpt), + update_3d_score=False, + add_backdrops=False, + ) + + config.model = class_config( + CC3DT, + basemodel=basemodel, + track_graph=track_graph, + detection_range=[40, 40, 40, 50, 50, 50, 50, 50, 30, 30], + ) + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.test_data_connector = class_config( + MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST + ) + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc.py new file mode 100644 index 0000000000000000000000000000000000000000..53aa553179009485cef5eab80576cf11c7bd4421 --- /dev/null +++ b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc.py @@ -0,0 +1,200 @@ +# pylint: disable=duplicate-code +"""CC-3DT with Faster-RCNN ResNet-101 detector using KF3D motion model.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.datasets.nuscenes import nuscenes_class_map +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + MultiSensorDataConnector, +) +from vis4d.eval.nuscenes import ( + NuScenesDet3DEvaluator, + NuScenesTrack3DEvaluator, +) +from vis4d.op.base import ResNet +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.cc_3dt.data import ( + CONN_NUSC_BBOX_3D_TEST, + CONN_NUSC_DET3D_EVAL, + CONN_NUSC_TRACK3D_EVAL, + get_nusc_cfg, +) +from vis4d.zoo.cc_3dt.model import CONN_BBOX_3D_TRAIN, get_cc_3dt_cfg + + +def get_config() -> ExperimentConfig: + """Returns the config dict for cc-3dt on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc") + + # Hyper Parameters + params = ExperimentParameters() + params.samples_per_gpu = 4 + params.workers_per_gpu = 4 + params.lr = 0.01 + params.num_epochs = 24 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/nuscenes" + version = "v1.0-trainval" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_nusc_cfg( + data_root=data_root, + version=version, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_cc_3dt_cfg( + num_classes=len(nuscenes_class_map), basemodel=basemodel, fps=2 + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config(LinearLR, start_factor=0.1, total_iters=1000), + end=1000, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[16, 22], gamma=0.1), + ), + ], + param_groups=[ + { + "custom_keys": [ + "faster_rcnn_head.rpn_head.rpn_cls.weight", + "faster_rcnn_head.rpn_head.rpn_box.weight", + "faster_rcnn_head.roi_head.fc_cls.weight", + "faster_rcnn_head.roi_head.fc_reg.weight", + "bbox_3d_head.dep_convs.0.weight", + "bbox_3d_head.dep_convs.1.weight", + "bbox_3d_head.dep_convs.2.weight", + "bbox_3d_head.dep_convs.3.weight", + "bbox_3d_head.dim_convs.0.weight", + "bbox_3d_head.dim_convs.1.weight", + "bbox_3d_head.dim_convs.2.weight", + "bbox_3d_head.dim_convs.3.weight", + "bbox_3d_head.rot_convs.0.weight" + "bbox_3d_head.rot_convs.1.weight", + "bbox_3d_head.rot_convs.2.weight", + "bbox_3d_head.rot_convs.3.weight", + "bbox_3d_head.cen_2d_convs.0.weight", + "bbox_3d_head.cen_2d_convs.1.weight", + "bbox_3d_head.cen_2d_convs.2.weight", + "bbox_3d_head.cen_2d_convs.3.weight", + "bbox_3d_head.fc_dep.weight", + "bbox_3d_head.fc_dep_uncer.weight", + "bbox_3d_head.fc_dim.weight", + "bbox_3d_head.fc_rot.weight", + "bbox_3d_head.fc_cen_2d.weight", + ], + "lr_mult": 10.0, + } + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_3D_TRAIN + ) + + config.test_data_connector = class_config( + MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + NuScenesDet3DEvaluator, + data_root=data_root, + version=version, + split=test_split, + ), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL + ), + ) + ) + + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config(NuScenesTrack3DEvaluator), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_TRACK3D_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.gradient_clip_val = 10 + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_pure_det_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_pure_det_nusc.py new file mode 100644 index 0000000000000000000000000000000000000000..6653033cdde78c86b51a18b029801b5b1bb30445 --- /dev/null +++ b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_pure_det_nusc.py @@ -0,0 +1,88 @@ +"""CC-3DT with Faster-RCNN ResNet-101 detector generating pure detection.""" + +from __future__ import annotations + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig +from vis4d.data.datasets.nuscenes import NuScenes, nuscenes_class_map +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback +from vis4d.engine.connectors import MultiSensorCallbackConnector +from vis4d.eval.nuscenes import NuScenesDet3DEvaluator +from vis4d.op.base import ResNet +from vis4d.zoo.base import get_default_callbacks_cfg +from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc import ( + get_config as get_kf3d_config, +) +from vis4d.zoo.cc_3dt.data import CONN_NUSC_DET3D_EVAL, get_nusc_cfg +from vis4d.zoo.cc_3dt.model import get_cc_3dt_cfg + + +def get_config() -> ExperimentConfig: + """Get config.""" + config = get_kf3d_config().ref_mode() + + config.experiment_name = "cc_3dt_frcnn_r101_fpn_pure_det_nusc" + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/nuscenes" + version = "v1.0-trainval" + train_split = "train" + test_split = "train" + + data_backend = class_config(HDF5Backend) + + config.data = get_nusc_cfg( + data_root=data_root, + version=version, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3 + ) + + config.model, _ = get_cc_3dt_cfg( + num_classes=len(nuscenes_class_map), + basemodel=basemodel, + fps=2, + pure_det=True, + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + NuScenesDet3DEvaluator, + data_root=data_root, + version=version, + split=test_split, + save_only=True, + ), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + MultiSensorCallbackConnector, + key_mapping=CONN_NUSC_DET3D_EVAL, + sensors=NuScenes.CAMERAS, + ), + ) + ) + + config.callbacks = callbacks + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc.py new file mode 100644 index 0000000000000000000000000000000000000000..343ede71b15887f9c5c4cd9a148c3a8e1f2bce89 --- /dev/null +++ b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc.py @@ -0,0 +1,46 @@ +"""CC-3DT inference with Faster-RCNN ResNet-101 detector using VeloLSTM.""" + +from __future__ import annotations + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig +from vis4d.data.datasets.nuscenes import nuscenes_class_map +from vis4d.model.motion.velo_lstm import VeloLSTM +from vis4d.op.base import ResNet +from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc import ( + get_config as get_kf3d_cfg, +) +from vis4d.zoo.cc_3dt.model import get_cc_3dt_cfg + + +def get_config() -> ExperimentConfig: + """Returns the config dict for cc-3dt on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_kf3d_cfg().ref_mode() + + config.experiment_name = "cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc" + + config.velo_lstm_ckpt = "" + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3 + ) + + config.model, _ = get_cc_3dt_cfg( + num_classes=len(nuscenes_class_map), + basemodel=basemodel, + motion_model="VeloLSTM", + lstm_model=class_config(VeloLSTM, weights=config.velo_lstm_ckpt), + fps=2, + ) + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc.py new file mode 100644 index 0000000000000000000000000000000000000000..169852be7e143b4cea7a7cef9c23c11a0d4f6c0e --- /dev/null +++ b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc.py @@ -0,0 +1,200 @@ +# pylint: disable=duplicate-code +"""CC-3DT with Faster-RCNN ResNet-50 detector using KF3D motion model.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.datasets.nuscenes import nuscenes_class_map +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + MultiSensorDataConnector, +) +from vis4d.eval.nuscenes import ( + NuScenesDet3DEvaluator, + NuScenesTrack3DEvaluator, +) +from vis4d.op.base import ResNet +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.cc_3dt.data import ( + CONN_NUSC_BBOX_3D_TEST, + CONN_NUSC_DET3D_EVAL, + CONN_NUSC_TRACK3D_EVAL, + get_nusc_cfg, +) +from vis4d.zoo.cc_3dt.model import CONN_BBOX_3D_TRAIN, get_cc_3dt_cfg + + +def get_config() -> ExperimentConfig: + """Returns the config dict for cc-3dt on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc") + + # Hyper Parameters + params = ExperimentParameters() + params.samples_per_gpu = 4 + params.workers_per_gpu = 4 + params.lr = 0.01 + params.num_epochs = 12 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/nuscenes" + version = "v1.0-trainval" + train_split = "train" + test_split = "val" + + data_backend = class_config(HDF5Backend) + + config.data = get_nusc_cfg( + data_root=data_root, + version=version, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_cc_3dt_cfg( + num_classes=len(nuscenes_class_map), basemodel=basemodel, fps=2 + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config(LinearLR, start_factor=0.1, total_iters=1000), + end=1000, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + param_groups=[ + { + "custom_keys": [ + "faster_rcnn_head.rpn_head.rpn_cls.weight", + "faster_rcnn_head.rpn_head.rpn_box.weight", + "faster_rcnn_head.roi_head.fc_cls.weight", + "faster_rcnn_head.roi_head.fc_reg.weight", + "bbox_3d_head.dep_convs.0.weight", + "bbox_3d_head.dep_convs.1.weight", + "bbox_3d_head.dep_convs.2.weight", + "bbox_3d_head.dep_convs.3.weight", + "bbox_3d_head.dim_convs.0.weight", + "bbox_3d_head.dim_convs.1.weight", + "bbox_3d_head.dim_convs.2.weight", + "bbox_3d_head.dim_convs.3.weight", + "bbox_3d_head.rot_convs.0.weight" + "bbox_3d_head.rot_convs.1.weight", + "bbox_3d_head.rot_convs.2.weight", + "bbox_3d_head.rot_convs.3.weight", + "bbox_3d_head.cen_2d_convs.0.weight", + "bbox_3d_head.cen_2d_convs.1.weight", + "bbox_3d_head.cen_2d_convs.2.weight", + "bbox_3d_head.cen_2d_convs.3.weight", + "bbox_3d_head.fc_dep.weight", + "bbox_3d_head.fc_dep_uncer.weight", + "bbox_3d_head.fc_dim.weight", + "bbox_3d_head.fc_rot.weight", + "bbox_3d_head.fc_cen_2d.weight", + ], + "lr_mult": 10.0, + } + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_3D_TRAIN + ) + + config.test_data_connector = class_config( + MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + NuScenesDet3DEvaluator, + data_root=data_root, + version=version, + split=test_split, + ), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL + ), + ) + ) + + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config(NuScenesTrack3DEvaluator), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_TRACK3D_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.gradient_clip_val = 10 + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/cc_3dt_nusc_test.py b/vis4d/zoo/cc_3dt/cc_3dt_nusc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed480340feab7908609a4b2304466776e49b768 --- /dev/null +++ b/vis4d/zoo/cc_3dt/cc_3dt_nusc_test.py @@ -0,0 +1,106 @@ +# pylint: disable=duplicate-code +"""CC-3DT with BEV detector on nuScenes.""" +from __future__ import annotations + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig, ExperimentConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.nuscenes_detection import NuScenesDetection +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback +from vis4d.engine.connectors import CallbackConnector +from vis4d.eval.nuscenes import ( + NuScenesDet3DEvaluator, + NuScenesTrack3DEvaluator, +) +from vis4d.zoo.base import get_default_callbacks_cfg +from vis4d.zoo.cc_3dt.cc_3dt_bevformer_base_velo_lstm_nusc import ( + get_config as get_cc_3dt_config, +) +from vis4d.zoo.cc_3dt.data import ( + CONN_NUSC_DET3D_EVAL, + CONN_NUSC_TRACK3D_EVAL, + get_test_dataloader, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for CC-3DT on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_cc_3dt_config().ref_mode() + + config.experiment_name = "cc_3dt_nusc_test" + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + config.pure_detection = "" + + data = DataConfig() + + data.train_dataloader = None + + test_dataset = class_config( + NuScenesDetection, + data_root="data/nuscenes", + version="v1.0-test", + split="test", + keys_to_load=[K.images, K.original_images], + data_backend=class_config(HDF5Backend), + pure_detection=config.pure_detection, + cache_as_binary=True, + cached_file_path="data/nuscenes/test.pkl", + ) + + data.test_dataloader = get_test_dataloader( + test_dataset=test_dataset, samples_per_gpu=1, workers_per_gpu=4 + ) + + config.data = data + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + NuScenesDet3DEvaluator, + data_root="data/nuscenes", + version="v1.0-test", + split="test", + save_only=True, + ), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL + ), + ) + ) + + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config(NuScenesTrack3DEvaluator), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_TRACK3D_EVAL + ), + ) + ) + + config.callbacks = callbacks + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/cc_3dt_nusc_vis.py b/vis4d/zoo/cc_3dt/cc_3dt_nusc_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..f629cd412754ca35cbfe2e1a4578a209e88c27ca --- /dev/null +++ b/vis4d/zoo/cc_3dt/cc_3dt_nusc_vis.py @@ -0,0 +1,77 @@ +"""CC-3DT Visualizaion for NuScenes Example.""" + +from __future__ import annotations + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig +from vis4d.data.datasets.nuscenes import NuScenes, nuscenes_class_map +from vis4d.engine.callbacks import VisualizerCallback +from vis4d.engine.connectors import MultiSensorCallbackConnector +from vis4d.vis.image.bbox3d_visualizer import MultiCameraBBox3DVisualizer +from vis4d.vis.image.bev_visualizer import BEVBBox3DVisualizer +from vis4d.zoo.base import get_default_callbacks_cfg +from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc import ( + get_config as get_cc_3dt_config, +) +from vis4d.zoo.cc_3dt.data import ( + CONN_NUSC_BBOX_3D_VIS, + CONN_NUSC_BEV_BBOX_3D_VIS, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for cc-3dt on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_cc_3dt_config().ref_mode() + + config.experiment_name = "cc_3dt_nusc_vis" + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config( + MultiCameraBBox3DVisualizer, + cat_mapping=nuscenes_class_map, + width=2, + camera_near_clip=0.15, + cameras=NuScenes.CAMERAS, + vis_freq=1, + ), + output_dir=config.output_dir, + save_prefix="boxes3d", + test_connector=class_config( + MultiSensorCallbackConnector, + key_mapping=CONN_NUSC_BBOX_3D_VIS, + ), + ) + ) + + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BEVBBox3DVisualizer, width=2, vis_freq=1), + output_dir=config.output_dir, + save_prefix="bev", + test_connector=class_config( + MultiSensorCallbackConnector, + key_mapping=CONN_NUSC_BEV_BBOX_3D_VIS, + ), + ) + ) + + config.callbacks = callbacks + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/cc_3dt_pp_kf3d.py b/vis4d/zoo/cc_3dt/cc_3dt_pp_kf3d.py new file mode 100644 index 0000000000000000000000000000000000000000..c5eed3563ebba92b3258c4a632d58861b340e047 --- /dev/null +++ b/vis4d/zoo/cc_3dt/cc_3dt_pp_kf3d.py @@ -0,0 +1,175 @@ +# pylint: disable=duplicate-code +"""CC-3DT++ on nuScenes.""" +from __future__ import annotations + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig, ExperimentConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.nuscenes import NuScenes +from vis4d.data.datasets.nuscenes_detection import NuScenesDetection +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback +from vis4d.engine.connectors import ( + CallbackConnector, + MultiSensorDataConnector, + data_key, +) +from vis4d.eval.nuscenes import ( + NuScenesDet3DEvaluator, + NuScenesTrack3DEvaluator, +) +from vis4d.model.track3d.cc_3dt import CC3DT +from vis4d.op.base import ResNet +from vis4d.op.track3d.cc_3dt import CC3DTrackAssociation +from vis4d.state.track3d.cc_3dt import CC3DTrackGraph +from vis4d.zoo.base import get_default_callbacks_cfg +from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc import ( + get_config as get_kf3d_cfg, +) +from vis4d.zoo.cc_3dt.data import ( + CONN_NUSC_DET3D_EVAL, + CONN_NUSC_TRACK3D_EVAL, + get_test_dataloader, +) + +CONN_NUSC_BBOX_3D_TEST = { + "images_list": data_key(K.images, sensors=NuScenes.CAMERAS), + "images_hw": data_key(K.original_hw, sensors=NuScenes.CAMERAS), + "intrinsics_list": data_key(K.intrinsics, sensors=NuScenes.CAMERAS), + "extrinsics_list": data_key(K.extrinsics, sensors=NuScenes.CAMERAS), + "frame_ids": K.frame_ids, + "pred_boxes3d": data_key("pred_boxes3d", sensors=["LIDAR_TOP"]), + "pred_boxes3d_classes": data_key( + "pred_boxes3d_classes", sensors=["LIDAR_TOP"] + ), + "pred_boxes3d_scores": data_key( + "pred_boxes3d_scores", sensors=["LIDAR_TOP"] + ), + "pred_boxes3d_velocities": data_key( + "pred_boxes3d_velocities", sensors=["LIDAR_TOP"] + ), +} + + +def get_config() -> ExperimentConfig: + """Returns the config dict for CC-3DT on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_kf3d_cfg().ref_mode() + + config.experiment_name = "cc_3dt_pp_kf3d_nusc" + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + config.pure_detection = "" + + data_root = "data/nuscenes" + version = "v1.0-trainval" + test_split = "val" + + data = DataConfig() + + data.train_dataloader = None + + test_dataset = class_config( + NuScenesDetection, + data_root=data_root, + version=version, + split=test_split, + keys_to_load=[K.images, K.original_images], + data_backend=class_config(HDF5Backend), + pure_detection=config.pure_detection, + cache_as_binary=True, + cached_file_path=f"{data_root}/val.pkl", + ) + + data.test_dataloader = get_test_dataloader( + test_dataset=test_dataset, samples_per_gpu=1, workers_per_gpu=1 + ) + + config.data = data + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3 + ) + + track_graph = class_config( + CC3DTrackGraph, + track=class_config( + CC3DTrackAssociation, + init_score_thr=0.2, + obj_score_thr=0.1, + match_score_thr=0.3, + nms_class_iou_thr=0.3, + bbox_affinity_weight=0.75, + with_velocities=True, + ), + update_3d_score=False, + use_velocities=True, + add_backdrops=False, + ) + + config.model = class_config( + CC3DT, + basemodel=basemodel, + track_graph=track_graph, + detection_range=[40, 40, 40, 50, 50, 50, 50, 50, 30, 30], + ) + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.test_data_connector = class_config( + MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg(config.output_dir) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + NuScenesDet3DEvaluator, + data_root=data_root, + version=version, + split=test_split, + ), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL + ), + ) + ) + + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + NuScenesTrack3DEvaluator, metadata=("use_camera", "use_radar") + ), + save_predictions=True, + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_NUSC_TRACK3D_EVAL + ), + ) + ) + + config.callbacks = callbacks + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/data.py b/vis4d/zoo/cc_3dt/data.py new file mode 100644 index 0000000000000000000000000000000000000000..3ed579cd778160145f7248c9829ca5aa468d0604 --- /dev/null +++ b/vis4d/zoo/cc_3dt/data.py @@ -0,0 +1,240 @@ +"""CC-3DT NuScenes data config.""" + +from __future__ import annotations + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.nuscenes import NuScenes +from vis4d.data.loader import multi_sensor_collate +from vis4d.data.reference import MultiViewDataset, UniformViewSampler +from vis4d.data.transforms import RandomApply, compose +from vis4d.data.transforms.flip import ( + FlipBoxes2D, + FlipBoxes3D, + FlipImages, + FlipIntrinsics, +) +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.pad import PadImages +from vis4d.data.transforms.post_process import PostProcessBoxes2D +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeBoxes2D, + ResizeImages, + ResizeIntrinsics, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.engine.connectors import data_key, pred_key +from vis4d.zoo.base import ( + get_inference_dataloaders_cfg, + get_train_dataloader_cfg, +) +from vis4d.zoo.base.datasets.nuscenes import ( + get_nusc_mini_val_cfg, + get_nusc_mono_mini_train_cfg, + get_nusc_mono_train_cfg, + get_nusc_train_cfg, + get_nusc_val_cfg, +) + +CONN_NUSC_DET3D_EVAL = { + "tokens": data_key("token"), + "boxes_3d": pred_key("boxes_3d"), + "velocities": pred_key("velocities"), + "class_ids": pred_key("class_ids"), + "scores_3d": pred_key("scores_3d"), +} + +CONN_NUSC_TRACK3D_EVAL = { + "tokens": data_key("token"), + "boxes_3d": pred_key("boxes_3d"), + "velocities": pred_key("velocities"), + "class_ids": pred_key("class_ids"), + "scores_3d": pred_key("scores_3d"), + "track_ids": pred_key("track_ids"), +} + +CONN_NUSC_BBOX_3D_TEST = { + "images": data_key(K.images, sensors=NuScenes.CAMERAS), + "images_hw": data_key(K.original_hw, sensors=NuScenes.CAMERAS), + "intrinsics": data_key(K.intrinsics, sensors=NuScenes.CAMERAS), + "extrinsics": data_key(K.extrinsics, sensors=NuScenes.CAMERAS), + "frame_ids": K.frame_ids, +} + +CONN_NUSC_BBOX_3D_VIS = { + "images": data_key(K.original_images, sensors=NuScenes.CAMERAS), + "image_names": data_key(K.sample_names, sensors=NuScenes.CAMERAS), + "boxes3d": pred_key("boxes_3d"), + "intrinsics": data_key(K.intrinsics, sensors=NuScenes.CAMERAS), + "extrinsics": data_key(K.extrinsics, sensors=NuScenes.CAMERAS), + "scores": pred_key("scores_3d"), + "class_ids": pred_key("class_ids"), + "track_ids": pred_key("track_ids"), + "sequence_names": data_key(K.sequence_names), +} + +CONN_NUSC_BEV_BBOX_3D_VIS = { + "sample_names": data_key(K.sample_names, sensors=["LIDAR_TOP"]), + "boxes3d": pred_key("boxes_3d"), + "extrinsics": data_key(K.extrinsics, sensors=["LIDAR_TOP"]), + "track_ids": pred_key("track_ids"), + "sequence_names": data_key(K.sequence_names), +} + + +def get_train_dataloader( + train_dataset: ConfigDict, samples_per_gpu: int, workers_per_gpu: int +) -> ConfigDict: + """Get the default train dataloader for nuScenes tracking.""" + train_dataset_cfg = class_config( + MultiViewDataset, + dataset=train_dataset, + sampler=class_config(UniformViewSampler, scope=2, num_ref_samples=1), + ) + + preprocess_transforms = [ + class_config(GenResizeParameters, shape=(900, 1600), keep_ratio=True), + class_config(ResizeImages), + class_config(ResizeBoxes2D), + ] + + preprocess_transforms.append( + class_config( + RandomApply, + transforms=[ + class_config(FlipImages), + class_config(FlipIntrinsics), + class_config(FlipBoxes2D), + class_config(FlipBoxes3D), + ], + probability=0.5, + ) + ) + + preprocess_transforms.append(class_config(PostProcessBoxes2D)) + + train_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + train_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages), + class_config(NormalizeImages), + class_config(ToTensor), + ], + ) + + return get_train_dataloader_cfg( + datasets_cfg=class_config( + DataPipe, + datasets=train_dataset_cfg, + preprocess_fn=train_preprocess_cfg, + ), + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + batchprocess_cfg=train_batchprocess_cfg, + ) + + +def get_test_dataloader( + test_dataset: ConfigDict, samples_per_gpu: int, workers_per_gpu: int +) -> ConfigDict: + """Get the default test dataloader for nuScenes tracking.""" + test_transforms = [ + class_config( + GenResizeParameters, + shape=(900, 1600), + keep_ratio=True, + sensors=NuScenes.CAMERAS, + ), + class_config(ResizeImages, sensors=NuScenes.CAMERAS), + class_config(ResizeIntrinsics, sensors=NuScenes.CAMERAS), + ] + + test_preprocess_cfg = class_config(compose, transforms=test_transforms) + + test_batch_transforms = [ + class_config(PadImages, sensors=NuScenes.CAMERAS), + class_config(NormalizeImages, sensors=NuScenes.CAMERAS), + class_config(ToTensor, sensors=NuScenes.SENSORS), + ] + + test_batchprocess_cfg = class_config( + compose, transforms=test_batch_transforms + ) + + test_dataset_cfg = class_config( + DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + video_based_inference=True, + batchprocess_cfg=test_batchprocess_cfg, + collate_fn=multi_sensor_collate, + sensors=NuScenes.SENSORS, + ) + + +def get_nusc_cfg( + data_root: str = "data/nuscenes", + version: str = "v1.0-trainval", + train_split: str = "train", + test_split: str = "val", + data_backend: None | ConfigDict = None, + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> DataConfig: + """Get the default config for nuScenes tracking.""" + data = DataConfig() + + if version == "v1.0-mini": # pragma: no cover + assert train_split == "mini_train" + assert test_split == "mini_val" + train_dataset = get_nusc_mono_mini_train_cfg( + data_root=data_root, data_backend=data_backend + ) + test_dataset = get_nusc_mini_val_cfg( + data_root=data_root, data_backend=data_backend + ) + elif version == "v1.0-trainval": + assert train_split == "train" + train_dataset = get_nusc_mono_train_cfg( + data_root=data_root, data_backend=data_backend + ) + + if test_split == "val": + test_dataset = get_nusc_val_cfg( + data_root=data_root, data_backend=data_backend + ) + elif test_split == "train": + test_dataset = get_nusc_train_cfg( + data_root=data_root, + skip_empty_samples=False, + keys_to_load=[K.images, K.original_images, K.boxes3d], + data_backend=data_backend, + ) + else: + # TODO: Add support for v1.0-test + raise ValueError(f"Unknown version {version}") + + data.train_dataloader = get_train_dataloader( + train_dataset=train_dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + data.test_dataloader = get_test_dataloader( + test_dataset, samples_per_gpu=1, workers_per_gpu=1 + ) + + return data diff --git a/vis4d/zoo/cc_3dt/model.py b/vis4d/zoo/cc_3dt/model.py new file mode 100644 index 0000000000000000000000000000000000000000..39d6c9b1935a94b6607aa8cf359fc0174dc25748 --- /dev/null +++ b/vis4d/zoo/cc_3dt/model.py @@ -0,0 +1,171 @@ +"""CC-3DT model config.""" + +from __future__ import annotations + +from ml_collections import ConfigDict, FieldReference + +from vis4d.config import class_config +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import LossConnector, pred_key, remap_pred_keys +from vis4d.engine.loss_module import LossModule +from vis4d.model.track3d.cc_3dt import FasterRCNNCC3DT +from vis4d.op.box.anchor import AnchorGenerator +from vis4d.op.detect3d.qd_3dt import Box3DUncertaintyLoss +from vis4d.op.detect.faster_rcnn import FasterRCNNHead +from vis4d.op.detect.rcnn import RCNNHead, RCNNLoss +from vis4d.op.detect.rpn import RPNLoss +from vis4d.op.loss.common import smooth_l1_loss +from vis4d.op.track.qdtrack import QDTrackInstanceSimilarityLoss +from vis4d.state.track3d.cc_3dt import CC3DTrackGraph +from vis4d.zoo.base import get_callable_cfg +from vis4d.zoo.base.models.faster_rcnn import ( + get_default_rcnn_box_codec_cfg, + get_default_rpn_box_codec_cfg, +) +from vis4d.zoo.base.models.qdtrack import CONN_ROI_LOSS_2D as _CONN_ROI_LOSS_2D +from vis4d.zoo.base.models.qdtrack import ( + CONN_TRACK_LOSS_2D as _CONN_TRACK_LOSS_2D, +) + +PRED_PREFIX = "qdtrack_out" + +CONN_RPN_LOSS_2D = { + "cls_outs": pred_key(f"{PRED_PREFIX}.detector_out.rpn.cls"), + "reg_outs": pred_key(f"{PRED_PREFIX}.detector_out.rpn.box"), + "target_boxes": pred_key(f"{PRED_PREFIX}.key_target_boxes"), + "images_hw": pred_key(f"{PRED_PREFIX}.key_images_hw"), +} + +CONN_ROI_LOSS_2D = remap_pred_keys(_CONN_ROI_LOSS_2D, PRED_PREFIX) + +CONN_TRACK_LOSS_2D = remap_pred_keys(_CONN_TRACK_LOSS_2D, PRED_PREFIX) + +CONN_DET_3D_LOSS = { + "pred": pred_key("detector_3d_out"), + "target": pred_key("detector_3d_target"), + "labels": pred_key("detector_3d_labels"), +} + +CONN_BBOX_3D_TRAIN = { + "images": K.images, + "images_hw": K.input_hw, + "intrinsics": K.intrinsics, + "boxes2d": K.boxes2d, + "boxes3d": K.boxes3d, + "boxes3d_classes": K.boxes3d_classes, + "boxes3d_track_ids": K.boxes3d_track_ids, + "keyframes": "keyframes", +} + + +def get_cc_3dt_cfg( + num_classes: int | FieldReference, + basemodel: ConfigDict, + pure_det: bool | FieldReference = False, + motion_model: str | FieldReference = "KF3D", + lstm_model: ConfigDict | None = None, + fps: int | FieldReference = 2, +) -> tuple[ConfigDict, ConfigDict]: + """Get CC-3DT model config. + + Args: + num_classes (int): Number of classes. + basemodel (ConfigDict): Base model config. + pure_det (bool, optional): Whether to use pure detection mode. + Defaults to False. + motion_model (str, optional): Motion model. Defaults to "KF3D". + lstm_model (ConfigDict, optional): LSTM model config. Defaults to None. + fps (int, optional): FPS. Defaults to 2. + """ + ###################################################### + ## MODEL ## + ###################################################### + anchor_generator = class_config( + AnchorGenerator, + scales=[4, 8], + ratios=[0.25, 0.5, 1.0, 2.0, 4.0], + strides=[4, 8, 16, 32, 64], + ) + + roi_head = class_config( + RCNNHead, + num_shared_convs=4, + num_classes=num_classes, + ) + + faster_rcnn_head = class_config( + FasterRCNNHead, + num_classes=num_classes, + anchor_generator=anchor_generator, + roi_head=roi_head, + ) + + track_graph = class_config( + CC3DTrackGraph, + motion_model=motion_model, + lstm_model=lstm_model, + fps=fps, + ) + + model = class_config( + FasterRCNNCC3DT, + num_classes=num_classes, + basemodel=basemodel, + faster_rcnn_head=faster_rcnn_head, + track_graph=track_graph, + pure_det=pure_det, + ) + + ###################################################### + ## LOSS ## + ###################################################### + rpn_box_encoder, _ = get_default_rpn_box_codec_cfg() + rcnn_box_encoder, _ = get_default_rcnn_box_codec_cfg() + + rpn_loss = class_config( + RPNLoss, + anchor_generator=anchor_generator, + box_encoder=rpn_box_encoder, + loss_bbox=get_callable_cfg(smooth_l1_loss, beta=1.0 / 9.0), + ) + rcnn_loss = class_config( + RCNNLoss, + box_encoder=rcnn_box_encoder, + num_classes=num_classes, + loss_bbox=get_callable_cfg(smooth_l1_loss, beta=1.0 / 9.0), + ) + + track_loss = class_config(QDTrackInstanceSimilarityLoss) + + loss = class_config( + LossModule, + losses=[ + { + "loss": rpn_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_RPN_LOSS_2D + ), + }, + { + "loss": rcnn_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_ROI_LOSS_2D + ), + "weight": 5.0, + }, + { + "loss": track_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_TRACK_LOSS_2D + ), + }, + { + "loss": class_config(Box3DUncertaintyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_DET_3D_LOSS + ), + }, + ], + ) + + return model, loss diff --git a/vis4d/zoo/cc_3dt/velo_lstm_bevformer_base_100e_nusc.py b/vis4d/zoo/cc_3dt/velo_lstm_bevformer_base_100e_nusc.py new file mode 100644 index 0000000000000000000000000000000000000000..51c566fd31ffc6557ae914a196ddda888f50a37b --- /dev/null +++ b/vis4d/zoo/cc_3dt/velo_lstm_bevformer_base_100e_nusc.py @@ -0,0 +1,150 @@ +# pylint: disable=duplicate-code +"""CC-3DT VeloLSTM for BEVFormer on nuScenes.""" +from __future__ import annotations + +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import MultiStepLR + +from vis4d.config import class_config +from vis4d.config.typing import ( + DataConfig, + ExperimentConfig, + ExperimentParameters, +) +from vis4d.data.datasets.nuscenes_trajectory import NuScenesTrajectory +from vis4d.engine.connectors import ( + DataConnector, + LossConnector, + data_key, + pred_key, +) +from vis4d.engine.loss_module import LossModule +from vis4d.model.motion.velo_lstm import VeloLSTM +from vis4d.op.motion.velo_lstm import VeloLSTMLoss +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, + get_train_dataloader_cfg, +) + +TRAJ_TRAIN = {"pred_traj": "pred_traj"} +TRAJ_LOSS = { + "loc_preds": pred_key("loc_preds"), + "loc_refines": pred_key("loc_refines"), + "gt_traj": data_key("gt_traj"), +} + + +def get_config() -> ExperimentConfig: + """Returns the config dict for VeloLSTM on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="velo_lstm_bevformer_base_100e_nusc") + + # Hyper Parameters + params = ExperimentParameters() + params.samples_per_gpu = 32 + params.workers_per_gpu = 4 + params.lr = 0.005 + params.num_epochs = 100 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data = DataConfig() + + train_dataset_cfg = class_config( + NuScenesTrajectory, + detector="cc_3dt_frcnn_r101_fpn", + data_root="data/nuscenes", + version="v1.0-trainval", + split="train", + pure_detection="./vis4d-workspace/pure_det/bevformer_base.json", + cache_as_binary=True, + cached_file_path="data/nuscenes/cc_3dt_bevformer_base_traj_train.pkl", + ) + + data.train_dataloader = get_train_dataloader_cfg( + datasets_cfg=train_dataset_cfg, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + collate_keys=["pred_traj", "gt_traj"], + ) + + data.test_dataloader = None + + config.data = data + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config(VeloLSTM) + + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(VeloLSTMLoss), + "weight": 10.0, + "connector": class_config( + LossConnector, key_mapping=TRAJ_LOSS + ), + } + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + Adam, lr=params.lr, amsgrad=True, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + MultiStepLR, milestones=[20, 40, 60, 80], gamma=0.5 + ), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=TRAJ_TRAIN + ) + + config.test_data_connector = None + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.gradient_clip_val = 3 + pl_trainer.check_val_every_n_epoch = 101 # Disable validation + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/cc_3dt/velo_lstm_frcnn_r101_fpn_100e_nusc.py b/vis4d/zoo/cc_3dt/velo_lstm_frcnn_r101_fpn_100e_nusc.py new file mode 100644 index 0000000000000000000000000000000000000000..f3752ef781b7ca2716c2670fdcdd61e4ecddea5f --- /dev/null +++ b/vis4d/zoo/cc_3dt/velo_lstm_frcnn_r101_fpn_100e_nusc.py @@ -0,0 +1,150 @@ +# pylint: disable=duplicate-code +"""CC-3DT VeloLSTM on nuScenes.""" +from __future__ import annotations + +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import MultiStepLR + +from vis4d.config import class_config +from vis4d.config.typing import ( + DataConfig, + ExperimentConfig, + ExperimentParameters, +) +from vis4d.data.data_pipe import DataPipe +from vis4d.data.datasets.nuscenes_trajectory import NuScenesTrajectory +from vis4d.engine.connectors import ( + DataConnector, + LossConnector, + data_key, + pred_key, +) +from vis4d.engine.loss_module import LossModule +from vis4d.model.motion.velo_lstm import VeloLSTM +from vis4d.op.motion.velo_lstm import VeloLSTMLoss +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, + get_train_dataloader_cfg, +) + +TRAJ_TRAIN = {"pred_traj": "pred_traj"} +TRAJ_LOSS = { + "loc_preds": pred_key("loc_preds"), + "loc_refines": pred_key("loc_refines"), + "gt_traj": data_key("gt_traj"), +} + + +def get_config() -> ExperimentConfig: + """Returns the config dict for VeloLSTM on nuScenes. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="velo_lstm_frcnn_r101_fpn_100e_nusc") + + # Hyper Parameters + params = ExperimentParameters() + params.samples_per_gpu = 32 + params.workers_per_gpu = 4 + params.lr = 0.005 + params.num_epochs = 100 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data = DataConfig() + + train_dataset_cfg = class_config( + NuScenesTrajectory, + detector="cc_3dt_frcnn_r101_fpn", + data_root="data/nuscenes", + version="v1.0-trainval", + split="train", + pure_detection="./vis4d-workspace/pure_det/cc_3dt_frcnn_r101_fpn.json", + cache_as_binary=True, + cached_file_path="data/nuscenes/cc_3dt_frcnn_r101_fpn_traj_train.pkl", + ) + + data.train_dataloader = get_train_dataloader_cfg( + datasets_cfg=class_config(DataPipe, datasets=train_dataset_cfg), + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + collate_keys=["pred_traj", "gt_traj"], + ) + + data.test_dataloader = None + + config.data = data + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config(VeloLSTM) + + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(VeloLSTMLoss), + "connector": class_config( + LossConnector, key_mapping=TRAJ_LOSS + ), + } + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + Adam, lr=params.lr, amsgrad=True, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + MultiStepLR, milestones=[20, 40, 60, 80], gamma=0.5 + ), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=TRAJ_TRAIN + ) + + config.test_data_connector = None + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.gradient_clip_val = 3 + pl_trainer.check_val_every_n_epoch = 101 # Disable validation + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/faster_rcnn/__init__.py b/vis4d/zoo/faster_rcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4aaded2cdd7db40c570a95ea559526a3b37657a8 --- /dev/null +++ b/vis4d/zoo/faster_rcnn/__init__.py @@ -0,0 +1,7 @@ +"""Faster-RCNN Model Zoo.""" + +from . import faster_rcnn_coco + +AVAILABLE_MODELS = { + "faster_rcnn_coco": faster_rcnn_coco, +} diff --git a/vis4d/zoo/faster_rcnn/faster_rcnn_coco.py b/vis4d/zoo/faster_rcnn/faster_rcnn_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..f7aa6dab689ed88a8403e3ffb08741a512ea7658 --- /dev/null +++ b/vis4d/zoo/faster_rcnn/faster_rcnn_coco.py @@ -0,0 +1,170 @@ +# pylint: disable=duplicate-code +"""Faster RCNN COCO training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.coco import COCODetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_BBOX_2D_VIS, +) +from vis4d.zoo.base.datasets.coco import ( + CONN_COCO_BBOX_EVAL, + get_coco_detection_cfg, +) +from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Faster-RCNN config dict for the coco detection task. + + This is an example that shows how to set up a training experiment for the + COCO detection task. + + Note that the high level params are exposed in the config. This allows + to easily change them from the command line. + E.g.: + >>> python -m vis4d.engine.run fit --config configs/faster_rcnn/faster_rcnn_coco.py --config.params.lr 0.001 + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="faster_rcnn_r50_fpn_coco") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 12 + params.num_classes = 80 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/coco" + train_split = "train2017" + test_split = "val2017" + + data_backend = class_config(HDF5Backend) + + config.data = get_coco_detection_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_faster_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, + key_mapping=CONN_BBOX_2D_TRAIN, + ) + + config.test_data_connector = class_config( + DataConnector, + key_mapping=CONN_BBOX_2D_TEST, + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BoundingBoxVisualizer, vis_freq=100), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + COCODetectEvaluator, data_root=data_root, split=test_split + ), + metrics_to_eval=["Det"], + test_connector=class_config( + CallbackConnector, key_mapping=CONN_COCO_BBOX_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/fcn_resnet/__init__.py b/vis4d/zoo/fcn_resnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5854217e69e0807c689cf86a5fc4c85913d09f --- /dev/null +++ b/vis4d/zoo/fcn_resnet/__init__.py @@ -0,0 +1,7 @@ +"""FCN Model Zoo.""" + +from . import fcn_resnet_coco + +AVAILABLE_MODELS = { + "fcn_resnet_coco": fcn_resnet_coco, +} diff --git a/vis4d/zoo/fcn_resnet/fcn_resnet_coco.py b/vis4d/zoo/fcn_resnet/fcn_resnet_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..fb905d9bd35f7d876ae550e672b53d6c0385ea77 --- /dev/null +++ b/vis4d/zoo/fcn_resnet/fcn_resnet_coco.py @@ -0,0 +1,163 @@ +"""FCN-ResNet COCO training example.""" + +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.connectors import DataConnector, LossConnector +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import PolyLR +from vis4d.model.seg.fcn_resnet import FCNResNet +from vis4d.op.loss import MultiLevelSegLoss +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.seg import ( + CONN_MASKS_TEST, + CONN_MASKS_TRAIN, + CONN_MULTI_SEG_LOSS, +) +from vis4d.zoo.base.datasets.coco import get_coco_sem_seg_cfg + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the COCO semantic segmentation task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="fcn_coco") + config.sync_batchnorm = True + config.val_check_interval = 2000 + config.check_val_every_n_epoch = None + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_steps = 40000 + params.num_classes = 21 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/COCO" + train_split = "train2017" + test_split = "val2017" + image_size = (520, 520) + + data_backend = class_config(HDF5Backend) + + config.data = get_coco_sem_seg_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + image_size=image_size, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL ## + ###################################################### + config.model = class_config( + FCNResNet, + base_model="resnet50", + num_classes=params.num_classes, + resize=image_size, + ) + + ###################################################### + ## LOSS ## + ###################################################### + config.loss = class_config( + LossModule, + losses={ + "loss": class_config( + MultiLevelSegLoss, feature_idx=[4, 5], weights=[0.5, 1] + ), + "connector": class_config( + LossConnector, key_mapping=CONN_MULTI_SEG_LOSS + ), + }, + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config( + PolyLR, + max_steps=params.num_steps, + min_lr=0.0001, + power=0.9, + ), + epoch_based=False, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg(epoch_based=False) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.epoch_based = False + pl_trainer.max_steps = params.num_steps + + pl_trainer.checkpoint_period = config.val_check_interval + pl_trainer.val_check_interval = config.val_check_interval + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + + pl_trainer.sync_batchnorm = config.sync_batchnorm + # pl_trainer.precision = 16 + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/mask_rcnn/__init__.py b/vis4d/zoo/mask_rcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b09f158ffd75c7e3b652df48643f5d049981294 --- /dev/null +++ b/vis4d/zoo/mask_rcnn/__init__.py @@ -0,0 +1,7 @@ +"""Mask-RCNN Model Zoo.""" + +from . import mask_rcnn_coco + +AVAILABLE_MODELS = { + "mask_rcnn_coco": mask_rcnn_coco, +} diff --git a/vis4d/zoo/mask_rcnn/mask_rcnn_coco.py b/vis4d/zoo/mask_rcnn/mask_rcnn_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..dff61dd6999b2b28495f1b2140fb960d70cc0de5 --- /dev/null +++ b/vis4d/zoo/mask_rcnn/mask_rcnn_coco.py @@ -0,0 +1,192 @@ +# pylint: disable=duplicate-code +"""Mask RCNN COCO training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.const import CommonKeys as K +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + remap_pred_keys, +) +from vis4d.eval.coco import COCODetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_BBOX_2D_VIS, +) +from vis4d.zoo.base.datasets.coco import ( + CONN_COCO_BBOX_EVAL, + get_coco_detection_cfg, +) +from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Mask-RCNN config dict for the coco detection task. + + This is an example that shows how to set up a training experiment for the + COCO detection task. + + Note that the high level params are exposed in the config. This allows + to easily change them from the command line. + E.g.: + >>> python -m vis4d.engine.run fit --config configs/faster_rcnn/faster_rcnn_coco.py --config.params.lr 0.001 + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="mask_rcnn_r50_fpn_coco") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 12 + params.num_classes = 80 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/coco" + train_split = "train2017" + test_split = "val2017" + + data_backend = class_config(HDF5Backend) + + config.data = get_coco_detection_cfg( + data_root=data_root, + train_split=train_split, + train_keys_to_load=( + K.images, + K.boxes2d, + K.boxes2d_classes, + K.instance_masks, + ), + test_split=test_split, + test_keys_to_load=( + K.images, + K.original_images, + K.boxes2d, + K.boxes2d_classes, + K.instance_masks, + ), + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_mask_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, + key_mapping=CONN_BBOX_2D_TRAIN, + ) + + config.test_data_connector = class_config( + DataConnector, + key_mapping=CONN_BBOX_2D_TEST, + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BoundingBoxVisualizer, vis_freq=100), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, + key_mapping=remap_pred_keys(CONN_BBOX_2D_VIS, "boxes"), + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + COCODetectEvaluator, + data_root=data_root, + split=test_split, + ), + metrics_to_eval=["Det"], + test_connector=class_config( + CallbackConnector, + key_mapping=remap_pred_keys(CONN_COCO_BBOX_EVAL, "boxes"), + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/qdtrack/__init__.py b/vis4d/zoo/qdtrack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d58127bbf41a2c403e77569d3cb50bee68ccf4f --- /dev/null +++ b/vis4d/zoo/qdtrack/__init__.py @@ -0,0 +1,14 @@ +"""QDTrack.""" + +from . import ( + qdtrack_frcnn_r50_fpn_augs_1x_bdd100k, + qdtrack_yolox_x_25e_bdd100k, +) + +# Lists of available models in BDD100K Model Zoo. +AVAILABLE_MODELS = { + "qdtrack_frcnn_r50_fpn_augs_1x_bdd100k": ( + qdtrack_frcnn_r50_fpn_augs_1x_bdd100k + ), + "qdtrack_yolox_x_25e_bdd100k": qdtrack_yolox_x_25e_bdd100k, +} diff --git a/vis4d/zoo/qdtrack/data_yolox.py b/vis4d/zoo/qdtrack/data_yolox.py new file mode 100644 index 0000000000000000000000000000000000000000..376654492f0b23fa9e995b7bd274f2dcbac849f5 --- /dev/null +++ b/vis4d/zoo/qdtrack/data_yolox.py @@ -0,0 +1,275 @@ +"""BDD100K data loading config for QDTrack YOLOX.""" + +from __future__ import annotations + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe, MultiSampleDataPipe +from vis4d.data.datasets.bdd100k import BDD100K, bdd100k_track_map +from vis4d.data.loader import build_train_dataloader, default_collate +from vis4d.data.reference import MultiViewDataset, UniformViewSampler +from vis4d.data.transforms.affine import ( + AffineBoxes2D, + AffineImages, + GenAffineParameters, +) +from vis4d.data.transforms.base import RandomApply, compose +from vis4d.data.transforms.crop import ( + CropBoxes2D, + CropImages, + GenCropParameters, +) +from vis4d.data.transforms.flip import FlipBoxes2D, FlipImages +from vis4d.data.transforms.mixup import ( + GenMixupParameters, + MixupBoxes2D, + MixupImages, +) +from vis4d.data.transforms.mosaic import ( + GenMosaicParameters, + MosaicBoxes2D, + MosaicImages, +) +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.pad import PadImages +from vis4d.data.transforms.photometric import RandomHSV +from vis4d.data.transforms.post_process import ( + PostProcessBoxes2D, + RescaleTrackIDs, +) +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeBoxes2D, + ResizeImages, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.zoo.base import get_inference_dataloaders_cfg +from vis4d.zoo.base.callable import get_callable_cfg + + +def get_train_dataloader( + data_backend: None | ConfigDict, + image_size: tuple[int, int], + normalize_image: bool, + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default train dataloader for BDD100K tracking.""" + bdd100k_det_train = class_config( + BDD100K, + data_root="data/bdd100k/images/100k/train/", + keys_to_load=(K.images, K.boxes2d), + annotation_path="data/bdd100k/labels/det_20/det_train.json", + category_map=bdd100k_track_map, + config_path="det", + image_channel_mode="BGR", + data_backend=data_backend, + skip_empty_samples=True, + cache_as_binary=True, + cached_file_path="data/bdd100k/pkl/det_train.pkl", + ) + + bdd100k_track_train = class_config( + BDD100K, + data_root="data/bdd100k/images/track/train/", + keys_to_load=(K.images, K.boxes2d), + annotation_path="data/bdd100k/labels/box_track_20/train/", + category_map=bdd100k_track_map, + config_path="box_track", + image_channel_mode="BGR", + data_backend=data_backend, + skip_empty_samples=True, + cache_as_binary=True, + cached_file_path="data/bdd100k/pkl/track_train.pkl", + ) + + train_dataset_cfg = [ + class_config( + MultiViewDataset, + dataset=bdd100k_det_train, + sampler=class_config( + UniformViewSampler, scope=0, num_ref_samples=1 + ), + ), + class_config( + MultiViewDataset, + dataset=bdd100k_track_train, + sampler=class_config( + UniformViewSampler, scope=3, num_ref_samples=1 + ), + ), + ] + + # Train Preprocessing + preprocess_transforms = [ + [ + class_config(GenMosaicParameters, out_shape=image_size), + class_config(MosaicImages, imresize_backend="cv2"), + class_config(MosaicBoxes2D), + ], + [class_config(RescaleTrackIDs)], + ] + + preprocess_transforms += [ + [ + class_config( + GenAffineParameters, + scaling_ratio_range=(0.5, 1.5), + border=(-image_size[0] // 2, -image_size[1] // 2), + ), + class_config(AffineImages, as_int=True), + class_config(AffineBoxes2D), + ] + ] + + preprocess_transforms += [ + [ + class_config( + GenMixupParameters, + out_shape=image_size, + mixup_ratio_dist="const", + scale_range=(0.8, 1.6), + pad_value=114.0, + ), + class_config(MixupImages, imresize_backend="cv2"), + class_config(MixupBoxes2D), + ], + [class_config(RescaleTrackIDs)], + ] + + preprocess_transforms.append( + [class_config(PostProcessBoxes2D, min_area=1.0)] + ) + + batch_transforms = [ + class_config(RandomHSV, same_on_batch=False), + class_config( + RandomApply, + transforms=[class_config(FlipImages), class_config(FlipBoxes2D)], + probability=0.5, + same_on_batch=False, + ), + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + scale_range=(0.5, 1.5), + same_on_batch=False, + ), + class_config(ResizeImages), + class_config(ResizeBoxes2D), + class_config(GenCropParameters, shape=image_size, same_on_batch=False), + class_config(CropImages), + class_config(CropBoxes2D), + ] + if normalize_image: + batch_transforms += [ + class_config(NormalizeImages), + class_config(PadImages), + ] + else: + batch_transforms += [class_config(PadImages, value=114.0)] + train_batchprocess_cfg = class_config( + compose, transforms=batch_transforms + [class_config(ToTensor)] + ) + + return class_config( + build_train_dataloader, + dataset=class_config( + MultiSampleDataPipe, + datasets=train_dataset_cfg, + preprocess_fn=preprocess_transforms, + ), + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + batchprocess_fn=train_batchprocess_cfg, + collate_fn=get_callable_cfg(default_collate), + pin_memory=True, + shuffle=True, + ) + + +def get_test_dataloader( + data_backend: None | ConfigDict, + image_size: tuple[int, int], + normalize_image: bool, + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default test dataloader for BDD100K tracking.""" + test_dataset = class_config( + BDD100K, + data_root="data/bdd100k/images/track/val/", + keys_to_load=(K.images, K.original_images), + annotation_path="data/bdd100k/labels/box_track_20/val/", + category_map=bdd100k_track_map, + config_path="box_track", + image_channel_mode="BGR", + data_backend=data_backend, + cache_as_binary=True, + cached_file_path="data/bdd100k/pkl/track_val.pkl", + ) + + preprocess_transforms = [ + class_config(GenResizeParameters, shape=image_size, keep_ratio=True), + class_config(ResizeImages), + ] + + test_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + if normalize_image: + batch_transforms = [ + class_config(NormalizeImages), + class_config(PadImages), + ] + else: + batch_transforms = [class_config(PadImages, value=114.0)] + test_batchprocess_cfg = class_config( + compose, transforms=batch_transforms + [class_config(ToTensor)] + ) + + test_dataset_cfg = class_config( + DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + video_based_inference=True, + batchprocess_cfg=test_batchprocess_cfg, + ) + + +def get_bdd100k_track_cfg( + data_backend: None | ConfigDict = None, + image_size: tuple[int, int] = (800, 1440), + normalize_image: bool = False, + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> DataConfig: + """Get the default config for BDD100K tracking.""" + data = DataConfig() + + data.train_dataloader = get_train_dataloader( + data_backend=data_backend, + image_size=image_size, + normalize_image=normalize_image, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + data.test_dataloader = get_test_dataloader( + data_backend=data_backend, + image_size=image_size, + normalize_image=normalize_image, + samples_per_gpu=1, + workers_per_gpu=1, + ) + + return data diff --git a/vis4d/zoo/qdtrack/qdtrack_frcnn_r50_fpn_augs_1x_bdd100k.py b/vis4d/zoo/qdtrack/qdtrack_frcnn_r50_fpn_augs_1x_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..757f97f945b37888be6beb6a94ea4c35dad0581a --- /dev/null +++ b/vis4d/zoo/qdtrack/qdtrack_frcnn_r50_fpn_augs_1x_bdd100k.py @@ -0,0 +1,175 @@ +# pylint: disable=duplicate-code +"""QDTrack with Faster R-CNN on BDD100K.""" +from __future__ import annotations + +from lightning.pytorch.callbacks import ModelCheckpoint +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.datasets.bdd100k import bdd100k_track_map +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import ( + EvaluatorCallback, + VisualizerCallback, + YOLOXModeSwitchCallback, +) +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.bdd100k import BDD100KTrackEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TRACK_VIS +from vis4d.zoo.base.datasets.bdd100k import CONN_BDD100K_TRACK_EVAL +from vis4d.zoo.base.models.qdtrack import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + get_qdtrack_cfg, +) +from vis4d.zoo.qdtrack.data_yolox import get_bdd100k_track_cfg + + +def get_config() -> ExperimentConfig: + """Returns the config dict for qdtrack on bdd100k. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="qdtrack_frcnn_r50_fpn_augs_1x_bdd100k") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 4 # batch size = 4 GPUs * 4 samples per GPU = 16 + params.workers_per_gpu = 8 + params.lr = 0.02 + params.num_epochs = 12 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_track_cfg( + data_backend=data_backend, + image_size=(720, 1280), + normalize_image=True, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL ## + ###################################################### + num_classes = len(bdd100k_track_map) + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3 + ) + + config.model, config.loss = get_qdtrack_cfg( + num_classes=num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config(LinearLR, start_factor=0.1, total_iters=1000), + end=1000, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Mode switch for strong augmentations + callbacks += [class_config(YOLOXModeSwitchCallback, switch_epoch=9)] + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config( + BoundingBoxVisualizer, vis_freq=500, image_mode="BGR" + ), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_TRACK_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KTrackEvaluator, + annotation_path="data/bdd100k/labels/box_track_20/val/", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_TRACK_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.checkpoint_callback = class_config( + ModelCheckpoint, + dirpath=config.get_ref("output_dir") + "/checkpoints", + verbose=True, + save_last=True, + save_on_train_epoch_end=True, + every_n_epochs=1, + save_top_k=4, + mode="max", + monitor="step", + ) + pl_trainer.wandb = True + pl_trainer.gradient_clip_val = 35 + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/qdtrack/qdtrack_yolox_x_25e_bdd100k.py b/vis4d/zoo/qdtrack/qdtrack_yolox_x_25e_bdd100k.py new file mode 100644 index 0000000000000000000000000000000000000000..782c42ad12278ba11ce0cb1b5ee614fa51f99d7f --- /dev/null +++ b/vis4d/zoo/qdtrack/qdtrack_yolox_x_25e_bdd100k.py @@ -0,0 +1,163 @@ +# pylint: disable=duplicate-code +"""QDTrack with YOLOX-x on BDD100K.""" +from __future__ import annotations + +from lightning.pytorch.callbacks import ModelCheckpoint + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.datasets.bdd100k import bdd100k_track_map +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.bdd100k import BDD100KTrackEvaluator +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, +) +from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TRACK_VIS +from vis4d.zoo.base.datasets.bdd100k import CONN_BDD100K_TRACK_EVAL +from vis4d.zoo.base.models.qdtrack import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + get_qdtrack_yolox_cfg, +) +from vis4d.zoo.base.models.yolox import ( + get_yolox_callbacks_cfg, + get_yolox_optimizers_cfg, +) +from vis4d.zoo.qdtrack.data_yolox import get_bdd100k_track_cfg + + +def get_config() -> ExperimentConfig: + """Returns the config dict for qdtrack on bdd100k. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="qdtrack_yolox_x_25e_bdd100k") + config.checkpoint_period = 5 + config.check_val_every_n_epoch = 5 + + # Hyper Parameters + params = ExperimentParameters() + params.samples_per_gpu = 8 # batch size = 8 GPUs * 8 samples per GPU = 64 + params.workers_per_gpu = 8 + params.lr = 0.001 + params.num_epochs = 25 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_backend = class_config(HDF5Backend) + + config.data = get_bdd100k_track_cfg( + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL ## + ###################################################### + num_classes = len(bdd100k_track_map) + weights = ( + "mmdet://yolox/yolox_x_8x8_300e_coco/" + "yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth" + ) + config.model, config.loss = get_qdtrack_yolox_cfg( + num_classes, "xlarge", weights=weights + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + # we use a schedule with 50 epochs, but only train for 25 epochs + num_total_epochs, num_last_epochs = 50, 10 + config.optimizers = get_yolox_optimizers_cfg( + params.lr, num_total_epochs, 1, num_last_epochs + ) + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg( + refresh_rate=config.log_every_n_steps + ) + + # YOLOX callbacks + callbacks += get_yolox_callbacks_cfg( + switch_epoch=num_total_epochs - num_last_epochs, num_sizes=0 + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config( + BoundingBoxVisualizer, vis_freq=500, image_mode="BGR" + ), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_TRACK_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + BDD100KTrackEvaluator, + annotation_path="data/bdd100k/labels/box_track_20/val/", + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BDD100K_TRACK_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + pl_trainer.checkpoint_callback = class_config( + ModelCheckpoint, + dirpath=config.get_ref("output_dir") + "/checkpoints", + verbose=True, + save_last=True, + save_on_train_epoch_end=True, + every_n_epochs=config.checkpoint_period, + save_top_k=5, + mode="max", + monitor="step", + ) + pl_trainer.wandb = True + pl_trainer.precision = "16-mixed" + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/retinanet/__init__.py b/vis4d/zoo/retinanet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ea582388b856e975f2668f55c7bd5b059adec8 --- /dev/null +++ b/vis4d/zoo/retinanet/__init__.py @@ -0,0 +1,7 @@ +"""RetinaNet Model Zoo.""" + +from . import retinanet_coco + +AVAILABLE_MODELS = { + "retinanet_coco": retinanet_coco, +} diff --git a/vis4d/zoo/retinanet/retinanet_coco.py b/vis4d/zoo/retinanet/retinanet_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..446345faaf7f5d4da66f06bcd4f157f11ad5809d --- /dev/null +++ b/vis4d/zoo/retinanet/retinanet_coco.py @@ -0,0 +1,206 @@ +# pylint: disable=duplicate-code +"""RetinaNet COCO training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.eval.coco import COCODetectEvaluator +from vis4d.model.detect.retinanet import RetinaNet +from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder +from vis4d.op.detect.retinanet import ( + RetinaNetHeadLoss, + get_default_anchor_generator, +) +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_VIS, + CONN_BOX_LOSS_2D, + CONN_IMAGES_TEST, + CONN_IMAGES_TRAIN, +) +from vis4d.zoo.base.datasets.coco import ( + CONN_COCO_BBOX_EVAL, + get_coco_detection_cfg, +) + + +def get_config() -> ExperimentConfig: + """Returns the RetinaNet config dict for the coco detection task. + + This is an example that shows how to set up a training experiment for the + COCO detection task. + + Note that the high level params are exposed in the config. This allows + to easily change them from the command line. + E.g.: + >>> python -m vis4d.engine.run fit --config vis4d/zoo/retinanet/retinanet_rcnn_coco.py --config.num_epochs 100 --config.params.lr 0.001 + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="retinanet_r50_fpn_coco") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_epochs = 12 + params.num_classes = 80 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/coco" + train_split = "train2017" + test_split = "val2017" + + data_backend = class_config(HDF5Backend) + + config.data = get_coco_detection_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config( + RetinaNet, + num_classes=params.num_classes, + # weights="mmdet", + ) + + box_encoder = class_config( + DeltaXYWHBBoxEncoder, + target_means=(0.0, 0.0, 0.0, 0.0), + target_stds=(1.0, 1.0, 1.0, 1.0), + ) + + anchor_generator = class_config(get_default_anchor_generator) + + retina_loss = class_config( + RetinaNetHeadLoss, + box_encoder=box_encoder, + anchor_generator=anchor_generator, + ) + + config.loss = class_config( + LossModule, + losses={ + "loss": retina_loss, + "connector": class_config( + LossConnector, key_mapping=CONN_BOX_LOSS_2D + ), + }, + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, + key_mapping=CONN_IMAGES_TRAIN, + ) + + config.test_data_connector = class_config( + DataConnector, + key_mapping=CONN_IMAGES_TEST, + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BoundingBoxVisualizer, vis_freq=100), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, + key_mapping=CONN_BBOX_2D_VIS, + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + COCODetectEvaluator, + data_root=data_root, + split=test_split, + ), + metrics_to_eval=["Det"], + test_connector=class_config( + CallbackConnector, + key_mapping=CONN_COCO_BBOX_EVAL, + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/run.py b/vis4d/zoo/run.py new file mode 100644 index 0000000000000000000000000000000000000000..a53b8d955e5fc130aa2c69011f5200d74b003c59 --- /dev/null +++ b/vis4d/zoo/run.py @@ -0,0 +1,27 @@ +"""CLI interface.""" + +from __future__ import annotations + +from absl import app # pylint: disable=no-name-in-module + +from vis4d.common.typing import ArgsType +from vis4d.zoo import AVAILABLE_MODELS + + +def main(argv: ArgsType) -> None: + """Main entry point for the model zoo.""" + assert len(argv) > 1, "Command must be specified: `list`" + if argv[1] == "list": + for ds, models in AVAILABLE_MODELS.items(): + print(ds) + model_names = list(models.keys()) + for model in model_names[:-1]: + print(" ├─", model) + print(" └─", model_names[-1]) + else: + raise ValueError(f"Invalid command. {argv[1]}") + + +def entrypoint() -> None: + """Entry point for the CLI.""" + app.run(main) diff --git a/vis4d/zoo/shift/__init__.py b/vis4d/zoo/shift/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..498a278c10e344317651b20aae8194781a091287 --- /dev/null +++ b/vis4d/zoo/shift/__init__.py @@ -0,0 +1,31 @@ +"""BDD100K Model Zoo.""" + +from .faster_rcnn import ( + faster_rcnn_r50_6e_shift_all_domains, + faster_rcnn_r50_12e_shift, + faster_rcnn_r50_36e_shift, +) +from .mask_rcnn import ( + mask_rcnn_r50_6e_shift_all_domains, + mask_rcnn_r50_12e_shift, + mask_rcnn_r50_36e_shift, +) +from .semantic_fpn import ( + semantic_fpn_r50_40k_shift, + semantic_fpn_r50_40k_shift_all_domains, + semantic_fpn_r50_160k_shift, + semantic_fpn_r50_160k_shift_all_domains, +) + +AVAILABLE_MODELS = { + "faster_rcnn_r50_6e_shift_all_domains": faster_rcnn_r50_6e_shift_all_domains, # pylint: disable=line-too-long + "faster_rcnn_r50_12e_shift": faster_rcnn_r50_12e_shift, + "faster_rcnn_r50_36e_shift": faster_rcnn_r50_36e_shift, + "mask_rcnn_r50_6e_shift_all_domains": mask_rcnn_r50_6e_shift_all_domains, + "mask_rcnn_r50_12e_shift": mask_rcnn_r50_12e_shift, + "mask_rcnn_r50_36e_shift": mask_rcnn_r50_36e_shift, + "semantic_fpn_r50_40k_shift_all_domains": semantic_fpn_r50_40k_shift_all_domains, # pylint: disable=line-too-long + "semantic_fpn_r50_40k_shift": semantic_fpn_r50_40k_shift, + "semantic_fpn_r50_160k_shift_all_domains": semantic_fpn_r50_160k_shift_all_domains, # pylint: disable=line-too-long + "semantic_fpn_r50_160k_shift": semantic_fpn_r50_160k_shift, +} diff --git a/vis4d/zoo/shift/faster_rcnn/__init__.py b/vis4d/zoo/shift/faster_rcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7bf040b0026b6b3c7bf685770abfa55fde40e0f --- /dev/null +++ b/vis4d/zoo/shift/faster_rcnn/__init__.py @@ -0,0 +1 @@ +"""Faster R-CNN for SHIFT.""" diff --git a/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_12e_shift.py b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_12e_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..2f422d3e57eae876e45fc84727d119d8ad75a4b0 --- /dev/null +++ b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_12e_shift.py @@ -0,0 +1,170 @@ +# pylint: disable=duplicate-code +"""Faster RCNN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.shift import SHIFTDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_BBOX_2D_VIS, +) +from vis4d.zoo.base.datasets.shift import ( + CONN_SHIFT_DET_EVAL, + get_shift_det_config, +) +from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Faster-RCNN config dict for the SHIFT detection task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="faster_rcnn_r50_12e_shift") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 12 + params.num_classes = 6 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}] + + data_backend = class_config(HDF5Backend) + + config.data = get_shift_det_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4 + ) + + config.model, config.loss = get_faster_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BoundingBoxVisualizer, vis_freq=100), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTDetectEvaluator, + annotation_path=( + f"{data_root}/discrete/images/val/front/det_2d.json" + ), + attributes_to_load=domain_attr, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SHIFT_DET_EVAL + ), + metrics_to_eval=[SHIFTDetectEvaluator.METRICS_DET], + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_36e_shift.py b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_36e_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..64dbef288e1f6ec12f2aee9085841533e8e7b471 --- /dev/null +++ b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_36e_shift.py @@ -0,0 +1,170 @@ +# pylint: disable=duplicate-code +"""Faster RCNN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.shift import SHIFTDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_BBOX_2D_VIS, +) +from vis4d.zoo.base.datasets.shift import ( + CONN_SHIFT_DET_EVAL, + get_shift_det_config, +) +from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Faster-RCNN config dict for the SHIFT detection task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="faster_rcnn_r50_36e_shift") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 36 + params.num_classes = 6 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}] + + data_backend = class_config(HDF5Backend) + + config.data = get_shift_det_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4 + ) + + config.model, config.loss = get_faster_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[24, 33], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BoundingBoxVisualizer, vis_freq=100), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTDetectEvaluator, + annotation_path=( + f"{data_root}/discrete/images/val/front/det_2d.json" + ), + attributes_to_load=domain_attr, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SHIFT_DET_EVAL + ), + metrics_to_eval=[SHIFTDetectEvaluator.METRICS_DET], + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_6e_shift_all_domains.py b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_6e_shift_all_domains.py new file mode 100644 index 0000000000000000000000000000000000000000..45d427e501100a5fa487bd7d183eeb38566133b2 --- /dev/null +++ b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_6e_shift_all_domains.py @@ -0,0 +1,170 @@ +# pylint: disable=duplicate-code +"""Faster RCNN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.shift import SHIFTDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_BBOX_2D_VIS, +) +from vis4d.zoo.base.datasets.shift import ( + CONN_SHIFT_DET_EVAL, + get_shift_det_config, +) +from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg + + +def get_config() -> ExperimentConfig: + """Returns the Faster-RCNN config dict for the SHIFT detection task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="faster_rcnn_r50_6e_shift_all_domains") + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 6 + params.num_classes = 6 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = None + + data_backend = class_config(HDF5Backend) + + config.data = get_shift_det_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4 + ) + + config.model, config.loss = get_faster_rcnn_cfg( + num_classes=params.num_classes, basemodel=basemodel + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[4, 5], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(BoundingBoxVisualizer, vis_freq=100), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTDetectEvaluator, + annotation_path=( + f"{data_root}/discrete/images/val/front/det_2d.json" + ), + attributes_to_load=domain_attr, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SHIFT_DET_EVAL + ), + metrics_to_eval=[SHIFTDetectEvaluator.METRICS_DET], + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/mask_rcnn/__init__.py b/vis4d/zoo/shift/mask_rcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14605694ec1b083b095b2e098ff55c802550018b --- /dev/null +++ b/vis4d/zoo/shift/mask_rcnn/__init__.py @@ -0,0 +1 @@ +"""Mask R-CNN for SHIFT.""" diff --git a/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_12e_shift.py b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_12e_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..a66c4a5140de1d3b2e773da75dc32489fb285b55 --- /dev/null +++ b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_12e_shift.py @@ -0,0 +1,174 @@ +# pylint: disable=duplicate-code +"""Mask RCNN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import FieldConfigDict, class_config +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.shift import SHIFTDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_INS_MASK_2D_VIS, +) +from vis4d.zoo.base.datasets.shift import ( + CONN_SHIFT_INS_EVAL, + get_shift_instance_seg_config, +) +from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg + + +def get_config() -> FieldConfigDict: + """Returns the Faster-RCNN config dict for the SHIFT detection task. + + Returns: + FieldConfigDict: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="mask_rcnn_r50_12e_shift") + + # High level hyper parameters + params = FieldConfigDict() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 12 + params.num_classes = 6 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}] + + data_backend = class_config(HDF5Backend) + + config.data = get_shift_instance_seg_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4 + ) + + config.model, config.loss = get_mask_rcnn_cfg( + num_classes=params.num_classes, + basemodel=basemodel, + no_overlap=True, + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=25), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTDetectEvaluator, + annotation_path=( + f"{data_root}/discrete/images/val/front/det_insseg_2d.json" + ), + attributes_to_load=domain_attr, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SHIFT_INS_EVAL + ), + metrics_to_eval=[ + SHIFTDetectEvaluator.METRICS_DET, + SHIFTDetectEvaluator.METRICS_INS_SEG, + ], + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_36e_shift.py b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_36e_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd2f746ee810d8bb40a205d99b46a788b20bd7f --- /dev/null +++ b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_36e_shift.py @@ -0,0 +1,174 @@ +# pylint: disable=duplicate-code +"""Mask RCNN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import FieldConfigDict, class_config +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.shift import SHIFTDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_INS_MASK_2D_VIS, +) +from vis4d.zoo.base.datasets.shift import ( + CONN_SHIFT_INS_EVAL, + get_shift_instance_seg_config, +) +from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg + + +def get_config() -> FieldConfigDict: + """Returns the Faster-RCNN config dict for the SHIFT detection task. + + Returns: + FieldConfigDict: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="mask_rcnn_r50_36e_shift") + + # High level hyper parameters + params = FieldConfigDict() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 36 + params.num_classes = 6 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}] + + data_backend = class_config(HDF5Backend) + + config.data = get_shift_instance_seg_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4 + ) + + config.model, config.loss = get_mask_rcnn_cfg( + num_classes=params.num_classes, + basemodel=basemodel, + no_overlap=True, + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[24, 33], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=25), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTDetectEvaluator, + annotation_path=( + f"{data_root}/discrete/images/val/front/det_insseg_2d.json" + ), + attributes_to_load=domain_attr, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SHIFT_INS_EVAL + ), + metrics_to_eval=[ + SHIFTDetectEvaluator.METRICS_DET, + SHIFTDetectEvaluator.METRICS_INS_SEG, + ], + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_6e_shift_all_domains.py b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_6e_shift_all_domains.py new file mode 100644 index 0000000000000000000000000000000000000000..5f92ff6df1367e4ea643ac65f76da5572f69d9af --- /dev/null +++ b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_6e_shift_all_domains.py @@ -0,0 +1,174 @@ +# pylint: disable=duplicate-code +"""Mask RCNN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR, MultiStepLR +from torch.optim.sgd import SGD + +from vis4d.config import FieldConfigDict, class_config +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.shift import SHIFTDetectEvaluator +from vis4d.op.base import ResNet +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors import ( + CONN_BBOX_2D_TEST, + CONN_BBOX_2D_TRAIN, + CONN_INS_MASK_2D_VIS, +) +from vis4d.zoo.base.datasets.shift import ( + CONN_SHIFT_INS_EVAL, + get_shift_instance_seg_config, +) +from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg + + +def get_config() -> FieldConfigDict: + """Returns the Faster-RCNN config dict for the SHIFT detection task. + + Returns: + FieldConfigDict: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="mask_rcnn_r50_6e_shift_all_domains") + + # High level hyper parameters + params = FieldConfigDict() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.02 + params.num_epochs = 6 + params.num_classes = 6 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = None + + data_backend = class_config(HDF5Backend) + + config.data = get_shift_instance_seg_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + basemodel = class_config( + ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4 + ) + + config.model, config.loss = get_mask_rcnn_cfg( + num_classes=params.num_classes, + basemodel=basemodel, + no_overlap=True, + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config(MultiStepLR, milestones=[4, 5], gamma=0.1), + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg() + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=25), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTDetectEvaluator, + annotation_path=( + f"{data_root}/discrete/images/val/front/det_insseg_2d.json" + ), + attributes_to_load=domain_attr, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SHIFT_INS_EVAL + ), + metrics_to_eval=[ + SHIFTDetectEvaluator.METRICS_DET, + SHIFTDetectEvaluator.METRICS_INS_SEG, + ], + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/semantic_fpn/__init__.py b/vis4d/zoo/shift/semantic_fpn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56b5a6da0c8a3eb7db1a09297403b73a5f30af4d --- /dev/null +++ b/vis4d/zoo/shift/semantic_fpn/__init__.py @@ -0,0 +1 @@ +"""Semantic FPN for SHIFT.""" diff --git a/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift.py b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6d91e2da013146ae8d8a7714795b00117c9ac7 --- /dev/null +++ b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift.py @@ -0,0 +1,191 @@ +# pylint: disable=duplicate-code +"""Semantic FPN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import PolyLR +from vis4d.eval.shift import SHIFTSegEvaluator +from vis4d.model.seg.semantic_fpn import SemanticFPN +from vis4d.op.loss import SegCrossEntropyLoss +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.seg import ( + CONN_MASKS_TEST, + CONN_MASKS_TRAIN, + CONN_SEG_EVAL, + CONN_SEG_LOSS, + CONN_SEG_VIS, +) +from vis4d.zoo.base.datasets.shift import get_shift_sem_seg_config + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the BDD100K semantic segmentation task. + + Returns: + ExperimentParameters: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="semantic_fpn_r50_160k_shift") + config.sync_batchnorm = True + config.val_check_interval = 2000 + config.check_val_every_n_epoch = None + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_steps = 160000 + params.num_classes = 23 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}] + data_backend = class_config(HDF5Backend) + + config.data = get_shift_sem_seg_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config(SemanticFPN, num_classes=params.num_classes) + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(SegCrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_SEG_LOSS + ), + }, + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config( + PolyLR, + max_steps=params.num_steps, + min_lr=0.0001, + power=0.9, + ), + epoch_based=False, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg(epoch_based=False) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTSegEvaluator, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_EVAL + ), + ) + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=20), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_VIS + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.epoch_based = False + pl_trainer.max_steps = params.num_steps + + pl_trainer.checkpoint_period = config.val_check_interval + pl_trainer.val_check_interval = config.val_check_interval + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + + pl_trainer.sync_batchnorm = config.sync_batchnorm + # pl_trainer.precision = 16 + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift_all_domains.py b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift_all_domains.py new file mode 100644 index 0000000000000000000000000000000000000000..25fcf86184a1633905aa1f97d6529f367616b13d --- /dev/null +++ b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift_all_domains.py @@ -0,0 +1,193 @@ +# pylint: disable=duplicate-code +"""Semantic FPN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import PolyLR +from vis4d.eval.shift import SHIFTSegEvaluator +from vis4d.model.seg.semantic_fpn import SemanticFPN +from vis4d.op.loss import SegCrossEntropyLoss +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.seg import ( + CONN_MASKS_TEST, + CONN_MASKS_TRAIN, + CONN_SEG_EVAL, + CONN_SEG_LOSS, + CONN_SEG_VIS, +) +from vis4d.zoo.base.datasets.shift import get_shift_sem_seg_config + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the BDD100K semantic segmentation task. + + Returns: + ExperimentParameters: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg( + exp_name="semantic_fpn_r50_160k_shift_all_domains" + ) + config.sync_batchnorm = True + config.val_check_interval = 2000 + config.check_val_every_n_epoch = None + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_steps = 160000 + params.num_classes = 23 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = None + data_backend = class_config(HDF5Backend) + + config.data = get_shift_sem_seg_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config(SemanticFPN, num_classes=params.num_classes) + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(SegCrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_SEG_LOSS + ), + }, + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config( + PolyLR, + max_steps=params.num_steps, + min_lr=0.0001, + power=0.9, + ), + epoch_based=False, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg(epoch_based=False) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTSegEvaluator, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_EVAL + ), + ) + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=20), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_VIS + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.epoch_based = False + pl_trainer.max_steps = params.num_steps + + pl_trainer.checkpoint_period = config.val_check_interval + pl_trainer.val_check_interval = config.val_check_interval + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + + pl_trainer.sync_batchnorm = config.sync_batchnorm + # pl_trainer.precision = 16 + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift.py b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8dc16032662fe410a5ae7431aaf61e4ef5eb74 --- /dev/null +++ b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift.py @@ -0,0 +1,191 @@ +# pylint: disable=duplicate-code +"""Semantic FPN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import PolyLR +from vis4d.eval.shift import SHIFTSegEvaluator +from vis4d.model.seg.semantic_fpn import SemanticFPN +from vis4d.op.loss import SegCrossEntropyLoss +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.seg import ( + CONN_MASKS_TEST, + CONN_MASKS_TRAIN, + CONN_SEG_EVAL, + CONN_SEG_LOSS, + CONN_SEG_VIS, +) +from vis4d.zoo.base.datasets.shift import get_shift_sem_seg_config + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the BDD100K semantic segmentation task. + + Returns: + ExperimentParameters: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="semantic_fpn_r50_40k_shift") + config.sync_batchnorm = True + config.val_check_interval = 2000 + config.check_val_every_n_epoch = None + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_steps = 160000 + params.num_classes = 23 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}] + data_backend = class_config(HDF5Backend) + + config.data = get_shift_sem_seg_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config(SemanticFPN, num_classes=params.num_classes) + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(SegCrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_SEG_LOSS + ), + }, + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config( + PolyLR, + max_steps=params.num_steps, + min_lr=0.0001, + power=0.9, + ), + epoch_based=False, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg(epoch_based=False) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTSegEvaluator, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_EVAL + ), + ) + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=20), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_VIS + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.epoch_based = False + pl_trainer.max_steps = params.num_steps + + pl_trainer.checkpoint_period = config.val_check_interval + pl_trainer.val_check_interval = config.val_check_interval + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + + pl_trainer.sync_batchnorm = config.sync_batchnorm + # pl_trainer.precision = 16 + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift_all_domains.py b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift_all_domains.py new file mode 100644 index 0000000000000000000000000000000000000000..50c607cede9fd2505201bc1967afd69602dae7b4 --- /dev/null +++ b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift_all_domains.py @@ -0,0 +1,191 @@ +# pylint: disable=duplicate-code +"""Semantic FPN SHIFT training example.""" +from __future__ import annotations + +from torch.optim.lr_scheduler import LinearLR +from torch.optim.sgd import SGD + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.engine.optim import PolyLR +from vis4d.eval.shift import SHIFTSegEvaluator +from vis4d.model.seg.semantic_fpn import SemanticFPN +from vis4d.op.loss import SegCrossEntropyLoss +from vis4d.vis.image import SegMaskVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.seg import ( + CONN_MASKS_TEST, + CONN_MASKS_TRAIN, + CONN_SEG_EVAL, + CONN_SEG_LOSS, + CONN_SEG_VIS, +) +from vis4d.zoo.base.datasets.shift import get_shift_sem_seg_config + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the BDD100K semantic segmentation task. + + Returns: + ExperimentParameters: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="semantic_fpn_r50_40k_shift_all_domains") + config.sync_batchnorm = True + config.val_check_interval = 2000 + config.check_val_every_n_epoch = None + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 2 + params.workers_per_gpu = 2 + params.lr = 0.01 + params.num_steps = 160000 + params.num_classes = 23 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/shift/" + views_to_load = ["front"] + train_split = "train" + test_split = "val" + domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}] + data_backend = class_config(HDF5Backend) + + config.data = get_shift_sem_seg_config( + data_root=data_root, + train_split=train_split, + test_split=test_split, + train_views_to_load=views_to_load, + test_views_to_load=views_to_load, + train_attributes_to_load=domain_attr, + test_attributes_to_load=domain_attr, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model = class_config(SemanticFPN, num_classes=params.num_classes) + config.loss = class_config( + LossModule, + losses=[ + { + "loss": class_config(SegCrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_SEG_LOSS + ), + }, + ], + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005 + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config( + LinearLR, start_factor=0.001, total_iters=500 + ), + end=500, + epoch_based=False, + ), + get_lr_scheduler_cfg( + class_config( + PolyLR, + max_steps=params.num_steps, + min_lr=0.0001, + power=0.9, + ), + epoch_based=False, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_MASKS_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg(epoch_based=False) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + SHIFTSegEvaluator, + ), + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_EVAL + ), + ) + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config(SegMaskVisualizer, vis_freq=20), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_SEG_VIS + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.epoch_based = False + pl_trainer.max_steps = params.num_steps + + pl_trainer.checkpoint_period = config.val_check_interval + pl_trainer.val_check_interval = config.val_check_interval + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + + pl_trainer.sync_batchnorm = config.sync_batchnorm + # pl_trainer.precision = 16 + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/util.py b/vis4d/zoo/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b58c5f87017cf7a261ccf9a2446b0c123f22e11e --- /dev/null +++ b/vis4d/zoo/util.py @@ -0,0 +1,14 @@ +"""Utility functions for the zoo module.""" + +from __future__ import annotations + +import importlib + +from vis4d.config.typing import ExperimentConfig + + +def get_config_for_name(config_name: str) -> ExperimentConfig: + """Get config for name.""" + module = importlib.import_module("vis4d.zoo." + config_name) + + return module.get_config() diff --git a/vis4d/zoo/vit/__init__.py b/vis4d/zoo/vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59c0b0ef7a1b904cf878134f38dd52a0554b0a10 --- /dev/null +++ b/vis4d/zoo/vit/__init__.py @@ -0,0 +1,8 @@ +"""ViT for image classification configs.""" + +from . import vit_small_imagenet, vit_tiny_imagenet + +AVAILABLE_MODELS = { + "vit_small_imagenet": vit_small_imagenet, + "vit_tiny_imagenet": vit_tiny_imagenet, +} diff --git a/vis4d/zoo/vit/vit_small_imagenet.py b/vis4d/zoo/vit/vit_small_imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c10bda561118188d22a2729fb5b0e3e5a284ce --- /dev/null +++ b/vis4d/zoo/vit/vit_small_imagenet.py @@ -0,0 +1,181 @@ +# pylint: disable=duplicate-code +"""VIT ImageNet-1k training example.""" +from __future__ import annotations + +from torch import nn +from torch.optim.adamw import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.engine.callbacks import EMACallback, EvaluatorCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.eval.common.cls import ClassificationEvaluator +from vis4d.model.adapter import ModelEMAAdapter +from vis4d.model.cls.vit import ViTClassifer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.cls import ( + CONN_CLS_LOSS, + CONN_CLS_TEST, + CONN_CLS_TRAIN, +) +from vis4d.zoo.base.datasets.imagenet import ( + CONN_IMAGENET_CLS_EVAL, + get_imagenet_cls_cfg, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the ImageNet Classification task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + + config = get_default_cfg(exp_name="vit_small_16_imagenet1k") + config.sync_batchnorm = True + config.check_val_every_n_epoch = 1 + config.ema_decay_rate = 0.99996 + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 256 + params.workers_per_gpu = 8 + params.num_epochs = 300 + params.lr = 1e-3 + params.weight_decay = 0.01 + params.num_classes = 1000 + params.grad_norm_clip = 1.0 + params.accumulate_grad_batches = 1 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/imagenet1k" + train_split = "train" + test_split = "val" + image_size = (224, 224) + + config.data = get_imagenet_cls_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + image_size=image_size, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL ## + ###################################################### + config.model = class_config( + ModelEMAAdapter, + model=class_config( + ViTClassifer, + variant="vit_small_patch16_224", + num_classes=params.num_classes, + drop_rate=0.1, + drop_path_rate=0.1, + ), + ) + + ###################################################### + ## LOSS ## + ###################################################### + config.loss = class_config( + LossModule, + losses={ + "loss": class_config(nn.CrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_CLS_LOSS + ), + }, + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + AdamW, lr=params.lr, weight_decay=params.weight_decay + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config(LinearLR, estart_factor=1e-3, total_iters=10), + end=10, + ), + get_lr_scheduler_cfg( + class_config( + CosineAnnealingLR, + T_max=params.num_epochs, + eta_min=1e-9, + ), + begin=10, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, + key_mapping=CONN_CLS_TRAIN, + ) + + config.test_data_connector = class_config( + DataConnector, + key_mapping=CONN_CLS_TEST, + ) + + ###################################################### + ## GENERIC CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg() + + # EMA callback + callbacks.append(class_config(EMACallback)) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config(ClassificationEvaluator), + metrics_to_eval=["Cls"], + test_connector=class_config( + CallbackConnector, key_mapping=CONN_IMAGENET_CLS_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.gradient_clip_val = params.grad_norm_clip + pl_trainer.accumulate_grad_batches = params.accumulate_grad_batches + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/vit/vit_tiny_imagenet.py b/vis4d/zoo/vit/vit_tiny_imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..e28542ed5db406b22a114904cb53123600bea198 --- /dev/null +++ b/vis4d/zoo/vit/vit_tiny_imagenet.py @@ -0,0 +1,181 @@ +# pylint: disable=duplicate-code +"""VIT ImageNet-1k training example.""" +from __future__ import annotations + +from torch import nn +from torch.optim.adamw import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.engine.callbacks import EMACallback, EvaluatorCallback +from vis4d.engine.connectors import ( + CallbackConnector, + DataConnector, + LossConnector, +) +from vis4d.engine.loss_module import LossModule +from vis4d.eval.common.cls import ClassificationEvaluator +from vis4d.model.adapter import ModelEMAAdapter +from vis4d.model.cls.vit import ViTClassifer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, + get_lr_scheduler_cfg, + get_optimizer_cfg, +) +from vis4d.zoo.base.data_connectors.cls import ( + CONN_CLS_LOSS, + CONN_CLS_TEST, + CONN_CLS_TRAIN, +) +from vis4d.zoo.base.datasets.imagenet import ( + CONN_IMAGENET_CLS_EVAL, + get_imagenet_cls_cfg, +) + + +def get_config() -> ExperimentConfig: + """Returns the config dict for the ImageNet Classification task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + + config = get_default_cfg(exp_name="vit_tiny_16_imagenet1k") + config.sync_batchnorm = True + config.check_val_every_n_epoch = 1 + + ## High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 256 + params.workers_per_gpu = 8 + params.num_epochs = 300 + params.lr = 1e-3 + params.weight_decay = 0.01 + params.num_classes = 1000 + params.grad_norm_clip = 1.0 + params.accumulate_grad_batches = 1 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/imagenet1k" + train_split = "train" + test_split = "val" + image_size = (224, 224) + + config.data = get_imagenet_cls_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + image_size=image_size, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL ## + ###################################################### + config.model = class_config( + ModelEMAAdapter, + model=class_config( + ViTClassifer, + variant="vit_tiny_patch16_224", + num_classes=params.num_classes, + drop_rate=0.1, + drop_path_rate=0.1, + ), + decay=0.99998, + ) + + ###################################################### + ## LOSS ## + ###################################################### + config.loss = class_config( + LossModule, + losses={ + "loss": class_config(nn.CrossEntropyLoss), + "connector": class_config( + LossConnector, key_mapping=CONN_CLS_LOSS + ), + }, + ) + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + + config.optimizers = [ + get_optimizer_cfg( + optimizer=class_config( + AdamW, lr=params.lr, weight_decay=params.weight_decay + ), + lr_schedulers=[ + get_lr_scheduler_cfg( + class_config(LinearLR, estart_factor=1e-3, total_iters=10), + end=10, + ), + get_lr_scheduler_cfg( + class_config( + CosineAnnealingLR, + T_max=params.num_epochs, + eta_min=1e-9, + ), + begin=10, + ), + ], + ) + ] + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, + key_mapping=CONN_CLS_TRAIN, + ) + + config.test_data_connector = class_config( + DataConnector, + key_mapping=CONN_CLS_TEST, + ) + + ###################################################### + ## GENERIC CALLBACKS ## + ###################################################### + callbacks = get_default_callbacks_cfg() + + # EMA callback + callbacks.append(class_config(EMACallback)) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config(ClassificationEvaluator), + metrics_to_eval=["Cls"], + test_connector=class_config( + CallbackConnector, key_mapping=CONN_IMAGENET_CLS_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.gradient_clip_val = params.grad_norm_clip + pl_trainer.accumulate_grad_batches = params.accumulate_grad_batches + + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/yolox/__init__.py b/vis4d/zoo/yolox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17ba54d4d1130c41aa4d141eb035c6923dc7e0bd --- /dev/null +++ b/vis4d/zoo/yolox/__init__.py @@ -0,0 +1,8 @@ +"""YOLOX Model Zoo.""" + +from . import yolox_s_300e_coco, yolox_tiny_300e_coco + +AVAILABLE_MODELS = { + "yolox_s_300e_coco": yolox_s_300e_coco, + "yolox_tiny_300e_coco": yolox_tiny_300e_coco, +} diff --git a/vis4d/zoo/yolox/data.py b/vis4d/zoo/yolox/data.py new file mode 100644 index 0000000000000000000000000000000000000000..eefda5e35a025aede27ee6a0b10ff8a93b9d56d0 --- /dev/null +++ b/vis4d/zoo/yolox/data.py @@ -0,0 +1,261 @@ +# pylint: disable=duplicate-code +"""COCO data loading config for YOLOX object detection.""" +from __future__ import annotations + +from collections.abc import Sequence + +from ml_collections import ConfigDict + +from vis4d.config import class_config +from vis4d.config.typing import DataConfig +from vis4d.data.const import CommonKeys as K +from vis4d.data.data_pipe import DataPipe, MultiSampleDataPipe +from vis4d.data.datasets.coco import COCO +from vis4d.data.io import DataBackend +from vis4d.data.loader import build_train_dataloader, default_collate +from vis4d.data.transforms.affine import ( + AffineBoxes2D, + AffineImages, + GenAffineParameters, +) +from vis4d.data.transforms.base import RandomApply, compose +from vis4d.data.transforms.flip import FlipBoxes2D, FlipImages +from vis4d.data.transforms.mixup import ( + GenMixupParameters, + MixupBoxes2D, + MixupImages, +) +from vis4d.data.transforms.mosaic import ( + GenMosaicParameters, + MosaicBoxes2D, + MosaicImages, +) +from vis4d.data.transforms.pad import PadImages +from vis4d.data.transforms.photometric import RandomHSV +from vis4d.data.transforms.post_process import PostProcessBoxes2D +from vis4d.data.transforms.resize import ( + GenResizeParameters, + ResizeBoxes2D, + ResizeImages, +) +from vis4d.data.transforms.to_tensor import ToTensor +from vis4d.engine.connectors import data_key, pred_key +from vis4d.zoo.base import get_inference_dataloaders_cfg +from vis4d.zoo.base.callable import get_callable_cfg + +CONN_COCO_BBOX_EVAL = { + "coco_image_id": data_key(K.sample_names), + "pred_boxes": pred_key("boxes"), + "pred_scores": pred_key("scores"), + "pred_classes": pred_key("class_ids"), +} + +CONN_COCO_MASK_EVAL = { + "coco_image_id": data_key(K.sample_names), + "pred_boxes": pred_key("boxes.boxes"), + "pred_scores": pred_key("boxes.scores"), + "pred_classes": pred_key("boxes.class_ids"), + "pred_masks": pred_key("masks"), +} + + +def get_train_dataloader( + data_root: str, + split: str, + keys_to_load: Sequence[str], + data_backend: None | DataBackend, + image_size: tuple[int, int], + scaling_ratio_range: tuple[float, float], + use_mixup: bool, + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default train dataloader for COCO detection.""" + # Train Dataset + train_dataset_cfg = class_config( + COCO, + keys_to_load=keys_to_load, + data_root=data_root, + split=split, + remove_empty=False, + image_channel_mode="BGR", + data_backend=data_backend, + ) + + # Train Preprocessing + preprocess_transforms = [ + [ + class_config(GenMosaicParameters, out_shape=image_size), + class_config(MosaicImages, imresize_backend="cv2"), + class_config(MosaicBoxes2D), + ] + ] + + preprocess_transforms += [ + [ + class_config( + GenAffineParameters, + scaling_ratio_range=scaling_ratio_range, + border=(-image_size[0] // 2, -image_size[1] // 2), + ), + class_config(AffineImages, as_int=True), + class_config(AffineBoxes2D), + ] + ] + + if use_mixup: + preprocess_transforms += [ + [ + class_config( + GenMixupParameters, + out_shape=image_size, + mixup_ratio_dist="const", + scale_range=(0.8, 1.6), + pad_value=114.0, + ), + class_config(MixupImages, imresize_backend="cv2"), + class_config(MixupBoxes2D), + ] + ] + + preprocess_transforms.append( + [class_config(PostProcessBoxes2D, min_area=1.0)] + ) + + train_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(RandomHSV, same_on_batch=False), + class_config( + RandomApply, + transforms=[ + class_config(FlipImages), + class_config(FlipBoxes2D), + ], + probability=0.5, + same_on_batch=False, + ), + class_config( + GenResizeParameters, + shape=image_size, + keep_ratio=True, + same_on_batch=False, + ), + class_config(ResizeImages, imresize_backend="cv2"), + class_config(ResizeBoxes2D), + class_config(PadImages, value=114.0, pad2square=True), + class_config(ToTensor), + ], + ) + + return class_config( + build_train_dataloader, + dataset=class_config( + MultiSampleDataPipe, + datasets=train_dataset_cfg, + preprocess_fn=preprocess_transforms, + ), + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + batchprocess_fn=train_batchprocess_cfg, + collate_fn=get_callable_cfg(default_collate), + pin_memory=True, + shuffle=True, + ) + + +def get_test_dataloader( + data_root: str, + split: str, + keys_to_load: Sequence[str], + data_backend: None | DataBackend, + image_size: tuple[int, int], + samples_per_gpu: int, + workers_per_gpu: int, +) -> ConfigDict: + """Get the default test dataloader for COCO detection.""" + # Test Dataset + test_dataset = class_config( + COCO, + keys_to_load=keys_to_load, + data_root=data_root, + split=split, + image_channel_mode="BGR", + data_backend=data_backend, + ) + + # Test Preprocessing + preprocess_transforms = [ + class_config(GenResizeParameters, shape=image_size, keep_ratio=True), + class_config(ResizeImages, imresize_backend="cv2"), + ] + + test_preprocess_cfg = class_config( + compose, transforms=preprocess_transforms + ) + + test_batchprocess_cfg = class_config( + compose, + transforms=[ + class_config(PadImages, value=114.0, pad2square=True), + class_config(ToTensor), + ], + ) + + # Test Dataset Config + test_dataset_cfg = class_config( + DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg + ) + + return get_inference_dataloaders_cfg( + datasets_cfg=test_dataset_cfg, + batchprocess_cfg=test_batchprocess_cfg, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + +def get_coco_yolox_cfg( + data_root: str = "data/coco", + train_split: str = "train2017", + train_keys_to_load: Sequence[str] = ( + K.images, + K.boxes2d, + K.boxes2d_classes, + ), + test_split: str = "val2017", + test_keys_to_load: Sequence[str] = (K.images, K.original_images), + data_backend: None | ConfigDict = None, + train_image_size: tuple[int, int] = (640, 640), + scaling_ratio_range: tuple[float, float] = (0.1, 2.0), + use_mixup: bool = True, + test_image_size: tuple[int, int] = (640, 640), + samples_per_gpu: int = 2, + workers_per_gpu: int = 2, +) -> DataConfig: + """Get the default config for COCO detection.""" + data = DataConfig() + + data.train_dataloader = get_train_dataloader( + data_root=data_root, + split=train_split, + keys_to_load=train_keys_to_load, + data_backend=data_backend, + image_size=train_image_size, + scaling_ratio_range=scaling_ratio_range, + use_mixup=use_mixup, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=workers_per_gpu, + ) + + data.test_dataloader = get_test_dataloader( + data_root=data_root, + split=test_split, + keys_to_load=test_keys_to_load, + data_backend=data_backend, + image_size=test_image_size, + samples_per_gpu=1, + workers_per_gpu=workers_per_gpu, + ) + + return data diff --git a/vis4d/zoo/yolox/yolox_s_300e_coco.py b/vis4d/zoo/yolox/yolox_s_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..5b65110335ba1f394f2a01f37a7fe1af380bc4e4 --- /dev/null +++ b/vis4d/zoo/yolox/yolox_s_300e_coco.py @@ -0,0 +1,159 @@ +# pylint: disable=duplicate-code +"""YOLOX COCO.""" +from __future__ import annotations + +from lightning.pytorch.callbacks import ModelCheckpoint + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.const import CommonKeys as K +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.coco import COCODetectEvaluator +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, +) +from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TEST, CONN_BBOX_2D_VIS +from vis4d.zoo.base.models.yolox import ( + get_yolox_callbacks_cfg, + get_yolox_cfg, + get_yolox_optimizers_cfg, +) +from vis4d.zoo.yolox.data import CONN_COCO_BBOX_EVAL, get_coco_yolox_cfg + +CONN_BBOX_2D_TRAIN = {"images": K.images} + + +def get_config() -> ExperimentConfig: + """Returns the YOLOX config dict for the coco detection task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="yolox_s_300e_coco") + config.checkpoint_period = 15 + config.check_val_every_n_epoch = 10 + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 8 + params.workers_per_gpu = 4 + params.lr = 0.01 + params.num_epochs = 300 + params.num_classes = 80 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/coco" + train_split = "train2017" + test_split = "val2017" + + data_backend = class_config(HDF5Backend) + + config.data = get_coco_yolox_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model, config.loss = get_yolox_cfg(params.num_classes, "small") + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + num_last_epochs, warmup_epochs = 15, 5 + config.optimizers = get_yolox_optimizers_cfg( + params.lr, params.num_epochs, warmup_epochs, num_last_epochs + ) + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg( + refresh_rate=config.log_every_n_steps + ) + + # YOLOX callbacks + callbacks += get_yolox_callbacks_cfg( + switch_epoch=params.num_epochs - num_last_epochs + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config( + BoundingBoxVisualizer, vis_freq=100, image_mode="BGR" + ), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + COCODetectEvaluator, data_root=data_root, split=test_split + ), + metrics_to_eval=["Det"], + test_connector=class_config( + CallbackConnector, key_mapping=CONN_COCO_BBOX_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + pl_trainer.checkpoint_callback = class_config( + ModelCheckpoint, + dirpath=config.get_ref("output_dir") + "/checkpoints", + verbose=True, + save_last=True, + save_on_train_epoch_end=True, + every_n_epochs=config.checkpoint_period, + save_top_k=3, + mode="max", + monitor="step", + ) + pl_trainer.wandb = True + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/vis4d/zoo/yolox/yolox_tiny_300e_coco.py b/vis4d/zoo/yolox/yolox_tiny_300e_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..b6bad53e5718edc606e5c6e7f236c91ad10f942b --- /dev/null +++ b/vis4d/zoo/yolox/yolox_tiny_300e_coco.py @@ -0,0 +1,162 @@ +# pylint: disable=duplicate-code +"""YOLOX COCO.""" +from __future__ import annotations + +from lightning.pytorch.callbacks import ModelCheckpoint + +from vis4d.config import class_config +from vis4d.config.typing import ExperimentConfig, ExperimentParameters +from vis4d.data.const import CommonKeys as K +from vis4d.data.io.hdf5 import HDF5Backend +from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback +from vis4d.engine.connectors import CallbackConnector, DataConnector +from vis4d.eval.coco import COCODetectEvaluator +from vis4d.vis.image import BoundingBoxVisualizer +from vis4d.zoo.base import ( + get_default_callbacks_cfg, + get_default_cfg, + get_default_pl_trainer_cfg, +) +from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TEST, CONN_BBOX_2D_VIS +from vis4d.zoo.base.models.yolox import ( + get_yolox_callbacks_cfg, + get_yolox_cfg, + get_yolox_optimizers_cfg, +) +from vis4d.zoo.yolox.data import CONN_COCO_BBOX_EVAL, get_coco_yolox_cfg + +CONN_BBOX_2D_TRAIN = {"images": K.images} + + +def get_config() -> ExperimentConfig: + """Returns the YOLOX config dict for the coco detection task. + + Returns: + ExperimentConfig: The configuration + """ + ###################################################### + ## General Config ## + ###################################################### + config = get_default_cfg(exp_name="yolox_tiny_300e_coco") + config.checkpoint_period = 15 + config.check_val_every_n_epoch = 10 + + # High level hyper parameters + params = ExperimentParameters() + params.samples_per_gpu = 8 + params.workers_per_gpu = 4 + params.lr = 0.01 + params.num_epochs = 300 + params.num_classes = 80 + config.params = params + + ###################################################### + ## Datasets with augmentations ## + ###################################################### + data_root = "data/coco" + train_split = "train2017" + test_split = "val2017" + + data_backend = class_config(HDF5Backend) + + config.data = get_coco_yolox_cfg( + data_root=data_root, + train_split=train_split, + test_split=test_split, + data_backend=data_backend, + scaling_ratio_range=(0.5, 1.5), + use_mixup=False, + test_image_size=(416, 416), + samples_per_gpu=params.samples_per_gpu, + workers_per_gpu=params.workers_per_gpu, + ) + + ###################################################### + ## MODEL & LOSS ## + ###################################################### + config.model, config.loss = get_yolox_cfg(params.num_classes, "tiny") + + ###################################################### + ## OPTIMIZERS ## + ###################################################### + num_last_epochs, warmup_epochs = 15, 5 + config.optimizers = get_yolox_optimizers_cfg( + params.lr, params.num_epochs, warmup_epochs, num_last_epochs + ) + + ###################################################### + ## DATA CONNECTOR ## + ###################################################### + config.train_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TRAIN + ) + + config.test_data_connector = class_config( + DataConnector, key_mapping=CONN_BBOX_2D_TEST + ) + + ###################################################### + ## CALLBACKS ## + ###################################################### + # Logger + callbacks = get_default_callbacks_cfg( + refresh_rate=config.log_every_n_steps + ) + + # YOLOX callbacks + callbacks += get_yolox_callbacks_cfg( + switch_epoch=params.num_epochs - num_last_epochs, shape=(320, 320) + ) + + # Visualizer + callbacks.append( + class_config( + VisualizerCallback, + visualizer=class_config( + BoundingBoxVisualizer, vis_freq=100, image_mode="BGR" + ), + output_dir=config.output_dir, + test_connector=class_config( + CallbackConnector, key_mapping=CONN_BBOX_2D_VIS + ), + ) + ) + + # Evaluator + callbacks.append( + class_config( + EvaluatorCallback, + evaluator=class_config( + COCODetectEvaluator, data_root=data_root, split=test_split + ), + metrics_to_eval=["Det"], + test_connector=class_config( + CallbackConnector, key_mapping=CONN_COCO_BBOX_EVAL + ), + ) + ) + + config.callbacks = callbacks + + ###################################################### + ## PL CLI ## + ###################################################### + # PL Trainer args + pl_trainer = get_default_pl_trainer_cfg(config) + pl_trainer.max_epochs = params.num_epochs + pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch + pl_trainer.checkpoint_callback = class_config( + ModelCheckpoint, + dirpath=config.get_ref("output_dir") + "/checkpoints", + verbose=True, + save_last=True, + save_on_train_epoch_end=True, + every_n_epochs=config.checkpoint_period, + save_top_k=3, + mode="max", + monitor="step", + ) + pl_trainer.wandb = True + config.pl_trainer = pl_trainer + + return config.value_mode() diff --git a/wilddet3d/__init__.py b/wilddet3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce3fca0e97e250dd3b5af327624c378434d3949 --- /dev/null +++ b/wilddet3d/__init__.py @@ -0,0 +1,29 @@ +"""WildDet3D: Open-Vocabulary Monocular 3D Object Detection in the Wild.""" + +import sys +from pathlib import Path + +# Add third_party submodules to Python path +_third_party = Path(__file__).parent.parent / "third_party" +_sam3_path = str(_third_party / "sam3") +_lingbot_path = str(_third_party / "lingbot_depth") + +if _sam3_path not in sys.path: + sys.path.insert(0, _sam3_path) +if _lingbot_path not in sys.path: + sys.path.insert(0, _lingbot_path) + +from .data_types import Det3DOut, WildDet3DInput, WildDet3DOut +from .inference import WildDet3DPredictor, build_model +from .model import WildDet3D +from .preprocessing import preprocess + +__all__ = [ + "WildDet3D", + "WildDet3DPredictor", + "WildDet3DInput", + "WildDet3DOut", + "Det3DOut", + "build_model", + "preprocess", +] diff --git a/wilddet3d/__pycache__/__init__.cpython-311.pyc b/wilddet3d/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2b553fec1d09854d0bbbaeb9ed86d53e833d2d6 Binary files /dev/null and b/wilddet3d/__pycache__/__init__.cpython-311.pyc differ diff --git a/wilddet3d/__pycache__/data_types.cpython-311.pyc b/wilddet3d/__pycache__/data_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5951004c70b0d5c933fe5795c3aec5ccd652d885 Binary files /dev/null and b/wilddet3d/__pycache__/data_types.cpython-311.pyc differ diff --git a/wilddet3d/__pycache__/inference.cpython-311.pyc b/wilddet3d/__pycache__/inference.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..404984d26919949c7d9eb15acd7cbc5ba22237bb Binary files /dev/null and b/wilddet3d/__pycache__/inference.cpython-311.pyc differ diff --git a/wilddet3d/__pycache__/model.cpython-311.pyc b/wilddet3d/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c25246280fffb9cd1da2963ee69ad33eee0b1fd1 Binary files /dev/null and b/wilddet3d/__pycache__/model.cpython-311.pyc differ diff --git a/wilddet3d/__pycache__/preprocessing.cpython-311.pyc b/wilddet3d/__pycache__/preprocessing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e770793d4dbe2c093b7a9d4212527c980aec66bb Binary files /dev/null and b/wilddet3d/__pycache__/preprocessing.cpython-311.pyc differ diff --git a/wilddet3d/connector.py b/wilddet3d/connector.py new file mode 100644 index 0000000000000000000000000000000000000000..ea034b79986782562e0822d7aaba9b6a5d8ade7d --- /dev/null +++ b/wilddet3d/connector.py @@ -0,0 +1,1852 @@ +"""WildDet3D data connector and collator configuration. + +This module provides: +1. DataConnector key mappings for train/test +2. WildDet3DCollator: converts per-image DataLoader output to WildDet3DInput +3. Point prompt sampling (from mask or box region) +""" + +from __future__ import annotations + +import random +import time +from collections import defaultdict +from typing import List, Literal, Optional + +import numpy as np +import torch +from torch import Tensor + +from wilddet3d.ops.profiler import profile_start, profile_stop + +from ml_collections import ConfigDict +from vis4d.config import class_config +from vis4d.data.const import CommonKeys as K +from vis4d.engine.connectors import DataConnector, data_key, pred_key + +from wilddet3d.model import WildDet3DInput + + +# ============================================================================ +# Point Sampling Utilities +# ============================================================================ + +def sample_points_from_mask( + mask: np.ndarray, + n_points: int, + mode: Literal["centered", "random_mask", "random_box"], + box: Optional[np.ndarray] = None, +) -> np.ndarray: + """Sample points from a binary mask. + + Args: + mask: Binary mask (H, W), 1=foreground, 0=background + n_points: Number of points to sample + mode: Sampling mode + - "centered": sample from mask center (farthest from edges) + - "random_mask": uniform sample from mask interior + - "random_box": uniform sample from box, label from mask + box: Box in xyxy format (required for random_box mode) + + Returns: + Points array (n_points, 3) with (x, y, label) + """ + if mode == "centered": + return _center_positive_sample(mask, n_points) + elif mode == "random_mask": + return _uniform_positive_sample(mask, n_points) + elif mode == "random_box": + assert box is not None, "'random_box' mode requires a provided box." + return _uniform_sample_from_box(mask, box, n_points) + else: + raise ValueError(f"Unknown point sampling mode {mode}.") + + +def _uniform_positive_sample(mask: np.ndarray, n_points: int) -> np.ndarray: + """Sample positive points uniformly from mask interior.""" + mask_points = np.stack(np.nonzero(mask), axis=0).transpose(1, 0) + if len(mask_points) == 0: + # Empty mask, return center of image as fallback + h, w = mask.shape + return np.array([[w // 2, h // 2, 1]] * n_points) + + selected_idxs = np.random.randint(low=0, high=len(mask_points), size=n_points) + selected_points = mask_points[selected_idxs] + selected_points = selected_points[:, ::-1] # (y, x) -> (x, y) + labels = np.ones((len(selected_points), 1)) + return np.concatenate([selected_points, labels], axis=1) + + +def _center_positive_sample(mask: np.ndarray, n_points: int) -> np.ndarray: + """Sample points farthest from mask edges (using distance transform).""" + try: + import cv2 + except ImportError: + # Fallback to uniform sampling if cv2 not available + return _uniform_positive_sample(mask, n_points) + + if np.max(mask) == 0: + h, w = mask.shape + return np.array([[w // 2, h // 2, 1]] * n_points) + + padded_mask = np.pad(mask.astype(np.uint8), 1) + points = [] + + for _ in range(n_points): + if np.max(padded_mask) == 0: + break + dist = cv2.distanceTransform(padded_mask, cv2.DIST_L2, 0) + point = np.unravel_index(dist.argmax(), dist.shape) + padded_mask[point[0], point[1]] = 0 + points.append(point[::-1]) # (y, x) -> (x, y) + + if len(points) == 0: + h, w = mask.shape + return np.array([[w // 2, h // 2, 1]] * n_points) + + points = np.stack(points, axis=0) + points = points - 1 # Subtract padding offset + labels = np.ones((len(points), 1)) + return np.concatenate([points, labels], axis=1) + + +def _uniform_sample_from_box( + mask: np.ndarray, + box: np.ndarray, + n_points: int, +) -> np.ndarray: + """Sample points uniformly from box, determine labels from mask.""" + int_box = np.ceil(box).astype(int) + x1, y1, x2, y2 = int_box + + # Ensure valid box + x2 = max(x2, x1 + 1) + y2 = max(y2, y1 + 1) + + x = np.random.randint(low=x1, high=x2, size=n_points) + y = np.random.randint(low=y1, high=y2, size=n_points) + + # Clip to mask boundaries + h, w = mask.shape + x = np.clip(x, 0, w - 1) + y = np.clip(y, 0, h - 1) + + labels = mask[y, x] + return np.stack([x, y, labels], axis=1) + + +def sample_points_without_mask( + box: np.ndarray, + n_positive: int, + n_negative: int, + H: int, + W: int, +) -> np.ndarray: + """Sample points when no mask is available. + + Uses box region as pseudo-mask: + - Positive points: uniformly from inside box + - Negative points: uniformly from outside box + + Args: + box: Box in xyxy format (x1, y1, x2, y2) + n_positive: Number of positive points to sample + n_negative: Number of negative points to sample + H: Image height + W: Image width + + Returns: + Points array (n_positive + n_negative, 3) with (x, y, label) + """ + x1, y1, x2, y2 = map(int, box) + + # Ensure valid box + x1 = max(0, min(x1, W - 1)) + x2 = max(x1 + 1, min(x2, W)) + y1 = max(0, min(y1, H - 1)) + y2 = max(y1 + 1, min(y2, H)) + + points_list = [] + + # Positive points: inside box + if n_positive > 0: + pos_x = np.random.randint(x1, x2, size=n_positive) + pos_y = np.random.randint(y1, y2, size=n_positive) + pos_labels = np.ones(n_positive) + pos_points = np.stack([pos_x, pos_y, pos_labels], axis=1) + points_list.append(pos_points) + + # Negative points: outside box + if n_negative > 0: + neg_points = [] + max_attempts = n_negative * 100 + + for _ in range(max_attempts): + if len(neg_points) >= n_negative: + break + x = np.random.randint(0, W) + y = np.random.randint(0, H) + # Check if outside box + if not (x1 <= x < x2 and y1 <= y < y2): + neg_points.append([x, y, 0]) + + if len(neg_points) < n_negative: + # Fallback: sample from image corners if box is too large + corners = [(0, 0), (W-1, 0), (0, H-1), (W-1, H-1)] + while len(neg_points) < n_negative: + cx, cy = corners[len(neg_points) % 4] + neg_points.append([cx, cy, 0]) + + neg_points = np.array(neg_points[:n_negative]) + points_list.append(neg_points) + + if points_list: + return np.concatenate(points_list, axis=0) + else: + return np.zeros((0, 3)) + + +def noise_box( + box: np.ndarray, + im_size: tuple, + box_noise_std: float = 0.1, + box_noise_max: Optional[float] = None, + min_box_area: float = 0.0, +) -> np.ndarray: + """Add noise to a box for data augmentation. + + Follows SAM3's noise_box implementation: + - Gaussian noise scaled by box dimensions + - Optional pixel clamp + - Fallback to original box if area too small + + Args: + box: Box in xyxy format (x1, y1, x2, y2) + im_size: Image size (H, W) + box_noise_std: Noise std relative to box size + box_noise_max: Max noise in pixels (None = no clamp) + min_box_area: Min area after noising (SAM3 default: 0.0) + + Returns: + Noised box in xyxy format + """ + if box_noise_std <= 0.0: + return box + + noise = box_noise_std * np.random.randn(4) + w, h = box[2] - box[0], box[3] - box[1] + scale_factor = np.array([w, h, w, h]) + noise = noise * scale_factor + + if box_noise_max is not None: + noise = np.clip(noise, -box_noise_max, box_noise_max) + + noised_box = box + noise + + # Clamp to image bounds + H, W = im_size + noised_box = np.maximum(noised_box, 0) + noised_box = np.minimum(noised_box, [W, H, W, H]) + + # Check min area (SAM3 default: 0.0 = no limit) + new_w = noised_box[2] - noised_box[0] + new_h = noised_box[3] - noised_box[1] + if new_w * new_h <= min_box_area: + return box + + return noised_box + + +# ============================================================================ +# WildDet3D Collator +# ============================================================================ + +class WildDet3DCollator: + """Collator that converts per-image data to WildDet3DInput. + + Design (SAM3 original - per-category queries): + - DataLoader produces per-image samples + - Collator groups GT boxes by category + - Each category creates ONE query with multi-instance targets + - This aligns with SAM3's multi-instance detection design + + Per-prompt batch strategy: + - N_prompts = sum of unique categories across batch (NOT sum of boxes!) + - img_ids[i] indicates which image prompt i belongs to + - Each prompt can have multiple GT boxes (multi-instance targets) + + Coordinate format: + - Input boxes2d: pixel xyxy (from dataset) + - geo_boxes: normalized cxcywh [0,1] (for SAM3) + - geo_points: normalized xy [0,1] (for SAM3) + - gt_boxes2d: normalized xyxy [0,1] (for loss) + - gt_boxes2d shape: (N_prompts, max_gts, 4) for multi-instance + - num_gts: (N_prompts,) number of GT boxes per query (can be > 1) + + Text/Visual Query: + - text_query_prob controls the ratio of text vs visual queries + - text_query_prob=1.0: all text queries (SAM3 default for training) + - text_query_prob=0.7: 70% text, 30% visual (recommended by SAM3) + - Visual queries use one randomly selected target box as geo_box + """ + + def __init__( + self, + max_prompts_per_image: int = 50, + use_text_prompts: bool = True, + default_text: str = "visual", + # Point prompt options + use_point_prompts: bool = False, + num_positive_points: int | tuple[int, int] = 1, + num_negative_points: int | tuple[int, int] = 0, + point_sample_mode: Literal["centered", "random_mask", "random_box"] = "random_mask", + # Box prompt options + use_box_prompts: bool = True, + box_noise_std: float = 0.0, + box_noise_max: float | None = None, + # Multi-tier box noise: (prob, std) tiers sampled per box. + # If set, overrides box_noise_std. Each tier is (probability, std). + # Probabilities must sum to 1.0. + # Example: [(0.3, 0.0), (0.5, 0.1), (0.2, 0.2)] + # = 30% no noise, 50% mild, 20% extreme + box_noise_tiers: list[tuple[float, float]] | None = None, + # Text/Visual query ratio (SAM3 original design) + text_query_prob: float = 0.7, # 70% text, 30% visual (SAM3 recommended) + keep_text_for_visual: bool = False, # If True, visual queries keep category text + # Geometry prompt options (text + geometry training) + use_geometry_prompts: bool = False, # If True, create 2 queries per category + geometric_query_str: str = "geometric", # Text for geometry queries + visual_query_str: str = "visual", # Text for visual queries + # 5-mode training: Branch 1 and Branch 2 probabilities + # Branch 1 (o2m): TEXT (text_only_prob) / VISUAL or VISUAL+LABEL (1-text_only_prob) + # Branch 2 (o2o): GEOMETRY or GEOMETRY+LABEL + # use_label_prob controls +LABEL variants for both branches + text_only_prob: float = 0.5, # Branch 1: P(TEXT) vs P(box-based query) + use_label_prob: float = 1/3, # P(+LABEL) when query has a box prompt + # Oracle evaluation mode (GT box as geometry prompt) + oracle_eval: bool = False, # If True, each GT box = one geometry prompt + oracle_text_category: bool = False, # If True, oracle + category text + # Point prompt: SAM3-style box/point budget (only when use_point_prompts=True) + # num_points is the total geometric prompt budget. + # box_chance controls probability of including a box (which takes 1 slot). + # E.g. num_points=(1,3), box_chance=0.5: + # num=1, box=True → pure box | num=1, box=False → 1 point + # num=2, box=True → box+1pt | num=2, box=False → 2 points + # num=3, box=True → box+2pt | num=3, box=False → 3 points + box_chance: float = 0.5, + # Exclusive point mode probability. When use_point_prompts=True, + # Branch 2 randomly picks EITHER box-only OR point-only (never + # both). Point-only is chosen with probability point_mode_prob, + # but only when the selected box has a mask (masks2d_rle). + # Otherwise box-only. Points use SAM3 random_box mode: uniform + # from box region, mask determines pos/neg labels. + point_mode_prob: float = 0.3, + # Negative sampling (SAM3 style) + include_negatives: bool = False, # Add negative queries (absent categories) + max_negatives_per_image: int = 5, # Max negative queries per image + # Training vs inference filtering + filter_empty_boxes: bool = True, # Set False at test time to keep 0-GT-box images + ): + """Initialize collator. + + Args: + max_prompts_per_image: Max number of prompts (categories) per image + use_text_prompts: Whether to include text with geometric prompts + default_text: Default text when class name not available + use_point_prompts: Whether to sample point prompts (for ablation) + num_positive_points: Number of positive points to sample + Can be int or (min, max) tuple for random range + num_negative_points: Number of negative points to sample + Can be int or (min, max) tuple for random range + point_sample_mode: How to sample points when mask is available + - "centered": sample from mask center (farthest from edges) + - "random_mask": uniform sample from mask interior + - "random_box": uniform sample from box, label from mask + use_box_prompts: Whether to use box prompts + box_noise_std: Noise std for box jittering (0 = no noise) + box_noise_max: Max noise in pixels (None = no clamp) + box_noise_tiers: Multi-tier noise as list of (prob, std). + Overrides box_noise_std when set. + text_query_prob: Probability of text-only queries (SAM3 recommended: 0.7) + Only used when use_geometry_prompts=False (legacy 2-mode). + keep_text_for_visual: If True, visual queries keep category text + If False (default), visual queries use "visual" as text. + Only used when use_geometry_prompts=False (legacy 2-mode). + use_geometry_prompts: If True, 5-mode training with 2 queries + per category (Branch 1 o2m + Branch 2 o2o). + geometric_query_str: Text for geometry queries (default "geometric") + visual_query_str: Text for visual queries (default "visual") + text_only_prob: Branch 1 probability of TEXT mode (no box). + Remaining (1-text_only_prob) is box-based (VISUAL or VISUAL+LABEL). + use_label_prob: Probability of +LABEL variant when query has a box. + Controls both Branch 1 (VISUAL vs VISUAL+LABEL) and + Branch 2 (GEOMETRY vs GEOMETRY+LABEL). + +LABEL format: "visual: car" / "geometric: car". + oracle_eval: If True, each GT 2D box becomes its own geometry + prompt (one-to-one). For measuring 3D regression quality + in isolation, following DetAny3D's GT prompt evaluation. + oracle_text_category: If True, oracle mode with category text. + Each GT box = one GEOMETRY+LABEL prompt with text + "geometric: " (e.g., "geometric: apple"). + """ + self.max_prompts_per_image = max_prompts_per_image + self.use_text_prompts = use_text_prompts + self.default_text = default_text + + # Point prompt options + self.use_point_prompts = use_point_prompts + self.num_positive_points = num_positive_points + self.num_negative_points = num_negative_points + self.point_sample_mode = point_sample_mode + + # Box prompt options + self.use_box_prompts = use_box_prompts + self.box_noise_std = box_noise_std + self.box_noise_max = box_noise_max + self.box_noise_tiers = box_noise_tiers + + # Text/Visual query ratio + self.text_query_prob = text_query_prob + self.keep_text_for_visual = keep_text_for_visual + + # Geometry prompt options (5-mode training) + self.use_geometry_prompts = use_geometry_prompts + self.geometric_query_str = geometric_query_str + self.visual_query_str = visual_query_str + self.text_only_prob = text_only_prob + self.use_label_prob = use_label_prob + + # Oracle evaluation mode + self.oracle_eval = oracle_eval + self.oracle_text_category = oracle_text_category + + # Point prompt: box/point budget + self.box_chance = box_chance + self.point_mode_prob = point_mode_prob + + # Negative sampling (SAM3 style presence loss training) + self.include_negatives = include_negatives + self.max_negatives_per_image = max_negatives_per_image + + # Training vs inference filtering + self.filter_empty_boxes = filter_empty_boxes + + def _sample_box_noise_std(self) -> float: + """Sample box noise std from tiers or fallback to self.box_noise_std.""" + if self.box_noise_tiers is not None: + r = random.random() + cumulative = 0.0 + for prob, std in self.box_noise_tiers: + cumulative += prob + if r < cumulative: + return std + return self.box_noise_tiers[-1][1] + return self.box_noise_std + + def _sample_num_points(self, num_spec: int | tuple[int, int]) -> int: + """Sample number of points from spec.""" + if isinstance(num_spec, int): + return num_spec + else: + low, high = num_spec + return np.random.randint(low, high + 1) + + def _sample_points_for_box( + self, + box_xyxy: np.ndarray, + mask: Optional[np.ndarray], + H: int, + W: int, + ) -> np.ndarray: + """Sample points for a single box. + + Args: + box_xyxy: Box in pixel xyxy format + mask: Optional binary mask (H, W) + H, W: Image dimensions + + Returns: + Points array (N, 3) with (x, y, label) in pixel coords + """ + n_pos = self._sample_num_points(self.num_positive_points) + n_neg = self._sample_num_points(self.num_negative_points) + + if mask is not None: + # Sample from actual mask + points = sample_points_from_mask( + mask, n_pos + n_neg, self.point_sample_mode, box_xyxy + ) + else: + # Use box as pseudo-mask + points = sample_points_without_mask(box_xyxy, n_pos, n_neg, H, W) + + return points + + def _sample_geo_budget(self) -> tuple[int, bool]: + """Sample geometric prompt budget (SAM3 style). + + Returns: + (n_points, use_box): number of point prompts and whether to + include a box. Box takes 1 slot from the total budget. + """ + n_total = self._sample_num_points(self.num_positive_points) + if self.box_chance > 0: + use_box = random.random() < self.box_chance + n_points = max(n_total - int(use_box), 0) + else: + use_box = False + n_points = n_total + return n_points, use_box + + def _sample_points_normalized( + self, + box_xyxy_pixel: np.ndarray, + n_points: int, + H: int, + W: int, + mask: Optional[np.ndarray] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Sample n_points from box region, return normalized coords + labels. + + Returns: + pts_xy: (n_points, 2) normalized [0,1] + pts_labels: (n_points,) long, 1=positive 0=negative + """ + if n_points <= 0: + return None, None + if mask is not None: + points = sample_points_from_mask( + mask, n_points, self.point_sample_mode, box_xyxy_pixel + ) + else: + points = sample_points_without_mask( + box_xyxy_pixel, n_points, 0, H, W + ) + pts_xy = torch.tensor( + points[:, :2] / np.array([W, H]), + dtype=torch.float32, + ) + pts_labels = torch.tensor(points[:, 2], dtype=torch.long) + return pts_xy, pts_labels + + def __call__(self, batch: List[dict]) -> WildDet3DInput: + """Collate batch of per-image samples to WildDet3DInput. + + Args: + batch: List of dicts, each containing: + - images: (3, H, W) + - boxes2d: (N_i, 4) pixel xyxy + - boxes2d_classes: (N_i,) class indices + - boxes2d_names: List[str] class names (optional) + - boxes3d: (N_i, 7+) 3D box params + - intrinsics: (3, 3) + - masks2d: (N_i, H, W) binary masks (optional) + + Returns: + WildDet3DInput with per-prompt batch + """ + profile_start(" collator_total") + + # Filter out images with no GT boxes to avoid empty prompts. + # Only applied during training; at test time we keep all images so + # that the evaluator receives predictions for every image (even ones + # with 0 valid 3D GT boxes). _forward_test already handles the + # n_prompts_this_img==0 case by returning empty tensors. + original_batch_size = len(batch) + if self.filter_empty_boxes: + batch = [ + item for item in batch + if item.get("boxes2d") is not None and len(item["boxes2d"]) > 0 + ] + + # if len(batch) < original_batch_size: + # import torch.distributed as dist + # rank = dist.get_rank() if dist.is_initialized() else 0 + # filtered_count = original_batch_size - len(batch) + # print( + # f"[WildDet3DCollator] Filtered {filtered_count}/{original_batch_size} " + # f"empty images on rank {rank}" + # ) + + B = len(batch) + + # Handle completely empty batch (all images filtered out) + if B == 0: + # import torch.distributed as dist + # rank = dist.get_rank() if dist.is_initialized() else 0 + # print( + # f"[WildDet3DCollator] WARNING: Entire batch empty after filtering " + # f"({original_batch_size} images all had 0 GT boxes) on rank {rank}" + # ) + # Return minimal empty batch - model will handle this gracefully + return WildDet3DInput( + images=torch.zeros(0, 3, 1, 1), # (0, 3, H, W) + intrinsics=torch.zeros(0, 3, 3), # (0, 3, 3) + img_ids=torch.zeros(0, dtype=torch.long), + text_ids=torch.zeros(0, dtype=torch.long), + unique_texts=[self.default_text], + sample_names=None, + dataset_name=None, + original_hw=None, + original_images=None, + original_intrinsics=None, + padding=None, + ) + + device = batch[0]["images"].device if batch[0]["images"].is_cuda else "cpu" + + # Collect image-level data + profile_start(" collator_image_stack") + # Images might be (3, H, W) or (1, 3, H, W) depending on data pipeline + images_list = [] + for b in batch: + img = b["images"] + # Handle case where img might have extra batch dim + if img.dim() == 4 and img.shape[0] == 1: + img = img.squeeze(0) # (1, 3, H, W) -> (3, H, W) + images_list.append(img) + images = torch.stack(images_list) # (B, 3, H, W) + intrinsics = torch.stack([b["intrinsics"] for b in batch]) # (B, 3, 3) + H, W = images.shape[-2:] # Use -2: and -1 for H, W to be safe + profile_stop(" collator_image_stack") + + # Collect metadata for evaluation/visualization + sample_names = [] + dataset_name_list = [] + original_hw_list = [] + original_images_list = [] + original_intrinsics_list = [] + padding_list = [] + for b_idx, b in enumerate(batch): + # sample_names - image identifier for evaluation + if "sample_names" in b: + sample_names.append(b["sample_names"]) + elif "image_id" in b: + sample_names.append(b["image_id"]) + else: + sample_names.append(None) + + # dataset_name - for evaluator to route to correct dataset + if "dataset_name" in b: + dataset_name_list.append(b["dataset_name"]) + else: + dataset_name_list.append(None) + + # original_hw - for coordinate scaling back + if "original_hw" in b: + original_hw_list.append(b["original_hw"]) + else: + original_hw_list.append(None) + + # original_images - unresized images for visualization + if "original_images" in b: + original_images_list.append(b["original_images"]) + else: + original_images_list.append(None) + + # original_intrinsics - intrinsics before resize + if "original_intrinsics" in b: + original_intrinsics_list.append(b["original_intrinsics"]) + else: + original_intrinsics_list.append(None) + + # padding - CenterPad offsets [pad_left, pad_right, pad_top, pad_bottom] + if "padding" in b: + padding_list.append(b["padding"]) + else: + padding_list.append(None) + + # Collect depth maps for geometry backend supervision + depth_maps_list = [] + for b in batch: + # depth_maps - K.depth_maps key from dataset + if "depth_maps" in b and b["depth_maps"] is not None: + depth_maps_list.append(b["depth_maps"]) + else: + depth_maps_list.append(None) + + # Stack depth maps if available (all images must have depth) + depth_gt = None + if depth_maps_list and all(d is not None for d in depth_maps_list): + try: + depth_gt = torch.stack(depth_maps_list, dim=0) # (B, H, W) or (B, 1, H, W) + if depth_gt.dim() == 3: + depth_gt = depth_gt.unsqueeze(1) # (B, H, W) -> (B, 1, H, W) + except (RuntimeError, TypeError): + depth_gt = None + + # Convert to proper format (None if all are None) + sample_names = sample_names if any(s is not None for s in sample_names) else None + dataset_name = dataset_name_list if any(d is not None for d in dataset_name_list) else None + original_hw = original_hw_list if any(h is not None for h in original_hw_list) else None + padding = padding_list if any(p is not None for p in padding_list) else None + original_images = None + if any(img is not None for img in original_images_list): + # Convert numpy arrays to tensors, then try stacking. + # Different-sized images (e.g. cross-dataset) cannot be stacked; + # in that case keep as list for the visualizer. + imgs = [] + for img in original_images_list: + if img is None: + continue + if not isinstance(img, torch.Tensor): + img = torch.as_tensor(img) + imgs.append(img) + if len(imgs) == 1: + original_images = imgs[0].unsqueeze(0) if imgs[0].dim() == 3 else imgs[0] + elif len(imgs) > 1: + try: + original_images = torch.stack(imgs) + except RuntimeError: + # Different shapes across batch - keep first only + original_images = imgs[0].unsqueeze(0) if imgs[0].dim() == 3 else imgs[0] + original_intrinsics = None + if any(intr is not None for intr in original_intrinsics_list): + intrs = [] + for intr in original_intrinsics_list: + if intr is None: + continue + if not isinstance(intr, torch.Tensor): + intr = torch.as_tensor(intr) + intrs.append(intr) + try: + original_intrinsics = torch.stack(intrs) + except (RuntimeError, TypeError): + original_intrinsics = None + + # Build per-prompt data (SAM3 original: per-category queries) + # If use_geometry_prompts=True: Each category creates TWO queries + # - TEXT query (one-to-many targets) + # - GEOMETRY query (one-to-one target) + # If use_geometry_prompts=False: Original behavior (text or visual per category) + img_ids_list = [] + text_ids_list = [] + geo_boxes_list = [] # normalized cxcywh (for visual/geometry queries) + geo_points_list = [] # normalized xy (N, 2) or None + geo_point_labels_list = [] # labels (N,) or None + is_visual_query_list = [] # Track which queries have visual prompts + # Query types (collator-level label only, does NOT control SAM3 internal matching): + # 0=TEXT, 1=VISUAL, 2=GEOMETRY, 3=VISUAL+LABEL, 4=GEOMETRY+LABEL + query_types_list = [] + + # Multi-instance targets: list of lists + # gt_boxes2d_per_query[i] = list of normalized xyxy boxes for query i + gt_boxes2d_per_query = [] + gt_boxes3d_per_query = [] + gt_category_ids_list = [] + + # Ignore boxes per query (for negative loss suppression) + ignore_boxes2d_per_query = [] + + # Build unique text list + unique_texts = [] + text_to_id = {} + + # Helper function to normalize box to xyxy [0,1] + def normalize_box_xyxy(box_xyxy_raw): + if isinstance(box_xyxy_raw, torch.Tensor): + gt_box_norm = box_xyxy_raw.clone().float() + else: + gt_box_norm = torch.tensor(box_xyxy_raw, dtype=torch.float32) + gt_box_norm[0::2] /= W + gt_box_norm[1::2] /= H + return gt_box_norm.to(device) + + # Helper function to convert xyxy to cxcywh + def xyxy_to_cxcywh(box_norm_xyxy): + cx = (box_norm_xyxy[0] + box_norm_xyxy[2]) / 2 + cy = (box_norm_xyxy[1] + box_norm_xyxy[3]) / 2 + w_box = box_norm_xyxy[2] - box_norm_xyxy[0] + h_box = box_norm_xyxy[3] - box_norm_xyxy[1] + return torch.tensor([cx, cy, w_box, h_box], device=device) + + profile_start(" collator_category_group") + + if self.oracle_eval: + # ========== Oracle Mode: Each GT box = one geometry prompt ========== + # Following DetAny3D's GT prompt evaluation approach. + # One-to-one mapping: each GT box becomes a separate geometry + # prompt, model predicts 3D for each box independently. + geo_text = self.geometric_query_str + if geo_text not in text_to_id: + text_to_id[geo_text] = len(unique_texts) + unique_texts.append(geo_text) + geo_text_id = text_to_id[geo_text] + + for img_idx, sample in enumerate(batch): + boxes2d = sample.get("boxes2d") + boxes3d = sample.get("boxes3d") + class_ids = sample.get("boxes2d_classes") + + if boxes2d is None or len(boxes2d) == 0: + continue + + # During test, boxes2d are in original pixel space (test + # transforms don't include ResizeBoxes2D / CenterPadBoxes2D). + # Transform to padded pixel space using the SAME math as + # _forward_test's inverse (subtract pad, divide scale), reversed: + # original -> padded: x * scale_x + pad_left + # where scale_x = content_w / orig_w (from _forward_test) + original_hw = sample.get("original_hw", None) + pad_info = sample.get("padding", None) + + if original_hw is not None and pad_info is not None: + orig_h, orig_w = original_hw + if isinstance(orig_h, torch.Tensor): + orig_h, orig_w = orig_h.item(), orig_w.item() + pad_left, pad_right, pad_top, pad_bottom = pad_info + if isinstance(pad_left, torch.Tensor): + pad_left = pad_left.item() + pad_right = pad_right.item() + pad_top = pad_top.item() + pad_bottom = pad_bottom.item() + content_w = W - pad_left - pad_right + content_h = H - pad_top - pad_bottom + scale_x = content_w / orig_w + scale_y = content_h / orig_h + + def transform_box_to_padded(box_raw): + """Transform box: original pixel -> padded pixel.""" + if isinstance(box_raw, torch.Tensor): + box = box_raw.clone().float() + else: + box = torch.tensor(box_raw, dtype=torch.float32) + box[0::2] = box[0::2] * scale_x + pad_left + box[1::2] = box[1::2] * scale_y + pad_top + return box + else: + def transform_box_to_padded(box_raw): + if isinstance(box_raw, torch.Tensor): + return box_raw.clone().float() + return torch.tensor(box_raw, dtype=torch.float32) + + for box_idx in range(len(boxes2d)): + img_ids_list.append(img_idx) + + # Category ID + if class_ids is not None: + cat_id = class_ids[box_idx] + if isinstance(cat_id, torch.Tensor): + cat_id = cat_id.item() + else: + cat_id = 0 + gt_category_ids_list.append(cat_id) + + # Geometry query type + query_types_list.append(2) # GEOMETRY + is_visual_query_list.append(True) + text_ids_list.append(geo_text_id) + + # Transform box to padded pixel space, then normalize + box_padded = transform_box_to_padded(boxes2d[box_idx]) + box_norm_xyxy = normalize_box_xyxy(box_padded) + geo_boxes_list.append(xyxy_to_cxcywh(box_norm_xyxy)) + geo_points_list.append(None) + geo_point_labels_list.append(None) + + # Target = this single box (one-to-one) + gt_boxes2d_per_query.append( + [normalize_box_xyxy(boxes2d[box_idx])] + ) + if boxes3d is not None and box_idx < len(boxes3d): + gt_boxes3d_per_query.append( + [boxes3d[box_idx].to(device)] + ) + else: + gt_boxes3d_per_query.append(None) + # Oracle mode: no ignore box suppression needed + ignore_boxes2d_per_query.append([]) + + elif self.oracle_text_category: + # ========== Oracle + Text Category Mode ========== + # Same as oracle (each GT box = one geometry prompt), but with + # category-specific text: "geometric: " instead of + # generic "geometric". Query type = GEOMETRY+LABEL (4). + for img_idx, sample in enumerate(batch): + boxes2d = sample.get("boxes2d") + boxes3d = sample.get("boxes3d") + class_ids = sample.get("boxes2d_classes") + class_names = sample.get("boxes2d_names", None) + + if boxes2d is None or len(boxes2d) == 0: + continue + + original_hw = sample.get("original_hw", None) + pad_info = sample.get("padding", None) + + if original_hw is not None and pad_info is not None: + orig_h, orig_w = original_hw + if isinstance(orig_h, torch.Tensor): + orig_h, orig_w = orig_h.item(), orig_w.item() + pad_left, pad_right, pad_top, pad_bottom = pad_info + if isinstance(pad_left, torch.Tensor): + pad_left = pad_left.item() + pad_right = pad_right.item() + pad_top = pad_top.item() + pad_bottom = pad_bottom.item() + content_w = W - pad_left - pad_right + content_h = H - pad_top - pad_bottom + scale_x = content_w / orig_w + scale_y = content_h / orig_h + + def transform_box_to_padded(box_raw): + """Transform box: original pixel -> padded pixel.""" + if isinstance(box_raw, torch.Tensor): + box = box_raw.clone().float() + else: + box = torch.tensor(box_raw, dtype=torch.float32) + box[0::2] = box[0::2] * scale_x + pad_left + box[1::2] = box[1::2] * scale_y + pad_top + return box + else: + def transform_box_to_padded(box_raw): + if isinstance(box_raw, torch.Tensor): + return box_raw.clone().float() + return torch.tensor(box_raw, dtype=torch.float32) + + for box_idx in range(len(boxes2d)): + img_ids_list.append(img_idx) + + # Category ID + if class_ids is not None: + cat_id = class_ids[box_idx] + if isinstance(cat_id, torch.Tensor): + cat_id = cat_id.item() + else: + cat_id = 0 + gt_category_ids_list.append(cat_id) + + # Get category name + if class_names is not None and cat_id < len(class_names): + cat_name = class_names[cat_id] + else: + cat_name = self.default_text + + # GEOMETRY+LABEL query: "geometric: " + gl_text = f"{self.geometric_query_str}: {cat_name}" + if gl_text not in text_to_id: + text_to_id[gl_text] = len(unique_texts) + unique_texts.append(gl_text) + query_types_list.append(4) # GEOMETRY+LABEL + is_visual_query_list.append(True) + text_ids_list.append(text_to_id[gl_text]) + + # Transform box to padded pixel space, then normalize + box_padded = transform_box_to_padded(boxes2d[box_idx]) + box_norm_xyxy = normalize_box_xyxy(box_padded) + geo_boxes_list.append(xyxy_to_cxcywh(box_norm_xyxy)) + geo_points_list.append(None) + geo_point_labels_list.append(None) + + # Target = this single box (one-to-one) + gt_boxes2d_per_query.append( + [normalize_box_xyxy(boxes2d[box_idx])] + ) + if boxes3d is not None and box_idx < len(boxes3d): + gt_boxes3d_per_query.append( + [boxes3d[box_idx].to(device)] + ) + else: + gt_boxes3d_per_query.append(None) + ignore_boxes2d_per_query.append([]) + + else: + # ========== Standard Mode: Group by category ========== + for img_idx, sample in enumerate(batch): + boxes2d = sample.get("boxes2d") # (N_i, 4) pixel xyxy + boxes3d = sample.get("boxes3d") # (N_i, 7+) + class_ids = sample.get("boxes2d_classes") # (N_i,) + class_names = sample.get("boxes2d_names", None) # List[str] or None + masks2d = sample.get("masks2d", None) # (N_i, H, W) or None + + if boxes2d is None or len(boxes2d) == 0: + continue + + # ========== SAM3 Original: Group boxes by category ========== + cat_to_box_indices = defaultdict(list) + for box_idx in range(len(boxes2d)): + if class_ids is not None: + cat_id = class_ids[box_idx] + if isinstance(cat_id, torch.Tensor): + cat_id = cat_id.item() + else: + cat_id = 0 + cat_to_box_indices[cat_id].append(box_idx) + + # Group ignore boxes by category (for negative loss suppression) + ignore_boxes2d_raw = sample.get("ignore_boxes2d", None) + ignore_class_ids_raw = sample.get("ignore_class_ids", None) + cat_to_ignore_indices = defaultdict(list) + if ( + ignore_boxes2d_raw is not None + and len(ignore_boxes2d_raw) > 0 + ): + for ign_idx in range(len(ignore_boxes2d_raw)): + ign_cat_id = int(ignore_class_ids_raw[ign_idx]) + cat_to_ignore_indices[ign_cat_id].append(ign_idx) + + # Limit number of categories (queries) per image + categories = list(cat_to_box_indices.keys()) + if len(categories) > self.max_prompts_per_image: + random.shuffle(categories) + categories = categories[:self.max_prompts_per_image] + + # ========== Create queries per category ========== + for cat_id in categories: + box_indices = cat_to_box_indices[cat_id] + + # Get category name for text + if self.use_text_prompts and class_names is not None: + cat_name = class_names[cat_id] if cat_id < len(class_names) else self.default_text + else: + cat_name = self.default_text + + if self.use_geometry_prompts: + # ========== 5-Mode Training ========== + # Creates 2 queries per category: + # + # Branch 1 ("multi-target"): target = ALL instances of this category + # - TEXT: text="car", no box + # - VISUAL: text="visual", geo_box + # - VISUAL+LABEL: text="visual: car", geo_box + # + # Branch 2 ("single-target"): target = 1 selected instance only + # - GEOMETRY: text="geometric", geo_box + # - GEOMETRY+LABEL: text="geometric: car", geo_box + # + # NOTE on "multi-target" vs "single-target": + # This refers to how many GT boxes are assigned as + # targets in this collator (num_gts). This is DIFFERENT + # from SAM3's internal o2o/o2m matching (DAC mechanism). + # SAM3's DAC always runs both Hungarian (o2o) and + # one-to-many (o2m) matchers in the decoder regardless + # of how many GT targets we assign here. + + # Helper: add text to unique_texts and return its id + def _get_text_id(text_str): + if text_str not in text_to_id: + text_to_id[text_str] = len(unique_texts) + unique_texts.append(text_str) + return text_to_id[text_str] + + # Helper: select a random GT box and return its + # normalized cxcywh (with optional noise) + def _make_geo_box(box_indices_inner): + sel_idx = random.choice(box_indices_inner) + bx = boxes2d[sel_idx] + bx_np = bx.cpu().numpy() if isinstance(bx, torch.Tensor) else bx + std = self._sample_box_noise_std() + if std > 0: + bx_np = noise_box( + bx_np, + im_size=(H, W), + box_noise_std=std, + box_noise_max=self.box_noise_max, + ) + norm_xyxy = torch.tensor([ + bx_np[0] / W, bx_np[1] / H, + bx_np[2] / W, bx_np[3] / H, + ], dtype=torch.float32, device=device) + return sel_idx, xyxy_to_cxcywh(norm_xyxy) + + # ----- Branch 1 (multi-target): TEXT / VISUAL / VISUAL+LABEL ----- + img_ids_list.append(img_idx) + gt_category_ids_list.append(cat_id) + + is_text_only = random.random() < self.text_only_prob + if is_text_only: + # TEXT: text="car", no box, no points, all targets + query_types_list.append(0) # TEXT + is_visual_query_list.append(False) + text_ids_list.append(_get_text_id(cat_name)) + geo_boxes_list.append(None) + geo_points_list.append(None) + geo_point_labels_list.append(None) + else: + # Box-based o2m query + has_label = random.random() < self.use_label_prob + if has_label: + # VISUAL+LABEL: text="visual: car", box, all targets + query_types_list.append(3) # VISUAL+LABEL + vl_text = f"{self.visual_query_str}: {cat_name}" + text_ids_list.append(_get_text_id(vl_text)) + else: + # VISUAL: text="visual", box, all targets + query_types_list.append(1) # VISUAL + text_ids_list.append(_get_text_id(self.visual_query_str)) + is_visual_query_list.append(True) + _, geo_cxcywh = _make_geo_box(box_indices) + geo_boxes_list.append(geo_cxcywh) + # Branch 1 visual: no point prompts (box only) + geo_points_list.append(None) + geo_point_labels_list.append(None) + + # Targets: ALL boxes of this category (multi-target) + query_gt_boxes2d = [] + query_gt_boxes3d = [] + for box_idx in box_indices: + query_gt_boxes2d.append(normalize_box_xyxy(boxes2d[box_idx])) + if boxes3d is not None and box_idx < len(boxes3d): + query_gt_boxes3d.append(boxes3d[box_idx].to(device)) + gt_boxes2d_per_query.append(query_gt_boxes2d) + gt_boxes3d_per_query.append(query_gt_boxes3d if query_gt_boxes3d else None) + # Collect ignore boxes for this category + ign_indices = cat_to_ignore_indices.get(cat_id, []) + query_ign = [normalize_box_xyxy(ignore_boxes2d_raw[i]) for i in ign_indices] if ign_indices and ignore_boxes2d_raw is not None else [] + ignore_boxes2d_per_query.append(query_ign) + + # ----- Branch 2 (single-target): GEOMETRY / GEOMETRY+LABEL ----- + img_ids_list.append(img_idx) + gt_category_ids_list.append(cat_id) + + has_label_b2 = random.random() < self.use_label_prob + if has_label_b2: + # GEOMETRY+LABEL: text="geometric: car", 1 target + query_types_list.append(4) # GEOMETRY+LABEL + gl_text = f"{self.geometric_query_str}: {cat_name}" + text_ids_list.append(_get_text_id(gl_text)) + else: + # GEOMETRY: text="geometric", 1 target + query_types_list.append(2) # GEOMETRY + text_ids_list.append(_get_text_id(self.geometric_query_str)) + is_visual_query_list.append(True) + + selected_idx, geo_cxcywh = _make_geo_box(box_indices) + + # Decide geometric prompt mode for Branch 2 + if self.use_point_prompts: + # Exclusive mode: box OR point, never both + masks2d = sample.get( + "masks2d", None + ) + has_mask = ( + masks2d is not None + and selected_idx < len(masks2d) + and masks2d[selected_idx].sum() > 0 + ) + use_pt = ( + has_mask + and random.random() + < self.point_mode_prob + ) + if use_pt: + # Point-only (no box) + sel_mask = masks2d[selected_idx] + if isinstance( + sel_mask, torch.Tensor + ): + sel_mask = ( + sel_mask.cpu().numpy() + ) + sel_box = boxes2d[selected_idx] + sel_box_np = ( + sel_box.cpu().numpy() + if isinstance( + sel_box, torch.Tensor + ) + else np.array(sel_box) + ) + n_pts = self._sample_num_points( + self.num_positive_points + ) + if n_pts == 1: + # Single point: always positive + # from mask center (farthest + # from edges) + points = sample_points_from_mask( + sel_mask, + 1, + "centered", + ) + else: + # Multi-point: random_box mode, + # mask determines pos/neg labels + points = sample_points_from_mask( + sel_mask, + n_pts, + "random_box", + sel_box_np, + ) + pts_xy = torch.tensor( + points[:, :2] + / np.array([W, H]), + dtype=torch.float32, + ) + pts_labels = torch.tensor( + points[:, 2], + dtype=torch.long, + ) + geo_boxes_list.append(None) + geo_points_list.append(pts_xy) + geo_point_labels_list.append( + pts_labels + ) + else: + # Box-only (no points) + geo_boxes_list.append(geo_cxcywh) + geo_points_list.append(None) + geo_point_labels_list.append(None) + else: + geo_boxes_list.append(geo_cxcywh) + geo_points_list.append(None) + geo_point_labels_list.append(None) + + # Target: ONLY the selected box (single-target) + query_gt_boxes2d = [normalize_box_xyxy(boxes2d[selected_idx])] + query_gt_boxes3d = [] + if boxes3d is not None and selected_idx < len(boxes3d): + query_gt_boxes3d.append(boxes3d[selected_idx].to(device)) + gt_boxes2d_per_query.append(query_gt_boxes2d) + gt_boxes3d_per_query.append(query_gt_boxes3d if query_gt_boxes3d else None) + # Same ignore boxes as Branch 1 (same category) + ign_indices = cat_to_ignore_indices.get(cat_id, []) + query_ign = [normalize_box_xyxy(ignore_boxes2d_raw[i]) for i in ign_indices] if ign_indices and ignore_boxes2d_raw is not None else [] + ignore_boxes2d_per_query.append(query_ign) + + else: + # ========== Original: Text/Visual random selection ========== + img_ids_list.append(img_idx) + gt_category_ids_list.append(cat_id) + + # Decide query type: text-only or visual + is_text_query = random.random() < self.text_query_prob + is_visual_query = not is_text_query + + # Track query type (0=TEXT for both text and visual in original mode) + query_types_list.append(0 if is_text_query else 1) # 1=VISUAL + is_visual_query_list.append(is_visual_query) + + # Determine text for this query + if is_visual_query and not self.keep_text_for_visual: + text = "visual" + else: + text = cat_name + + if text not in text_to_id: + text_to_id[text] = len(unique_texts) + unique_texts.append(text) + text_ids_list.append(text_to_id[text]) + + # Visual query: pick one target as geo_box + if is_visual_query and self.use_box_prompts: + selected_idx = random.choice(box_indices) + box_xyxy = boxes2d[selected_idx] + box_xyxy_np = box_xyxy.cpu().numpy() if isinstance(box_xyxy, torch.Tensor) else box_xyxy + + std = self._sample_box_noise_std() + if std > 0: + box_xyxy_np = noise_box( + box_xyxy_np, + im_size=(H, W), + box_noise_std=std, + box_noise_max=self.box_noise_max, + ) + + box_norm_xyxy = torch.tensor([ + box_xyxy_np[0] / W, + box_xyxy_np[1] / H, + box_xyxy_np[2] / W, + box_xyxy_np[3] / H, + ], dtype=torch.float32, device=device) + geo_boxes_list.append(xyxy_to_cxcywh(box_norm_xyxy)) + else: + geo_boxes_list.append(None) + # Legacy mode: no point prompts + geo_points_list.append(None) + geo_point_labels_list.append(None) + + # Multi-instance targets: ALL boxes of this category + query_gt_boxes2d = [] + query_gt_boxes3d = [] + for box_idx in box_indices: + query_gt_boxes2d.append(normalize_box_xyxy(boxes2d[box_idx])) + if boxes3d is not None and box_idx < len(boxes3d): + query_gt_boxes3d.append(boxes3d[box_idx].to(device)) + gt_boxes2d_per_query.append(query_gt_boxes2d) + gt_boxes3d_per_query.append(query_gt_boxes3d if query_gt_boxes3d else None) + # Collect ignore boxes for this category + ign_indices = cat_to_ignore_indices.get(cat_id, []) + query_ign = [normalize_box_xyxy(ignore_boxes2d_raw[i]) for i in ign_indices] if ign_indices and ignore_boxes2d_raw is not None else [] + ignore_boxes2d_per_query.append(query_ign) + + # ========== Negative sampling (SAM3 style) ========== + # Add TEXT queries for absent categories (num_gts=0). + # These train the presence head to predict "not present". + # SAM3 does this via COCO_FROM_JSON include_negatives=True. + if ( + self.include_negatives + and class_names is not None + and 0 < len(class_names) <= 100 + ): + present_cats = set(cat_to_box_indices.keys()) + all_cats = set(range(len(class_names))) + absent_cats = list(all_cats - present_cats) + + if len(absent_cats) > self.max_negatives_per_image: + absent_cats = random.sample( + absent_cats, self.max_negatives_per_image + ) + + for neg_cat_id in absent_cats: + neg_cat_name = class_names[neg_cat_id] + img_ids_list.append(img_idx) + gt_category_ids_list.append(neg_cat_id) + query_types_list.append(0) # TEXT (exhaustive) + is_visual_query_list.append(False) + if neg_cat_name not in text_to_id: + text_to_id[neg_cat_name] = len(unique_texts) + unique_texts.append(neg_cat_name) + text_ids_list.append(text_to_id[neg_cat_name]) + geo_boxes_list.append(None) + geo_points_list.append(None) + geo_point_labels_list.append(None) + gt_boxes2d_per_query.append([]) + gt_boxes3d_per_query.append(None) + ignore_boxes2d_per_query.append([]) + + profile_stop(" collator_category_group") + + N_prompts = len(img_ids_list) + + if N_prompts == 0: + import torch.distributed as dist + rank = dist.get_rank() if dist.is_initialized() else 0 + print( + f"[WildDet3DCollator] WARNING: Unexpected N_prompts=0 " + f"(B={B} images passed filter) on rank {rank}" + ) + return WildDet3DInput( + images=images, + intrinsics=intrinsics, + img_ids=torch.zeros(0, dtype=torch.long, device=device), + text_ids=torch.zeros(0, dtype=torch.long, device=device), + unique_texts=[self.default_text], + sample_names=sample_names, + dataset_name=dataset_name, + original_hw=original_hw, + original_images=original_images, + original_intrinsics=original_intrinsics, + padding=padding, + ) + + # Stack tensors + profile_start(" collator_tensor_stack") + img_ids = torch.tensor(img_ids_list, dtype=torch.long, device=device) + text_ids = torch.tensor(text_ids_list, dtype=torch.long, device=device) + + # ========== Box prompts for visual queries ========== + # geo_boxes: (N_prompts, 1, 4) - None for text-only queries + geo_boxes = None + geo_boxes_mask = None + geo_box_labels = None + + # Check if any visual queries exist + has_visual = any(g is not None for g in geo_boxes_list) + if has_visual: + # Stack geo_boxes, use zeros for text-only queries + stacked_geo_boxes = [] + for g in geo_boxes_list: + if g is not None: + stacked_geo_boxes.append(g) + else: + stacked_geo_boxes.append(torch.zeros(4, device=device)) + geo_boxes = torch.stack(stacked_geo_boxes).unsqueeze(1) # (N, 1, 4) + + # Mask: True = padding (i.e., text-only queries have no valid box) + geo_boxes_mask = torch.tensor( + [[g is None] for g in geo_boxes_list], + dtype=torch.bool, device=device + ) # (N, 1) + + # Labels: 1 for positive (valid) boxes + geo_box_labels = torch.tensor( + [[1 if g is not None else 0] for g in geo_boxes_list], + dtype=torch.long, device=device + ) # (N, 1) + + # ========== Point prompts: pad to (N_prompts, max_P, 2) ========== + geo_points = None + geo_points_mask = None + geo_point_labels = None + has_points = any(p is not None for p in geo_points_list) + if has_points: + max_P = max( + len(p) for p in geo_points_list if p is not None + ) + if max_P > 0: + pts_padded = [] + pts_mask_list = [] + pts_labels_padded = [] + for pts, lbls in zip( + geo_points_list, geo_point_labels_list + ): + if pts is None or len(pts) == 0: + pts_padded.append( + torch.zeros(max_P, 2, device=device) + ) + pts_mask_list.append( + torch.ones(max_P, dtype=torch.bool, device=device) + ) + pts_labels_padded.append( + torch.zeros(max_P, dtype=torch.long, device=device) + ) + else: + n = len(pts) + pad_n = max_P - n + pts_padded.append(torch.cat([ + pts.to(device), + torch.zeros(pad_n, 2, device=device), + ])) + pts_mask_list.append(torch.cat([ + torch.zeros(n, dtype=torch.bool, device=device), + torch.ones(pad_n, dtype=torch.bool, device=device), + ])) + pts_labels_padded.append(torch.cat([ + lbls.to(device), + torch.zeros(pad_n, dtype=torch.long, device=device), + ])) + geo_points = torch.stack(pts_padded) # (N, max_P, 2) + geo_points_mask = torch.stack(pts_mask_list) # (N, max_P) + geo_point_labels = torch.stack(pts_labels_padded) # (N, max_P) + + # ========== Multi-instance GT boxes: pad to (N_prompts, max_gt, 4) ========== + # Find max number of targets per query (at least 1 for tensor shape) + max_gt = max( + (len(q) for q in gt_boxes2d_per_query), default=1 + ) + max_gt = max(max_gt, 1) # Ensure at least 1 for padded tensor shape + num_gts_list = [] + + gt_boxes2d_padded = [] + for query_boxes in gt_boxes2d_per_query: + n_gt = len(query_boxes) + num_gts_list.append(n_gt) + + if n_gt == 0: + # Negative query: all-zero padding, num_gts=0 + padded = [torch.zeros(4, device=device)] * max_gt + elif n_gt < max_gt: + # Pad with zeros + padded = query_boxes + [torch.zeros(4, device=device)] * (max_gt - n_gt) + else: + padded = query_boxes + gt_boxes2d_padded.append(torch.stack(padded)) + + gt_boxes2d = torch.stack(gt_boxes2d_padded) # (N_prompts, max_gt, 4) + num_gts = torch.tensor(num_gts_list, dtype=torch.long, device=device) # (N_prompts,) + + # 3D boxes (if available) + gt_boxes3d = None + if any(q is not None for q in gt_boxes3d_per_query): + # Get 3D box dimension from first valid entry + box3d_dim = None + for q in gt_boxes3d_per_query: + if q is not None and len(q) > 0: + box3d_dim = q[0].shape[-1] + break + + if box3d_dim is not None: + gt_boxes3d_padded = [] + for query_boxes in gt_boxes3d_per_query: + if query_boxes is None or len(query_boxes) == 0: + # No 3D boxes for this query + padded = [torch.zeros(box3d_dim, device=device)] * max_gt + else: + n_gt = len(query_boxes) + if n_gt < max_gt: + padded = query_boxes + [torch.zeros(box3d_dim, device=device)] * (max_gt - n_gt) + else: + padded = query_boxes + gt_boxes3d_padded.append(torch.stack(padded)) + gt_boxes3d = torch.stack(gt_boxes3d_padded) # (N_prompts, max_gt, box3d_dim) + + gt_category_ids = torch.tensor(gt_category_ids_list, dtype=torch.long, device=device) + + # ========== Ignore boxes: pad to (N_prompts, max_ignore, 4) ========== + max_ignore = max( + (len(q) for q in ignore_boxes2d_per_query), default=0 + ) + if max_ignore > 0: + num_ignores_list = [] + ignore_padded = [] + for q in ignore_boxes2d_per_query: + n_ign = len(q) + num_ignores_list.append(n_ign) + if n_ign < max_ignore: + padded = q + [ + torch.zeros(4, device=device) + ] * (max_ignore - n_ign) + else: + padded = q + ignore_padded.append(torch.stack(padded)) + ignore_boxes2d_tensor = torch.stack(ignore_padded) + num_ignores_tensor = torch.tensor( + num_ignores_list, dtype=torch.long, device=device + ) + else: + ignore_boxes2d_tensor = None + num_ignores_tensor = None + + # Query types: 0=TEXT, 1=VISUAL, 2=GEOMETRY + query_types = torch.tensor(query_types_list, dtype=torch.long, device=device) + profile_stop(" collator_tensor_stack") + profile_stop(" collator_total") + + return WildDet3DInput( + images=images, + intrinsics=intrinsics, + img_ids=img_ids, + text_ids=text_ids, + unique_texts=unique_texts, + geo_boxes=geo_boxes, + geo_boxes_mask=geo_boxes_mask, + geo_box_labels=geo_box_labels, + geo_points=geo_points, + geo_points_mask=geo_points_mask, + geo_point_labels=geo_point_labels, + gt_boxes2d=gt_boxes2d, + gt_boxes3d=gt_boxes3d, + num_gts=num_gts, + gt_category_ids=gt_category_ids, + ignore_boxes2d=ignore_boxes2d_tensor, + num_ignores=num_ignores_tensor, + query_types=query_types, + # Metadata for evaluation/visualization + sample_names=sample_names, + dataset_name=dataset_name, + original_hw=original_hw, + original_images=original_images, + original_intrinsics=original_intrinsics, + padding=padding, + # Depth ground truth for geometry backend supervision + depth_gt=depth_gt, + depth_mask=None, # Not yet implemented + ) + + +# ============================================================================ +# WildDet3D Specific Connectors +# ============================================================================ + +# Training connector for WildDet3D +# Note: SAM3 uses geometric prompts (boxes/points) instead of text +CONN_WILDDET3D_TRAIN = { + "images": K.images, + "input_hw": K.input_hw, + # Geometric prompts (boxes as prompts) + "prompt_boxes": K.boxes2d, # Use GT boxes as prompts during training + "prompt_box_labels": K.boxes2d_classes, + # Targets + "boxes2d": K.boxes2d, + "boxes2d_classes": K.boxes2d_classes, + "boxes3d": K.boxes3d, + # Camera + "intrinsics": K.intrinsics, + # Depth for geometry backend + "depth_gt": K.depth_maps, +} + +# Test connector for WildDet3D +CONN_WILDDET3D_TEST = { + "images": K.images, + "input_hw": K.input_hw, + "original_hw": K.original_hw, + # Geometric prompts (from external detector or user input) + "prompt_boxes": K.boxes2d, # External 2D detections as prompts + # Camera + "intrinsics": K.intrinsics, + "padding": "padding", +} + +# Loss connector for WildDet3D +CONN_WILDDET3D_LOSS = { + # Model outputs + "pred_logits": pred_key("pred_logits"), + "pred_boxes_2d": pred_key("pred_boxes_2d"), + "pred_boxes_3d": pred_key("pred_boxes_3d"), + "aux_outputs": pred_key("aux_outputs"), + "geom_losses": pred_key("geom_losses"), + # Matching indices (computed by model) + "indices": pred_key("indices"), + # Targets + "targets": { + "boxes": data_key(K.boxes2d), + "boxes_xyxy": data_key(K.boxes2d), # Will be converted + "boxes_3d": data_key(K.boxes3d), + "num_boxes": data_key("num_boxes"), + "image_size": data_key(K.input_hw), # (H, W) for pixel coordinate conversion + }, + # Camera + "intrinsics": data_key(K.intrinsics), + # Image size for pixel coordinate conversion (following GDino3D) + "image_size": data_key(K.input_hw), +} + +# Evaluation connector +CONN_WILDDET3D_EVAL = { + "coco_image_id": data_key(K.sample_names), + "pred_boxes": pred_key("boxes"), + "pred_scores": pred_key("scores"), + "pred_classes": pred_key("class_ids"), + "pred_boxes3d": pred_key("boxes3d"), +} + +# Visualization connector +CONN_WILDDET3D_VIS = { + "images": data_key(K.original_images), + "image_names": data_key(K.sample_names), + "intrinsics": data_key("original_intrinsics"), + "boxes3d": pred_key("boxes3d"), + "class_ids": pred_key("class_ids"), + "scores": pred_key("scores"), +} + + +class WildDet3DPassthroughConnector: + """Data connector that passes WildDet3DInput directly to model. + + Since WildDet3DCollator already produces WildDet3DInput with all needed + data, we just pass it through as the 'batch' parameter to model.forward(). + + This bypasses the key_mapping approach used by vis4d's DataConnector, + which expects raw DataLoader output format. + """ + + def __call__(self, data: WildDet3DInput) -> dict: + """Pass batch directly to model. + + Args: + data: WildDet3DInput from collator + + Returns: + Dict with 'batch' key pointing to the input data + """ + return {"batch": data} + + +class WildDet3DLossConnector: + """Loss connector that passes model output and batch directly to loss. + + Similar to WildDet3DPassthroughConnector, this bypasses vis4d's key_mapping + since WildDet3DLoss expects structured objects (WildDet3DOut, WildDet3DInput). + + This connector is used with LossModule to enable proper wandb logging of + individual loss components (loss_cls, loss_bbox, loss_giou, etc.). + """ + + def __call__(self, predictions, batch: WildDet3DInput) -> dict: + """Map model output and batch to loss function inputs. + + Args: + predictions: WildDet3DOut from model.forward() + batch: WildDet3DInput from collator + + Returns: + Dict with 'out' and 'batch' keys for WildDet3DLoss.forward() + """ + return { + "out": predictions, + "batch": batch, + } + + +class WildDet3DVisConnector: + """Vis connector that extracts from WildDet3DInput for visualization. + + vis4d's CallbackConnector uses dict access (data[key]) which doesn't + work with WildDet3DInput dataclass. This connector does the + extraction manually. + + Args: + score_threshold: Only visualize boxes with score >= this value. + Separate from model's score_threshold so evaluation AP is unaffected. + """ + + def __init__(self, score_threshold: float = 0.0): + self.score_threshold = score_threshold + + def __call__(self, prediction, data: WildDet3DInput) -> dict: + """Extract visualization data from dataclass + prediction. + + Args: + prediction: Det3DOut NamedTuple from model. + data: WildDet3DInput from collator. + + Returns: + Dict with keys expected by BoundingBox3DVisualizer. + """ + # When the collator filters out images with no GT boxes (empty batch), + # original_images is None. Return empty tensor so the visualizer's + # for-loop iterates 0 times instead of crashing. + images = data.original_images + if images is None: + images = torch.zeros(0, 3, 1, 1) + + boxes3d = prediction.boxes3d + class_ids = prediction.class_ids + scores = prediction.scores + + # Filter by score threshold per image for cleaner visualization + if self.score_threshold > 0.0 and scores is not None: + filtered_boxes3d = [] + filtered_class_ids = [] + filtered_scores = [] + for i in range(len(scores)): + mask = scores[i] >= self.score_threshold + filtered_scores.append(scores[i][mask]) + filtered_class_ids.append(class_ids[i][mask]) + filtered_boxes3d.append(boxes3d[i][mask]) + boxes3d = filtered_boxes3d + class_ids = filtered_class_ids + scores = filtered_scores + + # Cast to float32 for numpy compatibility (bf16 not supported) + if scores is not None: + scores = [s.float() for s in scores] + if boxes3d is not None: + boxes3d = [b.float() for b in boxes3d] + + intrinsics = data.original_intrinsics + if intrinsics is not None: + intrinsics = intrinsics.float() + + return { + "images": images, + "image_names": data.sample_names, + "intrinsics": intrinsics, + "boxes3d": boxes3d, + "class_ids": class_ids, + "scores": scores, + } + + +class WildDet3DEvalConnector: + """Eval connector that extracts from WildDet3DInput for evaluator. + + Same issue as WildDet3DVisConnector: CallbackConnector doesn't work with + dataclass. This connector manually extracts fields. + """ + + def __call__(self, prediction, data: WildDet3DInput) -> dict: + """Extract evaluation data from dataclass + prediction. + + Args: + prediction: Det3DOut NamedTuple from model. + data: WildDet3DInput from collator. + + Returns: + Dict with keys expected by Omni3DEvaluator. + """ + return { + "coco_image_id": data.sample_names, + "dataset_names": data.dataset_name, + "pred_boxes": prediction.boxes, + "pred_scores": prediction.scores, + "pred_classes": prediction.class_ids, + "pred_boxes3d": prediction.boxes3d, + } + + +class WildDet3DDetect3DEvalConnector: + """Eval connector for Detect3DEvaluator with WildDet3DInput. + + Unlike WildDet3DEvalConnector, this connector does not include dataset_names + since Detect3DEvaluator.process_batch does not accept that argument. + """ + + def __call__(self, prediction, data: WildDet3DInput) -> dict: + """Extract evaluation data from dataclass + prediction. + + Args: + prediction: Det3DOut NamedTuple from model. + data: WildDet3DInput from collator. + + Returns: + Dict with keys expected by Detect3DEvaluator.process_batch. + """ + return { + "coco_image_id": data.sample_names, + "pred_boxes": prediction.boxes, + "pred_scores": prediction.scores, + "pred_classes": prediction.class_ids, + "pred_boxes3d": prediction.boxes3d, + } + + +def get_wilddet3d_data_connector_cfg() -> tuple[ConfigDict, ConfigDict]: + """Get WildDet3D data connector configuration. + + Returns: + Tuple of (train_connector, test_connector). + + Note: + Uses WildDet3DPassthroughConnector which passes the collated batch + directly to model.forward(batch=...), rather than mapping individual + keys like standard vis4d DataConnector. + """ + train_data_connector = class_config(WildDet3DPassthroughConnector) + test_data_connector = class_config(WildDet3DPassthroughConnector) + + return train_data_connector, test_data_connector + + +def get_wilddet3d_collator_cfg( + max_prompts_per_image: int = 50, + use_text_prompts: bool = True, + # Point prompt options (for ablation) + use_point_prompts: bool = False, + num_positive_points: int | tuple[int, int] = 1, + num_negative_points: int | tuple[int, int] = 0, + point_sample_mode: Literal["centered", "random_mask", "random_box"] = "random_mask", + # Box prompt options + use_box_prompts: bool = True, + box_noise_std: float = 0.0, + box_noise_max: float | None = 20.0, + # Text/Visual query ratio (SAM3 original design) + text_query_prob: float = 0.7, + keep_text_for_visual: bool = False, +) -> ConfigDict: + """Get WildDet3D collator configuration. + + The collator converts per-image DataLoader output to WildDet3DInput. + Following SAM3 original design: per-category queries with multi-instance targets. + + Args: + max_prompts_per_image: Max prompts (categories) per image + use_text_prompts: Whether to include text with geometric prompts + use_point_prompts: Whether to sample point prompts (for ablation) + num_positive_points: Number of positive points to sample + Can be int or (min, max) tuple for random range + num_negative_points: Number of negative points to sample + Can be int or (min, max) tuple for random range + point_sample_mode: How to sample points when mask is available + - "centered": sample from mask center (farthest from edges) + - "random_mask": uniform sample from mask interior + - "random_box": uniform sample from box, label from mask + use_box_prompts: Whether to use box prompts + box_noise_std: Noise std for box jittering (0 = no noise) + box_noise_max: Max noise in pixels + text_query_prob: Probability of text-only queries (SAM3 recommended: 0.7) + 1.0 = all text queries (pure text training) + 0.7 = 70% text, 30% visual (SAM3 mixed training) + 0.0 = all visual queries (DetAny3D style) + keep_text_for_visual: If True, visual queries keep category text + If False (default), visual queries use "visual" as text + + Returns: + Collator configuration + """ + return class_config( + WildDet3DCollator, + max_prompts_per_image=max_prompts_per_image, + use_text_prompts=use_text_prompts, + use_point_prompts=use_point_prompts, + num_positive_points=num_positive_points, + num_negative_points=num_negative_points, + point_sample_mode=point_sample_mode, + use_box_prompts=use_box_prompts, + box_noise_std=box_noise_std, + box_noise_max=box_noise_max, + text_query_prob=text_query_prob, + keep_text_for_visual=keep_text_for_visual, + ) diff --git a/wilddet3d/data/__init__.py b/wilddet3d/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68ffbd63f41daecbc90a48541f408e03349431e3 --- /dev/null +++ b/wilddet3d/data/__init__.py @@ -0,0 +1 @@ +"""Data utilities.""" diff --git a/wilddet3d/data/__pycache__/__init__.cpython-311.pyc b/wilddet3d/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e86e856a2a2bce6ef5568d3965d30f143d540678 Binary files /dev/null and b/wilddet3d/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/wilddet3d/data/datasets/__init__.py b/wilddet3d/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wilddet3d/data/datasets/argoverse.py b/wilddet3d/data/datasets/argoverse.py new file mode 100644 index 0000000000000000000000000000000000000000..5032fc27faed0de4d2b531e6c9826e106c99400b --- /dev/null +++ b/wilddet3d/data/datasets/argoverse.py @@ -0,0 +1,94 @@ +"""Argoverse V2 Sensor dataset.""" + +from __future__ import annotations + +from vis4d.common.typing import ArgsType, DictStrAny + +from .coco3d import COCO3DDataset + +TRAIN_SAMPLE_RATE = 10 +VAL_SAMPLE_RATE = 5 +ACC_FRAMES = 5 + + +av2_class_map = { + "regular vehicle": 0, + "pedestrian": 1, + "bicyclist": 2, + "motorcyclist": 3, + "wheeled rider": 4, + "bollard": 5, + "construction cone": 6, + "sign": 7, + "construction barrel": 8, + "stop sign": 9, + "mobile pedestrian crossing sign": 10, + "large vehicle": 11, + "bus": 12, + "box truck": 13, + "truck": 14, + "vehicular trailer": 15, + "truck cab": 16, + "school bus": 17, + "articulated bus": 18, + "message board trailer": 19, + "bicycle": 20, + "motorcycle": 21, + "wheeled device": 22, + "wheelchair": 23, + "stroller": 24, + "dog": 25, +} + +av2_det_map = { + "regular vehicle": 0, + "pedestrian": 1, + "bicyclist": 2, + "motorcyclist": 3, + "wheeled rider": 4, + "bollard": 5, + "construction cone": 6, + "sign": 7, + "construction barrel": 8, + "stop sign": 9, + "mobile pedestrian crossing sign": 10, + "large vehicle": 11, + "bus": 12, + "box truck": 13, + "truck": 14, + "vehicular trailer": 15, + "truck cab": 16, + "school bus": 17, + "articulated bus": 18, + "bicycle": 19, + "motorcycle": 20, + "wheeled device": 21, + "stroller": 22, +} + + +class AV2SensorDataset(COCO3DDataset): + """Argoverse V2 Sensor dataset.""" + + def __init__( + self, + class_map: dict[str, int] = av2_class_map, + max_depth: float = 80.0, + depth_scale: float = 256.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames.""" + return ( + img["file_path"] + .replace("images", "depth") + .replace(".jpg", "_depth.png") + ) diff --git a/wilddet3d/data/datasets/coco3d.py b/wilddet3d/data/datasets/coco3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0613fca6abf0443fec6b38fd9e2c43dcc56cb9 --- /dev/null +++ b/wilddet3d/data/datasets/coco3d.py @@ -0,0 +1,574 @@ +"""COCO 3D API.""" + +from __future__ import annotations + +import contextlib +import io +import json +import os +import time +from collections import defaultdict +from collections.abc import Sequence + +import numpy as np +from pycocotools.coco import COCO +from pyquaternion import Quaternion +from scipy.spatial.transform import Rotation as R +from vis4d.common.logging import rank_zero_info, rank_zero_warn +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.const import AxisMode +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.base import Dataset +from vis4d.data.datasets.util import ( + CacheMappingMixin, + im_decode, + print_class_histogram, +) +from vis4d.data.typing import DictData + + +class COCO3DDataset(CacheMappingMixin, Dataset): + """3D Object Detection Dataset using coco annotation files.""" + + def __init__( + self, + data_root: str, + dataset_name: str, + class_map: dict[str, int], + det_map: dict[str, int], + keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d), + with_depth: bool = False, + max_depth: float = 80.0, + depth_scale: float = 256.0, + remove_empty: bool = False, + data_prefix: str | None = None, + text_prompt_mapping: dict[str, dict[str, str]] | None = None, + cache_as_binary: bool = False, + cached_file_path: str | None = None, + # Omni3DAPI filtering thresholds (passed to COCO3D) + truncation_thres: float = 0.33333333, + visibility_thres: float = 0.33333333, + min_height_thres: float = 0.0625, + max_height_thres: float = 1.50, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__(**kwargs) + self.data_root = data_root + self.dataset_name = dataset_name + self.annotation_file = f"{dataset_name}.json" + + self.keys_to_load = list(keys_to_load) + self.remove_empty = remove_empty + + self.class_map = class_map # Class mapping in the annotation file + self.det_map = det_map # Class mapping for detection + self.categories = sorted(self.det_map, key=self.det_map.get) + + self.data_prefix = data_prefix + self.text_prompt_mapping = text_prompt_mapping + + # Omni3DAPI filtering thresholds + self.truncation_thres = truncation_thres + self.visibility_thres = visibility_thres + self.min_height_thres = min_height_thres + self.max_height_thres = max_height_thres + + # Metric Depth + if with_depth and not K.depth_maps in keys_to_load: + self.keys_to_load.append(K.depth_maps) + + self.max_depth = max_depth + self.depth_scale = depth_scale + + # Load annotations + self.samples, _ = self._load_mapping( + self._generate_data_mapping, + self._filter_data, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return self.dataset_name + + def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]: + """Remove empty samples.""" + samples = [] + + frequencies = {cat: 0 for cat in sorted(self.det_map)} + + empty_samples = 0 + no_depth_samples = 0 + for sample in data: + if self.remove_empty and len(sample["anns"]) == 0: + empty_samples += 1 + continue + + if ( + K.depth_maps in self.keys_to_load + and "depth_filename" not in sample + ): + empty_samples += 1 + no_depth_samples += 1 + continue + + for ann in sample["anns"]: + frequencies[ann["category_name"]] += 1 + + samples.append(sample) + + rank_zero_info( + f"Propocessing {self.dataset_name} with {len(samples)} samples." + ) + rank_zero_info(f"No depth samples: {no_depth_samples}") + rank_zero_info(f"Filtered {empty_samples} empty samples") + print_class_histogram(frequencies) + + return samples + + def _get_cat_id( + self, img: DictStrAny, ann: DictStrAny, cat_name: str + ) -> None: + """Get the category id from the category name.""" + ann["category_id"] = self.det_map[cat_name] + + def _generate_data_mapping(self) -> list[DictStrAny]: + """Generates the data mapping.""" + # Load annotations + with contextlib.redirect_stdout(io.StringIO()): + coco_api = COCO3D( + os.path.join( + self.data_root, "annotations", self.annotation_file + ), + self.categories, + truncation_thres=self.truncation_thres, + visibility_thres=self.visibility_thres, + min_height_thres=self.min_height_thres, + max_height_thres=self.max_height_thres, + ) + + cats_map = {v: k for k, v in self.class_map.items()} + + img_ids = sorted(coco_api.getImgIds()) + imgs = coco_api.loadImgs(img_ids) + + samples = [] + for img_id, img in zip(img_ids, imgs): + # Fix file path for Omni3D + if self.data_prefix is not None: + img["file_path"] = os.path.join( + self.data_prefix, img["file_path"] + ) + + valid_anns = [] + anns = coco_api.imgToAnns[img_id] + + boxes = [] + boxes3d = np.empty((0, 10), dtype=np.float32)[1:] + class_ids = np.empty((0,), dtype=np.int64)[1:] + ignore_boxes = [] + ignore_class_ids_list = [] + for ann in anns: + cat_name = cats_map[ann["category_id"]] + assert cat_name == ann["category_name"] + + if cat_name in {"dontcare", "ignore", "void"}: + continue + + if ann["ignore"]: + # Preserve ignore box 2D coords and class ID + # for negative loss suppression during training. + # Only keep objects that are actually visible in + # the image — skip behind_camera and degenerate bbox. + if ( + cat_name in self.det_map + and not ann.get("behind_camera", False) + ): + x1, y1, w, h = ann["bbox"] + if w > 0 and h > 0: + ignore_boxes.append( + (x1, y1, x1 + w, y1 + h) + ) + ignore_class_ids_list.append( + self.det_map[cat_name] + ) + continue + + # Box 3D + center = ann["center_cam"] + width, height, length = ann["dimensions"] + + # Check if the rotation matrix is valid + try: + x, y, z, w = R.from_matrix( + np.array(ann["R_cam"]) + ).as_quat() + except Exception as e: + rank_zero_warn( + f"Error processing rotation matrix for annotation {ann['id']}: {e}" + ) + continue + + orientation = Quaternion([w, x, y, z]) + + boxes3d = np.concatenate( + [ + boxes3d, + np.array( + [ + [ + *center, + width, + length, + height, + *orientation.elements, + ] + ], + dtype=np.float32, + ), + ] + ) + + # Box 2D + x1, y1, width, height = ann["bbox"] + x2, y2 = x1 + width, y1 + height + boxes.append((x1, y1, x2, y2)) + + # Class + self._get_cat_id(img, ann, cat_name) + + class_ids = np.concatenate( + [ + class_ids, + np.array([ann["category_id"]], dtype=np.int64), + ] + ) + + valid_anns.append(ann) + + boxes2d = ( + np.empty((0, 4), dtype=np.float32) + if not boxes + else np.array(boxes, dtype=np.float32) + ) + + depth_filename = self.get_depth_filenames(img) + + ignore_boxes2d = ( + np.empty((0, 4), dtype=np.float32) + if not ignore_boxes + else np.array(ignore_boxes, dtype=np.float32) + ) + ignore_class_ids = np.array( + ignore_class_ids_list, dtype=np.int64 + ) + + sample = { + "img_id": img_id, + "img": img, + "anns": valid_anns, + "boxes2d": boxes2d, + "boxes3d": boxes3d, + "class_ids": class_ids, + "ignore_boxes2d": ignore_boxes2d, + "ignore_class_ids": ignore_class_ids, + } + + if depth_filename is not None and ( + self.data_backend.exists(depth_filename) + or os.path.exists(depth_filename) + ): + sample["depth_filename"] = depth_filename + + samples.append(sample) + + return samples + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + Since not every data has depth. + """ + return None + + def get_cat_ids(self, idx: int) -> list[int]: + """Return the samples.""" + return self.samples[idx]["class_ids"].tolist() + + def __len__(self) -> int: + """Total number of samples of data.""" + return len(self.samples) + + def get_depth_map(self, sample: DictStrAny) -> np.ndarray: + """Get the depth map.""" + depth_bytes = self.data_backend.get(sample["depth_filename"]) + depth_array = im_decode(depth_bytes) + + depth = np.ascontiguousarray(depth_array, dtype=np.float32) + + depth = depth / self.depth_scale + + return depth + + def __getitem__(self, idx: int) -> DictData: + """Get single sample. + + Args: + idx (int): Index of sample. + + Returns: + DictData: sample at index in Vis4D input format. + """ + sample = self.samples[idx] + data_dict: DictData = {} + + # Get image info + data_dict[K.sample_names] = sample["img_id"] + + data_dict["dataset_name"] = self.dataset_name + data_dict[K.boxes2d_names] = self.categories + data_dict["text_prompt_mapping"] = self.text_prompt_mapping + + if K.images in self.keys_to_load: + im_bytes = self.data_backend.get(sample["img"]["file_path"]) + image = np.ascontiguousarray( + im_decode(im_bytes, mode=self.image_channel_mode), + dtype=np.float32, + )[None] + + data_dict[K.images] = image + data_dict[K.input_hw] = (image.shape[1], image.shape[2]) + + data_dict[K.original_images] = image + data_dict[K.original_hw] = (image.shape[1], image.shape[2]) + + # Get camera info + intrinsics = np.array(sample["img"]["K"], dtype=np.float32) + data_dict[K.intrinsics] = intrinsics + data_dict["original_intrinsics"] = intrinsics + + data_dict[K.boxes2d] = sample["boxes2d"] + data_dict[K.boxes2d_classes] = sample["class_ids"] + data_dict[K.boxes3d] = sample["boxes3d"] + data_dict[K.boxes3d_classes] = sample["class_ids"] + data_dict[K.axis_mode] = AxisMode.OPENCV + + # Ignore boxes for negative loss suppression (backward compat) + data_dict["ignore_boxes2d"] = sample.get( + "ignore_boxes2d", np.empty((0, 4), dtype=np.float32) + ) + data_dict["ignore_class_ids"] = sample.get( + "ignore_class_ids", np.empty((0,), dtype=np.int64) + ) + + if K.depth_maps in self.keys_to_load: + depth = self.get_depth_map(sample) + + depth[depth > self.max_depth] = 0 + + data_dict[K.depth_maps] = depth + + data_dict["tokens_positive"] = None + + self.data_backend.close() + + return data_dict + + +class COCO3D(COCO): + """COCO API with 3D annotations.""" + + def __init__( + self, + annotation_files: Sequence[str] | str, + category_names: Sequence[str] | None = None, + ignore_names: Sequence[str] = ("dontcare", "ignore", "void"), + truncation_thres: float = 0.33333333, + visibility_thres: float = 0.33333333, + min_height_thres: float = 0.0625, + max_height_thres: float = 1.50, + modal_2D_boxes: bool = False, + trunc_2D_boxes: bool = True, + max_depth: int = 1e8, + ) -> None: + """Creates an instance of the class.""" + self.dataset, self.anns, self.cats, self.imgs = {}, {}, {}, {} + self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) + + self.truncation_thres = truncation_thres + self.visibility_thres = visibility_thres + self.min_height_thres = min_height_thres + self.max_height_thres = max_height_thres + self.max_depth = max_depth + + if isinstance(annotation_files, str): + annotation_files = [annotation_files] + + cats_ids_master = [] + cats_master = [] + + for annotation_file in annotation_files: + _, tail = os.path.split(annotation_file) + name, _ = os.path.splitext(tail) + + print(f"loading {name} annotations into memory...") + tic = time.time() + + with open(annotation_file, "r") as f: + dataset = json.load(f) + + assert ( + type(dataset) == dict + ), f"annotation file format {type(dataset)} not supported" + print(f"Done (t={time.time() - tic:.2f}s)") + + if "info" not in dataset: + dataset["info"] = {"description": name} + + if type(dataset["info"]) == list: + dataset["info"] = dataset["info"][0] + + dataset["info"]["known_category_ids"] = [ + cat["id"] for cat in dataset["categories"] + ] + + # first dataset + if len(self.dataset) == 0: + self.dataset = dataset + # concatenate datasets + else: + if type(self.dataset["info"]) == dict: + self.dataset["info"] = [self.dataset["info"]] + + self.dataset["info"] += [dataset["info"]] + self.dataset["annotations"] += dataset["annotations"] + self.dataset["images"] += dataset["images"] + + # sort through categories + for cat in dataset["categories"]: + if not cat["id"] in cats_ids_master: + cats_ids_master.append(cat["id"]) + cats_master.append(cat) + + # category names are provided to us + if category_names is not None: + self.dataset["categories"] = [ + cats_master[i] + for i in np.argsort(cats_ids_master) + if cats_master[i]["name"] in category_names + ] + # no categories are provided, so assume use ALL available. + else: + self.dataset["categories"] = [ + cats_master[i] for i in np.argsort(cats_ids_master) + ] + + category_names = [ + cat["name"] for cat in self.dataset["categories"] + ] + + # determine which categories we may actually use for filtering. + trainable_cats = set(ignore_names) | set(category_names) + + valid_anns = [] + im_height_map = {} + + for im_obj in self.dataset["images"]: + im_height_map[im_obj["id"]] = im_obj["height"] + + # Filter out annotations + for anno_idx, anno in enumerate(self.dataset["annotations"]): + + im_height = im_height_map[anno["image_id"]] + + # tightly annotated 2D boxes are not always available. + if ( + modal_2D_boxes + and "bbox2D_tight" in anno + and anno["bbox2D_tight"][0] != -1 + ): + bbox2D = anno["bbox2D_tight"] + elif ( + trunc_2D_boxes + and "bbox2D_trunc" in anno + and not np.all([val == -1 for val in anno["bbox2D_trunc"]]) + ): + bbox2D = anno["bbox2D_trunc"] + elif anno["bbox2D_proj"][0] != -1: + bbox2D = anno["bbox2D_proj"] + elif anno["bbox2D_tight"][0] != -1: + bbox2D = anno["bbox2D_tight"] + else: + continue + + # convert to xywh + bbox2D[2] = bbox2D[2] - bbox2D[0] + bbox2D[3] = bbox2D[3] - bbox2D[1] + + ignore = self.is_ignore(anno, bbox2D, ignore_names, im_height) + + width = bbox2D[2] + height = bbox2D[3] + + self.dataset["annotations"][anno_idx]["area"] = width * height + self.dataset["annotations"][anno_idx]["iscrowd"] = False + self.dataset["annotations"][anno_idx]["ignore"] = ignore + self.dataset["annotations"][anno_idx]["ignore2D"] = ignore + self.dataset["annotations"][anno_idx]["ignore3D"] = ignore + + self.dataset["annotations"][anno_idx]["bbox"] = bbox2D + self.dataset["annotations"][anno_idx]["bbox3D"] = anno[ + "bbox3D_cam" + ] + self.dataset["annotations"][anno_idx]["depth"] = anno[ + "center_cam" + ][2] + + category_name = anno["category_name"] + + if category_name in trainable_cats: + valid_anns.append(self.dataset["annotations"][anno_idx]) + + self.dataset["annotations"] = valid_anns + + self.createIndex() + + def is_ignore( + self, + anno, + bbox2D: list[float, float, float, float], + ignore_names: Sequence[str] | None, + image_height: int, + ) -> bool: + ignore = anno["behind_camera"] + ignore |= not bool(anno["valid3D"]) + + if ignore: + return ignore + + ignore |= anno["dimensions"][0] <= 0 + ignore |= anno["dimensions"][1] <= 0 + ignore |= anno["dimensions"][2] <= 0 + ignore |= anno["center_cam"][2] > self.max_depth + ignore |= anno["lidar_pts"] == 0 + ignore |= anno["segmentation_pts"] == 0 + ignore |= anno["depth_error"] > 0.5 + + ignore |= bbox2D[3] <= self.min_height_thres * image_height + ignore |= bbox2D[3] >= self.max_height_thres * image_height + + ignore |= ( + anno["truncation"] >= 0 + and anno["truncation"] >= self.truncation_thres + ) + ignore |= ( + anno["visibility"] >= 0 + and anno["visibility"] <= self.visibility_thres + ) + + if ignore_names is not None: + ignore |= anno["category_name"] in ignore_names + + return ignore diff --git a/wilddet3d/data/datasets/cubifyanything.py b/wilddet3d/data/datasets/cubifyanything.py new file mode 100644 index 0000000000000000000000000000000000000000..92f066ab56189fabca293cfa15c3de216f936ad3 --- /dev/null +++ b/wilddet3d/data/datasets/cubifyanything.py @@ -0,0 +1,90 @@ +"""CubifyAnything (CA-1M) dataset for 3D object detection.""" + +from __future__ import annotations + +import json +import os + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + + +def get_cubifyanything_det_map( + dataset_name: str, + data_root: str = "data/cubifyanything", +) -> dict[str, int]: + """Build det_map from CA-1M annotation JSON categories. + + CA-1M has ~3000 free-form categories. Since our model is + open-vocabulary (text-prompted), we build det_map dynamically + from the annotation JSON's categories list. + + Args: + dataset_name: e.g. "CubifyAnything_train" or "CubifyAnything_val" + data_root: Root directory for CubifyAnything data. + """ + cache_path = os.path.join( + data_root, "annotations", f"{dataset_name}_class_map.json" + ) + if os.path.exists(cache_path): + with open(cache_path) as f: + return json.load(f) + json_path = os.path.join( + data_root, "annotations", f"{dataset_name}.json" + ) + with open(json_path) as f: + data = json.load(f) + class_map = {cat["name"]: cat["id"] for cat in data["categories"]} + with open(cache_path, "w") as f: + json.dump(class_map, f) + return class_map + + +def get_cubifyanything_class_map( + dataset_name: str, + data_root: str = "data/cubifyanything", +) -> dict[str, int]: + """Build class_map from CA-1M annotation JSON categories. + + CA-1M has ~3000 categories (not in omni3d_class_map), so + we build class_map dynamically from the annotation JSON. + class_map maps category_name -> category_id (same as det_map + for CA-1M, since all categories are trainable). + + Args: + dataset_name: e.g. "CubifyAnything_train" or "CubifyAnything_val" + data_root: Root directory for CubifyAnything data. + """ + return get_cubifyanything_det_map(dataset_name, data_root) + + +class CubifyAnything(COCO3DDataset): + """CubifyAnything (CA-1M) Dataset. + + Indoor scenes with uint16 mm-encoded depth maps. + """ + + def __init__( + self, + max_depth: float = 20.0, + depth_scale: float = 1000.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filename for a given image. + + Maps image path to depth path: + cubifyanything/data/CubifyAnything/train/42446540/ts.jpg + -> cubifyanything/depth_gt/train/42446540/ts.png + """ + return img["file_path"].replace( + "data/CubifyAnything", "depth_gt" + ).replace(".jpg", ".png") diff --git a/wilddet3d/data/datasets/foundationpose.py b/wilddet3d/data/datasets/foundationpose.py new file mode 100644 index 0000000000000000000000000000000000000000..882b8ab1042b14d22e75f36199989d04c95a6b32 --- /dev/null +++ b/wilddet3d/data/datasets/foundationpose.py @@ -0,0 +1,78 @@ +"""FoundationPose (GSO) dataset for 3D object detection. + +Synthetic dataset from FoundationPose with Google Scanned Objects (GSO). +438 categories, ~446K images with dense depth maps (uint16, depth_m * 256). +""" + +from __future__ import annotations + +import json +import os + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + + +def get_foundationpose_det_map( + dataset_name: str, + data_root: str = "data/foundationpose", +) -> dict[str, int]: + """Build det_map from FoundationPose annotation JSON categories.""" + cache_path = os.path.join( + data_root, "annotations", f"{dataset_name}_class_map.json" + ) + if os.path.exists(cache_path): + with open(cache_path) as f: + return json.load(f) + json_path = os.path.join( + data_root, "annotations", f"{dataset_name}.json" + ) + with open(json_path) as f: + data = json.load(f) + class_map = {cat["name"]: cat["id"] for cat in data["categories"]} + with open(cache_path, "w") as f: + json.dump(class_map, f) + return class_map + + +def get_foundationpose_class_map( + dataset_name: str, + data_root: str = "data/foundationpose", +) -> dict[str, int]: + """Build class_map from FoundationPose annotation JSON categories.""" + return get_foundationpose_det_map(dataset_name, data_root) + + +class FoundationPoseDataset(COCO3DDataset): + """FoundationPose (GSO) Dataset. + + Synthetic scenes with dense depth maps (uint16, depth_m * 256). + """ + + def __init__( + self, + max_depth: float = 20.0, + depth_scale: float = 256.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filename for a given image. + + Maps image path to depth path: + foundationpose/images_jpg/gso/{name}.jpg + -> foundationpose/depth/gso/{name}.png + """ + path = img["file_path"] + path = path.replace( + "foundationpose/images_jpg/", "foundationpose/depth/" + ) + path = path.replace(".jpg", ".png") + return path diff --git a/wilddet3d/data/datasets/in_the_wild.py b/wilddet3d/data/datasets/in_the_wild.py new file mode 100644 index 0000000000000000000000000000000000000000..0335f8c2ef99383b1919b48d5dcd9d371d5d9984 --- /dev/null +++ b/wilddet3d/data/datasets/in_the_wild.py @@ -0,0 +1,447 @@ +"""In-The-Wild 3D dataset (COCO/LVIS/Objects365 with human-annotated 3D boxes).""" + +from __future__ import annotations + +import json +import os +import time +from collections import defaultdict + +import numpy as np +import cv2 +from pycocotools import mask as maskUtils + +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.const import CommonKeys as K + +from .coco3d import COCO3DDataset + +_V4_DEPTH_ROOT = ( + "/weka/oe-training-default/weikaih/3d_boundingbox_detection" + "/single_frame_data/experiment/v4_depth_new" +) + +# Depth directories (v4_depth_new, unscaled) +_DEPTH_DIRS_NEW = { + "coco/val": f"{_V4_DEPTH_ROOT}/coco/val/depth", + "coco/train": f"{_V4_DEPTH_ROOT}/coco/train/depth", + "obj365/val": f"{_V4_DEPTH_ROOT}/obj365/val/depth", + "obj365/train": f"{_V4_DEPTH_ROOT}/obj365/train/depth", + "v3det/train": f"{_V4_DEPTH_ROOT}/v3det/train/depth", +} + +# Confidence map directories (uint8 PNG, same resolution as depth) +_CONF_DIRS = { + "coco/val": f"{_V4_DEPTH_ROOT}/coco/val/confidence", + "coco/train": f"{_V4_DEPTH_ROOT}/coco/train/confidence", + "obj365/val": f"{_V4_DEPTH_ROOT}/obj365/val/confidence", + "obj365/train": f"{_V4_DEPTH_ROOT}/obj365/train/confidence", + "v3det/train": f"{_V4_DEPTH_ROOT}/v3det/train/confidence", +} + +# Depth values in the .npy files are in mm; convert to meters +_DEPTH_MM_TO_M = 1.0 / 1000.0 + + +def _get_source_key_from_file_path(file_path: str) -> str: + """Infer v4_depth source key from image file_path. + + Handles both absolute paths (legacy) and HDF5 relative paths: + /weka/.../coco/train2017/X.jpg -> "coco/train" + images/coco_train/X.jpg -> "coco/train" + images/v3det_train/Q.../X.jpg -> "v3det/train" + """ + if "/v3det_train/" in file_path: + return "v3det/train" + elif "coco/val2017" in file_path or "/coco_val/" in file_path: + return "coco/val" + elif "coco/train2017" in file_path or "/coco_train/" in file_path: + return "coco/train" + elif ( + ("obj365" in file_path and "/train/" in file_path) + or "/obj365_train/" in file_path + ): + return "obj365/train" + else: + return "obj365/val" + + +def _get_formatted_id_from_file_path(file_path: str) -> str: + """Extract zero-padded 12-digit image ID from file path.""" + basename = file_path.split("/")[-1] # e.g. 000000000724.jpg + return ( + basename.replace(".jpg", "") + .replace("obj365_val_", "") + .replace("obj365_train_", "") + ) + + +def load_in_the_wild_class_map( + annotation_path: str = "data/in_the_wild/annotations/InTheWild_val.json", +) -> dict[str, int]: + """Load class map from InTheWild annotation file. + + Returns a mapping from category name to category ID (0-indexed alphabetical). + + Args: + annotation_path: Path to the InTheWild annotation JSON file. + + Returns: + dict mapping category name to annotation category ID. + """ + cache_path = annotation_path.replace(".json", "_class_map.json") + if os.path.exists(cache_path): + with open(cache_path) as f: + return json.load(f) + with open(annotation_path) as f: + data = json.load(f) + class_map = {cat["name"]: cat["id"] for cat in data["categories"]} + with open(cache_path, "w") as f: + json.dump(class_map, f) + return class_map + + +class InTheWild3DDataset(COCO3DDataset): + """In-The-Wild 3D dataset with 800+ open-vocabulary categories. + + Human-annotated 3D bounding boxes on COCO val2017, LVIS (COCO train2017), + and Objects365 val images. + + Annotations converted from human_annotated_val_full2d.json to Omni3D + COCO3D format using scripts/in_the_wild/convert_in_the_wild.py. + Camera intrinsics are scaled back to original image resolution (non-SR). + + Depth maps are from v4_depth (SR 1024-long-edge .npy, mm units), + resized to original image resolution on load. + """ + + def __init__( + self, + class_map: dict[str, int], + max_depth: float = 100.0, + per_image_categories: bool = False, + depth_confidence_threshold: int = 0, + mask_annotation_files: dict[str, str] | None = None, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class. + + Args: + class_map: Mapping from category name to category ID. + max_depth: Maximum depth in meters (clip beyond this). + per_image_categories: If True, boxes2d_names only contains + the GT categories present in each image. Required for + GDino/3D-MOOD eval (avoids BERT truncation with 1246 + categories). Must be False for WildDet3D (collator indexes + boxes2d_names by global cat_id). + depth_confidence_threshold: Minimum confidence (uint8, 0-255) + for a depth pixel to be considered valid. Pixels below + this threshold are set to 0 (invalid). Set to 0 to + disable confidence masking. Only applies when confidence + map exists for the image. + mask_annotation_files: Optional dict mapping source key + (e.g. "coco/train", "obj365/val") to the annotation + JSON path that contains segmentation masks. When + provided, masks are matched to each sample's boxes and + returned as "masks2d_rle" in __getitem__. + """ + super().__init__( + class_map=class_map, + det_map=class_map, + max_depth=max_depth, + **kwargs, + ) + self.per_image_categories = per_image_categories + self.depth_confidence_threshold = depth_confidence_threshold + + # Separate dict for mask RLEs (DatasetFromList serializes + # samples, so in-place mutation does not persist). + self._mask_rle_index: dict[int, list] = {} + if mask_annotation_files: + self._build_mask_index(mask_annotation_files) + + def _build_mask_index( + self, mask_annotation_files: dict[str, str | list[str]] + ) -> None: + """Load mask annotations and build per-sample mask index. + + For each mask annotation file, builds an index by image filename, + then matches masks to ITW sample boxes by (x1, y1, w, h) + coordinate proximity. Supports multiple files per source key + (e.g. both LVIS and COCO instances for coco/train). + + Args: + mask_annotation_files: {source_key: path_or_list_of_paths}. + """ + # Group samples by (source_key, basename) for matching + source_bn_to_indices = defaultdict(list) + for i in range(len(self.samples)): + sample = self.samples[i] + fp = sample["img"]["file_path"] + sk = _get_source_key_from_file_path(fp) + bn = fp.split("/")[-1] + source_bn_to_indices[(sk, bn)].append(i) + + # Normalize to list of paths per source key + expanded = {} + for source_key, paths in mask_annotation_files.items(): + if isinstance(paths, str): + expanded[source_key] = [paths] + else: + expanded[source_key] = list(paths) + + for source_key, ann_paths in expanded.items(): + for ann_path in ann_paths: + # Basenames we need from this source + needed_bns = { + bn + for (sk, bn) in source_bn_to_indices + if sk == source_key + } + if not needed_bns: + continue + + rank_zero_info( + f"[masks] Loading {source_key} from {ann_path} ..." + ) + t0 = time.time() + with open(ann_path) as f: + data = json.load(f) + rank_zero_info( + f"[masks] Loaded in {time.time() - t0:.1f}s " + f"({len(data.get('images', []))} images, " + f"{len(data.get('annotations', []))} annotations)" + ) + + # filename -> (mask_img_id, height, width) + fn_to_info = {} + for img in data["images"]: + fn = img.get("file_name") + if fn is None: + # LVIS format: file_name is None, use id + fn = f"{img['id']:012d}.jpg" + else: + fn = fn.split("/")[-1] + if fn in needed_bns: + fn_to_info[fn] = ( + img["id"], + img["height"], + img["width"], + ) + + # Reverse lookup: mask_img_id -> (height, width) + mid_to_hw = { + v[0]: (v[1], v[2]) for v in fn_to_info.values() + } + + rank_zero_info( + f"[masks] Matched {len(fn_to_info)} images " + "by filename" + ) + + # mask_img_id -> [(x1, y1, rle_dict), ...] + needed_ids = set(mid_to_hw.keys()) + mask_by_id = defaultdict(list) + for ann in data["annotations"]: + mid = ann["image_id"] + if mid not in needed_ids: + continue + seg = ann.get("segmentation") + if seg is None: + continue + bbox = ann["bbox"] # xywh + # Convert polygon / uncompressed RLE to compressed + # RLE for uniform handling + hw = mid_to_hw.get(mid) + if hw is None: + continue + if isinstance(seg, list): + # Polygon format + rles = maskUtils.frPyObjects(seg, hw[0], hw[1]) + seg = maskUtils.merge(rles) + elif isinstance(seg.get("counts"), list): + # Uncompressed RLE (iscrowd) -> compress + seg = maskUtils.frPyObjects( + seg, hw[0], hw[1] + ) + mask_by_id[mid].append( + (bbox[0], bbox[1], bbox[2], bbox[3], seg) + ) + + del data # free raw JSON + + # Match masks to ITW sample boxes (merge with + # existing matches from previous files) + n_matched = 0 + n_total = 0 + for (sk, bn), indices in source_bn_to_indices.items(): + if sk != source_key: + continue + info = fn_to_info.get(bn) + if info is None: + continue + mid = info[0] + masks_for_img = mask_by_id.get(mid, []) + if not masks_for_img: + continue + for si in indices: + sample = self.samples[si] + boxes2d = sample["boxes2d"] # (N, 4) xyxy + # Get existing matches (from previous file) + existing = self._mask_rle_index.get(si) + masks_rle = ( + list(existing) + if existing is not None + else [None] * len(boxes2d) + ) + for bi, box in enumerate(boxes2d): + if masks_rle[bi] is not None: + # Already matched by previous file + n_total += 1 + n_matched += 1 + continue + x1 = float(box[0]) + y1 = float(box[1]) + bw = float(box[2]) - x1 + bh = float(box[3]) - y1 + matched = None + for mx1, my1, mw, mh, rle in masks_for_img: + if ( + abs(mx1 - x1) < 1.0 + and abs(my1 - y1) < 1.0 + and abs(mw - bw) < 2.0 + and abs(mh - bh) < 2.0 + ): + matched = rle + break + masks_rle[bi] = matched + n_total += 1 + if matched is not None: + n_matched += 1 + self._mask_rle_index[si] = masks_rle + + rank_zero_info( + f"[masks] Matched {n_matched}/{n_total} boxes " + f"for {source_key}" + ) + + rank_zero_info( + f"[masks] Total: {len(self._mask_rle_index)}" + f"/{len(self.samples)} samples have masks" + ) + + def __getitem__(self, idx: int): + """Get single sample, optionally with per-image category filtering.""" + data_dict = super().__getitem__(idx) + if self.per_image_categories: + class_ids_in_img = data_dict[K.boxes2d_classes] + if len(class_ids_in_img) > 0: + unique_global_ids = sorted(set(class_ids_in_img.tolist())) + data_dict[K.boxes2d_names] = [ + self.categories[gid] for gid in unique_global_ids + ] + else: + data_dict[K.boxes2d_names] = [] + + # Decode masks and add as (N, H, W) uint8 array for transforms. + # masks_rle is aligned with sample["boxes2d"] (pre-filter). + # data_dict boxes2d comes from COCO3D which may filter some + # boxes (ignore, bad rotation, etc.), but the ordering of + # valid boxes is preserved, so masks_rle indices still match. + masks_rle = self._mask_rle_index.get(idx) + if masks_rle is not None and len(masks_rle) > 0: + n_boxes = len(data_dict[K.boxes2d]) + if n_boxes == 0 or n_boxes != len(masks_rle): + pass # Misaligned or empty, skip masks + else: + sample = self.samples[idx] + h = sample["img"]["height"] + w = sample["img"]["width"] + decoded = [] + for rle in masks_rle: + if rle is not None: + decoded.append(maskUtils.decode(rle)) + else: + decoded.append( + np.zeros((h, w), dtype=np.uint8) + ) + data_dict["masks2d"] = np.stack( + decoded, axis=0 + ) + + return data_dict + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Return path to the .npy depth file for this image. + + Uses v4_depth_new (unscaled depth maps). + """ + file_path = img["file_path"] + source_key = _get_source_key_from_file_path(file_path) + if "formatted_id" in img: + formatted_id = img["formatted_id"] + else: + formatted_id = _get_formatted_id_from_file_path(file_path) + depth_dir = _DEPTH_DIRS_NEW.get(source_key) + if depth_dir is None: + return None + depth_path = f"{depth_dir}/{formatted_id}_sr_1024_long.npy" + return depth_path if os.path.exists(depth_path) else None + + def get_depth_map(self, sample: DictStrAny) -> np.ndarray: + """Load .npy depth (mm) and resize to original image resolution. + + If depth_confidence_threshold > 0, loads the MoGe2 confidence + map (uint8 PNG, same resolution as depth) and zeros out pixels + where confidence < threshold. + """ + depth_npy = np.load(sample["depth_filename"]) # (H_sr, W_sr) float32, mm + + # Apply MoGe2 confidence masking before resize + if self.depth_confidence_threshold > 0: + img_entry = sample["img"] + file_path = img_entry["file_path"] + source_key = _get_source_key_from_file_path(file_path) + conf_dir = _CONF_DIRS.get(source_key) + if conf_dir is not None: + if "formatted_id" in img_entry: + formatted_id = img_entry["formatted_id"] + else: + formatted_id = _get_formatted_id_from_file_path( + file_path + ) + conf_path = f"{conf_dir}/{formatted_id}.png" + if os.path.exists(conf_path): + conf = cv2.imread( + conf_path, cv2.IMREAD_UNCHANGED + ) # uint8, same shape as depth + if conf.shape != depth_npy.shape: + conf = cv2.resize( + conf, + (depth_npy.shape[1], depth_npy.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + depth_npy[ + conf < self.depth_confidence_threshold + ] = 0.0 + + orig_h = sample["img"]["height"] + orig_w = sample["img"]["width"] + + # Resize to original image size using nearest-neighbor to avoid + # interpolation artifacts at depth discontinuities + if depth_npy.shape != (orig_h, orig_w): + depth_npy = cv2.resize( + depth_npy, + (orig_w, orig_h), + interpolation=cv2.INTER_NEAREST, + ) + + # Convert mm -> meters + depth = depth_npy * _DEPTH_MM_TO_M + + # Clip to max_depth + depth[depth > self.max_depth] = 0.0 + + return depth.astype(np.float32) diff --git a/wilddet3d/data/datasets/labelany3d_coco.py b/wilddet3d/data/datasets/labelany3d_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..9b0e34605646838fe4b3f60ab2a9b0b1d9b14226 --- /dev/null +++ b/wilddet3d/data/datasets/labelany3d_coco.py @@ -0,0 +1,145 @@ +"""LabelAny3D COCO Dataset. + +This dataset contains COCO images with 3D annotations generated by LabelAny3D. +It uses standard COCO 80 categories with metric 3D bounding boxes. + +The dataset provides: +- 2010 validation images from COCO val2017 +- 5409 3D bounding box annotations +- 80 COCO object categories +- Camera intrinsics (K matrix) for each image +- Full 3D annotations: center_cam, dimensions, R_cam, bbox3D_cam +""" + +from __future__ import annotations + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + +# ============================================================================= +# COCO 80 categories mapping +# Keys: category names (str) +# Values: category IDs from the annotation file (int) +# Sorted alphabetically for consistency +# ============================================================================= +labelany3d_coco_class_map = { + "airplane": 98, + "apple": 136, + "backpack": 116, + "banana": 135, + "baseball bat": 126, + "baseball glove": 127, + "bear": 113, + "bed": 39, + "bench": 105, + "bicycle": 11, + "bird": 106, + "boat": 100, + "book": 149, + "bottle": 15, + "bowl": 56, + "broccoli": 139, + "bus": 12, + "cake": 144, + "car": 1, + "carrot": 140, + "cat": 107, + "cell phone": 148, + "chair": 18, + "clock": 87, + "couch": 145, + "cow": 111, + "cup": 19, + "dining table": 146, + "dog": 108, + "donut": 143, + "elephant": 112, + "fire hydrant": 102, + "fork": 132, + "frisbee": 121, + "giraffe": 115, + "hair drier": 152, + "handbag": 118, + "horse": 109, + "hot dog": 141, + "keyboard": 77, + "kite": 125, + "knife": 133, + "laptop": 20, + "microwave": 54, + "motorcycle": 10, + "mouse": 81, + "orange": 138, + "oven": 57, + "parking meter": 104, + "person": 7, + "pizza": 142, + "potted plant": 73, + "refrigerator": 49, + "remote": 95, + "sandwich": 137, + "scissors": 150, + "sheep": 110, + "sink": 28, + "skateboard": 128, + "skis": 122, + "snowboard": 123, + "spoon": 134, + "sports ball": 124, + "stop sign": 103, + "suitcase": 120, + "surfboard": 129, + "teddy bear": 151, + "tennis racket": 130, + "tie": 119, + "toaster": 72, + "toilet": 32, + "toothbrush": 153, + "traffic light": 101, + "train": 99, + "truck": 5, + "tv": 147, + "umbrella": 117, + "vase": 58, + "wine glass": 131, + "zebra": 114, +} + +# Detection map for evaluation (0-indexed, continuous) +labelany3d_coco_det_map = {cat: i for i, cat in enumerate(sorted(labelany3d_coco_class_map.keys()))} + + +class LabelAny3DCOCO(COCO3DDataset): + """LabelAny3D COCO Dataset with 3D annotations.""" + + def __init__( + self, + class_map: dict[str, int] = labelany3d_coco_class_map, + max_depth: float = 80.0, + depth_scale: float = 256.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class. + + Args: + class_map: Mapping from category names to class IDs + max_depth: Maximum depth value for clipping + depth_scale: Scale factor for depth values + **kwargs: Additional arguments passed to COCO3DDataset + """ + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + LabelAny3D COCO doesn't have pre-computed depth maps. + Depth will be estimated on-the-fly during inference. + """ + return None + diff --git a/wilddet3d/data/datasets/odvg.py b/wilddet3d/data/datasets/odvg.py new file mode 100644 index 0000000000000000000000000000000000000000..17f63dba8df954450eb4b4680883836209729b35 --- /dev/null +++ b/wilddet3d/data/datasets/odvg.py @@ -0,0 +1,280 @@ +"""Object detection and visual grounding dataset.""" + +from __future__ import annotations + +import json +import os.path as osp + +import numpy as np +from tqdm import tqdm +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.const import CommonKeys as K +from vis4d.data.datasets.base import Dataset +from vis4d.data.datasets.util import ( + CacheMappingMixin, + im_decode, + print_class_histogram, +) +from vis4d.data.typing import DictData + + +class ODVGDataset(CacheMappingMixin, Dataset): + """Object detection and visual grounding dataset.""" + + def __init__( + self, + data_root: str, + ann_file: str, + label_map_file: str | None = None, + dataset_type: str = "VG", + dataset_prefix: str | None = None, + remove_empty: bool = False, + cache_as_binary: bool = False, + cached_file_path: str | None = None, + **kwargs: ArgsType, + ) -> None: + """Create an object detection and visual grounding dataset.""" + super().__init__(**kwargs) + + self.data_root = data_root + self.ann_file = ann_file + self.dataset_type = dataset_type + self.dataset_prefix = dataset_prefix + self.remove_empty = remove_empty + + if label_map_file is not None: + label_map_file = osp.join(self.data_root, label_map_file) + + with open(label_map_file, "r") as file: + # dict[class_id (str): class_name (str)] + self.label_map = json.load(file) + + self.dataset_type = "OD" + + self.det_map = {v: int(k) for k, v in self.label_map.items()} + self.categories = sorted(self.det_map, key=self.det_map.get) + else: + self.label_map = None + self.dataset_type = "VG" + + # Load annotations + self.samples, _ = self._load_mapping( + self._generate_data_mapping, + self._filter_data, + cache_as_binary=cache_as_binary, + cached_file_path=cached_file_path, + ) + + def __repr__(self) -> str: + """Concise representation of the dataset.""" + return f"ODVGDataset({self.ann_file})" + + def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]: + """Remove empty samples.""" + samples = [] + + if self.dataset_type == "OD": + frequencies = {cat: 0 for _, cat in self.label_map.items()} + + empty_samples = 0 + for sample in data: + if self.remove_empty and len(sample["anns"]) == 0: + empty_samples += 1 + continue + + if self.dataset_type == "OD": + for ann in sample["anns"]: + frequencies[ann["category"]] += 1 + + samples.append(sample) + + rank_zero_info(f"Propocessing {self} with {len(samples)} samples.") + rank_zero_info(f"Filtered {empty_samples} empty samples") + + if self.dataset_type == "OD": + frequencies = dict(sorted(frequencies.items())) + + print_class_histogram(frequencies) + + return samples + + def _generate_data_mapping(self) -> list[DictStrAny]: + """Generates the data mapping.""" + with open(osp.join(self.data_root, self.ann_file), "r") as f: + data_list = [json.loads(line) for line in f] + + if self.with_camera: + with open(osp.join(self.data_root, "cam_info.json"), "r") as f: + cameras = json.load(f) + + samples = [] + for data in tqdm(data_list): + data_info = {} + + if self.dataset_prefix is not None: + img_path = osp.join( + self.data_root, self.dataset_prefix, data["filename"] + ) + else: + img_path = osp.join(self.data_root, data["filename"]) + + data_info["img_path"] = img_path + + # Pseudo K + if self.with_camera: + data_info["K"] = cameras[img_path][0] + + # Pseudo Depth Path + if self.dataset_prefix is not None: + depth_path = osp.join( + self.data_root, + f"{self.dataset_prefix}_depth", + data["filename"].replace(".jpg", "_depth.png"), + ) + else: + depth_path = osp.join( + self.data_root, + data["filename"].replace(".jpg", "_depth.png"), + ) + data_info["depth_path"] = depth_path + + data_info["height"] = data["height"] + data_info["width"] = data["width"] + + valid_anns = [] + boxes = [] + class_ids = np.empty((0,), dtype=np.int64)[1:] + if self.dataset_type == "OD": + instances = data.get("detection", {}).get("instances", []) + + for ann in instances: + bbox = ann["bbox"] + + # Box 2D + x1, y1, x2, y2 = bbox + inter_w = max(0, min(x2, data["width"]) - max(x1, 0)) + inter_h = max(0, min(y2, data["height"]) - max(y1, 0)) + + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + + boxes.append(bbox) + + # Class + class_ids = np.concatenate( + [class_ids, np.array([ann["label"]], dtype=np.int64)] + ) + + valid_anns.append(ann) + else: + anno = data["grounding"] + + caption = anno["caption"].lower().strip() + if not caption.endswith("."): + caption = caption + ". " + + data_info["caption"] = caption + + regions = anno["regions"] + phrases = [] + positive_positions = [] + for i, region in enumerate(regions): + bboxes = region["bbox"] + + if not isinstance(bboxes[0], list): + bboxes = [bboxes] + + for bbox in bboxes: + x1, y1, x2, y2 = bbox + inter_w = max(0, min(x2, data["width"]) - max(x1, 0)) + inter_h = max(0, min(y2, data["height"]) - max(y1, 0)) + + if inter_w * inter_h == 0: + continue + if (x2 - x1) < 1 or (y2 - y1) < 1: + continue + + boxes.append(bbox) + phrases.append(region["phrase"]) + positive_positions.append(region["tokens_positive"]) + valid_anns.append(region) + + class_ids = np.concatenate( + [class_ids, np.array([i], dtype=np.int64)] + ) + + data_info["phrases"] = phrases + data_info["positive_positions"] = positive_positions + + boxes2d = ( + np.empty((0, 4), dtype=np.float32) + if not boxes + else np.array(boxes, dtype=np.float32) + ) + + data_info["boxes2d"] = boxes2d + data_info["class_ids"] = class_ids + data_info["anns"] = valid_anns + + samples.append(data_info) + + del data_list + return samples + + def get_cat_ids(self, idx: int) -> list[int]: + """Return the samples.""" + return self.samples[idx]["class_ids"].tolist() + + def __len__(self) -> int: + """Total number of samples of data.""" + return len(self.samples) + + def __getitem__(self, idx: int) -> DictData: + """Get single sample. + + Args: + idx (int): Index of sample. + + Returns: + DictData: sample at index in Vis4D input format. + """ + sample = self.samples[idx] + data_dict: DictData = {} + + # Get image info + sample_name = sample["img_path"].split("/")[-1] + data_dict[K.sample_names] = sample_name + + im_bytes = self.data_backend.get(sample["img_path"]) + image = np.ascontiguousarray( + im_decode(im_bytes, mode=self.image_channel_mode), + dtype=np.float32, + )[None] + + data_dict[K.images] = image + data_dict[K.input_hw] = (image.shape[1], image.shape[2]) + + data_dict[K.original_images] = image + data_dict[K.original_hw] = (image.shape[1], image.shape[2]) + + data_dict[K.boxes2d] = sample["boxes2d"] + data_dict[K.boxes2d_classes] = sample["class_ids"] + + if self.dataset_type == "OD": + data_dict[K.boxes2d_names] = self.categories + data_dict["phrases"] = None + data_dict["positive_positions"] = None + else: + data_dict[K.boxes2d_names] = sample["caption"] + data_dict["phrases"] = sample["phrases"] + data_dict["positive_positions"] = sample["positive_positions"] + + data_dict["dataset_type"] = self.dataset_type + data_dict["label_map"] = self.label_map + + self.data_backend.close() + + return data_dict diff --git a/wilddet3d/data/datasets/omni3d/__init__.py b/wilddet3d/data/datasets/omni3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f173db6b4ea0691fb71340d1a8b870c841873c00 --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/__init__.py @@ -0,0 +1 @@ +"""Omni3D Dataset.""" diff --git a/wilddet3d/data/datasets/omni3d/arkitscenes.py b/wilddet3d/data/datasets/omni3d/arkitscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbf4591454016c81de77b048b5d263529604bf7 --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/arkitscenes.py @@ -0,0 +1,81 @@ +"""ARKitScenes from Omni3D.""" + +from __future__ import annotations + +import os + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + +from .omni3d_classes import omni3d_class_map + +arkitscenes_det_map = { + "bathtub": 0, + "bed": 1, + "cabinet": 2, + "chair": 3, + "fireplace": 4, + "machine": 5, + "oven": 6, + "refrigerator": 7, + "shelves": 8, + "sink": 9, + "sofa": 10, + "stove": 11, + "table": 12, + "television": 13, + "toilet": 14, +} + +omni3d_arkitscenes_det_map = { + "table": 0, + "bed": 1, + "sofa": 2, + "television": 3, + "refrigerator": 4, + "chair": 5, + "oven": 6, + "machine": 7, + "stove": 8, + "shelves": 9, + "sink": 10, + "cabinet": 11, + "bathtub": 12, + "toilet": 13, +} + + +class ARKitScenes(COCO3DDataset): + """ARKitScenes Dataset.""" + + def __init__( + self, + class_map: dict[str, int] = omni3d_class_map, + max_depth: float = 10.0, + depth_scale: float = 1000.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + Since not every data has depth. + """ + _, _, split, video_id, image_name = img["file_path"].split("/") + + depth_filename = os.path.join( + "data/ARKitScenes_depth", + split, + video_id, + image_name.replace("jpg", "png"), + ) + + return depth_filename diff --git a/wilddet3d/data/datasets/omni3d/hypersim.py b/wilddet3d/data/datasets/omni3d/hypersim.py new file mode 100644 index 0000000000000000000000000000000000000000..8260f8415144d7d37d3f33a41abeb84bf91170e9 --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/hypersim.py @@ -0,0 +1,190 @@ +"""Hypersim from Omni3D.""" + +from __future__ import annotations + +import os + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + +from .omni3d_classes import omni3d_class_map + +hypersim_train_det_map = { + "bathtub": 0, + "bed": 1, + "blinds": 2, + "bookcase": 3, + "books": 4, + "box": 5, + "cabinet": 6, + "chair": 7, + "clothes": 8, + "counter": 9, + "curtain": 10, + "desk": 11, + "door": 12, + "dresser": 13, + "floor mat": 14, + "lamp": 15, + "mirror": 16, + "night stand": 17, + "person": 18, + "picture": 19, + "pillow": 20, + "refrigerator": 21, + "shelves": 22, + "sink": 23, + "sofa": 24, + "stationery": 25, + "table": 26, + "television": 27, + "toilet": 28, + "towel": 29, + "window": 30, +} + +hypersim_val_det_map = { + "bathtub": 0, + "bed": 1, + "blinds": 2, + "bookcase": 3, + "books": 4, + "box": 5, + "cabinet": 6, + "chair": 7, + "clothes": 8, + "counter": 9, + "curtain": 10, + "desk": 11, + "door": 12, + "dresser": 13, + "floor mat": 14, + "lamp": 15, + "mirror": 16, + "night stand": 17, + "picture": 18, + "pillow": 19, + "refrigerator": 20, + "shelves": 21, + "sink": 22, + "sofa": 23, + "stationery": 24, + "table": 25, + "television": 26, + "toilet": 27, + "towel": 28, + "window": 29, +} + +hypersim_test_det_map = { + "bathtub": 0, + "bed": 1, + "blinds": 2, + "board": 3, + "bookcase": 4, + "books": 5, + "box": 6, + "cabinet": 7, + "chair": 8, + "clothes": 9, + "counter": 10, + "curtain": 11, + "desk": 12, + "door": 13, + "floor mat": 14, + "lamp": 15, + "mirror": 16, + "night stand": 17, + "picture": 18, + "pillow": 19, + "refrigerator": 20, + "shelves": 21, + "sink": 22, + "sofa": 23, + "stationery": 24, + "table": 25, + "television": 26, + "towel": 27, + "window": 28, +} + + +omni3d_hypersim_det_map = { + "books": 0, + "chair": 1, + "towel": 2, + "blinds": 3, + "window": 4, + "lamp": 5, + "shelves": 6, + "mirror": 7, + "sink": 8, + "cabinet": 9, + "bathtub": 10, + "door": 11, + "desk": 12, + "box": 13, + "bookcase": 14, + "picture": 15, + "table": 16, + "counter": 17, + "bed": 18, + "night stand": 19, + "pillow": 20, + "sofa": 21, + "television": 22, + "floor mat": 23, + "curtain": 24, + "clothes": 25, + "stationery": 26, + "refrigerator": 27, +} + + +def get_hypersim_det_map(split: str) -> dict[str, int]: + """Get Hypersim detection map.""" + assert split in {"train", "val", "test"}, f"Invalid split: {split}" + + if split == "train": + return hypersim_train_det_map + elif split == "val": + return hypersim_val_det_map + elif split == "test": + return hypersim_test_det_map + + +class Hypersim(COCO3DDataset): + """Hypersim Dataset.""" + + def __init__( + self, + class_map: dict[str, int] = omni3d_class_map, + max_depth: float = 50.0, + depth_scale: float = 1000.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + Since not every data has depth. + """ + _, _, scene, _, img_dir, img_name = img["file_path"].split("/") + + depth_filename = os.path.join( + "data/hypersim_depth", + scene, + "images", + img_dir, + img_name.replace("jpg", "png"), + ) + + return depth_filename diff --git a/wilddet3d/data/datasets/omni3d/kitti_object.py b/wilddet3d/data/datasets/omni3d/kitti_object.py new file mode 100644 index 0000000000000000000000000000000000000000..13aa64d94ea7d4ff5dd15d8e0ca7bed7070d0fa1 --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/kitti_object.py @@ -0,0 +1,105 @@ +"""KITTI Object from Omni3D. + +KITTI Object Labels: +Categories, -, -, alpha, x1, y1, x2, y2, h, w, l, x, botom_y, z, ry + +KITTI Object Categories: +{ + "Pedestrian": "pedestrian", + "Cyclist": "cyclist", + "Car": "car", + "Van": "car", + "Truck": "truck", + "Tram": "tram", + "Person": "pedestrian", + "Person_sitting": "pedestrian", + "Misc": "misc", + "DontCare": "dontcare", +} +""" + +from __future__ import annotations + +import os + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + +from .omni3d_classes import omni3d_class_map + +kitti_train_det_map = kitti_test_det_map = { + "car": 0, + "cyclist": 1, + "pedestrian": 2, + "person": 3, + "tram": 4, + "truck": 5, + "van": 6, +} + +kitti_val_det_map = { + "car": 0, + "cyclist": 1, + "pedestrian": 2, + "tram": 3, + "truck": 4, +} + +# KITTI-Omni3D Mapping +omni3d_kitti_det_map = { + "pedestrian": 0, + "car": 1, + "cyclist": 2, + "van": 3, + "truck": 4, +} + + +def get_kitti_det_map(split: str) -> dict[str, int]: + """Get the KITTI detection map.""" + assert split in {"train", "val", "test"}, f"Invalid split: {split}" + + if split == "val": + return kitti_val_det_map + + # Train and Test are the same + return kitti_train_det_map + + +class KITTIObject(COCO3DDataset): + """KITTI Object Dataset.""" + + def __init__( + self, + class_map: dict[str, int] = omni3d_class_map, + max_depth: float = 80.0, + depth_scale: float = 256.0, + depth_data_root: str = "data/KITTI_object_depth", + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + self.depth_data_root = depth_data_root + + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + Since not every data has depth. + """ + _, _, split, image_id, img_filename = img["file_path"].split("/") + + depth_filename = os.path.join( + self.depth_data_root, + split, + image_id, + img_filename.replace(".jpg", ".png"), + ) + + return depth_filename diff --git a/wilddet3d/data/datasets/omni3d/nuscenes.py b/wilddet3d/data/datasets/omni3d/nuscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..f218bff9205ddaf87801c98c6a68005f66c7d364 --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/nuscenes.py @@ -0,0 +1,62 @@ +"""nuScenes from Omni3D.""" + +from __future__ import annotations + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + +from .omni3d_classes import omni3d_class_map + +nusc_det_map = { + "bicycle": 0, + "motorcycle": 1, + "pedestrian": 2, + "bus": 3, + "car": 4, + "trailer": 5, + "truck": 6, + "traffic cone": 7, + "barrier": 8, +} + + +class nuScenes(COCO3DDataset): + """nuScenes dataset.""" + + def __init__( + self, + class_map: dict[str, int] = omni3d_class_map, + max_depth: float = 80.0, + depth_scale: float = 256.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + Since not every data has depth. + """ + img["file_path"] = img["file_path"].replace("nuScenes", "nuscenes") + + depth_filename = ( + img["file_path"] + .replace("nuscenes", "nuscenes_depth") + .replace("jpg", "png") + ) + return depth_filename + + def get_cat_ids(self, idx: int) -> list[int]: + """Return the samples.""" + return self.samples[idx]["class_ids"].tolist() + + def __len__(self) -> int: + """Total number of samples of data.""" + return len(self.samples) diff --git a/wilddet3d/data/datasets/omni3d/objectron.py b/wilddet3d/data/datasets/omni3d/objectron.py new file mode 100644 index 0000000000000000000000000000000000000000..8f9cf007b3753b45c3c13290ba94324502828f7a --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/objectron.py @@ -0,0 +1,56 @@ +"""Objectron from Omni3D.""" + +from __future__ import annotations + +import os + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + +from .omni3d_classes import omni3d_class_map + +objectron_det_map = { + "bicycle": 0, + "books": 1, + "bottle": 2, + "camera": 3, + "cereal box": 4, + "chair": 5, + "cup": 6, + "laptop": 7, + "shoes": 8, +} + + +class Objectron(COCO3DDataset): + """Objectron dataset.""" + + def __init__( + self, + class_map: dict[str, int] = omni3d_class_map, + max_depth: float = 12.0, + depth_scale: float = 1000.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + Since not every data has depth. + """ + _, _, split, img_name = img["file_path"].split("/") + + depth_filename = os.path.join( + "data/objectron_depth", + split, + img_name.replace(".jpg", "_depth.png"), + ) + return depth_filename diff --git a/wilddet3d/data/datasets/omni3d/omni3d_classes.py b/wilddet3d/data/datasets/omni3d/omni3d_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..d9dceb74ea84bd9fd1ebe91388990ae6134fc291 --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/omni3d_classes.py @@ -0,0 +1,156 @@ +"""Omni3D classes.""" + +omni3d_class_map = { + "pedestrian": 0, + "car": 1, + "dontcare": 2, + "cyclist": 3, + "van": 4, + "truck": 5, + "tram": 6, + "person": 7, + "traffic cone": 8, + "barrier": 9, + "motorcycle": 10, + "bicycle": 11, + "bus": 12, + "trailer": 13, + "books": 14, + "bottle": 15, + "camera": 16, + "cereal box": 17, + "chair": 18, + "cup": 19, + "laptop": 20, + "shoes": 21, + "towel": 22, + "blinds": 23, + "window": 24, + "lamp": 25, + "shelves": 26, + "mirror": 27, + "sink": 28, + "cabinet": 29, + "bathtub": 30, + "door": 31, + "toilet": 32, + "desk": 33, + "box": 34, + "bookcase": 35, + "picture": 36, + "table": 37, + "counter": 38, + "bed": 39, + "night stand": 40, + "dresser": 41, + "pillow": 42, + "sofa": 43, + "television": 44, + "floor mat": 45, + "curtain": 46, + "clothes": 47, + "stationery": 48, + "refrigerator": 49, + "board": 50, + "kitchen pan": 51, + "bin": 52, + "stove": 53, + "microwave": 54, + "plates": 55, + "bowl": 56, + "oven": 57, + "vase": 58, + "faucet": 59, + "tissues": 60, + "machine": 61, + "printer": 62, + "monitor": 63, + "podium": 64, + "cart": 65, + "projector": 66, + "electronics": 67, + "computer": 68, + "air conditioner": 69, + "drawers": 70, + "coffee maker": 71, + "toaster": 72, + "potted plant": 73, + "painting": 74, + "bag": 75, + "tray": 76, + "keyboard": 77, + "blanket": 78, + "rack": 79, + "phone": 80, + "mouse": 81, + "fire extinguisher": 82, + "toys": 83, + "ladder": 84, + "fan": 85, + "glass": 86, + "clock": 87, + "toilet paper": 88, + "closet": 89, + "fume hood": 90, + "utensils": 91, + "soundsystem": 92, + "fire place": 93, + "shower curtain": 94, + "remote": 95, + "pen": 96, + "fireplace": 97, +} + +# Used for Cube R-CNN and Omni3D benchmark +omni3d_det_map = { + "pedestrian": 0, + "car": 1, + "cyclist": 2, + "van": 3, + "truck": 4, + "traffic cone": 5, + "barrier": 6, + "motorcycle": 7, + "bicycle": 8, + "bus": 9, + "trailer": 10, + "books": 11, + "bottle": 12, + "camera": 13, + "cereal box": 14, + "chair": 15, + "cup": 16, + "laptop": 17, + "shoes": 18, + "towel": 19, + "blinds": 20, + "window": 21, + "lamp": 22, + "shelves": 23, + "mirror": 24, + "sink": 25, + "cabinet": 26, + "bathtub": 27, + "door": 28, + "toilet": 29, + "desk": 30, + "box": 31, + "bookcase": 32, + "picture": 33, + "table": 34, + "counter": 35, + "bed": 36, + "night stand": 37, + "pillow": 38, + "sofa": 39, + "television": 40, + "floor mat": 41, + "curtain": 42, + "clothes": 43, + "stationery": 44, + "refrigerator": 45, + "bin": 46, + "stove": 47, + "oven": 48, + "machine": 49, +} diff --git a/wilddet3d/data/datasets/omni3d/sunrgbd.py b/wilddet3d/data/datasets/omni3d/sunrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..274579c850e7555f44197c05554d05fe1e564cee --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/sunrgbd.py @@ -0,0 +1,278 @@ +"""SUN RGB-D from Omni3D.""" + +from __future__ import annotations + +import os + +import numpy as np +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.datasets.util import im_decode + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + +from .omni3d_classes import omni3d_class_map + +# Train and Test are sharing the classes +sun_rgbd_train_det_map = sun_rgbd_test_det_map = { + "air conditioner": 0, + "bag": 1, + "bathtub": 2, + "bed": 3, + "bicycle": 4, + "bin": 5, + "blanket": 6, + "blinds": 7, + "board": 8, + "bookcase": 9, + "books": 10, + "bottle": 11, + "bowl": 12, + "box": 13, + "cabinet": 14, + "cart": 15, + "chair": 16, + "clock": 17, + "closet": 18, + "clothes": 19, + "coffee maker": 20, + "computer": 21, + "counter": 22, + "cup": 23, + "curtain": 24, + "desk": 25, + "door": 26, + "drawers": 27, + "dresser": 28, + "electronics": 29, + "fan": 30, + "faucet": 31, + "fire extinguisher": 32, + "fire place": 33, + "floor mat": 34, + "fume hood": 35, + "glass": 36, + "keyboard": 37, + "kitchen pan": 38, + "ladder": 39, + "lamp": 40, + "laptop": 41, + "machine": 42, + "microwave": 43, + "mirror": 44, + "monitor": 45, + "mouse": 46, + "night stand": 47, + "oven": 48, + "painting": 49, + "pen": 50, + "person": 51, + "phone": 52, + "picture": 53, + "pillow": 54, + "plates": 55, + "podium": 56, + "potted plant": 57, + "printer": 58, + "projector": 59, + "rack": 60, + "refrigerator": 61, + "remote": 62, + "shelves": 63, + "shoes": 64, + "shower curtain": 65, + "sink": 66, + "sofa": 67, + "soundsystem": 68, + "stationery": 69, + "stove": 70, + "table": 71, + "television": 72, + "tissues": 73, + "toaster": 74, + "toilet": 75, + "toilet paper": 76, + "towel": 77, + "toys": 78, + "tray": 79, + "utensils": 80, + "vase": 81, + "window": 82, +} + +sun_rgbd_val_det_map = { + "air conditioner": 0, + "bag": 1, + "bathtub": 2, + "bed": 3, + "bin": 4, + "blanket": 5, + "blinds": 6, + "board": 7, + "bookcase": 8, + "books": 9, + "bottle": 10, + "bowl": 11, + "box": 12, + "cabinet": 13, + "cart": 14, + "chair": 15, + "closet": 16, + "clothes": 17, + "coffee maker": 18, + "computer": 19, + "counter": 20, + "cup": 21, + "curtain": 22, + "desk": 23, + "door": 24, + "drawers": 25, + "dresser": 26, + "electronics": 27, + "fan": 28, + "faucet": 29, + "fire extinguisher": 30, + "fire place": 31, + "fume hood": 32, + "keyboard": 33, + "kitchen pan": 34, + "lamp": 35, + "laptop": 36, + "machine": 37, + "microwave": 38, + "mirror": 39, + "monitor": 40, + "night stand": 41, + "oven": 42, + "painting": 43, + "pen": 44, + "person": 45, + "phone": 46, + "picture": 47, + "pillow": 48, + "plates": 49, + "potted plant": 50, + "printer": 51, + "projector": 52, + "rack": 53, + "refrigerator": 54, + "shelves": 55, + "sink": 56, + "sofa": 57, + "soundsystem": 58, + "stationery": 59, + "stove": 60, + "table": 61, + "television": 62, + "tissues": 63, + "toaster": 64, + "toilet": 65, + "towel": 66, + "toys": 67, + "tray": 68, + "utensils": 69, + "vase": 70, + "window": 71, +} + +omni3d_sun_rgbd_det_map = { + "bicycle": 0, + "books": 1, + "bottle": 2, + "chair": 3, + "cup": 4, + "laptop": 5, + "shoes": 6, + "towel": 7, + "blinds": 8, + "window": 9, + "lamp": 10, + "shelves": 11, + "mirror": 12, + "sink": 13, + "cabinet": 14, + "bathtub": 15, + "door": 16, + "toilet": 17, + "desk": 18, + "box": 19, + "bookcase": 20, + "picture": 21, + "table": 22, + "counter": 23, + "bed": 24, + "night stand": 25, + "pillow": 26, + "sofa": 27, + "television": 28, + "floor mat": 29, + "curtain": 30, + "clothes": 31, + "stationery": 32, + "refrigerator": 33, + "bin": 34, + "stove": 35, + "oven": 36, + "machine": 37, +} + + +def get_sunrgbd_det_map(split: str) -> dict[str, int]: + """Get the SUN RGB-D detection map.""" + assert split in {"train", "val", "test"}, f"Invalid split: {split}" + + if split == "train": + return sun_rgbd_train_det_map + elif split == "val": + return sun_rgbd_val_det_map + else: + return sun_rgbd_test_det_map + + +class SUNRGBD(COCO3DDataset): + """SUN RGB-D Dataset.""" + + def __init__( + self, + class_map: dict[str, int] = omni3d_class_map, + max_depth: float = 8.0, + depth_scale: float = 1000.0, + **kwargs: ArgsType, + ) -> None: + """Initialize SUN RGB-D dataset.""" + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + Since not every data has depth. + """ + img["file_path"] = img["file_path"].replace("//", "/") + + data_dir = img["file_path"].split("/image")[0] + + depth_files = self.data_backend.listdir( + os.path.join(data_dir, "depth") + ) + assert len(depth_files) == 1 + + depth_filename = os.path.join(data_dir, "depth", depth_files[0]) + + return depth_filename + + def get_depth_map(self, sample: DictStrAny) -> np.ndarray: + """Get the depth map.""" + depth_bytes = self.data_backend.get(sample["depth_filename"]) + depth_array = im_decode(depth_bytes) + + depth_array = depth_array >> 3 | depth_array << (16 - 3) + + depth = np.ascontiguousarray(depth_array, dtype=np.float32) + + depth = depth / self.depth_scale + + return depth diff --git a/wilddet3d/data/datasets/omni3d/util.py b/wilddet3d/data/datasets/omni3d/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ebdaa8365c5741c74cff9e22b378fa65a4ff2b40 --- /dev/null +++ b/wilddet3d/data/datasets/omni3d/util.py @@ -0,0 +1,86 @@ +"""Omni3D data util.""" + +from __future__ import annotations + +from .arkitscenes import arkitscenes_det_map, omni3d_arkitscenes_det_map +from .hypersim import get_hypersim_det_map, omni3d_hypersim_det_map +from .kitti_object import get_kitti_det_map, omni3d_kitti_det_map +from .nuscenes import nusc_det_map +from .objectron import objectron_det_map +from .sunrgbd import get_sunrgbd_det_map, omni3d_sun_rgbd_det_map + +DATASET_ID_MAP = { + 0: "KITTI_train", + 1: "KITTI_val", + 2: "KITTI_test", + 3: "nuScenes_train", + 4: "nuScenes_val", + 5: "nuScenes_test", + 6: "Objectron_train", + 7: "Objectron_val", + 8: "Objectron_test", + 9: "Hypersim_train", + 10: "Hypersim_val", + 11: "Hypersim_test", + 12: "SUNRGBD_train", + 13: "SUNRGBD_val", + 14: "SUNRGBD_test", + 15: "ARKitScenes_train", + 16: "ARKitScenes_val", + 17: "ARKitScenes_test", +} + + +def get_dataset_det_map( + dataset_name: str, + omni3d50: bool = True, +) -> tuple[str, dict[str, int]]: + """Get the detection map.""" + if "train" in dataset_name: + split = "train" + elif "val" in dataset_name: + split = "val" + elif "test" in dataset_name: + split = "test" + else: + raise ValueError(f"Unknown dataset_name: {dataset_name}") + + if "nuScenes" in dataset_name: + det_map = nusc_det_map + elif "KITTI" in dataset_name: + if omni3d50: + det_map = omni3d_kitti_det_map + else: + det_map = get_kitti_det_map(split) + elif "Objectron" in dataset_name: + det_map = objectron_det_map + elif "SUNRGBD" in dataset_name: + if omni3d50: + det_map = omni3d_sun_rgbd_det_map + else: + det_map = get_sunrgbd_det_map(split) + elif "Hypersim" in dataset_name: + if omni3d50: + det_map = omni3d_hypersim_det_map + else: + det_map = get_hypersim_det_map(split) + elif "ARKitScenes" in dataset_name: + det_map = ( + omni3d_arkitscenes_det_map if omni3d50 else arkitscenes_det_map + ) + elif "CubifyAnything" in dataset_name: + from wilddet3d.data.datasets.cubifyanything import ( + get_cubifyanything_det_map, + ) + + det_map = get_cubifyanything_det_map(dataset_name) + elif "Waymo" in dataset_name: + from wilddet3d.data.datasets.waymo import ( + get_waymo_det_map, + ) + + det_map = get_waymo_det_map(dataset_name) + else: + raise ValueError(f"Unknown dataset_name: {dataset_name}") + + return det_map diff --git a/wilddet3d/data/datasets/scannet.py b/wilddet3d/data/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..7f31438a6b22fa81876aa58f4652877786fa9246 --- /dev/null +++ b/wilddet3d/data/datasets/scannet.py @@ -0,0 +1,449 @@ +"""ScanNet dataset.""" + +from __future__ import annotations + +from vis4d.common.typing import ArgsType, DictStrAny + +from .coco3d import COCO3DDataset + +scannet_class_map = { + "cabinet": 3, + "bed": 4, + "chair": 5, + "sofa": 6, + "table": 7, + "door": 8, + "window": 9, + "bookshelf": 10, + "picture": 11, + "counter": 12, + "desk": 14, + "curtain": 16, + "refrigerator": 24, + "shower curtain": 28, + "toilet": 33, + "sink": 34, + "bathtub": 36, + "other furniture": 39, +} + +scannet_det_map = { + "cabinet": 0, + "bed": 1, + "chair": 2, + "sofa": 3, + "table": 4, + "door": 5, + "window": 6, + "bookshelf": 7, + "picture": 8, + "counter": 9, + "desk": 10, + "curtain": 11, + "refrigerator": 12, + "shower curtain": 13, + "toilet": 14, + "sink": 15, + "bathtub": 16, + "other furniture": 17, +} + +scannet200_class_map = { + "chair": 2, + "book": 22, + "door": 5, + "object": 1163, + "window": 16, + "table": 4, + "trash can": 56, + "pillow": 13, + "picture": 15, + "box": 26, + "doorframe": 161, + "monitor": 19, + "cabinet": 7, + "desk": 9, + "shelf": 8, + "office chair": 10, + "towel": 31, + "couch": 6, + "sink": 14, + "backpack": 48, + "lamp": 28, + "bed": 11, + "bookshelf": 18, + "mirror": 71, + "curtain": 21, + "plant": 40, + "whiteboard": 52, + "radiator": 96, + "kitchen cabinet": 29, + "toilet paper": 49, + "armchair": 23, + "shoe": 63, + "coffee table": 24, + "toilet": 17, + "bag": 47, + "clothes": 32, + "keyboard": 46, + "bottle": 65, + "recycling bin": 97, + "nightstand": 34, + "stool": 38, + "tv": 33, + "file cabinet": 75, + "dresser": 36, + "computer tower": 64, + "telephone": 101, + "cup": 130, + "refrigerator": 27, + "end table": 44, + "jacket": 131, + "shower curtain": 55, + "bathtub": 42, + "microwave": 59, + "kitchen counter": 159, + "sofa chair": 74, + "paper towel dispenser": 82, + "bathroom vanity": 1164, + "suitcase": 93, + "laptop": 77, + "ottoman": 67, + "shower wall": 128, + "printer": 50, + "counter": 35, + "board": 69, + "soap dispenser": 100, + "stove": 62, + "light": 105, + "closet wall": 1165, + "mini fridge": 165, + "fan": 76, + "tissue box": 230, + "blanket": 54, + "bathroom stall": 125, + "copier": 72, + "bench": 68, + "bar": 145, + "soap dish": 157, + "laundry hamper": 1166, + "storage bin": 132, + "bathroom stall door": 1167, + "light switch": 232, + "coffee maker": 134, + "tv stand": 51, + "decoration": 250, + "ceiling light": 1168, + "range hood": 342, + "blackboard": 89, + "clock": 103, + "wardrobe": 99, + "rail": 95, + "bulletin board": 154, + "mat": 140, + "trash bin": 1169, + "ledge": 193, + "seat": 116, + "mouse": 202, + "basket": 73, + "shower": 78, + "dumbbell": 1170, + "paper": 79, + "person": 80, + "windowsill": 141, + "closet": 57, + "bucket": 102, + "sign": 261, + "speaker": 118, + "dishwasher": 136, + "container": 98, + "stair rail": 1171, + "shower curtain rod": 170, + "tube": 1172, + "bathroom cabinet": 1173, + "storage container": 221, + "paper bag": 570, + "paper towel roll": 138, + "ball": 168, + "closet door": 276, + "laundry basket": 106, + "cart": 214, + "dish rack": 323, + "stairs": 58, + "blinds": 86, + "purse": 399, + "bicycle": 121, + "tray": 185, + "plunger": 300, + "paper cutter": 180, + "toilet paper dispenser": 163, + "bin": 66, + "toilet seat cover dispenser": 208, + "guitar": 112, + "mailbox": 540, + "handicap bar": 395, + "fire extinguisher": 166, + "ladder": 122, + "column": 120, + "pipe": 107, + "vacuum cleaner": 283, + "plate": 88, + "piano": 90, + "water cooler": 177, + "cd case": 1174, + "bowl": 562, + "closet rod": 1175, + "bathroom counter": 1156, + "oven": 84, + "stand": 104, + "scale": 229, + "washing machine": 70, + "broom": 325, + "hat": 169, + "guitar case": 331, + "rack": 87, + "water pitcher": 488, + "laundry detergent": 776, + "hair dryer": 370, + "pillar": 191, + "divider": 748, + "power outlet": 242, + "dining table": 45, + "shower floor": 417, + "shower door": 188, + "coffee kettle": 1176, + "structure": 1178, + "clothes dryer": 110, + "toaster": 148, + "ironing board": 155, + "alarm clock": 572, + "shower head": 1179, + "water bottle": 392, + "keyboard piano": 1180, + "projector screen": 609, + "case of water bottles": 1181, + "toaster oven": 195, + "music stand": 581, + "coat rack": 1182, + "storage organizer": 1183, + "machine": 139, + "folded chair": 1184, + "fire alarm": 1185, + "fireplace": 156, + "vent": 408, + "furniture": 213, + "power strip": 1186, + "calendar": 1187, + "poster": 1188, + "toilet paper holder": 115, + "potted plant": 1189, + "stuffed animal": 304, + "luggage": 1190, + "headphones": 312, + "crate": 233, + "candle": 286, + "projector": 264, + "mattress": 1191, + "dustpan": 356, + "cushion": 39, + "stick": 1163, +} + +scannet200_det_map = { + "chair": 0, + "table": 1, + "door": 2, + "couch": 3, + "cabinet": 4, + "shelf": 5, + "desk": 6, + "office chair": 7, + "bed": 8, + "pillow": 9, + "sink": 10, + "picture": 11, + "window": 12, + "toilet": 13, + "bookshelf": 14, + "monitor": 15, + "curtain": 16, + "book": 17, + "armchair": 18, + "coffee table": 19, + "box": 20, + "refrigerator": 21, + "lamp": 22, + "kitchen cabinet": 23, + "towel": 24, + "clothes": 25, + "tv": 26, + "nightstand": 27, + "counter": 28, + "dresser": 29, + "stool": 30, + "plant": 31, + "bathtub": 32, + "end table": 33, + "dining table": 34, + "keyboard": 35, + "bag": 36, + "backpack": 37, + "toilet paper": 38, + "printer": 39, + "tv stand": 40, + "whiteboard": 41, + "blanket": 42, + "shower curtain": 43, + "trash can": 44, + "closet": 45, + "stairs": 46, + "microwave": 47, + "stove": 48, + "shoe": 49, + "computer tower": 50, + "bottle": 51, + "bin": 52, + "ottoman": 53, + "bench": 54, + "board": 55, + "washing machine": 56, + "mirror": 57, + "copier": 58, + "basket": 59, + "sofa chair": 60, + "file cabinet": 61, + "fan": 62, + "laptop": 63, + "shower": 64, + "paper": 65, + "person": 66, + "paper towel dispenser": 67, + "oven": 68, + "blinds": 69, + "rack": 70, + "plate": 71, + "blackboard": 72, + "piano": 73, + "suitcase": 74, + "rail": 75, + "radiator": 76, + "recycling bin": 77, + "container": 78, + "wardrobe": 79, + "soap dispenser": 80, + "telephone": 81, + "bucket": 82, + "clock": 83, + "stand": 84, + "light": 85, + "laundry basket": 86, + "pipe": 87, + "clothes dryer": 88, + "guitar": 89, + "toilet paper holder": 90, + "seat": 91, + "speaker": 92, + "column": 93, + "ladder": 94, + "cup": 95, + "jacket": 96, + "storage bin": 97, + "coffee maker": 98, + "dishwasher": 99, + "paper towel roll": 100, + "machine": 101, + "mat": 102, + "windowsill": 103, + "bar": 104, + "bulletin board": 105, + "ironing board": 106, + "fireplace": 107, + "soap dish": 108, + "kitchen counter": 109, + "doorframe": 110, + "toilet paper dispenser": 111, + "mini fridge": 112, + "fire extinguisher": 113, + "ball": 114, + "hat": 115, + "shower curtain rod": 116, + "water cooler": 117, + "paper cutter": 118, + "tray": 119, + "pillar": 120, + "ledge": 121, + "toaster oven": 122, + "mouse": 123, + "toilet seat cover dispenser": 124, + "cart": 125, + "scale": 126, + "tissue box": 127, + "light switch": 128, + "crate": 129, + "power outlet": 130, + "decoration": 131, + "sign": 132, + "projector": 133, + "closet door": 134, + "vacuum cleaner": 135, + "headphones": 136, + "dish rack": 137, + "broom": 138, + "range hood": 139, + "hair dryer": 140, + "water bottle": 141, + "vent": 142, + "mailbox": 143, + "bowl": 144, + "paper bag": 145, + "projector screen": 146, + "divider": 147, + "laundry detergent": 148, + "bathroom counter": 149, + "stick": 150, + "bathroom vanity": 151, + "closet wall": 152, + "laundry hamper": 153, + "bathroom stall door": 154, + "ceiling light": 155, + "trash bin": 156, + "dumbbell": 157, + "stair rail": 158, + "tube": 159, + "bathroom cabinet": 160, + "coffee kettle": 161, + "shower head": 162, + "case of water bottles": 163, + "power strip": 164, + "calendar": 165, + "poster": 166, + "mattress": 167, +} + + +class ScanNetDataset(COCO3DDataset): + """ScanNetV2 dataset.""" + + def __init__( + self, + class_map: dict[str, int] = scannet_class_map, + max_depth: float = 12.0, + depth_scale: float = 1000.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + class_map=class_map, + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filenames. + + Since not every data has depth. + """ + return ( + img["file_path"].replace("image", "depth").replace(".jpg", ".png") + ) diff --git a/wilddet3d/data/datasets/stereo4d.py b/wilddet3d/data/datasets/stereo4d.py new file mode 100644 index 0000000000000000000000000000000000000000..074a5948ec9b7a000dc78aaf97ec5ed6fc395a34 --- /dev/null +++ b/wilddet3d/data/datasets/stereo4d.py @@ -0,0 +1,178 @@ +"""Stereo4D tinyval 3D dataset (real stereo depth, 500 images).""" + +from __future__ import annotations + +import json +import os + +import cv2 +import numpy as np + +from vis4d.common.typing import ArgsType, DictStrAny +from vis4d.data.const import CommonKeys as K + +from .coco3d import COCO3DDataset + +# Stereo4D v3 depth directory (meters, 512x512 .npy files) +_STEREO4D_DEPTH_DIR = ( + "/weka/oe-training-default/weikaih/3d_boundingbox_detection" + "/video_data/stereo4d_test/stereo4d_dataset_v3/depth" +) + +# V3 annotation for image_id -> filename mapping (depth file lookup) +_V3_ANN_PATH = ( + "/weka/oe-training-default/weikaih/3d_boundingbox_detection" + "/video_data/stereo4d_test/stereo4d_dataset_v3" + "/annotations/stereo4d_test.json" +) + +# Tinyval source directory (to recover original v3 image_ids) +_TINYVAL_DIR = ( + "/weka/oe-training-default/weikaih/3d_boundingbox_detection" + "/single_frame_data/experiment/v4_score_merged_la3d/stereo4d/tinyval" +) + +# Cached v3 id-to-stem mapping (built once, reused) +_v3_id_to_stem_cache: dict[int, str] | None = None +_tinyval_orig_ids_cache: list[int] | None = None + + +def _load_v3_id_to_stem() -> dict[int, str]: + """Load v3 image_id -> file stem mapping for depth lookup.""" + global _v3_id_to_stem_cache + if _v3_id_to_stem_cache is not None: + return _v3_id_to_stem_cache + with open(_V3_ANN_PATH) as f: + v3 = json.load(f) + _v3_id_to_stem_cache = {} + for img in v3["images"]: + stem = os.path.splitext(os.path.basename(img["file_name"]))[0] + _v3_id_to_stem_cache[img["id"]] = stem + return _v3_id_to_stem_cache + + +def _load_tinyval_orig_ids() -> list[int]: + """Load tinyval original v3 image_ids in sorted order.""" + global _tinyval_orig_ids_cache + if _tinyval_orig_ids_cache is not None: + return _tinyval_orig_ids_cache + files = sorted( + f for f in os.listdir(_TINYVAL_DIR) if f.endswith(".json") + ) + _tinyval_orig_ids_cache = [] + for f in files: + img_id = int(f.split("_")[-1].replace(".json", "")) + _tinyval_orig_ids_cache.append(img_id) + return _tinyval_orig_ids_cache + + +def load_stereo4d_class_map( + annotation_path: str, +) -> dict[str, int]: + """Load class map from Stereo4D annotation file. + + Returns a mapping from category name to category ID. + """ + cache_path = annotation_path.replace(".json", "_class_map.json") + if os.path.exists(cache_path): + with open(cache_path) as f: + return json.load(f) + with open(annotation_path) as f: + data = json.load(f) + class_map = {cat["name"]: cat["id"] for cat in data["categories"]} + with open(cache_path, "w") as f: + json.dump(class_map, f) + return class_map + + +class Stereo4D3DDataset(COCO3DDataset): + """Stereo4D tinyval 3D dataset with real stereo depth. + + 500 images from Stereo4D test set with human-reviewed 3D bounding + boxes. Depth maps are real stereo depth (meters, 512x512). + + Key differences from InTheWild3DDataset: + - Depth is real stereo depth (meters), not estimated depth (mm). + - All images are 512x512. + - No confidence masking needed (stereo depth is high quality). + """ + + def __init__( + self, + class_map: dict[str, int], + max_depth: float = 100.0, + per_image_categories: bool = False, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class. + + Args: + class_map: Mapping from category name to category ID. + max_depth: Maximum depth in meters (clip beyond this). + per_image_categories: If True, boxes2d_names only contains + the GT categories present in each image. + """ + # Initialize depth mappings BEFORE super().__init__() because + # _generate_data_mapping -> get_depth_filenames needs these. + self.per_image_categories = per_image_categories + self._v3_id_to_stem = _load_v3_id_to_stem() + self._tinyval_orig_ids = _load_tinyval_orig_ids() + + super().__init__( + class_map=class_map, + det_map=class_map, + max_depth=max_depth, + **kwargs, + ) + + def __getitem__(self, idx: int): + """Get single sample, optionally with per-image category filtering.""" + data_dict = super().__getitem__(idx) + if self.per_image_categories: + class_ids_in_img = data_dict[K.boxes2d_classes] + if len(class_ids_in_img) > 0: + unique_global_ids = sorted(set(class_ids_in_img.tolist())) + data_dict[K.boxes2d_names] = [ + self.categories[gid] for gid in unique_global_ids + ] + else: + data_dict[K.boxes2d_names] = [] + return data_dict + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Return path to the .npy stereo depth file for this image. + + Maps: converted image index -> tinyval orig_id -> v3 stem -> depth .npy + """ + img_id = img["id"] + if img_id >= len(self._tinyval_orig_ids): + return None + orig_id = self._tinyval_orig_ids[img_id] + stem = self._v3_id_to_stem.get(orig_id) + if stem is None: + return None + depth_path = os.path.join(_STEREO4D_DEPTH_DIR, f"{stem}.npy") + return depth_path if os.path.exists(depth_path) else None + + def get_depth_map(self, sample: DictStrAny) -> np.ndarray: + """Load stereo depth .npy (meters, 512x512). + + No mm-to-meters conversion needed (already in meters). + Resize to original resolution if needed. + """ + depth = np.load(sample["depth_filename"]) # (H, W) float32, meters + + orig_h = sample["img"]["height"] + orig_w = sample["img"]["width"] + + if depth.shape != (orig_h, orig_w): + depth = cv2.resize( + depth, + (orig_w, orig_h), + interpolation=cv2.INTER_NEAREST, + ) + + # Clip to max_depth + depth[depth > self.max_depth] = 0.0 + + return depth.astype(np.float32) diff --git a/wilddet3d/data/datasets/threeeed.py b/wilddet3d/data/datasets/threeeed.py new file mode 100644 index 0000000000000000000000000000000000000000..1d61644b152c5719e94ebbe69eeb0c9e99f56c47 --- /dev/null +++ b/wilddet3d/data/datasets/threeeed.py @@ -0,0 +1,79 @@ +"""3EED dataset for 3D object detection. + +Multi-platform outdoor scenes (Waymo vehicle, M3ED drone, M3ED quadruped) +with sparse LiDAR depth maps (uint16, depth_m * 256). +Categories: car, pedestrian, bus, truck, othervehicle, cyclist. +""" + +from __future__ import annotations + +import json +import os + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + + +def get_threeeed_det_map( + dataset_name: str, + data_root: str = "data/3eed", +) -> dict[str, int]: + """Build det_map from 3EED annotation JSON categories.""" + cache_path = os.path.join( + data_root, "annotations", f"{dataset_name}_class_map.json" + ) + if os.path.exists(cache_path): + with open(cache_path) as f: + return json.load(f) + json_path = os.path.join( + data_root, "annotations", f"{dataset_name}.json" + ) + with open(json_path) as f: + data = json.load(f) + class_map = {cat["name"]: cat["id"] for cat in data["categories"]} + with open(cache_path, "w") as f: + json.dump(class_map, f) + return class_map + + +def get_threeeed_class_map( + dataset_name: str, + data_root: str = "data/3eed", +) -> dict[str, int]: + """Build class_map from 3EED annotation JSON categories.""" + return get_threeeed_det_map(dataset_name, data_root) + + +class ThreeEEDDataset(COCO3DDataset): + """3EED Dataset. + + Multi-platform outdoor scenes with sparse LiDAR depth maps. + """ + + def __init__( + self, + max_depth: float = 80.0, + depth_scale: float = 256.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filename for a given image. + + Maps image path to depth path: + 3eed/3eed_dataset/{platform}/{seq}/{frame}/image.jpg + -> 3eed/depth/{platform}/{seq}/{frame}.png + """ + # image: 3eed/3eed_dataset/waymo/seq/frame/image.jpg + # depth: 3eed/depth/waymo/seq/frame.png + path = img["file_path"] + parts = path.replace("3eed/3eed_dataset/", "3eed/depth/") + parts = parts.replace("/image.jpg", ".png") + return parts diff --git a/wilddet3d/data/datasets/waymo.py b/wilddet3d/data/datasets/waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2066530c43229d9423d632e0c6a6fe3c90d9be --- /dev/null +++ b/wilddet3d/data/datasets/waymo.py @@ -0,0 +1,89 @@ +"""Waymo Open Dataset for 3D object detection. + +Outdoor driving scenes with sparse LiDAR depth maps (uint16, depth_m * 256). +Categories: vehicle, pedestrian, cyclist, sign. +""" + +from __future__ import annotations + +import json +import os + +from vis4d.common.typing import ArgsType, DictStrAny + +from wilddet3d.data.datasets.coco3d import COCO3DDataset + + +def get_waymo_det_map( + dataset_name: str, + data_root: str = "data/waymo", +) -> dict[str, int]: + """Build det_map from Waymo annotation JSON categories. + + Waymo has 4 categories (vehicle, pedestrian, cyclist, sign). + Since our model is open-vocabulary (text-prompted), we build + det_map dynamically from the annotation JSON. + + Args: + dataset_name: e.g. "Waymo_train" or "Waymo_val" + data_root: Root directory for Waymo data. + """ + cache_path = os.path.join( + data_root, "annotations", f"{dataset_name}_class_map.json" + ) + if os.path.exists(cache_path): + with open(cache_path) as f: + return json.load(f) + json_path = os.path.join( + data_root, "annotations", f"{dataset_name}.json" + ) + with open(json_path) as f: + data = json.load(f) + class_map = {cat["name"]: cat["id"] for cat in data["categories"]} + with open(cache_path, "w") as f: + json.dump(class_map, f) + return class_map + + +def get_waymo_class_map( + dataset_name: str, + data_root: str = "data/waymo", +) -> dict[str, int]: + """Build class_map from Waymo annotation JSON categories. + + Args: + dataset_name: e.g. "Waymo_train" or "Waymo_val" + data_root: Root directory for Waymo data. + """ + return get_waymo_det_map(dataset_name, data_root) + + +class WaymoDataset(COCO3DDataset): + """Waymo Open Dataset. + + Outdoor driving scenes with sparse LiDAR depth maps. + """ + + def __init__( + self, + max_depth: float = 80.0, + depth_scale: float = 256.0, + **kwargs: ArgsType, + ) -> None: + """Creates an instance of the class.""" + super().__init__( + max_depth=max_depth, + depth_scale=depth_scale, + **kwargs, + ) + + def get_depth_filenames(self, img: DictStrAny) -> str | None: + """Get the depth filename for a given image. + + Maps image path to depth path: + waymo/images/validation/xxx.jpg + -> waymo/depth/validation/xxx.png + """ + return img["file_path"].replace( + "images", "depth" + ).replace(".jpg", ".png") diff --git a/wilddet3d/data/samplers.py b/wilddet3d/data/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..536278189841bd67aa0c5431e79a2bc15e99e2be --- /dev/null +++ b/wilddet3d/data/samplers.py @@ -0,0 +1,224 @@ +"""Dataset-ratio weighted sampler for multi-dataset training.""" + +from __future__ import annotations + +import math +from collections.abc import Callable, Iterator, Sequence + +import torch +from torch.utils.data import ConcatDataset, DataLoader, Sampler +from torch.utils.data.distributed import DistributedSampler + +from vis4d.common.distributed import get_rank, get_world_size +from vis4d.data.data_pipe import DataPipe +from vis4d.data.loader import build_train_dataloader +from vis4d.data.typing import DictData, DictDataOrList + + +class DatasetRatioSampler(Sampler[int]): + """Weighted sampler that controls per-dataset sampling ratios. + + For a ConcatDataset with N sub-datasets, this sampler assigns each + sample a weight based on which sub-dataset it belongs to, then + performs weighted random sampling. This allows controlling the + proportion each dataset appears during training without dropping + any data. + + Two modes of specifying ratios: + + 1. dataset_ratios (original): raw per-dataset weights. + weight_i = ratio_i / size_i, proportion is derived. + Example: dataset_ratios=[1.0, 1.0] for Omni3D(100K)+CA-1M(200K) + -> 50/50 sampling proportion. + + 2. target_proportions (new): directly specify desired proportions. + Must sum to 1.0. Weights are computed automatically. + Example: target_proportions=[0.5, 0.25, 0.25] + -> Omni3D 50%, CA-1M 25%, Waymo 25%. + + epoch_dataset_idx: If set, one epoch = the specified dataset sees + every sample once. num_samples is computed as: + size[idx] / proportion[idx] + + Supports distributed training (splits indices across ranks). + + Args: + dataset: A ConcatDataset (e.g., DataPipe with multiple datasets). + dataset_ratios: Per-dataset sampling weight. Mutually exclusive + with target_proportions. + target_proportions: Per-dataset target proportion (must sum to 1). + Mutually exclusive with dataset_ratios. + epoch_dataset_idx: If set, one epoch = this dataset sees all its + samples once. Overrides num_samples. + num_samples: Total samples per epoch. If None and + epoch_dataset_idx is None, uses sum of all dataset sizes. + shuffle: Whether to shuffle indices each epoch. + seed: Random seed for reproducibility. + """ + + def __init__( + self, + dataset: ConcatDataset, + dataset_ratios: list[float] | None = None, + target_proportions: list[float] | None = None, + epoch_dataset_idx: int | None = None, + num_samples: int | None = None, + shuffle: bool = True, + seed: int = 0, + ) -> None: + """Creates an instance of the class.""" + assert isinstance(dataset, ConcatDataset), ( + "dataset must be a ConcatDataset (e.g., DataPipe)" + ) + assert (dataset_ratios is None) != (target_proportions is None), ( + "Exactly one of dataset_ratios or target_proportions " + "must be provided" + ) + self.dataset = dataset + self.shuffle = shuffle + self.seed = seed + self.epoch = 0 + + num_datasets = len(dataset.datasets) + sizes = [len(d) for d in dataset.datasets] + + if target_proportions is not None: + assert len(target_proportions) == num_datasets, ( + f"target_proportions length ({len(target_proportions)}) " + f"must match number of sub-datasets ({num_datasets})" + ) + assert abs(sum(target_proportions) - 1.0) < 1e-6, ( + f"target_proportions must sum to 1.0, " + f"got {sum(target_proportions)}" + ) + # weight per sample = proportion_i / size_i + # Expected count: num_samples * (prop_i/size_i * size_i) / sum(prop) = num_samples * prop_i + sample_weights = [] + for size, prop in zip(sizes, target_proportions): + w = prop / size + sample_weights.extend([w] * size) + proportions = list(target_proportions) + else: + assert len(dataset_ratios) == num_datasets, ( + f"dataset_ratios length ({len(dataset_ratios)}) must " + f"match number of sub-datasets ({num_datasets})" + ) + # weight_i = ratio_i / size_i + sample_weights = [] + for size, ratio in zip(sizes, dataset_ratios): + w = ratio / size + sample_weights.extend([w] * size) + # Compute actual proportions for epoch_dataset_idx + raw = [r / s for r, s in zip(dataset_ratios, sizes)] + total = sum(raw) + proportions = [r / total for r in raw] + + self.weights = torch.tensor(sample_weights, dtype=torch.float64) + + # Determine num_samples (epoch length) + if epoch_dataset_idx is not None: + assert 0 <= epoch_dataset_idx < num_datasets + # 1 epoch = dataset[idx] sees all samples once + self.num_samples = int( + sizes[epoch_dataset_idx] / proportions[epoch_dataset_idx] + ) + print( + f"[DatasetRatioSampler] epoch_dataset_idx={epoch_dataset_idx}" + f" ({sizes[epoch_dataset_idx]} samples," + f" {proportions[epoch_dataset_idx]:.1%} proportion)" + f" -> {self.num_samples} samples/epoch" + ) + elif num_samples is not None: + self.num_samples = num_samples + else: + self.num_samples = len(dataset) + + # Log dataset info + for i, (size, prop) in enumerate(zip(sizes, proportions)): + expected = int(self.num_samples * prop) + print( + f"[DatasetRatioSampler] dataset[{i}]: " + f"size={size}, proportion={prop:.1%}, " + f"~{expected} samples/epoch" + ) + + # Distributed settings + self.world_size = get_world_size() + self.rank = get_rank() + # Each rank gets an equal share + self.num_samples_per_rank = math.ceil( + self.num_samples / self.world_size + ) + self.total_size = self.num_samples_per_rank * self.world_size + + def __iter__(self) -> Iterator[int]: + """Generate sampled indices.""" + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + + indices = torch.multinomial( + self.weights, + num_samples=self.total_size, + replacement=True, + generator=g, + ).tolist() + + # Subsample for this rank + indices = indices[self.rank::self.world_size] + assert len(indices) == self.num_samples_per_rank + + return iter(indices) + + def __len__(self) -> int: + """Return number of samples for this rank.""" + return self.num_samples_per_rank + + def set_epoch(self, epoch: int) -> None: + """Set epoch for shuffling (required for distributed training).""" + self.epoch = epoch + + +def build_train_dataloader_with_ratios( + dataset: DataPipe, + dataset_ratios: list[float] | None = None, + target_proportions: list[float] | None = None, + epoch_dataset_idx: int | None = None, + num_samples: int | None = None, + **kwargs, +) -> DataLoader[DictDataOrList]: + """Build training dataloader with per-dataset ratio sampling. + + Thin wrapper around vis4d's build_train_dataloader that creates a + DatasetRatioSampler at runtime (when the dataset is instantiated). + + Two ways to specify dataset mixing: + + 1. dataset_ratios: raw weights (original, for backwards compat). + Example: dataset_ratios=[1.0, 1.0] -> equal weight per dataset. + + 2. target_proportions: direct proportions (must sum to 1.0). + Example: target_proportions=[0.5, 0.25, 0.25] + + Args: + dataset: DataPipe (ConcatDataset) with multiple sub-datasets. + dataset_ratios: Per-dataset sampling weight (mutually exclusive + with target_proportions). + target_proportions: Per-dataset target proportion, must sum to 1. + epoch_dataset_idx: If set, 1 epoch = this dataset sees all its + samples once. Overrides num_samples. + num_samples: Total samples per epoch (overridden by + epoch_dataset_idx). + **kwargs: All other arguments forwarded to build_train_dataloader. + """ + sampler = DatasetRatioSampler( + dataset, + dataset_ratios=dataset_ratios, + target_proportions=target_proportions, + epoch_dataset_idx=epoch_dataset_idx, + num_samples=num_samples, + shuffle=kwargs.pop("shuffle", True), + ) + # shuffle must be False when using custom sampler (PyTorch requirement) + return build_train_dataloader( + dataset=dataset, sampler=sampler, shuffle=False, **kwargs + ) diff --git a/wilddet3d/data/transforms/__init__.py b/wilddet3d/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2f92c4b6244dea653d83d9a1f4396519f679b27 --- /dev/null +++ b/wilddet3d/data/transforms/__init__.py @@ -0,0 +1 @@ +"""Data transforms.""" diff --git a/wilddet3d/data/transforms/__pycache__/__init__.cpython-311.pyc b/wilddet3d/data/transforms/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecdb728e1ea33c5b26f108392223efb1366f2a82 Binary files /dev/null and b/wilddet3d/data/transforms/__pycache__/__init__.cpython-311.pyc differ diff --git a/wilddet3d/data/transforms/__pycache__/pad.cpython-311.pyc b/wilddet3d/data/transforms/__pycache__/pad.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d9577488bacd77ecd4c1ea79244d3e4fa0fae2b Binary files /dev/null and b/wilddet3d/data/transforms/__pycache__/pad.cpython-311.pyc differ diff --git a/wilddet3d/data/transforms/__pycache__/resize.cpython-311.pyc b/wilddet3d/data/transforms/__pycache__/resize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f2b24e8f082cbbd4983ed8ca306bcc01997ca59 Binary files /dev/null and b/wilddet3d/data/transforms/__pycache__/resize.cpython-311.pyc differ diff --git a/wilddet3d/data/transforms/crop.py b/wilddet3d/data/transforms/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..c3bc38cc72158d85e1e8f8795528df74d0ecf313 --- /dev/null +++ b/wilddet3d/data/transforms/crop.py @@ -0,0 +1,43 @@ +"""Crop transforms.""" + +from __future__ import annotations + +from vis4d.common.typing import ( + NDArrayBool, + NDArrayF32, + NDArrayI64, +) +from vis4d.data.const import CommonKeys as K +from vis4d.data.transforms.base import Transform + + +@Transform( + in_keys=[ + K.boxes3d, + K.boxes3d_classes, + K.boxes3d_track_ids, + "transforms.crop.keep_mask", + ], + out_keys=[K.boxes3d, K.boxes3d_classes, K.boxes3d_track_ids], +) +class CropBoxes3D: + """Crop 3D bounding boxes.""" + + def __call__( + self, + boxes_list: list[NDArrayF32], + classes_list: list[NDArrayI64], + track_ids_list: list[NDArrayI64] | None, + keep_mask_list: list[NDArrayBool], + ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]: + """Crop 3D bounding boxes.""" + for i, (boxes, classes, keep_mask) in enumerate( + zip(boxes_list, classes_list, keep_mask_list) + ): + boxes_list[i] = boxes[keep_mask] + classes_list[i] = classes[keep_mask] + + if track_ids_list is not None: + track_ids_list[i] = track_ids_list[i][keep_mask] + + return boxes_list, classes_list, track_ids_list diff --git a/wilddet3d/data/transforms/language.py b/wilddet3d/data/transforms/language.py new file mode 100644 index 0000000000000000000000000000000000000000..07e5c0ab40242a8e0550799a41de8c1f1cead87a --- /dev/null +++ b/wilddet3d/data/transforms/language.py @@ -0,0 +1,267 @@ +"""Language related transforms.""" + +from __future__ import annotations + +import random +import re + +import numpy as np +from transformers import AutoTokenizer +from vis4d.common.logging import rank_zero_warn +from vis4d.common.typing import NDArrayF32, NDArrayI64 +from vis4d.data.const import CommonKeys as K +from vis4d.data.transforms.base import Transform + + +def clean_name(name: str) -> str: + """Clean the name.""" + name = re.sub(r"\(.*\)", "", name) + name = re.sub(r"_", " ", name) + name = re.sub(r" ", " ", name) + name = name.lower() + return name + + +def generate_senetence_given_labels( + positive_label_list: list[int], + negative_label_list: list[str], + label_map: dict[str, str], +) -> tuple[dict[int, list[list[int]]], str, dict[int, int]]: + """Generate a sentence given positive and negative labels.""" + label_to_positions = {} + + label_list = negative_label_list + positive_label_list + + random.shuffle(label_list) + + pheso_caption = "" + + label_remap_dict = {} + for index, label in enumerate(label_list): + start_index = len(pheso_caption) + + pheso_caption += clean_name(label_map[str(label)]) + + end_index = len(pheso_caption) + + if label in positive_label_list: + label_to_positions[index] = [[start_index, end_index]] + label_remap_dict[int(label)] = index + + pheso_caption += ". " + + return label_to_positions, pheso_caption, label_remap_dict + + +@Transform( + [ + "dataset_type", + K.boxes2d, + K.boxes2d_classes, + K.boxes2d_names, + "label_map", + "positive_positions", + ], + [K.boxes2d, K.boxes2d_classes, K.boxes2d_names, "tokens_positive"], +) +class RandomSamplingNegPos: + """Randomly sample negative and positive labels for object detection.""" + + def __init__( + self, + tokenizer_name: str = "bert-base-uncased", + num_sample_negative: int = 85, + max_tokens: int = 256, + full_sampling_prob: float = 0.5, + ) -> None: + """Creates an instance of RandomSamplingNegPos.""" + if AutoTokenizer is None: + raise RuntimeError( + "transformers is not installed, please install it by: " + "pip install transformers." + ) + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + self.num_sample_negative = num_sample_negative + self.full_sampling_prob = full_sampling_prob + self.max_tokens = max_tokens + + def __call__( + self, + dataset_type_list: list[str], + boxes_list: list[NDArrayF32], + class_ids_list: list[NDArrayI64], + texts_list: list[str] | None = None, + label_map_list: dict | None = None, + positive_positions_list: list[dict] | None = None, + ) -> tuple[ + list[NDArrayF32], + list[NDArrayI64], + list[str], + list[dict[int, list[list[int]]]], + ]: + """Randomly sample negative and positive labels.""" + new_texts_list = [] + tokens_positive_list = [] + for i, (boxes, class_ids) in enumerate( + zip(boxes_list, class_ids_list) + ): + if dataset_type_list[i] == "OD": + assert ( + label_map_list[i] is not None + ), "label_map should not be None" + boxes_list[i], class_ids_list[i], text, tokens_positive = ( + self.od_aug(boxes, class_ids, label_map_list[i]) + ) + new_texts_list.append(text) + tokens_positive_list.append(tokens_positive) + else: + assert ( + positive_positions_list[i] is not None + ), "positive_positions should not be None" + tokens_positive = self.vg_aug( + class_ids, positive_positions_list[i] + ) + new_texts_list.append(texts_list[i]) + tokens_positive_list.append(tokens_positive) + + return boxes_list, class_ids_list, new_texts_list, tokens_positive_list + + def vg_aug(self, class_ids: NDArrayI64, positive_positions): + """Visual Genome data augmentation.""" + positive_label_list = np.unique(class_ids).tolist() + + label_to_positions = {} + for label in positive_label_list: + label_to_positions[label] = positive_positions[label] + + return label_to_positions + + def od_aug( + self, + boxes: NDArrayF32, + class_ids: NDArrayI64, + label_map: dict, + ) -> tuple[NDArrayF32, NDArrayI64, str, dict[int, list[list[int]]]]: + """Object detection data augmentation.""" + original_box_num = len(class_ids) + + # If the category name is in the format of 'a/b' (in object365), + # we randomly select one of them. + for key, value in label_map.items(): + if "/" in value: + label_map[key] = random.choice(value.split("/")).strip() + + keep_box_index, class_ids, positive_caption_length = ( + self.check_for_positive_overflow(class_ids, label_map) + ) + + boxes = boxes[keep_box_index] + + if len(boxes) < original_box_num: + rank_zero_warn( + f"Remove {original_box_num - len(boxes)} boxes due to " + "positive caption overflow." + ) + + valid_negative_indexes = list(label_map.keys()) + + positive_label_list = np.unique(class_ids).tolist() + + full_negative = self.num_sample_negative + if full_negative > len(valid_negative_indexes): + full_negative = len(valid_negative_indexes) + + outer_prob = random.random() + + if outer_prob < self.full_sampling_prob: + # c. probability_full: add both all positive and all negatives + num_negatives = full_negative + else: + if random.random() < 1.0: + num_negatives = np.random.choice(max(1, full_negative)) + 1 + else: + num_negatives = full_negative + + # Keep some negatives + negative_label_list = set() + if num_negatives != -1: + if num_negatives > len(valid_negative_indexes): + num_negatives = len(valid_negative_indexes) + + for i in np.random.choice( + valid_negative_indexes, size=num_negatives, replace=False + ): + if int(i) not in positive_label_list: + negative_label_list.add(i) + + random.shuffle(positive_label_list) + + negative_label_list = list(negative_label_list) + random.shuffle(negative_label_list) + + negative_max_length = self.max_tokens - positive_caption_length + screened_negative_label_list = [] + + for negative_label in negative_label_list: + label_text = clean_name(label_map[str(negative_label)]) + ". " + + tokenized = self.tokenizer.tokenize(label_text) + + negative_max_length -= len(tokenized) + + if negative_max_length > 0: + screened_negative_label_list.append(negative_label) + else: + break + + negative_label_list = screened_negative_label_list + label_to_positions, pheso_caption, label_remap_dict = ( + generate_senetence_given_labels( + positive_label_list, negative_label_list, label_map + ) + ) + + # label remap + if len(class_ids) > 0: + class_ids = np.vectorize(lambda x: label_remap_dict[x])(class_ids) + + return boxes, class_ids, pheso_caption, label_to_positions + + def check_for_positive_overflow( + self, class_ids: NDArrayI64, label_map: dict[str, str] + ) -> tuple[list[int], NDArrayI64, int]: + """Check if having too many positive labels.""" + # generate a caption by appending the positive labels + positive_label_list = np.unique(class_ids).tolist() + + # random shuffule so we can sample different annotations + # at different epochs + random.shuffle(positive_label_list) + + kept_lables = [] + length = 0 + for _, label in enumerate(positive_label_list): + label_text = clean_name(label_map[str(label)]) + ". " + + tokenized = self.tokenizer.tokenize(label_text) + + length += len(tokenized) + + if length > self.max_tokens: + break + else: + kept_lables.append(label) + + keep_box_index = [] + keep_gt_labels = [] + for i, class_id in enumerate(class_ids): + if class_id in kept_lables: + keep_box_index.append(i) + keep_gt_labels.append(class_id) + + return ( + keep_box_index, + np.array(keep_gt_labels, dtype=np.int64), + length, + ) diff --git a/wilddet3d/data/transforms/masks.py b/wilddet3d/data/transforms/masks.py new file mode 100644 index 0000000000000000000000000000000000000000..18c0849a46228f3e062f81cc8f7ab818a8fea410 --- /dev/null +++ b/wilddet3d/data/transforms/masks.py @@ -0,0 +1,120 @@ +"""Spatial transforms for per-box binary masks (masks2d). + +masks2d is a list (per image in batch) of (N, H, W) uint8 arrays, +where N is the number of boxes in that image. Each mask slice is a +binary mask for one box. These transforms keep masks aligned with +images, boxes2d, and depth_maps through the spatial augmentation +pipeline. +""" + +from __future__ import annotations + +import numpy as np +import torch +import torch.nn.functional as F + +from vis4d.data.transforms.base import Transform + +MASKS2D_KEY = "masks2d" + + +@Transform( + [MASKS2D_KEY, "transforms.resize.target_shape"], + MASKS2D_KEY, +) +class ResizeMasks2D: + """Resize per-box masks using nearest interpolation.""" + + def __call__( + self, + masks_list, + target_shapes, + ): + """Resize masks.""" + if masks_list is None: + return masks_list + for i, (masks, target_shape) in enumerate( + zip(masks_list, target_shapes) + ): + if masks is None or len(masks) == 0: + continue + # masks: (N, H, W) uint8 + t = torch.from_numpy(masks).float().unsqueeze(1) # (N,1,H,W) + t = F.interpolate( + t, size=target_shape, mode="nearest" + ) + masks_list[i] = ( + t.squeeze(1).to(torch.uint8).numpy() + ) # (N, H', W') + return masks_list + + +@Transform([MASKS2D_KEY, "transforms.crop.crop_box"], MASKS2D_KEY) +class CropMasks2D: + """Crop per-box masks.""" + + def __call__( + self, + masks_list, + crop_box_list, + ): + """Crop masks.""" + if masks_list is None: + return masks_list + for i, (masks, crop_box) in enumerate( + zip(masks_list, crop_box_list) + ): + if masks is None or len(masks) == 0: + continue + x1, y1, x2, y2 = crop_box + masks_list[i] = masks[:, y1:y2, x1:x2] + return masks_list + + +@Transform(MASKS2D_KEY, MASKS2D_KEY) +class FlipMasks2D: + """Flip per-box masks horizontally.""" + + def __call__( + self, + masks_list, + ): + """Flip masks.""" + if masks_list is None: + return masks_list + for i, masks in enumerate(masks_list): + if masks is None or len(masks) == 0: + continue + masks_list[i] = np.ascontiguousarray( + masks[:, :, ::-1] + ) + return masks_list + + +@Transform([MASKS2D_KEY, "transforms.pad"], MASKS2D_KEY) +class CenterPadMasks2D: + """Center-pad per-box masks.""" + + def __call__( + self, + masks_list, + pad_params, + ): + """Pad masks.""" + if masks_list is None: + return masks_list + for i, (masks, pad_param) in enumerate( + zip(masks_list, pad_params) + ): + if masks is None or len(masks) == 0: + continue + pad = ( + pad_param["pad_left"], + pad_param["pad_right"], + pad_param["pad_top"], + pad_param["pad_bottom"], + ) + t = torch.from_numpy(masks).unsqueeze(1) # (N,1,H,W) + t = F.pad(t, pad, mode="constant", value=0) + masks_list[i] = t.squeeze(1).numpy() # (N, H', W') + return masks_list diff --git a/wilddet3d/data/transforms/pad.py b/wilddet3d/data/transforms/pad.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc32a1d0233ca25bd34426a56cf3a7d17ba7d5a --- /dev/null +++ b/wilddet3d/data/transforms/pad.py @@ -0,0 +1,176 @@ +"""Pad transformation.""" + +from __future__ import annotations + +from typing import TypedDict + +import torch +import torch.nn.functional as F +from vis4d.common.typing import NDArrayF32 +from vis4d.data.const import CommonKeys as K +from vis4d.data.transforms.base import Transform +from vis4d.data.transforms.pad import _get_max_shape + + +class PadParam(TypedDict): + """Parameters for Reshape.""" + + pad_top: int + pad_bottom: int + pad_left: int + pad_right: int + + +@Transform( + [K.images, K.input_hw], + [K.images, "transforms.pad", K.input_hw, "padding"], +) +class CenterPadImages: + """Pad batch of images at the bottom right.""" + + def __init__( + self, + stride: int = 32, + mode: str = "constant", + value: float = 0.0, + update_input_hw: bool = False, + shape: tuple[int, int] | None = None, + pad2square: bool = False, + ) -> None: + """Creates an instance of PadImage. + + Args: + stride (int, optional): Chooses padding size so that the input will + be divisible by stride. Defaults to 32. + mode (str, optional): Padding mode. One of constant, reflect, + replicate or circular. Defaults to "constant". + value (float, optional): Value for constant padding. + Defaults to 0.0. + shape (tuple[int, int], optional): Shape of the padded image + (H, W). Defaults to None. + pad2square (bool, optional): Pad to square. Defaults to False. + """ + self.stride = stride + self.mode = mode + self.value = value + self.update_input_hw = update_input_hw + self.shape = shape + self.pad2square = pad2square + + def __call__( + self, images: list[NDArrayF32], input_hw: list[tuple[int, int]] + ) -> tuple[list[NDArrayF32], list[PadParam], list[tuple[int, int]]]: + """Pad images to consistent size.""" + heights = [im.shape[1] for im in images] + widths = [im.shape[2] for im in images] + + max_hw = _get_max_shape( + heights, widths, self.stride, self.shape, self.pad2square + ) + + # generate params for torch pad + pad_params = [] + target_input_hw = [] + paddings = [] + for i, (image, h, w) in enumerate(zip(images, heights, widths)): + pad_top, pad_bottom = (max_hw[0] - h) // 2, max_hw[0] - h - ( + max_hw[0] - h + ) // 2 + + pad_left, pad_right = (max_hw[1] - w) // 2, max_hw[1] - w - ( + max_hw[1] - w + ) // 2 + + image_ = torch.from_numpy(image).permute(0, 3, 1, 2) + image_ = F.pad( + image_, + (pad_left, pad_right, pad_top, pad_bottom), + self.mode, + self.value, + ) + images[i] = image_.permute(0, 2, 3, 1).numpy() + + pad_params.append( + PadParam( + pad_top=pad_top, + pad_bottom=pad_bottom, + pad_left=pad_left, + pad_right=pad_right, + ) + ) + + paddings.append([pad_left, pad_right, pad_top, pad_bottom]) + + target_input_hw.append(max_hw) + + if self.update_input_hw: + input_hw = target_input_hw + + return images, pad_params, input_hw, paddings + + +@Transform([K.intrinsics, "transforms.pad"], K.intrinsics) +class CenterPadIntrinsics: + """Resize Intrinsics.""" + + def __call__( + self, intrinsics: list[NDArrayF32], pad_params: list[PadParam] + ) -> list[NDArrayF32]: + """Scale camera intrinsics when resizing.""" + for i, intrinsic in enumerate(intrinsics): + intrinsic[0, 2] += pad_params[i]["pad_left"] + intrinsic[1, 2] += pad_params[i]["pad_top"] + + intrinsics[i] = intrinsic + return intrinsics + + +@Transform([K.boxes2d, "transforms.pad"], K.boxes2d) +class CenterPadBoxes2D: + """Pad batch of depth maps at the bottom right.""" + + def __call__( + self, boxes_list: list[NDArrayF32], pad_params: list[PadParam] + ) -> list[NDArrayF32]: + """Scale camera intrinsics when resizing.""" + for i, boxes in enumerate(boxes_list): + boxes[:, 0] += pad_params[i]["pad_left"] + boxes[:, 1] += pad_params[i]["pad_top"] + boxes[:, 2] += pad_params[i]["pad_left"] + boxes[:, 3] += pad_params[i]["pad_top"] + + boxes_list[i] = boxes + + return boxes_list + + +@Transform([K.depth_maps, "transforms.pad"], K.depth_maps) +class CenterPadDepthMaps: + """Pad batch of depth maps at the bottom right.""" + + def __init__(self, mode: str = "constant", value: int = 0) -> None: + """Creates an instance.""" + self.mode = mode + self.value = value + + def __call__( + self, depth_maps: list[NDArrayF32], pad_params: list[PadParam] + ) -> list[NDArrayF32]: + """Pad images to consistent size.""" + + # generate params for torch pad + for i, (depth, pad_param_dict) in enumerate( + zip(depth_maps, pad_params) + ): + pad_param = ( + pad_param_dict["pad_left"], + pad_param_dict["pad_right"], + pad_param_dict["pad_top"], + pad_param_dict["pad_bottom"], + ) + + depth_ = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0) + depth_ = F.pad(depth_, pad_param, self.mode, self.value) + depth_maps[i] = depth_.squeeze(0).squeeze(0).numpy() + + return depth_maps diff --git a/wilddet3d/data/transforms/resize.py b/wilddet3d/data/transforms/resize.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef2b4aac045f5bf11bb19f6ea9dedeb5a7a50fb --- /dev/null +++ b/wilddet3d/data/transforms/resize.py @@ -0,0 +1,121 @@ +"""Resize transformation.""" + +from __future__ import annotations + +import math + +import numpy as np +import torch +from vis4d.common.typing import NDArrayF32, NDArrayI64 +from vis4d.data.const import CommonKeys as K +from vis4d.data.transforms.base import Transform +from vis4d.data.transforms.resize import ResizeParam, resize_tensor + + +@Transform(K.images, ["transforms.resize", K.input_hw]) +class GenResizeParameters: + """Generate the parameters for a resize operation.""" + + def __init__( + self, shape: tuple[int, int], scales: tuple[float, float] | float = 1.0 + ) -> None: + """Create a new instance of the class.""" + self.shape = shape + self.scales = scales + + def __call__( + self, images: list[NDArrayF32] + ) -> tuple[list[ResizeParam], list[tuple[int, int]]]: + """Compute the parameters and put them in the data dict.""" + if isinstance(self.scales, float): + random_scale = self.scales + else: + random_scale = np.random.uniform(self.scales[0], self.scales[1]) + + shape = ( + math.ceil(self.shape[0] * random_scale - 0.5), + math.ceil(self.shape[1] * random_scale - 0.5), + ) + + output_ratio = shape[1] / shape[0] + + image = images[0] + + input_h, input_w = (image.shape[1], image.shape[2]) + input_ratio = input_w / input_h + + if output_ratio > input_ratio: + scale = shape[0] / input_h + else: + scale = shape[1] / input_w + + target_shape = ( + math.ceil(input_h * scale - 0.5), + math.ceil(input_w * scale - 0.5), + ) + + scale_factor = (target_shape[0] / input_h, target_shape[1] / input_w) + + resize_params = [ + ResizeParam(target_shape=target_shape, scale_factor=scale_factor) + ] * len(images) + target_shapes = [target_shape] * len(images) + + return resize_params, target_shapes + + +@Transform( + [K.panoptic_masks, "transforms.resize.target_shape"], K.panoptic_masks +) +class ResizePanopticMasks: + """Resize panoptic segmentation masks.""" + + def __call__( + self, + masks_list: list[NDArrayI64], + target_shape_list: list[tuple[int, int]], + ) -> list[NDArrayI64]: + """Resize masks.""" + for i, (masks, target_shape) in enumerate( + zip(masks_list, target_shape_list) + ): + masks_ = torch.from_numpy(masks) + masks_ = ( + resize_tensor( + masks_.float().unsqueeze(0).unsqueeze(0), + target_shape, + interpolation="nearest", + ) + .type(masks_.dtype) + .squeeze(0) + .squeeze(0) + ) + masks_list[i] = masks_.numpy() + return masks_list + + +@Transform([K.boxes3d, "transforms.resize.scale_factor"], K.boxes3d) +class ResizeBoxes3D: + """Resize list of 2D bounding boxes.""" + + def __call__( + self, + boxes_list: list[NDArrayF32], + scale_factors: list[tuple[float, float]], + ) -> list[NDArrayF32]: + """Resize 2D bounding boxes. + + Args: + boxes_list: (list[NDArrayF32]): The bounding boxes to be resized. + scale_factors (list[tuple[float, float]]): scaling factors. + + Returns: + list[NDArrayF32]: Resized bounding boxes according to parameters in + resize. + """ + for i, (boxes, scale_factor) in enumerate( + zip(boxes_list, scale_factors) + ): + boxes[:, 2] /= scale_factor[0] + boxes_list[i] = boxes + return boxes_list diff --git a/wilddet3d/data_types.py b/wilddet3d/data_types.py new file mode 100644 index 0000000000000000000000000000000000000000..c587150c43764b349490b8d6db254da11d9770ee --- /dev/null +++ b/wilddet3d/data_types.py @@ -0,0 +1,229 @@ +"""WildDet3D data types.""" + +from __future__ import annotations + +from dataclasses import dataclass, fields +from typing import List, NamedTuple + +import torch +from torch import Tensor + + +class Det3DOut(NamedTuple): + """Output of the detection model. + + boxes (list[Tensor]): 2D bounding boxes of shape [N, 4] in xyxy format. + boxes3d (list[Tensor]): 3D bounding boxes of shape [N, 10]. + scores (list[Tensor]): 2D confidence scores of shape [N,]. + class_ids (list[Tensor]): class ids of shape [N,]. + depth_maps (list[Tensor] | None): depth maps for each image. + categories (list[list[str]] | None): category names for each detection. + predicted_intrinsics (Tensor | None): predicted camera intrinsics (B, 3, 3). + scores_3d (list[Tensor] | None): 3D confidence scores of shape [N,]. + scores_2d (list[Tensor] | None): pure 2D confidence scores of shape [N,]. + """ + + boxes: list[Tensor] + boxes3d: list[Tensor] + scores: list[Tensor] + class_ids: list[Tensor] + depth_maps: list[Tensor] | None + categories: list[list[str]] | None = None + predicted_intrinsics: Tensor | None = None + scores_3d: list[Tensor] | None = None + scores_2d: list[Tensor] | None = None + + +class WildDet3DOut(NamedTuple): + """Output of WildDet3D model. + + All tensors use batch-first format: (B, num_queries, dim) + where B = N_prompts (per-prompt batch). + + Coordinate formats: + - pred_boxes_2d: normalized xyxy [0, 1] + - pred_boxes_3d: encoded 3D params (delta_center, log_depth, log_dims, rot_6d) + """ + # 2D Detection (from SAM3 decoder) - O2O outputs + pred_logits: Tensor # (N_prompts, num_queries, 1) - objectness + pred_boxes_2d: Tensor # (N_prompts, num_queries, 4) - normalized xyxy + + # 3D Detection (from 3D head) - O2O outputs + pred_boxes_3d: Tensor | None # (N_prompts, num_queries, 12) - encoded 3D params + + # Auxiliary outputs for each decoder layer (for deep supervision) + aux_outputs: list[dict] | None + + # Geometry backend losses (SILog depth, phi, theta) + geom_losses: dict[str, Tensor] | None + + # SAM3 specific outputs + presence_logits: Tensor | None # (N_prompts, num_queries, 1) + queries: Tensor | None # (N_prompts, num_queries, d_model) - for segmentation + + # Encoder hidden states (for depth head if needed) + encoder_hidden_states: Tensor | None # (H*W, N_prompts, d_model) + + # Matching indices from SAM3 (for loss computation) + # Format: (batch_idx, src_idx, tgt_idx) from Hungarian matching + indices: tuple | None = None + + # Normalized cxcywh boxes (needed by SAM3's Boxes loss for L1) + pred_boxes_2d_cxcywh: Tensor | None = None # (N_prompts, num_queries, 4) - normalized cxcywh + + # O2M (One-to-Many) outputs from SAM3 DAC mechanism + # These are separate outputs from the second half of queries in DAC mode + pred_logits_o2m: Tensor | None = None # (N_prompts, num_queries, 1) + pred_boxes_2d_o2m: Tensor | None = None # (N_prompts, num_queries, 4) - normalized xyxy + pred_boxes_2d_cxcywh_o2m: Tensor | None = None # (N_prompts, num_queries, 4) - normalized cxcywh + pred_boxes_3d_o2m: Tensor | None = None # (N_prompts, num_queries, 12) - encoded 3D params + + # 3D confidence head outputs (camera+depth conditioned) + pred_conf_3d: Tensor | None = None # (N_prompts, num_queries, 1) + pred_conf_3d_o2m: Tensor | None = None # (N_prompts, num_queries, 1) + + def __getitem__(self, key: str): + """Support dict-like access for vis4d data connector compatibility.""" + return getattr(self, key) + + def keys(self): + """Return field names for dict-like iteration.""" + return [f.name for f in fields(self)] + + def __contains__(self, key: str) -> bool: + """Support 'in' operator for dict-like access.""" + return hasattr(self, key) + + +@dataclass +class WildDet3DInput: + """WildDet3D batched input format (per-prompt batch). + + Design Principles: + 1. Aligned with SAM3's BatchedDatapoint + 2. Added 3D detection required fields (intrinsics, gt_boxes3d) + 3. Supports three modes: TEXT / GEOMETRIC / TEXT_GEOMETRIC + + Coordinate Format Convention: + - geo_boxes: normalized [0,1] cxcywh (SAM3 Geometry Encoder input) + - gt_boxes2d: normalized [0,1] xyxy (for loss computation) + - Model output pred_boxes_2d: normalized xyxy [0,1] + """ + + # ========== Image-level (Backbone processing) ========== + images: Tensor # (B_images, 3, H, W) + intrinsics: Tensor # (B_images, 3, 3) + + # ========== Prompt-level (expanded) ========== + img_ids: Tensor # (N_prompts,) - which image each prompt belongs to + text_ids: Tensor # (N_prompts,) - text index for each prompt + unique_texts: List[str] # deduplicated texts (including "visual" placeholder) + + # Geometry input - batch-first: (N_prompts, max_K, 4) - normalized cxcywh + # Converted to sequence-first when passed to SAM3 Prompt class + geo_boxes: Tensor | None = None # (N_prompts, max_K, 4) + geo_boxes_mask: Tensor | None = None # (N_prompts, max_K) - True=padding + geo_box_labels: Tensor | None = None # (N_prompts, max_K) - 0/1 for neg/pos + + # Point prompts (optional) + geo_points: Tensor | None = None # (N_prompts, max_P, 2) - (x, y) + geo_points_mask: Tensor | None = None # (N_prompts, max_P) - True=padding + geo_point_labels: Tensor | None = None # (N_prompts, max_P) - 0/1 for neg/pos + + # Ground Truth - normalized xyxy (training) + gt_boxes2d: Tensor | None = None # (N_prompts, max_gt, 4) - xyxy + gt_boxes3d: Tensor | None = None # (N_prompts, max_gt, 12) - 3D params + num_gts: Tensor | None = None # (N_prompts,) - number of GTs per prompt + gt_category_ids: Tensor | None = None # (N_prompts, max_gt) + + # Ignore boxes for negative loss suppression (per-prompt, same category) + # Objects marked ignore in Omni3D (truncated, occluded, behind camera, etc.) + # are not used as GT but should not cause FP penalty either. + ignore_boxes2d: Tensor | None = None # (N_prompts, max_ignore, 4) normalized xyxy + num_ignores: Tensor | None = None # (N_prompts,) number of ignore boxes per prompt + + # Query type tracking (collator-level label, does NOT control SAM3 internal matching). + # 0=TEXT, 1=VISUAL, 2=GEOMETRY, 3=VISUAL+LABEL, 4=GEOMETRY+LABEL + # "multi-target" (0,1,3): num_gts can be > 1 (all instances of a category) + # "single-target" (2,4): num_gts = 1 (one selected instance) + # NOTE: SAM3's DAC mechanism (internal o2o/o2m matcher) always runs + # both branches regardless of this field. + query_types: Tensor | None = None # (N_prompts,) int + + # Metadata for evaluation/visualization + sample_names: List[str] | None = None # (B_images,) - image identifiers + dataset_name: List[str] | None = None # (B_images,) - dataset names for evaluator + original_hw: List[tuple] | None = None # (B_images,) - original (H, W) per image + original_images: Tensor | None = None # (B_images, 3, H_orig, W_orig) - unresized + original_intrinsics: Tensor | None = None # (B_images, 3, 3) - intrinsics before resize + + # CenterPad offsets [pad_left, pad_right, pad_top, pad_bottom] + padding: List | None = None # (B_images,) - padding offsets per image + + # Depth Ground Truth (for geometry backend supervision) + depth_gt: Tensor | None = None # (B_images, 1, H, W) depth map + depth_mask: Tensor | None = None # (B_images, H, W) valid depth mask + + # Key aliases for vis4d DataConnector compatibility + # Maps expected DataLoader keys to actual dataclass field names + _KEY_ALIASES = { + # Target boxes (for loss computation) + "boxes2d": "gt_boxes2d", + "boxes3d": "gt_boxes3d", + "boxes2d_classes": "gt_category_ids", + # Geometric prompts (for SAM3 input) + "prompt_boxes": "geo_boxes", + "prompt_box_labels": "geo_box_labels", + # Not available in per-prompt batch + "depth_maps": None, + "original_hw": None, + "original_images": None, + "padding": None, + } + + def __getitem__(self, key: str): + """Support dict-like access for vis4d data connector compatibility. + + Supports both actual field names and aliased keys from raw DataLoader. + """ + # Check alias first + if key in self._KEY_ALIASES: + aliased_key = self._KEY_ALIASES[key] + if aliased_key is None: + return None # Field not available + return getattr(self, aliased_key) + + # Handle special computed fields + if key == "input_hw": + # Return (H, W) from images shape + return (self.images.shape[2], self.images.shape[3]) + + # Direct field access + if hasattr(self, key): + return getattr(self, key) + + # Return None for unknown keys instead of raising error + return None + + def keys(self): + """Return field names for dict-like iteration.""" + return [f.name for f in fields(self)] + + def __contains__(self, key: str) -> bool: + """Support 'in' operator for dict-like access.""" + return hasattr(self, key) + + @property + def num_images(self) -> int: + """Number of unique images.""" + return self.images.shape[0] + + @property + def num_prompts(self) -> int: + """Number of prompts (batch size for decoder).""" + return self.img_ids.shape[0] + + @property + def device(self) -> torch.device: + """Device of the batch.""" + return self.images.device diff --git a/wilddet3d/depth/__init__.py b/wilddet3d/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0090ead850330426c45d79494176e5df9133eb5b --- /dev/null +++ b/wilddet3d/depth/__init__.py @@ -0,0 +1,10 @@ +"""Depth estimation backends.""" + +from .base import GeometryBackendBase, GeometryBackendOutput +from .lingbot_backend import LingbotDepthBackend + +__all__ = [ + "GeometryBackendBase", + "GeometryBackendOutput", + "LingbotDepthBackend", +] diff --git a/wilddet3d/depth/__pycache__/__init__.cpython-311.pyc b/wilddet3d/depth/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac8232468f56074bfc14d446d055735c07afb79 Binary files /dev/null and b/wilddet3d/depth/__pycache__/__init__.cpython-311.pyc differ diff --git a/wilddet3d/depth/__pycache__/base.cpython-311.pyc b/wilddet3d/depth/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dcadc33aef3774e2b25009a974e589c037d330a Binary files /dev/null and b/wilddet3d/depth/__pycache__/base.cpython-311.pyc differ diff --git a/wilddet3d/depth/__pycache__/depth_fusion.cpython-311.pyc b/wilddet3d/depth/__pycache__/depth_fusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bb84f3cebb34bbafa5c669923b6daa82af1ac51 Binary files /dev/null and b/wilddet3d/depth/__pycache__/depth_fusion.cpython-311.pyc differ diff --git a/wilddet3d/depth/__pycache__/lingbot_backend.cpython-311.pyc b/wilddet3d/depth/__pycache__/lingbot_backend.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea80e0fd3d9f5d3fc767ea6945614a384b724758 Binary files /dev/null and b/wilddet3d/depth/__pycache__/lingbot_backend.cpython-311.pyc differ diff --git a/wilddet3d/depth/base.py b/wilddet3d/depth/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c856ce6ff1fb60e03b33c447dd24c071cfad4daf --- /dev/null +++ b/wilddet3d/depth/base.py @@ -0,0 +1,187 @@ +"""GeometryBackendBase: Abstract interface for depth/geometry backends. + +Each backend is a self-contained geometry module that: +- Extracts features using its own method (DINO, Swin+FPN, etc.) +- Runs its own depth head +- Computes its own geometry losses + +The interface provides a unified way to plug different geometry systems +into the 3D-MOOD / GroundingDINO3D framework. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TypedDict + +import torch +from torch import Tensor, nn + + +class GeometryBackendOutput(TypedDict, total=False): + """Output dictionary from GeometryBackend. + + Attributes: + depth_map: Predicted depth map [B, 1, H, W] in metric scale. + depth_latents: Depth latent tokens [B, N, C] for 3D head. + Dimension C is aligned to target_latent_dim (default: 128). + K_pred: Predicted camera intrinsics [B, 3, 3] (optional). + ray_intrinsics: Intrinsics to use for ray_embeddings generation [B, 3, 3]. + This may be adjusted intrinsics for DINOv2-based backends. + ray_image_hw: Image (H, W) to use for ray_embeddings generation. + This corresponds to the space where depth_latents were computed. + ray_downsample: Downsample factor for ray_embeddings (8 or 16). + Must match the spatial resolution of depth_latents. + aux: Auxiliary outputs (rays, points, confidence, etc.). + losses: Dictionary of geometry losses (only in training). + """ + + depth_map: Tensor + depth_latents: Tensor + K_pred: Tensor | None + ray_intrinsics: Tensor + ray_image_hw: tuple[int, int] + ray_downsample: int + aux: dict[str, Tensor] + losses: dict[str, Tensor] + + +class GeometryBackendBase(nn.Module, ABC): + """Abstract base class for geometry backends. + + Each concrete implementation wraps a complete geometry pipeline: + - Feature extraction (backbone + neck specific to this backend) + - Depth head + - Loss computation + + This allows switching between different depth systems (UniDepthHead, + DetAny3D, UniDepthV2) without changing the main GroundingDINO3D code. + + Args: + detach_depth_latents: If True, detach depth_latents before returning. + This prevents gradients from the 3D head from flowing back to + the depth head. Useful when you want to freeze depth training + but still use its features for 3D detection. + """ + + # Whether this backend's depth decoder already incorporates ray/camera info. + # If True, the 3D head does NOT need a separate camera prompt branch, + # because the depth_latents are already ray-aware. + # - UniDepthV2 / DetAny3D: True (decoder fuses rays internally) + # - UniDepthHead (v1): False (no ray info in decoder) + is_ray_aware: bool = False + + def __init__(self, detach_depth_latents: bool = False) -> None: + """Initialize the geometry backend. + + Args: + detach_depth_latents: Whether to detach depth_latents from the graph. + """ + super().__init__() + self.detach_depth_latents = detach_depth_latents + + def _maybe_detach_latents(self, depth_latents: Tensor | None) -> Tensor | None: + """Optionally detach depth latents from computation graph. + + Args: + depth_latents: Depth latents [B, N, C] or None + + Returns: + Detached latents if detach_depth_latents is True, otherwise unchanged + """ + if depth_latents is not None and self.detach_depth_latents: + return depth_latents.detach() + return depth_latents + + @abstractmethod + def forward_train( + self, + images: Tensor, + depth_feats: list[Tensor] | None, + intrinsics: Tensor, + image_hw: tuple[int, int], + depth_gt: Tensor | None = None, + depth_mask: Tensor | None = None, + **kwargs, + ) -> GeometryBackendOutput: + """Forward pass for training. + + Args: + images: Input images [B, 3, H, W]. + depth_feats: Multi-scale features from FPN [B, C, H_i, W_i] (for + backends that use external features like UniDepthHead). + Can be None for backends with their own encoder (e.g., UniDepthV2). + intrinsics: Camera intrinsics [B, 3, 3]. + image_hw: Input image size (H, W). + depth_gt: Ground truth depth [B, H, W] (optional). + depth_mask: Valid depth mask [B, H, W] (optional). + **kwargs: Additional backend-specific arguments. + + Returns: + GeometryBackendOutput containing: + - depth_map: [B, 1, H, W] + - depth_latents: [B, N, C] + - K_pred: [B, 3, 3] or None + - aux: dict of auxiliary outputs + - losses: dict of loss tensors + """ + raise NotImplementedError + + @torch.no_grad() + @abstractmethod + def forward_test( + self, + images: Tensor, + depth_feats: list[Tensor] | None, + intrinsics: Tensor, + image_hw: tuple[int, int], + **kwargs, + ) -> GeometryBackendOutput: + """Forward pass for inference (no loss computation). + + Args: + images: Input images [B, 3, H, W]. + depth_feats: Multi-scale features from FPN (optional). + intrinsics: Camera intrinsics [B, 3, 3]. + image_hw: Input image size (H, W). + **kwargs: Additional backend-specific arguments. + + Returns: + GeometryBackendOutput containing: + - depth_map: [B, 1, H, W] + - depth_latents: [B, N, C] + - K_pred: [B, 3, 3] or None + - aux: dict of auxiliary outputs + - losses: empty dict + """ + raise NotImplementedError + + def forward( + self, + images: Tensor, + depth_feats: list[Tensor] | None, + intrinsics: Tensor, + image_hw: tuple[int, int], + depth_gt: Tensor | None = None, + depth_mask: Tensor | None = None, + **kwargs, + ) -> GeometryBackendOutput: + """Forward pass (dispatches to train or test based on mode).""" + if self.training: + return self.forward_train( + images=images, + depth_feats=depth_feats, + intrinsics=intrinsics, + image_hw=image_hw, + depth_gt=depth_gt, + depth_mask=depth_mask, + **kwargs, + ) + return self.forward_test( + images=images, + depth_feats=depth_feats, + intrinsics=intrinsics, + image_hw=image_hw, + depth_gt=depth_gt, + **kwargs, + ) diff --git a/wilddet3d/depth/depth_fusion.py b/wilddet3d/depth/depth_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..3487cc477bced2c1fa5f9d48e987263ef542c770 --- /dev/null +++ b/wilddet3d/depth/depth_fusion.py @@ -0,0 +1,223 @@ +"""Early Depth Fusion Modules. + +Two variants for fusing depth latents into visual features before the encoder: + +1. EarlyDepthFusionUniDepthV2 (Concat-Add): + Concatenate visual + depth, project back, residual add. + delta = W * [P; D] + output = P + delta + +2. EarlyDepthFusionLingbot (ControlNet-style): + LayerNorm depth, project depth only, residual add. + delta = W_d @ LayerNorm(D) + output = P + delta +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch import Tensor + + +class EarlyDepthFusionUniDepthV2(nn.Module): + """Concat-Add fusion for UniDepthV2 backend. + + Concatenates visual and depth features, projects back to visual dim, + then adds as residual. More expressive than depth-only projection: + delta = W_P * P + W_D * D (from concat projection) + output = P + delta = (I + W_P) * P + W_D * D + + Args: + visual_dim: Dimension of visual features (e.g., 256). + depth_dim: Dimension of depth latents (e.g., 256). + fusion_type: Kept for config compatibility, ignored. + zero_init: Whether to zero-initialize the projection layer. + """ + + def __init__( + self, + visual_dim: int = 256, + depth_dim: int = 256, + fusion_type: str = "concat_add", + zero_init: bool = True, + ): + super().__init__() + + self.visual_dim = visual_dim + self.depth_dim = depth_dim + + # Projection: [C + C_depth] -> [C] + self.proj = nn.Conv2d( + visual_dim + depth_dim, + visual_dim, + kernel_size=1, + bias=True, + ) + + if zero_init: + nn.init.zeros_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + + def forward( + self, + visual_feats: list[Tensor], + depth_latents: Tensor, + depth_latents_hw: tuple[int, int], + ) -> list[Tensor]: + """Fuse depth latents into visual features. + + Args: + visual_feats: List of visual features [[B, C, H, W]]. + depth_latents: Depth features [B, N, C_depth]. + depth_latents_hw: (H_d, W_d) spatial dims of depth latents. + + Returns: + List of fused visual features with same shapes as input. + """ + if depth_latents is None or len(visual_feats) == 0: + return visual_feats + + B, N, C_depth = depth_latents.shape + H_d, W_d = depth_latents_hw + + assert N == H_d * W_d, f"depth_latents N={N} != H_d*W_d={H_d * W_d}" + + # Reshape: [B, N, C_depth] -> [B, C_depth, H_d, W_d] + depth_2d = depth_latents.permute(0, 2, 1).reshape( + B, C_depth, H_d, W_d + ) + + fused_feats = [] + for visual_feat in visual_feats: + B_v, C_v, H_v, W_v = visual_feat.shape + assert C_v == self.visual_dim + + # Interpolate depth to match visual spatial size + if (H_d, W_d) != (H_v, W_v): + depth_resized = torch.nn.functional.interpolate( + depth_2d, + size=(H_v, W_v), + mode="bilinear", + align_corners=False, + ) + else: + depth_resized = depth_2d + + # Concat + project + residual + concat_feat = torch.cat([visual_feat, depth_resized], dim=1) + proj_feat = self.proj(concat_feat) + fused_feat = visual_feat + proj_feat + + fused_feats.append(fused_feat) + + return fused_feats + + +class EarlyDepthFusionLingbot(nn.Module): + """ControlNet-style fusion for Lingbot depth backend. + + LayerNorm on depth latents, project depth only, residual add. + Visual features never pass through any trainable layer, preserving + the pretrained distribution. + + Args: + visual_dim: Dimension of visual features (e.g., 256). + depth_dim: Dimension of depth latents (e.g., 256). + fusion_type: Kept for config compatibility, ignored. + zero_init: Whether to zero-initialize the projection layer. + """ + + def __init__( + self, + visual_dim: int = 256, + depth_dim: int = 256, + fusion_type: str = "concat_add", + zero_init: bool = True, + ): + super().__init__() + + self.visual_dim = visual_dim + self.depth_dim = depth_dim + + # Normalize depth_latents to unit scale before projection. + # depth_latents (raw neck output, std~4.0) and visual features + # (SAM3 FPN, std~0.017) differ by ~230x. LayerNorm brings depth + # to mean=0, std=1 so the projection sees consistent input scale. + self.depth_norm = nn.LayerNorm(depth_dim) + + # Projection: depth_dim -> visual_dim (depth only) + self.proj = nn.Conv2d( + depth_dim, + visual_dim, + kernel_size=1, + bias=True, + ) + + if zero_init: + nn.init.zeros_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + + def forward( + self, + visual_feats: list[Tensor], + depth_latents: Tensor, + depth_latents_hw: tuple[int, int], + ) -> list[Tensor]: + """Fuse depth latents into visual features. + + Args: + visual_feats: List of visual features [[B, C, H, W]]. + depth_latents: Depth features [B, N, C_depth]. + depth_latents_hw: (H_d, W_d) spatial dims of depth latents. + + Returns: + List of fused visual features with same shapes as input. + """ + if depth_latents is None or len(visual_feats) == 0: + return visual_feats + + B, N, C_depth = depth_latents.shape + H_d, W_d = depth_latents_hw + + assert N == H_d * W_d, f"depth_latents N={N} != H_d*W_d={H_d * W_d}" + + # Normalize depth_latents to unit scale + # Cast to match LayerNorm dtype (AMP bf16 compatibility) + depth_latents = depth_latents.to(self.depth_norm.weight.dtype) + depth_latents = self.depth_norm(depth_latents) + + # Reshape: [B, N, C_depth] -> [B, C_depth, H_d, W_d] + depth_2d = depth_latents.permute(0, 2, 1).reshape( + B, C_depth, H_d, W_d + ) + + fused_feats = [] + for visual_feat in visual_feats: + B_v, C_v, H_v, W_v = visual_feat.shape + assert C_v == self.visual_dim + + # Interpolate depth to match visual spatial size + if (H_d, W_d) != (H_v, W_v): + depth_resized = torch.nn.functional.interpolate( + depth_2d, + size=(H_v, W_v), + mode="bilinear", + align_corners=False, + ) + else: + depth_resized = depth_2d + + # Project depth only + residual add + delta = self.proj(depth_resized) + fused_feat = visual_feat + delta + + self._last_delta_mean_abs = delta.detach().abs().mean().item() + + fused_feats.append(fused_feat) + + return fused_feats + + +# Backward compatibility alias +EarlyDepthFusion = EarlyDepthFusionUniDepthV2 diff --git a/wilddet3d/depth/lingbot_backend.py b/wilddet3d/depth/lingbot_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7dd5fa2e3b9fff89c01fac62b90845849b9eef --- /dev/null +++ b/wilddet3d/depth/lingbot_backend.py @@ -0,0 +1,1543 @@ +"""LingbotDepthBackend: LingBot-Depth geometry backend for 3D-MOOD. + +Uses DINOv2 RGB-D encoder with mixed depth input strategy (per-sample): +- 70% monocular: zero depth input +- 20% patch-masked: patch-level random masking (60-90% ratio, following + the Masked Depth Modeling paper) for depth completion training +- 10% copy-through: full depth_gt as input +- Inference: always zero depth (monocular mode) + +Intrinsic prediction: MLP on cls_token predicts camera K. +is_ray_aware = False so the 3D head's camera prompt branch is active. + +Depth loss: L1 + MoGe2 affine-invariant losses (global, local, edge) + + confidence mask BCE on all valid pixels. +Camera loss: ray-based MSE (same approach as UniDepthV2). +""" + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from .base import GeometryBackendBase, GeometryBackendOutput +from wilddet3d.ops.ray import generate_rays + +import utils3d + + +def backproject_depth_to_points( + depth: Tensor, K: Tensor, H: int, W: int +) -> Tensor: + """Back-project depth map to 3D points using camera intrinsics. + + Uses utils3d (same as MoGe2) with normalized intrinsics. + + Args: + depth: [B, 1, H, W] or [B, H, W] metric depth. + K: [B, 3, 3] camera intrinsics (pixel space). + H: Image height. + W: Image width. + + Returns: + points: [B, H, W, 3] 3D points in camera space (x, y, z). + """ + z = depth.squeeze(1) if depth.ndim == 4 else depth # [B, H, W] + # Normalize pixel intrinsics to [0, 1] for utils3d + K_norm = K.clone() + K_norm[:, 0, 0] /= W + K_norm[:, 0, 2] /= W + K_norm[:, 1, 1] /= H + K_norm[:, 1, 2] /= H + return utils3d.pt.depth_map_to_point_map(z, intrinsics=K_norm) + + +class LingbotDepthBackend(GeometryBackendBase): + """Backend using LingBot-Depth (DINOv2 RGB-D encoder + ConvStack decoder). + + Loads a pretrained MDMModel and decomposes it into: + - encoder: DINOv2_RGBD_Encoder (RGB-D feature extraction) + - neck: ConvStack (multiscale refinement) + - depth_head: ConvStack (depth regression) + + depth_latents are extracted from neck level 1 output (after 2 ResBlocks, + 256-dim, 2x encoder resolution) and pooled to encoder grid size. + This matches UniDepthV2's approach of using decoder intermediate features. + + During training, each sample independently gets one of three modes: + - monocular (zero depth): prob = monocular_prob (default 0.7) + - patch-masked depth: prob = masked_prob (default 0.2) + - copy-through (full depth): prob = 1 - monocular - masked (0.1) + During inference, always zero depth. + + Args: + pretrained_model: Path or HuggingFace repo ID for MDMModel. + num_tokens: Number of base tokens for the encoder. + target_latent_dim: Target dimension for depth_latents. + Neck level 1 outputs 256-dim; if target != 256, a Linear + projection is applied. Use 256 to avoid projection. + depth_loss_weight: Weight for L1 depth loss. + silog_loss_weight: Weight for SILog depth loss (scale-invariant). + affine_global_weight: Weight for MoGe2 affine-invariant global loss. + affine_local_weight: Weight for MoGe2 affine-invariant local loss. + edge_loss_weight: Weight for MoGe2 edge loss. + mask_loss_weight: Weight for confidence mask BCE loss. + monocular_prob: Probability of zero depth input (training). + masked_prob: Probability of patch-masked depth input (training). + mask_ratio_range: (min, max) masking ratio for patch-masked mode. + mask_patch_size: Patch size for depth masking grid. + camera_loss_weight: Weight for ray-based L2 camera loss. + detach_depth_latents: Whether to detach depth_latents from graph. + encoder_freeze_blocks: Number of encoder transformer blocks to + freeze (from the beginning). ViT-L has 24 blocks; e.g. 20 + freezes blocks[0..19], only training the last 4. + """ + + # Encoder does not fuse camera rays; 3D head needs camera prompt + is_ray_aware: bool = False + + def __init__( + self, + pretrained_model: str = ( + "robbyant/lingbot-depth-pretrain-vitl-14-v0.5" + ), + num_tokens: int = 2400, + target_latent_dim: int = 128, + depth_loss_weight: float = 1.0, + silog_loss_weight: float = 0.5, + affine_global_weight: float = 10.0, + affine_local_weight: float = 10.0, + edge_loss_weight: float = 10.0, + mask_loss_weight: float = 0.1, + monocular_prob: float = 0.7, + masked_prob: float = 0.2, + mask_ratio_range: tuple[float, float] = (0.6, 0.9), + mask_patch_size: int = 14, + camera_loss_weight: float = 1.0, + detach_depth_latents: bool = True, + encoder_freeze_blocks: int = 0, + unpad_test: bool = True, + ) -> None: + """Initialize the LingbotDepthBackend.""" + super().__init__(detach_depth_latents=detach_depth_latents) + self.unpad_test = unpad_test + + self.num_tokens = num_tokens + self.target_latent_dim = target_latent_dim + self.depth_loss_weight = depth_loss_weight + self.silog_loss_weight = silog_loss_weight + self.affine_global_weight = affine_global_weight + self.affine_local_weight = affine_local_weight + self.edge_loss_weight = edge_loss_weight + self.mask_loss_weight = mask_loss_weight + self.monocular_prob = monocular_prob + self.masked_prob = masked_prob + self.mask_ratio_range = mask_ratio_range + self.mask_patch_size = mask_patch_size + self.camera_loss_weight = camera_loss_weight + + # SILog loss (scale-invariant) - lazy init, only needed for training + self._silog_loss_weight = silog_loss_weight + self._silog_loss = None + + # Load pretrained MDMModel and decompose into sub-modules + from mdm.model.v2 import MDMModel + + print( + f"[LingbotDepth] Loading pretrained model: " + f"{pretrained_model}" + ) + mdm_model = MDMModel.from_pretrained(pretrained_model) + + self.encoder = mdm_model.encoder + self.neck = mdm_model.neck + self.depth_head = mdm_model.depth_head + self.remap_depth_in = mdm_model.remap_depth_in + self.remap_depth_out = mdm_model.remap_depth_out + + # Load mask_head from pretrained model (confidence prediction) + if hasattr(mdm_model, "mask_head"): + self.mask_head = mdm_model.mask_head + print("[LingbotDepth] mask_head loaded from checkpoint") + else: + self.mask_head = None + print( + "[LingbotDepth] WARNING: mask_head not found in " + "checkpoint, confidence prediction disabled" + ) + + # Get dimensions from loaded model + cls_dim = self.encoder.dim_features + + # Neck level 1 outputs 256-dim features. + # If target_latent_dim != 256, project; otherwise Identity. + self._neck_latent_dim = 256 + if target_latent_dim != self._neck_latent_dim: + self.latent_proj = nn.Linear( + self._neck_latent_dim, target_latent_dim + ) + else: + self.latent_proj = nn.Identity() + + # Intrinsic prediction head: cls_token -> camera K + # Same parameterization as UniDepthV2 CameraHead: + # exp(raw_f) * 0.7 * diagonal for focal length, + # sigmoid(raw_c) * W/H for principal point. + # Init: exp(0)=1.0 gives fx ~ 0.7*diag, sigmoid(0)=0.5 gives cx=W/2 + self.intrinsic_head = nn.Sequential( + nn.LayerNorm(cls_dim), + nn.Linear(cls_dim, 256), + nn.ReLU(), + nn.Linear(256, 4), + ) + nn.init.zeros_(self.intrinsic_head[-1].weight) + nn.init.zeros_(self.intrinsic_head[-1].bias) + + # De-normalization buffers: convert 3D-MOOD normalized images + # back to [0,1] for the encoder (which does its own ImageNet norm) + self.register_buffer( + "denorm_mean", + torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), + ) + self.register_buffer( + "denorm_std", + torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), + ) + + # Delete reference to full model (sub-modules survive via self) + del mdm_model + + # torch.compile for encoder (controlled by SAM3_COMPILE env var) + import os + if os.environ.get("SAM3_COMPILE", "0") == "1": + self.encoder = torch.compile(self.encoder) + print("[LingbotDepth] torch.compile ENABLED for encoder") + + # Freeze the first N transformer blocks of the encoder backbone. + # ViT-L has 24 blocks; e.g. encoder_freeze_blocks=20 freezes + # blocks[0..19] and only trains blocks[20..23] + patch_embed + + # norm + output_projections + neck + depth_head + new heads. + num_blocks = len(self.encoder.backbone.blocks) + encoder_freeze_blocks = min(encoder_freeze_blocks, num_blocks) + if encoder_freeze_blocks > 0: + bb = self.encoder.backbone + # Freeze everything in backbone first + for p in bb.parameters(): + p.requires_grad = False + # Unfreeze the last (num_blocks - freeze_blocks) blocks + for i in range(encoder_freeze_blocks, num_blocks): + for p in bb.blocks[i].parameters(): + p.requires_grad = True + # Unfreeze final norm (after all blocks) + for p in bb.norm.parameters(): + p.requires_grad = True + + copythrough_prob = 1.0 - monocular_prob - masked_prob + freeze_msg = ( + f" encoder freeze: {encoder_freeze_blocks}/{num_blocks}" + f" blocks frozen" + ) + print( + f"[LingbotDepth] Initialized: " + f"cls_dim={cls_dim}, num_tokens={num_tokens}, " + f"depth_latents=neck[1] (256-dim, pooled)\n" + f" remap_depth_in={self.remap_depth_in}, " + f"remap_depth_out={self.remap_depth_out}\n" + f" depth strategy: {monocular_prob:.0%} monocular / " + f"{masked_prob:.0%} patch-masked / " + f"{copythrough_prob:.0%} copy-through\n" + f" mask_ratio_range={mask_ratio_range}, " + f"mask_patch_size={mask_patch_size}\n" + f" losses: L1={depth_loss_weight}, " + f"affine_global={affine_global_weight}, " + f"affine_local={affine_local_weight}, " + f"edge={edge_loss_weight}, " + f"mask_bce={mask_loss_weight}, " + f"camera_ray={camera_loss_weight}\n" + f" mask_head={'loaded' if self.mask_head is not None else 'none'}\n" + f"{freeze_msg}" + ) + + def load_pretrained_weights(self) -> None: + """No-op: weights already loaded in __init__ via from_pretrained.""" + pass + + def _compute_token_grid( + self, H: int, W: int + ) -> tuple[int, int]: + """Compute token grid dimensions from image aspect ratio. + + Same formula as MDMModel.forward lines 110-115. + + Args: + H: Image height. + W: Image width. + + Returns: + (base_h, base_w) token grid dimensions. + """ + aspect_ratio = W / H + base_h = round(math.sqrt(self.num_tokens / aspect_ratio)) + base_w = round(math.sqrt(self.num_tokens * aspect_ratio)) + return base_h, base_w + + def _prepare_depth_input( + self, + depth_gt: Tensor | None, + depth_mask: Tensor | None, + B: int, + H: int, + W: int, + device: torch.device, + ) -> Tensor | None: + """Prepare depth input with mixed strategy for training. + + Per-sample mode selection: + - [0, monocular_prob): zero depth (monocular) + - [monocular_prob, monocular_prob + masked_prob): patch-masked + - [monocular_prob + masked_prob, 1.0): copy-through (full depth) + + Args: + depth_gt: Ground truth depth [B, H, W] or [B, 1, H, W]. + depth_mask: Valid depth mask [B, H, W] or [B, 1, H, W]. + B: Batch size. + H: Image height. + W: Image width. + device: Tensor device. + + Returns: + depth_input [B, 1, H, W] or None if no depth_gt. + """ + if depth_gt is None: + return None + + if depth_gt.ndim == 3: + depth_gt = depth_gt.unsqueeze(1) # [B, 1, H, W] + + # Apply depth_mask if provided + if depth_mask is not None: + if depth_mask.ndim == 3: + depth_mask = depth_mask.unsqueeze(1) + depth_gt = depth_gt * depth_mask.float() + + depth_input = torch.zeros_like(depth_gt) + rand_vals = torch.rand(B, device=device) + masked_threshold = self.monocular_prob + self.masked_prob + + for i in range(B): + if rand_vals[i] < self.monocular_prob: + # Monocular: keep zeros + pass + elif rand_vals[i] < masked_threshold: + # Patch-level random masking + depth_input[i] = self._patch_mask_depth( + depth_gt[i], H, W, device + ) + else: + # Copy-through: full depth + depth_input[i] = depth_gt[i] + + return depth_input + + def _patch_mask_depth( + self, + depth: Tensor, + H: int, + W: int, + device: torch.device, + ) -> Tensor: + """Apply patch-level random masking to depth map. + + Following the MDM paper: randomly mask 60-90% of patches, + zeroing out entire patch regions. + + Args: + depth: [1, H, W] single-sample depth map. + H: Image height. + W: Image width. + device: Tensor device. + + Returns: + Masked depth [1, H, W] with some patches zeroed out. + """ + ps = self.mask_patch_size + grid_h = H // ps + grid_w = W // ps + num_patches = grid_h * grid_w + + # Random masking ratio in [min, max] + lo, hi = self.mask_ratio_range + mask_ratio = torch.rand(1, device=device).item() * (hi - lo) + lo + num_masked = int(num_patches * mask_ratio) + + # Random permutation: first num_masked patches are masked (0) + perm = torch.randperm(num_patches, device=device) + keep = torch.ones(num_patches, device=device) + keep[perm[:num_masked]] = 0.0 + + # Reshape to spatial grid and upsample to image size + keep = keep.view(1, 1, grid_h, grid_w) + keep = F.interpolate( + keep, size=(grid_h * ps, grid_w * ps), mode="nearest" + ) # [1, 1, grid_h*ps, grid_w*ps] + + # Pad if image size not divisible by patch size + pad_h = H - grid_h * ps + pad_w = W - grid_w * ps + if pad_h > 0 or pad_w > 0: + keep = F.pad(keep, (0, pad_w, 0, pad_h), value=1.0) + + return depth * keep.squeeze(0) # [1, H, W] + + def _predict_intrinsics( + self, cls_token: Tensor, H: int, W: int + ) -> Tensor: + """Predict camera intrinsics from cls_token. + + Same parameterization as UniDepthV2 CameraHead.fill_intrinsics: + - fx = exp(raw) * 0.7 * diagonal + - fy = exp(raw) * 0.7 * diagonal + - cx = sigmoid(raw) * W + - cy = sigmoid(raw) * H + + Args: + cls_token: [B, cls_dim] class token from encoder. + H: Image height (original pixel space). + W: Image width (original pixel space). + + Returns: + K_pred: [B, 3, 3] predicted intrinsics in pixel coords. + """ + params = self.intrinsic_head(cls_token) # [B, 4] + + diagonal = (H**2 + W**2) ** 0.5 + fx = torch.exp(params[:, 0].clamp(-10, 10)) * 0.7 * diagonal + fy = torch.exp(params[:, 1].clamp(-10, 10)) * 0.7 * diagonal + cx = torch.sigmoid(params[:, 2]) * W + cy = torch.sigmoid(params[:, 3]) * H + + B = cls_token.shape[0] + K_pred = torch.zeros( + B, 3, 3, device=cls_token.device, dtype=cls_token.dtype + ) + K_pred[:, 0, 0] = fx + K_pred[:, 1, 1] = fy + K_pred[:, 0, 2] = cx + K_pred[:, 1, 2] = cy + K_pred[:, 2, 2] = 1.0 + + return K_pred + + def _run_encoder_and_decoder( + self, + images: Tensor, + depth_input: Tensor | None, + image_hw: tuple[int, int], + ) -> tuple[Tensor, Tensor, Tensor, int, int, list[Tensor]]: + """Run encoder + neck + depth_head pipeline. + + Replicates MDMModel.forward() logic (lines 98-168 of v2.py). + + Args: + images: [B, 3, H, W] 3D-MOOD normalized images. + depth_input: [B, 1, H, W] depth for encoder, or None. + image_hw: Original (H, W) dimensions. + + Returns: + depth_map: [B, 1, H, W] metric depth in meters. + depth_latents: [B, N, target_latent_dim]. + cls_token: [B, cls_dim]. + base_h: Token grid height. + base_w: Token grid width. + neck_out: List of neck feature maps for mask_head. + """ + from mdm.utils.geo import normalized_view_plane_uv + + B = images.shape[0] + H, W = image_hw + device, dtype = images.device, images.dtype + + # De-normalize from 3D-MOOD normalization to [0, 1] + # 3D-MOOD: norm_img = (img_255 - mean_255) / std_255 + # Reverse: img_01 = norm_img * (std_255/255) + (mean_255/255) + # = norm_img * imagenet_std + imagenet_mean + images_01 = images * self.denorm_std + self.denorm_mean + + # Compute token grid + base_h, base_w = self._compute_token_grid(H, W) + + # Prepare depth: zeros if None (monocular mode) + if depth_input is None: + depth_for_encoder = torch.zeros( + B, 1, H, W, device=device, dtype=dtype + ) + else: + depth_for_encoder = depth_input + + # Encoder forward: expects [0,1] images + # (encoder internally normalizes with ImageNet stats and resizes + # to (base_h*14, base_w*14)) + # enable_depth_mask=False avoids xformers BlockDiagonalMask + # dependency and uses standard attention instead + features, cls_token, _, _ = self.encoder( + images_01, + depth_for_encoder, + base_h, + base_w, + return_class_token=True, + remap_depth_in=self.remap_depth_in, + enable_depth_mask=False, + ) + # features: [B, encoder_dim, base_h, base_w] + # cls_token: [B, cls_dim] + + # Run neck + depth_head (MDMModel.forward lines 120-148) + aspect_ratio = W / H + + # Add cls_token to features + feat_with_cls = features + cls_token[..., None, None] + feat_list = [feat_with_cls, None, None, None, None] + + # Concat UV coordinates at 5 pyramid levels + for level in range(5): + uv = normalized_view_plane_uv( + width=base_w * 2**level, + height=base_h * 2**level, + aspect_ratio=aspect_ratio, + dtype=dtype, + device=device, + ) + uv = ( + uv.permute(2, 0, 1).unsqueeze(0).expand(B, -1, -1, -1) + ) + if feat_list[level] is None: + feat_list[level] = uv + else: + feat_list[level] = torch.cat( + [feat_list[level], uv], dim=1 + ) + + # Shared neck + neck_out = self.neck(feat_list) + + # Extract depth_latents from neck level 1 (after 2 ResBlocks) + # neck_out[1]: [B, 256, base_h*2, base_w*2] + # Pool to (base_h, base_w) to keep N = base_h * base_w + neck_feat = neck_out[1] # [B, 256, base_h*2, base_w*2] + neck_feat_pooled = F.adaptive_avg_pool2d( + neck_feat, (base_h, base_w) + ) # [B, 256, base_h, base_w] + depth_latents = neck_feat_pooled.flatten(2).permute( + 0, 2, 1 + ) # [B, N, 256] + depth_latents = self.latent_proj(depth_latents) + + # Depth head: take last output + depth_reg = self.depth_head(neck_out)[-1] # [B, 1, h, w] + + # Resize to original image dimensions + depth_reg = F.interpolate( + depth_reg, + (H, W), + mode="bilinear", + align_corners=False, + ) + + # Apply output remapping + # Clamp before exp to prevent overflow (float16 overflows at ~11, + # float32 at ~88). Range [-10, 10] maps to depth [4.5e-5, 22026] m. + if self.remap_depth_out == "exp": + depth_map = depth_reg.clamp(-10, 10).exp() # [B, 1, H, W] + elif self.remap_depth_out == "linear": + # Linear output can be negative; clamp to positive for + # downstream log-based losses and 3D head depth usage. + depth_map = depth_reg.clamp(min=1e-3) + else: + raise ValueError( + f"Invalid remap_depth_out: {self.remap_depth_out}" + ) + + return depth_map, depth_latents, cls_token, base_h, base_w, neck_out + + def _run_mask_head( + self, + neck_out: list[Tensor], + H: int, + W: int, + ) -> Tensor | None: + """Run mask_head to produce confidence map. + + Args: + neck_out: List of neck feature maps. + H: Target height. + W: Target width. + + Returns: + confidence_map: [B, 1, H, W] sigmoid probabilities, or None. + """ + if self.mask_head is None: + return None + confidence_raw = self.mask_head(neck_out)[-1] # [B, 1, h, w] + confidence_map = F.interpolate( + confidence_raw, + (H, W), + mode="bilinear", + align_corners=False, + ).sigmoid() + return confidence_map + + @torch.autocast(device_type="cuda", enabled=False) + def _compute_losses( + self, + depth_map: Tensor, + depth_gt: Tensor | None, + depth_mask: Tensor | None, + K_pred: Tensor, + intrinsics: Tensor, + image_hw: tuple[int, int], + confidence_map: Tensor | None = None, + ) -> dict[str, Tensor]: + """Compute depth and camera losses. + + Depth loss: masked L1 + MoGe2 affine-invariant losses (global, + local level 4 & 16, edge) + confidence mask BCE. + Camera loss: ray-based L2 RMSE (same as UniDepthV2). + + Args: + depth_map: [B, 1, H, W] predicted metric depth. + depth_gt: [B, H, W] or [B, 1, H, W] ground truth depth. + depth_mask: [B, H, W] or [B, 1, H, W] valid depth mask. + K_pred: [B, 3, 3] predicted intrinsics. + intrinsics: [B, 3, 3] ground truth intrinsics. + image_hw: (H, W) image dimensions. + confidence_map: [B, 1, H, W] confidence from mask_head. + + Returns: + Dictionary of loss tensors. + """ + # Lazy import: moge losses only needed for training + # Lazy imports: only needed for training + from moge.train.losses import ( + affine_invariant_global_loss, + affine_invariant_local_loss, + edge_loss, + mask_bce_loss, + ) + if self._silog_loss is None and self._silog_loss_weight > 0: + from wilddet3d.loss.silog_loss import SILogLoss + self._silog_loss = SILogLoss(scale_pred_weight=0.15) + + losses = {} + H, W = image_hw + + # Cast to float32 for numerical stability under mixed precision + depth_map = depth_map.float() + K_pred = K_pred.float() + intrinsics = intrinsics.float() + if depth_gt is not None: + depth_gt = depth_gt.float() + if depth_mask is not None: + depth_mask = depth_mask.float() + if confidence_map is not None: + confidence_map = confidence_map.float() + + # Depth losses + if depth_gt is not None: + depth_pred = depth_map.squeeze(1) # [B, H, W] + + if depth_gt.ndim == 4: + depth_gt = depth_gt.squeeze(1) # [B, H, W] + + valid_mask = depth_gt > 0 + if depth_mask is not None: + if depth_mask.ndim == 4: + depth_mask = depth_mask.squeeze(1) + valid_mask = valid_mask & depth_mask.bool() + + # Filter out extreme GT depth (>100m) and extreme + # pred/gt ratio (>3x or <1/3x) to prevent unstable + # gradients from outlier pixels. + _MAX_DEPTH = 100.0 + _MAX_RATIO = 3.0 + valid_mask = valid_mask & (depth_gt <= _MAX_DEPTH) + with torch.no_grad(): + ratio = depth_pred / (depth_gt + 1e-6) + valid_mask = valid_mask & ( + (ratio > 1.0 / _MAX_RATIO) + & (ratio < _MAX_RATIO) + ) + + B = depth_pred.shape[0] + + # L1 metric depth loss + if valid_mask.any(): + depth_loss = F.l1_loss( + depth_pred[valid_mask], depth_gt[valid_mask] + ) + else: + depth_loss = depth_pred.new_tensor(0.0) + + losses["depth_l1"] = ( + depth_loss.clamp(max=10.0) * self.depth_loss_weight + ) + + # SILog loss (scale-invariant) + if self._silog_loss is not None and valid_mask.any(): + silog_val = self._silog_loss( + depth_pred, depth_gt, mask=valid_mask + ) + losses["depth_silog"] = ( + silog_val.clamp(max=10.0) + * self.silog_loss_weight + ) + + # Back-project to 3D points for MoGe2 losses + # 50% chance per image: use K_pred or GT intrinsics + # This trains intrinsic head via MoGe2 loss while keeping + # depth supervised with GT intrinsics half the time. + use_pred_k = torch.rand(B, device=depth_pred.device) < 0.5 + K_for_pred = torch.where( + use_pred_k[:, None, None], K_pred, intrinsics + ) + pred_points = backproject_depth_to_points( + depth_pred, K_for_pred, H, W + ) # [B, H, W, 3] + gt_points = backproject_depth_to_points( + depth_gt, intrinsics, H, W + ) # [B, H, W, 3] + # MoGe2 convention: invalid GT -> inf + gt_points[~valid_mask] = float("inf") + + # Per-image MoGe2 losses (alignment is per-image) + zero = depth_pred.new_tensor(0.0) + aff_global_sum = zero + aff_local4_sum = zero + aff_local16_sum = zero + edge_sum = zero + + for i in range(B): + has_valid = valid_mask[i].any() + if has_valid: + loss_g, _, scale_i = ( + affine_invariant_global_loss( + pred_points[i], + gt_points[i], + align_resolution=48, + ) + ) + else: + loss_g = zero + scale_i = zero + aff_global_sum = aff_global_sum + loss_g + + # MoGe2 local loss expects normalized focal + # (fx/W, fy/H ~0.5-1.0), not pixel focal + fx_norm = K_pred[i, 0, 0] / W + fy_norm = K_pred[i, 1, 1] / H + focal_i = 1.0 / ( + 1.0 / fx_norm**2 + 1.0 / fy_norm**2 + ) ** 0.5 + + if has_valid: + loss_l4, _ = affine_invariant_local_loss( + pred_points[i], + gt_points[i], + focal_i, + scale_i, + level=4, + align_resolution=24, + num_patches=16, + importance_sampling=False, + ) + loss_l16, _ = affine_invariant_local_loss( + pred_points[i], + gt_points[i], + focal_i, + scale_i, + level=16, + align_resolution=12, + num_patches=256, + importance_sampling=False, + ) + loss_e, _ = edge_loss( + pred_points[i], gt_points[i] + ) + else: + loss_l4 = zero + loss_l16 = zero + loss_e = zero + aff_local4_sum = aff_local4_sum + loss_l4 + aff_local16_sum = aff_local16_sum + loss_l16 + edge_sum = edge_sum + loss_e + + losses["affine_global"] = ( + (aff_global_sum / B).clamp(max=10.0) + * self.affine_global_weight + ) + losses["affine_local_4"] = ( + (aff_local4_sum / B).clamp(max=10.0) + * self.affine_local_weight + ) + losses["affine_local_16"] = ( + (aff_local16_sum / B).clamp(max=10.0) + * self.affine_local_weight + ) + losses["edge"] = ( + (edge_sum / B).clamp(max=10.0) + * self.edge_loss_weight + ) + + # Mask BCE loss (confidence map) + # MoGe2 uses 3-state masks (fin / inf / unknown). + # For sparse data (LiDAR), most pixels have no + # annotation and should NOT be labeled "known invalid". + # Use per-image coverage to decide: dense (>50%) + # treats all non-valid as known-invalid; sparse + # treats only depth_mask-annotated invalid pixels. + if ( + confidence_map is not None + and self.mask_loss_weight > 0 + ): + conf = confidence_map.squeeze(1) # [B, H, W] + gt_mask_fin = valid_mask # [B, H, W] + has_depth = depth_gt > 0 # [B, H, W] + if depth_mask is not None: + annotated = depth_mask.bool() + else: + # Per-image: dense -> all pixels annotated; + # sparse -> only depth>0 pixels annotated. + coverage = has_depth.flatten(1).float().mean(1) + is_dense = coverage > 0.7 # [B] + annotated = torch.where( + is_dense[:, None, None], + torch.ones_like(has_depth), + has_depth, + ) + gt_mask_inf = annotated & ~has_depth + loss_mask, _ = mask_bce_loss( + conf, gt_mask_fin, gt_mask_inf + ) + losses["mask_bce"] = ( + loss_mask.mean().clamp(max=10.0) + * self.mask_loss_weight + ) + + # Camera loss: ray-based MSE (same as UniDepthV2) + rays_pred, _ = generate_rays(K_pred, image_hw) + rays_gt, _ = generate_rays(intrinsics, image_hw) + camera_loss = F.mse_loss(rays_pred, rays_gt) + losses["camera_ray"] = ( + camera_loss.clamp(max=10.0) * self.camera_loss_weight + ) + + return losses + + def _scale_intrinsics( + self, + intrinsics: Tensor, + from_hw: tuple[int, int], + to_hw: tuple[int, int], + ) -> Tensor: + """Scale intrinsics from one image space to another. + + Args: + intrinsics: [B, 3, 3] intrinsics in from_hw space. + from_hw: Source (H, W). + to_hw: Target (H, W). + + Returns: + Scaled intrinsics [B, 3, 3] in to_hw space. + """ + scale_x = to_hw[1] / from_hw[1] + scale_y = to_hw[0] / from_hw[0] + + K_scaled = intrinsics.clone() + K_scaled[:, 0, 0] *= scale_x # fx + K_scaled[:, 0, 2] *= scale_x # cx + K_scaled[:, 1, 1] *= scale_y # fy + K_scaled[:, 1, 2] *= scale_y # cy + + return K_scaled + + def _has_valid_padding(self, padding: list | None) -> bool: + """Check if padding info is valid and non-zero.""" + if padding is None: + return False + return any( + p is not None and any(v > 0 for v in p) for p in padding + ) + + def _crop_padding_single( + self, + image: Tensor, + intrinsics: Tensor, + pad_info: list[int], + H_pad: int, + W_pad: int, + depth_gt: Tensor | None = None, + depth_mask: Tensor | None = None, + ) -> tuple[Tensor, Tensor, int, int, Tensor | None, Tensor | None]: + """Crop padding from a single image and adjust intrinsics. + + Args: + image: [1, 3, H_pad, W_pad] padded image. + intrinsics: [1, 3, 3] padded-space intrinsics. + pad_info: [pad_left, pad_right, pad_top, pad_bottom]. + H_pad: Padded height. + W_pad: Padded width. + depth_gt: [1, 1, H_pad, W_pad] or None. + depth_mask: [1, 1, H_pad, W_pad] or [1, H_pad, W_pad] or None. + + Returns: + (cropped_image, adjusted_intrinsics, H_orig, W_orig, + cropped_depth_gt, cropped_depth_mask) + """ + pad_left, pad_right, pad_top, pad_bottom = pad_info + H_orig = H_pad - pad_top - pad_bottom + W_orig = W_pad - pad_left - pad_right + + # Crop image + img_cropped = image[ + :, :, pad_top : pad_top + H_orig, pad_left : pad_left + W_orig + ] + + # Adjust intrinsics: reverse CenterPadIntrinsics + K_cropped = intrinsics.clone() + K_cropped[0, 0, 2] -= pad_left # cx + K_cropped[0, 1, 2] -= pad_top # cy + + # Crop depth_gt + dgt_cropped = None + if depth_gt is not None: + dgt_cropped = depth_gt[ + :, :, + pad_top : pad_top + H_orig, + pad_left : pad_left + W_orig, + ] + + # Crop depth_mask + dm_cropped = None + if depth_mask is not None: + if depth_mask.ndim == 3: + dm_cropped = depth_mask[ + :, + pad_top : pad_top + H_orig, + pad_left : pad_left + W_orig, + ] + else: + dm_cropped = depth_mask[ + :, :, + pad_top : pad_top + H_orig, + pad_left : pad_left + W_orig, + ] + + return ( + img_cropped, + K_cropped, + H_orig, + W_orig, + dgt_cropped, + dm_cropped, + ) + + def _repad_depth_latents( + self, + depth_latents: Tensor, + base_h_orig: int, + base_w_orig: int, + base_h_pad: int, + base_w_pad: int, + pad_top: int, + pad_left: int, + H_pad: int, + W_pad: int, + ) -> Tensor: + """Repad depth latents from original to padded token grid. + + Places original-resolution tokens at the correct position within + the padded token grid, with zeros filling the padding regions. + + Args: + depth_latents: [1, N_orig, C] original-resolution latents. + base_h_orig: Original token grid height. + base_w_orig: Original token grid width. + base_h_pad: Padded token grid height. + base_w_pad: Padded token grid width. + pad_top: Pixel-space top padding. + pad_left: Pixel-space left padding. + H_pad: Padded image height. + W_pad: Padded image width. + + Returns: + [1, N_pad, C] depth latents in padded token grid. + """ + if ( + base_h_orig == base_h_pad + and base_w_orig == base_w_pad + ): + return depth_latents + + _, N_orig, C = depth_latents.shape + + # Reshape to spatial: [1, C, base_h_orig, base_w_orig] + dl_2d = depth_latents.permute(0, 2, 1).reshape( + 1, C, base_h_orig, base_w_orig + ) + + # Compute token-space offsets + pad_top_tok = round(pad_top * base_h_pad / H_pad) + pad_left_tok = round(pad_left * base_w_pad / W_pad) + + # Clamp to valid range + pad_top_tok = min(pad_top_tok, base_h_pad - 1) + pad_left_tok = min(pad_left_tok, base_w_pad - 1) + + # How many original tokens fit + h_fit = min(base_h_orig, base_h_pad - pad_top_tok) + w_fit = min(base_w_orig, base_w_pad - pad_left_tok) + + # Create padded output with zeros + dl_padded = torch.zeros( + 1, + C, + base_h_pad, + base_w_pad, + device=depth_latents.device, + dtype=depth_latents.dtype, + ) + dl_padded[ + :, + :, + pad_top_tok : pad_top_tok + h_fit, + pad_left_tok : pad_left_tok + w_fit, + ] = dl_2d[:, :, :h_fit, :w_fit] + + # Flatten back: [1, N_pad, C] + return dl_padded.flatten(2).permute(0, 2, 1) + + def _repad_depth_map( + self, + depth_map: Tensor, + pad_left: int, + pad_right: int, + pad_top: int, + pad_bottom: int, + ) -> Tensor: + """Repad depth map from original to padded resolution. + + Args: + depth_map: [1, 1, H_orig, W_orig]. + pad_left, pad_right, pad_top, pad_bottom: Pixel padding. + + Returns: + [1, 1, H_pad, W_pad] with zeros in padding region. + """ + return F.pad( + depth_map, + (pad_left, pad_right, pad_top, pad_bottom), + value=0.0, + ) + + def forward_train( + self, + images: Tensor, + depth_feats: list[Tensor] | None, + intrinsics: Tensor, + image_hw: tuple[int, int], + depth_gt: Tensor | None = None, + depth_mask: Tensor | None = None, + **kwargs, + ) -> GeometryBackendOutput: + """Forward pass for training. + + Uses mixed depth input strategy: each sample independently + gets monocular / patch-masked / copy-through depth input. + + When padding info is provided, crops padding before the encoder + so LingBot-Depth processes at original resolution with correct + aspect ratio, then repads outputs back to padded space. + + Args: + images: [B, 3, H, W] 3D-MOOD normalized images. + depth_feats: Ignored (we use our own encoder). + intrinsics: [B, 3, 3] camera intrinsics. + image_hw: (H, W) image dimensions. + depth_gt: [B, H, W] ground truth depth. + depth_mask: [B, H, W] valid depth mask. + **kwargs: May contain 'padding' (list of [L,R,T,B] per image). + + Returns: + GeometryBackendOutput. + """ + B = images.shape[0] + H_pad, W_pad = image_hw + padding = kwargs.get("padding", None) + + # If no valid padding, use original batched code path + if not self._has_valid_padding(padding): + return self._forward_train_batched( + images, intrinsics, image_hw, depth_gt, depth_mask + ) + + # Per-image processing at original (unpadded) resolution + # Padded token grid (target for repadding depth_latents) + base_h_pad, base_w_pad = self._compute_token_grid( + H_pad, W_pad + ) + + depth_maps_list = [] + depth_latents_list = [] + K_pred_list = [] + confidence_maps_list = [] + losses_accum = {} + + for i in range(B): + pad_info = padding[i] + if pad_info is None or all(v == 0 for v in pad_info): + # No padding for this image + pad_left = pad_right = pad_top = pad_bottom = 0 + img_i = images[i : i + 1] + K_i = intrinsics[i : i + 1] + H_orig, W_orig = H_pad, W_pad + dgt_i = ( + depth_gt[i : i + 1] if depth_gt is not None + else None + ) + dm_i = ( + depth_mask[i : i + 1] + if depth_mask is not None + else None + ) + else: + pad_left, pad_right, pad_top, pad_bottom = pad_info + ( + img_i, + K_i, + H_orig, + W_orig, + dgt_i, + dm_i, + ) = self._crop_padding_single( + images[i : i + 1], + intrinsics[i : i + 1], + pad_info, + H_pad, + W_pad, + ( + depth_gt[i : i + 1] + if depth_gt is not None + else None + ), + ( + depth_mask[i : i + 1] + if depth_mask is not None + else None + ), + ) + + orig_hw = (H_orig, W_orig) + + # Prepare depth input with mixed strategy (per-image) + depth_input_i = self._prepare_depth_input( + dgt_i, dm_i, 1, H_orig, W_orig, images.device + ) + + # Run encoder at ORIGINAL resolution (correct aspect ratio) + ( + depth_map_i, + depth_latents_i, + cls_token_i, + base_h_i, + base_w_i, + neck_out_i, + ) = self._run_encoder_and_decoder( + img_i, depth_input_i, orig_hw + ) + + # Predict intrinsics at original resolution + K_pred_i = self._predict_intrinsics( + cls_token_i, H_orig, W_orig + ) + + # Run mask_head for confidence map + confidence_map_i = self._run_mask_head( + neck_out_i, H_orig, W_orig + ) + + # Compute losses at original resolution + losses_i = self._compute_losses( + depth_map_i, + dgt_i, + dm_i, + K_pred_i, + K_i, + orig_hw, + confidence_map=confidence_map_i, + ) + + # Accumulate losses + for key, val in losses_i.items(): + if key not in losses_accum: + losses_accum[key] = val + else: + losses_accum[key] = losses_accum[key] + val + + # Repad depth_map back to padded resolution + depth_map_padded_i = self._repad_depth_map( + depth_map_i, + pad_left, + pad_right, + pad_top, + pad_bottom, + ) + depth_maps_list.append(depth_map_padded_i) + + # Repad confidence_map back to padded resolution + if confidence_map_i is not None: + confidence_maps_list.append( + self._repad_depth_map( + confidence_map_i, + pad_left, + pad_right, + pad_top, + pad_bottom, + ) + ) + + # Repad depth_latents to padded token grid + depth_latents_padded_i = self._repad_depth_latents( + depth_latents_i, + base_h_i, + base_w_i, + base_h_pad, + base_w_pad, + pad_top, + pad_left, + H_pad, + W_pad, + ) + depth_latents_list.append(depth_latents_padded_i) + + # K_pred: restore to padded space (add padding offset) + # fx, fy unchanged (padding doesn't change focal length) + # Use non-inplace ops to preserve autograd graph + K_pred_padded_i = K_pred_i.clone() + K_pred_padded_i[:, 0, 2] = K_pred_i[:, 0, 2] + pad_left + K_pred_padded_i[:, 1, 2] = K_pred_i[:, 1, 2] + pad_top + K_pred_list.append(K_pred_padded_i) + + # Average losses across batch + for key in losses_accum: + losses_accum[key] = losses_accum[key] / B + + # Stack results + depth_map = torch.cat(depth_maps_list, dim=0) + depth_latents = torch.cat(depth_latents_list, dim=0) + K_pred = torch.cat(K_pred_list, dim=0) + confidence_map = ( + torch.cat(confidence_maps_list, dim=0) + if confidence_maps_list + else None + ) + + depth_latents = self._maybe_detach_latents(depth_latents) + + # Ray intrinsics: padded intrinsics scaled to padded token grid + # (consistent with padded depth_latents space) + internal_hw = (base_h_pad * 14, base_w_pad * 14) + ray_intrinsics = self._scale_intrinsics( + intrinsics, (H_pad, W_pad), internal_hw + ) + + return GeometryBackendOutput( + depth_map=depth_map, + depth_latents=depth_latents, + K_pred=K_pred, + ray_intrinsics=ray_intrinsics, + ray_image_hw=internal_hw, + ray_downsample=14, + aux={ + "depth_latents_hw": (base_h_pad, base_w_pad), + "confidence_map": confidence_map, + }, + losses=losses_accum, + ) + + def _forward_train_batched( + self, + images: Tensor, + intrinsics: Tensor, + image_hw: tuple[int, int], + depth_gt: Tensor | None, + depth_mask: Tensor | None, + ) -> GeometryBackendOutput: + """Original batched forward_train path (no unpadding).""" + B = images.shape[0] + H, W = image_hw + + depth_input = self._prepare_depth_input( + depth_gt, depth_mask, B, H, W, images.device + ) + + ( + depth_map, depth_latents, cls_token, + base_h, base_w, neck_out, + ) = self._run_encoder_and_decoder( + images, depth_input, image_hw + ) + + depth_latents = self._maybe_detach_latents(depth_latents) + K_pred = self._predict_intrinsics(cls_token, H, W) + + # Run mask_head for confidence map + confidence_map = self._run_mask_head(neck_out, H, W) + + losses = self._compute_losses( + depth_map, depth_gt, depth_mask, K_pred, intrinsics, + image_hw, confidence_map=confidence_map, + ) + + internal_hw = (base_h * 14, base_w * 14) + ray_intrinsics = self._scale_intrinsics( + intrinsics, (H, W), internal_hw + ) + + return GeometryBackendOutput( + depth_map=depth_map, + depth_latents=depth_latents, + K_pred=K_pred, + ray_intrinsics=ray_intrinsics, + ray_image_hw=internal_hw, + ray_downsample=14, + aux={ + "depth_latents_hw": (base_h, base_w), + "confidence_map": confidence_map, + }, + losses=losses, + ) + + @torch.no_grad() + def forward_test( + self, + images: Tensor, + depth_feats: list[Tensor] | None, + intrinsics: Tensor, + image_hw: tuple[int, int], + depth_gt: Tensor | None = None, + **kwargs, + ) -> GeometryBackendOutput: + """Forward pass for inference. + + When padding info is provided, crops padding before the encoder + so LingBot-Depth processes at original resolution, then repads. + + Args: + images: [B, 3, H, W] 3D-MOOD normalized images. + depth_feats: Ignored. + intrinsics: [B, 3, 3] camera intrinsics. + image_hw: (H, W) image dimensions. + depth_gt: [B, 1, H, W] depth map input (optional). + **kwargs: May contain 'padding' (list of [L,R,T,B]). + + Returns: + GeometryBackendOutput. + """ + H_pad, W_pad = image_hw + padding = kwargs.get("padding", None) + + # If unpad disabled or no valid padding, use batched (padded) path + if not self.unpad_test or not self._has_valid_padding(padding): + return self._forward_test_batched( + images, intrinsics, image_hw, depth_gt + ) + + # Per-image processing at original resolution + B = images.shape[0] + base_h_pad, base_w_pad = self._compute_token_grid( + H_pad, W_pad + ) + + depth_maps_list = [] + depth_latents_list = [] + K_pred_list = [] + confidence_maps_list = [] + + for i in range(B): + pad_info = padding[i] + if pad_info is None or all(v == 0 for v in pad_info): + pad_left = pad_right = pad_top = pad_bottom = 0 + img_i = images[i : i + 1] + K_i = intrinsics[i : i + 1] + H_orig, W_orig = H_pad, W_pad + dgt_i = ( + depth_gt[i : i + 1] + if depth_gt is not None + else None + ) + else: + pad_left, pad_right, pad_top, pad_bottom = pad_info + ( + img_i, + K_i, + H_orig, + W_orig, + dgt_i, + _, + ) = self._crop_padding_single( + images[i : i + 1], + intrinsics[i : i + 1], + pad_info, + H_pad, + W_pad, + ( + depth_gt[i : i + 1] + if depth_gt is not None + else None + ), + ) + + orig_hw = (H_orig, W_orig) + + # Use depth_gt as input if available, otherwise monocular + depth_input_i = dgt_i if dgt_i is not None else None + + ( + depth_map_i, + depth_latents_i, + cls_token_i, + base_h_i, + base_w_i, + neck_out_i, + ) = self._run_encoder_and_decoder( + img_i, depth_input_i, orig_hw + ) + + K_pred_i = self._predict_intrinsics( + cls_token_i, H_orig, W_orig + ) + + # Run mask_head for confidence map + confidence_map_i = self._run_mask_head( + neck_out_i, H_orig, W_orig + ) + + # Repad depth_map + depth_maps_list.append( + self._repad_depth_map( + depth_map_i, + pad_left, + pad_right, + pad_top, + pad_bottom, + ) + ) + + # Repad confidence_map + if confidence_map_i is not None: + confidence_maps_list.append( + self._repad_depth_map( + confidence_map_i, + pad_left, + pad_right, + pad_top, + pad_bottom, + ) + ) + + # Repad depth_latents + depth_latents_list.append( + self._repad_depth_latents( + depth_latents_i, + base_h_i, + base_w_i, + base_h_pad, + base_w_pad, + pad_top, + pad_left, + H_pad, + W_pad, + ) + ) + + # K_pred: restore to padded space (non-inplace for autograd) + K_pred_padded_i = K_pred_i.clone() + K_pred_padded_i[:, 0, 2] = K_pred_i[:, 0, 2] + pad_left + K_pred_padded_i[:, 1, 2] = K_pred_i[:, 1, 2] + pad_top + K_pred_list.append(K_pred_padded_i) + + depth_map = torch.cat(depth_maps_list, dim=0) + depth_latents = torch.cat(depth_latents_list, dim=0) + K_pred = torch.cat(K_pred_list, dim=0) + confidence_map = ( + torch.cat(confidence_maps_list, dim=0) + if confidence_maps_list + else None + ) + + depth_latents = self._maybe_detach_latents(depth_latents) + + internal_hw = (base_h_pad * 14, base_w_pad * 14) + ray_intrinsics = self._scale_intrinsics( + intrinsics, (H_pad, W_pad), internal_hw + ) + + return GeometryBackendOutput( + depth_map=depth_map, + depth_latents=depth_latents, + K_pred=K_pred, + ray_intrinsics=ray_intrinsics, + ray_image_hw=internal_hw, + ray_downsample=14, + aux={ + "depth_latents_hw": (base_h_pad, base_w_pad), + "confidence_map": confidence_map, + }, + losses={}, + ) + + def _forward_test_batched( + self, + images: Tensor, + intrinsics: Tensor, + image_hw: tuple[int, int], + depth_gt: Tensor | None, + ) -> GeometryBackendOutput: + """Original batched forward_test path (no unpadding).""" + H, W = image_hw + + depth_input = depth_gt if depth_gt is not None else None + ( + depth_map, depth_latents, cls_token, + base_h, base_w, neck_out, + ) = self._run_encoder_and_decoder( + images, depth_input, image_hw + ) + + depth_latents = self._maybe_detach_latents(depth_latents) + K_pred = self._predict_intrinsics(cls_token, H, W) + + # Run mask_head for confidence map + confidence_map = self._run_mask_head(neck_out, H, W) + + internal_hw = (base_h * 14, base_w * 14) + ray_intrinsics = self._scale_intrinsics( + intrinsics, (H, W), internal_hw + ) + + return GeometryBackendOutput( + depth_map=depth_map, + depth_latents=depth_latents, + K_pred=K_pred, + ray_intrinsics=ray_intrinsics, + ray_image_hw=internal_hw, + ray_downsample=14, + aux={ + "depth_latents_hw": (base_h, base_w), + "confidence_map": confidence_map, + }, + losses={}, + ) diff --git a/wilddet3d/eval/__init__.py b/wilddet3d/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wilddet3d/eval/detect3d.py b/wilddet3d/eval/detect3d.py new file mode 100644 index 0000000000000000000000000000000000000000..78ca7f853a13298d515a09604ff080e25e01d3d8 --- /dev/null +++ b/wilddet3d/eval/detect3d.py @@ -0,0 +1,1734 @@ +"""3D Multiple Object Detection Evaluator.""" + +import contextlib +import copy +import datetime +import io +import itertools +import json +import os +import time +from collections import defaultdict + +import numpy as np +import pycocotools.mask as maskUtils +import torch +from pycocotools.cocoeval import COCOeval +from scipy.spatial.distance import cdist +from terminaltables import AsciiTable +from vis4d.common.array import array_to_numpy +from vis4d.common.distributed import all_gather_object_cpu +from vis4d.common.typing import ( + ArrayLike, + DictStrAny, + GenericFunc, + MetricLogs, + NDArrayF32, + NDArrayI64, +) +from vis4d.eval.base import Evaluator +from vis4d.eval.coco.detect import xyxy_to_xywh + +from vis4d.data.const import AxisMode +from vis4d.op.box.box3d import boxes3d_to_corners +from vis4d.op.geometry.rotation import quaternion_to_matrix + +from wilddet3d.data.datasets.coco3d import COCO3D +from wilddet3d.ops.box3d import box3d_overlap +from wilddet3d.ops.rotation import so3_relative_angle + + +def _canonicalize_rotation_np(R_cam, dims_whl): + """Canonicalize rotation for evaluation (numpy version). + + Matches _normalize_canonical in coder.py. Eliminates 4-fold OBB + rotation ambiguity: + Step 1 - Force W <= L: if W > L, swap and apply Ry(90). + Step 2 - Normalize yaw to [0, pi): if yaw outside, apply Ry(180). + + Args: + R_cam: 3x3 rotation matrix (numpy). + dims_whl: [W, H, L] dimensions (numpy or list). + + Returns: + R_out: 3x3 canonical rotation matrix. + """ + R_out = np.array(R_cam, dtype=np.float64).copy() + w, h, l = float(dims_whl[0]), float(dims_whl[1]), float(dims_whl[2]) + + # Step 1: Force W <= L + if w > l: + w, l = l, w + col0 = R_out[:, 0].copy() + R_out[:, 0] = -R_out[:, 2] + R_out[:, 2] = col0 + + # Step 2: Normalize yaw to [0, pi) + # YZX intrinsic: yaw = atan2(-R[2,0], R[0,0]) + yaw = np.arctan2(-R_out[2, 0], R_out[0, 0]) + if yaw < 0 or yaw > np.pi - 1e-4: + R_out[:, 0] = -R_out[:, 0] + R_out[:, 2] = -R_out[:, 2] + + return R_out + + +class Detect3DEvaluator(Evaluator): + """3D object detection evaluation with COCO format.""" + + def __init__( + self, + det_map: dict[str, int], + cat_map: dict[str, int], + annotation: str, + id2name: dict[int, str] | None = None, + per_class_eval: bool = True, + eval_prox: bool = False, + iou_type: str = "bbox", + num_columns: int = 6, + base_classes: list[str] | None = None, + # Frequency-based AP split (LVIS-style) + # Categories with APr + # Categories with rare_thresh..freq_thresh images -> APc + # Categories with >=freq_thresh images -> APf + freq_rare_thresh: int = 0, + freq_freq_thresh: int = 0, + # APRel3D parameters (LabelAny3D-style) + enable_aprel3d: bool = False, + aprel_2d_iou_thresh: float = 0.75, + ) -> None: + """Create an instance of the class.""" + if id2name is None: + self.id2name = {v: k for k, v in det_map.items()} + else: + self.id2name = id2name + + self.annotation = annotation + self.per_class_eval = per_class_eval + self.eval_prox = eval_prox + self.iou_type = iou_type + self.num_columns = num_columns + self.base_classes = base_classes + + # APRel3D settings (LabelAny3D-style) + self.enable_aprel3d = enable_aprel3d + self.aprel_2d_iou_thresh = aprel_2d_iou_thresh + + self.tp_errors = ["ATE", "AOE", "ASE"] + + category_names = sorted(det_map, key=det_map.get) + + with contextlib.redirect_stdout(io.StringIO()): + self._coco_gt = COCO3D([annotation], category_names) + + self.cat_map = cat_map + + # Build frequency split if thresholds are set + self.freq_rare_thresh = freq_rare_thresh + self.freq_freq_thresh = freq_freq_thresh + self.cat_freq_group: dict[int, str] | None = None + if freq_rare_thresh > 0 and freq_freq_thresh > 0: + with open(annotation) as f: + ann_data = json.load(f) + cat_img_count: dict[int, set] = {} + for ann in ann_data["annotations"]: + cid = ann["category_id"] + if cid not in cat_img_count: + cat_img_count[cid] = set() + cat_img_count[cid].add(ann["image_id"]) + self.cat_freq_group = {} + for cat in ann_data["categories"]: + n = len(cat_img_count.get(cat["id"], set())) + if n < freq_rare_thresh: + self.cat_freq_group[cat["id"]] = "rare" + elif n < freq_freq_thresh: + self.cat_freq_group[cat["id"]] = "common" + else: + self.cat_freq_group[cat["id"]] = "frequent" + n_r = sum(1 for v in self.cat_freq_group.values() if v == "rare") + n_c = sum(1 for v in self.cat_freq_group.values() if v == "common") + n_f = sum(1 for v in self.cat_freq_group.values() if v == "frequent") + print(f"[Detect3DEvaluator] Frequency split: " + f"rare(<{freq_rare_thresh})={n_r}, " + f"common({freq_rare_thresh}-{freq_freq_thresh})={n_c}, " + f"frequent(>={freq_freq_thresh})={n_f}") + + self.bbox_2D_evals_per_cat_area: DictStrAny = {} + self.bbox_3D_evals_per_cat_area: DictStrAny = {} + self._predictions: list[DictStrAny] = [] + + # Store optimal scales for APRel3D + self.optimal_scales: dict[int, float] = {} + + def __repr__(self) -> str: + """Returns the string representation of the object.""" + return f"3D Object Detection Evaluator with {self.annotation}" + + @property + def metrics(self) -> list[str]: + """Supported metrics. + + Returns: + list[str]: Metrics to evaluate. + """ + return ["2D", "3D"] + + def gather(self, gather_func: GenericFunc) -> None: + """Accumulate predictions across all processes. + + Uses NCCL-based all_gather_object instead of vis4d's file-based + all_gather_object_cpu, which fails on weka cross-node due to + filesystem cache consistency issues. + """ + import torch.distributed as dist + + if not dist.is_initialized() or dist.get_world_size() == 1: + return + + rank = dist.get_rank() + world_size = dist.get_world_size() + + # Use NCCL-based gathering (avoids cross-node filesystem issues) + all_preds = [None] * world_size + dist.all_gather_object(all_preds, self._predictions) + + if rank == 0: + self._predictions = list( + itertools.chain(*all_preds) + ) + else: + self._predictions = [] + + def reset(self) -> None: + """Reset the saved predictions to start new round of evaluation.""" + self._predictions.clear() + self.bbox_2D_evals_per_cat_area.clear() + self.bbox_3D_evals_per_cat_area.clear() + self.optimal_scales.clear() + + def _find_optimal_scale( + self, preds: list[DictStrAny], gts: list[DictStrAny] + ) -> float: + """Find optimal global scale factor (LabelAny3D method). + + Ported from LabelAny3D compute_optimal_scale(): + 1. Match each dt to best gt using 2D IoU (threshold 0.75). + 2. Grid search [0.1, 3.5] step 0.1, maximize avg 3D IoU. + """ + # Collect dt/gt 3D corners and 2D boxes + dt_boxes = [] + dt_boxes_2d = [] + for pred in preds: + if "bbox3D" not in pred or "bbox" not in pred: + continue + dt_boxes.append(pred["bbox3D"]) + # bbox is COCO [x, y, w, h], convert to [x1, y1, x2, y2] + b = pred["bbox"] + dt_boxes_2d.append([b[0], b[1], b[0] + b[2], b[1] + b[3]]) + + gt_boxes = [] + gt_boxes_2d = [] + for gt in gts: + if "bbox3D" not in gt or "bbox" not in gt: + continue + gt_boxes.append(gt["bbox3D"]) + b = gt["bbox"] + gt_boxes_2d.append([b[0], b[1], b[0] + b[2], b[1] + b[3]]) + + if len(gt_boxes) == 0 or len(dt_boxes) == 0: + return 1.0 + + dt_boxes = np.array(dt_boxes, dtype=np.float32) + gt_boxes = np.array(gt_boxes, dtype=np.float32) + + # Match each dt to the most similar gt using 2D IoU + matched_pairs = [] + for dt_idx, dt_2d in enumerate(dt_boxes_2d): + best_iou = 0 + best_gt_idx = -1 + + for gt_idx, gt_2d in enumerate(gt_boxes_2d): + x1 = max(dt_2d[0], gt_2d[0]) + y1 = max(dt_2d[1], gt_2d[1]) + x2 = min(dt_2d[2], gt_2d[2]) + y2 = min(dt_2d[3], gt_2d[3]) + + if x2 <= x1 or y2 <= y1: + continue + + inter_area = (x2 - x1) * (y2 - y1) + dt_area = (dt_2d[2] - dt_2d[0]) * (dt_2d[3] - dt_2d[1]) + gt_area = (gt_2d[2] - gt_2d[0]) * (gt_2d[3] - gt_2d[1]) + iou = inter_area / (dt_area + gt_area - inter_area) + + if iou > best_iou: + best_iou = iou + best_gt_idx = gt_idx + + if best_gt_idx >= 0 and best_iou > 0.75: + matched_pairs.append((dt_idx, best_gt_idx)) + + if len(matched_pairs) == 0: + return 1.0 + + def compute_avg_iou(scale): + avg_iou = 0.0 + for dt_idx, gt_idx in matched_pairs: + scaled_dt_box = dt_boxes[dt_idx] * scale + dt_tensor = torch.tensor( + scaled_dt_box[np.newaxis, :, :], + dtype=torch.float32, + ) + gt_tensor = torch.tensor( + gt_boxes[gt_idx][np.newaxis, :, :], + dtype=torch.float32, + ) + iou = box3d_overlap(dt_tensor, gt_tensor).cpu().numpy()[0] + avg_iou += iou + return avg_iou / len(matched_pairs) + + # Grid search: start with scale=1.0, then search [0.1, 3.5] + best_scale = 1.0 + best_iou = compute_avg_iou(best_scale) + + for scale in np.arange(0.1, 3.51, 0.1): + iou = compute_avg_iou(scale) + if iou > best_iou: + best_iou = iou + best_scale = scale + + return best_scale + + def _optimize_and_apply_scales(self) -> None: + """Optimize per-image scale and apply to all predictions.""" + print("Optimizing scales for APRel3D (LabelAny3D method)...") + print(f" 2D IoU match threshold: {self.aprel_2d_iou_thresh}") + + # Step 1: Group predictions by image + preds_by_image = defaultdict(list) + for pred in self._predictions: + preds_by_image[pred["image_id"]].append(pred) + + # Step 2: Optimize scale for each image + n_matched_images = 0 + for img_id, preds in preds_by_image.items(): + gts = self._coco_gt.loadAnns( + self._coco_gt.getAnnIds(imgIds=[img_id]) + ) + if len(gts) == 0: + self.optimal_scales[img_id] = 1.0 + continue + + s_star = self._find_optimal_scale(preds, gts) + self.optimal_scales[img_id] = s_star + if s_star != 1.0: + n_matched_images += 1 + + # Step 3: Apply scales (direct corner multiplication) + scaled_predictions = [] + for pred in self._predictions: + img_id = pred["image_id"] + scale = self.optimal_scales.get(img_id, 1.0) + + scaled_pred = pred.copy() + + if "center_cam" in pred: + scaled_pred["center_cam"] = [ + c * scale for c in pred["center_cam"] + ] + if "dimensions" in pred: + scaled_pred["dimensions"] = [ + d * scale for d in pred["dimensions"] + ] + if "bbox3D" in pred: + scaled_pred["bbox3D"] = [ + [c * scale for c in corner] + for corner in pred["bbox3D"] + ] + if "depth" in pred: + scaled_pred["depth"] = pred["depth"] * scale + + scaled_predictions.append(scaled_pred) + + self._predictions = scaled_predictions + + # Print statistics + scales = list(self.optimal_scales.values()) + if len(scales) > 0: + print(f"APRel3D: {len(scales)} images, " + f"{n_matched_images} had 2D-IoU matches") + print(f" Mean scale: {np.mean(scales):.3f}") + print(f" Std scale: {np.std(scales):.3f}") + print(f" Min scale: {np.min(scales):.3f}") + print(f" Max scale: {np.max(scales):.3f}") + + def process_batch( + self, + coco_image_id: list[int], + pred_boxes: list[ArrayLike], + pred_scores: list[ArrayLike], + pred_classes: list[ArrayLike], + pred_boxes3d: list[ArrayLike] | None = None, + ) -> None: + """Process sample and convert detections to coco format.""" + for i, image_id in enumerate(coco_image_id): + boxes = array_to_numpy( + pred_boxes[i].to(torch.float32), n_dims=None, dtype=np.float32 + ) + scores = array_to_numpy( + pred_scores[i].to(torch.float32), n_dims=None, dtype=np.float32 + ) + classes = array_to_numpy( + pred_classes[i], n_dims=None, dtype=np.int64 + ) + + if pred_boxes3d is not None: + boxes3d = array_to_numpy( + pred_boxes3d[i].to(torch.float32), + n_dims=None, + dtype=np.float32, + ) + else: + boxes3d = None + + self._predictions_to_coco( + image_id, boxes, boxes3d, scores, classes + ) + + def _predictions_to_coco( + self, + img_id: int, + boxes: NDArrayF32, + boxes3d: NDArrayF32 | None, + scores: NDArrayF32, + classes: NDArrayI64, + ) -> None: + """Convert predictions to COCO format.""" + boxes_xyxy = copy.deepcopy(boxes) + boxes_xywh = xyxy_to_xywh(boxes_xyxy) + + if boxes3d is not None: + # FIXME: Make axismode configurable + corners_3d = boxes3d_to_corners( + torch.from_numpy(boxes3d), AxisMode.OPENCV + ) + + for i, (box, box_score, box_class) in enumerate( + zip(boxes_xywh, scores, classes) + ): + xywh = box.tolist() + + result = { + "image_id": img_id, + "bbox": xywh, + "category_id": self.cat_map[self.id2name[box_class.item()]], + "score": box_score.item(), + } + + # mapping to Omni3D format + if boxes3d is not None: + result["center_cam"] = boxes3d[i][:3].tolist() + + # wlh to whl + result["dimensions"] = boxes3d[i][[3, 5, 4]].tolist() + + result["R_cam"] = ( + quaternion_to_matrix(torch.from_numpy(boxes3d[i][6:10])) + .numpy() + .tolist() + ) + + corners = corners_3d[i].numpy().tolist() + + result["bbox3D"] = [ + corners[6], + corners[4], + corners[0], + corners[2], + corners[7], + corners[5], + corners[1], + corners[3], + ] + + result["depth"] = boxes3d[i][2].item() + + self._predictions.append(result) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate predictions.""" + if metric == "2D": + metrics = ["AP", "AP50", "AP75", "AP95", "APs", "APm", "APl"] + else: + if self.iou_type == "bbox": + if self.enable_aprel3d: + metrics = [ + "APRel3D", + "APRel15", + "APRel25", + "APRel50", + "APReln", + "APRelm", + "APRelf", + ] + main_metric = "APRel3D" + else: + metrics = ["AP", "AP15", "AP25", "AP50", "APn", "APm", "APf"] + main_metric = "AP" + else: + if self.enable_aprel3d: + metrics = ["APRel", "ATERel", "ASERel", "AOERel", "ODSRel", "ODSRelSym"] + main_metric = "ODSRel" + else: + metrics = ["AP", "ATE", "ASE", "AOE", "ODS", "ODS_Sym"] + main_metric = "ODS" + + if self.base_classes is not None: + metrics += [f"{main_metric}_Base", f"{main_metric}_Novel"] + + if len(self._predictions) == 0: + return {m: 0.0 for m in metrics}, "No predictions to evaluate." + + # APRel3D: Optimize and apply scales before evaluation + if self.enable_aprel3d and metric == "3D": + self._optimize_and_apply_scales() + + with contextlib.redirect_stdout(io.StringIO()): + coco_dt = self._coco_gt.loadRes(self._predictions) + + assert coco_dt is not None + evaluator = Detect3Deval( + self._coco_gt, + coco_dt, + mode=metric, + eval_prox=self.eval_prox, + iou_type=self.iou_type, + ) + evaluator.evaluate() + evaluator.accumulate() + + if self.iou_type == "bbox": + log_str = "\n" + evaluator.summarize() + + # precision: (iou, recall, cls, area range, max dets) + precisions = evaluator.eval["precision"] + assert len(self._coco_gt.getCatIds()) == precisions.shape[2] + + if metric == "2D": + self.bbox_2D_evals_per_cat_area = evaluator.evals_per_cat_area + + score_dict = dict(zip(metrics, evaluator.stats)) + else: + if self.iou_type == "bbox": + self.bbox_3D_evals_per_cat_area = evaluator.evals_per_cat_area + + score_dict = dict(zip(metrics, evaluator.stats)) + + # Compute mASE, mAOE, mAOE_Sym for bbox mode + # Note: ATE is not returned in bbox mode because the normalization + # by IoU threshold makes it unreliable (can be > 1) + rot_tp_errors = evaluator.eval["rot_tp_errors"] + rot_sym_tp_errors = evaluator.eval["rot_sym_tp_errors"] + rot_canonical_tp_errors = evaluator.eval["rot_canonical_tp_errors"] + scale_tp_errors = evaluator.eval["scale_tp_errors"] + + rot_tp = rot_tp_errors[:, :, :, 0, -1] + rot_tp = rot_tp[rot_tp > -1] + + rot_sym_tp = rot_sym_tp_errors[:, :, :, 0, -1] + rot_sym_tp = rot_sym_tp[rot_sym_tp > -1] + + rot_canonical_tp = rot_canonical_tp_errors[:, :, :, 0, -1] + rot_canonical_tp = rot_canonical_tp[rot_canonical_tp > -1] + + scale_tp = scale_tp_errors[:, :, :, 0, -1] + scale_tp = scale_tp[scale_tp > -1] + + if rot_tp.size: + mAOE = np.mean(rot_tp).item() + mAOE_Sym = np.mean(rot_sym_tp).item() + mAOE_Canonical = np.mean(rot_canonical_tp).item() + mASE = np.mean(scale_tp).item() + else: + mAOE = float("nan") + mAOE_Sym = float("nan") + mAOE_Canonical = float("nan") + mASE = float("nan") + + # Add error metrics to output (no ATE in bbox mode) + if self.enable_aprel3d: + score_dict["ASERel"] = mASE + score_dict["AOERel"] = mAOE + score_dict["AOERelSym"] = mAOE_Sym + score_dict["AOERelCanonical"] = mAOE_Canonical + else: + score_dict["ASE"] = mASE + score_dict["AOE"] = mAOE + score_dict["AOE_Sym"] = mAOE_Sym + score_dict["AOE_Canonical"] = mAOE_Canonical + + # Add scale statistics for APRel3D + if self.enable_aprel3d and len(self.optimal_scales) > 0: + scales = list(self.optimal_scales.values()) + score_dict["mean_scale"] = np.mean(scales) + score_dict["std_scale"] = np.std(scales) + else: + trans_tp_errors = evaluator.eval["trans_tp_errors"] + rot_tp_errors = evaluator.eval["rot_tp_errors"] + rot_sym_tp_errors = evaluator.eval["rot_sym_tp_errors"] + rot_canonical_tp_errors = evaluator.eval["rot_canonical_tp_errors"] + scale_tp_errors = evaluator.eval["scale_tp_errors"] + + precision = precisions[:, :, :, 0, -1] + precision = precision[precision > -1] + if precision.size: + mAP = np.mean(precision).item() + else: + mAP = float("nan") + + trans_tp = trans_tp_errors[:, :, :, 0, -1] + trans_tp = trans_tp[trans_tp > -1] + + rot_tp = rot_tp_errors[:, :, :, 0, -1] + rot_tp = rot_tp[rot_tp > -1] + + rot_sym_tp = rot_sym_tp_errors[:, :, :, 0, -1] + rot_sym_tp = rot_sym_tp[rot_sym_tp > -1] + + rot_canonical_tp = rot_canonical_tp_errors[:, :, :, 0, -1] + rot_canonical_tp = rot_canonical_tp[rot_canonical_tp > -1] + + scale_tp = scale_tp_errors[:, :, :, 0, -1] + scale_tp = scale_tp[scale_tp > -1] + + if trans_tp.size: + mATE = np.mean(trans_tp).item() + mAOE = np.mean(rot_tp).item() + mAOE_Sym = np.mean(rot_sym_tp).item() + mAOE_Canonical = np.mean(rot_canonical_tp).item() + mASE = np.mean(scale_tp).item() + + mODS = ( + np.sum(mAP * 3 + (1 - mATE) + (1 - mAOE) + (1 - mASE)) + / 6 + ) + mODS_Sym = ( + np.sum(mAP * 3 + (1 - mATE) + (1 - mAOE_Sym) + (1 - mASE)) + / 6 + ) + mODS_Canonical = ( + np.sum(mAP * 3 + (1 - mATE) + (1 - mAOE_Canonical) + (1 - mASE)) + / 6 + ) + + else: + mATE = float("nan") + mAOE = float("nan") + mAOE_Sym = float("nan") + mAOE_Canonical = float("nan") + mASE = float("nan") + mODS = float("nan") + mODS_Sym = float("nan") + mODS_Canonical = float("nan") + + if self.enable_aprel3d: + score_dict = { + "APRel": mAP, + "ATERel": mATE, + "ASERel": mASE, + "AOERel": mAOE, + "AOERelSym": mAOE_Sym, + "AOERelCanonical": mAOE_Canonical, + "ODSRel": mODS, + "ODSRelSym": mODS_Sym, + "ODSRelCanonical": mODS_Canonical, + } + else: + score_dict = { + "AP": mAP, + "ATE": mATE, + "ASE": mASE, + "AOE": mAOE, + "AOE_Sym": mAOE_Sym, + "AOE_Canonical": mAOE_Canonical, + "ODS": mODS, + "ODS_Sym": mODS_Sym, + "ODS_Canonical": mODS_Canonical, + } + + # Add scale statistics for APRel3D + if self.enable_aprel3d and len(self.optimal_scales) > 0: + scales = list(self.optimal_scales.values()) + score_dict["mean_scale"] = np.mean(scales) + score_dict["std_scale"] = np.std(scales) + + log_str = "\nHigh-level metrics:" + for k, v in score_dict.items(): + log_str += f"\n{k}: {v:.4f}" + + if self.per_class_eval: + results_per_category = [] + score_base_list = [] + score_novel_list = [] + freq_ap: dict[str, list] = {"rare": [], "common": [], "frequent": []} + + for idx, cat_id in enumerate(self._coco_gt.getCatIds()): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = self._coco_gt.loadCats(cat_id)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision).item() + else: + ap = float("nan") + + if self.iou_type == "dist": + trans_tp = trans_tp_errors[:, :, idx, 0, -1] + trans_tp = trans_tp[trans_tp > -1] + + rot_tp = rot_tp_errors[:, :, idx, 0, -1] + rot_tp = rot_tp[rot_tp > -1] + + rot_sym_tp = rot_sym_tp_errors[:, :, idx, 0, -1] + rot_sym_tp = rot_sym_tp[rot_sym_tp > -1] + + scale_tp = scale_tp_errors[:, :, idx, 0, -1] + scale_tp = scale_tp[scale_tp > -1] + + if trans_tp.size: + ate = np.mean(trans_tp).item() + aoe = np.mean(rot_tp).item() + aoe_sym = np.mean(rot_sym_tp).item() + ase = np.mean(scale_tp).item() + + ods = ( + np.sum(ap * 3 + (1 - ate) + (1 - aoe) + (1 - ase)) + / 6 + ) + ods_sym = ( + np.sum(ap * 3 + (1 - ate) + (1 - aoe_sym) + (1 - ase)) + / 6 + ) + + else: + ate = float("nan") + aoe = float("nan") + aoe_sym = float("nan") + ase = float("nan") + ods = float("nan") + ods_sym = float("nan") + + results_per_category.append( + ( + f'{nm["name"]}', + f"{ap:0.3f}", + f"{ate:0.3f}", + f"{ase:0.3f}", + f"{aoe:0.3f}", + f"{aoe_sym:0.3f}", + f"{ods:0.3f}", + f"{ods_sym:0.3f}", + ) + ) + else: + results_per_category.append( + (f'{nm["name"]}', f"{ap:0.3f}") + ) + + if self.base_classes is not None: + if self.iou_type == "dist": + score = ods + else: + score = ap + + if nm["name"] in self.base_classes: + score_base_list.append(score) + else: + score_novel_list.append(score) + + if self.cat_freq_group is not None and not np.isnan(ap): + group = self.cat_freq_group.get(cat_id, "rare") + freq_ap[group].append(ap) + + results_flatten = list(itertools.chain(*results_per_category)) + + if self.iou_type == "dist": + num_columns = 8 + headers = ["category", "AP", "ATE", "ASE", "AOE", "AOE_Sym", "ODS", "ODS_Sym"] + else: + num_columns = min( + self.num_columns, len(results_per_category) * 2 + ) + headers = ["category", "AP"] * (num_columns // 2) + results = itertools.zip_longest( + *[results_flatten[i::num_columns] for i in range(num_columns)] + ) + table_data = [headers] + list(results) + if AsciiTable is not None: + table = AsciiTable(table_data) + log_str = f"\n{table.table}\n{log_str}" + else: + # Fallback when terminaltables is not installed. + log_str = f"\n(per-class table omitted; install terminaltables for pretty output)\n{log_str}" + + if self.base_classes is not None: + score_dict[f"{main_metric}_Base"] = np.mean(score_base_list).item() + score_dict[f"{main_metric}_Novel"] = np.mean( + score_novel_list + ).item() + + if self.cat_freq_group is not None and self.per_class_eval: + score_dict["APr"] = np.mean(freq_ap["rare"]).item() if freq_ap["rare"] else float("nan") + score_dict["APc"] = np.mean(freq_ap["common"]).item() if freq_ap["common"] else float("nan") + score_dict["APf"] = np.mean(freq_ap["frequent"]).item() if freq_ap["frequent"] else float("nan") + log_str += ( + f"\nFrequency split (<{self.freq_rare_thresh}/{self.freq_freq_thresh}):" + f" APr={score_dict['APr']:.4f} ({len(freq_ap['rare'])} cats)," + f" APc={score_dict['APc']:.4f} ({len(freq_ap['common'])} cats)," + f" APf={score_dict['APf']:.4f} ({len(freq_ap['frequent'])} cats)" + ) + + return score_dict, log_str + + def save( + self, metric: str, output_dir: str, prefix: str | None = None + ) -> None: + """Save the results to json files.""" + assert metric in self.metrics + + if prefix is not None: + result_folder = os.path.join(output_dir, prefix) + os.makedirs(result_folder, exist_ok=True) + else: + result_folder = output_dir + + result_file = os.path.join( + result_folder, f"detect_{metric}_results.json" + ) + + with open(result_file, mode="w", encoding="utf-8") as f: + json.dump(self._predictions, f) + + +class Detect3Deval(COCOeval): + """COCOeval Wrapper for 2D and 3D box evaluation. + + Now it support bbox IoU matching only. + """ + + def __init__( + self, + cocoGt=None, + cocoDt=None, + mode: str = "2D", + iou_type: str = "bbox", + eval_prox: bool = False, + ): + """Initialize Detect3Deval using coco APIs for Gt and Dt. + + Args: + cocoGt: COCO object with ground truth annotations + cocoDt: COCO object with detection results + mode: (str) defines whether to evaluate 2D or 3D performance. + One of {"2D", "3D"} + eval_prox: (bool) if True, performs "Proximity Evaluation", i.e. + evaluates detections in the proximity of the ground truth2D + boxes. This is used for datasets which are not exhaustively + annotated. + """ + if mode not in {"2D", "3D"}: + raise Exception(f"{mode} mode is not supported") + self.mode = mode + self.iou_type = iou_type + self.eval_prox = eval_prox + + self.cocoGt = cocoGt # ground truth COCO API + self.cocoDt = cocoDt # detections COCO API + + # per-image per-category evaluation results [KxAxI] elements + self.evalImgs = defaultdict(list) + + self.eval = {} # accumulated evaluation results + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + self.params = Detect3DParams(mode=mode, iouType=iou_type) # parameters + self._paramsEval = {} # parameters for evaluation + self.stats = [] # result summarization + self.ious = {} # ious between all gts and dts + + if cocoGt is not None: + self.params.imgIds = sorted(cocoGt.getImgIds()) + self.params.catIds = sorted(cocoGt.getCatIds()) + + self.evals_per_cat_area = None + + def _prepare(self) -> None: + """Prepare ._gts and ._dts for evaluation based on params.""" + p = self.params + + if p.useCats: + gts = self.cocoGt.loadAnns( + self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds) + ) + dts = self.cocoDt.loadAnns( + self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds) + ) + + else: + gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds)) + dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds)) + + # set ignore flag + ignore_flag = "ignore2D" if self.mode == "2D" else "ignore3D" + for gt in gts: + gt[ignore_flag] = gt[ignore_flag] if ignore_flag in gt else 0 + + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + + for gt in gts: + self._gts[gt["image_id"], gt["category_id"]].append(gt) + + for dt in dts: + self._dts[dt["image_id"], dt["category_id"]].append(dt) + + self.evalImgs = defaultdict( + list + ) # per-image per-category evaluation results + self.eval = {} # accumulated evaluation results + + def accumulate(self, p=None) -> None: + """Accumulate per image evaluation and store the result in self.eval. + + Args: + p: input params for evaluation + """ + print("Accumulating evaluation results...") + assert self.evalImgs, "Please run evaluate() first" + + tic = time.time() + + # allows input customized parameters + if p is None: + p = self.params + + p.catIds = p.catIds if p.useCats == 1 else [-1] + + T = len(p.iouThrs) + R = len(p.recThrs) + K = len(p.catIds) if p.useCats else 1 + A = len(p.areaRng) + M = len(p.maxDets) + + precision = -np.ones( + (T, R, K, A, M) + ) # -1 for the precision of absent categories + trans_tp_errors = -np.ones((T, R, K, A, M)) + rot_tp_errors = -np.ones((T, R, K, A, M)) + rot_sym_tp_errors = -np.ones((T, R, K, A, M)) + rot_canonical_tp_errors = -np.ones((T, R, K, A, M)) + scale_tp_errors = -np.ones((T, R, K, A, M)) + recall = -np.ones((T, K, A, M)) + scores = -np.ones((T, R, K, A, M)) + + # create dictionary for future indexing + _pe = self._paramsEval + + catIds = _pe.catIds if _pe.useCats else [-1] + setK = set(catIds) + setA = set(map(tuple, _pe.areaRng)) + setM = set(_pe.maxDets) + setI = set(_pe.imgIds) + + # get inds to evaluate + catid_list = [k for n, k in enumerate(p.catIds) if k in setK] + k_list = [n for n, k in enumerate(p.catIds) if k in setK] + m_list = [m for n, m in enumerate(p.maxDets) if m in setM] + a_list = [ + n + for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) + if a in setA + ] + i_list = [n for n, i in enumerate(p.imgIds) if i in setI] + + I0 = len(_pe.imgIds) + A0 = len(_pe.areaRng) + + has_precomputed_evals = not (self.evals_per_cat_area is None) + + if has_precomputed_evals: + evals_per_cat_area = self.evals_per_cat_area + else: + evals_per_cat_area = {} + + # retrieve E at each category, area range, and max number of detections + for k, (k0, catId) in enumerate(zip(k_list, catid_list)): + Nk = k0 * A0 * I0 + for a, a0 in enumerate(a_list): + Na = a0 * I0 + + if has_precomputed_evals: + E = evals_per_cat_area.get((catId, a), []) + + else: + E = [self.evalImgs[Nk + Na + i] for i in i_list] + E = [e for e in E if not e is None] + evals_per_cat_area[(catId, a)] = E + + if len(E) == 0: + continue + + for m, maxDet in enumerate(m_list): + + dtScores = np.concatenate( + [e["dtScores"][0:maxDet] for e in E] + ) + + # different sorting method generates slightly different results. + # mergesort is used to be consistent as Matlab implementation. + inds = np.argsort(-dtScores, kind="mergesort") + dtScoresSorted = dtScores[inds] + + dtm = np.concatenate( + [e["dtMatches"][:, 0:maxDet] for e in E], axis=1 + )[:, inds] + dtIg = np.concatenate( + [e["dtIgnore"][:, 0:maxDet] for e in E], axis=1 + )[:, inds] + gtIg = np.concatenate([e["gtIgnore"] for e in E]) + npig = np.count_nonzero(gtIg == 0) + + if npig == 0: + continue + + tps = np.logical_and(dtm, np.logical_not(dtIg)) + fps = np.logical_and( + np.logical_not(dtm), np.logical_not(dtIg) + ) + + tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64) + fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64) + + # Compute TP error (for both bbox and dist modes) + tems = np.concatenate( + [e["dtTranslationError"][:, 0:maxDet] for e in E], + axis=1, + )[:, inds] + + oems = np.concatenate( + [e["dtOrientationError"][:, 0:maxDet] for e in E], + axis=1, + )[:, inds] + + oems_sym = np.concatenate( + [e["dtOrientationErrorSym"][:, 0:maxDet] for e in E], + axis=1, + )[:, inds] + + oems_canonical = np.concatenate( + [e["dtOrientationErrorCanonical"][:, 0:maxDet] for e in E], + axis=1, + )[:, inds] + + sems = np.concatenate( + [e["dtScaleError"][:, 0:maxDet] for e in E], axis=1 + )[:, inds] + + for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): + tp = np.array(tp) + fp = np.array(fp) + nd = len(tp) + rc = tp / npig + pr = tp / (fp + tp + np.spacing(1)) + + q = np.zeros((R,)) + ss = np.zeros((R,)) + tran_tp_error = np.ones((R,)) + rot_tp_error = np.ones((R,)) + rot_sym_tp_error = np.ones((R,)) + rot_canonical_tp_error = np.ones((R,)) + scale_tp_error = np.ones((R,)) + + if nd: + recall[t, k, a, m] = rc[-1] + + else: + recall[t, k, a, m] = 0 + + # numpy is slow without cython optimization for accessing elements + # use python array gets significant speed improvement + pr = pr.tolist() + q = q.tolist() + tran_tp_error = tran_tp_error.tolist() + rot_tp_error = rot_tp_error.tolist() + rot_sym_tp_error = rot_sym_tp_error.tolist() + rot_canonical_tp_error = rot_canonical_tp_error.tolist() + scale_tp_error = scale_tp_error.tolist() + + for i in range(nd - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + + inds = np.searchsorted(rc, p.recThrs, side="left") + + try: + for ri, pi in enumerate(inds): + q[ri] = pr[pi] + ss[ri] = dtScoresSorted[pi] + # Store errors for both bbox and dist modes + tran_tp_error[ri] = tems[t][pi] + rot_tp_error[ri] = oems[t][pi] + rot_sym_tp_error[ri] = oems_sym[t][pi] + rot_canonical_tp_error[ri] = oems_canonical[t][pi] + scale_tp_error[ri] = sems[t][pi] + except: + pass + + precision[t, :, k, a, m] = np.array(q) + scores[t, :, k, a, m] = np.array(ss) + + # Store errors for both bbox and dist modes + trans_tp_errors[t, :, k, a, m] = np.array( + tran_tp_error + ) + rot_tp_errors[t, :, k, a, m] = np.array( + rot_tp_error + ) + rot_sym_tp_errors[t, :, k, a, m] = np.array( + rot_sym_tp_error + ) + rot_canonical_tp_errors[t, :, k, a, m] = np.array( + rot_canonical_tp_error + ) + scale_tp_errors[t, :, k, a, m] = np.array( + scale_tp_error + ) + + self.evals_per_cat_area = evals_per_cat_area + + self.eval = { + "params": p, + "counts": [T, R, K, A, M], + "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "precision": precision, + "recall": recall, + "scores": scores, + "trans_tp_errors": trans_tp_errors, + "rot_tp_errors": rot_tp_errors, + "rot_sym_tp_errors": rot_sym_tp_errors, + "rot_canonical_tp_errors": rot_canonical_tp_errors, + "scale_tp_errors": scale_tp_errors, + } + + toc = time.time() + print("DONE (t={:0.2f}s).".format(toc - tic)) + + def evaluate(self) -> None: + """Run per image evaluation on given images. + + It will store results (a list of dict) in self.evalImgs + """ + print("Running per image evaluation...") + + p = self.params + print(f"Evaluate annotation type *{p.iouType}*") + + tic = time.time() + + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + + catIds = p.catIds if p.useCats else [-1] + + # loop through images, area range, max detection number + self.ious = { + (imgId, catId): self.computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds + } + + maxDet = p.maxDets[-1] + + self.evalImgs = [ + self.evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + + self._paramsEval = copy.deepcopy(self.params) + + toc = time.time() + print("DONE (t={:0.2f}s).".format(toc - tic)) + + def computeIoU(self, imgId, catId) -> tuple[NDArrayF32, NDArrayF32]: + """Computes the IoUs by sorting based on score""" + p = self.params + + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + + if len(gt) == 0 and len(dt) == 0: + return [] + + inds = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in inds] + if len(dt) > p.maxDets[-1]: + dt = dt[0 : p.maxDets[-1]] + + if self.mode == "2D": + g = [g["bbox"] for g in gt] + d = [d["bbox"] for d in dt] + elif self.mode == "3D": + g = [g["bbox3D"] for g in gt] + d = [d["bbox3D"] for d in dt] + + # compute iou between each dt and gt region + # iscrowd is required in builtin maskUtils so we + # use a dummy buffer for it + iscrowd = [0 for _ in gt] + if self.mode == "2D": + ious = maskUtils.iou(d, g, iscrowd) + elif len(d) > 0 and len(g) > 0: + if p.iouType == "bbox": + dd = torch.tensor(d, dtype=torch.float32) + gg = torch.tensor(g, dtype=torch.float32) + + ious = box3d_overlap(dd, gg).cpu().numpy() + else: + ious = np.zeros((len(d), len(g))) + + dd = [d["center_cam"] for d in dt] + gg = [g["center_cam"] for g in gt] + + ious = cdist(dd, gg, metric="euclidean") + else: + ious = [] + + in_prox = None + + if self.eval_prox: + g = [g["bbox"] for g in gt] + d = [d["bbox"] for d in dt] + iscrowd = [0 for o in gt] + ious2d = maskUtils.iou(d, g, iscrowd) + + if type(ious2d) == list: + in_prox = [] + + else: + in_prox = ious2d > p.proximity_thresh + + return ious, in_prox + + def evaluateImg(self, imgId, catId, aRng, maxDet): + """ + Perform evaluation for single category and image + Returns: + dict (single image results) + """ + + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + + if len(gt) == 0 and len(dt) == 0: + return None + + flag_range = "area" if self.mode == "2D" else "depth" + flag_ignore = "ignore2D" if self.mode == "2D" else "ignore3D" + + for g in gt: + if g[flag_ignore] or ( + g[flag_range] < aRng[0] or g[flag_range] > aRng[1] + ): + g["_ignore"] = 1 + else: + g["_ignore"] = 0 + + # sort dt highest score first, sort gt ignore last + gtind = np.argsort([g["_ignore"] for g in gt], kind="mergesort") + gt = [gt[i] for i in gtind] + dtind = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in dtind[0:maxDet]] + + # load computed ious + ious = ( + self.ious[imgId, catId][0][:, gtind] + if len(self.ious[imgId, catId][0]) > 0 + else self.ious[imgId, catId][0] + ) + + if self.eval_prox: + in_prox = ( + self.ious[imgId, catId][1][:, gtind] + if len(self.ious[imgId, catId][1]) > 0 + else self.ious[imgId, catId][1] + ) + + T = len(p.iouThrs) + G = len(gt) + D = len(dt) + gtm = np.zeros((T, G)) + dtm = np.zeros((T, D)) + tem = np.ones((T, D)) # Translation Error + sem = np.ones((T, D)) # Scale Error + oem = np.ones((T, D)) # Oritentation Error + oem_sym = np.ones((T, D)) # Symmetric Orientation Error (mod 180) + oem_canonical = np.ones((T, D)) # Canonical Orientation Error + gtIg = np.array([g["_ignore"] for g in gt]) + dtIg = np.zeros((T, D)) + + dist_thres = 1 + if not len(ious) == 0: + for tind, t in enumerate(p.iouThrs): + for dind, d in enumerate(dt): + + # information about best match so far (m=-1 -> unmatched) + iou = min([t, 1 - 1e-10]) + m = -1 + + for gind, g in enumerate(gt): + # in case of proximity evaluation, if not in proximity continue + if self.eval_prox and not in_prox[dind, gind]: + continue + + # if this gt already matched, continue + if gtm[tind, gind] > 0: + continue + + # if dt matched to reg gt, and on ignore gt, stop + if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1: + break + + # continue to next gt unless better match made + if p.iouType == "bbox" and ious[dind, gind] < iou: + continue + + if p.iouType == "dist": + # Compute Object Radius + gt_obj_radius = ( + np.linalg.norm(np.array(g["dimensions"])) / 2 + ) + if ious[dind, gind] > gt_obj_radius * iou: + continue + else: + dist_thres = gt_obj_radius * iou + + # if match successful and best so far, store appropriately + iou = ious[dind, gind] + m = gind + + # if match made store id of match for both dt and gt + if m == -1: + continue + + dtIg[tind, dind] = gtIg[m] + dtm[tind, dind] = gt[m]["id"] + gtm[tind, m] = d["id"] + + # Compute errors for both bbox and dist modes + # (previously only computed for dist mode) + + # Compute GT object radius for normalization + gt_obj_radius = ( + np.linalg.norm(np.array(gt[m]["dimensions"])) / 2 + ) + + # Translation Error + if p.iouType == "dist": + # For dist mode: normalize by distance threshold + # (dist_thres was computed during matching) + tem[tind, dind] = np.linalg.norm( + np.array(d["center_cam"]) + - np.array(gt[m]["center_cam"]) + ) / (dist_thres) + else: + # For bbox mode: normalize by distance threshold + # (same as dist mode, for consistency) + dist_thres_bbox = gt_obj_radius * t + tem[tind, dind] = np.linalg.norm( + np.array(d["center_cam"]) + - np.array(gt[m]["center_cam"]) + ) / (dist_thres_bbox + 1e-6) + + # Orientation Error (same for both modes) + try: + angle = so3_relative_angle( + torch.tensor(d["R_cam"])[None], + torch.tensor(gt[m]["R_cam"])[None], + cos_bound=1e-2, + eps=1e-3, + ).item() + oem[tind, dind] = angle / np.pi + # Symmetric: fold 180 ambiguity, min(angle, pi-angle) + # range [0, pi/2], normalized by pi/2 to [0, 1] + oem_sym[tind, dind] = min(angle, np.pi - angle) / (np.pi / 2) + + # Canonical: normalize both to canonical form + # (W<=L + yaw [0,pi)) before computing angle + R_pred_c = _canonicalize_rotation_np( + d["R_cam"], d["dimensions"] + ) + R_gt_c = _canonicalize_rotation_np( + gt[m]["R_cam"], gt[m]["dimensions"] + ) + angle_c = so3_relative_angle( + torch.tensor(R_pred_c)[None], + torch.tensor(R_gt_c)[None], + cos_bound=1e-2, + eps=1e-3, + ).item() + oem_canonical[tind, dind] = angle_c / np.pi + except ValueError as e: + # Skip invalid rotation matrix pairs + # This can happen when GT or prediction has numerical precision issues + import warnings + R_pred = np.array(d["R_cam"]) + R_gt = np.array(gt[m]["R_cam"]) + R_rel = R_pred @ R_gt.T + warnings.warn( + f"Skipping rotation error for img={imgId}, cat={catId}: {e}\n" + f" det(R_pred)={np.linalg.det(R_pred):.6f}, " + f"det(R_gt)={np.linalg.det(R_gt):.6f}, " + f"trace(R_rel)={np.trace(R_rel):.6f}" + ) + # Set to maximum error (180 degrees = 1.0 in normalized units) + oem[tind, dind] = 1.0 + oem_sym[tind, dind] = 1.0 + oem_canonical[tind, dind] = 1.0 + + # Scale Error (same for both modes) + min_whl = np.minimum( + d["dimensions"], gt[m]["dimensions"] + ) + volume_annotation = np.prod(gt[m]["dimensions"]) + volume_result = np.prod(d["dimensions"]) + + intersection = np.prod(min_whl) + union = ( + volume_annotation + volume_result - intersection + ) + scale_iou = intersection / union + + sem[tind, dind] = 1 - scale_iou + + # set unmatched detections outside of area range to ignore + a = np.array( + [d[flag_range] < aRng[0] or d[flag_range] > aRng[1] for d in dt] + ).reshape((1, len(dt))) + + dtIg = np.logical_or( + dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0)) + ) + + # in case of proximity evaluation, ignore detections which are far from gt regions + if self.eval_prox and len(in_prox) > 0: + dt_far = in_prox.any(1) == 0 + dtIg = np.logical_or( + dtIg, np.repeat(dt_far.reshape((1, len(dt))), T, 0) + ) + + # store results for given image and category + return { + "image_id": imgId, + "category_id": catId, + "aRng": aRng, + "maxDet": maxDet, + "dtIds": [d["id"] for d in dt], + "gtIds": [g["id"] for g in gt], + "dtMatches": dtm, + "gtMatches": gtm, + "dtScores": [d["score"] for d in dt], + "gtIgnore": gtIg, + "dtIgnore": dtIg, + "dtTranslationError": tem, + "dtScaleError": sem, + "dtOrientationError": oem, + "dtOrientationErrorSym": oem_sym, + "dtOrientationErrorCanonical": oem_canonical, + } + + def summarize(self): + """ + Compute and display summary metrics for evaluation results. + Note this functin can *only* be applied on the default parameter setting + """ + + def _summarize( + mode, ap=1, iouThr=None, areaRng="all", maxDets=100, log_str="" + ): + p = self.params + eval = self.eval + + if mode == "2D": + if self.iou_type == "bbox": + iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" + else: + iStr = " {:<18} {} @[ Dist={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" + + elif mode == "3D": + if self.iou_type == "bbox": + iStr = " {:<18} {} @[ IoU={:<9} | depth={:>6s} | maxDets={:>3d} ] = {:0.3f}" + else: + iStr = " {:<18} {} @[ Dist={:<9} | depth={:>6s} | maxDets={:>3d} ] = {:0.3f}" + + titleStr = "Average Precision" if ap == 1 else "Average Recall" + typeStr = "(AP)" if ap == 1 else "(AR)" + + iouStr = ( + "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1]) + if iouThr is None + else "{:0.2f}".format(iouThr) + ) + + aind = [ + i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng + ] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + + if ap == 1: + + # dimension of precision: [TxRxKxAxM] + s = eval["precision"] + + # IoU + if iouThr is not None: + t = np.where(np.isclose(iouThr, p.iouThrs.astype(float)))[ + 0 + ] + s = s[t] + + s = s[:, :, :, aind, mind] + + else: + # dimension of recall: [TxKxAxM] + s = eval["recall"] + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, aind, mind] + + if len(s[s > -1]) == 0: + mean_s = -1 + + else: + mean_s = np.mean(s[s > -1]) + + if log_str != "": + log_str += "\n" + + log_str += "mode={} ".format(mode) + iStr.format( + titleStr, typeStr, iouStr, areaRng, maxDets, mean_s + ) + + return mean_s, log_str + + def _summarizeDets(mode): + + params = self.params + + # Define the thresholds to be printed + if mode == "2D": + thres = [0.5, 0.75, 0.95] + else: + if self.iou_type == "bbox": + thres = [0.15, 0.25, 0.50] + else: + thres = [0.5, 0.75, 1.0] + + stats = np.zeros((13,)) + stats[0], log_str = _summarize(mode, 1) + + stats[1], log_str = _summarize( + mode, + 1, + iouThr=thres[0], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + stats[2], log_str = _summarize( + mode, + 1, + iouThr=thres[1], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + stats[3], log_str = _summarize( + mode, + 1, + iouThr=thres[2], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + stats[4], log_str = _summarize( + mode, + 1, + areaRng=params.areaRngLbl[1], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + stats[5], log_str = _summarize( + mode, + 1, + areaRng=params.areaRngLbl[2], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + stats[6], log_str = _summarize( + mode, + 1, + areaRng=params.areaRngLbl[3], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + stats[7], log_str = _summarize( + mode, 0, maxDets=params.maxDets[0], log_str=log_str + ) + + stats[8], log_str = _summarize( + mode, 0, maxDets=params.maxDets[1], log_str=log_str + ) + + stats[9], log_str = _summarize( + mode, 0, maxDets=params.maxDets[2], log_str=log_str + ) + + stats[10], log_str = _summarize( + mode, + 0, + areaRng=params.areaRngLbl[1], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + stats[11], log_str = _summarize( + mode, + 0, + areaRng=params.areaRngLbl[2], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + stats[12], log_str = _summarize( + mode, + 0, + areaRng=params.areaRngLbl[3], + maxDets=params.maxDets[2], + log_str=log_str, + ) + + return stats, log_str + + if not self.eval: + raise Exception("Please run accumulate() first") + + stats, log_str = _summarizeDets(self.mode) + self.stats = stats + + return log_str + + +class Detect3DParams: + """Params for the 3d detection evaluation API.""" + + def __init__( + self, + mode: str = "2D", + iouType: str = "bbox", + proximity_thresh: float = 0.3, + ) -> None: + """Create an instance of Detect3DParams. + + Args: + mode: (str) defines whether to evaluate 2D or 3D performance. + iouType: (str) defines the type of IoU to be used for evaluation. + proximity_thresh (float): It defines the neighborhood when + evaluating on non-exhaustively annotated datasets. + """ + assert iouType in {"bbox", "dist"}, f"Invalid iouType {iouType}." + self.iouType = iouType + + if mode == "2D": + self.setDet2DParams() + elif mode == "3D": + self.setDet3DParams() + else: + raise Exception(f"{mode} mode is not supported") + self.mode = mode + self.proximity_thresh = proximity_thresh + + def setDet2DParams(self) -> None: + """Set parameters for 2D detection evaluation.""" + self.imgIds = [] + self.catIds = [] + + # np.arange causes trouble. the data point on arange is slightly larger than the true value + self.iouThrs = np.linspace( + 0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True + ) + + self.recThrs = np.linspace( + 0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True + ) + + self.maxDets = [1, 10, 100] + self.areaRng = [ + [0**2, 1e5**2], + [0**2, 32**2], + [32**2, 96**2], + [96**2, 1e5**2], + ] + + self.areaRngLbl = ["all", "small", "medium", "large"] + self.useCats = 1 + + def setDet3DParams(self) -> None: + """Set parameters for 3D detection evaluation.""" + self.imgIds = [] + self.catIds = [] + + # np.arange causes trouble. The data point on arange is slightly + # larger than the true value + if self.iouType == "bbox": + self.iouThrs = np.linspace( + 0.05, + 0.5, + int(np.round((0.5 - 0.05) / 0.05)) + 1, + endpoint=True, + ) + else: + self.iouThrs = np.linspace( + 0.5, 1.0, int(np.round((1.00 - 0.5) / 0.05)) + 1, endpoint=True + ) + + self.recThrs = np.linspace( + 0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True + ) + + self.maxDets = [1, 10, 100] + self.areaRng = [[0, 1e5], [0, 10], [10, 35], [35, 1e5]] + self.areaRngLbl = ["all", "near", "medium", "far"] + self.useCats = 1 diff --git a/wilddet3d/eval/omni3d.py b/wilddet3d/eval/omni3d.py new file mode 100644 index 0000000000000000000000000000000000000000..77fc2805c8faeb02dc7ced5cd636f461a7062358 --- /dev/null +++ b/wilddet3d/eval/omni3d.py @@ -0,0 +1,378 @@ +"""Omni3D 3D detection evaluation.""" + +import contextlib +import copy +import io +import itertools +import os +from collections.abc import Sequence + +import numpy as np +from terminaltables import AsciiTable +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber +from vis4d.eval.base import Evaluator + +from wilddet3d.data.datasets.omni3d.omni3d_classes import omni3d_class_map +from wilddet3d.data.datasets.omni3d.util import get_dataset_det_map + +from .detect3d import Detect3Deval, Detect3DEvaluator + +omni3d_in = { + "stationery", + "sink", + "table", + "floor mat", + "bottle", + "bookcase", + "bin", + "blinds", + "pillow", + "bicycle", + "refrigerator", + "night stand", + "chair", + "sofa", + "books", + "oven", + "towel", + "cabinet", + "window", + "curtain", + "bathtub", + "laptop", + "desk", + "television", + "clothes", + "stove", + "cup", + "shelves", + "box", + "shoes", + "mirror", + "door", + "picture", + "lamp", + "machine", + "counter", + "bed", + "toilet", +} + +omni3d_out = { + "cyclist", + "pedestrian", + "trailer", + "bus", + "motorcycle", + "car", + "barrier", + "truck", + "van", + "traffic cone", + "bicycle", +} + + +class Omni3DEvaluator(Evaluator): + """Omni3D 3D detection evaluator.""" + + def __init__( + self, + data_root: str = "data/omni3d", + omni3d50: bool = True, + datasets: Sequence[str] = ( + "KITTI_test", + "nuScenes_test", + "SUNRGBD_test", + "Hypersim_test", + "ARKitScenes_test", + "Objectron_test", + ), + per_class_eval: bool = True, + # APRel3D parameters (LabelAny3D-style) + enable_aprel3d: bool = False, + aprel_2d_iou_thresh: float = 0.75, + # Mini dataset support + use_mini_dataset: bool = False, + ) -> None: + """Initialize the evaluator. + + Args: + data_root: Root directory for Omni3D data. + omni3d50: Whether to use Omni3D-50 class mapping. + datasets: List of dataset names to evaluate. + per_class_eval: Whether to evaluate per-class metrics. + enable_aprel3d: Whether to enable APRel3D evaluation. + aprel_2d_iou_thresh: 2D IoU threshold for matching (default 0.75). + use_mini_dataset: If True, use annotations_mini100/ for GT. + """ + super().__init__() + self.id_to_name = {v: k for k, v in omni3d_class_map.items()} + self.dataset_names = datasets + self.per_class_eval = per_class_eval + self.enable_aprel3d = enable_aprel3d + self.aprel_2d_iou_thresh = aprel_2d_iou_thresh + self.use_mini_dataset = use_mini_dataset + + # Each dataset evaluator is stored here + self.evaluators: dict[str, Detect3DEvaluator] = {} + + # These store the evaluations for each category and area, + # concatenated from ALL evaluated datasets. Doing so avoids + # the need to re-compute them when accumulating results. + self.evals_per_cat_area2D = {} + self.evals_per_cat_area3D = {} + + self.overall_imgIds = set() + self.overall_catIds = set() + + # Determine annotation directory based on mini dataset flag + if use_mini_dataset: + annotation_dir = os.path.join(data_root, "annotations_mini100") + else: + annotation_dir = os.path.join(data_root, "annotations") + + for dataset_name in self.dataset_names: + annotation = os.path.join( + annotation_dir, f"{dataset_name}.json" + ) + + det_map = get_dataset_det_map( + dataset_name=dataset_name, omni3d50=omni3d50 + ) + + # create an individual dataset evaluator + self.evaluators[dataset_name] = Detect3DEvaluator( + det_map, + cat_map=omni3d_class_map, + annotation=annotation, + eval_prox=( + "Objectron" in dataset_name or "SUNRGBD" in dataset_name + ), + enable_aprel3d=enable_aprel3d, + aprel_2d_iou_thresh=aprel_2d_iou_thresh, + ) + + self.overall_imgIds.update( + set(self.evaluators[dataset_name]._coco_gt.getImgIds()) + ) + self.overall_catIds.update( + set(self.evaluators[dataset_name]._coco_gt.getCatIds()) + ) + + def __repr__(self) -> str: + """Returns the string representation of the object.""" + datasets_str = ", ".join(self.dataset_names) + return f"Omni3DEvaluator ({datasets_str})" + + @property + def metrics(self) -> list[str]: + """Supported metrics. + + Returns: + list[str]: Metrics to evaluate. + """ + return ["2D", "3D"] + + def reset(self) -> None: + """Reset the saved predictions to start new round of evaluation.""" + for dataset_name in self.dataset_names: + self.evaluators[dataset_name].reset() + self.evals_per_cat_area2D.clear() + self.evals_per_cat_area3D.clear() + + def gather(self, gather_func: GenericFunc) -> None: + """Accumulate predictions across processes.""" + for dataset_name in self.dataset_names: + self.evaluators[dataset_name].gather(gather_func) + + def process_batch( + self, + coco_image_id: list[int], + dataset_names: list[str], + pred_boxes: list[NDArrayNumber], + pred_scores: list[NDArrayNumber], + pred_classes: list[NDArrayNumber], + pred_boxes3d: list[NDArrayNumber] | None = None, + ) -> None: + """Process sample and convert detections to coco format.""" + # Handle empty batch (can happen when all images have 0 GT boxes) + if dataset_names is None or len(dataset_names) == 0: + return + for i, dataset_name in enumerate(dataset_names): + self.evaluators[dataset_name].process_batch( + [coco_image_id[i]], + [pred_boxes[i]], + [pred_scores[i]], + [pred_classes[i]], + pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None, + ) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate predictions and return the results.""" + assert metric in self.metrics, f"Unsupported metric: {metric}" + + log_dict = {} + per_dataset_results = {} # Store results for later aggregation + + for dataset_name in self.dataset_names: + rank_zero_info(f"Evaluating {dataset_name}...") + per_dataset_log_dict, dataset_log_str = self.evaluators[ + dataset_name + ].evaluate(metric) + + per_dataset_results[dataset_name] = per_dataset_log_dict + + # Get the main metric key (APRel3D/APRel in APRel mode, AP otherwise) + # Priority: APRel3D > APRel > AP + if "APRel3D" in per_dataset_log_dict: + main_metric_key = "APRel3D" + elif "APRel" in per_dataset_log_dict: + main_metric_key = "APRel" + elif "AP" in per_dataset_log_dict: + main_metric_key = "AP" + else: + # Fallback: use the first key that starts with "AP" + main_metric_key = next((k for k in per_dataset_log_dict.keys() if k.startswith("AP")), "AP") + + log_dict[f"AP_{dataset_name}"] = per_dataset_log_dict[main_metric_key] + + rank_zero_info(dataset_log_str + "\n") + + # store the partially accumulated evaluations per category per area + if metric == "2D": + for key, item in self.evaluators[ + dataset_name + ].bbox_2D_evals_per_cat_area.items(): + if not key in self.evals_per_cat_area2D: + self.evals_per_cat_area2D[key] = [] + self.evals_per_cat_area2D[key] += item + else: + for key, item in self.evaluators[ + dataset_name + ].bbox_3D_evals_per_cat_area.items(): + if not key in self.evals_per_cat_area3D: + self.evals_per_cat_area3D[key] = [] + self.evals_per_cat_area3D[key] += item + + results_per_category_dict = {} + results_per_category = [] + + rank_zero_info(f"Evaluating Omni3D for {metric} Detection...") + + evaluator = Detect3Deval(mode=metric) + evaluator.params.catIds = list(self.overall_catIds) + evaluator.params.imgIds = list(self.overall_imgIds) + evaluator.evalImgs = True + + if metric == "2D": + evaluator.evals_per_cat_area = self.evals_per_cat_area2D + metrics = ["AP", "AP50", "AP75", "AP95", "APs", "APm", "APl"] + else: + evaluator.evals_per_cat_area = self.evals_per_cat_area3D + if self.enable_aprel3d: + metrics = ["APRel3D", "APRel15", "APRel25", "APRel50", "APReln", "APRelm", "APRelf"] + else: + metrics = ["AP", "AP15", "AP25", "AP50", "APn", "APm", "APf"] + + evaluator._paramsEval = copy.deepcopy(evaluator.params) + + with contextlib.redirect_stdout(io.StringIO()): + evaluator.accumulate() + log_str = "\n" + evaluator.summarize() + + log_dict.update(dict(zip(metrics, evaluator.stats))) + + # Add error metrics (aggregate from all datasets) + # Note: In bbox mode, only ASE and AOE are returned (no ATE) + # In dist mode, ATE, ASE, and AOE are all returned + if metric == "3D": + # Collect error metrics from all datasets + all_ase = [] + all_aoe = [] + all_aoe_sym = [] + all_aoe_canonical = [] + all_ods_sym = [] + all_ods_canonical = [] + + # Determine which keys to look for based on mode + if self.enable_aprel3d: + ase_key, aoe_key, aoe_sym_key = "ASERel", "AOERel", "AOERelSym" + aoe_canonical_key = "AOERelCanonical" + ods_sym_key = "ODSRelSym" + ods_canonical_key = "ODSRelCanonical" + else: + ase_key, aoe_key, aoe_sym_key = "ASE", "AOE", "AOE_Sym" + aoe_canonical_key = "AOE_Canonical" + ods_sym_key = "ODS_Sym" + ods_canonical_key = "ODS_Canonical" + + for dataset_name in self.dataset_names: + per_dataset_log_dict = per_dataset_results[dataset_name] + if ase_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[ase_key]): + all_ase.append(per_dataset_log_dict[ase_key]) + if aoe_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[aoe_key]): + all_aoe.append(per_dataset_log_dict[aoe_key]) + if aoe_sym_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[aoe_sym_key]): + all_aoe_sym.append(per_dataset_log_dict[aoe_sym_key]) + if aoe_canonical_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[aoe_canonical_key]): + all_aoe_canonical.append(per_dataset_log_dict[aoe_canonical_key]) + if ods_sym_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[ods_sym_key]): + all_ods_sym.append(per_dataset_log_dict[ods_sym_key]) + if ods_canonical_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[ods_canonical_key]): + all_ods_canonical.append(per_dataset_log_dict[ods_canonical_key]) + + log_dict[ase_key] = np.mean(all_ase) if len(all_ase) > 0 else float("nan") + log_dict[aoe_key] = np.mean(all_aoe) if len(all_aoe) > 0 else float("nan") + log_dict[aoe_sym_key] = np.mean(all_aoe_sym) if len(all_aoe_sym) > 0 else float("nan") + log_dict[aoe_canonical_key] = np.mean(all_aoe_canonical) if len(all_aoe_canonical) > 0 else float("nan") + log_dict[ods_sym_key] = np.mean(all_ods_sym) if len(all_ods_sym) > 0 else float("nan") + log_dict[ods_canonical_key] = np.mean(all_ods_canonical) if len(all_ods_canonical) > 0 else float("nan") + + if self.per_class_eval: + precisions = evaluator.eval["precision"] + for idx, cat_id in enumerate(self.overall_catIds): + cat_name = self.id_to_name[cat_id] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = float(np.mean(precision).item()) + else: + ap = float("nan") + + results_per_category_dict[cat_name] = ap + results_per_category.append((f"{cat_name}", f"{ap:0.3f}")) + + num_columns = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + headers = ["category", "AP"] * (num_columns // 2) + results_2d = itertools.zip_longest( + *[results_flatten[i::num_columns] for i in range(num_columns)] + ) + table_data = [headers] + list(results_2d) + table = AsciiTable(table_data) + log_str = f"\n{table.table}\n{log_str}" + + # Omni3D Outdoor performance + ap_out_lst = [] + for cat in omni3d_out: + ap_out_lst.append(results_per_category_dict.get(cat, 0.0)) + + log_dict["Omni3D_Out"] = np.mean(ap_out_lst).item() + + # Omni3D Indoor performance + ap_in_lst = [] + for cat in omni3d_in: + ap_in_lst.append(results_per_category_dict.get(cat, 0.0)) + + log_dict["Omni3D_In"] = np.mean(ap_in_lst).item() + + return log_dict, log_str + + def save(self, metric: str, output_dir: str) -> None: + """Save the results to json files.""" + for dataset_name in self.dataset_names: + self.evaluators[dataset_name].save( + metric, output_dir, prefix=dataset_name + ) diff --git a/wilddet3d/eval/open.py b/wilddet3d/eval/open.py new file mode 100644 index 0000000000000000000000000000000000000000..044ffb5e6691f57caac00df8d37bc40e37bd8629 --- /dev/null +++ b/wilddet3d/eval/open.py @@ -0,0 +1,143 @@ +"""Multi-data 3D detection evaluation.""" + +from collections.abc import Sequence + +from vis4d.common.logging import rank_zero_info +from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber +from vis4d.eval.base import Evaluator + +from .detect3d import Detect3DEvaluator +from .omni3d import Omni3DEvaluator + + +class OpenDetect3DEvaluator(Evaluator): + """Multi-data 3D detection evaluator.""" + + def __init__( + self, + datasets: Sequence[str], + evaluators: Sequence[Detect3DEvaluator], + omni3d_evaluator: Omni3DEvaluator | None = None, + ) -> None: + """Initialize the evaluator.""" + super().__init__() + self.dataset_names = datasets + self.evaluators = { + name: evaluator for name, evaluator in zip(datasets, evaluators) + } + + self.omni3d_evaluator = omni3d_evaluator + + def __repr__(self) -> str: + """Returns the string representation of the object.""" + datasets_str = ", ".join(self.dataset_names) + return f"Open 3D Object Detection Evaluator ({datasets_str})" + + @property + def metrics(self) -> list[str]: + """Supported metrics. + + Returns: + list[str]: Metrics to evaluate. + """ + return ["2D", "3D"] + + def reset(self) -> None: + """Reset the saved predictions to start new round of evaluation.""" + for dataset_name in self.dataset_names: + self.evaluators[dataset_name].reset() + + if self.omni3d_evaluator is not None: + self.omni3d_evaluator.reset() + + def gather(self, gather_func: GenericFunc) -> None: + """Accumulate predictions across processes.""" + for dataset_name in self.dataset_names: + self.evaluators[dataset_name].gather(gather_func) + + if self.omni3d_evaluator is not None: + self.omni3d_evaluator.gather(gather_func) + + def process_batch( + self, + coco_image_id: list[int], + dataset_names: list[str], + pred_boxes: list[NDArrayNumber], + pred_scores: list[NDArrayNumber], + pred_classes: list[NDArrayNumber], + pred_boxes3d: list[NDArrayNumber] | None = None, + ) -> None: + """Process sample and convert detections to coco format.""" + for i, dataset_name in enumerate(dataset_names): + if ( + self.omni3d_evaluator is not None + and dataset_name in self.omni3d_evaluator.dataset_names + ): + self.omni3d_evaluator.process_batch( + [coco_image_id[i]], + [dataset_name], + [pred_boxes[i]], + [pred_scores[i]], + [pred_classes[i]], + pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None, + ) + else: + self.evaluators[dataset_name].process_batch( + [coco_image_id[i]], + [pred_boxes[i]], + [pred_scores[i]], + [pred_classes[i]], + pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None, + ) + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + """Evaluate predictions and return the results.""" + assert metric in self.metrics, f"Unsupported metric: {metric}" + + log_dict = {} + log_str = "" + + if self.omni3d_evaluator is not None: + log_dict_omni3d, omni3d_log_str = self.omni3d_evaluator.evaluate( + metric + ) + + log_dict.update(log_dict_omni3d) + log_str += omni3d_log_str + + for dataset_name in self.dataset_names: + rank_zero_info(f"Evaluating {dataset_name}...") + per_dataset_log_dict, dataset_log_str = self.evaluators[ + dataset_name + ].evaluate(metric) + + if "ODS" in per_dataset_log_dict: + score = "ODS" + else: + score = "AP" + + log_dict[f"{score}_{dataset_name}"] = per_dataset_log_dict[score] + + if self.evaluators[dataset_name].base_classes is not None: + log_dict[f"{score}_Base_{dataset_name}"] = ( + per_dataset_log_dict[f"{score}_Base"] + ) + log_dict[f"{score}_Novel_{dataset_name}"] = ( + per_dataset_log_dict[f"{score}_Novel"] + ) + + log_str += f"\nCheck {dataset_name} results in log dict." + + rank_zero_info(dataset_log_str + "\n") + + return log_dict, log_str + + def save(self, metric: str, output_dir: str) -> None: + """Save the results to json files.""" + if self.omni3d_evaluator is not None: + self.omni3d_evaluator.save(metric, output_dir) + + for dataset_name in self.dataset_names: + self.evaluators[dataset_name].save( + metric, output_dir, prefix=dataset_name + ) diff --git a/wilddet3d/eval/postprocess_cache_export.py b/wilddet3d/eval/postprocess_cache_export.py new file mode 100644 index 0000000000000000000000000000000000000000..e63c586a191439e227a0bafa215c7ba57e17a043 --- /dev/null +++ b/wilddet3d/eval/postprocess_cache_export.py @@ -0,0 +1,185 @@ +"""Postprocess cache exporter (test-time). + +This evaluator is used with vis4d's EvaluatorCallback to export per-image caches +needed for depth-based 3D box post-processing, without changing the normal +evaluation flow. + +Cache layout: + {cache_root}/{dataset_name}/{image_id}.npz + +We intentionally store the full metric depth map (aligned to original_hw) to +avoid coordinate-system bugs from cropping. +""" + +from __future__ import annotations + +import os +from typing import Any + +import numpy as np +import torch +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber +from vis4d.eval.base import Evaluator + + +class PostprocessCacheExporter(Evaluator): + """Exports model outputs needed for post-processing into .npz cache files.""" + + def __init__( + self, + cache_root: str, + compress: bool = True, + overwrite: bool = False, + depth_dtype: str = "float32", + ) -> None: + super().__init__() + self.cache_root = cache_root + self.compress = compress + self.overwrite = overwrite + if depth_dtype not in {"float16", "float32"}: + raise ValueError(f"Unsupported depth_dtype: {depth_dtype}") + self.depth_dtype = depth_dtype + + self._num_written = 0 + self._num_skipped = 0 + + @property + def metrics(self) -> list[str]: + # Not a real evaluator; we only export. + return [] + + def reset(self) -> None: # pragma: no cover + self._num_written = 0 + self._num_skipped = 0 + + def gather(self, gather_func: GenericFunc) -> None: # pragma: no cover + # Nothing to gather; each rank writes its own files (safe because image_id is unique). + return + + def process_batch( + self, + coco_image_id: list[int], + dataset_names: list[str], + pred_boxes: list[NDArrayNumber], + pred_scores: list[NDArrayNumber], + pred_classes: list[NDArrayNumber], + pred_boxes3d: list[NDArrayNumber] | None = None, + pred_categories: list[list[str]] | None = None, + depth_maps: list[torch.Tensor] | None = None, + intrinsics: list[NDArrayNumber] | NDArrayNumber | None = None, + original_hw: list[tuple[int, int]] | None = None, + ) -> None: + """Write one .npz per image.""" + if pred_boxes3d is None: + # No 3D boxes -> nothing to export for depth alignment. + print("[PostprocessCacheExporter] Skipping: pred_boxes3d is None") + return + if depth_maps is None: + # Depth backend disabled -> nothing to export. + print("[PostprocessCacheExporter] Skipping: depth_maps is None") + return + if intrinsics is None: + print("[PostprocessCacheExporter] Skipping: intrinsics is None") + return + if original_hw is None: + print("[PostprocessCacheExporter] Skipping: original_hw is None") + return + + print(f"[PostprocessCacheExporter] Processing batch: {len(coco_image_id)} images") + + # Normalize intrinsics to per-sample list + if torch.is_tensor(intrinsics): + # intrinsics: Tensor [B, 3, 3] (may be on GPU) + intrinsics_np = intrinsics.detach().cpu().numpy() + intrinsics_list = [intrinsics_np[j] for j in range(intrinsics_np.shape[0])] + elif isinstance(intrinsics, np.ndarray): + # intrinsics: ndarray [3,3] or [B,3,3] + if intrinsics.ndim == 2: + intrinsics_list = [intrinsics for _ in range(len(coco_image_id))] + else: + intrinsics_list = [intrinsics[j] for j in range(intrinsics.shape[0])] + else: + # intrinsics: sequence of arrays/tensors + intrinsics_list = list(intrinsics) + + for i, image_id in enumerate(coco_image_id): + dataset_name = dataset_names[i] + out_dir = os.path.join(self.cache_root, str(dataset_name)) + os.makedirs(out_dir, exist_ok=True) + + out_path = os.path.join(out_dir, f"{int(image_id)}.npz") + if (not self.overwrite) and os.path.exists(out_path): + self._num_skipped += 1 + continue + + boxes2d = array_to_numpy( + pred_boxes[i].to(torch.float32) if hasattr(pred_boxes[i], "to") else pred_boxes[i], + n_dims=None, + dtype=np.float32, + ) + scores = array_to_numpy( + pred_scores[i].to(torch.float32) if hasattr(pred_scores[i], "to") else pred_scores[i], + n_dims=None, + dtype=np.float32, + ) + class_ids = array_to_numpy( + pred_classes[i].to(torch.int64) if hasattr(pred_classes[i], "to") else pred_classes[i], + n_dims=None, + dtype=np.int64, + ) + boxes3d = array_to_numpy( + pred_boxes3d[i].to(torch.float32) if hasattr(pred_boxes3d[i], "to") else pred_boxes3d[i], + n_dims=None, + dtype=np.float32, + ) + + # depth_maps is list[Tensor] where each Tensor is [H, W] or [1, H, W] + depth = depth_maps[i] + if depth.ndim == 3 and depth.shape[0] == 1: + depth = depth[0] + depth_np = depth.detach().cpu().numpy() + depth_np = depth_np.astype(np.float16 if self.depth_dtype == "float16" else np.float32) + + Ki = intrinsics_list[i] + if torch.is_tensor(Ki): + K = Ki.detach().cpu().numpy().astype(np.float32) + else: + K = np.asarray(Ki, dtype=np.float32) + hw = original_hw[i] + + meta: dict[str, Any] = { + "dataset_name": str(dataset_name), + "image_id": int(image_id), + "original_hw": np.asarray(hw, dtype=np.int32), + } + + # Categories are variable-length strings; store as object array. + if pred_categories is not None and i < len(pred_categories) and pred_categories[i] is not None: + cats = np.asarray(pred_categories[i], dtype=object) + else: + cats = np.asarray([], dtype=object) + + save_fn = np.savez_compressed if self.compress else np.savez + save_fn( + out_path, + boxes2d=boxes2d, + scores=scores, + class_ids=class_ids, + boxes3d_raw=boxes3d, + categories=cats, + depth_map=depth_np, + intrinsics=K, + meta=np.asarray(meta, dtype=object), + ) + self._num_written += 1 + + def evaluate(self, metric: str) -> tuple[MetricLogs, str]: + # No evaluation; return empty. + return {}, f"PostprocessCacheExporter: wrote={self._num_written}, skipped={self._num_skipped}" + + def save(self, metric: str, output_dir: str, prefix: str | None = None) -> None: # pragma: no cover + # Nothing to save beyond the cache files. + return + + diff --git a/wilddet3d/head/__init__.py b/wilddet3d/head/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c8863b9706e3dfafb6f584f1e74090c2a6fed9 --- /dev/null +++ b/wilddet3d/head/__init__.py @@ -0,0 +1,12 @@ +"""3D detection head.""" + +from .coder_3d import Det3DCoder +from .depth_cross_attn import DepthCrossAttention +from .head_3d import Det3DHead, RoI2Det3D + +__all__ = [ + "Det3DHead", + "RoI2Det3D", + "Det3DCoder", + "DepthCrossAttention", +] diff --git a/wilddet3d/head/__pycache__/__init__.cpython-311.pyc b/wilddet3d/head/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2ddb486972490a79f081628dcaf1b21dbc1bb22 Binary files /dev/null and b/wilddet3d/head/__pycache__/__init__.cpython-311.pyc differ diff --git a/wilddet3d/head/__pycache__/coder_3d.cpython-311.pyc b/wilddet3d/head/__pycache__/coder_3d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4f100addb5accc9a72fef27dffcaa0c69119060 Binary files /dev/null and b/wilddet3d/head/__pycache__/coder_3d.cpython-311.pyc differ diff --git a/wilddet3d/head/__pycache__/depth_cross_attn.cpython-311.pyc b/wilddet3d/head/__pycache__/depth_cross_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e8db85fe10cd9ab695ef961548419bcb9b3a00f Binary files /dev/null and b/wilddet3d/head/__pycache__/depth_cross_attn.cpython-311.pyc differ diff --git a/wilddet3d/head/__pycache__/head_3d.cpython-311.pyc b/wilddet3d/head/__pycache__/head_3d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38351d955d44e0b88c747644984f81614f3e4617 Binary files /dev/null and b/wilddet3d/head/__pycache__/head_3d.cpython-311.pyc differ diff --git a/wilddet3d/head/coder_3d.py b/wilddet3d/head/coder_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..bc59dd319ef1af28753e870e0449986fd55bc8e0 --- /dev/null +++ b/wilddet3d/head/coder_3d.py @@ -0,0 +1,263 @@ +"""3D bounding box encoder.""" + +from __future__ import annotations + +import torch +from torch import Tensor +from vis4d.data.const import AxisMode +from vis4d.op.geometry.projection import project_points, unproject_points +from vis4d.op.geometry.rotation import ( + euler_angles_to_matrix, + matrix_to_quaternion, + quaternion_to_matrix, + rotation_matrix_yaw, +) + +from wilddet3d.ops.rotation import ( + matrix_to_rotation_6d, + rotation_6d_to_matrix, +) + + +def _normalize_rotation_half(poses: Tensor) -> Tensor: + """Normalize rotation matrices to [0, pi) yaw range. + + For objects with 180-degree rotational ambiguity (e.g. tables, chairs), + this folds yaw into [0, pi) so that 90 and 270 map to the same target. + Also handles boundary: 180 and 0 map to the same target. + + Uses Y-axis rotation (OPENCV convention) to detect and flip. + """ + import math + + yaw = rotation_matrix_yaw( + poses, axis_mode=AxisMode.OPENCV + )[:, 1] # [N] + # Flip by 180 around Y-axis: Ry(pi) = diag(-1, 1, -1) + # yaw in [-pi, 0) or yaw ~= pi -> flip to [0, pi) + flip_mask = (yaw < 0) | (yaw > math.pi - 1e-4) + poses_out = poses.clone() + # R_new = R @ Ry(pi), Ry(pi) negates columns 0 and 2 + poses_out[flip_mask, :, 0] = -poses[flip_mask, :, 0] + poses_out[flip_mask, :, 2] = -poses[flip_mask, :, 2] + return poses_out + + +def _normalize_canonical( + poses: Tensor, dims: Tensor, +) -> tuple[Tensor, Tensor]: + """Normalize rotation and dimensions to canonical form. + + Eliminates OBB rotation ambiguity via 2 steps: + + Step 1 - Force W <= L: + If W > L, swap W and L, then apply Ry(90 deg) to rotation. + boxes3d dims = [W, L, H]. Canonical: X=L, Z=W, so swapping + W<->L requires rotating 90 deg around Y to keep the box + geometry identical. + Ry(90): new_col0 = old_col2, new_col2 = -old_col0 + + Step 2 - Normalize yaw to [0, pi): + Same as _normalize_rotation_half. Apply Ry(180 deg) if yaw < 0 + or yaw >= pi. + + Together these reduce 4-fold Ry ambiguity to 1-fold. + (Rx(180) upside-down ambiguity is left to data preprocessing.) + + Args: + poses: Rotation matrices [N, 3, 3]. + dims: Dimensions [N, 3] as [W, L, H]. + + Returns: + poses_out: Normalized rotation matrices [N, 3, 3]. + dims_out: Normalized dimensions [N, 3] with W <= L. + """ + import math + + poses_out = poses.clone() + dims_out = dims.clone() + + # Step 1: Force W <= L + # dims = [W, L, H], indices 0, 1, 2 + swap_mask = dims_out[:, 0] > dims_out[:, 1] # W > L + if swap_mask.any(): + # Swap W and L + w_old = dims_out[swap_mask, 0].clone() + dims_out[swap_mask, 0] = dims_out[swap_mask, 1] + dims_out[swap_mask, 1] = w_old + + # Apply Ry(90 deg): R_new = R @ Ry(90) + # Ry(90) = [[0,0,1],[0,1,0],[-1,0,0]] + # col0_new = R @ [0,0,-1]^T = -col2 + # col1_new = R @ [0,1,0]^T = col1 (unchanged) + # col2_new = R @ [1,0,0]^T = col0 + col0 = poses_out[swap_mask, :, 0].clone() + col2 = poses_out[swap_mask, :, 2].clone() + poses_out[swap_mask, :, 0] = -col2 + poses_out[swap_mask, :, 2] = col0 + + # Step 2: Normalize yaw to [0, pi) + yaw = rotation_matrix_yaw( + poses_out, axis_mode=AxisMode.OPENCV + )[:, 1] # [N] + flip_mask = (yaw < 0) | (yaw > math.pi - 1e-4) + if flip_mask.any(): + # R_new = R @ Ry(pi), negates columns 0 and 2 + poses_out[flip_mask, :, 0] = -poses_out[flip_mask, :, 0] + poses_out[flip_mask, :, 2] = -poses_out[flip_mask, :, 2] + + return poses_out, dims_out + + +class Det3DCoder: + """3D box coder for encoding/decoding 3D bounding boxes.""" + + def __init__( + self, + center_scale: float = 10.0, + depth_scale: float = 2.0, + dim_scale: float = 2.0, + orientation: str = "rotation_6d", + ambiguous_rotation: bool = False, + canonical_rotation: bool = False, + ) -> None: + """Initialize the 3D box coder.""" + self.center_scale = center_scale + self.depth_scale = depth_scale + self.dim_scale = dim_scale + self.ambiguous_rotation = ambiguous_rotation + self.canonical_rotation = canonical_rotation + if canonical_rotation: + print( + "[Det3DCoder] canonical_rotation=True: " + "dims normalized to W<=L, yaw to [0, 180)" + ) + elif ambiguous_rotation: + print( + "[Det3DCoder] ambiguous_rotation=True: " + "GT rotation normalized to [0, 180) yaw range" + ) + + assert orientation in { + "yaw", + "rotation_6d", + }, f"Invalid orientation {orientation}." + self.orientation = orientation + + if orientation == "yaw": + reg_dims = 8 + elif orientation == "rotation_6d": + reg_dims = 12 + + self.reg_dims = reg_dims + + def encode( + self, boxes: Tensor, boxes3d: Tensor, intrinsics: Tensor, + ) -> tuple[Tensor, Tensor]: + """Encode the 3D bounding boxes. + + Args: + boxes: 2D boxes in PIXEL xyxy format. Shape (N, 4). + IMPORTANT: Should be GT 2D boxes during training (not predictions!) + This ensures stable targets. At inference, decode() uses pred boxes. + boxes3d: GT 3D boxes [center_3d(3), dims(3), quat(4)]. Shape (N, 10). + intrinsics: Camera intrinsics. Shape (3, 3) or (N, 3, 3). + + Returns: + boxes3d_target: Encoded targets [delta_2d(2), log_depth(1), log_dims(3), rot_6d(6)]. + boxes3d_weights: Per-element weights (0 for invalid depth/dims). + """ + projected_center_3d = project_points(boxes3d[:, :3], intrinsics) + ctr_x = (boxes[:, 0] + boxes[:, 2]) / 2 + ctr_y = (boxes[:, 1] + boxes[:, 3]) / 2 + center_2d = torch.stack([ctr_x, ctr_y], -1) + + delta_center = projected_center_3d - center_2d + + delta_center /= self.center_scale + + valid_depth = boxes3d[:, 2] > 0 + + depth = torch.where( + valid_depth, + torch.log(boxes3d[:, 2]) * self.depth_scale, + boxes3d[:, 2].new_zeros(1), + ) + depth = depth.unsqueeze(-1) + + raw_dims = boxes3d[:, 3:6] # [W, L, H] + + poses = quaternion_to_matrix(boxes3d[:, 6:]) + + if self.canonical_rotation: + poses, raw_dims = _normalize_canonical(poses, raw_dims) + elif self.ambiguous_rotation: + poses = _normalize_rotation_half(poses) + + valid_dims = raw_dims > 0 + dims = torch.where( + valid_dims, + torch.log(raw_dims) * self.dim_scale, + raw_dims.new_zeros(1), + ) + + if self.orientation == "yaw": + yaw = rotation_matrix_yaw( + poses, + axis_mode=AxisMode.OPENCV, + )[:, 1] + + sin_yaw = torch.sin(yaw).unsqueeze(-1) + cos_yaw = torch.cos(yaw).unsqueeze(-1) + + boxes3d_target = torch.cat( + [delta_center, depth, dims, sin_yaw, cos_yaw], -1 + ) + elif self.orientation == "rotation_6d": + rot_6d = matrix_to_rotation_6d(poses) + + boxes3d_target = torch.cat([delta_center, depth, dims, rot_6d], -1) + + boxes3d_weights = torch.ones_like(boxes3d_target) + boxes3d_weights[:, 2] = valid_depth.float() + boxes3d_weights[:, 3:6] = valid_dims.float() + + return boxes3d_target, boxes3d_weights + + def decode( + self, boxes: Tensor, boxes3d: Tensor, intrinsics: Tensor + ) -> Tensor: + """Decode the 3D bounding boxes.""" + delta_center = boxes3d[:, :2] * self.center_scale + + ctr_x = (boxes[:, 0] + boxes[:, 2]) / 2 + ctr_y = (boxes[:, 1] + boxes[:, 3]) / 2 + center_2d = torch.stack([ctr_x, ctr_y], -1) + + proj_center_3d = center_2d + delta_center + + depth = torch.exp(boxes3d[:, 2] / self.depth_scale) + + center_3d = unproject_points(proj_center_3d, depth, intrinsics) + + dims = torch.exp(boxes3d[:, 3:6] / self.dim_scale) + + if self.orientation == "yaw": + yaw = torch.atan2(boxes3d[:, 6], boxes3d[:, 7]) + + orientation = torch.stack( + [torch.zeros_like(yaw), yaw, torch.zeros_like(yaw)], -1 + ) + + poses = euler_angles_to_matrix(orientation) + elif self.orientation == "rotation_6d": + poses = rotation_6d_to_matrix(boxes3d[:, 6:]) + + if self.canonical_rotation: + poses, dims = _normalize_canonical(poses, dims) + elif self.ambiguous_rotation: + poses = _normalize_rotation_half(poses) + + orientation = matrix_to_quaternion(poses) + + return torch.cat([center_3d, dims, orientation], dim=1) diff --git a/wilddet3d/head/depth_cross_attn.py b/wilddet3d/head/depth_cross_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c0412e62ea0b872321704ee87f78c02d56ca38 --- /dev/null +++ b/wilddet3d/head/depth_cross_attn.py @@ -0,0 +1,340 @@ +"""Depth cross-attention head.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +from einops import rearrange +from timm.layers import trunc_normal_ +from torch import Tensor, nn +from torch.nn import functional as F + +from wilddet3d.ops.ray import generate_rays, rsh_cart_8 +from wilddet3d.ops.attention import ( + AttentionBlock, + NystromBlock, + PositionEmbeddingSine, +) +from wilddet3d.ops.mlp import MLP +from wilddet3d.ops.upsample import ConvUpsample +from wilddet3d.ops.util import flat_interpolate + + +class DepthCrossAttention(nn.Module): + """Depth cross-attention head for depth estimation.""" + + def __init__( + self, + embed_dims: int = 256, + depth_scale: float = 2.0, + input_dims: Sequence[int] = (256, 256, 256), + output_scales: int = 1, + ) -> None: + """Initialize the depth head.""" + super().__init__() + self.depth_scale = depth_scale + assert ( + output_scales >= 1 and output_scales <= 3 + ), "Invalid output scales." + self.output_scales = output_scales + + num_resolutions = len(input_dims) + self.input_dims = input_dims + self.num_resolutions = num_resolutions + + # Pool features as depth query + self.features_channel_cat = nn.Linear( + embed_dims * self.num_resolutions, embed_dims + ) + self.to_latents = MLP(embed_dims, expansion=2) + + self.pos_embed = PositionEmbeddingSine(embed_dims // 2, normalize=True) + + self.level_embeds = nn.Parameter( + torch.randn(self.num_resolutions, embed_dims), + requires_grad=True, + ) + self.level_embed_layer = nn.Sequential( + nn.Linear(embed_dims, embed_dims), + nn.GELU(), + nn.Linear(embed_dims, embed_dims), + nn.LayerNorm(embed_dims), + ) + + self.aggregate_16 = AttentionBlock( + embed_dims, + num_heads=1, + expansion=4, + context_dim=embed_dims, + ) + + self.prompt_camera = AttentionBlock( + embed_dims, num_heads=1, expansion=4, context_dim=embed_dims + ) + + # 1/16 resolution + self.project_rays_16 = MLP(81, expansion=4, output_dim=embed_dims) + + self.layers_16 = nn.ModuleList( + [ + AttentionBlock(embed_dims, num_heads=8, expansion=4), + NystromBlock(embed_dims, num_heads=8, expansion=4), + ] + ) + + self.up_8 = ConvUpsample(embed_dims, expansion=4) + + if self.output_scales == 1: + self.out_8 = nn.Conv2d(embed_dims // 2, 1, 3, padding=1) + + if self.output_scales >= 2: + # 1/8 resolution + embed_dims_8 = embed_dims // 2 + self.project_rays_8 = MLP(81, expansion=4, output_dim=embed_dims_8) + + self.layers_8 = nn.ModuleList( + [ + AttentionBlock(embed_dims_8, num_heads=4, expansion=4), + NystromBlock(embed_dims_8, num_heads=4, expansion=4), + ] + ) + + self.up_4 = ConvUpsample(embed_dims_8, expansion=4) + + if self.output_scales == 2: + self.out_4 = nn.Conv2d(embed_dims_8 // 2, 1, 3, padding=1) + + if self.output_scales == 3: + # 1/4 resolution + embed_dims_4 = embed_dims // 4 + self.project_rays_4 = MLP(81, expansion=4, output_dim=embed_dims_4) + + self.layers_4 = nn.ModuleList( + [ + AttentionBlock(embed_dims_4, num_heads=2, expansion=4), + NystromBlock(embed_dims_4, num_heads=2, expansion=4), + ] + ) + + self.up_2 = ConvUpsample(embed_dims_4, expansion=4) + + self.out_2 = nn.Conv2d(embed_dims_4 // 2, 1, 3, padding=1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_rsh_cart(self, rays_embedding: Tensor) -> Tensor: + """Get real spherical harmonic.""" + return rsh_cart_8(rays_embedding) + + def forward( + self, feats: Tensor, intrinsics: Tensor, image_hw: tuple[int, int] + ) -> Tensor: + """Forward.""" + # Camera Embedding + rays_hr, _ = generate_rays(intrinsics, image_hw) + + # 1/16 shape + shape = image_hw[0] // 16, image_hw[1] // 16 + + latents = [] + for _, feat in enumerate(feats): + latent = ( + F.interpolate( + feat, + size=shape, + mode="bilinear", + align_corners=False, + antialias=True, + ) + .flatten(2) + .permute(0, 2, 1) + ) + + latents.append(latent) + + # positional embeddings, spatial and level + level_embed = torch.cat( + [ + self.level_embed_layer(self.level_embeds)[i : i + 1] + .unsqueeze(0) + .repeat(feats[0].shape[0], shape[0] * shape[1], 1) + for i in range(self.num_resolutions) + ], + dim=1, + ) + pos_embed = self.pos_embed( + torch.zeros( + feats[0].shape[0], + 1, + shape[0], + shape[1], + device=feats[0].device, + requires_grad=False, + ) + ) + pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat( + 1, self.num_resolutions, 1 + ) + + features_tokens = torch.cat(latents, dim=1) + features_tokens_pos = pos_embed + level_embed + + features_channels = torch.cat(latents, dim=-1) + features_16 = self.features_channel_cat(features_channels) + latents_16 = self.to_latents(features_16) + + # Aggregate features: F -> D + latents_16 = self.aggregate_16( + latents_16, + context=features_tokens, + pos_embed_context=features_tokens_pos, + ) + + # 1/16 shape + rays_embedding_16 = F.normalize( + flat_interpolate(rays_hr, old=image_hw, new=shape), dim=-1 + ) + + rays_embedding_16 = self.project_rays_16( + self.get_rsh_cart(rays_embedding_16) + ) + + # Aggregate camera: D -> D|E + latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16) + + outs = [] + depth_latents = [] + + # Block 16 - Out 8 + for layer in self.layers_16: + latents_16 = layer(latents_16, pos_embed=rays_embedding_16) + + latents_8 = self.up_8( + rearrange( + latents_16, + "b (h w) c -> b c h w", + h=shape[0], + w=shape[1], + ).contiguous() + ) + + if self.output_scales == 1: + out_8 = self.out_8( + rearrange( + latents_8, + "b (h w) c -> b c h w", + h=shape[0] * 2, + w=shape[1] * 2, + ) + ) + outs.append(out_8) + depth_latents.append(latents_8.detach()) + + if self.output_scales >= 2: + # 1/8 shape + rays_embedding_8 = F.normalize( + flat_interpolate( + rays_hr, old=image_hw, new=(shape[0] * 2, shape[1] * 2) + ), + dim=-1, + ) + + rays_embedding_8 = self.project_rays_8( + self.get_rsh_cart(rays_embedding_8) + ) + + # Block 8 - Out 4 + for layer in self.layers_8: + latents_8 = layer(latents_8, pos_embed=rays_embedding_8) + + latents_4 = self.up_4( + rearrange( + latents_8, + "b (h w) c -> b c h w", + h=shape[0] * 2, + w=shape[1] * 2, + ).contiguous() + ) + + if self.output_scales == 2: + out_4 = self.out_4( + rearrange( + latents_4, + "b (h w) c -> b c h w", + h=shape[0] * 4, + w=shape[1] * 4, + ) + ) + outs.append(out_4) + depth_latents.append(latents_4.detach()) + + if self.output_scales == 3: + # 1/4 shape + rays_embedding_4 = F.normalize( + flat_interpolate( + rays_hr, old=image_hw, new=(shape[0] * 4, shape[1] * 4) + ), + dim=-1, + ) + + rays_embedding_4 = self.project_rays_4( + self.get_rsh_cart(rays_embedding_4) + ) + + # Block 4 - Out 2 + for layer in self.layers_4: + latents_4 = layer(latents_4, pos_embed=rays_embedding_4) + + latents_2 = self.up_2( + rearrange( + latents_4, + "b (h w) c -> b c h w", + h=shape[0] * 4, + w=shape[1] * 4, + ).contiguous() + ) + out_2 = self.out_2( + rearrange( + latents_2, + "b (h w) c -> b c h w", + h=shape[0] * 8, + w=shape[1] * 8, + ) + ) + outs.append(out_2) + depth_latents.append(latents_2.detach()) + + # MS Outputs + depth_preds = ( + sum( + [ + F.interpolate( + torch.exp((out / self.depth_scale).clamp(-10.0, 10.0)), + size=image_hw, + mode="bilinear", + align_corners=False, + antialias=True, + ) + for out in outs + ] + ) + / len(outs) + ).squeeze(1) + + depth_latent = depth_latents[-1] + + return depth_preds, depth_latent diff --git a/wilddet3d/head/head_3d.py b/wilddet3d/head/head_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..3f3ebbbf49e1a17c1bdaad1ec97a5d80ee85c3ff --- /dev/null +++ b/wilddet3d/head/head_3d.py @@ -0,0 +1,452 @@ +"""3D detection head.""" + +from __future__ import annotations + +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from torchvision.ops import batched_nms, nms +from vis4d.op.layer.attention import MultiheadAttention +from vis4d.op.layer.transformer import FFN, get_clones +from vis4d.op.layer.weight_init import xavier_init + +from wilddet3d.ops.box2d import bbox_cxcywh_to_xyxy +from wilddet3d.ops.ray import generate_rays, rsh_cart_8 +from wilddet3d.ops.mlp import MLP +from wilddet3d.ops.util import flat_interpolate + +from .coder_3d import Det3DCoder + + +def convert_grounding_to_cls_scores( + logits: Tensor, positive_maps: dict[int, list[int, int]] +) -> Tensor: + """Convert logits to class scores.""" + assert len(positive_maps) == logits.shape[0] # batch size + + scores = torch.zeros( + logits.shape[0], logits.shape[1], len(positive_maps[0]) + ).to(logits.device) + if positive_maps is not None: + if all(x == positive_maps[0] for x in positive_maps): + # only need to compute once + positive_map = positive_maps[0] + for label_j in positive_map: + scores[:, :, label_j - 1] = logits[ + :, :, torch.LongTensor(positive_map[label_j]) + ].mean(-1) + else: + for i, positive_map in enumerate(positive_maps): + for label_j in positive_map: + scores[i, :, label_j - 1] = logits[ + i, :, torch.LongTensor(positive_map[label_j]) + ].mean(-1) + return scores + + +class Det3DHead(nn.Module): + """3D detection head. + + Args: + embed_dims: Embedding dimension for the head. + num_decoder_layer: Number of decoder layers. + num_reg_fcs: Number of fully connected layers in regression branch. + as_two_stage: Whether to use two-stage detection. + box_coder: 3D box coder for encoding/decoding. + depth_output_scales: Scale factor for depth embedding dims. + use_camera_prompt: Whether to use camera/ray prompt branch. + Set to False when using ray-aware depth backends (UniDepthV2, DetAny3D) + since their depth_latents already incorporate ray information. + Set to True for non-ray-aware backends (UniDepthHead v1). + use_depth_prompt: Whether to use depth prompt branch. + Set to False for ablation: only use depth via encoder fusion. + """ + + def __init__( + self, + embed_dims: int = 256, + num_decoder_layer: int = 6, + num_reg_fcs: int = 2, + as_two_stage: bool = True, + box_coder: Det3DCoder | None = None, + depth_output_scales: int = 1, + depth_latent_dim: int | None = None, + use_camera_prompt: bool = True, + use_depth_prompt: bool = True, + ) -> None: + """Initialize the 3D detection head. + + Args: + depth_latent_dim: Dimension of depth latents from geometry backend. + If provided, uses this directly. If None, computes from + depth_output_scales as embed_dims // 2**depth_output_scales. + """ + super().__init__() + self.embed_dims = embed_dims + self.use_camera_prompt = use_camera_prompt + self.use_depth_prompt = use_depth_prompt + + self.num_pred_layer = ( + num_decoder_layer + 1 if as_two_stage else num_decoder_layer + ) + self.as_two_stage = as_two_stage + + self.box_coder = box_coder or Det3DCoder() + + reg_branch = self._get_reg_branch(num_reg_fcs, self.box_coder.reg_dims) + self.reg_branches = get_clones(reg_branch, self.num_pred_layer) + + # 3D confidence branch (predicts 3D-aware objectness score) + conf_branch = self._get_conf_branch(num_reg_fcs) + self.conf_branches = get_clones(conf_branch, self.num_pred_layer) + + # Camera prompt branch (only created if use_camera_prompt is True) + if self.use_camera_prompt: + project_rays, prompt_camera = self._get_condition_branch( + input_dims=81, expansion=4, embed_dims=embed_dims + ) + self.project_rays = get_clones(project_rays, self.num_pred_layer) + self.prompt_camera = get_clones(prompt_camera, self.num_pred_layer) + else: + self.project_rays = None + self.prompt_camera = None + + # Depth prompt branch (only created if use_depth_prompt is True) + if self.use_depth_prompt: + # Use depth_latent_dim directly if provided, else compute from depth_output_scales + if depth_latent_dim is not None: + depth_embed_dims = depth_latent_dim + else: + depth_embed_dims = embed_dims // 2**depth_output_scales + project_depth, prompt_depth = self._get_condition_branch( + depth_embed_dims, expansion=4, embed_dims=embed_dims + ) + self.project_depth = get_clones(project_depth, self.num_pred_layer) + self.prompt_depth = get_clones(prompt_depth, self.num_pred_layer) + else: + self.project_depth = None + self.prompt_depth = None + + self._init_weights() + + def _get_reg_branch( + self, num_reg_fcs: int, reg_dims: int + ) -> nn.Sequential: + """Get the regression branch.""" + reg_branch = [] + for _ in range(num_reg_fcs): + reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(nn.Linear(self.embed_dims, reg_dims)) + return nn.Sequential(*reg_branch) + + def _get_conf_branch(self, num_reg_fcs: int) -> nn.Sequential: + """Get the 3D confidence branch (output dim = 1).""" + conf_branch = [] + for _ in range(num_reg_fcs): + conf_branch.append(nn.Linear(self.embed_dims, self.embed_dims)) + conf_branch.append(nn.ReLU()) + conf_branch.append(nn.Linear(self.embed_dims, 1)) + return nn.Sequential(*conf_branch) + + def _get_condition_branch( + self, input_dims: int, expansion: int, embed_dims: int + ) -> tuple[nn.Module, nn.Module]: + """Get the condition branch.""" + project_layer = MLP( + input_dims, expansion=expansion, output_dim=embed_dims + ) + + prompt_layer = Prompt3DQueryLayer(embed_dims) + + return project_layer, prompt_layer + + def _init_weights(self) -> None: + """Initialize weights of the Deformable DETR head.""" + for m in self.reg_branches: + xavier_init(m, distribution="uniform") + for m in self.conf_branches: + xavier_init(m, distribution="uniform") + + def get_camera_embeddings( + self, + intrinsics: Tensor, + image_shape: tuple[int, int], + downsample: int = 16, + ) -> Tensor: + """Get the camera embeddings. + + Args: + intrinsics: Camera intrinsics [B, 3, 3]. Should match the space + where depth_latents were computed (may be adjusted for DINOv2). + image_shape: Image (H, W) in the same space as intrinsics. + downsample: Downsample factor for ray grid (8 or 16). + Must match depth_latents resolution. + + Returns: + ray_embeddings: [B, H//downsample * W//downsample, 81] + """ + rays, _ = generate_rays(intrinsics, image_shape) + + rays = F.normalize( + flat_interpolate( + rays, + old=image_shape, + new=(image_shape[0] // downsample, image_shape[1] // downsample), + ), + dim=-1, + ) + + return rsh_cart_8(rays) + + def single_forward( + self, + layer_id: int, + hidden_state: Tensor, + ray_embeddings: Tensor | None, + depth_latents: Tensor | None = None, + ) -> tuple[Tensor, Tensor]: + """Single layer forward pass of the 3D detection head. + + Args: + layer_id: Index of the decoder layer. + hidden_state: Query hidden states [B, num_queries, embed_dims]. + ray_embeddings: Ray embeddings [B, H*W, 81]. Only used if use_camera_prompt=True. + depth_latents: Depth latent features [B, H*W, depth_embed_dims]. + + Returns: + Tuple of (reg_output, conf_output): + - reg_output: 3D box regression [B, num_queries, reg_dims] + - conf_output: 3D confidence logits [B, num_queries, 1] + """ + # Camera-aware 3D queries (only if use_camera_prompt is True) + if self.use_camera_prompt and ray_embeddings is not None: + ray_embedding = self.project_rays[layer_id](ray_embeddings) + hidden_state = self.prompt_camera[layer_id]( + hidden_state, ray_embedding, ray_embedding + ) + + # Depth-aware 3D queries (only if use_depth_prompt is True) + if self.use_depth_prompt and depth_latents is not None: + proj_depth_latents = self.project_depth[layer_id](depth_latents) + hidden_state = self.prompt_depth[layer_id]( + hidden_state, proj_depth_latents, proj_depth_latents + ) + + reg_output = self.reg_branches[layer_id](hidden_state) + conf_output = self.conf_branches[layer_id](hidden_state) + + return reg_output, conf_output + + def forward( + self, + hidden_states: Tensor, + ray_embeddings: Tensor | None, + depth_latents: Tensor | None = None, + ) -> tuple[Tensor, Tensor]: + """Forward pass of the 3D detection head. + + Args: + hidden_states: Query hidden states [num_layers, B, num_queries, embed_dims]. + ray_embeddings: Ray embeddings [B, H*W, 81]. Can be None if use_camera_prompt=False. + depth_latents: Depth latent features [B, H*W, depth_embed_dims]. + + Returns: + Tuple of (stacked_reg, stacked_conf): + - stacked_reg: [num_layers, B, num_queries, reg_dims] + - stacked_conf: [num_layers, B, num_queries, 1] + """ + all_layers_outputs_3d = [] + all_layers_conf_3d = [] + + for layer_id in range(hidden_states.shape[0]): + hidden_state = hidden_states[layer_id] + + reg_output, conf_output = self.single_forward( + layer_id, hidden_state, ray_embeddings, depth_latents + ) + + all_layers_outputs_3d.append(reg_output) + all_layers_conf_3d.append(conf_output) + + return torch.stack(all_layers_outputs_3d), torch.stack(all_layers_conf_3d) + + +class Prompt3DQueryLayer(nn.Module): + """Prompt 3D object query Layer.""" + + def __init__(self, embed_dims: int = 256) -> None: + """Init.""" + super().__init__() + self.self_attn = MultiheadAttention( + embed_dims=256, num_heads=8, batch_first=True + ) + + self.norm1 = nn.LayerNorm(embed_dims) + + self.cross_attn = MultiheadAttention( + embed_dims=256, num_heads=1, batch_first=True + ) + + self.norm2 = nn.LayerNorm(embed_dims) + + self.ffn = FFN(embed_dims) + + self.norm3 = nn.LayerNorm(embed_dims) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + query_pos: Tensor | None = None, + ) -> Tensor: + """Forward.""" + # self attention + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + ) + query = self.norm1(query) + + # cross attention + query = self.cross_attn( + query=query, + key=key, + value=value, + query_pos=query_pos, + ) + query = self.norm2(query) + + # FFN + query = self.ffn(query) + query = self.norm3(query) + + return query + + +class RoI2Det3D: + """Convert RoI to 3D Detection.""" + + def __init__( + self, + nms: bool = False, + max_per_img: int = 300, + class_agnostic_nms: bool = False, + score_threshold: float = 0.0, + iou_threshold: float = 0.5, + box_coder: Det3DCoder | None = None, + ) -> None: + """Create an instance of RoI2Det3D.""" + self.nms = nms + self.max_per_img = max_per_img + self.class_agnostic_nms = class_agnostic_nms + self.score_threshold = score_threshold + self.iou_threshold = iou_threshold + + self.box_coder = box_coder or Det3DCoder() + + def __call__( + self, + cls_score: Tensor, + bbox_pred: Tensor, + token_positive_maps: dict[int, list[int]] | None, + img_shape: tuple[int, int], + ori_shape: tuple[int, int], + bbox_3d_pred: Tensor, + intrinsics: Tensor, + padding: list[int] | None, + ) -> tuple[Tensor, Tensor, Tensor]: + """Transform the bbox head output into bbox results.""" + assert len(cls_score) == len(bbox_pred) # num_queries + + det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred) + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1]) + det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0]) + + if token_positive_maps is not None: + cls_score = convert_grounding_to_cls_scores( + logits=cls_score.sigmoid()[None], + positive_maps=[token_positive_maps], + )[0] + + k = min(self.max_per_img, cls_score.view(-1).shape[0]) + if k == 0: + device = cls_score.device + return ( + torch.zeros(0, 4, device=device), + torch.zeros(0, device=device), + torch.zeros(0, dtype=torch.long, device=device), + torch.zeros(0, 10, device=device), + ) + scores, indexes = cls_score.view(-1).topk(k) + num_classes = cls_score.shape[-1] + det_labels = indexes % num_classes + bbox_index = indexes // num_classes + det_bboxes = det_bboxes[bbox_index] + bbox_3d_pred = bbox_3d_pred[bbox_index] + + # Remove low scoring boxes + if self.score_threshold > 0.0: + mask = scores > self.score_threshold + det_bboxes = det_bboxes[mask] + det_labels = det_labels[mask] + scores = scores[mask] + bbox_3d_pred = bbox_3d_pred[mask] + + if self.nms: + if self.class_agnostic_nms: + keep = nms(det_bboxes, scores, self.iou_threshold) + else: + keep = batched_nms( + det_bboxes, scores, det_labels, self.iou_threshold + ) + + det_bboxes = det_bboxes[keep] + det_labels = det_labels[keep] + scores = scores[keep] + bbox_3d_pred = bbox_3d_pred[keep] + else: + cls_score = cls_score.sigmoid() + scores, _ = cls_score.max(-1) + scores, indexes = scores.topk(self.max_per_img) + det_bboxes = det_bboxes[indexes] + bbox_3d_pred = bbox_3d_pred[indexes] + det_labels = scores.new_zeros(scores.shape, dtype=torch.long) + + if bbox_3d_pred.numel() == 0: + return ( + det_bboxes, + scores, + det_labels, + bbox_3d_pred.new_empty((0, 10)), + ) + + det_bboxes3d = self.box_coder.decode( + det_bboxes, bbox_3d_pred, intrinsics + ) + + # Remove padding when input_hw is affected by padding + if padding is not None: + det_bboxes[:, 0] -= padding[0] + det_bboxes[:, 1] -= padding[2] + det_bboxes[:, 2] -= padding[0] + det_bboxes[:, 3] -= padding[2] + + scales = [ + (img_shape[1] - padding[0] - padding[1]) / ori_shape[1], + (img_shape[0] - padding[2] - padding[3]) / ori_shape[0], + ] + + else: + scales = [img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]] + + # Rescale to original shape + det_bboxes /= det_bboxes.new_tensor(scales).repeat((1, 2)) + + return det_bboxes, scores, det_labels, det_bboxes3d diff --git a/wilddet3d/inference.py b/wilddet3d/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..cdda2f055a6a5d235d6d06c36e44b1be0d13f50f --- /dev/null +++ b/wilddet3d/inference.py @@ -0,0 +1,606 @@ +"""WildDet3D inference wrapper. + +Provides a simple forward() interface for WildDet3D inference: + +Supports three prompt types with 5-mode text labels: +- Text prompt: input_texts=["chair", "table"] +- Box prompt: input_boxes=[[x1, y1, x2, y2]] (pixel xyxy) +- Point prompt: input_points=[[(x, y, label), ...]] (pixel coords, + label: 1=pos, 0=neg) + +5-mode support via prompt_text parameter (for box/point prompts): +- "visual" -> VISUAL mode (one-to-many, no category label) +- "visual: car" -> VISUAL+LABEL mode (one-to-many, with category) +- "geometric" -> GEOMETRY mode (one-to-one, no category label) +- "geometric: car" -> GEOMETRY+LABEL mode (one-to-one, with category) +- "object" -> default (backward compatible) + +Example usage: + from wilddet3d.inference import build_model + from wilddet3d.preprocessing import preprocess + + # Build model + model = build_model( + checkpoint="path/to/checkpoint.ckpt" + ) + + # Preprocess data + data = preprocess(image, intrinsics) + + # TEXT mode + boxes, boxes3d, scores, class_ids, depth_maps = model( + images=data["images"], + intrinsics=data["intrinsics"], + input_hw=[data["input_hw"]], + original_hw=[data["original_hw"]], + padding=[data["padding"]], + input_texts=["chair", "table"], + ) + + # VISUAL mode (box prompt, one-to-many) + boxes, boxes3d, scores, class_ids, depth_maps = model( + ..., + input_boxes=[[100, 200, 300, 400]], + prompt_text="visual", + ) + + # GEOMETRY mode (box prompt, one-to-one) + boxes, boxes3d, scores, class_ids, depth_maps = model( + ..., + input_boxes=[[100, 200, 300, 400]], + prompt_text="geometric", + ) + + # Point prompt (works with any prompt_text) + boxes, boxes3d, scores, class_ids, depth_maps = model( + ..., + input_points=[[(150, 250, 1), (200, 300, 0)]], + prompt_text="geometric", + ) +""" + +from typing import List, Optional, Tuple + +import torch +from torch import Tensor, nn + +from wilddet3d.data_types import WildDet3DInput +from wilddet3d.depth import LingbotDepthBackend +from wilddet3d.depth.depth_fusion import EarlyDepthFusionLingbot +from wilddet3d.head import Det3DCoder, RoI2Det3D +from wilddet3d.model import WildDet3D + + +class WildDet3DPredictor(nn.Module): + """WildDet3D wrapper with a simple forward() interface. + + Provides a simple forward() interface: + boxes, boxes3d, scores, class_ids, depth_maps = model( + images=..., + intrinsics=..., + input_texts=["chair", "table"], + ) + """ + + def __init__( + self, + wilddet3d: WildDet3D, + score_threshold: float = 0.3, + ): + super().__init__() + self.wilddet3d = wilddet3d + self.score_threshold = score_threshold + + def forward( + self, + images: Tensor, + intrinsics: Optional[Tensor], + input_hw: List[Tuple[int, int]], + original_hw: List[Tuple[int, int]], + padding: List[Tuple[int, int, int, int]], + # Prompt types (mutually exclusive) + input_texts: Optional[List[str]] = None, + input_boxes: Optional[List[List[float]]] = None, + input_points: Optional[ + List[List[Tuple[float, float, int]]] + ] = None, + # Text label for box/point prompts (5-mode support) + # e.g. "visual", "visual: car", "geometric", "geometric: car" + prompt_text: str = "object", + return_predicted_intrinsics: bool = False, + # Optional depth input (e.g., from LiDAR) + depth_gt: Optional[Tensor] = None, # (B, 1, H, W) meters + ) -> Tuple[ + List[Tensor], + List[Tensor], + List[Tensor], + List[Tensor], + Optional[List[Tensor]], + Optional[Tensor], + ]: + """Forward with simple interface. + + Args: + images: (B, 3, H, W) preprocessed images + intrinsics: (B, 3, 3) camera intrinsics, or None to use + predicted + input_hw: List of (H, W) tuples for each image + original_hw: List of original (H, W) tuples + padding: List of (left, right, top, bottom) padding tuples + input_texts: Text prompts (e.g., ["chair", "table"]) + input_boxes: Box prompts per image, pixel xyxy + [[x1,y1,x2,y2], ...] + input_points: Point prompts per image + [[(x,y,label), ...], ...] + prompt_text: Text label for box/point prompts. Controls + 5-mode: "object" (default), "visual", "visual: car", + "geometric", "geometric: car" + return_predicted_intrinsics: Whether to return predicted + intrinsics + depth_gt: Optional depth input (B, 1, H, W) in meters + + Returns: + boxes: List of 2D boxes per image (pixel xyxy) + boxes3d: List of 3D boxes per image + scores: List of confidence scores per image + class_ids: List of class IDs per image + depth_maps: List of depth maps per image (or None) + predicted_intrinsics: (B, 3, 3) predicted intrinsics + (if requested) + """ + device = images.device + B = images.shape[0] + H, W = input_hw[0] + + # Determine prompt type and create batch + if input_texts is not None: + batch = self._create_text_batch( + images, + intrinsics, + input_texts, + device, + padding=padding, + ) + class_names = input_texts + elif input_boxes is not None: + batch = self._create_box_batch( + images, + intrinsics, + input_boxes, + (H, W), + device, + text=prompt_text, + padding=padding, + ) + class_names = [prompt_text] + elif input_points is not None: + batch = self._create_point_batch( + images, + intrinsics, + input_points, + (H, W), + device, + text=prompt_text, + padding=padding, + ) + class_names = [prompt_text] + else: + raise ValueError( + "Must provide one of: input_texts, input_boxes, " + "input_points" + ) + + # Attach depth input if provided + if depth_gt is not None: + batch.depth_gt = depth_gt + + # Run inference + with torch.no_grad(): + output = self.wilddet3d(batch) + + # Output is Det3DOut with per-image lists + boxes = output.boxes + boxes3d = output.boxes3d + scores = output.scores + scores_2d = output.scores_2d + scores_3d = output.scores_3d + class_ids = output.class_ids + depth_maps = output.depth_maps + + # Apply score threshold and rescale boxes to original size + boxes_out = [] + boxes3d_out = [] + scores_out = [] + scores_2d_out = [] + scores_3d_out = [] + class_ids_out = [] + + for i in range(B): + # Filter by 2D score + mask = scores[i] >= self.score_threshold + img_scores = scores[i][mask] + img_scores_2d = ( + scores_2d[i][mask] + if scores_2d is not None + else torch.zeros_like(img_scores) + ) + img_scores_3d = ( + scores_3d[i][mask] + if scores_3d is not None + else torch.zeros_like(img_scores) + ) + img_boxes = boxes[i][mask] + img_boxes3d = boxes3d[i][mask] + img_class_ids = class_ids[i][mask] + + # Rescale 2D boxes from input_hw to original_hw + # Account for padding + pad_left, pad_right, pad_top, pad_bottom = padding[i] + orig_h, orig_w = original_hw[i] + + # Remove padding offset and rescale + img_boxes = img_boxes.clone() + img_boxes[:, 0] -= pad_left # x1 + img_boxes[:, 2] -= pad_left # x2 + img_boxes[:, 1] -= pad_top # y1 + img_boxes[:, 3] -= pad_top # y2 + + # Scale from padded size to original + padded_h = H - pad_top - pad_bottom + padded_w = W - pad_left - pad_right + scale_x = orig_w / padded_w + scale_y = orig_h / padded_h + + img_boxes[:, 0::2] *= scale_x + img_boxes[:, 1::2] *= scale_y + + # Clamp to image bounds + img_boxes[:, 0::2] = img_boxes[:, 0::2].clamp(0, orig_w) + img_boxes[:, 1::2] = img_boxes[:, 1::2].clamp(0, orig_h) + + boxes_out.append(img_boxes) + boxes3d_out.append(img_boxes3d) + scores_out.append(img_scores) + scores_2d_out.append(img_scores_2d) + scores_3d_out.append(img_scores_3d) + class_ids_out.append(img_class_ids) + + # Get predicted intrinsics if available + predicted_K = output.predicted_intrinsics + + if return_predicted_intrinsics: + return ( + boxes_out, + boxes3d_out, + scores_out, + scores_2d_out, + scores_3d_out, + class_ids_out, + depth_maps, + predicted_K, + ) + else: + return ( + boxes_out, + boxes3d_out, + scores_out, + scores_2d_out, + scores_3d_out, + class_ids_out, + depth_maps, + ) + + def _create_text_batch( + self, + images: Tensor, + intrinsics: Tensor, + texts: List[str], + device: torch.device, + padding: Optional[List[Tuple[int, int, int, int]]] = None, + ) -> WildDet3DInput: + """Create batch for text prompts.""" + n_prompts = len(texts) + + return WildDet3DInput( + images=images, + intrinsics=intrinsics, + img_ids=torch.zeros( + n_prompts, dtype=torch.long, device=device + ), + text_ids=torch.arange( + n_prompts, dtype=torch.long, device=device + ), + unique_texts=texts, + padding=padding, + ) + + def _create_box_batch( + self, + images: Tensor, + intrinsics: Tensor, + boxes_xyxy: List[List[float]], + input_hw: Tuple[int, int], + device: torch.device, + text: str = "object", + padding: Optional[List[Tuple[int, int, int, int]]] = None, + ) -> WildDet3DInput: + """Create batch for box prompts. + + Args: + text: Text label for the prompt. Controls 5-mode behavior: + "visual" / "visual: car" for one-to-many matching, + "geometric" / "geometric: car" for one-to-one matching. + """ + H, W = input_hw + n_prompts = len(boxes_xyxy) + + # Convert pixel xyxy to normalized cxcywh + boxes_cxcywh = [] + for box in boxes_xyxy: + x1, y1, x2, y2 = box + cx = (x1 + x2) / 2 / W + cy = (y1 + y2) / 2 / H + w = (x2 - x1) / W + h = (y2 - y1) / H + boxes_cxcywh.append([cx, cy, w, h]) + + geo_boxes = torch.tensor( + boxes_cxcywh, dtype=torch.float32, device=device + ) + geo_boxes = geo_boxes.unsqueeze(1) # (n_prompts, 1, 4) + + return WildDet3DInput( + images=images, + intrinsics=intrinsics, + img_ids=torch.zeros( + n_prompts, dtype=torch.long, device=device + ), + text_ids=torch.zeros( + n_prompts, dtype=torch.long, device=device + ), + unique_texts=[text], + geo_boxes=geo_boxes, + geo_boxes_mask=torch.zeros( + n_prompts, 1, dtype=torch.bool, device=device + ), + geo_box_labels=torch.ones( + n_prompts, 1, dtype=torch.long, device=device + ), + padding=padding, + ) + + def _create_point_batch( + self, + images: Tensor, + intrinsics: Tensor, + points_list: List[List[Tuple[float, float, int]]], + input_hw: Tuple[int, int], + device: torch.device, + text: str = "object", + padding: Optional[List[Tuple[int, int, int, int]]] = None, + ) -> WildDet3DInput: + """Create batch for point prompts. + + Args: + text: Text label for the prompt. Controls 5-mode behavior: + "visual" / "visual: car" for one-to-many matching, + "geometric" / "geometric: car" for one-to-one matching. + """ + H, W = input_hw + n_prompts = len(points_list) + + # Find max points per prompt for padding + max_points = max(len(pts) for pts in points_list) + + # Normalize and pad points + geo_points = torch.zeros( + n_prompts, max_points, 2, device=device + ) + geo_point_labels = torch.zeros( + n_prompts, max_points, dtype=torch.long, device=device + ) + geo_points_mask = torch.ones( + n_prompts, max_points, dtype=torch.bool, device=device + ) + + for i, pts in enumerate(points_list): + for j, (x, y, label) in enumerate(pts): + geo_points[i, j] = torch.tensor([x / W, y / H]) + geo_point_labels[i, j] = label + geo_points_mask[i, j] = False # False = valid + + return WildDet3DInput( + images=images, + intrinsics=intrinsics, + img_ids=torch.zeros( + n_prompts, dtype=torch.long, device=device + ), + text_ids=torch.zeros( + n_prompts, dtype=torch.long, device=device + ), + unique_texts=[text], + geo_points=geo_points, + geo_points_mask=geo_points_mask, + geo_point_labels=geo_point_labels, + padding=padding, + ) + + +def build_model( + checkpoint: str, + sam3_checkpoint: str = "pretrained/sam3/sam3_detector.pt", + score_threshold: float = 0.3, + nms: bool = True, + iou_threshold: float = 0.6, + device: str = "cuda", + backbone_freeze_blocks: int = 28, + lingbot_encoder_freeze_blocks: int = 21, + ambiguous_rotation: bool = False, + canonical_rotation: bool = False, + use_depth_input_test: bool = False, + use_predicted_intrinsics: bool = False, + skip_pretrained: bool = False, +) -> WildDet3DPredictor: + """Build WildDet3D model with LingBot-Depth backend. + + Args: + checkpoint: Path to trained WildDet3D checkpoint (.ckpt file) + sam3_checkpoint: Path to SAM3 pretrained weights + score_threshold: Confidence threshold for filtering + nms: Whether to apply NMS + iou_threshold: IoU threshold for NMS + device: Device to load model on + backbone_freeze_blocks: Number of SAM3 ViT blocks to freeze. + lingbot_encoder_freeze_blocks: Number of LingBot encoder blocks + to freeze. + use_predicted_intrinsics: If True, use geometry backend's + predicted intrinsics (K_pred) for 3D box decoding instead of + the input intrinsics. Useful for in-the-wild images without + GT intrinsics. + skip_pretrained: If True, skip loading pretrained weights for + SAM3 and LingBot-Depth. Use this for inference when the + training checkpoint already contains all weights (avoids + loading ~4GB of pretrained weights that get immediately + overwritten). + + Returns: + WildDet3DPredictor model ready for inference + """ + print("Building WildDet3D model with LingBot-Depth backend...") + + # When skip_pretrained=True, patch MDMModel.from_pretrained to build + # model structure from config without loading weights (~1GB saved). + _mdm_patch_cleanup = None + if skip_pretrained: + from mdm.model.v2 import MDMModel + + _orig_from_pretrained = MDMModel.from_pretrained + + @classmethod + def _from_pretrained_config_only(cls, path, **kwargs): + from pathlib import Path as P + + from huggingface_hub import hf_hub_download + + if P(path).exists(): + cp = path + else: + cp = hf_hub_download( + repo_id=path, + repo_type="model", + filename="model.pt", + **kwargs, + ) + ckpt = torch.load( + cp, map_location="cpu", weights_only=True + ) + model = cls(**ckpt["model_config"]) + print( + f"[LingbotDepth] Built model structure from config " + f"(skipped pretrained weights)" + ) + return model + + MDMModel.from_pretrained = _from_pretrained_config_only + _mdm_patch_cleanup = lambda: setattr( + MDMModel, "from_pretrained", _orig_from_pretrained + ) + + # Build geometry backend (LingBot-Depth) + geometry_backend = LingbotDepthBackend( + pretrained_model="robbyant/lingbot-depth-postrain-dc-vitl14", + num_tokens=2400, + target_latent_dim=256, + depth_loss_weight=1.0, + silog_loss_weight=0.5, + monocular_prob=0.7, + masked_prob=0.2, + mask_ratio_range=(0.6, 0.9), + mask_patch_size=14, + camera_loss_weight=1.0, + detach_depth_latents=True, + encoder_freeze_blocks=lingbot_encoder_freeze_blocks, + ) + + # Restore original from_pretrained + if _mdm_patch_cleanup is not None: + _mdm_patch_cleanup() + + # Build components + box_coder = Det3DCoder( + ambiguous_rotation=ambiguous_rotation, + canonical_rotation=canonical_rotation, + ) + roi2det3d = RoI2Det3D( + box_coder=box_coder, + score_threshold=0.0, # Threshold in wrapper + nms=nms, + iou_threshold=iou_threshold, + ) + + # ControlNet-style fusion for LingBot-Depth + early_depth_fusion = EarlyDepthFusionLingbot( + visual_dim=256, + depth_dim=256, + zero_init=True, + ) + + # Build WildDet3D + # When skip_pretrained=True, build SAM3 model structure without + # loading pretrained weights (~3.2GB) since the training checkpoint + # already contains all weights. + if skip_pretrained: + from sam3.model_builder import build_sam3_image_model + + print( + "[skip_pretrained] Building SAM3 structure without " + "pretrained weights..." + ) + sam3_model = build_sam3_image_model( + checkpoint_path=None, + load_from_HF=False, + device="cpu", + eval_mode=False, + enable_segmentation=False, + ) + else: + sam3_model = None + + wilddet3d = WildDet3D( + sam3_model=sam3_model if skip_pretrained else None, + sam3_checkpoint=None if skip_pretrained else sam3_checkpoint, + box_coder=box_coder, + geometry_backend=geometry_backend, + roi2det3d=roi2det3d, + early_depth_fusion=early_depth_fusion, + backbone_freeze_blocks=backbone_freeze_blocks, + use_depth_input_test=use_depth_input_test, + use_predicted_intrinsics=use_predicted_intrinsics, + ) + + # Load trained checkpoint + print(f"Loading checkpoint: {checkpoint}") + ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False) + state_dict = ckpt.get("state_dict", ckpt) + + # Remove "model." prefix + new_state_dict = {} + for k, v in state_dict.items(): + new_key = ( + k.replace("model.", "") if k.startswith("model.") else k + ) + new_state_dict[new_key] = v + + wilddet3d.load_state_dict(new_state_dict, strict=False) + wilddet3d = wilddet3d.to(device) + wilddet3d.eval() + + # Wrap with predictor interface + model = WildDet3DPredictor( + wilddet3d, score_threshold=score_threshold + ) + model = model.to(device) + model.eval() + + print("Model ready!") + return model diff --git a/wilddet3d/loss/__init__.py b/wilddet3d/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wilddet3d/loss/det2d_loss.py b/wilddet3d/loss/det2d_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c2a09109ae00395909e5474b004e5d26c74e92 --- /dev/null +++ b/wilddet3d/loss/det2d_loss.py @@ -0,0 +1,964 @@ +"""G-DINO Loss.""" + +import torch +from torch import Tensor, nn +from vis4d.common.distributed import reduce_mean +from vis4d.op.loss.common import l1_loss +from vis4d.op.loss.reducer import SumWeightedLoss + +from wilddet3d.ops.box2d import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh +from wilddet3d.ops.matchers.hungarian import HungarianMatcher +from wilddet3d.loss.focal_loss import FocalLoss +from wilddet3d.loss.iou_loss import GIoULoss +from wilddet3d.ops.match_cost import ( + BBoxL1Cost, + BinaryFocalLossCost, + IoUCost, +) +from wilddet3d.ops.util import multi_apply + + +class Det2DLoss(nn.Module): + """Grounding DINO loss module.""" + + def __init__( + self, max_text_len: int = 256, sync_cls_avg_factor: bool = True + ): + super().__init__() + self.sync_cls_avg_factor = sync_cls_avg_factor + self.max_text_len = max_text_len + + # Matcher + self.cls_cost = BinaryFocalLossCost(weight=2.0) + self.reg_cost = BBoxL1Cost(weight=5.0, box_format="xywh") + self.iou_cost = IoUCost(weight=2.0, iou_mode="giou") + + self.assigner = HungarianMatcher() + + # Losses + self.loss_cls = FocalLoss(alpha=0.25, gamma=2.0) + self.bg_cls_weight = 0.0 + self.cls_loss_weight = 1.0 + + self.loss_bbox = l1_loss + self.bbox_loss_weight = 5.0 + + self.loss_iou = GIoULoss() + self.iou_loss_weight = 2.0 + + def get_targets( + self, + cls_scores_list: list[Tensor], + bbox_preds_list: list[Tensor], + input_hw: list[tuple[int, int]], + batch_gt_boxes: list[Tensor], + batch_gt_boxes_classes: list[Tensor], + positive_maps: list[Tensor], + text_token_mask: Tensor, + ) -> tuple: + """Compute regression and classification targets for a batch image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image, has shape [num_queries, + cls_out_channels]. + bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_queries, 4]. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all images. + - bbox_targets_list (list[Tensor]): BBox targets for all images. + - bbox_weights_list (list[Tensor]): BBox weights for all images. + - num_total_pos (int): Number of positive samples in all images. + - num_total_neg (int): Number of negative samples in all images. + """ + ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + pos_inds_list, + neg_inds_list, + ) = multi_apply( + self._get_targets_single, + cls_scores_list, + bbox_preds_list, + input_hw, + batch_gt_boxes, + batch_gt_boxes_classes, + positive_maps, + text_token_mask, + ) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + + return ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + num_total_pos, + num_total_neg, + ) + + def _get_cost( + self, + cls_score, + bbox_pred, + gt_boxes, + input_hw, + text_token_mask, + positive_map, + ): + """Compute regression and classification cost for one image.""" + if self.cls_cost.weight != 0: + cls_cost = self.cls_cost(cls_score, text_token_mask, positive_map) + else: + cls_cost = 0 + + if self.reg_cost.weight != 0: + reg_cost = self.reg_cost( + bbox_pred, gt_boxes, input_hw[0], input_hw[1] + ) + else: + reg_cost = 0 + + if self.iou_cost.weight != 0: + iou_cost = self.iou_cost(bbox_pred, gt_boxes) + else: + iou_cost = 0 + + return cls_cost + reg_cost + iou_cost + + def _get_targets_2d_single( + self, + cls_score: Tensor, + bbox_pred: Tensor, + input_hw: tuple[int, int], + gt_boxes: Tensor, + gt_classes: Tensor, + positive_map: Tensor, + text_token_mask: Tensor, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Compute regression and classification targets for one image.""" + img_h, img_w = input_hw + num_bboxes = bbox_pred.size(0) + factor = bbox_pred.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze( + 0 + ) + + # convert bbox_pred from xywh, normalized to xyxy, unnormalized + bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred) + bbox_pred = bbox_pred * factor + + # assigner and sampler + cost = self._get_cost( + cls_score, + bbox_pred, + gt_boxes, + input_hw, + text_token_mask, + positive_map, + ) + + assign_result = self.assigner(cost, bbox_pred, gt_boxes, gt_classes) + + pos_inds = ( + torch.nonzero( + assign_result.assigned_gt_indices > 0, as_tuple=False + ) + .squeeze(-1) + .unique() + ) + neg_inds = ( + torch.nonzero( + assign_result.assigned_gt_indices == 0, as_tuple=False + ) + .squeeze(-1) + .unique() + ) + pos_assigned_gt_inds = assign_result.assigned_gt_indices[pos_inds] - 1 + pos_gt_bboxes = gt_boxes[pos_assigned_gt_inds.long(), :] + + # Major changes. The labels are 0-1 binary labels for each bbox + # and text tokens. + labels = gt_boxes.new_full( + (num_bboxes, self.max_text_len), 0, dtype=torch.float32 + ) + labels[pos_inds] = positive_map[pos_assigned_gt_inds] + label_weights = gt_boxes.new_ones(num_bboxes) + + # bbox targets + bbox_targets = torch.zeros_like(bbox_pred, dtype=gt_boxes.dtype) + bbox_weights = torch.zeros_like(bbox_pred, dtype=gt_boxes.dtype) + bbox_weights[pos_inds] = 1.0 + + # DETR regress the relative position of boxes (cxcywh) in the image. + # Thus the learning target should be normalized by the image size, also + # the box format should be converted from defaultly x1y1x2y2 to cxcywh. + pos_gt_bboxes_normalized = pos_gt_bboxes / factor + pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized) + bbox_targets[pos_inds] = pos_gt_bboxes_targets + + return ( + labels, + label_weights, + bbox_targets, + bbox_weights, + pos_gt_bboxes, + pos_inds, + neg_inds, + pos_assigned_gt_inds, + ) + + def _get_targets_single( + self, + cls_score: Tensor, + bbox_pred: Tensor, + input_hw: tuple[int, int], + gt_boxes: Tensor, + gt_classes: Tensor, + positive_map: Tensor, + text_token_mask: Tensor, + ) -> tuple: + """Compute regression and classification targets for one image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_queries, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_queries, 4]. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for one image. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + ( + labels, + label_weights, + bbox_targets, + bbox_weights, + _, + pos_inds, + neg_inds, + _, + ) = self._get_targets_2d_single( + cls_score, + bbox_pred, + input_hw, + gt_boxes, + gt_classes, + positive_map, + text_token_mask, + ) + + return ( + labels, + label_weights, + bbox_targets, + bbox_weights, + pos_inds, + neg_inds, + ) + + def loss_by_feat_single( + self, + cls_scores: Tensor, + bbox_preds: Tensor, + text_token_mask: Tensor, + input_hw: list[tuple[int, int]], + batch_gt_boxes: list[Tensor], + batch_gt_boxes_classes: list[Tensor], + positive_maps: list[Tensor], + ) -> tuple[Tensor]: + """Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images, has shape (bs, num_queries, cls_out_channels). + bbox_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape (bs, num_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. + """ + num_imgs = cls_scores.size(0) + + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + + with torch.no_grad(): + ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + num_total_pos, + num_total_neg, + ) = self.get_targets( + cls_scores_list, + bbox_preds_list, + input_hw, + batch_gt_boxes, + batch_gt_boxes_classes, + positive_maps, + text_token_mask, + ) + + labels = torch.stack(labels_list, 0) + label_weights = torch.stack(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + + # Loss is not computed for the padded regions of the text. + assert text_token_mask.dim() == 2 + text_masks = text_token_mask.new_zeros( + (text_token_mask.size(0), self.max_text_len) + ) + text_masks[:, : text_token_mask.size(1)] = text_token_mask + text_mask = (text_masks > 0).unsqueeze(1) + text_mask = text_mask.repeat(1, cls_scores.size(1), 1) + cls_scores = torch.masked_select(cls_scores, text_mask).contiguous() + + labels = torch.masked_select(labels, text_mask) + label_weights = label_weights[..., None].repeat( + 1, 1, text_mask.size(-1) + ) + label_weights = torch.masked_select(label_weights, text_mask) + + # classification loss + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = ( + num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight + ) + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor]) + ) + cls_avg_factor = max(cls_avg_factor, 1) + + loss_cls = self.cls_loss_weight * self.loss_cls( + cls_scores, + labels, + reducer=SumWeightedLoss( + weight=label_weights, avg_factor=cls_avg_factor + ), + ) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_hw, bbox_pred in zip(input_hw, bbox_preds): + img_h, img_w = img_hw + factor = ( + bbox_pred.new_tensor([img_w, img_h, img_w, img_h]) + .unsqueeze(0) + .repeat(bbox_pred.size(0), 1) + ) + factors.append(factor) + factors = torch.cat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression L1 loss + loss_bbox = self.bbox_loss_weight * self.loss_bbox( + bbox_preds, + bbox_targets, + reducer=SumWeightedLoss( + weight=bbox_weights, avg_factor=num_total_pos + ), + ) + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.iou_loss_weight * self.loss_iou( + bboxes, + bboxes_gt, + reducer=SumWeightedLoss( + weight=bbox_weights.mean(-1), avg_factor=num_total_pos + ), + ) + + return loss_cls, loss_bbox, loss_iou + + def forward( + self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + text_token_mask: Tensor, + enc_cls_scores: Tensor, + enc_bbox_preds: Tensor, + dn_meta: dict[str, int], + positive_maps: list[Tensor], + input_hw: list[tuple[int, int]], + batch_gt_boxes: list[Tensor], + batch_gt_boxes_classes: list[Tensor], + ) -> dict[str, Tensor]: + """Loss function. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, + num_queries_total, cls_out_channels), where + `num_queries_total` is the sum of `num_denoising_queries` + and `num_matching_queries`. + all_layers_bbox_preds (Tensor): Regression outputs of all decoder + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and has shape (num_decoder_layers, bs, + num_queries_total, 4). + enc_cls_scores (Tensor): The score of each point on encode + feature map, has shape (bs, num_feat_points, cls_out_channels). + enc_bbox_preds (Tensor): The proposal generate from the encode + feature map, has shape (bs, num_feat_points, 4) with the last + dimension arranged as (cx, cy, w, h). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + # extract denoising and matching part of outputs + ( + all_layers_matching_cls_scores, + all_layers_matching_bbox_preds, + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + ) = split_outputs( + all_layers_cls_scores, all_layers_bbox_preds, dn_meta + ) + + # DETRHead loss_by_feat + losses_cls, losses_bbox, losses_iou = multi_apply( + self.loss_by_feat_single, + all_layers_matching_cls_scores, + all_layers_matching_bbox_preds, + text_token_mask=text_token_mask, + input_hw=input_hw, + batch_gt_boxes=batch_gt_boxes, + batch_gt_boxes_classes=batch_gt_boxes_classes, + positive_maps=positive_maps, + ) + + loss_dict = dict() + + # loss from the last decoder layer + loss_dict["loss_cls"] = losses_cls[-1] + loss_dict["loss_bbox"] = losses_bbox[-1] + loss_dict["loss_iou"] = losses_iou[-1] + + # loss from other decoder layers + for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i) in enumerate( + zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]) + ): + loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.loss_bbox"] = loss_bbox_i + loss_dict[f"d{num_dec_layer}.loss_iou"] = loss_iou_i + + # loss of proposal generated from encode feature map. + if enc_cls_scores is not None: + # NOTE The enc_loss calculation of the DINO is + # different from that of Deformable DETR. + enc_loss_cls, enc_losses_bbox, enc_losses_iou = ( + self.loss_by_feat_single( + enc_cls_scores, + enc_bbox_preds, + text_token_mask=text_token_mask, + input_hw=input_hw, + batch_gt_boxes=batch_gt_boxes, + batch_gt_boxes_classes=batch_gt_boxes_classes, + positive_maps=positive_maps, + ) + ) + loss_dict["enc_loss_cls"] = enc_loss_cls + loss_dict["enc_loss_bbox"] = enc_losses_bbox + loss_dict["enc_loss_iou"] = enc_losses_iou + + if all_layers_denoising_cls_scores is not None: + # calculate denoising loss from all decoder layers + dn_losses_cls, dn_losses_bbox, dn_losses_iou = self.loss_dn( + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + boxes2d=batch_gt_boxes, + boxes2d_classes=batch_gt_boxes_classes, + positive_maps=positive_maps, + input_hw=input_hw, + text_token_mask=text_token_mask, + dn_meta=dn_meta, + ) + + # collate denoising loss + loss_dict["dn_loss_cls"] = dn_losses_cls[-1] + loss_dict["dn_loss_bbox"] = dn_losses_bbox[-1] + loss_dict["dn_loss_iou"] = dn_losses_iou[-1] + + for num_dec_layer, ( + loss_cls_i, + loss_bbox_i, + loss_iou_i, + ) in enumerate( + zip( + dn_losses_cls[:-1], dn_losses_bbox[:-1], dn_losses_iou[:-1] + ) + ): + loss_dict[f"d{num_dec_layer}.dn_loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.dn_loss_bbox"] = loss_bbox_i + loss_dict[f"d{num_dec_layer}.dn_loss_iou"] = loss_iou_i + + return loss_dict + + def _get_dn_targets_single( + self, + gt_bboxes: Tensor, + gt_labels: Tensor, + positive_maps: Tensor, + img_shape: tuple[int, int], + num_groups: int, + num_denoising_queries: int, + ) -> tuple: + """Get targets in denoising part for one image. + + Args: + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for one image. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + num_queries_each_group = int(num_denoising_queries / num_groups) + device = gt_bboxes.device + + if len(gt_labels) > 0: + t = torch.arange(len(gt_labels), dtype=torch.long, device=device) + t = t.unsqueeze(0).repeat(num_groups, 1) + pos_assigned_gt_inds = t.flatten() + pos_inds = torch.arange( + num_groups, dtype=torch.long, device=device + ) + pos_inds = pos_inds.unsqueeze(1) * num_queries_each_group + t + pos_inds = pos_inds.flatten() + else: + pos_inds = pos_assigned_gt_inds = gt_bboxes.new_tensor( + [], dtype=torch.long + ) + + neg_inds = pos_inds + num_queries_each_group // 2 + # label targets + # this change + labels = gt_bboxes.new_full( + (num_denoising_queries, self.max_text_len), 0, dtype=torch.float32 + ) + labels[pos_inds] = positive_maps[pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_denoising_queries) + + # bbox targets + bbox_targets = torch.zeros(num_denoising_queries, 4, device=device) + bbox_weights = torch.zeros(num_denoising_queries, 4, device=device) + bbox_weights[pos_inds] = 1.0 + + img_h, img_w = img_shape + + # DETR regress the relative position of boxes (cxcywh) in the image. + # Thus the learning target should be normalized by the image size, also + # the box format should be converted from defaultly x1y1x2y2 to cxcywh. + factor = gt_bboxes.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze( + 0 + ) + gt_bboxes_normalized = gt_bboxes / factor + gt_bboxes_targets = bbox_xyxy_to_cxcywh(gt_bboxes_normalized) + bbox_targets[pos_inds] = gt_bboxes_targets.repeat([num_groups, 1]) + + return ( + labels, + label_weights, + bbox_targets, + bbox_weights, + pos_inds, + neg_inds, + ) + + def get_dn_targets( + self, + boxes2d: list[Tensor], + boxes2d_classes: list[Tensor], + positive_maps: list[Tensor], + input_hw: list[tuple[int, int]], + dn_meta: dict[str, int], + ) -> tuple: + """Get targets in denoising part for a batch of images. + + Args: + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all images. + - bbox_targets_list (list[Tensor]): BBox targets for all images. + - bbox_weights_list (list[Tensor]): BBox weights for all images. + - num_total_pos (int): Number of positive samples in all images. + - num_total_neg (int): Number of negative samples in all images. + """ + ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + pos_inds_list, + neg_inds_list, + ) = multi_apply( + self._get_dn_targets_single, + boxes2d, + boxes2d_classes, + positive_maps, + input_hw, + num_groups=dn_meta["num_denoising_groups"], + num_denoising_queries=dn_meta["num_denoising_queries"], + ) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + + return ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + num_total_pos, + num_total_neg, + ) + + def _loss_dn_single( + self, + dn_cls_scores: Tensor, + dn_bbox_preds: Tensor, + boxes2d: list[Tensor], + boxes2d_classes: list[Tensor], + positive_maps: list[Tensor], + input_hw: list[tuple[int, int]], + text_token_mask: Tensor, + dn_meta, + ): + """Denoising loss for outputs from a single decoder layer. + + Args: + dn_cls_scores (Tensor): Classification scores of a single decoder + layer in denoising part, has shape (bs, num_denoising_queries, + cls_out_channels). + dn_bbox_preds (Tensor): Regression outputs of a single decoder + layer in denoising part. Each is a 4D-tensor with normalized + coordinate format (cx, cy, w, h) and has shape + (bs, num_denoising_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. + """ + ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + num_total_pos, + num_total_neg, + ) = self.get_dn_targets( + boxes2d, boxes2d_classes, positive_maps, input_hw, dn_meta + ) + + labels = torch.stack(labels_list, 0) + label_weights = torch.stack(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + + # Loss is not computed for the padded regions of the text. + assert text_token_mask.dim() == 2 + text_masks = text_token_mask.new_zeros( + (text_token_mask.size(0), self.max_text_len) + ) + text_masks[:, : text_token_mask.size(1)] = text_token_mask + text_mask = (text_masks > 0).unsqueeze(1) + text_mask = text_mask.repeat(1, dn_cls_scores.size(1), 1) + cls_scores = torch.masked_select(dn_cls_scores, text_mask).contiguous() + labels = torch.masked_select(labels, text_mask) + label_weights = label_weights[..., None].repeat( + 1, 1, text_mask.size(-1) + ) + label_weights = torch.masked_select(label_weights, text_mask) + + # classification loss + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = ( + num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight + ) + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor]) + ) + cls_avg_factor = max(cls_avg_factor, 1) + + if len(cls_scores) > 0: + loss_cls = self.cls_loss_weight * self.loss_cls( + cls_scores, + labels, + reducer=SumWeightedLoss( + weight=label_weights, avg_factor=cls_avg_factor + ), + ) + else: + loss_cls = torch.zeros( + 1, dtype=cls_scores.dtype, device=cls_scores.device + ) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_hw, bbox_pred in zip(input_hw, dn_bbox_preds): + img_h, img_w = img_hw + factor = ( + bbox_pred.new_tensor([img_w, img_h, img_w, img_h]) + .unsqueeze(0) + .repeat(bbox_pred.size(0), 1) + ) + factors.append(factor) + factors = torch.cat(factors) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = dn_bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + if bbox_targets.shape[0] == 0: + loss_bbox = bbox_preds.sum() + loss_iou = bbox_preds.sum() + return loss_cls, loss_bbox, loss_iou + + # regression L1 loss + loss_bbox = self.bbox_loss_weight * self.loss_bbox( + bbox_preds, + bbox_targets, + reducer=SumWeightedLoss( + weight=bbox_weights, avg_factor=num_total_pos + ), + ) + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.iou_loss_weight * self.loss_iou( + bboxes, + bboxes_gt, + reducer=SumWeightedLoss( + weight=bbox_weights.mean(-1), avg_factor=num_total_pos + ), + ) + + return loss_cls, loss_bbox, loss_iou + + def loss_dn( + self, + all_layers_denoising_cls_scores: Tensor, + all_layers_denoising_bbox_preds: Tensor, + boxes2d: list[Tensor], + boxes2d_classes: list[Tensor], + positive_maps: list[Tensor], + input_hw: list[tuple[int, int]], + text_token_mask: Tensor, + dn_meta: dict[str, int], + ): + """Calculate denoising loss. + + Args: + all_layers_denoising_cls_scores (Tensor): Classification scores of + all decoder layers in denoising part, has shape ( + num_decoder_layers, bs, num_denoising_queries, + cls_out_channels). + all_layers_denoising_bbox_preds (Tensor): Regression outputs of all + decoder layers in denoising part. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and has shape + (num_decoder_layers, bs, num_denoising_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. It will be used for split outputs of + denoising and matching parts and loss calculation. + + Returns: + Tuple[List[Tensor]]: The loss_dn_cls, loss_dn_bbox, and loss_dn_iou + of each decoder layers. + """ + return multi_apply( + self._loss_dn_single, + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + boxes2d=boxes2d, + boxes2d_classes=boxes2d_classes, + positive_maps=positive_maps, + input_hw=input_hw, + text_token_mask=text_token_mask, + dn_meta=dn_meta, + ) + + +# TODO: Move to DINO ops +def split_outputs( + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + dn_meta: dict[str, int] | None = None, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Split outputs of the denoising part and the matching part. + + For the total outputs of `num_queries_total` length, the former + `num_denoising_queries` outputs are from denoising queries, and + the rest `num_matching_queries` ones are from matching queries, + where `num_queries_total` is the sum of `num_denoising_queries` and + `num_matching_queries`. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, + num_queries_total, cls_out_channels). + all_layers_bbox_preds (Tensor): Regression outputs of all decoder + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and has shape (num_decoder_layers, bs, + num_queries_total, 4). + dn_meta (Dict[str, int]): The dictionary saves information about + group collation, including 'num_denoising_queries' and + 'num_denoising_groups'. + + Returns: + Tuple[Tensor]: a tuple containing the following outputs. + + - all_layers_matching_cls_scores (Tensor): Classification scores + of all decoder layers in matching part, has shape + (num_decoder_layers, bs, num_matching_queries, cls_out_channels). + - all_layers_matching_bbox_preds (Tensor): Regression outputs of + all decoder layers in matching part. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and has shape + (num_decoder_layers, bs, num_matching_queries, 4). + - all_layers_denoising_cls_scores (Tensor): Classification scores + of all decoder layers in denoising part, has shape + (num_decoder_layers, bs, num_denoising_queries, + cls_out_channels). + - all_layers_denoising_bbox_preds (Tensor): Regression outputs of + all decoder layers in denoising part. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and has shape + (num_decoder_layers, bs, num_denoising_queries, 4). + """ + # FIXME: Can dn_meta be None? + num_denoising_queries = dn_meta["num_denoising_queries"] + + if dn_meta is not None: + all_layers_denoising_cls_scores = all_layers_cls_scores[ + :, :, :num_denoising_queries, : + ] + all_layers_denoising_bbox_preds = all_layers_bbox_preds[ + :, :, :num_denoising_queries, : + ] + all_layers_matching_cls_scores = all_layers_cls_scores[ + :, :, num_denoising_queries:, : + ] + all_layers_matching_bbox_preds = all_layers_bbox_preds[ + :, :, num_denoising_queries:, : + ] + else: + all_layers_denoising_cls_scores = None + all_layers_denoising_bbox_preds = None + all_layers_matching_cls_scores = all_layers_cls_scores + all_layers_matching_bbox_preds = all_layers_bbox_preds + + return ( + all_layers_matching_cls_scores, + all_layers_matching_bbox_preds, + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + ) diff --git a/wilddet3d/loss/det3d_loss.py b/wilddet3d/loss/det3d_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..739d0b36e24847c9839a6bb967a1d984b8d40948 --- /dev/null +++ b/wilddet3d/loss/det3d_loss.py @@ -0,0 +1,490 @@ +"""3D-MOOD loss.""" + +from __future__ import annotations + +import torch +from torch import Tensor +from vis4d.common.distributed import reduce_mean +from vis4d.common.typing import ArgsType +from vis4d.op.loss.common import l1_loss +from vis4d.op.loss.reducer import SumWeightedLoss + +from wilddet3d.ops.box2d import bbox_cxcywh_to_xyxy +from wilddet3d.loss.det2d_loss import ( + Det2DLoss, + split_outputs, +) +from wilddet3d.ops.util import multi_apply + +from .coder import Det3DCoder + + +class Det3DLoss(Det2DLoss): + """Grounding DINO with 3D loss.""" + + def __init__( + self, + *args: ArgsType, + box_coder: Det3DCoder | None = None, + loss_center_weight: float = 1.0, + loss_depth_weight: float = 1.0, + loss_dim_weight: float = 1.0, + loss_rot_weight: float = 1.0, + loss_2d_scale: float = 1.0, + loss_3d_scale: float = 1.0, + **kwargs: ArgsType, + ): + """Init.""" + super().__init__(*args, **kwargs) + self.box_coder = box_coder or Det3DCoder() + + self.reg_dims = self.box_coder.reg_dims + + self.loss_center_weight = loss_center_weight + self.loss_depth_weight = loss_depth_weight + self.loss_dim_weight = loss_dim_weight + self.loss_rot_weight = loss_rot_weight + self.loss_2d_scale = loss_2d_scale + self.loss_3d_scale = loss_3d_scale + + def get_targets_3d( + self, + cls_scores_list: list[Tensor], + bbox_preds_list: list[Tensor], + bbox_preds_3d_list: list[Tensor], + input_hw: list[tuple[int, int]], + batch_gt_boxes: list[Tensor], + batch_gt_boxes_3d: list[Tensor], + batch_gt_boxes_classes: list[Tensor], + batch_gt_intrinsics: list[Tensor], + positive_maps: list[Tensor], + text_token_mask: Tensor, + ) -> tuple: + """Compute regression and classification targets for a batch image.""" + ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + bbox_targets_3d_list, + bbox_weights_3d_list, + pos_inds_list, + neg_inds_list, + ) = multi_apply( + self._get_targets_3d_single, + cls_scores_list, + bbox_preds_list, + bbox_preds_3d_list, + input_hw, + batch_gt_boxes, + batch_gt_boxes_3d, + batch_gt_boxes_classes, + batch_gt_intrinsics, + positive_maps, + text_token_mask, + ) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + + return ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + bbox_targets_3d_list, + bbox_weights_3d_list, + num_total_pos, + num_total_neg, + ) + + def _get_targets_3d_single( + self, + cls_score: Tensor, + bbox_pred: Tensor, + bbox_pred_3d: Tensor, + input_hw: tuple[int, int], + gt_boxes: Tensor, + gt_boxes_3d: Tensor, + gt_classes: Tensor, + gt_intrinsics: Tensor, + positive_map: Tensor, + text_token_mask: Tensor, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Compute regression and classification targets for one image.""" + # 2D Target + with torch.no_grad(): + ( + labels, + label_weights, + bbox_targets, + bbox_weights, + pos_pred_boxes2d, + pos_inds, + neg_inds, + pos_assigned_gt_inds, + ) = self._get_targets_2d_single( + cls_score, + bbox_pred, + input_hw, + gt_boxes, + gt_classes, + positive_map, + text_token_mask, + ) + + # 3D Target + pos_gt_boxes3d = gt_boxes_3d[pos_assigned_gt_inds.long(), :] + + pos_gt_bboxes_3d, pos_gt_bboxes_3d_weights = self.box_coder.encode( + pos_pred_boxes2d, pos_gt_boxes3d, gt_intrinsics + ) + + bbox_targets_3d = torch.zeros_like(bbox_pred_3d) + bbox_targets_3d[pos_inds] = pos_gt_bboxes_3d + + bbox_weights_3d = torch.zeros_like(bbox_pred_3d) + bbox_weights_3d[pos_inds] = pos_gt_bboxes_3d_weights + + return ( + labels, + label_weights, + bbox_targets, + bbox_weights, + bbox_targets_3d, + bbox_weights_3d, + pos_inds, + neg_inds, + ) + + def loss_3d_by_feat_single( + self, + cls_scores: Tensor, + bbox_preds: Tensor, + bbox_3d_preds: Tensor, + text_token_mask: Tensor, + input_hw: list[tuple[int, int]], + batch_gt_boxes: list[Tensor], + batch_gt_boxes_3d: list[Tensor], + batch_gt_boxes_classes: list[Tensor], + batch_gt_intrinsics: list[Tensor], + positive_maps: list[Tensor], + ): + """Loss function for outputs from a single decoder layer.""" + num_imgs = cls_scores.size(0) + + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + bbox_preds_3d_list = [bbox_3d_preds[i] for i in range(num_imgs)] + + ( + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + bbox_targets_3d_list, + bbox_weights_3d_list, + num_total_pos, + num_total_neg, + ) = self.get_targets_3d( + cls_scores_list, + bbox_preds_list, + bbox_preds_3d_list, + input_hw, + batch_gt_boxes, + batch_gt_boxes_3d, + batch_gt_boxes_classes, + batch_gt_intrinsics, + positive_maps, + text_token_mask, + ) + + labels = torch.stack(labels_list, 0) + label_weights = torch.stack(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_targets_3d = torch.cat(bbox_targets_3d_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + bbox_weights_3d = torch.cat(bbox_weights_3d_list, 0) + + # Loss is not computed for the padded regions of the text. + assert text_token_mask.dim() == 2 + text_masks = text_token_mask.new_zeros( + (text_token_mask.size(0), self.max_text_len) + ) + text_masks[:, : text_token_mask.size(1)] = text_token_mask + text_mask = (text_masks > 0).unsqueeze(1) + text_mask = text_mask.repeat(1, cls_scores.size(1), 1) + cls_scores = torch.masked_select(cls_scores, text_mask).contiguous() + + labels = torch.masked_select(labels, text_mask) + label_weights = label_weights[..., None].repeat( + 1, 1, text_mask.size(-1) + ) + label_weights = torch.masked_select(label_weights, text_mask) + + # classification loss + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = ( + num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight + ) + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor]) + ) + cls_avg_factor = max(cls_avg_factor, 1) + + loss_cls = self.loss_2d_scale * self.cls_loss_weight * self.loss_cls( + cls_scores, + labels, + reducer=SumWeightedLoss( + weight=label_weights, avg_factor=cls_avg_factor + ), + ) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_hw, bbox_pred in zip(input_hw, bbox_preds): + img_h, img_w = img_hw + factor = ( + bbox_pred.new_tensor([img_w, img_h, img_w, img_h]) + .unsqueeze(0) + .repeat(bbox_pred.size(0), 1) + ) + factors.append(factor) + factors = torch.cat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression L1 loss (2D) + loss_bbox = self.loss_2d_scale * self.bbox_loss_weight * self.loss_bbox( + bbox_preds, + bbox_targets, + reducer=SumWeightedLoss( + weight=bbox_weights, avg_factor=num_total_pos + ), + ) + + # regression IoU loss (2D) + loss_iou = self.loss_2d_scale * self.iou_loss_weight * self.loss_iou( + bboxes, + bboxes_gt, + reducer=SumWeightedLoss( + weight=bbox_weights.mean(-1), avg_factor=num_total_pos + ), + ) + + # 3D Loss + bbox_3d_preds = bbox_3d_preds.reshape(-1, self.reg_dims) + + # Delta 2D center Loss (3D) + loss_cen = self.loss_3d_scale * self.loss_center_weight * l1_loss( + bbox_3d_preds[:, :2], + bbox_targets_3d[:, :2], + reducer=SumWeightedLoss( + weight=bbox_weights_3d[:, :2], avg_factor=num_total_pos + ), + ) + + # Depth Loss (3D) + loss_depth = self.loss_3d_scale * self.loss_depth_weight * l1_loss( + bbox_3d_preds[:, 2], + bbox_targets_3d[:, 2], + reducer=SumWeightedLoss( + weight=bbox_weights_3d[:, 2], avg_factor=num_total_pos + ), + ) + + # Dimension Loss (3D) + loss_dim = self.loss_3d_scale * self.loss_dim_weight * l1_loss( + bbox_3d_preds[:, 3:6], + bbox_targets_3d[:, 3:6], + reducer=SumWeightedLoss( + weight=bbox_weights_3d[:, 3:6], avg_factor=num_total_pos + ), + ) + + # Rotation Loss (3D) + loss_rot = self.loss_3d_scale * self.loss_rot_weight * l1_loss( + bbox_3d_preds[:, 6:], + bbox_targets_3d[:, 6:], + reducer=SumWeightedLoss( + weight=bbox_weights_3d[:, 6:], avg_factor=num_total_pos + ), + ) + + return ( + loss_cls, + loss_bbox, + loss_iou, + loss_cen, + loss_depth, + loss_dim, + loss_rot, + ) + + def forward( + self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + all_layers_bbox_3d_preds: Tensor, + text_token_mask: Tensor, + enc_cls_scores: Tensor, + enc_bbox_preds: Tensor, + enc_outputs_3d: Tensor, + dn_meta: dict[str, int], + positive_maps: list[Tensor], + input_hw: list[tuple[int, int]], + batch_gt_boxes: list[Tensor], + batch_gt_boxes_3d: list[Tensor], + batch_gt_boxes_classes: list[Tensor], + batch_gt_intrinsics: list[Tensor], + ) -> dict[str, Tensor]: + """Forward pass of the 3D Grounding DINO loss.""" + ( + all_layers_matching_cls_scores, + all_layers_matching_bbox_preds, + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + ) = split_outputs( + all_layers_cls_scores, all_layers_bbox_preds, dn_meta + ) + + ( + losses_cls, + losses_bbox, + losses_iou, + losses_cen, + losses_depth, + losses_dim, + losses_rot, + ) = multi_apply( + self.loss_3d_by_feat_single, + all_layers_matching_cls_scores, + all_layers_matching_bbox_preds, + all_layers_bbox_3d_preds, + text_token_mask=text_token_mask, + input_hw=input_hw, + batch_gt_boxes=batch_gt_boxes, + batch_gt_boxes_3d=batch_gt_boxes_3d, + batch_gt_boxes_classes=batch_gt_boxes_classes, + batch_gt_intrinsics=batch_gt_intrinsics, + positive_maps=positive_maps, + ) + + loss_dict = dict() + + # loss from the last decoder layer + loss_dict["loss_cls"] = losses_cls[-1] + loss_dict["loss_bbox"] = losses_bbox[-1] + loss_dict["loss_iou"] = losses_iou[-1] + loss_dict["loss_delta_2d"] = losses_cen[-1] + loss_dict["loss_depth"] = losses_depth[-1] + loss_dict["loss_dim"] = losses_dim[-1] + loss_dict["loss_rot"] = losses_rot[-1] + + # loss from other decoder layers + for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i) in enumerate( + zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1]) + ): + loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.loss_bbox"] = loss_bbox_i + loss_dict[f"d{num_dec_layer}.loss_iou"] = loss_iou_i + loss_dict[f"d{num_dec_layer}.loss_delta_2d"] = losses_cen[ + num_dec_layer + ] + loss_dict[f"d{num_dec_layer}.loss_depth"] = losses_depth[ + num_dec_layer + ] + loss_dict[f"d{num_dec_layer}.loss_dim"] = losses_dim[num_dec_layer] + loss_dict[f"d{num_dec_layer}.loss_rot"] = losses_rot[num_dec_layer] + + # loss of proposal generated from encode feature map. + if enc_cls_scores is not None: + if enc_outputs_3d is None: + # NOTE The enc_loss calculation of the DINO is + # different from that of Deformable DETR. + enc_loss_cls, enc_losses_bbox, enc_losses_iou = ( + self.loss_by_feat_single( + enc_cls_scores, + enc_bbox_preds, + text_token_mask=text_token_mask, + input_hw=input_hw, + batch_gt_boxes=batch_gt_boxes, + batch_gt_boxes_classes=batch_gt_boxes_classes, + positive_maps=positive_maps, + ) + ) + loss_dict["enc_loss_cls"] = enc_loss_cls + loss_dict["enc_loss_bbox"] = enc_losses_bbox + loss_dict["enc_loss_iou"] = enc_losses_iou + else: + ( + enc_loss_cls, + enc_losses_bbox, + enc_losses_iou, + enc_losses_cen, + enc_losses_depth, + enc_losses_dim, + enc_losses_rot, + ) = self.loss_3d_by_feat_single( + enc_cls_scores, + enc_bbox_preds, + enc_outputs_3d, + text_token_mask=text_token_mask, + input_hw=input_hw, + batch_gt_boxes=batch_gt_boxes, + batch_gt_boxes_3d=batch_gt_boxes_3d, + batch_gt_boxes_classes=batch_gt_boxes_classes, + batch_gt_intrinsics=batch_gt_intrinsics, + positive_maps=positive_maps, + ) + loss_dict["enc_loss_cls"] = enc_loss_cls + loss_dict["enc_loss_bbox"] = enc_losses_bbox + loss_dict["enc_loss_iou"] = enc_losses_iou + loss_dict["enc_loss_delta_2d"] = enc_losses_cen + loss_dict["enc_loss_depth"] = enc_losses_depth + loss_dict["enc_loss_dim"] = enc_losses_dim + loss_dict["enc_loss_rot"] = enc_losses_rot + + if all_layers_denoising_cls_scores is not None: + # calculate denoising loss from all decoder layers + dn_losses_cls, dn_losses_bbox, dn_losses_iou = self.loss_dn( + all_layers_denoising_cls_scores, + all_layers_denoising_bbox_preds, + boxes2d=batch_gt_boxes, + boxes2d_classes=batch_gt_boxes_classes, + positive_maps=positive_maps, + input_hw=input_hw, + text_token_mask=text_token_mask, + dn_meta=dn_meta, + ) + + # collate denoising loss + loss_dict["dn_loss_cls"] = dn_losses_cls[-1] + loss_dict["dn_loss_bbox"] = dn_losses_bbox[-1] + loss_dict["dn_loss_iou"] = dn_losses_iou[-1] + + for num_dec_layer, ( + loss_cls_i, + loss_bbox_i, + loss_iou_i, + ) in enumerate( + zip( + dn_losses_cls[:-1], dn_losses_bbox[:-1], dn_losses_iou[:-1] + ) + ): + loss_dict[f"d{num_dec_layer}.dn_loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.dn_loss_bbox"] = loss_bbox_i + loss_dict[f"d{num_dec_layer}.dn_loss_iou"] = loss_iou_i + + return loss_dict diff --git a/wilddet3d/loss/focal_loss.py b/wilddet3d/loss/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b1cdb12ad534780d2b33bac40f3159c78322e68b --- /dev/null +++ b/wilddet3d/loss/focal_loss.py @@ -0,0 +1,62 @@ +"""Focal Loss.""" + +from __future__ import annotations + +import torch.nn.functional as F +from torch import Tensor +from torchvision.ops import sigmoid_focal_loss +from vis4d.op.loss.base import Loss +from vis4d.op.loss.reducer import LossReducer, mean_loss + + +class FocalLoss(Loss): + """Focal loss `_.""" + + def __init__( + self, + alpha: float = 0.25, + gamma: float = 2.0, + reducer: LossReducer = mean_loss, + ) -> None: + """Creates an instance of the class. + + Args: + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + reducer (LossReducer, optional): Reducer for the loss function. + Defaults to mean_loss. + """ + super().__init__(reducer) + self.alpha = alpha + self.gamma = gamma + + def forward( + self, pred: Tensor, target: Tensor, reducer: LossReducer | None = None + ) -> Tensor: + """Forward function. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning label of the prediction. + + Returns: + Tensor: The calculated loss. + """ + # this means that target is not in One-Hot form. + if pred.dim() != target.dim(): + num_classes = pred.size(1) + target = F.one_hot(target, num_classes=num_classes + 1).float() + target = target[:, :num_classes] + + reducer = reducer or self.reducer + + focal_loss = sigmoid_focal_loss( + pred, + target, + alpha=self.alpha, + gamma=self.gamma, + ) + + return reducer(focal_loss) diff --git a/wilddet3d/loss/geom_loss_aggregator.py b/wilddet3d/loss/geom_loss_aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..a0459b01730c973841195726e1a7250747bce2e2 --- /dev/null +++ b/wilddet3d/loss/geom_loss_aggregator.py @@ -0,0 +1,55 @@ +"""Geometry Loss Aggregator. + +This module provides a loss class that aggregates geometry losses from +the model output (geom_losses dict from GeometryBackend). +""" + +from __future__ import annotations + +from torch import Tensor +from vis4d.common.typing import ArgsType +from vis4d.op.loss.base import Loss + + +class GeomLossAggregator(Loss): + """Aggregates geometry losses from model output. + + This loss class takes the geom_losses dict from the model output + and returns the sum of all losses. Each individual loss is also + logged separately. + + Args: + weight: Global weight multiplier for all geometry losses. + """ + + def __init__( + self, + *args: ArgsType, + weight: float = 1.0, + **kwargs: ArgsType, + ) -> None: + """Initialize the GeomLossAggregator.""" + super().__init__(*args, **kwargs) + self.weight = weight + + def forward( + self, + geom_losses: dict[str, Tensor] | None, + ) -> dict[str, Tensor]: + """Forward function. + + Args: + geom_losses: Dictionary of geometry losses from the model. + + Returns: + Dictionary of weighted losses. + """ + if geom_losses is None or len(geom_losses) == 0: + return {} + + weighted_losses = {} + for name, loss in geom_losses.items(): + weighted_losses[f"geom_{name}"] = loss * self.weight + + return weighted_losses + diff --git a/wilddet3d/loss/iou_loss.py b/wilddet3d/loss/iou_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f5d93c7efa7ff98cb6facb4b83076d79e7b7e4ae --- /dev/null +++ b/wilddet3d/loss/iou_loss.py @@ -0,0 +1,81 @@ +"""IoU Loss.""" + +import torch +from torch import Tensor +from vis4d.op.loss.base import Loss +from vis4d.op.loss.reducer import LossReducer, mean_loss + +from wilddet3d.ops.box2d import bbox_overlaps + + +def giou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor: + r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding + Box Regression `_. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): Corresponding gt bboxes, shape (n, 4). + eps (float): Epsilon to avoid log(0). + + Return: + Tensor: Loss tensor. + """ + # avoid fp16 overflow + if pred.dtype == torch.float16: + fp16 = True + pred = pred.to(torch.float32) + else: + fp16 = False + + gious = bbox_overlaps(pred, target, mode="giou", is_aligned=True, eps=eps) + + if fp16: + gious = gious.to(torch.float16) + + loss = 1 - gious + return loss + + +class GIoULoss(Loss): + r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding + Box Regression `_. + """ + + def __init__( + self, + eps: float = 1e-6, + reducer: LossReducer = mean_loss, + ) -> None: + super().__init__(reducer) + self.eps = eps + + def forward( + self, + pred: Tensor, + target: Tensor, + reducer: LossReducer | None = None, + ) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2), + shape (n, 4). + target (Tensor): The learning target of the prediction, + shape (n, 4). + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + reduction_override (Optional[str], optional): The reduction method + used to override the original reduction method of the loss. + Defaults to None. Options are "none", "mean" and "sum". + + Returns: + Tensor: Loss tensor. + """ + reducer = reducer or self.reducer + + loss = giou_loss(pred, target, eps=self.eps) + + return reducer(loss) diff --git a/wilddet3d/loss/silog_loss.py b/wilddet3d/loss/silog_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..de46146513fda9269994dffe007a4d4aa1bd6fbd --- /dev/null +++ b/wilddet3d/loss/silog_loss.py @@ -0,0 +1,59 @@ +"""SILog loss for depth estimation.""" + +from __future__ import annotations + +import torch +from torch import Tensor +from vis4d.common.typing import ArgsType +from vis4d.op.loss.base import Loss + +from .util import masked_mean_var + + +class SILogLoss(Loss): + """SILogLoss.""" + + def __init__( + self, + *args: ArgsType, + scale_pred_weight: float = 0.15, + eps: float = 1e-5, + min_depth: float = 0.0, + **kwargs: ArgsType, + ) -> None: + """Init.""" + super().__init__(*args, **kwargs) + self.scale_pred_weight = scale_pred_weight + self.eps = eps + self.min_depth = min_depth + + def forward( + self, depths: Tensor, target_depths: Tensor, mask: Tensor | None = None + ) -> Tensor: + """Forward function. + + Args: + depths (Tensor): Predicted depth. Shape: (B, H, W) + target_depths (Tensor): Target depth. Shape: (B, H, W) + mask (Tensor | None): Mask. Shape: (B, H, W) + """ + if mask is None: + mask = target_depths > self.min_depth + else: + mask = mask.to(torch.bool) + mask = torch.logical_and(mask, target_depths > self.min_depth) + + log_depths = torch.log(depths.clamp(min=self.eps)) + log_target_depths = torch.log(target_depths.clamp(min=self.eps)) + + log_error = log_depths - log_target_depths + + mean_error, var_error = masked_mean_var(log_error, mask=mask) + + scale_error = mean_error**2 + + loss = var_error + self.scale_pred_weight * scale_error + + out_loss = torch.sqrt(loss.clamp(min=self.eps)) + + return out_loss.mean() diff --git a/wilddet3d/loss/util.py b/wilddet3d/loss/util.py new file mode 100644 index 0000000000000000000000000000000000000000..0f01a93fdbcedfab1396af9ec44f4c6cabb75751 --- /dev/null +++ b/wilddet3d/loss/util.py @@ -0,0 +1,35 @@ +"""Loss util.""" + +from __future__ import annotations + +import torch +from torch import Tensor + + +def masked_mean_var(error: Tensor, mask: Tensor | None = None) -> Tensor: + """Compute mean and variance of error with mask.""" + if mask is None: + return error.mean(dim=[-2, -1], keepdim=True), error.var( + dim=[-2, -1], keepdim=True + ) + mask = mask.float() + mask_sum = torch.sum(mask, dim=[-2, -1], keepdim=True) + mask_mean = torch.sum( + error * mask, dim=[-2, -1], keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + mask_var = torch.sum( + mask * (error - mask_mean) ** 2, dim=[-2, -1], keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + return mask_mean.squeeze([-2, -1]), mask_var.squeeze([-2, -1]) + + +def masked_mean(data: Tensor, mask: Tensor | None): + """Compute mean of data with mask.""" + if mask is None: + return data.mean(dim=[-2, -1], keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=[-2, -1], keepdim=True) + mask_mean = torch.sum( + data * mask, dim=[-2, -1], keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + return mask_mean diff --git a/wilddet3d/loss/wilddet3d_loss.py b/wilddet3d/loss/wilddet3d_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d5947ab9822f8e66c32bd6d89a5cc6ea075e82d1 --- /dev/null +++ b/wilddet3d/loss/wilddet3d_loss.py @@ -0,0 +1,1256 @@ +"""WildDet3D Loss Module. + +This module implements the loss function for WildDet3D, combining: +1. SAM3-style 2D losses (IABCEMdetr for classification, L1+GIoU for boxes) +2. 3D-MOOD-style 3D losses (delta_center, depth, dimensions, rotation) + +Key Design Decisions: +- Uses SAM3's Hungarian matcher for assignment (already computed in model) +- Follows SAM3's loss normalization (global/local/none) +- Adds 3D regression losses on top of 2D losses +- Supports deep supervision on auxiliary decoder outputs +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal + +import numpy as np +import torch +from torch import Tensor, nn +import torch.nn.functional as F + +from vis4d.common.distributed import reduce_mean +from vis4d.op.loss.common import l1_loss +from vis4d.op.loss.reducer import SumWeightedLoss + +from wilddet3d.head.coder_3d import Det3DCoder +from sam3.model.box_ops import fast_diag_box_iou, fast_diag_generalized_box_iou +from sam3.train.matcher import BinaryOneToManyMatcher +from sam3.train.loss.loss_fns import ( + IABCEMdetr, + Boxes as SAM3Boxes, + sigmoid_focal_loss, +) + + +def _packed_to_padded(boxes_packed: Tensor, num_boxes: Tensor, fill_value: float = 0.0) -> Tensor: + """Convert packed tensor to padded tensor. + + This function converts a packed (concatenated) tensor of bounding boxes + to a batch-wise padded tensor, following SAM3's collator implementation. + + Args: + boxes_packed: Packed boxes tensor of shape (N_total, 4) where + N_total = N_1 + N_2 + ... + N_B + num_boxes: Number of boxes per image, shape (B,) + fill_value: Value to use for padding (default: 0.0) + + Returns: + Padded boxes tensor of shape (B, max_N, 4) where max_N = max(num_boxes) + + Example: + >>> boxes = torch.tensor([[1,2,3,4], [5,6,7,8], [9,10,11,12]]) + >>> num_boxes = torch.tensor([1, 2]) + >>> padded = _packed_to_padded(boxes, num_boxes) + >>> padded.shape + torch.Size([2, 2, 4]) + """ + B = num_boxes.shape[0] + Ns = num_boxes.tolist() + max_N = max(Ns) + + # Create padded tensor + boxes_padded = boxes_packed.new_full((B, max_N, *boxes_packed.shape[1:]), fill_value) + + # Fill in actual boxes + prev_idx = 0 + for i in range(B): + next_idx = prev_idx + Ns[i] + boxes_padded[i, :Ns[i]] = boxes_packed[prev_idx:next_idx] + prev_idx = next_idx + + return boxes_padded + + +@dataclass +class WildDet3DLossConfig: + """Configuration for WildDet3D loss. + + Follows SAM3's loss configuration style with additional 3D loss weights. + """ + # ========== Global Scale Factors ========== + # These allow adjusting the balance between 2D, 3D, and geometry losses + # Default 1.0, can be adjusted in training config to tune 2D:3D:Geom ratio + loss_2d_scale: float = 1.0 # Scale for 2D losses (cls, bbox, giou) + loss_3d_scale: float = 1.0 # Scale for 3D losses (delta, depth, dim, rot) + loss_geom_scale: float = 10.0 # Scale for geometry backend losses (SILog, SSI, camera angles) + + # ========== O2M (One-to-Many) Matcher Configuration ========== + # Note: O2O matcher is configured in wilddet3d.py (self.sam3.matcher) + use_o2m: bool = True # Enable O2M matching + o2m_loss_clip: float = 150.0 # Clip O2M loss to prevent gradient explosion + o2m_alpha: float = 0.3 # Alpha for O2M cost computation + o2m_threshold: float = 0.4 # IoU threshold for O2M matching + o2m_topk: int = 4 # Top-k predictions per GT (SAM3 original: topk: 4) + o2m_loss_weight: float = 2.0 # Weight for O2M loss (SAM3 original: o2m_weight: 2.0) + + # ========== 2D Loss Weights (SAM3 style) ========== + # Classification loss (IABCEMdetr style) + loss_cls_weight: float = 20.0 # SAM3 original + pos_weight: float = 5.0 # SAM3 original (was incorrectly 10.0) + gamma: float = 2.0 # SAM3 original focal (was incorrectly 0.0) + alpha: float = 0.25 # IoU-aware alpha + + # IABCEMdetr advanced features + use_weak_loss: bool = False # Enable weak supervision (SAM3 original: weak_loss: False) + weak_loss_weight: float = 1.0 # Weight for weak loss (only used if use_weak_loss=True) + use_presence: bool = True # Enable presence loss (per-category presence detection) + presence_loss_weight: float = 20.0 # Weight for presence loss (SAM3 original: presence_weight: 20.0) + presence_alpha: float = 0.5 # SAM3 original presence focal loss alpha + presence_gamma: float = 0.0 # SAM3 original (gamma=0 = plain BCE, no focal weighting) + + # Box regression loss + loss_bbox_weight: float = 5.0 # L1 loss weight + loss_giou_weight: float = 2.0 # GIoU loss weight + + # ========== 3D Loss Weights (3D-MOOD style) ========== + loss_delta_2d_weight: float = 1.0 # Delta 2D center + loss_depth_weight: float = 1.0 # Log depth + loss_dim_weight: float = 1.0 # Log dimensions + loss_rot_weight: float = 1.0 # 6D rotation + + # ========== Geometry Backend Loss Weights ========== + loss_silog_weight: float = 1.0 # SILog depth loss + loss_phi_weight: float = 0.1 # Phi angle loss + loss_theta_weight: float = 0.1 # Theta angle loss + loss_opt_ssi_weight: float = 0.5 # SSI loss weight (UniDepthV2) + + # ========== Normalization ========== + normalization: Literal["global", "local", "none"] = "global" + + # ========== Auxiliary Loss ========== + aux_loss_weight: float = 1.0 # Weight for auxiliary decoder outputs + + # ========== Mask Loss (optional) ========== + loss_mask_weight: float = 0.0 # Set > 0 to enable mask loss + loss_dice_weight: float = 0.0 # Set > 0 to enable dice loss + + # ========== 3D Confidence Head ========== + # Positive: soft target = quality (iou_3d + depth). Negative: push to 0. + # Inference: final_score = 2d_score + conf_3d_inference_weight * 3d_score + use_3d_conf: bool = False # Enable 3D confidence head loss + loss_3d_conf_weight: float = 20.0 # Weight for 3D confidence loss (same as 2D loss_cls_weight) + conf_depth_weight: float = 0.7 # Weight for depth quality in quality target + conf_iou_3d_weight: float = 0.3 # Weight for 3D IoU in quality target + + # ========== Ignore Box Negative Loss Suppression ========== + # Suppress negative classification loss for predictions that overlap + # with ignore-annotated objects (truncated, occluded, etc.). + # This aligns training with eval, where such detections are neutral. + use_ignore_suppress: bool = False + ignore_iou_threshold: float = 0.5 # 2D IoU threshold for suppression + + +class WildDet3DLoss(nn.Module): + """Loss function for WildDet3D. + + Combines SAM3-style 2D losses with 3D-MOOD-style 3D losses. + + Loss Components: + 1. Classification: IABCEMdetr (IoU-aware BCE with soft targets) + 2. 2D Box: L1 + GIoU + 3. 3D Box: L1 for (delta_center, log_depth, log_dims, rot_6d) + 4. Geometry: SILog depth + phi/theta angles (from geometry backend) + """ + + def __init__( + self, + config: WildDet3DLossConfig | None = None, + box_coder: Det3DCoder | None = None, + ) -> None: + """Initialize WildDet3D loss. + + Args: + config: Loss configuration + box_coder: 3D box encoder/decoder for target encoding + """ + super().__init__() + self.config = config or WildDet3DLossConfig() + self.box_coder = box_coder or Det3DCoder() + self.reg_dims = self.box_coder.reg_dims + + # SAM3's 2D loss classes (directly imported from sam3.train.loss.loss_fns) + # weak_loss=False follows SAM3's own training configs — all unmatched + # predictions receive negative loss regardless of is_exhaustive. + self.cls_loss = IABCEMdetr( + pos_weight=self.config.pos_weight, + gamma=self.config.gamma, + alpha=self.config.alpha, + weak_loss=False, + use_presence=self.config.use_presence, + presence_alpha=self.config.presence_alpha, + presence_gamma=self.config.presence_gamma, + ) + self.box_loss = SAM3Boxes() + + # O2M matcher for DAC one-to-many loss + if self.config.use_o2m: + self.o2m_matcher = BinaryOneToManyMatcher( + alpha=self.config.o2m_alpha, + threshold=self.config.o2m_threshold, + topk=self.config.o2m_topk, + ) + else: + self.o2m_matcher = None + + def _compute_ignore_neg_mask( + self, + pred_boxes: Tensor, + ignore_boxes: Tensor, + num_ignores: Tensor, + threshold: float = 0.5, + ) -> Tensor: + """Compute mask for predictions overlapping ignore boxes. + + Args: + pred_boxes: (B, S, 4) normalized xyxy predicted boxes. + ignore_boxes: (B, max_ignore, 4) normalized xyxy ignore boxes. + num_ignores: (B,) number of valid ignore boxes per prompt. + threshold: 2D IoU threshold above which to suppress. + + Returns: + mask: (B, S) float. 1.0 = suppress negative loss, 0.0 = normal. + """ + import torchvision.ops + + B, S, _ = pred_boxes.shape + device = pred_boxes.device + mask = torch.zeros(B, S, device=device) + + for b in range(B): + n_ign = num_ignores[b].item() + if n_ign == 0: + continue + iou = torchvision.ops.box_iou( + pred_boxes[b], + ignore_boxes[b, :n_ign], + ) # (S, n_ign) + mask[b] = (iou.max(dim=1).values > threshold).float() + + return mask + + def _build_targets_from_batch( + self, batch: "WildDet3DInput" + ) -> dict[str, Tensor]: + """Build targets dict from WildDet3DInput. + + WildDet3D uses per-category queries with multi-instance targets. + The collator produces: + - gt_boxes2d: (N_prompts, max_gt, 4) - multiple GTs per query + - gt_boxes3d: (N_prompts, max_gt, 12) - multiple GTs per query (if available) + - num_gts: (N_prompts,) - number of valid GTs per query (can be > 1) + + We convert this to the packed format expected by loss computation. + + Args: + batch: WildDet3DInput containing GT boxes + + Returns: + targets dict with: + - boxes_xyxy: (N_total, 4) GT boxes in xyxy format (packed) + - boxes_3d: (N_total, 12) 3D GT boxes (packed) + - num_boxes: (N_prompts,) number of GTs per query + - intrinsics: (N_prompts, 3, 3) camera intrinsics per prompt + """ + device = batch.images.device + N_prompts = batch.img_ids.shape[0] + + # Extract GT from batch + gt_boxes2d = batch.gt_boxes2d # (N_prompts, max_gt, 4) or (N_prompts, 4) + gt_boxes3d = batch.gt_boxes3d # (N_prompts, max_gt, 12) or None + num_gts = batch.num_gts # (N_prompts,) number of valid GTs per query + + if gt_boxes2d is None: + # No GT available + return { + "boxes_xyxy": torch.zeros(0, 4, device=device), + "boxes_3d": torch.zeros(0, 12, device=device), + "classes": torch.zeros(0, dtype=torch.long, device=device), + "num_boxes": torch.zeros(N_prompts, dtype=torch.long, device=device), + "intrinsics": batch.intrinsics[batch.img_ids], + } + + # Handle both old (N_prompts, 4) and new (N_prompts, max_gt, 4) formats + if gt_boxes2d.dim() == 2: + # Old format: (N_prompts, 4) - single GT per prompt + boxes_xyxy = gt_boxes2d + if num_gts is None: + num_gts = torch.ones(N_prompts, dtype=torch.long, device=device) + + if gt_boxes3d is not None and gt_boxes3d.dim() == 2: + boxes_3d = gt_boxes3d + else: + boxes_3d = torch.zeros(N_prompts, 12, device=device) + else: + # New format: (N_prompts, max_gt, 4) - multi-instance targets + # Pack valid boxes into a flat tensor + if num_gts is None: + # Fallback: assume all boxes are valid + num_gts = torch.tensor([gt_boxes2d.shape[1]] * N_prompts, dtype=torch.long, device=device) + + # Pack boxes into (N_total, 4) + boxes_list = [] + boxes_3d_list = [] + for i in range(N_prompts): + n_gt = num_gts[i].item() + boxes_list.append(gt_boxes2d[i, :n_gt]) # (n_gt, 4) + if gt_boxes3d is not None: + boxes_3d_list.append(gt_boxes3d[i, :n_gt]) # (n_gt, 12) + + if boxes_list: + boxes_xyxy = torch.cat(boxes_list, dim=0) # (N_total, 4) + else: + boxes_xyxy = torch.zeros(0, 4, device=device) + + if boxes_3d_list: + boxes_3d = torch.cat(boxes_3d_list, dim=0) # (N_total, 12) + else: + box3d_dim = gt_boxes3d.shape[-1] if gt_boxes3d is not None else 12 + boxes_3d = torch.zeros(boxes_xyxy.shape[0], box3d_dim, device=device) + + # SAM3 uses binary detection (all targets are class 1) + N_total = boxes_xyxy.shape[0] + classes = torch.ones(N_total, dtype=torch.long, device=device) + + # Get per-prompt intrinsics + intrinsics = batch.intrinsics[batch.img_ids] # (N_prompts, 3, 3) + + # SAM3's IABCEMdetr and Boxes loss classes need additional formats: + # - boxes (cxcywh packed) for L1 loss + # - boxes_padded (cxcywh padded) for presence keep_loss + # - object_ids_padded for presence keep_loss + # - is_exhaustive for weak loss masking + boxes_cxcywh = self._xyxy_to_cxcywh(boxes_xyxy) + + # Padded format (B, max_N, 4) for presence loss keep_loss computation + boxes_padded = _packed_to_padded(boxes_cxcywh, num_gts) + max_N = boxes_padded.shape[1] + + # Object IDs: sequential within each prompt's targets + object_ids_padded = torch.full( + (N_prompts, max_N), -1, dtype=torch.long, device=device + ) + offset = 0 + for i in range(N_prompts): + n = int(num_gts[i].item()) + if n > 0: + object_ids_padded[i, :n] = torch.arange( + offset, offset + n, device=device + ) + offset += n + + # is_exhaustive: multi-target queries are exhaustive, single-target are not + # query_types: 0=TEXT, 1=VISUAL, 3=VISUAL+LABEL → exhaustive (True) + # query_types: 2=GEOMETRY, 4=GEOMETRY+LABEL → not exhaustive (False) + if batch.query_types is not None: + qt = batch.query_types.to(device) + is_exhaustive = (qt == 0) | (qt == 1) | (qt == 3) + else: + is_exhaustive = torch.ones(N_prompts, dtype=torch.bool, device=device) + + return { + "boxes_xyxy": boxes_xyxy, + "boxes": boxes_cxcywh, + "boxes_padded": boxes_padded, + "boxes_3d": boxes_3d, + "classes": classes, + "num_boxes": num_gts, + "intrinsics": intrinsics, + "object_ids_padded": object_ids_padded, + "is_exhaustive": is_exhaustive, + } + + def forward( + self, + out: "WildDet3DOutput", + batch: "WildDet3DInput", + ) -> dict[str, Tensor]: + """Compute all losses. + + vis4d LossModule interface: expects either Tensor, dict, or namedtuple. + We return a dict of tensors, and LossModule will sum them automatically. + + Following SAM3 and GDino3D's design, we compute 2D box L1 loss in normalized + cxcywh space and GIoU loss in pixel xyxy space for consistent loss weights. + + Args: + out: Model output (WildDet3DOutput dataclass) + batch: Input batch (WildDet3DInput dataclass) + + Returns: + Dict of loss tensors (vis4d LossModule will sum them) + """ + import time + import os + import torch + _PROFILE_LOSS = os.environ.get("PROFILE_WILDDET3D", "0") == "1" + if _PROFILE_LOSS: + torch.cuda.synchronize() + _loss_start = time.perf_counter() + # Unpack model outputs + pred_logits = out.pred_logits + pred_boxes_2d = out.pred_boxes_2d + pred_boxes_3d = out.pred_boxes_3d + aux_outputs = out.aux_outputs + geom_losses = out.geom_losses + + # Build targets from batch + # Get per-prompt intrinsics by indexing into batch intrinsics + B_images = batch.images.shape[0] + N_prompts = batch.img_ids.shape[0] + intrinsics = batch.intrinsics[batch.img_ids] # (N_prompts, 3, 3) + + # Image size from batch + image_size = (batch.images.shape[2], batch.images.shape[3]) # (H, W) + + if _PROFILE_LOSS: + torch.cuda.synchronize() + _t_targets = time.perf_counter() + + targets = self._build_targets_from_batch(batch) + losses = {} + + # Normalize targets to [0, 1] range (for matching and computation) + normalized_targets = self._normalize_targets(targets) + + if _PROFILE_LOSS: + torch.cuda.synchronize() + _loss_targets_time = (time.perf_counter() - _t_targets) * 1000 + + # Store image_size for pixel coordinate conversion + if image_size is None and "image_size" in targets: + image_size = targets["image_size"] + + # Get matching indices from SAM3's internal matching + # SAM3's forward_grounding computes indices via _compute_matching when find_target is provided + # Handle empty batch (N_prompts=0) case - return zero loss with grad + if out.indices is None: + device = pred_logits.device + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + print(f"[WildDet3D Loss] Empty batch detected on rank {rank}, returning zero loss") + + # CRITICAL: Must still participate in all_reduce to prevent DDP deadlock + # Other ranks may have non-empty batches and will call all_reduce + if self.config.normalization == "global" and torch.distributed.is_initialized(): + dummy_num_boxes = torch.tensor(0.0, device=device) + torch.distributed.all_reduce(dummy_num_boxes) + + # Use pred_logits.sum() * 0 to maintain computation graph for DDP + zero_loss = pred_logits.sum() * 0 + return { + "loss_cls": zero_loss, # Keep grad for DDP + "loss_bbox": zero_loss.clone(), + "loss_giou": zero_loss.clone(), + } + + batch_idx, src_idx, tgt_idx = out.indices + + # Move indices to the same device as predictions + batch_idx = batch_idx.to(pred_logits.device) + src_idx = src_idx.to(pred_logits.device) + tgt_idx = tgt_idx.to(pred_logits.device) if tgt_idx is not None else None + + indices = (batch_idx, src_idx, tgt_idx) + + # Get number of boxes for normalization + num_boxes = self._get_num_boxes(normalized_targets) + + # ========== 2D Losses via SAM3's loss classes (scaled by loss_2d_scale) ========== + if _PROFILE_LOSS: + torch.cuda.synchronize() + _t0 = time.perf_counter() + + # Build SAM3-format outputs dict for loss classes + sam3_outputs = { + "pred_logits": pred_logits, + "pred_boxes_xyxy": pred_boxes_2d, + "pred_boxes": out.pred_boxes_2d_cxcywh, + } + if out.presence_logits is not None: + sam3_outputs["presence_logit_dec"] = out.presence_logits + + # Compute ignore negative loss suppression mask + if ( + self.config.use_ignore_suppress + and batch.ignore_boxes2d is not None + and batch.num_ignores is not None + ): + normalized_targets["_ignore_boxes2d"] = batch.ignore_boxes2d + normalized_targets["_num_ignores"] = batch.num_ignores + normalized_targets["ignore_neg_mask"] = ( + self._compute_ignore_neg_mask( + pred_boxes_2d, + batch.ignore_boxes2d, + batch.num_ignores, + threshold=self.config.ignore_iou_threshold, + ) + ) + + # Classification + presence via SAM3's IABCEMdetr + cls_losses = self.cls_loss.get_loss( + sam3_outputs, normalized_targets, indices, num_boxes + ) + losses["loss_cls"] = ( + self.config.loss_2d_scale * cls_losses["loss_ce"] * self.config.loss_cls_weight + ) + # Metrics from SAM3's IABCEMdetr (not losses, just for wandb logging) + if "ce_f1" in cls_losses: + losses["metric_ce_f1"] = cls_losses["ce_f1"].detach() + # Presence loss (computed inside IABCEMdetr when use_presence=True) + presence_val = cls_losses.get("presence_loss") + if presence_val is not None and isinstance(presence_val, Tensor): + losses["loss_presence"] = ( + self.config.loss_2d_scale * presence_val + * self.config.presence_loss_weight + ) + if "presence_dec_acc" in cls_losses: + losses["metric_presence_acc"] = cls_losses["presence_dec_acc"].detach() + + if _PROFILE_LOSS: + torch.cuda.synchronize() + _loss_cls_time = (time.perf_counter() - _t0) * 1000 + + # 2D box losses (L1 + GIoU) via SAM3's Boxes class + if _PROFILE_LOSS: + torch.cuda.synchronize() + _t1 = time.perf_counter() + + box_losses = self.box_loss.get_loss( + sam3_outputs, normalized_targets, indices, num_boxes + ) + losses["loss_bbox"] = ( + self.config.loss_2d_scale * box_losses["loss_bbox"] * self.config.loss_bbox_weight + ) + losses["loss_giou"] = ( + self.config.loss_2d_scale * box_losses["loss_giou"] * self.config.loss_giou_weight + ) + + if _PROFILE_LOSS: + torch.cuda.synchronize() + _loss_2d_box_time = (time.perf_counter() - _t1) * 1000 + + # ========== O2M Loss (2D scaled by loss_2d_scale, 3D scaled by loss_3d_scale) ========== + if _PROFILE_LOSS: + torch.cuda.synchronize() + _t_o2m = time.perf_counter() + _loss_o2m_time = 0 + + # Use real O2M outputs from SAM3 DAC mechanism (not O2O outputs) + if self.config.use_o2m and self.o2m_matcher is not None and out.pred_logits_o2m is not None: + o2m_losses = self._loss_o2m( + pred_logits=out.pred_logits_o2m, + pred_boxes_2d=out.pred_boxes_2d_o2m, + pred_boxes_2d_cxcywh=out.pred_boxes_2d_cxcywh_o2m, + pred_boxes_3d=out.pred_boxes_3d_o2m, + targets=normalized_targets, + num_boxes=num_boxes, + intrinsics=intrinsics, + image_size=image_size, + pred_conf_3d=out.pred_conf_3d_o2m, + ) + # Apply appropriate scale and loss weights (following SAM3 original) + # SAM3 original: loss = loss_value * o2m_weight * loss_weight + # We need to apply the individual loss weights, not just o2m_loss_weight + o2m_weight_map = { + "loss_cls": self.config.loss_cls_weight, + "loss_bbox": self.config.loss_bbox_weight, + "loss_giou": self.config.loss_giou_weight, + "loss_delta_2d": self.config.loss_delta_2d_weight, + "loss_depth": self.config.loss_depth_weight, + "loss_dim": self.config.loss_dim_weight, + "loss_rot": self.config.loss_rot_weight, + "loss_3d_cls": self.config.loss_3d_conf_weight, + } + for key, value in o2m_losses.items(): + loss_weight = o2m_weight_map.get(key, 1.0) + if key in ("loss_delta_2d", "loss_depth", "loss_dim", "loss_rot"): + # 3D losses use loss_3d_scale + o2m_loss_val = ( + self.config.loss_3d_scale * value * loss_weight * self.config.o2m_loss_weight + ) + elif key == "loss_3d_cls": + # 3D confidence loss: weight * o2m_weight (no extra scale) + o2m_loss_val = value * loss_weight * self.config.o2m_loss_weight + else: + # 2D losses (loss_cls, loss_bbox, loss_giou) use loss_2d_scale + o2m_loss_val = ( + self.config.loss_2d_scale * value * loss_weight * self.config.o2m_loss_weight + ) + # Clip O2M loss to prevent gradient explosion + losses[f"o2m_{key}"] = torch.clamp(o2m_loss_val, max=self.config.o2m_loss_clip) + + if _PROFILE_LOSS: + torch.cuda.synchronize() + _loss_o2m_time = (time.perf_counter() - _t_o2m) * 1000 + + # ========== 3D Losses (scaled by loss_3d_scale) ========== + if _PROFILE_LOSS: + torch.cuda.synchronize() + _t2 = time.perf_counter() + _loss_3d_time = 0 + + if pred_boxes_3d is not None and intrinsics is not None: + loss_3d = self._loss_boxes_3d( + pred_boxes_2d, pred_boxes_3d, indices, normalized_targets, + intrinsics, num_boxes, image_size=image_size + ) + # Apply loss_3d_scale to all 3D losses + for key, value in loss_3d.items(): + losses[key] = self.config.loss_3d_scale * value + + if _PROFILE_LOSS: + torch.cuda.synchronize() + _loss_3d_time = (time.perf_counter() - _t2) * 1000 + + # ========== 3D Confidence Loss (positive samples only) ========== + if (self.config.use_3d_conf + and out.pred_conf_3d is not None + and pred_boxes_3d is not None + and intrinsics is not None): + loss_3d_cls = self._loss_3d_classification( + out.pred_conf_3d, pred_boxes_2d, pred_boxes_3d, + indices, normalized_targets, intrinsics, num_boxes, image_size, + ) + losses["loss_3d_cls"] = self.config.loss_3d_conf_weight * loss_3d_cls + + # ========== Geometry Backend Losses (scaled by loss_geom_scale) ========== + if _PROFILE_LOSS: + torch.cuda.synchronize() + _t_geom = time.perf_counter() + _loss_geom_time = 0 + + if geom_losses is not None: + for key, value in geom_losses.items(): + if key.startswith("metric_"): + # Monitoring-only: log raw value, no scaling + losses[key] = value.detach() + else: + weight = getattr( + self.config, f"loss_{key}_weight", 1.0 + ) + losses[f"loss_{key}"] = ( + self.config.loss_geom_scale * value * weight + ) + + if _PROFILE_LOSS: + torch.cuda.synchronize() + _loss_geom_time = (time.perf_counter() - _t_geom) * 1000 + + # ========== Auxiliary Losses (Deep Supervision) ========== + if _PROFILE_LOSS: + torch.cuda.synchronize() + _t3 = time.perf_counter() + _loss_aux_time = 0 + + _num_aux_layers = 0 + if aux_outputs is not None: + _num_aux_layers = len(aux_outputs) + for i, aux_out in enumerate(aux_outputs): + aux_losses = self._compute_aux_loss( + aux_out, indices, normalized_targets, num_boxes, intrinsics, image_size + ) + for key, value in aux_losses.items(): + losses[f"d{i}.{key}"] = value * self.config.aux_loss_weight + + if _PROFILE_LOSS: + torch.cuda.synchronize() + _loss_aux_time = (time.perf_counter() - _t3) * 1000 + _loss_total_time = (time.perf_counter() - _loss_start) * 1000 + + # Print loss timing summary (every N steps via profiler) + from wilddet3d.ops.profiler import profiler + p = profiler() + p.current_step_timings["loss_total"] = _loss_total_time / 1000 + p.current_step_timings[" loss_targets"] = _loss_targets_time / 1000 + p.current_step_timings[" loss_cls"] = _loss_cls_time / 1000 + p.current_step_timings[" loss_2d_box"] = _loss_2d_box_time / 1000 + p.current_step_timings[" loss_o2m"] = _loss_o2m_time / 1000 + p.current_step_timings[" loss_3d"] = _loss_3d_time / 1000 + p.current_step_timings[" loss_geom"] = _loss_geom_time / 1000 + p.current_step_timings[" loss_aux"] = _loss_aux_time / 1000 + p.current_step_timings[" loss_aux_layers"] = _num_aux_layers + + # ========== Ensure all losses are tensors ========== + # vis4d LossModule expects dict of tensors + for k, v in list(losses.items()): + if not isinstance(v, Tensor): + losses[k] = torch.tensor(v, device=pred_logits.device) + + # vis4d LossModule will sum all losses in the dict automatically + return losses + + def _get_num_boxes(self, targets: dict) -> Tensor: + """Get number of boxes for loss normalization.""" + num_boxes = targets["num_boxes"].sum().float() + + if self.config.normalization == "global": + # Handle non-distributed case + if torch.distributed.is_initialized(): + torch.distributed.all_reduce(num_boxes) + world_size = torch.distributed.get_world_size() + num_boxes = torch.clamp(num_boxes / world_size, min=1) + else: + # Non-distributed: just clamp + num_boxes = torch.clamp(num_boxes, min=1) + elif self.config.normalization == "local": + num_boxes = torch.clamp(num_boxes, min=1) + else: # "none" + num_boxes = torch.ones_like(num_boxes) + + return num_boxes + + # 2D classification and box losses are now handled by SAM3's + # IABCEMdetr (self.cls_loss) and Boxes (self.box_loss) classes. + + def _loss_o2m( + self, + pred_logits: Tensor, # (B, S, 1) + pred_boxes_2d: Tensor, # (B, S, 4) normalized xyxy + pred_boxes_2d_cxcywh: Tensor | None, # (B, S, 4) normalized cxcywh + pred_boxes_3d: Tensor | None, # (B, S, reg_dims) + targets: dict, + num_boxes: Tensor, + intrinsics: Tensor | None = None, # (B, 3, 3) + image_size: tuple[int, int] | None = None, + pred_conf_3d: Tensor | None = None, # (B, S, 1) 3D confidence + ) -> dict[str, Tensor]: + """Compute O2M (One-to-Many) auxiliary loss. + + Uses SAM3's IABCEMdetr and Boxes classes for 2D losses, + plus our own 3D loss for matched predictions. + """ + losses = {} + device = pred_logits.device + B, S = pred_logits.shape[:2] + + # Prepare targets in padded format for O2M matcher + num_boxes_per_image = targets.get( + "num_boxes", + torch.tensor([len(targets["boxes_xyxy"])], device=device), + ) + boxes_padded = targets.get("boxes_padded") + if boxes_padded is None: + boxes_cxcywh = self._xyxy_to_cxcywh(targets["boxes_xyxy"]) + boxes_padded = _packed_to_padded(boxes_cxcywh, num_boxes_per_image) + + max_N = boxes_padded.shape[1] + target_is_valid_padded = torch.zeros( + B, max_N, dtype=torch.bool, device=device + ) + for i in range(B): + target_is_valid_padded[i, :num_boxes_per_image[i]] = True + + # O2M matching + if pred_boxes_2d_cxcywh is None: + pred_boxes_2d_cxcywh = self._xyxy_to_cxcywh(pred_boxes_2d) + + outputs_dict = { + "pred_logits": pred_logits, + "pred_boxes": pred_boxes_2d_cxcywh, + } + targets_dict = { + "boxes_padded": boxes_padded, + "labels": targets["classes"], + "num_boxes": num_boxes_per_image, + } + batch_idx, src_idx, tgt_idx = self.o2m_matcher( + outputs_dict, + targets_dict, + target_is_valid_padded=target_is_valid_padded, + ) + + if batch_idx.numel() == 0: + zero_losses = { + "loss_cls": torch.tensor(0.0, device=device), + "loss_bbox": torch.tensor(0.0, device=device), + "loss_giou": torch.tensor(0.0, device=device), + } + if pred_boxes_3d is not None and intrinsics is not None: + zero_losses.update({ + "loss_delta_2d": torch.tensor(0.0, device=device), + "loss_depth": torch.tensor(0.0, device=device), + "loss_dim": torch.tensor(0.0, device=device), + "loss_rot": torch.tensor(0.0, device=device), + }) + return zero_losses + + o2m_indices = (batch_idx, src_idx, tgt_idx) + + # Recompute ignore mask for O2M predictions (different pred boxes) + if "_ignore_boxes2d" in targets: + targets = targets.copy() + targets["ignore_neg_mask"] = self._compute_ignore_neg_mask( + pred_boxes_2d, + targets["_ignore_boxes2d"], + targets["_num_ignores"], + threshold=self.config.ignore_iou_threshold, + ) + + # 2D losses via SAM3 classes + o2m_outputs = { + "pred_logits": pred_logits, + "pred_boxes_xyxy": pred_boxes_2d, + "pred_boxes": pred_boxes_2d_cxcywh, + } + cls_losses = self.cls_loss.get_loss( + o2m_outputs, targets, o2m_indices, num_boxes + ) + losses["loss_cls"] = cls_losses["loss_ce"] + + box_losses = self.box_loss.get_loss( + o2m_outputs, targets, o2m_indices, num_boxes + ) + losses["loss_bbox"] = box_losses["loss_bbox"] + losses["loss_giou"] = box_losses["loss_giou"] + + # 3D losses (our own, not in SAM3) + if (pred_boxes_3d is not None and intrinsics is not None + and "boxes_3d" in targets): + loss_3d = self._loss_boxes_3d( + pred_boxes_2d=pred_boxes_2d, + pred_boxes_3d=pred_boxes_3d, + indices=o2m_indices, + targets=targets, + intrinsics=intrinsics, + num_boxes=num_boxes, + image_size=image_size, + ) + losses.update(loss_3d) + + # 3D confidence loss (O2M branch) + if (self.config.use_3d_conf + and pred_conf_3d is not None + and pred_boxes_3d is not None + and intrinsics is not None): + loss_3d_cls = self._loss_3d_classification( + pred_conf_3d, pred_boxes_2d, pred_boxes_3d, + o2m_indices, targets, intrinsics, num_boxes, image_size, + ) + losses["loss_3d_cls"] = loss_3d_cls + + return losses + + # _loss_boxes_2d replaced by SAM3's Boxes class (self.box_loss). + + def _loss_boxes_3d( + self, + pred_boxes_2d: Tensor, # (B, S, 4) + pred_boxes_3d: Tensor, # (B, S, reg_dims) + indices: tuple[Tensor, Tensor, Tensor | None], + targets: dict, + intrinsics: Tensor, + num_boxes: Tensor, + image_size: tuple[int, int] | None = None, + ) -> dict[str, Tensor]: + """Compute 3D box regression losses. + + Args: + pred_boxes_2d: Predicted 2D boxes in normalized xyxy [0,1]. Shape (B, S, 4). + pred_boxes_3d: Predicted 3D box parameters. Shape (B, S, reg_dims). + indices: Matching indices (batch_idx, src_idx, tgt_idx). + targets: Target dict containing boxes_3d. + intrinsics: Camera intrinsics. Shape (B, 3, 3). + num_boxes: Number of matched boxes for normalization. + image_size: (H, W) tuple for converting normalized to pixel coords. + Required for correct box_coder.encode() which expects pixel coords. + """ + batch_idx, src_idx, tgt_idx = indices + + # Get matched predictions (for loss computation) + src_boxes_3d = pred_boxes_3d[(batch_idx, src_idx)] + + # Get matched GT 2D boxes (for box_coder.encode target computation) + # IMPORTANT: Use GT 2D boxes, NOT predicted boxes! + # This matches GDino3D's design where encode() uses GT 2D boxes to compute + # stable targets, while decode() at inference uses predicted 2D boxes. + target_boxes_2d = ( + targets["boxes_xyxy"][tgt_idx] if tgt_idx is not None + else targets["boxes_xyxy"] + ) + + # Get matched GT 3D boxes + target_boxes_3d = ( + targets["boxes_3d"][tgt_idx] if tgt_idx is not None + else targets["boxes_3d"] + ) + + # Get intrinsics for matched samples + # Note: intrinsics is (B, 3, 3), need to index by batch_idx + # Since box_coder.encode() expects single intrinsics (3, 3), + # we need to process each matched box individually + if len(batch_idx) == 0: + # No matches, return zero losses + return { + "loss_delta_2d": torch.tensor(0.0, device=pred_boxes_2d.device), + "loss_depth": torch.tensor(0.0, device=pred_boxes_2d.device), + "loss_dim": torch.tensor(0.0, device=pred_boxes_2d.device), + "loss_rot": torch.tensor(0.0, device=pred_boxes_2d.device), + } + + target_boxes_3d_encoded_list = [] + weights_3d_list = [] + + # Validate image_size is provided - required for correct box_coder.encode() + if image_size is None: + raise ValueError( + "image_size is required for _loss_boxes_3d. " + "box_coder.encode() expects pixel coordinates because " + "project_points() returns pixel coords and " + "delta_center = projected_3d_center - 2d_box_center (both in pixels)." + ) + + H, W = image_size + factors = target_boxes_2d.new_tensor([W, H, W, H]) + + for i in range(len(batch_idx)): + single_box_3d = target_boxes_3d[i:i+1] + + # Skip entries with invalid (all-zero) 3D boxes: set weight=0 + # so they don't contribute to 3D loss. This handles the case + # where GT has a valid 2D box but no 3D annotation. + if single_box_3d.abs().sum() < 1e-6: + reg_dims = pred_boxes_3d.shape[-1] + target_boxes_3d_encoded_list.append( + torch.zeros(1, reg_dims, device=pred_boxes_3d.device) + ) + weights_3d_list.append( + torch.zeros(1, reg_dims, device=pred_boxes_3d.device) + ) + continue + + # Use GT 2D box (normalized xyxy) and convert to pixel + single_gt_box_2d = target_boxes_2d[i:i+1] + single_gt_box_2d_pixel = single_gt_box_2d * factors + + single_intrinsic = intrinsics[batch_idx[i]] # (3, 3) + + encoded, weights = self.box_coder.encode( + single_gt_box_2d_pixel, single_box_3d, single_intrinsic, + ) + target_boxes_3d_encoded_list.append(encoded) + weights_3d_list.append(weights) + + target_boxes_3d_encoded = torch.cat(target_boxes_3d_encoded_list, dim=0) + weights_3d = torch.cat(weights_3d_list, dim=0) + + losses = {} + + # Delta 2D center loss + loss_delta_2d = l1_loss( + src_boxes_3d[:, :2], + target_boxes_3d_encoded[:, :2], + reducer=SumWeightedLoss( + weight=weights_3d[:, :2], avg_factor=num_boxes.item() + ), + ) + losses["loss_delta_2d"] = loss_delta_2d * self.config.loss_delta_2d_weight + + # Depth loss + loss_depth = l1_loss( + src_boxes_3d[:, 2], + target_boxes_3d_encoded[:, 2], + reducer=SumWeightedLoss( + weight=weights_3d[:, 2], avg_factor=num_boxes.item() + ), + ) + losses["loss_depth"] = loss_depth * self.config.loss_depth_weight + + # Dimension loss + loss_dim = l1_loss( + src_boxes_3d[:, 3:6], + target_boxes_3d_encoded[:, 3:6], + reducer=SumWeightedLoss( + weight=weights_3d[:, 3:6], avg_factor=num_boxes.item() + ), + ) + losses["loss_dim"] = loss_dim * self.config.loss_dim_weight + + # Rotation loss + loss_rot = l1_loss( + src_boxes_3d[:, 6:], + target_boxes_3d_encoded[:, 6:], + reducer=SumWeightedLoss( + weight=weights_3d[:, 6:], avg_factor=num_boxes.item() + ), + ) + losses["loss_rot"] = loss_rot * self.config.loss_rot_weight + + return losses + + def _loss_3d_classification( + self, + pred_conf_3d: Tensor, # (B, S, 1) + pred_boxes_2d: Tensor, # (B, S, 4) normalized xyxy + pred_boxes_3d: Tensor, # (B, S, 12) encoded + indices: tuple[Tensor, Tensor, Tensor | None], + targets: dict, + intrinsics: Tensor, # (N_prompts, 3, 3) + num_boxes: Tensor, + image_size: tuple[int, int], + ) -> Tensor: + """Compute 3D confidence loss (positive + negative). + + Positive: soft target = quality (0.7*iou_3d + 0.3*depth) + Negative: target = 0, with focal weighting + Same structure as 2D cls loss (IABCEMdetr). + + At inference: final_score = 2d_score + 0.5 * 3d_score + """ + batch_idx, src_idx, tgt_idx = indices + B, S, _ = pred_conf_3d.shape + device = pred_conf_3d.device + M = len(batch_idx) + + if M == 0: + return pred_conf_3d.sum() * 0.0 + + prob = pred_conf_3d.sigmoid() + target_classes = torch.zeros(B, S, 1, device=device) + target_classes[(batch_idx, src_idx)] = 1.0 + + with torch.no_grad(): + # 1. Depth quality - directly from encoded params, no decode needed + src_boxes_3d = pred_boxes_3d[(batch_idx, src_idx)] + target_boxes_3d_raw = ( + targets["boxes_3d"][tgt_idx] if tgt_idx is not None + else targets["boxes_3d"] + ) + depth_scale = self.box_coder.depth_scale + pred_log_z = src_boxes_3d[:, 2] / depth_scale # = log(pred_z) + gt_z = target_boxes_3d_raw[:, 2].clamp(min=0.1) + gt_log_z = torch.log(gt_z) + depth_quality = torch.exp(-torch.abs(pred_log_z - gt_log_z)) + depth_quality = torch.nan_to_num(depth_quality, nan=0.0, posinf=1.0, neginf=0.0) + + # 2. 3D IoU using safe shapely-based implementation + # (CPU, full rotation support, never crashes) + from wilddet3d.ops.iou_box3d import batch_box3d_iou + + H, W = image_size + factors = pred_boxes_2d.new_tensor([[W, H, W, H]]) + src_boxes_2d_pixel = pred_boxes_2d[(batch_idx, src_idx)] * factors + + pred_decoded_list = [] + for i in range(M): + single_decoded = self.box_coder.decode( + src_boxes_2d_pixel[i:i+1], + src_boxes_3d[i:i+1], + intrinsics[batch_idx[i]], + ) + pred_decoded_list.append(single_decoded) + pred_decoded = torch.cat(pred_decoded_list, dim=0) # (M, 10) + + iou_3d = batch_box3d_iou(pred_decoded, target_boxes_3d_raw[:, :10]) + + # 3. Combined quality + quality = ( + self.config.conf_depth_weight * depth_quality + + self.config.conf_iou_3d_weight * iou_3d + ) + quality = torch.nan_to_num(quality, nan=0.0).clamp(0.0, 1.0) + + # 4. Build soft target (same as 2D IABCEMdetr pattern) + t = ( + prob[(batch_idx, src_idx)].squeeze(-1) ** self.config.alpha + * quality ** (1 - self.config.alpha) + ) + t = t.clamp(min=0.01).detach() + + positive_target = target_classes.clone() + positive_target[(batch_idx, src_idx)] = t.unsqueeze(-1) + + # Positive loss with soft quality target + loss_pos = F.binary_cross_entropy_with_logits( + pred_conf_3d, positive_target, reduction="none" + ) + loss_pos = loss_pos * target_classes * self.config.pos_weight + + # Negative loss with focal weighting (push unmatched queries toward 0) + loss_neg = F.binary_cross_entropy_with_logits( + pred_conf_3d, target_classes, reduction="none" + ) + loss_neg = loss_neg * (1 - target_classes) * (prob ** self.config.gamma) + + # Suppress negative loss for predictions overlapping ignore boxes + if "ignore_neg_mask" in targets: + neg_suppress = targets["ignore_neg_mask"].unsqueeze(-1) + loss_neg = loss_neg * (1 - neg_suppress) + + loss_bce = loss_pos + loss_neg + + # Apply presence mask (zero out loss for prompts with no GT) + if self.config.use_presence: + num_gts = targets.get( + "num_boxes", torch.zeros(B, dtype=torch.long, device=device) + ) + keep_loss = (num_gts > 0).float().view(B, 1, 1) # (B, 1, 1) for (B, S, 1) broadcasting + loss_bce = loss_bce * keep_loss + + return loss_bce.mean() + + def _compute_aux_loss( + self, + aux_out: dict, + indices: tuple[Tensor, Tensor, Tensor | None], + targets: dict, + num_boxes: Tensor, + intrinsics: Tensor | None = None, + image_size: tuple[int, int] | None = None, + ) -> dict[str, Tensor]: + """Compute losses for auxiliary decoder outputs. + + Following GDino3D's design, we compute all losses (2D + 3D) for auxiliary outputs + to enable full deep supervision across all decoder layers. + + Args: + aux_out: Auxiliary output dictionary containing pred_logits, pred_boxes_2d, pred_boxes_3d + indices: Matching indices from matcher + targets: Ground truth targets + num_boxes: Number of boxes for normalization + intrinsics: Camera intrinsics for 3D loss computation + image_size: (H, W) for pixel coordinate conversion + + Returns: + Dictionary of auxiliary losses + """ + losses = {} + + # Build SAM3-format outputs for aux layer + sam3_aux = { + "pred_logits": aux_out.get("pred_logits"), + "pred_boxes_xyxy": aux_out.get( + "pred_boxes_xyxy", aux_out.get("pred_boxes_2d") + ), + "pred_boxes": aux_out.get("pred_boxes"), + } + # If pred_boxes (cxcywh) not available, convert from xyxy + if sam3_aux["pred_boxes"] is None and sam3_aux["pred_boxes_xyxy"] is not None: + sam3_aux["pred_boxes"] = self._xyxy_to_cxcywh( + sam3_aux["pred_boxes_xyxy"] + ) + + # Recompute ignore mask for this aux layer's predicted boxes + if "_ignore_boxes2d" in targets and sam3_aux["pred_boxes_xyxy"] is not None: + targets = targets.copy() + targets["ignore_neg_mask"] = self._compute_ignore_neg_mask( + sam3_aux["pred_boxes_xyxy"], + targets["_ignore_boxes2d"], + targets["_num_ignores"], + threshold=self.config.ignore_iou_threshold, + ) + + # Classification loss via SAM3's IABCEMdetr (scaled by loss_2d_scale) + if sam3_aux["pred_logits"] is not None: + cls_losses = self.cls_loss.get_loss( + sam3_aux, targets, indices, num_boxes + ) + losses["loss_cls"] = ( + self.config.loss_2d_scale + * cls_losses["loss_ce"] + * self.config.loss_cls_weight + ) + + # 2D box losses via SAM3's Boxes class (scaled by loss_2d_scale) + if sam3_aux["pred_boxes"] is not None: + box_losses = self.box_loss.get_loss( + sam3_aux, targets, indices, num_boxes + ) + losses["loss_bbox"] = ( + self.config.loss_2d_scale + * box_losses["loss_bbox"] + * self.config.loss_bbox_weight + ) + losses["loss_giou"] = ( + self.config.loss_2d_scale + * box_losses["loss_giou"] + * self.config.loss_giou_weight + ) + + # 3D box loss (our own, scaled by loss_3d_scale) + pred_boxes_2d_aux = aux_out.get( + "pred_boxes_2d", aux_out.get("pred_boxes_xyxy") + ) + if "pred_boxes_3d" in aux_out and intrinsics is not None: + loss_3d = self._loss_boxes_3d( + pred_boxes_2d_aux, + aux_out["pred_boxes_3d"], + indices, + targets, + intrinsics, + num_boxes, + image_size=image_size, + ) + for key, value in loss_3d.items(): + losses[key] = self.config.loss_3d_scale * value + + # 3D confidence loss (deep supervision) + if (self.config.use_3d_conf + and "pred_conf_3d" in aux_out + and "pred_boxes_3d" in aux_out + and intrinsics is not None): + loss_3d_cls = self._loss_3d_classification( + aux_out["pred_conf_3d"], + pred_boxes_2d_aux, + aux_out["pred_boxes_3d"], + indices, targets, intrinsics, num_boxes, image_size, + ) + losses["loss_3d_cls"] = self.config.loss_3d_conf_weight * loss_3d_cls + + return losses + + def _normalize_targets(self, targets: dict) -> dict: + """Ensure targets are in expected format for loss computation. + + Note: WildDet3D collator always outputs GT boxes in normalized [0, 1] xyxy format. + This function simply ensures the classes tensor exists (for binary classification). + + Args: + targets: Dictionary containing ground truth data + - boxes_xyxy: (N, 4) boxes in normalized xyxy [0, 1] format + - classes: (N,) class labels (all ones for SAM3) + - num_boxes: (N,) number of boxes per prompt (always 1) + - boxes_3d: (N, 12) 3D boxes (optional) + + Returns: + Targets dict with classes tensor guaranteed to exist + """ + normalized = targets.copy() + boxes_xyxy = targets["boxes_xyxy"] + + # Ensure classes tensor exists (all ones for binary classification) + if "classes" not in normalized: + num_boxes = boxes_xyxy.shape[0] + normalized["classes"] = torch.ones( + num_boxes, dtype=torch.long, device=boxes_xyxy.device + ) + + return normalized + + def _xyxy_to_cxcywh(self, boxes_xyxy: Tensor) -> Tensor: + """Convert boxes from xyxy to cxcywh format. + + Args: + boxes_xyxy: (N, 4) boxes in xyxy format + + Returns: + boxes_cxcywh: (N, 4) boxes in cxcywh format + """ + x1, y1, x2, y2 = boxes_xyxy.unbind(-1) + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + return torch.stack([cx, cy, w, h], dim=-1) + diff --git a/wilddet3d/model.py b/wilddet3d/model.py new file mode 100644 index 0000000000000000000000000000000000000000..929b7444dd927b423b6d0d91b10c95e5440d98d1 --- /dev/null +++ b/wilddet3d/model.py @@ -0,0 +1,1647 @@ +"""WildDet3D: SAM3 with 3D Detection Head. + +This module combines SAM3 (2D detection with geometric prompting) with +3D detection head and geometry backend. + +Key Design Decisions (from Design Doc): +1. Coordinate format: SAM3 uses normalized cxcywh internally, + model outputs normalized xyxy [0, 1] +2. Tensor format: SAM3 Decoder outputs sequence-first (L, S, B, C), + 3D Head expects batch-first (L, B, S, C) -> need permute +3. Batch strategy: per-prompt batch with img_ids indexing +4. bbox_head: Reuse SAM3 Decoder's internal bbox_embed, + no external bbox_head needed +5. Forward: Reuse SAM3's forward_grounding() method for 2D detection, + then add 3D head on top + +Data Flow: +1. DataLoader produces per-image data +2. Collator expands to per-prompt batch (WildDet3DInput) +3. Model forward receives expanded data, calls SAM3's forward_grounding +4. 3D head processes SAM3 output +""" + +from __future__ import annotations + +from typing import List + +import torch +from torch import Tensor, nn +from torchvision.ops import nms, batched_nms, box_iou + +from wilddet3d.ops.profiler import profile_start, profile_stop, profile_step + +# SAM3 imports +from sam3.model.sam3_image import Sam3Image +from sam3.model.geometry_encoders import Prompt +from sam3.model.box_ops import box_cxcywh_to_xyxy +from sam3.model.data_misc import FindStage, BatchedFindTarget + +# 3D detection imports +from wilddet3d.head import ( + Det3DHead, + Det3DCoder, + RoI2Det3D, +) +from wilddet3d.data_types import Det3DOut, WildDet3DOut, WildDet3DInput +from wilddet3d.depth import GeometryBackendBase + + +class Fp32LayerNorm(nn.LayerNorm): + """LayerNorm that always computes in fp32. + + In mixed-precision training (bf16/fp16), standard LayerNorm can overflow + because the variance computation involves squaring values. bf16 max is + ~65504, so values > ~256 squared will overflow. + + This wrapper casts input to fp32, runs LayerNorm, then casts back. + The overhead is negligible since LayerNorm is memory-bound. + """ + + def forward(self, x: Tensor) -> Tensor: + orig_dtype = x.dtype + x = x.float() + x = super().forward(x) + return x.to(orig_dtype) + + +def _upgrade_layernorms_to_fp32(module: nn.Module) -> int: + """Replace all nn.LayerNorm in a module tree with Fp32LayerNorm. + + Walks the module tree and swaps each nn.LayerNorm with an Fp32LayerNorm + that shares the same weight and bias tensors (no copy, no extra memory). + + Args: + module: Root module to patch. + + Returns: + Number of LayerNorm modules replaced. + """ + count = 0 + for name, child in module.named_children(): + if isinstance(child, nn.LayerNorm) and not isinstance(child, Fp32LayerNorm): + fp32_ln = Fp32LayerNorm( + child.normalized_shape, + eps=child.eps, + elementwise_affine=child.elementwise_affine, + ) + # Share weight/bias tensors (no copy) + fp32_ln.weight = child.weight + fp32_ln.bias = child.bias + setattr(module, name, fp32_ln) + count += 1 + else: + count += _upgrade_layernorms_to_fp32(child) + return count + + +class WildDet3D(nn.Module): + """SAM3 with 3D Detection Head. + + This model combines: + 1. SAM3's backbone, encoder, decoder (for 2D detection with geometric prompting) + 2. Geometry backend (depth estimation) + 3. 3D head (3D box regression) + + Architecture: + ``` + Image + Prompts + | + v + +------------------------------------------+ + | SAM3 (backbone + encoder + decoder) | + | - ViT backbone with SimpleFPN | + | - Geometry Encoder for prompts | + | - Transformer Encoder/Decoder | + | - Internal bbox_embed for 2D boxes | + +-------------------+----------------------+ + | hidden_states, pred_boxes (cxcywh) + | + +-------+-------+ + v v + +-----------+ +---------------+ + | cxcywh | | Geometry | + | -> xyxy | | Backend | + +-----+-----+ | (depth) | + | +-------+-------+ + | | depth_latents + v v + +-------------------------------+ + | 3D Head | + | (depth + ray cross-attention)| + +---------------+---------------+ + | + v + pred_boxes_3d + ``` + """ + + def __init__( + self, + # ========== SAM3 Components ========== + sam3_model: Sam3Image | None = None, + sam3_checkpoint: str | None = None, + + # ========== 3D Components ========== + bbox3d_head: Det3DHead | None = None, + box_coder: Det3DCoder | None = None, + geometry_backend: GeometryBackendBase | None = None, + roi2det3d: RoI2Det3D | None = None, + + # ========== Depth-Memory Fusion ========== + early_depth_fusion: nn.Module | None = None, + + # ========== Freeze Settings ========== + backbone_freeze_blocks: int = 0, + + # ========== Oracle Evaluation ========== + oracle_eval: bool = False, + + # ========== Depth Input at Test Time ========== + use_depth_input_test: bool = False, + + # ========== Predicted Intrinsics ========== + use_predicted_intrinsics: bool = False, + + # ========== Eval Score Control ========== + eval_3d_conf_weight: float = 0.5, + use_presence_score: bool = True, + ) -> None: + """Initialize WildDet3D. + + Args: + sam3_model: Complete SAM3 model (backbone + encoder + decoder). + If None, will be built from sam3_checkpoint. + sam3_checkpoint: Path to SAM3 checkpoint. Only used if sam3_model is None. + bbox3d_head: 3D box regression head. If None, creates default. + box_coder: 3D box encoder/decoder. If None, creates default. + geometry_backend: Depth estimation backend. If None, no depth. + roi2det3d: Inference post-processor. If None, creates default. + early_depth_fusion: Early fusion module (after backbone, before encoder). + If None, no early fusion is performed. + backbone_freeze_blocks: Number of SAM3 ViT backbone blocks to + freeze (from the beginning). SAM3 has 32 blocks; e.g. 30 + freezes blocks[0..29], only training the last 2. + 0 means no freezing. + oracle_eval: If True, use oracle evaluation mode where each + prompt gets top-1 prediction (no NMS, no score filtering). + For measuring 3D regression quality with GT box prompts. + use_predicted_intrinsics: If True, use geometry backend's + predicted intrinsics (K_pred) for 3D box decoding at test + time instead of batch.intrinsics (dataset/default). + Useful for in-the-wild images without GT intrinsics. + Can be overridden by env var SAM3_USE_PRED_K=1/0. + eval_3d_conf_weight: Weight for 3D confidence in eval score. + final_score = 2d_score + weight * 3d_score. + Set to 0.0 to use only 2D confidence for eval. + """ + super().__init__() + + # SAM3 model - build if not provided + if sam3_model is None: + import os + from sam3.model_builder import build_sam3_image_model + + # Check if torch.compile should be enabled for SAM3 + use_compile = os.environ.get("SAM3_COMPILE", "0") == "1" + if use_compile: + print("[WildDet3D] torch.compile ENABLED for SAM3 backbone (SAM3_COMPILE=1)") + else: + print("[WildDet3D] torch.compile disabled (set SAM3_COMPILE=1 to enable)") + + print(f"Building SAM3 model from checkpoint: {sam3_checkpoint}") + sam3_model = build_sam3_image_model( + checkpoint_path=sam3_checkpoint, + load_from_HF=(sam3_checkpoint is None), # Only load from HF if no checkpoint provided + device="cpu", # Will be moved to correct device later + eval_mode=False, # Must be False to enable matcher for training + enable_segmentation=False, # Skip seg head for 3D detection (saves ~4GB memory) + compile=use_compile, # Enable torch.compile for backbone + ) + # Store checkpoint path for logging in on_load_checkpoint + self._sam3_checkpoint_path = sam3_checkpoint + else: + self._sam3_checkpoint_path = "provided_model" + + self.sam3 = sam3_model + self.hidden_dim = sam3_model.hidden_dim + self.oracle_eval = oracle_eval + self.use_depth_input_test = use_depth_input_test + self.use_predicted_intrinsics = use_predicted_intrinsics + self.eval_3d_conf_weight = eval_3d_conf_weight + self.use_presence_score = use_presence_score + print(f"[WildDet3D] use_presence_score={self.use_presence_score}") + + # 3D components + self.box_coder = box_coder or Det3DCoder() + self.geometry_backend = geometry_backend + self.roi2det3d = roi2det3d + self.early_depth_fusion = early_depth_fusion + + # Determine use_camera_prompt based on geometry_backend.is_ray_aware + # Ray-aware backends already fuse ray info into depth_latents, + # so we don't need the separate ray_embeddings (camera prompt) branch. + if self.geometry_backend is not None and hasattr(self.geometry_backend, 'is_ray_aware'): + use_camera_prompt = not self.geometry_backend.is_ray_aware + print(f"[WildDet3D] geometry_backend.is_ray_aware={self.geometry_backend.is_ray_aware}, use_camera_prompt={use_camera_prompt}") + else: + use_camera_prompt = True # Default to True for safety + print(f"[WildDet3D] No geometry_backend or is_ray_aware attr, defaulting use_camera_prompt=True") + + # Get depth_latent_dim from geometry_backend (for 3D head) + if self.geometry_backend is not None and hasattr(self.geometry_backend, 'target_latent_dim'): + depth_latent_dim = self.geometry_backend.target_latent_dim + else: + depth_latent_dim = 256 # Default + + # Create or validate bbox3d_head with correct use_camera_prompt setting + if bbox3d_head is not None: + self.bbox3d_head = bbox3d_head + # Warn if provided head has mismatched use_camera_prompt + if hasattr(bbox3d_head, 'use_camera_prompt') and bbox3d_head.use_camera_prompt != use_camera_prompt: + print(f"[WildDet3D] Warning: bbox3d_head.use_camera_prompt={bbox3d_head.use_camera_prompt} " + f"but geometry_backend suggests use_camera_prompt={use_camera_prompt}") + else: + self.bbox3d_head = Det3DHead( + embed_dims=self.hidden_dim, + box_coder=self.box_coder, + use_camera_prompt=use_camera_prompt, + depth_latent_dim=depth_latent_dim, + ) + print(f"[WildDet3D] Created bbox3d_head with use_camera_prompt={use_camera_prompt}, depth_latent_dim={depth_latent_dim}") + + # 3D conf_branches use xavier init (from _init_weights in head.py). + # No warm start from class_embed: the positive-only loss design + # (quality targets ~0.1-0.3 early) conflicts with class_embed's + # high-logit initialization, causing large initial loss. + + # Load geometry backend pretrained weights + # This is called during __init__ to ensure weights are loaded for first training + # (on_load_checkpoint is only called when resuming from checkpoint) + if self.geometry_backend is not None and hasattr(self.geometry_backend, 'load_pretrained_weights'): + print("[WildDet3D] Loading geometry backend pretrained weights...") + self.geometry_backend.load_pretrained_weights() + + # Ensure SAM3 has a matcher for training + # SAM3 built with eval_mode=True doesn't have a matcher, so we create one + # Using BinaryHungarianMatcherV2 with focal=True to match SAM3 original config + if self.sam3.matcher is None: + from sam3.train.matcher import BinaryHungarianMatcherV2 + print("[WildDet3D] Creating BinaryHungarianMatcherV2 for training...") + self.sam3.matcher = BinaryHungarianMatcherV2( + cost_class=2.0, # SAM3 original + cost_bbox=5.0, # SAM3 original + cost_giou=2.0, # SAM3 original + focal=True, # SAM3 original + alpha=0.25, # SAM3 original + gamma=2.0, # SAM3 original + ) + + # Freeze SAM3 ViT backbone blocks (like lingbot encoder_freeze_blocks) + # SAM3 ViT has 32 blocks at sam3.backbone.vision_backbone.trunk.blocks + if backbone_freeze_blocks > 0: + trunk = self.sam3.backbone.vision_backbone.trunk + num_blocks = len(trunk.blocks) + backbone_freeze_blocks = min(backbone_freeze_blocks, num_blocks) + + # Freeze patch_embed + ln_pre + first N blocks + for p in trunk.patch_embed.parameters(): + p.requires_grad = False + for p in trunk.ln_pre.parameters(): + p.requires_grad = False + for i in range(backbone_freeze_blocks): + for p in trunk.blocks[i].parameters(): + p.requires_grad = False + + frozen_params = sum( + p.numel() for p in trunk.parameters() if not p.requires_grad + ) + total_params = sum(p.numel() for p in trunk.parameters()) + print( + f"[WildDet3D] Backbone freeze: {backbone_freeze_blocks}/{num_blocks}" + f" blocks frozen ({frozen_params/1e6:.1f}M/{total_params/1e6:.1f}M params)" + ) + + # Upgrade ALL LayerNorm in the entire model to fp32. + # In bf16 mixed-precision, LayerNorm's variance computation can + # overflow (bf16 max ~65504). This covers sam3 (transformer decoder, + # backbone, encoder), geometry_backend (DINOv2 encoder, intrinsic + # head), early_depth_fusion (depth_norm), and bbox3d_head. + # Negligible performance cost -- LayerNorm is memory-bound. + n_replaced = _upgrade_layernorms_to_fp32(self) + print(f"[WildDet3D] Upgraded {n_replaced} LayerNorm -> Fp32LayerNorm (entire model)") + + def _xyxy_to_cxcywh(self, boxes: Tensor) -> Tensor: + """Convert boxes from xyxy to cxcywh format. + + Args: + boxes: Tensor of shape (..., 4) in xyxy format + + Returns: + Tensor of shape (..., 4) in cxcywh format + """ + x1, y1, x2, y2 = boxes.unbind(-1) + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + return torch.stack([cx, cy, w, h], dim=-1) + + def _build_find_target(self, batch: WildDet3DInput) -> BatchedFindTarget: + """Convert WildDet3DInput GT to SAM3's BatchedFindTarget format. + + This is used for SAM3's internal matching during training. + + SAM3 expects: + - boxes: (N_total, 4) packed cxcywh normalized + - boxes_padded: (N_prompts, max_gt, 4) padded cxcywh + - num_boxes: (N_prompts,) number of GT per prompt + - is_exhaustive: (N_prompts,) bool + + Note: In WildDet3D, each prompt corresponds to exactly one GT box, + so gt_boxes2d has shape (N_prompts, 4) not (N_prompts, max_gt, 4). + + Args: + batch: WildDet3DInput with gt_boxes2d in normalized xyxy + + Returns: + BatchedFindTarget for SAM3's _compute_matching + """ + device = batch.gt_boxes2d.device + gt_boxes_xyxy = batch.gt_boxes2d + + # Handle different input shapes + # Case 1: (N_prompts, 4) - one GT per prompt (WildDet3D design) + # Case 2: (N_prompts, max_gt, 4) - multiple GTs per prompt (general case) + if gt_boxes_xyxy.dim() == 2: + # Shape: (N_prompts, 4) - one GT per prompt + N_prompts = gt_boxes_xyxy.shape[0] + max_gt = 1 + + # Convert xyxy -> cxcywh + gt_boxes_cxcywh = self._xyxy_to_cxcywh(gt_boxes_xyxy) # (N_prompts, 4) + + # Each prompt has exactly 1 GT box + num_boxes = torch.ones(N_prompts, dtype=torch.long, device=device) + + # Packed boxes = all boxes (no padding) + boxes_packed = gt_boxes_cxcywh # (N_prompts, 4) + + # Padded format: add max_gt dimension + gt_boxes_cxcywh_padded = gt_boxes_cxcywh.unsqueeze(1) # (N_prompts, 1, 4) + + # Object IDs: sequential + object_ids = torch.arange(N_prompts, device=device) + object_ids_padded = torch.arange(N_prompts, device=device).unsqueeze(1) # (N_prompts, 1) + + else: + # Shape: (N_prompts, max_gt, 4) - multiple GTs per prompt + N_prompts = gt_boxes_xyxy.shape[0] + max_gt = gt_boxes_xyxy.shape[1] + + # Convert xyxy -> cxcywh + gt_boxes_cxcywh = self._xyxy_to_cxcywh(gt_boxes_xyxy) + + # Compute num_boxes per prompt (count non-zero boxes) + valid_mask = (gt_boxes_xyxy.abs().sum(dim=-1) > 1e-6) # (N_prompts, max_gt) + num_boxes = valid_mask.sum(dim=-1) # (N_prompts,) + + # Pack boxes (remove padding) + boxes_list = [] + for i in range(N_prompts): + n = int(num_boxes[i].item()) + if n > 0: + boxes_list.append(gt_boxes_cxcywh[i, :n]) + if boxes_list: + boxes_packed = torch.cat(boxes_list, dim=0) # (N_total, 4) + else: + boxes_packed = torch.zeros(0, 4, device=device) + + gt_boxes_cxcywh_padded = gt_boxes_cxcywh + + # Object IDs (placeholder - just sequential) + object_ids = torch.arange(len(boxes_packed), device=device) + object_ids_padded = torch.full( + (N_prompts, max_gt), -1, device=device, dtype=torch.long + ) + offset = 0 + for i in range(N_prompts): + n = int(num_boxes[i].item()) + if n > 0: + object_ids_padded[i, :n] = torch.arange( + offset, offset + n, device=device + ) + offset += n + + return BatchedFindTarget( + num_boxes=num_boxes, + boxes=boxes_packed, + boxes_padded=gt_boxes_cxcywh_padded, + repeated_boxes=None, + segments=None, + semantic_segments=None, + is_valid_segment=None, + # is_exhaustive: controls negative loss masking in SAM3's IABCEMdetr. + # Multi-target queries (TEXT=0, VISUAL=1, VISUAL+LABEL=3) are exhaustive: + # all instances of the category are annotated as targets. + # Single-target queries (GEOMETRY=2, GEOMETRY+LABEL=4) are NOT exhaustive: + # only 1 selected instance is the target, other instances of the + # same category exist but are not annotated for this query. + is_exhaustive=self._get_is_exhaustive(batch, N_prompts, device), + object_ids=object_ids, + object_ids_padded=object_ids_padded, + ) + + def _get_is_exhaustive( + self, + batch: WildDet3DInput, + N_prompts: int, + device: torch.device, + ) -> Tensor: + """Determine is_exhaustive per query based on query_types. + + Multi-target queries (TEXT=0, VISUAL=1, VISUAL+LABEL=3) are exhaustive: + all instances of the category are annotated as targets, so unmatched + predictions should receive negative loss. + + Single-target queries (GEOMETRY=2, GEOMETRY+LABEL=4) are NOT exhaustive: + only 1 selected instance is the target. Other instances of the same + category exist but are not annotated for this query, so unmatched + predictions should NOT receive negative loss. + """ + if batch.query_types is not None: + qt = batch.query_types.to(device) + return (qt == 0) | (qt == 1) | (qt == 3) + return torch.ones(N_prompts, dtype=torch.bool, device=device) + + def on_load_checkpoint(self, checkpoint): + """ + PyTorch Lightning hook called when loading a checkpoint. + + This is called BEFORE load_state_dict, so we can: + 1. Load SAM3 pretrained weights first (if first training) + 2. Load geometry backend pretrained weights first (if first training) + 3. Filter out incompatible keys from the checkpoint + 4. Let PyTorch Lightning load the filtered checkpoint + """ + print("\n" + "="*80) + print("WildDet3D CHECKPOINT LOADING (PyTorch Lightning Hook)") + print("="*80) + + # Get the state_dict from checkpoint + state_dict = checkpoint.get('state_dict', {}) + + # Analyze checkpoint content + has_sam3 = any('sam3.' in key for key in state_dict.keys()) + has_geometry_backend = any('geometry_backend' in key for key in state_dict.keys()) + has_bbox3d_head = any('bbox3d_head' in key for key in state_dict.keys()) + + # Determine if this is resume training or first training + is_resume = has_sam3 and has_geometry_backend + + if is_resume: + # Resume training: load everything from checkpoint + print("\nMode: Resume Training") + print("Loading complete checkpoint (all components)") + print(f" Resuming from epoch {checkpoint.get('epoch', 'unknown')}") + print(f" Resuming from global_step {checkpoint.get('global_step', 'unknown')}") + + else: + # First training: load pretrained weights + print("\nMode: First Training (Fine-tuning)") + + # Step 1: Load SAM3 pretrained weights (if not already loaded in __init__) + if not has_sam3 and self.sam3 is not None: + print("\n[Step 1/3] SAM3 weights already loaded in __init__") + print(f" SAM3 checkpoint: {getattr(self, '_sam3_checkpoint_path', 'unknown')}") + + # Step 2: Load geometry backend pretrained weights + if self.geometry_backend is not None and hasattr(self.geometry_backend, 'load_pretrained_weights'): + print("\n[Step 2/3] Loading geometry backend pretrained weights...") + self.geometry_backend.load_pretrained_weights() + + # Step 3: Filter checkpoint if needed + print("\n[Step 3/3] Processing checkpoint...") + if not has_sam3: + print(" No SAM3 weights in checkpoint (will use pretrained SAM3)") + if not has_geometry_backend: + print(" No geometry_backend weights in checkpoint (will use pretrained)") + if not has_bbox3d_head: + print(" No bbox3d_head weights in checkpoint (will initialize randomly)") + + # Step 4: Reset training state (epoch, step, optimizer) + print("\n[Step 4/4] Resetting training state for fine-tuning...") + if 'epoch' in checkpoint: + old_epoch = checkpoint['epoch'] + checkpoint['epoch'] = 0 + print(f" Reset epoch: {old_epoch} -> 0") + + if 'global_step' in checkpoint: + old_step = checkpoint['global_step'] + checkpoint['global_step'] = 0 + print(f" Reset global_step: {old_step} -> 0") + + # Remove optimizer states (they won't match our new optimizer config) + if 'optimizer_states' in checkpoint: + del checkpoint['optimizer_states'] + print(f" Removed optimizer_states (will initialize fresh)") + + # Remove lr_scheduler states + if 'lr_schedulers' in checkpoint: + del checkpoint['lr_schedulers'] + print(f" Removed lr_schedulers (will initialize fresh)") + + # Store resume status for later use + self._is_resume_training = is_resume + + print("\n" + "="*80) + print("Checkpoint loading hook completed") + print("="*80 + "\n") + + def forward( + self, + batch: WildDet3DInput, + targets: dict | None = None, + ) -> WildDet3DOut: + """Forward pass of WildDet3D using SAM3's forward_grounding. + + This method reuses SAM3's complete 2D detection pipeline and adds + 3D detection on top. + + Args: + batch: WildDet3DInput containing: + - images: (B_images, 3, H, W) + - intrinsics: (B_images, 3, 3) + - img_ids: (N_prompts,) - which image each prompt belongs to + - text_ids: (N_prompts,) - text index per prompt + - unique_texts: List[str] - all unique texts + - geo_boxes: (N_prompts, max_K, 4) - normalized cxcywh + - geo_boxes_mask: (N_prompts, max_K) - True=padding + - geo_box_labels: (N_prompts, max_K) - 0/1 for neg/pos + targets: Training targets (optional) + + Returns: + WildDet3DOut with 2D and 3D predictions + """ + B_images = batch.images.shape[0] + N_prompts = len(batch.img_ids) + _, _, H, W = batch.images.shape + device = batch.images.device + + profile_start("forward_total") + + # Sync SAM3 training mode with parent module + # This is important because SAM3's forward_grounding only computes + # matching indices when self.training is True + if self.sam3.training != self.training: + self.sam3.train(self.training) + + # Handle empty batch (no prompts) + if N_prompts == 0: + if self.training: + # Create dummy output connected to ALL model parameters for DDP backward + # DDP requires all parameters to participate in backward across all ranks + # Using only one parameter causes deadlock when other ranks use all params + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + print(f"[WildDet3D] Empty batch (N_prompts=0) on rank {rank}, using all-param dummy") + dummy_grad = sum(p.sum() * 0 for p in self.parameters() if p.requires_grad) + dummy_logits = torch.zeros(1, 1, 1, device=device) + dummy_grad + return WildDet3DOut( + pred_logits=dummy_logits, + pred_boxes_2d=torch.zeros(1, 1, 4, device=device), + pred_boxes_3d=None, + aux_outputs=None, + geom_losses=None, + presence_logits=None, + queries=None, + encoder_hidden_states=None, + indices=None, + ) + else: + # Test mode: return empty Det3DOut + return Det3DOut( + boxes=[torch.zeros(0, 4, device=device) for _ in range(B_images)], + boxes3d=[torch.zeros(0, 10, device=device) for _ in range(B_images)], + scores=[torch.zeros(0, device=device) for _ in range(B_images)], + class_ids=[torch.zeros(0, dtype=torch.long, device=device) for _ in range(B_images)], + depth_maps=None, + categories=None, + ) + + # ========== Step 1 & 2: SAM3 Backbone + Geometry Backend (PARALLEL) ========== + # These two operations are independent - run them in parallel using CUDA streams + profile_start(" backbone+geom_parallel") + + # Convert images for SAM3 (needed by backbone) + images_for_sam3 = self._convert_imagenet_to_sam3_norm(batch.images) + + # Prepare geometry backend inputs + geom_losses = None + depth_latents = None + geom_out = None + _, _, H, W = batch.images.shape + + if self.geometry_backend is not None: + # Create CUDA streams for parallel execution + backbone_stream = torch.cuda.Stream() + geom_stream = torch.cuda.Stream() + + # Prepare inputs for geometry backend (before streams) + intrinsics_per_image = batch.intrinsics + depth_gt = None + depth_mask = None + if self.training or self.use_depth_input_test: + depth_gt = getattr(batch, 'depth_gt', None) + if self.training: + depth_mask = getattr(batch, 'depth_mask', None) + + # Run backbone on stream 1 + profile_start(" backbone") + with torch.cuda.stream(backbone_stream): + backbone_out = {"img_batch_all_stages": batch.images} + backbone_out.update(self.sam3.backbone.forward_image(images_for_sam3)) + text_out = self.sam3.backbone.forward_text( + batch.unique_texts, device=device + ) + backbone_out.update(text_out) + + # Run geometry backend on stream 2 (parallel with backbone) + profile_start(" geometry_backend") + with torch.cuda.stream(geom_stream): + geom_out = self.geometry_backend( + images=batch.images, + depth_feats=None, # Not using backbone features + intrinsics=intrinsics_per_image, + image_hw=(H, W), + depth_gt=depth_gt, + depth_mask=depth_mask, + padding=batch.padding, + ) + + # Wait for both streams to complete + backbone_stream.synchronize() + profile_stop(" backbone") + geom_stream.synchronize() + profile_stop(" geometry_backend") + + # Extract geometry outputs + depth_latents = geom_out.get("depth_latents") + if self.training: + geom_losses = geom_out.get("losses", {}) + else: + # No geometry backend - just run backbone + profile_start(" backbone") + backbone_out = {"img_batch_all_stages": batch.images} + backbone_out.update(self.sam3.backbone.forward_image(images_for_sam3)) + text_out = self.sam3.backbone.forward_text( + batch.unique_texts, device=device + ) + backbone_out.update(text_out) + profile_stop(" backbone") + + profile_stop(" backbone+geom_parallel") + + # ========== Step 2.5: Early Depth Fusion (after backbone, before encoder) ========== + # Fuse depth_latents into backbone visual features before encoder + # This allows depth information to participate in encoder's self-attention + # and text cross-attention + if self.early_depth_fusion is not None and depth_latents is not None: + # Get depth_latents spatial dimensions from geometry backend output + aux = geom_out.get("aux", {}) + depth_latents_hw = aux.get("depth_latents_hw") + + if depth_latents_hw is not None and "backbone_fpn" in backbone_out: + # Fuse depth into visual features + backbone_fpn = backbone_out["backbone_fpn"] + + # early_depth_fusion expects list of visual features + if not isinstance(backbone_fpn, list): + backbone_fpn = [backbone_fpn] + + # Perform fusion + fused_fpn = self.early_depth_fusion( + visual_feats=backbone_fpn, + depth_latents=depth_latents, + depth_latents_hw=depth_latents_hw, + ) + + # Update backbone_out with fused features + # SAM3 will use these fused features in encoder + if len(fused_fpn) == 1: + backbone_out["backbone_fpn"] = fused_fpn[0] + else: + backbone_out["backbone_fpn"] = fused_fpn + + # Log fusion delta magnitude (monitoring only) + if self.training and geom_losses is not None: + geom_losses["metric_fusion_delta"] = torch.tensor( + self.early_depth_fusion._last_delta_mean_abs, + device=device, + ) + else: + # Warn user that early depth fusion is configured but cannot run + import warnings + if depth_latents_hw is None: + warnings.warn( + "EarlyDepthFusion is configured but depth_latents_hw not " + "provided by geometry backend. Skipping depth fusion. " + "Check geometry backend outputs include 'aux.depth_latents_hw'.", + UserWarning, + ) + elif "backbone_fpn" not in backbone_out: + warnings.warn( + "EarlyDepthFusion is configured but backbone_fpn not found " + "in backbone outputs. Skipping depth fusion.", + UserWarning, + ) + + # ========== Step 3: Build SAM3 inputs ========== + find_input = self._build_find_stage(batch, device) + geometric_prompt = self._build_geometric_prompt(batch, device) + + # ========== Step 4: SAM3 forward_grounding ========== + # This does: encode_prompt -> encoder -> decoder -> score/box prediction + # + # In training mode, we build find_target from batch GT boxes so that + # SAM3's internal _compute_matching can compute matching indices. + # These indices are then used by our loss function. + find_target = None + if self.training: + assert batch.gt_boxes2d is not None, \ + "Training requires GT boxes (batch.gt_boxes2d)" + find_target = self._build_find_target(batch) + + profile_start(" sam3_grounding") + sam3_out = self.sam3.forward_grounding( + backbone_out=backbone_out, + find_input=find_input, + find_target=find_target, + geometric_prompt=geometric_prompt, + ) + profile_stop(" sam3_grounding") + + # ========== Step 5: Extract SAM3 outputs ========== + # SAM3 output format (after _update_scores_and_boxes): + # - pred_logits: (N_prompts, num_queries, 1) - final layer + # - pred_boxes: (N_prompts, num_queries, 4) - normalized cxcywh + # - pred_boxes_xyxy: (N_prompts, num_queries, 4) - normalized xyxy + # - queries: (N_prompts, num_queries, d_model) - last layer hidden states + # - aux_outputs: list of dicts for each decoder layer (for deep supervision) + # O2O outputs (one-to-one matching) + pred_logits = sam3_out["pred_logits"] # (N_prompts, S, 1) + pred_boxes_xyxy = sam3_out["pred_boxes_xyxy"] # (N_prompts, S, 4) + pred_boxes_cxcywh = sam3_out["pred_boxes"] # (N_prompts, S, 4) + queries = sam3_out.get("queries") # (N_prompts, S, d_model) + encoder_hidden_states = sam3_out.get("encoder_hidden_states") + presence_logits = sam3_out.get("presence_logit_dec") + + # O2M outputs (one-to-many matching) from SAM3 DAC mechanism + # These are separate outputs from the second half of queries in DAC mode + pred_logits_o2m = sam3_out.get("pred_logits_o2m") # (N_prompts, S, 1) + pred_boxes_xyxy_o2m = sam3_out.get("pred_boxes_xyxy_o2m") # (N_prompts, S, 4) + pred_boxes_cxcywh_o2m = sam3_out.get("pred_boxes_o2m") # (N_prompts, S, 4) + queries_o2m = sam3_out.get("queries_o2m") # (N_prompts, S, d_model) + + # Extract auxiliary outputs from SAM3 for deep supervision + sam3_aux_outputs = sam3_out.get("aux_outputs", []) + + # ========== Step 6: 3D Head ========== + profile_start(" 3d_head") + pred_boxes_3d = None + pred_conf_3d = None + aux_outputs = None + + if self.bbox3d_head is not None and queries is not None: + # Generate ray embeddings if camera prompt is enabled + # For ray-aware backends, depth_latents already + # contain ray info, so we can either use camera prompt or skip it + ray_embeddings = None + if self.bbox3d_head.use_camera_prompt: + # Get ray parameters from geometry backend output + if geom_out is not None: + # Use backend's ray parameters for consistent space + ray_intrinsics = geom_out.get("ray_intrinsics", batch.intrinsics) + ray_image_hw = geom_out.get("ray_image_hw", (H, W)) + ray_downsample = geom_out.get("ray_downsample", 16) + else: + # Fallback: use image-level intrinsics with default downsample + # Note: This will broadcast to all prompts, not per-prompt + ray_intrinsics = batch.intrinsics + ray_image_hw = (H, W) + ray_downsample = 16 # Default + + ray_embeddings = self.bbox3d_head.get_camera_embeddings( + ray_intrinsics, ray_image_hw, ray_downsample + ) + + # Align depth_latents and ray_embeddings spatial resolution (if needed) + # + # Note: This code only runs when use_camera_prompt=True (i.e., for non-ray-aware + # backends). For ray-aware backends, use_camera_prompt=False and + # ray_embeddings=None, so this block is skipped. + # + # When this does run, depth_latents and ray_embeddings may have different spatial + # resolutions that need to be aligned for the 3D head's cross-attention. + if depth_latents is not None and ray_embeddings is not None: + # depth_latents: [B_images, N_depth, C_depth] + # ray_embeddings: [B_images, N_ray, C_ray] + B_depth, N_depth, C_depth = depth_latents.shape + B_ray, N_ray, C_ray = ray_embeddings.shape + + if N_depth != N_ray: + # Resize depth_latents to match ray spatial size + # Infer spatial dimensions (assuming square) + H_depth = int(N_depth ** 0.5) + W_depth = H_depth + H_ray = int(N_ray ** 0.5) + W_ray = H_ray + + # Reshape depth_latents: [B, N, C] -> [B, C, H, W] + depth_latents_2d = depth_latents.permute(0, 2, 1).reshape( + B_depth, C_depth, H_depth, W_depth + ) + + # Adaptive pool to ray size + depth_latents_resized = torch.nn.functional.adaptive_avg_pool2d( + depth_latents_2d, (H_ray, W_ray) + ) + + # Reshape back: [B, C, H, W] -> [B, N, C] + depth_latents = depth_latents_resized.reshape( + B_depth, C_depth, H_ray * W_ray + ).permute(0, 2, 1) + + # Index ray_embeddings and depth_latents from per-image to per-prompt + # ray_embeddings and depth_latents are per-image [B_images, N, C] + # But 3D head expects them to be per-prompt [N_prompts, N, C] + # Use batch.img_ids to correctly map prompts to their corresponding images + if ray_embeddings is not None: + # batch.img_ids: [N_prompts] - which image each prompt belongs to + # ray_embeddings: [B_images, N, C] + # Index to get: [N_prompts, N, C] + ray_embeddings = ray_embeddings[batch.img_ids] + + if depth_latents is not None: + # depth_latents: [B_images, N, C] + # Index to get: [N_prompts, N, C] + depth_latents = depth_latents[batch.img_ids] + + # ========== Deep Supervision: Process all decoder layers ========== + # Following SAM3's design, we process auxiliary outputs from all decoder layers + # for deep supervision during training + # + # SAM3's output structure: + # - aux_outputs[0..L-2]: intermediate decoder layers (layer 0 to layer L-2) + # - final output (pred_logits, queries, etc.): final decoder layer (layer L-1) + + # Collect all layers' queries in correct order: [layer0, layer1, ..., layerL-1] + # Track which aux_outputs have queries for building aux_outputs later + all_layers_queries = [] + aux_indices_with_queries = [] # Track original indices of aux_outputs with queries + for i, aux_out in enumerate(sam3_aux_outputs): + aux_queries = aux_out.get("queries") + if aux_queries is not None: + all_layers_queries.append(aux_queries) + aux_indices_with_queries.append(i) + all_layers_queries.append(queries) # Final layer at the end + + # Stack to (L, N_prompts, S, C) format expected by 3D head + if len(all_layers_queries) > 1: + # Have auxiliary outputs - stack all layers + hidden_states = torch.stack(all_layers_queries, dim=0) # (L, N_prompts, S, C) + else: + # No auxiliary outputs - just expand final layer + hidden_states = queries.unsqueeze(0) # (1, N_prompts, S, C) + + # Call 3D head with all layers + # Returns: (L, N_prompts, S, 12), (L, N_prompts, S, 1) + all_layers_boxes_3d, all_layers_conf_3d = self.bbox3d_head( + hidden_states=hidden_states, + ray_embeddings=ray_embeddings, + depth_latents=depth_latents, + ) + + # Extract final layer output + if len(all_layers_queries) > 1: + pred_boxes_3d = all_layers_boxes_3d[-1] # (N_prompts, S, 12) + pred_conf_3d = all_layers_conf_3d[-1] # (N_prompts, S, 1) + else: + pred_boxes_3d = all_layers_boxes_3d.squeeze(0) # (N_prompts, S, 12) + pred_conf_3d = all_layers_conf_3d.squeeze(0) # (N_prompts, S, 1) + + # Build auxiliary outputs for deep supervision + # Only include layers that have queries (tracked by aux_indices_with_queries) + if len(aux_indices_with_queries) > 0 and self.training: + aux_outputs = [] + for layer_idx, orig_idx in enumerate(aux_indices_with_queries): + aux_out = sam3_aux_outputs[orig_idx] + aux_dict = { + "pred_logits": aux_out["pred_logits"], + "pred_boxes_2d": aux_out["pred_boxes_xyxy"], + "pred_boxes_3d": all_layers_boxes_3d[layer_idx], # 3D predictions for this layer + } + # Include presence logits if available + if "presence_logit_dec" in aux_out: + aux_dict["presence_logits"] = aux_out["presence_logit_dec"] + aux_outputs.append(aux_dict) + + # Compute 3D boxes for O2M queries (if available, only during training) + pred_boxes_3d_o2m = None + pred_conf_3d_o2m = None + if self.bbox3d_head is not None and queries_o2m is not None and self.training: + # O2M queries use the same 3D head but only compute final layer (no aux) + o2m_hidden_states = queries_o2m.unsqueeze(0) # (1, N_prompts, S, C) + o2m_boxes_3d, o2m_conf_3d = self.bbox3d_head( + hidden_states=o2m_hidden_states, + ray_embeddings=ray_embeddings, + depth_latents=depth_latents, + ) + pred_boxes_3d_o2m = o2m_boxes_3d.squeeze(0) # (N_prompts, S, 12) + pred_conf_3d_o2m = o2m_conf_3d.squeeze(0) # (N_prompts, S, 1) + + profile_stop(" 3d_head") + + # Training mode: return raw outputs for loss computation + if self.training: + # Extract matching indices from SAM3 output (computed by _compute_matching) + sam3_indices = sam3_out.get("indices", None) + + profile_stop("forward_total") + + # Record profiling step (will print summary every N steps if enabled) + profile_step() + + return WildDet3DOut( + pred_logits=pred_logits, + pred_boxes_2d=pred_boxes_xyxy, + pred_boxes_3d=pred_boxes_3d, + aux_outputs=aux_outputs, + geom_losses=geom_losses, + presence_logits=presence_logits, + queries=queries, + encoder_hidden_states=encoder_hidden_states, + indices=sam3_indices, + pred_boxes_2d_cxcywh=pred_boxes_cxcywh, + # O2M outputs from SAM3 DAC mechanism + pred_logits_o2m=pred_logits_o2m, + pred_boxes_2d_o2m=pred_boxes_xyxy_o2m, + pred_boxes_2d_cxcywh_o2m=pred_boxes_cxcywh_o2m, + pred_boxes_3d_o2m=pred_boxes_3d_o2m, + # 3D confidence head outputs + pred_conf_3d=pred_conf_3d, + pred_conf_3d_o2m=pred_conf_3d_o2m, + ) + + # Test mode: forward_test returns Det3DOut for evaluation + return self._forward_test( + pred_logits=pred_logits, + pred_boxes_2d=pred_boxes_xyxy, + pred_boxes_3d=pred_boxes_3d, + pred_conf_3d=pred_conf_3d, + presence_logits=presence_logits, + batch=batch, + geom_out=geom_out, + ) + + def _convert_imagenet_to_sam3_norm(self, images: Tensor) -> Tensor: + """Convert ImageNet normalized images to SAM3 normalization. + + vis4d/3D-MOOD uses ImageNet normalization: + ImageNet: (x - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225] + Output range: ~[-2.5, 2.5] + + SAM3 expects custom normalization: + SAM3: (x - 0.5) / 0.5 + Output range: [-1, 1] + + This function converts from ImageNet normalized to SAM3 normalized: + 1. Denormalize ImageNet -> [0, 1] + 2. Normalize SAM3 -> [-1, 1] + + Args: + images: ImageNet normalized images (B, 3, H, W) + + Returns: + SAM3 normalized images (B, 3, H, W) + """ + # ImageNet constants + imagenet_mean = torch.tensor( + [0.485, 0.456, 0.406], device=images.device, dtype=images.dtype + ).view(1, 3, 1, 1) + imagenet_std = torch.tensor( + [0.229, 0.224, 0.225], device=images.device, dtype=images.dtype + ).view(1, 3, 1, 1) + + # Denormalize: ImageNet normalized -> [0, 1] + images_01 = images * imagenet_std + imagenet_mean + + # Normalize: [0, 1] -> SAM3 [-1, 1] + images_sam3 = (images_01 - 0.5) / 0.5 + + return images_sam3 + + def _forward_test( + self, + pred_logits: Tensor, + pred_boxes_2d: Tensor, + pred_boxes_3d: Tensor | None, + pred_conf_3d: Tensor | None = None, + presence_logits: Tensor | None = None, + batch: WildDet3DInput | None = None, + geom_out: dict | None = None, + ) -> Det3DOut: + """Forward pass for test/inference mode. + + Postprocesses model outputs to Det3DOut format for evaluation. + Converts per-prompt outputs to per-image outputs with: + - Pixel coordinate boxes (scaled from normalized) + - Decoded 3D boxes + - Score thresholding (optional) + + Args: + pred_logits: (N_prompts, S, 1) objectness logits + pred_boxes_2d: (N_prompts, S, 4) normalized xyxy boxes + pred_boxes_3d: (N_prompts, S, 12) encoded 3D params or None + pred_conf_3d: (N_prompts, S, 1) 3D confidence logits or None + presence_logits: (N_prompts, 1) presence logits (category exists in image) + batch: Input batch with img_ids, intrinsics, etc. + geom_out: Geometry backend output (may contain depth_maps) + + Returns: + Det3DOut with per-image detection results + """ + H, W = batch.images.shape[2:] + device = pred_logits.device + B_images = batch.images.shape[0] + + # 2D confidence (foreground/background) - used for threshold & NMS + scores_2d = pred_logits.sigmoid().squeeze(-1) # (N_prompts, S) + + # 3D confidence (depth/geometry quality) - tracked separately + scores_3d_all = None + if pred_conf_3d is not None: + scores_3d_all = pred_conf_3d.sigmoid().squeeze(-1) # (N_prompts, S) + + # Combined score for ranking (NMS tie-breaking etc) + # WILDDET3D_CONF_WEIGHT env var overrides config (e.g., "0.0" for 2D only) + import os + conf_weight = self.eval_3d_conf_weight + conf_weight_override = os.environ.get("WILDDET3D_CONF_WEIGHT", None) + if conf_weight_override is not None: + conf_weight = float(conf_weight_override) + if scores_3d_all is not None and conf_weight > 0: + scores_all = scores_2d + conf_weight * scores_3d_all + else: + scores_all = scores_2d + + # Apply presence score if available (following SAM3 original postprocessors.py) + # Presence score indicates whether a category has objects in the image + # This suppresses all proposals for categories that don't exist in the image + # SAM3 original: presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1) + if presence_logits is not None and self.use_presence_score: + presence_score = presence_logits.sigmoid() + # Ensure correct shape for broadcasting: (N_prompts, 1) or (N_prompts,) -> (N_prompts, 1) + if presence_score.dim() == 1: + presence_score = presence_score.unsqueeze(-1) + scores_all = scores_all * presence_score # (N_prompts, S) * (N_prompts, 1) + scores_2d = scores_2d * presence_score # Also apply to 2D scores + + # Scale boxes to pixel coordinates + # pred_boxes_2d is normalized xyxy [0, 1] + boxes_pixel = pred_boxes_2d.clone() + boxes_pixel[..., 0::2] *= W + boxes_pixel[..., 1::2] *= H + + # Group by image + boxes_list = [] + boxes3d_list = [] + scores_list = [] + scores_2d_list = [] + scores_3d_list = [] + class_ids_list = [] + + # Get parameters from roi2det3d if available + score_threshold = getattr(self.roi2det3d, 'score_threshold', -1.0) if self.roi2det3d else -1.0 + + # NMS parameters (following 3D-MOOD's RoI2Det3D design) + # Note: max_per_img not used - WildDet3D already limits to 100 proposals per category + use_nms = getattr(self.roi2det3d, 'nms', False) if self.roi2det3d else False + # class_agnostic_nms=False: NMS only within same category (recommended for per-category prediction) + class_agnostic_nms = getattr(self.roi2det3d, 'class_agnostic_nms', False) if self.roi2det3d else False + iou_threshold = getattr(self.roi2det3d, 'iou_threshold', 0.5) if self.roi2det3d else 0.5 + + # Environment variable overrides (useful for A/B testing) + import os + # SAM3_NMS=0 to disable, SAM3_NMS=1 to enable + nms_override = os.environ.get("SAM3_NMS", None) + if nms_override is not None: + use_nms = nms_override == "1" + # SAM3_SCORE_THRESH to override score threshold (e.g., "0.0" to disable) + score_thresh_override = os.environ.get("SAM3_SCORE_THRESH", None) + if score_thresh_override is not None: + score_threshold = float(score_thresh_override) + # SAM3_IOU_THRESH to override NMS IoU threshold (e.g., "0.8" for more conservative) + iou_thresh_override = os.environ.get("SAM3_IOU_THRESH", None) + if iou_thresh_override is not None: + iou_threshold = float(iou_thresh_override) + + # Debug: print config once at start + if not hasattr(self, '_nms_config_printed'): + print(f"[NMS CONFIG] use_nms={use_nms}, class_agnostic={class_agnostic_nms}, iou_thresh={iou_threshold}, score_thresh={score_threshold}") + # Log predicted intrinsics setting + _use_pred_k = self.use_predicted_intrinsics + _pred_k_override = os.environ.get("SAM3_USE_PRED_K", None) + if _pred_k_override is not None: + _use_pred_k = _pred_k_override == "1" + print(f"[INTRINSICS CONFIG] use_predicted_intrinsics={_use_pred_k}") + self._nms_config_printed = True + + S = scores_all.shape[1] # predictions per prompt + + for img_idx in range(B_images): + # Find prompts belonging to this image + prompt_mask = batch.img_ids == img_idx + n_prompts_this_img = prompt_mask.sum().item() + + if n_prompts_this_img == 0: + # No prompts for this image + boxes_list.append(torch.zeros(0, 4, device=device)) + boxes3d_list.append(torch.zeros(0, 10, device=device)) + scores_list.append(torch.zeros(0, device=device)) + scores_2d_list.append(torch.zeros(0, device=device)) + scores_3d_list.append(torch.zeros(0, device=device)) + class_ids_list.append(torch.zeros(0, dtype=torch.long, device=device)) + continue + + # Get predictions for this image's prompts + img_scores = scores_all[prompt_mask] # (n_prompts, S) + img_boxes = boxes_pixel[prompt_mask] # (n_prompts, S, 4) + + # Get class IDs for each prompt + if batch.gt_category_ids is not None: + img_class_ids = batch.gt_category_ids[prompt_mask] # (n_prompts,) or (n_prompts, max_gt) + if img_class_ids.dim() > 1: + img_class_ids = img_class_ids[:, 0] # Take first if multiple + elif batch.text_ids is not None: + img_class_ids = batch.text_ids[prompt_mask] + else: + img_class_ids = torch.zeros(n_prompts_this_img, dtype=torch.long, device=device) + + if self.oracle_eval: + # Oracle mode: IoU top-K + highest confidence + # 1. Compute 2D IoU between each proposal and its GT box + # 2. Take top-K proposals by IoU (well-localized candidates) + # 3. Among top-K, pick highest confidence (best quality) + oracle_topk = int(os.environ.get("SAM3_ORACLE_TOPK", "10")) + prompt_indices = torch.arange(n_prompts_this_img, device=device) + best_indices = torch.zeros(n_prompts_this_img, dtype=torch.long, device=device) + + if batch.geo_boxes is not None: + # geo_boxes is in padded-normalized cxcywh (correct space) + img_geo_boxes = batch.geo_boxes[prompt_mask] # (n_prompts, max_K, 4) + gt_cxcywh = img_geo_boxes[:, 0, :] # (n_prompts, 4) + gt_xyxy_norm = box_cxcywh_to_xyxy(gt_cxcywh) + gt_boxes_pixel = gt_xyxy_norm.clone() + gt_boxes_pixel[:, 0::2] *= W + gt_boxes_pixel[:, 1::2] *= H + + K = min(oracle_topk, S) + for p_idx in range(n_prompts_this_img): + ious = box_iou( + img_boxes[p_idx], gt_boxes_pixel[p_idx].unsqueeze(0) + ).squeeze(-1) # (S,) + # Top-K by IoU + _, topk_iou_indices = ious.topk(K) + # Among top-K, pick highest confidence + topk_scores = img_scores[p_idx][topk_iou_indices] + best_in_topk = topk_scores.argmax() + best_indices[p_idx] = topk_iou_indices[best_in_topk] + + if img_idx == 0 and not hasattr(self, '_oracle_debug_printed'): + self._oracle_debug_printed = True + p0_ious = box_iou( + img_boxes[0], gt_boxes_pixel[0].unsqueeze(0) + ).squeeze(-1) + sel = best_indices[0].item() + print( + f"[ORACLE] topK={K}, " + f"IoU={p0_ious[sel]:.4f}, " + f"score={img_scores[0][sel]:.4f}, " + f"maxIoU={p0_ious.max():.4f}" + ) + else: + # Fallback: pure argmax + best_indices = img_scores.argmax(dim=1) + + img_scores_flat = img_scores[prompt_indices, best_indices] + img_boxes_flat = img_boxes[prompt_indices, best_indices] + img_class_ids_flat = img_class_ids + + # Track 2D and 3D scores for oracle mode + img_scores_2d_flat = scores_2d[prompt_mask][prompt_indices, best_indices] + if scores_3d_all is not None: + img_scores_3d = scores_3d_all[prompt_mask] + img_scores_3d_flat = img_scores_3d[prompt_indices, best_indices] + else: + img_scores_3d_flat = torch.zeros_like(img_scores_flat) + + if pred_boxes_3d is not None: + img_boxes3d = pred_boxes_3d[prompt_mask] + img_boxes3d_flat = img_boxes3d[prompt_indices, best_indices] + else: + img_boxes3d_flat = None + + else: + # Standard mode: flatten all proposals + NMS + # Flatten all predictions: (n_prompts, S) -> (n_prompts * S,) + img_scores_flat = img_scores.flatten() # (n_prompts * S,) + img_boxes_flat = img_boxes.reshape(-1, 4) # (n_prompts * S, 4) + + # Track 2D scores separately for threshold filtering and output + img_scores_2d = scores_2d[prompt_mask].flatten() # (n_prompts * S,) + img_scores_2d_flat = img_scores_2d # alias for output + + # Track 3D scores + if scores_3d_all is not None: + img_scores_3d_flat = scores_3d_all[prompt_mask].flatten() + else: + img_scores_3d_flat = torch.zeros_like(img_scores_flat) + + # Expand class_ids to match flattened shape + img_class_ids_flat = img_class_ids.unsqueeze(1).expand(-1, S).flatten() # (n_prompts * S,) + + # Get 3D boxes if available (flattened) + if pred_boxes_3d is not None: + img_boxes3d = pred_boxes_3d[prompt_mask] # (n_prompts, S, 12) + img_boxes3d_flat = img_boxes3d.reshape(-1, 12) # (n_prompts * S, 12) + else: + img_boxes3d_flat = None + + # Score threshold filter (uses 2D score only) + if score_threshold > 0: + keep = img_scores_2d > score_threshold + img_scores_flat = img_scores_flat[keep] + img_scores_2d_flat = img_scores_2d_flat[keep] + img_scores_2d = img_scores_2d[keep] + img_scores_3d_flat = img_scores_3d_flat[keep] + img_boxes_flat = img_boxes_flat[keep] + img_class_ids_flat = img_class_ids_flat[keep] + if img_boxes3d_flat is not None: + img_boxes3d_flat = img_boxes3d_flat[keep] + + # NMS based on 2D boxes (following RoI2Det3D design) + if use_nms and len(img_boxes_flat) > 0: + n_before_nms = len(img_boxes_flat) + if class_agnostic_nms: + keep = nms(img_boxes_flat, img_scores_flat, iou_threshold) + else: + keep = batched_nms( + img_boxes_flat, img_scores_flat, img_class_ids_flat, iou_threshold + ) + img_scores_flat = img_scores_flat[keep] + img_scores_2d_flat = img_scores_2d_flat[keep] + img_scores_3d_flat = img_scores_3d_flat[keep] + img_boxes_flat = img_boxes_flat[keep] + img_class_ids_flat = img_class_ids_flat[keep] + if img_boxes3d_flat is not None: + img_boxes3d_flat = img_boxes3d_flat[keep] + if img_idx == 0: + n_after_nms = len(img_boxes_flat) + print(f"[NMS DEBUG] img={img_idx}, before={n_before_nms}, after={n_after_nms}, suppressed={n_before_nms - n_after_nms}, iou_thresh={iou_threshold}") + + # Decode 3D boxes in padded space BEFORE rescaling (matching GDino3D) + # Use padded-space intrinsics since 2D boxes are still in padded + # pixel coordinates at this point. + # When use_predicted_intrinsics is enabled, use geometry backend's + # K_pred (also in padded space) instead of dataset intrinsics. + if img_boxes3d_flat is not None and self.box_coder is not None and len(img_boxes_flat) > 0: + # Determine whether to use predicted intrinsics + use_pred_k = self.use_predicted_intrinsics + pred_k_override = os.environ.get("SAM3_USE_PRED_K", None) + if pred_k_override is not None: + use_pred_k = pred_k_override == "1" + + if use_pred_k and geom_out is not None and "K_pred" in geom_out and geom_out["K_pred"] is not None: + intrinsics_this_img = geom_out["K_pred"][img_idx] # (3, 3) padded-space + else: + intrinsics_this_img = batch.intrinsics[img_idx] # (3, 3) padded-space + + decoded_boxes3d = self.box_coder.decode( + img_boxes_flat, # pixel xyxy in padded space + img_boxes3d_flat, + intrinsics_this_img, + ) + else: + decoded_boxes3d = torch.zeros(len(img_boxes_flat), 10, device=device) + + # Rescale 2D boxes from padded space (H, W) to original image space + # Must account for CenterPad: first subtract padding offset, then + # divide by content_size/original_size (NOT padded_size/original_size). + # Matches GDino3D RoI2Det3D.__call__ (head.py:380-396). + if batch.original_hw is not None: + # original_hw may be List[tuple] or a single tuple + # (Lightning's transfer_batch_to_device can unwrap + # single-element lists for batch_size=1) + hw = batch.original_hw + if isinstance(hw, (tuple, list)) and len(hw) == 2 and isinstance(hw[0], (int, float)): + # Direct tuple (h, w) - single image batch + orig_h, orig_w = hw + elif isinstance(hw, (tuple, list)) and img_idx < len(hw): + orig_h, orig_w = hw[img_idx] + else: + orig_h, orig_w = None, None + + if orig_h is None: + continue + + img_boxes_flat = img_boxes_flat.clone() # Don't modify in-place + + # padding may also be unwrapped for batch_size=1 + pad_info = batch.padding + if pad_info is not None: + if isinstance(pad_info, (tuple, list)) and len(pad_info) == 4 and isinstance(pad_info[0], (int, float)): + # Direct [L,R,T,B] - single image batch + pad_left, pad_right, pad_top, pad_bottom = pad_info + elif isinstance(pad_info, (tuple, list)) and img_idx < len(pad_info) and pad_info[img_idx] is not None: + pad_left, pad_right, pad_top, pad_bottom = pad_info[img_idx] + else: + pad_left = pad_right = pad_top = pad_bottom = 0 + + # Step 1: subtract CenterPad offset + img_boxes_flat[:, 0::2] -= pad_left + img_boxes_flat[:, 1::2] -= pad_top + # Step 2: scale = content_size / original_size + content_w = W - pad_left - pad_right + content_h = H - pad_top - pad_bottom + scale_x = content_w / orig_w + scale_y = content_h / orig_h + else: + # Fallback: no padding info, use full image size + scale_x = W / orig_w + scale_y = H / orig_h + img_boxes_flat[:, 0::2] /= scale_x # x coordinates + img_boxes_flat[:, 1::2] /= scale_y # y coordinates + + boxes_list.append(img_boxes_flat) + boxes3d_list.append(decoded_boxes3d) + scores_list.append(img_scores_flat) + scores_2d_list.append(img_scores_2d_flat) + scores_3d_list.append(img_scores_3d_flat) + class_ids_list.append(img_class_ids_flat) + + # Get depth maps if available + depth_maps = None + if geom_out is not None and "depth_map" in geom_out: + depth_maps = [geom_out["depth_map"][i] for i in range(B_images)] + + # Get predicted intrinsics if available + predicted_intrinsics = None + if geom_out is not None and "K_pred" in geom_out: + predicted_intrinsics = geom_out["K_pred"] + + return Det3DOut( + boxes=boxes_list, + boxes3d=boxes3d_list, + scores=scores_list, + class_ids=class_ids_list, + depth_maps=depth_maps, + categories=None, + predicted_intrinsics=predicted_intrinsics, + scores_3d=scores_3d_list, + scores_2d=scores_2d_list, + ) + + def _build_find_stage( + self, + batch: WildDet3DInput, + device: torch.device, + ) -> FindStage: + """Convert WildDet3DInput to SAM3's FindStage format. + + FindStage is SAM3's internal representation for per-prompt batch, + containing img_ids, text_ids, and geometry inputs. + """ + N_prompts = len(batch.img_ids) + + # Prepare geometry inputs - need to convert to sequence-first + # FindStage expects (max_K, N_prompts, 4) for boxes + if batch.geo_boxes is not None: + # (N_prompts, max_K, 4) -> (max_K, N_prompts, 4) + input_boxes = batch.geo_boxes.permute(1, 0, 2) + input_boxes_mask = batch.geo_boxes_mask # (N_prompts, max_K) + input_boxes_label = ( + batch.geo_box_labels.permute(1, 0) + if batch.geo_box_labels is not None + else torch.ones( + input_boxes.shape[0], N_prompts, dtype=torch.long, device=device + ) + ) + else: + # No geometry input - create empty tensors + input_boxes = torch.zeros(0, N_prompts, 4, device=device) + input_boxes_mask = torch.ones(N_prompts, 0, dtype=torch.bool, device=device) + input_boxes_label = torch.zeros(0, N_prompts, dtype=torch.long, device=device) + + # Points (if any) + if batch.geo_points is not None: + input_points = batch.geo_points.permute(1, 0, 2) # (max_P, N, 2) + input_points_mask = batch.geo_points_mask + else: + input_points = torch.zeros(0, N_prompts, 2, device=device) + input_points_mask = torch.ones(N_prompts, 0, dtype=torch.bool, device=device) + + return FindStage( + img_ids=batch.img_ids, + text_ids=batch.text_ids, + input_boxes=input_boxes, + input_boxes_mask=input_boxes_mask, + input_boxes_label=input_boxes_label, + input_points=input_points, + input_points_mask=input_points_mask, + object_ids=None, + ) + + def _build_geometric_prompt( + self, + batch: WildDet3DInput, + device: torch.device, + ) -> Prompt: + """Build SAM3 Prompt object from batch. + + SAM3's Prompt class expects sequence-first format: (K, N_prompts, dim) + """ + N_prompts = len(batch.img_ids) + + # Box prompts + if batch.geo_boxes is not None and batch.geo_boxes.shape[1] > 0: + # (N_prompts, max_K, 4) -> (max_K, N_prompts, 4) + box_embeddings = batch.geo_boxes.permute(1, 0, 2) + box_mask = batch.geo_boxes_mask # (N_prompts, max_K) + box_labels = ( + batch.geo_box_labels.permute(1, 0) + if batch.geo_box_labels is not None + else torch.ones( + box_embeddings.shape[0], N_prompts, dtype=torch.long, device=device + ) + ) + else: + box_embeddings = None + box_mask = None + box_labels = None + + # Point prompts + if batch.geo_points is not None and batch.geo_points.shape[1] > 0: + point_embeddings = batch.geo_points.permute(1, 0, 2) # (max_P, N, 2) + point_mask = batch.geo_points_mask + point_labels = ( + batch.geo_point_labels.permute(1, 0) + if batch.geo_point_labels is not None + else torch.ones( + point_embeddings.shape[0], N_prompts, dtype=torch.long, device=device + ) + ) + else: + # For text-only mode: create empty tensors instead of None + # SAM3's geometry encoder cannot handle None for points + point_embeddings = torch.zeros(0, N_prompts, 2, device=device) + point_mask = torch.ones(N_prompts, 0, dtype=torch.bool, device=device) + point_labels = torch.zeros(0, N_prompts, dtype=torch.long, device=device) + + # Ensure box prompts also have empty tensors if None + if box_embeddings is None: + box_embeddings = torch.zeros(0, N_prompts, 4, device=device) + box_mask = torch.ones(N_prompts, 0, dtype=torch.bool, device=device) + box_labels = torch.zeros(0, N_prompts, dtype=torch.long, device=device) + + return Prompt( + box_embeddings=box_embeddings, + box_mask=box_mask, + box_labels=box_labels, + point_embeddings=point_embeddings, + point_mask=point_mask, + point_labels=point_labels, + ) + + @torch.no_grad() + def inference( + self, + batch: WildDet3DInput, + score_threshold: float = 0.3, + nms_threshold: float = 0.5, + ) -> list[dict]: + """Run inference and decode 3D boxes. + + Args: + batch: WildDet3DInput with images and prompts + score_threshold: Confidence threshold + nms_threshold: NMS IoU threshold + + Returns: + List of dicts per image with decoded 3D boxes + """ + self.eval() + + out = self.forward(batch) + + if self.roi2det3d is None or out.pred_boxes_3d is None: + return self._decode_2d_only(out, batch.img_ids, score_threshold) + + # Decode 3D boxes using roi2det3d + H, W = batch.images.shape[2:] + intrinsics_per_prompt = batch.intrinsics[batch.img_ids] + results = self.roi2det3d( + pred_logits=out.pred_logits, + pred_boxes_2d=out.pred_boxes_2d, + pred_boxes_3d=out.pred_boxes_3d, + intrinsics=intrinsics_per_prompt, + image_size=(H, W), + img_ids=batch.img_ids, + score_threshold=score_threshold, + nms_threshold=nms_threshold, + ) + return results + + def _decode_2d_only( + self, + out: WildDet3DOut, + img_ids: Tensor, + score_threshold: float, + ) -> list[dict]: + """Decode 2D-only results when 3D head is not available.""" + scores = out.pred_logits.sigmoid().squeeze(-1) # (N_prompts, S) + boxes = out.pred_boxes_2d # (N_prompts, S, 4) normalized xyxy + + results = [] + unique_img_ids = img_ids.unique() + + for img_id in unique_img_ids: + mask = img_ids == img_id + img_scores = scores[mask].flatten() + img_boxes = boxes[mask].reshape(-1, 4) + + keep = img_scores > score_threshold + results.append({ + "scores": img_scores[keep], + "boxes_2d": img_boxes[keep], + "boxes_3d": None, + }) + + return results + + +def build_wilddet3d( + sam3_checkpoint: str | None = None, + geometry_backend_type: str = "unidepth_v2", + hidden_dim: int = 256, + num_decoder_layers: int = 6, + device: str = "cuda", +) -> WildDet3D: + """Factory function to build WildDet3D model. + + Args: + sam3_checkpoint: Path to SAM3 checkpoint + geometry_backend_type: Type of geometry backend + hidden_dim: Hidden dimension for 3D head + num_decoder_layers: Number of decoder layers + device: Device to load model on + + Returns: + Initialized WildDet3D model + + Note: + Learning rate control is handled by param_groups in optimizer config, + not by freezing parameters. + """ + from sam3.model.sam3_image import build_sam3_image + from wilddet3d.depth import GeometryBackendBase + + # Build SAM3 model + sam3_model = build_sam3_image(checkpoint=sam3_checkpoint) + sam3_model = sam3_model.to(device) + + # Build geometry backend + # Note: geometry backend construction depends on the specific backend type + # For now, this is a placeholder - users should construct the backend externally + geometry_backend = None + + # Build 3D head + bbox3d_head = Det3DHead( + hidden_dim=hidden_dim, + num_layers=num_decoder_layers, + ) + + # Build box coder + box_coder = Det3DCoder() + + # Build inference post-processor + roi2det3d = RoI2Det3D(box_coder=box_coder) + + model = WildDet3D( + sam3_model=sam3_model, + bbox3d_head=bbox3d_head, + box_coder=box_coder, + geometry_backend=geometry_backend, + roi2det3d=roi2det3d, + ) + + return model.to(device) diff --git a/wilddet3d/ops/__init__.py b/wilddet3d/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f43b8f2189da9b7d3906ec11138e0eb7a3ac920c --- /dev/null +++ b/wilddet3d/ops/__init__.py @@ -0,0 +1 @@ +"""Operations and layers.""" diff --git a/wilddet3d/ops/__pycache__/__init__.cpython-311.pyc b/wilddet3d/ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75a2081652f643f2e3ee74d6379eac00e276e047 Binary files /dev/null and b/wilddet3d/ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/attention.cpython-311.pyc b/wilddet3d/ops/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50a698ee12f9b2aba6df8a21744b4f4dd7a266e6 Binary files /dev/null and b/wilddet3d/ops/__pycache__/attention.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/box2d.cpython-311.pyc b/wilddet3d/ops/__pycache__/box2d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa43381a26ff7d207d2c97175a33ad4e7df9d4dc Binary files /dev/null and b/wilddet3d/ops/__pycache__/box2d.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/mlp.cpython-311.pyc b/wilddet3d/ops/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4661d22de7cc3762130c1ae31c785b43f2ec3cf Binary files /dev/null and b/wilddet3d/ops/__pycache__/mlp.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/nystrom.cpython-311.pyc b/wilddet3d/ops/__pycache__/nystrom.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0788d650bf627d4ce5a5f5b8098d4debb71eb6e Binary files /dev/null and b/wilddet3d/ops/__pycache__/nystrom.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/profiler.cpython-311.pyc b/wilddet3d/ops/__pycache__/profiler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d2704871f91c73fb42c6b1c242180b4a39d5565 Binary files /dev/null and b/wilddet3d/ops/__pycache__/profiler.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/ray.cpython-311.pyc b/wilddet3d/ops/__pycache__/ray.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c54126675e1b0127cd2af998dd2da6b7928adeb Binary files /dev/null and b/wilddet3d/ops/__pycache__/ray.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/rotation.cpython-311.pyc b/wilddet3d/ops/__pycache__/rotation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5db103271b6baff2bbdee82a3754ef3a16809b1 Binary files /dev/null and b/wilddet3d/ops/__pycache__/rotation.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/upsample.cpython-311.pyc b/wilddet3d/ops/__pycache__/upsample.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..528c02c2784e5edb615b339d9dbc0c5ef29f39f4 Binary files /dev/null and b/wilddet3d/ops/__pycache__/upsample.cpython-311.pyc differ diff --git a/wilddet3d/ops/__pycache__/util.cpython-311.pyc b/wilddet3d/ops/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d8ee3f3acc4625bff40096db60d4097ab6ec24d Binary files /dev/null and b/wilddet3d/ops/__pycache__/util.cpython-311.pyc differ diff --git a/wilddet3d/ops/attention.py b/wilddet3d/ops/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4ffb2fa5a54a8a5ce4f1b7634db16d25fb0cd51e --- /dev/null +++ b/wilddet3d/ops/attention.py @@ -0,0 +1,284 @@ +"""Attention layer.""" + +from functools import partial +from math import log2, pi + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor, nn + +from .mlp import MLP +from .nystrom import NystromAttention + + +class LayerScale(nn.Module): + """Layer scale.""" + + def __init__( + self, + dim: int, + init_values: float | Tensor = 1e-5, + inplace: bool = False, + ) -> None: + """Initialize.""" + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + """Forward.""" + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class AttentionBlock(nn.Module): + """Attention block.""" + + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + ) -> None: + """Initialize.""" + super().__init__() + self.num_heads = num_heads + self.hidden_dim = dim + self.context_dim = context_dim or dim + + self.norm_attnx = nn.LayerNorm(self.hidden_dim) + self.norm_attnctx = nn.LayerNorm(self.context_dim) + + self.q = nn.Linear(self.hidden_dim, self.hidden_dim) + self.kv = nn.Linear(self.context_dim, self.hidden_dim * 2) + + self.cosine = cosine + self.dropout = dropout + self.out = nn.Linear(self.hidden_dim, self.hidden_dim) + + self.ls1 = ( + LayerScale(dim, layer_scale) + if layer_scale > 0.0 + else nn.Identity() + ) + + self.mlp = MLP( + self.hidden_dim, expansion=expansion, dropout=dropout, gated=gated + ) + + self.ls2 = ( + LayerScale(dim, layer_scale) + if layer_scale > 0.0 + else nn.Identity() + ) + + def attn( + self, + x: Tensor, + attn_bias: Tensor | None = None, + context: Tensor | None = None, + pos_embed: Tensor | None = None, + pos_embed_context: Tensor | None = None, + rope: nn.Module | None = None, + ) -> Tensor: + """Attention.""" + x = self.norm_attnx(x) + + context = self.norm_attnctx(context) + + q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads) + + k, v = rearrange( + self.kv(context), + "b n (kv h d) -> b h n d kv", + h=self.num_heads, + kv=2, + ).unbind(dim=-1) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b h n d", h=self.num_heads + ) + q = q + pos_embed + + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) + + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out(x) + return x + + def forward( + self, + x: Tensor, + attn_bias: Tensor | None = None, + context: Tensor | None = None, + pos_embed: Tensor | None = None, + pos_embed_context: Tensor | None = None, + rope: nn.Module | None = None, + ) -> Tensor: + """Forward.""" + context = x if context is None else context + + x = ( + self.ls1( + self.attn( + x, + rope=rope, + attn_bias=attn_bias, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + ) + ) + + x + ) + + return self.ls2(self.mlp(x)) + x + + +class NystromBlock(AttentionBlock): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + cosine=cosine, + gated=gated, + layer_scale=layer_scale, + context_dim=context_dim, + ) + self.attention_fn = NystromAttention( + num_landmarks=128, num_heads=num_heads, dropout=dropout + ) + + def attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + ) -> torch.Tensor: + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), + "b n (kv h d) -> b n h d kv", + h=self.num_heads, + kv=2, + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b n h d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) + x = self.attention_fn(q, k, v, key_padding_mask=attn_bias) + x = rearrange(x, "b n h d -> b n (h d)") + x = self.out(x) + return x + + +class PositionEmbeddingSine(nn.Module): + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * pi + self.scale = scale + + def forward( + self, x: torch.Tensor, mask: Tensor | None = None + ) -> torch.Tensor: + if mask is None: + mask = torch.zeros( + (x.size(0), x.size(2), x.size(3)), + device=x.device, + dtype=torch.bool, + ) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange( + self.num_pos_feats, dtype=torch.float32, device=x.device + ) + dim_t = self.temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats + ) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) diff --git a/wilddet3d/ops/box2d.py b/wilddet3d/ops/box2d.py new file mode 100644 index 0000000000000000000000000000000000000000..36bbb243d68c3c4ab8e054defd1c13cd8a379c8a --- /dev/null +++ b/wilddet3d/ops/box2d.py @@ -0,0 +1,101 @@ +"""Box operations for 2D bounding boxes.""" + +import numpy as np +import torch +from torch import Tensor + + +def fp16_clamp(x, min=None, max=None): + if not x.is_cuda and x.dtype == torch.float16: + return x.float().clamp(min, max).half() + return x.clamp(min, max) + + +def bbox_cxcywh_to_xyxy(bbox: Tensor) -> Tensor: + """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).""" + cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1) + bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)] + return torch.cat(bbox_new, dim=-1) + + +def bbox_xyxy_to_cxcywh(bbox: Tensor) -> Tensor: + """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).""" + x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1) + bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)] + return torch.cat(bbox_new, dim=-1) + + +def bbox_overlaps(bboxes1, bboxes2, mode="iou", is_aligned=False, eps=1e-6): + """Calculate overlap between two set of bboxes.""" + assert mode in ["iou", "iof", "giou"], f"Unsupported mode {mode}" + assert bboxes1.size(-1) == 4 or bboxes1.size(0) == 0 + assert bboxes2.size(-1) == 4 or bboxes2.size(0) == 0 + + assert bboxes1.shape[:-2] == bboxes2.shape[:-2] + batch_shape = bboxes1.shape[:-2] + + rows = bboxes1.size(-2) + cols = bboxes2.size(-2) + if is_aligned: + assert rows == cols + + if rows * cols == 0: + if is_aligned: + return bboxes1.new(batch_shape + (rows,)) + else: + return bboxes1.new(batch_shape + (rows, cols)) + + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * ( + bboxes1[..., 3] - bboxes1[..., 1] + ) + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * ( + bboxes2[..., 3] - bboxes2[..., 1] + ) + + if is_aligned: + lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) + rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) + + wh = fp16_clamp(rb - lt, min=0) + overlap = wh[..., 0] * wh[..., 1] + + if mode in ["iou", "giou"]: + union = area1 + area2 - overlap + else: + union = area1 + if mode == "giou": + enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2]) + enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:]) + else: + lt = torch.max( + bboxes1[..., :, None, :2], bboxes2[..., None, :, :2] + ) + rb = torch.min( + bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:] + ) + + wh = fp16_clamp(rb - lt, min=0) + overlap = wh[..., 0] * wh[..., 1] + + if mode in ["iou", "giou"]: + union = area1[..., None] + area2[..., None, :] - overlap + else: + union = area1[..., None] + if mode == "giou": + enclosed_lt = torch.min( + bboxes1[..., :, None, :2], bboxes2[..., None, :, :2] + ) + enclosed_rb = torch.max( + bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:] + ) + + eps = union.new_tensor([eps]) + union = torch.max(union, eps) + ious = overlap / union + if mode in ["iou", "iof"]: + return ious + enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0) + enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] + enclose_area = torch.max(enclose_area, eps) + gious = ious - (enclose_area - union) / enclose_area + return gious diff --git a/wilddet3d/ops/box3d.py b/wilddet3d/ops/box3d.py new file mode 100644 index 0000000000000000000000000000000000000000..d2067700d2b73e4d84717034d83285846d047e4d --- /dev/null +++ b/wilddet3d/ops/box3d.py @@ -0,0 +1,79 @@ +"""Box3D ops.""" + +from torch import Tensor +from vis4d_cuda_ops import iou_box3d + +from wilddet3d.ops.iou_box3d import check_coplanar, check_nonzero + + +def box3d_overlap( + boxes_dt: Tensor, + boxes_gt: Tensor, + eps_coplanar: float = 1e-3, + eps_nonzero: float = 1e-8, +) -> Tensor: + """ + Computes the intersection of 3D boxes_dt and boxes_gt. + + Inputs boxes_dt, boxes_gt are tensors of shape (B, 8, 3) + (where B doesn't have to be the same for boxes_dt and boxes_gt), + containing the 8 corners of the boxes, as follows: + + (4) +---------+. (5) + | ` . | ` . + | (0) +---+-----+ (1) + | | | | + (7) +-----+---+. (6)| + ` . | ` . | + (3) ` +---------+ (2) + + + NOTE: Throughout this implementation, we assume that boxes + are defined by their 8 corners exactly in the order specified in the + diagram above for the function to give correct results. In addition + the vertices on each plane must be coplanar. + As an alternative to the diagram, this is a unit bounding + box which has the correct vertex ordering: + + box_corner_vertices = [ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 1], + [1, 1, 1], + [0, 1, 1], + ] + + Args: + boxes_dt: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes + boxes_gt: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes + Returns: + iou: (N, M) tensor of the intersection over union which is + defined as: `iou = vol / (vol1 + vol2 - vol)` + """ + # Make sure predictions are coplanar and nonzero + invalid_coplanar = ~check_coplanar(boxes_dt, eps=eps_coplanar) + invalid_nonzero = ~check_nonzero(boxes_dt, eps=eps_nonzero) + + ious = iou_box3d(boxes_dt, boxes_gt)[1] + + # Offending boxes are set to zero IoU + if invalid_coplanar.any(): + ious[invalid_coplanar] = 0 + print( + "Warning: skipping {:d} non-coplanar boxes at eval.".format( + int(invalid_coplanar.float().sum()) + ) + ) + + if invalid_nonzero.any(): + ious[invalid_nonzero] = 0 + print( + "Warning: skipping {:d} zero volume boxes at eval.".format( + int(invalid_nonzero.float().sum()) + ) + ) + + return ious diff --git a/wilddet3d/ops/iou_box3d.py b/wilddet3d/ops/iou_box3d.py new file mode 100644 index 0000000000000000000000000000000000000000..95c3f745d697f716979112e30ff50d0af6aca624 --- /dev/null +++ b/wilddet3d/ops/iou_box3d.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.autograd import Function +from vis4d_cuda_ops import iou_box3d + +# -------------------------------------------------- # +# CONSTANTS # +# -------------------------------------------------- # +""" +_box_planes and _box_triangles define the 4- and 3-connectivity +of the 8 box corners. +_box_planes gives the quad faces of the 3D box +_box_triangles gives the triangle faces of the 3D box +""" +_box_planes = [ + [0, 1, 2, 3], + [3, 2, 6, 7], + [0, 1, 5, 4], + [0, 3, 7, 4], + [1, 2, 6, 5], + [4, 5, 6, 7], +] +_box_triangles = [ + [0, 1, 2], + [0, 3, 2], + [4, 5, 6], + [4, 6, 7], + [1, 5, 6], + [1, 6, 2], + [0, 4, 7], + [0, 7, 3], + [3, 2, 6], + [3, 6, 7], + [0, 1, 5], + [0, 4, 5], +] + + +def check_coplanar(boxes: Tensor, eps: float = 1e-4) -> torch.BoolTensor: + """ + Checks that plane vertices are coplanar. + Returns a bool tensor of size B, where True indicates a box is coplanar. + """ + faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device) + verts = boxes.index_select(index=faces.view(-1), dim=1) + B = boxes.shape[0] + P, V = faces.shape + # (B, P, 4, 3) -> (B, P, 3) + v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2) + + # Compute the normal + e0 = F.normalize(v1 - v0, dim=-1) + e1 = F.normalize(v2 - v0, dim=-1) + normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1) + + # Check the fourth vertex is also on the same plane + mat1 = (v3 - v0).view(B, 1, -1) # (B, 1, P*3) + mat2 = normal.view(B, -1, 1) # (B, P*3, 1) + + return (mat1.bmm(mat2).abs() < eps).view(B) + + +def check_nonzero(boxes: Tensor, eps: float = 1e-4) -> torch.BoolTensor: + """ + Checks that the sides of the box have a non zero area + """ + faces = torch.tensor( + _box_triangles, dtype=torch.int64, device=boxes.device + ) + verts = boxes.index_select(index=faces.view(-1), dim=1) + B = boxes.shape[0] + T, V = faces.shape + # (B, T, 3, 3) -> (B, T, 3) + v0, v1, v2 = verts.reshape(B, T, V, 3).unbind(2) + + normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # (B, T, 3) + face_areas = normals.norm(dim=-1) / 2 + + return (face_areas > eps).all(1).view(B) + + +class _box3d_overlap(Function): + """ + Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations. + Backward is not supported. + """ + + @staticmethod + def forward(ctx, boxes1, boxes2): + """ + Arguments defintions the same as in the box3d_overlap function + """ + vol, iou = iou_box3d(boxes1, boxes2) + return vol, iou + + @staticmethod + def backward(ctx, grad_vol, grad_iou): + raise ValueError("box3d_overlap backward is not supported") + + +def box3d_overlap( + boxes1: Tensor, boxes2: Tensor, eps: float = 1e-4 +) -> Tuple[Tensor, Tensor]: + """ + Computes the intersection of 3D boxes1 and boxes2. + + Inputs boxes1, boxes2 are tensors of shape (B, 8, 3) + (where B doesn't have to be the same for boxes1 and boxes2), + containing the 8 corners of the boxes, as follows: + + (4) +---------+. (5) + | ` . | ` . + | (0) +---+-----+ (1) + | | | | + (7) +-----+---+. (6)| + ` . | ` . | + (3) ` +---------+ (2) + + + NOTE: Throughout this implementation, we assume that boxes + are defined by their 8 corners exactly in the order specified in the + diagram above for the function to give correct results. In addition + the vertices on each plane must be coplanar. + As an alternative to the diagram, this is a unit bounding + box which has the correct vertex ordering: + + box_corner_vertices = [ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 1], + [1, 1, 1], + [0, 1, 1], + ] + + Args: + boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes + boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes + Returns: + vol: (N, M) tensor of the volume of the intersecting convex shapes + iou: (N, M) tensor of the intersection over union which is + defined as: `iou = vol / (vol1 + vol2 - vol)` + """ + if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]): + raise ValueError("Each box in the batch must be of shape (8, 3)") + + if not check_coplanar(boxes1, eps): + raise ValueError("boxes1 plane vertices are not coplanar") + + if not check_coplanar(boxes2, eps): + raise ValueError("boxes2 plane vertices are not coplanar") + + if not check_nonzero(boxes1, eps): + raise ValueError("boxes1 planes have zero areas") + + if not check_nonzero(boxes2, eps): + raise ValueError("boxes2 planes have zero areas") + + vol, iou = _box3d_overlap.apply(boxes1, boxes2) + + return vol, iou diff --git a/wilddet3d/ops/language/__init__.py b/wilddet3d/ops/language/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wilddet3d/ops/language/grounding.py b/wilddet3d/ops/language/grounding.py new file mode 100644 index 0000000000000000000000000000000000000000..d2afcc80796d496be319e700ec04a53d2fee407c --- /dev/null +++ b/wilddet3d/ops/language/grounding.py @@ -0,0 +1,206 @@ +"""Language grounding utilities.""" + +import re + +import nltk +import torch +from torch import Tensor +from transformers import BatchEncoding +from vis4d.common.logging import rank_zero_info, rank_zero_warn + + +def find_noun_phrases(caption: str) -> list: + """Find noun phrases in a caption using nltk. + Args: + caption (str): The caption to analyze. + + Returns: + list: List of noun phrases found in the caption. + + Examples: + >>> caption = 'There is two cat and a remote in the picture' + >>> find_noun_phrases(caption) # ['cat', 'a remote', 'the picture'] + """ + caption = caption.lower() + tokens = nltk.word_tokenize(caption) + pos_tags = nltk.pos_tag(tokens) + + grammar = "NP: {

?*+}" + cp = nltk.RegexpParser(grammar) + result = cp.parse(pos_tags) + + noun_phrases = [] + for subtree in result.subtrees(): + if subtree.label() == "NP": + noun_phrases.append(" ".join(t[0] for t in subtree.leaves())) + + return noun_phrases + + +def remove_punctuation(text: str) -> str: + """Remove punctuation from a text. + Args: + text (str): The input text. + + Returns: + str: The text with punctuation removed. + """ + punctuation = [ + "|", + ":", + ";", + "@", + "(", + ")", + "[", + "]", + "{", + "}", + "^", + "'", + '"', + "’", + "`", + "?", + "$", + "%", + "#", + "!", + "&", + "*", + "+", + ",", + ".", + ] + for p in punctuation: + text = text.replace(p, "") + return text.strip() + + +def run_ner(caption: str) -> tuple[list[list[int]], list[str]]: + """Run NER on a caption and return the tokens and noun phrases. + Args: + caption (str): The input caption. + + Returns: + Tuple[List, List]: A tuple containing the tokens and noun phrases. + - tokens_positive (List): A list of token positions. + - noun_phrases (List): A list of noun phrases. + """ + noun_phrases = find_noun_phrases(caption) + noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases] + noun_phrases = [phrase for phrase in noun_phrases if phrase != ""] + rank_zero_info("noun_phrases:", noun_phrases) + relevant_phrases = noun_phrases + labels = noun_phrases + + tokens_positive = [] + for entity, label in zip(relevant_phrases, labels): + try: + # search all occurrences and mark them as different entities + # TODO: Not Robust + for m in re.finditer(entity, caption.lower()): + tokens_positive.append([[m.start(), m.end()]]) + except Exception: + rank_zero_warn("noun entities:", noun_phrases) + rank_zero_warn("entity:", entity) + rank_zero_warn("caption:", caption.lower()) + return tokens_positive, noun_phrases + + +def create_positive_map( + tokenized: BatchEncoding, + tokens_positive: list[list[int]], + max_num_entities: int = 256, +) -> Tensor: + """construct a map such that positive_map[i,j] = True + if box i is associated to token j + + Args: + tokenized: The tokenized input. + tokens_positive (list): A list of token ranges + associated with positive boxes. + max_num_entities (int, optional): The maximum number of entities. + Defaults to 256. + + Returns: + torch.Tensor: The positive map. + + Raises: + Exception: If an error occurs during token-to-char mapping. + """ + positive_map = torch.zeros( + (len(tokens_positive), max_num_entities), dtype=torch.float + ) + + for j, tok_list in enumerate(tokens_positive): + for beg, end in tok_list: + try: + beg_pos = tokenized.char_to_token(beg) + end_pos = tokenized.char_to_token(end - 1) + except Exception as e: + print("beg:", beg, "end:", end) + print("token_positive:", tokens_positive) + raise e + if beg_pos is None: + try: + beg_pos = tokenized.char_to_token(beg + 1) + if beg_pos is None: + beg_pos = tokenized.char_to_token(beg + 2) + except Exception: + beg_pos = None + if end_pos is None: + try: + end_pos = tokenized.char_to_token(end - 2) + if end_pos is None: + end_pos = tokenized.char_to_token(end - 3) + except Exception: + end_pos = None + if beg_pos is None or end_pos is None: + continue + + assert beg_pos is not None and end_pos is not None + positive_map[j, beg_pos : end_pos + 1].fill_(1) + return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) + + +def create_positive_map_label_to_token( + positive_map: Tensor, plus: int = 0 +) -> dict: + """Create a dictionary mapping the label to the token. + Args: + positive_map (Tensor): The positive map tensor. + plus (int, optional): Value added to the label for indexing. + Defaults to 0. + + Returns: + dict: The dictionary mapping the label to the token. + """ + positive_map_label_to_token = {} + for i in range(len(positive_map)): + positive_map_label_to_token[i + plus] = torch.nonzero( + positive_map[i], as_tuple=True + )[0].tolist() + return positive_map_label_to_token + + +def clean_label_name(name: str) -> str: + """Clean label name.""" + name = re.sub(r"\(.*\)", "", name) + name = re.sub(r"_", " ", name) + name = re.sub(r" ", " ", name) + return name + + +def chunks(lst: list, n: int) -> list: + """Yield successive n-sized chunks from lst.""" + all_ = [] + for i in range(0, len(lst), n): + data_index = lst[i : i + n] + all_.append(data_index) + counter = 0 + for i in all_: + counter += len(i) + assert counter == len(lst) + + return all_ diff --git a/wilddet3d/ops/match_cost.py b/wilddet3d/ops/match_cost.py new file mode 100644 index 0000000000000000000000000000000000000000..d1cc1341913b130cff92343f85ce5728f44ccba1 --- /dev/null +++ b/wilddet3d/ops/match_cost.py @@ -0,0 +1,273 @@ +"""Matcher cost op.""" + +import torch +from torch import Tensor +from vis4d.op.box.box2d import bbox_iou + +from wilddet3d.ops.box2d import bbox_overlaps, bbox_xyxy_to_cxcywh + + +class MatchCost: + + def __init__(self, weight: float = 1.0) -> None: + """Create an instance of the class.""" + self.weight = weight + + +class ClassificationCost(MatchCost): + """ClsSoftmaxCost. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmdet.models.task_modules.assigners. + ... match_costs.match_cost import ClassificationCost + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight: float = 1.0) -> None: + """Create an instance of the class.""" + super().__init__(weight=weight) + + def __call__(self, cls_pred, gt_labels) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): ``scores`` inside is + predicted classification logits, of shape + (num_queries, num_class). + gt_instances (:obj:`InstanceData`): ``labels`` inside should have + shape (num_gt, ). + img_meta (Optional[dict]): _description_. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + pred_scores = cls_pred.softmax(-1) + cls_cost = -pred_scores[:, gt_labels] + + return cls_cost * self.weight + + +class BBoxL1Cost(MatchCost): + """BBoxL1Cost. + + Note: ``bboxes`` in ``InstanceData`` passed in is of format 'xyxy' + and its coordinates are unnormalized. + + Args: + box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN. + Defaults to 'xyxy'. + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmdet.models.task_modules.assigners. + ... match_costs.match_cost import BBoxL1Cost + >>> import torch + >>> self = BBoxL1Cost() + >>> bbox_pred = torch.rand(1, 4) + >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(bbox_pred, gt_bboxes, factor) + tensor([[1.6172, 1.6422]]) + """ + + def __init__(self, box_format: str = "xyxy", weight: float = 1.0) -> None: + """Create an instance of the class.""" + super().__init__(weight=weight) + assert box_format in ["xyxy", "xywh"] + self.box_format = box_format + + def __call__( + self, + pred_bboxes, + gt_bboxes, + img_h, + img_w, + ) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): ``bboxes`` inside is + predicted boxes with unnormalized coordinate + (x, y, x, y). + gt_instances (:obj:`InstanceData`): ``bboxes`` inside is gt + bboxes with unnormalized coordinate (x, y, x, y). + img_meta (Optional[dict]): Image information. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + # convert box format + if self.box_format == "xywh": + gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes) + pred_bboxes = bbox_xyxy_to_cxcywh(pred_bboxes) + + # normalized + factor = gt_bboxes.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze( + 0 + ) + gt_bboxes = gt_bboxes / factor + pred_bboxes = pred_bboxes / factor + + bbox_cost = torch.cdist(pred_bboxes, gt_bboxes, p=1) + + return bbox_cost * self.weight + + +class IoUCost(MatchCost): + """IoUCost. + + Note: ``bboxes`` in ``InstanceData`` passed in is of format 'xyxy' + and its coordinates are unnormalized. + + Args: + iou_mode (str): iou mode such as 'iou', 'giou'. Defaults to 'giou'. + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmdet.models.task_modules.assigners. + ... match_costs.match_cost import IoUCost + >>> import torch + >>> self = IoUCost() + >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]]) + >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) + >>> self(bboxes, gt_bboxes) + tensor([[-0.1250, 0.1667], + [ 0.1667, -0.5000]]) + """ + + def __init__(self, iou_mode: str = "giou", weight: float = 1.0): + super().__init__(weight=weight) + self.iou_mode = iou_mode + + def __call__( + self, + pred_bboxes, + gt_bboxes, + ): + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): ``bboxes`` inside is + predicted boxes with unnormalized coordinate + (x, y, x, y). + gt_instances (:obj:`InstanceData`): ``bboxes`` inside is gt + bboxes with unnormalized coordinate (x, y, x, y). + img_meta (Optional[dict]): Image information. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + # avoid fp16 overflow + if pred_bboxes.dtype == torch.float16: + fp16 = True + pred_bboxes = pred_bboxes.to(torch.float32) + else: + fp16 = False + + if self.iou_mode == "iou": + overlaps = bbox_iou(pred_bboxes, gt_bboxes) + else: + overlaps = bbox_overlaps( + pred_bboxes, gt_bboxes, mode=self.iou_mode + ) + + if fp16: + overlaps = overlaps.to(torch.float16) + + # The 1 is a constant that doesn't change the matching, so omitted. + iou_cost = -overlaps + return iou_cost * self.weight + + +class BinaryFocalLossCost(MatchCost): + """BinaryFocalLossCost. + + Args: + alpha (Union[float, int]): focal_loss alpha. Defaults to 0.25. + gamma (Union[float, int]): focal_loss gamma. Defaults to 2. + eps (float): Defaults to 1e-12. + binary_input (bool): Whether the input is binary. Currently, + binary_input = True is for masks input, binary_input = False + is for label input. Defaults to False. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__( + self, + alpha: float = 0.25, + gamma: float = 2.0, + eps: float = 1e-12, + binary_input: bool = False, + weight: float = 1.0, + ) -> None: + super().__init__(weight=weight) + self.alpha = alpha + self.gamma = gamma + self.eps = eps + self.binary_input = binary_input + + def _focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + (num_queries, num_class). + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.flatten(1) + gt_labels = gt_labels.flatten(1).float() + cls_pred = cls_pred.sigmoid() + neg_cost = ( + -(1 - cls_pred + self.eps).log() + * (1 - self.alpha) + * cls_pred.pow(self.gamma) + ) + pos_cost = ( + -(cls_pred + self.eps).log() + * self.alpha + * (1 - cls_pred).pow(self.gamma) + ) + + cls_cost = torch.einsum( + "nc,mc->nm", pos_cost, gt_labels + ) + torch.einsum("nc,mc->nm", neg_cost, (1 - gt_labels)) + return cls_cost * self.weight + + def __call__( + self, + cls_pred: Tensor, + text_token_mask: Tensor, + positive_map: Tensor, + ) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): Predicted instances which + must contain ``scores`` or ``masks``. + gt_instances (:obj:`InstanceData`): Ground truth which must contain + ``labels`` or ``mask``. + img_meta (Optional[dict]): Image information. Defaults to None. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + text_token_mask = torch.nonzero(text_token_mask).squeeze(-1) + + pred_scores = cls_pred[:, text_token_mask] + gt_labels = positive_map[:, text_token_mask] + + return self._focal_loss_cost(pred_scores, gt_labels) diff --git a/wilddet3d/ops/matchers/__init__.py b/wilddet3d/ops/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wilddet3d/ops/matchers/hungarian.py b/wilddet3d/ops/matchers/hungarian.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2fd9f1756bd6fe04db2a8ebf0a01d5e3f92442 --- /dev/null +++ b/wilddet3d/ops/matchers/hungarian.py @@ -0,0 +1,117 @@ +"""Box Hungarian Assigner.""" + +from __future__ import annotations + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from torch import Tensor +from vis4d.op.box.box2d import bbox_iou +from vis4d.op.box.matchers.base import MatchResult + + +class HungarianMatcher: + """Computes one-to-one matching between predictions and ground truth. + + This class computes an assignment between the targets and the predictions + based on the costs. The targets don't include the no_object, so generally + there are more predictions than targets. After the one-to-one matching, the + un-matched are treated as backgrounds. Thus each query prediction will be + assigned with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + """ + + def __call__( + self, + cost: Tensor, + boxes: Tensor, + targets: Tensor, + target_classes: Tensor, + ) -> MatchResult: + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + boxes (Tensor): Predicted boxes with normalized coordinates + (cx, cy, w, h), which are all in range [0, 1]. Shape + [num_query, 4]. + boxes_classes (Tensor): Predicted classification logits, shape + [num_query, num_class]. + targets (Tensor): Ground truth boxes with unnormalized + coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. + gt_labels (Tensor): Label of `targets`, shape (num_gt,). + img_meta (dict): Meta information for current image. + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + + gt_depth is a single channel map + depth_pred is per-label maps + + Returns: + MatchResult: Matching results. + """ + num_gts, num_bboxes = targets.size(0), boxes.size(0) + + match_iou = boxes.new_zeros((len(boxes),)) + + # 1. assign -1 by default + assigned_gt_inds = boxes.new_full((num_bboxes,), -1, dtype=torch.long) + assigned_labels = boxes.new_full((num_bboxes,), -1, dtype=torch.long) + + if num_gts == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return MatchResult(assigned_gt_inds, match_iou, assigned_labels) + + # 2. compute the weighted costs. + # NOTE: We dissentangle the cost computation and Hungarian matching + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + cost = np.nan_to_num(cost) + + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + + matched_row_inds = torch.from_numpy(matched_row_inds).to(boxes.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to(boxes.device) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = target_classes[matched_col_inds] + + pos_inds = ( + torch.nonzero(assigned_gt_inds > 0, as_tuple=False) + .squeeze(-1) + .unique() + ) + + _ious = bbox_iou(boxes[pos_inds], targets) + + for i, pid in enumerate(pos_inds): + matched_gt_idx = assigned_gt_inds[pid] - 1 + match_iou[pid] = _ious[i, matched_gt_idx] + + return MatchResult(assigned_gt_inds, match_iou, assigned_labels) diff --git a/wilddet3d/ops/mlp.py b/wilddet3d/ops/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..60ee8cea981c806c37f5343744a1a914703041d5 --- /dev/null +++ b/wilddet3d/ops/mlp.py @@ -0,0 +1,67 @@ +"""Multi-layer perceptron (MLP).""" + +import torch.nn.functional as F +from torch import Tensor, nn + + +class SwiGLU(nn.Module): + """SwiGLU activation function.""" + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x, gates = x.chunk(2, dim=-1) + return x * F.silu(gates) + + +class MLP(nn.Module): + """Multi-layer perceptron (MLP) module.""" + + def __init__( + self, + input_dim: int, + expansion: int = 4, + dropout: float = 0.0, + gated: bool = False, + output_dim: int | None = None, + ) -> None: + """Creates an instance of the class.""" + super().__init__() + if gated: + expansion = int(expansion * 2 / 3) + hidden_dim = int(input_dim * expansion) + output_dim = output_dim if output_dim is not None else input_dim + self.norm = nn.LayerNorm(input_dim) + self.proj1 = nn.Linear(input_dim, hidden_dim) + self.proj2 = nn.Linear(hidden_dim, output_dim) + self.act = nn.GELU() if not gated else SwiGLU() + self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """Forward pass.""" + x = self.norm(x) + x = self.proj1(x) + x = self.act(x) + x = self.proj2(x) + x = self.dropout(x) + return x + + def __call__(self, x: Tensor) -> Tensor: + """Type definition for call implementation.""" + return self._call_impl(x) + + +class SimpleMLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/wilddet3d/ops/nystrom.py b/wilddet3d/ops/nystrom.py new file mode 100644 index 0000000000000000000000000000000000000000..669ee5707bd24f1fdd1f926bb93ca7e02ccfdee3 --- /dev/null +++ b/wilddet3d/ops/nystrom.py @@ -0,0 +1,374 @@ +"""Nystrom Attention. + +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +""" + +import math +import warnings +from contextlib import nullcontext + +import torch +from torch import Tensor, nn + + +class AvgPool(nn.Module): + def __init__(self, n: int): + super().__init__() + self.n = n + + def forward(self, x: Tensor): + seq_len = x.shape[1] + head_dim = x.shape[2] + segments = seq_len // self.n + assert ( + segments > 0 + ), "num_landmarks should be smaller than the sequence length" + + if seq_len % self.n == 0: + return x.reshape( + -1, + self.n, + segments, + head_dim, + ).mean(dim=-2) + + n_round = self.n - seq_len % self.n + + x_avg_round = ( + x[:, : n_round * segments, :] + .reshape(-1, n_round, segments, head_dim) + .mean(dim=-2) + ) + x_avg_off = ( + x[:, n_round * segments :, :] + .reshape(-1, self.n - n_round, segments + 1, head_dim) + .mean(dim=-2) + ) + return torch.cat((x_avg_round, x_avg_off), dim=-2) + + +def bmm(a: Tensor, b: Tensor) -> Tensor: + return a @ b + + +def _apply_dropout(att, dropout): + if dropout is None: + return att + att = dropout(att) + return att + + +def _matmul_with_mask( + a: Tensor, + b: Tensor, + mask: Tensor | None = None, +) -> Tensor: + if mask is None: + return a @ b + + att = a @ b + if mask.dtype == torch.bool: + if mask.ndim == 2: + mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1) + att[~mask] = float("-inf") + else: + if ( + mask.ndim == 3 + and mask.shape[0] != att.shape[0] + and (att.shape[0] % mask.shape[0]) == 0 + ): + repeat_factor = att.shape[0] // mask.shape[0] + mask = mask.repeat([repeat_factor, 1, 1]) + warnings.warn( + "Mismatched batch dimensions for mask, repeating mask." + ) + att += mask + return att + + +def _softmax(a: Tensor) -> Tensor: + if a.is_sparse: + return torch.sparse.softmax(a, dim=a.ndim - 1) + return torch.softmax(a, dim=a.ndim - 1) + + +def scaled_query_key_softmax( + q: Tensor, + k: Tensor, + att_mask: Tensor | None = None, +) -> Tensor: + q = q / math.sqrt(k.size(-1)) + mask = att_mask + att = _matmul_with_mask(q, k.transpose(-2, -1), mask) + att = _softmax(att) + return att + + +def scaled_dot_product_attention( + q: Tensor, + k: Tensor, + v: Tensor, + att_mask: Tensor | None = None, + dropout: nn.Module | None = None, +) -> Tensor: + autocast_disabled = att_mask is not None and att_mask.is_sparse + + with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): + if autocast_disabled: + q, k, v = q.float(), k.float(), v.float() + + att = scaled_query_key_softmax(q, k, att_mask=att_mask) + att = _apply_dropout(att, dropout) + y = bmm(att, v) + return y + + +def bool_mask_to_additive( + mask: Tensor, dtype: torch.dtype | None = torch.float32 +) -> Tensor: + assert ( + mask.dtype == torch.bool + ), "This util is meant to convert in between bool masks and additive ones" + + mask_ = torch.zeros_like(mask, dtype=dtype) + mask_[~mask] = float("-inf") + return mask_ + + +def iterative_pinv( + softmax_mat: Tensor, n_iter=6, pinverse_original_init=False +): + """Computing the Moore-Penrose inverse via iterative method.""" + i = torch.eye( + softmax_mat.size(-1), + device=softmax_mat.device, + dtype=softmax_mat.dtype, + ) + k = softmax_mat + + if pinverse_original_init: + v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2) + else: + v = ( + 1 + / torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None] + * k.transpose(-1, -2) + ) + + for _ in range(n_iter): + kv = torch.matmul(k, v) + v = torch.matmul( + 0.25 * v, + 13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)), + ) + return v + + +def reshape_key_padding_mask( + key_padding_mask: Tensor, batched_dim: int +) -> Tensor: + assert key_padding_mask.ndim == 2 + batch_size, src_len = key_padding_mask.size() + num_heads = batched_dim // batch_size + return _reshape_key_padding_mask( + key_padding_mask, batch_size, src_len, num_heads + ) + + +def _reshape_key_padding_mask( + key_padding_mask: Tensor, + batch_size: int, + src_len: int, + num_heads: int, +) -> Tensor: + assert key_padding_mask.shape == (batch_size, src_len) + key_padding_mask = ( + key_padding_mask.view(batch_size, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(batch_size * num_heads, 1, src_len) + ) + return key_padding_mask + + +class NystromAttention(nn.Module): + """Nystrom attention mechanism.""" + + def __init__( + self, + dropout: float, + num_heads: int, + num_landmarks: int = 64, + landmark_pooling: nn.Module | None = None, + causal: bool = False, + use_razavi_pinverse: bool = True, + pinverse_original_init: bool = False, + inv_iterations: int = 6, + v_skip_connection: nn.Module | None = None, + conv_kernel_size: int | int = None, + ): + """Creates an instance of the class.""" + super().__init__() + self.requires_separate_masks = True + self.num_landmarks = num_landmarks + self.num_heads = num_heads + self.use_razavi_pinverse = use_razavi_pinverse + self.pinverse_original_init = pinverse_original_init + self.inv_iterations = inv_iterations + self.attn_drop = nn.Dropout(dropout) + self.skip_connection = v_skip_connection + self.causal = causal + + if self.skip_connection is None and conv_kernel_size is not None: + self.skip_connection = nn.Conv2d( + in_channels=self.num_heads, + out_channels=self.num_heads, + kernel_size=(conv_kernel_size, 1), + padding=(conv_kernel_size // 2, 0), + bias=False, + groups=self.num_heads, + ) + + if landmark_pooling is not None: + self.landmark_pooling = landmark_pooling + else: + self.landmark_pooling = AvgPool(n=self.num_landmarks) + + self.causal_mask_1: Tensor | None = None + self.causal_mask_2: Tensor | None = None + self.causal_mask_3: Tensor | None = None + + self.supports_attention_mask = False + self.supports_key_padding_mask = True + + def forward( + self, + q: Tensor, + k: Tensor, + v: Tensor, + key_padding_mask: Tensor | None = None, + *args, + **kwargs, + ): + batched_dim = k.size(0) + seq_len = k.size(-2) + tt = {"dtype": q.dtype, "device": q.device} + + if key_padding_mask is not None: + if key_padding_mask.dtype == torch.bool: + warnings.warn( + "Bool mask found, but an additive mask is expected. " + "Converting but this is slow" + ) + key_padding_mask = bool_mask_to_additive(key_padding_mask) + + if key_padding_mask.ndim == 2: + key_padding_mask = reshape_key_padding_mask( + key_padding_mask, batched_dim + ) + + zeros = torch.zeros_like(key_padding_mask) + ones = torch.ones_like(key_padding_mask) + is_masked = torch.isinf(-key_padding_mask) + + _mask = torch.where(is_masked, zeros, ones) + _mask = _mask.transpose(2, 1) + assert _mask.shape == (batched_dim, q.shape[1], 1) + + q = q * _mask + k = k * _mask + + assert key_padding_mask.size() == (batched_dim, 1, seq_len), ( + f"key_padding_mask has invalid dimensions {key_padding_mask.size()}." + f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})." + ) + + if self.num_landmarks >= seq_len: + mask: Tensor | None = None + + if self.causal: + mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt) + + if key_padding_mask is not None: + mask = ( + key_padding_mask + if mask is None + else mask + key_padding_mask + ) + + x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask) + + else: + q_landmarks = self.landmark_pooling(q) + k_landmarks = self.landmark_pooling(k) + + if self.causal and ( + self.causal_mask_1 is None + or (batched_dim, seq_len, self.num_landmarks) + != self.causal_mask_1.size() + ): + self.causal_mask_1 = self._triu_mask( + batched_dim, seq_len, self.num_landmarks, **tt + ) + self.causal_mask_2 = self._triu_mask( + batched_dim, self.num_landmarks, self.num_landmarks, **tt + ) + self.causal_mask_3 = self._triu_mask( + batched_dim, self.num_landmarks, seq_len, **tt + ) + + mask_3: Tensor | None = self.causal_mask_3 + if key_padding_mask is not None: + mask_3 = ( + key_padding_mask + if mask_3 is None + else mask_3 + key_padding_mask + ) + + kernel_1 = scaled_query_key_softmax( + q=q, k=k_landmarks, att_mask=None + ) + kernel_2 = scaled_query_key_softmax( + q=q_landmarks, k=k_landmarks, att_mask=None + ) + kernel_3 = scaled_dot_product_attention( + q=q_landmarks, k=k, v=v, att_mask=mask_3 + ) + + kernel_2_inv = ( + iterative_pinv( + kernel_2, self.inv_iterations, self.pinverse_original_init + ) + if self.use_razavi_pinverse + else torch.linalg.pinv(kernel_2) + ) + + x = torch.matmul( + torch.matmul( + kernel_1, + kernel_2_inv, + ), + kernel_3, + ) + + if self.skip_connection: + v_conv = self.skip_connection( + v.reshape(-1, self.num_heads, v.size(-2), v.size(-1)) + ) + x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1)) + x = self.attn_drop(x) + return x + + def _triu_mask( + self, dim_1: int, dim_2: int, dim_3: int, **kwargs + ) -> Tensor: + device = kwargs["device"] + dtype = kwargs["dtype"] + + return torch.triu( + torch.ones(dim_2, dim_3, dtype=dtype, device=device) + * float("-inf"), + diagonal=1, + ).expand(dim_1, -1, -1) diff --git a/wilddet3d/ops/profiler.py b/wilddet3d/ops/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..702b0f05f80677ba7fd998dc40c386091d686020 --- /dev/null +++ b/wilddet3d/ops/profiler.py @@ -0,0 +1,98 @@ +"""Training profiler for performance analysis. + +Usage: + Set environment variable PROFILE_WILDDET3D=1 to enable profiling. + Timing results are printed every N iterations. +""" + +import os +import time +from collections import defaultdict +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist + + +class TrainingProfiler: + """Profiler for measuring training component timings.""" + + _instance: Optional["TrainingProfiler"] = None + + def __init__(self, print_interval: int = 10, enabled: bool = True): + self.print_interval = print_interval + self.enabled = enabled + self.timings: Dict[str, List[float]] = defaultdict(list) + self.step_count = 0 + self.current_step_timings: Dict[str, float] = {} + self._start_times: Dict[str, float] = {} + + @classmethod + def get_instance(cls) -> "TrainingProfiler": + """Get singleton instance.""" + if cls._instance is None: + enabled = os.environ.get("PROFILE_WILDDET3D", "0") == "1" + print_interval = int(os.environ.get("PROFILE_INTERVAL", "10")) + cls._instance = cls(print_interval=print_interval, enabled=enabled) + if enabled: + print(f"[TrainingProfiler] Enabled, printing every {print_interval} steps") + return cls._instance + + def _is_main_process(self) -> bool: + import multiprocessing + current = multiprocessing.current_process() + return current.name == "MainProcess" + + def _safe_cuda_sync(self) -> None: + if self._is_main_process() and torch.cuda.is_available(): + torch.cuda.synchronize() + + def start(self, name: str) -> None: + if not self.enabled: + return + if not self._is_main_process(): + return + self._safe_cuda_sync() + self._start_times[name] = time.perf_counter() + + def stop(self, name: str) -> float: + if not self.enabled: + return 0.0 + if not self._is_main_process(): + return 0.0 + self._safe_cuda_sync() + elapsed = time.perf_counter() - self._start_times.get(name, time.perf_counter()) + self.current_step_timings[name] = elapsed + return elapsed + + def step(self) -> None: + if not self.enabled: + return + for name, elapsed in self.current_step_timings.items(): + self.timings[name].append(elapsed) + self.step_count += 1 + + def _is_rank_zero(self) -> bool: + if not dist.is_initialized(): + return True + return dist.get_rank() == 0 + + +def profiler() -> TrainingProfiler: + """Get the global profiler instance.""" + return TrainingProfiler.get_instance() + + +def profile_start(name: str) -> None: + """Start profiling a named section.""" + TrainingProfiler.get_instance().start(name) + + +def profile_stop(name: str) -> float: + """Stop profiling a named section and return elapsed time.""" + return TrainingProfiler.get_instance().stop(name) + + +def profile_step() -> None: + """Mark end of training step.""" + TrainingProfiler.get_instance().step() diff --git a/wilddet3d/ops/ray.py b/wilddet3d/ops/ray.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc7af6159d07a08e1ad770bd365bdc007a8c542 --- /dev/null +++ b/wilddet3d/ops/ray.py @@ -0,0 +1,771 @@ +"""Ray utilities for 3D reconstruction.""" + +import torch +from torch import Tensor +from torch.nn import functional as F + + +def generate_rays( + camera_intrinsics: Tensor, + image_shape: tuple[int, int], + noisy: bool = False, +) -> tuple[Tensor, Tensor]: + """Generates rays from camera intrinsics and image shape.""" + batch_size, device, dtype = ( + camera_intrinsics.shape[0], + camera_intrinsics.device, + camera_intrinsics.dtype, + ) + + height, width = image_shape + + # Generate grid of pixel coordinates + pixel_coords_x = torch.linspace( + 0, width - 1, width, device=device, dtype=dtype + ) + pixel_coords_y = torch.linspace( + 0, height - 1, height, device=device, dtype=dtype + ) + + if noisy: + pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 + pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 + + pixel_coords = torch.stack( + [ + pixel_coords_x.repeat(height, 1), + pixel_coords_y.repeat(width, 1).t(), + ], + dim=2, + ) # (H, W, 2) + pixel_coords = pixel_coords + 0.5 + + # Calculate ray directions + intrinsics_inv = ( + torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1) + ) + intrinsics_inv[:, 0, 0] = 1.0 / camera_intrinsics[:, 0, 0] + intrinsics_inv[:, 1, 1] = 1.0 / camera_intrinsics[:, 1, 1] + intrinsics_inv[:, 0, 2] = ( + -camera_intrinsics[:, 0, 2] / camera_intrinsics[:, 0, 0] + ) + intrinsics_inv[:, 1, 2] = ( + -camera_intrinsics[:, 1, 2] / camera_intrinsics[:, 1, 1] + ) + homogeneous_coords = torch.cat( + [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2 + ) # (H, W, 3) + ray_directions = torch.matmul( + intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1) + ) # (3, H*W) + ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W) + ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3) + + theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1]) + phi = torch.acos(ray_directions[..., 1].clamp(-1.0, 1.0)) + # pitch = torch.asin(ray_directions[..., 1]) + # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1]) + angles = torch.stack([theta, phi], dim=-1) + return ray_directions, angles + + +def spherical_zbuffer_to_euclidean( + spherical_tensor: Tensor, +) -> Tensor: + """Converts a spherical zbuffer tensor to euclidean coordinates.""" + theta = spherical_tensor[..., 0] # Extract polar angle + phi = spherical_tensor[..., 1] # Extract azimuthal angle + z = spherical_tensor[..., 2] # Extract zbuffer depth + + x = z * torch.tan(theta) + y = z / torch.tan(phi) / torch.cos(theta) + + euclidean_tensor = torch.stack((x, y, z), dim=-1) + return euclidean_tensor + + +def rsh_cart_3(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 3. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,16) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + ], + -1, + ) + + +def rsh_cart_8(xyz: Tensor): + """Computes all real spherical harmonics up to degree 8. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,81) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + # z4 = z2**2 + return torch.stack( + [ + 0.282094791773878 + * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 + * y + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 + * x + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 + * (x2 - y2) + * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 + * (5197.5 * z2 - 472.5) + * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + -0.707162732524596 + * y + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 2.6459606618019 + * z + * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 9.98394571852353e-5 + * y + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00239614697244565 + * xy + * (x2 - y2) + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), + 0.00397356022507413 + * y + * (3.0 * x2 - y2) + * ( + 3.25 + * z + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.0561946276120613 + * xy + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.206472245902897 + * y + * ( + -2.625 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 1.24862677781952 * z * (1.5 * z2 - 0.5) + - 1.68564615005635 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 2.02901851395672 + * z + * ( + -1.45833333333333 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.499450711127808 * z, + 0.206472245902897 + * x + * ( + -2.625 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 0.0280973138060306 + * (x2 - y2) + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.00397356022507413 + * x + * (x2 - 3.0 * y2) + * ( + 3.25 + * z + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.000599036743111412 + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + * (-6.0 * x2 * y2 + x4 + y4), + 9.98394571852353e-5 + * x + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 2.6459606618019 + * z + * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -0.707162732524596 + * x + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + 5.83141328139864 + * xy + * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3), + -2.91570664069932 + * yz + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 7.87853281621404e-6 + * (1013512.5 * z2 - 67567.5) + * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 5.10587282657803e-5 + * y + * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00147275890257803 + * xy + * (x2 - y2) + * ( + 3.75 + * z + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + - 14293.125 * z2 + + 1299.375 + ), + 0.0028519853513317 + * y + * (3.0 * x2 - y2) + * ( + -7.33333333333333 * z * (52.5 - 472.5 * z2) + + 3.0 + * z + * ( + 3.25 + * z + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ) + - 560.0 * z + ), + 0.0463392770473559 + * xy + * ( + -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + + 2.5 + * z + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ) + + 137.8125 * z2 + - 19.6875 + ), + 0.193851103820053 + * y + * ( + 3.2 * z * (1.5 - 7.5 * z2) + - 2.51428571428571 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + + 2.14285714285714 + * z + * ( + -2.625 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * ( + 2.33333333333333 * z * (1.5 - 7.5 * z2) + + 4.0 * z + ) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ) + + 5.48571428571429 * z + ), + 1.48417251362228 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.86581687426801 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 2.1808249179756 + * z + * ( + 1.14285714285714 * z * (1.5 * z2 - 0.5) + - 1.54285714285714 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 1.85714285714286 + * z + * ( + -1.45833333333333 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.457142857142857 * z + ) + - 0.954110901614325 * z2 + + 0.318036967204775, + 0.193851103820053 + * x + * ( + 3.2 * z * (1.5 - 7.5 * z2) + - 2.51428571428571 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + + 2.14285714285714 + * z + * ( + -2.625 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * ( + 2.33333333333333 * z * (1.5 - 7.5 * z2) + + 4.0 * z + ) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ) + + 5.48571428571429 * z + ), + 0.0231696385236779 + * (x2 - y2) + * ( + -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + + 2.5 + * z + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ) + + 137.8125 * z2 + - 19.6875 + ), + 0.0028519853513317 + * x + * (x2 - 3.0 * y2) + * ( + -7.33333333333333 * z * (52.5 - 472.5 * z2) + + 3.0 + * z + * ( + 3.25 + * z + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ) + - 560.0 * z + ), + 0.000368189725644507 + * (-6.0 * x2 * y2 + x4 + y4) + * ( + 3.75 + * z + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + - 14293.125 * z2 + + 1299.375 + ), + 5.10587282657803e-5 + * x + * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 7.87853281621404e-6 + * (1013512.5 * z2 - 67567.5) + * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -2.91570664069932 + * xz + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + -20.4099464848952 * x2**3 * y2 + - 20.4099464848952 * x2 * y2**3 + + 0.72892666017483 * x4**2 + + 51.0248662122381 * x4 * y4 + + 0.72892666017483 * y4**2, + ], + -1, + ) diff --git a/wilddet3d/ops/rotation.py b/wilddet3d/ops/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..96ec75eb0e6f0887f33de3fd9cf077a9948cfc7c --- /dev/null +++ b/wilddet3d/ops/rotation.py @@ -0,0 +1,198 @@ +"""Rotation ops.""" + +from __future__ import annotations + +import math + +import torch +from torch import Tensor +from torch.nn import functional as F +from vis4d.op.geometry.rotation import quaternion_to_matrix + +DEFAULT_ACOS_BOUND: float = 1.0 - 1e-4 + + +def _acos_linear_approximation(x: Tensor, x0: float) -> Tensor: + return (x - x0) * _dacos_dx(x0) + math.acos(x0) + + +def _dacos_dx(x: float) -> float: + return (-1.0) / math.sqrt(1.0 - x * x) + + +def acos_linear_extrapolation( + x: Tensor, + bounds: tuple[float, float] = (-DEFAULT_ACOS_BOUND, DEFAULT_ACOS_BOUND), +) -> Tensor: + """Implements arccos(x) with linear extrapolation outside (-1, 1).""" + lower_bound, upper_bound = bounds + + if lower_bound > upper_bound: + raise ValueError( + "lower bound has to be smaller or equal to upper bound." + ) + + if lower_bound <= -1.0 or upper_bound >= 1.0: + raise ValueError( + "Both lower bound and upper bound have to be within (-1, 1)." + ) + + acos_extrap = torch.empty_like(x) + x_upper = x >= upper_bound + x_lower = x <= lower_bound + x_mid = (~x_upper) & (~x_lower) + + acos_extrap[x_mid] = torch.acos(x[x_mid]) + acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound) + acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound) + + return acos_extrap + + +def so3_rotation_angle( + R: Tensor, + eps: float = 1e-4, + cos_angle: bool = False, + cos_bound: float = 1e-4, +) -> Tensor: + """Calculates angles (in radians) of a batch of rotation matrices.""" + _, dim1, dim2 = R.shape + if dim1 != 3 or dim2 != 3: + raise ValueError("Input has to be a batch of 3x3 Tensors.") + + rot_trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] + + if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any(): + raise ValueError( + "A matrix has trace outside valid range [-1-eps,3+eps]." + ) + + phi_cos = (rot_trace - 1.0) * 0.5 + + if cos_angle: + return phi_cos + else: + if cos_bound > 0.0: + bound = 1.0 - cos_bound + return acos_linear_extrapolation(phi_cos, (-bound, bound)) + else: + return torch.acos(phi_cos) + + +def so3_relative_angle( + R1: Tensor, + R2: Tensor, + cos_angle: bool = False, + cos_bound: float = 1e-4, + eps: float = 1e-4, +) -> Tensor: + """Calculates the relative angle between pairs of rotation matrices.""" + R12 = torch.bmm(R1, R2.permute(0, 2, 1)) + return so3_rotation_angle( + R12, cos_angle=cos_angle, cos_bound=cos_bound, eps=eps + ) + + +def axis_angle_to_quaternion(axis_angle: Tensor) -> Tensor: + """Convert rotations given as axis/angle to quaternions.""" + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], + dim=-1, + ) + return quaternions + + +def axis_angle_to_matrix(axis_angle: Tensor) -> Tensor: + """Convert rotations given as axis/angle to rotation matrices.""" + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def rotation_6d_to_matrix(d6: Tensor) -> Tensor: + """Converts 6D rotation representation to rotation matrix.""" + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: Tensor) -> Tensor: + """Converts rotation matrices to 6D rotation representation.""" + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) + + +def R_from_allocentric(K: Tensor, R_view, u=None, v=None): + """Convert rotation matrix to egocentric representation.""" + fx = K[:, 0, 0] + fy = K[:, 1, 1] + sx = K[:, 0, 2] + sy = K[:, 1, 2] + + if u is None: + u = sx + if v is None: + v = sy + + oray = torch.stack(((u - sx) / fx, (v - sy) / fy, torch.ones_like(u))).T + oray = oray / torch.linalg.norm(oray, dim=1).unsqueeze(1) + angle = torch.acos(oray[:, -1]) + + axis = torch.zeros_like(oray) + axis[:, 0] = axis[:, 0] - oray[:, 1] + axis[:, 1] = axis[:, 1] + oray[:, 0] + norms = torch.linalg.norm(axis, dim=1) + + valid_angle = angle > 0 + + M = axis_angle_to_matrix(angle.unsqueeze(1) * axis / norms.unsqueeze(1)) + + R = R_view.clone() + R[valid_angle] = torch.bmm(M[valid_angle], R_view[valid_angle]) + + return R + + +def R_to_allocentric(K: Tensor, R, u=None, v=None): + """Convert rotation matrix to allocentric representation.""" + fx = K[:, 0, 0] + fy = K[:, 1, 1] + sx = K[:, 0, 2] + sy = K[:, 1, 2] + + if u is None: + u = sx + if v is None: + v = sy + + oray = torch.stack(((u - sx) / fx, (v - sy) / fy, torch.ones_like(u))).T + oray = oray / torch.linalg.norm(oray, dim=1).unsqueeze(1) + angle = torch.acos(oray[:, -1]) + + axis = torch.zeros_like(oray) + axis[:, 0] = axis[:, 0] - oray[:, 1] + axis[:, 1] = axis[:, 1] + oray[:, 0] + norms = torch.linalg.norm(axis, dim=1) + + valid_angle = angle > 0 + + M = axis_angle_to_matrix(angle.unsqueeze(1) * axis / norms.unsqueeze(1)) + + R_view = R.clone() + R_view[valid_angle] = torch.bmm( + M[valid_angle].transpose(2, 1), R[valid_angle] + ) + + return R_view diff --git a/wilddet3d/ops/upsample.py b/wilddet3d/ops/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..e74da9b8ac2d4bcb0bac07a2aff0dce4f8b0dc66 --- /dev/null +++ b/wilddet3d/ops/upsample.py @@ -0,0 +1,127 @@ +"""Upsampling layers.""" + +import torch +from einops import rearrange +from torch import Tensor, nn + + +class CvnxtBlock(nn.Module): + def __init__( + self, + dim, + kernel_size=7, + layer_scale=1.0, + expansion=4, + dilation=1, + padding_mode: str = "zeros", + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=dilation * (kernel_size - 1) // 2, + groups=dim, + dilation=dilation, + padding_mode=padding_mode, + ) + self.norm = nn.LayerNorm(dim) + self.pwconv1 = nn.Linear(dim, expansion * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(expansion * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale * torch.ones((dim))) + if layer_scale > 0.0 + else 1.0 + ) + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + x = self.gamma * x + x = input + x.permute(0, 3, 1, 2) + return x + + +class ConvUpsample(nn.Module): + """Convolutional upsampling layer.""" + + def __init__( + self, + hidden_dim: int, + output_dim: int | None = None, + num_layers: int = 2, + expansion: int = 4, + layer_scale: float = 1.0, + kernel_size: int = 7, + ) -> None: + """Init.""" + super().__init__() + + if output_dim is None: + output_dim = hidden_dim // 2 + + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + CvnxtBlock( + hidden_dim, + kernel_size=kernel_size, + expansion=expansion, + layer_scale=layer_scale, + ) + ) + + self.up = nn.Sequential( + nn.Conv2d(hidden_dim, output_dim, kernel_size=1, padding=0), + nn.UpsamplingBilinear2d(scale_factor=2), + nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), + ) + + def forward(self, x: Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + x = rearrange(x, "b c h w -> b (h w) c") + return x + + +class ConvUpsampleShuffle(nn.Module): + def __init__( + self, + hidden_dim, + num_layers: int = 2, + expansion: int = 4, + layer_scale: float = 1.0, + kernel_size: int = 7, + ): + super().__init__() + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + CvnxtBlock( + hidden_dim, + kernel_size=kernel_size, + expansion=expansion, + layer_scale=layer_scale, + ) + ) + self.up = nn.Sequential( + nn.PixelShuffle(2), + nn.Conv2d( + hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1 + ), + ) + + def forward(self, x: Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + x = rearrange(x, "b c h w -> b (h w) c") + return x diff --git a/wilddet3d/ops/util.py b/wilddet3d/ops/util.py new file mode 100644 index 0000000000000000000000000000000000000000..621c9834abc94c41726be625ba44146d016c8522 --- /dev/null +++ b/wilddet3d/ops/util.py @@ -0,0 +1,44 @@ +"""Op utility functions.""" + +from __future__ import annotations + +from functools import partial + +import torch.nn.functional as F +from torch import Tensor + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments.""" + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def flat_interpolate( + flat_tensor: Tensor, + old: tuple[int, int], + new: tuple[int, int], + antialias: bool = True, + mode: str = "bilinear", +) -> Tensor: + if old[0] == new[0] and old[1] == new[1]: + return flat_tensor + tensor = flat_tensor.view( + flat_tensor.shape[0], old[0], old[1], -1 + ).permute( + 0, 3, 1, 2 + ) + tensor_interp = F.interpolate( + tensor, + size=(new[0], new[1]), + mode=mode, + align_corners=False, + antialias=antialias, + ) + flat_tensor_interp = tensor_interp.view( + flat_tensor.shape[0], -1, new[0] * new[1] + ).permute( + 0, 2, 1 + ) + return flat_tensor_interp.contiguous() diff --git a/wilddet3d/preprocessing.py b/wilddet3d/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..880a8aa5b96b80330bf39f94cb5c875d3dd72992 --- /dev/null +++ b/wilddet3d/preprocessing.py @@ -0,0 +1,83 @@ +"""Preprocessing utilities for WildDet3D inference. + +Handles image resizing, normalization, center padding, and intrinsics +adjustment to prepare raw inputs for the WildDet3D model. +""" + +from typing import Optional + +import numpy as np + +from vis4d.data.transforms.base import compose +from vis4d.data.transforms.normalize import NormalizeImages +from vis4d.data.transforms.resize import ResizeImages, ResizeIntrinsics +from vis4d.data.transforms.to_tensor import ToTensor + +from wilddet3d.data.transforms.pad import ( + CenterPadImages, + CenterPadIntrinsics, +) +from wilddet3d.data.transforms.resize import GenResizeParameters + +# WildDet3D expects 1008x1008 images +IMAGE_SIZE = (1008, 1008) + + +def preprocess( + image: np.ndarray, + intrinsics: Optional[np.ndarray] = None, +) -> dict: + """Preprocess image for WildDet3D. + + Args: + image: RGB image as numpy array (H, W, 3) + intrinsics: Camera intrinsics (3, 3), or None to use default/predicted + + Returns: + Dict with preprocessed tensors and metadata + """ + images = image.astype(np.float32)[None, ...] + H, W = images.shape[1], images.shape[2] + + # If no intrinsics provided, create a placeholder. + # When use_predicted_intrinsics=True in the model, the geometry backend's + # K_pred will be used for 3D box decoding instead of this placeholder. + # The placeholder is still needed so the data pipeline doesn't crash. + if intrinsics is None: + focal = max(H, W) + intrinsics = np.array( + [ + [focal, 0, W / 2], + [0, focal, H / 2], + [0, 0, 1], + ], + dtype=np.float32, + ) + + data_dict = { + "images": images, + "original_images": images.copy(), + "input_hw": (H, W), + "original_hw": (H, W), + "intrinsics": intrinsics.astype(np.float32), + "original_intrinsics": intrinsics.astype(np.float32).copy(), + } + + preprocess_transforms = compose( + transforms=[ + GenResizeParameters(shape=IMAGE_SIZE), + ResizeImages(), + ResizeIntrinsics(), + NormalizeImages(), + CenterPadImages( + stride=1, shape=IMAGE_SIZE, update_input_hw=True + ), + CenterPadIntrinsics(), + ] + ) + + data = preprocess_transforms([data_dict])[0] + to_tensor = ToTensor() + data = to_tensor([data])[0] + + return data diff --git a/wilddet3d/vis/__init__.py b/wilddet3d/vis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wilddet3d/vis/__pycache__/__init__.cpython-311.pyc b/wilddet3d/vis/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df8bab22fe019b68da5b00cab0d9dc421ab2ae78 Binary files /dev/null and b/wilddet3d/vis/__pycache__/__init__.cpython-311.pyc differ diff --git a/wilddet3d/vis/__pycache__/visualize.cpython-311.pyc b/wilddet3d/vis/__pycache__/visualize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de0dbe8a4aa3349fc38cd8e3812cc208b191cc4c Binary files /dev/null and b/wilddet3d/vis/__pycache__/visualize.cpython-311.pyc differ diff --git a/wilddet3d/vis/fonts/Manrope-Bold.ttf b/wilddet3d/vis/fonts/Manrope-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..52d93a3fbf43352085a4089047aeb600929fd583 Binary files /dev/null and b/wilddet3d/vis/fonts/Manrope-Bold.ttf differ diff --git a/wilddet3d/vis/fonts/Manrope-SemiBold.ttf b/wilddet3d/vis/fonts/Manrope-SemiBold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..85e036efd34d272b11398d7320f5505611913003 Binary files /dev/null and b/wilddet3d/vis/fonts/Manrope-SemiBold.ttf differ diff --git a/wilddet3d/vis/image/__init__.py b/wilddet3d/vis/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wilddet3d/vis/image/depth_visualizer.py b/wilddet3d/vis/image/depth_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b911c4caad26a9250d7f17375b420c54c5822c11 --- /dev/null +++ b/wilddet3d/vis/image/depth_visualizer.py @@ -0,0 +1,200 @@ +"""Depth visualizer.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + +import numpy as np +from PIL import Image +from vis4d.common.array import array_to_numpy +from vis4d.common.typing import ( + ArgsType, + ArrayLikeFloat, + NDArrayF32, + NDArrayUI8, +) +from vis4d.vis.base import Visualizer +from vis4d.vis.image.util import preprocess_image +from vis4d.vis.util import generate_color_map + +from .util import ( + colorize, + get_pointcloud_from_rgbd, + save_depth_map, + save_file_ply, +) + + +@dataclass +class DataSample: + """Dataclass storing a data sample that can be visualized.""" + + image: NDArrayUI8 + image_name: str + depth: NDArrayF32 + depth_gt: NDArrayF32 | None = None + depth_error: NDArrayF32 | None = None + points_rgb: NDArrayF32 | None = None + + +class DepthVisualizer(Visualizer): + """Depth visualizer class.""" + + def __init__( + self, + *args: ArgsType, + max_depth: None | float = None, + plot_error: bool = False, + lift: bool = False, + color_palette: list[tuple[int, int, int]] | None = None, + **kwargs: ArgsType, + ) -> None: + """Creates a new Visualizer for Depth. + + Args: + max_depth (None | float): Maximum depth to visualize. + """ + super().__init__(*args, **kwargs) + self.max_depth = max_depth + self._samples: list[DataSample] = [] + self._gt_samples = [] + self.plot_error = plot_error + self.lift = lift + self.color_palette = ( + generate_color_map(50) if color_palette is None else color_palette + ) + + def reset(self) -> None: + """Reset the visualizer.""" + self._samples.clear() + self._gt_samples.clear() + + def process( + self, + cur_iter: int, + images: list[ArrayLikeFloat], + image_names: list[str], + depths: ArrayLikeFloat, + depth_gts: ArrayLikeFloat | None = None, + intrinsics: ArrayLikeFloat | None = None, + ) -> None: + """Process data of a batch of data.""" + if self._run_on_batch(cur_iter): + for i, image in enumerate(images): + image = preprocess_image(image) + self._samples.append( + self.process_single_image( + image, + image_names[i], + array_to_numpy(depths[i]), + ( + array_to_numpy(depth_gts[i]) + if depth_gts is not None + else None + ), + ( + array_to_numpy(intrinsics[i]) + if intrinsics is not None + else None + ), + ) + ) + + def process_single_image( + self, + image: NDArrayUI8, + image_name: str, + depth: NDArrayF32, + depth_gt: NDArrayF32 | None = None, + intrinsic: NDArrayF32 | None = None, + ) -> DataSample: + """Process data of a batch of data.""" + if self.max_depth is not None: + mask = depth <= self.max_depth + else: + mask = np.full(depth.shape, True) + + if self.plot_error: + assert ( + depth_gt is not None + ), "Ground truth depth is required for plotting error." + error = np.zeros_like(depth_gt) + error[depth_gt > 0] = ( + np.abs(depth_gt - depth)[depth_gt > 0] / depth_gt[depth_gt > 0] + ) + else: + error = None + + if self.lift: + assert ( + intrinsic is not None + ), "Intrinsic matrix is required for lifting." + points_rgb = get_pointcloud_from_rgbd( + image, depth, intrinsic, mask + ) + else: + points_rgb = None + + return DataSample( + image=image, + image_name=image_name, + depth=depth, + depth_gt=depth_gt, + depth_error=error, + points_rgb=points_rgb, + ) + + def save_to_disk(self, cur_iter: int, output_folder: str) -> None: + """Saves the visualization to disk. + + Args: + cur_iter (int): Current iteration. + output_folder (str): Folder where the output should be written. + """ + if self._run_on_batch(cur_iter): + for sample in self._samples: + save_dir = os.path.join(output_folder, "depth") + os.makedirs(save_dir, exist_ok=True) + + Image.fromarray(sample.image).save( + f"{save_dir}/{sample.image_name}.png", + ) + + if self.plot_error: + error = sample.depth_error + + error_image = Image.fromarray( + colorize( + error.clip(0.0, 0.3), + vmin=0.001, + vmax=0.3, + cmap="coolwarm", + ) + ) + + error_image.save( + f"{save_dir}/{sample.image_name}_error.png" + ) + + save_depth_map( + sample.depth, + f"{save_dir}/{sample.image_name}_pred.png", + ) + + if sample.depth_gt is not None: + save_depth_map( + sample.depth_gt, + f"{save_dir}/{sample.image_name}_gt.png", + ) + + if self.lift: + save_dir = os.path.join(output_folder, "points") + os.makedirs(save_dir, exist_ok=True) + + if sample.points_rgb is not None: + save_file_ply( + sample.points_rgb[:, :3], + sample.points_rgb[:, 3:], + os.path.join(save_dir, f"{sample.image_name}.ply"), + ) diff --git a/wilddet3d/vis/image/util.py b/wilddet3d/vis/image/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e7adffc042a79ed0918ab0c4c616b7ade6a9a487 --- /dev/null +++ b/wilddet3d/vis/image/util.py @@ -0,0 +1,151 @@ +"""Utility functions for image processing operations.""" + +from __future__ import annotations + +import numpy as np +from matplotlib.pyplot import get_cmap +from PIL import Image +from vis4d.common.typing import ( + NDArrayBool, + NDArrayF32, + NDArrayUI8, + NDArrayUI16, +) + + +def save_depth_map( + depth_map: NDArrayF32, filename: str, depth_scale: float = 256.0 +) -> None: + """Dump depth map. + + Args: + depth_map (NDArrayF32): Depth map to dump. + filename (str): Path to dump depth map. + depth_scale (float): Depth scale. + """ + numpy_image = (depth_map * depth_scale).astype(np.uint16) + numpy_image = colorize(numpy_image) + Image.fromarray(numpy_image).save(filename) + + +def colorize( + value: NDArrayUI16, + vmin: float | None = None, + vmax: float | None = None, + cmap: str = "magma_r", +) -> Image.Image: + if value.ndim > 2: + return value + invalid_mask = value < 1e-3 + # normalize + vmin = value.min() if vmin is None else vmin + vmax = value.max() if vmax is None else vmax + value = (value - vmin) / (vmax - vmin) # vmin..vmax + + # set color + cmapper = get_cmap(cmap) + value = cmapper(value, bytes=True) # (nxmx4) + value[invalid_mask] = 0 + img = value[..., :3] + return img + + +def get_pointcloud_from_rgbd( + image: NDArrayUI8, + depth: NDArrayF32, + intrinsic_matrix: NDArrayF32, + mask: NDArrayBool, + remove_height: float | None = None, +) -> NDArrayF32: + """Get pointcloud from RGBD image. + + Args: + image (np.array): RGB image. Shape: (H, W, 3) + depth (np.array): Depth image. Shape: (H, W) + mask (np.ndarray): Mask of valid depth values. Shape: (H, W) + intrinsic_matrix (np.array): Intrinsic matrix of camera. Shape: (3, 3) + extrinsic_matrix (np.array, optional): Extrinsic matrix of camera. + Shape: (4, 4). Defaults to None. + voxelize (bool, optional): Whether to voxelize the pointcloud. + + Returns: + NDArrayF32: Pointcloud. Shape: (N, 6) + """ + # Mask the depth array + masked_depth = np.ma.masked_where(mask == False, depth) + + # Create idx array + idxs = np.indices(masked_depth.shape) + u_idxs = idxs[1] + v_idxs = idxs[0] + + # Get only non-masked depth and idxs + z = masked_depth[~masked_depth.mask] + compressed_u_idxs = u_idxs[~masked_depth.mask] + compressed_v_idxs = v_idxs[~masked_depth.mask] + image = np.stack( + [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], + axis=-1, + ) + + # Calculate local position of each point + # Apply vectorized math to depth using compressed arrays + cx = intrinsic_matrix[0, 2] + fx = intrinsic_matrix[0, 0] + x = (compressed_u_idxs - cx) * z / fx + cy = intrinsic_matrix[1, 2] + fy = intrinsic_matrix[1, 1] + + # Flip y as we want +y pointing up not down + y = (compressed_v_idxs - cy) * z / fy + + # Remove height + if remove_height is not None: + mask = y >= remove_height + x = x[mask] + y = y[mask] + z = z[mask] + image = image[mask] + else: + x = x.reshape(-1) + y = y.reshape(-1) + z = z.reshape(-1) + image = image.reshape(-1, 3) + + x_y_z_local = np.stack((x, y, z), axis=-1) + + return np.concatenate([x_y_z_local, image], axis=-1) + + +def save_file_ply(xyz: NDArrayF32, rgb: NDArrayF32, pc_file: str) -> None: + """Save point cloud to ply file.""" + if rgb.max() < 1.001: + rgb = rgb * 255.0 + rgb = rgb.astype(np.uint8) + + with open(pc_file, "w") as f: + # headers + f.writelines( + [ + "ply\n" "format ascii 1.0\n", + "element vertex {}\n".format(xyz.shape[0]), + "property float x\n", + "property float y\n", + "property float z\n", + "property uchar red\n", + "property uchar green\n", + "property uchar blue\n", + "end_header\n", + ] + ) + + for i in range(xyz.shape[0]): + str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format( + xyz[i][0], + xyz[i, 1], + xyz[i, 2], + rgb[i, 0], + rgb[i, 1], + rgb[i, 2], + ) + f.write(str_v) diff --git a/wilddet3d/vis/visualize.py b/wilddet3d/vis/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..3768eb58ac5434250669758c479b2c76f74eb2f0 --- /dev/null +++ b/wilddet3d/vis/visualize.py @@ -0,0 +1,261 @@ +"""WildDet3D visualization utilities. + +Anti-aliased 3D bounding boxes with Manrope font score labels. +Uses vis4d's preprocess_boxes3d for correct 3D corner projection, +cv2 LINE_AA for smooth lines, PIL + Manrope for text rendering. +""" + +from __future__ import annotations + +from pathlib import Path + +import cv2 +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from torch import Tensor + +from vis4d.common.array import array_to_numpy +from vis4d.data.const import AxisMode +from vis4d.op.box.box3d import boxes3d_to_corners +from vis4d.vis.util import generate_color_map + +_FONT_DIR = Path(__file__).parent / "fonts" + +# vis4d edge order (from PillowCanvasBackend.draw_box_3d) +# Front face: 0-1-5-4, Back face: 2-3-7-6, Sides: 0-2, 1-3, 4-6, 5-7 +_EDGES = [ + # Front + (0, 1), (1, 5), (5, 4), (4, 0), + # Sides + (0, 2), (1, 3), (4, 6), (5, 7), + # Back + (2, 3), (3, 7), (7, 6), (6, 2), +] + + +def _get_font(size: int = 14) -> ImageFont.FreeTypeFont: + """Get Manrope Bold font with fallbacks.""" + for path in [ + _FONT_DIR / "Manrope-Bold.ttf", + _FONT_DIR / "Manrope-SemiBold.ttf", + Path("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"), + ]: + if path.exists(): + return ImageFont.truetype(str(path), size) + return ImageFont.load_default() + + +def _project_pt_simple(pt_3d, K_np): + """Project single 3D point to 2D using intrinsics (no torch overhead).""" + x, y, z = pt_3d + fx, fy = K_np[0, 0], K_np[1, 1] + cx, cy = K_np[0, 2], K_np[1, 2] + u = fx * x / z + cx + v = fy * y / z + cy + return float(u), float(v) + + +def _clip_to_near(p1, p2, near=0.15): + """Clip line to near plane, return clipped point.""" + x1, y1, z1 = p1 + x2, y2, z2 = p2 + k_up = abs(z1 - near) + k_down = abs(z1 - z2) + k = min(k_up / k_down, 1.0) if k_down > 0 else 1.0 + return ((1 - k) * x1 + k * x2, (1 - k) * y1 + k * y2, near) + + +def draw_3d_boxes( + image: np.ndarray, + boxes3d: Tensor | np.ndarray, + intrinsics: np.ndarray, + scores_2d: Tensor | np.ndarray | None = None, + scores_3d: Tensor | np.ndarray | None = None, + class_ids: Tensor | np.ndarray | None = None, + class_names: list[str] | None = None, + line_width: int = 2, + font_size: int = 13, + n_colors: int = 50, + score_format: str = "{name} 2D:{s2d:.2f} 3D:{s3d:.2f}", + near_clip: float = 0.15, + save_path: str | None = None, +) -> Image.Image: + """Draw anti-aliased 3D bounding boxes with 2D/3D score labels. + + Args: + image: RGB image (H, W, 3) uint8. + boxes3d: 3D boxes (N, 10) in OPENCV camera coordinates. + intrinsics: Camera intrinsics (3, 3). + scores_2d: 2D confidence scores (N,). + scores_3d: 3D confidence scores (N,). + class_ids: Class indices (N,). + class_names: List of class names. + line_width: Width of 3D box edges. + font_size: Font size for labels. + n_colors: Number of colors in palette. + score_format: Format string. Available: {name}, {s2d}, {s3d}. + near_clip: Camera near clipping plane. + save_path: If provided, save the result. + + Returns: + PIL Image with drawn boxes and score labels. + """ + if isinstance(image, Tensor): + image = image.cpu().numpy() + if image.dtype != np.uint8: + image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8) + if isinstance(boxes3d, Tensor): + boxes3d_t = boxes3d.cpu().float() + else: + boxes3d_t = torch.tensor(boxes3d, dtype=torch.float32) + if isinstance(scores_2d, Tensor): + scores_2d = scores_2d.cpu().numpy() + if isinstance(scores_3d, Tensor): + scores_3d = scores_3d.cpu().numpy() + if isinstance(class_ids, Tensor): + class_ids = class_ids.cpu().numpy() + + N = len(boxes3d_t) + H, W = image.shape[:2] + K_np = intrinsics.astype(np.float32) + + if N == 0: + pil_img = Image.fromarray(image) + if save_path: + pil_img.save(save_path, quality=95) + return pil_img + + # Get 3D corners (N, 8, 3) using vis4d's OPENCV convention + corners_3d = boxes3d_to_corners(boxes3d_t, AxisMode.OPENCV).numpy() + + color_map = generate_color_map(n_colors) + + # --- Draw lines with cv2 (anti-aliased) --- + canvas = image.copy() + canvas_bgr = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR) + + for i in range(N): + cid = int(class_ids[i]) if class_ids is not None else i + color_rgb = color_map[cid % len(color_map)] + color_bgr = (int(color_rgb[2]), int(color_rgb[1]), int(color_rgb[0])) + + corners = corners_3d[i] # (8, 3) + + for e0, e1 in _EDGES: + p1 = tuple(corners[e0].tolist()) + p2 = tuple(corners[e1].tolist()) + + # Near-plane clipping + if p1[2] < near_clip and p2[2] < near_clip: + continue + if p1[2] < near_clip: + p1 = _clip_to_near(p1, p2, near_clip) + elif p2[2] < near_clip: + p2 = _clip_to_near(p2, p1, near_clip) + + # Project to 2D + u1, v1 = _project_pt_simple(p1, K_np) + u2, v2 = _project_pt_simple(p2, K_np) + + # Skip if way outside image + margin = max(W, H) + if (abs(u1) > margin * 2 or abs(v1) > margin * 2 or + abs(u2) > margin * 2 or abs(v2) > margin * 2): + continue + + cv2.line( + canvas_bgr, + (int(round(u1)), int(round(v1))), + (int(round(u2)), int(round(v2))), + color_bgr, + thickness=line_width, + lineType=cv2.LINE_AA, + ) + + canvas_rgb = cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB) + + # --- Draw text labels with PIL (Manrope font) --- + # Use RGBA for rounded rectangle with alpha + pil_img = Image.fromarray(canvas_rgb).convert("RGBA") + overlay = Image.new("RGBA", pil_img.size, (0, 0, 0, 0)) + draw_overlay = ImageDraw.Draw(overlay) + draw_main = ImageDraw.Draw(pil_img) + font = _get_font(font_size) + + for i in range(N): + cid = int(class_ids[i]) if class_ids is not None else 0 + color = color_map[cid % len(color_map)] + + # Project center to 2D + center_3d = boxes3d_t[i, :3].numpy() + if center_3d[2] < near_clip: + continue + cx, cy = _project_pt_simple(tuple(center_3d.tolist()), K_np) + if cx < -50 or cx >= W + 50 or cy < -50 or cy >= H + 50: + continue + + name = class_names[cid] if class_names is not None else str(cid) + s2d = float(scores_2d[i]) if scores_2d is not None else 0.0 + s3d = float(scores_3d[i]) if scores_3d is not None else 0.0 + label = score_format.format(name=name, s2d=s2d, s3d=s3d) + + # Measure text size (textbbox returns actual glyph bounds) + left, top, right, bottom = draw_main.textbbox((0, 0), label, font=font) + tw = right - left + th = bottom - top + y_offset = top # font ascent offset (glyphs don't start at y=0) + + # Position: inside the box, near the projected center + pad_x, pad_y = 6, 4 + radius = 5 + + # Place label centered at projected center + rx0 = cx - tw / 2 - pad_x + ry0 = cy - th / 2 - pad_y + rx1 = cx + tw / 2 + pad_x + ry1 = cy + th / 2 + pad_y + + # Clamp to image bounds + if rx0 < 2: + shift = 2 - rx0 + rx0 += shift + rx1 += shift + if rx1 > W - 2: + shift = rx1 - (W - 2) + rx0 -= shift + rx1 -= shift + if ry0 < 2: + shift = 2 - ry0 + ry0 += shift + ry1 += shift + if ry1 > H - 2: + shift = ry1 - (H - 2) + ry0 -= shift + ry1 -= shift + + # Draw rounded rectangle on overlay (semi-transparent) + fill_color = tuple(color) + (210,) + draw_overlay.rounded_rectangle( + [rx0, ry0, rx1, ry1], + radius=radius, + fill=fill_color, + ) + + # Text centered in the rounded rect (compensate font ascent offset) + text_x = rx0 + pad_x - left + text_y = ry0 + pad_y - y_offset + draw_overlay.text( + (text_x, text_y), + label, + fill=(255, 255, 255, 255), + font=font, + ) + + # Composite overlay onto main image + pil_img = Image.alpha_composite(pil_img, overlay).convert("RGB") + + if save_path: + pil_img.save(save_path, quality=95) + + return pil_img