import json
from pathlib import Path

import torch
from torch.utils.data import Dataset

from .formatting import format_example

# Label value written at every position that must not contribute to the loss.
# NOTE(review): -100 is the default ``ignore_index`` of
# ``torch.nn.CrossEntropyLoss`` -- confirm the training loop relies on that.
IGNORE_INDEX = -100


class DialogueDataset(Dataset):
    """Dataset of dialogue examples read from a JSON-lines file.

    Each non-blank line of the file is parsed as one JSON value and kept in
    memory. Items are produced lazily: the example is rendered to text with
    ``format_example``, tokenized, truncated, and split into next-token
    ``(input_ids, labels)`` tensors.
    """

    def __init__(self, path: str, tokenizer, max_seq_len: int) -> None:
        """Load every example from ``path``.

        Args:
            path: Location of the JSON-lines file (read as UTF-8).
            tokenizer: Object providing ``piece_to_id(str)`` and
                ``encode(text, out_type=int)``.
                NOTE(review): presumably a SentencePiece processor -- confirm.
            max_seq_len: Maximum number of token ids kept per example before
                the input/label shift.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON. The
                message includes the file path and the line number.
            ValueError: If the file contains no examples.
        """
        self.path = Path(path)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        # NOTE(review): the piece string is empty here; presumably an
        # assistant-role marker piece was intended and got lost -- confirm
        # against the tokenizer vocabulary and ``format_example``.
        self.assistant_id = tokenizer.piece_to_id("")
        self.examples = []
        with self.path.open("r", encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Blank lines are tolerated and skipped.
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as err:
                    # Re-raise the same exception type so existing handlers
                    # keep working, but say where the bad record lives.
                    raise json.JSONDecodeError(
                        f"{err.msg} (in {self.path}, line {line_no})",
                        err.doc,
                        err.pos,
                    ) from err
                self.examples.append(record)
        if not self.examples:
            raise ValueError(f"No examples found in {self.path}")

    def __len__(self) -> int:
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, index: int):
        """Return ``(input_ids, labels)`` for the example at ``index``.

        The token ids are truncated to ``max_seq_len``; ``input_ids`` is the
        sequence without its last token and ``labels`` is the sequence
        without its first token. If ``assistant_id`` occurs in the truncated
        ids, every label before the position of its first occurrence is set
        to ``IGNORE_INDEX``, so the first label left unmasked is the token
        that follows the marker. If the marker does not occur, no label is
        masked.
        """
        text = format_example(self.examples[index])
        ids = self.tokenizer.encode(text, out_type=int)
        ids = ids[: self.max_seq_len]
        input_ids = torch.tensor(ids[:-1], dtype=torch.long)
        labels = torch.tensor(ids[1:], dtype=torch.long)
        try:
            # Single scan: locate the first marker occurrence, if any.
            assistant_pos = ids.index(self.assistant_id)
        except ValueError:
            return input_ids, labels
        # labels[j] is ids[j + 1], so masking labels[:assistant_pos] hides
        # everything up to and including the prediction of the marker itself.
        labels[:assistant_pos] = IGNORE_INDEX
        return input_ids, labels


def collate_batch(batch, pad_id: int):
    """Right-pad a list of ``(input_ids, labels)`` pairs into two 2-D tensors.

    Args:
        batch: Sequence of ``(input_ids, labels)`` pairs of 1-D tensors.
        pad_id: Value used to fill ``input_ids`` beyond each sequence's end.

    Returns:
        ``(input_ids, labels)``, both ``torch.long`` with shape
        ``(len(batch), max_len)`` where ``max_len`` is the longest
        ``input_ids`` in the batch. Padded label positions hold
        ``IGNORE_INDEX``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        # Same exception type the bare max() would raise, with a clear cause.
        raise ValueError("collate_batch received an empty batch")
    max_len = max(inputs.numel() for inputs, _ in batch)
    shape = (len(batch), max_len)
    input_ids = torch.full(shape, pad_id, dtype=torch.long)
    labels = torch.full(shape, IGNORE_INDEX, dtype=torch.long)
    for row, (inputs, targets) in enumerate(batch):
        input_ids[row, : inputs.numel()] = inputs
        labels[row, : targets.numel()] = targets
    return input_ids, labels