# hmr-dataset / scripts/demo/extract_features.py
# (zirobtc — uploaded via huggingface_hub, revision fbb20ff verified)
import sys
import os
import torch
import numpy as np
import cv2
import argparse
from pathlib import Path
from tqdm import tqdm
import gc
import concurrent.futures
# Ensure repo root is on sys.path
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from genmo.utils.pylogger import Log
# Standard ImageNet Normalization
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def _require_extractor():
    """Resolve and return GVHMR's ViT-feature ``Extractor`` class.

    The vendored GVHMR checkout is appended to ``sys.path`` on demand so the
    deep import below can resolve; a clear RuntimeError is raised otherwise.
    """
    gvhmr_path = REPO_ROOT / "third_party" / "GVHMR"
    if gvhmr_path.exists() and str(gvhmr_path) not in sys.path:
        sys.path.insert(0, str(gvhmr_path))
    try:
        from third_party.GVHMR.hmr4d.utils.preproc.vitfeat_extractor import Extractor
    except Exception as err:
        raise RuntimeError("Could not import Extractor from GVHMR.") from err
    return Extractor
# --- FAST IMAGE LOADER ---
def process_single_image(args):
    """Load one frame, crop a square person box, and normalize it for the ViT.

    Args:
        args: ``(path, cx, cy, scale, img_size)`` — image path, box center,
            square box side length, and output resolution.

    Returns:
        ``(3, img_size, img_size)`` float32 CHW array (RGB, ImageNet-normalized).
        An unreadable frame yields an all-zero array so one corrupt image does
        not abort the whole sequence.

    Raises:
        RuntimeError: on non-numeric / non-finite / absurd box parameters, or
            when the resulting crop would be empty.
    """
    path, cx, cy, scale, img_size = args
    img = cv2.imread(path)
    if img is None:
        # Missing or corrupt frame: emit a blank input instead of crashing.
        return np.zeros((3, img_size, img_size), dtype=np.float32)

    H, W = img.shape[:2]
    max_side = float(max(H, W, 1))

    # Validate the box before doing any geometry with it.
    try:
        cx = float(cx)
        cy = float(cy)
        scale = float(scale)
    except Exception as e:
        raise RuntimeError(f"Bad bbx_xys types for {path}: cx={cx} cy={cy} scale={scale}") from e
    if not all(np.isfinite(v) for v in (cx, cy, scale)):
        raise RuntimeError(f"Bad bbx_xys (non-finite) for {path}: cx={cx} cy={cy} scale={scale}")
    if scale <= 1.0 or scale > max_side * 20.0:
        raise RuntimeError(f"Bad bbx_xys (scale) for {path}: (H,W)=({H},{W}) cx={cx} cy={cy} scale={scale}")

    half = scale / 2.0
    left, top = int(cx - half), int(cy - half)
    right, bottom = int(cx + half), int(cy + half)
    pad_l = max(0, -left)
    pad_t = max(0, -top)
    pad_r = max(0, right - W)
    pad_b = max(0, bottom - H)
    # Fail loudly instead of letting OpenCV try to allocate absurdly large padded images.
    if max(pad_l, pad_t, pad_r, pad_b) > int(max_side * 4.0):
        raise RuntimeError(
            f"Insane crop for {path}: (H,W)=({H},{W}) cx={cx:.2f} cy={cy:.2f} scale={scale:.2f} "
            f"pads(l,t,r,b)=({pad_l},{pad_t},{pad_r},{pad_b})"
        )

    if any((pad_l, pad_t, pad_r, pad_b)):
        # Zero-pad so the box fits, then shift coordinates into padded space
        # (only left/top padding moves the origin).
        img = cv2.copyMakeBorder(img, pad_t, pad_b, pad_l, pad_r, cv2.BORDER_CONSTANT, value=(0,0,0))
        left += pad_l
        right += pad_l
        top += pad_t
        bottom += pad_t

    crop = img[top:bottom, left:right]
    if crop.size == 0:
        raise RuntimeError(
            f"Empty crop for {path}: (H,W)=({H},{W}) cx={cx:.2f} cy={cy:.2f} scale={scale:.2f} "
            f"xyxy=({left},{top},{right},{bottom})"
        )
    if crop.shape[:2] != (img_size, img_size):
        crop = cv2.resize(crop, (img_size, img_size), interpolation=cv2.INTER_LINEAR)

    # BGR -> RGB, scale to [0, 1], ImageNet-normalize, then HWC -> CHW.
    rgb = crop[:, :, ::-1].astype(np.float32) / 255.0
    rgb = (rgb - IMAGENET_MEAN) / IMAGENET_STD
    return rgb.transpose(2, 0, 1)
