File size: 3,854 Bytes
9154b8b
32bb702
 
 
4a18314
8951ddb
 
 
69a2b4b
 
 
8951ddb
32bb702
95b1ec2
32bb702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9154b8b
32bb702
 
 
eb04f86
32bb702
 
e67b2a0
32bb702
 
 
 
 
e67b2a0
32bb702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7676d7
32bb702
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

import torch, torch.nn as nn
from fastai.vision.all import *
import gradio as gr

import pathlib

from pathlib import PosixPath
# On Linux, alias WindowsPath to PosixPath. Presumably this lets a learner
# that was pickled on Windows (and so stores WindowsPath objects) be
# unpickled here -- confirm against how bear_classifier.pkl was exported.
plt = platform.system()
if plt == 'Linux':
    pathlib.WindowsPath = pathlib.PosixPath


# ---------- 1. load the trained learner ----------
# Restore the exported learner. NOTE: load_learner unpickles the file, so
# only point it at a .pkl from a trusted source.
learn = load_learner('bear_classifier.pkl')
device = default_device()          # GPU if available
learn.model.eval()                 # inference mode


# ---------- 2. build the embedding extractor ----------
# The first child of the model is used as the feature extractor; its feature
# map is globally average-pooled and flattened into one vector per image.
embedding_model = nn.Sequential(
    learn.model[0],                # CNN backbone
    nn.AdaptiveAvgPool2d(1),       # [B, C, 1, 1]
    nn.Flatten(),                  # [B, C]
)
embedding_model = embedding_model.to(device).eval()

# ---------- 3. get or compute class-mean embeddings ----------
# Each class prototype is the mean embedding of that class's training images.
# Prefer the copy cached on disk; otherwise rebuild it from the training data.
try:
    # NOTE: weights_only=False unpickles arbitrary objects, so class_means.pt
    # must come from a trusted source (e.g. written by this script).
    cls_means = torch.load('class_means.pt', map_location='cpu', weights_only=False)
    grizzly_mean = cls_means['grizzly']
    black_mean   = cls_means['black']
    teddy_mean   = cls_means['teddy']
except FileNotFoundError:
    # recompute from the learner's training dataloader
    print('class_means.pt not found – recomputing on the fly…')
    embs, lbls = [], []
    for xb, yb in learn.dls.train:
        with torch.no_grad():
            feats = embedding_model(xb.to(device)).cpu()
        embs.append(feats)
        lbls.append(yb.cpu())

    # A learner restored with load_learner carries no training items, so the
    # loop above may yield nothing; fail with an actionable message instead of
    # the opaque error torch.cat raises on an empty list.
    if not embs:
        raise RuntimeError(
            "Cannot recompute class means: the learner's training dataloader "
            "is empty. Ship a precomputed class_means.pt next to this script.")

    embs = torch.cat(embs)
    lbls = torch.cat(lbls)

    # Resolve label indices through the vocab instead of assuming an order:
    # fastai sorts category names by default, so index 0 is not necessarily
    # 'grizzly'. The original positional indices are kept only as a fallback
    # for vocabularies that do not use these exact names.
    vocab = [str(c) for c in learn.dls.vocab]
    cls_means = {}
    for name, legacy_idx in (('grizzly', 0), ('black', 1), ('teddy', 2)):
        idx = vocab.index(name) if name in vocab else legacy_idx
        cls_means[name] = embs[lbls == idx].mean(0)

    grizzly_mean = cls_means['grizzly']
    black_mean   = cls_means['black']
    teddy_mean   = cls_means['teddy']
    # Cache the prototypes so later start-ups take the fast path above.
    torch.save(cls_means, 'class_means.pt')

# ---------- 4. helper functions ----------
# Largest cosine distance to a class prototype that still counts as a match.
DIST_THRESHOLD = 0.2                          # tune to taste

def get_embedding_tensor(pil_img):
    """Return the embedding vector of a single image as a CPU tensor.

    The image is run through the learner's test-time pipeline
    (``learn.dls.test_dl``) and then through ``embedding_model`` with
    gradients disabled. Only the first (and only) row of the batch is kept.
    """
    batch = learn.dls.test_dl([pil_img]).one_batch()
    images = batch[0].to(device)                # one-image batch
    with torch.no_grad():
        embedding = embedding_model(images)[0]  # drop the batch dimension
    return embedding.cpu()



def predict_embedding_ood_from_img(pil_img,
                                   distance_threshold=DIST_THRESHOLD):
    """Classify an image as a bear type, or report it as not a bear.

    The image embedding is compared with the three class-mean embeddings
    using cosine distance (1 - cosine similarity). The nearest class wins if
    its distance is below ``distance_threshold``.

    Args:
        pil_img: Image to classify, or None when no image was supplied.
        distance_threshold: Maximum cosine distance accepted as a match.

    Returns:
        A human-readable string: the predicted class with its distance, a
        "something other than a bear" message listing all three distances,
        or a prompt to upload an image when ``pil_img`` is None.
    """
    # Gradio calls the function with None when nothing was uploaded; answer
    # politely instead of crashing inside the fastai pipeline.
    if pil_img is None:
        return "Please upload an image first."

    emb = get_embedding_tensor(pil_img)

    prototypes = {'Grizzly': grizzly_mean,
                  'Black'  : black_mean,
                  'Teddy'  : teddy_mean}
    distances = {
        label: (1 - torch.nn.functional.cosine_similarity(
            emb, proto, dim=0)).item()
        for label, proto in prototypes.items()
    }

    closest_class, closest_dist = min(distances.items(),
                                      key=lambda kv: kv[1])

    if closest_dist < distance_threshold:
        return f"{closest_class} (the distance is {closest_dist:.2f})"

    # No prototype is close enough: treat the image as out-of-distribution.
    return (f"This is an image of something other than a bear (grizzly distance {distances['Grizzly']:.2f}, black distance {distances['Black']:.2f}, teddy distance {distances['Teddy']:.2f})")

# ---------- 5. Gradio interface ----------
# Web UI: one image in, one line of text out, with a few clickable examples.
demo = gr.Interface(
    fn=predict_embedding_ood_from_img,
    inputs=gr.Image(type='pil'),
    outputs=gr.Textbox(),
    title="Bear Classifier with OOD Detection",
    description=(
        "Predicts Grizzly / Black / Teddy bear, or returns Other "
        "if the image is far from all three class prototypes."
    ),
    # Each example is a one-element row because the interface has one input.
    examples=[[name] for name in ('grizzly.jpg', 'black.jpg',
                                  'teddy.jpg', 'f1.jpg')],
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()