from typing import Dict, List, Any
from PIL import Image
import requests
import torch
import base64
import os
from io import BytesIO
from blip import blip_decoder
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode


# Run inference on the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


class PreTrainedPipeline():
    def __init__(self, path=""):
        # Load the pretrained BLIP large captioning checkpoint and move it to the target device.
        self.model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
        self.model = blip_decoder(
            pretrained=self.model_url,
            image_size=384,
            vit='large',
            med_config=os.path.join(path, 'configs/med_config.json'),
        )
        self.model.eval()
        self.model = self.model.to(device)

        # Preprocessing: resize to the 384x384 input resolution and normalize
        # with the mean/std this BLIP checkpoint expects.
        image_size = 384
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Any) -> List[str]:
        """
        Args:
            data (:obj:`dict`):
                The request payload. ``data["inputs"]["image"]`` must hold a
                base64-encoded image; ``data["parameters"]`` is accepted but
                currently unused.
        Return:
            A :obj:`list` of :obj:`str`: the generated caption(s), e.g.
            ``["a dog running on the beach"]``.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Decode the base64 payload into an RGB PIL image and apply the BLIP transform.
        image = Image.open(BytesIO(base64.b64decode(inputs['image']))).convert('RGB')
        image = self.transform(image).unsqueeze(0).to(device)

        # Generate a caption with nucleus sampling (top-p 0.9).
        with torch.no_grad():
            caption = self.model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)

        return caption
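

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original handler): a minimal local smoke test
# showing the payload shape __call__ expects. The image is read from disk,
# base64-encoded, and sent under the "inputs"/"image" keys. The file name
# 'demo.jpg' and the empty `path` argument below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipeline = PreTrainedPipeline(path="")
    with open("demo.jpg", "rb") as f:
        payload = {
            "inputs": {"image": base64.b64encode(f.read()).decode("utf-8")},
            "parameters": None,
        }
    # Expected output: a list with one generated caption string.
    print(pipeline(payload))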