"""SFT dataset for HuggingFace datasets.""" from typing import Dict import torch from taoTrain.config import TrainingConfig from taoTrain.data.hf_base import BaseHFDataset class SFTDataset(BaseHFDataset): """Dataset for supervised fine-tuning with instruction-response pairs.""" def _preprocess(self): """Process instruction-response pairs.""" dataset_config = self.config.dataset def format_example(example): """Format instruction and response.""" instruction = example.get(dataset_config.instruction_column, "") response = example.get(dataset_config.response_column, "") if dataset_config.instruction_template: # Use custom template text = dataset_config.instruction_template.format( instruction=instruction, response=response ) else: # Default template text = f"{instruction}\n{response}" return {"text": text} # Format examples self.data = self.data.map( format_example, remove_columns=[ col for col in self.data.column_names if col not in ["text"] ] if "text" not in self.data.column_names else [], desc="Formatting examples...", ) # Tokenize def tokenize_function(examples): tokenized = self.tokenizer( examples["text"], truncation=True, max_length=self.config.model.max_seq_length, padding="max_length", return_attention_mask=True, ) return tokenized self.data = self.data.map( tokenize_function, batched=True, batch_size=100, remove_columns=self.data.column_names, desc="Tokenizing...", ) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: """Get preprocessed sample.""" item = self.data[idx] input_ids = torch.tensor(item["input_ids"], dtype=torch.long) attention_mask = torch.tensor(item["attention_mask"], dtype=torch.long) # For SFT, labels = input_ids shifted by 1 (next token prediction) # Position i predicts token at position i+1 labels = input_ids[1:].clone() labels = torch.cat([labels, torch.tensor([-100])], dim=0) # Mark padding tokens as -100 to ignore in loss computation labels[attention_mask == 0] = -100 return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, }