def calculate_embedding_flops(seqlen, hidden_size):
    """Return the FLOP estimate attributed to the embedding for one sequence.

    Args:
        seqlen: Number of tokens in the sequence.
        hidden_size: Width of the hidden representation.

    Returns:
        ``2 * seqlen * hidden_size``.
    """
    doubled_tokens = 2 * seqlen
    return doubled_tokens * hidden_size
|
|
|
|
def calculate_lm_head_flops(seqlen, hidden_size, vocab_size):
    """Return the FLOP estimate for the LM head on one sequence.

    Args:
        seqlen: Number of tokens in the sequence.
        hidden_size: Width of the hidden representation.
        vocab_size: Number of vocabulary entries.

    Returns:
        ``2 * seqlen * hidden_size * vocab_size``.
    """
    flops = 2 * seqlen
    flops = flops * hidden_size
    return flops * vocab_size
|
|
|
|
| def calculate_qkv_projection_flops(args, seqlen, hidden_size, num_attention_heads, num_query_groups): |
| if args.q_lora_rank is None: |
| q_flops = 2 * seqlen * hidden_size * num_attention_heads * args.kv_channels |
| else: |
| q_flops = ( |
| 2 |
| * seqlen |
| * args.q_lora_rank |
| * (args.hidden_size + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) |
| ) |
| if args.kv_lora_rank is None: |
| kv_flops = 2 * 2 * seqlen * hidden_size * num_query_groups * args.kv_channels |
| else: |
| kv_flops = ( |
| 2 |
| * seqlen |
| * ( |
| args.kv_lora_rank |
| * (args.hidden_size + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim)) |
| + args.hidden_size * args.qk_pos_emb_head_dim |
| ) |
| ) |
|
|
| return q_flops + kv_flops |
|
|
|
|
def calculate_attention_flops(args, seqlen, num_attention_heads):
    """Return the FLOP estimate for the attention core of one layer.

    Args:
        args: Configuration object. ``qk_pos_emb_head_dim`` and ``v_head_dim``
            are used when truthy; otherwise ``kv_channels`` supplies the
            per-head dimension for that term.
        seqlen: Number of tokens in the sequence.
        num_attention_heads: Number of attention heads.

    Returns:
        Score-term estimate plus value-term estimate. The result is a float
        because the score term uses true division.
    """
    # Per-head dimension for the score term.
    if args.qk_pos_emb_head_dim:
        qk_dim = args.qk_head_dim + args.qk_pos_emb_head_dim
    else:
        qk_dim = args.kv_channels
    # NOTE(review): the "/ 2" presumably accounts for causal masking - confirm.
    score_flops = 2 * num_attention_heads * seqlen * seqlen * qk_dim / 2

    # Per-head dimension for the value term.
    value_dim = args.v_head_dim if args.v_head_dim else args.kv_channels
    context_flops = num_attention_heads * seqlen * seqlen * value_dim

    return score_flops + context_flops
|
|
|
|
def calculate_output_flops(seqlen, hidden_size):
    """Return the FLOP estimate for the attention output projection.

    Args:
        seqlen: Number of tokens in the sequence.
        hidden_size: Width of the hidden representation; used as both the
            input and the output dimension of the projection.

    Returns:
        ``2 * seqlen * hidden_size * hidden_size``.
    """
    per_column = 2 * seqlen * hidden_size
    return per_column * hidden_size
|
|
|
|
def calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size):
    """Return the FLOP estimate for the MLP of one layer.

    Args:
        seqlen: Number of tokens in the sequence.
        hidden_size: Width of the hidden representation.
        ffn_hidden_size: Width of the MLP intermediate representation.

    Returns:
        ``2 * seqlen * hidden_size * ffn_hidden_size * 3``.
    """
    # NOTE(review): the factor 3 presumably counts three
    # hidden_size x ffn_hidden_size matrices (gated MLP) - confirm.
    single_matmul = 2 * seqlen * hidden_size * ffn_hidden_size
    return single_matmul * 3
