# TouchGrass-3b / data / dataset_loader.py
"""
Dataset Loader for TouchGrass.
Handles loading and preprocessing of music QA data for fine-tuning.
"""
from typing import List, Dict, Any, Optional
from pathlib import Path
import json
import random
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
class TouchGrassDataset(Dataset):
    """
    Dataset for TouchGrass fine-tuning.

    Loads chat-formatted JSONL data, renders each sample with the Qwen
    chat template, and tokenizes it for causal-LM training.
    """
    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_seq_length: int = 4096,
        mode: str = "train",
    ):
        """
        Initialize dataset.

        Args:
            data_path: Path to JSONL file with chat data (one JSON object
                per line, each containing a "messages" list)
            tokenizer: Tokenizer (extended Qwen tokenizer)
            max_seq_length: Maximum sequence length
            mode: "train" (pads to max_seq_length) or "eval" (no padding)
        """
        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.mode = mode
        # Load data eagerly; the whole JSONL file is held in memory.
        self.samples = self._load_data()
        print(f"Loaded {len(self.samples)} samples from {data_path}")

    def _load_data(self) -> List[Dict[str, Any]]:
        """Load data from JSONL file, skipping blank lines."""
        samples = []
        with open(self.data_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    samples.append(json.loads(line))
        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize sample `idx` and return input_ids/attention_mask/labels."""
        sample = self.samples[idx]
        messages = sample["messages"]
        # Format as single text with the Qwen chat template:
        # <|im_start|>role\ncontent<|im_end|> (see _format_chat_qwen)
        formatted_text = self._format_chat_qwen(messages)
        # Tokenize; train mode pads to a fixed length so batches stack.
        encoding = self.tokenizer(
            formatted_text,
            truncation=True,
            max_length=self.max_seq_length,
            padding="max_length" if self.mode == "train" else False,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        # Labels mirror input_ids for causal LM, but padding positions
        # must not contribute to the loss: set them to -100, the ignore
        # index used by torch.nn.CrossEntropyLoss and HF trainers.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        # NOTE: we still train on system/user tokens; a more
        # sophisticated setup would mask those out of the loss as well.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _format_chat_qwen(self, messages: List[Dict[str, str]]) -> str:
        """
        Format messages into Qwen chat format.

        Qwen chat format:
        <|im_start|>system
        You are a helpful assistant.<|im_end|>
        <|im_start|>user
        Hello!<|im_end|>
        <|im_start|>assistant
        Hi there!<|im_end|>

        Messages with unknown roles are silently skipped.
        """
        formatted = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"].strip()
            # Only the three canonical roles are rendered.
            if role == "system":
                formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            else:
                # Skip unknown roles
                continue
        return "\n".join(formatted)

    def get_sample(self, idx: int) -> str:
        """Get raw formatted text for inspection (no tokenization)."""
        sample = self.samples[idx]
        messages = sample["messages"]
        return self._format_chat_qwen(messages)
def test_dataset():
    """Smoke-test the dataset loader end to end.

    Tries the project's extended music tokenizer first and falls back to
    the base Qwen tokenizer if the extension cannot be loaded. Requires
    data/processed/train.jsonl to exist on disk.
    """
    # Load tokenizer (needs the music-token extension applied first).
    # AutoTokenizer is already imported at module level.
    print("Loading tokenizer...")
    try:
        from tokenizer.music_token_extension import MusicTokenizerExtension

        tokenizer_ext = MusicTokenizerExtension(
            base_tokenizer_name="Qwen/Qwen3.5-3B-Instruct",
        )
        tokenizer = tokenizer_ext.get_tokenizer()
    except Exception as e:
        print(f"Could not load tokenizer: {e}")
        print("Using dummy tokenizer for testing...")
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3.5-3B-Instruct",
            trust_remote_code=True,
        )
        # Base tokenizer has no pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token

    # Create dataset
    print("\nCreating dataset...")
    dataset = TouchGrassDataset(
        data_path="data/processed/train.jsonl",
        tokenizer=tokenizer,
        max_seq_length=1024,  # Smaller for testing
        mode="train",
    )
    print(f"Dataset size: {len(dataset)}")

    # Inspect a single sample
    if len(dataset) > 0:
        sample = dataset[0]
        print("\nSample keys:", list(sample.keys()))
        print("Input IDs shape:", sample["input_ids"].shape)
        print("Attention mask shape:", sample["attention_mask"].shape)
        print("Labels shape:", sample["labels"].shape)
        # Decode a prefix to eyeball the chat formatting
        decoded = tokenizer.decode(sample["input_ids"][:100])
        print(f"\nFirst 100 tokens:\n{decoded}...")
    print("\nDataset test complete!")
# Run the smoke test when executed as a script.
if __name__ == "__main__":
    test_dataset()