farbodpya commited on
Commit
e0399f1
·
verified ·
1 Parent(s): 16ec975

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +47 -29
README.md CHANGED
@@ -30,41 +30,33 @@ trained with CTC loss to extract text from images.
30
 
31
  ## Usage Example
32
 
33
- ```python
 
34
  import torch
 
35
  from PIL import Image
36
  import torchvision.transforms.functional as TF
37
  import cv2
38
  import matplotlib.pyplot as plt
39
  from huggingface_hub import hf_hub_download
40
- import torch.nn as nn
41
- import json
42
 
43
- # --- Device ---
 
 
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
 
46
- # --- Load vocab ---
 
 
47
  vocab_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="vocab.json")
48
  with open(vocab_path, "r", encoding="utf-8") as f:
49
- idx_to_char = json.load(f)
50
- idx_to_char = {int(k): v for k, v in idx_to_char.items()}
51
- char_to_idx = {v: k for k, v in idx_to_char.items()}
52
 
53
- # --- Greedy decoder ---
54
- def greedy_decode(logits, idx_to_char):
55
- pred = logits.argmax(-1).cpu().numpy()
56
- texts = []
57
- for seq in pred:
58
- prev = -1
59
- text = ""
60
- for p in seq:
61
- if p != prev and p != len(idx_to_char): # CTC blank
62
- text += idx_to_char[p]
63
- prev = p
64
- texts.append(text)
65
- return texts
66
-
67
- # --- Define CNN + Transformer Model ---
68
  def GN(c, groups=16): return nn.GroupNorm(min(groups, c), c)
69
 
70
  class LightResNetCNN(nn.Module):
@@ -114,13 +106,33 @@ class CNN_Transformer_OCR(nn.Module):
114
  out = self.fc(out)
115
  return out.log_softmax(2)
116
 
117
- # --- Load pretrained model ---
 
 
118
  model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
119
  model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1).to(device)
120
  model.load_state_dict(torch.load(model_path, map_location=device))
121
  model.eval()
122
 
123
- # --- Define transforms ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  class OCRTestTransform:
125
  def __init__(self, img_height=64, max_width=1600):
126
  self.img_height = img_height
@@ -138,7 +150,9 @@ class OCRTestTransform:
138
 
139
  transform_test = OCRTestTransform()
140
 
141
- # --- Line segmentation ---
 
 
142
  def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
143
  img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
144
  _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
@@ -163,7 +177,9 @@ def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=Fa
163
  plt.show()
164
  return lines
165
 
166
- # --- OCR function ---
 
 
167
  def ocr_page(image_path, visualize=False):
168
  lines = segment_lines_precise(image_path, visualize=visualize)
169
  all_texts = []
@@ -176,7 +192,9 @@ def ocr_page(image_path, visualize=False):
176
  print(f"Line {idx}: {pred_text}")
177
  return "\n".join(all_texts)
178
 
179
- # --- Example ---
180
- img_path = "example.png"
 
 
181
  final_text = ocr_page(img_path, visualize=True)
182
  print("\n=== Final OCR Page ===\n", final_text)
 
30
 
31
  ## Usage Example
32
 
33
+
34
+ import json
35
  import torch
36
+ import torch.nn as nn
37
  from PIL import Image
38
  import torchvision.transforms.functional as TF
39
  import cv2
40
  import matplotlib.pyplot as plt
41
  from huggingface_hub import hf_hub_download
 
 
42
 
43
+ # -----------------------------
44
+ # 1️⃣ Device
45
+ # -----------------------------
46
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
 
48
+ # -----------------------------
49
+ # 2️⃣ Load vocab
50
+ # -----------------------------
51
  vocab_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="vocab.json")
52
  with open(vocab_path, "r", encoding="utf-8") as f:
53
+ vocab = json.load(f)
54
+ char_to_idx = vocab["char_to_idx"]
55
+ idx_to_char = {int(k): v for k, v in vocab["idx_to_char"].items()}
56
 
57
+ # -----------------------------
58
+ # 3️⃣ Model definition
59
+ # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
60
  def GN(c, groups=16): return nn.GroupNorm(min(groups, c), c)
61
 
62
  class LightResNetCNN(nn.Module):
 
106
  out = self.fc(out)
107
  return out.log_softmax(2)
108
 
109
+ # -----------------------------
110
+ # 4️⃣ Load model weights
111
+ # -----------------------------
112
  model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
113
  model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1).to(device)
114
  model.load_state_dict(torch.load(model_path, map_location=device))
115
  model.eval()
116
 
117
+ # -----------------------------
118
+ # 5️⃣ Greedy decoder
119
+ # -----------------------------
120
+ def greedy_decode(output, idx_to_char):
121
+ output = output.argmax(2)
122
+ texts = []
123
+ for seq in output:
124
+ prev = -1
125
+ chars = []
126
+ for idx in seq.cpu().numpy():
127
+ if idx != prev and idx != 0:
128
+ chars.append(idx_to_char.get(idx, ""))
129
+ prev = idx
130
+ texts.append("".join(chars))
131
+ return texts
132
+
133
+ # -----------------------------
134
+ # 6️⃣ Transforms
135
+ # -----------------------------
136
  class OCRTestTransform:
137
  def __init__(self, img_height=64, max_width=1600):
138
  self.img_height = img_height
 
150
 
151
  transform_test = OCRTestTransform()
152
 
153
+ # -----------------------------
154
+ # 7️⃣ Line segmentation
155
+ # -----------------------------
156
  def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
157
  img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
158
  _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
 
177
  plt.show()
178
  return lines
179
 
180
+ # -----------------------------
181
+ # 8️⃣ OCR function
182
+ # -----------------------------
183
  def ocr_page(image_path, visualize=False):
184
  lines = segment_lines_precise(image_path, visualize=visualize)
185
  all_texts = []
 
192
  print(f"Line {idx}: {pred_text}")
193
  return "\n".join(all_texts)
194
 
195
+ # -----------------------------
196
+ # 9️⃣ Example usage
197
+ # -----------------------------
198
+ img_path = "/content/farsi_line.png" # put your own image path here
199
  final_text = ocr_page(img_path, visualize=True)
200
  print("\n=== Final OCR Page ===\n", final_text)