import gradio as gr
import lightning
import numpy as np
import os
import pandas as pd
import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# Model configuration.
BACKBONE = "resnet18d"
IMAGE_HEIGHT, IMAGE_WIDTH = 512, 512

# Load the Lightning checkpoint on CPU and pull out the raw state dict.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
trained_weights_path = "epoch=9-step=520.ckpt"
_checkpoint = torch.load(trained_weights_path, map_location=torch.device('cpu'))
trained_weights = _checkpoint["state_dict"]
# recreate the model
class Net(nn.Module):
    """Backbone + sex-embedding regressor for bone age.

    A timm CNN backbone encodes the single-channel radiograph; a learned
    2-entry embedding of the patient's sex is concatenated with the pooled
    image features before a single linear regression head.
    """

    def __init__(self, backbone):
        super().__init__()
        # in_chans=1: grayscale X-ray input; num_classes=0 strips the
        # classification head so the backbone returns pooled features.
        self.backbone = timm.create_model(backbone, pretrained=True, in_chans=1, num_classes=0)
        # timm models expose the pooled feature width directly — no need to
        # run a throwaway dummy forward pass just to discover it.
        dim_feats = self.backbone.num_features
        self.embed = nn.Embedding(2, 32)  # index 1 = female (see caller), 0 otherwise
        self.regressor = nn.Linear(dim_feats + 32, 1)

    def forward(self, x, female):
        """x: (B, 1, H, W) image batch; female: (B,) 0/1 indicator. Returns (B, 1)."""
        feat = self.backbone(x)
        feat = torch.cat([feat, self.embed(female.long())], dim=1)
        return self.regressor(feat)
class BoneAgeModel(lightning.LightningModule):
    """Minimal LightningModule bundling a network, loss, optimizer, and scheduler."""

    def __init__(self, net, optimizer, scheduler, loss_fn):
        super().__init__()
        self.net = net
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        # Per-batch validation losses, averaged and reported at epoch end.
        self.val_losses = []

    def training_step(self, batch, batch_index):
        return self.loss_fn(self.net(batch["x"]), batch["y"])

    def validation_step(self, batch, batch_index):
        batch_loss = self.loss_fn(self.net(batch["x"]), batch["y"])
        self.val_losses.append(batch_loss.item())

    def on_validation_epoch_end(self, *args, **kwargs):
        # Average the collected batch losses, then reset for the next epoch.
        val_loss = np.mean(self.val_losses)
        self.val_losses = []
        print(f"Validation Loss : {val_loss:0.3f}")

    def configure_optimizers(self):
        # Step the LR scheduler on every optimizer step rather than per epoch.
        lr_scheduler = {"scheduler": self.scheduler, "interval": "step"}
        return {"optimizer": self.optimizer, "lr_scheduler": lr_scheduler}
class BoneAgeModelV2(BoneAgeModel):
    """Variant whose network also consumes the per-sample sex indicator."""

    def training_step(self, batch, batch_index):
        prediction = self.net(batch["x"], batch["female"])
        return self.loss_fn(prediction, batch["y"])

    def validation_step(self, batch, batch_index):
        prediction = self.net(batch["x"], batch["female"])
        self.val_losses.append(self.loss_fn(prediction, batch["y"]).item())
# Rebuild the network and restore the trained weights for inference.
# Optimizer, scheduler, and loss are unused at inference time, hence None.
net = Net(BACKBONE)
trained_model = BoneAgeModelV2(net, None, None, None)
trained_model.load_state_dict(trained_weights)
trained_model.eval()  # switch to eval mode for inference
def predict_bone_age(Radiograph, Sex):
    """Predict bone age from a grayscale radiograph.

    Parameters
    ----------
    Radiograph : np.ndarray
        2-D uint8 grayscale image as delivered by the Gradio image widget.
    Sex : int | str
        Either the Radio choice index (0 = Male, 1 = Female, when the widget
        uses type="index") or the literal choice string. Both are accepted.

    Returns
    -------
    str
        Human-readable prediction, e.g. "Predicted Bone Age: 5 years, 3 months".
    """
    img = torch.from_numpy(Radiograph)
    img = img.unsqueeze(0).unsqueeze(0)  # add channel and batch dimensions
    img = img / 255.  # use same normalization as in the PyTorch dataset
    # BUG FIX: with gr.Radio(type="index") the widget sends an int index, so the
    # original comparison `Sex == "Female"` was always False and every patient
    # was scored as male. Accept both the index and the string form.
    is_female = Sex in (1, "Female")
    binary_sex = torch.tensor(is_female).unsqueeze(0)
    with torch.inference_mode():
        bone_age = trained_model.net(img, binary_sex)[0].item()
    years = int(bone_age)
    months = round((bone_age - years) * 12)
    if months == 12:  # avoid outputs like "4 years, 12 months"
        years, months = years + 1, 0
    return f"Predicted Bone Age: {years} years, {months} months"
# Input widgets: grayscale radiograph plus the patient's sex.
image = gr.Image(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, image_mode="L")  # "L" = grayscale
# BUG FIX: the predict function compares the value against the string "Female",
# so pass the selected choice itself (type="value") rather than its index.
sex = gr.Radio(["Male", "Female"], type="value")
label = gr.Label(show_label=True, label="Bone Age Prediction")
demo = gr.Interface(fn=predict_bone_age,
                    inputs=[image, sex],  # radiograph plus the sex selector
                    outputs=label)
demo.launch(debug=True)