# yolo-athletic-pose-estimation / example / post_process_keypoints.py
# Uploaded by ray96nex using huggingface_hub (commit 6b5b22f, verified).
import argparse
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
# Keypoint schema.
# The position of a name in ALL_KEYPOINTS is the index of the (x, y, conf)
# triplet it labels in each label line.

# 17 COCO-style body keypoints.
COCO_BODY_17 = [
    "nose",
    "left_eye", "right_eye",
    "left_ear", "right_ear",
    "left_shoulder", "right_shoulder",
    "left_elbow", "right_elbow",
    "left_wrist", "right_wrist",
    "left_hip", "right_hip",
    "left_knee", "right_knee",
    "left_ankle", "right_ankle",
]

# 6 foot keypoints: heel, big toe and little toe for each side.
FEET_6_LABELS = [
    "left_heel", "left_big_toe", "left_little_toe",
    "right_heel", "right_big_toe", "right_little_toe",
]

# Full 23-keypoint schema: body keypoints first, then the feet.
ALL_KEYPOINTS = [*COCO_BODY_17, *FEET_6_LABELS]
def _natural_sort_key(name):
    """Return a sort key that orders embedded integers numerically.

    ``re.split`` with a capturing group alternates text / digit chunks, so
    odd positions are always digit runs; comparing them as ints makes
    ``clip_2.txt`` sort before ``clip_10.txt``.
    """
    chunks = re.split(r"(\d+)", name)
    return [int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks)]


