Seth0330 committed
Commit f54b486 · verified · 1 Parent(s): 96b3399

Create pdrt/models.py

Files changed (1):
  1. pdrt/models.py +193 -0
pdrt/models.py ADDED
@@ -0,0 +1,193 @@
+ import re
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from transformers import VisionEncoderDecoderModel, DonutProcessor, VisionEncoderDecoderConfig
+
+ import paths
+
+ ######################################################
+ # Swin + CTC
+ ######################################################
+
+ class Identity(nn.Module):
+     """Pass-through module, used below to disable the encoder's pooler."""
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         return x
+
+ class Swin_CTC(nn.Module):
+
+     def __init__(self, vocab_size=100):
+         super().__init__()
+
+         # Swin Config
+         HEIGHT = paths.HEIGHT
+         WIDTH = paths.WIDTH
+         config = VisionEncoderDecoderConfig.from_pretrained(paths.DONUT_WEIGHTS)
+         config.encoder.image_size = [HEIGHT, WIDTH]
+
+         # Image Processor (note: the processor size is ordered [width, height])
+         self.processor = DonutProcessor.from_pretrained(paths.DONUT_WEIGHTS)
+         self.processor.image_processor.size = [WIDTH, HEIGHT]
+         self.processor.image_processor.do_align_long_axis = False
+
+         # Swin Encoder: keep only the encoder and replace its pooler with a
+         # pass-through so forward() sees the full token sequence
+         self.swin_encoder = VisionEncoderDecoderModel.from_pretrained(paths.DONUT_WEIGHTS, config=config).encoder
+         self.swin_encoder.pooler = Identity()
+
+         # Fully-connected layer to vocab
+         self.projection_V = nn.Linear(1024, vocab_size + 1)  # classes + blank token
+
+     def forward(self, x, targets=None, target_lengths=None):
+
+         x = self.swin_encoder(x).last_hidden_state  # (b, 4800, 1024)
+         x = self.projection_V(x)                    # (b, 4800, 1024) to (b, 4800, V)
+
+         if targets is not None:
+             x = x.permute(1, 0, 2)  # CTC expects sequence-first layout: (T, b, V)
+             loss = self.ctc_loss(x, targets, target_lengths)
+             return x, loss
+
+         return x, None
+
+     @staticmethod
+     def ctc_loss(x, targets, target_lengths):
+         batch_size = x.size(1)
+
+         log_probs = F.log_softmax(x, 2)
+
+         # Every sample uses the full encoder sequence length as its input length
+         input_lengths = torch.full(
+             size=(batch_size,),
+             fill_value=log_probs.size(0),
+             dtype=torch.int32
+         )
+
+         loss = nn.CTCLoss(blank=0)(
+             log_probs, targets, input_lengths, target_lengths
+         )
+
+         return loss
+
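+     # Note on the targets format: nn.CTCLoss accepts `targets` either as a
+     # padded (batch, max_target_len) tensor or as a 1-D concatenation of all
+     # label sequences, with `target_lengths` giving the true length of each
+     # one. Label values must lie in 1..vocab_size, since index 0 is reserved
+     # for the CTC blank (blank=0 above).
+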
+     def inference_one_sample(self, x, seq_to_text):
+
+         x, _ = self(x)  # forward of Swin+CTC model
+
+         x = x.permute(1, 0, 2)
+
+         x, xs = x, [x.size(0)] * x.size(1)
+         x = x.detach()
+
+         x = torch.nn.functional.log_softmax(x, 2)
+
+         # Transform to list of size = batch_size
+         x = [x[: xs[i], i, :] for i in range(len(xs))]
+         x = [x_n.max(dim=1) for x_n in x]
+
+         # Get symbols and probabilities
+         probs = [x_n.values.exp() for x_n in x]
+         x = [x_n.indices for x_n in x]
+
+         # Remove consecutive symbols
+         # Keep track of counts of consecutive symbols. Example: [0, 0, 0, 1, 2, 2] => [3, 1, 2]
+         counts = [torch.unique_consecutive(x_n, return_counts=True)[1] for x_n in x]
+
+         # Select indexes to keep. Example: [0, 3, 4] (always keep the first index, then use the cumulative sum of the counts tensor)
+         zero_tensor = torch.tensor([0], device=x[0].device)  # x is now a list of tensors, so take the device of the first one
+         idxs = [torch.cat((zero_tensor, count.cumsum(0)[:-1])) for count in counts]
+
+         # Keep only non-consecutive symbols and their associated probabilities
+         x = [x[i][idxs[i]] for i in range(len(x))]
+         probs = [probs[i][idxs[i]] for i in range(len(x))]
+
+         # Remove blank symbols
+         # Get indexes of non-blank symbols
+         idxs = [torch.nonzero(x_n, as_tuple=True) for x_n in x]
+
+         # Keep only non-blank symbols and their associated probabilities
+         x = [x[i][idxs[i]] for i in range(len(x))]
+         probs = [probs[i][idxs[i]] for i in range(len(x))]
+
+         # Save results
+         out = {}
+         out["hyp"] = [x_n.tolist() for x_n in x]
+
+         # Return char-based probability
+         out["prob-htr-char"] = [prob.tolist() for prob in probs]
+
+         text = ""
+         for i in out["hyp"][0]:
+             text += seq_to_text[i]
+
+         return text
+
+
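For reference, a minimal usage sketch for Swin_CTC (not part of this commit; it assumes the paths module defines DONUT_WEIGHTS, HEIGHT and WIDTH, and uses a toy seq_to_text mapping):

import torch
import paths
from pdrt.models import Swin_CTC

# Toy label map: ids 1..27 -> characters (id 0 is the CTC blank)
seq_to_text = {i: c for i, c in enumerate(" abcdefghijklmnopqrstuvwxyz", start=1)}

model = Swin_CTC(vocab_size=len(seq_to_text))
model.eval()

# Dummy batch of one image at the configured input size
dummy = torch.randn(1, 3, paths.HEIGHT, paths.WIDTH)

# Training-style call: logits plus CTC loss ("hi" -> label ids [9, 10])
targets = torch.tensor([[9, 10]])
target_lengths = torch.tensor([2], dtype=torch.int32)
logits, loss = model(dummy, targets, target_lengths)

# Greedy CTC decoding of a single sample
text = model.inference_one_sample(dummy, seq_to_text)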
+ ######################################################
+ # Vision Encoder-Decoder (VED)
+ ######################################################
+
+ class VED(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+
+         # VED Config
+         HEIGHT = paths.HEIGHT
+         WIDTH = paths.WIDTH
+         self.MAX_LENGTH = paths.MAX_LENGTH
+         config = VisionEncoderDecoderConfig.from_pretrained(paths.DONUT_WEIGHTS)
+         config.encoder.image_size = [HEIGHT, WIDTH]
+         config.decoder.max_length = self.MAX_LENGTH
+
+         # Image Processor (note: the processor size is ordered [width, height])
+         self.processor = DonutProcessor.from_pretrained(paths.DONUT_WEIGHTS)
+         self.processor.image_processor.size = [WIDTH, HEIGHT]
+         self.processor.image_processor.do_align_long_axis = False
+
+         # VED Model
+         self.model = VisionEncoderDecoderModel.from_pretrained(paths.DONUT_WEIGHTS, config=config)
+
+         # Params for Transformer Decoder
+         self.model.config.pad_token_id = self.processor.tokenizer.pad_token_id
+         # set <s_synthdog> token=57524
+         self.model.config.decoder_start_token_id = 57524
+
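+     # Note: 57524 hard-codes the id of the "<s_synthdog>" task-start token in
+     # the Donut tokenizer; if the tokenizer changes, the same id can be looked
+     # up with self.processor.tokenizer.convert_tokens_to_ids("<s_synthdog>").
+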
+     def forward(self, x, labels):
+
+         outputs = self.model(x, labels=labels)
+         return outputs, outputs.loss
+
+     def inference(self, x):
+
+         batch_size = x.shape[0]
+
+         decoder_input_ids = torch.full(
+             (batch_size, 1),
+             self.model.config.decoder_start_token_id,
+             device=x.device
+         )
+
+         self.model.eval()
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 x,
+                 decoder_input_ids=decoder_input_ids,
+                 max_length=self.MAX_LENGTH,
+                 early_stopping=True,
+                 pad_token_id=self.processor.tokenizer.pad_token_id,
+                 eos_token_id=self.processor.tokenizer.eos_token_id,
+                 use_cache=True,
+                 num_beams=1,
+                 bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+                 return_dict_in_generate=True,
+             )
+
+         predictions = []
+         for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
+             seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
+             seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove the first task-start token
+             predictions.append(seq)
+
+         return predictions
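
Similarly, a minimal inference sketch for VED (not part of this commit; the image path is a placeholder, and paths must provide DONUT_WEIGHTS, HEIGHT, WIDTH and MAX_LENGTH):

from PIL import Image
from pdrt.models import VED

model = VED()

# DonutProcessor resizes and normalises the image to the configured size
image = Image.open("sample.png").convert("RGB")  # placeholder path
pixel_values = model.processor(image, return_tensors="pt").pixel_values

predictions = model.inference(pixel_values)
print(predictions[0])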