Spaces:
Runtime error
Runtime error
| from typing import NamedTuple | |
| import yaml | |
| from tqdm import tqdm | |
| import numpy as np | |
| import torch as T | |
| from torch import nn | |
| from torch import functional as F | |
| from diac_utils import flat_2_3head | |
| from model_dd import DiacritizerD2 | |
class Readout(nn.Module):
    """Two-layer read-out head: ``Linear -> tanh -> Linear``.

    Args:
        in_size: Size of the last dimension of the input; also used as the
            width of the hidden layer.
        out_size: Size of the last dimension of the output.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
    ):
        super().__init__()
        self.W1 = nn.Linear(in_size, in_size)
        self.W2 = nn.Linear(in_size, out_size)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """Project ``x`` (``[..., in_size]``) to ``[..., out_size]``."""
        hidden = T.tanh(self.W1(x))
        # The second layer must consume the hidden activation. Feeding it
        # the raw input (as the previous code did) silently bypassed W1 and
        # the non-linearity.
        return self.W2(hidden)
class WordDD_LSTM(nn.Module):
    """LSTM over a character sequence followed by a per-step class read-out.

    Args:
        feature_size: Size of the input features; also used as the LSTM
            hidden size so the read-out head receives ``feature_size``
            features per step.
        num_classes: Number of output classes per step.
        return_logits: When ``True`` (default) return raw logits; otherwise
            return probabilities normalised over the class axis.
    """

    def __init__(
        self,
        feature_size: int,
        num_classes: int = 13,
        return_logits: bool = True,
    ):
        super().__init__()
        self.feature_size = feature_size
        self.num_classes = num_classes
        self.return_logits = return_logits
        # nn.LSTM needs an explicit hidden size. batch_first matches the
        # [b tc dc] layout used by forward().
        self.cell = nn.LSTM(feature_size, feature_size, batch_first=True)
        self.head = Readout(feature_size, num_classes)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """Map ``x`` of shape ``[b tc dc]`` to ``[b tc num_classes]``."""
        #^ x: [b tc dc]
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is
        # passed on to the head.
        z, _ = self.cell(x)
        #^ z: [b tc @dc]
        y = self.head(z)
        #^ y: [b tc Classes]
        yhat = y
        if not self.return_logits:
            # Normalise over the class axis (last), not over time.
            yhat = T.softmax(yhat, dim=-1)
        #^ yhat: [b tc @Classes]
        return yhat
# Result bundle returned by PartialDD.forward(..., return_extra=True).
# Field order is part of the interface (positional construction/unpacking):
#   preds_hard       -- final hard predictions (class ids)
#   preds_ctxt_logit -- raw output of the sentence-level (contextual) pass
#   preds_base_logit -- raw output of the word-level (base) pass
PartialDiacOutput = NamedTuple(
    "PartialDiacOutput",
    [
        ("preds_hard", T.Tensor),
        ("preds_ctxt_logit", T.Tensor),
        ("preds_base_logit", T.Tensor),
    ],
)
class PartialDD(nn.Module):
    """Partial diacritizer built on top of a ``DiacritizerD2`` model.

    The wrapped model is run twice: once on whole sentences (the
    "contextual" pass) and once on every word in isolation (the "base"
    pass). The final prediction keeps a contextual diacritic only where it
    differs from the base one; padding positions and positions where both
    passes agree are set to ``no_diac_id``.

    Args:
        config: Configuration dict forwarded to ``DiacritizerD2``.
        d2: When truthy, ``predict_partial`` synthesises ``subword_lengths``
            (one sub-word per non-zero token) instead of using the lengths
            yielded by the dataloader.
    """

    def __init__(
        self,
        config: dict,
        d2=False
    ):
        super().__init__()
        self._built = False
        self.no_diac_id = 0
        # Dummy parameter whose only purpose is to expose the module device.
        self._dummy = nn.Parameter(T.ones(1, 1))
        self.config = config
        # predict_partial() reads this flag; it was previously never stored,
        # which made that method fail with AttributeError.
        self._use_d2 = d2
        self.sentence_diac = DiacritizerD2(self.config)
        self.eval()

    @property
    def device(self):
        """Device the module currently lives on.

        Exposed as a property because every use in this class is attribute
        style (``tensor.to(self.device)``).
        """
        return self._dummy.device

    @property
    def tokenizer(self):
        """Tokenizer of the wrapped sentence-level model."""
        return self.sentence_diac.tokenizer

    def load_state_dict(
        self,
        state_dict: dict
    ):
        """Load ``state_dict`` into the wrapped ``sentence_diac`` model.

        Returns whatever the wrapped model's ``load_state_dict`` returns.
        """
        return self.sentence_diac.load_state_dict(state_dict)

    def _slim_batch(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        """Trim trailing padding so the batch is as small as its longest item.

        Args:
            toke_ids: ``[b tt]`` token ids, padded with the tokenizer pad id.
            char_ids: ``[b tw tc]`` character ids, padded with ``0``.
            diac_ids: ``[b tw tc ...]`` diacritic labels.
            subword_lengths: ``[b tw]`` number of tokens per word.

        Returns:
            The four tensors, sliced to the longest non-padded token, word
            and character extents found in the batch.
        """
        #^ toke_ids: [b tt]
        #^ char_ids: [b tw tc]
        #^ diac_ids: [b tw tc "13"]
        #^ subword_lengths: [b tw]
        token_nonpad_mask = toke_ids.ne(self.tokenizer.pad_token_id)
        Ttoken = int(token_nonpad_mask.sum(1).max())
        toke_ids = toke_ids[:, :Ttoken]

        char_nonpad_mask = char_ids.ne(0)
        # A word counts as present when at least one of its chars is non-pad.
        Tword = int(char_nonpad_mask.any(2).sum(1).max())
        Tchar = int(char_nonpad_mask.sum(2).max())
        char_ids = char_ids[:, :Tword, :Tchar]
        diac_ids = diac_ids[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]
        return toke_ids, char_ids, diac_ids, subword_lengths

    def word_diac(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
        *,
        shape: tuple = None,
    ):
        """Run the wrapped model on every word in isolation (base pass).

        Each word becomes its own one-word "sentence": the batch and word
        axes are flattened before calling ``sentence_diac`` and restored
        afterwards.

        Args:
            toke_ids: ``[b tt]`` token ids of the sentences.
            char_ids: ``[b tw tc]`` character ids.
            diac_ids: ``[b tw tc ...]`` diacritic labels.
            subword_lengths: ``[b tw]`` number of tokens belonging to each
                word; a zero marks the end of a sentence's words.
            shape: Optional ``(b, tw, tc)`` target shape. When given, the
                inputs are sliced to it; otherwise padding is trimmed with
                ``_slim_batch``.

        Returns:
            The model output reshaped to ``[b tw tc -1]``.
        """
        if shape is None:
            toke_ids, char_ids, diac_ids, subword_lengths = self._slim_batch(
                toke_ids, char_ids, diac_ids, subword_lengths
            )
        else:
            _, Tw, Tc = shape
            char_ids = char_ids[:, :Tw, :Tc]
            diac_ids = diac_ids[:, :Tw, :Tc]
            subword_lengths = subword_lengths[:, :Tw]
        Nb, Tw, Tc = char_ids.shape
        assert tuple(subword_lengths.shape) == (Nb, Tw), f"{subword_lengths.shape} != {(Nb, Tw)=}"

        # Exclusive end offset of every word inside its sentence's tokens.
        sent_word_strides = subword_lengths.long().cumsum(1).tolist()
        max_tokens_per_word: int = int(subword_lengths.max().item())

        # Same dtype/device as toke_ids, zero-filled (zero acts as padding).
        word_x = toke_ids.new_zeros((Nb, Tw, max_tokens_per_word))
        for i_b, word_ends in enumerate(sent_word_strides):
            sent_i = toke_ids[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(word_ends):
                if end_iw == start_iw:
                    # Zero-length word: the rest of the sentence is padding.
                    break
                word_x[i_b, i_word, : end_iw - start_iw] = sent_i[start_iw:end_iw]
                start_iw = end_iw
        #^ word_x: [b tw tt]

        word_x = word_x.reshape(Nb * Tw, max_tokens_per_word)
        cids_flat = char_ids.reshape(Nb * Tw, 1, Tc)
        word_lengths = subword_lengths.reshape(Nb * Tw, 1)
        z = self.sentence_diac(
            word_x,
            cids_flat,
            diac_ids.reshape(Nb * Tw, Tc, -1),
            subword_lengths=word_lengths,
        )
        #^ z: [b*tw, 1, tc, "13"]
        return z.reshape(Nb, Tw, Tc, -1)

    def forward(
        self,
        word_ids: T.Tensor,
        char_ids: T.Tensor,
        _labels: T.Tensor,
        *,
        eval_only: str = None,
        subword_lengths: T.Tensor = None,
        return_extra: bool = False
    ):
        """Predict partial diacritics for a batch.

        Args:
            word_ids: Token ids of the sentences.
            char_ids: ``[b tw tc]`` character ids (``0`` is padding).
            _labels: Diacritic labels, forwarded to the wrapped model.
            eval_only: ``'ctxt'`` returns the contextual predictions only,
                ``'base'`` returns the word-level predictions only, any
                other value returns the combined partial output.
            subword_lengths: ``[b tw]`` tokens per word; required unless
                ``eval_only == 'ctxt'``.
            return_extra: When ``True`` (and ``eval_only`` selects neither
                early exit) return a ``PartialDiacOutput`` that also carries
                the raw outputs of both passes.

        Returns:
            A ``[b tw tc]`` tensor of class ids, or a ``PartialDiacOutput``.

        Raises:
            ValueError: If the base pass is needed and ``subword_lengths``
                is ``None``.
        """
        assert not self.training
        #^ word_ids: [b tw]
        #^ char_ids: [b tw tc]

        # The contextual pass always runs: its output shape drives the
        # slicing of the base pass, even when only 'base' is requested.
        y_ctxt = self.sentence_diac(
            word_ids,
            char_ids,
            _labels,
        )
        out_shape = y_ctxt.shape[:-1]
        #^ y_ctxt: [b tw tc "13"]
        if eval_only == 'ctxt':
            return y_ctxt.argmax(-1)

        if subword_lengths is None:
            raise ValueError(
                "subword_lengths is required unless eval_only == 'ctxt'"
            )
        y_base = self.word_diac(
            word_ids,
            char_ids,
            _labels,
            subword_lengths,
            shape=out_shape
        )
        #^ y_base: [b tw tc "13"]
        if eval_only == 'base':
            return y_base.argmax(-1)

        ypred_ctxt = y_ctxt.argmax(-1)
        ypred_base = y_base.argmax(-1)
        #^ ypred: [b tw tc]

        # Build the mask on the same extent as the predictions; the wrapped
        # model may return fewer words/chars than the padded input has.
        _, Tw, Tc = out_shape
        padding_mask = char_ids[:, :Tw, :Tc].eq(0)
        #^ padding_mask: [b tw tc]

        # Keep a contextual diacritic only where it disagrees with the base.
        ypred_ctxt[padding_mask | (ypred_base == ypred_ctxt)] = self.no_diac_id
        if not return_extra:
            return ypred_ctxt
        return PartialDiacOutput(ypred_ctxt, y_ctxt, y_base)

    def step(self, xt, yt, mask=None):
        """Training step; not supported (this module is inference-only)."""
        raise NotImplementedError

    def predict_partial(
        self,
        dataloader,
        return_extra=False,
        eval_only: str = None,
    ):
        """Run partial diacritization over ``dataloader``.

        Each batch is expected to unpack as ``(inputs, _, subword_lengths)``
        with ``inputs = [toke_ids, char_ids, diac_ids]``.

        Args:
            dataloader: Iterable of batches as described above.
            return_extra: Also collect the raw outputs of both passes.
            eval_only: Forwarded to ``forward``.

        Returns:
            A dict with ``'diacritics'`` (haraka, tanween, shadda arrays)
            and ``'other'`` (``y_ctxt``, ``y_base``, ``diacs`` lists; the
            first two stay empty unless ``return_extra`` is set).
        """
        training = self.training
        self.eval()
        preds = {
            'haraka': [],
            'shadda': [],
            'tanween': [],
            'diacs': [],
            'y_ctxt': [],
            'y_base': [],
        }
        print("> Predicting...")
        try:
            for inputs, _, subword_lengths in tqdm(dataloader):
                #^ inputs: [toke_ids, char_ids, diac_ids]
                inputs[0] = inputs[0].to(self.device)  #< toke_ids
                inputs[1] = inputs[1].to(self.device)  #< char_ids
                if self._use_d2:
                    # One sub-word per non-zero token.
                    subword_lengths = T.ones_like(inputs[0])
                    subword_lengths[inputs[0] == 0] = 0
                with T.no_grad():
                    output = self(
                        *inputs,
                        subword_lengths=subword_lengths,
                        return_extra=return_extra,
                        eval_only=eval_only,
                    )
                if return_extra:
                    assert isinstance(output, PartialDiacOutput)
                    marks = output.preds_hard
                    preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
                    preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
                else:
                    assert isinstance(output, T.Tensor)
                    marks = output
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
                #^ [b ts tw]
                haraka, tanween, shadda = flat_2_3head(marks)
                preds['haraka'].extend(haraka)
                preds['tanween'].extend(tanween)
                preds['shadda'].extend(shadda)
        finally:
            # Restore the caller's mode even if a batch fails.
            self.train(training)
        return {
            'diacritics': (
                #! FIXME! Due to batch slimming, output diacritics may need padding.
                np.array(preds['haraka']),
                np.array(preds["tanween"]),
                np.array(preds["shadda"]),
            ),
            'other': (  # y_ctxt / y_base are empty when !return_extra
                preds['y_ctxt'],
                preds['y_base'],
                preds['diacs'],
            )
        }

    def predict(self, dataloader):
        """Run the contextual pass only over ``dataloader``.

        Each batch is expected to unpack as ``(inputs, _)`` with ``inputs``
        holding the positional arguments of ``forward``.

        Returns:
            A ``(haraka, tanween, shadda)`` tuple of numpy arrays.
        """
        training = self.training
        self.eval()
        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        try:
            for inputs, _ in tqdm(dataloader, total=len(dataloader)):
                inputs[0] = inputs[0].to(self.device)
                inputs[1] = inputs[1].to(self.device)
                # Inference only: skip autograd bookkeeping, as
                # predict_partial() already does.
                with T.no_grad():
                    marks = self(*inputs, eval_only='ctxt')
                #^ [b ts tw]
                haraka, tanween, shadda = flat_2_3head(marks)
                preds['haraka'].extend(haraka)
                preds['tanween'].extend(tanween)
                preds['shadda'].extend(shadda)
        finally:
            self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )