"""MiniMind Dataset and DataLoader utilities."""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from torch.utils.data import DataLoader, Dataset


def _extract_text(item: Dict[str, Any]) -> str:
    """Return the training text from a JSONL record ('text', else 'content')."""
    return item.get("text", item.get("content", ""))


class TextDataset(Dataset):
    """Simple text dataset for language model training.

    Args:
        data_path: Path to the data file.
        tokenizer: HuggingFace-style tokenizer — a callable accepting
            ``(text, truncation=, max_length=, padding=, return_tensors=)``
            and returning a dict with "input_ids" and "attention_mask".
        max_length: Sequences are truncated/padded to this length.
        format_type: "jsonl" or "txt". ("parquet" is not implemented and
            raises ValueError like any other unknown format.)

    Raises:
        ValueError: If ``format_type`` is not supported.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: Any,
        max_length: int = 2048,
        format_type: str = "jsonl",  # jsonl, txt
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path, format_type)

    def _load_data(self, data_path: str, format_type: str) -> List[str]:
        """Load non-empty text samples from *data_path*."""
        path = Path(data_path)
        data: List[str] = []
        if format_type == "jsonl":
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    # Skip blank lines: json.loads("") raises JSONDecodeError.
                    if not line:
                        continue
                    text = _extract_text(json.loads(line))
                    if text:
                        data.append(text)
        elif format_type == "txt":
            with open(path, "r", encoding="utf-8") as f:
                data = [line.strip() for line in f if line.strip()]
        else:
            raise ValueError(f"Unsupported format: {format_type}")
        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize sample *idx*; return input_ids, attention_mask, labels."""
        encoding = self.tokenizer(
            self.data[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": encoding["attention_mask"].squeeze(0),
            # Causal-LM convention: labels mirror the inputs (the loss shifts
            # them internally). Cloned so that in-place label masking (e.g.
            # -100 on padding) cannot corrupt input_ids via shared storage.
            "labels": input_ids.clone(),
        }


class DistillationDataset(Dataset):
    """Dataset for knowledge distillation with optional teacher logits.

    Args:
        data_path: Path to a JSONL file of text samples.
        tokenizer: HuggingFace-style tokenizer (see TextDataset).
        teacher_logits_path: Optional path to a torch-saved tensor of
            precomputed teacher logits, indexed per sample.
        max_length: Sequences are truncated/padded to this length.

    Raises:
        ValueError: If loaded teacher logits do not match the sample count.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: Any,
        teacher_logits_path: Optional[str] = None,
        max_length: int = 2048,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)
        self.teacher_logits = (
            self._load_teacher_logits(teacher_logits_path)
            if teacher_logits_path
            else None
        )
        # Misaligned logits would silently pair samples with wrong targets.
        if self.teacher_logits is not None and len(self.teacher_logits) != len(self.data):
            raise ValueError(
                f"teacher_logits count ({len(self.teacher_logits)}) does not "
                f"match sample count ({len(self.data)})"
            )

    def _load_data(self, data_path: str) -> List[str]:
        """Load non-empty text samples from a JSONL file.

        Filters out records with empty/missing text (the original kept them
        as "" samples) and falls back to the "content" key, matching
        TextDataset's behavior.
        """
        data: List[str] = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                text = _extract_text(json.loads(line))
                if text:
                    data.append(text)
        return data

    def _load_teacher_logits(self, path: str) -> Optional[torch.Tensor]:
        """Load precomputed teacher logits, or None if the file is missing."""
        p = Path(path)
        if not p.exists():
            return None
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted files here; consider weights_only=True on newer torch.
        return torch.load(p, map_location="cpu")

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize sample *idx*; attach teacher logits when available."""
        encoding = self.tokenizer(
            self.data[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        item = {
            "input_ids": input_ids,
            "attention_mask": encoding["attention_mask"].squeeze(0),
            # Cloned so in-place label masking cannot corrupt input_ids.
            "labels": input_ids.clone(),
        }
        if self.teacher_logits is not None:
            item["teacher_logits"] = self.teacher_logits[idx]
        return item


def create_dataloader(
    dataset: Dataset,
    batch_size: int = 8,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
) -> DataLoader:
    """Create a DataLoader with settings suited to training.

    Note: drop_last=True — the final partial batch is discarded, which keeps
    batch shapes constant but means len(dataset) % batch_size samples are
    skipped each epoch.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
    )