# rsna-boneage / app.py
# Author: felipekitamura
# Commit: "Create app.py" (7e9357a)
import gradio as gr
import lightning
import numpy as np
import os
import pandas as pd
import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# timm architecture name used to rebuild the network.
BACKBONE = "resnet18d"
# Size used for the gr.Image input component.
IMAGE_HEIGHT, IMAGE_WIDTH = 512, 512
# Resolve the checkpoint next to this script (not the current working
# directory) so the app also starts when launched from elsewhere.
trained_weights_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "epoch=9-step=520.ckpt"
)
# Lightning checkpoint: the model parameters are stored under "state_dict".
# map_location="cpu" lets the checkpoint load on a CPU-only host.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if the
# checkpoint contains non-tensor objects it may need weights_only=False
# (only acceptable because this is a trusted, bundled file).
trained_weights = torch.load(
    trained_weights_path, map_location=torch.device("cpu")
)["state_dict"]
# recreate the model
class Net(nn.Module):
    """Image backbone plus a sex embedding, regressed to one value per sample.

    Args:
        backbone: timm model name, created with one input channel and no
            classification head (``num_classes=0``).
        pretrained: forwarded to ``timm.create_model``. Defaults to True, as
            before; pass False when the weights will be overwritten from a
            checkpoint anyway.
    """

    def __init__(self, backbone, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            backbone, pretrained=pretrained, in_chans=1, num_classes=0
        )
        # Probe the pooled feature width with a dummy forward pass. Running it
        # in train mode would update BatchNorm running statistics with random
        # noise, so switch to eval mode without autograd and restore the
        # previous mode afterwards. The same randn call is kept so the RNG
        # stream (and thus the init of the layers below) is unchanged.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dim_feats = self.backbone(torch.randn((2, 1, 64, 64))).size(1)
        self.backbone.train(was_training)
        # Two embedding rows: index 0 / 1 for the binary "female" flag.
        self.embed = nn.Embedding(2, 32)
        self.regressor = nn.Linear(dim_feats + 32, 1)

    def forward(self, x, female):
        """Predict one value per sample from images ``x`` and 0/1 ``female`` flags.

        ``female`` is flattened so that shapes (B,), (B, 1) and a 0-d scalar
        (for B == 1) are all accepted.
        """
        feat = self.backbone(x)
        sex_feat = self.embed(female.long().reshape(-1))
        feat = torch.cat([feat, sex_feat], dim=1)
        return self.regressor(feat)
class BoneAgeModel(lightning.LightningModule):
    """Lightning wrapper that trains/validates ``net`` on image-only batches.

    Args:
        net: module called as ``net(batch["x"])``.
        optimizer: optimizer handed to Lightning; may be None for
            inference-only construction.
        scheduler: LR scheduler stepped every optimizer step; may be None.
        loss_fn: callable ``loss_fn(prediction, batch["y"])``.
    """

    def __init__(self, net, optimizer, scheduler, loss_fn):
        super().__init__()
        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        # Per-step validation losses, averaged and cleared at epoch end.
        self.val_losses = []

    def training_step(self, batch, batch_index):
        out = self.net(batch["x"])
        loss = self.loss_fn(out, batch["y"])
        return loss

    def validation_step(self, batch, batch_index):
        out = self.net(batch["x"])
        loss = self.loss_fn(out, batch["y"])
        self.val_losses.append(loss.item())

    def on_validation_epoch_end(self, *args, **kwargs):
        """Print the mean validation loss collected this epoch, then reset."""
        if not self.val_losses:
            # No validation steps ran; np.mean([]) would warn and print "nan".
            return
        val_loss = np.mean(self.val_losses)
        self.val_losses = []
        print(f"Validation Loss : {val_loss:0.3f}")

    def configure_optimizers(self):
        """Return the optimizer, plus a per-step LR scheduler when one is set."""
        if self.scheduler is None:
            # A None "scheduler" entry is not a valid lr_scheduler config.
            return self.optimizer
        lr_scheduler = {"scheduler": self.scheduler, "interval": "step"}
        return {"optimizer": self.optimizer, "lr_scheduler": lr_scheduler}
