import pathlib
import platform
from pathlib import PosixPath

import torch
import torch.nn as nn
from fastai.vision.all import *
import gradio as gr

# A learner exported on Windows pickles its paths as WindowsPath objects, which
# cannot be instantiated on other operating systems; alias WindowsPath to
# PosixPath before unpickling. The OS name is deliberately not stored in `plt`,
# which would shadow the matplotlib alias the fastai star import may provide.
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath

# ---------- 1. load the trained learner ----------
learn = load_learner('bear_classifier.pkl')
learn.model.eval()          # inference mode
device = default_device()   # GPU if available

# ---------- 2. build the embedding extractor ----------
# learn.model[0] is assumed to be the CNN body of the fastai vision model;
# pooling + flattening its feature map gives one vector per image.
# Note: .to(device) moves learn.model[0] itself, since the module is shared.
embedding_model = nn.Sequential(
    learn.model[0],             # CNN backbone (assumed output [B, C, H, W])
    nn.AdaptiveAvgPool2d(1),    # [B, C, 1, 1]
    nn.Flatten(),               # [B, C] -- presumably C=512 (ResNet-18/34); confirm
).to(device).eval()

# ---------- 3. get or compute class-mean embeddings ----------
try:
    # Prototypes saved by a previous run (or by the training notebook).
    # weights_only=False unpickles arbitrary objects: only ship trusted files.
    cls_means = torch.load('class_means.pt', map_location='cpu', weights_only=False)
    grizzly_mean = cls_means['grizzly']
    black_mean = cls_means['black']
    teddy_mean = cls_means['teddy']
except FileNotFoundError:
    # recompute from the learner's training dataloader
    print('class_means.pt not found – recomputing on the fly…')
    embs, lbls = [], []
    with torch.no_grad():
        # NOTE: the train DataLoader is the training pipeline (it may shuffle,
        # augment, and drop the last partial batch), so the means are approximate.
        for xb, yb in learn.dls.train:
            embs.append(embedding_model(xb.to(device)).cpu())
            lbls.append(yb.cpu())
    if not embs:
        # torch.cat([]) would fail with an opaque message; explain instead.
        raise RuntimeError(
            'class_means.pt not found and learn.dls.train yielded no batches to '
            'recompute it from (learners loaded with load_learner usually carry '
            'empty DataLoaders); create class_means.pt where the training data '
            'is available and ship it with the app.'
        )
    embs = torch.cat(embs)
    lbls = torch.cat(lbls)
    # Resolve each class name to the integer label fastai actually assigned.
    # fastai's Categorize sorts the vocab by default (e.g. black, grizzly, teddy),
    # so indices must not be hard-coded. Raises ValueError if a name is absent.
    vocab = list(learn.dls.vocab)
    grizzly_mean = embs[lbls == vocab.index('grizzly')].mean(0)
    black_mean = embs[lbls == vocab.index('black')].mean(0)
    teddy_mean = embs[lbls == vocab.index('teddy')].mean(0)
    torch.save({'grizzly': grizzly_mean,
                'black': black_mean,
                'teddy': teddy_mean}, 'class_means.pt')
# ---------- 4. helper functions ----------
DIST_THRESHOLD = 0.2  # max cosine distance to a prototype to count as a bear; tune to taste


def get_embedding_tensor(pil_img):
    """Convert a PIL image into its backbone embedding, returned on the CPU.

    The image is run through the learner's own test-time pipeline
    (``learn.dls.test_dl``) and then through ``embedding_model``.

    Returns:
        A 1-D tensor (presumably 512-d for a ResNet-18/34 backbone -- confirm).
    """
    dl = learn.dls.test_dl([pil_img])
    xb = dl.one_batch()[0].to(device)  # single-image batch, [1, 3, H, W]
    with torch.no_grad():
        emb = embedding_model(xb)[0]
    return emb.cpu()


def predict_embedding_ood_from_img(pil_img, distance_threshold=DIST_THRESHOLD):
    """Classify a bear image, or flag it as out-of-distribution.

    Computes the cosine distance (1 - cosine similarity) between the image
    embedding and each class-mean prototype. If the nearest prototype is closer
    than ``distance_threshold``, returns that class name and distance; otherwise
    reports that the image is not a bear, listing all three distances.

    Args:
        pil_img: PIL image from the Gradio input, or None if nothing was uploaded.
        distance_threshold: strict upper bound on the nearest-prototype distance.

    Returns:
        A human-readable result string.
    """
    if pil_img is None:
        # Gradio passes None when the user submits without an image.
        return "Please upload an image."

    emb = get_embedding_tensor(pil_img)
    prototypes = {'Grizzly': grizzly_mean,
                  'Black': black_mean,
                  'Teddy': teddy_mean}
    distances = {
        name: (1 - torch.nn.functional.cosine_similarity(emb, proto, dim=0)).item()
        for name, proto in prototypes.items()
    }

    # Ties resolve to the first entry in insertion order (Grizzly, Black, Teddy).
    closest_class, closest_dist = min(distances.items(), key=lambda kv: kv[1])
    if closest_dist < distance_threshold:
        return f"{closest_class} (the distance is {closest_dist:.2f})"
    return (f"This is an image of something other than a bear "
            f"(grizzly distance {distances['Grizzly']:.2f}, "
            f"black distance {distances['Black']:.2f}, "
            f"teddy distance {distances['Teddy']:.2f})")


# ---------- 5. Gradio interface ----------
demo = gr.Interface(
    fn=predict_embedding_ood_from_img,
    inputs=gr.Image(type='pil'),
    outputs=gr.Textbox(),
    title="Bear Classifier with OOD Detection",
    description=("Predicts Grizzly / Black / Teddy bear, "
                 "or reports that the image is not a bear if it is far "
                 "from all three class prototypes."),
    examples=[
        ['grizzly.jpg'],
        ['black.jpg'],
        ['teddy.jpg'],
        ['f1.jpg'],
    ],
)

if __name__ == "__main__":
    demo.launch()