aramis-user commited on
Commit
d2834b7
·
verified ·
1 Parent(s): 92d9b2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -17,6 +17,9 @@ model = Conv5_FC3(input_size= [
17
  model.load_state_dict(checkpoint_state["model"])
18
  model.eval()
19
 
 
 
 
20
  def preprocess_nii(nii_file):
21
  # Load NIfTI file
22
  img = nib.load(nii_file)
@@ -39,14 +42,16 @@ def predict(input_image):
39
  x = preprocess_nii(input_image)
40
  with torch.no_grad():
41
  output = model(x)
42
- output = output.squeeze(0).cpu().float()
43
- return output[0]
 
 
44
 
45
  # Gradio app: file upload instead of image
46
  demo = gr.Interface(
47
  fn=predict,
48
  inputs=gr.File(type="filepath", label=".nii.gz MRI upload"),
49
- outputs="label",
50
  title="ClinicaDL MRI Classifier",
51
  description="Upload a .nii.gz file to get the model's prediction."
52
  )
 
17
  model.load_state_dict(checkpoint_state["model"])
18
  model.eval()
19
 
20
+ # Class labels
21
+ CLASSES = ["CN", "AD"]
22
+
23
  def preprocess_nii(nii_file):
24
  # Load NIfTI file
25
  img = nib.load(nii_file)
 
42
  x = preprocess_nii(input_image)
43
  with torch.no_grad():
44
  output = model(x)
45
+
46
+ results = {cls: float(prob) for cls, prob in zip(CLASSES, probs)}
47
+
48
+ return results
49
 
50
  # Gradio app: file upload instead of image
51
  demo = gr.Interface(
52
  fn=predict,
53
  inputs=gr.File(type="filepath", label=".nii.gz MRI upload"),
54
+ outputs="json",
55
  title="ClinicaDL MRI Classifier",
56
  description="Upload a .nii.gz file to get the model's prediction."
57
  )