Update README.md
Browse files
README.md
CHANGED
|
@@ -31,170 +31,50 @@ trained with CTC loss to extract text from images.
|
|
| 31 |
## Usage Example
|
| 32 |
|
| 33 |
|
| 34 |
-
import json
|
| 35 |
import torch
|
| 36 |
-
import
|
| 37 |
-
|
| 38 |
-
import
|
| 39 |
-
import cv2
|
| 40 |
-
import matplotlib.pyplot as plt
|
| 41 |
from huggingface_hub import hf_hub_download
|
| 42 |
|
| 43 |
# -----------------------------
|
| 44 |
-
# 1️⃣
|
| 45 |
-
# -----------------------------
|
| 46 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 47 |
-
|
| 48 |
-
# -----------------------------
|
| 49 |
-
# 2️⃣ Load vocab
|
| 50 |
# -----------------------------
|
| 51 |
-
vocab_path = hf_hub_download(
|
| 52 |
with open(vocab_path, "r", encoding="utf-8") as f:
|
| 53 |
vocab = json.load(f)
|
| 54 |
-
char_to_idx = vocab["char_to_idx"]
|
| 55 |
idx_to_char = {int(k): v for k, v in vocab["idx_to_char"].items()}
|
| 56 |
|
| 57 |
# -----------------------------
|
| 58 |
-
#
|
| 59 |
# -----------------------------
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
self.layer1 = nn.Sequential(nn.Conv2d(in_channels, 32, 3, 1, 1), GN(32), nn.ReLU(), nn.MaxPool2d(2, 2))
|
| 67 |
-
self.layer2 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1), GN(64), nn.ReLU(), nn.MaxPool2d(2, 2))
|
| 68 |
-
self.layer3 = nn.Sequential(nn.Conv2d(64, 128, 3, 1, 1), GN(128), nn.ReLU(), nn.MaxPool2d(2, 2))
|
| 69 |
-
self.layer4 = nn.Sequential(nn.Conv2d(128, 256, 3, 1, 1), GN(256), nn.ReLU())
|
| 70 |
-
self.layer5 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), GN(256), nn.ReLU())
|
| 71 |
-
self.layer6 = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1), GN(128), nn.ReLU())
|
| 72 |
-
self.adaptive_pool = nn.AdaptiveAvgPool2d((self.adaptive_height, None))
|
| 73 |
-
def forward(self, x):
|
| 74 |
-
for i in range(1, 7):
|
| 75 |
-
x = getattr(self, f"layer{i}")(x)
|
| 76 |
-
x = self.adaptive_pool(x)
|
| 77 |
-
return x
|
| 78 |
-
|
| 79 |
-
class PositionalEncoding(nn.Module):
|
| 80 |
-
def __init__(self, d_model, max_len=2000):
|
| 81 |
-
super().__init__()
|
| 82 |
-
pe = torch.zeros(max_len, d_model)
|
| 83 |
-
position = torch.arange(0, max_len).unsqueeze(1)
|
| 84 |
-
div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
|
| 85 |
-
pe[:, 0::2] = torch.sin(position * div_term)
|
| 86 |
-
pe[:, 1::2] = torch.cos(position * div_term)
|
| 87 |
-
self.register_buffer("pe", pe.unsqueeze(0))
|
| 88 |
-
def forward(self, x):
|
| 89 |
-
return x + self.pe[:, :x.size(1), :]
|
| 90 |
-
|
| 91 |
-
class CNN_Transformer_OCR(nn.Module):
|
| 92 |
-
def __init__(self, num_classes, d_model=1280, nhead=16, num_layers=8, dropout=0.2):
|
| 93 |
-
super().__init__()
|
| 94 |
-
self.cnn = LightResNetCNN(in_channels=1, adaptive_height=8)
|
| 95 |
-
self.proj = nn.Linear(128 * 8, d_model)
|
| 96 |
-
self.posenc = PositionalEncoding(d_model)
|
| 97 |
-
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True, dropout=dropout)
|
| 98 |
-
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
| 99 |
-
self.fc = nn.Linear(d_model, num_classes)
|
| 100 |
-
def forward(self, x):
|
| 101 |
-
f = self.cnn(x)
|
| 102 |
-
B, C, H, W = f.size()
|
| 103 |
-
f = f.permute(0, 3, 1, 2).reshape(B, W, C * H)
|
| 104 |
-
f = self.posenc(self.proj(f))
|
| 105 |
-
out = self.transformer(f)
|
| 106 |
-
out = self.fc(out)
|
| 107 |
-
return out.log_softmax(2)
|
| 108 |
|
| 109 |
# -----------------------------
|
| 110 |
-
#
|
| 111 |
# -----------------------------
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# -----------------------------
|
| 118 |
-
#
|
| 119 |
-
# -----------------------------
|
| 120 |
-
def greedy_decode(output, idx_to_char):
|
| 121 |
-
output = output.argmax(2)
|
| 122 |
-
texts = []
|
| 123 |
-
for seq in output:
|
| 124 |
-
prev = -1
|
| 125 |
-
chars = []
|
| 126 |
-
for idx in seq.cpu().numpy():
|
| 127 |
-
if idx != prev and idx != 0:
|
| 128 |
-
chars.append(idx_to_char.get(idx, ""))
|
| 129 |
-
prev = idx
|
| 130 |
-
texts.append("".join(chars))
|
| 131 |
-
return texts
|
| 132 |
-
|
| 133 |
-
# -----------------------------
|
| 134 |
-
# 6️⃣ Transforms
|
| 135 |
-
# -----------------------------
|
| 136 |
-
class OCRTestTransform:
|
| 137 |
-
def __init__(self, img_height=64, max_width=1600):
|
| 138 |
-
self.img_height = img_height
|
| 139 |
-
self.max_width = max_width
|
| 140 |
-
def __call__(self, img):
|
| 141 |
-
img = img.convert("L")
|
| 142 |
-
w, h = img.size
|
| 143 |
-
new_w = int(w * self.img_height / h)
|
| 144 |
-
img = img.resize((min(new_w, self.max_width), self.img_height), Image.BICUBIC)
|
| 145 |
-
new_img = Image.new("L", (self.max_width, self.img_height), 255)
|
| 146 |
-
new_img.paste(img, (0, 0))
|
| 147 |
-
img = TF.to_tensor(new_img)
|
| 148 |
-
img = TF.normalize(img, (0.5,), (0.5,))
|
| 149 |
-
return img
|
| 150 |
-
|
| 151 |
-
transform_test = OCRTestTransform()
|
| 152 |
-
|
| 153 |
-
# -----------------------------
|
| 154 |
-
# 7️⃣ Line segmentation
|
| 155 |
-
# -----------------------------
|
| 156 |
-
def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
|
| 157 |
-
img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
| 158 |
-
_, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
| 159 |
-
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (img.shape[1]//30, 1))
|
| 160 |
-
morphed = cv2.dilate(binary, kernel, iterations=1)
|
| 161 |
-
contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 162 |
-
contours = sorted(contours, key=lambda ctr: cv2.boundingRect(ctr)[1])
|
| 163 |
-
lines = []
|
| 164 |
-
for ctr in contours:
|
| 165 |
-
x, y, w, h = cv2.boundingRect(ctr)
|
| 166 |
-
if h < min_line_height: continue
|
| 167 |
-
y1 = max(0, y - margin)
|
| 168 |
-
y2 = min(img.shape[0], y + h + margin)
|
| 169 |
-
line_img = img[y1:y2, x:x+w]
|
| 170 |
-
lines.append(Image.fromarray(line_img))
|
| 171 |
-
if visualize:
|
| 172 |
-
for i, line_img in enumerate(lines):
|
| 173 |
-
plt.figure(figsize=(12,2))
|
| 174 |
-
plt.imshow(line_img, cmap='gray')
|
| 175 |
-
plt.axis('off')
|
| 176 |
-
plt.title(f"Line {i+1}")
|
| 177 |
-
plt.show()
|
| 178 |
-
return lines
|
| 179 |
-
|
| 180 |
-
# -----------------------------
|
| 181 |
-
# 8️⃣ OCR function
|
| 182 |
# -----------------------------
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
img_tensor = transform_test(line_img).unsqueeze(0).to(device)
|
| 188 |
-
with torch.no_grad():
|
| 189 |
-
outputs = model(img_tensor)
|
| 190 |
-
pred_text = greedy_decode(outputs, idx_to_char)[0]
|
| 191 |
-
all_texts.append(pred_text)
|
| 192 |
-
print(f"Line {idx}: {pred_text}")
|
| 193 |
-
return "\n".join(all_texts)
|
| 194 |
-
|
| 195 |
# -----------------------------
|
| 196 |
# 9️⃣ Example usage
|
| 197 |
# -----------------------------
|
| 198 |
-
img_path = "/content/
|
| 199 |
-
|
| 200 |
-
print("\n=== Final OCR Page ===\n",
|
|
|
|
| 31 |
## Usage Example
|
| 32 |
|
| 33 |
|
|
|
|
| 34 |
import torch
|
| 35 |
+
import json
|
| 36 |
+
import sys
|
| 37 |
+
import importlib.util
|
|
|
|
|
|
|
| 38 |
from huggingface_hub import hf_hub_download
|
| 39 |
|
| 40 |
# -----------------------------
|
| 41 |
+
# 1️⃣ Load vocab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# -----------------------------
|
| 43 |
+
vocab_path = hf_hub_download("farbodpya/Persian-OCR", "vocab.json")
|
| 44 |
with open(vocab_path, "r", encoding="utf-8") as f:
|
| 45 |
vocab = json.load(f)
|
|
|
|
| 46 |
idx_to_char = {int(k): v for k, v in vocab["idx_to_char"].items()}
|
| 47 |
|
| 48 |
# -----------------------------
|
| 49 |
+
# 2️⃣ Import model.py dynamically
|
| 50 |
# -----------------------------
|
| 51 |
+
model_file = hf_hub_download("farbodpya/Persian-OCR", "model.py")
|
| 52 |
+
spec_model = importlib.util.spec_from_file_location("model", model_file)
|
| 53 |
+
model_module = importlib.util.module_from_spec(spec_model)
|
| 54 |
+
sys.modules["model"] = model_module
|
| 55 |
+
spec_model.loader.exec_module(model_module)
|
| 56 |
+
from model import CNN_Transformer_OCR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
# -----------------------------
|
| 59 |
+
# 3️⃣ Import utils.py dynamically
|
| 60 |
# -----------------------------
|
| 61 |
+
utils_file = hf_hub_download("farbodpya/Persian-OCR", "utils.py")
|
| 62 |
+
spec_utils = importlib.util.spec_from_file_location("utils", utils_file)
|
| 63 |
+
utils_module = importlib.util.module_from_spec(spec_utils)
|
| 64 |
+
sys.modules["utils"] = utils_module
|
| 65 |
+
spec_utils.loader.exec_module(utils_module)
|
| 66 |
+
from utils import ocr_page, load_vocab
|
| 67 |
|
| 68 |
# -----------------------------
|
| 69 |
+
# 4️⃣ Load model weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
# -----------------------------
|
| 71 |
+
weights_path = hf_hub_download("farbodpya/Persian-OCR", "pytorch_model.bin")
|
| 72 |
+
model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1)
|
| 73 |
+
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
|
| 74 |
+
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
# -----------------------------
|
| 76 |
# 9️⃣ Example usage
|
| 77 |
# -----------------------------
|
| 78 |
+
img_path = "/content/Screenshot 2025-09-19 145016.png"
|
| 79 |
+
text = ocr_page(img_path, model, idx_to_char, visualize=True)
|
| 80 |
+
print("\n=== Final OCR Page ===\n", text)
|