# File size: 3,205 Bytes
# 7e9357a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
import lightning
import numpy as np
import os
import pandas as pd
import timm
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

# Name of the timm architecture used to rebuild the network.
BACKBONE = "resnet18d"
# Height and width (in pixels) given to the Gradio image component.
IMAGE_HEIGHT = IMAGE_WIDTH = 512

# Checkpoint file; only its "state_dict" entry is kept.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
trained_weights_path = "epoch=9-step=520.ckpt"
trained_weights = torch.load(
  trained_weights_path,
  map_location="cpu",
)["state_dict"]

# recreate the model

class Net(nn.Module):
  """Image regressor that combines backbone features with a learned sex embedding.

  A timm backbone (created with ``num_classes=0``) turns a 1-channel image
  into a feature vector. That vector is concatenated with a 32-dimensional
  embedding of a binary flag and mapped to a single output value by a linear
  layer.
  """

  def __init__(self, backbone, pretrained=True):
    """Build the backbone, the embedding and the regression head.

    Args:
      backbone: timm model name passed to ``timm.create_model``.
      pretrained: forwarded to ``timm.create_model``. Defaults to True (the
        previous hard-coded behavior); pass False when the weights are
        going to be overwritten by a checkpoint anyway.
    """
    super().__init__()
    self.backbone = timm.create_model(backbone, pretrained=pretrained, in_chans=1, num_classes=0)

    # Probe the backbone once to learn its output width. Do it in eval mode
    # and without autograd: a training-mode forward pass on random noise
    # would update BatchNorm running statistics in the freshly loaded
    # backbone and build a graph that is never used.
    was_training = self.backbone.training
    self.backbone.eval()
    with torch.no_grad():
      dim_feats = self.backbone(torch.randn((2, 1, 64, 64))).size(1)
    self.backbone.train(was_training)

    self.embed = nn.Embedding(2, 32)
    self.regressor = nn.Linear(dim_feats + 32, 1)

  def forward(self, x, female):
    """Predict one value per sample.

    Args:
      x: image batch fed to the backbone (1 channel, see ``in_chans=1``).
      female: one flag per sample; it is cast to ``long`` and used as an
        index into the 2-entry embedding, so values must be 0/1 (or bool).

    Returns:
      Tensor with one column per sample (output of ``nn.Linear(..., 1)``).
    """
    feat = self.backbone(x)
    feat = torch.cat([feat, self.embed(female.long())], dim=1)
    return self.regressor(feat)

class BoneAgeModel(lightning.LightningModule):
  """Lightning module wrapping a network that is called with ``batch["x"]`` only.

  The optimizer, scheduler and loss function are supplied by the caller and
  handed back to Lightning unchanged in ``configure_optimizers``.
  """

  def __init__(self, net, optimizer, scheduler, loss_fn):
    """Store the network and the training components as given."""
    super().__init__()
    self.net = net
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.loss_fn = loss_fn

    # Per-batch validation losses, collected until the epoch ends.
    self.val_losses = []

  def training_step(self, batch, batch_index):
    """Return the loss between the prediction for ``batch["x"]`` and ``batch["y"]``."""
    return self.loss_fn(self.net(batch["x"]), batch["y"])

  def validation_step(self, batch, batch_index):
    """Compute the loss for one validation batch and record it as a float."""
    prediction = self.net(batch["x"])
    batch_loss = self.loss_fn(prediction, batch["y"])
    self.val_losses.append(batch_loss.item())

  def on_validation_epoch_end(self, *args, **kwargs):
    """Print the mean of the recorded validation losses and start a fresh list."""
    epoch_losses, self.val_losses = self.val_losses, []
    print(f"Validation Loss : {np.mean(epoch_losses):0.3f}")

  def configure_optimizers(self):
    """Return the stored optimizer with the scheduler set to a per-step interval."""
    return {
      "optimizer": self.optimizer,
      "lr_scheduler": {"scheduler": self.scheduler, "interval": "step"},
    }

class BoneAgeModelV2(BoneAgeModel):
  """Variant whose network also receives ``batch["female"]`` as a second input."""

  def training_step(self, batch, batch_index):
    """Return the loss for a batch, calling the network with image and sex flag."""
    prediction = self.net(batch["x"], batch["female"])
    return self.loss_fn(prediction, batch["y"])

  def validation_step(self, batch, batch_index):
    """Compute the loss for one validation batch and record it as a float."""
    prediction = self.net(batch["x"], batch["female"])
    batch_loss = self.loss_fn(prediction, batch["y"])
    self.val_losses.append(batch_loss.item())


# Rebuild the architecture, then restore the checkpoint parameters into it.
net = Net(backbone=BACKBONE)
# No optimizer, scheduler or loss function is supplied here, so None is passed for each.
trained_model = BoneAgeModelV2(net=net, optimizer=None, scheduler=None, loss_fn=None)
trained_model.load_state_dict(trained_weights)
# Switch the module to evaluation mode.
trained_model.eval()


def predict_bone_age(Radiograph, Sex):
  """Predict bone age from a grayscale radiograph and the patient's sex.

  Args:
    Radiograph: 2-D image array from the Gradio image input (pixel values
      assumed to be in 0-255 - TODO confirm), or None if nothing was uploaded.
    Sex: the radio selection. Accepted as the label ("Male"/"Female") or as
      its index (0/1), because the Radio component delivers one or the other
      depending on its ``type`` setting. None if nothing was selected.

  Returns:
    A string "Predicted Bone Age: Y years, M months" with M in 0..11.

  Raises:
    gr.Error: if the image or the sex selection is missing.
  """
  if Radiograph is None:
    raise gr.Error("Please upload a radiograph.")
  if Sex is None:
    # Previously a missing selection was silently treated as male.
    raise gr.Error("Please select the patient's sex.")

  # Comparing only against the string "Female" is always False when the
  # Radio passes an index, which scored every patient as male.
  is_female = Sex in ("Female", 1)

  img = torch.from_numpy(Radiograph)
  img = img.unsqueeze(0).unsqueeze(0) # add channel and batch dimensions
  img = img.float() / 255. # use same normalization as in the PyTorch dataset
  binary_sex = torch.tensor(is_female).unsqueeze(0)
  with torch.inference_mode():
    bone_age = trained_model.net(img, binary_sex)[0].item()

  # Round once on total months so that e.g. 5.99 years becomes
  # "6 years, 0 months" instead of "5 years, 12 months"; clamp at zero so a
  # slightly negative regression output does not print negative values.
  total_months = max(0, round(bone_age * 12))
  years, months = divmod(total_months, 12)
  return f"Predicted Bone Age: {years} years, {months} months"


# Radiograph input; image_mode="L" converts the upload to single-channel grayscale.
image = gr.Image(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, image_mode="L")
# Additional input. type="value" passes the selected label ("Male"/"Female")
# to the callback, which checks for the label "Female". With type="index"
# the callback received 0/1 and that check never matched.
sex = gr.Radio(["Male", "Female"], type="value")
label = gr.Label(show_label=True, label="Bone Age Prediction")

demo = gr.Interface(fn=predict_bone_age,
                    inputs=[image, sex], # radiograph first, then sex - same order as the callback's parameters
                    outputs=label)

demo.launch(debug=True)