farbodpya committed on
Commit
970fa72
·
verified ·
1 Parent(s): 2a5dd21

Update modeling_persianocr.py

Browse files
Files changed (1) hide show
  1. modeling_persianocr.py +60 -21
modeling_persianocr.py CHANGED
@@ -1,24 +1,29 @@
1
- import torch.nn as nn
2
- import torch
3
-
4
- # -----------------------------
5
- # 3️⃣ Model definition
6
- # -----------------------------
7
  import torch
8
  import torch.nn as nn
9
- from torch.nn import functional as F
10
  from transformers import PreTrainedModel, PretrainedConfig
11
 
12
- def GN(c, groups=16):
 
 
 
13
  return nn.GroupNorm(min(groups, c), c)
14
 
 
 
 
15
  class LightResNetCNN(nn.Module):
16
  def __init__(self, in_channels=1, adaptive_height=8):
17
  super().__init__()
18
  self.adaptive_height = adaptive_height
19
- self.layer1 = nn.Sequential(nn.Conv2d(in_channels, 32, 3, 1, 1), GN(32), nn.ReLU(), nn.MaxPool2d(2, 2))
20
- self.layer2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), GN(64), nn.ReLU(), nn.MaxPool2d(2, 2))
21
- self.layer3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), GN(128), nn.ReLU(), nn.MaxPool2d(2, 2))
 
 
 
 
 
 
22
  self.layer4 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), GN(256), nn.ReLU())
23
  self.layer5 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), GN(256), nn.ReLU())
24
  self.layer6 = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1), GN(128), nn.ReLU())
@@ -30,23 +35,40 @@ class LightResNetCNN(nn.Module):
30
  x = self.adaptive_pool(x)
31
  return x
32
 
 
 
 
33
  class PositionalEncoding(nn.Module):
34
  def __init__(self, d_model, max_len=2000):
35
  super().__init__()
36
  pe = torch.zeros(max_len, d_model)
37
  position = torch.arange(0, max_len).unsqueeze(1)
38
- div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
 
 
39
  pe[:, 0::2] = torch.sin(position * div_term)
40
  pe[:, 1::2] = torch.cos(position * div_term)
41
  self.register_buffer("pe", pe.unsqueeze(0))
42
 
43
  def forward(self, x):
44
- return x + self.pe[:, :x.size(1), :]
45
 
 
 
 
46
  class PersianOCRConfig(PretrainedConfig):
47
  model_type = "persianocr"
48
 
49
- def __init__(self, num_classes=100, d_model=1280, nhead=16, num_layers=8, dropout=0.2, adaptive_height=8, **kwargs):
 
 
 
 
 
 
 
 
 
50
  super().__init__(**kwargs)
51
  self.num_classes = num_classes
52
  self.d_model = d_model
@@ -55,26 +77,43 @@ class PersianOCRConfig(PretrainedConfig):
55
  self.dropout = dropout
56
  self.adaptive_height = adaptive_height
57
 
 
 
 
58
  class PersianOCRModel(PreTrainedModel):
59
  config_class = PersianOCRConfig
60
 
61
  def __init__(self, config):
62
  super().__init__(config)
63
- self.cnn = LightResNetCNN(in_channels=1, adaptive_height=config.adaptive_height)
 
 
64
  self.proj = nn.Linear(128 * config.adaptive_height, config.d_model)
65
  self.posenc = PositionalEncoding(config.d_model)
66
- encoder_layer = nn.TransformerEncoderLayer(config.d_model, config.nhead, batch_first=True, dropout=config.dropout)
67
- self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
 
 
 
 
68
  self.fc = nn.Linear(config.d_model, config.num_classes)
69
 
70
- self.post_init() # مهم: برای HuggingFace
 
71
 
72
  def forward(self, x, labels=None):
73
- f = self.cnn(x)
 
 
 
 
 
 
 
74
  B, C, H, W = f.size()
75
  f = f.permute(0, 3, 1, 2).reshape(B, W, C * H)
76
  f = self.posenc(self.proj(f))
77
  out = self.transformer(f)
78
- logits = self.fc(out)
79
- return {"logits": logits}
80
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
 
3
  from transformers import PreTrainedModel, PretrainedConfig
4
 
5
# -----------------------------
# GroupNorm Helper
# -----------------------------
def GN(c, groups=16):
    """Build a GroupNorm over ``c`` channels, capping the group count at ``c``."""
    # GroupNorm requires num_groups <= num_channels; clamp when c is small.
    num_groups = groups if groups < c else c
    return nn.GroupNorm(num_groups, c)
