# hmr-dataset / prepare_unity_data.py
# Uploaded by zirobtc ("Upload folder using huggingface_hub"), commit fbb20ff (verified).
import numpy as np
import json
import os
import re
from difflib import SequenceMatcher
from tqdm import tqdm
# --- CONFIGURATION ---
# Input index, folder holding the referenced .npz motion files, and the
# destination folder for the converted per-sequence JSON files.
INPUT_TRAIN_JSON = "./train.json"
NPZ_FOLDER = "./npz_data"
OUTPUT_DIR = "./unity_ready_json"
# Label start_t/end_t values are multiplied by this to obtain frame indices;
# it is also written to each output file as "fps".
FPS = 30.0

# Style IDs stored in each output frame's "s" field.
(
    STYLE_FILLIAN,  # 0
    STYLE_BIBOO,  # 1
    STYLE_ANNY,  # 2 -- default for unknown characters
    STYLE_LAPWING,  # 3 -- vlog override
) = range(4)
def fuzzy_match_name(text, target, threshold=0.75):
    """
    Return True if any alphabetic token of ``text`` is similar to ``target``.

    ``text`` is lower-cased and split on runs of characters outside ``a-z``;
    tokens shorter than 3 characters are ignored. A token matches when
    ``SequenceMatcher(None, token, target).ratio()`` is at least
    ``threshold``. ``target`` is lower-cased as well, so the comparison is
    case-insensitive on both sides (previously only ``text`` was folded).

    Args:
        text: String to search, e.g. a filename.
        target: Name to look for.
        threshold: Minimum similarity ratio (0..1) that counts as a match.

    Returns:
        bool: True if at least one token reaches ``threshold``.
    """
    needle = target.lower()
    return any(
        SequenceMatcher(None, token, needle).ratio() >= threshold
        for token in re.split(r'[^a-z]+', text.lower())
        if len(token) >= 3  # very short fragments give noisy ratios
    )
def get_base_style(filename):
    """
    Choose the whole-sequence style ID from a filename.

    Checks run in priority order: Biboo first, then Fillian; a name matching
    neither falls back to Anny (which also covers Miltina and other
    characters).

    Args:
        filename: The filename string to inspect.

    Returns:
        STYLE_BIBOO, STYLE_FILLIAN or STYLE_ANNY.
    """
    lowered = filename.lower()

    if fuzzy_match_name(lowered, "biboo", threshold=0.8):
        return STYLE_BIBOO

    # The plain substring test accepts the "filian" spelling wherever it
    # occurs, independent of how the fuzzy matcher tokenizes the name.
    looks_like_fillian = (
        fuzzy_match_name(lowered, "fillian", threshold=0.8)
        or "filian" in lowered
    )
    return STYLE_FILLIAN if looks_like_fillian else STYLE_ANNY
def is_vlog_label(label_entry):
    """
    Return True if a frame label indicates vlogging / handheld-camera footage.

    Exclusions win: a ``proc_label`` containing "place camera back" or
    "end of vlog" is never treated as vlogging. Otherwise the label counts
    as vlogging when ``proc_label`` contains "vlog", or when any entry of
    ``act_cat`` contains "vlog" (case-insensitive).

    A missing or null ``proc_label`` / ``act_cat`` is treated as empty
    instead of raising. Such an error would otherwise abort the whole batch,
    because label processing runs outside the per-file error handling.

    Args:
        label_entry: A label dict from ``frame_ann["labels"]``.

    Returns:
        bool
    """
    proc_label = (label_entry.get("proc_label") or "").lower()

    # 1. EXCLUSION RULES: wrap-up labels end a vlog; they are not part of it.
    if any(phrase in proc_label for phrase in ("place camera back", "end of vlog")):
        return False

    # 2. INCLUSION RULES
    if "vlog" in proc_label:
        return True
    categories = label_entry.get("act_cat") or []
    return any("vlog" in str(cat).lower() for cat in categories)
def is_transition_label(label_entry):
    """
    Return True if the label's ``proc_label`` mentions "transition".

    The check is case-insensitive. A missing or null ``proc_label`` is
    treated as empty instead of raising.
    """
    proc = (label_entry.get("proc_label") or "").lower()
    return "transition" in proc
def _load_motion(npz_path):
    """
    Load the ``poses``, ``trans`` and ``betas`` arrays from an ``.npz`` file.

    If ``poses`` or ``trans`` is 3-D, only element 0 of its leading axis is
    kept (presumably a batch axis of size 1 -- TODO confirm). The archive is
    opened in a ``with`` block so its file handle is released right away.

    Returns:
        tuple: ``(poses, trans, betas)`` as NumPy arrays.

    Raises:
        Whatever ``np.load`` or a missing key raises, plus ``ValueError``
        when ``trans`` has fewer frames than ``poses``.
    """
    with np.load(npz_path) as data:
        poses = data['poses']
        trans = data['trans']
        betas = data['betas']
    if poses.ndim == 3:
        poses = poses[0]
    if trans.ndim == 3:
        trans = trans[0]
    # Frames are indexed by position in ``poses``; a shorter ``trans`` would
    # otherwise raise IndexError later, outside the per-file error handling.
    if trans.shape[0] < poses.shape[0]:
        raise ValueError(
            f"trans has {trans.shape[0]} frames but poses has {poses.shape[0]}"
        )
    return poses, trans, betas


