farbodpya committed
Commit 540d190 · verified · 1 Parent(s): db96333

Update README.md

Files changed (1):
  1. README.md +102 -7
README.md CHANGED
@@ -24,26 +24,114 @@ trained with CTC loss to extract text from images.
  ## Installation

  ```bash
- pip install torch torchvision huggingface_hub
+ !pip install torch torchvision huggingface_hub
  ```

  ## Usage Example

+ import json
  import torch
+ import torch.nn as nn
  from PIL import Image
  import torchvision.transforms.functional as TF
  import cv2
  import matplotlib.pyplot as plt
  from huggingface_hub import hf_hub_download

- # --- Load model from HF ---
+ # -----------------------------
+ # 1️⃣ Device
+ # -----------------------------
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # -----------------------------
+ # 2️⃣ Load vocab
+ # -----------------------------
+ vocab_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="vocab.json")
+ with open(vocab_path, "r", encoding="utf-8") as f:
+     vocab = json.load(f)
+ char_to_idx = vocab["char_to_idx"]
+ idx_to_char = {int(k): v for k, v in vocab["idx_to_char"].items()}
+
+ # -----------------------------
+ # 3️⃣ Model definition
+ # -----------------------------
+ def GN(c, groups=16): return nn.GroupNorm(min(groups, c), c)
+
+ class LightResNetCNN(nn.Module):
+     def __init__(self, in_channels=1, adaptive_height=8):
+         super().__init__()
+         self.adaptive_height = adaptive_height
+         self.layer1 = nn.Sequential(nn.Conv2d(in_channels, 32, 3, 1, 1), GN(32), nn.ReLU(), nn.MaxPool2d(2, 2))
+         self.layer2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), GN(64), nn.ReLU(), nn.MaxPool2d(2, 2))
+         self.layer3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), GN(128), nn.ReLU(), nn.MaxPool2d(2, 2))
+         self.layer4 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), GN(256), nn.ReLU())
+         self.layer5 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), GN(256), nn.ReLU())
+         self.layer6 = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1), GN(128), nn.ReLU())
+         self.adaptive_pool = nn.AdaptiveAvgPool2d((self.adaptive_height, None))
+     def forward(self, x):
+         for i in range(1, 7):
+             x = getattr(self, f"layer{i}")(x)
+         x = self.adaptive_pool(x)
+         return x
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=2000):
+         super().__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer("pe", pe.unsqueeze(0))
+     def forward(self, x):
+         return x + self.pe[:, :x.size(1), :]
+
+ class CNN_Transformer_OCR(nn.Module):
+     def __init__(self, num_classes, d_model=1280, nhead=16, num_layers=8, dropout=0.2):
+         super().__init__()
+         self.cnn = LightResNetCNN(in_channels=1, adaptive_height=8)
+         self.proj = nn.Linear(128 * 8, d_model)
+         self.posenc = PositionalEncoding(d_model)
+         encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True, dropout=dropout)
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+         self.fc = nn.Linear(d_model, num_classes)
+     def forward(self, x):
+         f = self.cnn(x)
+         B, C, H, W = f.size()
+         f = f.permute(0, 3, 1, 2).reshape(B, W, C * H)
+         f = self.posenc(self.proj(f))
+         out = self.transformer(f)
+         out = self.fc(out)
+         return out.log_softmax(2)
+
+ # -----------------------------
+ # 4️⃣ Load model weights
+ # -----------------------------
  model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
  model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1).to(device)
  model.load_state_dict(torch.load(model_path, map_location=device))
  model.eval()

- # --- Define transforms ---
+ # -----------------------------
+ # 5️⃣ Greedy decoder
+ # -----------------------------
+ def greedy_decode(output, idx_to_char):
+     output = output.argmax(2)
+     texts = []
+     for seq in output:
+         prev = -1
+         chars = []
+         for idx in seq.cpu().numpy():
+             if idx != prev and idx != 0:
+                 chars.append(idx_to_char.get(idx, ""))
+             prev = idx
+         texts.append("".join(chars))
+     return texts
+
+ # -----------------------------
+ # 6️⃣ Transforms
+ # -----------------------------
  class OCRTestTransform:
      def __init__(self, img_height=64, max_width=1600):
          self.img_height = img_height
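The hunk cuts off inside `OCRTestTransform`; the rest of the class body is elided by the diff. A test-time transform with an `img_height` and a `max_width` typically resizes each line image to the fixed height while preserving aspect ratio, then caps the width. A minimal sketch under that assumption — the resampling choice and width policy here are guesses, not the repository's code:

```python
# Hypothetical sketch of an OCRTestTransform-style callable; the README
# elides the real body, so everything below __init__ is an assumption.
import torchvision.transforms.functional as TF
from PIL import Image

class OCRTestTransformSketch:
    def __init__(self, img_height=64, max_width=1600):
        self.img_height = img_height
        self.max_width = max_width

    def __call__(self, img: Image.Image):
        w, h = img.size
        # Scale to the fixed height, keeping the aspect ratio, capped at max_width
        new_w = min(self.max_width, max(1, round(w * self.img_height / h)))
        img = img.resize((new_w, self.img_height), Image.BILINEAR)
        return TF.to_tensor(img)  # (1, H, W) float tensor in [0, 1]
```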
@@ -61,7 +149,9 @@ class OCRTestTransform:

  transform_test = OCRTestTransform()

- # --- Line segmentation ---
+ # -----------------------------
+ # 7️⃣ Line segmentation
+ # -----------------------------
  def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
      img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
      _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
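The diff elides the middle of `segment_lines_precise`. Starting from the Otsu-binarized image shown above, a standard way to obtain per-line crops is a horizontal projection profile: sum the ink in each row and split wherever the profile goes to zero. A sketch under that assumption (the exact thresholds and bookkeeping are illustrative, not the author's code):

```python
# Horizontal-projection line segmentation sketch, in the spirit of
# segment_lines_precise; details beyond the visible three lines are assumed.
import cv2
import numpy as np

def segment_lines_sketch(image_path, min_line_height=12, margin=6):
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    row_ink = binary.sum(axis=1)          # rows containing text have ink > 0
    lines, start = [], None
    for y, ink in enumerate(row_ink):
        if ink > 0 and start is None:
            start = y                     # a text band begins
        elif ink == 0 and start is not None:
            if y - start >= min_line_height:   # ignore specks shorter than a line
                top = max(0, start - margin)
                bot = min(img.shape[0], y + margin)
                lines.append(img[top:bot, :])
            start = None
    if start is not None:                 # band running to the bottom edge
        lines.append(img[max(0, start - margin):, :])
    return lines
```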
@@ -86,7 +176,9 @@ def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
          plt.show()
      return lines

- # --- OCR function ---
+ # -----------------------------
+ # 8️⃣ OCR function
+ # -----------------------------
  def ocr_page(image_path, visualize=False):
      lines = segment_lines_precise(image_path, visualize=visualize)
      all_texts = []
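The inner loop of `ocr_page` (old lines 93–98) is likewise elided by the hunk boundaries. Based on the pieces the README does show — `transform_test`, `model`, `greedy_decode`, and the `Line {idx}` print — the loop plausibly looks like the sketch below; it assumes the objects defined earlier in the README, and the per-line preprocessing is a guess:

```python
# Hedged sketch of ocr_page's elided per-line loop. Relies on transform_test,
# model, device, greedy_decode, and idx_to_char from the README snippet above.
import torch
from PIL import Image

def ocr_lines_sketch(lines):
    all_texts = []
    for idx, line_img in enumerate(lines, start=1):
        pil_line = Image.fromarray(line_img)                   # ndarray -> PIL image
        x = transform_test(pil_line).unsqueeze(0).to(device)   # (1, 1, H, W)
        with torch.no_grad():
            log_probs = model(x)                               # (1, T, num_classes)
        pred_text = greedy_decode(log_probs, idx_to_char)[0]
        all_texts.append(pred_text)
        print(f"Line {idx}: {pred_text}")
    return "\n".join(all_texts)
```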
@@ -99,7 +191,10 @@ def ocr_page(image_path, visualize=False):
          print(f"Line {idx}: {pred_text}")
      return "\n".join(all_texts)

- # --- Example ---
- img_path = "example.png"
+ # -----------------------------
+ # 9️⃣ Example usage
+ # -----------------------------
+ img_path = "example.png"  # put your own image path here
  final_text = ocr_page(img_path, visualize=True)
  print("\n=== Final OCR Page ===\n", final_text)
+
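It can also help to sanity-check the tensor shapes the new model produces. Three 2×2 max-pools divide the input width by 8 and the adaptive pool fixes the feature height at 8, so a 64×800 line image becomes 100 time steps of 128 × 8 = 1024-dim features, projected to `d_model=1280`. A quick check, assuming the classes defined in the README above (the class count of 100 is arbitrary):

```python
# Shape walk-through of the encoder added in this commit; assumes
# LightResNetCNN and CNN_Transformer_OCR from the README are in scope.
import torch

x = torch.randn(1, 1, 64, 800)                # (B, C, H, W) grayscale line image
cnn = LightResNetCNN()
f = cnn(x)
print(f.shape)                                # torch.Size([1, 128, 8, 100])

model = CNN_Transformer_OCR(num_classes=100)  # hypothetical class count
out = model(x)
print(out.shape)                              # torch.Size([1, 100, 100]) = (B, T, num_classes)
```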