Spaces:
Runtime error
Runtime error
| import requests | |
| from io import BytesIO | |
| from PIL import Image, ImageOps | |
| import torchvision.transforms as T | |
| import torch | |
| import gc | |
| import pickle as pkl | |
| from pathlib import Path | |
| THIS_FOLDER = Path(__file__).parent.resolve() | |
| import datetime as dt | |
| from transformers import AutoModel | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import logging | |
# Suppress every log record at WARNING severity or lower for the whole process.
logging.disable(logging.WARNING)

# Pretrained ViT backbone, created once at import time.
visionTransformer = AutoModel.from_pretrained(
    r"google/vit-base-patch16-224-in21k", cache_dir=None
)
class PlantVision(nn.Module):
    """ViT feature extractor followed by a two-layer linear classifier head.

    The backbone is the module-level ``visionTransformer`` object, so every
    instance created in this process references the same backbone sub-modules.

    Args:
        num_classes: Number of logits produced by the head.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.vit = visionTransformer

        # Freeze the parameters of the first three children of the backbone.
        for position, child in enumerate(self.vit.children(), start=1):
            if position >= 4:
                continue
            for weight in child.parameters():
                weight.requires_grad = False

        self.vitLayers = list(self.vit.children())
        # All children except the last two run as the feature extractor; the
        # child at index 2 is applied separately to the hidden states.
        self.vitTop = nn.Sequential(*self.vitLayers[:-2])
        self.vitNorm = self.vitLayers[2]
        self.vit = None
        gc.collect()

        self.vitFlatten = nn.Flatten()
        # 151296 == 197 * 768 -- presumably tokens x hidden size for this
        # checkpoint at 224x224 input; TODO confirm against the model config.
        self.vitLinear = nn.Linear(151296, num_classes)
        self.fc = nn.Linear(num_classes, num_classes)

    def forward(self, input):
        """Map a batch of image tensors to ``num_classes`` logits per sample."""
        hidden = self.vitTop(input).last_hidden_state
        normed = self.vitNorm(hidden)
        flat = self.vitFlatten(normed)
        return self.fc(F.relu(self.vitLinear(flat)))
# All weights are mapped to, and inference runs on, the CPU.
device = 'cpu'

# Label collections for each plant feature, unpickled from the bundled
# resources folder next to this file.
flowerLabelSet = pkl.loads(Path(fr'{THIS_FOLDER}/resources/flowerLabelSet.pkl').read_bytes())
leafLabelSet = pkl.loads(Path(fr'{THIS_FOLDER}/resources/leafLabelSet.pkl').read_bytes())
fruitLabelSet = pkl.loads(Path(fr'{THIS_FOLDER}/resources/fruitLabelSet.pkl').read_bytes())
def loadModel(feature, labelSet):
    """Build a PlantVision model and load its head weights from cloud storage.

    Args:
        feature: Prefix of the weight files in the bucket (this module calls
            the function with 'flower', 'leaf' and 'fruit').
        labelSet: Sized collection of labels; its length sets the class count.

    Returns:
        The model converted to half precision.

    Raises:
        requests.HTTPError: If a weight file responds with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    model = PlantVision(num_classes=len(labelSet))
    # vitFlatten is an nn.Flatten; its weight file is still fetched and loaded
    # exactly as before so the set of required bucket objects is unchanged.
    for layerName in ('vitFlatten', 'vitLinear', 'fc'):
        response = requests.get(
            f"https://storage.googleapis.com/bmllc-plant-model-bucket/{feature}-{layerName}-weights.pt",
            # (connect, read) seconds: fail instead of hanging start-up forever.
            timeout=(10, 120),
        )
        # Surface a failed download here rather than letting torch.load choke
        # on an HTML/XML error body with an opaque unpickling error.
        response.raise_for_status()
        # NOTE(security): torch.load unpickles the downloaded bytes; this is
        # only safe while the bucket contents are fully trusted.
        weights = torch.load(BytesIO(response.content), map_location=torch.device(device))
        getattr(model, layerName).load_state_dict(weights, strict=False)
    model = model.half()
    return model
# Build the three classifiers at import time and report how long it took.
start = dt.datetime.now()
flower, leaf, fruit = (
    loadModel(name, labels)
    for name, labels in (
        ('flower', flowerLabelSet),
        ('leaf', leafLabelSet),
        ('fruit', fruitLabelSet),
    )
)
print(dt.datetime.now() - start)
def processImage(imagePath, feature):
    """Load an image and turn it into a normalised 224x224 tensor.

    Args:
        imagePath: Anything accepted by ``PIL.Image.open``.
        feature: Selects the ``{feature}MeansAndStds.pkl`` resource whose
            'mean' and 'std' entries are used for normalisation.

    Returns:
        The float32 tensor produced by the transform pipeline.
    """
    with open(fr'{THIS_FOLDER}/resources/{feature}MeansAndStds.pkl', 'rb') as f:
        meansAndStds = pkl.load(f)

    # Use a context manager so the underlying file handle is released right
    # away; convert() returns an independent, fully loaded copy.
    with Image.open(imagePath) as source:
        img = source.convert('RGB')

    cropped = ImageOps.fit(img, (224, 224), Image.Resampling.LANCZOS)
    process = T.Compose([
        # The image is already 224x224 here, so the crop leaves it unchanged.
        T.CenterCrop(224),
        T.ToTensor(),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(
            mean=meansAndStds['mean'],
            std=meansAndStds['std'],
        ),
    ])
    return process(cropped)
def see(tensor, feature, k):
    """Return the ``k`` highest-scoring labels for one preprocessed image.

    Args:
        tensor: Image tensor without a batch axis; one is added here.
        feature: One of 'flower', 'leaf' or 'fruit'; selects model and labels.
        k: Number of predictions to return.

    Returns:
        A list of ``k`` entries from the matching label set, best score first.

    Raises:
        ValueError: If ``feature`` is not one of the supported names.
    """
    candidates = {
        'flower': (flower, flowerLabelSet),
        'leaf': (leaf, leafLabelSet),
        'fruit': (fruit, fruitLabelSet),
    }
    if feature not in candidates:
        # Previously an unknown feature fell through the if/elif chain and
        # crashed later with an UnboundLocalError.
        raise ValueError(
            f"Unknown feature {feature!r}; expected one of {sorted(candidates)}"
        )
    network, labelSet = candidates[feature]

    # Module.float() converts the stored module in place and returns it, so
    # the model stays in float32 after the first call.
    model = network.float()
    with torch.no_grad():
        output = model(tensor.unsqueeze(0))
    top = torch.topk(output, k, dim=1)

    # Index the label set with plain ints rather than 0-d tensors.
    predictedSpecies = [labelSet[index] for index in top.indices[0].tolist()]
    gc.collect()
    return predictedSpecies