# Hugging Face Space app (the captured page showed "Runtime error" status at scrape time).
import gradio as gr
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from clinicadl.utils.network.cnn.models import Conv5_FC3
| # Download model from Hub | |
| model_path = hf_hub_download(repo_id="ARAMIS-LAB/CNN-AD-CN", filename="model.pth.tar") | |
| # Load ClinicaDL model | |
| checkpoint_state = torch.load(model_path, map_location="cpu") | |
| model = Conv5_FC3(input_size= [ | |
| 1, | |
| 169, | |
| 208, | |
| 179 | |
| ], gpu = False) | |
| model.load_state_dict(checkpoint_state["model"]) | |
| model.eval() | |
| # Class labels | |
| CLASSES = ["CN", "AD"] | |
def preprocess_nii(nii_file):
    """Load a NIfTI volume and prepare it as a model-ready tensor.

    Parameters
    ----------
    nii_file : str
        Path to a .nii / .nii.gz file (as supplied by ``gr.File`` with
        ``type="filepath"``).

    Returns
    -------
    torch.Tensor
        Float32 tensor of shape ``[1, 1, 169, 208, 179]``: z-score
        normalized and trilinearly resampled to the model's input size.
    """
    # Load NIfTI file into a float64 numpy array.
    img = nib.load(nii_file)
    data = img.get_fdata()

    # Z-score intensity normalization. Guard against a constant-intensity
    # volume, where std == 0 would otherwise produce NaNs (division by zero).
    std = np.std(data)
    if std == 0:
        std = 1.0
    data = (data - np.mean(data)) / std

    # Add batch and channel dimensions -> shape [1, 1, D, H, W].
    tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    # Resample to the fixed spatial size the CNN expects (169 x 208 x 179).
    target_shape = (1, 1, 169, 208, 179)
    tensor = F.interpolate(
        tensor, size=target_shape[2:], mode="trilinear", align_corners=False
    )
    return tensor
def predict(input_image):
    """Classify an uploaded NIfTI file.

    Parameters
    ----------
    input_image : str
        Filepath of the uploaded .nii/.nii.gz volume.

    Returns
    -------
    dict
        Mapping of class label -> probability (Python floats).
    """
    volume = preprocess_nii(input_image)
    with torch.no_grad():
        logits = model(volume)
        # Convert raw scores to a probability distribution over classes.
        probabilities = F.softmax(logits, dim=1)
    scores = probabilities.squeeze(0).cpu().numpy().tolist()
    return dict(zip(CLASSES, scores))
# Gradio app: a file-upload widget (instead of an image widget) feeding predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.File(type="filepath", label=".nii.gz MRI upload"),
    outputs="json",
    title="ClinicaDL MRI Classifier",
    description="Upload a .nii.gz file to get the model's prediction.",
)

demo.launch()