File size: 6,505 Bytes
fbb20ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import sys
import os
import torch
import numpy as np
import cv2
import argparse
from pathlib import Path
from tqdm import tqdm
import gc
import concurrent.futures
# Make the repository root (two levels above the directory containing this
# script) importable, so project packages such as `genmo` resolve.
REPO_ROOT = Path(__file__).resolve().parents[2]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
from genmo.utils.pylogger import Log
# Per-channel ImageNet statistics (R, G, B order), used to normalize RGB crops
# channel-wise: (pixel / 255 - mean) / std.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)
def _require_extractor():
    """Import and return GVHMR's ViT feature ``Extractor`` class.

    If ``<repo>/third_party/GVHMR`` exists and is not yet on ``sys.path``, it is
    prepended before the import is attempted.

    Raises:
        RuntimeError: if the import fails for any reason; the underlying
            exception is chained as ``__cause__``.
    """
    gvhmr_dir = REPO_ROOT.joinpath("third_party", "GVHMR")
    gvhmr_entry = str(gvhmr_dir)
    # NOTE(review): presumably needed so GVHMR's own absolute imports
    # (e.g. `hmr4d.*`) resolve — confirm.
    if gvhmr_dir.exists() and gvhmr_entry not in sys.path:
        sys.path.insert(0, gvhmr_entry)
    try:
        from third_party.GVHMR.hmr4d.utils.preproc.vitfeat_extractor import Extractor as extractor_cls
    except Exception as err:
        raise RuntimeError("Could not import Extractor from GVHMR.") from err
    return extractor_cls
