venice-h1 / scripts /extract_features.py
OdaxAI's picture
Add scripts/extract_features.py
09f0673 verified
Raw
History Blame Contribute Delete
8.81 kB
#!/usr/bin/env python3
"""
Feature extraction from frozen DeRIS backbone (Section 3.1-3.3 of Venice-H1).
Produces cached .pt files containing per-sample features:
- query_feat: [N_samples, N, D] query embeddings (D=256)
- det_scores: [N_samples, N] detection scores
- query_ious: [N_samples, N] per-query IoU vs GT
- oracle_idx: [N_samples] best query index
- mask_mean: [N_samples, N] μ_i = mean(P_i)
- mask_max: [N_samples, N] p̂_i = max(P_i)
- mask_area: [N_samples, N] a_i = mean(P_i > 0.5)
- mask_std: [N_samples, N] σ_i = std(P_i)
- grid_mean_4: [N_samples, N, 16] AvgPool 4×4
- grid_max_4: [N_samples, N, 16] MaxPool 4×4
- boundary_4: [N_samples, N] boundary energy at 4×4
- grid_mean_8: [N_samples, N, 64] AvgPool 8×8
- grid_max_8: [N_samples, N, 64] MaxPool 8×8
- boundary_8: [N_samples, N] boundary energy at 8×8
- grid_mean_16: [N_samples, N, 256] AvgPool 16×16
- grid_max_16: [N_samples, N, 256] MaxPool 16×16
- boundary_16: [N_samples, N] boundary energy at 16×16
Usage:
python scripts/extract_features.py \\
--deris_checkpoint /path/to/deris_l.pth \\
--data_root /path/to/refcoco/ \\
--dataset refcoco --split val \\
--output data/
"""
import argparse
import os
from pathlib import Path
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
def compute_mask_statistics(mask_probs: torch.Tensor) -> dict:
    """
    Compute scalar mask statistics (Section 3.2, Eq. 1).

    Each query's H×W probability map is flattened and reduced to four
    scalars. An empty query set (N == 0) yields empty [0] tensors.

    Args:
        mask_probs: [N, H, W] sigmoid mask probabilities

    Returns:
        dict with mask_mean, mask_max, mask_area, mask_std (each [N])

    Raises:
        ValueError: if mask_probs is not 3-dimensional.
    """
    if mask_probs.dim() != 3:
        raise ValueError(
            "mask_probs must have shape [N, H, W], "
            f"got {tuple(mask_probs.shape)}"
        )
    # flatten() instead of reshape(N, -1): reshape cannot infer the -1
    # dimension when N == 0, so an empty query set used to raise.
    flat = mask_probs.flatten(start_dim=1)
    return {
        "mask_mean": flat.mean(dim=1),                  # μ_i
        "mask_max": flat.max(dim=1).values,             # p̂_i
        "mask_area": (flat > 0.5).float().mean(dim=1),  # a_i: fraction of pixels above 0.5
        "mask_std": flat.std(dim=1),                    # σ_i (torch default std estimator)
    }
def compute_grid_signatures(mask_probs: torch.Tensor, grid_size: int) -> dict:
    """
    Build the grid signature of every query mask at one scale
    (Section 3.3, Eqs. 2-4).

    Args:
        mask_probs: [N, H, W] mask probabilities
        grid_size: G, the side length of the pooled grid (4, 8 or 16)

    Returns:
        dict keyed by scale:
            grid_mean_{G}: [N, G*G] adaptive average pooling of each mask
            grid_max_{G}:  [N, G*G] adaptive max pooling of each mask
            boundary_{G}:  [N] boundary energy of the average-pooled grid
    """
    num_queries = mask_probs.shape[0]
    target = (grid_size, grid_size)
    cells = grid_size * grid_size

    # The pooling ops expect a channel axis: [N, H, W] -> [N, 1, H, W].
    with_channel = mask_probs.unsqueeze(1)

    # Eq. 2 / Eq. 3: per-cell average and maximum, flattened to G*G.
    avg_pooled = F.adaptive_avg_pool2d(with_channel, target)
    max_pooled = F.adaptive_max_pool2d(with_channel, target)
    grid_mean = avg_pooled.reshape(num_queries, cells)
    grid_max = max_pooled.reshape(num_queries, cells)

    # Eq. 4: mean absolute difference between neighbouring cells of the
    # average-pooled grid, averaged over the two grid axes.
    coarse = grid_mean.reshape(num_queries, grid_size, grid_size)
    right_minus_left = coarse[:, :, 1:] - coarse[:, :, :-1]
    lower_minus_upper = coarse[:, 1:, :] - coarse[:, :-1, :]
    dx = right_minus_left.abs().mean(dim=(1, 2))
    dy = lower_minus_upper.abs().mean(dim=(1, 2))
    boundary = 0.5 * (dx + dy)

    signatures = {}
    signatures[f"grid_mean_{grid_size}"] = grid_mean
    signatures[f"grid_max_{grid_size}"] = grid_max
    signatures[f"boundary_{grid_size}"] = boundary
    return signatures
