Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torchvision | |
| import pickle | |
| from PIL import Image | |
| import os | |
| def classify_image(img): | |
| # Get trained model | |
| models = os.listdir('models/') | |
| assert len(models)==1, "More than 1 model in 'models/' folder!" | |
| model_name = models[0] | |
| num_labels = 3 | |
| if 'mobilenet_v3_small' in model_name: | |
| weights = torchvision.models.MobileNet_V3_Small_Weights.DEFAULT | |
| model = torchvision.models.mobilenet_v3_small(weights=weights) | |
| elif 'shufflenet_v2_x0_5' in model_name: | |
| weights = torchvision.models.ShuffleNet_V2_X0_5_Weights.DEFAULT | |
| model = torchvision.models.shufflenet_v2_x0_5(pretrained=True) | |
| # Making the last layer similar to first two models for purposes of experiment | |
| model.classifier = model.fc | |
| del model.fc | |
| model.classifier = torch.nn.Sequential( | |
| torch.nn.Dropout(p=0), | |
| model.classifier | |
| ) | |
| # Overwriting _forward_impl method for shufflenet_v2_x0_5 to change "fc" to "classifier" | |
| def custom_shufflenet_forward_impl(self, x:torch.Tensor) -> torch.Tensor: | |
| # See note [TorchScript super()] | |
| x = self.conv1(x) | |
| x = self.maxpool(x) | |
| x = self.stage2(x) | |
| x = self.stage3(x) | |
| x = self.stage4(x) | |
| x = self.conv5(x) | |
| x = x.mean([2, 3]) # globalpool | |
| x = self.classifier(x) | |
| return x | |
| # Patch the _forward_impl method | |
| model._forward_impl = custom_shufflenet_forward_impl.__get__(model, torchvision.models.ResNet) | |
| model.classifier[-1] = torch.nn.Linear(in_features=model.classifier[-1].in_features, | |
| out_features=num_labels, bias=True) | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| model.to(device) | |
| ## Load pretrained weights | |
| chkp_path = os.path.join('models', model_name) | |
| model.load_state_dict(torch.load(chkp_path, map_location=device)) | |
| ## Transform image | |
| if img.mode == 'RGBA': | |
| img = img.convert('RGB') | |
| auto_transforms = weights.transforms() | |
| transformed_img = auto_transforms(img).unsqueeze(0) | |
| # Make prediction | |
| model.eval() | |
| with torch.inference_mode(): | |
| pred_label = torch.argmax(model(transformed_img)).item() | |
| with open('class_to_idx.pkl', 'rb') as file: | |
| class_to_idx = pickle.load(file) | |
| idx_to_class = {v:k for k,v in class_to_idx.items()} | |
| return idx_to_class[pred_label] | |
| with open('class_to_idx.pkl', 'rb') as file: | |
| class_to_idx = pickle.load(file) | |
| food_types = list(class_to_idx.keys()) | |
| example_images = [ | |
| Image.open(os.path.join("examples_to_predict","clam chowder.jpeg")), | |
| Image.open(os.path.join("examples_to_predict","donuts.jpg")), | |
| Image.open(os.path.join("examples_to_predict","ice cream.jpeg")) | |
| ] | |
| #The examples parameter expects a list of lists | |
| examples = [[img] for img in example_images] | |
| demo = gr.Interface(classify_image, inputs=gr.Image(type='pil'), outputs="text", | |
| description=f"Upload the picture of one of these food types {food_types}. The program will classify the image.", | |
| allow_flagging='never', examples=examples) | |
| demo.launch() | |