class BoneAgeModelV2(BoneAgeModel):
    """``BoneAgeModel`` variant whose network also receives the sex flag.

    Batches must provide ``"x"``, ``"female"`` and ``"y"``; the network is
    invoked as ``net(batch["x"], batch["female"])``.
    """

    def _batch_loss(self, batch):
        # Forward pass with the extra "female" input, scored against "y".
        prediction = self.net(batch["x"], batch["female"])
        return self.loss_fn(prediction, batch["y"])

    def training_step(self, batch, batch_index):
        return self._batch_loss(batch)

    def validation_step(self, batch, batch_index):
        self.val_losses.append(self._batch_loss(batch).item())
# Rebuild the network and wrap it in the same LightningModule class used for
# training, so the checkpoint's state_dict keys line up. No optimizer,
# scheduler or loss function is needed for inference.
net = Net(BACKBONE)
trained_model = BoneAgeModelV2(
    net,
    optimizer=None,
    scheduler=None,
    loss_fn=None,
)
trained_model.load_state_dict(trained_weights)
# Inference mode for layers whose behaviour differs between train and eval
# (e.g. BatchNorm).
trained_model.eval()
def predict_bone_age(Radiograph, Sex):
    """Run the bone-age model on one radiograph and format the prediction.

    Args:
        Radiograph: image from the ``gr.Image`` input as a numpy array,
            presumably uint8 in [0, 255] (TODO confirm). A 2-D (H, W) array is
            expected; an (H, W, C) array is averaged over channels first.
        Sex: the radio selection, either the label (``"Male"``/``"Female"``)
            or its index in ``["Male", "Female"]`` (``1`` means Female).

    Returns:
        The string ``"Predicted Bone Age: {years} years, {months} months"``.

    Raises:
        gr.Error: when the image or the sex selection is missing.
    """
    if Radiograph is None:
        raise gr.Error("Please upload a radiograph.")
    if Sex is None:
        raise gr.Error("Please select the patient's sex.")

    pixels = np.asarray(Radiograph, dtype=np.float32)
    if pixels.ndim == 3:
        # Fallback for a multi-channel image: collapse to one grayscale channel.
        pixels = pixels.mean(axis=-1)
    img = torch.from_numpy(pixels)
    img = img.unsqueeze(0).unsqueeze(0)  # add channel and batch dimensions
    img = img / 255.  # use same normalization as in the PyTorch dataset

    # gr.Radio(type="index") yields 0/1 instead of the label, so a plain
    # `Sex == "Female"` test would always be False and every patient would be
    # encoded as male. Accept both forms.
    if isinstance(Sex, str):
        is_female = Sex == "Female"
    else:
        is_female = int(Sex) == 1
    binary_sex = torch.tensor([is_female])

    with torch.inference_mode():
        bone_age = trained_model.net(img, binary_sex)[0].item()

    # NOTE(review): the model output is treated as years, as before; confirm
    # against the units of the training target.
    # Round to whole months before splitting so e.g. 5.98 reads
    # "6 years, 0 months" rather than "5 years, 12 months". A negative
    # regression output is clamped to zero.
    total_months = max(0, round(bone_age * 12))
    years, months = divmod(total_months, 12)
    return f"Predicted Bone Age: {years} years, {months} months"
# Grayscale ("L") radiograph input; height/width size the input component.
image = gr.Image(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, image_mode="L")
# Additional input: sex. type="value" passes the label string
# ("Male"/"Female"), which predict_bone_age compares against "Female";
# type="index" would pass 0/1 instead.
sex = gr.Radio(["Male", "Female"], type="value")
label = gr.Label(show_label=True, label="Bone Age Prediction")
demo = gr.Interface(fn=predict_bone_age,
                    inputs=[image, sex],  # <- adding sex as an input
                    outputs=label)

if __name__ == "__main__":
    # Only start the server when run as a script, not on import.
    demo.launch(debug=True)