| import os | |
| import random | |
| import torch | |
| import copy | |
| import colbert.utils.distributed as distributed | |
| from colbert.utils.parser import Arguments | |
| from colbert.utils.runs import Run | |
| from colbert.training.training import train | |
def main():
    """Parse command-line arguments, validate them, and launch ColBERT training.

    Builds an ``Arguments`` parser with the model, model-training, and
    training-input parameter groups, parses the command line, checks the
    batch-size / sequence-length settings, and then calls ``train(args)``
    inside a ``Run.context``.

    Raises:
        ValueError: if ``accumsteps`` is not positive, if ``bsize`` is not
            divisible by ``accumsteps``, or if ``query_maxlen`` or
            ``doc_maxlen`` exceeds 512.
    """
    parser = Arguments(description='Training ColBERT with <query, positive passage, negative passage> triples.')

    parser.add_model_parameters()
    parser.add_model_training_parameters()
    parser.add_training_input()

    args = parser.parse()

    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let invalid settings reach ``train``.
    if args.accumsteps <= 0:
        raise ValueError(
            f"The number of gradient accumulation steps must be positive "
            f"(accumsteps={args.accumsteps})."
        )
    if args.bsize % args.accumsteps != 0:
        raise ValueError(
            f"The batch size must be divisible by the number of gradient accumulation steps. "
            f"(bsize={args.bsize}, accumsteps={args.accumsteps})"
        )

    # NOTE(review): 512 is presumably the encoder's maximum sequence length
    # (BERT-style position embeddings) — confirm against the model config.
    for name in ('query_maxlen', 'doc_maxlen'):
        value = getattr(args, name)
        if value > 512:
            raise ValueError(f"{name} must be <= 512, got {value}.")

    # ``lazy`` is set exactly when a collection was supplied; how ``train``
    # uses this flag is defined elsewhere — TODO confirm its semantics.
    args.lazy = args.collection is not None

    # The keyword suggests an interrupted run is not recorded as failed;
    # the exact behavior lives in colbert.utils.runs.Run.context.
    with Run.context(consider_failed_if_interrupted=False):
        train(args)


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()