Kimang18 commited on
Commit
f2c188e
·
1 Parent(s): 8569912

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +158 -0
model.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from timm.models.vision_transformer import PatchEmbed, VisionTransformer
6
+ from dataclasses import dataclass
7
+ from torch import Tensor
8
+ import math
9
+
10
+
11
class ImageEncoder(VisionTransformer):
    """ViT backbone used as a pure patch-feature extractor.

    The classifier head is disabled (no class token, no pooling, zero
    output classes); ``forward`` returns patch-level features only.
    """

    def __init__(self, config):
        # Map the project config onto timm's VisionTransformer arguments.
        vit_kwargs = dict(
            img_size=config.img_size,
            patch_size=config.patch_size,
            in_chans=config.n_channel,
            embed_dim=config.n_embed,
            depth=config.n_layer,
            num_heads=config.n_head,
            mlp_ratio=4,
            qkv_bias=True,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            drop_path_rate=0.0,
            embed_layer=PatchEmbed,
            # These three settings disable the classifier head.
            num_classes=0,
            global_pool='',
            class_token=False,
        )
        super().__init__(**vit_kwargs)

    def forward(self, x):
        # Skip timm's forward_head; return the sequence of patch features.
        return self.forward_features(x)
33
+
34
+
35
+ class RMSNorm(nn.RMSNorm):
36
+ def forward(self, x):
37
+ return super().forward(x.float()).type(x.dtype)
38
+
39
+
40
class Linear(nn.Linear):
    """Linear layer that casts its parameters to the input's dtype at call
    time, so one stored copy of the weights serves inputs of any floating
    precision."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)
43
+
44
+
45
class TextDecoder(nn.Module):
    """Single-block autoregressive text decoder with a positional query.

    PARSeq-style layout: learned position embeddings act as the attention
    *queries* over a token-derived context (self-attention with a causal
    mask), the result then cross-attends into the image features, and the
    head emits logits for the LAST position only.
    """

    def __init__(self, config, ) -> None:
        super().__init__()
        self.config = config
        # NOTE(review): head count is doubled relative to config.n_head —
        # presumably intentional (wider attention than the encoder); confirm.
        self.n_head = 2 * config.n_head
        self.tok_embed = nn.Embedding(config.vocab_size, config.n_embed)
        # torch.Tensor(...) allocates uninitialized memory; the parameter is
        # properly initialized by trunc_normal_ at the end of __init__.
        self.pos_embed = nn.Parameter(torch.Tensor(
            1, config.block_size, config.n_embed))
        self.dropout = nn.Dropout(config.dropout)

        # Masked self-attention over the token context.
        self.sa_ln = RMSNorm(config.n_embed)
        self.sa_attn = nn.MultiheadAttention(config.n_embed, self.n_head, dropout=config.dropout, batch_first=True)

        # Cross-attention into the image features.
        self.cross_ln = RMSNorm(config.n_embed)
        self.cross_attn = nn.MultiheadAttention(config.n_embed, self.n_head, dropout=config.dropout, batch_first=True)

        # Position-wise feed-forward block (4x expansion, GELU).
        self.ffn_ln = RMSNorm(config.n_embed)
        dim_feedforward = 4*config.n_embed
        self.ffn = nn.Sequential(
            Linear(config.n_embed, dim_feedforward, bias=config.bias),
            nn.GELU(),
            Linear(dim_feedforward, config.n_embed, bias=config.bias),
            nn.Dropout(config.dropout)
        )
        self.lm_head = Linear(config.n_embed, config.vocab_size)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x: Tensor, xi: Tensor):
        """
        x: input token ids, shape (batch, t)
        xi: image features (already normalized by ImageEncoder)

        Returns float32 logits for the last position only,
        shape (batch, 1, vocab_size).
        """
        b, t = x.size()
        # Scale token embeddings by sqrt(d_model), the Transformer convention.
        tok_embed = self.tok_embed(x) * math.sqrt(self.config.n_embed)

        # Context: BOS embedding kept as-is at step 0; positions are added to
        # the token embeddings for the remaining t-1 steps.
        ctx = torch.cat(
            [tok_embed[:, :1], self.pos_embed[:, :t-1] + tok_embed[:, 1:]], dim=1)
        ctx = self.dropout(ctx)
        # NOTE(review): sa_ln is applied to ctx here AND to res below — the
        # double use looks deliberate (pre-norm on both streams); confirm.
        ctx = self.sa_ln(ctx)
        # Positional queries, one per output step, broadcast over the batch.
        res = self.dropout(self.pos_embed[:, :t].expand(b, -1, -1))  # (b, t, n_embed)

        # Causal mask: True above the diagonal blocks attention to the future.
        mask = torch.triu(torch.ones((t, t), dtype=torch.bool, device=x.device), 1)
        query, sa_weights = self.sa_attn(self.sa_ln(res), ctx, ctx, attn_mask=mask)
        res = res + query
        # Cross-attention: image features serve as keys and values.
        query, ca_weights = self.cross_attn(self.cross_ln(res), xi, xi)
        res = res + query
        res = res + self.ffn(self.ffn_ln(res))
        # Only the final position's logits are returned (cast to float32).
        return self.lm_head(res[:, [-1], :]).float()
93
+
94
+
95
class OCRModel(nn.Module):
    """End-to-end OCR model: ViT image encoder + autoregressive text decoder.

    ``forward`` returns ``(logits, loss)`` where ``loss`` is ``None`` (no
    targets are accepted here); ``generate`` performs sampling-based decoding
    for a single image.
    """

    def __init__(self, config, tokenizer) -> None:
        super().__init__()
        self.encoder = ImageEncoder(config)
        self.decoder = TextDecoder(config)
        self.tokenizer = tokenizer

    def forward(self, img_tensor: Tensor, input_tokens: Tensor):
        """Encode the image and decode logits for the given token prefix.

        img_tensor: batch of images for the encoder.
        input_tokens: token id prefix, shape (batch, t).
        Returns (logits, loss): logits from the decoder, loss always None.
        """
        xi = self.encoder(img_tensor)
        # BUG FIX: TextDecoder.forward returns a single logits tensor, but the
        # original code did `logits, loss = self.decoder(...)`, which would
        # tuple-unpack the tensor along dim 0 (or raise a ValueError). Keep
        # the (logits, loss) return shape for callers; no loss is computed.
        logits = self.decoder(input_tokens, xi)
        return logits, None

    @torch.inference_mode()
    def generate(self, img_tensor: Tensor, max_new_tokens: int, temperature=1.0, top_k=None):
        """Sample a transcription for a single (unbatched) image tensor.

        img_tensor: one image (C, H, W); a batch dim is added internally.
        max_new_tokens: maximum tokens to sample after BOS.
        temperature: softmax temperature (>0); top_k: optional nucleus cutoff.
        Returns the decoded string with special tokens stripped.
        """
        xi = self.encoder(img_tensor.unsqueeze(0))
        # Start from a single BOS token; batch size is 1 here, which is what
        # the `.item()` early-exit below relies on.
        idx = torch.full((xi.size(0), 1), fill_value=self.tokenizer.bos_id, dtype=torch.long, device=img_tensor.device)
        for _ in range(max_new_tokens):
            logits = self.decoder(idx, xi)
            # Decoder emits (b, 1, vocab); take the last step and rescale.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            if idx_next.item() == self.tokenizer.eos_id:
                break
        return self.tokenizer.decode(idx[0].tolist(), ignore_special_tokens=True)
123
+
124
+
125
@dataclass
class ModelConfig:
    """Hyper-parameters shared by the image encoder and the text decoder."""

    img_size: Sequence[int]    # input image size, e.g. (height, width)
    patch_size: Sequence[int]  # ViT patch size, e.g. (height, width)
    n_channel: int             # input image channels (3 for RGB)
    vocab_size: int            # tokenizer vocabulary size
    block_size: int            # max decoder sequence length (pos_embed rows)
    n_layer: int               # encoder depth (ViT blocks)
    n_head: int                # encoder heads; decoder uses 2 * n_head
    n_embed: int               # embedding / model width
    dropout: float = 0.0       # dropout prob used throughout the decoder
    bias: bool = True          # whether decoder FFN Linear layers use bias
137
+
138
+
139
def load_model():
    """Build the Khmer OCR model and load its pretrained weights.

    Reads the tokenizer from a local ``tokenizer.pkl``, constructs the model
    with the published hyper-parameters, then downloads the checkpoint from
    the HuggingFace Hub and loads it on CPU.
    """
    import pickle

    # NOTE(review): pickle.load runs arbitrary code if the file is untrusted;
    # tokenizer.pkl must come from a trusted source.
    with open('tokenizer.pkl', 'rb') as fh:
        tokenizer = pickle.load(fh)

    model = OCRModel(
        ModelConfig(
            img_size=(32, 128),
            patch_size=(4, 8),
            n_channel=3,
            vocab_size=len(tokenizer),
            block_size=192,
            n_layer=12,
            n_head=3,
            n_embed=192,
            dropout=0.1,
            bias=True,
        ),
        tokenizer,
    )

    checkpoint_url = 'https://huggingface.co/KrorngAI/PARSeqForKhmer/resolve/main/parseq_kh.pt'
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model