aramis-user commited on
Commit
d7ee6eb
·
verified ·
1 Parent(s): 1c020ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -2
app.py CHANGED
@@ -17,11 +17,38 @@ model = Conv5_FC3(input_size= [
17
  model.load_state_dict(checkpoint_state["model"])
18
  model.eval()
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def predict(input_image):
 
21
  with torch.no_grad():
22
- output = model(input_image.unsqueeze(0).to(model.device))
23
  output = output.squeeze(0).cpu().float()
24
  return output[0]
25
 
26
- demo = gr.Interface(fn=predict, inputs="image", outputs="label")
 
 
 
 
 
 
 
 
27
  demo.launch()
 
17
  model.load_state_dict(checkpoint_state["model"])
18
  model.eval()
19
 
20
+ def preprocess_nii(nii_file):
21
+ # Load NIfTI file
22
+ img = nib.load(nii_file)
23
+ data = img.get_fdata() # numpy array (float64)
24
+
25
+ # Normalize intensities
26
+ data = (data - np.mean(data)) / (np.std(data) + 1e-8)
27
+
28
+ # Convert to tensor
29
+ tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
30
+ # Shape: [1, 1, D, H, W]
31
+
32
+ # Resize or pad to expected input shape
33
+ target_shape = (1, 1, 169, 208, 179)
34
+ tensor = F.interpolate(tensor, size=target_shape[2:], mode="trilinear", align_corners=False)
35
+
36
+ return tensor
37
+
38
  def predict(input_image):
39
+ x = preprocess_nii(input_image)
40
  with torch.no_grad():
41
+ output = model(x)
42
  output = output.squeeze(0).cpu().float()
43
  return output[0]
44
 
45
+ # Gradio app: file upload instead of image
46
+ demo = gr.Interface(
47
+ fn=predict,
48
+ inputs=gr.File(type="file", label=".nii.gz MRI upload"),
49
+ outputs="label",
50
+ title="ClinicaDL MRI Classifier",
51
+ description="Upload a .nii.gz file to get the model's prediction."
52
+ )
53
+
54
  demo.launch()