farbodpya commited on
Commit
c56b18e
·
verified ·
1 Parent(s): 0c3af99

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +137 -74
inference.py CHANGED
@@ -1,74 +1,137 @@
1
- import torch
2
- from PIL import Image
3
- import torchvision.transforms.functional as TF
4
- import cv2
5
- import matplotlib.pyplot as plt
6
- from huggingface_hub import hf_hub_download
7
-
8
-
9
- # --- Load model from HF ---
10
- model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
11
- model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1).to(device)
12
- model.load_state_dict(torch.load(model_path, map_location=device))
13
- model.eval()
14
-
15
- # --- Define transforms ---
16
- class OCRTestTransform:
17
- def __init__(self, img_height=64, max_width=1600):
18
- self.img_height = img_height
19
- self.max_width = max_width
20
- def __call__(self, img):
21
- img = img.convert("L")
22
- w, h = img.size
23
- new_w = int(w * self.img_height / h)
24
- img = img.resize((min(new_w, self.max_width), self.img_height), Image.BICUBIC)
25
- new_img = Image.new("L", (self.max_width, self.img_height), 255)
26
- new_img.paste(img, (0, 0))
27
- img = TF.to_tensor(new_img)
28
- img = TF.normalize(img, (0.5,), (0.5,))
29
- return img
30
-
31
- transform_test = OCRTestTransform()
32
-
33
- # --- Line segmentation ---
34
- def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
35
- img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
36
- _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
37
- kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (img.shape[1]//30, 1))
38
- morphed = cv2.dilate(binary, kernel, iterations=1)
39
- contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
40
- contours = sorted(contours, key=lambda ctr: cv2.boundingRect(ctr)[1])
41
- lines = []
42
- for ctr in contours:
43
- x, y, w, h = cv2.boundingRect(ctr)
44
- if h < min_line_height: continue
45
- y1 = max(0, y - margin)
46
- y2 = min(img.shape[0], y + h + margin)
47
- line_img = img[y1:y2, x:x+w]
48
- lines.append(Image.fromarray(line_img))
49
- if visualize:
50
- for i, line_img in enumerate(lines):
51
- plt.figure(figsize=(12,2))
52
- plt.imshow(line_img, cmap='gray')
53
- plt.axis('off')
54
- plt.title(f"Line {i+1}")
55
- plt.show()
56
- return lines
57
-
58
- # --- OCR function ---
59
- def ocr_page(image_path, visualize=False):
60
- lines = segment_lines_precise(image_path, visualize=visualize)
61
- all_texts = []
62
- for idx, line_img in enumerate(lines, 1):
63
- img_tensor = transform_test(line_img).unsqueeze(0).to(device)
64
- with torch.no_grad():
65
- outputs = model(img_tensor)
66
- pred_text = greedy_decode(outputs, idx_to_char)[0]
67
- all_texts.append(pred_text)
68
- print(f"Line {idx}: {pred_text}")
69
- return "\n".join(all_texts)
70
-
71
- # --- Example ---
72
- img_path = "example.png"
73
- final_text = ocr_page(img_path, visualize=True)
74
- print("\n=== Final OCR Page ===\n", final_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import torchvision.transforms.functional as TF
4
+ import cv2
5
+ import matplotlib.pyplot as plt
6
+ from huggingface_hub import hf_hub_download
7
+ import torch.nn as nn
8
+
9
# -----------------------------
# Device selection
# -----------------------------
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
13
+
14
# -----------------------------
# Vocabulary
# -----------------------------
import json

# vocab.json ships with the checkpoint and stores both direction mappings.
# JSON object keys are always strings, so the index->char keys are cast
# back to int before being used as lookup keys.
vocab_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="vocab.json")
with open(vocab_path, "r", encoding="utf-8") as fh:
    vocab = json.load(fh)

char_to_idx = vocab["char_to_idx"]
idx_to_char = {int(key): ch for key, ch in vocab["idx_to_char"].items()}
24
+
25
# -----------------------------
# Model
# -----------------------------
class CNN_Transformer_OCR(nn.Module):
    """CNN feature extractor followed by a Transformer encoder for line OCR.

    Expects (B, 1, 64, W) grayscale line images: the projection layer is
    sized as 64 channels * 16 rows, which requires a 64-pixel input height
    (both spatial dims shrink 4x through the two pooling stages).
    Returns (B, W/4, num_classes) per-timestep log-probabilities
    (CTC-style; index 0 is treated as the blank by the decoder).
    """

    def __init__(self, num_classes, d_model=1280, nhead=16, num_layers=8, dropout=0.2):
        super().__init__()
        # Two conv/pool stages: 1 -> 32 -> 64 channels, H and W each halve twice.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Flatten each column of the feature map (64 channels * 16 rows)
        # into one sequence-step embedding of size d_model.
        self.proj = nn.Linear(64 * 16, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True, dropout=dropout)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        feats = self.cnn(x)
        b, c, h, w = feats.size()
        # Treat each width position as one time step: (B, W, C*H).
        seq = feats.permute(0, 3, 1, 2).reshape(b, w, c * h)
        encoded = self.transformer(self.proj(seq))
        return self.fc(encoded).log_softmax(2)
47
+
48
# -----------------------------
# Model weights
# -----------------------------
# Download the checkpoint from the Hub and restore it. num_classes is the
# vocabulary size plus one extra slot — presumably the CTC blank at index 0,
# which greedy_decode skips.
model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
model = CNN_Transformer_OCR(num_classes=len(idx_to_char) + 1).to(device)
# weights_only=True (torch >= 1.13) refuses to unpickle arbitrary objects
# from the downloaded file; a plain state_dict loads fine under this mode.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()
53
+
54
# -----------------------------
# Decode function
# -----------------------------
def greedy_decode(output, idx_to_char):
    """CTC greedy decoding.

    Takes the per-step argmax of `output` (B, T, C), collapses consecutive
    repeats, drops the blank (index 0), and maps the surviving indices
    through `idx_to_char`. Returns one decoded string per batch element.
    """
    best_path = output.argmax(2)
    decoded = []
    for row in best_path:
        pieces = []
        last = -1
        for cur in row.cpu().tolist():
            # Emit a symbol only on a transition to a new, non-blank class.
            if cur != 0 and cur != last:
                pieces.append(idx_to_char.get(cur, ""))
            last = cur
        decoded.append("".join(pieces))
    return decoded
69
+
70
# -----------------------------
# Transform
# -----------------------------
class OCRTestTransform:
    """Preprocess a PIL line image for inference.

    Converts to grayscale, rescales to a fixed height preserving aspect
    ratio (capped at max_width), pads the right side with white up to
    max_width, and returns a tensor normalized to [-1, 1].
    """

    def __init__(self, img_height=64, max_width=1600):
        self.img_height = img_height
        self.max_width = max_width

    def __call__(self, img):
        gray = img.convert("L")
        w, h = gray.size
        scaled_w = min(int(w * self.img_height / h), self.max_width)
        gray = gray.resize((scaled_w, self.img_height), Image.BICUBIC)
        # White canvas of the full target width; the line sits flush left.
        canvas = Image.new("L", (self.max_width, self.img_height), 255)
        canvas.paste(gray, (0, 0))
        tensor = TF.to_tensor(canvas)
        return TF.normalize(tensor, (0.5,), (0.5,))


transform_test = OCRTestTransform()
89
+
90
# -----------------------------
# Line segmentation
# -----------------------------
def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
    """Split a page image into horizontal text-line crops.

    Binarizes with Otsu, dilates with a wide horizontal kernel so the glyphs
    of one line merge into a single contour, then crops each contour's
    bounding box (with a vertical margin) out of the original grayscale page.

    Args:
        image_path: path to the page image on disk.
        min_line_height: contours shorter than this are discarded as noise.
        margin: extra pixels kept above and below each detected line.
        visualize: if True, display each crop with matplotlib.

    Returns:
        List of PIL grayscale images, one per detected line, top-to-bottom.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread silently returns None on a missing/unreadable file,
        # which would otherwise surface as a cryptic cv2.threshold error.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Kernel width scales with the page; clamp to >= 1 so very narrow
    # images do not produce an invalid 0-wide structuring element.
    kernel_w = max(1, img.shape[1] // 30)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_w, 1))
    morphed = cv2.dilate(binary, kernel, iterations=1)
    contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Sort by the contour's top y so lines come out in reading order.
    contours = sorted(contours, key=lambda ctr: cv2.boundingRect(ctr)[1])
    lines = []
    for ctr in contours:
        x, y, w, h = cv2.boundingRect(ctr)
        if h < min_line_height:
            continue
        y1 = max(0, y - margin)
        y2 = min(img.shape[0], y + h + margin)
        lines.append(Image.fromarray(img[y1:y2, x:x + w]))
    if visualize:
        for i, line_img in enumerate(lines):
            plt.figure(figsize=(12, 2))
            plt.imshow(line_img, cmap='gray')
            plt.axis('off')
            plt.title(f"Line {i+1}")
            plt.show()
    return lines
116
+
117
# -----------------------------
# OCR page function
# -----------------------------
def ocr_page(image_path, visualize=False):
    """Run OCR over a full page image.

    Segments the page into line crops, feeds each through the model, and
    CTC-decodes the result. Prints each line's prediction as it is produced
    and returns the whole page as newline-joined text.
    """
    texts = []
    line_crops = segment_lines_precise(image_path, visualize=visualize)
    for line_no, crop in enumerate(line_crops, 1):
        batch = transform_test(crop).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(batch)
        text = greedy_decode(logits, idx_to_char)[0]
        texts.append(text)
        print(f"Line {line_no}: {text}")
    return "\n".join(texts)
131
+
132
# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
    # Guarded so that importing this module does not immediately run OCR
    # (and pop matplotlib windows) as a side effect.
    img_path = "example.png"
    final_text = ocr_page(img_path, visualize=True)
    print("\n=== Final OCR Page ===\n", final_text)