Spaces:
Runtime error
Runtime error
File size: 1,827 Bytes
55431de d6ebf48 7ccbc69 50dc181 6a939f0 7ccbc69 55431de d6ebf48 794835d d6ebf48 50dc181 5369f4e 1c020ad 50dc181 d6ebf48 55431de d2834b7 d7ee6eb 7ccbc69 d7ee6eb d6ebf48 d7ee6eb d6ebf48 23d80cd 7ccbc69 d2834b7 0381209 d2834b7 d6ebf48 d7ee6eb 92d9b2a d2834b7 d7ee6eb d6ebf48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
import torch.nn.functional as F
from clinicadl.utils.network.cnn.models import Conv5_FC3
import nibabel as nib
import numpy as np
# Fetch the pretrained ClinicaDL checkpoint from the Hugging Face Hub;
# `model_path` is the local filesystem path of the downloaded file.
model_path = hf_hub_download(
    repo_id="ARAMIS-LAB/CNN-AD-CN",
    filename="model.pth.tar",
)

# Rebuild the Conv5_FC3 architecture on CPU and restore the trained weights,
# which the checkpoint stores under its "model" key.
# input_size is presumably channels-first (C, D, H, W); it matches the
# (1, 1, 169, 208, 179) tensors produced by preprocess_nii.
checkpoint_state = torch.load(model_path, map_location="cpu")
model = Conv5_FC3(input_size=[1, 169, 208, 179], gpu=False)
model.load_state_dict(checkpoint_state["model"])
# Switch to evaluation mode for inference (affects layers such as dropout).
model.eval()

# Output labels, indexed by logit position.
# NOTE(review): assumes logit 0 = CN and logit 1 = AD. Depending on the
# ClinicaDL version used for training, the label code may instead be sorted
# alphabetically (AD=0, CN=1) -- confirm against the training configuration.
CLASSES = ["CN", "AD"]
def preprocess_nii(nii_file, target_shape=(169, 208, 179)):
    """Load an MRI volume and turn it into a model-ready tensor.

    The volume is z-score normalized over all voxels, then trilinearly
    resampled (no padding or cropping) to ``target_shape``.

    NOTE(review): ClinicaDL's default training transforms may use min-max
    normalization rather than z-scoring. The model may also expect
    t1-linear-preprocessed (registered and cropped) volumes, which plain
    resampling does not reproduce. Confirm against how this checkpoint was
    trained.

    Args:
        nii_file: Path to a ``.nii`` / ``.nii.gz`` file (anything
            ``nib.load`` accepts), or an already-loaded ``numpy.ndarray``
            holding the volume.
        target_shape: Spatial output size ``(D, H, W)``. The default matches
            the ``input_size`` the model is built with.

    Returns:
        float32 tensor of shape ``(1, 1, *target_shape)``
        (batch, channel, D, H, W).

    Raises:
        ValueError: if the volume is not 3D (after dropping trailing
            length-1 axes), or if its intensities are constant or contain
            non-finite values, so it cannot be normalized.
    """
    if isinstance(nii_file, np.ndarray):
        data = np.asarray(nii_file, dtype=np.float64)
    else:
        data = nib.load(nii_file).get_fdata()  # numpy array (float64)

    # Drop trailing length-1 axes (e.g. a 4D file holding a single 3D frame),
    # so the tensor below stays 5D as F.interpolate's trilinear mode requires.
    while data.ndim > 3 and data.shape[-1] == 1:
        data = data[..., 0]
    if data.ndim != 3:
        raise ValueError(
            f"Expected a 3D MRI volume, got an array of shape {data.shape}."
        )

    mean, std = np.mean(data), np.std(data)
    if not np.isfinite(std) or std == 0:
        # Dividing by a zero or NaN spread would give all-NaN model input and
        # meaningless probabilities; fail loudly instead.
        raise ValueError(
            "Cannot normalize MRI volume: intensities are constant "
            "or contain non-finite values."
        )
    data = (data - mean) / std

    # [D, H, W] -> [1, 1, D, H, W]: add batch and channel axes.
    tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    # Resample to the fixed spatial size the network was built for.
    return F.interpolate(
        tensor, size=tuple(target_shape), mode="trilinear", align_corners=False
    )
def predict(input_image):
    """Classify an uploaded MRI volume.

    Args:
        input_image: Path of the uploaded file, as delivered by the
            ``gr.File(type="filepath")`` input; ``None`` when nothing was
            uploaded.

    Returns:
        dict mapping each label in ``CLASSES`` to its softmax probability
        (plain Python floats, so the result is JSON-serializable).

    Raises:
        gr.Error: if no file was uploaded or the volume cannot be
            preprocessed, so Gradio shows a readable message to the user.
    """
    if input_image is None:
        # An empty upload would otherwise reach nib.load(None) and fail with
        # an opaque error.
        raise gr.Error("Please upload a .nii or .nii.gz MRI volume.")
    try:
        x = preprocess_nii(input_image)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    with torch.no_grad():
        logits = model(x)
    # Softmax over the class axis, then drop the batch axis of size one.
    probs = F.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    return dict(zip(CLASSES, probs))
# Gradio UI: the input is an uploaded NIfTI file (not an image), and the
# dict returned by `predict` is displayed as JSON.
mri_upload = gr.File(type="filepath", label=".nii.gz MRI upload")
demo = gr.Interface(
    fn=predict,
    inputs=mri_upload,
    outputs="json",
    title="ClinicaDL MRI Classifier",
    description="Upload a .nii.gz file to get the model's prediction.",
)
demo.launch()
|