Spaces:
Running
Running
| from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \ | |
| AutoModel | |
| from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig | |
| from surya.model.ordering.decoder import MBartOrder | |
| from surya.model.ordering.encoder import VariableDonutSwinModel | |
| from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel | |
| from surya.model.ordering.processor import OrderImageProcessor | |
| from surya.settings import settings | |
def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the reading-order vision encoder-decoder model.

    The checkpoint's ``VisionEncoderDecoderConfig`` is rebuilt with the
    project's custom sub-config classes (``MBartOrderConfig`` for the decoder,
    ``VariableDonutSwinConfig`` for the encoder). Those config classes are then
    registered with the transformers Auto classes, so that ``from_pretrained``
    builds the custom ``MBartOrder`` / ``VariableDonutSwinModel`` modules.

    Args:
        checkpoint: Model id or local path passed to ``from_pretrained``.
            The default is read from ``settings`` once, when this module is
            imported (standard Python default-argument semantics).
        device: Target passed to ``model.to``.
        dtype: Passed to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The loaded ``OrderVisionEncoderDecoderModel``, moved to ``device`` and
        switched to eval mode.

    Raises:
        TypeError: If the loaded decoder is not an ``MBartOrder`` or the loaded
            encoder is not a ``VariableDonutSwinModel``.
    """
    config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)

    # Re-wrap the generic sub-configs in the custom config classes, copying
    # every attribute of the original sub-config.
    config.decoder = MBartOrderConfig(**vars(config.decoder))
    config.encoder = VariableDonutSwinConfig(**vars(config.encoder))

    # Get transformers to load custom model
    AutoModel.register(MBartOrderConfig, MBartOrder)
    AutoModelForCausalLM.register(MBartOrderConfig, MBartOrder)
    AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)

    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Explicit checks rather than `assert`: asserts are removed under
    # `python -O`, which would let a mis-registered model load silently.
    if not isinstance(model.decoder, MBartOrder):
        raise TypeError(
            f"Expected decoder of type MBartOrder, got {type(model.decoder).__name__}"
        )
    if not isinstance(model.encoder, VariableDonutSwinModel):
        raise TypeError(
            f"Expected encoder of type VariableDonutSwinModel, got {type(model.encoder).__name__}"
        )

    model = model.to(device)
    model = model.eval()
    print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}")
    return model