import torch
from tror_yong_ocr import TrorYongOCR, TrorYongConfig
from tror_yong_ocr import get_tokenizer
# Pretrained checkpoint used when the caller does not supply another URL.
_DEFAULT_CHECKPOINT_URL = (
    "https://huggingface.co/KrorngAI/PARSeqForKhmer/resolve/main/"
    "best_model-80epoch.pt"
)


def load_model(*, checkpoint_url=_DEFAULT_CHECKPOINT_URL, map_location="cpu"):
    """Build a ``TrorYongOCR`` model and load pretrained weights into it.

    The tokenizer comes from ``get_tokenizer()`` and the architecture
    hyper-parameters are fixed below; they have to match the checkpoint
    being loaded, otherwise ``load_state_dict`` will reject it.

    Args:
        checkpoint_url: URL of the state-dict file to fetch with
            ``torch.hub.load_state_dict_from_url``. Defaults to the
            published 80-epoch checkpoint.
        map_location: Forwarded to ``torch.hub.load_state_dict_from_url``
            to choose where the tensors are placed. Defaults to CPU.

    Returns:
        The ``TrorYongOCR`` instance with the downloaded weights loaded.
        This function does not call ``.eval()``; NOTE(review): presumably
        the model is a ``torch.nn.Module``, in which case callers doing
        inference should switch it to eval mode themselves so the 0.1
        dropout is disabled.
    """
    tokenizer = get_tokenizer()
    config = TrorYongConfig(
        img_size=(32, 128),
        patch_size=(4, 8),
        n_channel=3,
        # Original note: "exclude pad and unk tokens".
        # NOTE(review): confirm against what len(tokenizer) actually counts.
        vocab_size=len(tokenizer),
        block_size=192,
        n_layer=4,
        n_head=6,
        n_embed=384,
        dropout=0.1,
        bias=True,
    )
    model = TrorYongOCR(config, tokenizer)

    # Downloads the checkpoint (network access is needed at least once).
    state_dict = torch.hub.load_state_dict_from_url(
        checkpoint_url,
        map_location=map_location,
    )
    model.load_state_dict(state_dict)
    return model