Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| """Return a pipeline automatically based on its name. | |
| """ | |
| from lmflow.pipeline.evaluator import Evaluator | |
| from lmflow.pipeline.finetuner import Finetuner | |
| from lmflow.pipeline.inferencer import Inferencer | |
| from lmflow.pipeline.raft_aligner import RaftAligner | |
# Registry of supported pipelines: public pipeline name -> pipeline class.
# AutoPipeline.get_pipeline looks names up here.
PIPELINE_MAPPING = dict(
    evaluator=Evaluator,
    finetuner=Finetuner,
    inferencer=Inferencer,
    raft_aligner=RaftAligner,
)
class AutoPipeline:
    """Factory that returns a pipeline instance selected by its name.

    The supported names are the keys of the module-level
    ``PIPELINE_MAPPING`` registry.
    """

    @classmethod
    def get_pipeline(
        cls,
        pipeline_name,
        model_args,
        data_args,
        pipeline_args,
        *args,
        **kwargs
    ):
        """Instantiate the pipeline registered under ``pipeline_name``.

        This is a classmethod because it uses no instance state, so it can
        be called directly as ``AutoPipeline.get_pipeline(...)``; calling it
        on an instance keeps working as well.

        Args:
            pipeline_name: Key identifying the pipeline in
                ``PIPELINE_MAPPING``.
            model_args: Forwarded unchanged as the first positional argument
                of the pipeline constructor.
            data_args: Forwarded unchanged as the second positional argument
                of the pipeline constructor.
            pipeline_args: Forwarded unchanged as the third positional
                argument of the pipeline constructor.
            *args: Extra positional arguments forwarded to the constructor.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            The newly constructed pipeline object.

        Raises:
            NotImplementedError: If ``pipeline_name`` is not a key of
                ``PIPELINE_MAPPING``.
        """
        if pipeline_name not in PIPELINE_MAPPING:
            raise NotImplementedError(
                f'Pipeline "{pipeline_name}" is not supported'
            )

        pipeline_cls = PIPELINE_MAPPING[pipeline_name]
        return pipeline_cls(
            model_args,
            data_args,
            pipeline_args,
            *args,
            **kwargs
        )