File size: 7,222 Bytes
cb2428f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# Copyright (c) Alibaba, Inc. and its affiliates.
import datetime as dt
import os
from dataclasses import dataclass
from typing import Literal, Optional, Union

import torch.distributed as dist

from swift.utils import get_logger, init_process_group, is_dist
from .base_args import BaseArguments, to_abspath
from .base_args.model_args import ModelArguments
from .merge_args import MergeArguments

logger = get_logger()


@dataclass
class LmdeployArguments:
    """
    Settings that are handed to the lmdeploy inference engine.

    Args:
        tp (int): Tensor parallelism size. Default is 1.
        session_len (Optional[int]): The session length. Default is None.
        cache_max_entry_count (float): Maximum entry count for cache. Default is 0.8.
        quant_policy (int): Quantization policy, e.g., 4, 8. Default is 0.
        vision_batch_size (int): Maximum batch size in VisionConfig. Default is 1.
    """

    # lmdeploy
    tp: int = 1
    session_len: Optional[int] = None
    cache_max_entry_count: float = 0.8
    quant_policy: int = 0  # e.g. 4, 8
    vision_batch_size: int = 1  # max_batch_size in VisionConfig

    def get_lmdeploy_engine_kwargs(self):
        """Collect the lmdeploy settings into a keyword-argument dict.

        Returns:
            dict: The five lmdeploy fields keyed by their field name. When
            ``torch.distributed`` is initialized, an additional ``'devices'``
            entry holds a one-element list with the current process rank.
        """
        option_names = (
            'tp',
            'session_len',
            'cache_max_entry_count',
            'quant_policy',
            'vision_batch_size',
        )
        engine_kwargs = {name: getattr(self, name) for name in option_names}
        if dist.is_initialized():
            # Pin this process to the device whose index equals its rank.
            engine_kwargs['devices'] = [dist.get_rank()]
        return engine_kwargs


@dataclass
class VllmArguments:
    """
    VllmArguments is a dataclass that holds the configuration for vllm.

    Args:
        gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
        tensor_parallel_size (int): Tensor parallelism size. Default is 1.
        pipeline_parallel_size (int): Pipeline parallelism size. Default is 1.
        max_num_seqs (int): Maximum number of sequences. Default is 256.
        max_model_len (Optional[int]): Maximum model length. Default is None.
        disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is False.
        enforce_eager (bool): Flag to enforce eager execution. Default is False.
        limit_mm_per_prompt (Optional[Union[dict, str]]): Limit multimedia per prompt, given as a dict
            or as a string such as '{"image": 5, "video": 2}'. It is passed through
            ``ModelArguments.parse_to_dict`` in ``__post_init__``. Default is None.
        vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
        enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False.
        use_async_engine (bool): Default is True. Not included in the kwargs built by
            ``get_vllm_engine_kwargs``.
        data_parallel_size (int): Default is 1. Not included in the kwargs built by
            ``get_vllm_engine_kwargs``.
        log_level (Literal): One of 'critical', 'error', 'warning', 'info', 'debug', 'trace'.
            Default is 'info'. Not included in the kwargs built by ``get_vllm_engine_kwargs``.
        vllm_quantization (Optional[str]): Forwarded as the ``quantization`` engine kwarg. Default is None.
    """
    # vllm
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    max_num_seqs: int = 256
    max_model_len: Optional[int] = None
    disable_custom_all_reduce: bool = False
    enforce_eager: bool = False
    limit_mm_per_prompt: Optional[Union[dict, str]] = None  # '{"image": 5, "video": 2}'
    vllm_max_lora_rank: int = 16
    enable_prefix_caching: bool = False
    use_async_engine: bool = True
    data_parallel_size: int = 1
    log_level: Literal['critical', 'error', 'warning', 'info', 'debug', 'trace'] = 'info'
    vllm_quantization: Optional[str] = None

    def __post_init__(self):
        # Normalize the dict-or-string input via the shared ModelArguments helper.
        self.limit_mm_per_prompt = ModelArguments.parse_to_dict(self.limit_mm_per_prompt)

    def get_vllm_engine_kwargs(self):
        """Build the keyword arguments for the vllm engine.

        ``adapters`` and ``adapter_mapping`` are not fields of this class; they are
        expected to come from a sibling argument class when this one is used as a
        mixin. Both are optional here: a missing (or ``None``) ``adapters`` is treated
        as an empty list so the method also works on a standalone instance.

        Returns:
            dict: Engine kwargs. ``enable_lora`` is True when at least one adapter is
            present and ``max_loras`` is the adapter count (minimum 1). When
            ``torch.distributed`` is initialized, ``'device'`` is set to the current rank.
        """
        # Copy so that the attribute on ``self`` is never mutated.
        adapters = list(getattr(self, 'adapters', None) or [])
        if hasattr(self, 'adapter_mapping'):
            adapters = adapters + list(self.adapter_mapping.values())
        kwargs = {
            'gpu_memory_utilization': self.gpu_memory_utilization,
            'tensor_parallel_size': self.tensor_parallel_size,
            'pipeline_parallel_size': self.pipeline_parallel_size,
            'max_num_seqs': self.max_num_seqs,
            'max_model_len': self.max_model_len,
            'disable_custom_all_reduce': self.disable_custom_all_reduce,
            'enforce_eager': self.enforce_eager,
            'limit_mm_per_prompt': self.limit_mm_per_prompt,
            'max_lora_rank': self.vllm_max_lora_rank,
            'enable_lora': len(adapters) > 0,
            # At least one LoRA slot is always requested, even without adapters.
            'max_loras': max(len(adapters), 1),
            'enable_prefix_caching': self.enable_prefix_caching,
            'quantization': self.vllm_quantization,
        }
        if dist.is_initialized():
            kwargs.update({'device': dist.get_rank()})
        return kwargs


