"""SFT JSONL dataset with async-only streaming and response-masking.""" from typing import Dict import torch from taoTrain.config import TrainingConfig from taoTrain.data.jsonl_base import BaseJSONLDataset from taoTrain.data.sft_utils import ( parse_sft_record, build_sft_sequence_tokens, build_response_only_next_token_labels, ) class SFTJSONLDataset(BaseJSONLDataset): """ Dataset for supervised fine-tuning with local JSONL files with chunked loading. Supports both single-turn and multi-turn SFT data: - Single-turn: {"input": "...", "output": "..."} - Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]} With response-only loss masking: only trains on assistant/response tokens. """ def __init__(self, *args, **kwargs): """Initialize dataset.""" super().__init__(*args, **kwargs) # Store full records for parsing (not just text field) self._current_chunk_records = None # Get SFT-specific config self.sft_config = self.config if hasattr(self.config, 'mode') else None self.user_token = getattr(self.sft_config, 'user_token', '') if self.sft_config else '' self.assistant_token = getattr(self.sft_config, 'assistant_token', '') if self.sft_config else '' self.response_loss_only = getattr(self.sft_config, 'response_loss_only', True) if self.sft_config else True def _load_chunk(self, chunk_num: int): """ Load a specific chunk from JSONL file, preserving full records for SFT parsing. Args: chunk_num: Chunk number to load (0-indexed) """ if not self.chunk_manager: return if chunk_num == self._current_chunk_num and self._current_chunk_data is not None: # Already loaded return # Read chunk - get full record objects chunk_examples = self.chunk_manager.read_chunk(chunk_num) # Store full records for SFT parsing (not just text field) self._current_chunk_records = chunk_examples # Initialize data structures self._current_chunk_data = { "input_ids": [], "attention_mask": [], "mask": [], } self._current_chunk_num = chunk_num # Preprocess this chunk (tokenize and mask) self._preprocess_chunk() def _preprocess_chunk(self): """ Process SFT records from current chunk into tokenized sequences with masking. Parses each record (single-turn or multi-turn) and generates: - Token sequences with role markers - Masking info (0=ignore, 1=train) - Labels with -100 for ignored tokens """ if not self._current_chunk_records: return max_seq_length = self.config.model.max_seq_length all_input_ids = [] all_attention_masks = [] all_masks = [] for record in self._current_chunk_records: try: # Parse record into (user, assistant) turns turns, is_multi_turn = parse_sft_record(record, self.config) if not turns: # Fallback: try to use "text" field if present if "text" in record: turns = [(record["text"], "")] else: continue # Skip invalid records # Build token sequence with role tokens and masking input_ids, attention_mask, mask = build_sft_sequence_tokens( turns=turns, tokenizer=self.tokenizer, user_token=self.user_token, assistant_token=self.assistant_token, max_seq_length=max_seq_length, ) all_input_ids.append(input_ids) all_attention_masks.append(attention_mask) all_masks.append(mask) except Exception as e: # Log and skip problematic records print(f"Warning: Failed to process SFT record: {e}") continue # Update chunk data with tokenized sequences and masks self._current_chunk_data = { "input_ids": all_input_ids, "attention_mask": all_attention_masks, "mask": all_masks, } def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: """ Get preprocessed sample with response-only loss masking. Args: idx: Sample index Returns: Dict with input_ids, attention_mask, and labels (with -100 for ignored tokens) """ # Load appropriate chunk if using streaming if self.chunk_manager: chunk_num = self._get_chunk_for_idx(idx) if chunk_num != self._current_chunk_num: self._load_chunk(chunk_num) local_idx = self._get_local_idx_in_chunk(idx) else: local_idx = idx # Get tokenized data input_ids = torch.tensor(self._current_chunk_data["input_ids"][local_idx], dtype=torch.long) attention_mask = torch.tensor(self._current_chunk_data["attention_mask"][local_idx], dtype=torch.long) mask = self._current_chunk_data["mask"][local_idx] labels = torch.tensor( build_response_only_next_token_labels(input_ids.tolist(), mask), dtype=torch.long, ) return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, }