Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from submodules.lang_seg.modules.models.lseg_net import LSegNet, clip | |
class LSegFeatureExtractor(LSegNet):
    """LSegNet specialized for dense CLIP-aligned image feature extraction.

    Splits the usual LSeg forward pass into two stages so that the expensive
    dense image features can be computed once (``extract_features``) and then
    decoded against arbitrary label sets later (``decode_feature``).
    """

    def __init__(self, half_res=True):
        """Initialize with the fixed LSeg ViT-L/16 CLIP configuration.

        Args:
            half_res: when True, ``extract_features`` stops before the final
                upsampling conv (returning features at half the input
                resolution) and ``decode_feature`` applies that conv to the
                decoded logits instead.
        """
        super().__init__(
            labels='',
            backbone='clip_vitl16_384',
            features=256,
            crop_size=224,
            arch_option=0,
            block_depth=0,
            activation='lrelu',
        )
        self.half_res = half_res

    def extract_features(self, x):
        """Encode an image batch into dense per-pixel CLIP-space features.

        Args:
            x: input image batch — presumably (b, 3, h, w); TODO confirm.

        Returns:
            (b, 512, h//2, w//2) features when ``self.half_res`` is True,
            otherwise (b, 512, h, w).
        """
        # Each layer_i: (b, 1024, h//16, w//16) activations from the ViT.
        layer_1, layer_2, layer_3, layer_4 = forward_layers(self.pretrained, x)

        # DPT head: apply the remaining act_postprocess stages (the first
        # three stages were already applied inside forward_layers).
        pretrained = self.pretrained
        layer_1 = pretrained.act_postprocess1[3:](layer_1)
        layer_2 = pretrained.act_postprocess2[3:](layer_2)
        layer_3 = pretrained.act_postprocess3[3:](layer_3)
        layer_4 = pretrained.act_postprocess4[3:](layer_4)

        # RefineNet-style top-down fusion.
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        # (b, 512, h//2, w//2)
        image_features = self.scratch.head1(path_1)
        if self.half_res:
            return image_features
        # Upsample to full input resolution: (b, 512, h, w).
        return self.scratch.output_conv(image_features)

    def decode_feature(self, image_features, labelset=''):
        """Correlate dense image features with a set of text labels.

        Args:
            image_features: (b, C, H, W) features from ``extract_features``,
                with C == ``self.out_c``.
            labelset: label strings to tokenize with CLIP; when '' the
                pre-tokenized ``self.text`` is used instead.

        Returns:
            Per-pixel label logits (b, num_labels, H, W); when
            ``self.half_res`` is True they are additionally passed through
            ``self.scratch.output_conv``.
        """
        imshape = image_features.shape

        # Encode the label set with the CLIP text tower.
        text = self.text if labelset == '' else clip.tokenize(labelset)
        self.logit_scale = self.logit_scale.to(image_features.device)
        text = text.to(image_features.device)
        text_features = self.clip_pretrained.encode_text(text)

        # (b, C, H, W) -> (b*H*W, C)
        image_features = image_features.permute(0, 2, 3, 1).reshape(-1, self.out_c)

        # Normalize both modalities so the product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # .half() presumably matches the dtype of CLIP's text features —
        # TODO confirm against clip_pretrained's precision.
        logits_per_image = self.logit_scale * image_features.half() @ text_features.t()

        # Back to (b, num_labels, H, W).
        out = (
            logits_per_image.float()
            .view(imshape[0], imshape[2], imshape[3], -1)
            .permute(0, 3, 1, 2)
        )

        if self.arch_option in [1, 2]:
            for _ in range(self.block_depth - 1):
                out = self.scratch.head_block(out)
            out = self.scratch.head_block(out, False)

        if self.half_res:
            out = self.scratch.output_conv(out)
        return out

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Build a model and load weights from a Lightning-style checkpoint.

        Fix: the original definition took ``cls`` as its first parameter but
        was never decorated, so ``LSegFeatureExtractor.from_pretrained(path)``
        bound ``path`` to ``cls`` and raised a TypeError; ``@classmethod``
        restores the intended alternate-constructor behavior.

        Args:
            pretrained_model_name_or_path: path to a checkpoint whose
                'state_dict' entries are prefixed with "net.".
            *args, **kwargs: forwarded to ``cls(...)``.

        Returns:
            A ``cls`` instance with the checkpoint weights loaded (strict).
        """
        print(f"Loading checkpoint from: {pretrained_model_name_or_path}")
        ckpt = torch.load(pretrained_model_name_or_path, map_location='cpu')
        print(f"Checkpoint loaded. Keys in checkpoint: {ckpt.keys()}")

        # Keep only the "net."-prefixed entries, stripping the prefix.
        print("Processing state dict...")
        new_state_dict = {
            k[len("net."):]: v
            for k, v in ckpt['state_dict'].items()
            if k.startswith("net.")
        }
        print(f"Processed state dict. Number of keys: {len(new_state_dict)}")

        print("Initializing model...")
        model = cls(*args, **kwargs)
        print("Loading state dict into model...")
        model.load_state_dict(new_state_dict, strict=True)
        print("State dict loaded successfully.")

        # Release checkpoint memory before returning.
        print("Cleaning up...")
        del ckpt
        del new_state_dict
        print("Model loading complete.")
        return model
def forward_layers(pretrained, x):
    """Run the ViT backbone on ``x`` and return its four hooked activations.

    Fixes: drops the unused locals of the original (``glob``, ``b``, ``c``)
    and the redundant ``nn.Sequential`` wrapper around a single
    ``nn.Unflatten``; behavior is unchanged.

    Args:
        pretrained: DPT "pretrained" module holding the ViT (``.model``),
            the hook outputs (``.activations``) and the ``act_postprocess*``
            pipelines.
        x: input image batch; only its spatial size (h, w) is read here.

    Returns:
        Tuple ``(layer_1, layer_2, layer_3, layer_4)`` of 4-D feature maps
        with spatial size (h // patch_h, w // patch_w).
    """
    _, _, h, w = x.shape

    # The forward pass is run for its side effect only: forward hooks
    # populate pretrained.activations["1"].."4".
    pretrained.model.forward_flex(x)

    layers = [
        pretrained.activations["1"],
        pretrained.activations["2"],
        pretrained.activations["3"],
        pretrained.activations["4"],
    ]

    # Apply the first two act_postprocess stages of each pipeline.
    postprocess = (
        pretrained.act_postprocess1,
        pretrained.act_postprocess2,
        pretrained.act_postprocess3,
        pretrained.act_postprocess4,
    )
    layers = [post[0:2](layer) for post, layer in zip(postprocess, layers)]

    # Reshape token sequences (b, c, n) into spatial maps (b, c, h', w');
    # already-4-D activations pass through unchanged.
    unflatten = nn.Unflatten(
        2,
        torch.Size(
            [
                h // pretrained.model.patch_size[1],
                w // pretrained.model.patch_size[0],
            ]
        ),
    )
    layers = [unflatten(layer) if layer.ndim == 3 else layer for layer in layers]
    return tuple(layers)