StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""SFT JSONL dataset with async-only streaming and response-masking."""
from typing import Dict
import torch
from taoTrain.config import TrainingConfig
from taoTrain.data.jsonl_base import BaseJSONLDataset
from taoTrain.data.sft_utils import (
parse_sft_record,
build_sft_sequence_tokens,
build_response_only_next_token_labels,
)
class SFTJSONLDataset(BaseJSONLDataset):
    """
    Dataset for supervised fine-tuning with local JSONL files with chunked loading.

    Supports both single-turn and multi-turn SFT data:
    - Single-turn: {"input": "...", "output": "..."}
    - Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]}

    With response-only loss masking (the default): only trains on
    assistant/response tokens. If the config sets ``response_loss_only`` to a
    falsy value, the attention mask is used as the loss mask instead of the
    response mask.
    """

    def __init__(self, *args, **kwargs):
        """Initialize dataset.

        All arguments are forwarded unchanged to ``BaseJSONLDataset``.
        """
        super().__init__(*args, **kwargs)

        # Store full records for parsing (not just text field)
        self._current_chunk_records = None

        # The config is treated as SFT-aware only when it exposes a ``mode``
        # attribute; otherwise every SFT setting falls back to its default.
        self.sft_config = self.config if hasattr(self.config, 'mode') else None

        # getattr() on None returns the default, so a missing or falsy config
        # and a missing attribute are handled by the same expression.
        sft_cfg = self.sft_config or None
        self.user_token = getattr(sft_cfg, 'user_token', '<user>')
        self.assistant_token = getattr(sft_cfg, 'assistant_token', '<assistant>')
        self.response_loss_only = getattr(sft_cfg, 'response_loss_only', True)

    def _load_chunk(self, chunk_num: int):
        """
        Load a specific chunk from JSONL file, preserving full records for SFT parsing.

        Does nothing when there is no chunk manager or when the requested
        chunk is already the one held in memory.

        Args:
            chunk_num: Chunk number to load (0-indexed)
        """
        if not self.chunk_manager:
            return

        if chunk_num == self._current_chunk_num and self._current_chunk_data is not None:
            # Already loaded
            return

        # Read chunk - keep the full record objects (not just a text field)
        # so that _preprocess_chunk can parse them as SFT records.
        self._current_chunk_records = self.chunk_manager.read_chunk(chunk_num)

        # Start from empty containers: _preprocess_chunk returns early on an
        # empty chunk and would otherwise leave the previous chunk's data.
        self._current_chunk_data = {
            "input_ids": [],
            "attention_mask": [],
            "mask": [],
        }
        self._current_chunk_num = chunk_num

        # Preprocess this chunk (tokenize and mask)
        self._preprocess_chunk()

    def _preprocess_chunk(self):
        """
        Process SFT records from current chunk into tokenized sequences with masking.

        Parses each record (single-turn or multi-turn) and stores, per record:
        - Token sequences with role markers
        - Attention mask
        - Masking info (0=ignore, 1=train)

        Labels are not built here; ``__getitem__`` derives them per sample.

        Records that yield no turns and have no "text" field, and records that
        raise during parsing/tokenization, are skipped. A skipped record makes
        the processed chunk shorter than the raw chunk, so later samples in
        the chunk move to lower positions.
        """
        if not self._current_chunk_records:
            return

        max_seq_length = self.config.model.max_seq_length

        all_input_ids = []
        all_attention_masks = []
        all_masks = []

        for record in self._current_chunk_records:
            # Deliberately best-effort: one malformed record must not abort
            # the whole chunk, so any failure is reported and skipped.
            try:
                # Parse record into (user, assistant) turns; the multi-turn
                # flag returned alongside is not needed here.
                turns, _ = parse_sft_record(record, self.config)

                if not turns:
                    # Fallback: try to use "text" field if present
                    if "text" in record:
                        # NOTE(review): the assistant side is empty here, so
                        # under response-only masking this record presumably
                        # contributes no trainable tokens - confirm intent.
                        turns = [(record["text"], "")]
                    else:
                        continue  # Skip invalid records

                # Build token sequence with role tokens and masking
                input_ids, attention_mask, mask = build_sft_sequence_tokens(
                    turns=turns,
                    tokenizer=self.tokenizer,
                    user_token=self.user_token,
                    assistant_token=self.assistant_token,
                    max_seq_length=max_seq_length,
                )

                all_input_ids.append(input_ids)
                all_attention_masks.append(attention_mask)
                all_masks.append(mask)
            except Exception as e:
                # Log and skip problematic records
                print(f"Warning: Failed to process SFT record: {e}")
                continue

        # Update chunk data with tokenized sequences and masks
        self._current_chunk_data = {
            "input_ids": all_input_ids,
            "attention_mask": all_attention_masks,
            "mask": all_masks,
        }

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get preprocessed sample with loss-masked labels.

        Args:
            idx: Sample index

        Returns:
            Dict with input_ids, attention_mask, and labels, each a
            ``torch.long`` tensor.

        Raises:
            IndexError: If the resolved position is outside the processed
                data (for example because records were skipped during
                preprocessing).
        """
        # Load appropriate chunk if using streaming
        if self.chunk_manager:
            chunk_num = self._get_chunk_for_idx(idx)
            if chunk_num != self._current_chunk_num:
                self._load_chunk(chunk_num)
            local_idx = self._get_local_idx_in_chunk(idx)
        else:
            local_idx = idx

        # Get tokenized data
        chunk = self._current_chunk_data
        try:
            raw_input_ids = chunk["input_ids"][local_idx]
            raw_attention_mask = chunk["attention_mask"][local_idx]
            response_mask = chunk["mask"][local_idx]
        except IndexError as err:
            # Same exception type as a bare list lookup, but with enough
            # context to trace the failure back to skipped records.
            raise IndexError(
                f"Sample index {idx} (local index {local_idx}) is out of range: "
                f"only {len(chunk['input_ids'])} processed samples are loaded. "
                "Records skipped during preprocessing shrink the chunk."
            ) from err

        input_ids = torch.tensor(raw_input_ids, dtype=torch.long)
        attention_mask = torch.tensor(raw_attention_mask, dtype=torch.long)

        # Honour the response_loss_only setting read in __init__. When it is
        # falsy, the attention mask is used as the loss mask.
        # Assumes attention_mask is 1 on real tokens and 0 on padding, so this
        # trains on every non-padding token - TODO confirm against
        # build_sft_sequence_tokens.
        if getattr(self, "response_loss_only", True):
            loss_mask = response_mask
        else:
            loss_mask = attention_mask.tolist()

        labels = torch.tensor(
            build_response_only_next_token_labels(input_ids.tolist(), loss_mask),
            dtype=torch.long,
        )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }