| |
| """ |
| Feature extraction from frozen DeRIS backbone (Section 3.1-3.3 of Venice-H1). |
| |
| Produces cached .pt files containing per-sample features: |
| - query_feat: [N_samples, N, D] query embeddings (D=256) |
| - det_scores: [N_samples, N] detection scores |
| - query_ious: [N_samples, N] per-query IoU vs GT |
| - oracle_idx: [N_samples] best query index |
| - mask_mean: [N_samples, N] μ_i = mean(P_i) |
| - mask_max: [N_samples, N] p̂_i = max(P_i) |
| - mask_area: [N_samples, N] a_i = mean(P_i > 0.5) |
| - mask_std: [N_samples, N] σ_i = std(P_i) |
| - grid_mean_4: [N_samples, N, 16] AvgPool 4×4 |
| - grid_max_4: [N_samples, N, 16] MaxPool 4×4 |
| - boundary_4: [N_samples, N] boundary energy at 4×4 |
| - grid_mean_8: [N_samples, N, 64] AvgPool 8×8 |
| - grid_max_8: [N_samples, N, 64] MaxPool 8×8 |
| - boundary_8: [N_samples, N] boundary energy at 8×8 |
| - grid_mean_16: [N_samples, N, 256] AvgPool 16×16 |
| - grid_max_16: [N_samples, N, 256] MaxPool 16×16 |
| - boundary_16: [N_samples, N] boundary energy at 16×16 |
| |
| Usage: |
| python scripts/extract_features.py \\ |
| --deris_checkpoint /path/to/deris_l.pth \\ |
| --data_root /path/to/refcoco/ \\ |
| --dataset refcoco --split val \\ |
| --output data/ |
| """ |
|
|
| import argparse |
| import os |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from tqdm import tqdm |
|
|
|
|
| def compute_mask_statistics(mask_probs: torch.Tensor) -> dict: |
| """ |
| Compute scalar mask statistics (Section 3.2, Eq. 1). |
| |
| Args: |
| mask_probs: [N, H, W] sigmoid mask probabilities |
| |
| Returns: |
| dict with mask_mean, mask_max, mask_area, mask_std (each [N]) |
| """ |
| N, H, W = mask_probs.shape |
| flat = mask_probs.reshape(N, -1) |
|
|
| return { |
| "mask_mean": flat.mean(dim=1), |
| "mask_max": flat.max(dim=1).values, |
| "mask_area": (flat > 0.5).float().mean(1), |
| "mask_std": flat.std(dim=1), |
| } |
|
|
|
|
| def compute_grid_signatures(mask_probs: torch.Tensor, grid_size: int) -> dict: |
| """ |
| Compute multi-scale grid signatures (Section 3.3, Eqs. 2-4). |
| |
| Args: |
| mask_probs: [N, H, W] mask probabilities |
| grid_size: G (one of 4, 8, 16) |
| |
| Returns: |
| dict with grid_mean, grid_max, boundary (per query) |
| """ |
| N = mask_probs.shape[0] |
| G = grid_size |
|
|
| |
| x = mask_probs.unsqueeze(1) |
|
|
| |
| grid_mean = F.adaptive_avg_pool2d(x, (G, G)).reshape(N, G * G) |
|
|
| |
| grid_max = F.adaptive_max_pool2d(x, (G, G)).reshape(N, G * G) |
|
|
| |
| grid_2d = grid_mean.reshape(N, G, G) |
| dx = (grid_2d[:, :, 1:] - grid_2d[:, :, :-1]).abs().mean(dim=(1, 2)) |
| dy = (grid_2d[:, 1:, :] - grid_2d[:, :-1, :]).abs().mean(dim=(1, 2)) |
| boundary = 0.5 * (dx + dy) |
|
|
| return { |
| f"grid_mean_{G}": grid_mean, |
| f"grid_max_{G}": grid_max, |
| f"boundary_{G}": boundary, |
| } |
|
|
|
|
| def compute_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float: |
| """Compute IoU between binary masks.""" |
| pred = (pred_mask > 0.5).float() |
| gt = gt_mask.float() |
| intersection = (pred * gt).sum() |
| union = (pred + gt).clamp(0, 1).sum() |
| if union < 1: |
| return 0.0 |
| return (intersection / union).item() |
|
|
|
|
| def extract_sample_features( |
| mask_logits: torch.Tensor, |
| query_embeddings: torch.Tensor, |
| det_scores: torch.Tensor, |
| gt_mask: torch.Tensor, |
| ) -> dict: |
| """ |
| Extract all Venice-H1 features for one sample. |
| |
| Args: |
| mask_logits: [N, H, W] raw mask logits from DeRIS |
| query_embeddings: [N, D] query embeddings |
| det_scores: [N] detection scores |
| gt_mask: [H_gt, W_gt] ground-truth binary mask |
| |
| Returns: |
| dict with all features for this sample |
| """ |
| N = mask_logits.shape[0] |
|
|
| |
| mask_probs = torch.sigmoid(mask_logits) |
|
|
| |
| stats = compute_mask_statistics(mask_probs) |
|
|
| |
| grid_4 = compute_grid_signatures(mask_probs, 4) |
| grid_8 = compute_grid_signatures(mask_probs, 8) |
| grid_16 = compute_grid_signatures(mask_probs, 16) |
|
|
| |
| H_gt, W_gt = gt_mask.shape |
| mask_probs_resized = F.interpolate( |
| mask_probs.unsqueeze(1), size=(H_gt, W_gt), |
| mode='bilinear', align_corners=False |
| ).squeeze(1) |
|
|
| query_ious = torch.tensor([ |
| compute_iou(mask_probs_resized[i], gt_mask) for i in range(N) |
| ]) |
| oracle_idx = query_ious.argmax().item() |
|
|
| return { |
| "query_feat": query_embeddings, |
| "det_scores": det_scores, |
| "query_ious": query_ious, |
| "oracle_idx": oracle_idx, |
| **stats, |
| **grid_4, |
| **grid_8, |
| **grid_16, |
| } |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser( |
| description="Extract Venice-H1 features from frozen DeRIS.") |
| parser.add_argument("--deris_checkpoint", type=str, required=True, |
| help="Path to frozen DeRIS-L/B checkpoint") |
| parser.add_argument("--data_root", type=str, default="data/", |
| help="Root directory for RefCOCO data") |
| parser.add_argument("--dataset", type=str, required=True, |
| choices=["refcoco", "refcoco+", "refcocog"]) |
| parser.add_argument("--split", type=str, required=True, |
| choices=["train", "val", "testA", "testB", "test"]) |
| parser.add_argument("--output", type=str, default="data/", |
| help="Output directory for cached features") |
| parser.add_argument("--n_queries", type=int, default=10) |
| parser.add_argument("--batch_size", type=int, default=1) |
| parser.add_argument("--device", type=str, default="cuda") |
| args = parser.parse_args() |
|
|
| device = torch.device(args.device if torch.cuda.is_available() else "cpu") |
| os.makedirs(args.output, exist_ok=True) |
|
|
| print(f"Venice-H1 Feature Extraction") |
| print(f" Backbone: {args.deris_checkpoint}") |
| print(f" Dataset: {args.dataset} / {args.split}") |
| print(f" Device: {device}") |
| print() |
|
|
| |
| |
| |
| |
| print("=" * 60) |
| print("IMPORTANT: You must adapt the model loading section below") |
| print("to your DeRIS installation. See comments in this script.") |
| print("=" * 60) |
| print() |
| print("Expected DeRIS outputs per sample:") |
| print(" - query_embeddings: [N, 256] (N=10 candidate queries)") |
| print(" - mask_logits: [N, H, W] (mask predictions)") |
| print(" - det_scores: [N] (detection confidence scores)") |
| print() |
| print("Once you have DeRIS producing these outputs, the feature") |
| print("extraction loop below handles everything else automatically.") |
| print() |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| print("Feature extraction template ready.") |
| print("Uncomment the dataloader section above and adapt to your setup.") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|