File size: 3,854 Bytes
9154b8b 32bb702 4a18314 8951ddb 69a2b4b 8951ddb 32bb702 95b1ec2 32bb702 9154b8b 32bb702 eb04f86 32bb702 e67b2a0 32bb702 e67b2a0 32bb702 e7676d7 32bb702 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import pathlib
import platform
from pathlib import PosixPath

import torch, torch.nn as nn
from fastai.vision.all import *
import gradio as gr
# The learner pickle may contain pathlib.WindowsPath objects (presumably it was
# exported on Windows). WindowsPath cannot be instantiated on any non-Windows
# system (Linux, macOS, ...), so alias it to PosixPath before load_learner()
# unpickles the file. No module-level name is bound here, so matplotlib's `plt`
# (re-exported by the fastai star import) is no longer shadowed.
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
# ---------- 1. load the trained learner ----------
learn = load_learner('bear_classifier.pkl')
# Put the whole trained network (body + head) into inference mode.
learn.model.eval()
# Device chosen by fastai's default_device(); presumably CUDA when available.
device = default_device()
# ---------- 2. build the embedding extractor ----------
# Reuse the trained CNN body and collapse its spatial feature map into a single
# vector per image. Assumes learn.model is fastai's (body, head) Sequential, so
# index 0 is the convolutional body — confirm against the exported model.
embedding_model = nn.Sequential(
    learn.model[0],
    nn.AdaptiveAvgPool2d(1),  # [B, C, H, W] -> [B, C, 1, 1]
    nn.Flatten(),             # [B, C, 1, 1] -> [B, C]; C presumably 512 (ResNet-18/34)
)
embedding_model = embedding_model.to(device).eval()
# ---------- 3. get or compute class-mean embeddings ----------
# Names of the three bear classes, as used in the learner's vocab and as the
# keys stored in class_means.pt.
_CLASS_NAMES = ('grizzly', 'black', 'teddy')


def _class_means_from_dl(dl, class_names=_CLASS_NAMES):
    """Return the per-class mean embedding over every labelled batch in *dl*.

    Each batch ``(xb, yb)`` is embedded with ``embedding_model``; ``yb`` holds
    integer indices into ``learn.dls.vocab``. Class names are mapped to those
    indices through the vocab instead of being hard-coded, because the vocab's
    order is decided by fastai (its CategoryMap sorts by default, so it is
    likely black, grizzly, teddy — verify with ``learn.dls.vocab``), not by the
    order the names are listed here.

    Args:
        dl: iterable of ``(xb, yb)`` batches.
        class_names: class names to average; each must appear in the vocab.

    Returns:
        dict mapping each class name to a 1-D CPU tensor (its mean embedding).

    Raises:
        RuntimeError: if *dl* yields no batches or a class has no samples.
        ValueError: if a class name is not in the learner's vocab.
    """
    vocab = list(learn.dls.vocab)
    emb_chunks, label_chunks = [], []
    with torch.no_grad():
        for xb, yb in dl:
            emb_chunks.append(embedding_model(xb.to(device)).cpu())
            label_chunks.append(yb.cpu())
    if not emb_chunks:
        # torch.cat([]) would otherwise fail with an opaque message.
        raise RuntimeError(
            'No training batches available to recompute class means '
            '(a learner loaded with load_learner() usually carries empty '
            'DataLoaders); provide class_means.pt instead.')
    embs = torch.cat(emb_chunks)
    lbls = torch.cat(label_chunks)
    means = {}
    for name in class_names:
        mask = lbls == vocab.index(name)
        if not mask.any():
            raise RuntimeError(f'No training samples for class {name!r}.')
        means[name] = embs[mask].mean(0)
    return means


try:
    # weights_only=False is kept from the original: the saved embeddings may be
    # fastai tensor subclasses, which the weights-only unpickler rejects
    # (assumption — confirm). Only load a class_means.pt you produced yourself.
    cls_means = torch.load('class_means.pt', map_location='cpu',
                           weights_only=False)
