File size: 5,283 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
from collections.abc import Mapping

import json
import torch
from transformers import PreTrainedTokenizerBase


def to_list(input_ids):
    """Normalize token ids to a flat Python list.

    Accepts a ``torch.Tensor`` (any device), a flat list, or a batched
    ``[[...]]`` list and returns a flat list of ids. A batch is unwrapped
    by taking its first row only, matching how the callers use a
    single-sample batch.

    Args:
        input_ids: Tensor or (possibly nested) list of token ids.

    Returns:
        A flat ``list`` of token ids (empty list stays empty).
    """
    if isinstance(input_ids, torch.Tensor):
        # Move to CPU first so .numpy() works for CUDA tensors too.
        input_ids = input_ids.cpu().numpy().tolist()
    # Guard against an empty list before peeking at element 0 —
    # the unguarded `input_ids[0]` raised IndexError on [].
    if isinstance(input_ids, list) and input_ids and isinstance(input_ids[0], list):
        input_ids = input_ids[0]
    return input_ids


def load_ds(ds):
    """Load dataset *ds* through swift and return a single-row train split.

    The validation split is discarded (ratio 0.0) and only the first row
    of the training split is kept, which is enough for a tokenization probe.
    """
    from swift.llm import load_dataset
    load_kwargs = dict(
        split_dataset_ratio=0.0,
        strict=False,
        num_proc=1,
        model_name=['小黄', 'Xiao Huang'],
        model_author=['魔搭', 'ModelScope'],
    )
    train_dataset, _ = load_dataset(ds, **load_kwargs)
    return train_dataset.select(range(1))


def load_and_tokenize(ms_model_id, template):
    """Encode one probe sample with the current swift pipeline.

    Picks a probe dataset based on the template kind (audio / vision /
    plain text), encodes a single row, and returns the token ids plus a
    best-effort decoded string.

    Args:
        ms_model_id: ModelScope model id passed to ``get_model_tokenizer``.
        template: Template name for ``get_template``; also scanned for
            vision keywords to choose the probe dataset.

    Returns:
        Tuple ``(input_ids, sent)`` — flat list of token ids and their
        decoded text (``''`` when decoding fails).

    Raises:
        Re-raises any exception from loading/encoding after printing the
        full traceback.
    """
    from swift.llm import EncodePreprocessor, get_model_tokenizer, get_template
    try:
        # Substrings that mark a vision/video template.
        vl_fields = ['vl', 'video', 'minicpmv', 'llava', 'vision', 'emu', 'florence']
        # Only mPLUG models get full weights here; others load tokenizer-only.
        model_ins, tokenizer = get_model_tokenizer(ms_model_id, load_model='mplug' in ms_model_id.lower())
        template_ins = get_template(template, tokenizer)
        if template_ins.use_model:
            # Template needs the model itself — reload with weights attached.
            model_ins, _ = get_model_tokenizer(ms_model_id, load_model=True)
            template_ins.model = model_ins
        template_ins.set_mode('train')
        if 'audio' in template_ins.__class__.__name__.lower():
            # Audio template: probe with an ASR dataset.
            output = EncodePreprocessor(template_ins)(
                load_ds('speech_asr/speech_asr_aishell1_trainsets:validation/test'))
            input_ids = output[0].get('input_ids')
        elif any([vl in template for vl in vl_fields]):
            # Vision template: encode one caption row.
            # NOTE(review): `input_ids` stays unbound if the dataset yields
            # no rows — presumably load_ds always returns one row; confirm.
            for row in load_ds('modelscope/coco_2014_caption:validation'):
                output = template_ins.encode(row)
                input_ids = output.get('input_ids')
                # output = EncodePreprocessor(template_ins)(load_ds('swift/OK-VQA_train'))
                if model_ins is not None and model_ins.model_meta.is_multimodal:
                    # Run collator + pre-forward hook so multimodal inputs
                    # (e.g. image features) are materialized into `output`.
                    inputs = template_ins.pre_data_collator([output], model=model_ins)
                    _, output = template_ins.pre_forward_hook(model_ins, None, inputs)
        else:
            # Plain-text template: probe with a question-generation dataset.
            output = EncodePreprocessor(template_ins)(load_ds('modelscope/DuReader_robust-QG'))
            input_ids = output[0].get('input_ids')
        # Sanity check: encoding must yield token ids or precomputed embeds.
        if isinstance(output, Mapping):
            assert output.get('input_ids') is not None or output.get('inputs_embeds') is not None
        else:
            assert output[0].get('input_ids') is not None or output[0].get('inputs_embeds') is not None
        input_ids = to_list(input_ids)
        sent = ''
        try:
            # Multimodal processors may wrap the real tokenizer; unwrap it.
            if not isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, 'tokenizer'):
                tokenizer = tokenizer.tokenizer
            sent = tokenizer.decode(input_ids)
        except Exception:
            # Best-effort decode only; the ids alone suffice for comparison.
            pass
        return input_ids, sent
    except Exception:
        # Print the full traceback before re-raising so CI logs show context.
        import traceback
        print(traceback.format_exc())
        raise


