Spaces:
Runtime error
Runtime error
| import re | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel, CLIPImageProcessor | |
| from PIL import Image | |
| import requests | |
| import torch.nn.functional as F | |
| from transformers import AutoProcessor, Pix2StructVisionModel, Pix2StructProcessor, Pix2StructForConditionalGeneration | |
# Default keyword arguments for the CLIPImageProcessor built in
# Pix2StructLargeVisionTower.load_model(): resize + center-crop to 256 px and
# normalise with the mean/std values listed below.
cfg = dict(
    crop_size=256,
    do_center_crop=True,
    do_normalize=True,
    do_resize=True,
    feature_extractor_type="CLIPFeatureExtractor",
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
    resample=3,
    size=256,
)
| ''' | |
| Pix2Struct-Large Model (pretrained version) | |
| ''' | |
class Pix2StructLargeVisionTower(nn.Module):
    """Vision tower built on the encoder of ``google/pix2struct-large``.

    Incoming images are expected to be normalised with the CLIP-style
    mean/std held by ``self.image_processor``.  ``forward`` undoes that
    normalisation, re-processes the pixels with the Pix2Struct processor and
    returns the encoder's last hidden state, optionally resampled to a fixed
    spatial grid.

    Args:
        vision_tower: Name of the tower; stored as ``vision_tower_name`` only.
            The checkpoint that is actually loaded is always
            ``google/pix2struct-large``.
        args: Namespace providing ``do_resize``, ``de_normalize``,
            ``mm_vision_select_layer``, ``input_image_size``,
            ``freeze_vision`` and, optionally, ``mm_vision_select_feature``.
        delay_load: Accepted for interface compatibility; currently ignored,
            the weights are always loaded in the constructor.
    """

    # forward() keeps the first _TOKEN_GRID ** 2 encoder tokens and, when
    # ``do_resize`` is set, resamples them to _RESIZED_GRID x _RESIZED_GRID.
    _TOKEN_GRID = 45
    _RESIZED_GRID = 32

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.do_resize = args.do_resize
        # de-normalize the input image and perform preprocessing with the
        # pix2struct processor
        self.de_normalize = args.de_normalize
        # NOTE: not implemented yet, this parameter has no effect on forward()
        self.select_layer = args.mm_vision_select_layer
        self.input_image_size = args.input_image_size
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.freeze_vision = args.freeze_vision
        self.args = args

        # load_model() is a no-op once ``is_loaded`` is True.
        self.load_model()

    def load_model(self):
        """Load the Pix2Struct encoder and processors (idempotent)."""
        if self.is_loaded:
            return

        # Only the encoder of the full encoder-decoder checkpoint is kept.
        whole_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-large")
        self.vision_tower = whole_model.encoder
        self.pix2struct_processor = AutoProcessor.from_pretrained("google/pix2struct-large")
        self.pix2struct_processor.image_processor.is_vqa = False

        # CLIP-style processor exposed to callers for the initial
        # resize / crop / normalise step.
        self.image_processor = CLIPImageProcessor(**cfg)
        if self.input_image_size is not None:
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size,
            }

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # Kept as plain tensors (not buffers); forward() moves them to the
        # dtype/device of the incoming batch.
        self.image_mean = torch.tensor(self.image_processor.image_mean).view(1, 3, 1, 1)
        self.image_std = torch.tensor(self.image_processor.image_std).view(1, 3, 1, 1)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick ``hidden_states[self.select_layer]`` from an encoder output.

        Returns the selected hidden state without its first token for
        ``select_feature == 'patch'`` and unchanged for ``'cls_patch'``.

        Raises:
            ValueError: If ``select_feature`` is neither of the two values.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]  # [bs, n, c], cls at idx=0
        if self.select_feature == 'patch':
            return image_features[:, 1:]
        if self.select_feature == 'cls_patch':
            return image_features
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    # @torch.no_grad()
    def forward(self, images):
        """Encode a batch of CLIP-normalised images.

        Args:
            images: Tensor that broadcasts against a ``(1, 3, 1, 1)``
                mean/std, i.e. presumably ``(bs, 3, H, W)``.

        Returns:
            ``(bs, c, 32, 32)`` feature map when ``do_resize`` is set,
            otherwise the first 45 * 45 encoder tokens as ``(bs, n, c)``.

        Raises:
            NotImplementedError: If ``de_normalize`` is false; no
                preprocessing path exists for that configuration.
        """
        if not self.de_normalize:
            # Previously this fell through to an undefined local variable.
            raise NotImplementedError(
                'Pix2StructLargeVisionTower.forward only supports de_normalize=True'
            )

        mean = self.image_mean.view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
        std = self.image_std.view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
        # Undo the CLIP normalisation and rescale to the 0..255 pixel range
        # before handing the batch to the Pix2Struct processor.
        pixels = (images * std + mean) * 255.0
        inputs = self.pix2struct_processor(images=pixels.float(), return_tensors="pt")
        inputs = inputs.to(device=self.device, dtype=self.dtype)

        image_features = self.vision_tower(**inputs).last_hidden_state
        bs, _, c = image_features.shape
        # HARD CODE: keep only the first 45 * 45 tokens.
        image_features = image_features[:, :self._TOKEN_GRID ** 2, :]
        if not self.do_resize:
            return image_features

        # HARD CODE: the reshape requires exactly 45 * 45 tokens to be present.
        feature_map = image_features.transpose(1, 2).reshape(
            bs, c, self._TOKEN_GRID, self._TOKEN_GRID
        )
        # Interpolate in float32, then cast back to the encoder's output dtype.
        resized = F.interpolate(
            feature_map.float(),
            size=(self._RESIZED_GRID, self._RESIZED_GRID),
            mode='bilinear',
            align_corners=True,
        )
        return resized.to(dtype=feature_map.dtype)

    @property
    def dummy_feature(self):
        """All-zero ``(1, hidden_size)`` tensor on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype of the first encoder parameter."""
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        """Device of the first encoder parameter."""
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        """Configuration object of the wrapped encoder."""
        return self.vision_tower.config

    @property
    def hidden_size(self):
        """Feature width reported by this tower (hard-coded to 1536)."""
        # return self.config.hidden_size
        hidden_dim = 1536
        return hidden_dim

    @property
    def num_patches(self):
        """Number of spatial positions produced by ``forward``.

        Derived from the hard-coded grids used in ``forward``: 32 * 32 when
        ``do_resize`` is set, otherwise 45 * 45.
        """
        grid = self._RESIZED_GRID if self.do_resize else self._TOKEN_GRID
        return grid ** 2
