| """ |
| MASH Dataset for Style-injection SFT β v3 (T5 compatible) |
| |
| Each raw sample produces TWO training examples: |
| 1. Transfer task: (ai_text, s_human_*) β human_text |
| 2. Reconstruction: (ai_text, s_ai_*) β ai_text |
| |
| The DataLoader interleaves both tasks within each batch. |
| """ |
|
|
| import json |
| import random |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
|
|
|
|
class MASHSFTDataset(Dataset):
    """
    Dataset for Style-injection SFT.

    Each __getitem__ returns one training example with:
        - input_ids / attention_mask: tokenized AI text
        - labels: tokenized target text, pad positions masked to -100
        - style_key ('human_ps', 'human_supp', 'ai_ps', 'ai_supp')
        - task ('transfer' or 'reconstruction')
        - essay_type (the raw record's 'type' field)
    """

    def __init__(self, data_path: str, tokenizer,
                 max_input_len: int = 512, max_target_len: int = 512,
                 include_reconstruction: bool = True,
                 reconstruction_ratio: float = 0.3):
        """
        Args:
            data_path: JSONL file; each record needs the keys
                'type', 'input_text', 'human_text' and 'ai_text'.
            tokenizer: HF-style tokenizer supporting ``text_target=``
                and exposing ``pad_token_id``.
            max_input_len: truncation/padding length for inputs.
            max_target_len: truncation/padding length for targets.
            include_reconstruction: also emit (ai_text -> ai_text) examples.
            reconstruction_ratio: desired fraction of reconstruction
                examples in the final mix; only applied when < 0.5.
        """
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

        # Load raw JSONL records.  Blank lines (e.g. a trailing newline)
        # are skipped so json.loads never sees an empty string.
        self.raw_data = []
        with open(data_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    self.raw_data.append(json.loads(line))

        # Expand each raw record into one or two training examples.
        self.examples = []
        for d in self.raw_data:
            essay_type = d['type']

            # Transfer task: AI text -> human text, human style key.
            self.examples.append({
                'input_text': d['input_text'],
                'target_text': d['human_text'],
                'style_key': f'human_{essay_type}',
                'task': 'transfer',
                'essay_type': essay_type,
            })

            # Reconstruction task: AI text -> AI text, AI style key.
            if include_reconstruction:
                self.examples.append({
                    'input_text': d['input_text'],
                    'target_text': d['ai_text'],
                    'style_key': f'ai_{essay_type}',
                    'task': 'reconstruction',
                    'essay_type': essay_type,
                })

        # Downsample reconstruction examples so they make up roughly
        # `reconstruction_ratio` of the dataset: n_recon solves
        # n_recon / (n_transfer + n_recon) == ratio.
        if include_reconstruction and reconstruction_ratio < 0.5:
            transfer = [e for e in self.examples if e['task'] == 'transfer']
            recon = [e for e in self.examples if e['task'] == 'reconstruction']
            n_recon = int(len(transfer) * reconstruction_ratio / (1 - reconstruction_ratio))
            random.shuffle(recon)
            recon = recon[:n_recon]
            self.examples = transfer + recon
            random.shuffle(self.examples)

    def __len__(self):
        """Number of training examples (after any downsampling)."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return model-ready tensors."""
        ex = self.examples[idx]

        input_enc = self.tokenizer(
            ex['input_text'],
            max_length=self.max_input_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        # ``text_target=`` routes through the tokenizer's target-side
        # processing (relevant for seq2seq models such as T5).
        target_enc = self.tokenizer(
            text_target=ex['target_text'],
            max_length=self.max_target_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        # Mask pad positions so the loss ignores them.
        labels = target_enc['input_ids'].squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_enc['input_ids'].squeeze(),
            'attention_mask': input_enc['attention_mask'].squeeze(),
            'labels': labels,
            'style_key': ex['style_key'],
            'task': ex['task'],
            'essay_type': ex['essay_type'],
        }
|
|
|
|
def collate_fn(batch):
    """Custom collate that preserves string metadata as plain lists.

    Tensor fields are stacked along a new batch dimension; the string
    fields 'style_key', 'task' and 'essay_type' are collected into
    lists under pluralized keys.
    """
    collated = {}
    for field in ('input_ids', 'attention_mask', 'labels'):
        collated[field] = torch.stack([example[field] for example in batch])
    collated['style_keys'] = [example['style_key'] for example in batch]
    collated['tasks'] = [example['task'] for example in batch]
    collated['essay_types'] = [example['essay_type'] for example in batch]
    return collated
|
|
|
|
class MASHDPODataset(Dataset):
    """Dataset for DPO alignment.

    Each __getitem__ returns the tokenized prompt plus label tensors
    for the chosen and rejected completions, with pad positions masked
    to -100.
    """

    def __init__(self, data_path: str, tokenizer,
                 max_input_len: int = 512, max_target_len: int = 512):
        """
        Args:
            data_path: JSONL file; each record needs 'input_text',
                'chosen_text', 'rejected_text' and 'style_key'.
            tokenizer: HF-style tokenizer supporting ``text_target=``
                and exposing ``pad_token_id``.
            max_input_len: truncation/padding length for inputs.
            max_target_len: truncation/padding length for targets.
        """
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

        # Load JSONL records, skipping blank lines (e.g. a trailing
        # newline) so json.loads never sees an empty string.
        self.examples = []
        with open(data_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    self.examples.append(json.loads(line))

    def __len__(self):
        return len(self.examples)

    def _encode_labels(self, text):
        """Tokenize a target text and mask pad positions with -100."""
        enc = self.tokenizer(
            text_target=text,
            max_length=self.max_target_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        labels = enc['input_ids'].squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100
        return labels

    def __getitem__(self, idx):
        ex = self.examples[idx]

        input_enc = self.tokenizer(
            ex['input_text'],
            max_length=self.max_input_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        return {
            'input_ids': input_enc['input_ids'].squeeze(),
            'attention_mask': input_enc['attention_mask'].squeeze(),
            'chosen_labels': self._encode_labels(ex['chosen_text']),
            'rejected_labels': self._encode_labels(ex['rejected_text']),
            'style_key': ex['style_key'],
        }
|
|
|
|
def dpo_collate_fn(batch):
    """Collate DPO examples: stack tensor fields, keep style keys as a list."""
    tensor_fields = ('input_ids', 'attention_mask', 'chosen_labels', 'rejected_labels')
    collated = {name: torch.stack([item[name] for item in batch]) for name in tensor_fields}
    collated['style_keys'] = [item['style_key'] for item in batch]
    return collated
|
|