farbodpya committed on
Commit
ba0b33a
·
verified ·
1 Parent(s): 096ba65

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +79 -1
README.md CHANGED
@@ -19,4 +19,82 @@ trained with CTC loss to extract text from images.
19
  ## Installation
20
 
21
  ```bash
22
- pip install torch torchvision huggingface_hub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  ## Installation
20
 
21
  ```bash
22
+ pip install torch torchvision huggingface_hub
+ ```
23
+
24
+
25
+
26
+ ## Usage Example
27
+
+ > Note: this snippet assumes `CNN_Transformer_OCR`, `idx_to_char`, `device`, and `greedy_decode` are already defined (see the model repository for their definitions).
+
+ ```python
28
+ import torch
29
+ from PIL import Image
30
+ import torchvision.transforms.functional as TF
31
+ import cv2
32
+ import matplotlib.pyplot as plt
33
+ from huggingface_hub import hf_hub_download
34
+
35
+ # --- Load model from HF ---
36
+ model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
37
+ model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1).to(device)
38
+ model.load_state_dict(torch.load(model_path, map_location=device))
39
+ model.eval()
40
+
41
+ # --- Define transforms ---
42
+ class OCRTestTransform:
43
+ def __init__(self, img_height=64, max_width=1600):
44
+ self.img_height = img_height
45
+ self.max_width = max_width
46
+ def __call__(self, img):
47
+ img = img.convert("L")
48
+ w, h = img.size
49
+ new_w = int(w * self.img_height / h)
50
+ img = img.resize((min(new_w, self.max_width), self.img_height), Image.BICUBIC)
51
+ new_img = Image.new("L", (self.max_width, self.img_height), 255)
52
+ new_img.paste(img, (0, 0))
53
+ img = TF.to_tensor(new_img)
54
+ img = TF.normalize(img, (0.5,), (0.5,))
55
+ return img
56
+
57
+ transform_test = OCRTestTransform()
58
+
59
+ # --- Line segmentation ---
60
+ def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
61
+ img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
62
+ _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
63
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (img.shape[1]//30, 1))
64
+ morphed = cv2.dilate(binary, kernel, iterations=1)
65
+ contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
66
+ contours = sorted(contours, key=lambda ctr: cv2.boundingRect(ctr)[1])
67
+ lines = []
68
+ for ctr in contours:
69
+ x, y, w, h = cv2.boundingRect(ctr)
70
+ if h < min_line_height: continue
71
+ y1 = max(0, y - margin)
72
+ y2 = min(img.shape[0], y + h + margin)
73
+ line_img = img[y1:y2, x:x+w]
74
+ lines.append(Image.fromarray(line_img))
75
+ if visualize:
76
+ for i, line_img in enumerate(lines):
77
+ plt.figure(figsize=(12,2))
78
+ plt.imshow(line_img, cmap='gray')
79
+ plt.axis('off')
80
+ plt.title(f"Line {i+1}")
81
+ plt.show()
82
+ return lines
83
+
84
+ # --- OCR function ---
85
+ def ocr_page(image_path, visualize=False):
86
+ lines = segment_lines_precise(image_path, visualize=visualize)
87
+ all_texts = []
88
+ for idx, line_img in enumerate(lines, 1):
89
+ img_tensor = transform_test(line_img).unsqueeze(0).to(device)
90
+ with torch.no_grad():
91
+ outputs = model(img_tensor)
92
+ pred_text = greedy_decode(outputs, idx_to_char)[0]
93
+ all_texts.append(pred_text)
94
+ print(f"Line {idx}: {pred_text}")
95
+ return "\n".join(all_texts)
96
+
97
+ # --- Example ---
98
+ img_path = "example.png"
99
+ final_text = ocr_page(img_path, visualize=True)
100
+ print("\n=== Final OCR Page ===\n", final_text)
+ ```