CLINICADL / app.py
aramis-user's picture
Update app.py
0381209 verified
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
import torch.nn.functional as F
from clinicadl.utils.network.cnn.models import Conv5_FC3
import nibabel as nib
import numpy as np
# Fetch the pretrained checkpoint file from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="ARAMIS-LAB/CNN-AD-CN",
    filename="model.pth.tar",
)

# Instantiate the ClinicaDL network (gpu=False), then restore the weights
# stored under the "model" key of the checkpoint, deserialized onto the CPU.
model = Conv5_FC3(input_size=[1, 169, 208, 179], gpu=False)
checkpoint_state = torch.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint_state["model"])
model.eval()  # evaluation mode: the model is only used for prediction here

# Class labels, in the order they are paired with the model's output columns
CLASSES = ["CN", "AD"]
def preprocess_nii(nii_file, target_size=(169, 208, 179)):
    """Load a NIfTI volume and convert it into a normalized input tensor.

    The volume is z-score normalized over all of its voxels and then
    resampled to ``target_size`` with trilinear interpolation. The image is
    only interpolated; it is never cropped or padded.

    Args:
        nii_file: Path to a NIfTI file that nibabel can load.
        target_size: Spatial size ``(D, H, W)`` of the returned tensor.
            The default matches the spatial part of the ``input_size`` the
            model in this file is built with.

    Returns:
        A ``torch.float32`` tensor of shape ``[1, 1, *target_size]``.

    Raises:
        ValueError: If the file does not contain a single 3D volume.
    """
    # Load NIfTI file
    img = nib.load(nii_file)
    data = img.get_fdata()  # numpy array (float64)

    # A single volume is sometimes stored as 4D with trailing singleton axes
    # (e.g. X x Y x Z x 1). Drop them so the tensor below ends up 5D, which
    # is what trilinear interpolation requires.
    while data.ndim > 3 and data.shape[-1] == 1:
        data = data[..., 0]
    if data.ndim != 3:
        raise ValueError(
            f"Expected a single 3D volume, got an array of shape {data.shape}"
        )

    # Normalize intensities (z-score). A constant image has zero standard
    # deviation; dividing by it would fill the tensor with NaN/inf, so in
    # that case the data is only centered.
    normalized = data - np.mean(data)
    std = np.std(data)
    if std > 0:
        normalized = normalized / std

    # Convert to tensor and add batch and channel axes.
    # Shape: [1, 1, D, H, W]
    tensor = torch.tensor(normalized, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    # Resample to the expected spatial size
    return F.interpolate(
        tensor,
        size=tuple(target_size),
        mode="trilinear",
        align_corners=False,
    )
def predict(input_image):
    """Classify an uploaded NIfTI scan and return per-class probabilities.

    Args:
        input_image: Path to the uploaded file. It is ``None`` (or empty)
            when the request was submitted without a file.

    Returns:
        A dict mapping each label in ``CLASSES`` to its softmax probability
        as a plain Python float.

    Raises:
        ValueError: If no file was provided.
    """
    # Fail early with a clear message instead of an opaque loader error
    # when nothing was uploaded.
    if not input_image:
        raise ValueError(
            "No file provided: please upload a .nii or .nii.gz MRI volume."
        )

    x = preprocess_nii(input_image)
    with torch.no_grad():
        logits = model(x)
        probs = F.softmax(logits, dim=1)  # convert to probabilities

    # Drop the batch axis and convert to a Python list of floats so the
    # result is JSON-serializable.
    probs = probs.squeeze(0).cpu().numpy().tolist()
    return dict(zip(CLASSES, probs))
# Gradio app: the input is a file upload (passed to predict as a path)
# rather than an image widget; the prediction is shown as JSON.
nifti_upload = gr.File(type="filepath", label=".nii.gz MRI upload")
demo = gr.Interface(
    fn=predict,
    inputs=nifti_upload,
    outputs="json",
    title="ClinicaDL MRI Classifier",
    description="Upload a .nii.gz file to get the model's prediction.",
)
demo.launch()