sunweiwei committed on
Commit
97ac334
·
verified ·
1 Parent(s): 551ddcb

Create patch_seed_oss.py

Browse files
Files changed (1) hide show
  1. patch_seed_oss.py +27 -0
patch_seed_oss.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Monkey-patch for vLLM's Seed-OSS model.
#
# After the stock constructor has run, every decoder layer's attention output
# projection (``self_attn.o_proj``) is rebuilt as a ``RowParallelLinear`` whose
# ``bias`` flag follows ``config.attention_out_bias`` (default ``False``).
# Importing this module installs the patch as a side effect.
import functools

import vllm.model_executor.models.seed_oss as seed_oss_module
from vllm.model_executor.layers.linear import RowParallelLinear

# Attribute set on the wrapper so that a second import/reload can recognise it.
_PATCH_MARKER = "_seed_oss_o_proj_patched"

# The patch was written against ``SeedOSSForCausalLM``.
# NOTE(review): upstream vLLM appears to spell the class ``SeedOssForCausalLM``;
# fall back to that spelling instead of failing at import time. Confirm against
# the installed vLLM build.
_model_cls = getattr(seed_oss_module, "SeedOSSForCausalLM", None)
if _model_cls is None:
    _model_cls = seed_oss_module.SeedOssForCausalLM

_current_init = _model_cls.__init__
if getattr(_current_init, _PATCH_MARKER, False):
    # Already patched (module imported twice / reloaded): wrap the genuine
    # constructor again rather than stacking a wrapper on top of a wrapper.
    _original_init = _current_init.__wrapped__
else:
    _original_init = _current_init


def _resolve_quant_config(model, init_kwargs):
    """Return the quantization config the stock constructor was given.

    Lookup order:
      1. an explicit ``quant_config`` keyword (the only source the original
         patch consulted);
      2. ``vllm_config.quant_config`` when a ``vllm_config`` keyword is present
         (NOTE(review): presumably how current vLLM constructors receive it --
         confirm against the installed version);
      3. a ``quant_config`` attribute stored on the model by its constructor.

    Returns ``None`` when none of these is available.
    """
    if "quant_config" in init_kwargs:
        return init_kwargs["quant_config"]
    vllm_config = init_kwargs.get("vllm_config")
    if vllm_config is not None:
        return getattr(vllm_config, "quant_config", None)
    return getattr(model, "quant_config", None)


# ``functools.wraps`` sets ``__wrapped__`` so ``inspect.signature`` reports the
# stock constructor's parameters instead of ``(*args, **kwargs)``; callers that
# introspect the constructor keep seeing the real signature.
@functools.wraps(_original_init)
def _patched_init(self, *args, **kwargs):
    """Run the stock constructor, then rebuild every layer's ``o_proj``.

    Args:
        *args: Positional arguments forwarded unchanged to the stock constructor.
        **kwargs: Keyword arguments forwarded unchanged; also inspected for the
            quantization config (see ``_resolve_quant_config``).
    """
    _original_init(self, *args, **kwargs)

    # Get configuration parameters
    config = self.config
    hidden_size = config.hidden_size
    total_num_heads = config.num_attention_heads
    # ``head_dim`` may be absent *or* present-but-None; derive it in both cases.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = hidden_size // total_num_heads
    attention_out_bias = getattr(config, "attention_out_bias", False)
    quant_config = _resolve_quant_config(self, kwargs)

    for layer_idx, layer in enumerate(self.model.layers):
        attn = getattr(layer, "self_attn", None)
        if attn is None:
            # Not a real decoder layer (presumably a pipeline-parallel
            # placeholder for a layer hosted on another rank) -- nothing to do.
            continue

        # Reuse the prefix the stock layer was built with when it exposes one,
        # so nested/prefixed models keep consistent names; otherwise use the
        # name the original patch hard-coded.
        stock_prefix = getattr(getattr(attn, "o_proj", None), "prefix", None)
        if not (isinstance(stock_prefix, str) and stock_prefix):
            stock_prefix = f"model.layers.{layer_idx}.self_attn.o_proj"

        attn.o_proj = RowParallelLinear(
            total_num_heads * head_dim,
            hidden_size,
            bias=attention_out_bias,  # Read from config
            quant_config=quant_config,
            prefix=stock_prefix,
        )


setattr(_patched_init, _PATCH_MARKER, True)
_model_cls.__init__ = _patched_init