File size: 1,336 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

from internlm.core.context import global_context as gpc

DATASET_TYPE_IDS_MAP = {"vision": 0}


def get_dataset_type_id(path):
    """Classify *path* into a dataset-type id via directory-name matching.

    A path is considered to belong to dataset type ``key`` when it contains
    a path component ``/{key}/``, optionally prefixed by any run of ``z``
    and/or ``_`` characters (e.g. ``/vision/`` or ``/z_vision/``).

    Args:
        path (str): dataset path to classify.

    Returns:
        int: the id mapped to the single matching key in
        ``DATASET_TYPE_IDS_MAP``.

    Raises:
        AssertionError: if zero or multiple dataset types match *path*.
    """
    import re

    match_idxes = [
        type_id
        for name, type_id in DATASET_TYPE_IDS_MAP.items()
        if re.search(rf"/[z_]*{name}/", path)
    ]
    # Exactly one dataset type must match, otherwise the path is ambiguous.
    assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
    return match_idxes[0]


def unpack_data(input_ids, cu_seqlens, micro_bsz=None, seq_len=None):
    """Unpack packed token sequences into fixed-length, zero-padded rows.

    Args:
        input_ids (torch.Tensor): packed ids of shape ``(bsz, packed_length)``.
        cu_seqlens: per-batch cumulative sequence lengths; ``cu_seqlens[i]``
            is an integer tensor/sequence of length ``micro_bsz + 1`` whose
            consecutive differences are the individual sequence lengths.
        micro_bsz (int, optional): number of sequences packed into each row.
            Defaults to ``gpc.config.data["micro_bsz"]``.
        seq_len (int, optional): padded length of each unpacked sequence.
            Defaults to ``gpc.config.data.seq_len``.

    Returns:
        torch.Tensor: tensor of shape ``(bsz, micro_bsz, seq_len)``, or
        ``(micro_bsz, seq_len)`` when ``bsz == 1`` (batch dim squeezed,
        preserving the original behavior callers rely on).
    """
    # Resolve defaults from the global context only when not passed
    # explicitly, keeping the original zero-argument-config behavior.
    if micro_bsz is None:
        micro_bsz = gpc.config.data["micro_bsz"]
    if seq_len is None:
        seq_len = gpc.config.data.seq_len

    bsz = input_ids.shape[0]
    outputs = torch.zeros(bsz, micro_bsz, seq_len, device=input_ids.device, dtype=input_ids.dtype)

    for i in range(bsz):
        cu_seqlens_slice = cu_seqlens[i]
        for j in range(micro_bsz):
            seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
            # Bug fix: read from row i (was hard-coded row 0), so every
            # batch element is unpacked from its own packed sequence.
            outputs[i, j, 0:seq_length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]

    if bsz == 1:
        outputs = outputs.squeeze(0)

    return outputs