# PotatoPasta's picture
# Upload folder using huggingface_hub
# be603e6 verified
"""
preprocess.py
Load keypoints from AI Hub sign-language data (.npz) and normalize the
sequence length with temporal interpolation. The npz files already contain
extracted keypoints, so MediaPipe is not needed.
Landmark structure (fixed by AI Hub):
pose : 9
left_hand : 21
right_hand : 21
face : 19
Each coordinate has 3 axes (x, y, z).
Feature preset:
A → pose + hands + face, 3 axes → 210 dim
B → pose + hands + face, 2 axes → 140 dim
C → pose + hands, 2 axes → 102 dim
"""
from __future__ import annotations
import logging
import os
from typing import Dict, List
import numpy as np
from tqdm import tqdm
# ────────────────────────────────────────────────────────────
# Feature preset
# ────────────────────────────────────────────────────────────
# Number of landmarks stored per frame for each body part.
LANDMARK_COUNTS = dict(pose=9, left_hand=21, right_hand=21, face=19)

# Preset code -> which body parts are concatenated (in this order) and how
# many coordinate axes are kept per landmark.
FEATURE_PRESETS: Dict[str, Dict] = dict(
    A=dict(use=["pose", "left_hand", "right_hand", "face"], axes=3),
    B=dict(use=["pose", "left_hand", "right_hand", "face"], axes=2),
    C=dict(use=["pose", "left_hand", "right_hand"], axes=2),
)

# Default number of frames that sequences are resampled to.
TARGET_LENGTH = 64

# Module-level logger.
LOGGER = logging.getLogger(__name__)
def feature_dim_for(preset: str) -> int:
    """Return the per-frame feature vector size for a preset code.

    Role: preset code -> feature_dim (dimension of a single frame vector)
    Input: preset, a key of FEATURE_PRESETS
    Output: (landmarks summed over the preset's parts) * (axes kept)
    Raises: ValueError if preset is not a key of FEATURE_PRESETS
    """
    spec = FEATURE_PRESETS.get(preset)
    if spec is None:
        raise ValueError(f"Unknown feature preset: {preset}. "
                         f"Expected one of {sorted(FEATURE_PRESETS)}")

    # Add up the landmark count of every body part the preset includes.
    landmark_total = 0
    for part in spec["use"]:
        landmark_total += LANDMARK_COUNTS[part]
    return landmark_total * spec["axes"]
# ────────────────────────────────────────────────────────────
# npz λ‘œλ” + temporal interpolation
# ────────────────────────────────────────────────────────────
def load_npz_keypoints(npz_path: str,
                       preset: str,
                       target_length: int = TARGET_LENGTH) -> np.ndarray:
    """Load one .npz sample as a fixed-length, flattened keypoint sequence.

    Role: load the keypoints of the parts selected by the preset from a
        single .npz, flatten each frame, and normalize the sequence length
    Input: npz path, preset (key of FEATURE_PRESETS), target_length
    Output: (target_length, feature_dim) float32
    Raises:
        KeyError: preset is not a key of FEATURE_PRESETS.
        ValueError: an array is not 3-D, has the wrong landmark count,
            has fewer coordinate axes than the preset needs, or the parts
            disagree on the number of frames.
    """
    cfg = FEATURE_PRESETS[preset]
    n_axes = cfg["axes"]
    parts: List[np.ndarray] = []
    n_frames = None

    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only load trusted files; if the
    # archives hold plain numeric arrays this flag can likely be dropped
    # (confirm against the dataset before changing it).
    #
    # The context manager closes the archive's file handle even when a
    # validation error is raised below.
    with np.load(npz_path, allow_pickle=True) as data:
        for key in cfg["use"]:
            arr = data[key]
            if arr.ndim != 3 or arr.shape[1] != LANDMARK_COUNTS[key]:
                raise ValueError(f"{npz_path} {key} shape mismatch: {arr.shape}")
            if arr.shape[2] < n_axes:
                raise ValueError(
                    f"{npz_path} {key} has {arr.shape[2]} axes, "
                    f"preset {preset} needs {n_axes}")
            # Always keep exactly the first n_axes coordinates so the frame
            # width matches feature_dim_for(preset) for every preset.
            arr = arr[..., :n_axes]
            if n_frames is None:
                n_frames = arr.shape[0]
            elif arr.shape[0] != n_frames:
                raise ValueError(f"{npz_path} inconsistent T across keys")
            # Explicit width (instead of -1) keeps the reshape valid for a
            # zero-frame array, so the clearer "empty sequence" error from
            # temporal_interpolate is reached.
            parts.append(arr.reshape(n_frames, LANDMARK_COUNTS[key] * n_axes))

    seq = np.concatenate(parts, axis=-1).astype(np.float32)
    return temporal_interpolate(seq, target_length)
