csaybar commited on
Commit
102bc5b
·
verified ·
1 Parent(s): 054d3ca

Update single/load.py

Browse files
Files changed (1) hide show
  1. single/load.py +6 -5
single/load.py CHANGED
@@ -57,10 +57,10 @@ class CombinedNet(nn.Module):
57
 
58
 
59
  # MLSTAC API -----------------------------------------------------------------------
60
- def example_data(path: pathlib.Path, *args, **kwargs):
61
  data_f = path / "example_data.safetensor"
62
  sample = safetensors.torch.load_file(data_f)
63
- return sample["image"]
64
 
65
  def trainable_model(path, device: str = "cpu", *args, **kwargs):
66
  trainable_f = path / "model.safetensor"
@@ -92,6 +92,7 @@ def compiled_model(path, device: str = "cpu", *args, **kwargs):
92
 
93
  return cloud_model
94
 
 
95
  def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
96
  # Load model
97
  model = compiled_model(path, device, benchmark=True)
@@ -100,13 +101,13 @@ def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
100
  probav = example_data(path)
101
 
102
  # Run model
103
- cloudprobs = model(probav.float().unsqueeze(0).to(device)).squeeze(0).cpu()
104
 
105
  #Display results
106
  fig, ax = plt.subplots(1, 2, figsize=(8, 4))
107
- ax[0].imshow(probav[[2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0))
108
  ax[0].set_title("Input")
109
- ax[1].imshow(cloudprobs[0].cpu().detach().numpy(), cmap="gray")
110
  ax[1].set_title("Output")
111
  for a in ax:
112
  a.axis("off")
 
57
 
58
 
59
  # MLSTAC API -----------------------------------------------------------------------
60
+ def example_data(path: pathlib.Path, device = "cpu", *args, **kwargs):
61
  data_f = path / "example_data.safetensor"
62
  sample = safetensors.torch.load_file(data_f)
63
+ return sample["image"].float().unsqueeze(0).to(device)
64
 
65
  def trainable_model(path, device: str = "cpu", *args, **kwargs):
66
  trainable_f = path / "model.safetensor"
 
92
 
93
  return cloud_model
94
 
95
+
96
  def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
97
  # Load model
98
  model = compiled_model(path, device, benchmark=True)
 
101
  probav = example_data(path)
102
 
103
  # Run model
104
+ cloudprobs = model(probav).squeeze().cpu()
105
 
106
  #Display results
107
  fig, ax = plt.subplots(1, 2, figsize=(8, 4))
108
+ ax[0].imshow(probav[0, [2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0))
109
  ax[0].set_title("Input")
110
+ ax[1].imshow(cloudprobs.cpu().detach().numpy(), cmap="gray")
111
  ax[1].set_title("Output")
112
  for a in ax:
113
  a.axis("off")