farbodpya commited on
Commit
16ec975
·
verified ·
1 Parent(s): 540d190

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +29 -47
README.md CHANGED
@@ -30,32 +30,41 @@ trained with CTC loss to extract text from images.
30
 
31
  ## Usage Example
32
 
33
- import json
34
  import torch
35
- import torch.nn as nn
36
  from PIL import Image
37
  import torchvision.transforms.functional as TF
38
  import cv2
39
  import matplotlib.pyplot as plt
40
  from huggingface_hub import hf_hub_download
 
 
41
 
42
- # -----------------------------
43
- # 1️⃣ Device
44
- # -----------------------------
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
 
47
- # -----------------------------
48
- # 2️⃣ Load vocab
49
- # -----------------------------
50
  vocab_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="vocab.json")
51
  with open(vocab_path, "r", encoding="utf-8") as f:
52
- vocab = json.load(f)
53
- char_to_idx = vocab["char_to_idx"]
54
- idx_to_char = {int(k): v for k, v in vocab["idx_to_char"].items()}
55
 
56
- # -----------------------------
57
- # 3️⃣ Model definition
58
- # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
59
  def GN(c, groups=16): return nn.GroupNorm(min(groups, c), c)
60
 
61
  class LightResNetCNN(nn.Module):
@@ -105,33 +114,13 @@ class CNN_Transformer_OCR(nn.Module):
105
  out = self.fc(out)
106
  return out.log_softmax(2)
107
 
108
- # -----------------------------
109
- # 4️⃣ Load model weights
110
- # -----------------------------
111
  model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
112
  model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1).to(device)
113
  model.load_state_dict(torch.load(model_path, map_location=device))
114
  model.eval()
115
 
116
- # -----------------------------
117
- # 5️⃣ Greedy decoder
118
- # -----------------------------
119
- def greedy_decode(output, idx_to_char):
120
- output = output.argmax(2)
121
- texts = []
122
- for seq in output:
123
- prev = -1
124
- chars = []
125
- for idx in seq.cpu().numpy():
126
- if idx != prev and idx != 0:
127
- chars.append(idx_to_char.get(idx, ""))
128
- prev = idx
129
- texts.append("".join(chars))
130
- return texts
131
-
132
- # -----------------------------
133
- # 6️⃣ Transforms
134
- # -----------------------------
135
  class OCRTestTransform:
136
  def __init__(self, img_height=64, max_width=1600):
137
  self.img_height = img_height
@@ -149,9 +138,7 @@ class OCRTestTransform:
149
 
150
  transform_test = OCRTestTransform()
151
 
152
- # -----------------------------
153
- # 7️⃣ Line segmentation
154
- # -----------------------------
155
  def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
156
  img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
157
  _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
@@ -176,9 +163,7 @@ def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=Fa
176
  plt.show()
177
  return lines
178
 
179
- # -----------------------------
180
- # 8️⃣ OCR function
181
- # -----------------------------
182
  def ocr_page(image_path, visualize=False):
183
  lines = segment_lines_precise(image_path, visualize=visualize)
184
  all_texts = []
@@ -191,10 +176,7 @@ def ocr_page(image_path, visualize=False):
191
  print(f"Line {idx}: {pred_text}")
192
  return "\n".join(all_texts)
193
 
194
- # -----------------------------
195
- # 9️⃣ Example usage
196
- # -----------------------------
197
- img_path = "example.png" # put your own image path here
198
  final_text = ocr_page(img_path, visualize=True)
199
  print("\n=== Final OCR Page ===\n", final_text)
200
-
 
30
 
31
  ## Usage Example
32
 
33
+ ```python
34
  import torch
 
35
  from PIL import Image
36
  import torchvision.transforms.functional as TF
37
  import cv2
38
  import matplotlib.pyplot as plt
39
  from huggingface_hub import hf_hub_download
40
+ import torch.nn as nn
41
+ import json
42
 
43
+ # --- Device ---
 
 
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
 
46
+ # --- Load vocab ---
 
 
47
  vocab_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="vocab.json")
48
  with open(vocab_path, "r", encoding="utf-8") as f:
49
+ idx_to_char = json.load(f)
50
+ idx_to_char = {int(k): v for k, v in idx_to_char.items()}
51
+ char_to_idx = {v: k for k, v in idx_to_char.items()}
52
 
53
+ # --- Greedy decoder ---
54
+ def greedy_decode(logits, idx_to_char):
55
+ pred = logits.argmax(-1).cpu().numpy()
56
+ texts = []
57
+ for seq in pred:
58
+ prev = -1
59
+ text = ""
60
+ for p in seq:
61
+ if p != prev and p != len(idx_to_char): # CTC blank
62
+ text += idx_to_char[p]
63
+ prev = p
64
+ texts.append(text)
65
+ return texts
66
+
67
+ # --- Define CNN + Transformer Model ---
68
  def GN(c, groups=16): return nn.GroupNorm(min(groups, c), c)
69
 
70
  class LightResNetCNN(nn.Module):
 
114
  out = self.fc(out)
115
  return out.log_softmax(2)
116
 
117
+ # --- Load pretrained model ---
 
 
118
  model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
119
  model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1).to(device)
120
  model.load_state_dict(torch.load(model_path, map_location=device))
121
  model.eval()
122
 
123
+ # --- Define transforms ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  class OCRTestTransform:
125
  def __init__(self, img_height=64, max_width=1600):
126
  self.img_height = img_height
 
138
 
139
  transform_test = OCRTestTransform()
140
 
141
+ # --- Line segmentation ---
 
 
142
  def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
143
  img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
144
  _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
 
163
  plt.show()
164
  return lines
165
 
166
+ # --- OCR function ---
 
 
167
  def ocr_page(image_path, visualize=False):
168
  lines = segment_lines_precise(image_path, visualize=visualize)
169
  all_texts = []
 
176
  print(f"Line {idx}: {pred_text}")
177
  return "\n".join(all_texts)
178
 
179
+ # --- Example ---
180
+ img_path = "example.png"
 
 
181
  final_text = ocr_page(img_path, visualize=True)
182
  print("\n=== Final OCR Page ===\n", final_text)