Spaces:
Running
Running
| """ | |
| Evaluation metrics for motion generation | |
| """ | |
| import random | |
| import os | |
| import re | |
| import json | |
| import numpy as np | |
| import scipy.linalg | |
| import torch | |
| from typing import List, Tuple, Dict, Optional, Any | |
| from rapidfuzz.distance import Levenshtein | |
| from collections import defaultdict | |
| from data import motion_specials_to_ids | |
| from config import ( | |
| SEED, PIPELINE_OUTPUT_DIR, M_START, M_END, | |
| INFERENCE_TEMPERATURE, INFERENCE_TOP_K, INFERENCE_REPETITION_PENALTY | |
| ) | |
# Seed Python's `random` module at import time (eval_t2m_set shuffles with it).
# NOTE(review): NumPy's global RNG, which calculate_diversity_np /
# calculate_multimodality_np draw from via np.random.choice, is not seeded
# here, so those metrics vary between runs unless it is seeded elsewhere.
random.seed(SEED)
| # ====================================================================================== | |
| # Logic from test_overfit.py (Metrics & Visualization) | |
| # ====================================================================================== | |
def calculate_activation_statistics_np(activations: np.ndarray):
    """
    Summarize a batch of activations by its first two moments.

    Params:
    -- activations: array of shape (num_samples, dim_feat)
    Returns:
    -- mu: per-feature mean, shape (dim_feat,)
    -- sigma: feature covariance with rows as observations, shape (dim_feat, dim_feat)
    """
    mu, sigma = np.mean(activations, axis=0), np.cov(activations, rowvar=False)
    return mu, sigma
def calculate_frechet_distance_np(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """
    Numpy implementation of the Frechet Distance between two Gaussians.

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Params:
    -- mu1, mu2: mean vectors (must have the same shape)
    -- sigma1, sigma2: covariance matrices (must have the same shape)
    -- eps: diagonal jitter added to both covariances when the first matrix
            square root is not finite
    Returns:
    -- the squared Frechet distance
    Raises:
    -- AssertionError if input shapes disagree
    -- ValueError if the square root keeps a non-negligible imaginary part
    """
    mean_a, mean_b = np.atleast_1d(mu1), np.atleast_1d(mu2)
    cov_a, cov_b = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, "Training and test mean vectors have different lengths"
    assert cov_a.shape == cov_b.shape, "Training and test covariances have different dimensions"

    delta = mean_a - mean_b

    sqrt_product, _ = scipy.linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(sqrt_product).all():
        # Near-singular product: regularize both covariances and retry.
        jitter = np.eye(cov_a.shape[0]) * eps
        sqrt_product = scipy.linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

    # Round-off can leave a tiny imaginary part; discard it only if negligible.
    if np.iscomplexobj(sqrt_product):
        if not np.allclose(np.diagonal(sqrt_product).imag, 0, atol=1e-3):
            largest = np.max(np.abs(sqrt_product.imag))
            raise ValueError(f"Imaginary component {largest}")
        sqrt_product = sqrt_product.real

    trace_sqrt = np.trace(sqrt_product)
    return delta.dot(delta) + np.trace(cov_a) + np.trace(cov_b) - 2 * trace_sqrt
def calculate_diversity_np(activation: np.ndarray, diversity_times: int = 200) -> float:
    """
    Mean L2 distance between randomly paired rows of ``activation`` (N x D).

    Two independent index draws, each without replacement and of size
    ``min(diversity_times, N - 1)``, come from NumPy's global RNG. The
    row-wise Euclidean distances of the paired rows are averaged.
    Returns 0.0 when fewer than two rows are given.
    """
    assert len(activation.shape) == 2
    n_rows = activation.shape[0]
    if n_rows < 2:
        return 0.0
    n_pairs = min(diversity_times, max(1, n_rows - 1))
    left = activation[np.random.choice(n_rows, n_pairs, replace=False)]
    right = activation[np.random.choice(n_rows, n_pairs, replace=False)]
    return float(np.linalg.norm(left - right, axis=1).mean())
def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int = 20) -> float:
    """
    Average within-label pairwise L2 distance (higher = more multimodal).

    activation: [num_labels, num_per_label, D]
    A single random pairing of within-label positions (two independent draws
    without replacement from NumPy's global RNG) is applied to every label.
    Returns NaN when there are fewer than two samples per label.
    """
    assert len(activation.shape) == 3
    _, per_label, _ = activation.shape
    if per_label < 2:
        return float("nan")
    n_pairs = min(multimodality_times, max(1, per_label - 1))
    idx_a = np.random.choice(per_label, n_pairs, replace=False)
    idx_b = np.random.choice(per_label, n_pairs, replace=False)
    gaps = np.linalg.norm(activation[:, idx_a] - activation[:, idx_b], axis=2)
    return float(gaps.mean())
