Spaces:
Sleeping
Sleeping
Add changes
Browse files- app.py +6 -0
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -5,6 +5,8 @@ import torch
|
|
| 5 |
from model import create_effnetb4_model
|
| 6 |
from timeit import default_timer as timer
|
| 7 |
from typing import Tuple, Dict
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class_names = ['apple_pie','baby_back_ribs','baklava','beef_carpaccio','beef_tartare','beet_salad','beignets','bibimbap','bread_pudding',
|
| 10 |
'breakfast_burrito','bruschetta','caesar_salad','cannoli','caprese_salad','carrot_cake','ceviche','cheese_plate','cheesecake','chicken_curry',
|
|
@@ -40,6 +42,10 @@ effnetb4.load_state_dict(state_dict)
|
|
| 40 |
def predict(img) -> Tuple[Dict, float]:
|
| 41 |
"""Transforms and performs a prediction on img and returns prediction and time taken."""
|
| 42 |
start_time = timer()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
img = effnetb4_transforms(img).unsqueeze(0)
|
| 44 |
|
| 45 |
effnetb4.eval()
|
|
|
|
| 5 |
from model import create_effnetb4_model
|
| 6 |
from timeit import default_timer as timer
|
| 7 |
from typing import Tuple, Dict
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
|
| 11 |
class_names = ['apple_pie','baby_back_ribs','baklava','beef_carpaccio','beef_tartare','beet_salad','beignets','bibimbap','bread_pudding',
|
| 12 |
'breakfast_burrito','bruschetta','caesar_salad','cannoli','caprese_salad','carrot_cake','ceviche','cheese_plate','cheesecake','chicken_curry',
|
|
|
|
| 42 |
def predict(img) -> Tuple[Dict, float]:
|
| 43 |
"""Transforms and performs a prediction on img and returns prediction and time taken."""
|
| 44 |
start_time = timer()
|
| 45 |
+
|
| 46 |
+
if isinstance(img, np.ndarray):
|
| 47 |
+
img = Image.fromarray(img)
|
| 48 |
+
|
| 49 |
img = effnetb4_transforms(img).unsqueeze(0)
|
| 50 |
|
| 51 |
effnetb4.eval()
|
requirements.txt
CHANGED
|
@@ -1,3 +1,5 @@
|
|
| 1 |
torch==2.2.0
|
| 2 |
torchvision==0.17.0
|
| 3 |
gradio==4.44.0
|
|
|
|
|
|
|
|
|
| 1 |
torch==2.2.0
|
| 2 |
torchvision==0.17.0
|
| 3 |
gradio==4.44.0
|
| 4 |
+
Pillow==8.4.0
|
| 5 |
+
numpy<2
|