| import os | |
| from findfile import find_file | |
| from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier | |
| from anonymous_demo.utils.demo_utils import retry | |
class CheckpointManager:
    """Common base class for checkpoint managers.

    It defines no attributes or methods of its own. Concrete managers
    (e.g. ``TADCheckpointManager``) subclass it to share a common type.
    """
class TADCheckpointManager(CheckpointManager):
    """Checkpoint manager that builds ``TADTextClassifier`` instances."""

    @staticmethod
    def get_tad_text_classifier(checkpoint: str = None,
                                eval_batch_size: int = 128,
                                **kwargs):
        """Construct a ``TADTextClassifier`` from ``checkpoint``.

        This is a ``staticmethod`` because the signature takes no
        ``self``/``cls``. Without the decorator, calling it on an instance
        would silently bind the instance to ``checkpoint``. Calls made through
        the class, ``TADCheckpointManager.get_tad_text_classifier(...)``,
        behave exactly as before.

        Args:
            checkpoint: Passed unchanged as the first positional argument of
                ``TADTextClassifier``. Presumably a checkpoint path or name;
                confirm against ``TADTextClassifier``.
            eval_batch_size: Forwarded as the ``eval_batch_size`` keyword.
            **kwargs: Any further keyword arguments, forwarded unchanged.

        Returns:
            The newly constructed ``TADTextClassifier``.
        """
        return TADTextClassifier(checkpoint, eval_batch_size=eval_batch_size, **kwargs)