import os
import numpy as np
import gradio as gr
from glob import glob
from functools import partial
from dataclasses import dataclass

import torch
import torchvision
import torch.nn as nn
import lightning.pytorch as pl
import torchvision.transforms as TF

from torchmetrics import MeanMetric
from torchmetrics.classification import MultilabelF1Score

@dataclass
class DatasetConfig:
    IMAGE_SIZE: tuple = (384, 384)  # (width, height); reversed wherever (H, W) is expected
    CHANNELS: int = 3
    NUM_CLASSES: int = 10
    # ImageNet normalization statistics (the Torchvision backbones are pretrained on ImageNet).
    MEAN: tuple = (0.485, 0.456, 0.406)
    STD: tuple = (0.229, 0.224, 0.225)

@dataclass
class TrainingConfig:
    METRIC_THRESH: float = 0.4
    MODEL_NAME: str = "efficientnet_v2_s"
    FREEZE_BACKBONE: bool = False

def get_model(model_name: str, num_classes: int, freeze_backbone: bool = True):
    """A helper function to load and prepare any classification model
    available in Torchvision for transfer learning or fine-tuning."""

    model = getattr(torchvision.models, model_name)(weights="DEFAULT")

    if freeze_backbone:
        # Freeze every backbone parameter; only the new classification head will be trainable.
        for param in model.parameters():
            param.requires_grad = False

    model_childrens = [name for name, _ in model.named_children()]

    # The classification head is the model's last child module. It may be a
    # Sequential container or a single Linear layer, so try indexing first.
    try:
        final_layer_in_features = getattr(model, f"{model_childrens[-1]}")[-1].in_features
    except Exception:
        final_layer_in_features = getattr(model, f"{model_childrens[-1]}").in_features

    new_output_layer = nn.Linear(in_features=final_layer_in_features, out_features=num_classes)

    try:
        getattr(model, f"{model_childrens[-1]}")[-1] = new_output_layer
    except Exception:
        setattr(model, model_childrens[-1], new_output_layer)

    return model

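# Example (sketch, not executed here): for "efficientnet_v2_s" the last child is a
# Sequential classifier, so get_model() replaces its final Linear layer with a fresh
# 10-class head; for models such as ResNet, the single `fc` Linear layer is swapped instead.
#
#   backbone = get_model("efficientnet_v2_s", num_classes=10, freeze_backbone=False)
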
class ProteinModel(pl.LightningModule):
    def __init__(
        self,
        model_name: str,
        num_classes: int = 10,
        freeze_backbone: bool = False,
        init_lr: float = 0.001,
        optimizer_name: str = "Adam",
        weight_decay: float = 1e-4,
        use_scheduler: bool = False,
        f1_metric_threshold: float = 0.4,
    ):
        super().__init__()

        # Save the constructor arguments as hyperparameters (accessible via self.hparams).
        self.save_hyperparameters()

        # Build the backbone with a fresh classification head.
        self.model = get_model(
            model_name=self.hparams.model_name,
            num_classes=self.hparams.num_classes,
            freeze_backbone=self.hparams.freeze_backbone,
        )

        # Multi-label classification: one independent sigmoid/BCE term per class.
        self.loss_fn = nn.BCEWithLogitsLoss()

        # Metric objects for tracking the running loss and macro F1-score.
        self.mean_train_loss = MeanMetric()
        self.mean_train_f1 = MultilabelF1Score(
            num_labels=self.hparams.num_classes,
            average="macro",
            threshold=self.hparams.f1_metric_threshold,
        )
        self.mean_valid_loss = MeanMetric()
        self.mean_valid_f1 = MultilabelF1Score(
            num_labels=self.hparams.num_classes,
            average="macro",
            threshold=self.hparams.f1_metric_threshold,
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, *args, **kwargs):
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target)

        # Update the running metrics; the metric objects handle accumulation and reset.
        self.mean_train_loss(loss, weight=data.shape[0])
        self.mean_train_f1(logits, target)

        self.log("train/batch_loss", self.mean_train_loss, prog_bar=True)
        self.log("train/batch_f1", self.mean_train_f1, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        # Log the epoch-level training aggregates.
        self.log("train/loss", self.mean_train_loss, prog_bar=True)
        self.log("train/f1", self.mean_train_f1, prog_bar=True)
        self.log("step", self.current_epoch)

    def validation_step(self, batch, *args, **kwargs):
        data, target = batch
        logits = self(data)
        loss = self.loss_fn(logits, target)

        self.mean_valid_loss.update(loss, weight=data.shape[0])
        self.mean_valid_f1.update(logits, target)

    def on_validation_epoch_end(self):
        # Log the epoch-level validation aggregates.
        self.log("valid/loss", self.mean_valid_loss, prog_bar=True)
        self.log("valid/f1", self.mean_valid_f1, prog_bar=True)
        self.log("step", self.current_epoch)

    def configure_optimizers(self):
        optimizer = getattr(torch.optim, self.hparams.optimizer_name)(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.hparams.init_lr,
            weight_decay=self.hparams.weight_decay,
        )

        if self.hparams.use_scheduler:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[
                    self.trainer.max_epochs // 2,
                ],
                gamma=0.1,
            )

            # This dictionary tells Lightning how and when to step the scheduler.
            lr_scheduler_config = {
                "scheduler": lr_scheduler,
                "interval": "epoch",
                "name": "multi_step_lr",
            }
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

        else:
            return optimizer

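# Sketch (not part of this inference app): the LightningModule above would be trained
# roughly as follows, assuming `train_loader` and `valid_loader` are prepared DataLoaders.
#
#   model = ProteinModel(model_name=TrainingConfig.MODEL_NAME, num_classes=DatasetConfig.NUM_CLASSES)
#   trainer = pl.Trainer(max_epochs=30, accelerator="auto")
#   trainer.fit(model, train_loader, valid_loader)
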
@torch.inference_mode()
def predict(input_image, threshold=0.4, model=None, preprocess_fn=None, device="cpu", idx2labels=None):
    input_tensor = preprocess_fn(input_image)
    input_tensor = input_tensor.unsqueeze(0).to(device)

    # Generate the raw per-class logits.
    output = model(input_tensor).cpu()

    # Convert logits to independent per-class probabilities.
    probabilities = torch.sigmoid(output)[0].numpy().tolist()

    output_probs = dict()
    predicted_classes = []

    for idx, prob in enumerate(probabilities):
        output_probs[idx2labels[idx]] = prob
        if prob >= threshold:
            predicted_classes.append(idx2labels[idx])

    predicted_classes = "\n".join(predicted_classes)
    return predicted_classes, output_probs

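# Example (sketch): calling predict() directly, assuming `img` is a PIL image and
# `model`, `preprocess`, `DEVICE`, and `labels` are set up as in the __main__ block below.
#
#   class_names, class_probs = predict(
#       img, threshold=0.4, model=model, preprocess_fn=preprocess, device=DEVICE, idx2labels=labels
#   )
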
if __name__ == "__main__":
    labels = {
        0: "Mitochondria",
        1: "Nuclear bodies",
        2: "Nucleoli",
        3: "Golgi apparatus",
        4: "Nucleoplasm",
        5: "Nucleoli fibrillar center",
        6: "Cytosol",
        7: "Plasma membrane",
        8: "Centrosome",
        9: "Nuclear speckles",
    }

    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    CKPT_PATH = os.path.join(os.getcwd(), r"ckpt_022-vloss_0.1756_vf1_0.7919.ckpt")
    model = ProteinModel.load_from_checkpoint(CKPT_PATH)
    model.to(DEVICE)
    model.eval()
    # Dummy forward pass (sanity check) with a batch of one random image.
    _ = model(torch.randn(1, DatasetConfig.CHANNELS, *DatasetConfig.IMAGE_SIZE[::-1], device=DEVICE))

    preprocess = TF.Compose(
        [
            TF.Resize(size=DatasetConfig.IMAGE_SIZE[::-1]),  # Resize expects (H, W).
            TF.ToTensor(),
            TF.Normalize(DatasetConfig.MEAN, DatasetConfig.STD, inplace=True),
        ]
    )

    # Collect sample images and pair each with the default decision threshold.
    images_dir = glob(os.path.join(os.getcwd(), "samples", "*.png"))
    examples = [[i, TrainingConfig.METRIC_THRESH] for i in np.random.choice(images_dir, size=10, replace=False)]

    with gr.Interface(
        fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE, idx2labels=labels),
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Slider(0.0, 1.0, value=0.4, label="Threshold", info="Select the cut-off threshold for a class to be considered a valid prediction."),
        ],
        outputs=[
            gr.Textbox(label="Labels Present"),
            gr.Label(label="Probabilities", show_label=False),
        ],
        examples=examples,
        cache_examples=False,
        allow_flagging="never",
        title="Awan AI Medical Image Classification",
        theme=gr.themes.Soft(primary_hue="sky", secondary_hue="pink"),
    ) as iface:
        # Extra component added inside the Interface context (gr.Interface is a gr.Blocks subclass).
        additional_inputs = [gr.Model3D(label="3D Model", value="./HackMercedIXRunThrough.glb", clear_color=[0.4, 0.2, 0.7, 1.0])]

    iface.launch(share=True)