Spaces:
Runtime error
Runtime error
import time
from pathlib import Path

import gradio as gr
import torch

from models import *
# Maps the classifier's output index (position in the softmax row) to a
# human-readable food label.
class_idx_to_names = {
    0: "pizza",
    1: "steak",
    2: "sushi",
}

# Example inputs for the UI: one single-element row (an image path) per file in
# ./examples, resolved against the current working directory. Sorted so the
# example gallery has a stable order across runs. Directories and hidden files
# (e.g. .gitkeep, .DS_Store) are skipped because they cannot be loaded as images.
examples = [
    [str(path)]
    for path in sorted(Path("examples").glob("*"))
    if path.is_file() and not path.name.startswith(".")
]
def predict_one(model, transforms, image, device, class_idx_to_names):
    """Classify a single image and time the preprocessing + forward pass.

    Args:
        model: torch module producing class logits. Assumes its output for a
            one-image batch is shaped (1, num_classes) -- TODO confirm for
            every model returned by the `models` loaders.
        transforms: callable mapping ``image`` to an unbatched tensor.
        image: the input image, in whatever form ``transforms`` accepts
            (the UI passes a PIL image).
        device: device the model and input are moved to, e.g. ``"cpu"``.
        class_idx_to_names: mapping from output index to class label.

    Returns:
        ``(predictions, seconds)``. ``predictions`` maps each class label to
        its softmax probability. ``seconds`` is the wall-clock time spent
        transforming the image and running the model.
    """
    model.eval()
    model = model.to(device)
    with torch.inference_mode():
        start_time = time.perf_counter()
        # Add a batch dimension: the model expects (batch, ...) input.
        batch = transforms(image).unsqueeze(dim=0).to(device)
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        end_time = time.perf_counter()
    # Convert the single row in one call instead of one .item() per class.
    predictions = {
        class_idx_to_names[index]: prob
        for index, prob in enumerate(probs[0].tolist())
    }
    return predictions, end_time - start_time
# Loaded (model, transforms) pairs keyed by model name, so each model is built
# once per process instead of on every request. predict_one() only calls
# .eval() / .to(device) on the model, so reusing the same object is safe.
_MODEL_CACHE = {}


def _load_model(model_choice):
    """Return the (model, transforms) pair for ``model_choice``, building it on first use."""
    # None (nothing selected) and "effnet_b2" both mean EfficientNet-B2; any
    # other value selects the ViT model, matching the original branching.
    key = "effnet_b2" if model_choice is None or model_choice == "effnet_b2" else "vit"
    if key not in _MODEL_CACHE:
        loader = get_effnet_b2 if key == "effnet_b2" else get_vit_16_base_transformer
        _MODEL_CACHE[key] = loader()
    return _MODEL_CACHE[key]


def predict(image, model_choice):
    """Gradio callback: classify ``image`` with the model chosen in the UI.

    Args:
        image: the uploaded image (a PIL image from the ``gr.Image`` input),
            or ``None`` if nothing was uploaded.
        model_choice: value of the model Radio. ``None`` or ``"effnet_b2"``
            selects EfficientNet-B2; anything else selects the ViT model.

    Returns:
        ``(predictions, time_taken)`` as produced by ``predict_one`` on CPU:
        a label -> probability dict and the inference time in seconds.

    Raises:
        gr.Error: if no image was supplied, so the user sees a clear message
            instead of a crash inside the image transforms.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    model, transforms = _load_model(model_choice)
    predictions, time_taken = predict_one(model, transforms, image, "cpu", class_idx_to_names)
    return predictions, time_taken
title = "Food Recognition ππ"
desc = "A dual model app ft. EfficientNetB2 Feature Extractor and VisionTransformer."
# Markdown rendered below the interface: training/evaluation stats for both models.
article = '''
## Stats on different Models
---
| Model Name | Train Loss | Test Loss | Train Accuracy | Test Accuracy | Num Parameters | Model Size |
|-----------------|------------|-----------|----------------|---------------|----------------|------------|
| EfficientNet_b2 | 0.340270 | 0.301134 | 0.906250 | 0.953409 | 7705221 | 29.91 MB |
| ViT_Base_16 | 0.040448 | 0.055140 | 0.995833 | 0.981250 | 85800963 | 327.39 MB |
'''

# Two inputs (image, model choice) -> two outputs (top-3 label scores, timing),
# matching predict()'s parameters and return tuple.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="upload a JPEG or PNG"),
        gr.Radio(
            ["effnet_b2", "ViT (Vision Transformer)"],
            # Make the advertised default explicit in the UI; predict() also
            # treats an unset choice (None) as effnet_b2.
            value="effnet_b2",
            label="choose model (default on effnet)",
        ),
    ],
    outputs=[
        gr.Label(num_top_classes=3, label="predictions"),
        gr.Number(label="Prediction Time in seconds"),
    ],
    # Each example row supplies only the image; Gradio leaves the Radio as-is.
    examples=examples,
    title=title,
    description=desc,
    article=article,
)
demo.launch(debug=False)