| # -------------------------------------------------------------------------------------- | |
| # Token sequence → activation (bag-of-motion-tokens) helpers | |
| # -------------------------------------------------------------------------------------- | |
| def _extract_motion_tokens_from_sequence(seq: str) -> list[str]: | |
| # Expect tokens like <M123>, within M_START/M_END fences; keep only <M...> | |
| return [tok for tok in seq.split() if tok.startswith("<M") and tok.endswith(">")] | |
| def _extract_ids_from_sequence(seq: str) -> list[int]: | |
| return [int(t[2:-1]) for t in _extract_motion_tokens_from_sequence(seq) if t[2:-1].isdigit()] | |
| def _build_token_index(tokens_vocab: list[str]) -> Dict[str, int]: | |
| return {tok: idx for idx, tok in enumerate(tokens_vocab)} | |
def _sequence_to_activation(seq: str, token_to_index: Dict[str, int]) -> np.ndarray:
    """
    Bag-of-motion-tokens vector for ``seq``.

    Counts every motion token of ``seq`` that is a key of ``token_to_index``
    (other tokens are ignored). The count vector is then scaled to unit L2
    norm to reduce sequence-length bias; an all-zero vector is returned as is.

    Returns:
        float32 array of shape (len(token_to_index),).
    """
    counts = np.zeros(len(token_to_index), dtype=np.float32)
    hits = [token_to_index[t] for t in _extract_motion_tokens_from_sequence(seq) if t in token_to_index]
    for position in hits:
        counts[position] += 1.0
    magnitude = np.linalg.norm(counts)
    return counts / magnitude if magnitude > 0 else counts
def generate_motion(model, tokenizer, prompt, device, max_new_tokens: int = 100):
    """
    Generates a motion sequence from a prompt using sampling.

    Args:
        model: causal LM exposing ``eval()`` and ``generate()``.
        tokenizer: tokenizer matching ``model``; ``M_END`` is used as the EOS token.
        prompt: full prompt text, normally ending with ``"Motion: "``.
        device: device the tokenized inputs are moved to.
        max_new_tokens: generation budget. The default of 100 is the
            previously hard-coded value; longer motions are cut off at it.

    Returns:
        The decoded text after the last ``"Motion: "`` marker, or the whole
        decoded output if the marker is absent, stripped of surrounding
        whitespace. Special tokens are kept (``skip_special_tokens=False``)
        so the motion fences survive.
    """
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=INFERENCE_TEMPERATURE,
            top_k=INFERENCE_TOP_K,
            repetition_penalty=INFERENCE_REPETITION_PENALTY,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
            # Only consulted by beam-based decoding; kept for parity with the
            # original call in case the model's generation config sets num_beams > 1.
            early_stopping=True
        )
    decoded = tokenizer.decode(output[0], skip_special_tokens=False)
    # The decoded text includes the prompt, so keep only what follows the last
    # marker. rpartition returns the whole string when the marker is missing.
    return decoded.rpartition("Motion: ")[2].strip()
def _collect_eval_pairs(model, tokenizer, data, device) -> list[Tuple[str, str, str, str]]:
    """
    Returns list of (word, participant_id, gt_sequence, generated_sequence) for each sample in data.

    Each sample must provide ``"word"`` and ``"participant_id"`` keys. Its
    optional ``"motion_tokens"`` entry holds space-separated integer codes;
    a missing or ``None`` entry means no codes. The codes are wrapped as
    ``<M{code}>`` tokens between the ``M_START``/``M_END`` fences to form the
    ground-truth sequence. The generated sequence comes from
    ``generate_motion``, using the same prompt template as
    ``run_inference_on_all_samples``. The returned ``participant_id`` is
    always a string.
    """
    results = []
    for sample in data:
        codes = (sample.get("motion_tokens") or "").split()
        gt_wrapped = " ".join(f"<M{code}>" for code in codes)
        gt_sequence = f"{M_START} {gt_wrapped} {M_END}"
        prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
        generated_sequence = generate_motion(model, tokenizer, prompt, device)
        pid = str(sample.get("participant_id", ""))
        results.append((sample["word"], pid, gt_sequence, generated_sequence))
    return results
def _activations_from_pairs(pairs: list[Tuple[str, str, str]], vocab_tokens: list[str]):
    """
    Build bag-of-token activation matrices from (GT, GEN) sequence pairs.

    Accepts legacy 3-tuples ``(word, gt, gen)`` and 4-tuples
    ``(word, pid, gt, gen)``; the pid is ignored.

    Returns:
        gt_acts: (N, D) array, one row per pair
        gen_acts: (N, D) array
        labels: list[str] of length N (the word of each pair)
    Raises:
        ValueError (from ``np.stack``) when ``pairs`` is empty.
    """
    index = _build_token_index(vocab_tokens)
    rows_gt, rows_gen, words = [], [], []
    for pair in pairs:
        # 4-tuples carry a participant id in slot 1 that is not needed here.
        word, gt_seq, gen_seq = (pair[0], pair[2], pair[3]) if len(pair) == 4 else pair
        rows_gt.append(_sequence_to_activation(gt_seq, index))
        rows_gen.append(_sequence_to_activation(gen_seq, index))
        words.append(word)
    return np.stack(rows_gt, axis=0), np.stack(rows_gen, axis=0), words
