import uform
import torch
import coremltools as ct
from os.path import join
from argparse import ArgumentParser
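
# This script exports a UForm model's text and image encoders to Core ML:
# each encoder is wrapped so tracing captures both its intermediate features
# and its pooled embedding, traced with torch.jit.trace, converted to an
# 'mlprogram' with coremltools, and saved as an .mlpackage.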


class TextEncoder(torch.nn.Module):
    """Wraps the UForm text encoder so the traced module returns both the
    contextual features and the pooled embedding."""

    def __init__(self, model):
        super().__init__()
        self.model = model.eval()

    def forward(self, input_ids, attention_mask):
        features = self.model.forward_features(input_ids, attention_mask)
        embeddings = self.model.forward_embedding(features, attention_mask)
        return features, embeddings


class ImageEncoder(torch.nn.Module):
    """Wraps the UForm image encoder so the traced module returns both the
    patch features and the pooled embedding."""

    def __init__(self, model):
        super().__init__()
        self.model = model.eval()

    def forward(self, image):
        features = self.model.forward_features(image)
        embeddings = self.model.forward_embedding(features)
        return features, embeddings


def convert_model(opts):
    src_model = uform.get_model(opts.model_name)

    # Dummy inputs at the model's native sequence length and image
    # resolution, used only to record the graph during tracing.
    input_ids = torch.ones(1, src_model.text_encoder.max_position_embeddings, dtype=torch.int32)
    attention_mask = torch.ones(1, src_model.text_encoder.max_position_embeddings, dtype=torch.int32)
    image = torch.ones(1, 3, src_model.image_encoder.image_size, src_model.image_encoder.image_size, dtype=torch.float32)

    print('Tracing models…')
    image_encoder = ImageEncoder(src_model.image_encoder).eval()
    image_encoder = torch.jit.trace(image_encoder, image)
    text_encoder = TextEncoder(src_model.text_encoder).eval()
    text_encoder = torch.jit.trace(text_encoder, (input_ids, attention_mask))

    print('Converting models…')
    # Use a fixed batch dimension when the bounds coincide, otherwise a
    # flexible RangeDim. The default must lie within the bounds, so use the
    # lower bound rather than a hard-coded 1.
    if opts.image_batchsize_lb == opts.image_batchsize_ub:
        image_batch_dim_shape = opts.image_batchsize_lb
    else:
        image_batch_dim_shape = ct.RangeDim(
            lower_bound=opts.image_batchsize_lb,
            upper_bound=opts.image_batchsize_ub,
            default=opts.image_batchsize_lb
        )
    image_encoder = ct.convert(
        image_encoder,
        convert_to='mlprogram',
        inputs=[
            ct.TensorType(
                name='image',
                shape=(image_batch_dim_shape,) + image.shape[1:],
                dtype=image.numpy().dtype
            )
        ],
        outputs=[
            ct.TensorType(name='features'),
            ct.TensorType(name='embeddings')
        ],
        compute_precision=ct.precision.FLOAT16 if opts.use_fp16 else ct.precision.FLOAT32
    )

    if opts.text_batchsize_lb == opts.text_batchsize_ub:
        text_batch_dim_shape = opts.text_batchsize_lb
    else:
        text_batch_dim_shape = ct.RangeDim(
            lower_bound=opts.text_batchsize_lb,
            upper_bound=opts.text_batchsize_ub,
            default=opts.text_batchsize_lb
        )
    text_encoder = ct.convert(
        text_encoder,
        convert_to='mlprogram',
        inputs=[
            ct.TensorType(
                name='input_ids',
                shape=(text_batch_dim_shape,) + input_ids.shape[1:],
                dtype=input_ids.numpy().dtype
            ),
            ct.TensorType(
                name='attention_mask',
                shape=(text_batch_dim_shape,) + attention_mask.shape[1:],
                dtype=attention_mask.numpy().dtype
            )
        ],
        outputs=[
            ct.TensorType(name='features'),
            ct.TensorType(name='embeddings')
        ],
        compute_precision=ct.precision.FLOAT16 if opts.use_fp16 else ct.precision.FLOAT32
    )

    print('Image encoder:', image_encoder, sep='\n')
    print('Text encoder:', text_encoder, sep='\n')

    image_encoder.save(join(opts.output_dir, f"{opts.model_name.replace('/', '.')}.image-encoder.mlpackage"))
    text_encoder.save(join(opts.output_dir, f"{opts.model_name.replace('/', '.')}.text-encoder.mlpackage"))
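

# Optional smoke test (not run by this script): a saved package can be
# reloaded with coremltools to inspect its interface, e.g.
#   mlmodel = ct.models.MLModel('path/to/encoder.mlpackage')  # hypothetical path
#   print(mlmodel.get_spec().description)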


if __name__ == '__main__':
    opts = ArgumentParser()
    opts.add_argument('--model_name',
                      action='store',
                      type=str,
                      required=True,
                      help='UForm model name')
    opts.add_argument('--text_batchsize_lb',
                      action='store',
                      type=int,
                      required=True,
                      help='lower bound of batch size for text encoder')
    opts.add_argument('--text_batchsize_ub',
                      action='store',
                      type=int,
                      required=True,
                      help='upper bound of batch size for text encoder')
    opts.add_argument('--image_batchsize_lb',
                      action='store',
                      type=int,
                      required=True,
                      help='lower bound of batch size for image encoder')
    opts.add_argument('--image_batchsize_ub',
                      action='store',
                      type=int,
                      required=True,
                      help='upper bound of batch size for image encoder')
    opts.add_argument('--use_fp16',
                      action='store_true',
                      help='whether to use fp16 for inference or not')
    opts.add_argument('--output_dir',
                      action='store',
                      type=str,
                      required=True,
                      help='output directory')
    convert_model(opts.parse_args())
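

# Example invocation (script and model names are illustrative; any checkpoint
# accepted by uform.get_model should work):
#
#   python convert_to_coreml.py \
#       --model_name unum-cloud/uform-vl-english \
#       --text_batchsize_lb 1 --text_batchsize_ub 16 \
#       --image_batchsize_lb 1 --image_batchsize_ub 16 \
#       --use_fp16 \
#       --output_dir ./coreml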