# -*- coding: utf-8 -*-
"""gradio_app_chestvision-PRO.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1gVrx5TyipNPvn8D7GaK0pNBCnLeYTAD_
"""

"""### Import dependencies"""
import os
from functools import partial

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
import pytorch_lightning as torch_light
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import torchmetrics
import gradio
from PIL import Image
from transformers import AutoModel, pipeline
| """### Initialize Containers""" | |
| configs = { | |
| "IMAGE_SIZE": (224, 224), # Resize images to (W, H) | |
| "NUM_CHANNELS": 3, # RGB images | |
| "NUM_CLASSES": 15, # Number of output labels | |
| # ImageNet dataset normalization values (for pretrained backbones) | |
| "MEAN": (0.485, 0.456, 0.406), | |
| "STD": (0.229, 0.224, 0.225), | |
| "DEFAULT_BACKBONE": "CheXFormer-small", | |
| "DEFAULT_VLM": "Lingshu-7B", | |
| "THRESHOLD": 0.2 | |
| } | |
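
# Illustration (not executed): ToTensor() scales pixels to [0, 1], then Normalize
# applies x_norm = (x - MEAN[c]) / STD[c] per channel c. A mid-gray red-channel
# pixel, for instance, maps to (0.5 - 0.485) / 0.229 ≈ 0.066.
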
ViT_REGISTRY = {
    "CheXFormer-small": "m42-health/CXformer-small",
    # "CheXFormer-base": "m42-health/CXformer-base",
    "ViT-base-16": "google/vit-base-patch16-224",
}

VLM_REGISTRY = {
    "MedMO": "MBZUAI/MedMO-8B",
    "Qwen3-VL-2B": "Qwen/Qwen3-VL-2B-Instruct",
    "Lingshu-7B": "lingshu-medical-mllm/Lingshu-7B",
    "MedGemma-4b": "google/medgemma-1.5-4b-it",
}
VLM_SYSTEM_PROMPT = """You are a medical imaging assistant specializing in chest radiography.
A trained multi-label classifier analyzed a chest X-ray; its predicted medical condition(s) and their associated probabilities are provided in the user query.

Your task:
1. Analyze the chest X-ray image to identify key features supporting the predicted condition(s).
2. Do NOT introduce new diagnoses.
3. Only explain radiographic findings that could support the listed prediction(s).
4. Use cautious, uncertainty-aware language.
5. If a probability is below 0.50, emphasize uncertainty.
6. Do NOT contradict the classifier.

Structure your answer as:
Observed Radiographic Findings:
...
How Chest X-ray Features Support the Predicted Conditions:
...
"""
# Caches so repeated Gradio calls reuse already-loaded models
ViT_MODEL_CACHE = {}
VLM_MODEL_CACHE = {}
| """### Define helper functions""" | |
| # helper function for loading pre-trained model | |
| # =================================================================================================== | |
| class get_pretrained_model(nn.Module): | |
| def __init__( | |
| self, | |
| model_name: str, | |
| num_classes: int, | |
| num_layers_to_unfreeze: int = 0): | |
| super().__init__() | |
| print(f"Loading pretrained [{model_name}] model") | |
| self.backbone = AutoModel.from_pretrained( | |
| ViT_REGISTRY[model_name], | |
| # model_name, | |
| trust_remote_code=True) | |
| hidden_size = self.backbone.config.hidden_size | |
| # Freeze entire backbone first | |
| for param in self.backbone.parameters(): | |
| param.requires_grad = False | |
| # Selectively unfreeze last N layers | |
| if num_layers_to_unfreeze > 0: | |
| self._unfreeze_last_n_layers(num_layers_to_unfreeze) | |
| # Single classification head | |
| self.classifier = nn.Sequential( | |
| nn.LayerNorm(hidden_size), | |
| nn.Dropout(0.4), | |
| nn.Linear(hidden_size, num_classes) ) | |
| def forward(self, x): | |
| outputs = self.backbone(x) | |
| # Use CLS token | |
| img_embeddings = outputs.last_hidden_state[:, 0] | |
| logits = self.classifier(img_embeddings) | |
| return logits | |
| def _unfreeze_last_n_layers(self, n: int): | |
| if hasattr(self.backbone, "encoder"): | |
| encoder_layers = self.backbone.encoder.layer | |
| elif hasattr(self.backbone, "vision_model"): | |
| encoder_layers = self.backbone.vision_model.encoder.layer | |
| else: | |
| raise ValueError("Cannot find encoder layers in backbone.") | |
| total_layers = len(encoder_layers) | |
| n = min(n, total_layers) | |
| print(f"Unfreezing last {n} of {total_layers} transformer layers.") | |
| for layer in encoder_layers[-n:]: | |
| for param in layer.parameters(): | |
| param.requires_grad = True | |
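
# Sketch (hypothetical helper, not called by the app): a quick smoke test of the
# wrapper above with a dummy batch. Assumes network access to download the weights.
def _smoke_test_backbone():
    model = get_pretrained_model("ViT-base-16", num_classes=configs["NUM_CLASSES"])
    dummy = torch.randn(1, configs["NUM_CHANNELS"], *configs["IMAGE_SIZE"])
    with torch.no_grad():
        print(model(dummy).shape)  # expected: torch.Size([1, 15])
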
# Preprocessing pipeline for input images
# ===================================================================================================
preprocess_fxn = transforms.Compose(
    [transforms.Resize(size=configs["IMAGE_SIZE"][::-1]),  # Resize takes (H, W), so reverse the (W, H) config
     transforms.ToTensor(),
     transforms.Normalize(configs["MEAN"], configs["STD"], inplace=True)])
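
# Sketch (hypothetical helper, not called by the app): preprocess_fxn maps a PIL
# image to a normalized (C, H, W) float tensor; unsqueeze(0) adds the batch dim.
def _demo_preprocess(img_pil: Image.Image) -> torch.Tensor:
    x = preprocess_fxn(img_pil.convert("RGB"))  # -> torch.Size([3, 224, 224])
    return x.unsqueeze(0)                       # -> torch.Size([1, 3, 224, 224])
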
# Map numeric outputs to string labels
labels_dict = {
    0: "Atelectasis",
    1: "Cardiomegaly",
    2: "Consolidation",
    3: "Edema",
    4: "Effusion",
    5: "Emphysema",
    6: "Fibrosis",
    7: "Hernia",
    8: "Infiltration",
    9: "Mass",
    10: "No finding",
    11: "Nodule",
    12: "Pleural_Thickening",
    13: "Pneumonia",
    14: "Pneumothorax",
}
| """### Create torch lightning model (i.e., classifier) module""" | |
| class modelModule(torch_light.LightningModule): | |
| def __init__(self, backbone_model_name, num_layers_to_unfreeze, num_classes=configs['NUM_CLASSES']): | |
| super().__init__() | |
| self.num_classes = num_classes | |
| self.backbone_model_name = backbone_model_name | |
| self.num_layers_to_unfreeze = num_layers_to_unfreeze | |
| # Load a pretrained backbone and replace its final layer | |
| self.model = get_pretrained_model( | |
| num_classes = self.num_classes, | |
| model_name = self.backbone_model_name, | |
| num_layers_to_unfreeze = self.num_layers_to_unfreeze) | |
| # Binary classification loss operating on raw logits | |
| self.loss_function = torch.nn.BCEWithLogitsLoss() | |
| self.accuracy_function = torchmetrics.classification.MultilabelAccuracy(num_labels=self.num_classes, average="weighted", threshold=0.5) | |
| self.f1_score_function = torchmetrics.classification.MultilabelF1Score(num_labels=self.num_classes, average="weighted", threshold=0.5) | |
| self.auroc_function = torchmetrics.classification.MultilabelAUROC(num_labels=self.num_classes, average="weighted", thresholds=10) | |
| self.map_score_function = torchmetrics.classification.MultilabelAveragePrecision(num_labels=self.num_classes, average="weighted", thresholds=10) | |
| # average options: macro (simple average), micro (sum), weighted (weight by class size, then avg) | |
| # threshold: Threshold for transforming probability to binary (0,1) predictions. For some metrics (e.g., AUROC), represents the number of thresholds (evenly spaced b/n 0–1) the metric should be computed at (resulting array of values are the averaged to obtain the final score) | |
    def forward(self, x):
        # Forward pass through the backbone model
        return self.model(x)

    def _common_step(self, batch, batch_idx):
        """
        Shared logic for train / val / test steps.
        Computes loss and evaluation metrics.
        """
        x, y = batch
        # Compute model predictions
        y_logits = self.forward(x)
        y_prob = torch.sigmoid(y_logits)
        # Loss expects raw logits + labels; metrics expect probabilities + labels
        loss = self.loss_function(y_logits, y.float())
        accuracy = self.accuracy_function(y_prob, y)
        f1_score = self.f1_score_function(y_prob, y)
        auroc = self.auroc_function(y_prob, y)
        mAP = self.map_score_function(y_prob, y)  # mean average precision
        return loss, y_logits, y, accuracy, f1_score, auroc, mAP
    def training_step(self, batch, batch_idx):
        # Run shared step
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)
        # Log epoch-level training metrics
        self.log_dict(
            {"train_loss": loss, "train_accuracy": accuracy, "train_f1_score": f1_score, "train_auroc": auroc, "train_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)
        # Lightning expects the loss key for backprop
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        # Run shared step
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)
        # Log validation metrics
        self.log_dict(
            {"val_loss": loss, "val_accuracy": accuracy, "val_f1_score": f1_score, "val_auroc": auroc, "val_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Run shared step
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)
        # Log test metrics
        self.log_dict(
            {"test_loss": loss, "test_accuracy": accuracy, "test_f1_score": f1_score, "test_auroc": auroc, "test_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

    def predict_step(self, batch, batch_idx):
        """
        Prediction logic used by trainer.predict().
        Returns model outputs without computing loss.
        """
        x = batch if not isinstance(batch, (tuple, list)) else batch[0]
        logits = self.forward(x)
        # Convert logits to probabilities for inference
        probs = torch.sigmoid(logits)
        return probs

    def configure_optimizers(self):
        # Optimizer over all trainable parameters
        optimizer = optim.Adam(self.parameters(), lr=3e-5)
        return optimizer
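
# Sketch (hypothetical, not called by the app, which only runs inference): how this
# module could be trained with the EarlyStopping/ModelCheckpoint callbacks imported
# above. `train_ds` and `val_ds` are assumed Datasets yielding (image, label) pairs.
def _demo_train(train_ds, val_ds):
    model = modelModule(backbone_model_name="ViT-base-16", num_layers_to_unfreeze=2)
    trainer = torch_light.Trainer(
        max_epochs=10,
        callbacks=[EarlyStopping(monitor="val_loss", patience=3),
                   ModelCheckpoint(monitor="val_auroc", mode="max")])
    trainer.fit(model,
                DataLoader(train_ds, batch_size=32, shuffle=True),
                DataLoader(val_ds, batch_size=32))
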
| """### Create function for running inference (i.e., assistive medical diagnosis)""" | |
| def generate_query(formatted_predictions): | |
| return f""" | |
| The predicted conditions and their corresponding probabilities are given by the following dictionary: | |
| {formatted_predictions} | |
| What features of the chest X-ray image support the predicted condition(s)? | |
| """ | |
def predictionReportGenerator(vlm_model, img_pil, system_prompt, query_prompt):
    image_ = img_pil.convert("RGB")
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_},
                {"type": "text", "text": query_prompt}]}]
    output = vlm_model(text=messages, max_new_tokens=350)
    # The pipeline returns the full chat history; the last message is the model's reply
    prediction_explanation = output[0]["generated_text"][-1]["content"]
    return prediction_explanation
def run_diagnosis(
        backbone_name,
        vlm_name,
        input_image,
        threshold,
        preprocess_fn=None,
        Idx2labels=None):
    # Preprocess
    x = preprocess_fn(input_image).unsqueeze(0)

    # Resolve backbone checkpoint
    ckpt_path = os.path.join(CKPT_ROOT, f"{backbone_name}.ckpt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Load classification model (cache for speed)
    if backbone_name not in ViT_MODEL_CACHE:
        ViT_MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
            ckpt_path, backbone_model_name=backbone_name, num_layers_to_unfreeze=2)
    model = ViT_MODEL_CACHE[backbone_name]
    model.eval()

    # Forward pass (no gradients needed at inference time)
    with torch.no_grad():
        logits = model(x.to(model.device))
    probs = torch.sigmoid(logits)[0].cpu().numpy()
    output_probs = {Idx2labels[i]: float(p) for i, p in enumerate(probs)}
    predicted_classes = [Idx2labels[i] for i, p in enumerate(probs) if p >= threshold]

    explanation_ = "No prediction was made."
    if predicted_classes:
        # Load explanation model (cache for speed)
        if vlm_name not in VLM_MODEL_CACHE:
            VLM_MODEL_CACHE[vlm_name] = pipeline(task="image-text-to-text",
                                                 model=VLM_REGISTRY[vlm_name],
                                                 trust_remote_code=True)
        VLM_model = VLM_MODEL_CACHE[vlm_name]
        formatted_predictions = {label: output_probs[label] for label in predicted_classes}
        query_prompt = generate_query(formatted_predictions)
        explanation_ = predictionReportGenerator(vlm_model=VLM_model, img_pil=input_image,
                                                 system_prompt=VLM_SYSTEM_PROMPT, query_prompt=query_prompt)
    return "\n".join(predicted_classes), explanation_, output_probs
| """### Gradio app""" | |
| CKPT_ROOT = os.path.join(os.getcwd(), "Vision Transformers") | |
| example_list_dir = os.path.join(os.getcwd(), "Curated test samples") | |
| example_list_img_names = os.listdir(example_list_dir) | |
| example_list = [ | |
| [configs["DEFAULT_BACKBONE"], configs["DEFAULT_VLM"], os.path.join(example_list_dir, example_img)] | |
| for example_img in example_list_img_names[:8] | |
| if example_img.lower().endswith(".png")] | |
gradio_app = gradio.Interface(
    fn=partial(run_diagnosis, preprocess_fn=preprocess_fxn, Idx2labels=labels_dict),
    inputs=[gradio.Dropdown(["CheXFormer-small", "ViT-base-16"], value="CheXFormer-small", label="Select Classification Model"),
            gradio.Dropdown(["MedGemma-4b", "MedMO", "Lingshu-7B", "Qwen3-VL-2B"], value="Lingshu-7B", label="Select Explanation Model"),
            gradio.Image(type="pil", label="Upload a chest X-ray image here"),
            gradio.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.2, label="Set Prediction Threshold")],
    outputs=[gradio.Textbox(label="Predicted Medical Condition(s)"),
             gradio.Textbox(label="Prediction Report"),
             gradio.Label(label="Predicted Probabilities", show_label=False)],
    examples=example_list,
    cache_examples=False,
    title="ChestVision-PRO",
    description="Vision Transformer models for assistive medical diagnosis, with Vision-Language-Model-based justification of predictions",
    article="Author: C. Foli (02.2026) | Website: coming soon...")

gradio_app.launch()