Spaces: Running on Zero
Add way to call consolidate (#80)

* Add way to call consolidate
* black
* isort

---------

Co-authored-by: Srini Iyer <sviyer@meta.com>
1 changed file: bytelatent/checkpoint.py (+18 −0)
```diff
@@ -12,6 +12,7 @@ import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 import torch.optim.optimizer
+import typer
 from pydantic import BaseModel, ConfigDict
 from torch.distributed._tensor import DeviceMesh
 from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
@@ -323,3 +324,20 @@ class CheckpointManager:
         dist.barrier()
 
         return cls(args)
+
+
+def main(
+    command: str,
+    model_checkpoint_dir: str,
+):
+    if command == "consolidate":
+        print(
+            f"Consolidating {model_checkpoint_dir}. Output will be in the {CONSOLIDATE_FOLDER} folder."
+        )
+        consolidate_checkpoints(fsspec.filesystem("file"), model_checkpoint_dir)
+    else:
+        raise ValueError("Invalid command")
+
+
+if __name__ == "__main__":
+    typer.run(main)
```