| import argparse | |
| from dataclasses import fields | |
| import torch | |
| from swift.llm import MODEL_ARCH_MAPPING, ModelKeys, get_model_tokenizer | |
def get_model_and_tokenizer(ms_model_id, model_arch=None):
    """Load a model with swift and optionally check its architecture key paths.

    Args:
        ms_model_id: Model id forwarded unchanged to ``get_model_tokenizer``.
        model_arch: Optional key into ``MODEL_ARCH_MAPPING``. When truthy, every
            non-None field of the matching ``ModelKeys`` dataclass (except
            ``arch_name``) is treated as either one submodule path (a ``str``)
            or an iterable of paths. Each path is resolved on the loaded model
            with ``get_submodule``. A ``'{}'`` placeholder in a path is replaced
            by ``'0'``, so templated paths are checked against index 0.

    Returns:
        The ``(model, tokenizer)`` pair produced by ``get_model_tokenizer``.

    Raises:
        Any exception raised while loading or checking is printed with its
        traceback and then re-raised. A path that fails to resolve raises
        ``AttributeError`` naming the arch, the ``ModelKeys`` field and the
        path (chained from torch's original error).
    """
    import traceback

    try:
        import transformers
        print(f'Test model: {ms_model_id} with transformers version: {transformers.__version__}')
        model_ins, tokenizer = get_model_tokenizer(ms_model_id)
        model_ins: torch.nn.Module
        if model_arch:
            # Keep the caller's string key separate from the looked-up
            # ModelKeys instead of rebinding the parameter to another type.
            model_keys: ModelKeys = MODEL_ARCH_MAPPING[model_arch]
            for f in fields(model_keys):
                if f.name == 'arch_name':
                    continue  # the arch's own name, not a submodule path
                value = getattr(model_keys, f.name)
                if value is None:
                    continue
                # A field holds either a single path or a collection of paths.
                paths = [value] if isinstance(value, str) else value
                for path in paths:
                    resolved = path.replace('{}', '0')
                    try:
                        # Called only for its side effect: raises if the path
                        # does not exist on the model.
                        model_ins.get_submodule(resolved)
                    except AttributeError as err:
                        raise AttributeError(
                            f'{model_arch}.{f.name}: submodule {resolved!r} '
                            f'not found in model {ms_model_id}') from err
    except Exception:
        print(traceback.format_exc())
        raise
    return model_ins, tokenizer
if __name__ == '__main__':
    # CLI entry point: load one model and, if an arch key is given, verify that
    # every submodule path listed for that arch exists on the loaded model.
    parser = argparse.ArgumentParser(
        description='Load a model with swift and optionally verify its MODEL_ARCH_MAPPING submodule paths.')
    parser.add_argument(
        '--ms_model_id',
        type=str,
        required=True,
        help='Model id passed to swift get_model_tokenizer.',
    )
    parser.add_argument(
        '--model_arch',
        type=str,
        # Optional to match get_model_and_tokenizer(model_arch=None): when
        # omitted, only model/tokenizer loading is exercised.
        default=None,
        help='Optional MODEL_ARCH_MAPPING key whose submodule paths are checked.',
    )
    args = parser.parse_args()
    get_model_and_tokenizer(args.ms_model_id, args.model_arch)