Spaces:
Build error
Build error
import functools
from typing import Sequence

import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms

from source.model import CNN
def classify_eye(image: torch.Tensor,
                 model: "CNN",
                 labels: Sequence[str] = ("Normal", "Red")) -> str:
    """Classify a single preprocessed eye image and return its label.

    The image is wrapped in a batch of one, passed through ``model`` with
    gradient tracking disabled, and the index of the highest score along
    dim 1 selects the label.

    Args:
        image: Image tensor of shape (3, H, W). The original pipeline feeds
            (3, 32, 32) tensors here.
        model: Network returning class scores along dim 1 (assumed to be
            (1, num_classes) logits for a batch of one — TODO confirm against
            CNN). Callers are expected to have put it in eval mode.
        labels: Label for each class index. Index 0 maps to labels[0], and so on.
            Defaults to ("Normal", "Red").

    Returns:
        The label of the highest-scoring class.

    Raises:
        ValueError: If the predicted class index has no entry in ``labels``.
    """
    # (3, H, W) -> (1, 3, H, W): the model expects a batch dimension.
    batch = image.unsqueeze(0)
    # Pure inference: skip autograd bookkeeping. Calling the module (rather
    # than .forward) also runs any registered hooks.
    with torch.no_grad():
        scores = model(batch)
    _, predicted = torch.max(scores, dim=1)
    class_index = int(predicted.item())
    if not 0 <= class_index < len(labels):
        # Previously an unexpected index made this function return the raw
        # logits tensor instead of a string; fail loudly instead.
        raise ValueError(
            f"Model predicted class {class_index}, but only "
            f"{len(labels)} labels were provided."
        )
    return labels[class_index]
# Trained CNN weights, resolved relative to the process working directory.
_WEIGHTS_PATH = 'source/weights/CNN-B8-LR-0.01-E30.pt'
# All inference runs on the CPU.
_DEVICE = torch.device("cpu")
# Preprocessing: resize to 32x32, convert to a float tensor in [0, 1], then
# map each channel to [-1, 1] via (x - 0.5) / 0.5.
_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


@functools.lru_cache(maxsize=1)
def _load_model() -> CNN:
    """Build the CNN on the CPU, load its trained weights, and set eval mode.

    Cached so the weights file is read once per process rather than on every
    request. Changes to the weights file on disk are not picked up until the
    process restarts.
    """
    model = CNN().to(_DEVICE)
    state_dict = torch.load(f=_WEIGHTS_PATH, map_location=_DEVICE)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def main_classification(image):
    """Preprocess a raw eye image and return the label from classify_eye.

    Args:
        image: Array-like with an ``astype`` method whose values are cast to
            uint8 (presumably an H x W x C numpy array, e.g. from a Gradio
            image input — confirm against the caller).

    Returns:
        The label produced by ``classify_eye`` for the preprocessed image.

    Raises:
        ValueError: If ``image`` is None (e.g. no image was uploaded).
    """
    if image is None:
        raise ValueError("No image was provided for classification.")
    # Let Pillow infer the mode from the array shape, then normalise to RGB so
    # grayscale or RGBA inputs still yield the 3 channels the CNN expects.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    tensor = _TRANSFORM(pil_image).to(_DEVICE)
    return classify_eye(tensor, _load_model())