Update single/load.py
Browse files- single/load.py +6 -5
single/load.py
CHANGED
|
@@ -57,10 +57,10 @@ class CombinedNet(nn.Module):
|
|
| 57 |
|
| 58 |
|
| 59 |
# MLSTAC API -----------------------------------------------------------------------
|
| 60 |
-
def example_data(path: pathlib.Path, *args, **kwargs):
|
| 61 |
data_f = path / "example_data.safetensor"
|
| 62 |
sample = safetensors.torch.load_file(data_f)
|
| 63 |
-
return sample["image"]
|
| 64 |
|
| 65 |
def trainable_model(path, device: str = "cpu", *args, **kwargs):
|
| 66 |
trainable_f = path / "model.safetensor"
|
|
@@ -92,6 +92,7 @@ def compiled_model(path, device: str = "cpu", *args, **kwargs):
|
|
| 92 |
|
| 93 |
return cloud_model
|
| 94 |
|
|
|
|
| 95 |
def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
|
| 96 |
# Load model
|
| 97 |
model = compiled_model(path, device, benchmark=True)
|
|
@@ -100,13 +101,13 @@ def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
|
|
| 100 |
probav = example_data(path)
|
| 101 |
|
| 102 |
# Run model
|
| 103 |
-
cloudprobs = model(probav
|
| 104 |
|
| 105 |
#Display results
|
| 106 |
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
|
| 107 |
-
ax[0].imshow(probav[[2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0))
|
| 108 |
ax[0].set_title("Input")
|
| 109 |
-
ax[1].imshow(cloudprobs
|
| 110 |
ax[1].set_title("Output")
|
| 111 |
for a in ax:
|
| 112 |
a.axis("off")
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
# MLSTAC API -----------------------------------------------------------------------
|
| 60 |
+
def example_data(path: pathlib.Path, device = "cpu", *args, **kwargs):
|
| 61 |
data_f = path / "example_data.safetensor"
|
| 62 |
sample = safetensors.torch.load_file(data_f)
|
| 63 |
+
return sample["image"].float().unsqueeze(0).to(device)
|
| 64 |
|
| 65 |
def trainable_model(path, device: str = "cpu", *args, **kwargs):
|
| 66 |
trainable_f = path / "model.safetensor"
|
|
|
|
| 92 |
|
| 93 |
return cloud_model
|
| 94 |
|
| 95 |
+
|
| 96 |
def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
|
| 97 |
# Load model
|
| 98 |
model = compiled_model(path, device, benchmark=True)
|
|
|
|
| 101 |
probav = example_data(path)
|
| 102 |
|
| 103 |
# Run model
|
| 104 |
+
cloudprobs = model(probav).squeeze().cpu()
|
| 105 |
|
| 106 |
#Display results
|
| 107 |
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
|
| 108 |
+
ax[0].imshow(probav[0, [2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0))
|
| 109 |
ax[0].set_title("Input")
|
| 110 |
+
ax[1].imshow(cloudprobs.cpu().detach().numpy(), cmap="gray")
|
| 111 |
ax[1].set_title("Output")
|
| 112 |
for a in ax:
|
| 113 |
a.axis("off")
|