Khrawsynth committed on
Commit
19343bf
·
verified ·
1 Parent(s): b49cb90

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +140 -12
README.md CHANGED
@@ -1,3 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # AssameseOCR
2
 
3
  **AssameseOCR** is a vision-language model for Optical Character Recognition (OCR) of printed Assamese text. Built on Microsoft's Florence-2-large foundation model with a custom character-level decoder, it achieves 94.67% character accuracy on the Mozhi dataset.
@@ -98,11 +143,13 @@ pip install torch torchvision transformers pillow
98
 
99
  ```python
100
  import torch
 
101
  from PIL import Image
102
  from transformers import AutoModelForCausalLM, CLIPImageProcessor
 
103
  import json
104
 
105
- # Load tokenizer
106
  class CharTokenizer:
107
  def __init__(self, vocab):
108
  self.vocab = vocab
@@ -112,6 +159,18 @@ class CharTokenizer:
112
  self.bos_token_id = self.char2id["<s>"]
113
  self.eos_token_id = self.char2id["</s>"]
114
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def decode(self, ids, skip_special_tokens=True):
116
  chars = []
117
  for i in ids:
@@ -127,9 +186,46 @@ class CharTokenizer:
127
  vocab = json.load(f)
128
  return cls(vocab)
129
 
130
- # Load model components
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  device = "cuda" if torch.cuda.is_available() else "cpu"
132
 
 
 
 
 
 
 
 
133
  # Load Florence base model
134
  florence_model = AutoModelForCausalLM.from_pretrained(
135
  "microsoft/Florence-2-large-ft",
@@ -139,17 +235,50 @@ florence_model = AutoModelForCausalLM.from_pretrained(
139
  # Load image processor
140
  image_processor = CLIPImageProcessor.from_pretrained("microsoft/Florence-2-large-ft")
141
 
142
- # Load tokenizer
143
- char_tokenizer = CharTokenizer.load("assamese_char_tokenizer.json")
 
 
 
 
 
 
144
 
145
- # Load AssameseOCR weights
146
- # (Note: You'll need to define the FlorenceCharOCR class as in training)
147
- checkpoint = torch.load("assamese_ocr_best.pt", map_location=device)
148
- # ocr_model.load_state_dict(checkpoint['model_state_dict'])
149
 
150
- # Inference
151
- image = Image.open("assamese_text.jpg")
152
- # Process and predict...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  ```
154
 
155
  ## Vocabulary
