Spaces:
Sleeping
Sleeping
Kyryll Kochkin
commited on
Update torch.load to include weights_only parameter
Browse files- train_conditional.py +3 -3
train_conditional.py
CHANGED
|
@@ -34,7 +34,7 @@ class PixelTransformerConfig:
|
|
| 34 |
config_path = os.path.join(path, "config.pt")
|
| 35 |
if not os.path.exists(config_path):
|
| 36 |
raise ValueError(f"No config found at {config_path}")
|
| 37 |
-
config_dict = torch.load(config_path)
|
| 38 |
return cls(**config_dict)
|
| 39 |
|
| 40 |
def save_pretrained(self, path: str):
|
|
@@ -170,7 +170,7 @@ class PixelTransformer(nn.Module):
|
|
| 170 |
|
| 171 |
# Create model and load state dict
|
| 172 |
model = cls(config)
|
| 173 |
-
state_dict = torch.load(os.path.join(path, "model.pt"), map_location='cpu')
|
| 174 |
model.load_state_dict(state_dict)
|
| 175 |
|
| 176 |
# Move model to device after loading
|
|
@@ -261,4 +261,4 @@ if __name__ == "__main__":
|
|
| 261 |
warmup_steps=500,
|
| 262 |
)
|
| 263 |
model = train_pixel_transformer(config)
|
| 264 |
-
model.save_pretrained("my_model")
|
|
|
|
| 34 |
config_path = os.path.join(path, "config.pt")
|
| 35 |
if not os.path.exists(config_path):
|
| 36 |
raise ValueError(f"No config found at {config_path}")
|
| 37 |
+
config_dict = torch.load(config_path, weights_only=False)
|
| 38 |
return cls(**config_dict)
|
| 39 |
|
| 40 |
def save_pretrained(self, path: str):
|
|
|
|
| 170 |
|
| 171 |
# Create model and load state dict
|
| 172 |
model = cls(config)
|
| 173 |
+
state_dict = torch.load(os.path.join(path, "model.pt"), map_location='cpu', weights_only=False)
|
| 174 |
model.load_state_dict(state_dict)
|
| 175 |
|
| 176 |
# Move model to device after loading
|
|
|
|
| 261 |
warmup_steps=500,
|
| 262 |
)
|
| 263 |
model = train_pixel_transformer(config)
|
| 264 |
+
model.save_pretrained("my_model")
|