|
|
| import torch, torch.nn as nn |
| from fastai.vision.all import * |
| import gradio as gr |
|
|
| |
# ---------------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------------
# NOTE(review): load_learner unpickles the file, which can execute arbitrary
# code -- only ship/load a trusted bear_classifier.pkl.
learn = load_learner('bear_classifier.pkl')

# Inference only: put the whole network in eval mode (dropout off,
# BatchNorm uses its running statistics).
learn.model.eval()

# Device picked by fastai's default_device(); the embedding model is moved
# there below.
device = default_device()
|
|
| |
|
|
|
|
| |
# Embedding extractor used both to build the class prototypes and at
# inference time: the learner's first sub-module (presumably the convolutional
# body of a fastai vision_learner -- confirm), followed by global average
# pooling and flattening. Assuming the body emits (batch, channels, H, W)
# feature maps, the result is one (batch, channels) vector per image.
embedding_model = nn.Sequential(
    *[learn.model[0], nn.AdaptiveAvgPool2d(1), nn.Flatten()]
)
# Module.to / Module.eval act in place (and return self), so no rebinding is
# needed. learn.model[0] is shared, not copied, so it moves to `device` too.
embedding_model.to(device)
embedding_model.eval()
|
|
| |
# Per-class mean embeddings ("prototypes") for nearest-prototype
# classification and out-of-distribution detection. They are read from the
# class_means.pt cache when present; otherwise they are recomputed from the
# learner's training DataLoader and written back to the cache.
try:
    # The cache only ever holds a dict of tensors, so the restricted loader
    # (weights_only=True) suffices and avoids unpickling arbitrary objects.
    cls_means = torch.load('class_means.pt', map_location='cpu', weights_only=True)
    grizzly_mean = cls_means['grizzly']
    black_mean = cls_means['black']
    teddy_mean = cls_means['teddy']
except FileNotFoundError:
    print('class_means.pt not found – recomputing on the fly…')
    embs, lbls = [], []
    # NOTE(review): the training DataLoader presumably applies augmentation
    # and may drop a final partial batch, so these means are approximate.
    with torch.no_grad():
        for xb, yb in learn.dls.train:
            embs.append(embedding_model(xb.to(device)).cpu())
            lbls.append(yb.cpu())
    if not embs:
        # Learner.export() stores empty DataLoaders, so an exported learner
        # usually has no training batches to iterate. Fail with a clear
        # message instead of torch.cat's "expected a non-empty list".
        raise RuntimeError(
            'class_means.pt is missing and the loaded learner has no training '
            'data to recompute it from; generate class_means.pt with the '
            'original training DataLoaders.'
        )
    embs = torch.cat(embs)
    lbls = torch.cat(lbls)

    # Resolve label indices through the learner's vocab rather than assuming
    # grizzly=0, black=1, teddy=2: fastai sorts category names by default, so
    # index 0 is typically 'black'. (A cache written with the old hard-coded
    # indices may hold swapped grizzly/black means -- delete it to rebuild.)
    vocab = list(learn.dls.vocab)
    cls_means = {}
    for cls_name in ('grizzly', 'black', 'teddy'):
        mask = lbls == vocab.index(cls_name)
        if not mask.any():
            # A mean over zero rows would silently be NaN.
            raise RuntimeError(f'no training samples for class {cls_name!r}')
        cls_means[cls_name] = embs[mask].mean(0)
    grizzly_mean = cls_means['grizzly']
    black_mean = cls_means['black']
    teddy_mean = cls_means['teddy']
    torch.save(cls_means, 'class_means.pt')
|
|
| |
# Cosine-distance cut-off (1 - cosine similarity): if even the nearest class
# prototype is at least this far from an image's embedding, the image is
# reported as "not a bear".
DIST_THRESHOLD: float = 0.2
|
|
def get_embedding_tensor(pil_img):
    """Return the embedding vector of a single PIL image as a CPU tensor.

    The image is batched through the learner's ``test_dl`` (presumably so it
    gets the same inference-time transforms as ``learn.predict`` -- confirm),
    run through ``embedding_model`` on ``device``, and the single row of the
    resulting batch is returned. The vector length equals the backbone's
    output channel count (e.g. 512 for a ResNet-18/34 body -- confirm for the
    architecture actually used).
    """
    with torch.no_grad():
        batch = learn.dls.test_dl([pil_img]).one_batch()
        inputs = batch[0].to(device)
        features = embedding_model(inputs)
    return features[0].cpu()
|
|
|
|
|
|
def predict_embedding_ood_from_img(pil_img,
                                   distance_threshold=DIST_THRESHOLD):
    """Classify a bear image by its nearest class prototype, or flag it as OOD.

    The image embedding is compared with the per-class mean embeddings
    (``grizzly_mean``, ``black_mean``, ``teddy_mean``) using cosine distance,
    ``1 - cosine_similarity``. If the nearest prototype is strictly closer
    than ``distance_threshold``, that class name is returned. Otherwise the
    image is reported as something other than a bear, together with all three
    distances.

    Args:
        pil_img: PIL image from the Gradio ``Image`` input, or ``None`` when
            the form is submitted without an image.
        distance_threshold: Maximum cosine distance to the nearest prototype
            for the image to be accepted as that class.

    Returns:
        A human-readable result string.
    """
    # Gradio passes None when Submit is pressed with no image; test_dl cannot
    # build a batch from that, so answer with a message instead of raising.
    if pil_img is None:
        return "Please upload an image."

    emb = get_embedding_tensor(pil_img)

    # Insertion order matters: min() keeps the first of several equal
    # distances, matching the original Grizzly, Black, Teddy order.
    prototypes = {'Grizzly': grizzly_mean,
                  'Black': black_mean,
                  'Teddy': teddy_mean}
    distances = {
        name: (1 - torch.nn.functional.cosine_similarity(emb, proto, dim=0)).item()
        for name, proto in prototypes.items()
    }

    closest_class, closest_dist = min(distances.items(), key=lambda kv: kv[1])

    if closest_dist < distance_threshold:
        return f"{closest_class} (the distance is {closest_dist:.2f})"
    return ("This is an image of something other than a bear "
            f"(grizzly distance {distances['Grizzly']:.2f}, "
            f"black distance {distances['Black']:.2f}, "
            f"teddy distance {distances['Teddy']:.2f})")
| |
|
|
| |
# Gradio UI: one PIL image in, one text result out. The example images are
# given as relative paths and are presumably shipped alongside the app.
demo = gr.Interface(
    fn=predict_embedding_ood_from_img,
    inputs=gr.Image(type='pil'),
    outputs=gr.Textbox(),
    title="Bear Classifier with OOD Detection",
    # Describe the real output: the function never returns the word "Other";
    # it returns a "something other than a bear" message with distances.
    description=("Predicts Grizzly / Black / Teddy bear, "
                 "or reports that the image is not a bear if it is far "
                 "from all three class prototypes."),
    examples=[
        ['grizzly.jpg'],
        ['black.jpg'],
        ['teddy.jpg'],
        ['f1.jpg'],
    ],
)


if __name__ == "__main__":
    demo.launch()
|
|