chenchangliu committed on
Commit
cc470a4
·
verified ·
1 Parent(s): 8ad8985

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from torchvision.models import efficientnet_b0
6
+ from PIL import Image
7
+ import gradio as gr
8
+
9
+
10
# ---------- CONFIG ----------
CKPT_PATH = "best_effnet_twohead.pt"  # checkpoint file produced by training (holds state dict under "model")
LABELS_PATH = "labels.json"           # JSON file with "species" and "state" label-name lists
IMG_SIZE = 224                        # square input size fed to the network
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------


# Load the label vocabularies saved at training time (runs at import).
with open(LABELS_PATH, "r") as f:
    labels = json.load(f)

# Class-id -> name lists, indexed by argmax ids in predict();
# ordering must match the training label encoding.
SPECIES = labels["species"]
STATE = labels["state"]
25
+
26
+ # model (must match training)
27
class EffNetTwoHead(nn.Module):
    """EfficientNet-B0 backbone with two independent linear classifier heads.

    The architecture (and the attribute names ``features``, ``avgpool``,
    ``head_species``, ``head_state``) must mirror the training-time model so
    the checkpoint's state-dict keys line up on load.
    """

    def __init__(self, num_species, num_states):
        super().__init__()
        # Random-initialized backbone; real weights arrive via load_state_dict.
        backbone = efficientnet_b0(weights=None)
        self.features = backbone.features
        self.avgpool = backbone.avgpool
        in_features = backbone.classifier[1].in_features
        self.head_species = nn.Linear(in_features, num_species)
        self.head_state = nn.Linear(in_features, num_states)

    def forward(self, x):
        """Return a tuple ``(species_logits, state_logits)`` for a batch of images."""
        pooled = self.avgpool(self.features(x))
        pooled = torch.flatten(pooled, 1)
        return self.head_species(pooled), self.head_state(pooled)
42
+
43
+
44
# Load the trained weights and switch to inference mode.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints (or pass weights_only=True if the installed torch supports it).
ckpt = torch.load(CKPT_PATH, map_location="cpu")
model = EffNetTwoHead(len(SPECIES), len(STATE))
model.load_state_dict(ckpt["model"])  # checkpoint stores the state dict under "model"
model.to(DEVICE).eval()
49
+
50
+
51
# Preprocessing pipeline — must match what was applied during training.
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    # Standard ImageNet channel statistics.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
60
+
61
+
62
@torch.no_grad()
def predict(image: Image.Image):
    """Classify a single PIL image.

    Returns a ``(species_name, state_name)`` pair of strings, or the
    placeholder pair ``("No image", "No image")`` when no image is given.
    """
    if image is None:
        return "No image", "No image"

    # Preprocess into a 1-image batch on the inference device.
    batch = tfm(image.convert("RGB")).unsqueeze(0).to(DEVICE)

    species_logits, state_logits = model(batch)
    species_idx = int(species_logits.argmax(dim=1))
    state_idx = int(state_logits.argmax(dim=1))

    return SPECIES[species_idx], STATE[state_idx]
75
+
76
+
77
# Wire the prediction function into a simple two-output Gradio UI.
prediction_outputs = [
    gr.Textbox(label="Predicted species"),
    gr.Textbox(label="Predicted state"),
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=prediction_outputs,
    title="EfficientNet Two-Head Layer Trap Nest (LTN) Classifier",
)

if __name__ == "__main__":
    demo.launch()