"""
Dataset Loader for TouchGrass.
Handles loading and preprocessing of music QA data for fine-tuning.
"""
from typing import List, Dict, Any, Optional
from pathlib import Path
import json
import random
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
class TouchGrassDataset(Dataset):
    """
    Dataset for TouchGrass fine-tuning.

    Loads chat-formatted JSONL data, renders each sample with the Qwen
    chat template, and tokenizes it for causal-LM training.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_seq_length: int = 4096,
        mode: str = "train",
    ):
        """
        Initialize dataset.

        Args:
            data_path: Path to JSONL file with chat data.
            tokenizer: Tokenizer (extended Qwen tokenizer).
            max_seq_length: Maximum sequence length.
            mode: "train" (pad to max length) or "eval" (no padding).
        """
        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.mode = mode

        # Load data eagerly; the whole JSONL is held in memory.
        self.samples = self._load_data()
        print(f"Loaded {len(self.samples)} samples from {data_path}")

    def _load_data(self) -> List[Dict[str, Any]]:
        """Load samples from the JSONL file, skipping blank lines."""
        samples = []
        with open(self.data_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    samples.append(json.loads(line))
        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize one sample into input_ids / attention_mask / labels."""
        sample = self.samples[idx]
        messages = sample["messages"]

        # Render messages with the Qwen chat template:
        #   <|im_start|>role\ncontent<|im_end|>
        formatted_text = self._format_chat_qwen(messages)

        encoding = self.tokenizer(
            formatted_text,
            truncation=True,
            max_length=self.max_seq_length,
            padding="max_length" if self.mode == "train" else False,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Labels mirror input_ids for causal LM, but padding positions are
        # set to -100 (CrossEntropyLoss ignore_index) so pad tokens do not
        # contribute to the loss. Previously the loss was computed over
        # padding as well, which skews training toward the pad token.
        # NOTE: we still train on all non-pad tokens; a more sophisticated
        # scheme would also mask user/system spans in the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _format_chat_qwen(self, messages: List[Dict[str, str]]) -> str:
        """
        Format messages into the Qwen chat format:

            <|im_start|>system
            You are a helpful assistant.<|im_end|>
            <|im_start|>user
            Hello!<|im_end|>
            <|im_start|>assistant
            Hi there!<|im_end|>

        Messages with unknown roles are skipped.
        """
        formatted = []
        for msg in messages:
            role = msg["role"]
            if role not in ("system", "user", "assistant"):
                # Skip unknown roles
                continue
            content = msg["content"].strip()
            formatted.append(f"<|im_start|>{role}\n{content}<|im_end|>")
        return "\n".join(formatted)

    def get_sample(self, idx: int) -> str:
        """Get raw formatted text for inspection (no tokenization)."""
        sample = self.samples[idx]
        messages = sample["messages"]
        return self._format_chat_qwen(messages)
def test_dataset():
    """Smoke-test the dataset loader end to end.

    Tries to load the project's extended music tokenizer; on any failure
    falls back to the base Qwen tokenizer. Requires the training data at
    data/processed/train.jsonl. Prints shapes and a decoded sample.
    """
    # NOTE: the original re-imported AutoTokenizer twice inside this
    # function; the module-level import (top of file) already provides it.
    print("Loading tokenizer...")
    try:
        from tokenizer.music_token_extension import MusicTokenizerExtension

        tokenizer_ext = MusicTokenizerExtension(
            base_tokenizer_name="Qwen/Qwen3.5-3B-Instruct",
        )
        tokenizer = tokenizer_ext.get_tokenizer()
    except Exception as e:
        # Best-effort fallback: extension may be missing in a fresh checkout.
        print(f"Could not load tokenizer: {e}")
        print("Using dummy tokenizer for testing...")
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3.5-3B-Instruct",
            trust_remote_code=True,
        )
        tokenizer.pad_token = tokenizer.eos_token

    # Create dataset
    print("\nCreating dataset...")
    dataset = TouchGrassDataset(
        data_path="data/processed/train.jsonl",
        tokenizer=tokenizer,
        max_seq_length=1024,  # Smaller for testing
        mode="train",
    )
    print(f"Dataset size: {len(dataset)}")

    # Inspect one sample if any data was loaded
    if len(dataset) > 0:
        sample = dataset[0]
        print("\nSample keys:", list(sample.keys()))
        print("Input IDs shape:", sample["input_ids"].shape)
        print("Attention mask shape:", sample["attention_mask"].shape)
        print("Labels shape:", sample["labels"].shape)

        # Decode to check formatting
        decoded = tokenizer.decode(sample["input_ids"][:100])
        print(f"\nFirst 100 tokens:\n{decoded}...")

    print("\nDataset test complete!")
# Allow running this module directly as a smoke test.
if __name__ == "__main__":
    test_dataset()