# Upload provenance (converted from web-page residue so the module parses):
# catninja123 — "Upload src/dataset.py with huggingface_hub" — commit 236d8d8 (verified)
"""
MASH Dataset for Style-injection SFT β€” v3 (T5 compatible)
Each raw sample produces TWO training examples:
1. Transfer task: (ai_text, s_human_*) β†’ human_text
2. Reconstruction: (ai_text, s_ai_*) β†’ ai_text
The DataLoader interleaves both tasks within each batch.
"""
import json
import random
import torch
from torch.utils.data import Dataset, DataLoader
class MASHSFTDataset(Dataset):
    """
    Dataset for Style-injection SFT (T5 compatible).

    Each raw JSONL sample yields up to two training examples:
      1. Transfer task:       (ai_text, s_human_*) -> human_text
      2. Reconstruction task: (ai_text, s_ai_*)    -> ai_text

    Each __getitem__ returns one training example with:
      - input_ids / attention_mask  (tokenized AI input text)
      - labels                      (tokenized target; pad positions set to -100)
      - style_key ('human_ps', 'human_supp', 'ai_ps', 'ai_supp')
      - task ('transfer' or 'reconstruction')
      - essay_type ('ps' or 'supp')
    """
    def __init__(self, data_path: str, tokenizer,
                 max_input_len: int = 512, max_target_len: int = 512,
                 include_reconstruction: bool = True,
                 reconstruction_ratio: float = 0.3):
        """
        Args:
            data_path: path to a JSONL file whose lines contain at least
                'type' ('ps' or 'supp'), 'input_text', 'human_text', 'ai_text'.
            tokenizer: HF-style tokenizer supporting the ``text_target`` kwarg
                (T5 compatible); must have a non-None ``pad_token_id``.
            max_input_len: truncation/padding length for the input text.
            max_target_len: truncation/padding length for the target text.
            include_reconstruction: also emit the (ai -> ai) reconstruction task.
            reconstruction_ratio: desired fraction of reconstruction examples.
                Values >= 0.5 keep the natural 1:1 mix (no downsampling).
        """
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len
        # Load raw JSONL data; skip blank lines (e.g. a trailing newline)
        # instead of crashing on json.loads('').
        self.raw_data = []
        with open(data_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    self.raw_data.append(json.loads(line))
        # Build training examples.
        self.examples = []
        for d in self.raw_data:
            essay_type = d['type']  # 'ps' or 'supp'
            # Transfer task: (ai_text, s_human_*) -> human_text
            self.examples.append({
                'input_text': d['input_text'],
                'target_text': d['human_text'],
                'style_key': f'human_{essay_type}',
                'task': 'transfer',
                'essay_type': essay_type,
            })
            # Reconstruction task: (ai_text, s_ai_*) -> ai_text
            if include_reconstruction:
                self.examples.append({
                    'input_text': d['input_text'],
                    'target_text': d['ai_text'],
                    'style_key': f'ai_{essay_type}',
                    'task': 'reconstruction',
                    'essay_type': essay_type,
                })
        # If reconstruction ratio < 0.5, downsample reconstruction examples so
        # that recon / (recon + transfer) ~= reconstruction_ratio.
        if include_reconstruction and reconstruction_ratio < 0.5:
            transfer = [e for e in self.examples if e['task'] == 'transfer']
            recon = [e for e in self.examples if e['task'] == 'reconstruction']
            n_recon = int(len(transfer) * reconstruction_ratio / (1 - reconstruction_ratio))
            random.shuffle(recon)  # uses the module-level RNG; seed globally for reproducibility
            recon = recon[:n_recon]
            self.examples = transfer + recon
        random.shuffle(self.examples)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize one example; see the class docstring for the returned keys."""
        ex = self.examples[idx]
        # Tokenize input.
        input_enc = self.tokenizer(
            ex['input_text'],
            max_length=self.max_input_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        # Tokenize target -- use text_target for T5 compatibility.
        target_enc = self.tokenizer(
            text_target=ex['target_text'],
            max_length=self.max_target_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        # Replace padding token ids with -100 so they are ignored by the
        # cross-entropy loss. NOTE(review): assumes pad_token_id is not None
        # (true for T5 tokenizers) -- confirm for other tokenizers.
        labels = target_enc['input_ids'].squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            'input_ids': input_enc['input_ids'].squeeze(),
            'attention_mask': input_enc['attention_mask'].squeeze(),
            'labels': labels,
            'style_key': ex['style_key'],
            'task': ex['task'],
            'essay_type': ex['essay_type'],
        }
def collate_fn(batch):
    """Collate SFT examples: stack the tensors, keep string fields as lists."""
    stacked = {
        key: torch.stack([example[key] for example in batch])
        for key in ('input_ids', 'attention_mask', 'labels')
    }
    stacked['style_keys'] = [example['style_key'] for example in batch]
    stacked['tasks'] = [example['task'] for example in batch]
    stacked['essay_types'] = [example['essay_type'] for example in batch]
    return stacked
class MASHDPODataset(Dataset):
    """
    Dataset for DPO alignment.

    Each JSONL line must contain 'input_text', 'chosen_text',
    'rejected_text', and 'style_key'. __getitem__ returns the tokenized
    prompt plus chosen/rejected label tensors with pad positions set to -100.
    """
    def __init__(self, data_path: str, tokenizer,
                 max_input_len: int = 512, max_target_len: int = 512):
        """
        Args:
            data_path: path to a JSONL preference file (see class docstring).
            tokenizer: HF-style tokenizer supporting the ``text_target`` kwarg
                (T5 compatible); must have a non-None ``pad_token_id``.
            max_input_len: truncation/padding length for the input text.
            max_target_len: truncation/padding length for both targets.
        """
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len
        # Load JSONL; skip blank lines (e.g. a trailing newline) instead of
        # crashing on json.loads('').
        self.examples = []
        with open(data_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    self.examples.append(json.loads(line))

    def __len__(self):
        return len(self.examples)

    def _encode_target(self, text):
        """Tokenize one target text and mask its padding positions to -100."""
        enc = self.tokenizer(
            text_target=text,
            max_length=self.max_target_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        labels = enc['input_ids'].squeeze()
        # -100 is ignored by the cross-entropy loss. NOTE(review): assumes
        # pad_token_id is not None (true for T5 tokenizers).
        labels[labels == self.tokenizer.pad_token_id] = -100
        return labels

    def __getitem__(self, idx):
        """Return input_ids/attention_mask, chosen/rejected labels, style_key."""
        ex = self.examples[idx]
        input_enc = self.tokenizer(
            ex['input_text'],
            max_length=self.max_input_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        return {
            'input_ids': input_enc['input_ids'].squeeze(),
            'attention_mask': input_enc['attention_mask'].squeeze(),
            'chosen_labels': self._encode_target(ex['chosen_text']),
            'rejected_labels': self._encode_target(ex['rejected_text']),
            'style_key': ex['style_key'],
        }
def dpo_collate_fn(batch):
    """Collate DPO examples: stack the tensors, keep style keys as a list."""
    out = {
        key: torch.stack([example[key] for example in batch])
        for key in ('input_ids', 'attention_mask', 'chosen_labels', 'rejected_labels')
    }
    out['style_keys'] = [example['style_key'] for example in batch]
    return out