| from transformers import Pipeline |
| from PIL import Image |
| import torchvision.transforms as v2 |
| import torch |
| import torch.nn.functional as F |
|
|
class TrashClassificationPipeline(Pipeline):
    """Custom ``transformers`` pipeline that classifies one image into one label.

    Each call crops and normalises the image into a batch of size 1, runs
    ``self.model`` on it, takes the argmax over the returned ``"logits"`` and
    maps that index to a label via ``self.model.config.id2label``.
    """

    def __init__(self, **kwargs):
        """Initialise the base ``Pipeline`` and build the image transform.

        All keyword arguments are forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(**kwargs)

        # NOTE: ``v2`` is this file's alias for ``torchvision.transforms``
        # (the v1 API), not ``torchvision.transforms.v2``.
        # NOTE(review): there is no Resize before CenterCrop, so large images
        # are reduced to their central 224x224 region rather than downscaled,
        # and smaller images are zero-padded by CenterCrop. Confirm this
        # matches the preprocessing the model was trained with.
        self.transform = v2.Compose([
            v2.CenterCrop(size=(224, 224)),
            v2.PILToTensor(),
            # uint8 [0, 255] -> float32 [0, 1] before normalisation.
            v2.ConvertImageDtype(torch.float32),
            # These are the commonly used ImageNet channel statistics;
            # presumably the model was trained with them -- confirm.
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _sanitize_parameters(self, **kwargs):
        """Accept no call-time options.

        Returns three empty dicts (preprocess, forward and postprocess
        kwargs). Any keyword arguments passed at call time are ignored.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Turn a single image into a model-ready batch of size 1.

        Args:
            inputs: A ``PIL.Image.Image``, or a filesystem path (``str``)
                that ``PIL.Image.open`` can read.

        Returns:
            A float32 tensor produced by ``self.transform``, with a leading
            batch dimension of 1 added.
        """
        if isinstance(inputs, str):
            # convert() returns a fully loaded copy, so the file handle can
            # be closed by the ``with`` block straight away.
            with Image.open(inputs) as opened:
                image = opened.convert("RGB")
        else:
            image = inputs

        # Normalize carries exactly three channel statistics, so grayscale,
        # palette or RGBA images must be coerced to RGB or the transform
        # fails on a channel-count mismatch.
        if image.mode != "RGB":
            image = image.convert("RGB")

        tensor = self.transform(image)
        return tensor.unsqueeze(0)

    def _forward(self, tensor):
        """Run the model on the preprocessed batch and return its logits.

        Args:
            tensor: The batch produced by ``preprocess``.

        Returns:
            The ``"logits"`` entry of the model output.
        """
        # Make sure dropout / batch-norm behave deterministically even if
        # the model was left in training mode elsewhere.
        self.model.eval()
        with torch.no_grad():
            return self.model(tensor)["logits"]

    def postprocess(self, out):
        """Map the logits of the first (only) batch element to its label.

        Args:
            out: Logits indexed as ``[batch, class]``.

        Returns:
            The label looked up in ``self.model.config.id2label``.

        Raises:
            KeyError: If the predicted index is absent from ``id2label``
                under both int and str keys.
        """
        # Softmax is strictly monotonic, so the argmax over raw logits is the
        # same class; computing probabilities is unnecessary here.
        class_index = int(out.argmax(dim=1)[0])

        id2label = self.model.config.id2label
        # transformers' PretrainedConfig normally stores id2label with int
        # keys (JSON string keys are converted on construction); configs that
        # assign the mapping directly may still hold str keys. Accept both.
        if class_index in id2label:
            return id2label[class_index]
        return id2label[str(class_index)]