interactSpeech / tests /llm /load_model.py
Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
raw
history blame
1.36 kB
import argparse
from dataclasses import fields
import torch
from swift.llm import MODEL_ARCH_MAPPING, ModelKeys, get_model_tokenizer
def get_model_and_tokenizer(ms_model_id, model_arch=None):
try:
import transformers
print(f'Test model: {ms_model_id} with transformers version: {transformers.__version__}')
model_ins, tokenizer = get_model_tokenizer(ms_model_id)
model_ins: torch.nn.Module
if model_arch:
model_arch: ModelKeys = MODEL_ARCH_MAPPING[model_arch]
for f in fields(model_arch):
value = getattr(model_arch, f.name)
if value is not None and f.name != 'arch_name':
if isinstance(value, str):
value = [value]
for v in value:
v = v.replace('{}', '0')
model_ins.get_submodule(v)
except Exception:
import traceback
print(traceback.format_exc())
raise
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--ms_model_id',
type=str,
required=True,
)
parser.add_argument(
'--model_arch',
type=str,
required=True,
)
args = parser.parse_args()
get_model_and_tokenizer(args.ms_model_id, args.model_arch)