# NOTE(review): the three lines below were paste/extraction artifacts, not code;
# commented out so the module remains importable.
# Spaces:
# Runtime error
# Runtime error
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor, nn

import transformers
from PIL import Image
from transformers import SamImageProcessor, SamModel, SamProcessor, SamVisionConfig
# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
class SamLayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize over the channel dimension in whichever layout was configured."""
        if self.data_format == "channels_last":
            return torch.nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalize over dim 1 manually, in float32 for numerical
        # stability, then cast back and apply the affine parameters broadcast as (C, 1, 1).
        orig_dtype = x.dtype
        x32 = x.float()
        mean = x32.mean(1, keepdim=True)
        var = (x32 - mean).pow(2).mean(1, keepdim=True)
        normed = ((x32 - mean) / torch.sqrt(var + self.eps)).to(dtype=orig_dtype)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
class ShortSamVisionNeck(nn.Module):
    """Truncated SAM vision neck: a single 1x1 projection followed by layer norm.

    Maps encoder hidden states of shape (batch, H, W, hidden_size) to
    (batch, H, W, output_channels), keeping only the first conv/norm pair
    of the stock SAM neck.
    """

    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
        self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")

    def forward(self, hidden_states):
        # NHWC -> NCHW for the conv/norm pair, then back to NHWC for the caller.
        nchw = hidden_states.permute(0, 3, 1, 2)
        nchw = self.layer_norm1(self.conv1(nchw))
        return nchw.permute(0, 2, 3, 1)
| class SAMVisionTower(nn.Module): | |
| def __init__(self, vision_tower, args): | |
| super().__init__() | |
| self.args = args | |
| self.is_loaded = False | |
| self.vision_tower_name = vision_tower | |
| self.input_image_size = args.input_image_size | |
| self.pixel_shuffle = getattr(args, 'add_pixel_shuffle', False) | |
| self.freeze = args.freeze_vision | |
| self.load_model() | |
| def load_model(self): | |
| if self.is_loaded: | |
| return | |
| self.image_processor= SamProcessor.from_pretrained("facebook/sam-vit-large") | |
| sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder | |
| sam_model.neck = ShortSamVisionNeck(sam_model.config) | |
| self.image_processor.preprocess = self.image_processor.__call__ | |
| self.image_processor.image_mean = [0.485,0.456,0.406] | |
| self.vision_tower = sam_model | |
| if self.freeze: | |
| self.vision_tower.requires_grad_(False) | |
| self.is_loaded = True | |
| def forward(self, images): | |
| if type(images) is list: | |
| image_features = [] | |
| for image in images: | |
| image_feature = self.vision_tower(image.to(device=self.device).unsqueeze(0)) | |
| image_features.append(image_feature) | |
| else: | |
| image_features = self.vision_tower(images.to(device=self.device)).last_hidden_state.flatten(start_dim=1, end_dim=2).to(device=self.device) | |
| if self.pixel_shuffle: | |
| b, n, c = image_features.shape | |
| h = w = int(n ** 0.5) | |
| image_features = image_features.transpose(1,2).reshape(b, c, h, w) | |
| image_features = nn.functional.pixel_unshuffle(image_features, 2) | |
| return image_features | |
| def dummy_feature(self): | |
| return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) | |
| def dtype(self): | |
| return next(self.vision_tower.parameters()).dtype | |
| def device(self): | |
| return next(self.vision_tower.parameters()).device | |
| def config(self): | |
| # if self.is_loaded: | |
| # return self.vision_tower.config | |
| # else: | |
| # return self.cfg_only | |
| config_info = SamVisionConfig() | |
| return SamVisionConfig() | |
| def hidden_size(self): | |
| #return self.config.hidden_size | |
| if self.pixel_shuffle: | |
| hidden_size = 256 * 4 | |
| else: | |
| hidden_size = 256 | |
| return hidden_size | |
| def num_patches(self): | |
| # return (self.config.image_size // self.config.patch_size) ** 2 | |
| return self.config.num_patches | |
#main
# Smoke test: push one local image through the frozen SAM ViT-large vision
# encoder and print the output feature size. Requires network access (HF hub)
# and a user-specific image path.
if __name__ == "__main__":
    sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
    #sam_model = SamModel.from_pretrained("facebook/sam-vit-large")
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
    # Freeze every encoder parameter — inference only.
    for name, param in sam_model.named_parameters():
        param.requires_grad = False
    #raw_image = torch.rand(1, 3, 224, 224).to('cuda')
    # NOTE(review): hard-coded, user-specific cluster path — replace before reuse.
    raw_image = Image.open('/lustre/fsw/portfolios/llmservice/users/fuxiaol/image/me.jpg').convert("RGB")
    inputs = sam_processor(raw_image, return_tensors="pt")
    #print(inputs)
    #print(inputs['pixel_values'])
    # NOTE(review): this runs the stock SAM neck (the ShortSamVisionNeck swap in
    # SAMVisionTower.load_model is not applied here) — presumably intentional
    # for inspecting the default output shape; confirm.
    out = sam_model(inputs['pixel_values'])
    print(out[0].size())
    #vision_config = SamVisionConfig()
    #print('=============')
    #print(vision_config.hidden_size)
    #print('=============')
    #print(out)
    #print(out)
    #print(out)
    #config_vision
    #vision_config = SamVisionConfig()
    #print(sam_model.layers)
    #print(vision_config)