File size: 1,683 Bytes
e020674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
import subprocess
import torch
import logging
import colorlog
from dataflow.logger import get_logger
from dataflow.core import get_operator

def pipeline_step(yaml_path, step_name):
    import yaml
    logger = get_logger()
    logger.info(f"Loading yaml {yaml_path} ......")
    with open(yaml_path, "r") as f:
        config = yaml.safe_load(f)
    config = merge_yaml(config)
    logger.info(f"Load yaml success, config: {config}")
    algorithm = get_operator(step_name, config)
    logger.info("Start running ...")
    algorithm.run()

def merge_yaml(config):
    if not config.get("vllm_used"):
        return config
    else:
        vllm_args_list = config.get("vllm_args", [])
        if isinstance(vllm_args_list, list) and len(vllm_args_list) > 0 and isinstance(vllm_args_list[0], dict):
            vllm_args = vllm_args_list[0]
            config.update(vllm_args)  # 合并进顶层
        return config
    

def init_model(generator_type:str =None):
    if generator_type is None:
        raise ValueError("generator_type is not found in config")
    if generator_type == "local":
        from dataflow.utils.LocalModelGenerator import LocalModelGenerator
        return LocalModelGenerator(config)
    elif generator_type == "aisuite":
        from dataflow.utils.APIGenerator_aisuite import APIGenerator_aisuite
        return APIGenerator_aisuite(config)
    elif generator_type == "request":
        from dataflow.utils.APIGenerator_request import APIGenerator_request
        return APIGenerator_request(config)
    else:
        raise ValueError(f"Invalid generator type: {config['generator_type']}, must be one of: local, aisuite, request")