File size: 5,283 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import argparse
from collections.abc import Mapping
import json
import torch
from transformers import PreTrainedTokenizerBase
def to_list(input_ids):
    """Normalize token ids into a flat Python list.

    Args:
        input_ids: A ``torch.Tensor``, a list, a list of lists, or any other
            object (for example ``None``).

    Returns:
        For a tensor, its ``tolist()`` conversion. If the (possibly
        converted) value is a non-empty list whose first element is itself a
        list, only that first inner list is returned. Any other value,
        including an empty list or ``None``, is returned unchanged.
    """
    if isinstance(input_ids, torch.Tensor):
        # ``Tensor.tolist`` works on any device and needs no NumPy round-trip.
        input_ids = input_ids.detach().cpu().tolist()
    # The emptiness check keeps ``[]`` from raising IndexError on ``[0]``.
    if isinstance(input_ids, list) and input_ids and isinstance(input_ids[0], list):
        input_ids = input_ids[0]
    return input_ids
def load_ds(ds):
    """Load ``ds`` with ``swift.llm.load_dataset`` and keep only row 0.

    Args:
        ds: Dataset identifier forwarded unchanged to ``load_dataset``.

    Returns:
        The first element returned by ``load_dataset`` restricted to its
        first row via ``select(range(1))``.
    """
    # swift is imported at call time, not at module import time.
    from swift.llm import load_dataset

    load_options = dict(
        split_dataset_ratio=0.0,
        strict=False,
        num_proc=1,
        model_name=['小黄', 'Xiao Huang'],
        model_author=['魔搭', 'ModelScope'],
    )
    # The second returned dataset is not needed here.
    train_split, _unused_split = load_dataset(ds, **load_options)
    return train_split.select(range(1))
def load_and_tokenize(ms_model_id, template):
    """Encode one sample row with ``template`` and decode it back to text.

    The sample dataset depends on the template: a template whose class name
    contains ``audio`` uses the aishell1 ASR set, a template name containing
    one of the vision markers uses the COCO caption set, and everything else
    uses DuReader.

    Args:
        ms_model_id: Model id passed to ``get_model_tokenizer``.
        template: Template name passed to ``get_template``; it is also
            matched against the vision markers.

    Returns:
        A ``(input_ids, sent)`` tuple: ``input_ids`` is whatever ``to_list``
        returns for the encoded ids, and ``sent`` is the decoded text or
        ``''`` when decoding raises.

    Raises:
        AssertionError: If the encoding carries neither ``input_ids`` nor
            ``inputs_embeds``.
        Exception: Any other failure; its traceback is printed before it is
            re-raised.
    """
    from swift.llm import EncodePreprocessor, get_model_tokenizer, get_template

    try:
        vision_markers = ['vl', 'video', 'minicpmv', 'llava', 'vision', 'emu', 'florence']

        # Only ids containing 'mplug' load the model on this first call.
        load_model_now = 'mplug' in ms_model_id.lower()
        model_ins, tokenizer = get_model_tokenizer(ms_model_id, load_model=load_model_now)
        template_ins = get_template(template, tokenizer)
        if template_ins.use_model:
            model_ins, _ = get_model_tokenizer(ms_model_id, load_model=True)
            template_ins.model = model_ins
        template_ins.set_mode('train')

        template_cls_name = template_ins.__class__.__name__.lower()
        if 'audio' in template_cls_name:
            encode_rows = EncodePreprocessor(template_ins)
            output = encode_rows(load_ds('speech_asr/speech_asr_aishell1_trainsets:validation/test'))
            input_ids = output[0].get('input_ids')
        elif any(marker in template for marker in vision_markers):
            for row in load_ds('modelscope/coco_2014_caption:validation'):
                output = template_ins.encode(row)
                input_ids = output.get('input_ids')
            # Disabled alternative kept from the original:
            # output = EncodePreprocessor(template_ins)(load_ds('swift/OK-VQA_train'))
            if model_ins is not None and model_ins.model_meta.is_multimodal:
                inputs = template_ins.pre_data_collator([output], model=model_ins)
                _, output = template_ins.pre_forward_hook(model_ins, None, inputs)
        else:
            encode_rows = EncodePreprocessor(template_ins)
            output = encode_rows(load_ds('modelscope/DuReader_robust-QG'))
            input_ids = output[0].get('input_ids')

        # ``output`` is either a single mapping or an indexable collection;
        # in the latter case only its first entry is checked.
        first_record = output if isinstance(output, Mapping) else output[0]
        assert first_record.get('input_ids') is not None or first_record.get('inputs_embeds') is not None

        input_ids = to_list(input_ids)
        sent = ''
        try:
            is_plain_tokenizer = isinstance(tokenizer, PreTrainedTokenizerBase)
            if not is_plain_tokenizer and hasattr(tokenizer, 'tokenizer'):
                # Unwrap an object that merely carries a tokenizer.
                tokenizer = tokenizer.tokenizer
            sent = tokenizer.decode(input_ids)
        except Exception:
            # Decoding is best-effort; ``sent`` stays empty on failure.
            pass
        return input_ids, sent
    except Exception:
        import traceback
        print(traceback.format_exc())
        raise
