"""Gradio demo: classify a brain MRI volume (NIfTI) as CN or AD.

The ClinicaDL ``Conv5_FC3`` checkpoint is downloaded from the Hugging Face Hub
repo ``ARAMIS-LAB/CNN-AD-CN`` at import time. It is served through a
file-upload Gradio interface that returns per-class softmax probabilities as
JSON.
"""

import gradio as gr
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
from clinicadl.utils.network.cnn.models import Conv5_FC3
from huggingface_hub import hf_hub_download

# Spatial size (D, H, W) the network is instantiated with below; every input
# volume is resampled to this grid before inference.
INPUT_SHAPE = (169, 208, 179)

# Floor for the standard deviation during z-scoring, so a constant-valued
# volume does not divide by zero and turn the whole tensor into NaNs.
_EPS = 1e-8

# Download model from Hub.
model_path = hf_hub_download(repo_id="ARAMIS-LAB/CNN-AD-CN", filename="model.pth.tar")

# Load ClinicaDL model.
# NOTE(review): torch.load unpickles the checkpoint. That is only acceptable
# because the file comes from a fixed, trusted Hub repo; never point this at
# user-supplied files.
checkpoint_state = torch.load(model_path, map_location="cpu")
model = Conv5_FC3(input_size=[1, *INPUT_SHAPE], gpu=False)
model.load_state_dict(checkpoint_state["model"])
model.eval()

# Class labels, zipped positionally with the model's softmax outputs.
# NOTE(review): assumes output index 0 is CN and 1 is AD — confirm against the
# training configuration of the checkpoint.
CLASSES = ["CN", "AD"]


def preprocess_nii(nii_file):
    """Load a NIfTI volume and convert it to a model-ready tensor.

    Args:
        nii_file: Path to a ``.nii`` / ``.nii.gz`` file (anything accepted by
            ``nibabel.load``).

    Returns:
        A float32 tensor of shape ``(1, 1, 169, 208, 179)``.

    Raises:
        ValueError: If the image is not 3-D after dropping trailing
            singleton axes.
    """
    img = nib.load(nii_file)
    data = img.get_fdata()  # numpy array (float64)

    # Some NIfTI files carry trailing length-1 axes (e.g. a single time
    # point). Drop them; otherwise the tensor below would be 6-D and
    # trilinear interpolation would fail.
    while data.ndim > 3 and data.shape[-1] == 1:
        data = data[..., 0]
    if data.ndim != 3:
        raise ValueError(f"Expected a 3-D volume, got image of shape {img.shape}")

    # NaN/inf voxels would propagate into the mean/std and make every
    # normalized voxel NaN, so zero them first.
    data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)

    # Z-score intensities over the whole volume (background included).
    std = np.std(data)
    data = (data - np.mean(data)) / max(std, _EPS)

    # Add batch and channel axes: (1, 1, D, H, W).
    tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    # Resample to the network's fixed input grid. This is a plain trilinear
    # resize (no padding or cropping), so anatomy is stretched whenever the
    # input's aspect ratio differs from INPUT_SHAPE.
    # NOTE(review): the model presumably expects ClinicaDL t1-linear
    # preprocessed images, which would already have this shape — confirm.
    if tuple(tensor.shape[2:]) != INPUT_SHAPE:
        tensor = F.interpolate(tensor, size=INPUT_SHAPE, mode="trilinear", align_corners=False)
    return tensor


def predict(input_image):
    """Return per-class probabilities for an uploaded NIfTI file.

    Args:
        input_image: Value produced by the ``gr.File`` input. This is a
            filepath string, or ``None`` when nothing was uploaded. Objects
            exposing a ``.name`` path attribute are also accepted
            defensively.

    Returns:
        Mapping from class label (see ``CLASSES``) to softmax probability.

    Raises:
        gr.Error: If no file was uploaded or it cannot be read as a 3-D
            NIfTI volume; Gradio shows the message to the user.
    """
    if input_image is None:
        raise gr.Error("Please upload a .nii or .nii.gz file.")
    path = getattr(input_image, "name", input_image)

    try:
        x = preprocess_nii(path)
    except Exception as err:  # UI boundary: nibabel raises several error types
        raise gr.Error(f"Could not read the upload as a 3-D NIfTI volume: {err}") from err

    with torch.no_grad():
        logits = model(x)
    # Softmax over the class axis; tolist() yields plain Python floats.
    probs = F.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    return dict(zip(CLASSES, probs))


# Gradio app: file upload instead of image.
demo = gr.Interface(
    fn=predict,
    inputs=gr.File(type="filepath", label=".nii.gz MRI upload"),
    outputs="json",
    title="ClinicaDL MRI Classifier",
    description="Upload a .nii.gz file to get the model's prediction.",
)

if __name__ == "__main__":
    demo.launch()