Spaces:
Running
Running
| import torch.nn as nn | |
| import timm | |
class TimmFRWrapperV2(nn.Module):
    """Thin ``nn.Module`` wrapper around a timm backbone.

    The backbone is built with ``timm.create_model(model_name)`` and its
    classifier is then reset via ``reset_classifier(featdim)``. ``forward``
    simply delegates to that backbone.

    Args:
        model_name: Architecture name handed to ``timm.create_model``.
        featdim: Value handed to the backbone's ``reset_classifier``.
        batchnorm: Accepted but not used anywhere in this class.
    """

    def __init__(self, model_name="edgenext_x_small", featdim=512, batchnorm=False):
        super().__init__()
        self.model_name = model_name
        self.featdim = featdim

        # Build the backbone, swap its classifier, then register it as a
        # submodule under the attribute name ``model``.
        backbone = timm.create_model(model_name)
        backbone.reset_classifier(featdim)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped backbone's output for ``x``."""
        return self.model(x)
class LoRaLin(nn.Module):
    """Low-rank replacement for a linear layer, built from two linears.

    The input is first projected from ``in_features`` down to ``rank``
    (no bias) and then from ``rank`` up to ``out_features`` (bias optional).

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        rank: Width of the intermediate projection.
        bias: Whether the second (output) projection carries a bias term.
    """

    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.in_features, self.out_features, self.rank = (
            in_features,
            out_features,
            rank,
        )
        # Creation order (linear1, then linear2) is kept on purpose: it fixes
        # both the state_dict key order and the order of random initialisation.
        self.linear1 = nn.Linear(in_features, rank, bias=False)
        self.linear2 = nn.Linear(rank, out_features, bias=bias)

    def forward(self, input):
        """Apply the down-projection followed by the up-projection."""
        return self.linear2(self.linear1(input))
def replace_linear_with_lowrank_recursive_2(model, rank_ratio=0.2):
    """Swap ``nn.Linear`` children of ``model`` for ``LoRaLin``, in place.

    Walks the module tree depth-first. A direct child is replaced when it is
    an ``nn.Linear`` and its own attribute name does not contain ``"head"``;
    every other child is recursed into. Only the immediate child name is
    checked, not the names of its ancestors.

    The replacement is a freshly constructed ``LoRaLin``; the weights of the
    layer it replaces are not copied over.

    Args:
        model: Module whose children are inspected and possibly replaced.
        rank_ratio: Fraction of ``min(in_features, out_features)`` used as
            the low-rank width (truncated to int, never below 2).

    Returns:
        None. ``model`` is modified in place.
    """
    for child_name, child in model.named_children():
        replaceable = isinstance(child, nn.Linear) and "head" not in child_name
        if not replaceable:
            replace_linear_with_lowrank_recursive_2(child, rank_ratio)
            continue

        n_in, n_out = child.in_features, child.out_features
        low_rank = max(2, int(min(n_in, n_out) * rank_ratio))
        has_bias = child.bias is not None
        setattr(model, child_name, LoRaLin(n_in, n_out, low_rank, has_bias))
def replace_linear_with_lowrank_2(model, rank_ratio=0.2):
    """Apply the recursive low-rank replacement to ``model`` and return it.

    Convenience entry point: delegates to
    ``replace_linear_with_lowrank_recursive_2`` (which mutates ``model`` in
    place) and hands the same object back so the call can be used inline.

    Args:
        model: Module to convert in place.
        rank_ratio: Forwarded unchanged to the recursive helper.

    Returns:
        The same ``model`` object that was passed in.
    """
    replace_linear_with_lowrank_recursive_2(model, rank_ratio=rank_ratio)
    return model
# Registry of EdgeFace variants, keyed by model name. Each entry holds:
#   repo        - repository identifier string (presumably a Hugging Face Hub
#                 repo id; confirm against the loader that consumes it)
#   filename    - checkpoint file name
#   timm_model  - timm architecture name for the backbone
#   post_setup  - one-argument callable; identity for the plain variants,
#                 low-rank linear replacement for the "gamma" variants
model_configs = dict(
    edgeface_base=dict(
        repo="idiap/EdgeFace-Base",
        filename="edgeface_base.pt",
        timm_model="edgenext_base",
        post_setup=lambda x: x,
    ),
    edgeface_s_gamma_05=dict(
        repo="idiap/EdgeFace-S-GAMMA",
        filename="edgeface_s_gamma_05.pt",
        timm_model="edgenext_small",
        post_setup=lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.5),
    ),
    edgeface_xs_gamma_06=dict(
        repo="idiap/EdgeFace-XS-GAMMA",
        filename="edgeface_xs_gamma_06.pt",
        timm_model="edgenext_x_small",
        post_setup=lambda x: replace_linear_with_lowrank_2(x, rank_ratio=0.6),
    ),
    edgeface_xxs=dict(
        repo="idiap/EdgeFace-XXS",
        filename="edgeface_xxs.pt",
        timm_model="edgenext_xx_small",
        post_setup=lambda x: x,
    ),
)