File size: 1,355 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import argparse
from dataclasses import fields
import torch
from swift.llm import MODEL_ARCH_MAPPING, ModelKeys, get_model_tokenizer
def _check_arch_submodules(model_ins, model_keys):
    """Check that every submodule path declared on ``model_keys`` exists in ``model_ins``.

    Every dataclass field of ``model_keys`` except ``arch_name`` is inspected.
    ``None`` fields are skipped, and a field holding a single string is treated
    as a one-element list of paths. A ``'{}'`` placeholder inside a path is
    replaced with ``'0'`` before lookup (presumably a layer index, so only the
    first layer is probed — confirm against ``ModelKeys``).

    Args:
        model_ins: The loaded model; only its ``get_submodule`` method is used.
        model_keys: A dataclass instance (a ``ModelKeys``) whose non-``arch_name``
            fields hold submodule paths.

    Raises:
        AttributeError: A declared path does not resolve to a submodule. The
            message names the ``ModelKeys`` field and the concrete path that
            failed, and chains the original error.
    """
    for f in fields(model_keys):
        if f.name == 'arch_name':
            continue
        value = getattr(model_keys, f.name)
        if value is None:
            continue
        paths = [value] if isinstance(value, str) else value
        for path in paths:
            concrete_path = path.replace('{}', '0')
            try:
                model_ins.get_submodule(concrete_path)
            except AttributeError as e:
                # get_submodule only reports the missing attribute; add which
                # ModelKeys field and full path were being checked.
                raise AttributeError(
                    f'ModelKeys field {f.name!r}: submodule {concrete_path!r} '
                    f'not found on {type(model_ins).__name__}: {e}') from e


def get_model_and_tokenizer(ms_model_id, model_arch=None):
    """Load ``ms_model_id`` and optionally verify it against a registered model arch.

    Args:
        ms_model_id: Model id passed to ``get_model_tokenizer``.
        model_arch: Optional key into ``MODEL_ARCH_MAPPING``. When given (and
            truthy), every submodule path declared by that arch's ``ModelKeys``
            must resolve on the loaded model.

    Returns:
        The ``(model, tokenizer)`` pair returned by ``get_model_tokenizer``.

    Raises:
        KeyError: ``model_arch`` is not a key of ``MODEL_ARCH_MAPPING``.
        AttributeError: A declared submodule path is missing from the model.
        Any exception raised inside the body is first printed with its full
        traceback and then re-raised unchanged.
    """
    try:
        import transformers
        print(f'Test model: {ms_model_id} with transformers version: {transformers.__version__}')
        model_ins, tokenizer = get_model_tokenizer(ms_model_id)
        model_ins: torch.nn.Module
        if model_arch:
            # Resolve into a new local rather than rebinding the str parameter
            # to a ModelKeys object.
            model_keys: ModelKeys = MODEL_ARCH_MAPPING[model_arch]
            _check_arch_submodules(model_ins, model_keys)
    except Exception:
        # Print the traceback to stdout before propagating; presumably so a
        # harness capturing stdout records it (TODO confirm).
        import traceback
        print(traceback.format_exc())
        raise
    return model_ins, tokenizer
if __name__ == '__main__':
    # CLI entry point: load a model and, when --model_arch is given, check that
    # the arch's declared submodule paths resolve on it.
    parser = argparse.ArgumentParser(
        description='Load a model and optionally check it against a registered model arch.')
    parser.add_argument(
        '--ms_model_id',
        type=str,
        required=True,
        help='Model id to load via get_model_tokenizer.',
    )
    # Optional because get_model_and_tokenizer accepts model_arch=None (load-only check).
    parser.add_argument(
        '--model_arch',
        type=str,
        default=None,
        help='Key into MODEL_ARCH_MAPPING; when omitted, the model is only loaded.',
    )
    args = parser.parse_args()
    get_model_and_tokenizer(args.ms_model_id, args.model_arch)
|