Spaces:
Running
Running
import warnings
import torch
# Hide the deprecation warning whose message starts with
# "torch.utils._pytree._register_pytree_node is deprecated".
# NOTE(review): this filter is deliberately installed before the surya imports
# below — presumably because importing them (they appear to pull in
# transformers) is what emits the warning. Keep it ahead of those imports.
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")
import logging
# Drop records below ERROR from the "transformers.modeling_utils" logger.
# NOTE(review): also placed before the surya imports; confirm before moving it.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
from typing import List, Optional, Tuple
from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel
from surya.model.recognition.config import DonutSwinConfig, SuryaOCRConfig, SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig
from surya.model.recognition.encoder import DonutSwinModel
from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder
from surya.settings import settings
# Import-time backend selection: when the settings opt out of efficient
# attention, PyTorch's memory-efficient scaled-dot-product-attention kernel is
# switched off. The math and flash kernels stay enabled, so attention still has
# a backend to run on. These torch.backends.cuda toggles are process-wide and
# independent of each other.
if not settings.ENABLE_EFFICIENT_ATTENTION:
    print("Efficient attention is disabled. This will use significantly more VRAM.")
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the Surya OCR recognition model and ready it for inference.

    The composite ``SuryaOCRConfig`` is read from ``checkpoint``. Its
    ``decoder``, ``encoder`` and ``text_encoder`` entries are mapping-like
    (they are unpacked with ``**``). Each entry is rebuilt into its dedicated
    config class and written back onto the composite config before the
    weights are loaded.

    Args:
        checkpoint: Passed to both ``from_pretrained`` calls. The default is
            ``settings.RECOGNITION_MODEL_CHECKPOINT`` as read when this module
            was imported.
        device: Target passed to ``model.to``. The default is
            ``settings.TORCH_DEVICE_MODEL`` (read at import time).
        dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``. The default
            is ``settings.MODEL_DTYPE`` (read at import time).

    Returns:
        The ``OCREncoderDecoderModel`` moved to ``device`` and put in eval mode.

    Raises:
        AssertionError: If the decoder, encoder or text encoder sub-module is
            not of the expected class (these checks are skipped under ``-O``).
    """
    config = SuryaOCRConfig.from_pretrained(checkpoint)

    # (attribute on the composite config, class that attribute is rebuilt as),
    # processed in decoder -> encoder -> text_encoder order.
    typed_sub_configs = (
        ("decoder", SuryaOCRDecoderConfig),
        ("encoder", DonutSwinConfig),
        ("text_encoder", SuryaOCRTextEncoderConfig),
    )
    for attr_name, config_cls in typed_sub_configs:
        raw_section = getattr(config, attr_name)
        setattr(config, attr_name, config_cls(**raw_section))

    model = OCREncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Confirm each sub-module was instantiated with the expected class.
    assert isinstance(model.decoder, SuryaOCRDecoder)
    assert isinstance(model.encoder, DonutSwinModel)
    assert isinstance(model.text_encoder, SuryaOCRTextEncoder)

    model = model.to(device).eval()
    print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model