StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""SFT dataset for HuggingFace datasets."""
from typing import Dict
import torch
from taoTrain.config import TrainingConfig
from taoTrain.data.hf_base import BaseHFDataset
class SFTDataset(BaseHFDataset):
    """Dataset for supervised fine-tuning with instruction-response pairs.

    Reads ``self.config.dataset`` (``instruction_column``,
    ``response_column``, ``instruction_template``),
    ``self.config.model.max_seq_length``, ``self.tokenizer`` and
    ``self.data``.
    NOTE(review): these attributes are presumably set up by
    ``BaseHFDataset`` before ``_preprocess`` runs -- confirm against the
    base class.
    """

    # Label value used to exclude a position from the loss. -100 is the
    # default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    _IGNORE_INDEX = -100

    def _preprocess(self):
        """Format every example into one ``text`` field, then tokenize it.

        Replaces ``self.data`` twice: first with the formatted examples,
        then with the tokenizer output (the columns present before
        tokenization are dropped).
        """
        dataset_config = self.config.dataset

        def field_text(example, column):
            """Return ``example[column]``, or "" when missing or null."""
            value = example.get(column, "")
            # A present-but-null column would otherwise be rendered as the
            # literal string "None" inside the training text.
            return "" if value is None else value

        def format_example(example):
            """Combine instruction and response into a single string."""
            instruction = field_text(example, dataset_config.instruction_column)
            response = field_text(example, dataset_config.response_column)

            if dataset_config.instruction_template:
                # Custom template with {instruction} / {response} fields.
                text = dataset_config.instruction_template.format(
                    instruction=instruction,
                    response=response,
                )
            else:
                # Default template: instruction, newline, response.
                text = f"{instruction}\n{response}"

            return {"text": text}

        # Drop every source column unless the dataset already has a "text"
        # column, in which case nothing is removed here and "text" is
        # overwritten by the formatted value.
        columns = self.data.column_names
        self.data = self.data.map(
            format_example,
            remove_columns=[] if "text" in columns else list(columns),
            desc="Formatting examples...",
        )

        def tokenize_function(examples):
            """Tokenize a batch, truncating and padding to max_seq_length."""
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=self.config.model.max_seq_length,
                padding="max_length",
                return_attention_mask=True,
            )

        self.data = self.data.map(
            tokenize_function,
            batched=True,
            batch_size=100,
            remove_columns=self.data.column_names,
            desc="Tokenizing...",
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one tokenized sample with next-token labels.

        Args:
            idx: Index into ``self.data``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``, all
            ``torch.long`` tensors of the same length. ``labels[i]`` is
            ``input_ids[i + 1]``. It is ``-100`` for the final position and
            wherever the input token or its target token is padding.
        """
        item = self.data[idx]
        input_ids = torch.tensor(item["input_ids"], dtype=torch.long)
        attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long)

        # For SFT, labels = input_ids shifted by 1 (next-token prediction):
        # position i predicts the token at position i + 1. The last position
        # has no target and stays ignored.
        # NOTE(review): this assumes the model/loss does NOT shift labels
        # itself; a model that shifts internally would be off by two --
        # confirm against the training loop.
        labels = torch.full_like(input_ids, self._IGNORE_INDEX)
        labels[:-1] = input_ids[1:]

        # A position contributes to the loss only if both the input token
        # and the token it predicts are real. Masking with the unshifted
        # attention mask alone would leave the last real token (right
        # padding) with a padding token as its target.
        # If an end-of-sequence target is wanted, add EOS to the text
        # instead of relying on the first padding token.
        real = attention_mask != 0
        keep = torch.zeros_like(real)
        keep[:-1] = real[:-1] & real[1:]
        labels[~keep] = self._IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }