File size: 1,021 Bytes
97ac334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import functools

import vllm.model_executor.models.seed_oss as seed_oss_module
from vllm.model_executor.layers.linear import RowParallelLinear

# NOTE(review): upstream vLLM appears to spell this class ``SeedOssForCausalLM``.
# The ``SeedOSSForCausalLM`` spelling used originally is tried first for
# backward compatibility; confirm against the installed vLLM version.
_SeedOssCausalLM = getattr(seed_oss_module, "SeedOSSForCausalLM", None) or getattr(
    seed_oss_module, "SeedOssForCausalLM"
)

_original_init = _SeedOssCausalLM.__init__


@functools.wraps(_original_init)
def _patched_init(self, *args, **kwargs):
    """Build the stock SeedOSS model, then rebuild each ``o_proj`` with a config-driven bias.

    The original constructor runs first. Afterwards, every decoder layer that
    has a ``self_attn`` module gets its ``o_proj`` replaced by a
    ``RowParallelLinear`` whose ``bias`` flag is taken from
    ``config.attention_out_bias`` (``False`` when the attribute is absent).

    ``functools.wraps`` sets ``__wrapped__`` so ``inspect.signature`` reports
    the original constructor's parameters rather than ``(*args, **kwargs)``.
    NOTE(review): vLLM's model loader is understood to inspect ``__init__``
    for ``vllm_config``/``prefix`` and fall back to legacy argument guessing
    when they are missing — keep the wrapper transparent to that check.

    Args:
        *args: Forwarded unchanged to the original ``__init__``.
        **kwargs: Forwarded unchanged to the original ``__init__``. ``vllm_config``,
            ``quant_config`` and ``prefix`` are additionally consulted here.
    """
    _original_init(self, *args, **kwargs)

    config = self.config
    hidden_size = config.hidden_size
    total_num_heads = config.num_attention_heads
    # ``head_dim`` may be missing or explicitly ``None`` in the config.
    head_dim = getattr(config, "head_dim", None) or hidden_size // total_num_heads
    attention_out_bias = getattr(config, "attention_out_bias", False)

    # New-style constructors receive the quantization config inside
    # ``vllm_config``, not as a ``quant_config`` kwarg; an explicit
    # ``quant_config`` kwarg still wins for backward compatibility.
    quant_config = kwargs.get("quant_config")
    vllm_config = kwargs.get("vllm_config")
    if quant_config is None and vllm_config is not None:
        quant_config = getattr(vllm_config, "quant_config", None)

    # Fallback weight-name root when the existing o_proj exposes no prefix:
    # "model" at top level, "<prefix>.model" when nested under a prefix.
    outer_prefix = kwargs.get("prefix", "")
    model_prefix = f"{outer_prefix}.model" if outer_prefix else "model"

    for layer_idx, layer in enumerate(self.model.layers):
        self_attn = getattr(layer, "self_attn", None)
        if self_attn is None:
            # Presumably a placeholder for a layer owned by another
            # pipeline-parallel rank — there is no attention module to patch.
            continue
        # Reuse the name the original o_proj was registered under so weight
        # loading and quantization lookups keep matching.
        o_proj_prefix = getattr(self_attn.o_proj, "prefix", None) or (
            f"{model_prefix}.layers.{layer_idx}.self_attn.o_proj"
        )
        self_attn.o_proj = RowParallelLinear(
            total_num_heads * head_dim,
            hidden_size,
            bias=attention_out_bias,
            quant_config=quant_config,
            prefix=o_proj_prefix,
        )


_SeedOssCausalLM.__init__ = _patched_init