# --- FAST IMAGE LOADER ---
def process_single_image(args):
    """Load one frame, cut a square box out of it and return it normalized, CHW.

    Args:
        args: a single tuple ``(path, cx, cy, scale, img_size)`` so the function
            can be passed straight to ``Executor.map``. ``cx``/``cy`` is the box
            centre and ``scale`` its side length, in source-image pixels (any
            value ``float()`` accepts). ``img_size`` is the output side length.

    Returns:
        ``float32`` array of shape ``(3, img_size, img_size)`` in RGB order,
        normalized with ``IMAGENET_MEAN`` / ``IMAGENET_STD``. Parts of the box
        outside the image are black before normalization. If ``cv2.imread``
        cannot read ``path``, an all-zero array of that shape is returned
        instead (deliberate best-effort; the box is not validated then).

    Raises:
        RuntimeError: if the box values are not numeric or not finite, if
            ``scale`` is ``<= 1`` or more than 20x the image's longer side, or
            if the box would need more than 4x the longer side of padding.
    """
    path, cx, cy, scale, img_size = args
    img = cv2.imread(path)
    if img is None:
        return np.zeros((3, img_size, img_size), dtype=np.float32)
    H, W = img.shape[:2]
    max_side = float(max(H, W, 1))

    try:
        cx = float(cx)
        cy = float(cy)
        scale = float(scale)
    except Exception as e:
        raise RuntimeError(f"Bad bbx_xys types for {path}: cx={cx} cy={cy} scale={scale}") from e
    if not (np.isfinite(cx) and np.isfinite(cy) and np.isfinite(scale)):
        raise RuntimeError(f"Bad bbx_xys (non-finite) for {path}: cx={cx} cy={cy} scale={scale}")
    if scale <= 1.0 or scale > max_side * 20.0:
        raise RuntimeError(f"Bad bbx_xys (scale) for {path}: (H,W)=({H},{W}) cx={cx} cy={cy} scale={scale}")

    # Integer pixel box [x0, x1) x [y0, y1). Use floor() rather than int():
    # int() truncates toward zero, which shifts boxes starting left of / above
    # the image by one pixel and lets x1 - x0 and y1 - y0 differ, so the crop
    # was not always square (or was even empty) before the resize below.
    side = max(1, int(round(scale)))
    x0 = int(np.floor(cx - scale / 2.0))
    y0 = int(np.floor(cy - scale / 2.0))
    x1, y1 = x0 + side, y0 + side

    pad_l, pad_t = max(0, -x0), max(0, -y0)
    pad_r, pad_b = max(0, x1 - W), max(0, y1 - H)
    # Fail loudly instead of allocating an absurdly large canvas for a bogus box.
    if max(pad_l, pad_t, pad_r, pad_b) > int(max_side * 4.0):
        raise RuntimeError(
            f"Insane crop for {path}: (H,W)=({H},{W}) cx={cx:.2f} cy={cy:.2f} scale={scale:.2f} "
            f"pads(l,t,r,b)=({pad_l},{pad_t},{pad_r},{pad_b})"
        )

    # Paste only the in-bounds part of the box onto a black canvas. This yields
    # the same pixels as zero-padding the full frame and slicing, without
    # copying the whole (padded) frame for every crop.
    crop = np.zeros((side, side) + img.shape[2:], dtype=img.dtype)
    sx0, sy0 = max(x0, 0), max(y0, 0)
    sx1, sy1 = min(x1, W), min(y1, H)
    if sx1 > sx0 and sy1 > sy0:
        crop[sy0 - y0 : sy1 - y0, sx0 - x0 : sx1 - x0] = img[sy0:sy1, sx0:sx1]

    if side != img_size:
        crop = cv2.resize(crop, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
    crop = crop[:, :, ::-1].astype(np.float32) / 255.0  # OpenCV BGR -> RGB, [0, 1]
    crop = (crop - IMAGENET_MEAN) / IMAGENET_STD
    return crop.transpose(2, 0, 1)  # HWC -> CHW
def load_images_parallel(image_paths, bbx_xys, img_size=256, workers=12):
    """Crop and normalize every frame of a sequence using a thread pool.

    Args:
        image_paths: frame paths (``str`` or ``Path``), one per box row.
        bbx_xys: per-frame box rows — a ``torch.Tensor``, NumPy array or
            sequence; only columns 0, 1, 2 (``cx, cy, scale``) are read.
        img_size: side length of each output crop.
        workers: maximum number of worker threads.

    Returns:
        ``torch.float32`` tensor of shape ``(N, 3, img_size, img_size)`` with
        ``N == len(image_paths)``; ``N`` may be 0.

    Raises:
        ValueError: if ``image_paths`` and ``bbx_xys`` differ in length
            (``zip`` would otherwise silently drop the extra entries).
        RuntimeError: propagated from ``process_single_image`` for a bad box.
    """
    if isinstance(bbx_xys, torch.Tensor):
        bbx_xys = bbx_xys.detach().cpu().numpy()
    image_paths = list(image_paths)
    if len(image_paths) != len(bbx_xys):
        raise ValueError(
            f"image_paths ({len(image_paths)}) and bbx_xys ({len(bbx_xys)}) must have the same length"
        )

    # Preallocate the result and fill it as crops arrive, instead of keeping a
    # list of crops plus a stacked copy (halves peak memory).
    out = np.empty((len(image_paths), 3, img_size, img_size), dtype=np.float32)
    if not image_paths:
        return torch.from_numpy(out)

    tasks = [(str(p), b[0], b[1], b[2], img_size) for p, b in zip(image_paths, bbx_xys)]
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        # executor.map yields results in task order, so index i matches frame i.
        for i, chw in enumerate(executor.map(process_single_image, tasks)):
            out[i] = chw
    return torch.from_numpy(out)
# --- OPTIMIZED INFERENCE LOOP ---
def fast_inference(model, tensor, batch_size=64, device=None):
    """Run ``model`` over ``tensor`` in mini-batches and concatenate the outputs.

    Args:
        model: torch module, put into eval mode here and called as
            ``model({"img": batch})`` (dict input as the original code passed;
            presumably what the GVHMR extractor module expects — confirm).
        tensor: inputs stacked along dim 0 (assumed ``(N, 3, H, W)`` crops as
            produced by ``load_images_parallel``).
        batch_size: samples per forward pass; must be positive.
        device: device to run on. ``None`` (default) means ``"cuda"``, i.e. the
            current CUDA device, matching the previous hard-coded ``.cuda()``.

    Returns:
        CPU tensor of the per-batch outputs concatenated along dim 0. On CUDA the
        forward runs under autocast (default fp16), so the dtype may be half.

    Raises:
        ValueError: if ``tensor`` is empty or ``batch_size`` is not positive.
    """
    device = torch.device("cuda" if device is None else device)
    use_cuda = device.type == "cuda"
    num = tensor.shape[0]
    if num == 0:
        raise ValueError("fast_inference received an empty input tensor")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    # Contiguous so each slice below is a plain memory block to copy. Note this
    # is NOT pinned memory, so non_blocking copies from it are effectively
    # synchronous.
    src = tensor.contiguous()
    outputs = []
    with torch.inference_mode():
        for start in range(0, num, batch_size):
            batch = src[start : start + batch_size].to(device, non_blocking=True)
            # Mixed precision only on CUDA; other devices run in full precision.
            with torch.amp.autocast("cuda", enabled=use_cuda):
                feat = model({"img": batch})
            outputs.append(feat.cpu())
    return torch.cat(outputs, dim=0)
def _has_features(data):
    """Return True if ``data`` already holds a non-empty 2-D ``f_imgseq`` tensor."""
    if "f_imgseq" not in data:
        return False
    feats = data["f_imgseq"]
    return isinstance(feats, torch.Tensor) and feats.ndim == 2 and feats.shape[1] > 0


def _atomic_torch_save(obj, path):
    """``torch.save`` ``obj`` to ``path`` via a temporary sibling file + ``os.replace``.

    The sequence file is overwritten in place, so an interrupted write (Ctrl-C,
    OOM kill, full disk) would otherwise destroy the data it was loaded from.
    The temp file sits in the same directory so ``os.replace`` is a single
    rename; its name does not match ``*.pt`` so it is never picked up as input.
    """
    tmp_path = path.with_name(f".{path.name}.{os.getpid()}.tmp")
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def _process_sequence(pt_file, model, images_root, args):
    """Compute ``f_imgseq`` for one sequence file and write it back.

    Returns without writing when the file already has features (and
    ``--overwrite`` is off) or when its frames cannot be found. Running this in
    its own function also frees the frame/feature tensors before the next
    sequence is loaded.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects — only run this on trusted dataset files.
    data = torch.load(pt_file, map_location="cpu", weights_only=False)
    if not args.overwrite and _has_features(data):
        return

    img_rel_paths = data["imgname"]
    bbx_xys = data["bbx_xys"]
    abs_img_paths = [images_root / p for p in img_rel_paths]
    if not abs_img_paths:
        Log.info(f"Skipping {pt_file.stem}: no frames listed in 'imgname'")
        return
    # Only the first frame is probed; later frames are handled per image
    # inside process_single_image.
    if not abs_img_paths[0].exists():
        Log.info(f"Skipping {pt_file.stem}: image not found: {abs_img_paths[0]}")
        return

    # 1. Load & crop frames (CPU, thread pool)
    input_tensor = load_images_parallel(abs_img_paths, bbx_xys, workers=args.workers)
    # 2. Batched inference (GPU, autocast)
    vit_features = fast_inference(model, input_tensor, batch_size=args.batch_size)
    # 3. Save without risking a truncated file
    data["f_imgseq"] = vit_features.float()  # Save as float32 for compatibility
    _atomic_torch_save(data, pt_file)


def main():
    """Backfill ViT image features (``f_imgseq``) into every sequence file.

    For each ``<dataset_root>/genmo_features/*.pt``: load it, skip it if it
    already has features (unless ``--overwrite``), otherwise crop the frames
    listed in ``imgname`` with ``bbx_xys``, run the GVHMR ViT extractor module
    and write the features back into the same file. Per-file failures are
    logged and do not stop the run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_root", required=True)
    parser.add_argument("--batch_size", type=int, default=256, help="Increase this if VRAM allows")
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--overwrite", action="store_true")
    args = parser.parse_args()

    dataset_root = Path(args.dataset_root)
    feat_dir = dataset_root / "genmo_features"
    images_root = dataset_root  # `imgname` entries are resolved against the dataset root
    if not feat_dir.exists():
        Log.error("Feature dir not found")
        return

    Log.info("Initializing ViT Model...")
    ExtractorClass = _require_extractor()
    extractor_wrapper = ExtractorClass(tqdm_leave=False)
    # Inner torch module of the wrapper; fast_inference drives it directly.
    model = extractor_wrapper.extractor

    pt_files = sorted(feat_dir.glob("*.pt"))
    Log.info(f"Processing {len(pt_files)} sequences. Batch Size: {args.batch_size}")
    for pt_file in tqdm(pt_files, desc="Dataset Progress"):
        try:
            _process_sequence(pt_file, model, images_root, args)
        except Exception as e:
            # Best-effort batch job: log and continue with the next sequence.
            Log.error(f"Error {pt_file.stem}: {e}")
if __name__ == "__main__":
    # Ask the CUDA caching allocator for expandable segments. setdefault keeps a
    # value the user already exported instead of silently overwriting it.
    # Presumably read when the allocator first initializes, so it is set before
    # main() touches the GPU.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    main()
|