@dataclass
class InferArguments(MergeArguments, VllmArguments, LmdeployArguments, BaseArguments):
    """
    InferArguments is a dataclass that extends BaseArguments, MergeArguments, VllmArguments, and LmdeployArguments.
    It is used to define the arguments required for model inference.

    Args:
        infer_backend (Literal): Backend to use for inference. Default is 'pt'.
            Allowed values are 'vllm', 'pt', 'lmdeploy'.
        result_path (Optional[str]): File to store inference results. When None, a timestamped
            ``.jsonl`` path is generated in ``__post_init__``. Default is None.
        metric (Literal): Either 'acc' or 'rouge'. Default is None.
        max_batch_size (int): Maximum batch size for the pt engine. Default is 1.
        ddp_backend (Optional[str]): Backend passed to ``init_process_group`` when running
            distributed. Default is None.
        val_dataset_sample (Optional[int]): Sample size for validation dataset. Default is None.
    """
    infer_backend: Literal['vllm', 'pt', 'lmdeploy'] = 'pt'

    result_path: Optional[str] = None
    metric: Literal['acc', 'rouge'] = None
    # for pt engine
    max_batch_size: int = 1
    ddp_backend: Optional[str] = None

    # only for inference
    val_dataset_sample: Optional[int] = None

    def _get_result_path(self, folder_name: str) -> str:
        """Create the result directory and return a timestamped ``.jsonl`` path inside it.

        The base directory is ``self.ckpt_dir`` when it is set, otherwise
        ``result/{self.model_suffix}``; ``folder_name`` is appended to it.
        """
        result_dir = self.ckpt_dir or f'result/{self.model_suffix}'
        os.makedirs(result_dir, exist_ok=True)
        result_dir = to_abspath(os.path.join(result_dir, folder_name))
        os.makedirs(result_dir, exist_ok=True)
        timestamp = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
        return os.path.join(result_dir, f'{timestamp}.jsonl')

    def _init_result_path(self, folder_name: str) -> None:
        """Resolve ``self.result_path``: make a user-supplied path absolute, or generate one."""
        if self.result_path is not None:
            self.result_path = to_abspath(self.result_path)
            return
        self.result_path = self._get_result_path(folder_name)
        logger.info(f'args.result_path: {self.result_path}')

    def _init_stream(self):
        """Recompute ``eval_human`` and disable streaming when beam search is used.

        ``eval_human`` is False when evaluation data is available, i.e. a dataset with a
        positive ``split_dataset_ratio`` or an explicit ``val_dataset``. The value set
        here replaces whatever was assigned before (e.g. by ``_init_eval_human``), so a
        change is logged to keep the log in line with the effective setting.
        """
        has_eval_data = (self.dataset and self.split_dataset_ratio > 0) or self.val_dataset
        eval_human = not has_eval_data
        if getattr(self, 'eval_human', None) != eval_human:
            logger.info(f'Setting args.eval_human: {eval_human}')
        self.eval_human = eval_human

        if self.stream and self.num_beams != 1:
            self.stream = False
            logger.info('Setting args.stream: False')

    def _init_ddp(self):
        """Initialize the process group when running distributed.

        Raises:
            AssertionError: If ``eval_human`` or ``stream`` is enabled in a distributed run.
        """
        if not is_dist():
            return
        # Raised explicitly (rather than via ``assert``) so the check survives ``python -O``.
        if self.eval_human or self.stream:
            raise AssertionError(f'args.eval_human: {self.eval_human}, args.stream: {self.stream}')
        self._init_device()
        init_process_group(self.ddp_backend)

    def __post_init__(self) -> None:
        BaseArguments.__post_init__(self)
        VllmArguments.__post_init__(self)
        self._init_result_path('infer_result')
        self._init_eval_human()
        # NOTE: _init_stream recomputes eval_human and therefore has the final say.
        self._init_stream()
        self._init_ddp()

    def _init_eval_human(self):
        """Set ``eval_human`` to True when neither ``dataset`` nor ``val_dataset`` has entries."""
        self.eval_human = len(self.dataset) == 0 and len(self.val_dataset) == 0
        logger.info(f'Setting args.eval_human: {self.eval_human}')