Spaces:
Runtime error
Runtime error
| import torch | |
| import torchvision.models as models | |
| from PIL import Image | |
| from efficientnet_pytorch import EfficientNet | |
| import torchvision.transforms as transforms | |
| import gradio as gr | |
| from gradio import components | |
| import numpy as np | |
# Fine-tuned checkpoint read by _get_model(), and the class-index -> label map
# used both to size the classifier head and to name predictions.
_MODEL_WEIGHTS_PATH = 'efficientnetb7_tyrequality_classifier.pth'
_CLASS_TO_LABEL = {0: 'defective', 1: 'good'}

# Lazily-initialised singletons. Building EfficientNet-B7 and reading its
# checkpoint is expensive, so it happens once on first use instead of on
# every request.
_model = None
_transform = None


def _get_model():
    """Return the EfficientNet-B7 tyre classifier, building and loading it on first call."""
    global _model
    if _model is None:
        model = EfficientNet.from_name('efficientnet-b7', num_classes=len(_CLASS_TO_LABEL))
        # map_location forces CPU tensors so the checkpoint loads on hosts without a GPU.
        # NOTE(review): torch.load unpickles the file, so only ever point this at a
        # trusted checkpoint (newer torch versions also accept weights_only=True).
        state_dict = torch.load(_MODEL_WEIGHTS_PATH, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model.eval()  # put dropout / batch-norm layers into inference behaviour
        _model = model
    return _model


def _get_transform():
    """Return the preprocessing pipeline (resize, crop, tensorise, normalise), built once."""
    global _transform
    if _transform is None:
        _transform = transforms.Compose([
            transforms.Resize(224),      # shorter side -> 224 px, aspect ratio kept
            transforms.CenterCrop(224),  # then take the central 224 x 224 crop
            transforms.ToTensor(),
            # Standard ImageNet channel mean/std. Presumably this is what the
            # classifier was trained with; confirm against the training code.
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
    return _transform


def predict(image):
    """Classify a tyre photo as 'defective' or 'good'.

    Args:
        image: Image array delivered by the ``gr.Image()`` input. It is
            presumably an H x W x C uint8 array (Gradio's numpy format);
            confirm if the input component changes. It is ``None`` when the
            user submits without uploading anything.

    Returns:
        A human-readable string with the predicted label and the softmax
        probability of that label, expressed as a percentage.

    Raises:
        gr.Error: if no image was provided. This makes Gradio show a readable
            message instead of failing inside ``np.uint8(None)``.
    """
    if image is None:
        raise gr.Error("Please upload an image of a tyre.")

    # Force 3-channel RGB so greyscale or RGBA uploads match the model's input.
    pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
    input_data = _get_transform()(pil_image).unsqueeze(0)  # add a batch dimension of 1

    model = _get_model()
    with torch.no_grad():
        output = model(input_data)

    probs = torch.nn.functional.softmax(output, dim=1)
    # Take the label and the confidence from the same max, so the reported
    # confidence is always the probability of the reported class.
    conf, predicted_class = torch.max(probs, 1)
    label = _CLASS_TO_LABEL[predicted_class.item()]
    return "Tire status is {} with confidence level in {}%".format(label, conf.item() * 100)
# Web UI: one image upload feeding predict(), with its result string shown in a text box.
iface = gr.Interface(
    predict,      # fn
    gr.Image(),   # inputs
    "textbox",    # outputs
)
# share=True additionally asks Gradio for a temporary public share link.
iface.launch(share=True)