File size: 6,212 Bytes
3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""SFT utility functions for parsing and masking."""

from typing import Dict, Any, List, Tuple
from taoTrain.config import TrainingConfig


def parse_sft_record(record: Dict[str, Any], config: TrainingConfig) -> Tuple[List[Tuple[str, str]], bool]:
    """
    Parse a JSONL record into a list of (user, assistant) turns.

    Formats are tried in this order; the first one that matches wins:

    1. Multi-turn: {"turns": [{"user": "...", "assistant": "..."}, ...]}
       Entries that are not dicts with both "user" and "assistant" keys are
       skipped. If "turns" is not a list/tuple, or no usable entry remains,
       parsing falls through to the formats below.
    2. Single-turn: {"input": "...", "output": "..."}
    3. Single-turn using the column names from ``config.dataset``
       (``instruction_column`` / ``response_column``, falling back to
       "instruction" / "response" when those are unset or empty).
    4. Legacy pre-formatted text: {"text": "..."}, returned as one turn whose
       assistant part is the empty string.

    Args:
        record: JSONL record (dict).
        config: Training configuration. Only ``config.dataset`` is read, and
            only when formats 1 and 2 did not match.

    Returns:
        (turns_list, is_multi_turn) where:
        - turns_list: List of (user_text, assistant_text) tuples; empty when
          no format matched.
        - is_multi_turn: True only when format 1 matched.
    """
    # 1. Multi-turn. Only a list/tuple is iterated: a malformed value such as
    #    {"turns": null} (None after json.loads) or a number would otherwise
    #    raise TypeError here. Instead it falls through to the formats below,
    #    exactly like a "turns" list with no usable entries.
    raw_turns = record.get("turns")
    if isinstance(raw_turns, (list, tuple)):
        turns = [
            (turn["user"], turn["assistant"])
            for turn in raw_turns
            if isinstance(turn, dict) and "user" in turn and "assistant" in turn
        ]
        if turns:
            return turns, True

    # 2. Single-turn with fixed input/output field names.
    if "input" in record and "output" in record:
        return [(record["input"], record["output"])], False

    # 3. Single-turn with column names taken from the dataset config.
    dataset_config = config.dataset
    instruction_col = dataset_config.instruction_column or "instruction"
    response_col = dataset_config.response_column or "response"
    if instruction_col in record and response_col in record:
        return [(record[instruction_col], record[response_col])], False

    # 4. Legacy pre-formatted text. There is no separate response, so the
    #    assistant part is empty.
    # NOTE(review): build_sft_sequence_tokens marks only assistant tokens as
    # trainable, so these records contribute no loss there — confirm this
    # legacy path is still intended.
    if "text" in record:
        return [(record["text"], "")], False

    return [], False


def build_sft_sequence_tokens(
    turns: List[Tuple[str, str]],
    tokenizer,
    user_token: str = "<user>",
    assistant_token: str = "<assistant>",
    max_seq_length: int = 1024,
) -> Tuple[List[int], List[int], List[int]]:
    """
    Tokenize a conversation for SFT and produce padding and loss masks.

    Layout of the sequence before truncation/padding:

        <user_token> user_text <assistant_token> assistant_text ... [eos]

    Only assistant_text tokens are marked trainable (1). Role-marker tokens,
    user tokens, the EOS token (appended when ``tokenizer.eos_token_id`` is
    not None) and padding are marked 0.

    Args:
        turns: (user_text, assistant_text) pairs, emitted in order.
        tokenizer: Called as ``tokenizer(text, add_special_tokens=False)`` and
            expected to return a mapping with an ``"input_ids"`` sequence.
            ``eos_token_id`` is read if present; ``pad_token_id`` is read only
            when padding is needed.
        user_token: Text encoded once and placed before every user message.
        assistant_token: Text encoded once and placed before every assistant
            message.
        max_seq_length: Target length. Longer sequences are cut from the right
            (which can drop the EOS token); shorter ones are right-padded.

    Returns:
        (input_ids, attention_mask, mask), each of length ``max_seq_length``
        for non-negative ``max_seq_length``:
        - input_ids: token IDs, padded with ``pad_token_id`` (0 when that is
          None or otherwise falsy)
        - attention_mask: 1 for real tokens, 0 for padding
        - mask: loss mask, 1 for assistant-response tokens, 0 otherwise
    """

    def _encode(text: str) -> List[int]:
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    # Role markers are encoded once and reused for every turn.
    user_marker = _encode(user_token)
    assistant_marker = _encode(assistant_token)

    # Each segment is (token ids, loss flag shared by all of those ids).
    segments: List[Tuple[List[int], int]] = []
    for user_text, assistant_text in turns:
        segments.append((user_marker, 0))
        segments.append((_encode(user_text), 0))
        segments.append((assistant_marker, 0))
        segments.append((_encode(assistant_text), 1))  # the only trainable span

    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is not None:
        segments.append(([eos_id], 0))

    token_ids = [tok for ids, _ in segments for tok in ids]
    loss_mask = [flag for ids, flag in segments for _ in ids]

    # Right-truncate, then right-pad up to max_seq_length.
    token_ids = token_ids[:max_seq_length]
    loss_mask = loss_mask[:max_seq_length]

    real_len = len(token_ids)
    pad_count = max_seq_length - real_len
    if pad_count > 0:
        pad_id = tokenizer.pad_token_id or 0
        token_ids = token_ids + [pad_id] * pad_count
        loss_mask = loss_mask + [0] * pad_count

    attention_mask = [1] * real_len + [0] * max(pad_count, 0)
    return token_ids, attention_mask, loss_mask


def apply_response_masking(input_ids: List[int], mask: List[int]) -> List[int]:
    """
    Convert a loss mask into labels for response-only training.

    Args:
        input_ids: Token IDs.
        mask: Loss mask aligned position-by-position with ``input_ids``
            (0 = ignore, any other value = train).

    Returns:
        A copy of ``input_ids`` in which every position whose mask value is 0
        is replaced by -100 (the value CrossEntropyLoss ignores by default);
        all other positions keep their token ID.

    Raises:
        ValueError: If ``input_ids`` and ``mask`` differ in length. Without
            this check, a short mask would leave the tail of ``input_ids``
            unmasked (and silently trained on), and a long mask would fail
            with an opaque IndexError.
    """
    if len(input_ids) != len(mask):
        raise ValueError(f"input_ids and mask must have the same length: {len(input_ids)} != {len(mask)}")

    labels = input_ids.copy()
    for i, m in enumerate(mask):
        if m == 0:
            labels[i] = -100  # CrossEntropyLoss will ignore this token
    return labels


def build_response_only_next_token_labels(input_ids: List[int], mask: List[int]) -> List[int]:
    """
    Build next-token labels for SFT response-only training.

    Position i predicts token i+1, so the loss mask must be applied to the target
    token, not the current input token. This trains the first assistant token from
    the assistant role marker and avoids training on masked EOS/padding targets.

    Concretely, ``labels[i]`` is ``input_ids[i + 1]`` when ``mask[i + 1]`` is
    non-zero and -100 otherwise. The last position has no next token, so it is
    always -100.

    Args:
        input_ids: Token IDs.
        mask: Loss mask aligned with ``input_ids`` (0 = ignore, non-zero = train).

    Returns:
        Labels with the same length as ``input_ids``; an empty input yields an
        empty list.

    Raises:
        ValueError: If ``input_ids`` and ``mask`` differ in length.
    """
    if len(input_ids) != len(mask):
        raise ValueError(f"input_ids and mask must have the same length: {len(input_ids)} != {len(mask)}")

    if len(input_ids) == 0:
        # Nothing to predict. Appending the trailing -100 here would return a
        # label list one element longer than the (empty) input.
        return []

    labels = apply_response_masking(input_ids, mask)
    return labels[1:] + [-100]