def load_images_parallel(image_paths, bbx_xys, img_size=256, workers=12):
    """Decode, crop, and normalize all frames with a thread pool.

    Args:
        image_paths: sequence of per-frame image paths.
        bbx_xys: per-frame ``(cx, cy, scale)`` boxes (Tensor or array-like).
        img_size: side length of the square crop fed to the ViT.
        workers: thread-pool size (image decode releases the GIL in OpenCV).

    Returns:
        ``(F, 3, img_size, img_size)`` float32 torch tensor on CPU.
    """
    if isinstance(bbx_xys, torch.Tensor):
        bbx_xys = bbx_xys.cpu().numpy()
    jobs = [
        (str(path), box[0], box[1], box[2], img_size)
        for path, box in zip(image_paths, bbx_xys)
    ]
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        crops = list(pool.map(process_single_image, jobs))
    return torch.from_numpy(np.stack(crops))
# --- OPTIMIZED INFERENCE LOOP ---
def fast_inference(model, tensor, batch_size=64):
    """
    Run the ViT feature extractor over all frames in GPU batches.

    Replaces the slow extractor loop with chunked, mixed-precision inference.

    Args:
        model: torch module expecting a ``{"img": batch}`` dict input.
        tensor: ``(F, 3, H, W)`` float tensor on CPU.
        batch_size: frames per forward pass; raise if VRAM allows.

    Returns:
        ``(F, D)`` feature tensor on CPU (autocast dtype; caller casts as needed).
    """
    model.eval()
    F = tensor.shape[0]
    features = []
    tensor = tensor.contiguous()
    # Pin host memory so cuda(non_blocking=True) below is a genuinely
    # asynchronous copy; from pageable memory the transfer is synchronous.
    # (The original comment claimed pinning but the code never did it.)
    if torch.cuda.is_available():
        tensor = tensor.pin_memory()
    with torch.inference_mode():
        for j in range(0, F, batch_size):
            # Non-blocking host->device transfer (effective thanks to pinning).
            batch = tensor[j : j + batch_size].cuda(non_blocking=True)
            # AMP (Automatic Mixed Precision) -> 2x Speedup
            with torch.amp.autocast("cuda"):
                # HMR2 expects dictionary input
                feat = model({"img": batch})
            features.append(feat.detach().cpu())
    return torch.cat(features, dim=0)
def main():
    """CLI entry point: backfill ViT image features into each sequence .pt file.

    For every ``<dataset_root>/genmo_features/*.pt`` that lacks a valid
    ``f_imgseq`` tensor (unless ``--overwrite``), loads the referenced frames,
    extracts ViT features, and writes the tensor back into the same file.
    Per-sequence failures are logged and skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_root", required=True)
    parser.add_argument("--batch_size", type=int, default=256, help="Increase this if VRAM allows")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()

    dataset_root = Path(args.dataset_root)
    feat_dir = dataset_root / "genmo_features"
    images_root = dataset_root
    if not feat_dir.exists():
        Log.error("Feature dir not found")
        return

    Log.info("Initializing ViT Model...")
    ExtractorClass = _require_extractor()
    extractor_wrapper = ExtractorClass(tqdm_leave=False)
    # Use the inner torch module (HMR2) directly, bypassing the wrapper's loop.
    model = extractor_wrapper.extractor

    pt_files = sorted(feat_dir.glob("*.pt"))
    Log.info(f"Processing {len(pt_files)} sequences. Batch Size: {args.batch_size}")

    for pt_file in tqdm(pt_files, desc="Dataset Progress"):
        try:
            data = torch.load(pt_file, map_location="cpu", weights_only=False)

            # Skip sequences that already carry a valid feature tensor.
            if not args.overwrite and "f_imgseq" in data:
                existing = data["f_imgseq"]
                if isinstance(existing, torch.Tensor) and existing.ndim == 2 and existing.shape[1] > 0:
                    continue

            rel_paths = data["imgname"]
            bbx_xys = data["bbx_xys"]
            abs_paths = [images_root / p for p in rel_paths]
            if not abs_paths[0].exists():
                continue

            # CPU-parallel decode/crop, then batched GPU inference.
            frames = load_images_parallel(abs_paths, bbx_xys, workers=args.workers)
            feats = fast_inference(model, frames, batch_size=args.batch_size)

            data["f_imgseq"] = feats.float()  # Save as float32 for compatibility
            torch.save(data, pt_file)
        except Exception as e:
            Log.error(f"Error {pt_file.stem}: {e}")
            continue
if __name__ == "__main__":
    # Enable the expandable-segments mode of the CUDA caching allocator to
    # reduce fragmentation for variable-length batches.
    # NOTE(review): assumes PyTorch reads PYTORCH_CUDA_ALLOC_CONF lazily at
    # first CUDA use (torch is already imported at this point) — confirm.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    main()