farbodpya committed
Commit 28efaf9 · verified · Parent: 242f8cb

Update model.py

Files changed (1):
  1. model.py +11 -59
model.py CHANGED
@@ -2,28 +2,16 @@ import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
 
-# -----------------------------
-# GroupNorm Helper
-# -----------------------------
 def GN(c, groups=16):
     return nn.GroupNorm(min(groups, c), c)
 
-# -----------------------------
-# CNN Backbone
-# -----------------------------
 class LightResNetCNN(nn.Module):
     def __init__(self, in_channels=1, adaptive_height=8):
         super().__init__()
         self.adaptive_height = adaptive_height
-        self.layer1 = nn.Sequential(
-            nn.Conv2d(in_channels, 32, 3, 1, 1), GN(32), nn.ReLU(), nn.MaxPool2d(2, 2)
-        )
-        self.layer2 = nn.Sequential(
-            nn.Conv2d(32, 64, 3, 1, 1), GN(64), nn.ReLU(), nn.MaxPool2d(2, 2)
-        )
-        self.layer3 = nn.Sequential(
-            nn.Conv2d(64, 128, 3, 1, 1), GN(128), nn.ReLU(), nn.MaxPool2d(2, 2)
-        )
+        self.layer1 = nn.Sequential(nn.Conv2d(in_channels, 32, 3, 1, 1), GN(32), nn.ReLU(), nn.MaxPool2d(2, 2))
+        self.layer2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), GN(64), nn.ReLU(), nn.MaxPool2d(2, 2))
+        self.layer3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), GN(128), nn.ReLU(), nn.MaxPool2d(2, 2))
         self.layer4 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), GN(256), nn.ReLU())
         self.layer5 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), GN(256), nn.ReLU())
         self.layer6 = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1), GN(128), nn.ReLU())
@@ -35,40 +23,23 @@ class LightResNetCNN(nn.Module):
         x = self.adaptive_pool(x)
         return x
 
-# -----------------------------
-# Positional Encoding
-# -----------------------------
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model, max_len=2000):
         super().__init__()
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model)
-        )
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         self.register_buffer("pe", pe.unsqueeze(0))
 
     def forward(self, x):
-        return x + self.pe[:, : x.size(1), :]
+        return x + self.pe[:, :x.size(1), :]
 
-# -----------------------------
-# Config
-# -----------------------------
 class PersianOCRConfig(PretrainedConfig):
     model_type = "persianocr"
 
-    def __init__(
-        self,
-        num_classes=100,
-        d_model=1280,
-        nhead=16,
-        num_layers=8,
-        dropout=0.2,
-        adaptive_height=8,
-        **kwargs
-    ):
+    def __init__(self, num_classes=100, d_model=1280, nhead=16, num_layers=8, dropout=0.2, adaptive_height=8, **kwargs):
         super().__init__(**kwargs)
         self.num_classes = num_classes
         self.d_model = d_model
@@ -77,43 +48,24 @@ class PersianOCRConfig(PretrainedConfig):
         self.dropout = dropout
         self.adaptive_height = adaptive_height
 
-# -----------------------------
-# Model
-# -----------------------------
 class PersianOCRModel(PreTrainedModel):
     config_class = PersianOCRConfig
 
     def __init__(self, config):
         super().__init__(config)
-        self.cnn = LightResNetCNN(
-            in_channels=1, adaptive_height=config.adaptive_height
-        )
+        self.cnn = LightResNetCNN(in_channels=1, adaptive_height=config.adaptive_height)
         self.proj = nn.Linear(128 * config.adaptive_height, config.d_model)
         self.posenc = PositionalEncoding(config.d_model)
-        encoder_layer = nn.TransformerEncoderLayer(
-            config.d_model, config.nhead, batch_first=True, dropout=config.dropout
-        )
-        self.transformer = nn.TransformerEncoder(
-            encoder_layer, num_layers=config.num_layers
-        )
+        encoder_layer = nn.TransformerEncoderLayer(config.d_model, config.nhead, batch_first=True, dropout=config.dropout)
+        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
         self.fc = nn.Linear(config.d_model, config.num_classes)
-
-        # This line is very important for HuggingFace
         self.post_init()
 
     def forward(self, x, labels=None):
-        """
-        Args:
-            x: Tensor [batch, 1, H, W] - grayscale input
-            labels: optional, for CTC loss
-        Returns:
-            dict with logits
-        """
-        f = self.cnn(x)  # [B, C, H, W]
+        f = self.cnn(x)
         B, C, H, W = f.size()
         f = f.permute(0, 3, 1, 2).reshape(B, W, C * H)
         f = self.posenc(self.proj(f))
         out = self.transformer(f)
-        logits = self.fc(out)  # [B, W, num_classes]
-
+        logits = self.fc(out)
         return {"logits": logits}
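The commit changes formatting only: it collapses multi-line calls onto single lines and drops the banner comments and docstring, so the model's behavior is unchanged. A minimal smoke test for the resulting model.py, run from the same directory; the 64×256 input size is an illustrative assumption, and it assumes the adaptive pooling step (defined outside this diff) fixes the feature height at adaptive_height while leaving the width alone:

import torch
from model import PersianOCRConfig, PersianOCRModel

config = PersianOCRConfig()            # defaults: num_classes=100, d_model=1280, nhead=16
model = PersianOCRModel(config).eval()

# Grayscale text-line image: [batch, 1, height, width]. The three
# MaxPool2d(2, 2) stages in layer1-layer3 divide the width by 8,
# so a 256-pixel-wide line yields 32 time steps.
x = torch.randn(1, 1, 64, 256)
with torch.no_grad():
    out = model(x)

print(out["logits"].shape)  # expected: torch.Size([1, 32, 100]) -> [batch, time_steps, num_classes]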
 
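The deleted docstring mentioned CTC loss, and forward() still accepts a labels argument that this version never uses, so the loss is presumably computed in the training loop. A hedged sketch of how that wiring could look with torch.nn.CTCLoss, continuing from the smoke test above; the blank index of 0 and the dummy 10-token targets are assumptions for illustration, not the repository's actual training code:

import torch
import torch.nn as nn

ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)  # assumes class 0 is the CTC blank

logits = out["logits"]                  # [B, T, num_classes] from the forward pass above
log_probs = logits.log_softmax(-1)      # CTCLoss expects log-probabilities...
log_probs = log_probs.permute(1, 0, 2)  # ...in [T, B, C] layout

B, T = logits.shape[:2]
targets = torch.randint(1, config.num_classes, (B, 10))       # dummy label ids; 0 reserved for blank
input_lengths = torch.full((B,), T, dtype=torch.long)         # every sample uses all T time steps
target_lengths = torch.full((B,), targets.size(1), dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)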