10
 
11
+ # -----------------------------
12
+ # CNN Backbone
13
+ # -----------------------------
14
  class LightResNetCNN(nn.Module):
15
  def __init__(self, in_channels=1, adaptive_height=8):
16
  super().__init__()
17
  self.adaptive_height = adaptive_height
18
+ self.layer1 = nn.Sequential(
19
+ nn.Conv2d(in_channels, 32, 3, 1, 1), GN(32), nn.ReLU(), nn.MaxPool2d(2, 2)
20
+ )
21
+ self.layer2 = nn.Sequential(
22
+ nn.Conv2d(32, 64, 3, 1, 1), GN(64), nn.ReLU(), nn.MaxPool2d(2, 2)
23
+ )
24
+ self.layer3 = nn.Sequential(
25
+ nn.Conv2d(64, 128, 3, 1, 1), GN(128), nn.ReLU(), nn.MaxPool2d(2, 2)
26
+ )
27
  self.layer4 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), GN(256), nn.ReLU())
28
  self.layer5 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), GN(256), nn.ReLU())
29
  self.layer6 = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1), GN(128), nn.ReLU())
 
35
  x = self.adaptive_pool(x)
36
  return x
37
 
38
# -----------------------------
# Positional Encoding
# -----------------------------
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a [1, max_len, d_model] sin/cos table once and adds the
    appropriate prefix of it to the input on every forward pass.

    Args:
        d_model: embedding width; both even and odd widths are supported.
        max_len: maximum sequence length the table covers.
    """

    def __init__(self, d_model, max_len=2000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        # Frequencies 10000^(-2i/d_model) for each even dimension index 2i.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model the cos half has one column fewer than the sin half,
        # so truncate div_term to d_model // 2 entries to keep shapes aligned.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer (not a Parameter): moves with .to(device) but is not trained.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to ``x`` of shape [batch, seq_len, d_model]."""
        return x + self.pe[:, : x.size(1), :]
55
 
56
+ # -----------------------------
57
+ # Config
58
+ # -----------------------------
59
  class PersianOCRConfig(PretrainedConfig):
60
  model_type = "persianocr"
61
 
62
+ def __init__(
63
+ self,
64
+ num_classes=100,
65
+ d_model=1280,
66
+ nhead=16,
67
+ num_layers=8,
68
+ dropout=0.2,
69
+ adaptive_height=8,
70
+ **kwargs
71
+ ):
72
  super().__init__(**kwargs)
73
  self.num_classes = num_classes
74
  self.d_model = d_model
 
77
  self.dropout = dropout
78
  self.adaptive_height = adaptive_height
79
 
80
# -----------------------------
# Model
# -----------------------------
class PersianOCRModel(PreTrainedModel):
    """CNN feature extractor + Transformer encoder with a per-timestep
    classification head, wrapped as a HuggingFace ``PreTrainedModel``."""

    config_class = PersianOCRConfig

    def __init__(self, config):
        super().__init__(config)
        # Convolutional backbone over single-channel (grayscale) images.
        self.cnn = LightResNetCNN(in_channels=1, adaptive_height=config.adaptive_height)
        # Each image column (128 channels x adaptive_height rows) is flattened
        # and projected into the transformer width.
        self.proj = nn.Linear(128 * config.adaptive_height, config.d_model)
        self.posenc = PositionalEncoding(config.d_model)
        layer = nn.TransformerEncoderLayer(
            config.d_model, config.nhead, batch_first=True, dropout=config.dropout
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=config.num_layers)
        self.fc = nn.Linear(config.d_model, config.num_classes)

        # Important for HuggingFace: finalizes weight initialization.
        self.post_init()

    def forward(self, x, labels=None):
        """
        Args:
            x: Tensor [batch, 1, H, W] - grayscale input
            labels: optional, intended for CTC loss
        Returns:
            dict with logits
        """
        feats = self.cnn(x)  # [B, C, H, W]
        batch, channels, height, width = feats.size()
        # Fold every spatial column (C * H values) into one timestep so the
        # width axis becomes the sequence axis.
        seq = feats.permute(0, 3, 1, 2).reshape(batch, width, channels * height)
        seq = self.posenc(self.proj(seq))
        encoded = self.transformer(seq)
        logits = self.fc(encoded)  # [B, W, num_classes]

        # NOTE(review): `labels` is accepted but currently unused — no loss is
        # computed or returned; confirm whether CTC loss should be added here.
        return {"logits": logits}