jalarrimore commited on
Commit
4f67aa7
·
verified ·
1 Parent(s): 68709bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -5,18 +5,33 @@ __all__ = ['learn', 'classify_image', 'bear', 'env', 'age', 'image', 'label', 'e
5
  from fastai.vision.all import *
6
  import gradio as gr
7
 
8
- def get_bear(p):
9
- name_components = str(p).split(' ')
10
- return name_components[2]
11
 
12
- def get_age(p):
13
- name_components = str(p).split(' ')
14
- return name_components[1]
 
 
 
 
 
 
15
 
16
- def get_environ(p):
17
- name_components = str(p).split(' ')
18
- file_components = name_components[0].split('/')
19
- return file_components[-1]
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def bear_err(inp,bear,environ,age): return error_rate(inp[:,:5],bear)
22
  def bear_loss(inp,bear,environ,age): return F.cross_entropy(inp[:,:5],bear)
@@ -28,7 +43,7 @@ def age_loss(inp,bear,environ,age): return F.cross_entropy(inp[:,8:],age)
28
  def combine_loss(inp,bear,environ,age):
29
  return bear_loss(inp,bear,environ,age)+environ_loss(inp,bear,environ,age)+age_loss(inp,bear,environ,age)
30
 
31
- learn = load_learner('multitarget_bears.pkl')
32
 
33
  bear = ('black', 'brown', 'grizzly', 'sloth', 'sun')
34
  env = ('forest', 'plains', 'water')
 
5
  from fastai.vision.all import *
6
  import gradio as gr
7
 
8
+ def get_x(r): return r['fname']
 
 
9
 
10
+ def get_bear(r):
11
+ """Returns all bear types associated with the given filename."""
12
+ fname = r['fname'] # Extract filename from row
13
+ row = df[df['fname'] == fname] # Filter dataframe using single value
14
+ if not row.empty:
15
+ bear_types = row['bear_type'].values[0].split() # Get list of bear types
16
+ bear_types = [b.replace("bear", "").strip() for b in bear_types] # Remove "bear" and trim spaces
17
+ return [b for b in bear_types if b] # Remove any empty strings
18
+ return []
19
 
20
+ def get_age(r):
21
+ """Returns all ages associated with the given filename."""
22
+ fname = r['fname']
23
+ row = df[df['fname'] == fname]
24
+ if not row.empty:
25
+ return row['age'].values[0].split()
26
+ return []
27
+
28
+ def get_environ(r):
29
+ """Returns all environments associated with the given filename."""
30
+ fname = r['fname']
31
+ row = df[df['fname'] == fname]
32
+ if not row.empty:
33
+ return row['environ'].values[0].split()
34
+ return []
35
 
36
  def bear_err(inp,bear,environ,age): return error_rate(inp[:,:5],bear)
37
  def bear_loss(inp,bear,environ,age): return F.cross_entropy(inp[:,:5],bear)
 
43
  def combine_loss(inp,bear,environ,age):
44
  return bear_loss(inp,bear,environ,age)+environ_loss(inp,bear,environ,age)+age_loss(inp,bear,environ,age)
45
 
46
+ learn = load_learner('multimulti.pkl')
47
 
48
  bear = ('black', 'brown', 'grizzly', 'sloth', 'sun')
49
  env = ('forest', 'plains', 'water')