bonosa commited on
Commit
f9212bb
Β·
1 Parent(s): 95fb4df
Files changed (3) hide show
  1. app.py +10 -3
  2. importpickle.py +14 -0
  3. parrotclass_state_dict.pth +3 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import sys
3
  import pathlib
4
  from pathlib import Path
5
  from fastai.vision.all import load_learner, PILImage
 
6
  import gradio
7
  from gradio import Interface, Image, Label
8
  import logging
@@ -10,11 +11,16 @@ logging.basicConfig(level=logging.INFO)
10
  def custom_resnet_splitter(model):
11
  resnet = model[0]
12
  return [params(resnet[0]), params(resnet[1]), params(resnet[4]), params(resnet[5]), params(resnet[6]), params(resnet[7]), params(model[1])]
13
-
14
  def predict_parrot_species(image):
15
  try:
16
- model_path = Path('parrotclass.pkl') # Use Path instead of a string
17
- learn_inf = load_learner(model_path)
 
 
 
 
 
 
18
  pred, _, _ = learn_inf.predict(image)
19
  return pred
20
  except Exception as e:
@@ -23,6 +29,7 @@ def predict_parrot_species(image):
23
 
24
 
25
 
 
26
  input_image = Image(shape=(224, 224))
27
  output_label = Label()
28
 
 
3
  import pathlib
4
  from pathlib import Path
5
  from fastai.vision.all import load_learner, PILImage
6
+ from torchvision.models import resnet18
7
  import gradio
8
  from gradio import Interface, Image, Label
9
  import logging
 
11
  def custom_resnet_splitter(model):
12
  resnet = model[0]
13
  return [params(resnet[0]), params(resnet[1]), params(resnet[4]), params(resnet[5]), params(resnet[6]), params(resnet[7]), params(model[1])]
 
14
  def predict_parrot_species(image):
15
  try:
16
+ model_path = Path('parrotclass_state_dict.pth')
17
+ state_dict = torch.load(model_path)
18
+
19
+ # Load the model architecture
20
+ model = resnet18(num_classes=len(parrot_types))
21
+ model.load_state_dict(state_dict)
22
+ learn_inf = Learner(dls, model)
23
+
24
  pred, _, _ = learn_inf.predict(image)
25
  return pred
26
  except Exception as e:
 
29
 
30
 
31
 
32
+
33
  input_image = Image(shape=(224, 224))
34
  output_label = Label()
35
 
importpickle.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ def convert_pickle_path(model_path, output_path):
6
+ with open(model_path, 'rb') as f:
7
+ model = pickle.load(f)
8
+
9
+ torch.save(model, output_path, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=False)
10
+
11
+ if __name__ == '__main__':
12
+ input_model_path = Path('parrotclass.pkl')
13
+ output_model_path = Path('parrotclass_converted.pkl')
14
+ convert_pickle_path(input_model_path, output_model_path)
parrotclass_state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03d128964b45f9a76cc9e32efdfc853ef2c94c882831117b96041ac7f0c7b99f
3
+ size 46917395