"""
Dataset Loader for TouchGrass.

Handles loading and preprocessing of music QA data for fine-tuning.
"""
|
| |
|
| | from typing import List, Dict, Any, Optional
|
| | from pathlib import Path
|
| | import json
|
| | import random
|
| | from torch.utils.data import Dataset, DataLoader
|
| | from transformers import AutoTokenizer
|
| |
|
| |
|
class TouchGrassDataset(Dataset):
    """
    Dataset for TouchGrass fine-tuning.

    Loads chat-formatted JSONL data and tokenizes each sample into
    (input_ids, attention_mask, labels) tensors for causal-LM training.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_seq_length: int = 4096,
        mode: str = "train",
    ):
        """
        Initialize dataset.

        Args:
            data_path: Path to JSONL file with chat data. Each non-blank line
                is a JSON object with a "messages" list of
                {"role": ..., "content": ...} dicts.
            tokenizer: Tokenizer (extended Qwen tokenizer).
            max_seq_length: Maximum sequence length; longer texts are truncated.
            mode: "train" (pad to max_seq_length) or "eval" (no padding).
        """
        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.mode = mode

        self.samples = self._load_data()

        print(f"Loaded {len(self.samples)} samples from {data_path}")

    def _load_data(self) -> List[Dict[str, Any]]:
        """Load samples from the JSONL file, skipping blank lines."""
        samples = []
        with open(self.data_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    samples.append(json.loads(line))
        return samples

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Tokenize sample ``idx`` into training tensors.

        Returns:
            Dict with "input_ids", "attention_mask", and "labels". Labels are
            a copy of input_ids with padding positions masked to -100 so they
            are ignored by the loss.
        """
        sample = self.samples[idx]
        messages = sample["messages"]

        formatted_text = self._format_chat_qwen(messages)

        encoding = self.tokenizer(
            formatted_text,
            truncation=True,
            max_length=self.max_seq_length,
            # Pad to a fixed length only in training mode; eval keeps the
            # natural sequence length.
            padding="max_length" if self.mode == "train" else False,
            return_tensors="pt",
        )

        # Tokenizer returns a batch dimension of 1; drop it.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        labels = input_ids.clone()
        # Fix: padding tokens must not contribute to the loss. -100 is the
        # ignore_index used by torch.nn.CrossEntropyLoss (and HF models),
        # so masked positions are skipped when computing the LM loss.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _format_chat_qwen(self, messages: List[Dict[str, str]]) -> str:
        """
        Format messages into Qwen chat format.

        Qwen chat format:
            <|im_start|>system
            You are a helpful assistant.<|im_end|>
            <|im_start|>user
            Hello!<|im_end|>
            <|im_start|>assistant
            Hi there!<|im_end|>
        """
        formatted = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"].strip()

            if role == "system":
                formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            else:
                # Unknown roles are silently dropped.
                continue

        return "\n".join(formatted)

    def get_sample(self, idx: int) -> str:
        """Get raw formatted text for inspection (no tokenization)."""
        sample = self.samples[idx]
        messages = sample["messages"]
        return self._format_chat_qwen(messages)
|
| |
|
| |
|
def test_dataset():
    """Smoke-test the dataset loader end to end.

    Tries the project's extended music tokenizer first and falls back to the
    plain Qwen tokenizer if it is unavailable, then builds a dataset from
    data/processed/train.jsonl and prints one decoded sample.
    """
    print("Loading tokenizer...")
    try:
        from tokenizer.music_token_extension import MusicTokenizerExtension
        tokenizer_ext = MusicTokenizerExtension(
            base_tokenizer_name="Qwen/Qwen3.5-3B-Instruct",
        )
        tokenizer = tokenizer_ext.get_tokenizer()
    except Exception as e:
        print(f"Could not load tokenizer: {e}")
        print("Using dummy tokenizer for testing...")
        # Import here (and only here) so transformers stays optional until
        # the fallback path actually needs it. The duplicate function-level
        # import that previously preceded the try block was dead code.
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3.5-3B-Instruct",
            trust_remote_code=True,
        )
        tokenizer.pad_token = tokenizer.eos_token

    print("\nCreating dataset...")
    dataset = TouchGrassDataset(
        data_path="data/processed/train.jsonl",
        tokenizer=tokenizer,
        max_seq_length=1024,
        mode="train",
    )

    print(f"Dataset size: {len(dataset)}")

    if len(dataset) > 0:
        sample = dataset[0]
        print("\nSample keys:", list(sample.keys()))
        print("Input IDs shape:", sample["input_ids"].shape)
        print("Attention mask shape:", sample["attention_mask"].shape)
        print("Labels shape:", sample["labels"].shape)

        decoded = tokenizer.decode(sample["input_ids"][:100])
        print(f"\nFirst 100 tokens:\n{decoded}...")

    print("\nDataset test complete!")
|
| |
|
| |
|
# Allow running this module directly as a manual smoke test.
if __name__ == "__main__":
    test_dataset()