def load_ds_old(ds):
    """Load ``ds`` with ``swift.llm.load_dataset`` and keep only row 0.

    Args:
        ds: Dataset identifier forwarded unchanged to ``load_dataset``.

    Returns:
        The first element returned by ``load_dataset`` restricted to its
        first row via ``select(range(1))``.
    """
    from swift.llm import load_dataset

    # The second returned dataset is not needed here.
    train_split, _unused_split = load_dataset(ds, split_dataset_ratio=0.0)
    first_row_only = range(1)
    return train_split.select(first_row_only)
def load_and_tokenize_old(ms_model_id, template):
    """Encode one sample row using the ``MODEL_MAPPING``-based lookup.

    Args:
        ms_model_id: Model id compared case-insensitively with each
            ``MODEL_MAPPING`` entry's ``model_id_or_path``.
        template: Fallback template name, used (with ``_`` replaced by
            ``-``) only when the matched entry's template is
            ``'default-generation'``.

    Returns:
        A ``(input_ids, sent)`` tuple: the encoded ids passed through
        ``to_list``, and the decoded text or ``''`` when decoding raises.

    Raises:
        ValueError: If no ``MODEL_MAPPING`` entry matches ``ms_model_id``.
    """
    from swift.llm import MODEL_MAPPING, get_model_tokenizer, get_template

    # Leave the loop with ``model_type``/``model_info`` bound to the match.
    for model_type, model_info in MODEL_MAPPING.items():
        if model_info['model_id_or_path'].lower() == ms_model_id.lower():
            break
    else:
        raise ValueError(f'No model_type found: {ms_model_id}')

    vision_markers = ['vl', 'video', 'minicpm-v', 'llava', 'vision', 'emu', 'florence']
    model_ins, tokenizer = get_model_tokenizer(model_type, load_model=True)

    # NOTE: this writes into the MODEL_MAPPING entry itself, as the original did.
    if model_info['template'] == 'default-generation':
        model_info['template'] = template.replace('_', '-')
    template_name = model_info['template']

    template_ins = get_template(template_name, tokenizer)
    template_ins.model = model_ins

    # Choose the sample dataset from the template name.
    if 'audio' in template_name:
        sample_ds = 'aishell1-zh-mini'
    elif any(marker in template_name for marker in vision_markers):
        sample_ds = 'coco-en-mini'
    else:
        sample_ds = 'dureader-robust-zh'
    output = template_ins.encode(load_ds_old(sample_ds)[0])

    # The ids are read from the first element of what ``encode`` returned.
    input_ids = to_list(output[0]['input_ids'])
    sent = ''
    try:
        sent = tokenizer.decode(input_ids)
    except Exception:
        # Decoding is best-effort; ``sent`` stays empty on failure.
        pass
    return input_ids, sent
if __name__ == '__main__':
    # Command line: --ms_model_id and --template are required; --new picks
    # the code path and defaults to '1' (the new one).
    cli = argparse.ArgumentParser()
    cli.add_argument('--ms_model_id', type=str, required=True)
    cli.add_argument('--template', type=str, required=True)
    cli.add_argument('--new', type=str, required=False, default='1')
    options = cli.parse_args()

    use_new_path = options.new == '1'
    tokenize = load_and_tokenize if use_new_path else load_and_tokenize_old
    input_ids, sent = tokenize(options.ms_model_id, options.template)

    # Each code path writes to its own file.
    out_path = 'new_input_ids.txt' if use_new_path else 'old_input_ids.txt'
    if input_ids is not None:
        with open(out_path, 'w') as f:
            json.dump({'input_ids': input_ids, 'sent': sent}, f)
|