| import json |
| from pathlib import Path |
|
|
| import torch |
| from torch.utils.data import Dataset |
|
|
| from .formatting import format_example |
|
|
|
|
class DialogueDataset(Dataset):
    """Map-style dataset of dialogue examples stored as JSON Lines.

    Every non-blank line of ``path`` is decoded with ``json.loads`` and kept in
    memory. Indexing renders one example with ``format_example``, tokenizes it,
    truncates it to ``max_seq_len`` tokens and returns a one-token-shifted
    ``(input_ids, labels)`` pair for next-token prediction. Labels for the
    prompt (everything up to and including the first ``<assistant>`` token)
    are replaced by ``-100``.

    Args:
        path: Path of the JSON Lines file to load.
        tokenizer: Object providing ``piece_to_id(str)`` and
            ``encode(text, out_type=int)``. Presumably a SentencePiece
            processor; confirm.
        max_seq_len: Maximum number of tokens kept per example before the
            one-token shift, so returned tensors hold at most
            ``max_seq_len - 1`` elements.

    Raises:
        ValueError: If the file has no non-blank lines. Invalid JSON lines
            propagate ``json.JSONDecodeError`` (a ``ValueError`` subclass).
    """

    def __init__(self, path: str, tokenizer, max_seq_len: int):
        self.path = Path(path)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        # NOTE(review): if `tokenizer` is SentencePiece, piece_to_id returns the
        # unk id for a piece that is not in the vocabulary. The first unknown
        # token would then act as the marker. Confirm "<assistant>" is a
        # registered symbol.
        self.assistant_id = tokenizer.piece_to_id("<assistant>")
        self.examples = []

        with self.path.open("r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. a trailing newline
                self.examples.append(json.loads(line))

        if not self.examples:
            raise ValueError(f"No examples found in {self.path}")

    def __len__(self) -> int:
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, index: int):
        """Return the ``(input_ids, labels)`` pair for example ``index``.

        Both tensors are 1-D ``torch.long`` of equal length. ``labels[j]`` is
        the token that follows ``input_ids[j]``, or ``-100`` where masked.
        """
        text = format_example(self.examples[index])
        ids = self.tokenizer.encode(text, out_type=int)
        return self._to_training_pair(ids, self.assistant_id, self.max_seq_len)

    @staticmethod
    def _to_training_pair(ids, assistant_id: int, max_seq_len: int):
        """Truncate ``ids``, shift by one token and mask the prompt labels.

        Args:
            ids: Token ids of the full, untruncated example (a list of ints).
            assistant_id: Token id of the ``<assistant>`` marker.
            max_seq_len: Number of leading tokens to keep.

        Returns:
            ``(input_ids, labels)``, both ``torch.long`` tensors. Each has
            length ``len(ids[:max_seq_len]) - 1``, or 0 when that is empty.
        """
        # Locate the marker BEFORE truncating. If truncation cuts it off, the
        # surviving tokens are all prompt and must all be masked. Checking
        # only the truncated ids would silently train on the whole prompt.
        # NOTE(review): only the first marker is used. If format_example emits
        # several (multi-turn), later user turns are trained on; confirm this
        # is intended.
        try:
            assistant_pos = ids.index(assistant_id)
        except ValueError:
            assistant_pos = None  # no marker at all: every token is trained on

        ids = ids[:max_seq_len]
        input_ids = torch.tensor(ids[:-1], dtype=torch.long)
        labels = torch.tensor(ids[1:], dtype=torch.long)

        if assistant_pos is not None:
            # labels[j] predicts ids[j + 1], so masking labels[:assistant_pos]
            # hides the prompt and the marker itself. The first trained label
            # is the token right after the marker. A marker position beyond
            # the truncated length masks every label. Such an example adds
            # nothing to the loss, and a batch made only of them gives NaN
            # with mean-reduced CrossEntropyLoss.
            labels[:assistant_pos] = -100
        return input_ids, labels
|
|
|
|
def collate_batch(batch, pad_id: int):
    """Right-pad a list of ``(input_ids, labels)`` pairs into two 2-D tensors.

    The padded width is the longest ``input_ids`` in the batch. Inputs are
    padded with ``pad_id``; labels are padded with ``-100`` (the default
    ``ignore_index`` of ``torch.nn.CrossEntropyLoss``). Both outputs are
    ``torch.long`` with shape ``(len(batch), width)``.

    Assumes each element holds 1-D tensors and that no labels tensor is
    longer than its ``input_ids``. Verify against the dataset if that changes.
    """
    width = max(pair[0].numel() for pair in batch)
    shape = (len(batch), width)
    padded_inputs = torch.full(shape, pad_id, dtype=torch.long)
    padded_labels = torch.full(shape, -100, dtype=torch.long)

    for row, (inp, tgt) in enumerate(batch):
        padded_inputs[row, : inp.numel()] = inp
        padded_labels[row, : tgt.numel()] = tgt

    return padded_inputs, padded_labels
|
|