ashishkblink commited on
Commit
8999da3
·
verified ·
1 Parent(s): df4c01e

Upload f5_tts/scripts/count_params_gflops.py with huggingface_hub

Browse files
f5_tts/scripts/count_params_gflops.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(os.getcwd())
5
+
6
+ from f5_tts.model import CFM, DiT
7
+
8
+ import torch
9
+ import thop
10
+
11
+
12
+ """ ~155M """
13
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
14
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
15
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
16
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
17
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
18
+ # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
19
+
20
+ """ ~335M """
21
+ # FLOPs: 622.1 G, Params: 333.2 M
22
+ # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
23
+ # FLOPs: 363.4 G, Params: 335.8 M
24
+ transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
25
+
26
+
27
+ model = CFM(transformer=transformer)
28
+ target_sample_rate = 24000
29
+ n_mel_channels = 100
30
+ hop_length = 256
31
+ duration = 20
32
+ frame_length = int(duration * target_sample_rate / hop_length)
33
+ text_length = 150
34
+
35
+ flops, params = thop.profile(
36
+ model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long))
37
+ )
38
+ print(f"FLOPs: {flops / 1e9} G")
39
+ print(f"Params: {params / 1e6} M")