| | import re |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from transformers import VisionEncoderDecoderModel, DonutProcessor, VisionEncoderDecoderConfig |
| |
|
| | import paths |
| |
|
| | |
| | |
| | |
| |
|
class Identity(nn.Module):
    """No-op module: returns its input unchanged.

    Used to replace a sub-module (e.g. a pooler) so that the surrounding
    model skips that computation entirely.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return *x* untouched."""
        return x
| |
|
class Swin_CTC(nn.Module):
    """Swin Transformer encoder (from a pretrained Donut model) with a linear CTC head.

    The Donut vision encoder produces a sequence of patch embeddings; a linear
    layer projects each step to ``vocab_size + 1`` classes, where index 0 is
    the CTC blank token.
    """

    def __init__(self, vocab_size=100):
        """Build the encoder and the CTC projection head.

        Args:
            vocab_size: number of real output classes (blank excluded); the
                projection emits ``vocab_size + 1`` logits per time step.
        """
        super().__init__()

        # Resize the pretrained encoder to the project's input geometry.
        HEIGHT = paths.HEIGHT
        WIDTH = paths.WIDTH
        config = VisionEncoderDecoderConfig.from_pretrained(paths.DONUT_WEIGHTS)
        config.encoder.image_size = [HEIGHT, WIDTH]

        # NOTE(review): size is passed as [width, height] here — confirm this
        # matches what the installed DonutProcessor version expects.
        self.processor = DonutProcessor.from_pretrained(paths.DONUT_WEIGHTS)
        self.processor.image_processor.size = [WIDTH, HEIGHT]
        self.processor.image_processor.do_align_long_axis = False

        # Keep only the Swin encoder; replace the pooler with a no-op so the
        # full patch sequence survives for per-time-step CTC predictions.
        self.swin_encoder = VisionEncoderDecoderModel.from_pretrained(
            paths.DONUT_WEIGHTS, config=config
        ).encoder
        self.swin_encoder.pooler = Identity()

        # 1024 is the encoder's hidden size; +1 adds the CTC blank class.
        self.projection_V = nn.Linear(1024, vocab_size + 1)

    def forward(self, x, targets=None, target_lengths=None):
        """Encode images and project to per-step class logits.

        Args:
            x: image batch accepted by the Swin encoder.
            targets: optional padded label tensor for CTC training.
            target_lengths: length of each target sequence.

        Returns:
            ``(logits, loss)``. ``loss`` is ``None`` when ``targets`` is
            ``None``; when a loss is computed, the returned logits are
            permuted to ``(T, N, C)``.
        """
        x = self.swin_encoder(x).last_hidden_state
        x = self.projection_V(x)

        if targets is not None:
            # nn.CTCLoss expects (T, N, C).
            x = x.permute(1, 0, 2)
            loss = self.ctc_loss(x, targets, target_lengths)
            return x, loss

        return x, None

    @staticmethod
    def ctc_loss(x, targets, target_lengths):
        """CTC loss for logits ``x`` of shape ``(T, N, C)``; blank index is 0.

        Args:
            x: raw logits, time-major.
            targets: padded target indices.
            target_lengths: per-sample target lengths.

        Returns:
            Scalar CTC loss tensor.
        """
        batch_size = x.size(1)
        log_probs = F.log_softmax(x, 2)

        # Every sample uses the full time axis as its input length.
        input_lengths = torch.full(
            size=(batch_size,),
            fill_value=log_probs.size(0),
            dtype=torch.int32,
        )

        loss = nn.CTCLoss(blank=0)(
            log_probs, targets, input_lengths, target_lengths
        )
        return loss

    def inference_one_sample(self, x, seq_to_text):
        """Greedy CTC decoding; returns the decoded string of the first sample.

        Args:
            x: input image batch (only the first sample's text is returned).
            seq_to_text: mapping from class index to character.

        Returns:
            Decoded text for the first sample in the batch.
        """
        x, _ = self(x)

        # (N, T, C) -> (T, N, C) to match the CTC layout used below.
        x = x.permute(1, 0, 2)

        xs = [x.size(0)] * x.size(1)
        x = x.detach()
        x = F.log_softmax(x, 2)

        # Per-sample best class and its log-probability at each time step.
        x = [x[: xs[i], i, :] for i in range(len(xs))]
        x = [x_n.max(dim=1) for x_n in x]

        probs = [x_n.values.exp() for x_n in x]
        x = [x_n.indices for x_n in x]

        # Standard CTC greedy collapse: keep the first index of each
        # consecutive run of identical predictions.
        counts = [torch.unique_consecutive(x_n, return_counts=True)[1] for x_n in x]

        # BUG FIX: at this point ``x`` is a Python list of tensors, so the
        # original ``x.device`` raised AttributeError; take the device from
        # the first per-sample tensor instead.
        zero_tensor = torch.tensor([0], device=x[0].device)
        idxs = [torch.cat((zero_tensor, count.cumsum(0)[:-1])) for count in counts]

        x = [x[i][idxs[i]] for i in range(len(x))]
        probs = [probs[i][idxs[i]] for i in range(len(x))]

        # Drop CTC blanks (class index 0).
        idxs = [torch.nonzero(x_n, as_tuple=True) for x_n in x]
        x = [x[i][idxs[i]] for i in range(len(x))]
        probs = [probs[i][idxs[i]] for i in range(len(x))]

        out = {}
        out["hyp"] = [x_n.tolist() for x_n in x]
        out["prob-htr-char"] = [prob.tolist() for prob in probs]

        # Map the first sample's class indices to characters.
        text = ""
        for i in out["hyp"][0]:
            text += seq_to_text[i]

        return text
