aramis-user commited on
Commit
5369f4e
·
verified ·
1 Parent(s): b0097c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -8,8 +8,12 @@ model_path = hf_hub_download(repo_id="ARAMIS-LAB/CNN-AD-CN", filename="model.pth
8
 
9
  # Load ClinicaDL model
10
  checkpoint_state = torch.load(model_path, map_location="cpu")
11
- print(checkpoint_state)
12
- model = Conv5_FC3()
 
 
 
 
13
  model.load_state_dict(checkpoint_state["model"])
14
  model.eval()
15
 
@@ -17,7 +21,7 @@ def predict(input_image):
17
  with torch.no_grad():
18
  output = model(input_image.unsqueeze(0))
19
  output = output.squeeze(0).cpu().float()
20
- return output
21
 
22
  demo = gr.Interface(fn=predict, inputs="image", outputs="label")
23
  demo.launch()
 
8
 
9
  # Load ClinicaDL model
10
  checkpoint_state = torch.load(model_path, map_location="cpu")
11
+ model = Conv5_FC3(input_size= [
12
+ 1,
13
+ 169,
14
+ 208,
15
+ 179
16
+ ])
17
  model.load_state_dict(checkpoint_state["model"])
18
  model.eval()
19
 
 
21
  with torch.no_grad():
22
  output = model(input_image.unsqueeze(0))
23
  output = output.squeeze(0).cpu().float()
24
+ return output[0]
25
 
26
  demo = gr.Interface(fn=predict, inputs="image", outputs="label")
27
  demo.launch()