# Hugging Face Spaces app — Gradio demo for the two-head EfficientNet LTN classifier.
| import json | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from torchvision.models import efficientnet_b0 | |
| from PIL import Image | |
| import gradio as gr | |
# ---------- CONFIG ----------
CKPT_PATH = "best_effnet_twohead.pt"
LABELS_PATH = "labels.json"
IMG_SIZE = 224
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------

# Class-name lookup tables written out at training time.
with open(LABELS_PATH, "r") as fh:
    labels = json.load(fh)
SPECIES = labels["species"]  # index -> species name
STATE = labels["state"]      # index -> nest-state name
# Model definition — must match the architecture used during training.
class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 trunk with two independent linear classifier heads.

    One head predicts the species class, the other the nest-state class;
    both read the same pooled backbone feature vector.
    """

    def __init__(self, num_species, num_states):
        super().__init__()
        # weights=None: the checkpoint supplies all parameters, so no
        # pretrained download is needed here.
        backbone = efficientnet_b0(weights=None)
        self.features = backbone.features
        self.avgpool = backbone.avgpool
        feat_dim = backbone.classifier[1].in_features
        self.head_species = nn.Linear(feat_dim, num_species)
        self.head_state = nn.Linear(feat_dim, num_states)

    def forward(self, x):
        """Return (species_logits, state_logits) for a batch of images."""
        feats = torch.flatten(self.avgpool(self.features(x)), 1)
        return self.head_species(feats), self.head_state(feats)
# Restore trained weights and switch to inference mode.
model = EffNetTwoHead(len(SPECIES), len(STATE))
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(DEVICE).eval()
# Preprocessing pipeline — must mirror the transforms applied during training.
tfm = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # ImageNet channel statistics, matching the EfficientNet backbone.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
def predict(image: Image.Image):
    """Classify a layer-trap-nest photo into a species and a nest state.

    Args:
        image: Uploaded photo (any PIL mode), or ``None`` when the user
            submits without selecting an image.

    Returns:
        Two strings of the form ``"<label> (id=<i>, conf=<p>)"`` — one for
        the predicted species, one for the predicted nest state.
    """
    if image is None:
        return "No image", "No image"
    image = image.convert("RGB")
    x = tfm(image).unsqueeze(0).to(DEVICE)
    # inference_mode: skip autograd graph construction — the original ran the
    # forward pass with gradient tracking enabled, wasting memory per request.
    with torch.inference_mode():
        log_sp, log_st = model(x)
    prob_sp = torch.softmax(log_sp, dim=1)[0]  # [num_species]
    prob_st = torch.softmax(log_st, dim=1)[0]  # [num_states]
    sp_id = int(prob_sp.argmax().item())
    st_id = int(prob_st.argmax().item())
    sp_conf = float(prob_sp[sp_id].item())
    st_conf = float(prob_st[st_id].item())
    sp_text = f"{SPECIES[sp_id]} (id={sp_id}, conf={sp_conf:.3f})"
    st_text = f"{STATE[st_id]} (id={st_id}, conf={st_conf:.3f})"
    return sp_text, st_text
# Legend for the bundled example images (index: species - state).
EXAMPLE_LEGEND = [
    "0: Cacoxnus indagator - Lv",
    "1: Chelostoma florisomne - DauLv",
    "2: Chelostoma florisomne - OldFood",
    "3: Coeliopencyrtus - DauLv",
    "4: Eumenidae - DauLv",
    "5: Eumenidae - OldFood",
    "6: Heriades - DeadLv",
    "7: Heriades - Lv",
    "8: Heriades - OldFood",
    "9: Hylaeus - Lv",
    "10: Hylaeus - OldFood",
    "11: Megachile - DauLv",
    "12: Osmia bicornis - DauLv",
    "13: Osmia bicornis - DeadLv",
    "14: Osmia bicornis - OldFood",
    "15: Osmia cornuta - DauLv",
    "16: Osmia cornuta - DeadLv",
    "17: Osmia cornuta - OldFood",
    "18: Passaloecus - Lv",
    "19: Passaloecus - OldFood",
    "20: Psenulus - DauLv",
    "21: Trypoxylon - DauLv",
    "22: Trypoxylon - OldFood",
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Predicted species"),
        gr.Textbox(label="Predicted state"),
    ],
    title="EfficientNet Two-Head Layer Trap Nest (LTN) Classifier",
    description="12 species, 4 states supported.",
    # One example row per bundled image ./0.jpg ... ./22.jpg.
    examples=[[f"./{i}.jpg"] for i in range(len(EXAMPLE_LEGEND))],
    # BUG FIX: `descriptions` is not a gr.Interface parameter (TypeError on
    # gradio 4.x). Expose the example legend through `article`, which renders
    # as markdown below the interface.
    article="**Example images**\n\n"
    + "\n".join(f"- {line}" for line in EXAMPLE_LEGEND),
)

if __name__ == "__main__":
    demo.launch()