Anirudh Balaraman commited on
Commit
94fe9ea
·
unverified ·
1 Parent(s): 5d668a6

Update run_inference.py

Browse files
Files changed (1) hide show
  1. run_inference.py +5 -5
run_inference.py CHANGED
@@ -18,9 +18,9 @@ from src.preprocessing.register_and_crop import register_files
18
  from src.utils import get_parent_image, get_patch_coordinate, setup_logging
19
  import streamlit as st
20
 
21
- @st.cache_resource # <--- This decorator is the magic!
22
  def load_pirads_model(num_classes, mil_mode, project_dir, device):
23
- # Move the model initialization inside here
24
  model = MILModel3D(num_classes=num_classes, mil_mode=mil_mode)
25
  checkpoint = torch.load(
26
  os.path.join(project_dir, "models", "pirads.pt"), map_location="cpu"
@@ -28,11 +28,11 @@ def load_pirads_model(num_classes, mil_mode, project_dir, device):
28
  model.load_state_dict(checkpoint["state_dict"])
29
  model.to(device)
30
 
31
- model.eval() # Set to evaluation mode
32
  return model
33
  @st.cache_resource
34
  def load_cspca_model(_pirads_model, project_dir, device):
35
- # Move the model initialization inside here
36
  model = CSPCAModel(backbone=_pirads_model).to(device)
37
  checkpt = torch.load(
38
  os.path.join(project_dir, "models", "cspca_model.pth"), map_location="cpu"
@@ -40,7 +40,7 @@ def load_cspca_model(_pirads_model, project_dir, device):
40
  model.load_state_dict(checkpt["state_dict"])
41
  model = model.to(device)
42
 
43
- model.eval() # Set to evaluation mode
44
  return model
45
 
46
 
 
18
  from src.utils import get_parent_image, get_patch_coordinate, setup_logging
19
  import streamlit as st
20
 
21
+ @st.cache_resource
22
  def load_pirads_model(num_classes, mil_mode, project_dir, device):
23
+
24
  model = MILModel3D(num_classes=num_classes, mil_mode=mil_mode)
25
  checkpoint = torch.load(
26
  os.path.join(project_dir, "models", "pirads.pt"), map_location="cpu"
 
28
  model.load_state_dict(checkpoint["state_dict"])
29
  model.to(device)
30
 
31
+ model.eval()
32
  return model
33
  @st.cache_resource
34
  def load_cspca_model(_pirads_model, project_dir, device):
35
+
36
  model = CSPCAModel(backbone=_pirads_model).to(device)
37
  checkpt = torch.load(
38
  os.path.join(project_dir, "models", "cspca_model.pth"), map_location="cpu"
 
40
  model.load_state_dict(checkpt["state_dict"])
41
  model = model.to(device)
42
 
43
+ model.eval()
44
  return model
45
 
46