File size: 3,156 Bytes
9d8880e
 
 
 
 
 
 
4f67aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d8880e
 
 
 
 
 
 
 
 
 
 
4f67aa7
9d8880e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6910a38
 
9d8880e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Public names exported by `from <this module> import *`. Every entry must be
# bound at module level: a missing name makes the star-import raise
# AttributeError (the previous list named a non-existent 'label' instead of
# the Gradio output component `outputs`).
__all__ = ['learn', 'classify_image', 'bear', 'env', 'age', 'image', 'outputs', 'examples', 'intf',
           'bear_err', 'bear_loss', 'environ_err', 'environ_loss', 'age_err', 'age_loss', 'combine_loss',
           'get_x', 'get_bear', 'get_age', 'get_environ']

from fastai.vision.all import *
import gradio as gr

def get_x(r):
    """Return the ``'fname'`` entry of row ``r`` (the image filename)."""
    filename = r['fname']
    return filename

def _row_tokens(r, column):
    """Return the whitespace-split tokens of ``column`` for the row matching ``r['fname']``.

    Looks ``r['fname']`` up in the module-level DataFrame ``df`` (first
    matching row wins) and splits that row's ``column`` value on whitespace.
    Returns ``[]`` when no row matches.
    """
    # NOTE(review): `df` is not defined anywhere in this file, so calling any of
    # these getters here raises NameError. They are presumably kept only so the
    # pickled learner can resolve them by name — confirm before relying on them.
    matches = df[df['fname'] == r['fname']]
    if matches.empty:
        return []
    return matches[column].values[0].split()


def get_bear(r):
    """Returns all bear types associated with the given filename.

    Each token of the ``'bear_type'`` column has the substring "bear" removed
    and is stripped; tokens that become empty (e.g. a standalone "bear") are
    dropped.
    """
    stripped = (b.replace("bear", "").strip() for b in _row_tokens(r, 'bear_type'))
    return [b for b in stripped if b]


def get_age(r):
    """Returns all ages (tokens of the ``'age'`` column) associated with the given filename."""
    return _row_tokens(r, 'age')


def get_environ(r):
    """Returns all environments (tokens of the ``'environ'`` column) associated with the given filename."""
    return _row_tokens(r, 'environ')

# Column layout of the model's concatenated output row: 5 bear-type columns,
# then 3 environment columns, then the remaining columns for age. This matches
# the order and widths classify_image derives from the bear/env/age tuples
# (5, 3 and 2 labels). Keep these in sync if the heads ever change.
_BEAR_COLS = slice(0, 5)
_ENVIRON_COLS = slice(5, 8)
_AGE_COLS = slice(8, None)


def bear_err(inp, bear, environ, age):
    """Error rate of the bear-type columns of ``inp`` against targets ``bear``."""
    return error_rate(inp[:, _BEAR_COLS], bear)


def bear_loss(inp, bear, environ, age):
    """Cross-entropy of the bear-type columns of ``inp`` against targets ``bear``."""
    return F.cross_entropy(inp[:, _BEAR_COLS], bear)


def environ_err(inp, bear, environ, age):
    """Error rate of the environment columns of ``inp`` against targets ``environ``."""
    return error_rate(inp[:, _ENVIRON_COLS], environ)


def environ_loss(inp, bear, environ, age):
    """Cross-entropy of the environment columns of ``inp`` against targets ``environ``."""
    return F.cross_entropy(inp[:, _ENVIRON_COLS], environ)


def age_err(inp, bear, environ, age):
    """Error rate of the age columns of ``inp`` against targets ``age``."""
    return error_rate(inp[:, _AGE_COLS], age)


def age_loss(inp, bear, environ, age):
    """Cross-entropy of the age columns of ``inp`` against targets ``age``."""
    return F.cross_entropy(inp[:, _AGE_COLS], age)


def combine_loss(inp, bear, environ, age):
    """Unweighted sum of the bear, environment and age cross-entropy losses.

    ``inp`` is presumably a (batch, 10) tensor of logits laid out as described
    above, and ``bear``/``environ``/``age`` integer class targets — confirm.
    """
    return (bear_loss(inp, bear, environ, age)
            + environ_loss(inp, bear, environ, age)
            + age_loss(inp, bear, environ, age))

# Load the exported fastai learner from 'multimulti.pkl' (relative to the
# working directory). load_learner deserializes via pickle, so only ever point
# it at a trusted file.
# NOTE(review): the pickled learner presumably references get_x, get_bear,
# get_age, get_environ, combine_loss and the *_err metrics by name, which would
# explain why they are defined above even though this app never calls them.
learn = load_learner('multimulti.pkl')

# Label vocabularies for the three prediction heads. classify_image slices the
# model output in this order (bear, then env, then age) using these lengths,
# so order and length presumably must match the trained model — confirm.
bear = (
    'black',
    'brown',
    'grizzly',
    'sloth',
    'sun',
)
env = (
    'forest',
    'plains',
    'water',
)
age = (
    'adult',
    'baby',
)

def classify_image(img_path):
    """Predict bear type, environment and age for one input image.

    Args:
        img_path: the image supplied by the Gradio ``image`` input. It is
            wrapped in a one-item list and passed to ``learn.dls.test_dl``.
            NOTE(review): with ``gr.Image``'s default settings this is
            presumably a numpy array rather than a file path — confirm.

    Returns:
        A ``pd.DataFrame`` with one row per predicted item and the columns
        ``'Bear'``, ``'Environment'`` and ``'Age'``, each holding the label
        whose score is highest within that head.
    """
    tst_dl = learn.dls.test_dl([img_path])
    preds, _ = learn.get_preds(dl=tst_dl)

    # Each prediction row concatenates the bear, environment and age scores,
    # with widths given by the label tuples. Walk the heads left to right.
    results = {}
    start = 0
    for column, labels in (('Bear', bear), ('Environment', env), ('Age', age)):
        stop = start + len(labels)
        idxs = preds[:, start:stop].argmax(dim=1).tolist()
        # Index the tuple with plain ints: a tuple cannot be indexed by a
        # multi-element tensor, which the previous `labels[tensor]` relied on.
        results[column] = [labels[i] for i in idxs]
        start = stop

    return pd.DataFrame(results)

# Gradio UI: an image input sized 192x192, a table of predictions as output,
# and example images given as relative paths (presumably files shipped next to
# this app — confirm).
image = gr.Image(width=192, height=192)
outputs = gr.Dataframe()

examples = [
    'Ronan_Grizzly_Bear_1.jpg',
    'blackbear.jpg',
    'blackbear2.jpg',
    'brownbear.jpg',
    'polar.jpg',
    '0(sun sloth forest black adult baby water grizzly).jpg',
    '2(black adult plains baby water grizzly).jpg',
]

intf = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=outputs,
    examples=examples,
)
# Starts the app as soon as this module runs.
intf.launch(inline=False)