"""
Dataset loading and vocabulary building utilities
"""
import json
import os
import random
from typing import List, Dict, Tuple, Any
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from config import M_START, M_END, PAD_TOKEN
# ======================================================================================
# Logic from test_overfit.py
# ======================================================================================
def read_json_data(json_path: str) -> List[Dict[str, Any]]:
    """Read and parse the dataset JSON file, failing loudly if it is missing."""
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"Dataset not found at: {json_path}")
    with open(json_path, "r", encoding="utf-8") as handle:
        return json.load(handle)
def deduplicate_and_prepare_data(entries: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""
Cleans the entire dataset by ensuring each (word, participant_id) pair is unique.
If a conflict is found (same pair, different motion), it keeps only the first one encountered.
Then, it prepares the full list of motion tokens from the cleaned data.
"""
print("\n---> Cleaning dataset by removing ambiguous (word, participant_id) pairs...")
unique_samples = {}
conflicts_found = 0
for entry in entries:
word = entry.get("word", "").lower()
pid = entry.get("participant_id", "")
key = (word, pid)
if key not in unique_samples:
unique_samples[key] = entry
else:
# A sample for this key already exists. We only care if it's a conflict.
existing_tokens = unique_samples[key].get("motion_tokens")
current_tokens = entry.get("motion_tokens")
if existing_tokens != current_tokens:
conflicts_found += 1
# We do nothing, effectively discarding this new conflicting sample.
cleaned_data = list(unique_samples.values())
print(f"Original samples: {len(entries)}")
print(f"Cleaned samples (unique (word, pid) pairs): {len(cleaned_data)}")
print(f"Removed {len(entries) - len(cleaned_data)} total samples. ({conflicts_found} were direct conflicts).")
print("\n---> Extracting motion tokens from the full cleaned dataset...")
all_motion_tokens = set()
for entry in cleaned_data:
motion_tokens = entry.get("motion_tokens", "").strip().split()
for token in motion_tokens:
all_motion_tokens.add(f"<M{token}>")
unique_tokens = sorted(list(all_motion_tokens))
print(f"Found {len(unique_tokens)} unique motion tokens in the entire dataset.")
return cleaned_data, unique_tokens
class MotionDataset(Dataset):
    """Stage-1 dataset: bare motion-token sequences wrapped in start/end markers."""

    def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Pre-render each sample as "{M_START} <M1> <M2> ... {M_END}" text;
        # tokenization happens lazily in __getitem__.
        self.sequences = [
            f"{M_START} "
            + " ".join(f"<M{tok}>" for tok in item.get("motion_tokens", "").split())
            + f" {M_END}"
            for item in data
        ]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # Returns the tokenizer's BatchEncoding (padded/truncated, as tensors).
        return self.tokenizer(
            self.sequences[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
class TextMotionDataset(Dataset):
    """Stage-2 dataset: (instruction prompt, motion sequence) pairs, pre-tokenized."""

    def __init__(self, data: List[Dict[str, Any]], tokenizer: AutoTokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Eagerly tokenize everything so __getitem__ is a plain list lookup.
        self.items = [self._encode(sample) for sample in data]

    def _encode(self, sample):
        """Tokenize one sample and mask the prompt portion out of the labels."""
        prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
        body = " ".join(f"<M{tok}>" for tok in sample.get("motion_tokens", "").split())
        encoded = self.tokenizer(
            prompt + f"{M_START} {body} {M_END}",
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # The prompt tokens get label -100 so the loss covers only the motion part.
        prompt_len = self.tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
        labels = encoded['input_ids'].clone()
        labels[0, :prompt_len] = -100
        return {
            "input_ids": encoded['input_ids'].squeeze(0),
            "attention_mask": encoded['attention_mask'].squeeze(0),
            "labels": labels.squeeze(0),
        }

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
# ======================================================================================
# Legacy utilities (kept for compatibility if needed, but mostly superseded)
# ======================================================================================
def build_motion_vocab(dataset):
    """
    Scan the dataset for the largest motion-token ID.

    Args:
        dataset: Iterable of dicts, each carrying a whitespace-separated
            "motion_tokens" string of integer token IDs.

    Returns:
        Tuple (codebook_size, max_token_id), with codebook_size equal to
        max_token_id + 1. An empty dataset — or one whose entries all have
        empty "motion_tokens" strings — yields (1, 0).
    """
    # default=0 keeps empty "motion_tokens" strings from raising ValueError
    # (max() of an empty sequence), matching the original floor of 0.
    global_max_id = max(
        (int(tok) for ex in dataset for tok in ex["motion_tokens"].split()),
        default=0,
    )
    return global_max_id + 1, global_max_id
def motion_specials_to_ids(s: str) -> List[int]:
    """
    Extract integer motion IDs from a whitespace-separated token string.

    Accepts both "<motion_ID>" and "<MID>" formats. Tokens that do not match
    either format, or whose payload is not a non-negative integer, are
    silently skipped. (The previous version's bare ``except`` is replaced by
    explicit ``isdigit`` validation, and the "<motion_" branch now also
    requires the closing ">" for consistency with the "<M" branch.)
    """
    ids: List[int] = []
    for tok in s.strip().split():
        if tok.startswith("<motion_") and tok.endswith(">"):
            payload = tok[len("<motion_"):-1]
        elif tok.startswith("<M") and tok.endswith(">"):
            payload = tok[2:-1]
        else:
            continue
        if payload.isdigit():
            ids.append(int(payload))
    return ids
|