jalarrimore commited on
Commit
9d8880e
·
verified ·
1 Parent(s): 8e15f7a

Built file with multitarget

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = ['learn', 'classify_image', 'bear', 'env', 'age', 'image', 'label', 'examples', 'intf',
2
+ 'bear_err', 'bear_loss', 'environ_err', 'environ_loss', 'age_err', 'age_loss', 'combine_loss',
3
+ 'get_bear', 'get_age', 'get_environ']
4
+
5
+ from fastai.vision.all import *
6
+ import gradio as gr
7
+
8
+ def get_bear(p):
9
+ name_components = str(p).split(' ')
10
+ return name_components[2]
11
+
12
+ def get_age(p):
13
+ name_components = str(p).split(' ')
14
+ return name_components[1]
15
+
16
+ def get_environ(p):
17
+ name_components = str(p).split(' ')
18
+ file_components = name_components[0].split('/')
19
+ return file_components[-1]
20
+
21
+ def bear_err(inp,bear,environ,age): return error_rate(inp[:,:5],bear)
22
+ def bear_loss(inp,bear,environ,age): return F.cross_entropy(inp[:,:5],bear)
23
+ def environ_err(inp,bear,environ,age): return error_rate(inp[:,5:8],environ)
24
+ def environ_loss(inp,bear,environ,age): return F.cross_entropy(inp[:,5:8],environ)
25
+ def age_err(inp,bear,environ,age): return error_rate(inp[:,8:],age)
26
+ def age_loss(inp,bear,environ,age): return F.cross_entropy(inp[:,8:],age)
27
+
28
+ def combine_loss(inp,bear,environ,age):
29
+ return bear_loss(inp,bear,environ,age)+environ_loss(inp,bear,environ,age)+age_loss(inp,bear,environ,age)
30
+
31
+ learn = load_learner('multitarget_bears.pkl')
32
+
33
+ bear = ('black', 'brown', 'grizzly', 'sloth', 'sun')
34
+ env = ('forest', 'plains', 'water')
35
+ age = ('adult', 'baby')
36
+
37
+ def classify_image(img_path):
38
+ tst_dl = learn.dls.test_dl([img_path])
39
+ preds, _ = learn.get_preds(dl=tst_dl)
40
+
41
+ idxs_bear = preds[:,:len(bear)].argmax(dim=1)
42
+ results_bear = pd.Series(bear[idxs_bear], name="idxs_bear")
43
+
44
+ idxs_env = preds[:,len(bear):(len(bear)+len(env))].argmax(dim=1)
45
+ results_env = pd.Series(env[idxs_env], name="idxs_env")
46
+
47
+ idxs_age = preds[:,(len(bear)+len(env)):(len(bear)+len(env)+len(age))].argmax(dim=1)
48
+ results_age = pd.Series(age[idxs_age], name="idxs_age")
49
+
50
+ results = pd.DataFrame()
51
+ results['Bear'] = results_bear
52
+ results['Environment'] = results_env
53
+ results['Age'] = results_age
54
+
55
+ return results
56
+
57
+ image = gr.Image(height=192, width=192)
58
+ outputs = gr.Dataframe()
59
+
60
+ examples = ['Ronan_Grizzly_Bear_1.jpg', 'blackbear.jpg', 'blackbear2.jpg', 'brownbear.jpg', 'polar.jpg']
61
+
62
+ intf = gr.Interface(fn=classify_image, inputs=image, outputs=outputs, examples=examples)
63
+ intf.launch(inline=False)