# mash-stylebart-trainer / src/dataset_v4.py
# Uploaded by catninja123 via huggingface_hub (commit 4e94eef, verified)
"""
MASH Dataset v4 - Instruction Prefix for Flan-T5-XL
Instead of style embeddings, we prepend a task instruction to the input text.
This leverages Flan-T5's instruction-following capability directly.
Instruction format:
"Rewrite the following AI-generated {essay_type} essay in a natural,
authentic human voice. Preserve the original meaning and key details
while making the writing sound genuinely human-written:\n\n{ai_text}"
"""
import json
import random
import torch
from torch.utils.data import Dataset
# Instruction templates — varied to prevent overfitting to a single phrasing.
# Every template exposes exactly two placeholders:
#   {type} — human-readable essay type (see TYPE_NAMES)
#   {text} — the AI-generated source essay
INSTRUCTIONS = [
    "Rewrite the following AI-generated {type} essay in a natural, authentic human voice. Preserve the original meaning and key details while making the writing sound genuinely human-written:\n\n{text}",
    "Transform this AI-written {type} essay into natural human writing. Keep the same ideas and details but make it sound like a real person wrote it:\n\n{text}",
    "Convert the following machine-generated {type} essay to sound authentically human. Maintain the core content while adopting a genuine, personal writing style:\n\n{text}",
    "Rewrite this {type} essay to remove all traces of AI writing. The output should read as if written by a real student, preserving the original meaning:\n\n{text}",
    "Make the following AI-generated {type} essay sound human-written. Keep the same content and structure but use natural, authentic language:\n\n{text}",
]

# Maps the short type codes stored in the data files to the human-readable
# names spliced into the instruction text.
TYPE_NAMES = {
    'ps': 'personal statement',
    'supp': 'supplemental',
}


class InstructionDataset(Dataset):
    """
    Dataset for instruction-prefix SFT (transfer direction only).

    Each JSONL record must contain the keys 'type' ('ps' or 'supp'),
    'input_text' (the AI-generated essay) and 'human_text' (the human
    reference). A sample is:

        instruction(type, AI text)  ->  human text
    """

    def __init__(self, data_path: str, tokenizer,
                 max_input_len: int = 512, max_target_len: int = 512):
        """
        Args:
            data_path: path to a JSONL file, one example per line.
            tokenizer: a HF-style tokenizer; must be callable with
                (text | text_target, max_length, truncation, padding,
                return_tensors) and expose ``pad_token_id``.
            max_input_len: max token length of instruction + AI text.
            max_target_len: max token length of the human target text.
        """
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

        # Explicit encoding so loading does not depend on the platform's
        # default locale (previously relied on the implicit default).
        self.raw_data = []
        with open(data_path, encoding='utf-8') as f:
            for line in f:
                self.raw_data.append(json.loads(line))

        # Build examples (transfer only — no reconstruction needed without style vectors)
        self.examples = []
        for d in self.raw_data:
            essay_type = d['type']  # 'ps' or 'supp'
            self.examples.append({
                'input_text': d['input_text'],
                'human_text': d['human_text'],
                'essay_type': essay_type,
                # Unknown codes fall back to the raw code itself.
                'type_name': TYPE_NAMES.get(essay_type, essay_type),
            })

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        ex = self.examples[idx]

        # A template is drawn at random on every access, so the same example
        # is paired with different phrasings across epochs (light augmentation).
        template = random.choice(INSTRUCTIONS)
        instruction_text = template.format(
            type=ex['type_name'],
            text=ex['input_text'],
        )

        # Tokenize input (instruction + AI text).
        input_enc = self.tokenizer(
            instruction_text,
            max_length=self.max_input_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        # Tokenize target (human text).
        target_enc = self.tokenizer(
            text_target=ex['human_text'],
            max_length=self.max_target_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        # squeeze(0) drops only the batch dimension; a bare squeeze() would
        # also collapse the sequence dimension whenever max_*_len == 1.
        labels = target_enc['input_ids'].squeeze(0)
        # Padding positions become -100 so the loss function ignores them.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_enc['input_ids'].squeeze(0),
            'attention_mask': input_enc['attention_mask'].squeeze(0),
            'labels': labels,
        }
def collate_fn(batch):
    """Stack per-sample tensors into batched tensors (no style keys needed)."""
    keys = ('input_ids', 'attention_mask', 'labels')
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}