| | from PIL import Image |
| | import torch |
| | from transformers import NougatProcessor, VisionEncoderDecoderModel |
| |
|
| | |
# Upper bound on the number of newly generated tokens per page; passed to
# ``model.generate`` as ``max_new_tokens`` inside ``predict``.
context_length = 2048

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor and the encoder-decoder weights come from the same checkpoint.
processor = NougatProcessor.from_pretrained("MohamedRashad/arabic-small-nougat")
model = VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-small-nougat")
model.to(device)
| |
|
def predict(img_path):
    """Transcribe one page image with the module-level Nougat model.

    Args:
        img_path: Anything ``PIL.Image.open`` accepts (typically a path to
            the page image).

    Returns:
        The first decoded sequence of the batch, after
        ``processor.post_process_generation`` with ``fix_markdown=False``.
    """
    # ``Image.open`` is lazy and keeps the file open; the context manager
    # closes the handle once ``convert`` has produced an independent,
    # fully loaded RGB copy.
    with Image.open(img_path) as source_image:
        image = source_image.convert("RGB")

    # Turn the image into the tensor input expected by the encoder.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    # Generate token ids on the same device as the model; the unknown token
    # is passed as a bad word so it is never emitted.
    outputs = model.generate(
        pixel_values.to(device),
        min_length=1,
        max_new_tokens=context_length,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )

    # Only one image was submitted, so keep the single decoded sequence.
    page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)

    return page_sequence
| |
|
# Example run: transcribe the page image "1.png" and write the text to stdout.
print(predict(img_path="1.png"))
| |
|