def parse_labels_to_df(labels_dir, img_width, img_height, keypoint_names=None):
    """Parse a directory of per-frame pose label files into a DataFrame.

    Every ``*.txt`` file in *labels_dir* is one frame. Files are ordered by a
    natural sort of their names, so un-padded numbers (``clip_2`` vs
    ``clip_10``) stay in numeric order. The frame index is the file's 0-based
    position in that order, so a frame with no label file does not leave a
    gap in the numbering.

    Each line with at least 5 fields is one detection. The first 5 fields
    are skipped; presumably they are the class id and box in YOLO pose format
    (TODO confirm). The remaining fields are read as ``(x, y, conf)``
    triplets, with x/y normalized to [0, 1] and scaled to pixels here.
    Keypoints with ``conf <= 0`` are omitted. A truncated trailing triplet is
    ignored, as are triplets beyond ``len(keypoint_names)``.

    Args:
        labels_dir: Directory containing the label ``.txt`` files.
        img_width: Image width in pixels; multiplies normalized x.
        img_height: Image height in pixels; multiplies normalized y.
        keypoint_names: Names for the triplets, in order. Defaults to the
            module-level ``ALL_KEYPOINTS`` schema.

    Returns:
        DataFrame with columns ``frame`` (int) and ``keypoints``, a dict
        mapping keypoint name to ``{'x', 'y', 'conf'}``. There is one row per
        detection line, so a frame with several detections yields several
        rows sharing the same ``frame`` value. The DataFrame is empty (no
        columns) when no detections are found.
    """
    if keypoint_names is None:
        keypoint_names = ALL_KEYPOINTS
    max_keypoints = len(keypoint_names)

    # Lexicographic order would place clip_10.txt before clip_2.txt; the
    # raw name is a tie-breaker so the order never depends on os.listdir.
    label_files = sorted(
        (name for name in os.listdir(labels_dir) if name.endswith(".txt")),
        key=lambda name: (_natural_sort_key(name), name),
    )

    records = []
    for frame_idx, label_file in enumerate(label_files):
        label_path = os.path.join(labels_dir, label_file)
        with open(label_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 5:
                    continue
                keypoint_data = parts[5:]
                kp_dict = {}
                # Only complete triplets are read, so a truncated line no
                # longer raises IndexError.
                n_triplets = min(len(keypoint_data) // 3, max_keypoints)
                for kp_idx in range(n_triplets):
                    start = 3 * kp_idx
                    x_norm, y_norm, conf = (
                        float(v) for v in keypoint_data[start:start + 3]
                    )
                    if conf > 0:
                        kp_dict[keypoint_names[kp_idx]] = {
                            "x": x_norm * img_width,
                            "y": y_norm * img_height,
                            "conf": conf,
                        }
                records.append({"frame": frame_idx, "keypoints": kp_dict})
    return pd.DataFrame(records)
def process_keypoints(df, sigma=2.0, keypoint_names=None):
    """Expand, interpolate, and smooth keypoints.

    Flattens the per-row ``keypoints`` dicts into wide columns
    ``<kp>_x``, ``<kp>_y``, ``<kp>_conf``. Then, for each coordinate column
    of the requested keypoints, it:

    1. fills NaN gaps by linear interpolation over row position and pads the
       leading/trailing edges with the nearest valid value;
    2. applies a 1-D Gaussian filter along the rows (when ``sigma > 0``).

    ``_conf`` columns are left untouched, so they stay NaN wherever a
    keypoint was missing in the input row. Rows are treated as an evenly
    spaced sequence in their current order.

    Args:
        df: DataFrame with a ``frame`` column and a ``keypoints`` column of
            dicts mapping name -> {'x', 'y', 'conf'} (the shape built by
            ``parse_labels_to_df``).
        sigma: Standard deviation, in rows, of the Gaussian kernel. A value
            ``<= 0`` disables smoothing; ``gaussian_filter1d`` cannot accept
            ``sigma == 0`` (it divides by ``sigma**2``).
        keypoint_names: Keypoints whose x/y columns are interpolated and
            smoothed. Defaults to the module-level ``ALL_KEYPOINTS``. Columns
            of other keypoints are kept but not processed.

    Returns:
        The wide DataFrame. It is empty (no columns) when *df* has no rows.
    """
    if keypoint_names is None:
        keypoint_names = ALL_KEYPOINTS

    if len(df) == 0:
        # Matches the previous result for empty input, and avoids indexing
        # columns that an empty DataFrame may not have.
        return pd.DataFrame()

    # Expand dictionary into columns.
    rows = []
    for frame, keypoints in zip(df["frame"], df["keypoints"]):
        flat = {"frame": frame}
        for kp, vals in keypoints.items():
            flat[f"{kp}_x"] = vals["x"]
            flat[f"{kp}_y"] = vals["y"]
            flat[f"{kp}_conf"] = vals["conf"]
        rows.append(flat)
    expanded = pd.DataFrame(rows)

    # Temporal interpolation and smoothing. Every column that exists here
    # has at least one value, so ffill/bfill leave no NaN behind.
    for kp in keypoint_names:
        for axis in ("x", "y"):
            col = f"{kp}_{axis}"
            if col not in expanded.columns:
                continue
            filled = expanded[col].interpolate(method="linear").ffill().bfill()
            if sigma > 0:
                expanded[col] = gaussian_filter1d(
                    filled.to_numpy(dtype=float), sigma=sigma
                )
            else:
                expanded[col] = filled
    return expanded
def main(argv=None):
    """Command-line entry point: parse label files, clean them, write a CSV.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Interpolate and smooth keypoints from per-frame pose label files."
    )
    parser.add_argument("--labels", required=True, help="Input labels directory")
    parser.add_argument("--output", required=True, help="Output CSV path")
    parser.add_argument("--width", type=int, default=1920,
                        help="Image width in pixels used to de-normalize x (default: 1920)")
    parser.add_argument("--height", type=int, default=1080,
                        help="Image height in pixels used to de-normalize y (default: 1080)")
    parser.add_argument("--sigma", type=float, default=2.0,
                        help="Gaussian smoothing sigma (default: 2.0)")
    args = parser.parse_args(argv)

    # Fail with a usage message rather than a traceback / nonsense output.
    if not os.path.isdir(args.labels):
        parser.error(f"labels directory not found: {args.labels}")
    if args.width <= 0 or args.height <= 0:
        parser.error("--width and --height must be positive")

    print("Post-processing keypoints...")
    df_raw = parse_labels_to_df(args.labels, args.width, args.height)
    df_clean = process_keypoints(df_raw, sigma=args.sigma)

    output_path = Path(args.output)
    # DataFrame.to_csv does not create missing parent directories.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df_clean.to_csv(output_path, index=False)
    print(f"✅ Cleaned keypoints saved to {args.output}")


if __name__ == "__main__":
    main()