Spaces:
Running
Running
enable dynamic selection of CUDA GPU only if available, else GPU inference
Browse files- lungtumormask/__main__.py +1 -1
- lungtumormask/dataprocessing.py +2 -1
- lungtumormask/mask.py +6 -2
lungtumormask/__main__.py
CHANGED
|
@@ -17,4 +17,4 @@ def main():
|
|
| 17 |
argsin = sys.argv[1:]
|
| 18 |
args = parser.parse_args(argsin)
|
| 19 |
|
| 20 |
-
mask.mask(args.input, args.output)
|
|
|
|
| 17 |
argsin = sys.argv[1:]
|
| 18 |
args = parser.parse_args(argsin)
|
| 19 |
|
| 20 |
+
mask.mask(args.input, args.output)
|
lungtumormask/dataprocessing.py
CHANGED
|
@@ -10,7 +10,8 @@ from monai.transforms import (Compose, LoadImaged, ToNumpyd, ThresholdIntensityd
|
|
| 10 |
|
| 11 |
def mask_lung(scan_path, batch_size=20):
|
| 12 |
model = lungmask.mask.get_model('unet', 'R231')
|
| 13 |
-
|
|
|
|
| 14 |
model.to(device)
|
| 15 |
|
| 16 |
scan_dict = {
|
|
|
|
| 10 |
|
| 11 |
def mask_lung(scan_path, batch_size=20):
|
| 12 |
model = lungmask.mask.get_model('unet', 'R231')
|
| 13 |
+
if T.cuda.is_available():
|
| 14 |
+
device = torch.device('cuda')
|
| 15 |
model.to(device)
|
| 16 |
|
| 17 |
scan_dict = {
|
lungtumormask/mask.py
CHANGED
|
@@ -5,10 +5,14 @@ import torch as T
|
|
| 5 |
import nibabel
|
| 6 |
|
| 7 |
def load_model():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
model = UNet_double(3, 1, 1, tuple([64, 128, 256, 512, 1024]), tuple([2 for i in range(4)]), num_res_units = 0)
|
| 9 |
-
state_dict = T.hub.load_state_dict_from_url("https://github.com/VemundFredriksen/LungTumorMask/releases/download/0.0/dc_student.pth", progress=True, map_location=
|
| 10 |
#model.load_state_dict(T.load("D:\\OneDrive\\Skole\\Universitet\\10. Semester\\Masteroppgave\\bruk_for_full_model.pth", map_location="cuda:0"))
|
| 11 |
-
model.load_state_dict(state_dict)
|
| 12 |
model.eval()
|
| 13 |
return model
|
| 14 |
|
|
|
|
| 5 |
import nibabel
|
| 6 |
|
| 7 |
def load_model():
|
| 8 |
+
if T.cuda.is_available():
|
| 9 |
+
gpu_device = T.device('cuda')
|
| 10 |
+
else:
|
| 11 |
+
gpu_device = T.device('cpu')
|
| 12 |
model = UNet_double(3, 1, 1, tuple([64, 128, 256, 512, 1024]), tuple([2 for i in range(4)]), num_res_units = 0)
|
| 13 |
+
state_dict = T.hub.load_state_dict_from_url("https://github.com/VemundFredriksen/LungTumorMask/releases/download/0.0/dc_student.pth", progress=True, map_location=gpu_device)
|
| 14 |
#model.load_state_dict(T.load("D:\\OneDrive\\Skole\\Universitet\\10. Semester\\Masteroppgave\\bruk_for_full_model.pth", map_location="cuda:0"))
|
| 15 |
+
model.load_state_dict(state_dict)
|
| 16 |
model.eval()
|
| 17 |
return model
|
| 18 |
|