Spaces:
Runtime error
Runtime error
Hugo Flores Garcia
commited on
Commit
·
405226b
1
Parent(s):
88c78e1
use torch.compile for training
Browse files- scripts/exp/train.py +7 -5
scripts/exp/train.py
CHANGED
|
@@ -485,7 +485,6 @@ def load(
|
|
| 485 |
save_path: str,
|
| 486 |
resume: bool = False,
|
| 487 |
tag: str = "latest",
|
| 488 |
-
load_weights: bool = False,
|
| 489 |
fine_tune_checkpoint: Optional[str] = None,
|
| 490 |
grad_clip_val: float = 5.0,
|
| 491 |
) -> State:
|
|
@@ -498,7 +497,7 @@ def load(
|
|
| 498 |
kwargs = {
|
| 499 |
"folder": f"{save_path}/{tag}",
|
| 500 |
"map_location": "cpu",
|
| 501 |
-
"package":
|
| 502 |
}
|
| 503 |
tracker.print(f"Loading checkpoint from {kwargs['folder']}")
|
| 504 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
|
@@ -511,11 +510,14 @@ def load(
|
|
| 511 |
|
| 512 |
if args["fine_tune"]:
|
| 513 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
| 514 |
-
model =
|
| 515 |
-
|
|
|
|
|
|
|
|
|
|
| 516 |
|
| 517 |
-
model = VampNet() if model is None else model
|
| 518 |
|
|
|
|
| 519 |
model = accel.prepare_model(model)
|
| 520 |
|
| 521 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
|
|
|
| 485 |
save_path: str,
|
| 486 |
resume: bool = False,
|
| 487 |
tag: str = "latest",
|
|
|
|
| 488 |
fine_tune_checkpoint: Optional[str] = None,
|
| 489 |
grad_clip_val: float = 5.0,
|
| 490 |
) -> State:
|
|
|
|
| 497 |
kwargs = {
|
| 498 |
"folder": f"{save_path}/{tag}",
|
| 499 |
"map_location": "cpu",
|
| 500 |
+
"package": False,
|
| 501 |
}
|
| 502 |
tracker.print(f"Loading checkpoint from {kwargs['folder']}")
|
| 503 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
|
|
|
| 510 |
|
| 511 |
if args["fine_tune"]:
|
| 512 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
| 513 |
+
model = torch.compile(
|
| 514 |
+
VampNet.load(location=Path(fine_tune_checkpoint),
|
| 515 |
+
map_location="cpu",
|
| 516 |
+
)
|
| 517 |
+
)
|
| 518 |
|
|
|
|
| 519 |
|
| 520 |
+
model = torch.compile(VampNet()) if model is None else model
|
| 521 |
model = accel.prepare_model(model)
|
| 522 |
|
| 523 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|