Recover

modeling_cocom.py (+5 -42)
```diff
@@ -3,45 +3,11 @@ import torch
 import math
 from peft import get_peft_model, LoraConfig, TaskType
 import os
-from flash_attn.flash_attn_interface import flash_attn_func
-import torch.nn as nn
-import torch
 
 def freeze_model(model):
     for param in model.parameters():
         param.requires_grad = False
 
-class CustomFlashAttention(nn.Module):
-    def __init__(self, embed_dim, num_heads, dropout=0.0):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == embed_dim, "Embedding size must be divisible by the number of heads."
-
-        # Define projection layers
-        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
-        self.out_proj = nn.Linear(embed_dim, embed_dim)
-
-    def forward(self, hidden_states):
-        batch_size, seq_len, embed_dim = hidden_states.size()
-        qkv = self.qkv_proj(hidden_states)  # Project to Q, K, V
-        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
-        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch_size, num_heads, seq_len, head_dim)
-        query, key, value = qkv[0], qkv[1], qkv[2]
-
-        # FlashAttention expects contiguous inputs
-        query = query.contiguous()
-        key = key.contiguous()
-        value = value.contiguous()
-
-        # Apply FlashAttention
-        attn_output, _ = flash_attn_func(query, key, value, dropout_p=self.dropout, causal=False)
-
-        # Reshape and project back to the original dimension
-        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
-        return self.out_proj(attn_output)
 
 class BERT_Compressor(torch.nn.Module):
     def __init__(self, compr_model_name, compr_rate, compr_linear_type, decoder_hidden_size):
```
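The removed wrapper has two likely problems against the flash-attn 2.x interface, which may be why this "Recover" commit reverts it: `flash_attn_func` expects `q`/`k`/`v` shaped `(batch, seqlen, nheads, headdim)`, while the code permuted to `(batch, nheads, seqlen, headdim)`; and with default arguments it returns a single tensor, not a tuple, so `attn_output, _ = flash_attn_func(...)` would fail to unpack. A minimal sketch of a corrected module, assuming flash-attn 2.x is installed; `FlashSelfAttention` is an illustrative name, not part of this repo:

```python
import torch
import torch.nn as nn
from flash_attn.flash_attn_interface import flash_attn_func

class FlashSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = dropout
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states):
        # flash-attn kernels require fp16/bf16 inputs on CUDA.
        batch_size, seq_len, embed_dim = hidden_states.size()
        qkv = self.qkv_proj(hidden_states)
        # Keep the (batch, seqlen, nheads, headdim) layout flash-attn expects;
        # no transpose to (batch, nheads, seqlen, headdim) as in eager attention.
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        query, key, value = qkv.unbind(dim=2)
        # Returns a single tensor of shape (batch, seqlen, nheads, headdim).
        attn_output = flash_attn_func(
            query, key, value,
            dropout_p=self.dropout if self.training else 0.0,
            causal=False,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, embed_dim)
        return self.out_proj(attn_output)
```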
```diff
@@ -109,7 +75,7 @@ class COCOMConfig(PretrainedConfig):
         device_map = "cuda",
         **kwargs):
         super().__init__(**kwargs)
-
+
         self.decoder_model_name = decoder_model_name # model name of decoder
         self.quantization = quantization # quantization, could be no, int4, int8
         self.generation_top_k = generation_top_k # top k for each query, for pretraining, set to 1
```
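For context, the `quantization` field above accepts `no`, `int4`, or `int8`. A hedged sketch of how such a string is commonly mapped onto a `transformers` `BitsAndBytesConfig` at load time; this mapping is illustrative, not COCOM's actual loading code:

```python
from transformers import BitsAndBytesConfig

def quantization_config(quantization: str):
    # Illustrative mapping for the config's "no" / "int4" / "int8" values.
    if quantization == "int8":
        return BitsAndBytesConfig(load_in_8bit=True)
    if quantization == "int4":
        return BitsAndBytesConfig(load_in_4bit=True)
    return None  # "no": load the decoder in full precision
```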
```diff
@@ -226,12 +192,6 @@ class COCOM(PreTrainedModel):
         self.sep = cfg.sep
         self.compr_rate = cfg.compr_rate
         self.local_rank = os.getenv('LOCAL_RANK', '0')
-        for layer in self.decoder.encoder.layer:
-            layer.attention.self = CustomFlashAttention(
-                embed_dim=cfg.hidden_size,
-                num_heads=cfg.num_attention_heads,
-                dropout=cfg.attention_probs_dropout_prob,
-            )
 
     def compress_and_replace_emb(self, enc_input_ids, enc_attention_mask, dec_input_ids):
         indices = range(0, enc_input_ids.size(0) + 1, self.generation_top_k)
```
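Beyond the interface issues noted above, swapping `layer.attention.self` in place is fragile: a BERT-style self-attention module is called with extra positional arguments such as `attention_mask` and is expected to return a tuple, whereas the removed `CustomFlashAttention.forward` took only `hidden_states` and returned a bare tensor, so the patched stack would likely fail on its first forward pass (note also that the loop walks `self.decoder.encoder.layer`, a BERT-style attribute path). If FlashAttention is still wanted, recent `transformers` releases support selecting it at load time instead of monkey-patching. A sketch under that assumption; the model name is a placeholder, not taken from this diff:

```python
import torch
from transformers import AutoModelForCausalLM

# Assumes transformers >= 4.36 with flash-attn 2 installed; the FlashAttention
# kernels require fp16/bf16 weights on CUDA.
decoder = AutoModelForCausalLM.from_pretrained(
    "placeholder/decoder-model",  # e.g. cfg.decoder_model_name
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```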
```diff
@@ -348,4 +308,7 @@ class COCOM(PreTrainedModel):
         'dec_attention_mask': inp_dec['attention_mask'].to(self.decoder.device)
         }
 
-        return self.generate(model_input, max_new_tokens)
+        return self.generate(model_input, max_new_tokens)
+
+
+
```
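Net effect of the commit: the flash-attn import, the `CustomFlashAttention` class, and the layer-patching loop are gone, and the remaining hunks are whitespace-only, so the module no longer needs flash-attn at import time. A hypothetical smoke test for the revert; the module name follows the file name in this diff:

```python
import importlib

# Should import cleanly on a machine without flash-attn installed.
mod = importlib.import_module("modeling_cocom")
assert not hasattr(mod, "CustomFlashAttention")
```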