Spaces:
Runtime error
Runtime error
bonosa
commited on
Commit
Β·
f9212bb
1
Parent(s):
95fb4df
yay!
Browse files- app.py +10 -3
- importpickle.py +14 -0
- parrotclass_state_dict.pth +3 -0
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import sys
|
|
| 3 |
import pathlib
|
| 4 |
from pathlib import Path
|
| 5 |
from fastai.vision.all import load_learner, PILImage
|
|
|
|
| 6 |
import gradio
|
| 7 |
from gradio import Interface, Image, Label
|
| 8 |
import logging
|
|
@@ -10,11 +11,16 @@ logging.basicConfig(level=logging.INFO)
|
|
| 10 |
def custom_resnet_splitter(model):
|
| 11 |
resnet = model[0]
|
| 12 |
return [params(resnet[0]), params(resnet[1]), params(resnet[4]), params(resnet[5]), params(resnet[6]), params(resnet[7]), params(model[1])]
|
| 13 |
-
|
| 14 |
def predict_parrot_species(image):
|
| 15 |
try:
|
| 16 |
-
model_path = Path('
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
pred, _, _ = learn_inf.predict(image)
|
| 19 |
return pred
|
| 20 |
except Exception as e:
|
|
@@ -23,6 +29,7 @@ def predict_parrot_species(image):
|
|
| 23 |
|
| 24 |
|
| 25 |
|
|
|
|
| 26 |
input_image = Image(shape=(224, 224))
|
| 27 |
output_label = Label()
|
| 28 |
|
|
|
|
| 3 |
import pathlib
|
| 4 |
from pathlib import Path
|
| 5 |
from fastai.vision.all import load_learner, PILImage
|
| 6 |
+
from torchvision.models import resnet18
|
| 7 |
import gradio
|
| 8 |
from gradio import Interface, Image, Label
|
| 9 |
import logging
|
|
|
|
| 11 |
def custom_resnet_splitter(model):
|
| 12 |
resnet = model[0]
|
| 13 |
return [params(resnet[0]), params(resnet[1]), params(resnet[4]), params(resnet[5]), params(resnet[6]), params(resnet[7]), params(model[1])]
|
|
|
|
| 14 |
def predict_parrot_species(image):
|
| 15 |
try:
|
| 16 |
+
model_path = Path('parrotclass_state_dict.pth')
|
| 17 |
+
state_dict = torch.load(model_path)
|
| 18 |
+
|
| 19 |
+
# Load the model architecture
|
| 20 |
+
model = resnet18(num_classes=len(parrot_types))
|
| 21 |
+
model.load_state_dict(state_dict)
|
| 22 |
+
learn_inf = Learner(dls, model)
|
| 23 |
+
|
| 24 |
pred, _, _ = learn_inf.predict(image)
|
| 25 |
return pred
|
| 26 |
except Exception as e:
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
|
| 32 |
+
|
| 33 |
input_image = Image(shape=(224, 224))
|
| 34 |
output_label = Label()
|
| 35 |
|
importpickle.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import torch
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
def convert_pickle_path(model_path, output_path):
|
| 6 |
+
with open(model_path, 'rb') as f:
|
| 7 |
+
model = pickle.load(f)
|
| 8 |
+
|
| 9 |
+
torch.save(model, output_path, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=False)
|
| 10 |
+
|
| 11 |
+
if __name__ == '__main__':
|
| 12 |
+
input_model_path = Path('parrotclass.pkl')
|
| 13 |
+
output_model_path = Path('parrotclass_converted.pkl')
|
| 14 |
+
convert_pickle_path(input_model_path, output_model_path)
|
parrotclass_state_dict.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03d128964b45f9a76cc9e32efdfc853ef2c94c882831117b96041ac7f0c7b99f
|
| 3 |
+
size 46917395
|