Kyryll Kochkin commited on
Commit
d261f14
·
unverified ·
1 Parent(s): 15fccb8

Update torch.load to include weights_only parameter

Browse files
Files changed (1) hide show
  1. train_conditional.py +3 -3
train_conditional.py CHANGED
@@ -34,7 +34,7 @@ class PixelTransformerConfig:
34
  config_path = os.path.join(path, "config.pt")
35
  if not os.path.exists(config_path):
36
  raise ValueError(f"No config found at {config_path}")
37
- config_dict = torch.load(config_path)
38
  return cls(**config_dict)
39
 
40
  def save_pretrained(self, path: str):
@@ -170,7 +170,7 @@ class PixelTransformer(nn.Module):
170
 
171
  # Create model and load state dict
172
  model = cls(config)
173
- state_dict = torch.load(os.path.join(path, "model.pt"), map_location='cpu')
174
  model.load_state_dict(state_dict)
175
 
176
  # Move model to device after loading
@@ -261,4 +261,4 @@ if __name__ == "__main__":
261
  warmup_steps=500,
262
  )
263
  model = train_pixel_transformer(config)
264
- model.save_pretrained("my_model")
 
34
  config_path = os.path.join(path, "config.pt")
35
  if not os.path.exists(config_path):
36
  raise ValueError(f"No config found at {config_path}")
37
+ config_dict = torch.load(config_path, weights_only=False)
38
  return cls(**config_dict)
39
 
40
  def save_pretrained(self, path: str):
 
170
 
171
  # Create model and load state dict
172
  model = cls(config)
173
+ state_dict = torch.load(os.path.join(path, "model.pt"), map_location='cpu', weights_only=False)
174
  model.load_state_dict(state_dict)
175
 
176
  # Move model to device after loading
 
261
  warmup_steps=500,
262
  )
263
  model = train_pixel_transformer(config)
264
+ model.save_pretrained("my_model")