def _compute_frame_styles(entry_data, num_frames, base_style):
    """
    Build the per-frame style-ID array for one sequence.

    Every frame starts at ``base_style``. The labels in
    ``entry_data["frame_ann"]["labels"]`` are then walked in ``start_t``
    order as a small state machine:

    * a vlog label (``is_vlog_label``) paints its span STYLE_LAPWING and
      sets the "previous was vlog" flag;
    * a transition label seen while the flag is set is also painted
      STYLE_LAPWING, and the flag stays set (this collapses the gap);
    * any other label clears the flag and leaves its frames unchanged.

    Label times become frame indices via ``int(t * FPS)``, clamped to
    ``[0, num_frames]``. Labels whose span is empty after clamping are
    skipped without touching the flag. A missing or null ``frame_ann`` or
    ``labels`` leaves every frame at ``base_style``.
    """
    frame_styles = np.full(num_frames, base_style, dtype=int)
    frame_ann = entry_data.get("frame_ann") or {}
    labels = sorted(frame_ann.get("labels") or [],
                    key=lambda lbl: lbl.get("start_t", 0))

    previous_was_vlog = False
    for label in labels:
        s_f = max(0, int(label.get("start_t", 0.0) * FPS))
        e_f = min(num_frames, int(label.get("end_t", 0.0) * FPS))
        if e_f <= s_f:
            continue
        if is_vlog_label(label):
            frame_styles[s_f:e_f] = STYLE_LAPWING
            previous_was_vlog = True
        elif is_transition_label(label) and previous_was_vlog:
            frame_styles[s_f:e_f] = STYLE_LAPWING
        else:
            # Regular action or "place camera back": back to base style.
            previous_was_vlog = False
    return frame_styles


def _output_path(entry_id, npz_filename):
    """
    Return ``OUTPUT_DIR/<entry_id>_<stem>.json`` for one entry.

    ``stem`` is ``npz_filename`` without its extension. Path separators in
    it are replaced by "_" so that a ``feat_p`` with directory components
    still maps to a file directly inside OUTPUT_DIR. Otherwise the name would
    point into subdirectories this script never creates.
    """
    stem = os.path.splitext(npz_filename)[0]
    for sep in (os.sep, os.altsep):
        if sep:
            stem = stem.replace(sep, "_")
    return os.path.join(OUTPUT_DIR, f"{entry_id}_{stem}.json")


def process_single_entry(entry_id, entry_data):
    """
    Convert one train-index entry into a compact Unity-ready JSON file.

    Reads the ``.npz`` named by ``entry_data["feat_p"]`` (relative to
    NPZ_FOLDER), assigns a style ID to every frame and writes an object with
    keys ``fps``, ``video_ref``, ``base_style_debug`` and ``frames`` to
    ``OUTPUT_DIR/<entry_id>_<name>.json``. Each frame record is
    ``{"i": index, "p": poses row, "t": trans row, "b": betas, "s": style}``
    with values rounded to 4 decimals.

    Entries without ``feat_p``, or whose ``.npz`` does not exist, are skipped
    silently. Files that fail to load are reported and skipped, so one bad
    file does not abort the batch.

    Args:
        entry_id: Key of the entry in the train index; prefixes the output name.
        entry_data: The entry dict (``feat_p``, optional ``frame_ann``,
            optional ``video_ref_path``).
    """
    npz_filename = entry_data.get("feat_p")
    if not npz_filename:
        return
    npz_path = os.path.join(NPZ_FOLDER, npz_filename)
    if not os.path.exists(npz_path):
        return

    # 1. Load data (best effort: report and skip unreadable files).
    try:
        poses, trans, betas = _load_motion(npz_path)
        num_frames = poses.shape[0]
    except Exception as e:
        print(f"❌ Error loading {npz_filename}: {e}")
        return

    # 2. Styles: filename-derived base style plus the vlog override.
    base_style = get_base_style(npz_filename)
    frame_styles = _compute_frame_styles(entry_data, num_frames, base_style)

    # 3. Frame records; the same betas list is attached to every frame.
    poses_list = np.round(poses, 4).tolist()
    trans_list = np.round(trans, 4).tolist()
    betas_list = np.round(betas, 4).tolist()
    frames_data = [
        {"i": i, "p": poses_list[i], "t": trans_list[i], "b": betas_list, "s": style}
        for i, style in enumerate(frame_styles.tolist())
    ]

    # 4. Save JSON (compact separators keep the files small).
    style_debug = {STYLE_FILLIAN: "Fillian", STYLE_BIBOO: "Biboo"}.get(base_style, "Anny")
    wrapper = {
        "fps": FPS,
        "video_ref": entry_data.get("video_ref_path", ""),
        "base_style_debug": style_debug,
        "frames": frames_data,
    }
    with open(_output_path(entry_id, npz_filename), 'w', encoding='utf-8') as f:
        json.dump(wrapper, f, separators=(',', ':'))
def main():
    """
    Convert every sequence in INPUT_TRAIN_JSON into a Unity-ready JSON file.

    Creates OUTPUT_DIR if needed, loads the train index (a JSON object whose
    items are ``(entry_id, entry_data)`` pairs) and passes each item to
    ``process_single_entry``. Prints an error and returns early if the index
    file does not exist.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"📂 Loading Train JSON: {INPUT_TRAIN_JSON}")
    if not os.path.exists(INPUT_TRAIN_JSON):
        print("❌ Train JSON not found.")
        return
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(INPUT_TRAIN_JSON, 'r', encoding='utf-8') as f:
        train_index = json.load(f)

    print(f"🚀 Processing {len(train_index)} sequences...")
    for key, val in tqdm(train_index.items()):
        process_single_entry(key, val)
    print("✅ Conversion Complete.")
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()