| def _to_label_tensor3(acts: np.ndarray, labels: list[str]) -> np.ndarray: | |
| """ | |
| Convert N x D activations with string labels to [L, K, D] by truncating each label | |
| to the minimum count across labels. | |
| """ | |
| label_to_indices: Dict[str, list[int]] = {} | |
| for i, lbl in enumerate(labels): | |
| label_to_indices.setdefault(lbl, []).append(i) | |
| per_label_counts = [len(idxs) for idxs in label_to_indices.values()] | |
| if len(per_label_counts) == 0: | |
| raise ValueError("No labels found for multimodality computation.") | |
| min_count = max(2, min(per_label_counts)) | |
| label_names = sorted(label_to_indices.keys()) | |
| stacked = [] | |
| for lbl in label_names: | |
| idxs = label_to_indices[lbl][:min_count] | |
| stacked.append(acts[idxs]) | |
| return np.stack(stacked, axis=0) # [L, K, D] | |
def evaluate_metrics_motiongpt_style(model, tokenizer, eval_data, all_motion_tokens, device):
    """
    Generate one motion per sample of ``eval_data`` and score GEN against GT
    on bag-of-motion-token activations over ``all_motion_tokens``.

    Computes:
      - Diversity: GT vs GEN (pair)
      - Multimodality (MIM): GT vs GEN (pair); NaN for both if it fails
      - FID: between GT and GEN activation statistics

    Returns:
        dict with ``diversity_gt``, ``diversity_gen``, ``mim_gt``, ``mim_gen``,
        ``fid`` and the raw ``pairs`` (kept for visualization).
    """
    divider = "=" * 80
    print("\n" + divider)
    print(" METRICS EVALUATION (FID, Diversity, Multimodality)")
    print(divider)

    pairs = _collect_eval_pairs(model, tokenizer, eval_data, device)
    gt_acts, gen_acts, labels = _activations_from_pairs(pairs, all_motion_tokens)
    n_samples = gt_acts.shape[0]

    # Diversity: identical pair budget for GT and GEN
    div_budget = min(200, max(4, n_samples - 1))
    diversity_gt = calculate_diversity_np(gt_acts, diversity_times=div_budget)
    diversity_gen = calculate_diversity_np(gen_acts, diversity_times=div_budget)

    # Multimodality (MIM): any failure (e.g. too few samples per word) -> NaN
    try:
        gt_groups = _to_label_tensor3(gt_acts, labels)
        gen_groups = _to_label_tensor3(gen_acts, labels)
        mm_budget = min(20, max(3, gt_groups.shape[1] - 1))
        mim_gt = calculate_multimodality_np(gt_groups, multimodality_times=mm_budget)
        mim_gen = calculate_multimodality_np(gen_groups, multimodality_times=mm_budget)
    except Exception as exc:
        print(f"⚠️ Multimodality could not be computed reliably: {exc}")
        mim_gt = mim_gen = float("nan")

    # FID between the Gaussian fits of GT and GEN activations
    stats_gt = calculate_activation_statistics_np(gt_acts)
    stats_gen = calculate_activation_statistics_np(gen_acts)
    fid = calculate_frechet_distance_np(*stats_gt, *stats_gen)

    print(f"Diversity: GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
    print(f"Multimodality (MIM): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
    print(f"FID (GT vs GEN): {fid:.4f}")

    return {
        "diversity_gt": diversity_gt,
        "diversity_gen": diversity_gen,
        "mim_gt": mim_gt,
        "mim_gen": mim_gen,
        "fid": fid,
        "pairs": pairs,  # for visualization usage
    }
def _encode_params_to_feature(params: np.ndarray, vq_model, mean, std, device) -> np.ndarray:
    """
    Convert SMPL-X parameter sequence (T, D) into a single clip feature using
    the VQ-VAE encoder output BEFORE quantization. Average-pool over time to get (D_embed,).

    Args:
        params: motion parameters, (T, D) per the comments below. An empty
            array short-circuits to an all-zero feature.
        vq_model: wrapper exposing the VQ-VAE as ``vq_model.vqvae``. Uses its
            ``encoder`` and, when present, ``preprocess`` and ``output_emb_width``.
        mean, std: per-dimension normalization stats. They are used only on
            the manual normalization path; pass None to skip normalization there.
        device: torch device for the encoder input.

    Returns:
        float32 feature vector, L2-normalized unless it is all zeros. For empty
        ``params`` it is all zeros, with length ``output_emb_width`` (512 if
        the attribute is missing).
    """
    if params.size == 0:
        # Nothing to encode: all-zero feature of the embedding width.
        return np.zeros((getattr(vq_model.vqvae, "output_emb_width", 512),), dtype=np.float32)
    x = torch.from_numpy(params.astype(np.float32)).to(device) # [T, D]
    x = x.unsqueeze(0) # [1, T, D]
    with torch.no_grad():
        # Normalize / preprocess
        x_pre = None
        if hasattr(vq_model.vqvae, "preprocess"):
            # NOTE(review): mean/std are NOT applied on this path. If
            # preprocess() only reshapes (e.g. permutes to [N, D, T]) and does not
            # normalize, the encoder sees the de-normalized params produced by
            # decode_tokens_to_params in the caller. Confirm against the VQ-VAE code.
            try:
                x_pre = vq_model.vqvae.preprocess(x) # expected to return tensor ready for encoder
            except Exception:
                # Any preprocess failure falls through to the manual path below.
                x_pre = None
        if x_pre is None:
            # Manual normalization with provided mean/std
            if mean is not None and std is not None:
                mean_t = torch.from_numpy(np.array(mean, dtype=np.float32)).to(device).view(1, 1, -1)
                std_t = torch.from_numpy(np.array(std, dtype=np.float32)).to(device).view(1, 1, -1)
                # 1e-8 avoids division by zero for constant dimensions.
                x_norm = (x - mean_t) / (std_t + 1e-8)
            else:
                x_norm = x
            # Some encoders expect [N, D, T]
            x_pre = x_norm.transpose(1, 2).contiguous() # [1, D, T]
        # Encode to get pre-quant latent
        z_e = vq_model.vqvae.encoder(x_pre)
        # z_e could be [N, D_embed, T_q] or [N, T_q, D_embed]
        if z_e.dim() == 3:
            # Find the time axis. Prefer matching the known embedding width;
            # otherwise assume the shorter non-batch axis is time (ties -> axis 1).
            embed_dim_known = getattr(vq_model.vqvae, "output_emb_width", None)
            if embed_dim_known is not None:
                if z_e.shape[1] == embed_dim_known:
                    time_axis = 2 # [N, D_embed, T_q]
                elif z_e.shape[2] == embed_dim_known:
                    time_axis = 1 # [N, T_q, D_embed]
                else:
                    time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
            else:
                time_axis = 2 if z_e.shape[2] < z_e.shape[1] else 1
            # Average-pool over time, then drop the size-1 batch dim.
            feat = z_e.mean(dim=time_axis).squeeze(0)
        elif z_e.dim() == 2:
            # presumably an already pooled [1, D_embed] latent — TODO confirm
            feat = z_e.squeeze(0)
        else:
            # Any other rank: flatten to one vector (the mean over the size-1 axis is a no-op).
            feat = z_e.view(1, -1).mean(dim=0)
    feat_np = feat.detach().cpu().numpy().astype(np.float32)
    # L2 normalize (an all-zero feature is returned unchanged)
    norm = np.linalg.norm(feat_np)
    if norm > 0:
        feat_np = feat_np / norm
    return feat_np
def evaluate_metrics_encoder_style(
    model,
    tokenizer,
    eval_data,
    device,
    vqvae_ckpt: Optional[str] = None,
    stats_path: Optional[str] = None,
    sample_limit: Optional[int] = 100,
):
    """
    Computes FID, Diversity, and MIM using VQ-VAE encoder pre-quantization features.

    Each evaluated sample's GT and generated token sequences are decoded to
    motion parameters with ``decode_tokens_to_params``. The parameters are
    re-encoded by ``_encode_params_to_feature`` (pre-quantization, time-pooled,
    L2-normalized), and the resulting features are compared.

    Args:
        model, tokenizer, device: generation model, its tokenizer, target device.
        eval_data: sliceable collection of sample dicts (see ``_collect_eval_pairs``).
        vqvae_ckpt / stats_path: explicit VQ-VAE checkpoint / normalization
            stats. Otherwise the ``VQVAE_CHECKPOINT`` / ``VQVAE_STATS_PATH``
            environment variables are used, then the ``visualize`` defaults.
        sample_limit: evaluate only the first ``sample_limit`` samples
            (``None`` = all).

    Returns:
        dict with ``diversity_gt``, ``diversity_gen``, ``mim_gt``, ``mim_gen``,
        ``fid`` and ``pairs``. Returns ``{}`` if the VQ-VAE could not be set up
        or there are no samples.
    """
    print("\n" + "="*80)
    print(" METRICS EVALUATION (VQ-VAE Encoder Features)")
    print("="*80)
    # Lazy import to reuse the visualization utilities and stats
    try:
        from visualize import (
            load_vqvae, load_stats, decode_tokens_to_params,
            VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS,
        )
        vq_ckpt = vqvae_ckpt or os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
        stats_p = stats_path or os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
        vq_model = load_vqvae(vq_ckpt, device=device)
        mean, std = load_stats(stats_p)
    except Exception as exc:
        print(f"⚠️ Could not set up VQ-VAE encoder metrics: {exc}")
        return {}

    # Collect GT/GEN token sequences (limited to speed things up)
    pairs = _collect_eval_pairs(model, tokenizer, eval_data[:sample_limit], device)
    if not pairs:
        # np.stack below would fail on an empty feature list.
        print("⚠️ No samples available for VQ-VAE encoder metrics.")
        return {}

    # Decoding failures still map to empty params (-> all-zero features), but
    # they are counted and reported instead of being swallowed silently.
    failures = {"GT": 0, "GEN": 0}

    def _decode(token_ids, side):
        """Decode token ids to (T, D) params; empty params on failure."""
        try:
            return decode_tokens_to_params(token_ids, vq_model, mean, std, device=device)  # (T, D) denorm
        except Exception:
            failures[side] += 1
            # Placeholder: only `.size == 0` matters to _encode_params_to_feature.
            return np.zeros((0, 182), dtype=np.float32)

    gt_feats, gen_feats, labels = [], [], []
    for pair in pairs:
        if len(pair) == 4:
            word, _pid, gt_seq, gen_seq = pair
        else:
            word, gt_seq, gen_seq = pair
        tokens_gt = _extract_ids_from_sequence(gt_seq)
        tokens_gen = _extract_ids_from_sequence(gen_seq)
        params_gt = _decode(tokens_gt, "GT")
        params_gen = _decode(tokens_gen, "GEN")
        # Encode (pre-quant) -> pooled feature
        gt_feats.append(_encode_params_to_feature(params_gt, vq_model, mean, std, device))
        gen_feats.append(_encode_params_to_feature(params_gen, vq_model, mean, std, device))
        labels.append(word)

    if failures["GT"] or failures["GEN"]:
        print(f"⚠️ Token decoding failed for {failures['GT']} GT / {failures['GEN']} GEN sequences; zero features were used for them.")

    gt_feats = np.stack(gt_feats, axis=0)
    gen_feats = np.stack(gen_feats, axis=0)

    # Diversity
    diversity_times = min(200, max(4, gt_feats.shape[0] - 1))
    diversity_gt = calculate_diversity_np(gt_feats, diversity_times=diversity_times)
    diversity_gen = calculate_diversity_np(gen_feats, diversity_times=diversity_times)

    # Multimodality (MIM)
    try:
        gt_lbl_tensor = _to_label_tensor3(gt_feats, labels)
        gen_lbl_tensor = _to_label_tensor3(gen_feats, labels)
        multimodality_times = min(20, max(3, gt_lbl_tensor.shape[1] - 1))
        mim_gt = calculate_multimodality_np(gt_lbl_tensor, multimodality_times=multimodality_times)
        mim_gen = calculate_multimodality_np(gen_lbl_tensor, multimodality_times=multimodality_times)
    except Exception as exc:
        print(f"⚠️ Multimodality could not be computed reliably: {exc}")
        mim_gt = float("nan")
        mim_gen = float("nan")

    # FID (on encoder features)
    mu_gen, cov_gen = calculate_activation_statistics_np(gen_feats)
    mu_gt, cov_gt = calculate_activation_statistics_np(gt_feats)
    fid = calculate_frechet_distance_np(mu_gt, cov_gt, mu_gen, cov_gen)

    print(f"Diversity (encoder feats): GT = {diversity_gt:.4f} | GEN = {diversity_gen:.4f}")
    print(f"Multimodality (MIM, encoder): GT = {mim_gt:.4f} | GEN = {mim_gen:.4f}")
    print(f"FID (encoder feats, GT vs GEN): {fid:.4f}")

    return {
        "diversity_gt": diversity_gt,
        "diversity_gen": diversity_gen,
        "mim_gt": mim_gt,
        "mim_gen": mim_gen,
        "fid": fid,
        "pairs": pairs,
    }
def save_side_by_side_visualizations(pairs: list[Tuple[str, str, str]], output_dir: str, limit: int = 4):
    """
    Generate side-by-side 3D animations for GT vs GEN, saved as HTML files.

    Args:
        pairs: 3-tuples ``(word, gt_seq, gen_seq)`` or 4-tuples
            ``(word, participant_id, gt_seq, gen_seq)``.
        output_dir: directory (created if missing) that receives one
            ``word_<word>_<pid>_side_by_side.html`` file per rendered pair.
            Repeated (word, pid) combinations get a ``_<n>`` suffix instead of
            overwriting the earlier file.
        limit: render only pairs whose word is among the first ``limit``
            distinct words (every pair of those words is rendered). A
            non-positive or non-int value renders all pairs.

    Best-effort: missing optional dependencies skip visualization entirely,
    and per-pair failures are printed and skipped.
    """
    if not pairs:
        # Avoid loading the VQ-VAE / SMPL-X models when there is nothing to render.
        print("⚠️ Visualization skipped: no pairs to render.")
        return
    try:
        from visualize import (
            load_vqvae, load_stats, load_smplx_model,
            decode_tokens_to_params, params_to_vertices,
            VQVAE_CHECKPOINT as DEFAULT_VQ, STATS_PATH as DEFAULT_STATS, SMPLX_MODEL_DIR as DEFAULT_SMPLX
        )
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
    except Exception as exc:
        print(f"⚠️ Visualization skipped (missing dependencies): {exc}")
        return
    os.makedirs(output_dir, exist_ok=True)
    vqvae_ckpt = os.getenv("VQVAE_CHECKPOINT", DEFAULT_VQ)
    stats_path = os.getenv("VQVAE_STATS_PATH", DEFAULT_STATS)
    smplx_dir = os.getenv("SMPLX_MODEL_DIR", DEFAULT_SMPLX)
    print("Loading VQ-VAE, stats, SMPL-X ...")
    vq_model = load_vqvae(vqvae_ckpt)
    mean, std = load_stats(stats_path)
    smplx_model = load_smplx_model(smplx_dir)

    def animate_side_by_side(verts_left, faces, verts_right, fps=20, titles=("Ground Truth", "LLM Generated"), output_html=None):
        """Build (and optionally save) a two-scene mesh animation truncated to the shorter motion."""
        T = min(verts_left.shape[0], verts_right.shape[0])
        if T == 0:
            raise ValueError("cannot animate a motion with zero frames")
        verts_left, verts_right = verts_left[:T], verts_right[:T]
        i, j, k = faces.T.tolist()
        mesh_style = dict(i=i, j=j, k=k, opacity=0.7, showscale=False)
        fig = make_subplots(
            rows=1, cols=2,
            specs=[[{'type': 'scene'}, {'type': 'scene'}]],
            horizontal_spacing=0.05,
            subplot_titles=list(titles)
        )
        left_mesh = go.Mesh3d(x=verts_left[0, :, 0], y=verts_left[0, :, 1], z=verts_left[0, :, 2], **mesh_style)
        right_mesh = go.Mesh3d(x=verts_right[0, :, 0], y=verts_right[0, :, 1], z=verts_right[0, :, 2], **mesh_style)
        fig.add_trace(left_mesh, row=1, col=1)
        fig.add_trace(right_mesh, row=1, col=2)
        fig.frames = [
            go.Frame(
                name=str(t),
                data=[
                    go.Mesh3d(x=verts_left[t, :, 0], y=verts_left[t, :, 1], z=verts_left[t, :, 2], scene="scene", **mesh_style),
                    go.Mesh3d(x=verts_right[t, :, 0], y=verts_right[t, :, 1], z=verts_right[t, :, 2], scene="scene2", **mesh_style),
                ]
            )
            for t in range(T)
        ]
        hidden_scene = dict(
            aspectmode='data',
            xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False),
            camera=dict(eye=dict(x=0, y=-2, z=0.7)),
        )
        fig.update_layout(
            showlegend=False,
            margin=dict(l=10, r=10, t=50, b=10),
            scene=dict(hidden_scene),
            scene2=dict(hidden_scene),
            updatemenus=[dict(
                type="buttons", x=0.5, xanchor="center", y=1.15, yanchor="top",
                buttons=[
                    dict(label="Play", method="animate", args=[None, {"frame": {"duration": max(1, 1000 // fps), "redraw": True}, "fromcurrent": True}]),
                    dict(label="Pause", method="animate", args=[[None], {"frame": {"duration": 0, "redraw": False}}])
                ]
            )]
        )
        if output_html:
            fig.write_html(output_html)
            print(f"✅ Saved: {output_html}")
        return fig

    # Determine which words to include (up to `limit` distinct words, first-seen order)
    allowed_words = None
    if isinstance(limit, int) and limit > 0:
        ordered_unique_words = []
        for pair in pairs:
            word = pair[0]
            if word not in ordered_unique_words:
                ordered_unique_words.append(word)
                if len(ordered_unique_words) >= limit:
                    break
        allowed_words = set(ordered_unique_words)

    # How many files each base name has produced so far; used to avoid overwrites.
    name_counts = {}
    for pair in pairs:
        try:
            if len(pair) == 4:
                word, pid, gt_seq, gen_seq = pair
            else:
                word, gt_seq, gen_seq = pair
                pid = "unknown"
            if allowed_words is not None and word not in allowed_words:
                continue
            tokens_gt = _extract_ids_from_sequence(gt_seq)
            tokens_gen = _extract_ids_from_sequence(gen_seq)
            params_gt = decode_tokens_to_params(tokens_gt, vq_model, mean, std)
            params_gen = decode_tokens_to_params(tokens_gen, vq_model, mean, std)
            verts_gt, faces = params_to_vertices(params_gt, smplx_model)
            verts_gen, _ = params_to_vertices(params_gen, smplx_model)
            # Sanitize for filesystem safety
            safe_word = re.sub(r'[^A-Za-z0-9_-]+', '_', str(word))
            safe_pid = re.sub(r'[^A-Za-z0-9_-]+', '_', str(pid))
            stem = f"word_{safe_word}_{safe_pid}_side_by_side"
            seen_before = name_counts.get(stem, 0)
            name_counts[stem] = seen_before + 1
            if seen_before:
                # Same (word, pid), or names that collide after sanitizing: keep both files.
                stem = f"{stem}_{seen_before + 1}"
            output_html = os.path.join(output_dir, f"{stem}.html")
            animate_side_by_side(
                verts_left=verts_gt,
                faces=faces,
                verts_right=verts_gen,
                fps=20,
                titles=("Ground Truth", "LLM Generated"),
                output_html=output_html
            )
        except Exception as exc:
            print(f"⚠️ Error creating visualization for word '{pair[0]}': {exc}")
def run_inference_on_all_samples(model, tokenizer, data, device):
    """
    Generate a motion for every sample and compare it with that sample's own
    ground truth.

    Samples are grouped by ``word`` in first-seen order. For each sample the
    GT sequence, the generated sequence and a token-by-token diff are printed.
    Each word ends with an exact-match accuracy summary. Prints only; returns None.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(" INFERENCE AND EVALUATION (ALL SAMPLES)")
    print(" Goal: Test the model's performance on every variant.")
    print(rule)

    def compare_sequences(gt: str, gen: str):
        """Print GT/GEN tokens in aligned columns with a per-position match mark."""
        left, right = gt.split(), gen.split()
        print("\nDetailed Comparison (✅ = Match, ❌ = Mismatch/Missing/Added):")
        gt_cells, gen_cells, mark_cells = [], [], []
        for pos in range(max(len(left), len(right))):
            a = left[pos] if pos < len(left) else "___"
            b = right[pos] if pos < len(right) else "___"
            width = max(len(a), len(b))
            gt_cells.append(a.ljust(width) + " ")
            gen_cells.append(b.ljust(width) + " ")
            mark_cells.append(("✅" if a == b else "❌").ljust(width) + " ")
        print(" GT: " + "".join(gt_cells))
        print(" GEN: " + "".join(gen_cells))
        print(" " + "".join(mark_cells))

    samples_by_word = {}
    for item in data:
        samples_by_word.setdefault(item['word'], []).append(item)

    for word, samples in samples_by_word.items():
        print(f"\n\n{'='*25} TESTING WORD: '{word}' {'='*25}")
        total = len(samples)
        matches = 0
        for variant_no, sample in enumerate(samples, start=1):
            print(f"\n--- Testing Variant {variant_no}/{total}: '{sample['participant_id']}' ---")
            codes = sample.get("motion_tokens", "").split()
            wrapped = " ".join("<M" + code + ">" for code in codes)
            gt_sequence = f"{M_START} {wrapped} {M_END}"
            print(f"Ground Truth:\n{gt_sequence}")
            prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
            generated_sequence = generate_motion(model, tokenizer, prompt, device)
            print(f"\nLLM Generated:\n{generated_sequence}")
            compare_sequences(gt_sequence, generated_sequence)
            matches += int(gt_sequence.strip() == generated_sequence.strip())
            print("-" * 80)
        accuracy = (matches / total) * 100
        print(f"\nSUMMARY FOR '{word}': {matches}/{total} correct ({accuracy:.1f}%)")
| # ====================================================================================== | |
| # Existing Utilities (Compatibility) | |
| # ====================================================================================== | |
def seq_edit_distance(a_ids: List[int], b_ids: List[int]) -> int:
    """
    Token-level Levenshtein distance between two id sequences.

    Uses rapidfuzz's ``Levenshtein.distance`` with its default weights (unit
    cost per insertion, deletion and substitution). Elements are compared by
    equality, so whole ids are compared, not their digits.
    """
    distance = Levenshtein.distance(a_ids, b_ids)
    return distance
def best_ref_distance(pred_ids: List[int], refs: List[List[int]]) -> int:
    """
    Smallest edit distance from ``pred_ids`` to any sequence in ``refs``.

    With no references, the prediction's length is returned instead.
    """
    if refs:
        return min(seq_edit_distance(pred_ids, candidate) for candidate in refs)
    return len(pred_ids)
def build_text_to_refs(dataset):
    """
    Group reference motion sequences by their prompt text.

    Every example must have a ``"text_query"`` string and a space-separated
    ``"motion_tokens"`` string of integers.

    Returns:
        ``defaultdict(list)`` mapping text -> list of int-id lists, in dataset order.
    """
    refs_by_text = defaultdict(list)
    for example in dataset:
        ids = list(map(int, example["motion_tokens"].split()))
        refs_by_text[example["text_query"]].append(ids)
    return refs_by_text
| def _concat(ids_list: List[List[int]]) -> List[int]: | |
| out = [] | |
| for s in ids_list: | |
| out.extend(s) | |
| return out | |
| def _distinct_n(ids_list: List[List[int]], n: int) -> float: | |
| if n <= 0: | |
| return 0.0 | |
| total = 0 | |
| uniq = set() | |
| for seq in ids_list: | |
| if len(seq) < n: | |
| continue | |
| total += (len(seq) - n + 1) | |
| for i in range(len(seq) - n + 1): | |
| uniq.add(tuple(seq[i:i+n])) | |
| if total == 0: | |
| return 0.0 | |
| return len(uniq) / float(total) | |
def token_fid_diag(gens: List[List[int]], refs: List[List[int]], codebook_size: int) -> float:
    """
    Diagonal-covariance Fréchet distance between histograms of token usage.
    This is a lightweight proxy for FID using token distributions.

    Each sequence becomes a normalized histogram over ``codebook_size`` bins.
    Ids outside [0, codebook_size) are ignored, and an empty histogram stays
    all zeros. With per-bin means mu and variances var of the two sets:

        d = sum((mu_g - mu_r)^2) + sum(var_g + var_r - 2 * sqrt(var_g * var_r))

    Returns NaN if either set is empty.
    """
    if min(len(gens), len(refs)) == 0:
        return float("nan")

    def _histograms(batch: List[List[int]]) -> np.ndarray:
        rows = []
        for seq in batch:
            in_range = [tok for tok in seq if 0 <= tok < codebook_size]
            counts = np.bincount(in_range, minlength=codebook_size).astype(np.float64)
            total = counts.sum()
            rows.append(counts / total if total > 0 else counts)
        return np.stack(rows, axis=0)

    gen_hist = _histograms(gens)
    ref_hist = _histograms(refs)
    mean_gap = np.sum((gen_hist.mean(axis=0) - ref_hist.mean(axis=0)) ** 2)
    v_g = gen_hist.var(axis=0)
    v_r = ref_hist.var(axis=0)
    # Diagonal covariance approximation of the trace term
    spread = np.sum(v_g + v_r - 2.0 * np.sqrt(np.clip(v_g * v_r, 0.0, None)))
    return float(mean_gap + spread)
def compute_token_metrics(
    gen_by_text: Dict[str, List[int]],
    text_to_refs: Dict[str, List[List[int]]],
    codebook_size: int,
) -> Dict[str, float]:
    """
    Compute token-level metrics for one generated id sequence per prompt.

    - FID_diag: diagonal-covariance Fréchet distance between the token
      histograms of all generations and of all reference sequences
    - MIM: mean, over generated prompts, of the minimum edit distance to that
      prompt's references (a prompt without references scores its own length)
    - distinct_1 / distinct_2: distinct-n diversity over the generations

    Returns:
        dict with keys ``FID_diag``, ``MIM``, ``distinct_1``, ``distinct_2``.
        ``MIM`` is NaN when ``gen_by_text`` is empty.
    """
    generations = list(gen_by_text.values())
    # One flat list of reference *sequences* (not ids) for the histogram FID.
    # The previous unused `_concat(...)` result was dead code and has been removed.
    reference_seqs = [ref for refs in text_to_refs.values() for ref in refs]
    fid_diag = token_fid_diag(generations, reference_seqs, codebook_size)

    # MIM: average best edit distance per prompt (only over prompts we generated)
    mim_dists = [best_ref_distance(ids, text_to_refs.get(text, [])) for text, ids in gen_by_text.items()]
    mim = float(sum(mim_dists) / len(mim_dists)) if mim_dists else float("nan")

    return {
        "FID_diag": fid_diag,
        "MIM": mim,
        "distinct_1": _distinct_n(generations, 1),
        "distinct_2": _distinct_n(generations, 2),
    }
def eval_t2m_set(
    model,
    tokenizer,
    sample_pairs: List[Tuple[str, List[List[int]]]],
    mot_begin_id: int,
    mot_end_id: int,
    motion_token_ids: list,
    length_stats_by_text: dict,
    global_median_len: int,
    prompt_vocab: Optional[dict] = None,
    has_pid: bool = False,
    per_prompt_vocab: bool = True,
    n_eval: int = 100
):
    """
    Evaluate text-to-motion generation on a random subset of samples.

    A random subset of at most ``n_eval`` ``(text, reference_id_lists)`` pairs
    is drawn with Python's ``random`` module, and one motion is generated per
    text. The ids between ``<MOT_BEGIN>`` and ``<MOT_END>`` are scored by their
    minimum edit distance to that text's references. ``sample_pairs`` itself
    is not modified.

    Returns:
        ``{"avg_edit_dist", "median_len", "n_samples"}``, where median_len is
        the upper median of predicted lengths. Returns ``{}`` when nothing was
        evaluated. Kept for pipeline compatibility.
    """
    # Shuffle a copy: the previous in-place random.shuffle reordered the
    # caller's list as a hidden side effect. The RNG draws depend only on the
    # length, so the chosen subset is the same as before for a given seed.
    shuffled = list(sample_pairs)
    random.shuffle(shuffled)
    subset = shuffled[:min(n_eval, len(shuffled))]
    dists = []
    lens = []
    for text, ref_list in subset:
        # NOTE(review): `generate_t2m` is neither defined nor imported in this
        # module as shown, so this call raises NameError unless the name is
        # provided elsewhere. Confirm which module it should be imported from.
        gen = generate_t2m(
            model=model,
            tokenizer=tokenizer,
            prompt_text=text,
            mot_begin_id=mot_begin_id,
            mot_end_id=mot_end_id,
            motion_token_ids=motion_token_ids,
            length_stats_by_text=length_stats_by_text,
            global_median_len=global_median_len,
            prompt_vocab=prompt_vocab,
            pid=None,
            has_pid=has_pid,
            per_prompt_vocab=per_prompt_vocab
        )
        # Keep only the span after the last begin fence and before the first end fence.
        span = gen.split("<MOT_BEGIN>")[-1]
        span = span.split("<MOT_END>")[0]
        pred_ids = motion_specials_to_ids(span)
        dists.append(best_ref_distance(pred_ids, ref_list))
        lens.append(len(pred_ids))
    if not dists:
        print("Eval T2M: no samples")
        return {}
    avg_dist = sum(dists) / len(dists)
    median_len = sorted(lens)[len(lens) // 2]
    print(f"Eval T2M: avg_edit_dist={avg_dist:.2f}, median_len={median_len}, n={len(dists)}")
    return {"avg_edit_dist": avg_dist, "median_len": median_len, "n_samples": len(dists)}