rishabhsabnavis committed on
Commit
7c3484a
·
verified ·
1 Parent(s): 589b573

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +77 -0
train.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ASLP-LAB
2
+ # 2025 Ziqian Ning (ningziqian@mail.nwpu.edu.cn)
3
+ # 2025 Huakang Chen (huakang@mail.nwpu.edu.cn)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from importlib.resources import files
18
+
19
+ from model import CFM, DiT, Trainer
20
+
21
+ from prefigure.prefigure import get_all_args
22
+ import json
23
+ import os
24
+
25
# Set the OpenMP and MKL thread-count environment variables to "1".
# NOTE(review): this runs after `model` has already been imported; confirm
# the numeric backends read these variables lazily rather than at import time.
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})
27
+
28
def main():
    """Build the CFM model described by the model config and start training.

    Reads the run arguments with ``get_all_args("config/default.ini")``,
    loads the JSON file at ``args.model_config``, wraps the selected
    transformer class in ``CFM``, prints the total parameter count, and
    passes the model to ``Trainer`` before calling ``Trainer.train``.

    Raises:
        ValueError: If ``model_config["model_type"]`` is not ``"diffrhythm"``.
            Previously an unknown type fell through and failed later with an
            unbound-local error on ``model_cls``.
    """
    args = get_all_args("config/default.ini")

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(args.model_config, encoding="utf-8") as f:
        model_config = json.load(f)

    model_type = model_config["model_type"]
    if model_type != "diffrhythm":
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'diffrhythm'"
        )

    # No W&B run id is supplied here.
    # NOTE(review): presumably this makes Trainer start a fresh run - confirm.
    wandb_resume_id = None
    model_cls = DiT

    model = CFM(
        transformer=model_cls(**model_config["model"], max_frames=args.max_frames),
        num_channels=model_config["model"]['mel_dim'],
        audio_drop_prob=args.audio_drop_prob,
        cond_drop_prob=args.cond_drop_prob,
        style_drop_prob=args.style_drop_prob,
        lrc_drop_prob=args.lrc_drop_prob,
        max_frames=args.max_frames
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params}")

    trainer = Trainer(
        model,
        args,
        args.epochs,
        args.learning_rate,
        num_warmup_updates=args.num_warmup_updates,
        save_per_updates=args.save_per_updates,
        checkpoint_path=f"ckpts/{args.exp_name}",
        grad_accumulation_steps=args.grad_accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        wandb_project="diffrhythm-test",
        wandb_run_name=args.exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=args.last_per_steps,
        bnb_optimizer=False,
        reset_lr=args.reset_lr,
        batch_size=args.batch_size,
        grad_ckpt=args.grad_ckpt
    )

    trainer.train(
        resumable_with_seed=args.resumable_with_seed,  # seed for shuffling dataset
    )


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()