from typing import List, Iterator, cast

import copy
import numpy as np

import torch as T
from torch import nn
from torch.nn import functional as F
from transformers import BertConfig, BertModel
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

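
# The `config` dict consumed below is expected to provide: 'base_model',
# 'diac_model_config', 'num-chars', 'char-embed-dim', 'dropout',
# 'sentence_dropout', and optionally 'modeling' (a wrapper sub-dict),
# 'token_hidden_size' ('auto' or an int), 'deep-down-proj', 'deep-cls',
# 'pre-init-diac-model', 'keep-token-model-layers', 'full-finetune',
# and 'num-finetune-last-layers'. See the __main__ block for an example.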
class Diacritizer(nn.Module):
    def __init__(
        self,
        config,
        device=None,
        load_pretrained=True,
    ) -> None:
        super().__init__()
        # Dummy parameter used only to track the module's device (see the `device` property).
        self._dummy = nn.Parameter(T.ones(1))

        if 'modeling' in config:
            config = config['modeling']
        self.config = config

        model_name = config.get('base_model', "CAMeL-Lab/bert-base-arabic-camelbert-mix-ner")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        if load_pretrained:
            self.token_model: BertModel = AutoModel.from_pretrained(model_name)
        else:
            # Same architecture, randomly initialized weights.
            marbert_config = AutoConfig.from_pretrained(model_name)
            self.token_model = AutoModel.from_config(marbert_config)

        self.num_classes = 15
        self.diac_model_config = BertConfig(**config['diac_model_config'])
        self.token_model_config: BertConfig = self.token_model.config

        self.char_embs = nn.Embedding(config["num-chars"], embedding_dim=config["char-embed-dim"])
        self.diac_emb_model = self.build_diac_model(self.token_model)

        self.down_project_token_embeds_deep = None
        self.down_project_token_embeds = None
        if 'token_hidden_size' in config:
            if config['token_hidden_size'] == 'auto':
                down_proj_size = self.diac_emb_model.config.hidden_size
            else:
                down_proj_size = config['token_hidden_size']
            if config.get('deep-down-proj', False):
                # Optional two-layer projection of the concatenated [token ; char] features.
                self.down_project_token_embeds_deep = nn.Sequential(
                    nn.Linear(
                        self.token_model_config.hidden_size + config["char-embed-dim"],
                        down_proj_size * 4,
                        bias=False,
                    ),
                    nn.Tanh(),
                    nn.Linear(
                        down_proj_size * 4,
                        down_proj_size,
                        bias=False,
                    ),
                )

            self.down_project_token_embeds = nn.Linear(
                self.token_model_config.hidden_size + config["char-embed-dim"],
                down_proj_size,
                bias=False,
            )

        classifier_feature_size = self.diac_model_config.hidden_size
        if config.get('deep-cls', False):
            # Fuse word-level (token model) and character-level (diac model) features
            # before classification.
            self.final_feature_transform = nn.Linear(
                self.diac_model_config.hidden_size + self.token_model_config.hidden_size,
                out_features=classifier_feature_size,
                bias=False,
            )
        else:
            self.final_feature_transform = None

        self.feature_layer_norm = nn.LayerNorm(classifier_feature_size)
        self.classifier = nn.Linear(classifier_feature_size, self.num_classes, bias=True)

        self.trim_model_(config)

        self.dropout = nn.Dropout(config['dropout'])
        self.sent_dropout_p = config['sentence_dropout']
        self.closs = F.cross_entropy

    def build_diac_model(self, token_model=None):
        if self.config.get('pre-init-diac-model', False):
            # Initialize the character-level (diacritic) encoder from a copy of the
            # pretrained token model; the pooler and word-embedding table are dropped
            # because this encoder is fed features via `inputs_embeds`.
            model = copy.deepcopy(self.token_model)
            model.pooler = None
            model.embeddings.word_embeddings = None

            # Keep the block of layers following those retained for the token model
            # (trim_model_ keeps the first `keep-token-model-layers` layers).
            num_layers = self.config.get('keep-token-model-layers', None)
            model.encoder.layer = nn.ModuleList(
                list(model.encoder.layer[num_layers:num_layers * 2])
            )
            model.encoder.config.num_hidden_layers = num_layers
        else:
            model = BertModel(self.diac_model_config)
        return model

    def trim_model_(self, config):
        # The pooler heads are unused; the diac model never sees token ids,
        # so its word-embedding table can be dropped as well.
        self.token_model.pooler = None
        self.diac_emb_model.pooler = None
        self.diac_emb_model.embeddings.word_embeddings = None

        num_token_model_kept_layers = config.get('keep-token-model-layers', None)
        if num_token_model_kept_layers is not None:
            self.token_model.encoder.layer = nn.ModuleList(
                list(self.token_model.encoder.layer[:num_token_model_kept_layers])
            )
            self.token_model.encoder.config.num_hidden_layers = num_token_model_kept_layers

        if not config.get('full-finetune', False):
            # Freeze the pretrained encoder except for its last few layers.
            for param in self.token_model.parameters():
                param.requires_grad = False
            finetune_last_layers = config.get('num-finetune-last-layers', 4)
            if finetune_last_layers > 0:
                unfrozen_layers = self.token_model.encoder.layer[-finetune_last_layers:]
                for layer in unfrozen_layers:
                    for param in layer.parameters():
                        param.requires_grad = True

    def get_grouped_params(self):
        downstream_params: Iterator[nn.Parameter] = cast(
            Iterator,
            (param
             for module in (self.diac_emb_model, self.classifier, self.char_embs)
             for param in module.parameters())
        )
        pg = {
            'pretrained': self.token_model.parameters(),
            'downstream': downstream_params,
        }
        return pg
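
    # Illustrative use of the parameter groups above when building an optimizer;
    # the optimizer type and learning rates are assumptions, not values this
    # module prescribes:
    #
    #     pg = model.get_grouped_params()
    #     optimizer = T.optim.AdamW([
    #         {'params': pg['pretrained'], 'lr': 1e-5},  # pretrained encoder
    #         {'params': pg['downstream'], 'lr': 1e-4},  # diacritic head
    #     ])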

    @property
    def device(self):
        return self._dummy.device

    def step(self, xt, yt, mask=None, subword_lengths: T.Tensor = None):
        # xt: [token_ids (Nb, Tt), char_ids (Nb, Tw, Tc), diac_x (Nb, Tw, Tc)]
        # yt: diacritic class labels (Nb, Tw, Tc)
        xt[0], xt[1], yt, subword_lengths = self._slim_batch_size(xt[0], xt[1], yt, subword_lengths)
        xt[0] = xt[0].to(self.device)
        xt[1] = xt[1].to(self.device)
        yt = yt.to(self.device)

        Nb, Tword, Tchar = xt[1].shape
        if Tword * Tchar < 500:
            diac = self(*xt, subword_lengths)
            loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1), reduction='sum')
        else:
            # For long sentences, split the work into chunks (the slices below index
            # the batch dimension) to bound peak memory.
            num_chunks = int(np.ceil(Tword * Tchar / 300))
            loss = 0
            for i in range(num_chunks):
                _slice = slice(i * 300, (i + 1) * 300)
                chunk = self._slice_batch(xt, _slice)
                diac = self(*chunk, subword_lengths[_slice])
                chunk_loss = self.closs(
                    diac.view(-1, self.num_classes),
                    yt[_slice].view(-1),
                    reduction='sum',
                )
                loss = loss + chunk_loss

        return loss

    def _slice_batch(self, xt: List[T.Tensor], _slice):
        return [xt[0][_slice], xt[1][_slice], xt[2][_slice]]

    def _slim_batch_size(
        self,
        tx: T.Tensor,
        cx: T.Tensor,
        yt: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        # Trim trailing padding so each batch is only as long as its longest
        # sentence (in sub-word tokens, words, and characters).
        token_nonpad_mask = tx.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        tx = tx[:, :Ttoken]

        char_nonpad_mask = cx.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        cx = cx[:, :Tword, :Tchar]
        yt = yt[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]

        return tx, cx, yt, subword_lengths

    def token_dropout(self, toke_x):
        # During training, randomly replace sub-word token ids with the pad id
        # (input-side sentence dropout).
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(toke_x.shape, q, device=toke_x.device))
            toke_x[sdo == 0] = self.tokenizer.pad_token_id
        return toke_x

    def sentence_dropout(self, word_embs: T.Tensor):
        # During training, zero out whole word embeddings with probability
        # `sent_dropout_p`.
        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(word_embs.shape[:2], q))
            sdo = sdo.detach().unsqueeze(-1).to(word_embs)
            word_embs = word_embs * sdo
        return word_embs

    def embed_tokens(self, input_ids: T.Tensor, attention_mask: T.Tensor):
        y: BaseModelOutputWithPoolingAndCrossAttentions
        y = self.token_model(input_ids, attention_mask=attention_mask)
        z = y.last_hidden_state
        return z

    def forward(
        self,
        toke_x: T.Tensor,
        char_x: T.Tensor,
        diac_x: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        # toke_x          : (Nb, Tt)      sub-word token ids
        # char_x          : (Nb, Tw, Tc)  character ids per word
        # diac_x          : (Nb, Tw, Tc)  unused here
        # subword_lengths : (Nb, Tw)      number of sub-word tokens per word
        # returns         : (Nb, Tw, Tc, num_classes) diacritic logits
        token_nonpad_mask = toke_x.ne(self.tokenizer.pad_token_id)
        char_nonpad_mask = char_x.ne(0)

        Nb, Tw, Tc = char_x.shape

        # Contextual sub-word embeddings from the pretrained token model;
        # drop the [CLS] and [SEP] positions.
        token_embs = self.embed_tokens(toke_x, attention_mask=token_nonpad_mask)
        token_embs = token_embs[:, 1:-1, ...]

        # Average the sub-word embeddings belonging to each word to obtain one
        # vector per word.
        sent_word_strides = subword_lengths.cumsum(1)
        sent_enc: T.Tensor = T.zeros(Nb, Tw, token_embs.shape[-1]).to(token_embs)
        for i_b in range(Nb):
            token_embs_ib = token_embs[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw:
                    break
                word_emb = token_embs_ib[start_iw:end_iw].sum(0) / (end_iw - start_iw)
                sent_enc[i_b, i_word] = word_emb
                start_iw = end_iw

        # Character embeddings, flattened to one character sequence per word.
        char_x_flat = char_x.reshape(Nb * Tw, Tc)
        char_nonpad_mask = char_x_flat.gt(0)
        char_x_flat = char_x_flat * char_nonpad_mask

        cembs = self.char_embs(char_x_flat)
        # Broadcast each word vector over that word's characters and concatenate.
        wembs = sent_enc.unsqueeze(-2).expand(Nb, Tw, Tc, -1).view(Nb * Tw, Tc, -1)
        cw_embs = T.cat([cembs, wembs], dim=-1)
        cw_embs = self.dropout(cw_embs)

        # Optional down-projection of the concatenated features.
        cw_embs_ = cw_embs
        if self.down_project_token_embeds is not None:
            cw_embs_ = self.down_project_token_embeds(cw_embs)
        if self.down_project_token_embeds_deep is not None:
            cw_embs_ = cw_embs_ + self.down_project_token_embeds_deep(cw_embs)
        cw_embs = cw_embs_

        # Character-level encoder over the fused character/word features.
        diac_enc: BaseModelOutputWithPoolingAndCrossAttentions
        diac_enc = self.diac_emb_model(inputs_embeds=cw_embs, attention_mask=char_nonpad_mask)
        diac_emb = diac_enc.last_hidden_state
        diac_emb = self.dropout(diac_emb)
        diac_emb = diac_emb.view(Nb, Tw, Tc, -1)

        if self.final_feature_transform is not None:
            # Optionally fuse the word-level encoding back in before classification.
            sent_residual = sent_enc.unsqueeze(2).expand(-1, -1, Tc, -1)
            final_feature = T.cat([sent_residual, diac_emb], dim=-1)
            final_feature = self.final_feature_transform(final_feature)
            final_feature = T.tanh(final_feature)
            final_feature = self.dropout(final_feature)
        else:
            final_feature = diac_emb

        diac_out = self.classifier(final_feature)
        return diac_out

    def predict(self, dataloader):
        from tqdm import tqdm
        import diac_utils as du
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _, subword_lengths in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            output = self(*inputs, subword_lengths).detach()

            # Class id per character, split into the three diacritic heads.
            marks = np.argmax(T.softmax(output, dim=-1).cpu().numpy(), axis=-1)
            haraka, tanween, shadda = du.flat_2_3head(marks)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds['tanween']),
            np.array(preds['shadda']),
        )


if __name__ == "__main__":
    model = Diacritizer({
        "num-chars": 36,
        "hidden_size": 768,
        "char-embed-dim": 32,
        "dropout": 0.25,
        "sentence_dropout": 0.2,
        "diac_model_config": {
            # BertConfig field names; hidden_size must be divisible by the head count.
            "num_hidden_layers": 4,
            "num_attention_heads": 8,
            "hidden_size": 768 + 32,
            "intermediate_size": (768 + 32) * 4,
        },
    }, load_pretrained=False)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(model)
    print(f"{trainable_params:,}/{total_params:,} Trainable Parameters")
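
    # Minimal forward-pass smoke test on random inputs; the shapes below are
    # illustrative assumptions (2 sentences, 5 words of up to 8 characters,
    # 16 sub-word tokens per sentence, 2 sub-word tokens per word).
    Nb, Tw, Tc, Tt = 2, 5, 8, 16
    toke_x = T.randint(5, 100, (Nb, Tt))
    char_x = T.randint(1, 36, (Nb, Tw, Tc))
    diac_x = T.zeros(Nb, Tw, Tc, dtype=T.long)
    subword_lengths = T.full((Nb, Tw), 2)
    with T.no_grad():
        logits = model(toke_x, char_x, diac_x, subword_lengths)
    print(f"logits shape: {tuple(logits.shape)}")  # expected: (2, 5, 8, 15)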