def compute_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
    """
    Compute IoU between a predicted mask and a ground-truth mask.

    Args:
        pred_mask: mask probabilities; binarized with a 0.5 threshold
        gt_mask: ground-truth mask; any non-zero value counts as
            foreground, so 0/1, boolean and 0/255 encodings all work

    Returns:
        Intersection-over-union as a Python float, or 0.0 when both
        masks are empty.
    """
    pred = pred_mask > 0.5
    # Binarize the ground truth as well: multiplying by a raw 0/255 mask
    # would scale the intersection and produce IoU values above 1.
    gt = gt_mask.bool()
    union = (pred | gt).sum()
    if union == 0:
        return 0.0
    intersection = (pred & gt).sum()
    return (intersection / union).item()
def extract_sample_features(
    mask_logits: torch.Tensor,
    query_embeddings: torch.Tensor,
    det_scores: torch.Tensor,
    gt_mask: torch.Tensor,
    grid_sizes: tuple = (4, 8, 16),
) -> dict:
    """
    Extract all Venice-H1 features for one sample.

    Args:
        mask_logits: [N, H, W] raw mask logits from DeRIS
        query_embeddings: [N, D] query embeddings
        det_scores: [N] detection scores
        gt_mask: [H_gt, W_gt] ground-truth binary mask
        grid_sizes: grid scales G for the multi-scale signatures; the
            default (4, 8, 16) reproduces the original feature set

    Returns:
        dict with query_feat, det_scores, query_ious ([N] IoU of each
        query mask vs. gt_mask), oracle_idx (int index of the best-IoU
        query), the four mask statistics, and grid_mean_{G} /
        grid_max_{G} / boundary_{G} for every G in grid_sizes.
    """
    # Eq. 1: mask probabilities
    mask_probs = torch.sigmoid(mask_logits)  # [N, H, W]

    # Mask statistics (Section 3.2)
    stats = compute_mask_statistics(mask_probs)

    # Multi-scale grid signatures (Section 3.3), one group of keys per scale
    grid_features = {}
    for size in grid_sizes:
        grid_features.update(compute_grid_signatures(mask_probs, size))

    # IoU is measured at ground-truth resolution, so the probability maps
    # are resized to the GT size before thresholding inside compute_iou.
    gt_height, gt_width = gt_mask.shape
    resized_probs = F.interpolate(
        mask_probs.unsqueeze(1),
        size=(gt_height, gt_width),
        mode="bilinear",
        align_corners=False,
    ).squeeze(1)
    query_ious = torch.tensor(
        [compute_iou(query_probs, gt_mask) for query_probs in resized_probs]
    )
    oracle_idx = query_ious.argmax().item()

    return {
        "query_feat": query_embeddings,  # [N, D]
        "det_scores": det_scores,        # [N]
        "query_ious": query_ious,        # [N]
        "oracle_idx": oracle_idx,
        **stats,
        **grid_features,
    }
def _resolve_device(requested: str) -> torch.device:
    """Return the requested device, using CPU only when CUDA was asked for but is unavailable."""
    device = torch.device(requested)
    if device.type == "cuda" and not torch.cuda.is_available():
        return torch.device("cpu")
    return device