| |
|
| |
|
| | |
| | |
| | |
| |
|
class VED(nn.Module):
    """Donut VisionEncoderDecoder wrapper for image-to-text transcription."""

    def __init__(self):
        """Load the pretrained Donut model resized to the project geometry."""
        super().__init__()

        # Resize the pretrained model to the project's input geometry and
        # cap the generation length.
        HEIGHT = paths.HEIGHT
        WIDTH = paths.WIDTH
        self.MAX_LENGTH = paths.MAX_LENGTH
        config = VisionEncoderDecoderConfig.from_pretrained(paths.DONUT_WEIGHTS)
        config.encoder.image_size = [HEIGHT, WIDTH]
        config.decoder.max_length = self.MAX_LENGTH

        # NOTE(review): size is passed as [width, height] here — confirm this
        # matches what the installed DonutProcessor version expects.
        self.processor = DonutProcessor.from_pretrained(paths.DONUT_WEIGHTS)
        self.processor.image_processor.size = [WIDTH, HEIGHT]
        self.processor.image_processor.do_align_long_axis = False

        self.model = VisionEncoderDecoderModel.from_pretrained(paths.DONUT_WEIGHTS, config=config)

        # Align padding with the tokenizer. (The original assigned this twice
        # in a row; a single assignment is sufficient.)
        self.model.config.pad_token_id = self.processor.tokenizer.pad_token_id

        # NOTE(review): hard-coded decoder start token id for this specific
        # checkpoint's tokenizer — verify it matches the intended start/task
        # token rather than hard-coding the raw id.
        self.model.config.decoder_start_token_id = 57524

    def forward(self, x, labels):
        """Teacher-forced forward pass.

        Args:
            x: image batch.
            labels: target token ids used both as decoder input and loss target.

        Returns:
            ``(outputs, loss)`` where ``loss`` is ``outputs.loss``.
        """
        outputs = self.model(x, labels=labels)
        return outputs, outputs.loss

    def inference(self, x):
        """Greedy generation for a batch of images.

        Args:
            x: image batch.

        Returns:
            List of decoded strings, one per input image, with special tokens
            and the leading task tag stripped.
        """
        batch_size = x.shape[0]

        # Seed every sequence with the configured decoder start token.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.model.config.decoder_start_token_id,
            device=x.device,
        )

        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                x,
                decoder_input_ids=decoder_input_ids,
                max_length=self.MAX_LENGTH,
                early_stopping=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        # Strip EOS/PAD tokens and the first "<...>" task tag from each
        # decoded sequence.
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()
            predictions.append(seq)

        return predictions