Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| from torch import nn | |
# EfficientNet variant to serve; must be one of the keys of _MODEL_CONFIGS.
model_name = "b0"

# Per-variant settings. The resize/crop sizes feed the image transform and
# FEATURE_SHAPE is the input width of the replacement classifier layer.
_MODEL_CONFIGS = {
    "b4": {
        "IMAGE_RESIZE_SHAPE": 384,
        "IMAGE_FINAL_SHAPE": 380,
        "BATCH_SIZE": 32,
        "FEATURE_SHAPE": 1792,
    },
    "b0": {
        "IMAGE_RESIZE_SHAPE": 256,
        "IMAGE_FINAL_SHAPE": 224,
        "BATCH_SIZE": 32,
        "FEATURE_SHAPE": 1280,
    },
}

# Fail fast with a clear message instead of leaving the constants undefined,
# which would otherwise surface later as an unrelated-looking NameError.
if model_name not in _MODEL_CONFIGS:
    raise ValueError(
        f"Unsupported model_name {model_name!r}; "
        f"expected one of {sorted(_MODEL_CONFIGS)}"
    )

_config = _MODEL_CONFIGS[model_name]
IMAGE_RESIZE_SHAPE = _config["IMAGE_RESIZE_SHAPE"]
IMAGE_FINAL_SHAPE = _config["IMAGE_FINAL_SHAPE"]
BATCH_SIZE = _config["BATCH_SIZE"]
FEATURE_SHAPE = _config["FEATURE_SHAPE"]
def load_labels(label_text_path):
    """Read class labels from a text file, one label per line.

    Args:
        label_text_path: Path of the text file to read.

    Returns:
        A dict mapping each zero-based line index to that line's text with
        surrounding whitespace stripped. Blank lines are kept (as ``""``) so
        that indices stay aligned with line positions in the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding so the result does not depend on the platform default.
    with open(label_text_path, "r", encoding="utf-8") as f:
        return {index: line.strip() for index, line in enumerate(f)}
# Index -> class-name mapping used to turn model outputs into readable labels.
label_dict = load_labels("labels.txt")

# Load PyTorch model
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects; only point this at a trusted checkpoint.
model_params = torch.load("food101.pt", map_location=torch.device("cpu"))

# Constructors for the supported EfficientNet variants.
_MODEL_BUILDERS = {
    "b4": models.efficientnet_b4,
    "b0": models.efficientnet_b0,
}
if model_name not in _MODEL_BUILDERS:
    raise ValueError(
        f"Unsupported model_name {model_name!r}; "
        f"expected one of {sorted(_MODEL_BUILDERS)}"
    )
model = _MODEL_BUILDERS[model_name]()

# Replace the final classifier layer with a 101-output head before loading
# the checkpoint (presumably the checkpoint was saved with this head shape).
model.classifier[1] = nn.Linear(in_features=FEATURE_SHAPE, out_features=101)
model.load_state_dict(model_params)

# Freeze and switch to inference mode only after the model is fully
# assembled, so the replacement head is covered as well.
for param in model.parameters():
    param.requires_grad = False
model.eval()
# Image preprocessing pipeline applied to every input before inference.
# Per-channel mean/std used to standardize the tensor (these appear to be the
# usual ImageNet statistics -- confirm against the training setup).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

# Resize, center-crop to the model's input size, convert to a tensor, then
# normalize.
_preprocessing_steps = [
    transforms.Resize(IMAGE_RESIZE_SHAPE),
    transforms.CenterCrop(IMAGE_FINAL_SHAPE),
    transforms.ToTensor(),
    normalize,
]
transform = transforms.Compose(_preprocessing_steps)
# Define prediction function
def predict_image_class(image):
    """Classify an image and return the ten most probable labels.

    Args:
        image: Image as a NumPy array (it must support ``astype``), or
            ``None`` when no image was supplied.

    Returns:
        A dict mapping label name to softmax probability for the top ten
        classes, in descending probability order. An empty dict is returned
        when ``image`` is ``None``.
    """
    # The UI can submit without an image; return an empty result rather than
    # failing on ``None.astype``.
    if image is None:
        return {}

    # Let PIL infer the mode from the array shape, then normalise to RGB so
    # grayscale or RGBA arrays are handled as well as plain RGB ones.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")

    # Preprocess and add the batch dimension the model expects.
    batch = transform(pil_image).unsqueeze(0)

    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(batch)

    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # topk returns values sorted in descending order; cap k so a model with
    # fewer than ten classes does not raise.
    k = min(10, probabilities.numel())
    top_probabilities, top_indices = torch.topk(probabilities, k)

    return {
        label_dict[index]: probability
        for index, probability in zip(
            top_indices.tolist(), top_probabilities.tolist()
        )
    }
def main():
    """Build the Gradio interface for the classifier and start serving it.

    The app wraps ``predict_image_class`` with an image input and a label
    output, enables request queueing, and listens on 0.0.0.0:7860.
    """
    # Define Gradio interface
    # Implicit concatenation keeps exactly one space between the sentences; a
    # backslash continuation inside the literal would embed the next source
    # line's leading whitespace into the displayed text.
    description = (
        "This is a demo of EfficientNet trained on Food101 dataset. "
        "Upload an image of food and it will predict the class of the food."
    )
    inputs = gr.Image()
    outputs = gr.Label(num_top_classes=10, label="Prediction")
    gradio_app = gr.Interface(
        fn=predict_image_class,
        inputs=inputs,
        outputs=outputs,
        title="FoodVision",
        description=description,
        theme="snehilsanyal/scikit-learn",
        examples=[
            ["examples/pizza.jpg"],
            ["examples/samosa.jpg"],
        ],
    )
    gradio_app.queue().launch(server_name="0.0.0.0", server_port=7860)
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()