def main():
    """
    Command-line entry point for the Venice-H1 feature-extraction template.

    Parses the CLI arguments, creates the output directory, and prints
    the run configuration together with instructions for wiring in a
    DeRIS model. Model loading, the extraction loop and the save step
    are provided only as a commented-out template; as written, no
    features are computed or saved.
    """
    parser = argparse.ArgumentParser(
        description="Extract Venice-H1 features from frozen DeRIS.")
    parser.add_argument("--deris_checkpoint", type=str, required=True,
                        help="Path to frozen DeRIS-L/B checkpoint")
    parser.add_argument("--data_root", type=str, default="data/",
                        help="Root directory for RefCOCO data")
    parser.add_argument("--dataset", type=str, required=True,
                        choices=["refcoco", "refcoco+", "refcocog"])
    parser.add_argument("--split", type=str, required=True,
                        choices=["train", "val", "testA", "testB", "test"])
    parser.add_argument("--output", type=str, default="data/",
                        help="Output directory for cached features")
    # --n_queries and --batch_size are only referenced by the commented
    # template further down.
    parser.add_argument("--n_queries", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    # Honour the requested device; previously any non-CUDA request
    # (e.g. "mps") was replaced by "cpu" whenever CUDA was unavailable.
    device = _resolve_device(args.device)
    os.makedirs(args.output, exist_ok=True)

    print("Venice-H1 Feature Extraction")
    print(f" Backbone: {args.deris_checkpoint}")
    print(f" Dataset: {args.dataset} / {args.split}")
    print(f" Device: {device}")
    print()

    # ---- Load DeRIS model ----
    # NOTE: Adapt this import to your DeRIS installation path
    # from deris.model import build_deris
    # model = build_deris(args.deris_checkpoint).to(device).eval()
    print("=" * 60)
    print("IMPORTANT: You must adapt the model loading section below")
    print("to your DeRIS installation. See comments in this script.")
    print("=" * 60)
    print()
    print("Expected DeRIS outputs per sample:")
    print(" - query_embeddings: [N, 256] (N=10 candidate queries)")
    print(" - mask_logits: [N, H, W] (mask predictions)")
    print(" - det_scores: [N] (detection confidence scores)")
    print()
    print("Once you have DeRIS producing these outputs, the feature")
    print("extraction loop below handles everything else automatically.")
    print()

    # ---- Placeholder: replace with your data loader ----
    # NOTE(review): the template slices with [:args.n_queries] *before*
    # squeeze(0). If the model outputs carry a leading batch dimension,
    # that slice truncates the batch axis rather than the query axis --
    # confirm against your DeRIS output layout when adapting this.
    #
    # dataloader = build_refcoco_loader(args.data_root, args.dataset,
    #                                   args.split, batch_size=1)
    #
    # all_features = []
    # for batch in tqdm(dataloader, desc=f"Extracting {args.split}"):
    #     img = batch["image"].to(device)
    #     expr = batch["expression"]
    #     gt_mask = batch["gt_mask"].to(device)
    #
    #     with torch.no_grad():
    #         outputs = model(img, expr)
    #         mask_logits = outputs["pred_masks"][:args.n_queries]
    #         query_emb = outputs["query_embeddings"][:args.n_queries]
    #         scores = outputs["det_scores"][:args.n_queries]
    #
    #     feats = extract_sample_features(
    #         mask_logits.squeeze(0), query_emb.squeeze(0),
    #         scores.squeeze(0), gt_mask.squeeze(0))
    #     all_features.append(feats)
    #
    # ---- Stack and save ----
    # output_path = os.path.join(
    #     args.output,
    #     f"cached_{args.split}_{args.dataset}_unc_feats.pt")
    # stacked = {k: torch.stack([f[k] for f in all_features])
    #            for k in all_features[0].keys()
    #            if k != "oracle_idx"}
    # stacked["oracle_idx"] = torch.tensor(
    #     [f["oracle_idx"] for f in all_features])
    # torch.save(stacked, output_path)
    # print(f"Saved {len(all_features)} samples → {output_path}")

    print("Feature extraction template ready.")
    print("Uncomment the dataloader section above and adapt to your setup.")


if __name__ == "__main__":
    main()