File size: 6,004 Bytes
3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""SFT JSONL dataset with async-only streaming and response-masking."""

import logging
from typing import Dict

import torch

from taoTrain.config import TrainingConfig
from taoTrain.data.jsonl_base import BaseJSONLDataset
from taoTrain.data.sft_utils import (
    parse_sft_record,
    build_sft_sequence_tokens,
    build_response_only_next_token_labels,
)


class SFTJSONLDataset(BaseJSONLDataset):
    """Dataset for supervised fine-tuning from local JSONL files, loaded in chunks.

    Supports both single-turn and multi-turn SFT records:

    - Single-turn: ``{"input": "...", "output": "..."}``
    - Multi-turn: ``{"turns": [{"user": "...", "assistant": "..."}, ...]}``

    When ``response_loss_only`` is enabled (the default), labels are built so
    that only assistant/response tokens contribute to the loss. When it is
    disabled, every non-padding token (as given by the attention mask) is
    used for the loss.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the dataset and read SFT-specific settings.

        All arguments are forwarded unchanged to ``BaseJSONLDataset``.
        """
        super().__init__(*args, **kwargs)
        # Full JSON records of the current chunk (not just a "text" field),
        # because SFT parsing needs several keys per record.
        self._current_chunk_records = None

        # NOTE(review): the config is only treated as SFT-aware when it has a
        # ``mode`` attribute; otherwise the defaults below are used.
        self.sft_config = self.config if hasattr(self.config, 'mode') else None
        # getattr(None, name, default) returns the default, so these also
        # cover the "no SFT config" case.
        self.user_token = getattr(self.sft_config, 'user_token', '<user>')
        self.assistant_token = getattr(self.sft_config, 'assistant_token', '<assistant>')
        self.response_loss_only = getattr(self.sft_config, 'response_loss_only', True)

    def _load_chunk(self, chunk_num: int):
        """Load one chunk from the JSONL file and tokenize it for SFT.

        The full record objects are kept so that ``_preprocess_chunk`` can
        parse single-turn and multi-turn formats.

        Args:
            chunk_num: Chunk number to load (0-indexed).
        """
        if not self.chunk_manager:
            return

        if chunk_num == self._current_chunk_num and self._current_chunk_data is not None:
            # This chunk is already loaded and preprocessed.
            return

        self._current_chunk_records = self.chunk_manager.read_chunk(chunk_num)

        # Start from empty structures so a chunk with no records still yields
        # a well-formed (empty) data dict.
        self._current_chunk_data = {
            "input_ids": [],
            "attention_mask": [],
            "mask": [],
        }
        self._current_chunk_num = chunk_num

        self._preprocess_chunk()

    def _preprocess_chunk(self):
        """Tokenize the current chunk's SFT records and build loss masks.

        Each record (single-turn or multi-turn) is parsed into (user,
        assistant) turns and turned into:

        - token ids with role markers,
        - an attention mask,
        - a loss mask (0 = ignore, 1 = train).

        Records that cannot be parsed or tokenized are skipped and reported
        through ``logging``. Records with no turns fall back to their
        ``"text"`` field when present.
        """
        if not self._current_chunk_records:
            return

        logger = logging.getLogger(__name__)
        max_seq_length = self.config.model.max_seq_length

        all_input_ids = []
        all_attention_masks = []
        all_masks = []
        skipped = 0

        for record in self._current_chunk_records:
            try:
                turns, _ = parse_sft_record(record, self.config)

                if not turns:
                    # Fallback: treat a bare "text" field as a user-only turn.
                    if "text" not in record:
                        skipped += 1
                        continue
                    turns = [(record["text"], "")]

                input_ids, attention_mask, mask = build_sft_sequence_tokens(
                    turns=turns,
                    tokenizer=self.tokenizer,
                    user_token=self.user_token,
                    assistant_token=self.assistant_token,
                    max_seq_length=max_seq_length,
                )
            except Exception as e:  # best-effort: one bad record must not kill the chunk
                logger.warning("Failed to process SFT record: %s", e)
                skipped += 1
                continue

            all_input_ids.append(input_ids)
            all_attention_masks.append(attention_mask)
            all_masks.append(mask)

        if skipped:
            # NOTE(review): skipped records shrink this chunk; confirm that the
            # base class's idx -> local_idx mapping tolerates fewer samples
            # than records, otherwise __getitem__ can index out of range.
            logger.warning(
                "Skipped %d of %d SFT records in chunk %s",
                skipped,
                len(self._current_chunk_records),
                self._current_chunk_num,
            )

        self._current_chunk_data = {
            "input_ids": all_input_ids,
            "attention_mask": all_attention_masks,
            "mask": all_masks,
        }

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one tokenized sample with next-token labels.

        Args:
            idx: Global sample index.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``. Labels
            use -100 for tokens excluded from the loss. Which tokens are
            excluded depends on ``response_loss_only``.
        """
        # In streaming mode, make sure the chunk holding idx is resident.
        if self.chunk_manager:
            chunk_num = self._get_chunk_for_idx(idx)
            if chunk_num != self._current_chunk_num:
                self._load_chunk(chunk_num)
            local_idx = self._get_local_idx_in_chunk(idx)
        else:
            local_idx = idx

        data = self._current_chunk_data
        input_ids = torch.tensor(data["input_ids"][local_idx], dtype=torch.long)
        attention_mask = torch.tensor(data["attention_mask"][local_idx], dtype=torch.long)

        if self.response_loss_only:
            # Train only on assistant/response tokens, as marked during preprocessing.
            loss_mask = data["mask"][local_idx]
        else:
            # Train on every real (non-padding) token.
            # NOTE(review): assumes the attention mask is aligned 1:1 with
            # the loss mask produced by build_sft_sequence_tokens.
            loss_mask = attention_mask.tolist()

        labels = torch.tensor(
            build_response_only_next_token_labels(input_ids.tolist(), loss_mask),
            dtype=torch.long,
        )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }