SignMotionGPT / data.py
rdz-falcon's picture
Deploy SignMotionGPT Demo with LFS
4bd136e
"""
Dataset loading and vocabulary building utilities
"""
import json
import os
import random
from typing import List, Dict, Tuple, Any
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from config import M_START, M_END, PAD_TOKEN
# ======================================================================================
# Logic from test_overfit.py
# ======================================================================================
def read_json_data(json_path: str) -> List[Dict[str, Any]]:
"""Loads the dataset from the specified JSON file."""
if not os.path.exists(json_path):
raise FileNotFoundError(f"Dataset not found at: {json_path}")
with open(json_path, "r", encoding="utf-8") as f:
return json.load(f)
def deduplicate_and_prepare_data(entries: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""
Cleans the entire dataset by ensuring each (word, participant_id) pair is unique.
If a conflict is found (same pair, different motion), it keeps only the first one encountered.
Then, it prepares the full list of motion tokens from the cleaned data.
"""
print("\n---> Cleaning dataset by removing ambiguous (word, participant_id) pairs...")
unique_samples = {}
conflicts_found = 0
for entry in entries:
word = entry.get("word", "").lower()
pid = entry.get("participant_id", "")
key = (word, pid)
if key not in unique_samples:
unique_samples[key] = entry
else:
# A sample for this key already exists. We only care if it's a conflict.
existing_tokens = unique_samples[key].get("motion_tokens")
current_tokens = entry.get("motion_tokens")
if existing_tokens != current_tokens:
conflicts_found += 1
# We do nothing, effectively discarding this new conflicting sample.
cleaned_data = list(unique_samples.values())
print(f"Original samples: {len(entries)}")
print(f"Cleaned samples (unique (word, pid) pairs): {len(cleaned_data)}")
print(f"Removed {len(entries) - len(cleaned_data)} total samples. ({conflicts_found} were direct conflicts).")
print("\n---> Extracting motion tokens from the full cleaned dataset...")
all_motion_tokens = set()
for entry in cleaned_data:
motion_tokens = entry.get("motion_tokens", "").strip().split()
for token in motion_tokens:
all_motion_tokens.add(f"<M{token}>")
unique_tokens = sorted(list(all_motion_tokens))
print(f"Found {len(unique_tokens)} unique motion tokens in the entire dataset.")
return cleaned_data, unique_tokens
class MotionDataset(Dataset):
"""Dataset for Stage 1: Contains only motion token sequences."""
def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
self.tokenizer = tokenizer
self.max_length = max_length
self.sequences = []
for item in data:
tokens_str = item.get("motion_tokens", "")
wrapped_tokens = " ".join([f"<M{t}>" for t in tokens_str.split()])
full_sequence = f"{M_START} {wrapped_tokens} {M_END}"
self.sequences.append(full_sequence)
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
return self.tokenizer(
self.sequences[idx],
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt"
)
class TextMotionDataset(Dataset):
"""Dataset for Stage 2: Contains (prompt, motion_sequence) pairs."""
def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
self.tokenizer = tokenizer
self.max_length = max_length
self.items = []
for item in data:
prompt = f"Instruction: Generate motion for word '{item['word']}' with variant '{item['participant_id']}'.\nMotion: "
tokens_str = item.get("motion_tokens", "")
wrapped_tokens = " ".join([f"<M{t}>" for t in tokens_str.split()])
target_sequence = f"{M_START} {wrapped_tokens} {M_END}"
full_text = prompt + target_sequence
tokenized = self.tokenizer(
full_text,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt"
)
prompt_tokenized = self.tokenizer(prompt, return_tensors="pt")
prompt_len = prompt_tokenized.input_ids.shape[1]
labels = tokenized['input_ids'].clone()
labels[0, :prompt_len] = -100
self.items.append({
"input_ids": tokenized['input_ids'].squeeze(0),
"attention_mask": tokenized['attention_mask'].squeeze(0),
"labels": labels.squeeze(0)
})
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
return self.items[idx]
# ======================================================================================
# Legacy utilities (kept for compatibility if needed, but mostly superseded)
# ======================================================================================
def build_motion_vocab(dataset):
"""
Build motion vocabulary by finding max token ID
Returns: (codebook_size, max_token_id)
"""
def max_token_in_example(ex):
return max(int(x) for x in ex["motion_tokens"].split())
global_max_id = 0
for ex in dataset:
global_max_id = max(global_max_id, max_token_in_example(ex))
codebook_size = global_max_id + 1
return codebook_size, global_max_id
def motion_specials_to_ids(s: str) -> List[int]:
"""Extract motion IDs from special tokens"""
toks = s.strip().split()
ids = []
for t in toks:
if t.startswith("<motion_") or (t.startswith("<M") and t.endswith(">") and t[2:-1].isdigit()):
# Handle both <motion_ID> and <MID> formats
try:
if t.startswith("<motion_"):
ids.append(int(t[8:-1]))
else:
ids.append(int(t[2:-1]))
except:
pass
return ids