"""
Dataset Loader for TouchGrass.
Handles loading and preprocessing of music QA data for fine-tuning.
"""

import json
import random
from pathlib import Path
from typing import Any, Dict, List, Optional

from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer


class TouchGrassDataset(Dataset):
    """
    Dataset for TouchGrass fine-tuning.

    Loads chat-formatted samples from a JSONL file, renders each one into
    the Qwen ``<|im_start|>`` / ``<|im_end|>`` chat layout and tokenizes it
    for causal-LM training.
    """

    # Label value that marks positions to be excluded from the loss. -100 is
    # the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    IGNORE_INDEX = -100

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_seq_length: int = 4096,
        mode: str = "train",
    ):
        """
        Initialize dataset.

        Args:
            data_path: Path to JSONL file with chat data
            tokenizer: Tokenizer (extended Qwen tokenizer)
            max_seq_length: Maximum sequence length
            mode: "train" or "eval"
        """
        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.mode = mode

        # Load data eagerly so len() and indexing are plain list operations.
        self.samples = self._load_data()
        print(f"Loaded {len(self.samples)} samples from {data_path}")

    def _load_data(self) -> List[Dict[str, Any]]:
        """
        Load data from JSONL file.

        Blank lines are skipped.

        Returns:
            One parsed JSON object per non-blank line, in file order.

        Raises:
            json.JSONDecodeError: If a line is not valid JSON. The message
                names the file and the 1-based line number of the record.
        """
        samples = []
        with open(self.data_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    samples.append(json.loads(line))
                except json.JSONDecodeError as err:
                    # Same exception type as before, with enough context to
                    # locate the broken record in a large file.
                    raise json.JSONDecodeError(
                        f"{err.msg} (record on line {line_no} of {self.data_path})",
                        err.doc,
                        err.pos,
                    ) from err
        return samples

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Tokenize the sample at ``idx``.

        In "train" mode the sequence is padded to ``max_seq_length``; in any
        other mode it is left unpadded. Sequences are always truncated to
        ``max_seq_length``.

        Returns:
            Dict with 1-D tensors ``input_ids``, ``attention_mask`` and
            ``labels``. ``labels`` equals ``input_ids`` except at positions
            where ``attention_mask`` is 0, which are set to ``IGNORE_INDEX``.
        """
        # Render the conversation as a single string:
        #   <|im_start|>role\ncontent<|im_end|>
        formatted_text = self.get_sample(idx)

        # Tokenize
        encoding = self.tokenizer(
            formatted_text,
            truncation=True,
            max_length=self.max_seq_length,
            padding="max_length" if self.mode == "train" else False,
            return_tensors="pt",
        )

        # return_tensors="pt" yields a leading batch dimension of 1; drop it.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Labels are same as input_ids for causal LM ...
        labels = input_ids.clone()
        # ... except padding, which must not contribute to the loss. Masking
        # by attention_mask (not by pad token id) keeps a genuine
        # end-of-sequence token supervised even when pad_token == eos_token.
        labels[attention_mask == 0] = self.IGNORE_INDEX

        # All real tokens are trained on, whatever the role.
        # More sophisticated: mask user/system tokens in loss

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _format_chat_qwen(self, messages: List[Dict[str, str]]) -> str:
        """
        Format messages into Qwen chat format.

        Qwen chat format:
        <|im_start|>system
        You are a helpful assistant.<|im_end|>
        <|im_start|>user
        Hello!<|im_end|>
        <|im_start|>assistant
        Hi there!<|im_end|>

        Message content is stripped of surrounding whitespace. Messages whose
        role is not "system", "user" or "assistant" are skipped.
        """
        formatted = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"].strip()

            # Skip unknown roles
            if role not in ("system", "user", "assistant"):
                continue
            formatted.append(f"<|im_start|>{role}\n{content}<|im_end|>")

        return "\n".join(formatted)

    def get_sample(self, idx: int) -> str:
        """Get raw formatted text for inspection."""
        sample = self.samples[idx]
        messages = sample["messages"]
        return self._format_chat_qwen(messages)


def test_dataset():
    """
    Test the dataset loader.

    Builds a tokenizer (the music-extended one when importable, otherwise the
    plain base tokenizer), loads ``data/processed/train.jsonl`` and prints the
    shapes and a decoded prefix of the first sample.
    """
    # Load tokenizer (need to extend first)
    print("Loading tokenizer...")
    try:
        from tokenizer.music_token_extension import MusicTokenizerExtension

        tokenizer_ext = MusicTokenizerExtension(
            base_tokenizer_name="Qwen/Qwen3.5-3B-Instruct",
        )
        tokenizer = tokenizer_ext.get_tokenizer()
    except Exception as e:
        # Deliberate best-effort: any failure of the extension falls back to
        # the unextended base tokenizer so the smoke test can still run.
        print(f"Could not load tokenizer: {e}")
        print("Using dummy tokenizer for testing...")
        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3.5-3B-Instruct",
            trust_remote_code=True,
        )
        tokenizer.pad_token = tokenizer.eos_token

    # Create dataset
    print("\nCreating dataset...")
    dataset = TouchGrassDataset(
        data_path="data/processed/train.jsonl",
        tokenizer=tokenizer,
        max_seq_length=1024,  # Smaller for testing
        mode="train",
    )

    print(f"Dataset size: {len(dataset)}")

    # Get a sample
    if len(dataset) > 0:
        sample = dataset[0]
        print("\nSample keys:", list(sample.keys()))
        print("Input IDs shape:", sample["input_ids"].shape)
        print("Attention mask shape:", sample["attention_mask"].shape)
        print("Labels shape:", sample["labels"].shape)

        # Decode to check formatting
        decoded = tokenizer.decode(sample["input_ids"][:100])
        print(f"\nFirst 100 tokens:\n{decoded}...")

    print("\nDataset test complete!")


if __name__ == "__main__":
    test_dataset()