Spaces:
Build error
Build error
fadindashfr
commited on
Commit
·
17a1b09
1
Parent(s):
4c5329d
add device = 'cpu'
Browse files
app.py
CHANGED
|
@@ -1,9 +1,19 @@
|
|
| 1 |
import torch
|
| 2 |
from monai.bundle import ConfigParser
|
| 3 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
parser = ConfigParser() # load configuration files that specify various parameters for running the MONAI workflow.
|
| 6 |
-
parser.read_config(f=
|
| 7 |
parser.read_meta(f="configs/metadata.json") # read the metadata from specified JSON file
|
| 8 |
|
| 9 |
inference = parser.get_parsed_content("inferer")
|
|
@@ -11,6 +21,9 @@ network = parser.get_parsed_content("network_def")
|
|
| 11 |
preprocess = parser.get_parsed_content("preprocessing")
|
| 12 |
state_dict = torch.load("models/model.pt")
|
| 13 |
network.load_state_dict(state_dict, strict=True) # Loads a model’s parameter dictionary
|
|
|
|
|
|
|
|
|
|
| 14 |
class_names = {
|
| 15 |
0: "Other",
|
| 16 |
1: "Inflammatory",
|
|
@@ -21,6 +34,7 @@ class_names = {
|
|
| 21 |
def classify_image(image_file, label_file):
|
| 22 |
data = {"image":image_file, "label":label_file}
|
| 23 |
batch = preprocess(data)
|
|
|
|
| 24 |
network.eval()
|
| 25 |
with torch.no_grad():
|
| 26 |
pred = inference(batch['image'].unsqueeze(dim=0), network) # expect 4 channels input (3 RGB, 1 Label mask)
|
|
|
|
| 1 |
import torch
|
| 2 |
from monai.bundle import ConfigParser
|
| 3 |
import gradio as gr
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
with open("configs/inference.json") as f:
|
| 7 |
+
inference_config = json.load(f)
|
| 8 |
+
|
| 9 |
+
device = torch.device('cpu')
|
| 10 |
+
if torch.cuda.is_available():
|
| 11 |
+
device = torch.device('cuda:0')
|
| 12 |
+
|
| 13 |
+
inference_config["device"] = device
|
| 14 |
|
| 15 |
parser = ConfigParser() # load configuration files that specify various parameters for running the MONAI workflow.
|
| 16 |
+
parser.read_config(f=inference_config) # read the config from specified JSON file
|
| 17 |
parser.read_meta(f="configs/metadata.json") # read the metadata from specified JSON file
|
| 18 |
|
| 19 |
inference = parser.get_parsed_content("inferer")
|
|
|
|
| 21 |
preprocess = parser.get_parsed_content("preprocessing")
|
| 22 |
state_dict = torch.load("models/model.pt")
|
| 23 |
network.load_state_dict(state_dict, strict=True) # Loads a model’s parameter dictionary
|
| 24 |
+
network = network.to(device)
|
| 25 |
+
network.eval()
|
| 26 |
+
|
| 27 |
class_names = {
|
| 28 |
0: "Other",
|
| 29 |
1: "Inflammatory",
|
|
|
|
| 34 |
def classify_image(image_file, label_file):
|
| 35 |
data = {"image":image_file, "label":label_file}
|
| 36 |
batch = preprocess(data)
|
| 37 |
+
batch['image'] = batch['image'].to(device)
|
| 38 |
network.eval()
|
| 39 |
with torch.no_grad():
|
| 40 |
pred = inference(batch['image'].unsqueeze(dim=0), network) # expect 4 channels input (3 RGB, 1 Label mask)
|