sqlbipro committed on
Commit
32bb702
·
1 Parent(s): ce5a1b0

fixed other

Browse files
app.py CHANGED
@@ -1,24 +1,105 @@
1
- import gradio as gr
2
- from fastai.vision.all import *
3
- import skimage
4
 
 
 
 
5
 
 
6
  learn = load_learner('bear_classifier.pkl')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- labels = learn.dls.vocab
 
 
9
 
10
- def predict(img):
11
- img = PILImage.create(img)
12
- pred,pred_idx,probs = learn.predict(img)
13
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
14
 
15
- title = "Bear Classifier"
16
- description = "A bear classifier trained on some images downloaded from ddg with fastai. Created as a demo for Gradio and HuggingFace Spaces."
17
- examples = ['grizzly.jpg', 'black.jpg', 'teddy.jpg', 'f1.jpg']
18
- interpretation='default'
19
- enable_queue=True
20
 
21
- img = gr.Image()
22
- lbl = gr.Label()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- gr.Interface(fn=predict,inputs=img,outputs=lbl,title=title,description=description,examples=examples).launch()
 
 
 
 
 
1
 
2
import torch
import torch.nn as nn
from fastai.vision.all import *
import gradio as gr

# ---------- 1. load the trained learner ----------
learn = load_learner('bear_classifier.pkl')
learn.model.eval()                 # inference mode (disable dropout/batchnorm updates)
device = default_device()          # GPU if available, else CPU

# ---------- 2. build the embedding extractor ----------
# learn.model[0] is the CNN backbone; pooling + flatten collapse the feature
# map [B, C, H, W] -> [B, C, 1, 1] -> [B, C] so each image becomes one
# feature vector (presumably C == 512 for this backbone — TODO confirm).
embedding_model = nn.Sequential(
    learn.model[0],
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
).to(device).eval()
23
+
24
# ---------- 3. get or compute class-mean embeddings ----------
try:
    # NOTE(review): weights_only=False deserializes arbitrary pickle code —
    # only load class_means.pt from a trusted source.
    cls_means = torch.load('class_means.pt', map_location='cpu', weights_only=False)
    grizzly_mean = cls_means['grizzly']
    black_mean = cls_means['black']
    teddy_mean = cls_means['teddy']
except FileNotFoundError:
    # Recompute each class prototype (mean embedding) from the training data.
    print('class_means.pt not found – recomputing on the fly…')
    embs, lbls = [], []
    with torch.no_grad():
        for xb, yb in learn.dls.train:
            embs.append(embedding_model(xb.to(device)).cpu())
            lbls.append(yb.cpu())
    embs = torch.cat(embs)
    lbls = torch.cat(lbls)
    # Map class names to label indices through the learner's vocab instead of
    # assuming a fixed 0/1/2 order: fastai sorts the vocab alphabetically, so
    # index 0 is normally 'black', not 'grizzly' — the old hard-coded indices
    # would have mislabeled the prototypes. (Assumes the vocab entries are
    # exactly 'grizzly'/'black'/'teddy' — TODO confirm against the .pkl.)
    vocab = list(learn.dls.vocab)
    grizzly_mean = embs[lbls == vocab.index('grizzly')].mean(0)
    black_mean = embs[lbls == vocab.index('black')].mean(0)
    teddy_mean = embs[lbls == vocab.index('teddy')].mean(0)
    torch.save({'grizzly': grizzly_mean,
                'black': black_mean,
                'teddy': teddy_mean}, 'class_means.pt')

# ---------- 4. helper functions ----------
DIST_THRESHOLD = 0.2  # cosine-distance cutoff for "not a bear"; tune to taste
52
+
53
def get_embedding_tensor(pil_img):
    """Embed a PIL image as a feature vector on the CPU.

    The image goes through the learner's test-dataloader transforms, then
    through ``embedding_model``; gradients are disabled for inference.
    """
    batch = learn.dls.test_dl([pil_img]).one_batch()[0].to(device)  # [1,3,H,W]
    with torch.no_grad():
        features = embedding_model(batch)
    return features[0].cpu()
60
+
61
+
62
+
63
def predict_embedding_ood_from_img(pil_img,
                                   distance_threshold=DIST_THRESHOLD):
    """Classify a bear image, rejecting out-of-distribution inputs.

    Embeds the image and measures cosine distance to the three class-mean
    prototypes; if even the nearest prototype is farther than
    ``distance_threshold``, the image is reported as not a bear.
    """
    emb = get_embedding_tensor(pil_img)
    cos = torch.nn.functional.cosine_similarity

    d_grizzly = 1 - cos(emb, grizzly_mean, dim=0)
    d_black = 1 - cos(emb, black_mean, dim=0)
    d_teddy = 1 - cos(emb, teddy_mean, dim=0)

    distances = {'Grizzly': d_grizzly.item(),
                 'Black': d_black.item(),
                 'Teddy': d_teddy.item()}

    # Nearest prototype wins, provided it is close enough.
    closest_class, closest_dist = min(distances.items(), key=lambda kv: kv[1])
    if closest_dist < distance_threshold:
        return f"{closest_class} (the distance is {closest_dist:.2f})"
    return (f"This is an image of something other than a bear (grizzly distance {d_grizzly.item():.2f}, black distance {d_black.item():.2f}, teddy distance {d_teddy.item():.2f})")
86
 
87
+ # ---------- 5. Gradio interface ----------
88
# ---------- 5. Gradio interface ----------
_EXAMPLE_IMAGES = ['grizzly.jpg', 'black.jpg', 'teddy.jpg', 'f1.jpg']

demo = gr.Interface(
    fn=predict_embedding_ood_from_img,
    inputs=gr.Image(type='pil'),
    outputs=gr.Textbox(),
    title="Bear Classifier with OOD Detection",
    description=("Predicts Grizzly / Black / Teddy bear, "
                 "or returns Other if the image is far from all three "
                 "class prototypes."),
    examples=[[path] for path in _EXAMPLE_IMAGES],
)

if __name__ == "__main__":
    demo.launch()
bear_classifier.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:afd8cecb9d55e266e82492ca822d7627544fb6f22f2b6f55dc2fa3331fc9bc9a
3
- size 87482744
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cf53ff828bb524fc62bf2c14c7810de19b48b090ad336954c500a537dfc0db3
3
+ size 87478514
export.pkl → class_means.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:826b49fb79d63196975da000a660d1e8bc013e31be68c98e9efadd6cb223a6c2
3
- size 46978302
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea80f1004a02a20d047bcab3790de1823671f13c513a410d90cb7406a004cae7
3
+ size 7984
requirements.txt CHANGED
@@ -5,5 +5,4 @@ torch
5
  torchvision
6
  fastcore<1.8,>=1.5.29
7
  fsspec<=2025.3.0,>=2023.1.0
8
- fasttransform
9
  cloudpickle
 
5
  torchvision
6
  fastcore<1.8,>=1.5.29
7
  fsspec<=2025.3.0,>=2023.1.0
 
8
  cloudpickle