def temporal_interpolate(sequence: np.ndarray,
                         target_length: int = TARGET_LENGTH) -> np.ndarray:
    """Resample a variable-length sequence to exactly target_length frames.

    Role: normalize a variable-length sequence to target_length frames
    Input: (T, D)
    Output: (target_length, D) float32, always a new array
    Raises: ValueError if the input is not 2-D, or if it has no frames and
        target_length is not 0

    Note: frames are not blended. Each output frame is the input frame
    nearest to an evenly spaced position between the first and last frame
    (ties round to the even index).
    """
    if sequence.ndim != 2:
        raise ValueError(f"sequence must be 2D, got shape {sequence.shape}")
    n_frames = sequence.shape[0]
    if n_frames == target_length:
        return sequence.astype(np.float32)
    if n_frames == 0:
        raise ValueError("empty sequence")
    if n_frames == 1:
        # A single frame can only be repeated.
        return np.repeat(sequence, target_length, axis=0).astype(np.float32)

    # Evenly spaced fractional positions, snapped to the nearest frame.
    # np.rint rounds half to even, the same rule as the built-in round().
    src_idx = np.linspace(0, n_frames - 1, target_length)
    nearest = np.rint(src_idx).astype(np.intp)
    # Index-array selection returns a fresh (target_length, D) array, also
    # when target_length is 0, so copy=False never aliases the input.
    return sequence[nearest].astype(np.float32, copy=False)
# ────────────────────────────────────────────────────────────
# 데이터셋 일괄 λ‘œλ“œ
# ────────────────────────────────────────────────────────────
def preprocess_dataset(dataset: Dict[str, Dict[str, List[str]]],
                       preset: str,
                       target_length: int = TARGET_LENGTH
                       ) -> Dict[str, Dict[str, List[np.ndarray]]]:
    """Load every sample of a nested dataset mapping as keypoint arrays.

    Role: {word: {signer: [npz_path, ...]}}
        -> {word: {signer: [np.ndarray(target_length, D), ...]}}
    Input: nested path mapping (result of data.load_dataset), feature
        preset, target_length
    Output: same nested structure with each path replaced by the array
        returned by load_npz_keypoints; words and signers with no
        successfully loaded sample are absent
    Raises: ValueError if preset is not a known feature preset
    Note: the npz files themselves act as the cache, so there is no
        separate disk cache. A sample that fails to load is skipped and
        logged as a warning.
    """
    # Fail fast on a bad preset; otherwise the per-sample handler below
    # would swallow the same error once per file and return an empty result.
    feature_dim_for(preset)

    out: Dict[str, Dict[str, List[np.ndarray]]] = {}
    total = sum(len(paths)
                for signer_map in dataset.values()
                for paths in signer_map.values())

    # The context manager closes the progress bar even if the loop is
    # interrupted (e.g. KeyboardInterrupt).
    with tqdm(total=total, desc=f"[preprocess preset={preset}]") as pbar:
        for word, signer_map in dataset.items():
            for signer, paths in signer_map.items():
                for path in paths:
                    try:
                        kp = load_npz_keypoints(path, preset, target_length)
                    except Exception as exc:
                        # Deliberate best-effort: one bad file must not
                        # abort the whole run.
                        LOGGER.warning("skip %s: %s", path, exc)
                    else:
                        # Entries are created only on success, so the
                        # result never holds empty lists or empty maps.
                        out.setdefault(word, {}).setdefault(signer, []).append(kp)
                    finally:
                        pbar.update(1)
    return out
# ────────────────────────────────────────────────────────────
# Smoke test
# ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: print the feature dimension of every preset, then, if a
    # path to an .npz file is given as the first argument, load it with
    # each preset and print the resulting shape and dtype.
    import sys

    presets = ("A", "B", "C")
    for p in presets:
        print(f"preset {p} → feature_dim = {feature_dim_for(p)}")

    if len(sys.argv) > 1:
        path = sys.argv[1]
        if os.path.exists(path):
            for p in presets:
                kp = load_npz_keypoints(path, preset=p)
                print(f" {p}: {kp.shape} dtype={kp.dtype}")
        else:
            # Report a bad path instead of silently skipping the load test.
            print(f"file not found: {path}", file=sys.stderr)