import warnings

import timm
import torch
from safetensors.torch import load_file
# Default device that loaded models are moved to: the string "cuda" when
# PyTorch reports a usable CUDA device, otherwise the string "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def _infer_num_classes(state_dict, head_prefixes=("head", "head.fc", "fc", "classifier")):
    """Return the output size of the classifier head stored in *state_dict*.

    Checks ``<prefix>.weight`` for each prefix in *head_prefixes*, in order,
    and returns the size of that tensor's first dimension (``out_features``
    for a ``Linear`` head).

    Raises:
        ValueError: if none of the candidate head weights is present.
    """
    candidates = [f"{prefix}.weight" for prefix in head_prefixes]
    for key in candidates:
        weight = state_dict.get(key)
        if weight is not None and weight.ndim >= 1:
            return int(weight.shape[0])
    raise ValueError(
        f"Could not infer num_classes: none of {candidates} found in the "
        "checkpoint; pass num_classes explicitly."
    )


def load_model(path, arch="vit_base_patch16_224", *, num_classes=None, strict=False):
    """Build a timm model and load weights from a safetensors checkpoint.

    Args:
        path: Path to a ``.safetensors`` file holding the model state dict.
        arch: timm architecture name passed to ``timm.create_model``.
        num_classes: Output size of the classifier head. When ``None``
            (default) it is inferred from the checkpoint's head weights.
        strict: Forwarded to ``load_state_dict``. Defaults to ``False`` for
            backward compatibility; a warning is issued if any keys did not
            match.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        ValueError: if *num_classes* is ``None`` and no classifier head
            weight can be found in the checkpoint.
    """
    state_dict = load_file(path)
    if num_classes is None:
        # Do not use "the last key": safetensors' load_file yields keys sorted
        # by name, so for ViT checkpoints the last key is ``pos_embed`` with
        # shape (1, N, D), which would give num_classes=1.
        num_classes = _infer_num_classes(state_dict)
    model = timm.create_model(arch, pretrained=False, num_classes=num_classes)
    result = model.load_state_dict(state_dict, strict=strict)
    # With strict=False, mismatched keys are skipped silently and the affected
    # layers keep their random initialisation, so surface it to the caller.
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Partial checkpoint load from {path!s} into {arch}: "
            f"{len(result.missing_keys)} missing key(s) {result.missing_keys[:5]}, "
            f"{len(result.unexpected_keys)} unexpected key(s) {result.unexpected_keys[:5]}",
            stacklevel=2,
        )
    model.to(DEVICE)
    model.eval()
    return model