|
|
|
|
def calculate_layer_flops(args, seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size):
    """Return the FLOP estimate for one transformer layer on one sequence.

    Adds, in order, the Q/K/V projection, attention core, output projection
    and MLP estimates produced by the corresponding ``calculate_*`` helpers.

    Args:
        args: Configuration object forwarded to the projection and attention
            helpers.
        seqlen: Number of tokens in the sequence.
        hidden_size: Width of the hidden representation.
        num_attention_heads: Number of attention heads.
        num_query_groups: Number of key/value groups.
        ffn_hidden_size: MLP intermediate width to charge for this layer.

    Returns:
        Sum of the four component estimates.
    """
    qkv_part = calculate_qkv_projection_flops(args, seqlen, hidden_size, num_attention_heads, num_query_groups)
    attention_part = calculate_attention_flops(args, seqlen, num_attention_heads)
    output_part = calculate_output_flops(seqlen, hidden_size)
    mlp_part = calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size)
    return qkv_part + attention_part + output_part + mlp_part
|
|
|
|
def calculate_fwd_flops(
    seqlens,
    args,
):
    """Return the forward-pass FLOP estimate summed over ``seqlens``.

    For every sequence length this adds the per-layer estimate of the dense
    layers, the per-layer estimate of the MoE layers, and the LM-head
    estimate. ``calculate_embedding_flops`` is not called here.

    Layer split:
        * ``args.num_experts is None``: every layer is dense.
        * ``args`` has no ``moe_layer_freq``: every layer is MoE.
        * ``moe_layer_freq`` is a list: entries equal to 0 count as dense,
          entries greater than 0 count as MoE.
        * otherwise: layer ``i`` is MoE when ``i % moe_layer_freq == 0``.

    An MoE layer is charged an FFN width of
    ``moe_ffn_hidden_size * moe_router_topk`` plus
    ``moe_shared_expert_intermediate_size`` when that attribute is present
    and not ``None``.

    Args:
        seqlens: Iterable of sequence lengths.
        args: Configuration object holding the model dimensions.

    Returns:
        Total estimate over all sequences (``0`` when ``seqlens`` is empty).
    """
    hidden = args.hidden_size
    heads = args.num_attention_heads
    groups = args.num_query_groups
    vocab = args.vocab_size
    dense_width = args.ffn_hidden_size

    # Stays None for models without experts; never used then (moe_count is 0).
    moe_width = None
    if args.num_experts is None:
        dense_count, moe_count = args.num_layers, 0
    else:
        shared_width = getattr(args, "moe_shared_expert_intermediate_size", None)
        moe_width = args.moe_ffn_hidden_size * args.moe_router_topk + (0 if shared_width is None else shared_width)

        if not hasattr(args, "moe_layer_freq"):
            dense_count, moe_count = 0, args.num_layers
        elif isinstance(args.moe_layer_freq, list):
            pattern = args.moe_layer_freq
            dense_count = len([flag for flag in pattern if flag == 0])
            moe_count = len([flag for flag in pattern if flag > 0])
        else:
            period = args.moe_layer_freq
            layer_ids = range(args.num_layers)
            dense_count = len([idx for idx in layer_ids if idx % period != 0])
            moe_count = len([idx for idx in layer_ids if idx % period == 0])

    # Dense first, then MoE: the accumulation order is kept fixed because the
    # per-layer estimates are floats.
    layer_mix = ((dense_width, dense_count), (moe_width, moe_count))

    running_total = 0
    for seqlen in seqlens:
        for ffn_width, layer_count in layer_mix:
            if layer_count > 0:
                running_total += (
                    calculate_layer_flops(args, seqlen, hidden, heads, groups, ffn_width) * layer_count
                )
        running_total += calculate_lm_head_flops(seqlen, hidden, vocab)

    return running_total
|
|