File size: 1,827 Bytes
55431de
d6ebf48
 
7ccbc69
50dc181
6a939f0
7ccbc69
55431de
d6ebf48
 
794835d
d6ebf48
50dc181
5369f4e
 
 
 
 
1c020ad
50dc181
d6ebf48
55431de
d2834b7
 
 
d7ee6eb
 
 
 
 
 
7ccbc69
d7ee6eb
 
 
 
 
 
 
 
 
 
 
d6ebf48
d7ee6eb
d6ebf48
23d80cd
7ccbc69
d2834b7
0381209
 
d2834b7
 
d6ebf48
d7ee6eb
 
 
92d9b2a
d2834b7
d7ee6eb
 
 
 
d6ebf48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
import torch.nn.functional as F
from clinicadl.utils.network.cnn.models import Conv5_FC3
import nibabel as nib 
import numpy as np

# Download the pretrained ClinicaDL CNN checkpoint from the Hugging Face Hub.
# hf_hub_download returns the local path of the (cached) downloaded file.
model_path = hf_hub_download(repo_id="ARAMIS-LAB/CNN-AD-CN", filename="model.pth.tar")

# Load the ClinicaDL checkpoint onto the CPU.
# SECURITY NOTE(review): torch.load unpickles the file, and unpickling can run
# arbitrary code if the remote repository is ever compromised. Consider pinning
# a trusted `revision=` in hf_hub_download and/or passing `weights_only=True`
# once it is confirmed the checkpoint holds only tensors and plain containers.
checkpoint_state = torch.load(model_path, map_location="cpu")

# Build the network for single-channel 169 x 208 x 179 volumes (CPU only) and
# restore the weights stored under the checkpoint's "model" key.
model = Conv5_FC3(input_size=[1, 169, 208, 179], gpu=False)
model.load_state_dict(checkpoint_state["model"])
# Switch to evaluation mode (affects layers such as dropout / batch-norm, if present).
model.eval()

# Class labels; assumed to match the order of the model's output logits
# (index 0 -> "CN", index 1 -> "AD") — TODO confirm against the training config.
CLASSES = ["CN", "AD"]

def preprocess_nii(nii_file, target_shape=(169, 208, 179)):
    """Load a NIfTI volume and turn it into a model-ready tensor.

    The volume is z-score normalized (zero mean, unit standard deviation) over
    all voxels and then trilinearly resampled to ``target_shape``.

    Args:
        nii_file: Path to a ``.nii`` / ``.nii.gz`` file (anything ``nib.load``
            accepts).
        target_shape: Spatial size ``(D, H, W)`` to resample the volume to.
            Defaults to the 169 x 208 x 179 size the model is built with.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 1, *target_shape)``.

    Raises:
        ValueError: If the image is not a 3D volume. A 4D image with a single
            trailing frame is accepted and reduced to 3D.
    """
    img = nib.load(nii_file)
    data = img.get_fdata()  # numpy array (float64)

    # Some 3D scans are stored as 4D with one trailing frame; drop it so the
    # tensor built below is 5D ([N, C, D, H, W]), as trilinear interpolation requires.
    if data.ndim == 4 and data.shape[3] == 1:
        data = data[..., 0]
    if data.ndim != 3:
        raise ValueError(f"Expected a 3D MRI volume, got an array of shape {data.shape}")

    # Z-score normalization. A constant image has std == 0; dividing by it would
    # fill the tensor with NaNs, so in that case only the mean is removed.
    # NOTE(review): confirm this matches the normalization used at training time.
    std = np.std(data)
    data = data - np.mean(data)
    if std > 0:
        data = data / std

    # Add batch and channel axes: [D, H, W] -> [1, 1, D, H, W].
    tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    # Resample (no padding or cropping) to the spatial size the model expects.
    tensor = F.interpolate(tensor, size=tuple(target_shape), mode="trilinear", align_corners=False)

    return tensor
    
def predict(input_image):
    """Classify an uploaded NIfTI MRI with the loaded model.

    Args:
        input_image: Filesystem path of the uploaded file, as supplied by the
            ``gr.File(type="filepath")`` input, or ``None`` if nothing was uploaded.

    Returns:
        A dict mapping each label in ``CLASSES`` to its softmax probability,
        as plain Python floats so the result is JSON-serializable.

    Raises:
        gr.Error: If no file was uploaded, so Gradio shows a readable message
            instead of an opaque failure from the loader.
    """
    if input_image is None:
        raise gr.Error("Please upload a .nii or .nii.gz MRI file.")

    x = preprocess_nii(input_image)
    with torch.no_grad():
        logits = model(x)

    # Softmax over the class axis turns logits into probabilities; drop the
    # batch axis and convert to a list of Python floats.
    probs = F.softmax(logits, dim=1).squeeze(0).cpu().tolist()
    return dict(zip(CLASSES, probs))

# Gradio UI: a single file-upload input (the NIfTI volume) and a JSON output
# showing the per-class probabilities returned by predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.File(type="filepath", label=".nii.gz MRI upload"),
    outputs="json",
    title="ClinicaDL MRI Classifier",
    description="Upload a .nii.gz file to get the model's prediction.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()