felipekitamura commited on
Commit
7e9357a
·
verified ·
1 Parent(s): cbd708f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import lightning
3
+ import numpy as np
4
+ import os
5
+ import pandas as pd
6
+ import timm
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
+ BACKBONE = "resnet18d"
13
+ IMAGE_HEIGHT, IMAGE_WIDTH = 512, 512
14
+
15
+ trained_weights_path = "epoch=9-step=520.ckpt"
16
+ trained_weights = torch.load(trained_weights_path, map_location=torch.device('cpu'))["state_dict"]
17
+
18
+ # recreate the model
19
+
20
+ class Net(nn.Module):
21
+
22
+ def __init__(self, backbone):
23
+ super().__init__()
24
+ self.backbone = timm.create_model(backbone, pretrained=True, in_chans=1, num_classes=0)
25
+ dim_feats = self.backbone(torch.randn((2, 1, 64, 64))).size(1)
26
+ self.embed = nn.Embedding(2, 32)
27
+ self.regressor = nn.Linear(dim_feats + 32, 1)
28
+
29
+ def forward(self, x, female):
30
+ feat = self.backbone(x)
31
+ feat = torch.cat([feat, self.embed(female.long())], dim=1)
32
+ return self.regressor(feat)
33
+
34
+ class BoneAgeModel(lightning.LightningModule):
35
+
36
+ def __init__(self, net, optimizer, scheduler, loss_fn):
37
+ super().__init__()
38
+ self.net = net
39
+ self.optimizer = optimizer
40
+ self.scheduler = scheduler
41
+ self.loss_fn = loss_fn
42
+
43
+ self.val_losses = []
44
+
45
+ def training_step(self, batch, batch_index):
46
+ out = self.net(batch["x"])
47
+ loss = self.loss_fn(out, batch["y"])
48
+ return loss
49
+
50
+ def validation_step(self, batch, batch_index):
51
+ out = self.net(batch["x"])
52
+ loss = self.loss_fn(out, batch["y"])
53
+ self.val_losses.append(loss.item())
54
+
55
+ def on_validation_epoch_end(self, *args, **kwargs):
56
+ val_loss = np.mean(self.val_losses)
57
+ self.val_losses = []
58
+ print(f"Validation Loss : {val_loss:0.3f}")
59
+
60
+ def configure_optimizers(self):
61
+ lr_scheduler = {"scheduler": self.scheduler, "interval": "step"}
62
+ return {"optimizer": self.optimizer, "lr_scheduler": lr_scheduler}
63
+
64
+ class BoneAgeModelV2(BoneAgeModel):
65
+
66
+ def training_step(self, batch, batch_index):
67
+ out = self.net(batch["x"], batch["female"])
68
+ loss = self.loss_fn(out, batch["y"])
69
+ return loss
70
+
71
+ def validation_step(self, batch, batch_index):
72
+ out = self.net(batch["x"], batch["female"])
73
+ loss = self.loss_fn(out, batch["y"])
74
+ self.val_losses.append(loss.item())
75
+
76
+
77
+ net = Net(BACKBONE)
78
+ trained_model = BoneAgeModelV2(net, None, None, None)
79
+ trained_model.load_state_dict(trained_weights)
80
+ trained_model.eval()
81
+
82
+
83
+ def predict_bone_age(Radiograph, Sex):
84
+ img = torch.from_numpy(Radiograph)
85
+ img = img.unsqueeze(0).unsqueeze(0) # add channel and batch dimensions
86
+ img = img / 255. # use same normalization as in the PyTorch dataset
87
+ binary_sex = torch.tensor(Sex == "Female").unsqueeze(0)
88
+ with torch.inference_mode():
89
+ bone_age = trained_model.net(img, binary_sex)[0].item()
90
+ years = int(bone_age)
91
+ months = round((bone_age - years) * 12)
92
+ return f"Predicted Bone Age: {years} years, {months} months"
93
+
94
+
95
+ image = gr.Image(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, image_mode="L") # L for grayscale
96
+ # additional input
97
+ sex = gr.Radio(["Male", "Female"], type="index")
98
+ label = gr.Label(show_label=True, label="Bone Age Prediction")
99
+
100
+ demo = gr.Interface(fn=predict_bone_age,
101
+ inputs=[image, sex], # <- adding sex as an input
102
+ outputs=label)
103
+
104
+ demo.launch(debug=True)