aramis-user commited on
Commit
7ccbc69
·
verified ·
1 Parent(s): 6a939f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
3
  import torch
 
4
  from clinicadl.utils.network.cnn.models import Conv5_FC3
5
  import nibabel as nib
 
6
 
7
  # Download model from Hub
8
  model_path = hf_hub_download(repo_id="ARAMIS-LAB/CNN-AD-CN", filename="model.pth.tar")
@@ -27,7 +29,7 @@ def preprocess_nii(nii_file):
27
  data = img.get_fdata() # numpy array (float64)
28
 
29
  # Normalize intensities
30
- data = (data - np.mean(data)) / (np.std(data) + 1e-8)
31
 
32
  # Convert to tensor
33
  tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
@@ -43,6 +45,7 @@ def predict(input_image):
43
  x = preprocess_nii(input_image)
44
  with torch.no_grad():
45
  output = model(x)
 
46
 
47
  results = {cls: float(prob) for cls, prob in zip(CLASSES, probs)}
48
 
 
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
3
  import torch
4
+ import torch.nn.functional as F
5
  from clinicadl.utils.network.cnn.models import Conv5_FC3
6
  import nibabel as nib
7
+ import numpy as np
8
 
9
  # Download model from Hub
10
  model_path = hf_hub_download(repo_id="ARAMIS-LAB/CNN-AD-CN", filename="model.pth.tar")
 
29
  data = img.get_fdata() # numpy array (float64)
30
 
31
  # Normalize intensities
32
+ data = (data - np.mean(data)) / np.std(data)
33
 
34
  # Convert to tensor
35
  tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
 
45
  x = preprocess_nii(input_image)
46
  with torch.no_grad():
47
  output = model(x)
48
+ probs = F.softmax(logits, dim=1) # convert to probabilities
49
 
50
  results = {cls: float(prob) for cls, prob in zip(CLASSES, probs)}
51