#main
def _run_caption_demo(image_path):
    """Caption one image with ``google/pix2struct-textcaps-large``.

    The image is first pushed through a 1024 px CLIP-style preprocessing
    pipeline, de-normalised again, written to ``pix2image.jpg`` for visual
    inspection, and finally fed to the Pix2Struct processor and model.  The
    generated caption is printed to stdout.

    Args:
        image_path: Path of the image file to caption.
    """
    import torchvision.transforms as T

    # Same settings as the module-level ``cfg`` but at 1024 px.  A separate
    # dict is used so the module-level config read by
    # Pix2StructLargeVisionTower.load_model() is not overwritten.
    demo_cfg = dict(cfg, crop_size=1024, size=1024)

    processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-large")
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-large")

    image_processor = CLIPImageProcessor(**demo_cfg)
    with Image.open(image_path) as image:
        pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
    print(pixel_values.size())

    # Undo the CLIP normalisation so the tensor holds plain pixel intensities.
    mean = torch.tensor(demo_cfg["image_mean"]).view(1, 3, 1, 1)
    std = torch.tensor(demo_cfg["image_std"]).view(1, 3, 1, 1)
    pixel_values = pixel_values * std + mean
    print(pixel_values.size())

    # Dump the de-normalised image for a visual sanity check.
    T.ToPILImage()(pixel_values.squeeze(0)).save('pix2image.jpg')

    inputs = processor(images=pixel_values, max_patches=1024, return_tensors="pt")['flattened_patches']

    # autoregressive generation
    generated_ids = model.generate(inputs, max_new_tokens=50)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)


if __name__ == "__main__":
    import sys

    # NOTE: two disabled experiments that used to live here as string
    # literals (Pix2StructVisionModel on "google/pix2struct-textcaps-base",
    # and printing CLIP ViT-L/14-336 hidden-state sizes) were removed as
    # dead code.
    _DEFAULT_IMAGE = "/lustre/fsw/portfolios/llmservice/users/fuxiaol/sample2.jpg"
    _run_caption_demo(sys.argv[1] if len(sys.argv) > 1 else _DEFAULT_IMAGE)