Spaces:
Running
Running
Anirudh Balaraman commited on
Update run_inference.py
Browse files- run_inference.py +5 -5
run_inference.py
CHANGED
|
@@ -18,9 +18,9 @@ from src.preprocessing.register_and_crop import register_files
|
|
| 18 |
from src.utils import get_parent_image, get_patch_coordinate, setup_logging
|
| 19 |
import streamlit as st
|
| 20 |
|
| 21 |
-
@st.cache_resource
|
| 22 |
def load_pirads_model(num_classes, mil_mode, project_dir, device):
|
| 23 |
-
|
| 24 |
model = MILModel3D(num_classes=num_classes, mil_mode=mil_mode)
|
| 25 |
checkpoint = torch.load(
|
| 26 |
os.path.join(project_dir, "models", "pirads.pt"), map_location="cpu"
|
|
@@ -28,11 +28,11 @@ def load_pirads_model(num_classes, mil_mode, project_dir, device):
|
|
| 28 |
model.load_state_dict(checkpoint["state_dict"])
|
| 29 |
model.to(device)
|
| 30 |
|
| 31 |
-
model.eval()
|
| 32 |
return model
|
| 33 |
@st.cache_resource
|
| 34 |
def load_cspca_model(_pirads_model, project_dir, device):
|
| 35 |
-
|
| 36 |
model = CSPCAModel(backbone=_pirads_model).to(device)
|
| 37 |
checkpt = torch.load(
|
| 38 |
os.path.join(project_dir, "models", "cspca_model.pth"), map_location="cpu"
|
|
@@ -40,7 +40,7 @@ def load_cspca_model(_pirads_model, project_dir, device):
|
|
| 40 |
model.load_state_dict(checkpt["state_dict"])
|
| 41 |
model = model.to(device)
|
| 42 |
|
| 43 |
-
model.eval()
|
| 44 |
return model
|
| 45 |
|
| 46 |
|
|
|
|
| 18 |
from src.utils import get_parent_image, get_patch_coordinate, setup_logging
|
| 19 |
import streamlit as st
|
| 20 |
|
| 21 |
+
@st.cache_resource
|
| 22 |
def load_pirads_model(num_classes, mil_mode, project_dir, device):
|
| 23 |
+
|
| 24 |
model = MILModel3D(num_classes=num_classes, mil_mode=mil_mode)
|
| 25 |
checkpoint = torch.load(
|
| 26 |
os.path.join(project_dir, "models", "pirads.pt"), map_location="cpu"
|
|
|
|
| 28 |
model.load_state_dict(checkpoint["state_dict"])
|
| 29 |
model.to(device)
|
| 30 |
|
| 31 |
+
model.eval()
|
| 32 |
return model
|
| 33 |
@st.cache_resource
|
| 34 |
def load_cspca_model(_pirads_model, project_dir, device):
|
| 35 |
+
|
| 36 |
model = CSPCAModel(backbone=_pirads_model).to(device)
|
| 37 |
checkpt = torch.load(
|
| 38 |
os.path.join(project_dir, "models", "cspca_model.pth"), map_location="cpu"
|
|
|
|
| 40 |
model.load_state_dict(checkpt["state_dict"])
|
| 41 |
model = model.to(device)
|
| 42 |
|
| 43 |
+
model.eval()
|
| 44 |
return model
|
| 45 |
|
| 46 |
|