# hf_tutorial / modeling_avey.py
# Hugging Face Hub page residue: last commit "Update modeling_avey.py"
# (a0be5c6, verified) by agadelmoula-avey.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput
)
from .configuration_avey import AveyConfig
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.checkpoint import checkpoint
import torch
# torch._dynamo.config.allow_unspec_int_on_nn_module = True
class Contextualizer(nn.Module):
    """Mix one half of the features along the sequence axis and gate the other.

    The last dimension of the input is split into two equal halves. The first
    half is mixed across positions either with a learned
    ``chunk_size x chunk_size`` matrix (``static=True``) or with a
    row-normalised cosine-similarity matrix computed from that half itself
    (``static=False``). The mixed half then multiplies the second half
    element-wise.
    """

    def __init__(self, config: AveyConfig, static: bool):
        super().__init__()
        self.eps = config.eps
        self.static = static
        if static:
            self.spatial_proj = nn.Parameter(
                torch.empty(config.chunk_size, config.chunk_size)
            )
            nn.init.xavier_normal_(self.spatial_proj)

    def cosim(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return pairwise cosine similarities between the rows of ``embeddings``.

        ``eps`` is added under the square root so all-zero rows do not divide by 0.
        """
        squared_norm = embeddings.pow(2).sum(dim=-1, keepdim=True)
        unit = embeddings / (squared_norm + self.eps).sqrt()
        return unit @ unit.transpose(-1, -2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpacking enforces a 3-D input; only the sequence length is needed.
        _, seq_len, _ = x.shape
        mixed, gate = x.chunk(2, dim=-1)
        if self.static:
            # Slice so sequences shorter than chunk_size are supported.
            weights = self.spatial_proj[:seq_len, :seq_len]
        else:
            sims = self.cosim(mixed)
            # NOTE(review): cosine similarities lie in [-1, 1], so a row sum can
            # be near zero or negative, which makes this normalisation unstable.
            weights = sims / (sims.sum(dim=-1, keepdim=True) + self.eps)
        return (weights @ mixed) * gate
class ContextualizerLayer(nn.Module):
    """Expand, contextualise part of the features, and project back to ``d_embed``.

    The input is projected to ``d_embed * expansion_factor`` features with a
    squared-ReLU activation. It is then split into a contextual part, which is
    fed to ``Contextualizer`` and comes out at half its width, and a
    pass-through part. Both are concatenated and fused back to ``d_embed``.
    """

    def __init__(self, config: AveyConfig, static: bool):
        super().__init__()
        expanded_dim = config.d_embed * config.expansion_factor
        ctx_dim = int(expanded_dim * config.context_proportion)
        rest_dim = int(expanded_dim * (1 - config.context_proportion))
        # Hand any rounding loss to the pass-through part so widths sum exactly.
        rest_dim += expanded_dim - (ctx_dim + rest_dim)
        # Contextualizer chunks its input in two, so its width must be even.
        if ctx_dim % 2:
            ctx_dim, rest_dim = ctx_dim + 1, rest_dim - 1
        self.split_factor = [ctx_dim, rest_dim]
        self.enricher = nn.Linear(config.d_embed, expanded_dim)
        self.contextualizer = Contextualizer(config, static)
        self.fuser = nn.Linear(int(ctx_dim / 2 + rest_dim), config.d_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        enriched = F.relu(self.enricher(x)).square()
        ctx_part, rest_part = enriched.split(self.split_factor, dim=-1)
        fused_in = torch.cat([self.contextualizer(ctx_part), rest_part], dim=-1)
        return self.fuser(fused_in)
class AveyLayer(nn.Module):
    """Pre-norm residual block: ``x + ContextualizerLayer(RMSNorm(x))``."""

    def __init__(self, config: AveyConfig, static: bool):
        super().__init__()
        self.rms_norm = nn.RMSNorm(config.d_embed, eps=config.eps)
        self.ctxt = ContextualizerLayer(config, static)

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.rms_norm(x)
        update = self.ctxt(normed)
        return x + update
class Ranker(nn.Module):
    """Chunk a sequence and prepend each chunk with its most similar predecessors.

    ``preprocess`` zero-pads the sequence to a multiple of ``chunk_size`` and
    splits it into ``N`` chunks. For each chunk it picks up to ``config.k``
    earlier chunks by a max-cosine-similarity score and scales them. It
    prepends them to the chunk, left-pads the result to
    ``(config.k + 1) * chunk_size`` tokens, and compresses it back to
    ``chunk_size`` tokens with the learned ``down_proj`` matrix, plus a
    residual to the chunk. ``contract`` undoes the chunking.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size = config.chunk_size
        # Window size in chunks: up to config.k retrieved chunks + the current one.
        self.k = config.k + 1
        self.extended_len = self.k * config.chunk_size
        self.eps = config.eps
        self.down_proj = nn.Parameter(torch.empty(self.chunk_size, self.extended_len))
        nn.init.xavier_normal_(self.down_proj)

    def preprocess(self, x):
        """Turn ``x`` of shape ``(B, T, E)`` into a ``(B * N, chunk_size, E)`` batch.

        Returns the chunk batch and a bookkeeping dict consumed by ``contract``.
        """
        batch, seq_len, dim = x.shape
        cs = self.chunk_size
        target_len = self.extended_len
        remainder = seq_len % cs
        padded = remainder != 0
        if padded:
            # Right-pad with zeros so the sequence splits into whole chunks.
            tail = torch.zeros(batch, cs - remainder, dim, device=x.device, dtype=x.dtype)
            x = torch.cat([x, tail], dim=1)
        n_chunks = x.size(1) // cs
        x_chunks = x.view(batch, n_chunks, cs, dim)

        windows = []
        for idx in range(n_chunks):
            window = self._extend(x_chunks[:, :idx], x_chunks[:, idx])
            shortfall = target_len - window.size(1)
            if shortfall > 0:
                # Left-pad so the current chunk always sits in the last cs slots.
                filler = torch.zeros(batch, shortfall, dim, device=x.device, dtype=x.dtype)
                window = torch.cat([filler, window], dim=1)
            else:
                window = window[:, -target_len:]
            windows.append(window)

        # (cs, L) @ (B, N, L, E) -> (B, N, cs, E), plus a residual to each chunk.
        mixed = self.down_proj @ torch.stack(windows, dim=1) + x_chunks
        state = {"B": batch, "N": n_chunks, "orig_T": seq_len, "padded": padded}
        return mixed.view(batch * n_chunks, cs, dim), state

    def contract(self, h, st):
        """Inverse of ``preprocess``: reassemble chunks into ``(B, T, E)`` and drop padding."""
        batch, n_chunks = st["B"], st["N"]
        dim = h.size(-1)
        seq = h.view(batch, n_chunks, self.chunk_size, dim).reshape(
            batch, n_chunks * self.chunk_size, dim
        )
        return seq[:, : st["orig_T"], :] if st["padded"] else seq

    def _extend(self, other_chunks, cur_chunk):
        """Prepend the best-matching earlier chunks, scaled, to ``cur_chunk``.

        A candidate's score is the sum, over tokens of ``cur_chunk``, of the
        maximum cosine similarity to any token of the candidate. Up to
        ``self.k - 1`` candidates are picked in descending score order. Each is
        scaled by ``score / (min picked score + eps)`` and concatenated in
        front of ``cur_chunk``.
        """
        batch, cs, dim = cur_chunk.shape
        n_prev = 0 if other_chunks is None else other_chunks.size(1)
        n_pick = min(n_prev, self.k - 1)
        if n_pick <= 0:
            return cur_chunk

        prev_unit = other_chunks / (other_chunks.norm(dim=-1, keepdim=True) + self.eps)
        cur_unit = cur_chunk / (cur_chunk.norm(dim=-1, keepdim=True) + self.eps)
        # (B, 1, cs, E) @ (B, n_prev, E, cs) -> (B, n_prev, cs, cs)
        sims = cur_unit.unsqueeze(1) @ prev_unit.transpose(-1, -2)
        scores = sims.max(dim=-1).values.sum(dim=-1)  # (B, n_prev)

        top_vals, top_idx = scores.topk(n_pick, dim=1)
        # NOTE(review): an all-zero chunk (e.g. one zeroed by an attention mask)
        # scores exactly 0, and scores can be negative. If the smallest picked
        # score is <= 0, these weights explode or flip sign.
        weights = top_vals / (top_vals.min(dim=-1, keepdim=True).values + self.eps)

        gather_idx = top_idx[:, :, None, None].expand(-1, -1, cs, dim)
        picked = other_chunks.gather(1, gather_idx)  # (B, n_pick, cs, E)
        flat = (picked * weights[:, :, None, None]).reshape(batch, n_pick * cs, dim)
        return torch.cat([flat, cur_chunk], dim=1)
class AveyPreTrainedModel(PreTrainedModel):
    """Shared base for Avey models: config class and weight initialisation."""

    config_class = AveyConfig

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` and ``nn.Embedding`` modules.

        Weights get Xavier-normal init. Linear biases and the embedding's
        ``padding_idx`` row are zeroed. Parameters created directly as
        ``nn.Parameter`` (e.g. in ``Contextualizer`` and ``Ranker``) and
        ``nn.RMSNorm`` modules are not touched here.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.xavier_normal_(module.weight)
            if module.padding_idx is not None:
                # Modify the leaf parameter in place without recording autograd history.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
class AveyModel(AveyPreTrainedModel):
    """Avey encoder backbone: embeddings -> Ranker chunking -> Avey layers -> un-chunking."""

    def __init__(self, config: AveyConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.d_embed)
        # Even-indexed layers use the static (learned) spatial projection,
        # odd-indexed layers the dynamic cosine-similarity mixing.
        self.layers = nn.ModuleList(
            [AveyLayer(config, i % 2 == 0) for i in range(config.n_layers)]
        )
        self.ranker = Ranker(config)
        self.post_init()

    def forward(self, input_ids: torch.Tensor, attention_mask=None, **kwargs):
        """Encode ``input_ids`` and return a ``BaseModelOutput``.

        Args:
            input_ids: token ids. The embedding output must be 3-D, so
                presumably ``(batch, seq_len)``.
            attention_mask: optional mask broadcast over the embedding dim. The
                embeddings of masked-out (0) positions are zeroed. The mask is
                not applied anywhere else.
            **kwargs: accepted for HF-API compatibility and ignored.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` has the same
            sequence length as ``input_ids``.
        """
        h = self.embeddings(input_ids)
        if attention_mask is not None:
            h = h * attention_mask.unsqueeze(-1)
        # Ranker.preprocess zero-pads to a multiple of chunk_size itself and
        # Ranker.contract trims back to the original length, so no separate
        # padding/trimming is needed here.
        h, state = self.ranker.preprocess(h)
        for layer in self.layers:
            h = layer(h)
        h = self.ranker.contract(h, state)
        return BaseModelOutput(last_hidden_state=h)
class AveyForMaskedLM(AveyPreTrainedModel):
    """Masked-LM head on ``AveyModel`` whose output projection is the embedding matrix."""

    def __init__(self, config: AveyConfig):
        super().__init__(config)
        self.config = config
        self.base_avey_model = AveyModel(config)
        self.ln_f = nn.RMSNorm(config.d_embed, eps=config.eps)
        self.post_init()

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Return ``MaskedLMOutput`` with vocabulary logits and, if ``labels`` given, the loss.

        ``labels`` uses -100 for positions to ignore. Extra ``kwargs`` (e.g.
        ``attention_mask``) are forwarded to the backbone.
        """
        h = self.base_avey_model(input_ids, **kwargs).last_hidden_state
        # Output projection shares the input embedding weights.
        logits = F.linear(self.ln_f(h), self.base_avey_model.embeddings.weight)
        loss = None
        if labels is not None:
            # reshape (not view) so non-contiguous labels, e.g. slices, also work.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return MaskedLMOutput(loss=loss, logits=logits)
class AveyForSequenceClassification(AveyPreTrainedModel):
    """Sequence classification/regression head on top of an Avey masked-LM backbone.

    Token states are mean-pooled over the sequence and normalised with the
    MLM's final RMSNorm. They then pass through an MLP and a linear classifier.
    """

    def __init__(self, config: AveyConfig, avey_model: AveyForMaskedLM = None):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        # Reuse a provided (e.g. pretrained) MLM; otherwise build a fresh one.
        if avey_model is None:
            self.avey_model = AveyForMaskedLM(config)
        else:
            self.avey_model = avey_model
        self.classifier = nn.Linear(config.d_embed, config.num_labels)
        self.dense = nn.Sequential(
            nn.Linear(self.config.d_embed, self.config.d_embed * 2),
            nn.GELU(),
            nn.Linear(self.config.d_embed * 2, self.config.d_embed * 2),
            nn.GELU(),
            nn.Linear(self.config.d_embed * 2, self.config.d_embed),
        )
        # NOTE(review): post_init() visits every submodule, including a passed-in
        # avey_model. Whether its loaded weights survive depends on the
        # transformers version honouring `_is_hf_initialized`; verify.
        self.post_init()

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Return ``SequenceClassifierOutput``; the loss type follows ``config.problem_type``.

        If ``problem_type`` is unset, it is inferred (and stored on the config)
        from ``num_labels`` and the label dtype, as in transformers' own heads.
        """
        h = self.avey_model.base_avey_model(input_ids, **kwargs).last_hidden_state
        # NOTE(review): this mean also covers padded positions when an
        # attention_mask is passed, so pooled features depend on batch padding.
        # Changing it would alter the outputs of existing fine-tuned checkpoints.
        h = h.mean(dim=1)
        h = self.avey_model.ln_f(h)
        h = self.dense(h)
        h = F.gelu(h)
        logits = self.classifier(h)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.reshape(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        return SequenceClassifierOutput(logits=logits, loss=loss)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load a classifier, wrapping a masked-LM checkpoint when given one.

        If the checkpoint config lists a ``*MaskedLM`` architecture, the MLM is
        loaded and a freshly initialised classification head is built around
        it. Otherwise the standard ``PreTrainedModel.from_pretrained`` path is
        used.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = AveyConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # `architectures` exists on transformers configs but defaults to None
        # (e.g. a config not saved from a model); iterating None would raise.
        archs = getattr(config, "architectures", None) or []
        is_mlm = any("MaskedLM" in a for a in archs)
        if is_mlm:
            # NOTE(review): model_args are not forwarded on this path.
            mlm_model = AveyForMaskedLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
            return cls(config, avey_model=mlm_model)
        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )
class AveyForTokenClassification(AveyPreTrainedModel):
    """Per-token classification head on top of an Avey masked-LM backbone."""

    def __init__(self, config: AveyConfig, avey_model: AveyForMaskedLM = None):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        # Reuse a provided (e.g. pretrained) MLM; otherwise build a fresh one.
        if avey_model is None:
            self.avey_model = AveyForMaskedLM(config)
        else:
            self.avey_model = avey_model
        self.classifier = nn.Linear(config.d_embed, config.num_labels)
        self.dense = nn.Sequential(
            nn.Linear(config.d_embed, config.d_embed),
            nn.Tanh(),
        )
        # NOTE(review): post_init() also visits a passed-in avey_model; verify
        # that its loaded weights are not re-initialised on your transformers version.
        self.post_init()

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs):
        """Return ``TokenClassifierOutput`` with per-token logits and optional loss.

        The loss is cross-entropy with the default ``ignore_index`` of -100.
        """
        outputs = self.avey_model.base_avey_model(input_ids, **kwargs)
        h = outputs.last_hidden_state
        h = self.avey_model.ln_f(h)
        h = self.dense(h)
        logits = self.classifier(h)
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # reshape (not view) so non-contiguous labels also work.
            loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1))
        return TokenClassifierOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load a token classifier, wrapping a masked-LM checkpoint when given one.

        If the checkpoint config lists a ``*MaskedLM`` architecture, the MLM is
        loaded and a freshly initialised head is built around it. Otherwise the
        standard ``PreTrainedModel.from_pretrained`` path is used.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = AveyConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # `architectures` exists on transformers configs but defaults to None;
        # iterating None would raise TypeError.
        archs = getattr(config, "architectures", None) or []
        is_mlm = any("MaskedLM" in a for a in archs)
        if is_mlm:
            # NOTE(review): model_args are not forwarded on this path.
            mlm_model = AveyForMaskedLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
            return cls(config, avey_model=mlm_model)
        return super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )