bonosa commited on
Commit
3858d46
Β·
1 Parent(s): 466f4b7

work damn you!

Browse files
app.py CHANGED
@@ -1,31 +1,42 @@
1
-
 
 
 
2
  from fastai.vision.all import load_learner, PILImage
3
-
4
  import gradio
5
- import pathlib
6
- import traceback
7
- import sys
8
-
9
  from gradio import Interface, Image, Label
10
 
11
- # Patch pathlib to use WindowsPath on Windows systems
12
- if pathlib.Path == pathlib.WindowsPath:
13
- pathlib.PosixPath = pathlib.WindowsPath
 
 
 
 
 
 
 
14
 
15
- def exception_handler(type, value, tb):
16
- print("".join(traceback.format_exception(type, value, tb)))
 
 
 
 
 
 
 
17
 
18
- sys.excepthook = exception_handler
19
 
20
  # 4. Create a Gradio interface
21
  def predict_parrot_species(image):
22
- learn_inf = load_learner('parrotclassfierwin.pkl')
23
  pred, _, _ = learn_inf.predict(image)
24
  return pred
25
 
26
  input_image = Image(shape=(224, 224))
27
- output_label =Label()
28
 
29
  gr_interface = gradio.Interface(fn=predict_parrot_species, inputs=input_image, outputs=output_label)
30
- gr_interface = gradio.Interface(fn=predict_parrot_species, inputs=input_image, outputs=output_label)
31
- gr_interface.launch()
 
1
+ import torch
2
+ import pickle
3
+ import sys
4
+ import pathlib
5
  from fastai.vision.all import load_learner, PILImage
 
6
  import gradio
 
 
 
 
7
  from gradio import Interface, Image, Label
8
 
9
+ # Custom Unpickler to handle PosixPath
10
+
11
+ class CustomUnpickler(pickle.Unpickler):
12
+ def find_class(self, module, name):
13
+ if module == "pathlib" and name == "PosixPath":
14
+ return pathlib.WindowsPath
15
+ elif module == "pathlib" and name == "WindowsPath":
16
+ return pathlib.WindowsPath
17
+ return super().find_class(module, name)
18
+
19
 
20
+ def persistent_load(self, pid):
21
+ return pid
22
+ import dill
23
+ # Load learner using custom unpickler
24
+ def custom_load_learner(fname, cpu=True):
25
+ map_loc = None if torch.cuda.is_available() and not cpu else 'cpu'
26
+ with open(fname, 'rb') as f:
27
+ return dill.load(f)
28
+ from fastai.learner import load_learner
29
 
 
30
 
31
  # 4. Create a Gradio interface
32
  def predict_parrot_species(image):
33
+ learn_inf = custom_load_learner('parrotclass.pkl')
34
  pred, _, _ = learn_inf.predict(image)
35
  return pred
36
 
37
  input_image = Image(shape=(224, 224))
38
+ output_label = Label()
39
 
40
  gr_interface = gradio.Interface(fn=predict_parrot_species, inputs=input_image, outputs=output_label)
41
+ gr_interface.launch()
42
+
parrotclassfierwin.pkl β†’ parrotclass.pkl RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce9f99e1500abd02732fc149f3dfe114befa32e03a031f677d2bb31551fe7868
3
- size 46985603
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6657a442bfa942517a8f2d9f03a768fdf6f9695c12eb47b39140cde7a71d8592
3
+ size 46973589