|
|
|
|
|
import pathlib
import platform
from pathlib import PosixPath

import torch, torch.nn as nn
from fastai.vision.all import *
import gradio as gr
|
|
# Learners exported on Windows presumably pickle their paths as
# pathlib.WindowsPath, which cannot be instantiated on Linux; aliasing it to
# PosixPath lets load_learner() unpickle such a file on this host.
# The OS name is deliberately not stored in a module global: the name `plt`
# would shadow the matplotlib.pyplot alias that `from fastai.vision.all
# import *` brings into this namespace.
if platform.system() == 'Linux':
    pathlib.WindowsPath = pathlib.PosixPath
|
|
|
|
|
|
|
|
|
|
|
# Load the trained classifier and switch it to inference mode.
learn = load_learner('bear_classifier.pkl')
learn.model.eval()

device = default_device()

# Feature extractor: the learner's first sub-module (presumably the
# convolutional body of a cnn_learner model — confirm), followed by global
# average pooling and flattening, so each image maps to one vector.
embedding_model = nn.Sequential(
    learn.model[0],
    nn.AdaptiveAvgPool2d(1),  # collapse the spatial dimensions to 1x1
    nn.Flatten(),             # drop the 1x1 dims -> one vector per image
)
# Module.to() and Module.eval() both act in place on the module.
embedding_model.to(device)
embedding_model.eval()
|
|
|
|
|
|
|
|
def _vocab_index(vocab, name):
    """Return the position of class *name* in the learner's category vocab.

    An exact (case-insensitive) match wins; otherwise the first vocab entry
    that starts with *name* is used (e.g. a folder label 'teddys' for 'teddy').
    The vocab is looked up rather than assuming an order, because fastai
    sorts category labels, which does not match the grizzly/black/teddy order.

    Raises:
        KeyError: if no vocab entry matches *name*.
    """
    labels = [str(v).lower() for v in vocab]
    if name in labels:
        return labels.index(name)
    for i, label in enumerate(labels):
        if label.startswith(name):
            return i
    raise KeyError(f"no class matching {name!r} in learner vocab {list(vocab)}")


def _compute_class_means():
    """Average the training-set embeddings of each class.

    Returns:
        dict mapping 'grizzly', 'black' and 'teddy' to their mean embedding
        (CPU tensors).

    Raises:
        RuntimeError: if the learner carries no training data, or a class has
            no training images.
    """
    # Visit every training item exactly once: the default train loader
    # shuffles and drops the last partial batch. Training-time augmentation
    # (if any was configured) is still applied by this loader.
    loader = learn.dls.train.new(shuffle=False, drop_last=False)
    embs, lbls = [], []
    with torch.no_grad():
        for xb, yb in loader:
            embs.append(embedding_model(xb.to(device)).cpu())
            lbls.append(yb.cpu())
    if not embs:
        # Learner.export() strips the dataset items, so a learner obtained
        # from load_learner() cannot rebuild the prototypes by itself.
        raise RuntimeError(
            "class_means.pt not found and the learner has no training data "
            "to recompute it from; ship class_means.pt with the model.")
    embs = torch.cat(embs)
    lbls = torch.cat(lbls)

    means = {}
    for name in ('grizzly', 'black', 'teddy'):
        mask = lbls == _vocab_index(learn.dls.vocab, name)
        if not mask.any():
            raise RuntimeError(f"no training images for class {name!r}")
        means[name] = embs[mask].mean(0)
    return means


try:
    # weights_only=False unpickles arbitrary objects: only load a trusted
    # class_means.pt. (Presumably needed because saved tensors may be fastai
    # tensor subclasses — confirm before switching to weights_only=True.)
    cls_means = torch.load('class_means.pt', map_location='cpu', weights_only=False)
except FileNotFoundError:
    print('class_means.pt not found – recomputing on the fly…')
    cls_means = _compute_class_means()
    torch.save(cls_means, 'class_means.pt')

# Per-class prototype embeddings used for nearest-prototype classification.
grizzly_mean = cls_means['grizzly']
black_mean = cls_means['black']
teddy_mean = cls_means['teddy']
|
|
|
|
|
|
|
|
# Cosine-distance cutoff (distance = 1 - cosine similarity) between an image
# embedding and its nearest class prototype: the image is accepted as that
# class only when the distance is strictly below this value; otherwise it is
# reported as not a bear. The value 0.2 looks empirically tuned — TODO confirm.
DIST_THRESHOLD = 0.2
|
|
|
|
|
def get_embedding_tensor(pil_img):
    """Embed a single PIL image with ``embedding_model``.

    The image is batched through ``learn.dls.test_dl`` so it receives the
    learner's inference-time preprocessing, then run through the pooled
    feature extractor without gradient tracking.

    Returns:
        1-D CPU tensor. Its length equals the feature extractor's output
        channel count (512 for a ResNet-18/34 body — confirm for the trained
        architecture).
    """
    batch = learn.dls.test_dl([pil_img]).one_batch()
    inputs = batch[0].to(device)
    with torch.no_grad():
        features = embedding_model(inputs)
    # A batch of one image: keep its only row.
    return features[0].cpu()
|
|
|
|
|
|
|
|
|
|
|
def predict_embedding_ood_from_img(pil_img,
                                   distance_threshold=DIST_THRESHOLD):
    """Classify a bear image by its nearest class prototype, or reject it.

    The image embedding is compared with the grizzly, black and teddy
    prototypes by cosine distance (1 - cosine similarity).

    Args:
        pil_img: PIL image to classify. ``None`` (what gr.Image delivers when
            nothing was uploaded) yields a prompt instead of an exception.
        distance_threshold: the nearest prototype must be strictly closer
            than this for the image to be accepted as that class.

    Returns:
        A user-facing string: the closest class with its distance, or a
        "not a bear" message listing all three distances.
    """
    if pil_img is None:
        return "Please upload an image."

    emb = get_embedding_tensor(pil_img)
    prototypes = {'Grizzly': grizzly_mean,
                  'Black': black_mean,
                  'Teddy': teddy_mean}
    distances = {
        name: (1 - torch.nn.functional.cosine_similarity(emb, mean, dim=0)).item()
        for name, mean in prototypes.items()
    }

    closest_class, closest_dist = min(distances.items(), key=lambda kv: kv[1])
    if closest_dist < distance_threshold:
        return f"{closest_class} (the distance is {closest_dist:.2f})"
    return ("This is an image of something other than a bear "
            f"(grizzly distance {distances['Grizzly']:.2f}, "
            f"black distance {distances['Black']:.2f}, "
            f"teddy distance {distances['Teddy']:.2f})")
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one PIL image in, one text verdict out.
demo = gr.Interface(
    fn=predict_embedding_ood_from_img,
    inputs=gr.Image(type='pil'),
    outputs=gr.Textbox(),
    title="Bear Classifier with OOD Detection",
    # Keep this in sync with the classifier's output: it returns a class name,
    # or a "something other than a bear" message with the three distances.
    description=("Predicts Grizzly / Black / Teddy bear, "
                 "or reports that the image is not a bear if it is far "
                 "from all three class prototypes."),
    examples=[
        ['grizzly.jpg'],
        ['black.jpg'],
        ['teddy.jpg'],
        ['f1.jpg'],  # presumably a non-bear image to demo rejection — confirm
    ],
)

if __name__ == "__main__":
    demo.launch()
|
|
|