aramis-user commited on
Commit
50dc181
·
verified ·
1 Parent(s): 6167e17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -1,19 +1,22 @@
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
3
  import torch
 
4
 
5
  # Download model from Hub
6
  model_path = hf_hub_download(repo_id="ARAMIS-LAB/CNN-AD-CN", filename="model.pth.tar")
7
 
8
  # Load ClinicaDL model
9
- model = torch.load(model_path, map_location="cpu")
10
- print(model)
 
11
  model.eval()
12
 
13
  def predict(input_image):
14
  with torch.no_grad():
15
  output = model(input_image.unsqueeze(0))
16
- return output.numpy().tolist()
 
17
 
18
  demo = gr.Interface(fn=predict, inputs="image", outputs="label")
19
  demo.launch()
 
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
3
  import torch
4
+ from clinicadl.utils.network.cnn.models import Conv5_FC3
5
 
6
  # Download model from Hub
7
  model_path = hf_hub_download(repo_id="ARAMIS-LAB/CNN-AD-CN", filename="model.pth.tar")
8
 
9
  # Load ClinicaDL model
10
+ checkpoint_state = torch.load(model_path, map_location="cpu")
11
+ model = Conv5_FC3()
12
+ model.load_state_dict(checkpoint_state["model"])
13
  model.eval()
14
 
15
  def predict(input_image):
16
  with torch.no_grad():
17
  output = model(input_image.unsqueeze(0))
18
+ output = output.squeeze(0).cpu().float()
19
+ return output
20
 
21
  demo = gr.Interface(fn=predict, inputs="image", outputs="label")
22
  demo.launch()