fixed other
Browse files- app.py +97 -16
- bear_classifier.pkl +2 -2
- export.pkl → class_means.pt +2 -2
- requirements.txt +0 -1
app.py
CHANGED
|
@@ -1,24 +1,105 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
from fastai.vision.all import *
|
| 3 |
-
import skimage
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
|
|
|
|
| 6 |
learn = load_learner('bear_classifier.pkl')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
pred,pred_idx,probs = learn.predict(img)
|
| 13 |
-
return {labels[i]: float(probs[i]) for i in range(len(labels))}
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
+
import torch, torch.nn as nn
from fastai.vision.all import *
import gradio as gr

# ---------- 1. load the trained learner ----------
learn = load_learner('bear_classifier.pkl')
learn.model.eval()           # inference mode
device = default_device()    # GPU if available

# ---------- 2. build the embedding extractor ----------
# Reuse the learner's CNN backbone, pool to one value per channel, and
# flatten, so each image maps to a single feature vector.
embedding_model = nn.Sequential(
    learn.model[0],          # CNN backbone
    nn.AdaptiveAvgPool2d(1), # [B, C, 1, 1]
    nn.Flatten()             # [B, 512]
).to(device).eval()

# ---------- 3. get or compute class-mean embeddings ----------
try:
    # if they were saved to disk:
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # acceptable because class_means.pt is our own artifact, never untrusted.
    cls_means = torch.load('class_means.pt', map_location='cpu', weights_only=False)
    grizzly_mean = cls_means['grizzly']
    black_mean = cls_means['black']
    teddy_mean = cls_means['teddy']
except FileNotFoundError:
    # recompute from the learner's training dataloader
    print('class_means.pt not found – recomputing on the fly…')
    embs, lbls = [], []
    for xb, yb in learn.dls.train:
        with torch.no_grad():
            feats = embedding_model(xb.to(device)).cpu()
        embs.append(feats)
        lbls.append(yb.cpu())
    embs = torch.cat(embs)
    lbls = torch.cat(lbls)
    # Map class names to label indices through the learner's vocab instead of
    # hard-coding grizzly=0, black=1, teddy=2: fastai orders the vocab
    # alphabetically, so the hard-coded assumption mislabels the means.
    class_idx = {c: i for i, c in enumerate(learn.dls.vocab)}
    grizzly_mean = embs[lbls == class_idx['grizzly']].mean(0)
    black_mean = embs[lbls == class_idx['black']].mean(0)
    teddy_mean = embs[lbls == class_idx['teddy']].mean(0)
    torch.save({'grizzly': grizzly_mean,
                'black' : black_mean,
                'teddy' : teddy_mean}, 'class_means.pt')

# ---------- 4. helper functions ----------
DIST_THRESHOLD = 0.2  # cosine-distance cutoff for "not a bear" — tune to taste
| 53 |
+
def get_embedding_tensor(pil_img):
    """Convert PIL → tensor → 512-d embedding (returned on the CPU)."""
    # Run the image through the learner's own preprocessing pipeline so it is
    # resized/normalized exactly like the training data.
    batch = learn.dls.test_dl([pil_img]).one_batch()[0].to(device)  # [1,3,H,W]
    with torch.no_grad():
        return embedding_model(batch)[0].cpu()  # [512]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def predict_embedding_ood_from_img(pil_img,
                                   distance_threshold=DIST_THRESHOLD):
    """Classify a bear image by nearest class-mean embedding.

    Computes the cosine distance from the image embedding to each class
    prototype; if the closest prototype is within *distance_threshold* the
    class name is returned, otherwise the image is reported as not a bear.
    """
    emb = get_embedding_tensor(pil_img)

    prototypes = {'Grizzly': grizzly_mean,
                  'Black'  : black_mean,
                  'Teddy'  : teddy_mean}
    # cosine distance = 1 − cosine similarity, one value per class
    distances = {
        name: (1 - torch.nn.functional.cosine_similarity(emb, mean, dim=0)).item()
        for name, mean in prototypes.items()
    }

    closest_class, closest_dist = min(distances.items(), key=lambda kv: kv[1])

    if closest_dist < distance_threshold:
        return f"{closest_class} (the distance is {closest_dist:.2f})"
    return (f"This is an image of something other than a bear "
            f"(grizzly distance {distances['Grizzly']:.2f}, "
            f"black distance {distances['Black']:.2f}, "
            f"teddy distance {distances['Teddy']:.2f})")
|
| 86 |
|
| 87 |
+
# ---------- 5. Gradio interface ----------
|
| 88 |
+
demo = gr.Interface(
|
| 89 |
+
fn=predict_embedding_ood_from_img,
|
| 90 |
+
inputs=gr.Image(type='pil'),
|
| 91 |
+
outputs=gr.Textbox(),
|
| 92 |
+
title="Bear Classifier with OOD Detection",
|
| 93 |
+
description=("Predicts Grizzly / Black / Teddy bear, "
|
| 94 |
+
"or returns Other if the image is far from all three "
|
| 95 |
+
"class prototypes."),
|
| 96 |
+
examples=[
|
| 97 |
+
['grizzly.jpg'],
|
| 98 |
+
['black.jpg'],
|
| 99 |
+
['teddy.jpg'],
|
| 100 |
+
['f1.jpg']
|
| 101 |
+
]
|
| 102 |
+
)
|
| 103 |
|
| 104 |
+
if __name__ == "__main__":
|
| 105 |
+
demo.launch()
|
bear_classifier.pkl
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7cf53ff828bb524fc62bf2c14c7810de19b48b090ad336954c500a537dfc0db3
|
| 3 |
+
size 87478514
|
export.pkl → class_means.pt
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea80f1004a02a20d047bcab3790de1823671f13c513a410d90cb7406a004cae7
|
| 3 |
+
size 7984
|
requirements.txt
CHANGED
|
@@ -5,5 +5,4 @@ torch
|
|
| 5 |
torchvision
|
| 6 |
fastcore<1.8,>=1.5.29
|
| 7 |
fsspec<=2025.3.0,>=2023.1.0
|
| 8 |
-
fasttransform
|
| 9 |
cloudpickle
|
|
|
|
| 5 |
torchvision
|
| 6 |
fastcore<1.8,>=1.5.29
|
| 7 |
fsspec<=2025.3.0,>=2023.1.0
|
|
|
|
| 8 |
cloudpickle
|