File size: 3,357 Bytes
624b7ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processing functions for TeleYAML dataset - v2 with nested format support."""
from typing import Any, Optional
from megatron.bridge.data.builders.hf_dataset import ProcessExampleOutput
from megatron.bridge.training.tokenizers.tokenizer import MegatronTokenizer


def _flatten_messages(messages: list[dict[str, str]]) -> str:
    """Convert a list of chat messages into a formatted string.
    
    Args:
        messages: List of message dicts with 'role' and 'content' keys
        
    Returns:
        Formatted string with role tags
    """
    parts = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        parts.append(f"<{role}>\n{content}\n</{role}>")
    return "\n".join(parts)


def _flatten_output(output_dict: dict[str, Any]) -> str:
    """Convert nested output dict into a formatted string.
    
    Args:
        output_dict: Dict with 'reasoning_context' and/or 'content' keys
        
    Returns:
        Formatted string combining reasoning and content
    """
    reasoning = output_dict.get("reasoning_context", "")
    content = output_dict.get("content", "")
    
    if reasoning and content:
        return f"<reasoning>\n{reasoning}\n</reasoning>\n\n{content}"
    elif reasoning:
        return reasoning
    else:
        return content


def process_teleyaml_example(
    example: dict[str, Any], tokenizer: Optional[MegatronTokenizer] = None
) -> ProcessExampleOutput:
    """Normalize a TeleYAML example into the standard input/output format.

    Supports both dataset layouts:

    Flat (v1):
        {"input": "string", "output": "string"}

    Nested (v2):
        {"input": {"messages": [...]}, "output": {"reasoning_context": "...", "content": "..."}}

    Args:
        example: Raw TeleYAML example.
        tokenizer: Optional tokenizer (accepted for interface compatibility;
            not used).

    Returns:
        ProcessExampleOutput holding the flattened input, flattened output,
        and the output repeated as the single original answer.
    """
    def _normalize_input(value: Any) -> str:
        # Nested v2 chat format takes precedence; anything else is coerced to str.
        if isinstance(value, dict) and "messages" in value:
            return _flatten_messages(value["messages"])
        return value if isinstance(value, str) else str(value)

    def _normalize_output(value: Any) -> str:
        # Any dict is treated as the nested v2 output shape.
        if isinstance(value, dict):
            return _flatten_output(value)
        return value if isinstance(value, str) else str(value)

    _input = _normalize_input(example.get("input", ""))
    _output = _normalize_output(example.get("output", ""))

    return ProcessExampleOutput(
        input=_input, output=_output, original_answers=[_output]
    )