phonsobon committed on
Commit
7ec4bd8
·
verified ·
1 Parent(s): a06ee27

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +0 -96
README.md CHANGED
@@ -48,102 +48,6 @@ It uses a CTC head so it can handle variable-length text without needing segment
48
  ```bash
49
  pip install torch torchvision pillow
50
  ```
51
-
52
- ### Option A — using `model.pt` (state_dict)
53
-
54
- ```python
55
- import torch
56
- import torch.nn as nn
57
- import numpy as np
58
- from PIL import Image
59
- from huggingface_hub import hf_hub_download
60
-
61
- # ── 1. Model definition ──────────────────────────────────────────────────────
62
- class KhmerOCR_DTWG(nn.Module):
63
- def __init__(self, num_chars, hidden_size=256):
64
- super().__init__()
65
- self.cnn = nn.Sequential(
66
- self._conv(1, 32), nn.MaxPool2d(2, 2),
67
- self._conv(32, 64), nn.MaxPool2d(2, 2),
68
- self._conv(64, 128),
69
- self._conv(128, 128),
70
- nn.MaxPool2d((2, 1), (2, 1)),
71
- self._conv(128, 256),
72
- self._conv(256, 256),
73
- nn.MaxPool2d((4, 1), (4, 1)),
74
- )
75
- self.lstm1 = nn.LSTM(256, hidden_size, bidirectional=True, batch_first=True)
76
- self.fc1 = nn.Linear(hidden_size * 2, hidden_size)
77
- self.lstm2 = nn.LSTM(hidden_size, hidden_size, bidirectional=True, batch_first=True)
78
- self.fc = nn.Linear(hidden_size * 2, num_chars + 1)
79
-
80
- def _conv(self, i, o):
81
- return nn.Sequential(
82
- nn.Conv2d(i, o, 3, 1, 1, bias=False),
83
- nn.BatchNorm2d(o),
84
- nn.ReLU(inplace=True),
85
- )
86
-
87
- def forward(self, x):
88
- x = self.cnn(x)
89
- x = x.squeeze(2).permute(0, 2, 1)
90
- x, _ = self.lstm1(x)
91
- x = torch.relu(self.fc1(x))
92
- x, _ = self.lstm2(x)
93
- x = self.fc(x)
94
- return x.permute(1, 0, 2)
95
-
96
- # ── 2. Vocabulary ─────────────────────────────────────────────────────────────
97
- TOKENS = (
98
- "abcdefghijklmnopqrstuvwxyz"
99
- "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
100
- "0123456789"
101
- "แž€แžแž‚แžƒแž„แž…แž†แž‡แžˆแž‰แžŠแž‹แžŒแžแžŽแžแžแž‘แž’แž“แž”แž•แž–แž—แž˜แž™แžšแž›แžœแžแžžแžŸแž แžกแžขแžฃแžคแžฅแžฆแžงแžฉแžชแžซแžฌแžญแžฎแžฏแžฐแžฑแžฒแžณ"
102
- "แžถแžทแžธแžนแžบแžปแžผแžฝแžพแžฟแŸ€แŸแŸ‚แŸƒแŸ„แŸ…แŸ†แŸ‡แŸˆแŸ‰แŸŠแŸ‹แŸŒแŸแŸŽแŸแŸแŸ‘แŸ’แŸ”แŸ•แŸ–แŸ—แŸ˜แŸ›แŸ"
103
- "แŸ แŸกแŸขแŸฃแŸคแŸฅแŸฆแŸงแŸจแŸฉแŸณ"
104
- "!@#$%^&*()-_=+[]{};:'\",.<>?/|\\ "
105
- )
106
- NUM_CHARS = len(TOKENS)
107
- idx2char = {i + 1: c for i, c in enumerate(TOKENS)}
108
-
109
- # ── 3. Load model ─────────────────────────────────────────────────────────────
110
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111
-
112
- weights_path = hf_hub_download(repo_id="phonsobon/mini-ocr", filename="model.pt")
113
- model = KhmerOCR_DTWG(NUM_CHARS).to(device)
114
- model.load_state_dict(torch.load(weights_path, map_location=device))
115
- model.eval()
116
-
117
- # ── 4. Helpers ────────────────────────────────────────────────────────────────
118
- def load_image(path):
119
- img = Image.open(path).convert("L")
120
- w, h = img.size
121
- new_w = int(w / h * 32)
122
- img = img.resize((new_w, 32))
123
- img = np.array(img, dtype=np.float32) / 255.0
124
- return torch.tensor(img).unsqueeze(0).unsqueeze(0) # (1, 1, 32, W)
125
-
126
- def ctc_decode(logits):
127
- preds = torch.argmax(logits, dim=2)[:, 0].cpu().numpy()
128
- prev, text = -1, []
129
- for p in preds:
130
- if p != prev and p != 0:
131
- text.append(idx2char.get(p, ""))
132
- prev = p
133
- return "".join(text)
134
-
135
- # ── 5. Inference ──────────────────────────────────────────────────────────────
136
- img = load_image("your_image.png").to(device)
137
-
138
- with torch.no_grad():
139
- logits = model(img)
140
- result = ctc_decode(logits)
141
-
142
- print("OCR result:", result)
143
- ```
144
-
145
- ### Option B — TorchScript (no class needed)
146
-
147
  ```python
148
  import torch
149
  import numpy as np
 
48
  ```bash
49
  pip install torch torchvision pillow
50
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  ```python
52
  import torch
53
  import numpy as np