def load_ds_old(ds):
    """Load dataset *ds* with the legacy swift API; return one train row."""
    from swift.llm import load_dataset
    train_dataset, _ = load_dataset(ds, split_dataset_ratio=0.0)
    single_row = train_dataset.select(range(1))
    return single_row


def load_and_tokenize_old(ms_model_id, template):
    """Encode one probe sample with the legacy swift pipeline.

    Resolves *ms_model_id* to a legacy ``model_type`` via ``MODEL_MAPPING``,
    encodes one row of a kind-appropriate dataset, and returns the token
    ids plus a best-effort decoded string.

    Args:
        ms_model_id: ModelScope model id to resolve against MODEL_MAPPING.
        template: Fallback template name used when the mapping entry has
            the generic 'default-generation' template.

    Returns:
        Tuple ``(input_ids, sent)`` — flat token-id list and decoded text
        (``''`` when decoding fails).

    Raises:
        ValueError: if *ms_model_id* is not present in MODEL_MAPPING.
    """
    from swift.llm import MODEL_MAPPING, get_model_tokenizer, get_template

    # Resolve the legacy model_type by case-insensitive id match.
    for model_type, model_info in MODEL_MAPPING.items():
        if model_info['model_id_or_path'].lower() == ms_model_id.lower():
            break
    else:
        raise ValueError(f'No model_type found: {ms_model_id}')

    vl_fields = ['vl', 'video', 'minicpm-v', 'llava', 'vision', 'emu', 'florence']
    model_ins, tokenizer = get_model_tokenizer(model_type, load_model=True)

    # Generic mapping entries inherit the caller-supplied template name
    # (underscores become dashes for the legacy naming scheme). This
    # intentionally mutates the shared MODEL_MAPPING entry in place.
    if model_info['template'] == 'default-generation':
        model_info['template'] = template.replace('_', '-')
    template_ins = get_template(model_info['template'], tokenizer)
    template_ins.model = model_ins

    # Pick the probe dataset by template kind, then encode a single row.
    if 'audio' in model_info['template']:
        probe_ds = 'aishell1-zh-mini'
    elif any(field in model_info['template'] for field in vl_fields):
        probe_ds = 'coco-en-mini'
    else:
        probe_ds = 'dureader-robust-zh'
    output = template_ins.encode(load_ds_old(probe_ds)[0])

    input_ids = to_list(output[0]['input_ids'])
    sent = ''
    try:
        sent = tokenizer.decode(input_ids)
    except Exception:
        # Best-effort decode only; ids alone suffice for comparison.
        pass
    return input_ids, sent


if __name__ == '__main__':
    # CLI: compare new vs old tokenization for one model/template pair.
    parser = argparse.ArgumentParser()
    parser.add_argument('--ms_model_id', type=str, required=True)
    parser.add_argument('--template', type=str, required=True)
    # '--new 1' (default) runs the current pipeline; anything else the legacy one.
    parser.add_argument('--new', type=str, required=False, default='1')
    args = parser.parse_args()

    use_new = args.new == '1'
    tokenize_fn = load_and_tokenize if use_new else load_and_tokenize_old
    input_ids, sent = tokenize_fn(args.ms_model_id, args.template)

    out_path = 'new_input_ids.txt' if use_new else 'old_input_ids.txt'
    if input_ids is not None:
        with open(out_path, 'w') as f:
            json.dump({'input_ids': input_ids, 'sent': sent}, f)