Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from timm.models.vision_transformer import default_cfgs | |
| from timm.models.helpers import load_pretrained, load_custom_pretrained | |
| from src.models.vit.utils import checkpoint_filter_fn | |
| from src.models.vit.vit import VisionTransformer | |
def create_vit(model_cfg):
    """Build a VisionTransformer from a config dict and load pretrained weights.

    Args:
        model_cfg: dict with at least "backbone", "normalization",
            "image_size" (a (H, W) pair) and "d_model". The caller's dict is
            not mutated; a shallow copy is edited locally.

    Returns:
        A ``VisionTransformer`` instance with 1000-class head and pretrained
        weights loaded according to the backbone name.
    """
    model_cfg = model_cfg.copy()
    backbone = model_cfg.pop("backbone")
    # "normalization" is consumed elsewhere in the pipeline, not by the model
    # constructor — pop it so VisionTransformer(**model_cfg) does not choke.
    model_cfg.pop("normalization")
    model_cfg["n_cls"] = 1000  # ImageNet-1k classification head
    mlp_expansion_ratio = 4
    model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]

    if backbone in default_cfgs:
        # Shallow-copy timm's entry: the original code aliased the shared
        # module-level default_cfgs dict and then mutated it in place via the
        # "input_size" assignment below, corrupting the global timm config
        # for every subsequent user of this backbone.
        default_cfg = dict(default_cfgs[backbone])
    else:
        default_cfg = dict(
            pretrained=False,
            num_classes=1000,
            drop_rate=0.0,
            drop_path_rate=0.0,
            drop_block_rate=None,
        )

    # Override the expected input resolution with the configured image size.
    default_cfg["input_size"] = (
        3,
        model_cfg["image_size"][0],
        model_cfg["image_size"][1],
    )

    model = VisionTransformer(**model_cfg)

    if backbone == "vit_base_patch8_384":
        # NOTE(review): presumably timm has no downloadable weights for this
        # variant, so a locally cached checkpoint under $TORCH_HOME is loaded
        # instead — confirm the file is provisioned on target machines.
        path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth")
        state_dict = torch.load(path, map_location="cpu")
        filtered_dict = checkpoint_filter_fn(state_dict, model)
        model.load_state_dict(filtered_dict, strict=True)
    elif "deit" in backbone:
        # DeiT checkpoints are remapped through checkpoint_filter_fn before
        # loading (key renames / positional-embedding adjustment).
        load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
    else:
        load_custom_pretrained(model, default_cfg)

    return model