except FileNotFoundError:
    print('class_means.pt not found – recomputing on the fly…')
    # Walk the training set in order and keep the last partial batch, so every
    # training image contributes exactly once. NOTE(review): training-split
    # batch transforms may include augmentation — confirm that is acceptable.
    cls_means = _class_means_from_dl(
        learn.dls.train.new(shuffle=False, drop_last=False))
    torch.save(cls_means, 'class_means.pt')

grizzly_mean = cls_means['grizzly']
black_mean = cls_means['black']
teddy_mean = cls_means['teddy']
# ---------- 4. helper functions ----------
# Largest cosine distance (1 - cosine similarity) from an image embedding to
# its nearest class mean at which the image is still accepted as that class;
# anything farther is reported as "not a bear". Tune on held-out images.
DIST_THRESHOLD: float = 0.2
def get_embedding_tensor(pil_img):
    """Embed a single PIL image with the pooled CNN backbone.

    The image is pushed through the learner's ``test_dl`` pipeline (presumably
    the validation-time transforms) and then through ``embedding_model``.

    Returns:
        1-D CPU tensor holding the image's embedding (length C, presumably 512
        for a ResNet-18/34 body — confirm against the exported model).
    """
    batch = learn.dls.test_dl([pil_img]).one_batch()
    images = batch[0].to(device)  # a batch containing just this one image
    with torch.no_grad():
        features = embedding_model(images)
    return features[0].cpu()
def predict_embedding_ood_from_img(pil_img,
                                   distance_threshold=DIST_THRESHOLD):
    """Classify a bear image, or report that it is not a bear.

    The image embedding is compared with each class-mean embedding using cosine
    distance (1 - cosine similarity). The nearest class is returned if its
    distance is below *distance_threshold*; otherwise the image is treated as
    out-of-distribution and all three distances are reported.

    Args:
        pil_img: PIL image to classify, or None when the UI is submitted
            without an image.
        distance_threshold: maximum cosine distance to accept the nearest class.

    Returns:
        A human-readable result string.
    """
    if pil_img is None:
        # Gradio passes None when no image was uploaded; test_dl would crash.
        return "Please upload an image."

    emb = get_embedding_tensor(pil_img)
    cosine = torch.nn.functional.cosine_similarity
    # Insertion order fixes tie-breaking in min() below: Grizzly, Black, Teddy.
    prototypes = {'Grizzly': grizzly_mean,
                  'Black': black_mean,
                  'Teddy': teddy_mean}
    distances = {name: (1 - cosine(emb, proto, dim=0)).item()
                 for name, proto in prototypes.items()}

    closest_class, closest_dist = min(distances.items(),
                                      key=lambda kv: kv[1])
    if closest_dist < distance_threshold:
        return f"{closest_class} (the distance is {closest_dist:.2f})"
    return ("This is an image of something other than a bear "
            f"(grizzly distance {distances['Grizzly']:.2f}, "
            f"black distance {distances['Black']:.2f}, "
            f"teddy distance {distances['Teddy']:.2f})")
# ---------- 5. Gradio interface ----------
# Single-image web UI around predict_embedding_ood_from_img. The example image
# files are expected next to this script (cannot be verified from here).
demo = gr.Interface(
    fn=predict_embedding_ood_from_img,
    inputs=gr.Image(type='pil'),  # the prediction function expects a PIL image
    outputs=gr.Textbox(),
    title="Bear Classifier with OOD Detection",
    # Describes what the function actually returns: a class name with its
    # distance, or a "not a bear" message listing all three distances.
    description=("Predicts Grizzly / Black / Teddy bear with its cosine "
                 "distance to that class prototype, or reports that the "
                 "image is not a bear if it is far from all three class "
                 "prototypes."),
    examples=[
        ['grizzly.jpg'],
        ['black.jpg'],
        ['teddy.jpg'],
        ['f1.jpg'],
    ],
)

if __name__ == "__main__":
    demo.launch()
|