@@ -211,5 +340,4 @@ If you use AssameseOCR in your research, please cite:
211
  - [KhasiBERT](https://huggingface.co/MWirelabs/KhasiBERT-110M) - Khasi language model
212
  - [NE-BERT](https://huggingface.co/MWirelabs/NE-BERT) - 9 Northeast languages
213
  - [Kren-M](https://huggingface.co/MWirelabs/Kren-M) - Khasi-English conversational AI
214
-
215
  - **AssameseOCR** - Assamese text recognition
 
1
+ ---
2
+ language:
3
+ - asm # Assamese ISO 639-3 code
4
+ license: apache-2.0
5
+ base_model: microsoft/Florence-2-large-ft
6
+ tags:
7
+ - vision
8
+ - ocr
9
+ - assamese
10
+ - northeast-india
11
+ - indic-languages
12
+ - character-recognition
13
+ - florence-2
14
+ - vision-language
15
+ datasets:
16
+ - darknight054/indic-mozhi-ocr
17
+ metrics:
18
+ - accuracy
19
+ - character_error_rate
20
+ library_name: transformers
21
+ pipeline_tag: image-to-text
22
+
23
+ model-index:
24
+ - name: AssameseOCR
25
+ results:
26
+ - task:
27
+ type: image-to-text
28
+ name: Optical Character Recognition
29
+ dataset:
30
+ name: Mozhi Indic OCR (Assamese)
31
+ type: darknight054/indic-mozhi-ocr
32
+ config: assamese
33
+ split: test
34
+ metrics:
35
+ - type: accuracy
36
+ value: 94.67
37
+ name: Character Accuracy
38
+ verified: false
39
+ - type: character_error_rate
40
+ value: 5.33
41
+ name: Character Error Rate (CER)
42
+ verified: false
43
+
44
+ ---
45
+
46
  # AssameseOCR
47
 
48
  **AssameseOCR** is a vision-language model for Optical Character Recognition (OCR) of printed Assamese text. Built on Microsoft's Florence-2-large foundation model with a custom character-level decoder, it achieves 94.67% character accuracy on the Mozhi dataset.
 
143
 
144
  ```python
145
  import torch
146
+ import torch.nn as nn
147
  from PIL import Image
148
  from transformers import AutoModelForCausalLM, CLIPImageProcessor
149
+ from huggingface_hub import hf_hub_download
150
  import json
151
 
152
+ # CharTokenizer class
153
  class CharTokenizer:
154
  def __init__(self, vocab):
155
  self.vocab = vocab
 
159
  self.bos_token_id = self.char2id["<s>"]
160
  self.eos_token_id = self.char2id["</s>"]
161
 
162
def encode(self, text, max_length=None, add_special_tokens=True):
    """Convert *text* to a list of character token ids.

    Args:
        text: Input string; each character becomes one token.
        max_length: If given, truncate or pad the result to exactly
            this many ids.
        add_special_tokens: Prepend ``<s>`` and append ``</s>`` when True.

    Returns:
        list[int]: Token ids, padded with ``<pad>`` up to ``max_length``
        when shorter, truncated when longer.
    """
    ids = [self.bos_token_id] if add_special_tokens else []
    unk_id = self.char2id["<unk>"]  # hoisted: looked up once, not per char
    for ch in text:
        ids.append(self.char2id.get(ch, unk_id))
    if add_special_tokens:
        ids.append(self.eos_token_id)
    if max_length:
        if len(ids) > max_length:
            ids = ids[:max_length]
            # BUGFIX: plain truncation silently dropped </s>, so a
            # truncated sequence carried no end-of-sequence marker and
            # decoding had no stop signal. Keep EOS as the last token.
            if add_special_tokens:
                ids[-1] = self.eos_token_id
        elif len(ids) < max_length:
            ids += [self.pad_token_id] * (max_length - len(ids))
    return ids
173
+
174
  def decode(self, ids, skip_special_tokens=True):
175
  chars = []
176
  for i in ids:
 
186
  vocab = json.load(f)
187
  return cls(vocab)
188
 
189
# Character-level OCR head on top of a frozen Florence-2 vision encoder.
class FlorenceCharOCR(nn.Module):
    """Transformer decoder that emits character-id logits from Florence-2
    image features.

    The Florence-2 backbone is frozen; only the vision projection, the
    character embedding, the transformer decoder and the output layer
    receive gradients.
    """

    def __init__(self, florence_model, vocab_size, vision_hidden_dim, decoder_hidden_dim=512, num_layers=4):
        super().__init__()
        self.florence_model = florence_model

        # Freeze every backbone weight — only the OCR head below trains.
        for backbone_param in self.florence_model.parameters():
            backbone_param.requires_grad = False

        # Project vision features into the decoder's hidden size.
        self.vision_proj = nn.Linear(vision_hidden_dim, decoder_hidden_dim)
        self.embedding = nn.Embedding(vocab_size, decoder_hidden_dim)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=decoder_hidden_dim,
                nhead=8,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.fc_out = nn.Linear(decoder_hidden_dim, vocab_size)

    def forward(self, pixel_values, tgt_ids, tgt_mask=None):
        """Return per-position character logits.

        Args:
            pixel_values: Preprocessed image batch for the Florence encoder.
            tgt_ids: (batch, tgt_len) character ids decoded so far.
            tgt_mask: Optional attention mask for the target sequence.

        Returns:
            Tensor of shape (batch, tgt_len, vocab_size).
        """
        # NOTE(review): _encode_image is a private Florence-2 API — confirm
        # it is still exposed by the pinned transformers version.
        with torch.no_grad():
            memory = self.florence_model._encode_image(pixel_values)

        memory = self.vision_proj(memory)
        hidden = self.decoder(self.embedding(tgt_ids), memory, tgt_mask=tgt_mask)
        return self.fc_out(hidden)
+
219
+ # Load components
220
  device = "cuda" if torch.cuda.is_available() else "cpu"
221
 
222
+ # Download files from HuggingFace
223
+ tokenizer_path = hf_hub_download(repo_id="MWirelabs/assamese-ocr", filename="assamese_char_tokenizer.json")
224
+ model_path = hf_hub_download(repo_id="MWirelabs/assamese-ocr", filename="assamese_ocr_best.pt")
225
+
226
+ # Load tokenizer
227
+ char_tokenizer = CharTokenizer.load(tokenizer_path)
228
+
229
  # Load Florence base model
230
  florence_model = AutoModelForCausalLM.from_pretrained(
231
  "microsoft/Florence-2-large-ft",
 
235
  # Load image processor
236
  image_processor = CLIPImageProcessor.from_pretrained("microsoft/Florence-2-large-ft")
237
 
238
+ # Initialize OCR model
239
+ ocr_model = FlorenceCharOCR(
240
+ florence_model=florence_model,
241
+ vocab_size=len(char_tokenizer.vocab),
242
+ vision_hidden_dim=1024,
243
+ decoder_hidden_dim=512,
244
+ num_layers=4
245
+ ).to(device)
246
 
247
+ # Load trained weights
248
+ checkpoint = torch.load(model_path, map_location=device)
249
+ ocr_model.load_state_dict(checkpoint['model_state_dict'])
250
+ ocr_model.eval()
251
 
252
# Greedy character-by-character decoding for a single image file.
def recognize_text(image_path):
    """Run OCR on the image at *image_path* and return the recognized string.

    Uses the module-level ``image_processor``, ``ocr_model`` and
    ``char_tokenizer`` loaded above; decoding is greedy (argmax) with a
    hard cap of 128 generated tokens.
    """
    img = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(images=[img], return_tensors="pt")['pixel_values'].to(device)

    with torch.no_grad():
        # Seed the sequence with the BOS token and extend one id at a time.
        ids = [char_tokenizer.bos_token_id]

        for _ in range(128):  # upper bound on output length
            step_input = torch.tensor([ids], device=device)
            logits = ocr_model(pixel_values, step_input)

            # Greedy choice: most likely character at the last position.
            next_id = logits[0, -1].argmax().item()
            ids.append(next_id)

            # End-of-sequence terminates decoding early.
            if next_id == char_tokenizer.eos_token_id:
                break

    return char_tokenizer.decode(ids, skip_special_tokens=True)

# Example usage
result = recognize_text("assamese_text.jpg")
print(f"Recognized text: {result}")
282
  ```
283
 
284
  ## Vocabulary
 
340
  - [KhasiBERT](https://huggingface.co/MWirelabs/KhasiBERT-110M) - Khasi language model
341
  - [NE-BERT](https://huggingface.co/MWirelabs/NE-BERT) - 9 Northeast languages
342
  - [Kren-M](https://huggingface.co/MWirelabs/Kren-M) - Khasi-English conversational AI
 
343
  - **AssameseOCR** - Assamese text recognition