Spaces:
Runtime error
Runtime error
root commited on
Commit ·
c8c0ef5
1
Parent(s): 57d225d
push to levo2.0
Browse files- Dockerfile +5 -1
- app.py +12 -13
- codeclm/models/builders.py +1 -1
- codeclm/models/codeclm_gen.py +326 -0
- codeclm/models/levo.py +2 -2
- codeclm/models/llama/modeling_llama.py +4 -1
- codeclm/modules/conditioners.py +29 -37
- codeclm/tokenizer/Flow1dVAE/generate_1rvq.py +28 -56
- codeclm/tokenizer/Flow1dVAE/model_1rvq.py +10 -29
- codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config.py +55 -0
- codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_reflow.py +2 -2
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/features.py +14 -5
- codeclm/tokenizer/audio_tokenizer.py +2 -2
- generate.py +106 -500
- generate.sh +9 -64
- levo_inference.py +57 -50
- requirements.txt +1 -0
- sample/lyrics.jsonl +2 -3
- vllm_hacked/model_executor/layers/utils.py +196 -0
- vllm_hacked/model_executor/layers/utils_ori.py +195 -0
- vllm_hacked/model_executor/models/llama.py +688 -0
- vllm_hacked/model_executor/sampling_metadata.py +596 -0
- vllm_hacked/model_executor/sampling_metadata_ori.py +596 -0
- vllm_hacked/sampling_params.py +596 -0
- vllm_hacked/sampling_params_ori.py +593 -0
- ckpt/.gitkeep → vllm_hacked/v1/sample/__init__ori.py +0 -0
- vllm_hacked/v1/sample/metadata.py +45 -0
- vllm_hacked/v1/sample/metadata_ori.py +43 -0
- vllm_hacked/v1/sample/ops/penalties_ori.py +43 -0
- vllm_hacked/v1/sample/sampler.py +338 -0
- vllm_hacked/v1/sample/sampler_ori.py +285 -0
- vllm_hacked/v1/spec_decode/utils.py +18 -0
- vllm_hacked/v1/spec_decode/utils_ori.py +14 -0
- vllm_hacked/v1/utils_ori.py +396 -0
- vllm_hacked/v1/worker/gpu_input_batch.py +669 -0
- vllm_hacked/v1/worker/gpu_input_batch_ori.py +863 -0
- vllm_hacked/v1/worker/gpu_model_runner.py +0 -0
- vllm_hacked/v1/worker/gpu_model_runner_ori.py +0 -0
- vllm_hacked/v1/worker/gpu_worker.py +710 -0
- vllm_hacked/worker_base.py +279 -0
- z_script.py +0 -44
Dockerfile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
FROM
|
| 2 |
|
| 3 |
USER root
|
| 4 |
|
|
@@ -13,6 +13,10 @@ ENV PATH="/home/user/.local/bin:$PATH"
|
|
| 13 |
|
| 14 |
WORKDIR /app
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
COPY --chown=user ./requirements.txt requirements.txt
|
| 17 |
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 18 |
|
|
|
|
| 1 |
+
FROM witszhang/songgeneration_vllm:v0
|
| 2 |
|
| 3 |
USER root
|
| 4 |
|
|
|
|
| 13 |
|
| 14 |
WORKDIR /app
|
| 15 |
|
| 16 |
+
COPY --chown=user ./vllm_hacked/model_executor/models/llama.py /opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/llama.py
|
| 17 |
+
COPY --chown=user ./vllm_hacked/v1/sample/sampler.py /opt/conda/lib/python3.11/site-packages/vllm/v1/sample/sampler.py
|
| 18 |
+
COPY --chown=user ./vllm_hacked/v1/sample/metadata.py /opt/conda/lib/python3.11/site-packages/vllm/v1/sample/metadata.py
|
| 19 |
+
COPY --chown=user ./vllm_hacked/sampling_params.py /opt/conda/lib/python3.11/site-packages/vllm/sampling_params.py
|
| 20 |
COPY --chown=user ./requirements.txt requirements.txt
|
| 21 |
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 22 |
|
app.py
CHANGED
|
@@ -15,14 +15,12 @@ from download import download_model
|
|
| 15 |
|
| 16 |
# 下载模型
|
| 17 |
APP_DIR = op.dirname(op.abspath(__file__))
|
| 18 |
-
download_model(APP_DIR)
|
| 19 |
-
large_model_path = op.join(APP_DIR, "ckpt", "SongGeneration-v1.5-beta")
|
| 20 |
-
download_model(large_model_path, repo_id="waytan22/SongGeneration-v1.5-beta", revision="db10f47")
|
| 21 |
print("Successful downloaded model.")
|
| 22 |
|
| 23 |
# 模型初始化
|
| 24 |
from levo_inference import LeVoInference
|
| 25 |
-
|
| 26 |
|
| 27 |
EXAMPLE_LYRICS = """
|
| 28 |
[intro-medium]
|
|
@@ -159,7 +157,7 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co
|
|
| 159 |
# 创建Gradio界面
|
| 160 |
with gr.Blocks(title="SongGeneration Demo Space") as demo:
|
| 161 |
gr.Markdown("# 🎵 SongGeneration Demo Space")
|
| 162 |
-
gr.Markdown("
|
| 163 |
|
| 164 |
with gr.Row():
|
| 165 |
with gr.Column():
|
|
@@ -215,7 +213,7 @@ lyrics
|
|
| 215 |
minimum=0.1,
|
| 216 |
maximum=3.0,
|
| 217 |
step=0.1,
|
| 218 |
-
value=1.
|
| 219 |
interactive=True,
|
| 220 |
elem_id="cfg-coef",
|
| 221 |
)
|
|
@@ -239,7 +237,7 @@ lyrics
|
|
| 239 |
# )
|
| 240 |
with gr.Row():
|
| 241 |
generate_btn = gr.Button("Generate Song", variant="primary")
|
| 242 |
-
generate_bgm_btn = gr.Button("Generate Pure Music", variant="primary")
|
| 243 |
|
| 244 |
with gr.Column():
|
| 245 |
output_audio = gr.Audio(label="Generated Song", type="filepath")
|
|
@@ -267,18 +265,19 @@ lyrics
|
|
| 267 |
# 生成按钮点击事件
|
| 268 |
generate_btn.click(
|
| 269 |
fn=generate_song,
|
| 270 |
-
inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(
|
| 271 |
-
outputs=[output_audio, output_json]
|
| 272 |
-
)
|
| 273 |
-
generate_bgm_btn.click(
|
| 274 |
-
fn=generate_song,
|
| 275 |
-
inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50), gr.State("bgm")],
|
| 276 |
outputs=[output_audio, output_json]
|
| 277 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
|
| 279 |
|
| 280 |
# 启动应用
|
| 281 |
if __name__ == "__main__":
|
| 282 |
torch.set_num_threads(1)
|
|
|
|
| 283 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 284 |
|
|
|
|
| 15 |
|
| 16 |
# 下载模型
|
| 17 |
APP_DIR = op.dirname(op.abspath(__file__))
|
| 18 |
+
download_model(APP_DIR, repo_id="waytan22/SongGeneration-v2.0", revision="ffd9215")
|
|
|
|
|
|
|
| 19 |
print("Successful downloaded model.")
|
| 20 |
|
| 21 |
# 模型初始化
|
| 22 |
from levo_inference import LeVoInference
|
| 23 |
+
Model = None
|
| 24 |
|
| 25 |
EXAMPLE_LYRICS = """
|
| 26 |
[intro-medium]
|
|
|
|
| 157 |
# 创建Gradio界面
|
| 158 |
with gr.Blocks(title="SongGeneration Demo Space") as demo:
|
| 159 |
gr.Markdown("# 🎵 SongGeneration Demo Space")
|
| 160 |
+
gr.Markdown("Push to Levo 2.0 — faster and more controllable. The code is in [GIT](https://github.com/tencent-ailab/SongGeneration)")
|
| 161 |
|
| 162 |
with gr.Row():
|
| 163 |
with gr.Column():
|
|
|
|
| 213 |
minimum=0.1,
|
| 214 |
maximum=3.0,
|
| 215 |
step=0.1,
|
| 216 |
+
value=1.8,
|
| 217 |
interactive=True,
|
| 218 |
elem_id="cfg-coef",
|
| 219 |
)
|
|
|
|
| 237 |
# )
|
| 238 |
with gr.Row():
|
| 239 |
generate_btn = gr.Button("Generate Song", variant="primary")
|
| 240 |
+
# generate_bgm_btn = gr.Button("Generate Pure Music", variant="primary")
|
| 241 |
|
| 242 |
with gr.Column():
|
| 243 |
output_audio = gr.Audio(label="Generated Song", type="filepath")
|
|
|
|
| 265 |
# 生成按钮点击事件
|
| 266 |
generate_btn.click(
|
| 267 |
fn=generate_song,
|
| 268 |
+
inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(5000)],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
outputs=[output_audio, output_json]
|
| 270 |
)
|
| 271 |
+
# generate_bgm_btn.click(
|
| 272 |
+
# fn=generate_song,
|
| 273 |
+
# inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50), gr.State("bgm")],
|
| 274 |
+
# outputs=[output_audio, output_json]
|
| 275 |
+
# )
|
| 276 |
|
| 277 |
|
| 278 |
# 启动应用
|
| 279 |
if __name__ == "__main__":
|
| 280 |
torch.set_num_threads(1)
|
| 281 |
+
MODEL = LeVoInference(op.join(APP_DIR, "ckpt"))
|
| 282 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 283 |
|
codeclm/models/builders.py
CHANGED
|
@@ -52,7 +52,7 @@ def get_audio_tokenizer_model_cpu(checkpoint_path: str, cfg: omegaconf.DictConfi
|
|
| 52 |
return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
|
| 53 |
|
| 54 |
|
| 55 |
-
def get_lm_model(cfg: omegaconf.DictConfig, version: str = 'v1.
|
| 56 |
"""Instantiate a LM."""
|
| 57 |
lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
|
| 58 |
|
|
|
|
| 52 |
return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
|
| 53 |
|
| 54 |
|
| 55 |
+
def get_lm_model(cfg: omegaconf.DictConfig, version: str = 'v1.5'): #-> LMModel:
|
| 56 |
"""Instantiate a LM."""
|
| 57 |
lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
|
| 58 |
|
codeclm/models/codeclm_gen.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main model for using CodecLM. This will combine all the required components
|
| 3 |
+
and provide easy access to the generation API.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import typing as tp
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
|
| 12 |
+
# from .lm_llama import LMModel
|
| 13 |
+
from ..utils.autocast import TorchAutocast
|
| 14 |
+
import torch
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
import torchaudio
|
| 17 |
+
# from optim.ema import EMA
|
| 18 |
+
from codeclm.utils.utils import dict_from_config
|
| 19 |
+
from codeclm.modules.pattern import (
|
| 20 |
+
CodebooksPatternProvider,
|
| 21 |
+
DelayedPatternProvider,
|
| 22 |
+
)
|
| 23 |
+
from codeclm.modules.conditioners import (
|
| 24 |
+
ConditioningAttributes,
|
| 25 |
+
AudioCondition,
|
| 26 |
+
BaseConditioner,
|
| 27 |
+
QuantizedEmbeddingConditioner,
|
| 28 |
+
ConditionerProvider,
|
| 29 |
+
ConditionFuser,
|
| 30 |
+
QwTextConditioner,
|
| 31 |
+
QwTokenizerConditioner,
|
| 32 |
+
ClassifierFreeGuidanceDropoutInference,
|
| 33 |
+
)
|
| 34 |
+
import omegaconf
|
| 35 |
+
|
| 36 |
+
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig, version: str = 'v1.0') -> ConditionerProvider:
|
| 37 |
+
"""Instantiate a conditioning model."""
|
| 38 |
+
cfg = getattr(cfg, 'conditioners')
|
| 39 |
+
dict_cfg = {} if cfg is None else dict_from_config(cfg)
|
| 40 |
+
conditioners: tp.Dict[str, BaseConditioner] = {}
|
| 41 |
+
condition_provider_args = dict_cfg.pop('args', {})
|
| 42 |
+
|
| 43 |
+
for cond, cond_cfg in dict_cfg.items():
|
| 44 |
+
model_type = cond_cfg['model']
|
| 45 |
+
model_args = cond_cfg[model_type]
|
| 46 |
+
if model_type == 'QwTokenizer':
|
| 47 |
+
conditioners[str(cond)] = QwTokenizerConditioner(
|
| 48 |
+
output_dim=output_dim,
|
| 49 |
+
**model_args
|
| 50 |
+
)
|
| 51 |
+
elif model_type == "QwTextTokenizer":
|
| 52 |
+
conditioners[str(cond)] = QwTextConditioner(
|
| 53 |
+
output_dim=output_dim,
|
| 54 |
+
version=version,
|
| 55 |
+
**model_args
|
| 56 |
+
)
|
| 57 |
+
elif model_type == "qt_embedding":
|
| 58 |
+
conditioners[str(cond)] = QuantizedEmbeddingConditioner(
|
| 59 |
+
dim=output_dim,
|
| 60 |
+
**model_args
|
| 61 |
+
)
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f"Unrecognized conditioning model: {model_type}")
|
| 64 |
+
conditioner = ConditionerProvider(conditioners, **condition_provider_args)
|
| 65 |
+
return conditioner
|
| 66 |
+
|
| 67 |
+
def get_codebooks_pattern_provider(code_depth: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
|
| 68 |
+
"""Instantiate a codebooks pattern provider object."""
|
| 69 |
+
pattern_providers = {
|
| 70 |
+
'delay': DelayedPatternProvider,
|
| 71 |
+
}
|
| 72 |
+
name = cfg.modeling
|
| 73 |
+
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
|
| 74 |
+
klass = pattern_providers[name]
|
| 75 |
+
return klass(code_depth, **kwargs)
|
| 76 |
+
|
| 77 |
+
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
| 78 |
+
MelodyType = tp.Union[torch.Tensor, MelodyList]
|
| 79 |
+
|
| 80 |
+
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
|
| 81 |
+
"""Instantiate a condition fuser object."""
|
| 82 |
+
fuser_cfg = getattr(cfg, 'fuser')
|
| 83 |
+
fuser_methods = ['sum', 'prepend']
|
| 84 |
+
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
|
| 85 |
+
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
|
| 86 |
+
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
|
| 87 |
+
return fuser
|
| 88 |
+
|
| 89 |
+
class CodecLM_gen:
|
| 90 |
+
"""CodecLM main model with convenient generation API.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
name (str): name of the model.
|
| 94 |
+
compression_model (CompressionModel): Compression model
|
| 95 |
+
used to map audio to invertible discrete representations.
|
| 96 |
+
lm (LMModel): Language model over discrete representations.
|
| 97 |
+
max_duration (float, optional): maximum duration the model can produce,
|
| 98 |
+
otherwise, inferred from the training params.
|
| 99 |
+
"""
|
| 100 |
+
def __init__(self, cfg, name: str, audiotokenizer: AudioTokenizer,
|
| 101 |
+
max_duration: tp.Optional[float] = None):
|
| 102 |
+
self.cfg = cfg
|
| 103 |
+
self.name = name
|
| 104 |
+
self.audiotokenizer = audiotokenizer
|
| 105 |
+
self.seperate_tokenizer = None
|
| 106 |
+
if max_duration is None:
|
| 107 |
+
max_duration = self.cfg.max_dur
|
| 108 |
+
assert max_duration is not None
|
| 109 |
+
|
| 110 |
+
self.max_duration: float = max_duration
|
| 111 |
+
# self.device = next(iter(lm.parameters())).device
|
| 112 |
+
# self.device = next(iter(audiotokenizer.parameters())).device
|
| 113 |
+
self.generation_params: dict = {}
|
| 114 |
+
# self.set_generation_params(duration=15) # 15 seconds by default
|
| 115 |
+
self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
|
| 116 |
+
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
| 117 |
+
self.autocast = TorchAutocast(enabled=False)
|
| 118 |
+
self.condition_provider = get_conditioner_provider(cfg.lm.dim, self.cfg)
|
| 119 |
+
codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
|
| 120 |
+
self.pattern_provider = get_codebooks_pattern_provider(cfg.lm.code_depth, codebooks_pattern_cfg)
|
| 121 |
+
self.fuser = get_condition_fuser(cfg)
|
| 122 |
+
self.eos_token_id = cfg.lm.code_size
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def frame_rate(self) -> float:
|
| 128 |
+
"""Roughly the number of AR steps per seconds."""
|
| 129 |
+
return self.audiotokenizer.frame_rate
|
| 130 |
+
|
| 131 |
+
@property
|
| 132 |
+
def sample_rate(self) -> int:
|
| 133 |
+
"""Sample rate of the generated audio."""
|
| 134 |
+
return self.audiotokenizer.sample_rate
|
| 135 |
+
|
| 136 |
+
@property
|
| 137 |
+
def audio_channels(self) -> int:
|
| 138 |
+
"""Audio channels of the generated audio."""
|
| 139 |
+
return self.audiotokenizer.channels
|
| 140 |
+
|
| 141 |
+
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
| 142 |
+
top_p: float = 0.0, temperature: float = 1.0,
|
| 143 |
+
duration: float = 30.0, cfg_coef: float = 3.0,
|
| 144 |
+
extend_stride: float = 18, record_tokens: bool = False,
|
| 145 |
+
record_window: int = 50):
|
| 146 |
+
"""Set the generation parameters for CodecLM.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
|
| 150 |
+
top_k (int, optional): top_k used for sampling. Defaults to 250.
|
| 151 |
+
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
|
| 152 |
+
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
|
| 153 |
+
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
|
| 154 |
+
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
|
| 155 |
+
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
| 156 |
+
instead of batching together the two. This has some impact on how things
|
| 157 |
+
are padded but seems to have little impact in practice.
|
| 158 |
+
extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
|
| 159 |
+
should we extend the audio each time. Larger values will mean less context is
|
| 160 |
+
preserved, and shorter value will require extra computations.
|
| 161 |
+
"""
|
| 162 |
+
assert extend_stride <= self.max_duration, "Cannot stride by more than max generation duration."
|
| 163 |
+
self.extend_stride = extend_stride
|
| 164 |
+
self.duration = duration
|
| 165 |
+
self.generation_params = {
|
| 166 |
+
'use_sampling': use_sampling,
|
| 167 |
+
'temp': temperature,
|
| 168 |
+
'top_k': top_k,
|
| 169 |
+
'top_p': top_p,
|
| 170 |
+
'cfg_coef': cfg_coef,
|
| 171 |
+
'record_tokens': record_tokens,
|
| 172 |
+
'record_window': record_window,
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
| 176 |
+
"""Override the default progress callback."""
|
| 177 |
+
self._progress_callback = progress_callback
|
| 178 |
+
|
| 179 |
+
# Inference
|
| 180 |
+
def generate_condition(self, descriptions: tp.List[str],
|
| 181 |
+
melody_wavs: torch.Tensor = None,
|
| 182 |
+
return_tokens: bool = False,
|
| 183 |
+
melody_is_wav: bool = True,
|
| 184 |
+
type_info: tp.List[str] = None,
|
| 185 |
+
embeded_eosp1: torch.Tensor = None,
|
| 186 |
+
) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
| 187 |
+
if melody_wavs is not None:
|
| 188 |
+
if melody_wavs.dim() == 2:
|
| 189 |
+
melody_wavs = melody_wavs[None]
|
| 190 |
+
if melody_wavs.dim() != 3:
|
| 191 |
+
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
| 192 |
+
melody_wavs = list(melody_wavs)
|
| 193 |
+
|
| 194 |
+
# if melody_is_wav:
|
| 195 |
+
# melody_wavs = [wav.mean(dim=-2) for wav in melody_wavs]
|
| 196 |
+
|
| 197 |
+
texts, audio_qt_embs = self._prepare_tokens_and_attributes(descriptions=descriptions,
|
| 198 |
+
melody_wavs=melody_wavs,
|
| 199 |
+
melody_is_wav=melody_is_wav)
|
| 200 |
+
fused_input = self.get_condition_tensors(texts, audio_qt_embs, type_info, embeded_eosp1)
|
| 201 |
+
|
| 202 |
+
return fused_input, audio_qt_embs
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
@torch.no_grad()
|
| 206 |
+
def _prepare_tokens_and_attributes(
|
| 207 |
+
self,
|
| 208 |
+
descriptions: tp.Sequence[tp.Optional[str]],
|
| 209 |
+
melody_wavs: tp.Optional[MelodyList] = None,
|
| 210 |
+
melody_is_wav = True
|
| 211 |
+
) -> tp.Tuple[tp.List[str], tp.List[torch.Tensor]]:
|
| 212 |
+
"""Prepare model inputs.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
| 216 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
| 217 |
+
melody_wavs (torch.Tensor, optional): A batch of waveforms
|
| 218 |
+
used as melody conditioning. Defaults to None.
|
| 219 |
+
"""
|
| 220 |
+
texts = [description for description in descriptions]
|
| 221 |
+
audio_qt_embs = []
|
| 222 |
+
|
| 223 |
+
if melody_wavs is None:
|
| 224 |
+
audio_qt_embs = None
|
| 225 |
+
elif melody_wavs is not None:
|
| 226 |
+
if 'prompt_audio' not in self.condition_provider.conditioners:
|
| 227 |
+
raise RuntimeError("This model doesn't support melody conditioning. "
|
| 228 |
+
"Use the `melody` model.")
|
| 229 |
+
assert len(melody_wavs) == len(texts), \
|
| 230 |
+
f"number of melody wavs must match number of descriptions! " \
|
| 231 |
+
f"got melody len={len(melody_wavs)}, and descriptions len={len(texts)}"
|
| 232 |
+
if type(melody_wavs) == list:
|
| 233 |
+
melody_wavs = torch.stack(melody_wavs, dim=0)
|
| 234 |
+
# melody_wavs = melody_wavs.to(self.device)
|
| 235 |
+
print(melody_wavs.shape)
|
| 236 |
+
if melody_is_wav:
|
| 237 |
+
melody_tokens, scale = self.audiotokenizer.encode(melody_wavs)
|
| 238 |
+
else:
|
| 239 |
+
melody_tokens = melody_wavs
|
| 240 |
+
target_melody_token_len = self.cfg.prompt_len * self.audiotokenizer.frame_rate
|
| 241 |
+
print(melody_tokens.shape, target_melody_token_len)
|
| 242 |
+
print(melody_tokens)
|
| 243 |
+
if melody_tokens.shape[-1] > target_melody_token_len:
|
| 244 |
+
melody_tokens = melody_tokens[...,:target_melody_token_len]
|
| 245 |
+
for melody in melody_tokens:
|
| 246 |
+
audio_qt_embs.append(melody.long())
|
| 247 |
+
return texts, audio_qt_embs
|
| 248 |
+
|
| 249 |
+
@torch.no_grad()
|
| 250 |
+
def prepare_condition_tensors(self,
|
| 251 |
+
batch_size = 1,
|
| 252 |
+
text: tp.Optional[tp.List[str]] = None,
|
| 253 |
+
audio_qt_emb: tp.Optional[tp.List[torch.Tensor]] = None,
|
| 254 |
+
type_info: tp.Optional[tp.List[str]] = None,
|
| 255 |
+
prepare_null_condition = False,
|
| 256 |
+
):
|
| 257 |
+
conditions = []
|
| 258 |
+
for i in range(batch_size):
|
| 259 |
+
attr = ConditioningAttributes()
|
| 260 |
+
if 'description' in self.condition_provider.conditioners:
|
| 261 |
+
attr["text"]["description"] = ""
|
| 262 |
+
if text is not None:
|
| 263 |
+
attr["text"]["description"] = text[i]
|
| 264 |
+
if 'prompt_audio' in self.condition_provider.conditioners:
|
| 265 |
+
if audio_qt_emb is None: # tokenize stage will padding to max length
|
| 266 |
+
attr["audio"]['prompt_audio'] = AudioCondition(
|
| 267 |
+
wav=torch.zeros((1, self.cfg.audio_tokenizer_code_depth, 0)).long().cuda() + 16385,
|
| 268 |
+
length=torch.Tensor([0]).long(),
|
| 269 |
+
sample_rate=[self.cfg.sample_rate],)
|
| 270 |
+
else:
|
| 271 |
+
aT = audio_qt_emb[i].shape[-1]
|
| 272 |
+
pattern = self.pattern_provider.get_pattern(aT)
|
| 273 |
+
audio_qt_seq, _, _ = pattern.build_pattern_sequence(audio_qt_emb[i][None],
|
| 274 |
+
self.eos_token_id, keep_only_valid_steps=False)
|
| 275 |
+
attr["audio"]['prompt_audio'] = AudioCondition(
|
| 276 |
+
wav=audio_qt_seq.long().cuda(),
|
| 277 |
+
length=torch.Tensor([audio_qt_seq.shape[-1]]).long(),
|
| 278 |
+
sample_rate=[self.cfg.sample_rate],)
|
| 279 |
+
if 'type_info' in self.condition_provider.conditioners:
|
| 280 |
+
attr["text"]["type_info"] = ""
|
| 281 |
+
if type_info is not None:
|
| 282 |
+
attr["text"]["type_info"] = type_info[i]
|
| 283 |
+
conditions.append(attr)
|
| 284 |
+
# print("conditions", conditions)
|
| 285 |
+
if prepare_null_condition:
|
| 286 |
+
cfg_inference = ClassifierFreeGuidanceDropoutInference()
|
| 287 |
+
null_conditions = cfg_inference(conditions, condition_types=["audio", "text"],
|
| 288 |
+
customized=None)
|
| 289 |
+
conditions = conditions + null_conditions
|
| 290 |
+
tokenized_conditions = self.condition_provider.tokenize(conditions)
|
| 291 |
+
# import pdb; pdb.set_trace()
|
| 292 |
+
condition_tensors = self.condition_provider(tokenized_conditions)
|
| 293 |
+
return condition_tensors
|
| 294 |
+
|
| 295 |
+
def get_condition_tensors(self, texts, audio_qt_embs, type_info, embeded_eosp1):
|
| 296 |
+
condition_tensors = self.prepare_condition_tensors(batch_size=1, text=texts, audio_qt_emb=audio_qt_embs, type_info=type_info, prepare_null_condition=self.cfg.vllm.cfg)
|
| 297 |
+
if self.cfg.vllm.cfg:
|
| 298 |
+
input_ = torch.cat((embeded_eosp1, embeded_eosp1), dim=0)
|
| 299 |
+
else:
|
| 300 |
+
input_ = embeded_eosp1
|
| 301 |
+
fused_input = self.fuser(input_, condition_tensors)
|
| 302 |
+
return fused_input
|
| 303 |
+
|
| 304 |
+
@torch.no_grad()
|
| 305 |
+
def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False, chunk_size=128, gen_type='mixed'):
|
| 306 |
+
"""Generate Audio from tokens"""
|
| 307 |
+
assert gen_tokens.dim() == 3
|
| 308 |
+
if self.seperate_tokenizer is not None:
|
| 309 |
+
gen_tokens_song = gen_tokens[:, [0], :]
|
| 310 |
+
gen_tokens_vocal = gen_tokens[:, [1], :]
|
| 311 |
+
gen_tokens_bgm = gen_tokens[:, [2], :]
|
| 312 |
+
if gen_type == 'bgm':
|
| 313 |
+
gen_tokens_vocal = torch.full_like(gen_tokens_vocal, 3142)
|
| 314 |
+
if vocal_prompt is not None:
|
| 315 |
+
vocal_prompt = torch.zeros_like(vocal_prompt)
|
| 316 |
+
elif gen_type == 'vocal':
|
| 317 |
+
gen_tokens_bgm = torch.full_like(gen_tokens_bgm, 9670)
|
| 318 |
+
if bgm_prompt is not None:
|
| 319 |
+
bgm_prompt = torch.zeros_like(bgm_prompt)
|
| 320 |
+
else:
|
| 321 |
+
assert gen_type == 'mixed', f"gen_type {gen_type} not supported"
|
| 322 |
+
gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked, chunk_size=chunk_size)
|
| 323 |
+
return gen_audio_seperate
|
| 324 |
+
else:
|
| 325 |
+
gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
|
| 326 |
+
return gen_audio
|
codeclm/models/levo.py
CHANGED
|
@@ -96,7 +96,7 @@ class LmModel(LlamaModel_base):
|
|
| 96 |
self.vocab_size = config.vocab_size
|
| 97 |
layer_cls = LlamaDecoderLayer # cross attention decoder layer can be overwritten here
|
| 98 |
|
| 99 |
-
assert version.parse(transformers.__version__) < version.parse("4.40")
|
| 100 |
|
| 101 |
self.layers = nn.ModuleList([layer_cls(config) for _ in range(config.num_hidden_layers)])
|
| 102 |
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
@@ -221,4 +221,4 @@ class LmModel(LlamaModel_base):
|
|
| 221 |
hidden_states=all_hidden_states,
|
| 222 |
attentions=all_self_attns,
|
| 223 |
)
|
| 224 |
-
|
|
|
|
| 96 |
self.vocab_size = config.vocab_size
|
| 97 |
layer_cls = LlamaDecoderLayer # cross attention decoder layer can be overwritten here
|
| 98 |
|
| 99 |
+
#assert version.parse(transformers.__version__) < version.parse("4.40")
|
| 100 |
|
| 101 |
self.layers = nn.ModuleList([layer_cls(config) for _ in range(config.num_hidden_layers)])
|
| 102 |
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
| 221 |
hidden_states=all_hidden_states,
|
| 222 |
attentions=all_self_attns,
|
| 223 |
)
|
| 224 |
+
|
codeclm/models/llama/modeling_llama.py
CHANGED
|
@@ -34,10 +34,13 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
|
| 34 |
from transformers.utils import (
|
| 35 |
add_start_docstrings,
|
| 36 |
add_start_docstrings_to_model_forward,
|
| 37 |
-
is_flash_attn_available,
|
| 38 |
logging,
|
| 39 |
replace_return_docstrings,
|
| 40 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
from .configuration_llama import LlamaConfig
|
| 42 |
|
| 43 |
|
|
|
|
| 34 |
from transformers.utils import (
|
| 35 |
add_start_docstrings,
|
| 36 |
add_start_docstrings_to_model_forward,
|
|
|
|
| 37 |
logging,
|
| 38 |
replace_return_docstrings,
|
| 39 |
)
|
| 40 |
+
try:
|
| 41 |
+
from transformers.utils import is_flash_attn_available
|
| 42 |
+
except ImportError:
|
| 43 |
+
from transformers.utils import is_flash_attn_2_available as is_flash_attn_available
|
| 44 |
from .configuration_llama import LlamaConfig
|
| 45 |
|
| 46 |
|
codeclm/modules/conditioners.py
CHANGED
|
@@ -112,6 +112,7 @@ class QwTokenizerConditioner(TextConditioner):
|
|
| 112 |
token_path = "",
|
| 113 |
max_len = 300,
|
| 114 |
add_token_list=[]): #""
|
|
|
|
| 115 |
from transformers import Qwen2Tokenizer
|
| 116 |
self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
|
| 117 |
if add_token_list != []:
|
|
@@ -157,9 +158,6 @@ class QwTokenizerConditioner(TextConditioner):
|
|
| 157 |
tp_cover_range[b, st: sp_list[i+1]] = tokens[b, st] - 151645
|
| 158 |
|
| 159 |
if self.max_len is not None:
|
| 160 |
-
if inputs['input_ids'].shape[-1] > self.max_len:
|
| 161 |
-
warnings.warn(f"Max len limit ({self.max_len}) Exceed! \
|
| 162 |
-
{[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
|
| 163 |
tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device)
|
| 164 |
mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
|
| 165 |
tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device)
|
|
@@ -168,7 +166,7 @@ class QwTokenizerConditioner(TextConditioner):
|
|
| 168 |
structure_embeds = self.structure_emb(tp_cover_range.to(device))
|
| 169 |
|
| 170 |
embeds = content_embeds + structure_embeds
|
| 171 |
-
return embeds,
|
| 172 |
|
| 173 |
def pad_2d_tensor(self, x, max_len, pad_id):
|
| 174 |
batch_size, seq_len = x.size()
|
|
@@ -192,9 +190,9 @@ class QwTextConditioner(TextConditioner):
|
|
| 192 |
version: str = 'v1.0'): #""
|
| 193 |
|
| 194 |
from transformers import Qwen2Tokenizer
|
| 195 |
-
self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
|
| 196 |
-
|
| 197 |
-
|
| 198 |
voc_size = len(self.text_tokenizer.get_vocab())
|
| 199 |
# here initialize a output_proj (nn.Embedding) layer
|
| 200 |
super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
|
|
@@ -223,7 +221,7 @@ class QwTextConditioner(TextConditioner):
|
|
| 223 |
mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
|
| 224 |
|
| 225 |
embeds = self.output_proj(tokens)
|
| 226 |
-
return embeds,
|
| 227 |
|
| 228 |
def pad_2d_tensor(self, x, max_len, pad_id):
|
| 229 |
batch_size, seq_len = x.size()
|
|
@@ -255,7 +253,6 @@ class QuantizedEmbeddingConditioner(AudioConditioner):
|
|
| 255 |
self.emb = nn.ModuleList([nn.Embedding(code_size+2, dim, padding_idx=code_size+1) for _ in range(code_depth)])
|
| 256 |
# add End-Of-Text embedding
|
| 257 |
self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
|
| 258 |
-
self.layer2_EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
|
| 259 |
self.output_proj = None
|
| 260 |
self.max_len = max_len
|
| 261 |
self.vocab_size = code_size
|
|
@@ -274,20 +271,20 @@ class QuantizedEmbeddingConditioner(AudioConditioner):
|
|
| 274 |
wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size+1)
|
| 275 |
else:
|
| 276 |
wav = wav[:, :, :self.max_len-1]
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
lengths = lengths + 1
|
| 284 |
lengths = torch.clamp(lengths, max=self.max_len)
|
| 285 |
|
| 286 |
if lengths is not None:
|
| 287 |
-
mask = length_to_mask(lengths, max_len=
|
| 288 |
else:
|
| 289 |
-
mask = torch.ones((B, self.code_depth), device=
|
| 290 |
-
return
|
| 291 |
|
| 292 |
|
| 293 |
# ================================================================
|
|
@@ -356,10 +353,10 @@ class ConditionerProvider(nn.Module):
|
|
| 356 |
output = {}
|
| 357 |
for attribute, inputs in tokenized.items():
|
| 358 |
if attribute == 'description' and structure_dur is not None:
|
| 359 |
-
|
| 360 |
else:
|
| 361 |
-
|
| 362 |
-
output[attribute] = (
|
| 363 |
return output
|
| 364 |
|
| 365 |
def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
|
|
@@ -460,8 +457,7 @@ class ConditionFuser(StreamingModule):
|
|
| 460 |
|
| 461 |
def forward(
|
| 462 |
self,
|
| 463 |
-
|
| 464 |
-
input2: torch.Tensor,
|
| 465 |
conditions: tp.Dict[str, ConditionType]
|
| 466 |
) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 467 |
"""Fuse the conditions to the provided model input.
|
|
@@ -475,14 +471,14 @@ class ConditionFuser(StreamingModule):
|
|
| 475 |
used for cross-attention or None if no cross attention inputs exist.
|
| 476 |
"""
|
| 477 |
#import pdb; pdb.set_trace()
|
| 478 |
-
B, T, _ =
|
| 479 |
|
| 480 |
if 'offsets' in self._streaming_state:
|
| 481 |
first_step = False
|
| 482 |
offsets = self._streaming_state['offsets']
|
| 483 |
else:
|
| 484 |
first_step = True
|
| 485 |
-
offsets = torch.zeros(
|
| 486 |
|
| 487 |
assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
|
| 488 |
f"given conditions contain unknown attributes for fuser, " \
|
|
@@ -491,31 +487,28 @@ class ConditionFuser(StreamingModule):
|
|
| 491 |
# if 'prepend' mode is used,
|
| 492 |
# the concatenation order will be the SAME with the conditions in config:
|
| 493 |
# prepend: ['description', 'prompt_audio'] (then goes the input)
|
| 494 |
-
|
| 495 |
-
fused_input_2 = input2
|
| 496 |
for fuse_op in self.fuse2cond.keys():
|
| 497 |
fuse_op_conditions = self.fuse2cond[fuse_op]
|
| 498 |
if fuse_op == 'sum' and len(fuse_op_conditions) > 0:
|
| 499 |
for cond in fuse_op_conditions:
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
fused_input_2 += this_cond_2
|
| 503 |
elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0:
|
| 504 |
if not first_step:
|
| 505 |
continue
|
| 506 |
reverse_list = deepcopy(fuse_op_conditions)
|
| 507 |
reverse_list.reverse()
|
| 508 |
for cond in reverse_list:
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
fused_input_2 = torch.cat((this_cond_2, fused_input_2), dim=1) # concat along T dim
|
| 512 |
elif fuse_op not in self.FUSING_METHODS:
|
| 513 |
raise ValueError(f"unknown op ({fuse_op})")
|
| 514 |
|
| 515 |
if self._is_streaming:
|
| 516 |
self._streaming_state['offsets'] = offsets + T
|
| 517 |
|
| 518 |
-
return
|
| 519 |
|
| 520 |
|
| 521 |
|
|
@@ -575,8 +568,7 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
|
|
| 575 |
self.check(sample, condition_type, condition)
|
| 576 |
|
| 577 |
if condition_type == 'audio':
|
| 578 |
-
audio_cond = sample.audio[condition]
|
| 579 |
-
depth = audio_cond.wav.shape[1]
|
| 580 |
sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
|
| 581 |
else:
|
| 582 |
sample.text[condition] = None
|
|
@@ -639,7 +631,7 @@ class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
|
|
| 639 |
sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
|
| 640 |
else:
|
| 641 |
if customized is None:
|
| 642 |
-
if condition in ['type_info']
|
| 643 |
if "[Musicality-very-high]" in sample.text[condition]:
|
| 644 |
sample.text[condition] = "[Musicality-very-low], ."
|
| 645 |
print(f"cfg unconditioning: change sample.text[condition] to [Musicality-very-low]")
|
|
|
|
| 112 |
token_path = "",
|
| 113 |
max_len = 300,
|
| 114 |
add_token_list=[]): #""
|
| 115 |
+
add_token_list.append('.')
|
| 116 |
from transformers import Qwen2Tokenizer
|
| 117 |
self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
|
| 118 |
if add_token_list != []:
|
|
|
|
| 158 |
tp_cover_range[b, st: sp_list[i+1]] = tokens[b, st] - 151645
|
| 159 |
|
| 160 |
if self.max_len is not None:
|
|
|
|
|
|
|
|
|
|
| 161 |
tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device)
|
| 162 |
mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
|
| 163 |
tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device)
|
|
|
|
| 166 |
structure_embeds = self.structure_emb(tp_cover_range.to(device))
|
| 167 |
|
| 168 |
embeds = content_embeds + structure_embeds
|
| 169 |
+
return embeds, mask
|
| 170 |
|
| 171 |
def pad_2d_tensor(self, x, max_len, pad_id):
|
| 172 |
batch_size, seq_len = x.size()
|
|
|
|
| 190 |
version: str = 'v1.0'): #""
|
| 191 |
|
| 192 |
from transformers import Qwen2Tokenizer
|
| 193 |
+
self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
|
| 194 |
+
self.text_tokenizer.add_tokens(['[Musicality-very-high]', '[Musicality-high]', '[Musicality-medium]', '[Musicality-low]', '[Musicality-very-low]', '[Pure-Music]', '.'], special_tokens=True)
|
| 195 |
+
print(self.text_tokenizer)
|
| 196 |
voc_size = len(self.text_tokenizer.get_vocab())
|
| 197 |
# here initialize a output_proj (nn.Embedding) layer
|
| 198 |
super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
|
|
|
|
| 221 |
mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
|
| 222 |
|
| 223 |
embeds = self.output_proj(tokens)
|
| 224 |
+
return embeds, mask
|
| 225 |
|
| 226 |
def pad_2d_tensor(self, x, max_len, pad_id):
|
| 227 |
batch_size, seq_len = x.size()
|
|
|
|
| 253 |
self.emb = nn.ModuleList([nn.Embedding(code_size+2, dim, padding_idx=code_size+1) for _ in range(code_depth)])
|
| 254 |
# add End-Of-Text embedding
|
| 255 |
self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
|
|
|
|
| 256 |
self.output_proj = None
|
| 257 |
self.max_len = max_len
|
| 258 |
self.vocab_size = code_size
|
|
|
|
| 271 |
wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size+1)
|
| 272 |
else:
|
| 273 |
wav = wav[:, :, :self.max_len-1]
|
| 274 |
+
# self.emb.to(wav.device) # 都放cuda
|
| 275 |
+
wav = wav.to(self.emb[0].weight.device)
|
| 276 |
+
embeds = sum([self.emb[k](wav[:, k]) for k in range(self.code_depth)]) # B,T,D
|
| 277 |
+
# self.EOT_emb.data = self.EOT_emb.data.to(embeds.device)
|
| 278 |
+
embeds = torch.cat((self.EOT_emb.unsqueeze(0).repeat(B, 1, 1),
|
| 279 |
+
embeds), dim=1)
|
| 280 |
lengths = lengths + 1
|
| 281 |
lengths = torch.clamp(lengths, max=self.max_len)
|
| 282 |
|
| 283 |
if lengths is not None:
|
| 284 |
+
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
| 285 |
else:
|
| 286 |
+
mask = torch.ones((B, self.code_depth), device=embeds.device, dtype=torch.int)
|
| 287 |
+
return embeds, mask
|
| 288 |
|
| 289 |
|
| 290 |
# ================================================================
|
|
|
|
| 353 |
output = {}
|
| 354 |
for attribute, inputs in tokenized.items():
|
| 355 |
if attribute == 'description' and structure_dur is not None:
|
| 356 |
+
condition, mask = self.conditioners[attribute](inputs, structure_dur = structure_dur)
|
| 357 |
else:
|
| 358 |
+
condition, mask = self.conditioners[attribute](inputs)
|
| 359 |
+
output[attribute] = (condition, mask)
|
| 360 |
return output
|
| 361 |
|
| 362 |
def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
|
|
|
|
| 457 |
|
| 458 |
def forward(
|
| 459 |
self,
|
| 460 |
+
input: torch.Tensor,
|
|
|
|
| 461 |
conditions: tp.Dict[str, ConditionType]
|
| 462 |
) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 463 |
"""Fuse the conditions to the provided model input.
|
|
|
|
| 471 |
used for cross-attention or None if no cross attention inputs exist.
|
| 472 |
"""
|
| 473 |
#import pdb; pdb.set_trace()
|
| 474 |
+
B, T, _ = input.shape
|
| 475 |
|
| 476 |
if 'offsets' in self._streaming_state:
|
| 477 |
first_step = False
|
| 478 |
offsets = self._streaming_state['offsets']
|
| 479 |
else:
|
| 480 |
first_step = True
|
| 481 |
+
offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
|
| 482 |
|
| 483 |
assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
|
| 484 |
f"given conditions contain unknown attributes for fuser, " \
|
|
|
|
| 487 |
# if 'prepend' mode is used,
|
| 488 |
# the concatenation order will be the SAME with the conditions in config:
|
| 489 |
# prepend: ['description', 'prompt_audio'] (then goes the input)
|
| 490 |
+
fused_input = input
|
|
|
|
| 491 |
for fuse_op in self.fuse2cond.keys():
|
| 492 |
fuse_op_conditions = self.fuse2cond[fuse_op]
|
| 493 |
if fuse_op == 'sum' and len(fuse_op_conditions) > 0:
|
| 494 |
for cond in fuse_op_conditions:
|
| 495 |
+
this_cond, cond_mask = conditions[cond]
|
| 496 |
+
fused_input += this_cond
|
|
|
|
| 497 |
elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0:
|
| 498 |
if not first_step:
|
| 499 |
continue
|
| 500 |
reverse_list = deepcopy(fuse_op_conditions)
|
| 501 |
reverse_list.reverse()
|
| 502 |
for cond in reverse_list:
|
| 503 |
+
this_cond, cond_mask = conditions[cond]
|
| 504 |
+
fused_input = torch.cat((this_cond, fused_input), dim=1) # concat along T dim
|
|
|
|
| 505 |
elif fuse_op not in self.FUSING_METHODS:
|
| 506 |
raise ValueError(f"unknown op ({fuse_op})")
|
| 507 |
|
| 508 |
if self._is_streaming:
|
| 509 |
self._streaming_state['offsets'] = offsets + T
|
| 510 |
|
| 511 |
+
return fused_input
|
| 512 |
|
| 513 |
|
| 514 |
|
|
|
|
| 568 |
self.check(sample, condition_type, condition)
|
| 569 |
|
| 570 |
if condition_type == 'audio':
|
| 571 |
+
audio_cond = sample.audio[condition]
|
|
|
|
| 572 |
sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
|
| 573 |
else:
|
| 574 |
sample.text[condition] = None
|
|
|
|
| 631 |
sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
|
| 632 |
else:
|
| 633 |
if customized is None:
|
| 634 |
+
if condition in ['type_info']:
|
| 635 |
if "[Musicality-very-high]" in sample.text[condition]:
|
| 636 |
sample.text[condition] = "[Musicality-very-low], ."
|
| 637 |
print(f"cfg unconditioning: change sample.text[condition] to [Musicality-very-low]")
|
codeclm/tokenizer/Flow1dVAE/generate_1rvq.py
CHANGED
|
@@ -10,6 +10,7 @@ import math
|
|
| 10 |
import numpy as np
|
| 11 |
import tools.torch_tools as torch_tools
|
| 12 |
from safetensors.torch import load_file
|
|
|
|
| 13 |
|
| 14 |
class Tango:
|
| 15 |
def __init__(self, \
|
|
@@ -23,9 +24,9 @@ class Tango:
|
|
| 23 |
scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
|
| 24 |
self.device = device
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
self.layer_num = layer_num
|
| 30 |
|
| 31 |
self.MAX_DURATION = 360
|
|
@@ -52,43 +53,34 @@ class Tango:
|
|
| 52 |
# scheduler_name, subfolder="scheduler")
|
| 53 |
# print("Successfully loaded inference scheduler from {}".format(scheduler_name))
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
# orig_samples = orig_samples.to(self.device)
|
| 63 |
-
# saved_samples = orig_samples[:,0:40*48000].clamp(-1,1)
|
| 64 |
-
# orig_samples = orig_samples[:,0:40*48000].clamp(-1,1)
|
| 65 |
-
# max_volume = orig_samples.abs().max(dim=-1)[0]
|
| 66 |
-
# orig_samples = orig_samples/max_volume.unsqueeze(-1)
|
| 67 |
-
# print("orig_samples.shape", orig_samples.shape)
|
| 68 |
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
|
| 73 |
-
|
| 74 |
-
# latents = self.model.inference(orig_samples.repeat(batch_size, 1), [lyric, ]*batch_size, true_latents, latent_length, additional_feats=[], guidance_scale=1.5, num_steps = steps, disable_progress=disable_progress,layer=6, scenario = scenario)
|
| 75 |
-
# print("latents.shape", latents.shape)
|
| 76 |
-
# print("latent_length", latent_length)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
# print("audio.shape:",audio.shape)
|
| 83 |
-
# # audio = audio.reshape(audio.shape[0]//2, 2, -1)
|
| 84 |
-
# # audio = torch.from_numpy(audio)
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
|
| 93 |
@torch.no_grad()
|
| 94 |
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
|
@@ -105,7 +97,6 @@ class Tango:
|
|
| 105 |
min_samples = int(40 * self.sample_rate)
|
| 106 |
# 40秒对应10个token
|
| 107 |
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
| 108 |
-
print("output_len: ", output_len)
|
| 109 |
|
| 110 |
while(audios.shape[-1] < min_samples):
|
| 111 |
audios = torch.cat([audios, audios], -1)
|
|
@@ -117,10 +108,8 @@ class Tango:
|
|
| 117 |
audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
| 118 |
|
| 119 |
for audio_inx in range(0, audio_input.shape[0], batch_size):
|
| 120 |
-
# import pdb; pdb.set_trace()
|
| 121 |
codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
|
| 122 |
codes_list.append(torch.cat(codes, 1))
|
| 123 |
-
# print("codes_list",codes_list[0].shape)
|
| 124 |
|
| 125 |
codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
|
| 126 |
codes=codes[:,:,:output_len]
|
|
@@ -159,21 +148,13 @@ class Tango:
|
|
| 159 |
# else choose from 20.48s which might includes verse or chorus
|
| 160 |
prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
|
| 161 |
|
| 162 |
-
true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
|
| 163 |
-
# print("true_latent.shape", true_latent.shape)
|
| 164 |
-
# print("first_latent.shape", first_latent.shape)
|
| 165 |
-
#true_latent.shape torch.Size([1, 250, 64])
|
| 166 |
-
# first_latent.shape torch.Size([1, 1000, 64])
|
| 167 |
-
|
| 168 |
first_latent[:,0:true_latent.shape[1],:] = true_latent
|
| 169 |
first_latent_length = true_latent.shape[1]
|
| 170 |
first_latent_codes = self.sound2code(prompt)
|
| 171 |
first_latent_codes_length = first_latent_codes.shape[-1]
|
| 172 |
codes = torch.cat([first_latent_codes, codes], -1)
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
codes_len= codes.shape[-1]
|
| 178 |
target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
|
| 179 |
# target_len = int(codes_len / 100 * 4 * self.sample_rate)
|
|
@@ -196,17 +177,12 @@ class Tango:
|
|
| 196 |
codes_input=[]
|
| 197 |
codes_input.append(codes[:,:,sinx:sinx+min_samples])
|
| 198 |
if(sinx == 0):
|
| 199 |
-
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
| 200 |
incontext_length = first_latent_length
|
| 201 |
latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
| 202 |
latent_list.append(latents)
|
| 203 |
else:
|
| 204 |
-
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
| 205 |
true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
|
| 206 |
-
print("true_latent.shape", true_latent.shape)
|
| 207 |
len_add_to_1000 = min_samples - true_latent.shape[-2]
|
| 208 |
-
# print("len_add_to_1000", len_add_to_1000)
|
| 209 |
-
# exit()
|
| 210 |
incontext_length = true_latent.shape[-2]
|
| 211 |
true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
|
| 212 |
latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
|
@@ -228,8 +204,6 @@ class Tango:
|
|
| 228 |
else:
|
| 229 |
ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
|
| 230 |
ov_win = torch.cat([ov_win, 1 - ov_win], -1)
|
| 231 |
-
print("output.shape", output.shape)
|
| 232 |
-
print("ov_win.shape", ov_win.shape)
|
| 233 |
output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
|
| 234 |
output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
|
| 235 |
output = output[:, 0:target_len]
|
|
@@ -248,9 +222,7 @@ class Tango:
|
|
| 248 |
@torch.no_grad()
|
| 249 |
def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
|
| 250 |
codes = self.sound2code(sound)
|
| 251 |
-
# print(codes.shape)
|
| 252 |
wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
|
| 253 |
-
# print(fname, wave.shape)
|
| 254 |
return wave
|
| 255 |
|
| 256 |
def to(self, device=None, dtype=None, non_blocking=False):
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import tools.torch_tools as torch_tools
|
| 12 |
from safetensors.torch import load_file
|
| 13 |
+
from tools.get_1dvae_large import get_model
|
| 14 |
|
| 15 |
class Tango:
|
| 16 |
def __init__(self, \
|
|
|
|
| 24 |
scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
|
| 25 |
self.device = device
|
| 26 |
|
| 27 |
+
self.vae = get_model(vae_config, vae_model)
|
| 28 |
+
self.vae = self.vae.to(device)
|
| 29 |
+
self.vae=self.vae.eval()
|
| 30 |
self.layer_num = layer_num
|
| 31 |
|
| 32 |
self.MAX_DURATION = 360
|
|
|
|
| 53 |
# scheduler_name, subfolder="scheduler")
|
| 54 |
# print("Successfully loaded inference scheduler from {}".format(scheduler_name))
|
| 55 |
|
| 56 |
+
def sound2sound(self, orig_samples, lyric, st_et, batch_size=1, duration=40.96, steps=200, disable_progress=False,scenario = "start_seg"):
|
| 57 |
+
""" Genrate audio without condition. """
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
if(orig_samples.shape[-1]<int(duration*48000)+480):
|
| 60 |
+
orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000+480)-orig_samples.shape[-1], \
|
| 61 |
+
dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
orig_samples = orig_samples.to(self.device)
|
| 64 |
+
saved_samples = orig_samples[:,0:40*48000].clamp(-1,1)
|
| 65 |
+
orig_samples = orig_samples[:,0:40*48000].clamp(-1,1)
|
| 66 |
+
max_volume = orig_samples.abs().max(dim=-1)[0]
|
| 67 |
+
orig_samples = orig_samples/max_volume.unsqueeze(-1)
|
| 68 |
|
| 69 |
+
latent_length = int((st_et[1] - st_et[0]) * 48000) // 1920 + 1
|
| 70 |
|
| 71 |
+
true_latents = self.vae.encode_audio(orig_samples).permute(0,2,1)
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
+
latents = self.model.inference(orig_samples.repeat(batch_size, 1), [lyric, ]*batch_size, true_latents, latent_length, additional_feats=[], guidance_scale=1.5, num_steps = steps, disable_progress=disable_progress,layer=6, scenario = scenario)
|
| 74 |
+
latents = latents[:,:,:latent_length]
|
| 75 |
+
audio = self.vae.decode_audio(latents)
|
| 76 |
+
audio = torch.cat((audio, torch.zeros(audio.shape[0],audio.shape[1], 48000*40 - audio.shape[-1], dtype=audio.dtype, device=audio.device)), dim=-1)
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
if(saved_samples.shape[-1]<audio.shape[-1]):
|
| 79 |
+
saved_samples = torch.cat([saved_samples, torch.zeros(saved_samples.shape[0], audio.shape[-1]-saved_samples.shape[-1], dtype=saved_samples.dtype, device=saved_samples.device)],-1)
|
| 80 |
+
else:
|
| 81 |
+
saved_samples = saved_samples[:,0:audio.shape[-1]]
|
| 82 |
+
output = torch.cat([saved_samples.detach().cpu(),audio[0].detach().cpu()],0)
|
| 83 |
+
return output
|
| 84 |
|
| 85 |
@torch.no_grad()
|
| 86 |
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
|
|
|
| 97 |
min_samples = int(40 * self.sample_rate)
|
| 98 |
# 40秒对应10个token
|
| 99 |
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
|
|
|
| 100 |
|
| 101 |
while(audios.shape[-1] < min_samples):
|
| 102 |
audios = torch.cat([audios, audios], -1)
|
|
|
|
| 108 |
audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
| 109 |
|
| 110 |
for audio_inx in range(0, audio_input.shape[0], batch_size):
|
|
|
|
| 111 |
codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
|
| 112 |
codes_list.append(torch.cat(codes, 1))
|
|
|
|
| 113 |
|
| 114 |
codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
|
| 115 |
codes=codes[:,:,:output_len]
|
|
|
|
| 148 |
# else choose from 20.48s which might includes verse or chorus
|
| 149 |
prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
|
| 150 |
|
| 151 |
+
true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
first_latent[:,0:true_latent.shape[1],:] = true_latent
|
| 153 |
first_latent_length = true_latent.shape[1]
|
| 154 |
first_latent_codes = self.sound2code(prompt)
|
| 155 |
first_latent_codes_length = first_latent_codes.shape[-1]
|
| 156 |
codes = torch.cat([first_latent_codes, codes], -1)
|
| 157 |
|
|
|
|
|
|
|
|
|
|
| 158 |
codes_len= codes.shape[-1]
|
| 159 |
target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
|
| 160 |
# target_len = int(codes_len / 100 * 4 * self.sample_rate)
|
|
|
|
| 177 |
codes_input=[]
|
| 178 |
codes_input.append(codes[:,:,sinx:sinx+min_samples])
|
| 179 |
if(sinx == 0):
|
|
|
|
| 180 |
incontext_length = first_latent_length
|
| 181 |
latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
| 182 |
latent_list.append(latents)
|
| 183 |
else:
|
|
|
|
| 184 |
true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
|
|
|
|
| 185 |
len_add_to_1000 = min_samples - true_latent.shape[-2]
|
|
|
|
|
|
|
| 186 |
incontext_length = true_latent.shape[-2]
|
| 187 |
true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
|
| 188 |
latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
|
|
|
| 204 |
else:
|
| 205 |
ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
|
| 206 |
ov_win = torch.cat([ov_win, 1 - ov_win], -1)
|
|
|
|
|
|
|
| 207 |
output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
|
| 208 |
output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
|
| 209 |
output = output[:, 0:target_len]
|
|
|
|
| 222 |
@torch.no_grad()
|
| 223 |
def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
|
| 224 |
codes = self.sound2code(sound)
|
|
|
|
| 225 |
wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
|
|
|
|
| 226 |
return wave
|
| 227 |
|
| 228 |
def to(self, device=None, dtype=None, non_blocking=False):
|
codeclm/tokenizer/Flow1dVAE/model_1rvq.py
CHANGED
|
@@ -301,17 +301,17 @@ class PromptCondAudioDiffusion(nn.Module):
|
|
| 301 |
# for v in self.hubert.parameters():v.requires_grad = False
|
| 302 |
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
| 303 |
# self.xvecmodel = XVECModel()
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
self.set_from = "random"
|
| 314 |
-
|
| 315 |
self.mask_emb = torch.nn.Embedding(3, 48)
|
| 316 |
print("Transformer initialized from pretrain.")
|
| 317 |
torch.cuda.empty_cache()
|
|
@@ -602,38 +602,20 @@ class PromptCondAudioDiffusion(nn.Module):
|
|
| 602 |
dtype = self.dtype
|
| 603 |
# codes_bestrq_middle, codes_bestrq_last = codes
|
| 604 |
codes_bestrq_emb = codes[0]
|
| 605 |
-
|
| 606 |
-
|
| 607 |
batch_size = codes_bestrq_emb.shape[0]
|
| 608 |
-
|
| 609 |
-
|
| 610 |
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
| 611 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
| 612 |
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
| 613 |
-
print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
|
| 614 |
-
# quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
if('spk' in additional_feats):
|
| 620 |
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
| 621 |
-
|
| 622 |
num_frames = quantized_bestrq_emb.shape[1]
|
| 623 |
-
|
| 624 |
num_channels_latents = self.num_channels
|
| 625 |
shape = (batch_size, num_frames, 64)
|
| 626 |
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
| 631 |
latent_masks[:,0:latent_length] = 2
|
| 632 |
if(scenario=='other_seg'):
|
| 633 |
latent_masks[:,0:incontext_length] = 1
|
| 634 |
|
| 635 |
-
|
| 636 |
-
|
| 637 |
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
| 638 |
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
| 639 |
true_latents = true_latents.permute(0,2,1).contiguous()
|
|
@@ -642,7 +624,6 @@ class PromptCondAudioDiffusion(nn.Module):
|
|
| 642 |
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
| 643 |
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
| 644 |
|
| 645 |
-
|
| 646 |
attention_mask=(latent_masks > 0.5)
|
| 647 |
B, L = attention_mask.size()
|
| 648 |
attention_mask = attention_mask.view(B, 1, L)
|
|
|
|
| 301 |
# for v in self.hubert.parameters():v.requires_grad = False
|
| 302 |
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
| 303 |
# self.xvecmodel = XVECModel()
|
| 304 |
+
config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
|
| 305 |
+
unet = GPT2Model(config)
|
| 306 |
+
mlp = nn.Sequential(
|
| 307 |
+
nn.Linear(1200, 1024),
|
| 308 |
+
nn.SiLU(),
|
| 309 |
+
nn.Linear(1024, 1024),
|
| 310 |
+
nn.SiLU(),
|
| 311 |
+
nn.Linear(1024, 768)
|
| 312 |
+
)
|
| 313 |
self.set_from = "random"
|
| 314 |
+
self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
|
| 315 |
self.mask_emb = torch.nn.Embedding(3, 48)
|
| 316 |
print("Transformer initialized from pretrain.")
|
| 317 |
torch.cuda.empty_cache()
|
|
|
|
| 602 |
dtype = self.dtype
|
| 603 |
# codes_bestrq_middle, codes_bestrq_last = codes
|
| 604 |
codes_bestrq_emb = codes[0]
|
|
|
|
|
|
|
| 605 |
batch_size = codes_bestrq_emb.shape[0]
|
|
|
|
|
|
|
| 606 |
quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
|
|
|
|
| 607 |
quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
if('spk' in additional_feats):
|
| 609 |
spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
|
|
|
|
| 610 |
num_frames = quantized_bestrq_emb.shape[1]
|
|
|
|
| 611 |
num_channels_latents = self.num_channels
|
| 612 |
shape = (batch_size, num_frames, 64)
|
| 613 |
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
| 614 |
latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
|
| 615 |
latent_masks[:,0:latent_length] = 2
|
| 616 |
if(scenario=='other_seg'):
|
| 617 |
latent_masks[:,0:incontext_length] = 1
|
| 618 |
|
|
|
|
|
|
|
| 619 |
quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
|
| 620 |
+ (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
|
| 621 |
true_latents = true_latents.permute(0,2,1).contiguous()
|
|
|
|
| 624 |
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
|
| 625 |
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
| 626 |
|
|
|
|
| 627 |
attention_mask=(latent_masks > 0.5)
|
| 628 |
B, L = attention_mask.size()
|
| 629 |
attention_mask = attention_mask.view(B, 1, L)
|
codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config.py
CHANGED
|
@@ -18,6 +18,8 @@
|
|
| 18 |
from collections import OrderedDict
|
| 19 |
from typing import Any, List, Mapping, Optional
|
| 20 |
|
|
|
|
|
|
|
| 21 |
from transformers import PreTrainedTokenizer, TensorType, is_torch_available
|
| 22 |
from transformers.configuration_utils import PretrainedConfig
|
| 23 |
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
|
|
@@ -27,6 +29,59 @@ from transformers.utils import logging
|
|
| 27 |
logger = logging.get_logger(__name__)
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
class GPT2Config(PretrainedConfig):
|
| 31 |
"""
|
| 32 |
This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
|
|
|
|
| 18 |
from collections import OrderedDict
|
| 19 |
from typing import Any, List, Mapping, Optional
|
| 20 |
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
from transformers import PreTrainedTokenizer, TensorType, is_torch_available
|
| 24 |
from transformers.configuration_utils import PretrainedConfig
|
| 25 |
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
|
|
|
|
| 29 |
logger = logging.get_logger(__name__)
|
| 30 |
|
| 31 |
|
| 32 |
+
class SequenceSummary(nn.Module):
|
| 33 |
+
"""Compute a single vector summary of a sequence hidden states."""
|
| 34 |
+
|
| 35 |
+
def __init__(self, config: PretrainedConfig):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.summary_type = getattr(config, "summary_type", "last")
|
| 38 |
+
self.summary_use_proj = getattr(config, "summary_use_proj", True)
|
| 39 |
+
self.summary_activation = getattr(config, "summary_activation", None)
|
| 40 |
+
self.summary_last_dropout = getattr(config, "summary_last_dropout", 0.0)
|
| 41 |
+
self.summary_first_dropout = getattr(config, "summary_first_dropout", 0.0)
|
| 42 |
+
self.summary_proj_to_labels = getattr(config, "summary_proj_to_labels", True)
|
| 43 |
+
|
| 44 |
+
if self.summary_use_proj:
|
| 45 |
+
if self.summary_proj_to_labels and hasattr(config, "num_labels"):
|
| 46 |
+
num_classes = config.num_labels
|
| 47 |
+
else:
|
| 48 |
+
num_classes = config.hidden_size
|
| 49 |
+
self.summary = nn.Linear(config.hidden_size, num_classes)
|
| 50 |
+
|
| 51 |
+
self.activation = nn.Tanh() if self.summary_activation == "tanh" else None
|
| 52 |
+
self.first_dropout = nn.Dropout(self.summary_first_dropout) if self.summary_first_dropout > 0 else None
|
| 53 |
+
self.last_dropout = nn.Dropout(self.summary_last_dropout) if self.summary_last_dropout > 0 else None
|
| 54 |
+
|
| 55 |
+
def forward(self, hidden_states, cls_index=None):
|
| 56 |
+
if self.summary_type == "last":
|
| 57 |
+
output = hidden_states[:, -1]
|
| 58 |
+
elif self.summary_type == "first":
|
| 59 |
+
output = hidden_states[:, 0]
|
| 60 |
+
elif self.summary_type == "mean":
|
| 61 |
+
output = hidden_states.mean(dim=1)
|
| 62 |
+
elif self.summary_type == "cls_index":
|
| 63 |
+
if cls_index is None:
|
| 64 |
+
cls_index = torch.full_like(hidden_states[:, :1, :1], hidden_states.size(1) - 1, dtype=torch.long)
|
| 65 |
+
cls_index = cls_index[:, 0].long()
|
| 66 |
+
output = hidden_states[torch.arange(hidden_states.size(0)), cls_index]
|
| 67 |
+
else:
|
| 68 |
+
output = hidden_states[:, -1] # default to last
|
| 69 |
+
|
| 70 |
+
if self.first_dropout:
|
| 71 |
+
output = self.first_dropout(output)
|
| 72 |
+
|
| 73 |
+
if self.summary_use_proj:
|
| 74 |
+
output = self.summary(output)
|
| 75 |
+
|
| 76 |
+
if self.activation:
|
| 77 |
+
output = self.activation(output)
|
| 78 |
+
|
| 79 |
+
if self.last_dropout:
|
| 80 |
+
output = self.last_dropout(output)
|
| 81 |
+
|
| 82 |
+
return output
|
| 83 |
+
|
| 84 |
+
|
| 85 |
class GPT2Config(PretrainedConfig):
|
| 86 |
"""
|
| 87 |
This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
|
codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_reflow.py
CHANGED
|
@@ -37,7 +37,7 @@ from transformers.modeling_outputs import (
|
|
| 37 |
SequenceClassifierOutputWithPast,
|
| 38 |
TokenClassifierOutput,
|
| 39 |
)
|
| 40 |
-
from transformers.modeling_utils import PreTrainedModel
|
| 41 |
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
| 42 |
from transformers.utils import (
|
| 43 |
ModelOutput,
|
|
@@ -50,7 +50,7 @@ from transformers.utils import (
|
|
| 50 |
replace_return_docstrings,
|
| 51 |
)
|
| 52 |
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
| 53 |
-
from models_gpt.models.gpt2_config import GPT2Config
|
| 54 |
|
| 55 |
|
| 56 |
if is_flash_attn_2_available():
|
|
|
|
| 37 |
SequenceClassifierOutputWithPast,
|
| 38 |
TokenClassifierOutput,
|
| 39 |
)
|
| 40 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 41 |
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
| 42 |
from transformers.utils import (
|
| 43 |
ModelOutput,
|
|
|
|
| 50 |
replace_return_docstrings,
|
| 51 |
)
|
| 52 |
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
| 53 |
+
from models_gpt.models.gpt2_config import GPT2Config, SequenceSummary
|
| 54 |
|
| 55 |
|
| 56 |
if is_flash_attn_2_available():
|
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/features.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
| 15 |
|
| 16 |
import torchaudio
|
| 17 |
from torch import nn
|
| 18 |
-
|
| 19 |
|
| 20 |
class MelSTFT(nn.Module):
|
| 21 |
def __init__(
|
|
@@ -39,7 +39,16 @@ class MelSTFT(nn.Module):
|
|
| 39 |
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 40 |
|
| 41 |
def forward(self, waveform):
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
import torchaudio
|
| 17 |
from torch import nn
|
| 18 |
+
import torch
|
| 19 |
|
| 20 |
class MelSTFT(nn.Module):
|
| 21 |
def __init__(
|
|
|
|
| 39 |
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 40 |
|
| 41 |
def forward(self, waveform):
|
| 42 |
+
# 将数据移至 CPU 处理 STFT,再移回 GPU
|
| 43 |
+
device = waveform.device
|
| 44 |
+
waveform_cpu = waveform.cpu()
|
| 45 |
+
# 强制在 CPU 上运行
|
| 46 |
+
with torch.cpu.amp.autocast(enabled=False):
|
| 47 |
+
if self.is_db:
|
| 48 |
+
spec = self.amplitude_to_db(self.mel_stft.to('cpu')(waveform_cpu))
|
| 49 |
+
else:
|
| 50 |
+
spec = self.mel_stft.to('cpu')(waveform_cpu)
|
| 51 |
+
# 结果移回原设备,并将 mel_stft 移回原设备供下次使用(或者克隆一个 cpu 版的)
|
| 52 |
+
spec = spec.to(device)
|
| 53 |
+
self.mel_stft.to(device)
|
| 54 |
+
return spec
|
codeclm/tokenizer/audio_tokenizer.py
CHANGED
|
@@ -136,7 +136,7 @@ class Flow1dVAE1rvq(AudioTokenizer):
|
|
| 136 |
@torch.no_grad()
|
| 137 |
def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
|
| 138 |
wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5,
|
| 139 |
-
num_steps=
|
| 140 |
return wav[None]
|
| 141 |
|
| 142 |
|
|
@@ -222,7 +222,7 @@ class Flow1dVAESeparate(AudioTokenizer):
|
|
| 222 |
@torch.no_grad()
|
| 223 |
def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False, chunk_size=128):
|
| 224 |
wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
|
| 225 |
-
num_steps=
|
| 226 |
return wav[None]
|
| 227 |
|
| 228 |
|
|
|
|
| 136 |
@torch.no_grad()
|
| 137 |
def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
|
| 138 |
wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5,
|
| 139 |
+
num_steps=10, disable_progress=False) # [B,N,T] -> [B,T]
|
| 140 |
return wav[None]
|
| 141 |
|
| 142 |
|
|
|
|
| 222 |
@torch.no_grad()
|
| 223 |
def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False, chunk_size=128):
|
| 224 |
wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
|
| 225 |
+
num_steps=10, disable_progress=False, chunked=chunked, chunk_size=chunk_size) # [B,N,T] -> [B,T]
|
| 226 |
return wav[None]
|
| 227 |
|
| 228 |
|
generate.py
CHANGED
|
@@ -1,22 +1,19 @@
|
|
| 1 |
-
|
| 2 |
-
import sys
|
| 3 |
-
import os
|
| 4 |
-
import argparse
|
| 5 |
-
|
| 6 |
import time
|
| 7 |
-
import json
|
| 8 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import torchaudio
|
| 10 |
import numpy as np
|
| 11 |
-
|
| 12 |
-
from
|
| 13 |
-
import gc
|
| 14 |
-
from codeclm.trainer.codec_song_pl import CodecLM_PL
|
| 15 |
-
from codeclm.models import CodecLM
|
| 16 |
-
from third_party.demucs.models.pretrained import get_model_from_yaml
|
| 17 |
import re
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
|
| 20 |
|
| 21 |
def check_language_by_text(text):
|
| 22 |
chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
|
|
@@ -32,563 +29,172 @@ def check_language_by_text(text):
|
|
| 32 |
else:
|
| 33 |
return "en"
|
| 34 |
|
| 35 |
-
class Separator:
|
| 36 |
-
def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
|
| 37 |
-
if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
|
| 38 |
-
self.device = torch.device(f"cuda:{gpu_id}")
|
| 39 |
-
else:
|
| 40 |
-
self.device = torch.device("cpu")
|
| 41 |
-
self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
|
| 42 |
-
|
| 43 |
-
def init_demucs_model(self, model_path, config_path):
|
| 44 |
-
model = get_model_from_yaml(config_path, model_path)
|
| 45 |
-
model.to(self.device)
|
| 46 |
-
model.eval()
|
| 47 |
-
return model
|
| 48 |
-
|
| 49 |
-
def load_audio(self, f):
|
| 50 |
-
a, fs = torchaudio.load(f)
|
| 51 |
-
if (fs != 48000):
|
| 52 |
-
a = torchaudio.functional.resample(a, fs, 48000)
|
| 53 |
-
if a.shape[-1] >= 48000*10:
|
| 54 |
-
a = a[..., :48000*10]
|
| 55 |
-
return a[:, 0:48000*10]
|
| 56 |
-
|
| 57 |
-
def run(self, audio_path, output_dir='tmp', ext=".flac"):
|
| 58 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 59 |
-
name, _ = os.path.splitext(os.path.split(audio_path)[-1])
|
| 60 |
-
output_paths = []
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
for path in [drums_path, bass_path, other_path]:
|
| 71 |
-
os.remove(path)
|
| 72 |
-
full_audio = self.load_audio(audio_path)
|
| 73 |
-
vocal_audio = self.load_audio(vocal_path)
|
| 74 |
-
bgm_audio = full_audio - vocal_audio
|
| 75 |
-
return full_audio, vocal_audio, bgm_audio
|
| 76 |
|
| 77 |
|
| 78 |
def parse_args():
|
| 79 |
parser = argparse.ArgumentParser(description='Song Generation Script')
|
| 80 |
|
| 81 |
# 必需参数
|
| 82 |
-
parser.add_argument('--ckpt_path', type=str, required=True,
|
| 83 |
-
help='Path to the checkpoint directory containing config.yaml and model.pt')
|
| 84 |
parser.add_argument('--input_jsonl', type=str, required=True,
|
| 85 |
help='Path to input JSONL file containing generation tasks')
|
| 86 |
parser.add_argument('--save_dir', type=str, required=True,
|
| 87 |
help='Directory to save generated audio files and results')
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
help='Type of generation: "vocal" or "bgm" or "separate" or "mixed" (default: "mixed")')
|
| 91 |
-
parser.add_argument('--use_flash_attn', action='store_true',
|
| 92 |
-
help='Whether to use flash attention (default: False)')
|
| 93 |
-
parser.add_argument('--low_mem', action='store_true',
|
| 94 |
-
help='Whether to use low memory mode (default: False)')
|
| 95 |
return parser.parse_args()
|
| 96 |
|
| 97 |
-
|
|
|
|
| 98 |
torch.set_num_threads(1)
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
input_jsonl = args.input_jsonl
|
| 101 |
save_dir = args.save_dir
|
| 102 |
-
cfg_path =
|
| 103 |
-
|
| 104 |
cfg = OmegaConf.load(cfg_path)
|
| 105 |
-
cfg.lm.use_flash_attn_2 = args.use_flash_attn
|
| 106 |
-
print(f"use_flash_attn: {args.use_flash_attn}")
|
| 107 |
cfg.mode = 'inference'
|
| 108 |
max_duration = cfg.max_dur
|
| 109 |
-
gen_type = args.generate_type
|
| 110 |
|
| 111 |
-
|
| 112 |
-
separator = Separator()
|
| 113 |
-
auto_prompt = torch.load('tools/new_auto_prompt.pt')
|
| 114 |
audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
audio_tokenizer = audio_tokenizer.eval().cuda()
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
item['raw_pmt_wav'] = pmt_wav
|
| 131 |
-
item['raw_vocal_wav'] = vocal_wav
|
| 132 |
-
item['raw_bgm_wav'] = bgm_wav
|
| 133 |
-
if pmt_wav.dim() == 2:
|
| 134 |
-
pmt_wav = pmt_wav[None]
|
| 135 |
-
if pmt_wav.dim() != 3:
|
| 136 |
-
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
| 137 |
-
pmt_wav = list(pmt_wav)
|
| 138 |
-
if vocal_wav.dim() == 2:
|
| 139 |
-
vocal_wav = vocal_wav[None]
|
| 140 |
-
if vocal_wav.dim() != 3:
|
| 141 |
-
raise ValueError("Vocal wavs should have a shape [B, C, T].")
|
| 142 |
-
vocal_wav = list(vocal_wav)
|
| 143 |
-
if bgm_wav.dim() == 2:
|
| 144 |
-
bgm_wav = bgm_wav[None]
|
| 145 |
-
if bgm_wav.dim() != 3:
|
| 146 |
-
raise ValueError("BGM wavs should have a shape [B, C, T].")
|
| 147 |
-
bgm_wav = list(bgm_wav)
|
| 148 |
-
if type(pmt_wav) == list:
|
| 149 |
-
pmt_wav = torch.stack(pmt_wav, dim=0)
|
| 150 |
-
if type(vocal_wav) == list:
|
| 151 |
-
vocal_wav = torch.stack(vocal_wav, dim=0)
|
| 152 |
-
if type(bgm_wav) == list:
|
| 153 |
-
bgm_wav = torch.stack(bgm_wav, dim=0)
|
| 154 |
-
pmt_wav = pmt_wav
|
| 155 |
-
vocal_wav = vocal_wav
|
| 156 |
-
bgm_wav = bgm_wav
|
| 157 |
-
with torch.no_grad():
|
| 158 |
-
pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
|
| 159 |
-
melody_is_wav = False
|
| 160 |
-
elif "auto_prompt_audio_type" in item:
|
| 161 |
-
assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
|
| 162 |
-
if item['auto_prompt_audio_type'] == 'Auto':
|
| 163 |
-
lang = check_language_by_text(item['gt_lyric'])
|
| 164 |
-
prompt_token = auto_prompt['Auto'][lang][np.random.randint(0, len(auto_prompt['Auto'][lang]))]
|
| 165 |
-
else:
|
| 166 |
-
prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
|
| 167 |
-
pmt_wav = prompt_token[:,[0],:]
|
| 168 |
-
vocal_wav = prompt_token[:,[1],:]
|
| 169 |
-
bgm_wav = prompt_token[:,[2],:]
|
| 170 |
-
melody_is_wav = False
|
| 171 |
-
else:
|
| 172 |
-
pmt_wav = None
|
| 173 |
-
vocal_wav = None
|
| 174 |
-
bgm_wav = None
|
| 175 |
-
melody_is_wav = True
|
| 176 |
-
item['pmt_wav'] = pmt_wav
|
| 177 |
-
item['vocal_wav'] = vocal_wav
|
| 178 |
-
item['bgm_wav'] = bgm_wav
|
| 179 |
-
item['melody_is_wav'] = melody_is_wav
|
| 180 |
-
item["idx"] = f"{item['idx']}"
|
| 181 |
-
item["wav_path"] = target_wav_name
|
| 182 |
-
new_items.append(item)
|
| 183 |
-
|
| 184 |
-
del audio_tokenizer
|
| 185 |
-
del separator
|
| 186 |
-
|
| 187 |
-
torch.cuda.empty_cache()
|
| 188 |
-
|
| 189 |
-
if "audio_tokenizer_checkpoint_sep" in cfg.keys():
|
| 190 |
-
seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
|
| 191 |
-
else:
|
| 192 |
-
seperate_tokenizer = None
|
| 193 |
-
|
| 194 |
-
if seperate_tokenizer is not None:
|
| 195 |
-
seperate_tokenizer = seperate_tokenizer.eval().cuda()
|
| 196 |
-
|
| 197 |
-
for item in new_items:
|
| 198 |
-
if "prompt_audio_path" in item:
|
| 199 |
-
with torch.no_grad():
|
| 200 |
-
vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
|
| 201 |
-
item['vocal_wav'] = vocal_wav
|
| 202 |
-
item['bgm_wav'] = bgm_wav
|
| 203 |
-
|
| 204 |
-
torch.cuda.empty_cache()
|
| 205 |
-
audiolm = builders.get_lm_model(cfg, version=version)
|
| 206 |
-
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
| 207 |
-
audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
|
| 208 |
-
audiolm.load_state_dict(audiolm_state_dict, strict=False)
|
| 209 |
-
audiolm = audiolm.eval()
|
| 210 |
-
audiolm = audiolm.cuda().to(torch.float16)
|
| 211 |
-
|
| 212 |
-
model = CodecLM(name = "tmp",
|
| 213 |
-
lm = audiolm,
|
| 214 |
-
audiotokenizer = None,
|
| 215 |
-
max_duration = max_duration,
|
| 216 |
-
seperate_tokenizer = seperate_tokenizer,
|
| 217 |
)
|
|
|
|
|
|
|
| 218 |
|
| 219 |
-
|
| 220 |
-
temp =
|
| 221 |
-
top_k =
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
record_window = 50
|
| 225 |
-
|
| 226 |
-
model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
|
| 227 |
-
top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
|
| 228 |
os.makedirs(save_dir, exist_ok=True)
|
| 229 |
os.makedirs(save_dir + "/audios", exist_ok=True)
|
| 230 |
os.makedirs(save_dir + "/jsonl", exist_ok=True)
|
| 231 |
-
|
| 232 |
-
for item in new_items:
|
| 233 |
-
lyric = item["gt_lyric"]
|
| 234 |
-
if version == 'v1.0':
|
| 235 |
-
descriptions = item["descriptions"] if "descriptions" in item else None
|
| 236 |
-
else:
|
| 237 |
-
descriptions = item["descriptions"] if "descriptions" in item else '.'
|
| 238 |
-
descriptions = '[Musicality-very-high]' + ', ' + descriptions
|
| 239 |
-
pmt_wav = item['pmt_wav']
|
| 240 |
-
vocal_wav = item['vocal_wav']
|
| 241 |
-
bgm_wav = item['bgm_wav']
|
| 242 |
-
melody_is_wav = item['melody_is_wav']
|
| 243 |
-
target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
generate_inp = {
|
| 247 |
-
'lyrics': [lyric.replace(" ", " ")],
|
| 248 |
-
'descriptions': [descriptions],
|
| 249 |
-
'melody_wavs': pmt_wav,
|
| 250 |
-
'vocal_wavs': vocal_wav,
|
| 251 |
-
'bgm_wavs': bgm_wav,
|
| 252 |
-
'melody_is_wav': melody_is_wav,
|
| 253 |
-
}
|
| 254 |
-
start_time = time.time()
|
| 255 |
-
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 256 |
-
with torch.no_grad():
|
| 257 |
-
tokens = model.generate(**generate_inp, return_tokens=True)
|
| 258 |
-
mid_time = time.time()
|
| 259 |
-
|
| 260 |
-
with torch.no_grad():
|
| 261 |
-
if 'raw_pmt_wav' in item:
|
| 262 |
-
if gen_type == 'separate':
|
| 263 |
-
wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='mixed')
|
| 264 |
-
wav_vocal = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='vocal')
|
| 265 |
-
wav_bgm = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='bgm')
|
| 266 |
-
elif gen_type == 'mixed':
|
| 267 |
-
wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
|
| 268 |
-
else:
|
| 269 |
-
wav_seperate = model.generate_audio(tokens,chunked=True, gen_type=gen_type)
|
| 270 |
-
del item['raw_pmt_wav']
|
| 271 |
-
del item['raw_vocal_wav']
|
| 272 |
-
del item['raw_bgm_wav']
|
| 273 |
-
else:
|
| 274 |
-
if gen_type == 'separate':
|
| 275 |
-
wav_vocal = model.generate_audio(tokens, chunked=True, gen_type='vocal')
|
| 276 |
-
wav_bgm = model.generate_audio(tokens, chunked=True, gen_type='bgm')
|
| 277 |
-
wav_seperate = model.generate_audio(tokens, chunked=True, gen_type='mixed')
|
| 278 |
-
else:
|
| 279 |
-
wav_seperate = model.generate_audio(tokens, chunked=True, gen_type=gen_type)
|
| 280 |
-
del item['pmt_wav']
|
| 281 |
-
del item['vocal_wav']
|
| 282 |
-
del item['bgm_wav']
|
| 283 |
-
del item['melody_is_wav']
|
| 284 |
-
end_time = time.time()
|
| 285 |
-
if gen_type == 'separate':
|
| 286 |
-
torchaudio.save(target_wav_name.replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
|
| 287 |
-
torchaudio.save(target_wav_name.replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
|
| 288 |
-
torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
|
| 289 |
-
else:
|
| 290 |
-
torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
|
| 291 |
-
|
| 292 |
-
print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
|
| 293 |
-
item["idx"] = f"{item['idx']}"
|
| 294 |
-
item["wav_path"] = target_wav_name
|
| 295 |
-
|
| 296 |
-
src_jsonl_name = os.path.split(input_jsonl)[-1]
|
| 297 |
-
with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
|
| 298 |
-
for item in new_items:
|
| 299 |
-
fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
|
| 300 |
-
|
| 301 |
-
def generate_lowmem(args):
|
| 302 |
-
torch.set_num_threads(1)
|
| 303 |
-
ckpt_path = args.ckpt_path
|
| 304 |
-
input_jsonl = args.input_jsonl
|
| 305 |
-
save_dir = args.save_dir
|
| 306 |
-
cfg_path = os.path.join(ckpt_path, 'config.yaml')
|
| 307 |
-
ckpt_path = os.path.join(ckpt_path, 'model.pt')
|
| 308 |
-
cfg = OmegaConf.load(cfg_path)
|
| 309 |
-
cfg.lm.use_flash_attn_2 = args.use_flash_attn
|
| 310 |
-
print(f"use_flash_attn: {args.use_flash_attn}")
|
| 311 |
-
cfg.mode = 'inference'
|
| 312 |
-
max_duration = cfg.max_dur
|
| 313 |
-
gen_type = args.generate_type
|
| 314 |
-
chunk_size = 128
|
| 315 |
-
use_audio_tokenizer = False
|
| 316 |
with open(input_jsonl, "r") as fp:
|
| 317 |
lines = fp.readlines()
|
| 318 |
-
for line in lines:
|
| 319 |
-
item = json.loads(line)
|
| 320 |
-
if "prompt_audio_path" in item:
|
| 321 |
-
use_audio_tokenizer = True
|
| 322 |
-
break
|
| 323 |
-
if use_audio_tokenizer:
|
| 324 |
-
separator = Separator()
|
| 325 |
-
audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
|
| 326 |
-
audio_tokenizer = audio_tokenizer.eval().cuda()
|
| 327 |
-
auto_prompt = torch.load('tools/new_prompt.pt')
|
| 328 |
new_items = []
|
| 329 |
for line in lines:
|
| 330 |
item = json.loads(line)
|
|
|
|
|
|
|
|
|
|
| 331 |
target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
|
| 332 |
-
|
|
|
|
| 333 |
if "prompt_audio_path" in item:
|
| 334 |
assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
|
| 335 |
assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
|
| 336 |
with torch.no_grad():
|
| 337 |
-
pmt_wav
|
| 338 |
item['raw_pmt_wav'] = pmt_wav
|
| 339 |
-
item['raw_vocal_wav'] = vocal_wav
|
| 340 |
-
item['raw_bgm_wav'] = bgm_wav
|
| 341 |
if pmt_wav.dim() == 2:
|
| 342 |
pmt_wav = pmt_wav[None]
|
| 343 |
if pmt_wav.dim() != 3:
|
| 344 |
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
| 345 |
pmt_wav = list(pmt_wav)
|
| 346 |
-
if vocal_wav.dim() == 2:
|
| 347 |
-
vocal_wav = vocal_wav[None]
|
| 348 |
-
if vocal_wav.dim() != 3:
|
| 349 |
-
raise ValueError("Vocal wavs should have a shape [B, C, T].")
|
| 350 |
-
vocal_wav = list(vocal_wav)
|
| 351 |
-
if bgm_wav.dim() == 2:
|
| 352 |
-
bgm_wav = bgm_wav[None]
|
| 353 |
-
if bgm_wav.dim() != 3:
|
| 354 |
-
raise ValueError("BGM wavs should have a shape [B, C, T].")
|
| 355 |
-
bgm_wav = list(bgm_wav)
|
| 356 |
if type(pmt_wav) == list:
|
| 357 |
pmt_wav = torch.stack(pmt_wav, dim=0)
|
| 358 |
-
if type(vocal_wav) == list:
|
| 359 |
-
vocal_wav = torch.stack(vocal_wav, dim=0)
|
| 360 |
-
if type(bgm_wav) == list:
|
| 361 |
-
bgm_wav = torch.stack(bgm_wav, dim=0)
|
| 362 |
with torch.no_grad():
|
| 363 |
pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
|
|
|
|
| 364 |
melody_is_wav = False
|
| 365 |
elif "auto_prompt_audio_type" in item:
|
| 366 |
assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
|
| 367 |
-
|
|
|
|
| 368 |
pmt_wav = prompt_token[:,[0],:]
|
| 369 |
-
vocal_wav = prompt_token[:,[1],:]
|
| 370 |
-
bgm_wav = prompt_token[:,[2],:]
|
| 371 |
melody_is_wav = False
|
| 372 |
else:
|
| 373 |
pmt_wav = None
|
| 374 |
-
vocal_wav = None
|
| 375 |
-
bgm_wav = None
|
| 376 |
melody_is_wav = True
|
| 377 |
-
item['pmt_wav'] = pmt_wav
|
| 378 |
-
item['vocal_wav'] = vocal_wav
|
| 379 |
-
item['bgm_wav'] = bgm_wav
|
| 380 |
-
item['melody_is_wav'] = melody_is_wav
|
| 381 |
item["idx"] = f"{item['idx']}"
|
| 382 |
item["wav_path"] = target_wav_name
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
if use_audio_tokenizer:
|
| 386 |
-
del audio_tokenizer
|
| 387 |
-
del separator
|
| 388 |
-
|
| 389 |
-
torch.cuda.empty_cache()
|
| 390 |
-
|
| 391 |
-
if "audio_tokenizer_checkpoint_sep" in cfg.keys() and use_audio_tokenizer:
|
| 392 |
-
seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
|
| 393 |
-
else:
|
| 394 |
-
seperate_tokenizer = None
|
| 395 |
-
|
| 396 |
-
if seperate_tokenizer is not None:
|
| 397 |
-
seperate_tokenizer = seperate_tokenizer.eval().cuda()
|
| 398 |
-
|
| 399 |
-
for item in new_items:
|
| 400 |
-
if "prompt_audio_path" in item:
|
| 401 |
-
with torch.no_grad():
|
| 402 |
-
vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
|
| 403 |
-
item['vocal_wav'] = vocal_wav
|
| 404 |
-
item['bgm_wav'] = bgm_wav
|
| 405 |
-
|
| 406 |
-
if use_audio_tokenizer:
|
| 407 |
-
del seperate_tokenizer
|
| 408 |
-
|
| 409 |
-
torch.cuda.empty_cache()
|
| 410 |
-
|
| 411 |
-
# Define model or load pretrained model
|
| 412 |
-
audiolm = builders.get_lm_model(cfg)
|
| 413 |
-
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
| 414 |
-
audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
|
| 415 |
-
audiolm.load_state_dict(audiolm_state_dict, strict=False)
|
| 416 |
-
audiolm = audiolm.eval()
|
| 417 |
-
|
| 418 |
-
offload_audiolm = True if 'offload' in cfg.keys() and 'audiolm' in cfg.offload else False
|
| 419 |
-
if offload_audiolm:
|
| 420 |
-
audiolm_offload_param = OffloadParamParse.parse_config(audiolm, cfg.offload.audiolm)
|
| 421 |
-
audiolm_offload_param.show()
|
| 422 |
-
offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
|
| 423 |
-
offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
|
| 424 |
-
offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
|
| 425 |
-
else:
|
| 426 |
-
audiolm = audiolm.cuda().to(torch.float16)
|
| 427 |
-
|
| 428 |
-
model = CodecLM(name = "tmp",
|
| 429 |
-
lm = audiolm,
|
| 430 |
-
audiotokenizer = None,
|
| 431 |
-
max_duration = max_duration,
|
| 432 |
-
seperate_tokenizer = None,
|
| 433 |
-
)
|
| 434 |
-
|
| 435 |
-
cfg_coef = 1.5 #25
|
| 436 |
-
temp = 0.9
|
| 437 |
-
top_k = 50
|
| 438 |
-
top_p = 0.0
|
| 439 |
-
record_tokens = True
|
| 440 |
-
record_window = 50
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
|
| 444 |
-
top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
|
| 445 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 446 |
-
os.makedirs(save_dir + "/audios", exist_ok=True)
|
| 447 |
-
os.makedirs(save_dir + "/jsonl", exist_ok=True)
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
for item in new_items:
|
| 451 |
-
lyric = item["gt_lyric"]
|
| 452 |
-
descriptions = item["descriptions"] if "descriptions" in item else None
|
| 453 |
-
pmt_wav = item['pmt_wav']
|
| 454 |
-
vocal_wav = item['vocal_wav']
|
| 455 |
-
bgm_wav = item['bgm_wav']
|
| 456 |
-
melody_is_wav = item['melody_is_wav']
|
| 457 |
-
|
| 458 |
generate_inp = {
|
| 459 |
-
'
|
| 460 |
-
'
|
| 461 |
'melody_wavs': pmt_wav,
|
| 462 |
-
'vocal_wavs': vocal_wav,
|
| 463 |
-
'bgm_wavs': bgm_wav,
|
| 464 |
'melody_is_wav': melody_is_wav,
|
|
|
|
| 465 |
}
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
offload_wav_tokenizer_diffusion = False
|
| 492 |
-
if offload_wav_tokenizer_diffusion:
|
| 493 |
-
sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
|
| 494 |
-
sep_offload_param.show()
|
| 495 |
-
sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
|
| 496 |
-
sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
|
| 497 |
-
sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
|
| 498 |
-
else:
|
| 499 |
-
seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)
|
| 500 |
-
|
| 501 |
-
model = CodecLM(name = "tmp",
|
| 502 |
-
lm = None,
|
| 503 |
-
audiotokenizer = None,
|
| 504 |
-
max_duration = max_duration,
|
| 505 |
-
seperate_tokenizer = seperate_tokenizer,
|
| 506 |
-
)
|
| 507 |
|
| 508 |
-
for item in new_items:
|
| 509 |
with torch.no_grad():
|
|
|
|
| 510 |
if 'raw_pmt_wav' in item:
|
| 511 |
-
|
| 512 |
-
wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type='mixed')
|
| 513 |
-
wav_vocal = model.generate_audio(item['tokens'],chunked=True, gen_type='vocal')
|
| 514 |
-
wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
|
| 515 |
-
elif gen_type == 'mixed':
|
| 516 |
-
wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
|
| 517 |
-
else:
|
| 518 |
-
wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
|
| 519 |
del item['raw_pmt_wav']
|
| 520 |
-
del item['raw_vocal_wav']
|
| 521 |
-
del item['raw_bgm_wav']
|
| 522 |
else:
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
torchaudio.save(item['wav_path'].replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
|
| 532 |
-
torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
|
| 533 |
-
else:
|
| 534 |
-
torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
|
| 535 |
-
del item['tokens']
|
| 536 |
-
del item['pmt_wav']
|
| 537 |
-
del item['vocal_wav']
|
| 538 |
-
del item['bgm_wav']
|
| 539 |
-
del item['melody_is_wav']
|
| 540 |
-
if offload_wav_tokenizer_diffusion:
|
| 541 |
-
sep_offload_profiler.reset_empty_cache_mem_line()
|
| 542 |
|
| 543 |
-
if offload_wav_tokenizer_diffusion:
|
| 544 |
-
sep_offload_profiler.stop()
|
| 545 |
-
torch.cuda.empty_cache()
|
| 546 |
src_jsonl_name = os.path.split(input_jsonl)[-1]
|
| 547 |
with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
|
| 548 |
for item in new_items:
|
| 549 |
fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
|
| 550 |
|
| 551 |
-
|
| 552 |
if __name__ == "__main__":
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
|
| 556 |
-
OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
|
| 557 |
-
OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
|
| 558 |
-
np.random.seed(int(time.time()))
|
| 559 |
-
# 解析命令行参数
|
| 560 |
-
args = parse_args()
|
| 561 |
-
if torch.cuda.is_available():
|
| 562 |
-
device = torch.cuda.current_device()
|
| 563 |
-
reserved = torch.cuda.memory_reserved(device)
|
| 564 |
-
total = torch.cuda.get_device_properties(device).total_memory
|
| 565 |
-
res_mem = (total - reserved) / 1024 / 1024 / 1024
|
| 566 |
-
print(f"reserved memory: {res_mem}GB")
|
| 567 |
-
|
| 568 |
-
model_name = args.ckpt_path.split("/")[-1].lower().replace('-', '_')
|
| 569 |
-
assert model_name in ['songgeneration_base', 'songgeneration_base_new', 'songgeneration_base_full', 'songgeneration_large', 'songgeneration_new_small', 'songgeneration_new_large', 'songgeneration_new_medium'], f'{model_name} is not supported, currently only songgeneration_base, songgeneration_base_new, songgeneration_base_full, songgeneration_large are supported. Please download correct files and rename the folder to the corresponding version name.'
|
| 570 |
-
if model_name == 'songgeneration_base' or model_name == 'songgeneration_base_new' or model_name == 'songgeneration_base_full':
|
| 571 |
-
if res_mem > 24 and not args.low_mem:
|
| 572 |
-
print("use generate")
|
| 573 |
-
generate(args)
|
| 574 |
-
else:
|
| 575 |
-
from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
|
| 576 |
-
print("use generate_lowmem")
|
| 577 |
-
generate_lowmem(args)
|
| 578 |
-
elif model_name == 'songgeneration_large':
|
| 579 |
-
if res_mem > 36 and not args.low_mem:
|
| 580 |
-
print("use generate")
|
| 581 |
-
generate(args)
|
| 582 |
-
else:
|
| 583 |
-
print("use generate_lowmem")
|
| 584 |
-
from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
|
| 585 |
-
generate_lowmem(args)
|
| 586 |
-
elif model_name == 'songgeneration_new_small' or model_name == 'songgeneration_new_large' or model_name == 'songgeneration_new_medium':
|
| 587 |
-
print("use generate")
|
| 588 |
-
generate(args, version = 'v1.5')
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
else:
|
| 592 |
-
print("CUDA is not available")
|
| 593 |
-
exit()
|
| 594 |
-
|
|
|
|
| 1 |
+
import glob
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import time
|
|
|
|
| 3 |
import torch
|
| 4 |
+
from codeclm.models.codeclm_gen import CodecLM_gen
|
| 5 |
+
from codeclm.models import builders
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
import torchaudio
|
| 9 |
import numpy as np
|
| 10 |
+
import json
|
| 11 |
+
from vllm import LLM, SamplingParams
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
import re
|
| 13 |
+
import argparse
|
| 14 |
+
import librosa
|
| 15 |
+
auto_prompt_type = ['Pop', 'Latin', 'Rock', 'Electronic', 'Metal', 'Country', 'R&B/Soul', 'Ballad', 'Jazz', 'World', 'Hip-Hop', 'Funk', 'Soundtrack','Auto']
|
| 16 |
|
|
|
|
| 17 |
|
| 18 |
def check_language_by_text(text):
|
| 19 |
chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
|
|
|
|
| 29 |
else:
|
| 30 |
return "en"
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
def load_audio(f):
|
| 34 |
+
a, fs= librosa.load(f, sr=48000)
|
| 35 |
+
a = torch.tensor(a).unsqueeze(0)
|
| 36 |
+
if (fs != 48000):
|
| 37 |
+
a = torchaudio.functional.resample(a, fs, 48000)
|
| 38 |
+
if a.shape[-1] >= 48000*10:
|
| 39 |
+
a = a[..., :48000*10]
|
| 40 |
+
return a[:, 0:48000*10]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
def parse_args():
|
| 44 |
parser = argparse.ArgumentParser(description='Song Generation Script')
|
| 45 |
|
| 46 |
# 必需参数
|
|
|
|
|
|
|
| 47 |
parser.add_argument('--input_jsonl', type=str, required=True,
|
| 48 |
help='Path to input JSONL file containing generation tasks')
|
| 49 |
parser.add_argument('--save_dir', type=str, required=True,
|
| 50 |
help='Directory to save generated audio files and results')
|
| 51 |
+
parser.add_argument('--config_path', type=str, required=True,
|
| 52 |
+
help='Path to the config file')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
return parser.parse_args()
|
| 54 |
|
| 55 |
+
|
| 56 |
+
def main():
    """Batch song-generation driver.

    Reads one generation task per line from the input JSONL, builds
    conditioning embeddings (lyrics / text description / optional audio
    prompt), samples audio tokens with a vLLM engine using classifier-free
    guidance, decodes them to waveforms, and writes .flac files under
    `<save_dir>/audios` plus the augmented task records under
    `<save_dir>/jsonl`.
    """
    torch.set_num_threads(1)
    torch.backends.cudnn.enabled = False  # some cluster nodes raise odd cuDNN errors; disable cuDNN outright
    from omegaconf import OmegaConf
    # Custom resolvers referenced by the YAML configuration files.
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
    args = parse_args()
    input_jsonl = args.input_jsonl
    save_dir = args.save_dir
    cfg_path = args.config_path

    cfg = OmegaConf.load(cfg_path)
    cfg.mode = 'inference'
    max_duration = cfg.max_dur

    # Frozen audio tokenizer (waveform <-> discrete-token codec).
    audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
    if audio_tokenizer is not None:
        for param in audio_tokenizer.parameters():
            param.requires_grad = False
        print("Audio tokenizer successfully loaded!")
    audio_tokenizer = audio_tokenizer.eval().cuda()
    # Conditioning wrapper: turns lyrics/descriptions/prompt tokens into
    # the fused LM input embeddings consumed by vLLM below.
    model_condition = CodecLM_gen(cfg=cfg,name = "tmp",audiotokenizer = audio_tokenizer,max_duration = max_duration)
    model_condition.condition_provider.conditioners.load_state_dict(torch.load(cfg.lm_checkpoint+"/conditioners_weights.pth"))
    print('Conditioner successfully loaded!')
    # vLLM engine fed with precomputed prompt embeddings, so no text
    # tokenizer is initialized.
    llm = LLM(
        model=cfg.lm_checkpoint,
        trust_remote_code=True,
        tensor_parallel_size=cfg.vllm.device_num,
        enforce_eager=False,
        dtype="bfloat16",
        gpu_memory_utilization=cfg.vllm.gpu_memory_utilization,
        tokenizer=None,
        skip_tokenizer_init=True,
        enable_prompt_embeds=True,
        enable_chunked_prefill=True,
    )
    print("LLM 初始化成功")
    # Genre -> stored prompt-token bank, used when no audio prompt is given.
    auto_prompt = torch.load('tools/new_prompt.pt')

    guidance_scale = cfg.vllm.guidance_scale
    temp = cfg.vllm.temp
    top_k = cfg.vllm.top_k
    sum_time = 0
    sum_wav_len = 0
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)
    with open(input_jsonl, "r") as fp:
        lines = fp.readlines()
    new_items = []
    for line in lines:
        item = json.loads(line)
        lyric = item["gt_lyric"]
        descriptions = item["descriptions"].lower() if "descriptions" in item else '.'
        descriptions = '[Musicality-very-high]' + ', ' + descriptions
        target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
        if os.path.exists(target_wav_name):
            # Resume support: skip tasks whose output file already exists.
            continue
        if "prompt_audio_path" in item:
            # Audio-prompted generation: encode the reference clip to tokens.
            assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
            assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
            with torch.no_grad():
                pmt_wav = load_audio(item['prompt_audio_path'])
            item['raw_pmt_wav'] = pmt_wav  # kept for the decoding stage below
            if pmt_wav.dim() == 2:
                pmt_wav = pmt_wav[None]
            if pmt_wav.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            # NOTE(review): the list() round-trip plus the always-true check
            # below reconstructs the same [B, C, T] tensor; looks redundant.
            pmt_wav = list(pmt_wav)
            if type(pmt_wav) == list:
                pmt_wav = torch.stack(pmt_wav, dim=0)
            with torch.no_grad():
                pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
            print(pmt_wav.shape)
            melody_is_wav = False
        elif "auto_prompt_audio_type" in item:
            # Genre-based prompting: pick a random stored prompt token set
            # matching the lyric language.
            assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
            lang = check_language_by_text(item['gt_lyric'])
            prompt_token = auto_prompt[item["auto_prompt_audio_type"]][lang][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]][lang]))]
            pmt_wav = prompt_token[:,[0],:]
            melody_is_wav = False
        else:
            # No prompt: unconditional generation.
            pmt_wav = None
            melody_is_wav = True
        item["idx"] = f"{item['idx']}"
        item["wav_path"] = target_wav_name
        # NOTE(review): loaded once per task; could be hoisted out of the loop.
        embeded_eosp1 = torch.load(cfg.lm_checkpoint+'/embeded_eosp1.pt')
        generate_inp = {
            'descriptions': [lyric.replace(" ", " ")],  # NOTE(review): as rendered this replace is a no-op; presumably normalizes a non-breaking/double space -- confirm original bytes
            'type_info': [descriptions],
            'melody_wavs': pmt_wav,
            'melody_is_wav': melody_is_wav,
            'embeded_eosp1': embeded_eosp1,
        }
        fused_input, audio_qt_embs = model_condition.generate_condition(**generate_inp, return_tokens=True)
        # Ban the prompt's own token ids from being sampled again.
        prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
        allowed_token_ids = [x for x in range(cfg.lm.code_size+1) if x not in prompt_token]
        sampling_params = SamplingParams(
            max_tokens=cfg.audio_tokenizer_frame_rate*cfg.max_dur,
            temperature=temp,
            stop_token_ids=[cfg.lm.code_size],  # code_size doubles as the EOS id
            top_k=top_k,
            frequency_penalty=0.2,
            seed=int(time.time() * 1000000) % (2**32) if cfg.vllm.cfg else -1,  # NOTE(review): -1 used as "no seed" sentinel -- verify vLLM accepts it
            allowed_token_ids=allowed_token_ids,
            guidance_scale=guidance_scale
        )
        # Expand into the batch-of-3 CFG layout currently supported:
        # [cond, cond, uncond].
        prompts = [{"prompt_embeds": embed} for embed in fused_input]
        promptss = []
        for _ in range(2):
            promptss+=prompts
        uncondi = prompts[1]
        promptss = promptss[::2] + [uncondi]
        start_time = time.time()
        outputs = llm.generate(promptss, sampling_params=sampling_params)
        mid_time = time.time()
        # outputs[1] carries the CFG-guided sequence; drop the trailing EOS
        # and reshape to [1, 1, T] for the decoder.
        token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
        token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            # wav_nocfg = model_condition.generate_audio(token_ids)
            if 'raw_pmt_wav' in item:
                wav_cfg = model_condition.generate_audio(token_ids_CFG, item['raw_pmt_wav'])
                del item['raw_pmt_wav']  # tensors are not JSON-serializable
            else:
                wav_cfg = model_condition.generate_audio(token_ids_CFG)
        end_time = time.time()
        torchaudio.save(target_wav_name, wav_cfg[0].cpu().float(), cfg.sample_rate)
        sum_time += end_time - start_time
        sum_wav_len += (token_ids_CFG.shape[-1] / 25)  # 25 tokens per second (matches the rtf math below)
        print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}, rtf {(end_time - start_time) / token_ids_CFG.shape[-1] * 25:.2f}")
        new_items.append(item)
        print(f"Total time: {sum_time:.4f} seconds, total wav length: {sum_wav_len:.4f} seconds, rtf {sum_time/sum_wav_len:.2f}")

    # Persist the augmented task records for the items processed this run.
    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
|
| 197 |
|
|
|
|
| 198 |
# Entry-point guard: lets helpers (check_language_by_text, load_audio) be
# imported from this module without triggering generation.
if __name__ == "__main__":
    main()
|
| 200 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generate.sh
CHANGED
|
@@ -3,70 +3,15 @@ export PYTHONDONTWRITEBYTECODE=1
|
|
| 3 |
export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
|
| 4 |
export NCCL_HOME=/usr/local/tccl
|
| 5 |
export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
|
|
|
| 8 |
JSONL=$2
|
| 9 |
SAVE_DIR=$3
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
if [[ $arg == "--not_use_flash_attn" ]]; then
|
| 15 |
-
USE_FLASH_ATTN="False"
|
| 16 |
-
fi
|
| 17 |
-
done
|
| 18 |
-
for arg in "$@"; do
|
| 19 |
-
if [[ $arg == "--low_mem" ]]; then
|
| 20 |
-
LOW_MEM="True"
|
| 21 |
-
fi
|
| 22 |
-
done
|
| 23 |
-
for arg in "$@"; do
|
| 24 |
-
if [[ $arg == "--separate" ]]; then
|
| 25 |
-
GENERATE_TYPE="separate"
|
| 26 |
-
fi
|
| 27 |
-
done
|
| 28 |
-
for arg in "$@"; do
|
| 29 |
-
if [[ $arg == "--bgm" ]]; then
|
| 30 |
-
GENERATE_TYPE="bgm"
|
| 31 |
-
fi
|
| 32 |
-
done
|
| 33 |
-
for arg in "$@"; do
|
| 34 |
-
if [[ $arg == "--vocal" ]]; then
|
| 35 |
-
GENERATE_TYPE="vocal"
|
| 36 |
-
fi
|
| 37 |
-
done
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
if [ "$USE_FLASH_ATTN" == "True" ] && [ "$LOW_MEM" == "True" ]; then
|
| 41 |
-
echo "Use Flash Attention + Low Memory Mode"
|
| 42 |
-
python3 generate.py \
|
| 43 |
-
--ckpt_path $CKPT_PATH \
|
| 44 |
-
--input_jsonl $JSONL \
|
| 45 |
-
--save_dir $SAVE_DIR \
|
| 46 |
-
--generate_type $GENERATE_TYPE \
|
| 47 |
-
--use_flash_attn \
|
| 48 |
-
--low_mem
|
| 49 |
-
elif [ "$USE_FLASH_ATTN" == "True" ] && [ "$LOW_MEM" == "False" ]; then
|
| 50 |
-
echo "Use Flash Attention + Auto Memory Mode"
|
| 51 |
-
python3 generate.py \
|
| 52 |
-
--ckpt_path $CKPT_PATH \
|
| 53 |
-
--input_jsonl $JSONL \
|
| 54 |
-
--save_dir $SAVE_DIR \
|
| 55 |
-
--generate_type $GENERATE_TYPE \
|
| 56 |
-
--use_flash_attn
|
| 57 |
-
elif [ "$USE_FLASH_ATTN" == "False" ] && [ "$LOW_MEM" == "False" ]; then
|
| 58 |
-
echo "Not Use Flash Attention + Auto Memory Mode"
|
| 59 |
-
python3 generate.py \
|
| 60 |
-
--ckpt_path $CKPT_PATH \
|
| 61 |
-
--input_jsonl $JSONL \
|
| 62 |
-
--generate_type $GENERATE_TYPE \
|
| 63 |
-
--save_dir $SAVE_DIR
|
| 64 |
-
elif [ "$USE_FLASH_ATTN" == "False" ] && [ "$LOW_MEM" == "True" ]; then
|
| 65 |
-
echo "Not Use Flash Attention + Low Memory Mode"
|
| 66 |
-
python3 generate.py \
|
| 67 |
-
--ckpt_path $CKPT_PATH \
|
| 68 |
-
--input_jsonl $JSONL \
|
| 69 |
-
--save_dir $SAVE_DIR \
|
| 70 |
-
--generate_type $GENERATE_TYPE \
|
| 71 |
-
--low_mem
|
| 72 |
-
fi
|
|
|
|
| 3 |
export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
export NCCL_HOME=/usr/local/tccl
export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
# Pin BLAS/OpenMP threading so the generation process does not oversubscribe CPUs.
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export CUDA_LAUNCH_BLOCKING=0


# Usage: generate.sh <config.yaml> <tasks.jsonl> <output_dir>
CONFIG_PATH=$1
JSONL=$2
SAVE_DIR=$3
# Quote the expansions so paths containing spaces survive word splitting.
python3 generate.py \
    --input_jsonl "$JSONL" \
    --save_dir "$SAVE_DIR" \
    --config_path "$CONFIG_PATH"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
levo_inference.py
CHANGED
|
@@ -1,22 +1,19 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
-
|
| 4 |
|
| 5 |
sys.path.append('./codeclm/tokenizer')
|
| 6 |
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
|
| 7 |
sys.path.append('.')
|
| 8 |
|
| 9 |
import torch
|
| 10 |
-
|
| 11 |
-
import json
|
| 12 |
import numpy as np
|
| 13 |
from omegaconf import OmegaConf
|
|
|
|
| 14 |
|
| 15 |
from codeclm.models import builders
|
| 16 |
-
from codeclm.models import
|
| 17 |
-
|
| 18 |
-
from separator import Separator
|
| 19 |
-
from generate import check_language_by_text
|
| 20 |
|
| 21 |
|
| 22 |
class LeVoInference(torch.nn.Module):
|
|
@@ -30,39 +27,37 @@ class LeVoInference(torch.nn.Module):
|
|
| 30 |
OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
|
| 31 |
|
| 32 |
cfg_path = os.path.join(ckpt_path, 'config.yaml')
|
| 33 |
-
pt_path = os.path.join(ckpt_path, 'model.pt')
|
| 34 |
-
|
| 35 |
self.cfg = OmegaConf.load(cfg_path)
|
| 36 |
self.cfg.mode = 'inference'
|
| 37 |
self.max_duration = self.cfg.max_dur
|
| 38 |
|
| 39 |
-
# Define model or load pretrained model
|
| 40 |
-
audiolm = builders.get_lm_model(self.cfg, version='v1.5')
|
| 41 |
-
checkpoint = torch.load(pt_path, map_location='cpu')
|
| 42 |
-
audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
|
| 43 |
-
audiolm.load_state_dict(audiolm_state_dict, strict=False)
|
| 44 |
-
audiolm = audiolm.eval()
|
| 45 |
-
audiolm = audiolm.cuda().to(torch.float16)
|
| 46 |
-
|
| 47 |
audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
|
| 48 |
-
audio_tokenizer
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
self.
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
)
|
| 59 |
-
self.separator = Separator()
|
| 60 |
-
|
| 61 |
|
| 62 |
self.default_params = dict(
|
| 63 |
-
cfg_coef = 1.
|
| 64 |
-
temperature =
|
| 65 |
-
top_k =
|
| 66 |
top_p = 0.0,
|
| 67 |
record_tokens = True,
|
| 68 |
record_window = 50,
|
|
@@ -70,14 +65,11 @@ class LeVoInference(torch.nn.Module):
|
|
| 70 |
duration = self.max_duration,
|
| 71 |
)
|
| 72 |
|
| 73 |
-
self.model.set_generation_params(**self.default_params)
|
| 74 |
-
|
| 75 |
def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
|
| 76 |
params = {**self.default_params, **params}
|
| 77 |
-
self.model.set_generation_params(**params)
|
| 78 |
|
| 79 |
if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
|
| 80 |
-
pmt_wav
|
| 81 |
melody_is_wav = True
|
| 82 |
elif genre is not None and auto_prompt_path is not None:
|
| 83 |
auto_prompt = torch.load(auto_prompt_path)
|
|
@@ -87,33 +79,48 @@ class LeVoInference(torch.nn.Module):
|
|
| 87 |
else:
|
| 88 |
prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
|
| 89 |
pmt_wav = prompt_token[:,[0],:]
|
| 90 |
-
vocal_wav = prompt_token[:,[1],:]
|
| 91 |
-
bgm_wav = prompt_token[:,[2],:]
|
| 92 |
melody_is_wav = False
|
| 93 |
else:
|
| 94 |
pmt_wav = None
|
| 95 |
-
vocal_wav = None
|
| 96 |
-
bgm_wav = None
|
| 97 |
melody_is_wav = True
|
| 98 |
|
| 99 |
description = description if description else '.'
|
| 100 |
description = '[Musicality-very-high]' + ', ' + description
|
| 101 |
generate_inp = {
|
| 102 |
-
'
|
| 103 |
-
'
|
| 104 |
'melody_wavs': pmt_wav,
|
| 105 |
-
'vocal_wavs': vocal_wav,
|
| 106 |
-
'bgm_wavs': bgm_wav,
|
| 107 |
'melody_is_wav': melody_is_wav,
|
|
|
|
| 108 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
-
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 111 |
-
tokens = self.model.generate(**generate_inp, return_tokens=True)
|
| 112 |
-
|
| 113 |
with torch.no_grad():
|
| 114 |
if melody_is_wav:
|
| 115 |
-
|
| 116 |
else:
|
| 117 |
-
|
| 118 |
|
| 119 |
-
return
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
+
import time
|
| 4 |
|
| 5 |
sys.path.append('./codeclm/tokenizer')
|
| 6 |
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
|
| 7 |
sys.path.append('.')
|
| 8 |
|
| 9 |
import torch
|
|
|
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
from omegaconf import OmegaConf
|
| 12 |
+
from vllm import LLM, SamplingParams
|
| 13 |
|
| 14 |
from codeclm.models import builders
|
| 15 |
+
from codeclm.models.codeclm_gen import CodecLM_gen
|
| 16 |
+
from generate import check_language_by_text, load_audio
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class LeVoInference(torch.nn.Module):
|
|
|
|
| 27 |
OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
|
| 28 |
|
| 29 |
cfg_path = os.path.join(ckpt_path, 'config.yaml')
|
|
|
|
|
|
|
| 30 |
self.cfg = OmegaConf.load(cfg_path)
|
| 31 |
self.cfg.mode = 'inference'
|
| 32 |
self.max_duration = self.cfg.max_dur
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
|
| 35 |
+
if audio_tokenizer is not None:
|
| 36 |
+
for param in audio_tokenizer.parameters():
|
| 37 |
+
param.requires_grad = False
|
| 38 |
+
print("Audio tokenizer successfully loaded!")
|
| 39 |
+
audio_tokenizer = audio_tokenizer.eval().cuda()
|
| 40 |
+
self.model_condition = CodecLM_gen(cfg=self.cfg,name = "tmp",audiotokenizer = audio_tokenizer,max_duration = self.max_duration)
|
| 41 |
+
self.model_condition.condition_provider.conditioners.load_state_dict(torch.load(self.cfg.lm_checkpoint+"/conditioners_weights.pth"))
|
| 42 |
+
self.embeded_eosp1 = torch.load(self.cfg.lm_checkpoint+'/embeded_eosp1.pt')
|
| 43 |
+
print('Conditioner successfully loaded!')
|
| 44 |
+
self.llm = LLM(
|
| 45 |
+
model=self.cfg.lm_checkpoint,
|
| 46 |
+
trust_remote_code=True,
|
| 47 |
+
tensor_parallel_size=self.cfg.vllm.device_num,
|
| 48 |
+
enforce_eager=False,
|
| 49 |
+
dtype="bfloat16",
|
| 50 |
+
gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
|
| 51 |
+
tokenizer=None,
|
| 52 |
+
skip_tokenizer_init=True,
|
| 53 |
+
enable_prompt_embeds=True,
|
| 54 |
+
enable_chunked_prefill=True,
|
| 55 |
)
|
|
|
|
|
|
|
| 56 |
|
| 57 |
self.default_params = dict(
|
| 58 |
+
cfg_coef = 1.8,
|
| 59 |
+
temperature = 0.8,
|
| 60 |
+
top_k = 5000,
|
| 61 |
top_p = 0.0,
|
| 62 |
record_tokens = True,
|
| 63 |
record_window = 50,
|
|
|
|
| 65 |
duration = self.max_duration,
|
| 66 |
)
|
| 67 |
|
|
|
|
|
|
|
| 68 |
def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
|
| 69 |
params = {**self.default_params, **params}
|
|
|
|
| 70 |
|
| 71 |
if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
|
| 72 |
+
pmt_wav = load_audio(prompt_audio_path)
|
| 73 |
melody_is_wav = True
|
| 74 |
elif genre is not None and auto_prompt_path is not None:
|
| 75 |
auto_prompt = torch.load(auto_prompt_path)
|
|
|
|
| 79 |
else:
|
| 80 |
prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
|
| 81 |
pmt_wav = prompt_token[:,[0],:]
|
|
|
|
|
|
|
| 82 |
melody_is_wav = False
|
| 83 |
else:
|
| 84 |
pmt_wav = None
|
|
|
|
|
|
|
| 85 |
melody_is_wav = True
|
| 86 |
|
| 87 |
description = description if description else '.'
|
| 88 |
description = '[Musicality-very-high]' + ', ' + description
|
| 89 |
generate_inp = {
|
| 90 |
+
'descriptions': [lyric.replace(" ", " ")],
|
| 91 |
+
'type_info': [description],
|
| 92 |
'melody_wavs': pmt_wav,
|
|
|
|
|
|
|
| 93 |
'melody_is_wav': melody_is_wav,
|
| 94 |
+
'embeded_eosp1': self.embeded_eosp1,
|
| 95 |
}
|
| 96 |
+
fused_input, audio_qt_embs = self.model_condition.generate_condition(**generate_inp, return_tokens=True)
|
| 97 |
+
prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
|
| 98 |
+
allowed_token_ids = [x for x in range(self.cfg.lm.code_size+1) if x not in prompt_token]
|
| 99 |
+
sampling_params = SamplingParams(
|
| 100 |
+
max_tokens=self.cfg.audio_tokenizer_frame_rate*self.max_duration,
|
| 101 |
+
temperature=params["temperature"],
|
| 102 |
+
stop_token_ids=[self.cfg.lm.code_size],
|
| 103 |
+
top_k=params["top_k"],
|
| 104 |
+
frequency_penalty=0.2,
|
| 105 |
+
seed=int(time.time() * 1000000) % (2**32) if self.cfg.vllm.cfg else -1,
|
| 106 |
+
allowed_token_ids=allowed_token_ids,
|
| 107 |
+
guidance_scale=params["cfg_coef"]
|
| 108 |
+
)
|
| 109 |
+
# 拆成现支持的batch 3 CFG形式
|
| 110 |
+
prompts = [{"prompt_embeds": embed} for embed in fused_input]
|
| 111 |
+
promptss = []
|
| 112 |
+
for _ in range(2):
|
| 113 |
+
promptss+=prompts
|
| 114 |
+
uncondi = prompts[1]
|
| 115 |
+
promptss = promptss[::2] + [uncondi]
|
| 116 |
+
outputs = self.llm.generate(promptss, sampling_params=sampling_params)
|
| 117 |
+
token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
|
| 118 |
+
token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)
|
| 119 |
|
|
|
|
|
|
|
|
|
|
| 120 |
with torch.no_grad():
|
| 121 |
if melody_is_wav:
|
| 122 |
+
wav_cfg = self.model_condition.generate_audio(token_ids_CFG, pmt_wav)
|
| 123 |
else:
|
| 124 |
+
wav_cfg = self.model_condition.generate_audio(token_ids_CFG)
|
| 125 |
|
| 126 |
+
return wav_cfg[0]
|
requirements.txt
CHANGED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
gradio>=6.5.1
|
sample/lyrics.jsonl
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
{"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"}
|
| 2 |
{"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
|
| 3 |
-
{"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad,
|
| 4 |
-
{"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "
|
|
|
|
|
|
|
| 1 |
{"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
|
| 2 |
+
{"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, guitar and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
|
| 3 |
+
{"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "sample/sample_prompt_audio.wav"}
|
vllm_hacked/model_executor/layers/utils.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
"""Utility methods for model layers."""
|
| 4 |
+
from typing import Callable, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from vllm import _custom_ops as ops
|
| 9 |
+
from vllm import envs
|
| 10 |
+
from vllm.platforms import CpuArchEnum, current_platform
|
| 11 |
+
from vllm.utils import direct_register_custom_op
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of the last dimension element-wise.

    Pairs element i of the first half with element i of the second half,
    e.g. [[1, 2, 3, 4, 5, 6]] -> [[1, 4, 2, 5, 3, 6]], so gate/up columns
    end up adjacent for the triton swiglu kernel.
    """
    orig_shape = w.shape
    half = orig_shape[-1] // 2
    lo = w[..., :half]
    hi = w[..., half:]
    # Stacking on a trailing axis then flattening yields the interleaving.
    interleaved = torch.stack((lo, hi), dim=-1)
    return interleaved.reshape(orig_shape)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-sequence occurrence counts of each token id, plus a presence mask.

    Returns (counts, mask), both of shape [num_seqs, vocab_size]; mask is
    True where a token id occurs at least once in that sequence.
    """
    # One extra column absorbs the `vocab_size` padding id used for prompts.
    counts = torch.zeros((num_seqs, vocab_size + 1),
                         dtype=torch.long,
                         device=tokens.device)
    counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    counts = counts.narrow(1, 0, vocab_size)  # drop the padding column
    return counts, counts.gt(0)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor.

    logits : The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )

    Returns the same `logits` tensor (modified in place).
    """
    num_seqs, vocab_size = logits.shape
    # Prompt tokens only contribute a presence mask (repetition penalty);
    # output tokens additionally contribute counts (frequency penalty).
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Apply repetition penalties as a custom op
    from vllm._custom_ops import apply_repetition_penalties
    apply_repetition_penalties(logits, prompt_mask, output_mask,
                               repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    # logits /= (1+frequency_penalties).unsqueeze(dim=1) ** output_bin_counts  # alternative multiplicative frequency penalty; not adopted -- with mixed-sign logits, division can encourage rather than penalize
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def default_unquantized_gemm(layer: torch.nn.Module,
                             x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None):
    """Fallback full-precision GEMM: y = x @ weight.T (+ bias).

    `layer` is accepted only for dispatch-interface parity and is unused.
    """
    result = torch.nn.functional.linear(x, weight, bias)
    return result
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def rocm_unquantized_gemm_impl(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """ROCm GEMM with optional "skinny" custom kernels for small batches.

    Falls back to torch.nn.functional.linear unless the skinny-GEMM path is
    enabled (env flag), the GPU is gfx9, dtype is fp16/bf16, and K % 8 == 0.
    """
    from vllm.platforms.rocm import on_gfx9
    k = weight.shape[1]
    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
                  x.dtype in [torch.float16, torch.bfloat16] \
                  and k % 8 == 0)

    if use_skinny is not True:
        return torch.nn.functional.linear(x, weight, bias)

    # Flatten all leading dims so n is the effective token count.
    x_view = x.view(-1, x.size(-1))
    n = x_view.shape[0]
    m = weight.shape[0]
    cu_count = current_platform.get_cu_count()

    # Very small n (<= 4 rows): split-K kernel; n == 1 with aligned m and
    # small k (and no bias): LLMM1 kernel. Otherwise plain linear.
    if m > 8 and 0 < n <= 4:
        out = ops.wvSplitK(weight, x_view, cu_count, bias)
        return out.view(*x.shape[:-1], weight.shape[0])
    elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        out = ops.LLMM1(weight, x_view, 4)
        return out.view(*x.shape[:-1], weight.shape[0])
    return torch.nn.functional.linear(x, weight, bias)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def rocm_unquantized_gemm_impl_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Shape-only fake for the ROCm GEMM custom op (used during tracing).

    Allocates an uninitialized output with x's dtype/device and the GEMM
    result shape (*x.shape[:-1], out_features).
    """
    out_shape = (*x.shape[:-1], weight.shape[0])
    return x.new_empty(out_shape)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """ROCm GEMM entry point; `layer` is accepted for interface parity but
    unused. Dispatches through the registered custom op so compilation /
    tracing sees an opaque node with a fake implementation.
    """
    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Register the ROCm GEMM as a torch custom op, pairing the real kernel with
# a shape-only fake so torch.compile / tracing can reason about it without
# executing device code.
direct_register_custom_op(
    op_name="rocm_unquantized_gemm_impl",
    op_func=rocm_unquantized_gemm_impl,
    fake_impl=rocm_unquantized_gemm_impl_fake,
)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
    """Return whether the CPU SGL (AMX tile) kernel supports this GEMM.

    Requirements: bf16 or int8 dtype, ``k`` divisible by 32, ``n`` divisible
    by 16, and AMX tile support on the host CPU.
    """
    # Evaluate the cheap shape/dtype predicates first and only probe the
    # CPU for AMX capability when they all pass.
    return (dtype in (torch.bfloat16, torch.int8) and k % 32 == 0
            and n % 16 == 0 and torch._C._cpu._is_amx_tile_supported())
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
    """Bind the best available CPU linear kernel to ``layer.cpu_linear``.

    Preference order: the SGL/AMX packed kernel, then oneDNN on x86,
    then plain ``torch.nn.functional.linear``. When ``remove_weight`` is
    True the original weight Parameter is replaced with an empty tensor
    after the kernel has captured (or repacked) it, freeing memory.
    """
    N, K = layer.weight.size()
    dtype = layer.weight.dtype
    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
        # Repack the weight once; the lambda closes over the packed copy.
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        if getattr(layer, "bias", None) is not None:
            # The packed kernel expects an fp32 bias.
            bias_f32 = layer.bias.to(torch.float32)
        else:
            bias_f32 = None
        layer.cpu_linear = (
            lambda x, weight, bias: torch.ops._C.weight_packed_linear(
                x, packed_weight, bias_f32
                if bias is not None else None, True))
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
    elif (ops._supports_onednn
          and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
        # Keep a reference before optionally dropping the Parameter: the
        # oneDNN handler is built from the original (transposed) weight.
        origin_weight = layer.weight
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
        handler = ops.create_onednn_mm(origin_weight.t(), 32)
        # The weight arg is ignored at call time; it lives in the handler.
        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(
            handler, x, bias)
    else:
        # Fallback: eager dense linear (weight is NOT removed here since
        # the lambda reads it at every call).
        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
            x, weight, bias)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def cpu_unquantized_gemm(layer: torch.nn.Module,
                         x: torch.Tensor,
                         weight: torch.Tensor,
                         bias: Optional[torch.Tensor] = None):
    """Invoke the CPU linear kernel previously bound to ``layer.cpu_linear``
    (see ``dispatch_cpu_unquantized_gemm``)."""
    linear_fn = layer.cpu_linear
    return linear_fn(x, weight, bias)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    """Select the unquantized GEMM implementation for the current platform.

    ROCm and CPU get specialized paths; every other platform (e.g. CUDA)
    uses the plain ``torch.nn.functional.linear`` wrapper.
    """
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    elif current_platform.is_cpu():
        return cpu_unquantized_gemm
    else:
        return default_unquantized_gemm
|
vllm_hacked/model_executor/layers/utils_ori.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
"""Utility methods for model layers."""
|
| 4 |
+
from typing import Callable, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from vllm import _custom_ops as ops
|
| 9 |
+
from vllm import envs
|
| 10 |
+
from vllm.platforms import CpuArchEnum, current_platform
|
| 11 |
+
from vllm.utils import direct_register_custom_op
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of the last dimension.

    The first and second halves are zipped element-wise so that values
    ``i`` and ``i + N//2`` become adjacent, e.g.::

        [[1, 2, 3, 4, 5, 6]] -> [[1, 4, 2, 5, 3, 6]]

    Used together with the triton swiglu kernel, which expects
    gate/up weights folded next to each other.
    """
    half = w.shape[-1] // 2
    lo, hi = w[..., :half], w[..., half:]
    return torch.stack((lo, hi), dim=-1).reshape(w.shape)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Histogram token ids per sequence and report which ids occurred.

    ``tokens`` is ``(num_seqs, L)`` and may be padded with the id
    ``vocab_size``; one extra histogram column absorbs that padding and
    is sliced off before returning.

    Returns:
        ``(counts, mask)`` — both ``(num_seqs, vocab_size)``; ``mask`` is
        True where a token id appeared at least once.
    """
    counts = torch.zeros((num_seqs, vocab_size + 1),
                         dtype=torch.long,
                         device=tokens.device)
    counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    counts = counts[:, :vocab_size]
    return counts, counts > 0
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor
    logits : The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )

    Returns the (mutated) logits tensor for convenience.
    """
    num_seqs, vocab_size = logits.shape
    # Which vocab ids appear in the prompt / in the generated output so far.
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Apply repetition penalties as a custom op
    from vllm._custom_ops import apply_repetition_penalties
    apply_repetition_penalties(logits, prompt_mask, output_mask,
                               repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    # Frequency penalty scales with occurrence count; presence is binary.
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def default_unquantized_gemm(layer: torch.nn.Module,
                             x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None):
    """Fallback dense GEMM: ``x @ weight.T + bias``.

    ``layer`` is unused; it is accepted only so every dispatch target
    shares one call signature.
    """
    linear = torch.nn.functional.linear
    return linear(x, weight, bias)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def rocm_unquantized_gemm_impl(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """ROCm GEMM for unquantized weights (pristine ``_ori`` copy).

    Uses the ROCm "skinny" kernels (``wvSplitK`` / ``LLMM1``) for
    qualifying shapes, otherwise falls back to
    ``torch.nn.functional.linear``. Input ``x`` is ``(..., K)``,
    ``weight`` is ``(M, K)``; the result is ``(..., M)``.
    """
    # Lazy import keeps ROCm probing off non-ROCm platforms.
    from vllm.platforms.rocm import on_gfx9
    k = weight.shape[1]
    # Skinny path needs the env toggle, gfx9, fp16/bf16 and K % 8 == 0.
    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
                  x.dtype in [torch.float16, torch.bfloat16] \
                  and k % 8 == 0)

    if use_skinny is not True:
        return torch.nn.functional.linear(x, weight, bias)

    # Flatten batch dims so the kernels see a 2-D problem.
    x_view = x.view(-1, x.size(-1))
    n = x_view.shape[0]
    m = weight.shape[0]
    cu_count = current_platform.get_cu_count()

    if m > 8 and 0 < n <= 4:
        # Tiny batch: split the K reduction across compute units.
        out = ops.wvSplitK(weight, x_view, cu_count, bias)
        return out.view(*x.shape[:-1], weight.shape[0])
    elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        # Single-token, bias-free path with modest K.
        out = ops.LLMM1(weight, x_view, 4)
        return out.view(*x.shape[:-1], weight.shape[0])
    return torch.nn.functional.linear(x, weight, bias)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def rocm_unquantized_gemm_impl_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Shape-only "fake" implementation for meta/compile tracing: returns
    an empty tensor with the GEMM output shape ``(..., M)``."""
    return x.new_empty((*x.shape[:-1], weight.shape[0]))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Route through the registered vLLM custom op so the call stays
    graph-capturable; ``layer`` is unused (signature parity only)."""
    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# Register the ROCm GEMM as a vLLM custom op; the fake impl provides
# shape inference for meta tensors.
direct_register_custom_op(
    op_name="rocm_unquantized_gemm_impl",
    op_func=rocm_unquantized_gemm_impl,
    fake_impl=rocm_unquantized_gemm_impl_fake,
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
    """Return True when the CPU SGL (AMX tile) kernel supports an (n, k)
    GEMM with this dtype: AMX-capable CPU, bf16/int8 dtype,
    k % 32 == 0 and n % 16 == 0."""
    return (torch._C._cpu._is_amx_tile_supported()
            and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0
            and n % 16 == 0)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
    """Bind the best available CPU linear kernel to ``layer.cpu_linear``.

    Preference order: SGL/AMX packed kernel, then oneDNN on x86, then
    plain ``torch.nn.functional.linear``. With ``remove_weight`` the
    original weight Parameter is replaced by an empty tensor once the
    chosen kernel has captured or repacked it.
    """
    N, K = layer.weight.size()
    dtype = layer.weight.dtype
    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
        # Repack once; the lambda closes over the packed copy.
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        if getattr(layer, "bias", None) is not None:
            # The packed kernel expects an fp32 bias.
            bias_f32 = layer.bias.to(torch.float32)
        else:
            bias_f32 = None
        layer.cpu_linear = (
            lambda x, weight, bias: torch.ops._C.weight_packed_linear(
                x, packed_weight, bias_f32
                if bias is not None else None, True))
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
    elif (ops._supports_onednn
          and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
        # Keep a reference before dropping the Parameter: the oneDNN
        # handler is built from the original (transposed) weight.
        origin_weight = layer.weight
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
        handler = ops.create_onednn_mm(origin_weight.t(), 32)
        # weight arg is ignored at call time; it lives in the handler.
        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(
            handler, x, bias)
    else:
        # Fallback eager linear; weight must stay since the lambda reads it.
        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
            x, weight, bias)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def cpu_unquantized_gemm(layer: torch.nn.Module,
                         x: torch.Tensor,
                         weight: torch.Tensor,
                         bias: Optional[torch.Tensor] = None):
    """Run the CPU linear kernel previously bound to ``layer.cpu_linear``
    by ``dispatch_cpu_unquantized_gemm``."""
    return layer.cpu_linear(x, weight, bias)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    """Select the unquantized GEMM implementation for the current
    platform: ROCm and CPU get specialized paths, everything else uses
    the plain F.linear wrapper."""
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    elif current_platform.is_cpu():
        return cpu_unquantized_gemm
    else:
        return default_unquantized_gemm
|
vllm_hacked/model_executor/models/llama.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
|
| 4 |
+
# Adapted from
|
| 5 |
+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
| 6 |
+
# Copyright 2023 The vLLM team.
|
| 7 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 10 |
+
# and OPT implementations in this library. It has been modified from its
|
| 11 |
+
# original forms to accommodate minor architectural differences compared
|
| 12 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 13 |
+
#
|
| 14 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 15 |
+
# you may not use this file except in compliance with the License.
|
| 16 |
+
# You may obtain a copy of the License at
|
| 17 |
+
#
|
| 18 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 19 |
+
#
|
| 20 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 21 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 22 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 23 |
+
# See the License for the specific language governing permissions and
|
| 24 |
+
# limitations under the License.
|
| 25 |
+
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
| 26 |
+
from collections.abc import Iterable
|
| 27 |
+
from itertools import islice
|
| 28 |
+
from typing import Any, Optional, Union
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from torch import nn
|
| 32 |
+
from transformers import LlamaConfig
|
| 33 |
+
|
| 34 |
+
from vllm.attention import Attention, AttentionType
|
| 35 |
+
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
| 36 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 37 |
+
from vllm.config import CacheConfig, VllmConfig
|
| 38 |
+
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
| 39 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 40 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 41 |
+
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
| 42 |
+
QKVParallelLinear,
|
| 43 |
+
RowParallelLinear)
|
| 44 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 45 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 46 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
| 47 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 48 |
+
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
| 49 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 50 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 51 |
+
from vllm.sequence import IntermediateTensors
|
| 52 |
+
|
| 53 |
+
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
| 54 |
+
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
| 55 |
+
is_pp_missing_parameter,
|
| 56 |
+
make_empty_intermediate_tensors_factory, make_layers,
|
| 57 |
+
maybe_prefix)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class LlamaMLP(nn.Module):
    """Llama feed-forward block.

    A fused gate/up column-parallel projection, SiLU-and-mul activation,
    and a row-parallel down projection back to the hidden size. Only the
    "silu" activation is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        # Gate and up projections are merged into a single matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated)
        return out
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class LlamaAttention(nn.Module):
    """Llama multi-head attention with tensor-parallel QKV/output
    projections, rotary embeddings and optional per-layer sliding window."""

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config,
                              rope_scaling=rope_scaling,
                              quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, 'target_layer_count'):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            assert effective_layer_idx < len(layer_types), \
                f"effective_layer_idx: {effective_layer_idx} \
                is out of bounds for layer_types: {layer_types}"

            is_sliding = layer_types[
                effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        # Encoder-only models get the non-causal attention variant.
        attn_cls = (EncoderOnlyAttention
                    if attn_type == AttentionType.ENCODER_ONLY else Attention)

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Fused QKV projection, then split into per-role tensors.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Rotary position embedding is applied to Q and K only.
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        # GGUF llama checkpoints store weights in non-neox rotation order.
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class LlamaDecoderLayer(nn.Module):
    """One Llama transformer block: RMSNorm -> self-attention ->
    RMSNorm -> MLP, with the residual stream threaded through the fused
    add-and-norm calls."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 config: Optional[LlamaConfig] = None) -> None:
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            # First layer: seed the residual stream with the raw input.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fused add-and-norm: returns (normed, updated residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@support_torch_compile
|
| 333 |
+
class LlamaModel(nn.Module):
|
| 334 |
+
|
| 335 |
+
    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the Llama trunk: token embedding, decoder layers, final
        norm, plus custom per-codebook embedding tables (``self.emb``)."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra vocab rows reserved for LoRA adapters, if configured.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # Embedding lives on the first PP rank (or last, when tied to lm_head).
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Custom embedding layers: one table per codebook level
        # (config.code_depth of them); vocab_size + 1 presumably leaves room
        # for a padding/sentinel id — TODO confirm against the tokenizer.
        self.emb = nn.ModuleList([nn.Embedding(config.vocab_size+1, config.hidden_size)
                    for _ in range(self.config.code_depth)])
|
| 381 |
+
|
| 382 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 383 |
+
# print('===== get_input_embeddings is called =====')
|
| 384 |
+
# print ('input_ids:', input_ids)
|
| 385 |
+
# print(self.embed_tokens(input_ids).shape)
|
| 386 |
+
# print(sum([self.emb[k](input_ids) for k in range(self.config.code_depth)]).shape)
|
| 387 |
+
# import pdb; pdb.set_trace()
|
| 388 |
+
# return self.embed_tokens(input_ids)
|
| 389 |
+
return sum([self.emb[k](input_ids) for k in range(self.config.code_depth)])
|
| 390 |
+
|
| 391 |
+
def forward(
|
| 392 |
+
self,
|
| 393 |
+
input_ids: Optional[torch.Tensor],
|
| 394 |
+
positions: torch.Tensor,
|
| 395 |
+
intermediate_tensors: Optional[IntermediateTensors],
|
| 396 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 397 |
+
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
|
| 398 |
+
list[torch.Tensor]]]:
|
| 399 |
+
if get_pp_group().is_first_rank:
|
| 400 |
+
# import pdb; pdb.set_trace()
|
| 401 |
+
# print('input_ids', input_ids.shape, 'input_embedes_shape', inputs_embeds.shape)
|
| 402 |
+
if inputs_embeds is not None:
|
| 403 |
+
hidden_states = inputs_embeds
|
| 404 |
+
# print('use_input_embedes')
|
| 405 |
+
# print('input_ids exist:', input_ids is not None)
|
| 406 |
+
# import random
|
| 407 |
+
# count = random.random()
|
| 408 |
+
# if count>0.9:
|
| 409 |
+
# import pdb; pdb.set_trace()
|
| 410 |
+
else:
|
| 411 |
+
# hidden_states = self.get_input_embeddings(input_ids)
|
| 412 |
+
hidden_states = sum([self.emb[k](input_ids) for k in range(self.config.code_depth)]) # 修改为自己的embedding
|
| 413 |
+
print('use_input_ids:', input_ids)
|
| 414 |
+
residual = None
|
| 415 |
+
else:
|
| 416 |
+
assert intermediate_tensors is not None
|
| 417 |
+
hidden_states = intermediate_tensors["hidden_states"]
|
| 418 |
+
residual = intermediate_tensors["residual"]
|
| 419 |
+
|
| 420 |
+
aux_hidden_states = []
|
| 421 |
+
for idx, layer in enumerate(
|
| 422 |
+
islice(self.layers, self.start_layer, self.end_layer)):
|
| 423 |
+
if idx in self.aux_hidden_state_layers:
|
| 424 |
+
aux_hidden_states.append(hidden_states + residual)
|
| 425 |
+
hidden_states, residual = layer(positions, hidden_states, residual)
|
| 426 |
+
|
| 427 |
+
if not get_pp_group().is_last_rank:
|
| 428 |
+
return IntermediateTensors({
|
| 429 |
+
"hidden_states": hidden_states,
|
| 430 |
+
"residual": residual
|
| 431 |
+
})
|
| 432 |
+
|
| 433 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
| 434 |
+
|
| 435 |
+
if len(aux_hidden_states) > 0:
|
| 436 |
+
return hidden_states, aux_hidden_states
|
| 437 |
+
return hidden_states
|
| 438 |
+
|
| 439 |
+
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Handles fused-projection remapping (q/k/v -> qkv, gate/up ->
        gate_up), KV-cache quantization scales, and pipeline-parallel
        missing parameters.

        Args:
            weights: Iterable of (checkpoint name, tensor) pairs.

        Returns:
            The set of parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Rotary-embedding caches are recomputed at runtime; never load.
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                    (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Scales may be stored as 1-element tensors; load the scalar.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            # First try the fused (stacked) parameter mapping; the for/else
            # falls through to a plain 1:1 load when no mapping matched.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    """Llama causal-LM wrapper: decoder stack + LM head + logits processor.

    Supports LoRA, pipeline parallelism and EAGLE3 speculative decoding.
    """

    # Maps fused parameter names to their unfused checkpoint shards.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Construct the model, LM head and logits processor for this rank."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"),
                                      layer_type=layer_type)
        # LM head and logits processing only exist on the last PP rank.
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)


    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select which layers additionally emit hidden states (EAGLE3)."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default EAGLE3 tap points: early, middle and near-final layers."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        # Factory hook so subclasses can substitute a different model class.
        return LlamaModel(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Delegates to the inner model's (customized, summed) embeddings.
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner decoder stack; see ``LlamaModel.forward``."""
        model_output = self.model(input_ids, positions, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project final hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights, remapping Mistral-format names on the fly.

        Skips ``lm_head.*`` when embeddings are tied (the head shares the
        embedding weights).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a Mistral-format name/tensor to HF-Llama conventions."""

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # Reorders interleaved rotary dims to HF's half-split layout.
            attn_in = self.config.head_dim * n_heads

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads,
                                    self.config.hidden_size)
        elif "wk" in modules and modules[
                -1] == "qscale_weight" and loaded_weight.numel() > 1:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads,
                                    self.config.hidden_size)
        elif "wq" in modules and modules[
                -1] == "qscale_weight" and loaded_weight.numel() > 1:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads, 1)

        # Rename path components, preferring two-segment matches (e.g.
        # "kv_fake_quantizer.qscale_act") over single-segment ones.
        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
|
vllm_hacked/model_executor/sampling_metadata.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from array import array
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
| 10 |
+
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
|
| 11 |
+
SequenceGroupMetadata)
|
| 12 |
+
from vllm.utils import (PyObjectCache, async_tensor_h2d,
|
| 13 |
+
is_pin_memory_available, make_tensor_with_pad)
|
| 14 |
+
|
| 15 |
+
_SAMPLING_EPS = 1e-5
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class SequenceGroupToSample:
    """Per-sequence-group sampling state for one scheduler iteration."""

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if it
    # is in a decode stage. The length of query_len <= seq_len if chunked
    # prefill is enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits. to compute prompt logprob. Empty if
    # prompt logprob is not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        # True when at least one token index is marked for sampling.
        return len(self.sample_indices) > 0

    def __post_init__(self):
        # Sanity checks: prompt-logprob indices require the corresponding
        # sampling param; prefill groups must carry concrete lengths.
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def gen_seq_group_to_sample_builder(num_seqs: int):
    """Return a zero-arg factory producing blank ``SequenceGroupToSample``
    templates sized for *num_seqs* sequences (used by the object cache)."""

    def _build_blank():
        return SequenceGroupToSample(
            seq_ids=[0] * num_seqs,
            sampling_params=None,
            seq_data=None,  # type: ignore
            seq_len=0,
            query_len=0,
            generator=None,
            is_prompt=True,
            prompt_logprob_indices=[],
            sample_indices=[],
        )

    return _build_blank
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations"""

    def __init__(self):
        # Maps a sequence count to a pool of reusable template objects.
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        """Fetch (creating the pool on first use) a reusable
        ``SequenceGroupToSample`` template for *num_seqs* sequences."""
        pool = self._seq_group_to_sample_cache.get(num_seqs)
        if pool is None:
            pool = PyObjectCache(gen_seq_group_to_sample_builder(num_seqs))
            self._seq_group_to_sample_cache[num_seqs] = pool
        return pool.get_object()

    def reset(self):
        """Return every pooled object to its pool for reuse."""
        for pool in self._seq_group_to_sample_cache.values():
            pool.reset()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follow;
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each token indices is 2D tensor of (num_indices, num_indices) where
            the first item means the sample index within the returned logit
            (before pruning padding), and the second item means the sample
            index after pruning using selected_token_indices.
            For example, if the returned logit is [1, 2, 3], and we select
            [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
            The first tuple is [1, 2] (sampled index within original logit),
            and the second tuple is [0, 1] (sampled index within pruned logit).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.

    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: List[int],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None,
    ) -> "SamplingMetadata":
        """Build a SamplingMetadata for one batch.

        Computes the index bookkeeping on CPU via ``_prepare_seq_groups``,
        then moves the index tensors to *device* asynchronously.
        """
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=device,
            pin_memory=pin_memory,
        )
        categorized_sample_indices = {
            t:
            async_tensor_h2d(
                seq_ids,
                dtype=torch.int,
                target_device=device,
                pin_memory=pin_memory,
            )
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}), ")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: List[int],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[
        List[SequenceGroupToSample],
        List[int],
        Dict[SamplingType, List[int]],
        int,
]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence group to batch.
        seq_lens: A list of sequence lens per sequence group.
            Index of prompt len should match with seq_group_metadata_list.
        query_lens: A list of query lengths. Prompt lens include the length
            of entire prompt tokens, and it could be shorter.
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.
        cache: Optional object cache; when given, SequenceGroupToSample
            objects are recycled instead of freshly allocated.

    Returns:
        seq_groups: A list of sequence group to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the outcome logits from the model for the performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    # indices to sample/prompt logprob within pruned output logits,
    # indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            # Reuse a pooled object; clear its per-iteration index lists.
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                             if cache is not None else [])
        sample_indices: List[int] = (sample_obj.sample_indices
                                     if cache is not None else [])
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                # Seeded request: create (and register) a dedicated RNG.
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            query_len = query_lens[i] if query_lens is not None and len(
                query_lens) > 0 else 1
            sample_len = len(seq_ids) * query_len if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                # Look up the RNG created during this request's prefill.
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This blocks computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
           # Use categorized_sample_indices for sampling.
           # prompt_logprob_indices to find prompt logprob indices.
           # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(range(logit_idx, logit_idx + sample_len)))
            logit_idx += sample_len

        if cache is not None:
            # Fill in the recycled object's remaining fields in place.
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices),
            )

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@dataclass
class SamplingTensors:
    """Flattened, device-resident sampling parameters for one sampler call.

    Each tensor holds one entry per logit row handed to the sampler (prompt
    logprob rows followed by sample rows, per sequence group), in exactly the
    order the rows were appended by `from_sampling_metadata`. Keeping the
    lists parallel is what ties a row's temperature/top_p/... to its logits.
    """

    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    # Padded token-id matrices used only for penalty application; empty
    # tensors when no penalty is active in the batch (see from_lists).
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple["SamplingTensors", bool, bool, bool]:
        """Flatten per-sequence-group sampling params into batched tensors.

        Returns ``(tensors, do_penalties, do_top_p_top_k, do_min_p)``.  The
        three booleans report whether *any* row in the batch needs the
        corresponding (relatively costly) sampler pass, so the sampler can
        skip it wholesale otherwise.
        """
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False

        # First pass: replicate each group's scalar params once per logit row
        # and compute the three "is this pass needed at all" flags.
        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            min_p = sampling_params.min_p

            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            # -1 is the "disabled" sentinel for top_k; vocab_size is a no-op.
            top_k = vocab_size if top_k == -1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True

            is_prompt = seq_group.is_prompt
            if is_prompt and sampling_params.prompt_logprobs is not None:
                # For tokens in the prompt that we only need to get
                # their logprobs: use neutral penalties (0/0/1) so the
                # penalty pass leaves these rows untouched.
                query_len = seq_group.query_len
                assert query_len is not None
                prefill_len = len(seq_group.prompt_logprob_indices)
                temperatures += [temperature] * prefill_len
                top_ps += [top_p] * prefill_len
                top_ks += [top_k] * prefill_len
                min_ps += [min_p] * prefill_len
                presence_penalties += [0] * prefill_len
                frequency_penalties += [0] * prefill_len
                repetition_penalties += [1] * prefill_len

            if seq_group.do_sample:
                sample_lens = len(seq_group.sample_indices)
                assert sample_lens >= len(seq_ids)
                temperatures += [temperature] * sample_lens
                top_ps += [top_p] * sample_lens
                top_ks += [top_k] * sample_lens
                min_ps += [min_p] * sample_lens
                presence_penalties += [p] * sample_lens
                frequency_penalties += [f] * sample_lens
                repetition_penalties += [r] * sample_lens

        # Second pass (only when some penalty is active): collect the token-id
        # arrays in the same row order as the scalar lists above, including
        # empty placeholder rows for prompt-logprob-only rows.
        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                sampling_params = seq_group.sampling_params
                if (seq_group.is_prompt
                        and sampling_params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                    output_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures,
            top_ps,
            top_ks,
            min_ps,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            prompt_tokens,
            output_tokens,
            vocab_size,
            device,
            dtype,
        )
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)

    @classmethod
    def from_lists(
        cls,
        temperatures: List[float],
        top_ps: List[float],
        top_ks: List[int],
        min_ps: List[float],
        presence_penalties: List[float],
        frequency_penalties: List[float],
        repetition_penalties: List[float],
        prompt_tokens: List[array],
        output_tokens: List[array],
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "SamplingTensors":
        """Convert the parallel CPU lists to pinned tensors and ship them
        to `device` with non-blocking copies."""
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()

        do_penalties = prompt_tokens or output_tokens

        if do_penalties:
            # Ragged token arrays are padded with vocab_size (an
            # out-of-vocabulary id, so padding never matches a real token).
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            # No penalties anywhere in the batch: share one empty tensor.
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
        )
|
vllm_hacked/model_executor/sampling_metadata_ori.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from array import array
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
| 10 |
+
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
|
| 11 |
+
SequenceGroupMetadata)
|
| 12 |
+
from vllm.utils import (PyObjectCache, async_tensor_h2d,
|
| 13 |
+
is_pin_memory_available, make_tensor_with_pad)
|
| 14 |
+
|
| 15 |
+
_SAMPLING_EPS = 1e-5
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class SequenceGroupToSample:
    """Per-sequence-group sampling state for one model forward step.

    The diagram below relates the various length fields for a sequence that
    is being chunk-prefilled across iterations N-1 and N.
    """
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if it
    # is in a decode stage. The length of query_len <= seq_len if chunked
    # prefill is enabled.
    query_len: Optional[int]
    # A random number generator for sampling (None for unseeded requests).
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits, used to compute prompt logprob. Empty
    # if prompt logprob is not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self) -> bool:
        # A group participates in sampling iff it owns at least one sample row.
        return len(self.sample_indices) > 0

    def __post_init__(self):
        # Sanity-check internal consistency: prompt-logprob rows require the
        # request to have asked for prompt logprobs, and prefill groups must
        # carry both length fields.
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None
|
| 62 |
+
|
| 63 |
+
def gen_seq_group_to_sample_builder(num_seqs: int):
    """Return a zero-argument factory that builds blank
    ``SequenceGroupToSample`` objects sized for ``num_seqs`` sequences.

    The factory is handed to ``PyObjectCache`` so pooled objects can be
    created on demand; the pool later overwrites every field in place.
    """

    def _build_blank() -> SequenceGroupToSample:
        return SequenceGroupToSample(
            seq_ids=[0] * num_seqs,
            sampling_params=None,
            seq_data=None,  # type: ignore
            seq_len=0,
            query_len=0,
            generator=None,
            is_prompt=True,
            prompt_logprob_indices=[],
            sample_indices=[],
        )

    return _build_blank
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations.

    Pools of reusable ``SequenceGroupToSample`` objects are kept per group
    size (number of sequences), so steady-state decoding allocates nothing.
    """

    def __init__(self):
        # num_seqs -> pool of blank SequenceGroupToSample objects.
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        """Fetch a reusable object for a group of ``num_seqs`` sequences,
        lazily creating the pool for that size on first use."""
        try:
            pool = self._seq_group_to_sample_cache[num_seqs]
        except KeyError:
            pool = PyObjectCache(gen_seq_group_to_sample_builder(num_seqs))
            self._seq_group_to_sample_cache[num_seqs] = pool
        return pool.get_object()

    def reset(self):
        """Return every pooled object to its pool for the next iteration."""
        for pool in self._seq_group_to_sample_cache.values():
            pool.reset()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    Typical usage::

        hidden_states = execute_model(...)
        logits = hidden_states[sampling_metadata.selected_token_indices]
        sample(logits)   # consults categorized_sample_indices internally

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to
            find logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample,
            expressed within the pruned logits (i.e. after applying
            selected_token_indices).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Skip the GPU=>CPU serialization of token
            outputs when True.
        reuse_sampling_tensors: Reuse sampling tensors that are part of the
            sampler forward pass; mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: List[int],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None,
    ) -> "SamplingMetadata":
        """Build a SamplingMetadata from scheduler output.

        Index lists are computed on CPU by `_prepare_seq_groups`, then
        transferred to `device` asynchronously (pinned-memory H2D copies).
        """
        groups, token_indices, per_type_indices, prompt_count = (
            _prepare_seq_groups(seq_group_metadata_list, seq_lens,
                                query_lens, device, generators, cache))

        token_indices_t = async_tensor_h2d(
            token_indices,
            dtype=torch.long,
            target_device=device,
            pin_memory=pin_memory,
        )
        per_type_indices_t = {
            sampling_type: async_tensor_h2d(
                indices,
                dtype=torch.int,
                target_device=device,
                pin_memory=pin_memory,
            )
            for sampling_type, indices in per_type_indices.items()
        }

        return SamplingMetadata(
            seq_groups=groups,
            selected_token_indices=token_indices_t,
            categorized_sample_indices=per_type_indices_t,
            num_prompts=prompt_count,
        )

    def __repr__(self) -> str:
        pieces = (
            "SamplingMetadata(",
            f"seq_groups={self.seq_groups}, ",
            f"selected_token_indices={self.selected_token_indices}, ",
            f"categorized_sample_indices={self.categorized_sample_indices}), ",
        )
        return "".join(pieces)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: List[int],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[
        List[SequenceGroupToSample],
        List[int],
        Dict[SamplingType, List[int]],
        int,
]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence group to batch.
        seq_lens: A list of sequence lens per sequence group.
            Index of prompt len should match with seq_group_metadata_list.
        query_lens: A list of query lengths. Prompt lens include the length
            of entire prompt tokens, and it could be shorter.
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.
        cache: Optional pool of reusable SequenceGroupToSample objects;
            when provided, objects are mutated in place instead of allocated.

    Returns:
        seq_groups: A list of sequence group to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the outcome logits from the model for the performance.
    selected_token_indices: List[int] = []
    # Running cursor into the raw model output; used for
    # selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    # indices to sample/prompt logprob within pruned output logits,
    # indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            # Reuse a pooled object sized for this group; overwrite its ids
            # and clear the index lists before refilling them below.
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        # When caching, alias the pooled object's lists so appends below
        # mutate it directly; otherwise build fresh lists.
        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                             if cache is not None else [])
        sample_indices: List[int] = (sample_obj.sample_indices
                                     if cache is not None else [])
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                # Seeded request: create its generator now and register it so
                # later decode steps can look it up by request_id.
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            # This code path only supports single-sequence prefill groups.
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            # NOTE(review): query_len may exceed 1 here (e.g. speculative /
            # multi-token decode in this fork) — hence sample_len scales by it.
            query_len = query_lens[i] if query_lens is not None and len(
                query_lens) > 0 else 1
            sample_len = len(seq_ids) * query_len if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        # The cursor advances regardless: those model-output rows exist even
        # when we do not select them.
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(range(logit_idx, logit_idx + sample_len)))
            logit_idx += sample_len

        if cache is not None:
            # Finish populating the pooled object in place (index lists were
            # already filled through the aliases above).
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices),
            )

        seq_groups.append(sample_obj)

    if cache is not None:
        # NOTE(review): reset() returns pooled objects to their pools even
        # though seq_groups still references them — callers apparently must
        # consume seq_groups before the next prepare(); confirm lifecycle.
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@dataclass
|
| 372 |
+
class SamplingTensors:
|
| 373 |
+
"""Tensors for sampling."""
|
| 374 |
+
|
| 375 |
+
temperatures: torch.Tensor
|
| 376 |
+
top_ps: torch.Tensor
|
| 377 |
+
top_ks: torch.Tensor
|
| 378 |
+
min_ps: torch.Tensor
|
| 379 |
+
presence_penalties: torch.Tensor
|
| 380 |
+
frequency_penalties: torch.Tensor
|
| 381 |
+
repetition_penalties: torch.Tensor
|
| 382 |
+
prompt_tokens: torch.Tensor
|
| 383 |
+
output_tokens: torch.Tensor
|
| 384 |
+
|
| 385 |
+
@classmethod
|
| 386 |
+
def from_sampling_metadata(
|
| 387 |
+
cls,
|
| 388 |
+
sampling_metadata: "SamplingMetadata",
|
| 389 |
+
vocab_size: int,
|
| 390 |
+
device: torch.device,
|
| 391 |
+
dtype: torch.dtype,
|
| 392 |
+
) -> Tuple["SamplingTensors", bool, bool, bool]:
|
| 393 |
+
prompt_tokens: List[array] = []
|
| 394 |
+
output_tokens: List[array] = []
|
| 395 |
+
top_ks: List[int] = []
|
| 396 |
+
temperatures: List[float] = []
|
| 397 |
+
top_ps: List[float] = []
|
| 398 |
+
min_ps: List[float] = []
|
| 399 |
+
presence_penalties: List[float] = []
|
| 400 |
+
frequency_penalties: List[float] = []
|
| 401 |
+
repetition_penalties: List[float] = []
|
| 402 |
+
do_penalties = False
|
| 403 |
+
do_top_p_top_k = False
|
| 404 |
+
do_min_p = False
|
| 405 |
+
|
| 406 |
+
assert sampling_metadata.seq_groups is not None
|
| 407 |
+
for seq_group in sampling_metadata.seq_groups:
|
| 408 |
+
seq_ids = seq_group.seq_ids
|
| 409 |
+
sampling_params = seq_group.sampling_params
|
| 410 |
+
temperature = sampling_params.temperature
|
| 411 |
+
p = sampling_params.presence_penalty
|
| 412 |
+
f = sampling_params.frequency_penalty
|
| 413 |
+
r = sampling_params.repetition_penalty
|
| 414 |
+
top_p = sampling_params.top_p
|
| 415 |
+
min_p = sampling_params.min_p
|
| 416 |
+
|
| 417 |
+
# k should not be greater than the vocab size.
|
| 418 |
+
top_k = min(sampling_params.top_k, vocab_size)
|
| 419 |
+
top_k = vocab_size if top_k == -1 else top_k
|
| 420 |
+
if temperature < _SAMPLING_EPS:
|
| 421 |
+
# NOTE: Zero temperature means deterministic sampling
|
| 422 |
+
# (i.e., greedy sampling or beam search).
|
| 423 |
+
# Set the temperature to 1 to avoid division by zero.
|
| 424 |
+
temperature = 1.0
|
| 425 |
+
if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
|
| 426 |
+
or top_k != vocab_size):
|
| 427 |
+
do_top_p_top_k = True
|
| 428 |
+
if not do_min_p and min_p > _SAMPLING_EPS:
|
| 429 |
+
do_min_p = True
|
| 430 |
+
if not do_penalties and (abs(p) >= _SAMPLING_EPS
|
| 431 |
+
or abs(f) >= _SAMPLING_EPS
|
| 432 |
+
or abs(r - 1.0) >= _SAMPLING_EPS):
|
| 433 |
+
do_penalties = True
|
| 434 |
+
|
| 435 |
+
is_prompt = seq_group.is_prompt
|
| 436 |
+
if is_prompt and sampling_params.prompt_logprobs is not None:
|
| 437 |
+
# For tokens in the prompt that we only need to get
|
| 438 |
+
# their logprobs
|
| 439 |
+
query_len = seq_group.query_len
|
| 440 |
+
assert query_len is not None
|
| 441 |
+
prefill_len = len(seq_group.prompt_logprob_indices)
|
| 442 |
+
temperatures += [temperature] * prefill_len
|
| 443 |
+
top_ps += [top_p] * prefill_len
|
| 444 |
+
top_ks += [top_k] * prefill_len
|
| 445 |
+
min_ps += [min_p] * prefill_len
|
| 446 |
+
presence_penalties += [0] * prefill_len
|
| 447 |
+
frequency_penalties += [0] * prefill_len
|
| 448 |
+
repetition_penalties += [1] * prefill_len
|
| 449 |
+
|
| 450 |
+
if seq_group.do_sample:
|
| 451 |
+
sample_lens = len(seq_group.sample_indices)
|
| 452 |
+
assert sample_lens >= len(seq_ids)
|
| 453 |
+
temperatures += [temperature] * sample_lens
|
| 454 |
+
top_ps += [top_p] * sample_lens
|
| 455 |
+
top_ks += [top_k] * sample_lens
|
| 456 |
+
min_ps += [min_p] * sample_lens
|
| 457 |
+
presence_penalties += [p] * sample_lens
|
| 458 |
+
frequency_penalties += [f] * sample_lens
|
| 459 |
+
repetition_penalties += [r] * sample_lens
|
| 460 |
+
|
| 461 |
+
if do_penalties:
|
| 462 |
+
for seq_group in sampling_metadata.seq_groups:
|
| 463 |
+
seq_ids = seq_group.seq_ids
|
| 464 |
+
sampling_params = seq_group.sampling_params
|
| 465 |
+
if (seq_group.is_prompt
|
| 466 |
+
and sampling_params.prompt_logprobs is not None):
|
| 467 |
+
prefill_len = len(seq_group.prompt_logprob_indices)
|
| 468 |
+
prompt_tokens.extend(
|
| 469 |
+
array(VLLM_TOKEN_ID_ARRAY_TYPE)
|
| 470 |
+
for _ in range(prefill_len))
|
| 471 |
+
output_tokens.extend(
|
| 472 |
+
array(VLLM_TOKEN_ID_ARRAY_TYPE)
|
| 473 |
+
for _ in range(prefill_len))
|
| 474 |
+
if seq_group.do_sample:
|
| 475 |
+
for seq_id in seq_ids:
|
| 476 |
+
seq_data = seq_group.seq_data[seq_id]
|
| 477 |
+
prompt_tokens.append(seq_data.prompt_token_ids_array)
|
| 478 |
+
output_tokens.append(seq_data.output_token_ids_array)
|
| 479 |
+
|
| 480 |
+
sampling_tensors = SamplingTensors.from_lists(
|
| 481 |
+
temperatures,
|
| 482 |
+
top_ps,
|
| 483 |
+
top_ks,
|
| 484 |
+
min_ps,
|
| 485 |
+
presence_penalties,
|
| 486 |
+
frequency_penalties,
|
| 487 |
+
repetition_penalties,
|
| 488 |
+
prompt_tokens,
|
| 489 |
+
output_tokens,
|
| 490 |
+
vocab_size,
|
| 491 |
+
device,
|
| 492 |
+
dtype,
|
| 493 |
+
)
|
| 494 |
+
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
|
| 495 |
+
|
| 496 |
+
@classmethod
|
| 497 |
+
def from_lists(
|
| 498 |
+
cls,
|
| 499 |
+
temperatures: List[float],
|
| 500 |
+
top_ps: List[float],
|
| 501 |
+
top_ks: List[int],
|
| 502 |
+
min_ps: List[float],
|
| 503 |
+
presence_penalties: List[float],
|
| 504 |
+
frequency_penalties: List[float],
|
| 505 |
+
repetition_penalties: List[float],
|
| 506 |
+
prompt_tokens: List[array],
|
| 507 |
+
output_tokens: List[array],
|
| 508 |
+
vocab_size: int,
|
| 509 |
+
device: torch.device,
|
| 510 |
+
dtype: torch.dtype,
|
| 511 |
+
) -> "SamplingTensors":
|
| 512 |
+
# Note that the performance will be very bad without
|
| 513 |
+
# pinned memory.
|
| 514 |
+
pin_memory = is_pin_memory_available()
|
| 515 |
+
|
| 516 |
+
do_penalties = prompt_tokens or output_tokens
|
| 517 |
+
|
| 518 |
+
if do_penalties:
|
| 519 |
+
prompt_t = make_tensor_with_pad(
|
| 520 |
+
prompt_tokens,
|
| 521 |
+
vocab_size,
|
| 522 |
+
device="cpu",
|
| 523 |
+
dtype=torch.int64,
|
| 524 |
+
pin_memory=pin_memory,
|
| 525 |
+
)
|
| 526 |
+
output_t = make_tensor_with_pad(
|
| 527 |
+
output_tokens,
|
| 528 |
+
vocab_size,
|
| 529 |
+
device="cpu",
|
| 530 |
+
dtype=torch.int64,
|
| 531 |
+
pin_memory=pin_memory,
|
| 532 |
+
)
|
| 533 |
+
else:
|
| 534 |
+
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
|
| 535 |
+
prompt_t = empty_tensor
|
| 536 |
+
output_t = empty_tensor
|
| 537 |
+
|
| 538 |
+
temperatures_t = torch.tensor(
|
| 539 |
+
temperatures,
|
| 540 |
+
device="cpu",
|
| 541 |
+
dtype=dtype,
|
| 542 |
+
pin_memory=pin_memory,
|
| 543 |
+
)
|
| 544 |
+
top_ps_t = torch.tensor(
|
| 545 |
+
top_ps,
|
| 546 |
+
device="cpu",
|
| 547 |
+
dtype=dtype,
|
| 548 |
+
pin_memory=pin_memory,
|
| 549 |
+
)
|
| 550 |
+
min_ps_t = torch.tensor(
|
| 551 |
+
min_ps,
|
| 552 |
+
device="cpu",
|
| 553 |
+
dtype=dtype,
|
| 554 |
+
pin_memory=pin_memory,
|
| 555 |
+
)
|
| 556 |
+
presence_penalties_t = torch.tensor(
|
| 557 |
+
presence_penalties,
|
| 558 |
+
device="cpu",
|
| 559 |
+
dtype=dtype,
|
| 560 |
+
pin_memory=pin_memory,
|
| 561 |
+
)
|
| 562 |
+
frequency_penalties_t = torch.tensor(
|
| 563 |
+
frequency_penalties,
|
| 564 |
+
device="cpu",
|
| 565 |
+
dtype=dtype,
|
| 566 |
+
pin_memory=pin_memory,
|
| 567 |
+
)
|
| 568 |
+
repetition_penalties_t = torch.tensor(
|
| 569 |
+
repetition_penalties,
|
| 570 |
+
device="cpu",
|
| 571 |
+
dtype=dtype,
|
| 572 |
+
pin_memory=pin_memory,
|
| 573 |
+
)
|
| 574 |
+
top_ks_t = torch.tensor(
|
| 575 |
+
top_ks,
|
| 576 |
+
device="cpu",
|
| 577 |
+
dtype=torch.int,
|
| 578 |
+
pin_memory=pin_memory,
|
| 579 |
+
)
|
| 580 |
+
# Because the memory is pinned, we can do non-blocking
|
| 581 |
+
# transfer to device.
|
| 582 |
+
|
| 583 |
+
return cls(
|
| 584 |
+
temperatures=temperatures_t.to(device=device, non_blocking=True),
|
| 585 |
+
top_ps=top_ps_t.to(device=device, non_blocking=True),
|
| 586 |
+
top_ks=top_ks_t.to(device=device, non_blocking=True),
|
| 587 |
+
min_ps=min_ps_t.to(device=device, non_blocking=True),
|
| 588 |
+
presence_penalties=presence_penalties_t.to(device=device,
|
| 589 |
+
non_blocking=True),
|
| 590 |
+
frequency_penalties=frequency_penalties_t.to(device=device,
|
| 591 |
+
non_blocking=True),
|
| 592 |
+
repetition_penalties=repetition_penalties_t.to(device=device,
|
| 593 |
+
non_blocking=True),
|
| 594 |
+
prompt_tokens=prompt_t.to(device=device, non_blocking=True),
|
| 595 |
+
output_tokens=output_t.to(device=device, non_blocking=True),
|
| 596 |
+
)
|
vllm_hacked/sampling_params.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
"""Sampling parameters for text generation."""
|
| 4 |
+
import copy
|
| 5 |
+
import warnings
|
| 6 |
+
from dataclasses import field
|
| 7 |
+
from enum import Enum, IntEnum
|
| 8 |
+
from functools import cached_property
|
| 9 |
+
from typing import Annotated, Any, Optional, Union
|
| 10 |
+
|
| 11 |
+
import msgspec
|
| 12 |
+
from pydantic.dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
from vllm.logits_process import LogitsProcessor
|
| 16 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 17 |
+
|
| 18 |
+
logger = init_logger(__name__)
|
| 19 |
+
|
| 20 |
+
_SAMPLING_EPS = 1e-5
|
| 21 |
+
_MAX_TEMP = 1e-2
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SamplingType(IntEnum):
    """How the next token is selected for a request."""
    # Deterministic argmax decoding (used when temperature is near zero;
    # see SamplingParams.sampling_type).
    GREEDY = 0
    # Stochastic sampling without a fixed seed.
    RANDOM = 1
    # Stochastic sampling with a user-supplied seed for reproducibility.
    RANDOM_SEED = 2
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# maybe make msgspec?
|
| 31 |
+
# maybe make msgspec?
@dataclass
class StructuredOutputsParams:
    """Constraint specification used to build a structured-outputs
    logits processor. At most one constraint kind may be set."""
    # One of these fields will be used to build a logit processor.
    json: Optional[Union[str, dict]] = None
    regex: Optional[str] = None
    choice: Optional[list[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # These are other options that can be set.
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
    whitespace_pattern: Optional[str] = None
    structural_tag: Optional[str] = None

    _backend: Optional[str] = field(default=None, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
        constraints = (self.json, self.regex, self.choice, self.grammar,
                       self.json_object)
        if sum(value is not None for value in constraints) > 1:
            raise ValueError(
                "You can only use one kind of structured outputs constraint "
                f"but multiple are specified: {self.__dict__}")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated alias for StructuredOutputsParams.

    Kept only for backward compatibility; instantiating it emits a
    DeprecationWarning and otherwise behaves exactly like the parent.
    """

    def __post_init__(self):
        # Warn at the caller's level (stacklevel=2), then run the parent's
        # mutual-exclusion validation.
        warnings.warn(
            "GuidedDecodingParams is deprecated. This will be removed in "
            "v0.12.0 or v1.0.0, which ever is soonest. Please use "
            "StructuredOutputsParams instead.",
            DeprecationWarning,
            stacklevel=2)
        return super().__post_init__()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class RequestOutputKind(Enum):
    """Controls how incremental outputs are reported for a request."""
    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class SamplingParams(
|
| 86 |
+
msgspec.Struct,
|
| 87 |
+
omit_defaults=True, # type: ignore[call-arg]
|
| 88 |
+
# required for @cached_property.
|
| 89 |
+
dict=True): # type: ignore[call-arg]
|
| 90 |
+
"""Sampling parameters for text generation.
|
| 91 |
+
|
| 92 |
+
Overall, we follow the sampling parameters from the OpenAI text completion
|
| 93 |
+
API (https://platform.openai.com/docs/api-reference/completions/create).
|
| 94 |
+
In addition, we support beam search, which is not supported by OpenAI.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
n: int = 1
|
| 98 |
+
"""Number of outputs to return for the given prompt request.
|
| 99 |
+
|
| 100 |
+
NOTE:
|
| 101 |
+
`AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
|
| 102 |
+
are generated and streamed cumulatively per request. To see all `n`
|
| 103 |
+
outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
|
| 104 |
+
in `SamplingParams`."""
|
| 105 |
+
best_of: Optional[int] = None
|
| 106 |
+
"""Number of output sequences that are generated from the prompt. From
|
| 107 |
+
these `best_of` sequences, the top `n` sequences are returned. `best_of`
|
| 108 |
+
must be greater than or equal to `n`. By default, `best_of` is set to `n`.
|
| 109 |
+
Warning, this is only supported in V0."""
|
| 110 |
+
_real_n: Optional[int] = None
|
| 111 |
+
presence_penalty: float = 0.0
|
| 112 |
+
"""Penalizes new tokens based on whether they appear in the generated text
|
| 113 |
+
so far. Values > 0 encourage the model to use new tokens, while values < 0
|
| 114 |
+
encourage the model to repeat tokens."""
|
| 115 |
+
frequency_penalty: float = 0.0
|
| 116 |
+
"""Penalizes new tokens based on their frequency in the generated text so
|
| 117 |
+
far. Values > 0 encourage the model to use new tokens, while values < 0
|
| 118 |
+
encourage the model to repeat tokens."""
|
| 119 |
+
repetition_penalty: float = 1.0
|
| 120 |
+
"""Penalizes new tokens based on whether they appear in the prompt and the
|
| 121 |
+
generated text so far. Values > 1 encourage the model to use new tokens,
|
| 122 |
+
while values < 1 encourage the model to repeat tokens."""
|
| 123 |
+
temperature: float = 1.0
|
| 124 |
+
"""Controls the randomness of the sampling. Lower values make the model
|
| 125 |
+
more deterministic, while higher values make the model more random. Zero
|
| 126 |
+
means greedy sampling."""
|
| 127 |
+
top_p: float = 1.0
|
| 128 |
+
"""Controls the cumulative probability of the top tokens to consider. Must
|
| 129 |
+
be in (0, 1]. Set to 1 to consider all tokens."""
|
| 130 |
+
top_k: int = 0
|
| 131 |
+
"""Controls the number of top tokens to consider. Set to 0 (or -1) to
|
| 132 |
+
consider all tokens."""
|
| 133 |
+
min_p: float = 0.0
|
| 134 |
+
"""Represents the minimum probability for a token to be considered,
|
| 135 |
+
relative to the probability of the most likely token. Must be in [0, 1].
|
| 136 |
+
Set to 0 to disable this."""
|
| 137 |
+
seed: Optional[int] = None
|
| 138 |
+
"""Random seed to use for the generation."""
|
| 139 |
+
stop: Optional[Union[str, list[str]]] = None
|
| 140 |
+
"""String(s) that stop the generation when they are generated. The returned
|
| 141 |
+
output will not contain the stop strings."""
|
| 142 |
+
stop_token_ids: Optional[list[int]] = None
|
| 143 |
+
"""Token IDs that stop the generation when they are generated. The returned
|
| 144 |
+
output will contain the stop tokens unless the stop tokens are special
|
| 145 |
+
tokens."""
|
| 146 |
+
ignore_eos: bool = False
|
| 147 |
+
"""Whether to ignore the EOS token and continue generating
|
| 148 |
+
tokens after the EOS token is generated."""
|
| 149 |
+
max_tokens: Optional[int] = 16
|
| 150 |
+
"""Maximum number of tokens to generate per output sequence."""
|
| 151 |
+
min_tokens: int = 0
|
| 152 |
+
"""Minimum number of tokens to generate per output sequence before EOS or
|
| 153 |
+
`stop_token_ids` can be generated"""
|
| 154 |
+
logprobs: Optional[int] = None
|
| 155 |
+
"""Number of log probabilities to return per output token. When set to
|
| 156 |
+
`None`, no probability is returned. If set to a non-`None` value, the
|
| 157 |
+
result includes the log probabilities of the specified number of most
|
| 158 |
+
likely tokens, as well as the chosen tokens. Note that the implementation
|
| 159 |
+
follows the OpenAI API: The API will always return the log probability of
|
| 160 |
+
the sampled token, so there may be up to `logprobs+1` elements in the
|
| 161 |
+
response. When set to -1, return all `vocab_size` log probabilities."""
|
| 162 |
+
prompt_logprobs: Optional[int] = None
|
| 163 |
+
"""Number of log probabilities to return per prompt token.
|
| 164 |
+
When set to -1, return all `vocab_size` log probabilities."""
|
| 165 |
+
# NOTE: This parameter is only exposed at the engine level for now.
|
| 166 |
+
# It is not exposed in the OpenAI API server, as the OpenAI API does
|
| 167 |
+
# not support returning only a list of token IDs.
|
| 168 |
+
detokenize: bool = True
|
| 169 |
+
"""Whether to detokenize the output."""
|
| 170 |
+
skip_special_tokens: bool = True
|
| 171 |
+
"""Whether to skip special tokens in the output."""
|
| 172 |
+
spaces_between_special_tokens: bool = True
|
| 173 |
+
"""Whether to add spaces between special tokens in the output."""
|
| 174 |
+
# Optional[list[LogitsProcessor]] type. We use Any here because
|
| 175 |
+
# Optional[list[LogitsProcessor]] type is not supported by msgspec.
|
| 176 |
+
logits_processors: Optional[Any] = None
|
| 177 |
+
"""Functions that modify logits based on previously generated tokens, and
|
| 178 |
+
optionally prompt tokens as a first argument."""
|
| 179 |
+
include_stop_str_in_output: bool = False
|
| 180 |
+
"""Whether to include the stop strings in output text."""
|
| 181 |
+
truncate_prompt_tokens: Optional[Annotated[int,
|
| 182 |
+
msgspec.Meta(ge=-1)]] = None
|
| 183 |
+
"""If set to -1, will use the truncation size supported by the model. If
|
| 184 |
+
set to an integer k, will use only the last k tokens from the prompt
|
| 185 |
+
(i.e., left truncation). If set to `None`, truncation is disabled."""
|
| 186 |
+
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
|
| 187 |
+
|
| 188 |
+
# The below fields are not supposed to be used as an input.
|
| 189 |
+
# They are set in post_init.
|
| 190 |
+
output_text_buffer_length: int = 0
|
| 191 |
+
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
|
| 192 |
+
|
| 193 |
+
# Fields used to construct logits processors
|
| 194 |
+
structured_outputs: Optional[StructuredOutputsParams] = None
|
| 195 |
+
"""Parameters for configuring structured outputs."""
|
| 196 |
+
guided_decoding: Optional[GuidedDecodingParams] = None
|
| 197 |
+
"""Deprecated alias for structured_outputs."""
|
| 198 |
+
logit_bias: Optional[dict[int, float]] = None
|
| 199 |
+
"""If provided, the engine will construct a logits processor that applies
|
| 200 |
+
these logit biases."""
|
| 201 |
+
allowed_token_ids: Optional[list[int]] = None
|
| 202 |
+
"""If provided, the engine will construct a logits processor which only
|
| 203 |
+
retains scores for the given token ids."""
|
| 204 |
+
extra_args: Optional[dict[str, Any]] = None
|
| 205 |
+
"""Arbitrary additional args, that can be used by custom sampling
|
| 206 |
+
implementations, plugins, etc. Not used by any in-tree sampling
|
| 207 |
+
implementations."""
|
| 208 |
+
guidance_scale: Optional[float] = None
|
| 209 |
+
|
| 210 |
+
# Fields used for bad words
|
| 211 |
+
bad_words: Optional[list[str]] = None
|
| 212 |
+
"""Words that are not allowed to be generated. More precisely, only the
|
| 213 |
+
last token of a corresponding token sequence is not allowed when the next
|
| 214 |
+
generated token can complete the sequence."""
|
| 215 |
+
_bad_words_token_ids: Optional[list[list[int]]] = None
|
| 216 |
+
|
| 217 |
+
    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
        best_of: Optional[int] = None,
        presence_penalty: Optional[float] = 0.0,
        frequency_penalty: Optional[float] = 0.0,
        repetition_penalty: Optional[float] = 1.0,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        stop: Optional[Union[str, list[str]]] = None,
        stop_token_ids: Optional[list[int]] = None,
        bad_words: Optional[list[str]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[list[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int,
                                                   msgspec.Meta(
                                                       ge=-1)]] = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: Optional[StructuredOutputsParams] = None,
        guided_decoding: Optional[GuidedDecodingParams] = None,
        logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
        allowed_token_ids: Optional[list[int]] = None,
        extra_args: Optional[dict[str, Any]] = None,
        guidance_scale: Optional[float] = None,
    ) -> "SamplingParams":
        """Build a SamplingParams, mapping explicit ``None`` arguments back
        to the field defaults.

        This is a convenience constructor for callers (e.g. API layers)
        whose request schema allows nullable fields: a ``None`` for the
        core scalar knobs is replaced by the corresponding default before
        the struct is created.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }
        if guided_decoding is not None:
            # Legacy argument: warn, then treat it as structured_outputs.
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2)
            structured_outputs = guided_decoding
            guided_decoding = None

        return SamplingParams(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=0.0
            if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0
            if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            guidance_scale=guidance_scale,
        )
|
| 305 |
+
|
| 306 |
+
    def __post_init__(self) -> None:
        """Normalize fields, validate them, and apply greedy-mode clamping.

        Runs after msgspec struct construction. Order matters:
        normalization first, then `_verify_args()`, then the greedy
        overrides (which would otherwise fail validation checks).
        """
        # how we deal with `best_of``:
        # if `best_of`` is not set, we default to `n`;
        # if `best_of`` is set, we set `n`` to `best_of`,
        # and set `_real_n`` to the original `n`.
        # when we return the result, we will check
        # if we need to return `n` or `_real_n` results
        if self.best_of:
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}.")
            if not self._real_n:
                self._real_n = self.n
                self.n = self.best_of

        if 0 < self.temperature < _MAX_TEMP:
            # Extremely small but nonzero temperatures can overflow the
            # logits when dividing; clamp up to _MAX_TEMP instead.
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature, _MAX_TEMP, _MAX_TEMP)
            self.temperature = max(self.temperature, _MAX_TEMP)

        # -1 is an accepted sentinel for "no seed".
        if self.seed == -1:
            self.seed = None

        # Normalize `stop` to a list of strings.
        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        # Accept boolean True as "one logprob" for convenience.
        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.guided_decoding is not None:
            # Legacy field: warn, then migrate to structured_outputs.
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2)
            self.structured_outputs = self.guided_decoding
            self.guided_decoding = None
|
| 375 |
+
|
| 376 |
+
def _verify_args(self) -> None:
|
| 377 |
+
if not isinstance(self.n, int):
|
| 378 |
+
raise ValueError(f"n must be an int, but is of "
|
| 379 |
+
f"type {type(self.n)}")
|
| 380 |
+
if self.n < 1:
|
| 381 |
+
raise ValueError(f"n must be at least 1, got {self.n}.")
|
| 382 |
+
if self.best_of is not None:
|
| 383 |
+
if not isinstance(self.best_of, int):
|
| 384 |
+
raise ValueError(
|
| 385 |
+
f"best_of must be an integer, got {type(self.best_of)}")
|
| 386 |
+
if self.best_of < 1:
|
| 387 |
+
raise ValueError(
|
| 388 |
+
f"best_of must be at least 1, got {self.best_of}")
|
| 389 |
+
if self.best_of < self.n:
|
| 390 |
+
raise ValueError(
|
| 391 |
+
f"best_of must be greater than or equal to n, "
|
| 392 |
+
f"got n={self.n} and best_of={self.best_of}.")
|
| 393 |
+
if not -2.0 <= self.presence_penalty <= 2.0:
|
| 394 |
+
raise ValueError("presence_penalty must be in [-2, 2], got "
|
| 395 |
+
f"{self.presence_penalty}.")
|
| 396 |
+
if not -2.0 <= self.frequency_penalty <= 2.0:
|
| 397 |
+
raise ValueError("frequency_penalty must be in [-2, 2], got "
|
| 398 |
+
f"{self.frequency_penalty}.")
|
| 399 |
+
if self.repetition_penalty <= 0.0:
|
| 400 |
+
raise ValueError(
|
| 401 |
+
"repetition_penalty must be greater than zero, got "
|
| 402 |
+
f"{self.repetition_penalty}.")
|
| 403 |
+
if self.temperature < 0.0:
|
| 404 |
+
raise ValueError(
|
| 405 |
+
f"temperature must be non-negative, got {self.temperature}.")
|
| 406 |
+
if not 0.0 < self.top_p <= 1.0:
|
| 407 |
+
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
| 408 |
+
# quietly accept -1 as disabled, but prefer 0
|
| 409 |
+
if self.top_k < -1:
|
| 410 |
+
raise ValueError(f"top_k must be 0 (disable), or at least 1, "
|
| 411 |
+
f"got {self.top_k}.")
|
| 412 |
+
if not isinstance(self.top_k, int):
|
| 413 |
+
raise TypeError(
|
| 414 |
+
f"top_k must be an integer, got {type(self.top_k).__name__}")
|
| 415 |
+
if not 0.0 <= self.min_p <= 1.0:
|
| 416 |
+
raise ValueError("min_p must be in [0, 1], got "
|
| 417 |
+
f"{self.min_p}.")
|
| 418 |
+
if self.max_tokens is not None and self.max_tokens < 1:
|
| 419 |
+
raise ValueError(
|
| 420 |
+
f"max_tokens must be at least 1, got {self.max_tokens}.")
|
| 421 |
+
if self.min_tokens < 0:
|
| 422 |
+
raise ValueError(f"min_tokens must be greater than or equal to 0, "
|
| 423 |
+
f"got {self.min_tokens}.")
|
| 424 |
+
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
|
| 425 |
+
raise ValueError(
|
| 426 |
+
f"min_tokens must be less than or equal to "
|
| 427 |
+
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
|
| 428 |
+
if (self.logprobs is not None and self.logprobs != -1
|
| 429 |
+
and self.logprobs < 0):
|
| 430 |
+
raise ValueError(
|
| 431 |
+
f"logprobs must be non-negative or -1, got {self.logprobs}.")
|
| 432 |
+
if (self.prompt_logprobs is not None and self.prompt_logprobs != -1
|
| 433 |
+
and self.prompt_logprobs < 0):
|
| 434 |
+
raise ValueError(
|
| 435 |
+
f"prompt_logprobs must be non-negative or -1, got "
|
| 436 |
+
f"{self.prompt_logprobs}.")
|
| 437 |
+
if (self.truncate_prompt_tokens is not None
|
| 438 |
+
and (self.truncate_prompt_tokens == 0
|
| 439 |
+
or self.truncate_prompt_tokens < -1)):
|
| 440 |
+
raise ValueError(
|
| 441 |
+
f"truncate_prompt_tokens must be an integer >= 1 or -1, "
|
| 442 |
+
f"got {self.truncate_prompt_tokens}")
|
| 443 |
+
assert isinstance(self.stop_token_ids, list)
|
| 444 |
+
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
|
| 445 |
+
raise ValueError(f"stop_token_ids must contain only integers, "
|
| 446 |
+
f"got {self.stop_token_ids}.")
|
| 447 |
+
assert isinstance(self.stop, list)
|
| 448 |
+
if any(not stop_str for stop_str in self.stop):
|
| 449 |
+
raise ValueError("stop cannot contain an empty string.")
|
| 450 |
+
if self.stop and not self.detokenize:
|
| 451 |
+
raise ValueError(
|
| 452 |
+
"stop strings are only supported when detokenize is True. "
|
| 453 |
+
"Set detokenize=True to use stop.")
|
| 454 |
+
if self.best_of != self._real_n and self.output_kind == (
|
| 455 |
+
RequestOutputKind.DELTA):
|
| 456 |
+
raise ValueError("best_of must equal n to use output_kind=DELTA")
|
| 457 |
+
|
| 458 |
+
def _verify_greedy_sampling(self) -> None:
|
| 459 |
+
if self.n > 1:
|
| 460 |
+
raise ValueError("n must be 1 when using greedy sampling, "
|
| 461 |
+
f"got {self.n}.")
|
| 462 |
+
|
| 463 |
+
    def update_from_generation_config(
            self,
            generation_config: dict[str, Any],
            model_eos_token_id: Optional[int] = None) -> None:
        """Update if there are non-default values from generation_config.

        Args:
            generation_config: Model generation config dict; only the
                "eos_token_id" entry (int or list of ints) is consulted.
            model_eos_token_id: The model's primary EOS token id, if known.
        """

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    # Merge the extra EOS ids into the user's stop tokens.
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)
|
| 488 |
+
|
| 489 |
+
    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Tokenize `bad_words` into `_bad_words_token_ids`.

        Each bad word is encoded both with and without a leading space so
        it is blocked at the start of the text as well as mid-text.
        Raises ValueError if any resulting token id falls outside the
        tokenizer's vocabulary.
        """
        if not self.bad_words:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            for add_prefix_space in [False, True]:
                prefix = " " if add_prefix_space else ""
                prompt = prefix + bad_word.lstrip()
                prompt_token_ids = tokenizer.encode(text=prompt,
                                                    add_special_tokens=False)

                # If no space at the beginning
                # or if prefix space produces a new word token
                # NOTE: the add_prefix_space branch compares against the
                # entry just appended by the add_prefix_space=False pass;
                # it is only kept when the leading token differs but the
                # overall length matches.
                if (not add_prefix_space) or (
                        add_prefix_space and prompt_token_ids[0]
                        != self._bad_words_token_ids[-1][0]
                        and len(prompt_token_ids) == len(
                            self._bad_words_token_ids[-1])):
                    self._bad_words_token_ids.append(prompt_token_ids)

        invalid_token_ids = [
            token_id for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {tokenizer.max_token_id+1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}.")
|
| 524 |
+
|
| 525 |
+
@cached_property
|
| 526 |
+
def sampling_type(self) -> SamplingType:
|
| 527 |
+
if self.temperature < _SAMPLING_EPS:
|
| 528 |
+
return SamplingType.GREEDY
|
| 529 |
+
if self.seed is not None:
|
| 530 |
+
return SamplingType.RANDOM_SEED
|
| 531 |
+
return SamplingType.RANDOM
|
| 532 |
+
|
| 533 |
+
@property
|
| 534 |
+
def all_stop_token_ids(self) -> set[int]:
|
| 535 |
+
return self._all_stop_token_ids
|
| 536 |
+
|
| 537 |
+
@property
|
| 538 |
+
def bad_words_token_ids(self) -> Optional[list[list[int]]]:
|
| 539 |
+
# For internal use only. Backward compatibility not guaranteed
|
| 540 |
+
return self._bad_words_token_ids
|
| 541 |
+
|
| 542 |
+
def clone(self) -> "SamplingParams":
|
| 543 |
+
"""Deep copy, but maybe not the LogitsProcessor objects.
|
| 544 |
+
|
| 545 |
+
LogitsProcessor objects may contain an arbitrary, nontrivial amount of
|
| 546 |
+
data that is expensive to copy. However, if not copied, the processor
|
| 547 |
+
needs to support parallel decoding for multiple sequences
|
| 548 |
+
See https://github.com/vllm-project/vllm/issues/3087
|
| 549 |
+
"""
|
| 550 |
+
|
| 551 |
+
logit_processor_refs = None if self.logits_processors is None else {
|
| 552 |
+
id(lp): lp.clone() if hasattr(lp, 'clone') else lp
|
| 553 |
+
for lp in self.logits_processors
|
| 554 |
+
}
|
| 555 |
+
return copy.deepcopy(self, memo=logit_processor_refs)
|
| 556 |
+
|
| 557 |
+
def __repr__(self) -> str:
|
| 558 |
+
return (
|
| 559 |
+
f"SamplingParams(n={self.n}, "
|
| 560 |
+
f"presence_penalty={self.presence_penalty}, "
|
| 561 |
+
f"frequency_penalty={self.frequency_penalty}, "
|
| 562 |
+
f"repetition_penalty={self.repetition_penalty}, "
|
| 563 |
+
f"temperature={self.temperature}, "
|
| 564 |
+
f"top_p={self.top_p}, "
|
| 565 |
+
f"top_k={self.top_k}, "
|
| 566 |
+
f"min_p={self.min_p}, "
|
| 567 |
+
f"seed={self.seed}, "
|
| 568 |
+
f"stop={self.stop}, "
|
| 569 |
+
f"stop_token_ids={self.stop_token_ids}, "
|
| 570 |
+
f"bad_words={self.bad_words}, "
|
| 571 |
+
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
|
| 572 |
+
f"ignore_eos={self.ignore_eos}, "
|
| 573 |
+
f"max_tokens={self.max_tokens}, "
|
| 574 |
+
f"min_tokens={self.min_tokens}, "
|
| 575 |
+
f"logprobs={self.logprobs}, "
|
| 576 |
+
f"prompt_logprobs={self.prompt_logprobs}, "
|
| 577 |
+
f"skip_special_tokens={self.skip_special_tokens}, "
|
| 578 |
+
"spaces_between_special_tokens="
|
| 579 |
+
f"{self.spaces_between_special_tokens}, "
|
| 580 |
+
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
|
| 581 |
+
f"structured_outputs={self.structured_outputs}, "
|
| 582 |
+
f"extra_args={self.extra_args})")
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
class BeamSearchParams(
|
| 586 |
+
msgspec.Struct,
|
| 587 |
+
omit_defaults=True, # type: ignore[call-arg]
|
| 588 |
+
# required for @cached_property.
|
| 589 |
+
dict=True): # type: ignore[call-arg]
|
| 590 |
+
"""Beam search parameters for text generation."""
|
| 591 |
+
beam_width: int
|
| 592 |
+
max_tokens: int
|
| 593 |
+
ignore_eos: bool = False
|
| 594 |
+
temperature: float = 0.0
|
| 595 |
+
length_penalty: float = 1.0
|
| 596 |
+
include_stop_str_in_output: bool = False
|
vllm_hacked/sampling_params_ori.py
ADDED
|
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
"""Sampling parameters for text generation."""
|
| 4 |
+
import copy
|
| 5 |
+
import warnings
|
| 6 |
+
from dataclasses import field
|
| 7 |
+
from enum import Enum, IntEnum
|
| 8 |
+
from functools import cached_property
|
| 9 |
+
from typing import Annotated, Any, Optional, Union
|
| 10 |
+
|
| 11 |
+
import msgspec
|
| 12 |
+
from pydantic.dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
from vllm.logger import init_logger
|
| 15 |
+
from vllm.logits_process import LogitsProcessor
|
| 16 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 17 |
+
|
| 18 |
+
logger = init_logger(__name__)
|
| 19 |
+
|
| 20 |
+
_SAMPLING_EPS = 1e-5
|
| 21 |
+
_MAX_TEMP = 1e-2
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SamplingType(IntEnum):
|
| 25 |
+
GREEDY = 0
|
| 26 |
+
RANDOM = 1
|
| 27 |
+
RANDOM_SEED = 2
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# maybe make msgspec?
|
| 31 |
+
@dataclass
|
| 32 |
+
class StructuredOutputsParams:
|
| 33 |
+
# One of these fields will be used to build a logit processor.
|
| 34 |
+
json: Optional[Union[str, dict]] = None
|
| 35 |
+
regex: Optional[str] = None
|
| 36 |
+
choice: Optional[list[str]] = None
|
| 37 |
+
grammar: Optional[str] = None
|
| 38 |
+
json_object: Optional[bool] = None
|
| 39 |
+
# These are other options that can be set.
|
| 40 |
+
disable_fallback: bool = False
|
| 41 |
+
disable_any_whitespace: bool = False
|
| 42 |
+
disable_additional_properties: bool = False
|
| 43 |
+
whitespace_pattern: Optional[str] = None
|
| 44 |
+
structural_tag: Optional[str] = None
|
| 45 |
+
|
| 46 |
+
_backend: Optional[str] = field(default=None, init=False)
|
| 47 |
+
"""CAUTION: Should only be set by Processor._validate_structured_output"""
|
| 48 |
+
_backend_was_auto: bool = field(default=False, init=False)
|
| 49 |
+
"""CAUTION: Should only be set by Processor._validate_structured_output"""
|
| 50 |
+
|
| 51 |
+
def __post_init__(self):
|
| 52 |
+
"""Validate that some fields are mutually exclusive."""
|
| 53 |
+
count = sum([
|
| 54 |
+
self.json is not None, self.regex is not None, self.choice
|
| 55 |
+
is not None, self.grammar is not None, self.json_object is not None
|
| 56 |
+
])
|
| 57 |
+
if count > 1:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
"You can only use one kind of structured outputs constraint "
|
| 60 |
+
f"but multiple are specified: {self.__dict__}")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
|
| 64 |
+
class GuidedDecodingParams(StructuredOutputsParams):
|
| 65 |
+
|
| 66 |
+
def __post_init__(self):
|
| 67 |
+
warnings.warn(
|
| 68 |
+
"GuidedDecodingParams is deprecated. This will be removed in "
|
| 69 |
+
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
| 70 |
+
"StructuredOutputsParams instead.",
|
| 71 |
+
DeprecationWarning,
|
| 72 |
+
stacklevel=2)
|
| 73 |
+
return super().__post_init__()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class RequestOutputKind(Enum):
|
| 77 |
+
# Return entire output so far in every RequestOutput
|
| 78 |
+
CUMULATIVE = 0
|
| 79 |
+
# Return only deltas in each RequestOutput
|
| 80 |
+
DELTA = 1
|
| 81 |
+
# Do not return intermediate RequestOutput
|
| 82 |
+
FINAL_ONLY = 2
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class SamplingParams(
|
| 86 |
+
msgspec.Struct,
|
| 87 |
+
omit_defaults=True, # type: ignore[call-arg]
|
| 88 |
+
# required for @cached_property.
|
| 89 |
+
dict=True): # type: ignore[call-arg]
|
| 90 |
+
"""Sampling parameters for text generation.
|
| 91 |
+
|
| 92 |
+
Overall, we follow the sampling parameters from the OpenAI text completion
|
| 93 |
+
API (https://platform.openai.com/docs/api-reference/completions/create).
|
| 94 |
+
In addition, we support beam search, which is not supported by OpenAI.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
n: int = 1
|
| 98 |
+
"""Number of outputs to return for the given prompt request.
|
| 99 |
+
|
| 100 |
+
NOTE:
|
| 101 |
+
`AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
|
| 102 |
+
are generated and streamed cumulatively per request. To see all `n`
|
| 103 |
+
outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
|
| 104 |
+
in `SamplingParams`."""
|
| 105 |
+
best_of: Optional[int] = None
|
| 106 |
+
"""Number of output sequences that are generated from the prompt. From
|
| 107 |
+
these `best_of` sequences, the top `n` sequences are returned. `best_of`
|
| 108 |
+
must be greater than or equal to `n`. By default, `best_of` is set to `n`.
|
| 109 |
+
Warning, this is only supported in V0."""
|
| 110 |
+
_real_n: Optional[int] = None
|
| 111 |
+
presence_penalty: float = 0.0
|
| 112 |
+
"""Penalizes new tokens based on whether they appear in the generated text
|
| 113 |
+
so far. Values > 0 encourage the model to use new tokens, while values < 0
|
| 114 |
+
encourage the model to repeat tokens."""
|
| 115 |
+
frequency_penalty: float = 0.0
|
| 116 |
+
"""Penalizes new tokens based on their frequency in the generated text so
|
| 117 |
+
far. Values > 0 encourage the model to use new tokens, while values < 0
|
| 118 |
+
encourage the model to repeat tokens."""
|
| 119 |
+
repetition_penalty: float = 1.0
|
| 120 |
+
"""Penalizes new tokens based on whether they appear in the prompt and the
|
| 121 |
+
generated text so far. Values > 1 encourage the model to use new tokens,
|
| 122 |
+
while values < 1 encourage the model to repeat tokens."""
|
| 123 |
+
temperature: float = 1.0
|
| 124 |
+
"""Controls the randomness of the sampling. Lower values make the model
|
| 125 |
+
more deterministic, while higher values make the model more random. Zero
|
| 126 |
+
means greedy sampling."""
|
| 127 |
+
top_p: float = 1.0
|
| 128 |
+
"""Controls the cumulative probability of the top tokens to consider. Must
|
| 129 |
+
be in (0, 1]. Set to 1 to consider all tokens."""
|
| 130 |
+
top_k: int = 0
|
| 131 |
+
"""Controls the number of top tokens to consider. Set to 0 (or -1) to
|
| 132 |
+
consider all tokens."""
|
| 133 |
+
min_p: float = 0.0
|
| 134 |
+
"""Represents the minimum probability for a token to be considered,
|
| 135 |
+
relative to the probability of the most likely token. Must be in [0, 1].
|
| 136 |
+
Set to 0 to disable this."""
|
| 137 |
+
seed: Optional[int] = None
|
| 138 |
+
"""Random seed to use for the generation."""
|
| 139 |
+
stop: Optional[Union[str, list[str]]] = None
|
| 140 |
+
"""String(s) that stop the generation when they are generated. The returned
|
| 141 |
+
output will not contain the stop strings."""
|
| 142 |
+
stop_token_ids: Optional[list[int]] = None
|
| 143 |
+
"""Token IDs that stop the generation when they are generated. The returned
|
| 144 |
+
output will contain the stop tokens unless the stop tokens are special
|
| 145 |
+
tokens."""
|
| 146 |
+
ignore_eos: bool = False
|
| 147 |
+
"""Whether to ignore the EOS token and continue generating
|
| 148 |
+
tokens after the EOS token is generated."""
|
| 149 |
+
max_tokens: Optional[int] = 16
|
| 150 |
+
"""Maximum number of tokens to generate per output sequence."""
|
| 151 |
+
min_tokens: int = 0
|
| 152 |
+
"""Minimum number of tokens to generate per output sequence before EOS or
|
| 153 |
+
`stop_token_ids` can be generated"""
|
| 154 |
+
logprobs: Optional[int] = None
|
| 155 |
+
"""Number of log probabilities to return per output token. When set to
|
| 156 |
+
`None`, no probability is returned. If set to a non-`None` value, the
|
| 157 |
+
result includes the log probabilities of the specified number of most
|
| 158 |
+
likely tokens, as well as the chosen tokens. Note that the implementation
|
| 159 |
+
follows the OpenAI API: The API will always return the log probability of
|
| 160 |
+
the sampled token, so there may be up to `logprobs+1` elements in the
|
| 161 |
+
response. When set to -1, return all `vocab_size` log probabilities."""
|
| 162 |
+
prompt_logprobs: Optional[int] = None
|
| 163 |
+
"""Number of log probabilities to return per prompt token.
|
| 164 |
+
When set to -1, return all `vocab_size` log probabilities."""
|
| 165 |
+
# NOTE: This parameter is only exposed at the engine level for now.
|
| 166 |
+
# It is not exposed in the OpenAI API server, as the OpenAI API does
|
| 167 |
+
# not support returning only a list of token IDs.
|
| 168 |
+
detokenize: bool = True
|
| 169 |
+
"""Whether to detokenize the output."""
|
| 170 |
+
skip_special_tokens: bool = True
|
| 171 |
+
"""Whether to skip special tokens in the output."""
|
| 172 |
+
spaces_between_special_tokens: bool = True
|
| 173 |
+
"""Whether to add spaces between special tokens in the output."""
|
| 174 |
+
# Optional[list[LogitsProcessor]] type. We use Any here because
|
| 175 |
+
# Optional[list[LogitsProcessor]] type is not supported by msgspec.
|
| 176 |
+
logits_processors: Optional[Any] = None
|
| 177 |
+
"""Functions that modify logits based on previously generated tokens, and
|
| 178 |
+
optionally prompt tokens as a first argument."""
|
| 179 |
+
include_stop_str_in_output: bool = False
|
| 180 |
+
"""Whether to include the stop strings in output text."""
|
| 181 |
+
truncate_prompt_tokens: Optional[Annotated[int,
|
| 182 |
+
msgspec.Meta(ge=-1)]] = None
|
| 183 |
+
"""If set to -1, will use the truncation size supported by the model. If
|
| 184 |
+
set to an integer k, will use only the last k tokens from the prompt
|
| 185 |
+
(i.e., left truncation). If set to `None`, truncation is disabled."""
|
| 186 |
+
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
|
| 187 |
+
|
| 188 |
+
# The below fields are not supposed to be used as an input.
|
| 189 |
+
# They are set in post_init.
|
| 190 |
+
output_text_buffer_length: int = 0
|
| 191 |
+
_all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
|
| 192 |
+
|
| 193 |
+
# Fields used to construct logits processors
|
| 194 |
+
structured_outputs: Optional[StructuredOutputsParams] = None
|
| 195 |
+
"""Parameters for configuring structured outputs."""
|
| 196 |
+
guided_decoding: Optional[GuidedDecodingParams] = None
|
| 197 |
+
"""Deprecated alias for structured_outputs."""
|
| 198 |
+
logit_bias: Optional[dict[int, float]] = None
|
| 199 |
+
"""If provided, the engine will construct a logits processor that applies
|
| 200 |
+
these logit biases."""
|
| 201 |
+
allowed_token_ids: Optional[list[int]] = None
|
| 202 |
+
"""If provided, the engine will construct a logits processor which only
|
| 203 |
+
retains scores for the given token ids."""
|
| 204 |
+
extra_args: Optional[dict[str, Any]] = None
|
| 205 |
+
"""Arbitrary additional args, that can be used by custom sampling
|
| 206 |
+
implementations, plugins, etc. Not used by any in-tree sampling
|
| 207 |
+
implementations."""
|
| 208 |
+
|
| 209 |
+
# Fields used for bad words
|
| 210 |
+
bad_words: Optional[list[str]] = None
|
| 211 |
+
"""Words that are not allowed to be generated. More precisely, only the
|
| 212 |
+
last token of a corresponding token sequence is not allowed when the next
|
| 213 |
+
generated token can complete the sequence."""
|
| 214 |
+
_bad_words_token_ids: Optional[list[list[int]]] = None
|
| 215 |
+
|
| 216 |
+
@staticmethod
|
| 217 |
+
def from_optional(
|
| 218 |
+
n: Optional[int] = 1,
|
| 219 |
+
best_of: Optional[int] = None,
|
| 220 |
+
presence_penalty: Optional[float] = 0.0,
|
| 221 |
+
frequency_penalty: Optional[float] = 0.0,
|
| 222 |
+
repetition_penalty: Optional[float] = 1.0,
|
| 223 |
+
temperature: Optional[float] = 1.0,
|
| 224 |
+
top_p: Optional[float] = 1.0,
|
| 225 |
+
top_k: int = 0,
|
| 226 |
+
min_p: float = 0.0,
|
| 227 |
+
seed: Optional[int] = None,
|
| 228 |
+
stop: Optional[Union[str, list[str]]] = None,
|
| 229 |
+
stop_token_ids: Optional[list[int]] = None,
|
| 230 |
+
bad_words: Optional[list[str]] = None,
|
| 231 |
+
include_stop_str_in_output: bool = False,
|
| 232 |
+
ignore_eos: bool = False,
|
| 233 |
+
max_tokens: Optional[int] = 16,
|
| 234 |
+
min_tokens: int = 0,
|
| 235 |
+
logprobs: Optional[int] = None,
|
| 236 |
+
prompt_logprobs: Optional[int] = None,
|
| 237 |
+
detokenize: bool = True,
|
| 238 |
+
skip_special_tokens: bool = True,
|
| 239 |
+
spaces_between_special_tokens: bool = True,
|
| 240 |
+
logits_processors: Optional[list[LogitsProcessor]] = None,
|
| 241 |
+
truncate_prompt_tokens: Optional[Annotated[int,
|
| 242 |
+
msgspec.Meta(
|
| 243 |
+
ge=-1)]] = None,
|
| 244 |
+
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
| 245 |
+
structured_outputs: Optional[StructuredOutputsParams] = None,
|
| 246 |
+
guided_decoding: Optional[GuidedDecodingParams] = None,
|
| 247 |
+
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
|
| 248 |
+
allowed_token_ids: Optional[list[int]] = None,
|
| 249 |
+
extra_args: Optional[dict[str, Any]] = None,
|
| 250 |
+
) -> "SamplingParams":
|
| 251 |
+
if logit_bias is not None:
|
| 252 |
+
# Convert token_id to integer
|
| 253 |
+
# Clamp the bias between -100 and 100 per OpenAI API spec
|
| 254 |
+
logit_bias = {
|
| 255 |
+
int(token): min(100.0, max(-100.0, bias))
|
| 256 |
+
for token, bias in logit_bias.items()
|
| 257 |
+
}
|
| 258 |
+
if guided_decoding is not None:
|
| 259 |
+
warnings.warn(
|
| 260 |
+
"guided_decoding is deprecated. This will be removed in "
|
| 261 |
+
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
| 262 |
+
"structured_outputs instead.",
|
| 263 |
+
DeprecationWarning,
|
| 264 |
+
stacklevel=2)
|
| 265 |
+
structured_outputs = guided_decoding
|
| 266 |
+
guided_decoding = None
|
| 267 |
+
|
| 268 |
+
return SamplingParams(
|
| 269 |
+
n=1 if n is None else n,
|
| 270 |
+
best_of=best_of,
|
| 271 |
+
presence_penalty=0.0
|
| 272 |
+
if presence_penalty is None else presence_penalty,
|
| 273 |
+
frequency_penalty=0.0
|
| 274 |
+
if frequency_penalty is None else frequency_penalty,
|
| 275 |
+
repetition_penalty=1.0
|
| 276 |
+
if repetition_penalty is None else repetition_penalty,
|
| 277 |
+
temperature=1.0 if temperature is None else temperature,
|
| 278 |
+
top_p=1.0 if top_p is None else top_p,
|
| 279 |
+
top_k=top_k,
|
| 280 |
+
min_p=min_p,
|
| 281 |
+
seed=seed,
|
| 282 |
+
stop=stop,
|
| 283 |
+
stop_token_ids=stop_token_ids,
|
| 284 |
+
bad_words=bad_words,
|
| 285 |
+
include_stop_str_in_output=include_stop_str_in_output,
|
| 286 |
+
ignore_eos=ignore_eos,
|
| 287 |
+
max_tokens=max_tokens,
|
| 288 |
+
min_tokens=min_tokens,
|
| 289 |
+
logprobs=logprobs,
|
| 290 |
+
prompt_logprobs=prompt_logprobs,
|
| 291 |
+
detokenize=detokenize,
|
| 292 |
+
skip_special_tokens=skip_special_tokens,
|
| 293 |
+
spaces_between_special_tokens=spaces_between_special_tokens,
|
| 294 |
+
logits_processors=logits_processors,
|
| 295 |
+
truncate_prompt_tokens=truncate_prompt_tokens,
|
| 296 |
+
output_kind=output_kind,
|
| 297 |
+
structured_outputs=structured_outputs,
|
| 298 |
+
logit_bias=logit_bias,
|
| 299 |
+
allowed_token_ids=allowed_token_ids,
|
| 300 |
+
extra_args=extra_args,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
def __post_init__(self) -> None:
|
| 304 |
+
# how we deal with `best_of``:
|
| 305 |
+
# if `best_of`` is not set, we default to `n`;
|
| 306 |
+
# if `best_of`` is set, we set `n`` to `best_of`,
|
| 307 |
+
# and set `_real_n`` to the original `n`.
|
| 308 |
+
# when we return the result, we will check
|
| 309 |
+
# if we need to return `n` or `_real_n` results
|
| 310 |
+
if self.best_of:
|
| 311 |
+
if self.best_of < self.n:
|
| 312 |
+
raise ValueError(
|
| 313 |
+
f"best_of must be greater than or equal to n, "
|
| 314 |
+
f"got n={self.n} and best_of={self.best_of}.")
|
| 315 |
+
if not self._real_n:
|
| 316 |
+
self._real_n = self.n
|
| 317 |
+
self.n = self.best_of
|
| 318 |
+
|
| 319 |
+
if 0 < self.temperature < _MAX_TEMP:
|
| 320 |
+
logger.warning(
|
| 321 |
+
"temperature %s is less than %s, which may cause numerical "
|
| 322 |
+
"errors nan or inf in tensors. We have maxed it out to %s.",
|
| 323 |
+
self.temperature, _MAX_TEMP, _MAX_TEMP)
|
| 324 |
+
self.temperature = max(self.temperature, _MAX_TEMP)
|
| 325 |
+
|
| 326 |
+
if self.seed == -1:
|
| 327 |
+
self.seed = None
|
| 328 |
+
|
| 329 |
+
if self.stop is None:
|
| 330 |
+
self.stop = []
|
| 331 |
+
elif isinstance(self.stop, str):
|
| 332 |
+
self.stop = [self.stop]
|
| 333 |
+
|
| 334 |
+
if self.stop_token_ids is None:
|
| 335 |
+
self.stop_token_ids = []
|
| 336 |
+
|
| 337 |
+
if self.bad_words is None:
|
| 338 |
+
self.bad_words = []
|
| 339 |
+
|
| 340 |
+
if self.logprobs is True:
|
| 341 |
+
self.logprobs = 1
|
| 342 |
+
|
| 343 |
+
if self.prompt_logprobs is True:
|
| 344 |
+
self.prompt_logprobs = 1
|
| 345 |
+
|
| 346 |
+
# Number of characters to hold back for stop string evaluation
|
| 347 |
+
# until sequence is finished.
|
| 348 |
+
if self.stop and not self.include_stop_str_in_output:
|
| 349 |
+
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
|
| 350 |
+
|
| 351 |
+
self._verify_args()
|
| 352 |
+
|
| 353 |
+
if self.temperature < _SAMPLING_EPS:
|
| 354 |
+
# Zero temperature means greedy sampling.
|
| 355 |
+
self.top_p = 1.0
|
| 356 |
+
self.top_k = 0
|
| 357 |
+
self.min_p = 0.0
|
| 358 |
+
self._verify_greedy_sampling()
|
| 359 |
+
|
| 360 |
+
# eos_token_id is added to this by the engine
|
| 361 |
+
self._all_stop_token_ids.update(self.stop_token_ids)
|
| 362 |
+
|
| 363 |
+
if self.guided_decoding is not None:
|
| 364 |
+
warnings.warn(
|
| 365 |
+
"guided_decoding is deprecated. This will be removed in "
|
| 366 |
+
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
| 367 |
+
"structured_outputs instead.",
|
| 368 |
+
DeprecationWarning,
|
| 369 |
+
stacklevel=2)
|
| 370 |
+
self.structured_outputs = self.guided_decoding
|
| 371 |
+
self.guided_decoding = None
|
| 372 |
+
|
| 373 |
+
def _verify_args(self) -> None:
|
| 374 |
+
if not isinstance(self.n, int):
|
| 375 |
+
raise ValueError(f"n must be an int, but is of "
|
| 376 |
+
f"type {type(self.n)}")
|
| 377 |
+
if self.n < 1:
|
| 378 |
+
raise ValueError(f"n must be at least 1, got {self.n}.")
|
| 379 |
+
if self.best_of is not None:
|
| 380 |
+
if not isinstance(self.best_of, int):
|
| 381 |
+
raise ValueError(
|
| 382 |
+
f"best_of must be an integer, got {type(self.best_of)}")
|
| 383 |
+
if self.best_of < 1:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
f"best_of must be at least 1, got {self.best_of}")
|
| 386 |
+
if self.best_of < self.n:
|
| 387 |
+
raise ValueError(
|
| 388 |
+
f"best_of must be greater than or equal to n, "
|
| 389 |
+
f"got n={self.n} and best_of={self.best_of}.")
|
| 390 |
+
if not -2.0 <= self.presence_penalty <= 2.0:
|
| 391 |
+
raise ValueError("presence_penalty must be in [-2, 2], got "
|
| 392 |
+
f"{self.presence_penalty}.")
|
| 393 |
+
if not -2.0 <= self.frequency_penalty <= 2.0:
|
| 394 |
+
raise ValueError("frequency_penalty must be in [-2, 2], got "
|
| 395 |
+
f"{self.frequency_penalty}.")
|
| 396 |
+
if self.repetition_penalty <= 0.0:
|
| 397 |
+
raise ValueError(
|
| 398 |
+
"repetition_penalty must be greater than zero, got "
|
| 399 |
+
f"{self.repetition_penalty}.")
|
| 400 |
+
if self.temperature < 0.0:
|
| 401 |
+
raise ValueError(
|
| 402 |
+
f"temperature must be non-negative, got {self.temperature}.")
|
| 403 |
+
if not 0.0 < self.top_p <= 1.0:
|
| 404 |
+
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
| 405 |
+
# quietly accept -1 as disabled, but prefer 0
|
| 406 |
+
if self.top_k < -1:
|
| 407 |
+
raise ValueError(f"top_k must be 0 (disable), or at least 1, "
|
| 408 |
+
f"got {self.top_k}.")
|
| 409 |
+
if not isinstance(self.top_k, int):
|
| 410 |
+
raise TypeError(
|
| 411 |
+
f"top_k must be an integer, got {type(self.top_k).__name__}")
|
| 412 |
+
if not 0.0 <= self.min_p <= 1.0:
|
| 413 |
+
raise ValueError("min_p must be in [0, 1], got "
|
| 414 |
+
f"{self.min_p}.")
|
| 415 |
+
if self.max_tokens is not None and self.max_tokens < 1:
|
| 416 |
+
raise ValueError(
|
| 417 |
+
f"max_tokens must be at least 1, got {self.max_tokens}.")
|
| 418 |
+
if self.min_tokens < 0:
|
| 419 |
+
raise ValueError(f"min_tokens must be greater than or equal to 0, "
|
| 420 |
+
f"got {self.min_tokens}.")
|
| 421 |
+
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
|
| 422 |
+
raise ValueError(
|
| 423 |
+
f"min_tokens must be less than or equal to "
|
| 424 |
+
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
|
| 425 |
+
if (self.logprobs is not None and self.logprobs != -1
|
| 426 |
+
and self.logprobs < 0):
|
| 427 |
+
raise ValueError(
|
| 428 |
+
f"logprobs must be non-negative or -1, got {self.logprobs}.")
|
| 429 |
+
if (self.prompt_logprobs is not None and self.prompt_logprobs != -1
|
| 430 |
+
and self.prompt_logprobs < 0):
|
| 431 |
+
raise ValueError(
|
| 432 |
+
f"prompt_logprobs must be non-negative or -1, got "
|
| 433 |
+
f"{self.prompt_logprobs}.")
|
| 434 |
+
if (self.truncate_prompt_tokens is not None
|
| 435 |
+
and (self.truncate_prompt_tokens == 0
|
| 436 |
+
or self.truncate_prompt_tokens < -1)):
|
| 437 |
+
raise ValueError(
|
| 438 |
+
f"truncate_prompt_tokens must be an integer >= 1 or -1, "
|
| 439 |
+
f"got {self.truncate_prompt_tokens}")
|
| 440 |
+
assert isinstance(self.stop_token_ids, list)
|
| 441 |
+
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
|
| 442 |
+
raise ValueError(f"stop_token_ids must contain only integers, "
|
| 443 |
+
f"got {self.stop_token_ids}.")
|
| 444 |
+
assert isinstance(self.stop, list)
|
| 445 |
+
if any(not stop_str for stop_str in self.stop):
|
| 446 |
+
raise ValueError("stop cannot contain an empty string.")
|
| 447 |
+
if self.stop and not self.detokenize:
|
| 448 |
+
raise ValueError(
|
| 449 |
+
"stop strings are only supported when detokenize is True. "
|
| 450 |
+
"Set detokenize=True to use stop.")
|
| 451 |
+
if self.best_of != self._real_n and self.output_kind == (
|
| 452 |
+
RequestOutputKind.DELTA):
|
| 453 |
+
raise ValueError("best_of must equal n to use output_kind=DELTA")
|
| 454 |
+
|
| 455 |
+
def _verify_greedy_sampling(self) -> None:
|
| 456 |
+
if self.n > 1:
|
| 457 |
+
raise ValueError("n must be 1 when using greedy sampling, "
|
| 458 |
+
f"got {self.n}.")
|
| 459 |
+
|
| 460 |
+
def update_from_generation_config(
|
| 461 |
+
self,
|
| 462 |
+
generation_config: dict[str, Any],
|
| 463 |
+
model_eos_token_id: Optional[int] = None) -> None:
|
| 464 |
+
"""Update if there are non-default values from generation_config"""
|
| 465 |
+
|
| 466 |
+
if model_eos_token_id is not None:
|
| 467 |
+
# Add the eos token id into the sampling_params to support
|
| 468 |
+
# min_tokens processing.
|
| 469 |
+
self._all_stop_token_ids.add(model_eos_token_id)
|
| 470 |
+
|
| 471 |
+
# Update eos_token_id for generation
|
| 472 |
+
if (eos_ids := generation_config.get("eos_token_id")) is not None:
|
| 473 |
+
# it can be either int or list of int
|
| 474 |
+
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
| 475 |
+
if model_eos_token_id is not None:
|
| 476 |
+
# We don't need to include the primary eos_token_id in
|
| 477 |
+
# stop_token_ids since it's handled separately for stopping
|
| 478 |
+
# purposes.
|
| 479 |
+
eos_ids.discard(model_eos_token_id)
|
| 480 |
+
if eos_ids:
|
| 481 |
+
self._all_stop_token_ids.update(eos_ids)
|
| 482 |
+
if not self.ignore_eos:
|
| 483 |
+
eos_ids.update(self.stop_token_ids)
|
| 484 |
+
self.stop_token_ids = list(eos_ids)
|
| 485 |
+
|
| 486 |
+
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
|
| 487 |
+
if not self.bad_words:
|
| 488 |
+
return
|
| 489 |
+
self._bad_words_token_ids = []
|
| 490 |
+
for bad_word in self.bad_words:
|
| 491 |
+
# To prohibit words both at the beginning
|
| 492 |
+
# and in the middle of text
|
| 493 |
+
# (related to add_prefix_space tokenizer parameter)
|
| 494 |
+
for add_prefix_space in [False, True]:
|
| 495 |
+
prefix = " " if add_prefix_space else ""
|
| 496 |
+
prompt = prefix + bad_word.lstrip()
|
| 497 |
+
prompt_token_ids = tokenizer.encode(text=prompt,
|
| 498 |
+
add_special_tokens=False)
|
| 499 |
+
|
| 500 |
+
# If no space at the beginning
|
| 501 |
+
# or if prefix space produces a new word token
|
| 502 |
+
if (not add_prefix_space) or (
|
| 503 |
+
add_prefix_space and prompt_token_ids[0]
|
| 504 |
+
!= self._bad_words_token_ids[-1][0]
|
| 505 |
+
and len(prompt_token_ids) == len(
|
| 506 |
+
self._bad_words_token_ids[-1])):
|
| 507 |
+
self._bad_words_token_ids.append(prompt_token_ids)
|
| 508 |
+
|
| 509 |
+
invalid_token_ids = [
|
| 510 |
+
token_id for bad_words_token_ids in self._bad_words_token_ids
|
| 511 |
+
for token_id in bad_words_token_ids
|
| 512 |
+
if token_id < 0 or token_id > tokenizer.max_token_id
|
| 513 |
+
]
|
| 514 |
+
if len(invalid_token_ids) > 0:
|
| 515 |
+
raise ValueError(
|
| 516 |
+
f"The model vocabulary size is {tokenizer.max_token_id+1},"
|
| 517 |
+
f" but the following tokens"
|
| 518 |
+
f" were specified as bad: {invalid_token_ids}."
|
| 519 |
+
f" All token id values should be integers satisfying:"
|
| 520 |
+
f" 0 <= token_id <= {tokenizer.max_token_id}.")
|
| 521 |
+
|
| 522 |
+
@cached_property
def sampling_type(self) -> SamplingType:
    """Classify this request: greedy, seeded-random, or random sampling."""
    # Near-zero temperature is treated as deterministic decoding.
    if self.temperature < _SAMPLING_EPS:
        return SamplingType.GREEDY
    return (SamplingType.RANDOM_SEED
            if self.seed is not None else SamplingType.RANDOM)
|
| 529 |
+
|
| 530 |
+
@property
def all_stop_token_ids(self) -> set[int]:
    """All token ids that can terminate generation (stop tokens plus any
    EOS ids merged in earlier -- see the generation-config update path)."""
    return self._all_stop_token_ids
|
| 533 |
+
|
| 534 |
+
@property
def bad_words_token_ids(self) -> Optional[list[list[int]]]:
    """Token-id sequences built from ``bad_words`` by
    ``update_from_tokenizer``.

    For internal use only; backward compatibility is not guaranteed.
    """
    return self._bad_words_token_ids
|
| 538 |
+
|
| 539 |
+
def clone(self) -> "SamplingParams":
    """Deep copy, but maybe not the LogitsProcessor objects.

    LogitsProcessor objects may contain an arbitrary, nontrivial amount of
    data that is expensive to copy. However, if not copied, the processor
    needs to support parallel decoding for multiple sequences
    See https://github.com/vllm-project/vllm/issues/3087
    """
    if self.logits_processors is None:
        memo = None
    else:
        # Pre-seed deepcopy's memo so each processor is either duplicated
        # through its own clone() hook or shared by reference -- never
        # blindly deep-copied.
        memo = {
            id(proc): proc.clone() if hasattr(proc, 'clone') else proc
            for proc in self.logits_processors
        }
    return copy.deepcopy(self, memo=memo)
|
| 553 |
+
|
| 554 |
+
def __repr__(self) -> str:
    """Debug representation listing every user-visible sampling knob."""
    parts = [
        f"n={self.n}",
        f"presence_penalty={self.presence_penalty}",
        f"frequency_penalty={self.frequency_penalty}",
        f"repetition_penalty={self.repetition_penalty}",
        f"temperature={self.temperature}",
        f"top_p={self.top_p}",
        f"top_k={self.top_k}",
        f"min_p={self.min_p}",
        f"seed={self.seed}",
        f"stop={self.stop}",
        f"stop_token_ids={self.stop_token_ids}",
        f"bad_words={self.bad_words}",
        f"include_stop_str_in_output={self.include_stop_str_in_output}",
        f"ignore_eos={self.ignore_eos}",
        f"max_tokens={self.max_tokens}",
        f"min_tokens={self.min_tokens}",
        f"logprobs={self.logprobs}",
        f"prompt_logprobs={self.prompt_logprobs}",
        f"skip_special_tokens={self.skip_special_tokens}",
        f"spaces_between_special_tokens={self.spaces_between_special_tokens}",
        f"truncate_prompt_tokens={self.truncate_prompt_tokens}",
        f"structured_outputs={self.structured_outputs}",
        f"extra_args={self.extra_args}",
    ]
    return "SamplingParams(" + ", ".join(parts) + ")"
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
class BeamSearchParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Beam search parameters for text generation."""
    # Number of candidate beams maintained during search.
    beam_width: int
    # Hard cap on generated tokens.
    max_tokens: int
    # NOTE(review): presumably EOS does not terminate a beam when True
    # -- confirm against the beam-search consumer.
    ignore_eos: bool = False
    temperature: float = 0.0
    # Length-normalization exponent applied when ranking finished beams;
    # exact semantics depend on the consumer.
    length_penalty: float = 1.0
    # Whether matched stop strings are kept in the returned text.
    include_stop_str_in_output: bool = False
|
ckpt/.gitkeep → vllm_hacked/v1/sample/__init__ori.py
RENAMED
|
File without changes
|
vllm_hacked/v1/sample/metadata.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from vllm.v1.sample.logits_processor import LogitsProcessors
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class SamplingMetadata:
    """Per-batch tensors and flags consumed by the v1 ``Sampler``."""

    # Per-request sampling temperature; None when every request is greedy.
    temperature: Optional[torch.Tensor]
    all_greedy: bool
    all_random: bool

    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]

    # req_index -> torch.Generator for seeded random sampling.
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    # True when no request uses presence/frequency/repetition penalties,
    # letting the sampler skip the penalty pass entirely.
    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    # Per-request output token ids (consumed by the penalty ops).
    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors

    # Fork addition (vllm_hacked): classifier-free-guidance scale read by
    # Sampler.forward to blend conditional/unconditional logits.
    # NOTE(review): neutral CFG is usually 1.0 but the default here is
    # 1.8 -- confirm this is intentional.
    guidance_scale: Optional[float] = 1.8
|
vllm_hacked/v1/sample/metadata_ori.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from vllm.v1.sample.logits_processor import LogitsProcessors
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class SamplingMetadata:
    """Per-batch tensors and flags consumed by the v1 ``Sampler``.

    Pristine upstream definition (the ``_ori`` copy); the patched
    ``metadata.py`` extends it with an extra ``guidance_scale`` field.
    """

    # Per-request sampling temperature; None when every request is greedy.
    temperature: Optional[torch.Tensor]
    all_greedy: bool
    all_random: bool

    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]

    # req_index -> torch.Generator for seeded random sampling.
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    # True when no request uses presence/frequency/repetition penalties.
    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    # Per-request output token ids (consumed by the penalty ops).
    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors
|
vllm_hacked/v1/sample/ops/penalties_ori.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from vllm.model_executor.layers.utils import apply_penalties
|
| 7 |
+
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    # Pad the ragged output-token lists into a dense tensor on the logits
    # device before handing off to the shared penalty kernel.
    _, vocab_size = logits.shape
    dense_output_ids = _convert_to_tensors(output_token_ids, vocab_size,
                                           logits.device)
    return apply_penalties(logits, prompt_token_ids, dense_output_ids,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    # `vocab_size` is a safe pad value: no real token can carry this id.
    padded = make_tensor_with_pad(
        output_token_ids,
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    # Asynchronous host-to-device copy (pinned memory permitting).
    return padded.to(device, non_blocking=True)
|
vllm_hacked/v1/sample/sampler.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
"""A layer that samples the next tokens from the model's outputs."""
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from vllm.config import LogprobsMode
|
| 11 |
+
from vllm.utils import is_pin_memory_available
|
| 12 |
+
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
| 13 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 14 |
+
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
| 15 |
+
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
|
| 16 |
+
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
| 17 |
+
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
| 18 |
+
|
| 19 |
+
_SAMPLING_EPS = 1e-5
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
        i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.

    Fork note (vllm_hacked): `forward` additionally applies classifier-free
    guidance (CFG) by blending paired conditional/unconditional logit rows
    before the standard pipeline above.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        super().__init__()
        # `logprobs_mode` selects which stage of logits/logprobs is
        # reported back; it is also forwarded to the top-k/top-p sampler.
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).

        # Fork addition: classifier-free guidance over paired logit rows.
        # (Earlier commented-out CFG debug scaffolding removed.)

        '''单条推理CFG'''
        # ("single-request inference CFG") The batch layout is assumed to
        # carry the conditional logits at row -2 and the unconditional
        # logits at row -1. NOTE(review): `1024` presumably identifies a
        # non-CFG batch width used elsewhere in this fork -- TODO confirm.
        if logits.shape[0] > 1 and logits.shape[0] != 1024:
            logits = logits.to(torch.float32)
            scores = torch.nn.functional.log_softmax(logits, dim=-1)
            scores_cond = scores[-2]
            scores_uncond = scores[-1]
            # CFG blend: uncond + scale * (cond - uncond).
            scores_processed = sampling_metadata.guidance_scale * (scores_cond - scores_uncond) + scores_uncond
            if logits.shape[0] == 3:
                # Keep row 0 untouched; overwrite the cond/uncond pair
                # with the blended scores.
                scores_processed = torch.stack([scores[0].clone(), scores_processed.clone(), scores_processed.clone()])
            elif logits.shape[0] == 2:
                scores_processed = torch.stack([scores_processed.clone(), scores_processed.clone()])
            logits = scores_processed

        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)

        # Sample the next token.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        # Use in-place division to avoid creating a new tensor.
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        # Argmax over the vocab dimension, flattened to a 1D token tensor.
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                # Fully greedy batch: done, optionally with processed
                # logprobs depending on `logprobs_mode`.
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature,
                                        sampling_metadata.all_random)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: per-request choice between the greedy and random
        # pick, keyed on near-zero temperature.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        # Numerically stable log-softmax in float32.
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        # Fast path: skip entirely when no request carries penalties.
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        # Mask (set to -inf) every token the mask marks as disallowed.
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        # In-place suppression of configured bad-word sequences.
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits
|
vllm_hacked/v1/sample/sampler_ori.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
"""A layer that samples the next tokens from the model's outputs."""
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from vllm.config import LogprobsMode
|
| 11 |
+
from vllm.utils import is_pin_memory_available
|
| 12 |
+
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
| 13 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 14 |
+
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
| 15 |
+
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
|
| 16 |
+
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
| 17 |
+
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
| 18 |
+
|
| 19 |
+
_SAMPLING_EPS = 1e-5
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Sampler(nn.Module):
|
| 23 |
+
"""
|
| 24 |
+
A layer that samples the next tokens from the model's outputs
|
| 25 |
+
with the following steps in order:
|
| 26 |
+
|
| 27 |
+
1. If logprobs are requested:
|
| 28 |
+
a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
|
| 29 |
+
as the final logprobs to return.
|
| 30 |
+
b) If `logprobs_mode` is `raw_logits`, clone the logits
|
| 31 |
+
as the final logprobs to return.
|
| 32 |
+
2. Convert logits to float32.
|
| 33 |
+
3. Apply allowed token ids whitelist.
|
| 34 |
+
4. Apply bad words exclusion.
|
| 35 |
+
5. Apply logit processors which are not argmax-invariant,
|
| 36 |
+
i.e. that can impact greedy sampling.
|
| 37 |
+
a) Min tokens processor
|
| 38 |
+
b) Logit bias processor
|
| 39 |
+
6. Apply penalties
|
| 40 |
+
a) Repetition penalty
|
| 41 |
+
b) Frequency penalty
|
| 42 |
+
c) Presence penalty
|
| 43 |
+
7. Sample the next tokens. `sample` method performs the following steps:
|
| 44 |
+
a) If not `all_random`, perform greedy sampling. If `all_greedy`,
|
| 45 |
+
return the greedily sampled tokens and final logprobs if requested.
|
| 46 |
+
b) Apply temperature.
|
| 47 |
+
c) Apply logit processors which are argmax-invariant, by default
|
| 48 |
+
the min_p processor.
|
| 49 |
+
d) Apply top_k and/or top_p.
|
| 50 |
+
e) Sample the next tokens with the probability distribution.
|
| 51 |
+
f) If `all_random` or temperature >= epsilon (1e-5), return the
|
| 52 |
+
randomly sampled tokens and final logprobs if requested. Else,
|
| 53 |
+
return the greedily sampled tokens and logprobs if requested.
|
| 54 |
+
8. Gather the logprobs of the top `max_num_logprobs` and sampled token
|
| 55 |
+
(if requested). Note that if the sampled token is within the top
|
| 56 |
+
`max_num_logprobs`, the logprob will be eventually merged in
|
| 57 |
+
`LogprobsProcessor` during output processing. Therefore, the
|
| 58 |
+
final output may contain either `max_num_logprobs + 1` or
|
| 59 |
+
`max_num_logprobs` logprobs.
|
| 60 |
+
9. Return the final `SamplerOutput`.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
    """Construct the sampler, wiring the top-k/top-p sub-sampler and
    recording the requested logprobs mode."""
    super().__init__()
    self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
    self.pin_memory = is_pin_memory_available()
    self.logprobs_mode = logprobs_mode
|
| 68 |
+
|
| 69 |
+
    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Run the full sampling pipeline for one batch of logits.

        Order: capture raw logprobs (if requested), cast to float32, apply
        allowed-token whitelist, bad-words exclusion, non-argmax-invariant
        logit processors, penalties, then sample and gather top-k logprobs.
        """
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                # Clone: the logits tensor is mutated in place further down.
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)

        # Sample the next token.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        # In "processed_*" logprobs modes `sample` returns post-processing
        # logprobs which replace the raw ones captured above.
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output
|
| 126 |
+
|
| 127 |
+
def apply_temperature(
|
| 128 |
+
self,
|
| 129 |
+
logits: torch.Tensor,
|
| 130 |
+
temp: torch.Tensor,
|
| 131 |
+
all_random: bool,
|
| 132 |
+
) -> torch.Tensor:
|
| 133 |
+
# Use in-place division to avoid creating a new tensor.
|
| 134 |
+
# Avoid division by zero if there are greedy requests.
|
| 135 |
+
if not all_random:
|
| 136 |
+
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
| 137 |
+
return logits.div_(temp.unsqueeze(dim=1))
|
| 138 |
+
|
| 139 |
+
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
| 140 |
+
return logits.argmax(dim=-1).view(-1)
|
| 141 |
+
|
| 142 |
+
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        Returns (sampled_token_ids, processed_logprobs); the second item is
        non-None only in the "processed_*" logprobs modes.
        """

        # A batch cannot be simultaneously all-greedy and all-random.
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            # At least one request is greedy: compute argmax tokens up front.
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                # Pure-greedy batch: skip temperature/top-k/top-p entirely.
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature,
                                        sampling_metadata.all_random)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: per request, take the greedy token where
        # temperature ~ 0, otherwise the randomly sampled one.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs
|
| 197 |
+
|
| 198 |
+
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
| 199 |
+
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
| 200 |
+
|
| 201 |
+
def gather_logprobs(
|
| 202 |
+
self,
|
| 203 |
+
logprobs: torch.Tensor,
|
| 204 |
+
num_logprobs: int,
|
| 205 |
+
token_ids: torch.Tensor,
|
| 206 |
+
) -> LogprobsTensors:
|
| 207 |
+
"""
|
| 208 |
+
Gather logprobs for topk and sampled/prompt token.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
logprobs: (num tokens) x (vocab) tensor
|
| 212 |
+
num_logprobs: minimum number of logprobs to
|
| 213 |
+
retain per token
|
| 214 |
+
token_ids: prompt tokens (if prompt logprobs)
|
| 215 |
+
or sampled tokens (if sampled
|
| 216 |
+
logprobs); 1D token ID tensor
|
| 217 |
+
with (num tokens) elements
|
| 218 |
+
Must be int64.
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
| 222 |
+
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
| 223 |
+
Sampled token rank tensor, (num tokens)
|
| 224 |
+
"""
|
| 225 |
+
assert token_ids.dtype == torch.int64
|
| 226 |
+
# Find the topK values.
|
| 227 |
+
topk_logprobs, topk_indices = torch.topk(logprobs,
|
| 228 |
+
num_logprobs,
|
| 229 |
+
dim=-1)
|
| 230 |
+
|
| 231 |
+
# Get with the logprob of the prompt or sampled token.
|
| 232 |
+
token_ids = token_ids.unsqueeze(-1)
|
| 233 |
+
token_logprobs = logprobs.gather(-1, token_ids)
|
| 234 |
+
|
| 235 |
+
# Compute the ranks of the actual token.
|
| 236 |
+
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
|
| 237 |
+
|
| 238 |
+
# Concatenate together with the topk.
|
| 239 |
+
indices = torch.cat((token_ids, topk_indices), dim=1)
|
| 240 |
+
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
| 241 |
+
|
| 242 |
+
# Use int32 to reduce the tensor size.
|
| 243 |
+
indices = indices.to(torch.int32)
|
| 244 |
+
|
| 245 |
+
return LogprobsTensors(indices, logprobs, token_ranks)
|
| 246 |
+
|
| 247 |
+
def apply_penalties(
|
| 248 |
+
self,
|
| 249 |
+
logits: torch.Tensor,
|
| 250 |
+
sampling_metadata: SamplingMetadata,
|
| 251 |
+
) -> torch.Tensor:
|
| 252 |
+
if not sampling_metadata.no_penalties:
|
| 253 |
+
assert sampling_metadata.prompt_token_ids is not None
|
| 254 |
+
logits = apply_all_penalties(
|
| 255 |
+
logits,
|
| 256 |
+
sampling_metadata.prompt_token_ids,
|
| 257 |
+
sampling_metadata.presence_penalties,
|
| 258 |
+
sampling_metadata.frequency_penalties,
|
| 259 |
+
sampling_metadata.repetition_penalties,
|
| 260 |
+
sampling_metadata.output_token_ids,
|
| 261 |
+
)
|
| 262 |
+
return logits
|
| 263 |
+
|
| 264 |
+
def apply_allowed_token_ids(
|
| 265 |
+
self,
|
| 266 |
+
logits: torch.Tensor,
|
| 267 |
+
sampling_metadata: SamplingMetadata,
|
| 268 |
+
) -> torch.Tensor:
|
| 269 |
+
if sampling_metadata.allowed_token_ids_mask is not None:
|
| 270 |
+
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
|
| 271 |
+
float("-inf"))
|
| 272 |
+
return logits
|
| 273 |
+
|
| 274 |
+
def apply_bad_words(
|
| 275 |
+
self,
|
| 276 |
+
logits: torch.Tensor,
|
| 277 |
+
sampling_metadata: SamplingMetadata,
|
| 278 |
+
) -> torch.Tensor:
|
| 279 |
+
if sampling_metadata.bad_words_token_ids:
|
| 280 |
+
apply_bad_words(
|
| 281 |
+
logits,
|
| 282 |
+
sampling_metadata.bad_words_token_ids,
|
| 283 |
+
sampling_metadata.output_token_ids,
|
| 284 |
+
)
|
| 285 |
+
return logits
|
vllm_hacked/v1/spec_decode/utils.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
from vllm.v1.worker.gpu_input_batch import InputBatch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
    """Return True iff request `req_id` is compatible with speculative
    decoding given the batch's per-feature request registries."""
    # Spec decode doesn't support min_p sampling.
    if req_id in input_batch.min_p_reqs:
        return False
    # Spec decode doesn't support penalties.
    if (req_id in input_batch.frequency_penalties_reqs
            or req_id in input_batch.presence_penalties_reqs
            or req_id in input_batch.repetition_penalties_reqs):
        return False
    # Spec decode doesn't support logprobs.
    if req_id in input_batch.num_logprobs:
        return False
    return True
|
vllm_hacked/v1/spec_decode/utils_ori.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
from vllm.sampling_params import SamplingParams
|
| 4 |
+
|
| 5 |
+
_SAMPLING_EPS = 1e-5


def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
    """True if request is incompatible with speculative decoding."""
    # Any penalty, a meaningful min_p, or a logprobs request disqualifies it.
    incompatible = (sampling_params.frequency_penalty != 0.0
                    or sampling_params.presence_penalty != 0.0
                    or sampling_params.repetition_penalty != 1.0
                    or sampling_params.min_p > _SAMPLING_EPS
                    or sampling_params.logprobs is not None)
    return incompatible
|
vllm_hacked/v1/utils_ori.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
import argparse
|
| 4 |
+
import contextlib
|
| 5 |
+
import multiprocessing
|
| 6 |
+
import time
|
| 7 |
+
import weakref
|
| 8 |
+
from collections.abc import Sequence
|
| 9 |
+
from contextlib import AbstractContextManager
|
| 10 |
+
from multiprocessing import connection
|
| 11 |
+
from multiprocessing.process import BaseProcess
|
| 12 |
+
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
|
| 13 |
+
Union, overload)
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch.autograd.profiler import record_function
|
| 17 |
+
|
| 18 |
+
import vllm.envs as envs
|
| 19 |
+
from vllm.logger import init_logger
|
| 20 |
+
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
| 21 |
+
usage_message)
|
| 22 |
+
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
|
| 23 |
+
kill_process_tree)
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
from vllm.v1.engine.coordinator import DPCoordinator
|
| 29 |
+
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
| 30 |
+
CoreEngineProcManager)
|
| 31 |
+
|
| 32 |
+
logger = init_logger(__name__)
|
| 33 |
+
|
| 34 |
+
T = TypeVar("T")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ConstantList(Generic[T], Sequence):
    """Immutable, read-only view over a list.

    Read operations delegate to the wrapped list; every mutating operation
    raises TypeError.
    """

    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise TypeError("Cannot append to a constant list")

    def extend(self, item):
        raise TypeError("Cannot extend a constant list")

    def insert(self, item):
        raise TypeError("Cannot insert into a constant list")

    def pop(self, item):
        raise TypeError("Cannot pop from a constant list")

    def remove(self, item):
        raise TypeError("Cannot remove from a constant list")

    def clear(self):
        raise TypeError("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        # Default the stop bound to the end of the underlying list.
        end = len(self._x) if stop is None else stop
        return self._x.index(item, start, end)

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: T, /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise TypeError("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise TypeError("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class CpuGpuBuffer:
    """Buffer to easily copy tensors between CPU and GPU."""

    def __init__(
        self,
        *size: Union[int, torch.SymInt],
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        with_numpy: bool = True,
    ) -> None:
        # Host-side staging tensor (optionally pinned for async transfers).
        self.cpu = torch.zeros(*size,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
        self.gpu = self.cpu.to(device)
        self.np: np.ndarray
        # To keep type hints simple (avoiding generics and subclasses), we
        # only conditionally create the numpy array attribute. This can cause
        # AttributeError if `self.np` is accessed when `with_numpy=False`.
        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
                    "numpy array, so call CpuGpuBuffer with with_numpy=False")
            self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy the first `n` elements (or everything) host -> device."""
        if n is None:
            src, dst = self.cpu, self.gpu
        else:
            src, dst = self.cpu[:n], self.gpu[:n]
        return dst.copy_(src, non_blocking=True)

    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
        """NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU."""
        if n is None:
            src, dst = self.gpu, self.cpu
        else:
            src, dst = self.gpu[:n], self.cpu[:n]
        return dst.copy_(src, non_blocking=True)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Assign a new ZMQ socket address.

    If local_only is True, participants are colocated and so a unique IPC
    address will be returned.

    Otherwise, the provided host and port will be used to construct a TCP
    address (port == 0 means assign an available port).
    """
    if local_only:
        return get_open_zmq_ipc_path()
    # port == 0 (falsy) means "pick any free port".
    return get_tcp_uri(host, port or get_open_port())
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class APIServerProcessManager:
    """Manages a group of API server processes.

    Handles creation, monitoring, and termination of API server worker
    processes. Also monitors extra processes to check if they are healthy.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers
        # "spawn" (not fork) so each worker starts from a clean interpreter.
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        # zip with range(num_servers) caps the count even if the address
        # lists are longer.
        for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
                                        output_addresses):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_count": num_servers,
                "client_index": i
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            proc = spawn_context.Process(target=target_server_fn,
                                         name=f"ApiServer_{i}",
                                         args=(listen_address, sock, args,
                                               client_config))
            self.processes.append(proc)
            proc.start()

        logger.info("Started %d API server processes", len(self.processes))

        # Shutdown only the API server processes on garbage collection
        # The extra processes are managed by their owners
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

    def close(self) -> None:
        # Trigger the finalizer early; weakref.finalize is idempotent, so a
        # later GC-time call is a no-op.
        self._finalizer()
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        engine_manager: Optional[Union["CoreEngineProcManager",
                                       "CoreEngineActorManager"]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    Raises an exception if any process exits with a non-zero status.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.
    """

    # Imported here (not at module top) — presumably to avoid a circular
    # import; NOTE(review): confirm against vllm.v1.engine.utils.
    from vllm.v1.engine.utils import (CoreEngineActorManager,
                                      CoreEngineProcManager)

    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            # Local engine processes are watched via their sentinels too.
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            # Ray-actor engines are watched via their run refs instead.
            actor_run_refs = engine_manager.get_run_refs()

        # Check if any process terminates
        while sentinel_to_proc or actor_run_refs:
            # Wait for any process to terminate (5s timeout so the loop can
            # also poll the Ray actor refs below).
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
                                                         timeout=5)

            # Process any terminated processes
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if process exited with error
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}")

            if actor_run_refs:
                # Imported lazily: ray is only required in actor mode.
                import ray
                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s",
                         str(e))
        raise
    finally:
        # Always tear everything down, regardless of how the wait ended.
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()
|
| 299 |
+
|
| 300 |
+
# Note(rob): shutdown function cannot be a bound method,
|
| 301 |
+
# else the gc cannot collect the object.
|
| 302 |
+
def shutdown(procs: list[BaseProcess]):
    """Terminate the given processes: SIGTERM first, join for up to 5
    seconds total, then hard-kill any surviving process trees."""
    # Ask every live process to terminate.
    for proc in procs:
        if proc.is_alive():
            proc.terminate()

    # Allow 5 seconds for remaining procs to terminate.
    deadline = time.monotonic() + 5
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)

    # Anything still alive gets its whole process tree killed.
    for proc in procs:
        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
    Copy the first length elements of a tensor into another tensor in a
    non-blocking manner.

    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.

    Returns the sliced target tensor.
    """
    dst = to_tensor[:length]
    dst.copy_(from_tensor[:length], non_blocking=True)
    return dst
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Report usage statistics if enabled."""

    # Honor the user's opt-out: do nothing when stats reporting is disabled.
    if not is_usage_stats_enabled():
        return

    # Imported here (not at module top) — presumably to avoid a circular
    # import with the model loader; NOTE(review): confirm.
    from vllm.model_executor.model_loader import get_architecture_class_name

    usage_message.report_usage(
        get_architecture_class_name(vllm_config.model_config),
        usage_context,
        extra_kvs={
            # Common configuration
            "dtype":
            str(vllm_config.model_config.dtype),
            "tensor_parallel_size":
            vllm_config.parallel_config.tensor_parallel_size,
            "block_size":
            vllm_config.cache_config.block_size,
            "gpu_memory_utilization":
            vllm_config.cache_config.gpu_memory_utilization,
            "kv_cache_memory_bytes":
            vllm_config.cache_config.kv_cache_memory_bytes,
            # Quantization
            "quantization":
            vllm_config.model_config.quantization,
            "kv_cache_dtype":
            str(vllm_config.cache_config.cache_dtype),

            # Feature flags
            "enable_lora":
            bool(vllm_config.lora_config),
            "enable_prefix_caching":
            vllm_config.cache_config.enable_prefix_caching,
            "enforce_eager":
            vllm_config.model_config.enforce_eager,
            "disable_custom_all_reduce":
            vllm_config.parallel_config.disable_custom_all_reduce,
        })
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
_PROFILER_FUNC = None
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a profiling-scope context manager for `name`.

    Resolves the backend once — torch's record_function, nvtx.annotate, or a
    no-op nullcontext depending on env flags — and caches the chosen callable
    in the module-global `_PROFILER_FUNC` so subsequent calls skip the check.
    """
    global _PROFILER_FUNC

    # fast path assume it is set
    if _PROFILER_FUNC is not None:
        return _PROFILER_FUNC(name)

    func = contextlib.nullcontext
    if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
        func = record_function
    elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
        # Imported lazily: nvtx is only needed when NVTX scopes are enabled.
        import nvtx
        func = nvtx.annotate

    _PROFILER_FUNC = func
    return func(name)
|
vllm_hacked/v1/worker/gpu_input_batch.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# Datastructures defining an input batch
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, cast
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from vllm.lora.request import LoRARequest
|
| 11 |
+
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
| 12 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
| 13 |
+
from vllm.utils import swap_dict_values
|
| 14 |
+
from vllm.v1.outputs import LogprobsTensors
|
| 15 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 16 |
+
from vllm.v1.utils import copy_slice
|
| 17 |
+
from vllm.v1.worker.block_table import BlockTable
|
| 18 |
+
|
| 19 |
+
_SAMPLING_EPS = 1e-5
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class CachedRequestState:
    """Per-request state the worker caches between scheduler steps.

    Holds the fixed prompt-side inputs plus the mutable decode-time state
    (output tokens, block ids, computed-token count) for one request.
    """

    req_id: str
    prompt_token_ids: list[int]
    prompt: Optional[str]
    mm_inputs: list[MultiModalKwargs]
    mm_positions: list[PlaceholderRange]
    sampling_params: SamplingParams
    # Per-request RNG; None when the request has no dedicated generator
    # (see InputBatch.add_request, which only records non-None generators).
    generator: Optional[torch.Generator]

    # Block ids for this request's row in the block table.
    block_ids: list[int]
    num_computed_tokens: int
    output_token_ids: list[int]

    # M-RoPE state — presumably used by multimodal rotary-embedding models;
    # not read anywhere in this file (TODO confirm against the model runner).
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    @property
    def num_tokens(self) -> int:
        """Total token count for the request: prompt + generated output."""
        return len(self.prompt_token_ids) + len(self.output_token_ids)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class InputBatch:
    """Persistent, index-addressed state for all requests in the batch.

    Each active request occupies one row index in a set of parallel
    CPU/GPU buffers (token ids, sampling parameters, penalties, LoRA
    mapping, ...). CPU-side numpy views and pinned tensors are mutated
    incrementally as requests are added/removed, and only the live
    ``[:num_reqs]`` prefix is copied to the GPU when sampling metadata
    is (re)built via ``refresh_sampling_metadata``.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
    ):
        """Pre-allocate all per-request buffers for up to ``max_num_reqs``.

        Args:
            max_num_reqs: Maximum number of concurrently batched requests.
            max_model_len: Maximum sequence length (token-id buffer width).
            max_num_blocks_per_req: Block-table width per request.
            device: Target device for the GPU-side sampling tensors.
            pin_memory: Whether CPU staging tensors are pinned for faster
                host-to-device copies.
            vocab_size: Vocabulary size (used as a pad id and top-k bound).
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Row index -> request id; None marks a transiently empty slot.
        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        # numpy view sharing storage with token_ids_cpu_tensor.
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = BlockTable(
            max_num_reqs=max_num_reqs,
            max_num_blocks_per_req=max_num_blocks_per_req,
            pin_memory=pin_memory,
            device=device,
        )

        # Sampling-related. Each parameter keeps a GPU tensor, a pinned CPU
        # staging tensor, a numpy view of the CPU tensor, and a set of the
        # request ids for which the parameter is active (non-default).
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        self.min_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
        self.min_p_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
        )
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # req_index -> (min_tokens, stop_token_ids)
        self.min_tokens: dict[int, tuple[int, set[int]]] = {}

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        self.logit_bias: list[Optional[dict[int,
                                            float]]] = [None] * max_num_reqs
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        # Row index -> reference to the request's output_token_ids list
        # (shared with CachedRequestState); None marks an empty slot.
        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

    @property
    def req_ids(self) -> list[str]:
        # None elements should only be present transiently
        # while performing state updates to the batch.
        return cast(list[str], self._req_ids)

    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        """Insert ``request`` at ``req_index`` (append slot if None).

        Copies the request's tokens and sampling parameters into the
        per-row buffers and registers it in every relevant bookkeeping
        structure. Does NOT refresh sampling_metadata; callers do that
        separately once the batch is finalized.
        """
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        sampling_params = request.sampling_params
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            # Out-of-range top_k means "no top-k": clamp to vocab_size.
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k
        self.min_p_cpu[req_index] = sampling_params.min_p
        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.min_p > _SAMPLING_EPS:
            self.min_p_reqs.add(req_id)
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)
        if sampling_params.min_tokens:
            self.min_tokens[req_index] = (sampling_params.min_tokens,
                                          sampling_params.all_stop_token_ids)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
        if sampling_params.logit_bias is not None:
            self.logit_bias[req_index] = sampling_params.logit_bias

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
                                                          self.vocab_size,
                                                          dtype=torch.bool,
                                                          device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[
                req_index] = sampling_params.bad_words_token_ids

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

    def remove_request(self, req_id: str) -> Optional[int]:
        """This method must always be followed by a call to condense().

        Clears every bookkeeping structure for ``req_id`` and leaves its
        row slot marked as None. Returns the freed row index, or None if
        the request was not in the batch.
        """

        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.min_p_reqs.discard(req_id)
        self.min_tokens.pop(req_index, None)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            self.lora_id_to_request_ids[lora_id].discard(req_id)
            if len(self.lora_id_to_request_ids[lora_id]) == 0:
                self.lora_id_to_request_ids.pop(lora_id)
                self.lora_id_to_lora_request.pop(lora_id)
            self.request_lora_mapping[req_index] = 0

        self.logit_bias[req_index] = None
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
        """Swap all per-row state between row indices ``i1`` and ``i2``."""
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] =\
            self._req_ids[i2], self._req_ids[i1] # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
        self.num_tokens[i1], self.num_tokens[i2] =\
            self.num_tokens[i2], self.num_tokens[i1]
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
            self.temperature_cpu[i2], self.temperature_cpu[i1]
        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
            self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
            self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
        self.min_p_cpu[i1], self.min_p_cpu[i2] =\
            self.min_p_cpu[i2], self.min_p_cpu[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporiarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.min_tokens, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
        self.logit_bias[i1], self.logit_bias[i2] =\
            self.logit_bias[i2], self.logit_bias[i1]

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            self.allowed_token_ids_mask_cpu_tensor[i1], \
                self.allowed_token_ids_mask_cpu_tensor[i2] =\
                self.allowed_token_ids_mask_cpu_tensor[i2], \
                self.allowed_token_ids_mask_cpu_tensor[i1]
        self.block_table.swap_row(i1, i2)

    def condense(self, empty_req_indices: list[int]) -> None:
        """Compact the batch by moving trailing rows into freed slots.

        After requests are removed their slots are holes; this moves the
        highest-indexed live rows down into the lowest empty indices so
        rows 0..num_reqs-1 are all live.
        """
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Swap the states.
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            # Only the valid prefix of the token row needs to move.
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            min_token = self.min_tokens.pop(last_req_index, None)
            if min_token is not None:
                self.min_tokens[empty_index] = min_token

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            self.logit_bias[empty_index] = self.logit_bias[last_req_index]

            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]

    def refresh_sampling_metadata(self):
        """Rebuild the cached SamplingMetadata after batch changes."""
        self.sampling_metadata = self._make_sampling_metadata()

    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Copy active CPU parameter slices to the GPU and assemble metadata.

        Parameters unused by every request in the batch are left as None /
        skipped so the host-to-device copies only happen when needed.
        """
        num_reqs = self.num_reqs
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, num_reqs)
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
        if not self.no_min_p:
            copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(self.frequency_penalties_cpu_tensor,
                       self.frequency_penalties, num_reqs)
            copy_slice(self.presence_penalties_cpu_tensor,
                       self.presence_penalties, num_reqs)
            copy_slice(self.repetition_penalties_cpu_tensor,
                       self.repetition_penalties, num_reqs)

            # The prompt tokens are used only for applying penalties during
            # the sampling process. Hence copy these tensors only when
            # there are requests which need penalties to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            min_p=None if self.no_min_p else self.min_p[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            min_tokens=self.min_tokens,
            no_penalties=self.no_penalties,
            logit_bias=self.logit_bias[:num_reqs],
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
        )

    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Build a padded (num_reqs, max_prompt_len) prompt-id GPU tensor."""
        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:self.
                                                 num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(self.num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)

    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        datastructures used to activate the current LoRAs.
        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where, token_lora_mapping[i] is the LoRA id to use for ith token.
            3. lora_requests: Set of relevant LoRA requests.
        """

        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        # Expand per-request ids to per-token ids via element-wise repeat.
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens))
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests

    @property
    def num_reqs(self) -> int:
        """Number of live requests currently in the batch."""
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        """True when no request uses random sampling."""
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        """True when no request uses greedy sampling."""
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        return len(self.top_k_reqs) == 0

    @property
    def no_min_p(self) -> bool:
        return len(self.min_p_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        """True when no request needs any of the three penalty kinds."""
        return (len(self.presence_penalties_reqs) == 0
                and len(self.frequency_penalties_reqs) == 0
                and len(self.repetition_penalties_reqs) == 0)

    @property
    def max_num_logprobs(self) -> Optional[int]:
        return max(self.num_logprobs.values()) if self.num_logprobs else None

    @property
    def no_prompt_logprob(self) -> bool:
        return not self.num_prompt_logprobs

    @property
    def no_allowed_token_ids(self) -> bool:
        return len(self.has_allowed_token_ids) == 0
|
vllm_hacked/v1/worker/gpu_input_batch_ori.py
ADDED
|
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
# Datastructures defining a GPU input batch
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Optional, cast
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from typing_extensions import deprecated
|
| 11 |
+
|
| 12 |
+
from vllm.lora.request import LoRARequest
|
| 13 |
+
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
|
| 14 |
+
from vllm.pooling_params import PoolingParams
|
| 15 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
| 16 |
+
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
| 17 |
+
from vllm.v1.outputs import LogprobsTensors
|
| 18 |
+
from vllm.v1.pool.metadata import PoolingMetadata
|
| 19 |
+
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
| 20 |
+
LogitsProcessors,
|
| 21 |
+
MoveDirectionality)
|
| 22 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 23 |
+
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
| 24 |
+
from vllm.v1.utils import copy_slice
|
| 25 |
+
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class CachedRequestState:
|
| 30 |
+
|
| 31 |
+
req_id: str
|
| 32 |
+
prompt_token_ids: Optional[list[int]]
|
| 33 |
+
mm_features: list[MultiModalFeatureSpec]
|
| 34 |
+
sampling_params: Optional[SamplingParams]
|
| 35 |
+
pooling_params: Optional[PoolingParams]
|
| 36 |
+
generator: Optional[torch.Generator]
|
| 37 |
+
|
| 38 |
+
block_ids: tuple[list[int], ...]
|
| 39 |
+
num_computed_tokens: int
|
| 40 |
+
output_token_ids: list[int]
|
| 41 |
+
|
| 42 |
+
mrope_positions: Optional[torch.Tensor] = None
|
| 43 |
+
mrope_position_delta: Optional[int] = None
|
| 44 |
+
|
| 45 |
+
lora_request: Optional[LoRARequest] = None
|
| 46 |
+
prompt_embeds: Optional[torch.Tensor] = None
|
| 47 |
+
|
| 48 |
+
def __post_init__(self):
|
| 49 |
+
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
| 50 |
+
self.prompt_token_ids, self.prompt_embeds)
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def num_tokens(self) -> int:
|
| 54 |
+
return self.num_prompt_tokens + len(self.output_token_ids)
|
| 55 |
+
|
| 56 |
+
# Temporary back-compatibility for plugins that define model runner
|
| 57 |
+
@property
|
| 58 |
+
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
| 59 |
+
"removed in v0.13. Please use `mm_kwargs` instead.")
|
| 60 |
+
def mm_inputs(self) -> list[MultiModalKwargsItems]:
|
| 61 |
+
return [
|
| 62 |
+
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
|
| 63 |
+
if f.data is not None
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
def get_token_id(self, idx: int) -> int:
|
| 67 |
+
if idx < self.num_prompt_tokens:
|
| 68 |
+
if self.prompt_token_ids is None:
|
| 69 |
+
raise ValueError(
|
| 70 |
+
f"Tried to access token index {idx}, but that token was "
|
| 71 |
+
"provided via prompt_embeds, and its ID is unknown.")
|
| 72 |
+
return self.prompt_token_ids[idx]
|
| 73 |
+
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
|
| 74 |
+
return self.output_token_ids[idx - self.num_prompt_tokens]
|
| 75 |
+
else:
|
| 76 |
+
return -1
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class InputBatch:
|
| 80 |
+
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
max_num_reqs: int,
|
| 84 |
+
max_model_len: int,
|
| 85 |
+
max_num_batched_tokens: int,
|
| 86 |
+
device: torch.device,
|
| 87 |
+
pin_memory: bool,
|
| 88 |
+
vocab_size: int,
|
| 89 |
+
block_sizes: list[int], # The block_size of each kv cache group
|
| 90 |
+
logitsprocs: Optional[LogitsProcessors] = None,
|
| 91 |
+
is_spec_decode: bool = False,
|
| 92 |
+
is_pooling_model: bool = False,
|
| 93 |
+
num_speculative_tokens: int = 0,
|
| 94 |
+
):
|
| 95 |
+
self.is_pooling_model = is_pooling_model
|
| 96 |
+
self.is_spec_decode = is_spec_decode
|
| 97 |
+
self.max_num_reqs = max_num_reqs
|
| 98 |
+
self.max_model_len = max_model_len
|
| 99 |
+
self.max_num_batched_tokens = max_num_batched_tokens
|
| 100 |
+
self.device = device
|
| 101 |
+
self.pin_memory = pin_memory
|
| 102 |
+
self.vocab_size = vocab_size
|
| 103 |
+
|
| 104 |
+
self._req_ids: list[Optional[str]] = []
|
| 105 |
+
self.req_id_to_index: dict[str, int] = {}
|
| 106 |
+
|
| 107 |
+
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
| 108 |
+
# Find a way to reduce the CPU memory usage.
|
| 109 |
+
# This buffer is not directly transferred to the GPU, so it does not
|
| 110 |
+
# need to be pinned.
|
| 111 |
+
self.token_ids_cpu_tensor = torch.zeros(
|
| 112 |
+
(max_num_reqs, max_model_len),
|
| 113 |
+
device="cpu",
|
| 114 |
+
dtype=torch.int32,
|
| 115 |
+
pin_memory=False,
|
| 116 |
+
)
|
| 117 |
+
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
| 118 |
+
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
|
| 119 |
+
device="cpu",
|
| 120 |
+
dtype=bool,
|
| 121 |
+
pin_memory=False)
|
| 122 |
+
# Store prompt embeddings per request to avoid OOM from large upfront
|
| 123 |
+
# allocation if max_model_len is big.
|
| 124 |
+
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
| 125 |
+
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
|
| 126 |
+
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
| 127 |
+
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
| 128 |
+
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
| 129 |
+
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
| 130 |
+
(max_num_reqs, ),
|
| 131 |
+
device="cpu",
|
| 132 |
+
dtype=torch.int32,
|
| 133 |
+
pin_memory=pin_memory,
|
| 134 |
+
)
|
| 135 |
+
self.num_computed_tokens_cpu = \
|
| 136 |
+
self.num_computed_tokens_cpu_tensor.numpy()
|
| 137 |
+
|
| 138 |
+
# Block table.
|
| 139 |
+
self.block_table = MultiGroupBlockTable(
|
| 140 |
+
max_num_reqs=max_num_reqs,
|
| 141 |
+
max_model_len=max_model_len,
|
| 142 |
+
max_num_batched_tokens=max_num_batched_tokens,
|
| 143 |
+
pin_memory=pin_memory,
|
| 144 |
+
device=device,
|
| 145 |
+
block_sizes=block_sizes,
|
| 146 |
+
num_speculative_tokens=num_speculative_tokens,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# Sampling-related.
|
| 150 |
+
self.temperature = torch.empty((max_num_reqs, ),
|
| 151 |
+
dtype=torch.float32,
|
| 152 |
+
device=device)
|
| 153 |
+
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 154 |
+
dtype=torch.float32,
|
| 155 |
+
device="cpu",
|
| 156 |
+
pin_memory=pin_memory)
|
| 157 |
+
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
| 158 |
+
self.greedy_reqs: set[str] = set()
|
| 159 |
+
self.random_reqs: set[str] = set()
|
| 160 |
+
|
| 161 |
+
self.top_p = torch.empty((max_num_reqs, ),
|
| 162 |
+
dtype=torch.float32,
|
| 163 |
+
device=device)
|
| 164 |
+
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 165 |
+
dtype=torch.float32,
|
| 166 |
+
device="cpu",
|
| 167 |
+
pin_memory=pin_memory)
|
| 168 |
+
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
| 169 |
+
self.top_p_reqs: set[str] = set()
|
| 170 |
+
|
| 171 |
+
self.top_k = torch.empty((max_num_reqs, ),
|
| 172 |
+
dtype=torch.int32,
|
| 173 |
+
device=device)
|
| 174 |
+
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 175 |
+
dtype=torch.int32,
|
| 176 |
+
device="cpu",
|
| 177 |
+
pin_memory=pin_memory)
|
| 178 |
+
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
| 179 |
+
self.top_k_reqs: set[str] = set()
|
| 180 |
+
|
| 181 |
+
# IDs of requests which do not support spec decoding
|
| 182 |
+
self.spec_decode_unsupported_reqs: set[str] = set()
|
| 183 |
+
|
| 184 |
+
# Frequency penalty related data structures
|
| 185 |
+
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
| 186 |
+
dtype=torch.float,
|
| 187 |
+
device=device)
|
| 188 |
+
self.frequency_penalties_cpu_tensor = torch.empty(
|
| 189 |
+
(max_num_reqs, ),
|
| 190 |
+
dtype=torch.float,
|
| 191 |
+
device="cpu",
|
| 192 |
+
pin_memory=pin_memory)
|
| 193 |
+
self.frequency_penalties_cpu = \
|
| 194 |
+
self.frequency_penalties_cpu_tensor.numpy()
|
| 195 |
+
self.frequency_penalties_reqs: set[str] = set()
|
| 196 |
+
|
| 197 |
+
# Presence penalty related data structures
|
| 198 |
+
self.presence_penalties = torch.empty((max_num_reqs, ),
|
| 199 |
+
dtype=torch.float,
|
| 200 |
+
device=device)
|
| 201 |
+
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 202 |
+
dtype=torch.float,
|
| 203 |
+
device="cpu",
|
| 204 |
+
pin_memory=pin_memory)
|
| 205 |
+
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
| 206 |
+
)
|
| 207 |
+
self.presence_penalties_reqs: set[str] = set()
|
| 208 |
+
|
| 209 |
+
# Repetition penalty related data structures
|
| 210 |
+
self.repetition_penalties = torch.empty((max_num_reqs, ),
|
| 211 |
+
dtype=torch.float,
|
| 212 |
+
device=device)
|
| 213 |
+
self.repetition_penalties_cpu_tensor = torch.empty(
|
| 214 |
+
(max_num_reqs, ),
|
| 215 |
+
dtype=torch.float,
|
| 216 |
+
device="cpu",
|
| 217 |
+
pin_memory=pin_memory)
|
| 218 |
+
self.repetition_penalties_cpu = \
|
| 219 |
+
self.repetition_penalties_cpu_tensor.numpy()
|
| 220 |
+
self.repetition_penalties_reqs: set[str] = set()
|
| 221 |
+
|
| 222 |
+
# Speculative decoding
|
| 223 |
+
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
|
| 224 |
+
dtype=torch.int64,
|
| 225 |
+
device="cpu",
|
| 226 |
+
pin_memory=pin_memory)
|
| 227 |
+
self.num_accepted_tokens_cpu = \
|
| 228 |
+
self.num_accepted_tokens_cpu_tensor.numpy()
|
| 229 |
+
|
| 230 |
+
# lora related
|
| 231 |
+
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
| 232 |
+
dtype=np.int32)
|
| 233 |
+
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
| 234 |
+
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
| 235 |
+
|
| 236 |
+
# req_index -> generator
|
| 237 |
+
# NOTE(woosuk): The indices of the requests that do not have their own
|
| 238 |
+
# generator should not be included in the dictionary.
|
| 239 |
+
self.generators: dict[int, torch.Generator] = {}
|
| 240 |
+
|
| 241 |
+
self.num_logprobs: dict[str, int] = {}
|
| 242 |
+
# NOTE(rob): num_prompt_logprobs only includes reqs
|
| 243 |
+
# that are currently in the prefill phase.
|
| 244 |
+
self.num_prompt_logprobs: dict[str, int] = {}
|
| 245 |
+
|
| 246 |
+
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
| 247 |
+
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
| 248 |
+
|
| 249 |
+
# Internal representation of per-step batch state changes, used for
|
| 250 |
+
# reordering persistent batch and generating logitsprocs batch state
|
| 251 |
+
# updates. Should reset each step.
|
| 252 |
+
self.batch_update_builder = BatchUpdateBuilder()
|
| 253 |
+
|
| 254 |
+
# TODO convert this to LogitsProcessor
|
| 255 |
+
self.has_allowed_token_ids: set[str] = set()
|
| 256 |
+
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
| 257 |
+
# the value is False. Since we use masked_fill_ to set -inf.
|
| 258 |
+
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
| 259 |
+
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
| 260 |
+
|
| 261 |
+
# req_index -> bad_words_token_ids
|
| 262 |
+
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
| 263 |
+
|
| 264 |
+
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
|
| 265 |
+
dtype=bool)
|
| 266 |
+
|
| 267 |
+
self.req_output_token_ids: list[Optional[list[int]]] = []
|
| 268 |
+
|
| 269 |
+
# Store provided logitsprocs. If none are provided, initialize empty
|
| 270 |
+
# data structure
|
| 271 |
+
self.logitsprocs = logitsprocs or LogitsProcessors()
|
| 272 |
+
|
| 273 |
+
# This is updated each time the batch constituents change.
|
| 274 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 275 |
+
|
| 276 |
+
self.pooling_params: dict[str, PoolingParams] = {}
|
| 277 |
+
|
| 278 |
+
# Cached reference to the GPU tensor of previously sampled tokens
|
| 279 |
+
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
| 280 |
+
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
| 281 |
+
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
| 282 |
+
|
| 283 |
+
@property
|
| 284 |
+
def req_ids(self) -> list[str]:
|
| 285 |
+
# None elements should only be present transiently
|
| 286 |
+
# while performing state updates to the batch.
|
| 287 |
+
return cast(list[str], self._req_ids)
|
| 288 |
+
|
| 289 |
+
def _register_add_request(self, request: "CachedRequestState") -> int:
|
| 290 |
+
"""Track add-request operations for logits processors.
|
| 291 |
+
Not applicable to pooling models.
|
| 292 |
+
"""
|
| 293 |
+
|
| 294 |
+
# Fill the next empty index if there is one.
|
| 295 |
+
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
|
| 296 |
+
# Append to end otherwise.
|
| 297 |
+
new_req_index = self.num_reqs
|
| 298 |
+
|
| 299 |
+
assert new_req_index < self.max_num_reqs
|
| 300 |
+
self.batch_update_builder.batch_changed = True
|
| 301 |
+
if request.sampling_params:
|
| 302 |
+
# Detailed added request metadata is only required for non-pooling
|
| 303 |
+
# models, to support logitsprocs.
|
| 304 |
+
self.batch_update_builder.added.append(
|
| 305 |
+
(new_req_index, request.sampling_params,
|
| 306 |
+
request.prompt_token_ids, request.output_token_ids))
|
| 307 |
+
|
| 308 |
+
return new_req_index
|
| 309 |
+
|
| 310 |
+
def add_request(
|
| 311 |
+
self,
|
| 312 |
+
request: "CachedRequestState",
|
| 313 |
+
) -> int:
|
| 314 |
+
req_index = self._register_add_request(request)
|
| 315 |
+
|
| 316 |
+
req_id = request.req_id
|
| 317 |
+
if req_index == len(self._req_ids):
|
| 318 |
+
self._req_ids.append(req_id)
|
| 319 |
+
self.req_output_token_ids.append(request.output_token_ids)
|
| 320 |
+
else:
|
| 321 |
+
self._req_ids[req_index] = req_id
|
| 322 |
+
self.req_output_token_ids[req_index] = request.output_token_ids
|
| 323 |
+
|
| 324 |
+
self.req_id_to_index[req_id] = req_index
|
| 325 |
+
|
| 326 |
+
# Copy the prompt token ids and output token ids.
|
| 327 |
+
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
| 328 |
+
request.prompt_token_ids, request.prompt_embeds)
|
| 329 |
+
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
| 330 |
+
start_idx = num_prompt_tokens
|
| 331 |
+
end_idx = start_idx + len(request.output_token_ids)
|
| 332 |
+
if request.prompt_token_ids is not None:
|
| 333 |
+
self.token_ids_cpu[
|
| 334 |
+
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
| 335 |
+
self.is_token_ids[req_index, :num_prompt_tokens] = True
|
| 336 |
+
else:
|
| 337 |
+
self.is_token_ids[req_index, :num_prompt_tokens] = False
|
| 338 |
+
if request.prompt_embeds is not None:
|
| 339 |
+
self.req_prompt_embeds[req_index] = request.prompt_embeds
|
| 340 |
+
self.token_ids_cpu[req_index,
|
| 341 |
+
start_idx:end_idx] = request.output_token_ids
|
| 342 |
+
self.is_token_ids[req_index, start_idx:end_idx] = True
|
| 343 |
+
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
|
| 344 |
+
# NOTE(woosuk): This may include spec decode tokens.
|
| 345 |
+
self.num_tokens[req_index] = request.num_tokens
|
| 346 |
+
# Number of tokens without spec decode tokens.
|
| 347 |
+
self.num_tokens_no_spec[req_index] = request.num_tokens
|
| 348 |
+
|
| 349 |
+
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
| 350 |
+
self.block_table.add_row(request.block_ids, req_index)
|
| 351 |
+
|
| 352 |
+
if sampling_params := request.sampling_params:
|
| 353 |
+
if (self.is_spec_decode
|
| 354 |
+
and is_spec_decode_unsupported(sampling_params)):
|
| 355 |
+
self.spec_decode_unsupported_reqs.add(req_id)
|
| 356 |
+
if sampling_params.sampling_type == SamplingType.GREEDY:
|
| 357 |
+
# Should avoid division by zero later when apply_temperature.
|
| 358 |
+
self.temperature_cpu[req_index] = 0.0
|
| 359 |
+
self.greedy_reqs.add(req_id)
|
| 360 |
+
else:
|
| 361 |
+
self.temperature_cpu[req_index] = sampling_params.temperature
|
| 362 |
+
self.random_reqs.add(req_id)
|
| 363 |
+
|
| 364 |
+
self.top_p_cpu[req_index] = sampling_params.top_p
|
| 365 |
+
if sampling_params.top_p < 1:
|
| 366 |
+
self.top_p_reqs.add(req_id)
|
| 367 |
+
top_k = sampling_params.top_k
|
| 368 |
+
if 0 < top_k < self.vocab_size:
|
| 369 |
+
self.top_k_reqs.add(req_id)
|
| 370 |
+
else:
|
| 371 |
+
top_k = self.vocab_size
|
| 372 |
+
self.top_k_cpu[req_index] = top_k
|
| 373 |
+
self.frequency_penalties_cpu[
|
| 374 |
+
req_index] = sampling_params.frequency_penalty
|
| 375 |
+
if sampling_params.frequency_penalty != 0.0:
|
| 376 |
+
self.frequency_penalties_reqs.add(req_id)
|
| 377 |
+
self.presence_penalties_cpu[
|
| 378 |
+
req_index] = sampling_params.presence_penalty
|
| 379 |
+
if sampling_params.presence_penalty != 0.0:
|
| 380 |
+
self.presence_penalties_reqs.add(req_id)
|
| 381 |
+
self.repetition_penalties_cpu[
|
| 382 |
+
req_index] = sampling_params.repetition_penalty
|
| 383 |
+
if sampling_params.repetition_penalty != 1.0:
|
| 384 |
+
self.repetition_penalties_reqs.add(req_id)
|
| 385 |
+
|
| 386 |
+
# NOTE(woosuk): self.generators should not include the requests that
|
| 387 |
+
# do not have their own generator.
|
| 388 |
+
if request.generator is not None:
|
| 389 |
+
self.generators[req_index] = request.generator
|
| 390 |
+
|
| 391 |
+
if sampling_params.logprobs is not None:
|
| 392 |
+
self.num_logprobs[req_id] = (self.vocab_size
|
| 393 |
+
if sampling_params.logprobs == -1
|
| 394 |
+
else sampling_params.logprobs)
|
| 395 |
+
if sampling_params.prompt_logprobs is not None:
|
| 396 |
+
self.num_prompt_logprobs[req_id] = (
|
| 397 |
+
self.vocab_size if sampling_params.prompt_logprobs == -1
|
| 398 |
+
else sampling_params.prompt_logprobs)
|
| 399 |
+
|
| 400 |
+
if sampling_params.allowed_token_ids:
|
| 401 |
+
self.has_allowed_token_ids.add(req_id)
|
| 402 |
+
if self.allowed_token_ids_mask_cpu_tensor is None:
|
| 403 |
+
# Lazy allocation for this tensor, which can be large.
|
| 404 |
+
# False means we don't fill with -inf.
|
| 405 |
+
self.allowed_token_ids_mask = torch.zeros(
|
| 406 |
+
self.max_num_reqs,
|
| 407 |
+
self.vocab_size,
|
| 408 |
+
dtype=torch.bool,
|
| 409 |
+
device=self.device)
|
| 410 |
+
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
| 411 |
+
self.max_num_reqs,
|
| 412 |
+
self.vocab_size,
|
| 413 |
+
dtype=torch.bool,
|
| 414 |
+
device="cpu")
|
| 415 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
| 416 |
+
# False means we don't fill with -inf.
|
| 417 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
| 418 |
+
sampling_params.allowed_token_ids] = False
|
| 419 |
+
|
| 420 |
+
if sampling_params.bad_words_token_ids:
|
| 421 |
+
self.bad_words_token_ids[
|
| 422 |
+
req_index] = sampling_params.bad_words_token_ids
|
| 423 |
+
elif pooling_params := request.pooling_params:
|
| 424 |
+
self.pooling_params[req_id] = pooling_params
|
| 425 |
+
self.logits_processing_needs_token_ids[req_index] = (
|
| 426 |
+
pooling_params.requires_token_ids)
|
| 427 |
+
else:
|
| 428 |
+
raise NotImplementedError("Unrecognized request type")
|
| 429 |
+
|
| 430 |
+
# Speculative decoding: by default 1 token is generated.
|
| 431 |
+
self.num_accepted_tokens_cpu[req_index] = 1
|
| 432 |
+
|
| 433 |
+
# Add request lora ID
|
| 434 |
+
if request.lora_request:
|
| 435 |
+
lora_id = request.lora_request.lora_int_id
|
| 436 |
+
if lora_id not in self.lora_id_to_request_ids:
|
| 437 |
+
self.lora_id_to_request_ids[lora_id] = set()
|
| 438 |
+
|
| 439 |
+
self.request_lora_mapping[req_index] = lora_id
|
| 440 |
+
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
| 441 |
+
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
| 442 |
+
else:
|
| 443 |
+
# No LoRA
|
| 444 |
+
self.request_lora_mapping[req_index] = 0
|
| 445 |
+
|
| 446 |
+
return req_index
|
| 447 |
+
|
| 448 |
+
def remove_request(self, req_id: str) -> Optional[int]:
|
| 449 |
+
"""This method must always be followed by a call to condense().
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
req_id: request to remove
|
| 453 |
+
|
| 454 |
+
Returns:
|
| 455 |
+
Removed request index, or `None` if `req_id` not recognized
|
| 456 |
+
"""
|
| 457 |
+
|
| 458 |
+
req_index = self.req_id_to_index.pop(req_id, None)
|
| 459 |
+
if req_index is None:
|
| 460 |
+
return None
|
| 461 |
+
|
| 462 |
+
self.batch_update_builder.removed_append(req_index)
|
| 463 |
+
self._req_ids[req_index] = None
|
| 464 |
+
self.req_output_token_ids[req_index] = None
|
| 465 |
+
|
| 466 |
+
# LoRA
|
| 467 |
+
lora_id = self.request_lora_mapping[req_index]
|
| 468 |
+
if lora_id != 0:
|
| 469 |
+
lora_req_ids = self.lora_id_to_request_ids[lora_id]
|
| 470 |
+
lora_req_ids.discard(req_id)
|
| 471 |
+
if not lora_req_ids:
|
| 472 |
+
del self.lora_id_to_request_ids[lora_id]
|
| 473 |
+
del self.lora_id_to_lora_request[lora_id]
|
| 474 |
+
self.request_lora_mapping[req_index] = 0
|
| 475 |
+
|
| 476 |
+
if self.is_pooling_model:
|
| 477 |
+
self.pooling_params.pop(req_id, None)
|
| 478 |
+
return req_index
|
| 479 |
+
|
| 480 |
+
self.greedy_reqs.discard(req_id)
|
| 481 |
+
self.random_reqs.discard(req_id)
|
| 482 |
+
self.top_p_reqs.discard(req_id)
|
| 483 |
+
self.top_k_reqs.discard(req_id)
|
| 484 |
+
self.spec_decode_unsupported_reqs.discard(req_id)
|
| 485 |
+
self.frequency_penalties_reqs.discard(req_id)
|
| 486 |
+
self.presence_penalties_reqs.discard(req_id)
|
| 487 |
+
self.repetition_penalties_reqs.discard(req_id)
|
| 488 |
+
self.generators.pop(req_index, None)
|
| 489 |
+
self.num_logprobs.pop(req_id, None)
|
| 490 |
+
self.num_prompt_logprobs.pop(req_id, None)
|
| 491 |
+
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
| 492 |
+
|
| 493 |
+
self.has_allowed_token_ids.discard(req_id)
|
| 494 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 495 |
+
# False means we don't fill with -inf.
|
| 496 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
| 497 |
+
self.bad_words_token_ids.pop(req_index, None)
|
| 498 |
+
return req_index
|
| 499 |
+
|
| 500 |
+
def swap_states(self, i1: int, i2: int) -> None:
|
| 501 |
+
old_id_i1 = self._req_ids[i1]
|
| 502 |
+
old_id_i2 = self._req_ids[i2]
|
| 503 |
+
self._req_ids[i1], self._req_ids[i2] =\
|
| 504 |
+
self._req_ids[i2], self._req_ids[i1] # noqa
|
| 505 |
+
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
| 506 |
+
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
| 507 |
+
assert old_id_i1 is not None and old_id_i2 is not None
|
| 508 |
+
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
| 509 |
+
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
| 510 |
+
self.num_tokens[i1], self.num_tokens[i2] =\
|
| 511 |
+
self.num_tokens[i2], self.num_tokens[i1]
|
| 512 |
+
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
| 513 |
+
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
| 514 |
+
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
| 515 |
+
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
| 516 |
+
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
| 517 |
+
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
| 518 |
+
|
| 519 |
+
# NOTE: the following is unsafe
|
| 520 |
+
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
| 521 |
+
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
| 522 |
+
# instead, we need to temporiarily copy the data for one of the indices
|
| 523 |
+
# TODO(lucas): optimize this by only copying valid indices
|
| 524 |
+
tmp = self.token_ids_cpu[i1, ...].copy()
|
| 525 |
+
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
| 526 |
+
self.token_ids_cpu[i2, ...] = tmp
|
| 527 |
+
|
| 528 |
+
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
|
| 529 |
+
|
| 530 |
+
# Swap prompt embeddings if they exist
|
| 531 |
+
embeds_i1 = self.req_prompt_embeds.get(i1)
|
| 532 |
+
embeds_i2 = self.req_prompt_embeds.get(i2)
|
| 533 |
+
if embeds_i1 is not None:
|
| 534 |
+
self.req_prompt_embeds[i2] = embeds_i1
|
| 535 |
+
else:
|
| 536 |
+
self.req_prompt_embeds.pop(i2, None)
|
| 537 |
+
if embeds_i2 is not None:
|
| 538 |
+
self.req_prompt_embeds[i1] = embeds_i2
|
| 539 |
+
else:
|
| 540 |
+
self.req_prompt_embeds.pop(i1, None)
|
| 541 |
+
|
| 542 |
+
self.block_table.swap_row(i1, i2)
|
| 543 |
+
|
| 544 |
+
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
|
| 545 |
+
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
| 546 |
+
|
| 547 |
+
if self.is_pooling_model:
|
| 548 |
+
# Sampling and logits parameters don't apply to pooling models.
|
| 549 |
+
return
|
| 550 |
+
|
| 551 |
+
# For autoregressive models, track detailed request reordering info
|
| 552 |
+
# to support logitsprocs.
|
| 553 |
+
self.batch_update_builder.moved.append(
|
| 554 |
+
(i1, i2, MoveDirectionality.SWAP))
|
| 555 |
+
|
| 556 |
+
self.temperature_cpu[i1], self.temperature_cpu[i2] = \
|
| 557 |
+
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
| 558 |
+
self.top_p_cpu[i1], self.top_p_cpu[i2] = \
|
| 559 |
+
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
| 560 |
+
self.top_k_cpu[i1], self.top_k_cpu[i2] = \
|
| 561 |
+
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
| 562 |
+
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
|
| 563 |
+
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
| 564 |
+
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
|
| 565 |
+
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
| 566 |
+
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
|
| 567 |
+
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
| 568 |
+
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
|
| 569 |
+
self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
|
| 570 |
+
|
| 571 |
+
swap_dict_values(self.generators, i1, i2)
|
| 572 |
+
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
| 573 |
+
|
| 574 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 575 |
+
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
| 576 |
+
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
| 577 |
+
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
| 578 |
+
self.allowed_token_ids_mask_cpu_tensor[i1]
|
| 579 |
+
|
| 580 |
+
def condense(self) -> None:
|
| 581 |
+
"""Slide non-empty requests down into lower, empty indices.
|
| 582 |
+
|
| 583 |
+
Any consecutive empty indices at the very end of the list are not
|
| 584 |
+
filled.
|
| 585 |
+
|
| 586 |
+
Returns:
|
| 587 |
+
swaps: list of (from,to) swap tuples for moved requests
|
| 588 |
+
empty_req_indices: indices not filled by condensation
|
| 589 |
+
"""
|
| 590 |
+
num_reqs = self.num_reqs
|
| 591 |
+
|
| 592 |
+
if not (empty_req_indices := self.batch_update_builder.removed):
|
| 593 |
+
# All removed requests were replaced by added requests, or else no
|
| 594 |
+
# requests were removed at all. No condense() needed
|
| 595 |
+
return
|
| 596 |
+
if num_reqs == 0:
|
| 597 |
+
# The batched states are empty.
|
| 598 |
+
self._req_ids.clear()
|
| 599 |
+
self.req_output_token_ids.clear()
|
| 600 |
+
return
|
| 601 |
+
|
| 602 |
+
# NOTE(woosuk): This function assumes that the empty_req_indices
|
| 603 |
+
# is sorted in descending order.
|
| 604 |
+
last_req_index = num_reqs + len(empty_req_indices) - 1
|
| 605 |
+
while empty_req_indices:
|
| 606 |
+
# Find the largest non-empty index.
|
| 607 |
+
while last_req_index in empty_req_indices:
|
| 608 |
+
last_req_index -= 1
|
| 609 |
+
|
| 610 |
+
# Find the smallest empty index.
|
| 611 |
+
empty_index = self.batch_update_builder.peek_removed()
|
| 612 |
+
assert empty_index is not None
|
| 613 |
+
if empty_index >= last_req_index:
|
| 614 |
+
break
|
| 615 |
+
|
| 616 |
+
# Move active request down into empty request
|
| 617 |
+
# index.
|
| 618 |
+
self.batch_update_builder.pop_removed()
|
| 619 |
+
req_id = self._req_ids[last_req_index]
|
| 620 |
+
output_token_ids = self.req_output_token_ids[last_req_index]
|
| 621 |
+
assert req_id is not None
|
| 622 |
+
self._req_ids[empty_index] = req_id
|
| 623 |
+
self._req_ids[last_req_index] = None
|
| 624 |
+
self.req_output_token_ids[empty_index] = output_token_ids
|
| 625 |
+
self.req_output_token_ids[last_req_index] = None
|
| 626 |
+
self.req_id_to_index[req_id] = empty_index
|
| 627 |
+
|
| 628 |
+
num_tokens = self.num_tokens[last_req_index]
|
| 629 |
+
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
| 630 |
+
last_req_index, :num_tokens]
|
| 631 |
+
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
|
| 632 |
+
last_req_index, :num_tokens]
|
| 633 |
+
if last_req_index in self.req_prompt_embeds:
|
| 634 |
+
self.req_prompt_embeds[
|
| 635 |
+
empty_index] = self.req_prompt_embeds.pop(last_req_index)
|
| 636 |
+
self.num_tokens[empty_index] = num_tokens
|
| 637 |
+
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
| 638 |
+
last_req_index]
|
| 639 |
+
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
| 640 |
+
last_req_index]
|
| 641 |
+
self.num_computed_tokens_cpu[
|
| 642 |
+
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
| 643 |
+
self.block_table.move_row(last_req_index, empty_index)
|
| 644 |
+
|
| 645 |
+
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
| 646 |
+
last_req_index]
|
| 647 |
+
|
| 648 |
+
if self.is_pooling_model:
|
| 649 |
+
last_req_index -= 1
|
| 650 |
+
# Sampling state not used by pooling models.
|
| 651 |
+
continue
|
| 652 |
+
|
| 653 |
+
# Autoregressive models require detailed tracking of condense
|
| 654 |
+
# operations to support logitsprocs
|
| 655 |
+
self.batch_update_builder.moved.append(
|
| 656 |
+
(last_req_index, empty_index,
|
| 657 |
+
MoveDirectionality.UNIDIRECTIONAL))
|
| 658 |
+
|
| 659 |
+
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
| 660 |
+
last_req_index]
|
| 661 |
+
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
| 662 |
+
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
| 663 |
+
self.frequency_penalties_cpu[
|
| 664 |
+
empty_index] = self.frequency_penalties_cpu[last_req_index]
|
| 665 |
+
self.presence_penalties_cpu[
|
| 666 |
+
empty_index] = self.presence_penalties_cpu[last_req_index]
|
| 667 |
+
self.repetition_penalties_cpu[
|
| 668 |
+
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
| 669 |
+
self.num_accepted_tokens_cpu[
|
| 670 |
+
empty_index] = self.num_accepted_tokens_cpu[last_req_index]
|
| 671 |
+
generator = self.generators.pop(last_req_index, None)
|
| 672 |
+
if generator is not None:
|
| 673 |
+
self.generators[empty_index] = generator
|
| 674 |
+
|
| 675 |
+
# TODO convert these to LogitsProcessors
|
| 676 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 677 |
+
self.allowed_token_ids_mask_cpu_tensor[
|
| 678 |
+
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
| 679 |
+
last_req_index]
|
| 680 |
+
|
| 681 |
+
bad_words_token_ids = self.bad_words_token_ids.pop(
|
| 682 |
+
last_req_index, None)
|
| 683 |
+
if bad_words_token_ids is not None:
|
| 684 |
+
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
| 685 |
+
|
| 686 |
+
# Decrement last_req_index since it is now empty.
|
| 687 |
+
last_req_index -= 1
|
| 688 |
+
|
| 689 |
+
# Trim lists to the batch size.
|
| 690 |
+
del self._req_ids[num_reqs:]
|
| 691 |
+
del self.req_output_token_ids[num_reqs:]
|
| 692 |
+
|
| 693 |
+
def refresh_metadata(self):
|
| 694 |
+
"""Apply any batch updates to sampling metadata."""
|
| 695 |
+
|
| 696 |
+
if self.is_pooling_model:
|
| 697 |
+
batch_changed = self.batch_update_builder.reset()
|
| 698 |
+
if batch_changed:
|
| 699 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 700 |
+
return
|
| 701 |
+
|
| 702 |
+
# For non-pooling models - generate and apply logitsprocs update;
|
| 703 |
+
# reset batch update tracking.
|
| 704 |
+
# Update sampling metadata if batch state is changed.
|
| 705 |
+
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
|
| 706 |
+
for logit_proc in self.logitsprocs.all:
|
| 707 |
+
logit_proc.update_state(batch_update)
|
| 708 |
+
if batch_update:
|
| 709 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 710 |
+
|
| 711 |
+
def _make_sampling_metadata(self) -> SamplingMetadata:
|
| 712 |
+
num_reqs = self.num_reqs
|
| 713 |
+
if not self.all_greedy:
|
| 714 |
+
temperature = copy_slice(self.temperature_cpu_tensor,
|
| 715 |
+
self.temperature, num_reqs)
|
| 716 |
+
else:
|
| 717 |
+
temperature = None
|
| 718 |
+
if not self.no_top_p:
|
| 719 |
+
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
| 720 |
+
if not self.no_top_k:
|
| 721 |
+
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
| 722 |
+
|
| 723 |
+
if not self.no_penalties:
|
| 724 |
+
# Since syncing these tensors is expensive only copy them
|
| 725 |
+
# if necessary i.e. if there are requests which require
|
| 726 |
+
# penalties to be applied during sampling.
|
| 727 |
+
copy_slice(self.frequency_penalties_cpu_tensor,
|
| 728 |
+
self.frequency_penalties, num_reqs)
|
| 729 |
+
copy_slice(self.presence_penalties_cpu_tensor,
|
| 730 |
+
self.presence_penalties, num_reqs)
|
| 731 |
+
copy_slice(self.repetition_penalties_cpu_tensor,
|
| 732 |
+
self.repetition_penalties, num_reqs)
|
| 733 |
+
|
| 734 |
+
needs_prompt_token_ids = (
|
| 735 |
+
not self.no_penalties
|
| 736 |
+
or self.logits_processing_needs_token_ids[:num_reqs].any())
|
| 737 |
+
if needs_prompt_token_ids:
|
| 738 |
+
# The prompt tokens are used only for applying penalties or
|
| 739 |
+
# step pooling during the sampling/pooling process.
|
| 740 |
+
# Hence copy these tensors only when there are requests which
|
| 741 |
+
# need penalties/step_pooler to be applied.
|
| 742 |
+
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
| 743 |
+
else:
|
| 744 |
+
prompt_token_ids = None
|
| 745 |
+
|
| 746 |
+
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
| 747 |
+
if not self.no_allowed_token_ids:
|
| 748 |
+
assert self.allowed_token_ids_mask is not None
|
| 749 |
+
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
|
| 750 |
+
self.allowed_token_ids_mask, num_reqs)
|
| 751 |
+
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
| 752 |
+
|
| 753 |
+
return SamplingMetadata(
|
| 754 |
+
temperature=temperature,
|
| 755 |
+
all_greedy=self.all_greedy,
|
| 756 |
+
all_random=self.all_random,
|
| 757 |
+
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
| 758 |
+
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
| 759 |
+
generators=self.generators,
|
| 760 |
+
max_num_logprobs=self.max_num_logprobs,
|
| 761 |
+
prompt_token_ids=prompt_token_ids,
|
| 762 |
+
frequency_penalties=self.frequency_penalties[:num_reqs],
|
| 763 |
+
presence_penalties=self.presence_penalties[:num_reqs],
|
| 764 |
+
repetition_penalties=self.repetition_penalties[:num_reqs],
|
| 765 |
+
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
| 766 |
+
no_penalties=self.no_penalties,
|
| 767 |
+
allowed_token_ids_mask=allowed_token_ids_mask,
|
| 768 |
+
bad_words_token_ids=self.bad_words_token_ids,
|
| 769 |
+
logitsprocs=self.logitsprocs,
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
def get_pooling_params(self) -> list[PoolingParams]:
|
| 773 |
+
assert len(self.req_ids) == len(self.pooling_params)
|
| 774 |
+
return [self.pooling_params[req_id] for req_id in self.req_ids]
|
| 775 |
+
|
| 776 |
+
def get_pooling_metadata(self) -> PoolingMetadata:
|
| 777 |
+
pooling_params = self.get_pooling_params()
|
| 778 |
+
|
| 779 |
+
return PoolingMetadata(
|
| 780 |
+
prompt_lens=torch.from_numpy(
|
| 781 |
+
self.num_prompt_tokens[:self.num_reqs]),
|
| 782 |
+
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
| 783 |
+
pooling_params=pooling_params,
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
| 787 |
+
num_reqs = self.num_reqs
|
| 788 |
+
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
|
| 789 |
+
prompt_token_ids_cpu_tensor = torch.empty(
|
| 790 |
+
(self.num_reqs, max_prompt_len),
|
| 791 |
+
device="cpu",
|
| 792 |
+
dtype=torch.int64,
|
| 793 |
+
pin_memory=self.pin_memory,
|
| 794 |
+
)
|
| 795 |
+
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
| 796 |
+
prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
|
| 797 |
+
# Use the value of vocab_size as a pad since we don't have a
|
| 798 |
+
# token_id of this value.
|
| 799 |
+
for i in range(num_reqs):
|
| 800 |
+
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
|
| 801 |
+
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
| 802 |
+
non_blocking=True)
|
| 803 |
+
|
| 804 |
+
def make_lora_inputs(
|
| 805 |
+
self, num_scheduled_tokens: np.ndarray
|
| 806 |
+
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
| 807 |
+
"""
|
| 808 |
+
Given the num_scheduled_tokens for each request in the batch, return
|
| 809 |
+
datastructures used to activate the current LoRAs.
|
| 810 |
+
Returns:
|
| 811 |
+
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
| 812 |
+
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
| 813 |
+
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
| 814 |
+
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
| 815 |
+
3. lora_requests: Set of relevant LoRA requests.
|
| 816 |
+
"""
|
| 817 |
+
|
| 818 |
+
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
|
| 819 |
+
prompt_lora_mapping = tuple(req_lora_mapping)
|
| 820 |
+
token_lora_mapping = tuple(
|
| 821 |
+
req_lora_mapping.repeat(num_scheduled_tokens))
|
| 822 |
+
active_lora_requests: set[LoRARequest] = set(
|
| 823 |
+
self.lora_id_to_lora_request.values())
|
| 824 |
+
|
| 825 |
+
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
| 826 |
+
|
| 827 |
+
@property
|
| 828 |
+
def num_reqs(self) -> int:
|
| 829 |
+
return len(self.req_id_to_index)
|
| 830 |
+
|
| 831 |
+
@property
|
| 832 |
+
def all_greedy(self) -> bool:
|
| 833 |
+
return len(self.random_reqs) == 0
|
| 834 |
+
|
| 835 |
+
@property
|
| 836 |
+
def all_random(self) -> bool:
|
| 837 |
+
return len(self.greedy_reqs) == 0
|
| 838 |
+
|
| 839 |
+
@property
|
| 840 |
+
def no_top_p(self) -> bool:
|
| 841 |
+
return len(self.top_p_reqs) == 0
|
| 842 |
+
|
| 843 |
+
@property
|
| 844 |
+
def no_top_k(self) -> bool:
|
| 845 |
+
return len(self.top_k_reqs) == 0
|
| 846 |
+
|
| 847 |
+
@property
|
| 848 |
+
def no_penalties(self) -> bool:
|
| 849 |
+
return (len(self.presence_penalties_reqs) == 0
|
| 850 |
+
and len(self.frequency_penalties_reqs) == 0
|
| 851 |
+
and len(self.repetition_penalties_reqs) == 0)
|
| 852 |
+
|
| 853 |
+
@property
|
| 854 |
+
def max_num_logprobs(self) -> Optional[int]:
|
| 855 |
+
return max(self.num_logprobs.values()) if self.num_logprobs else None
|
| 856 |
+
|
| 857 |
+
@property
|
| 858 |
+
def no_prompt_logprob(self) -> bool:
|
| 859 |
+
return not self.num_prompt_logprobs
|
| 860 |
+
|
| 861 |
+
@property
|
| 862 |
+
def no_allowed_token_ids(self) -> bool:
|
| 863 |
+
return len(self.has_allowed_token_ids) == 0
|
vllm_hacked/v1/worker/gpu_model_runner.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
vllm_hacked/v1/worker/gpu_model_runner_ori.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
vllm_hacked/v1/worker/gpu_worker.py
ADDED
|
@@ -0,0 +1,710 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
"""A GPU worker class."""
|
| 4 |
+
import copy
|
| 5 |
+
import gc
|
| 6 |
+
import os
|
| 7 |
+
from contextlib import AbstractContextManager, nullcontext
|
| 8 |
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
import vllm.envs as envs
|
| 15 |
+
from vllm.config import VllmConfig
|
| 16 |
+
from vllm.distributed import (ensure_model_parallel_initialized,
|
| 17 |
+
init_distributed_environment,
|
| 18 |
+
set_custom_all_reduce)
|
| 19 |
+
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
| 20 |
+
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
| 21 |
+
from vllm.logger import init_logger
|
| 22 |
+
from vllm.lora.request import LoRARequest
|
| 23 |
+
from vllm.model_executor import set_random_seed
|
| 24 |
+
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
|
| 25 |
+
from vllm.platforms import current_platform
|
| 26 |
+
from vllm.sequence import IntermediateTensors
|
| 27 |
+
from vllm.tasks import SupportedTask
|
| 28 |
+
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
|
| 29 |
+
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
| 30 |
+
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
| 31 |
+
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
| 32 |
+
DraftTokenIds, ModelRunnerOutput)
|
| 33 |
+
from vllm.v1.utils import report_usage_stats
|
| 34 |
+
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
| 35 |
+
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
| 36 |
+
from vllm.v1.worker.worker_base import WorkerBase
|
| 37 |
+
|
| 38 |
+
logger = init_logger(__name__)
|
| 39 |
+
|
| 40 |
+
if TYPE_CHECKING:
|
| 41 |
+
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
| 42 |
+
from vllm.v1.core.sched.output import SchedulerOutput
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Worker(WorkerBase):
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
vllm_config: VllmConfig,
|
| 50 |
+
local_rank: int,
|
| 51 |
+
rank: int,
|
| 52 |
+
distributed_init_method: str,
|
| 53 |
+
is_driver_worker: bool = False,
|
| 54 |
+
):
|
| 55 |
+
|
| 56 |
+
super().__init__(vllm_config=vllm_config,
|
| 57 |
+
local_rank=local_rank,
|
| 58 |
+
rank=rank,
|
| 59 |
+
distributed_init_method=distributed_init_method,
|
| 60 |
+
is_driver_worker=is_driver_worker)
|
| 61 |
+
|
| 62 |
+
if self.model_config.trust_remote_code:
|
| 63 |
+
# note: lazy import to avoid importing torch before initializing
|
| 64 |
+
from vllm.utils import init_cached_hf_modules
|
| 65 |
+
init_cached_hf_modules()
|
| 66 |
+
|
| 67 |
+
# Buffers saved before sleep
|
| 68 |
+
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
| 69 |
+
|
| 70 |
+
# Torch profiler. Enabled and configured through env vars:
|
| 71 |
+
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
| 72 |
+
if envs.VLLM_TORCH_PROFILER_DIR:
|
| 73 |
+
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
| 74 |
+
logger.info("Profiling enabled. Traces will be saved to: %s",
|
| 75 |
+
torch_profiler_trace_dir)
|
| 76 |
+
logger.debug(
|
| 77 |
+
"Profiler config: record_shapes=%s,"
|
| 78 |
+
"profile_memory=%s,with_stack=%s,with_flops=%s",
|
| 79 |
+
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
| 80 |
+
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
| 81 |
+
envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
| 82 |
+
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
| 83 |
+
)
|
| 84 |
+
self.profiler = torch.profiler.profile(
|
| 85 |
+
activities=[
|
| 86 |
+
torch.profiler.ProfilerActivity.CPU,
|
| 87 |
+
torch.profiler.ProfilerActivity.CUDA,
|
| 88 |
+
],
|
| 89 |
+
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
| 90 |
+
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
| 91 |
+
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
| 92 |
+
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
| 93 |
+
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
| 94 |
+
torch_profiler_trace_dir, use_gzip=True))
|
| 95 |
+
else:
|
| 96 |
+
self.profiler = None
|
| 97 |
+
|
| 98 |
+
def sleep(self, level: int = 1) -> None:
|
| 99 |
+
from vllm.device_allocator.cumem import CuMemAllocator
|
| 100 |
+
|
| 101 |
+
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
|
| 102 |
+
|
| 103 |
+
# Save the buffers before level 2 sleep
|
| 104 |
+
if level == 2:
|
| 105 |
+
model = self.model_runner.model
|
| 106 |
+
self._sleep_saved_buffers = {
|
| 107 |
+
name: buffer.cpu().clone()
|
| 108 |
+
for name, buffer in model.named_buffers()
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
allocator = CuMemAllocator.get_instance()
|
| 112 |
+
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
| 113 |
+
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
| 114 |
+
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
| 115 |
+
used_bytes = total - free_bytes_after_sleep
|
| 116 |
+
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
| 117 |
+
logger.info(
|
| 118 |
+
"Sleep mode freed %.2f GiB memory, "
|
| 119 |
+
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
|
| 120 |
+
used_bytes / GiB_bytes)
|
| 121 |
+
|
| 122 |
+
def wake_up(self, tags: Optional[list[str]] = None) -> None:
|
| 123 |
+
from vllm.device_allocator.cumem import CuMemAllocator
|
| 124 |
+
|
| 125 |
+
allocator = CuMemAllocator.get_instance()
|
| 126 |
+
allocator.wake_up(tags)
|
| 127 |
+
|
| 128 |
+
# Restore the buffers after level 2 sleep
|
| 129 |
+
if len(self._sleep_saved_buffers):
|
| 130 |
+
model = self.model_runner.model
|
| 131 |
+
for name, buffer in model.named_buffers():
|
| 132 |
+
if name in self._sleep_saved_buffers:
|
| 133 |
+
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
| 134 |
+
self._sleep_saved_buffers = {}
|
| 135 |
+
|
| 136 |
+
def _maybe_get_memory_pool_context(self,
|
| 137 |
+
tag: str) -> AbstractContextManager:
|
| 138 |
+
if self.vllm_config.model_config.enable_sleep_mode:
|
| 139 |
+
from vllm.device_allocator.cumem import CuMemAllocator
|
| 140 |
+
|
| 141 |
+
allocator = CuMemAllocator.get_instance()
|
| 142 |
+
if tag == "weights":
|
| 143 |
+
assert allocator.get_current_usage() == 0, (
|
| 144 |
+
"Sleep mode can only be "
|
| 145 |
+
"used for one instance per process.")
|
| 146 |
+
context = allocator.use_memory_pool(tag=tag)
|
| 147 |
+
else:
|
| 148 |
+
context = nullcontext()
|
| 149 |
+
return context
|
| 150 |
+
|
| 151 |
+
def initialize_cache(self, num_gpu_blocks: int,
|
| 152 |
+
num_cpu_blocks: int) -> None:
|
| 153 |
+
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
| 154 |
+
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
| 155 |
+
|
| 156 |
+
def init_device(self):
|
| 157 |
+
if self.device_config.device.type == "cuda":
|
| 158 |
+
# This env var set by Ray causes exceptions with graph building.
|
| 159 |
+
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
| 160 |
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
| 161 |
+
current_platform.set_device(self.device)
|
| 162 |
+
|
| 163 |
+
current_platform.check_if_supports_dtype(self.model_config.dtype)
|
| 164 |
+
|
| 165 |
+
# Initialize the distributed environment BEFORE taking
|
| 166 |
+
# memory snapshot
|
| 167 |
+
# This ensures NCCL buffers are allocated before we measure
|
| 168 |
+
# available memory
|
| 169 |
+
init_worker_distributed_environment(self.vllm_config, self.rank,
|
| 170 |
+
self.distributed_init_method,
|
| 171 |
+
self.local_rank,
|
| 172 |
+
current_platform.dist_backend)
|
| 173 |
+
|
| 174 |
+
# Set random seed.
|
| 175 |
+
set_random_seed(self.model_config.seed)
|
| 176 |
+
|
| 177 |
+
# Now take memory snapshot after NCCL is initialized
|
| 178 |
+
gc.collect()
|
| 179 |
+
torch.cuda.empty_cache()
|
| 180 |
+
|
| 181 |
+
# take current memory snapshot
|
| 182 |
+
self.init_snapshot = MemorySnapshot()
|
| 183 |
+
self.requested_memory = (self.init_snapshot.total_memory *
|
| 184 |
+
self.cache_config.gpu_memory_utilization)
|
| 185 |
+
if self.init_snapshot.free_memory < self.requested_memory:
|
| 186 |
+
GiB = lambda b: round(b / GiB_bytes, 2)
|
| 187 |
+
raise ValueError(
|
| 188 |
+
f"Free memory on device "
|
| 189 |
+
f"({GiB(self.init_snapshot.free_memory)}/"
|
| 190 |
+
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
|
| 191 |
+
f"is less than desired GPU memory utilization "
|
| 192 |
+
f"({self.cache_config.gpu_memory_utilization}, "
|
| 193 |
+
f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
|
| 194 |
+
f"utilization or reduce GPU memory used by other processes."
|
| 195 |
+
)
|
| 196 |
+
else:
|
| 197 |
+
raise RuntimeError(
|
| 198 |
+
f"Not support device type: {self.device_config.device}")
|
| 199 |
+
|
| 200 |
+
# Construct the model runner
|
| 201 |
+
self.model_runner: GPUModelRunner = GPUModelRunner(
|
| 202 |
+
self.vllm_config, self.device)
|
| 203 |
+
|
| 204 |
+
if self.rank == 0:
|
| 205 |
+
# If usage stat is enabled, collect relevant info.
|
| 206 |
+
report_usage_stats(self.vllm_config)
|
| 207 |
+
|
| 208 |
+
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
| 209 |
+
# to hijack tensor allocation.
|
| 210 |
+
def load_model(self) -> None:
|
| 211 |
+
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
| 212 |
+
with self._maybe_get_memory_pool_context(tag="weights"):
|
| 213 |
+
self.model_runner.load_model(eep_scale_up=eep_scale_up)
|
| 214 |
+
|
| 215 |
+
def update_config(self, overrides: dict[str, Any]) -> None:
|
| 216 |
+
self.model_runner.update_config(overrides)
|
| 217 |
+
|
| 218 |
+
def reload_weights(self) -> None:
|
| 219 |
+
self.model_runner.reload_weights()
|
| 220 |
+
|
| 221 |
+
@torch.inference_mode()
|
| 222 |
+
def determine_available_memory(self) -> int:
|
| 223 |
+
"""Profiles the peak memory usage of the model to determine how much
|
| 224 |
+
memory can be used for KV cache without OOMs.
|
| 225 |
+
|
| 226 |
+
The engine will first conduct a profiling of the existing memory usage.
|
| 227 |
+
Then, it calculates the free memory that can be used for KV cache in
|
| 228 |
+
bytes.
|
| 229 |
+
|
| 230 |
+
Tip:
|
| 231 |
+
You may limit the usage of GPU memory
|
| 232 |
+
by adjusting the `gpu_memory_utilization` parameter.
|
| 233 |
+
"""
|
| 234 |
+
GiB = lambda b: b / GiB_bytes
|
| 235 |
+
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
|
| 236 |
+
# still need a profile run which compiles the model for
|
| 237 |
+
# max_num_batched_tokens
|
| 238 |
+
self.model_runner.profile_run()
|
| 239 |
+
|
| 240 |
+
msg = (
|
| 241 |
+
f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
|
| 242 |
+
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
|
| 243 |
+
"KV Cache as specified by kv_cache_memory_bytes config and "
|
| 244 |
+
"skipped memory profiling. This does does not respect the "
|
| 245 |
+
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
|
| 246 |
+
"config when you want manual control of KV cache memory "
|
| 247 |
+
"size. If OOM'ed, check the difference of initial free "
|
| 248 |
+
"memory between the current run and the previous run "
|
| 249 |
+
"where kv_cache_memory_bytes is suggested and update it "
|
| 250 |
+
"correspondingly.")
|
| 251 |
+
logger.info(msg)
|
| 252 |
+
return kv_cache_memory_bytes
|
| 253 |
+
|
| 254 |
+
torch.cuda.empty_cache()
|
| 255 |
+
torch.cuda.reset_peak_memory_stats()
|
| 256 |
+
|
| 257 |
+
# Execute a forward pass with dummy inputs to profile the memory usage
|
| 258 |
+
# of the model.
|
| 259 |
+
with memory_profiling(
|
| 260 |
+
self.init_snapshot,
|
| 261 |
+
weights_memory=int(self.model_runner.model_memory_usage),
|
| 262 |
+
) as profile_result:
|
| 263 |
+
self.model_runner.profile_run()
|
| 264 |
+
|
| 265 |
+
self.non_torch_memory = profile_result.non_torch_increase
|
| 266 |
+
self.peak_activation_memory = profile_result.torch_peak_increase
|
| 267 |
+
|
| 268 |
+
free_gpu_memory = profile_result.after_profile.free_memory
|
| 269 |
+
# NOTE(woosuk): Here we assume that the other processes using the same
|
| 270 |
+
# GPU did not change their memory usage during the profiling.
|
| 271 |
+
assert self.init_snapshot.free_memory > free_gpu_memory, (
|
| 272 |
+
"Error in memory profiling. "
|
| 273 |
+
f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
|
| 274 |
+
f"current free memory {GiB(free_gpu_memory)} GiB. "
|
| 275 |
+
"This happens when other processes sharing the same container "
|
| 276 |
+
"release GPU memory while vLLM is profiling during initialization. "
|
| 277 |
+
"To fix this, ensure consistent GPU memory allocation or "
|
| 278 |
+
"isolate vLLM in its own container.")
|
| 279 |
+
self.available_kv_cache_memory_bytes = self.requested_memory \
|
| 280 |
+
- profile_result.non_kv_cache_memory
|
| 281 |
+
|
| 282 |
+
unrequested_memory = self.init_snapshot.free_memory \
|
| 283 |
+
- self.requested_memory
|
| 284 |
+
logger.debug(
|
| 285 |
+
"Initial free memory: %.2f GiB; "
|
| 286 |
+
"Requested memory: %.2f (util), %.2f GiB",
|
| 287 |
+
GiB(self.init_snapshot.free_memory),
|
| 288 |
+
self.cache_config.gpu_memory_utilization,
|
| 289 |
+
GiB(self.requested_memory),
|
| 290 |
+
)
|
| 291 |
+
logger.debug(
|
| 292 |
+
"Free memory after profiling: %.2f GiB (total), "
|
| 293 |
+
"%.2f GiB (within requested)",
|
| 294 |
+
GiB(free_gpu_memory),
|
| 295 |
+
GiB(free_gpu_memory - unrequested_memory),
|
| 296 |
+
)
|
| 297 |
+
logger.debug(profile_result)
|
| 298 |
+
logger.info("Available KV cache memory: %.2f GiB",
|
| 299 |
+
GiB(self.available_kv_cache_memory_bytes))
|
| 300 |
+
gc.collect()
|
| 301 |
+
|
| 302 |
+
return int(self.available_kv_cache_memory_bytes)
|
| 303 |
+
|
| 304 |
+
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
| 305 |
+
return self.model_runner.get_kv_cache_spec()
|
| 306 |
+
|
| 307 |
+
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
| 308 |
+
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
| 309 |
+
|
| 310 |
+
if self.vllm_config.model_config.enable_sleep_mode:
|
| 311 |
+
from vllm.device_allocator.cumem import CuMemAllocator
|
| 312 |
+
|
| 313 |
+
allocator = CuMemAllocator.get_instance()
|
| 314 |
+
context = allocator.use_memory_pool(tag="kv_cache")
|
| 315 |
+
else:
|
| 316 |
+
context = nullcontext()
|
| 317 |
+
with context:
|
| 318 |
+
self.model_runner.initialize_kv_cache(kv_cache_config)
|
| 319 |
+
|
| 320 |
+
def compile_or_warm_up_model(self) -> None:
|
| 321 |
+
# warm up sizes that are not in cudagraph capture sizes,
|
| 322 |
+
# but users still want to compile for better performance,
|
| 323 |
+
# e.g. for the max-num-batched token size in chunked prefill.
|
| 324 |
+
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
|
| 325 |
+
if not self.model_config.enforce_eager:
|
| 326 |
+
warmup_sizes = [
|
| 327 |
+
x for x in warmup_sizes if x not in
|
| 328 |
+
self.vllm_config.compilation_config.cudagraph_capture_sizes
|
| 329 |
+
]
|
| 330 |
+
# We skip EPLB here since we don't want to record dummy metrics
|
| 331 |
+
for size in sorted(warmup_sizes, reverse=True):
|
| 332 |
+
logger.info("Compile and warming up model for size %d", size)
|
| 333 |
+
self.model_runner._dummy_run(size,
|
| 334 |
+
skip_eplb=True,
|
| 335 |
+
remove_lora=False)
|
| 336 |
+
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
|
| 337 |
+
|
| 338 |
+
# Warmup and tune the kernels used during model execution before
|
| 339 |
+
# cuda graph capture.
|
| 340 |
+
kernel_warmup(self)
|
| 341 |
+
|
| 342 |
+
cuda_graph_memory_bytes = 0
|
| 343 |
+
if not self.model_config.enforce_eager:
|
| 344 |
+
cuda_graph_memory_bytes = self.model_runner.capture_model()
|
| 345 |
+
|
| 346 |
+
if (self.cache_config.kv_cache_memory_bytes is None
|
| 347 |
+
and hasattr(self, "peak_activation_memory")):
|
| 348 |
+
# Suggests optimal kv cache memory size if we rely on
|
| 349 |
+
# memory_profiling to guess the kv cache memory size which
|
| 350 |
+
# provides peak_activation_memory and a few other memory
|
| 351 |
+
# consumption. `memory_profiling` does not consider
|
| 352 |
+
# CUDAGraph memory size and may not utilize all gpu memory.
|
| 353 |
+
# Users may want fine-grained control to specify kv cache
|
| 354 |
+
# memory size.
|
| 355 |
+
GiB = lambda b: round(b / GiB_bytes, 2)
|
| 356 |
+
|
| 357 |
+
# empirically observed that the memory profiling may
|
| 358 |
+
# slightly underestimate the memory consumption.
|
| 359 |
+
# So leave a small buffer (=150MiB) to avoid OOM.
|
| 360 |
+
redundancy_buffer_memory = 150 * (1 << 20)
|
| 361 |
+
non_kv_cache_memory = (self.model_runner.model_memory_usage +
|
| 362 |
+
self.peak_activation_memory +
|
| 363 |
+
self.non_torch_memory +
|
| 364 |
+
cuda_graph_memory_bytes)
|
| 365 |
+
kv_cache_memory_bytes_to_gpu_limit = (
|
| 366 |
+
self.init_snapshot.free_memory - non_kv_cache_memory -
|
| 367 |
+
redundancy_buffer_memory)
|
| 368 |
+
kv_cache_memory_bytes_to_requested_limit = (
|
| 369 |
+
int(self.requested_memory) - non_kv_cache_memory -
|
| 370 |
+
redundancy_buffer_memory)
|
| 371 |
+
|
| 372 |
+
msg = (
|
| 373 |
+
f"Free memory on device "
|
| 374 |
+
f"({GiB(self.init_snapshot.free_memory)}/"
|
| 375 |
+
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
|
| 376 |
+
f"Desired GPU memory utilization is "
|
| 377 |
+
f"({self.cache_config.gpu_memory_utilization}, "
|
| 378 |
+
f"{GiB(self.requested_memory)} GiB). "
|
| 379 |
+
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
|
| 380 |
+
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
|
| 381 |
+
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
|
| 382 |
+
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
|
| 383 |
+
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
|
| 384 |
+
f"config with `--kv-cache-memory="
|
| 385 |
+
f"{kv_cache_memory_bytes_to_requested_limit}` "
|
| 386 |
+
f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
|
| 387 |
+
f"into requested memory, or `--kv-cache-memory="
|
| 388 |
+
f"{kv_cache_memory_bytes_to_gpu_limit}` "
|
| 389 |
+
f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
|
| 390 |
+
f"utilize gpu memory. Current kv cache memory in use is "
|
| 391 |
+
f"{GiB(self.available_kv_cache_memory_bytes)} GiB.")
|
| 392 |
+
|
| 393 |
+
logger.debug(msg)
|
| 394 |
+
|
| 395 |
+
# Warm up sampler and preallocate memory buffer for logits and other
|
| 396 |
+
# sampling related tensors of max possible shape to avoid memory
|
| 397 |
+
# fragmentation issue.
|
| 398 |
+
# NOTE: This is called after `capture_model` on purpose to prevent
|
| 399 |
+
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
| 400 |
+
if get_pp_group().is_last_rank:
|
| 401 |
+
max_num_reqs = min(self.scheduler_config.max_num_seqs,
|
| 402 |
+
self.scheduler_config.max_num_batched_tokens)
|
| 403 |
+
|
| 404 |
+
# We skip EPLB here since we don't want to record dummy metrics
|
| 405 |
+
hidden_states, last_hidden_states = \
|
| 406 |
+
self.model_runner._dummy_run(
|
| 407 |
+
num_tokens=max_num_reqs,
|
| 408 |
+
skip_eplb=True,
|
| 409 |
+
)
|
| 410 |
+
if self.model_runner.is_pooling_model:
|
| 411 |
+
self.model_runner._dummy_pooler_run(hidden_states)
|
| 412 |
+
else:
|
| 413 |
+
self.model_runner._dummy_sampler_run(
|
| 414 |
+
hidden_states=last_hidden_states)
|
| 415 |
+
|
| 416 |
+
# Reset the seed to ensure that the random state is not affected by
|
| 417 |
+
# the model initialization and profiling.
|
| 418 |
+
set_random_seed(self.model_config.seed)
|
| 419 |
+
|
| 420 |
+
def get_model(self) -> nn.Module:
|
| 421 |
+
return self.model_runner.get_model()
|
| 422 |
+
|
| 423 |
+
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
| 424 |
+
return self.model_runner.get_supported_tasks()
|
| 425 |
+
|
| 426 |
+
@torch.inference_mode()
|
| 427 |
+
def execute_model(
|
| 428 |
+
self,
|
| 429 |
+
scheduler_output: "SchedulerOutput",
|
| 430 |
+
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
|
| 431 |
+
intermediate_tensors = None
|
| 432 |
+
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
| 433 |
+
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
| 434 |
+
num_input_tokens = self.model_runner._get_num_input_tokens(
|
| 435 |
+
num_scheduled_tokens)
|
| 436 |
+
all_gather_tensors = {
|
| 437 |
+
"residual":
|
| 438 |
+
not is_residual_scattered_for_sp(self.vllm_config,
|
| 439 |
+
num_input_tokens)
|
| 440 |
+
}
|
| 441 |
+
if forward_pass and not get_pp_group().is_first_rank:
|
| 442 |
+
intermediate_tensors = IntermediateTensors(
|
| 443 |
+
get_pp_group().recv_tensor_dict(
|
| 444 |
+
all_gather_group=get_tp_group(),
|
| 445 |
+
all_gather_tensors=all_gather_tensors))
|
| 446 |
+
|
| 447 |
+
output = self.model_runner.execute_model(scheduler_output,
|
| 448 |
+
intermediate_tensors)
|
| 449 |
+
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
|
| 450 |
+
return output
|
| 451 |
+
|
| 452 |
+
assert isinstance(output, IntermediateTensors)
|
| 453 |
+
parallel_config = self.vllm_config.parallel_config
|
| 454 |
+
assert parallel_config.distributed_executor_backend != (
|
| 455 |
+
"external_launcher") and not get_pp_group().is_last_rank
|
| 456 |
+
|
| 457 |
+
get_pp_group().send_tensor_dict(output.tensors,
|
| 458 |
+
all_gather_group=get_tp_group(),
|
| 459 |
+
all_gather_tensors=all_gather_tensors)
|
| 460 |
+
|
| 461 |
+
kv_connector_output = output.kv_connector_output
|
| 462 |
+
if not kv_connector_output:
|
| 463 |
+
return None
|
| 464 |
+
|
| 465 |
+
# In case of PP with kv transfer, we need to pass through the
|
| 466 |
+
# kv_connector_output
|
| 467 |
+
if (not kv_connector_output.finished_sending
|
| 468 |
+
and not kv_connector_output.finished_recving):
|
| 469 |
+
return EMPTY_MODEL_RUNNER_OUTPUT
|
| 470 |
+
|
| 471 |
+
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
| 472 |
+
output.kv_connector_output = kv_connector_output
|
| 473 |
+
return output
|
| 474 |
+
|
| 475 |
+
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
| 476 |
+
return self.model_runner.take_draft_token_ids()
|
| 477 |
+
|
| 478 |
+
def profile(self, is_start: bool = True):
|
| 479 |
+
if self.profiler is None:
|
| 480 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 481 |
+
if is_start:
|
| 482 |
+
self.profiler.start()
|
| 483 |
+
else:
|
| 484 |
+
self.profiler.stop()
|
| 485 |
+
# only print profiler results on rank 0
|
| 486 |
+
if self.local_rank == 0:
|
| 487 |
+
print(self.profiler.key_averages().table(
|
| 488 |
+
sort_by="self_cuda_time_total"))
|
| 489 |
+
|
| 490 |
+
def execute_dummy_batch(self) -> None:
|
| 491 |
+
self.model_runner._dummy_run(1, uniform_decode=True)
|
| 492 |
+
|
| 493 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
| 494 |
+
return self.model_runner.add_lora(lora_request)
|
| 495 |
+
|
| 496 |
+
def remove_lora(self, lora_id: int) -> bool:
|
| 497 |
+
return self.model_runner.remove_lora(lora_id)
|
| 498 |
+
|
| 499 |
+
def list_loras(self) -> set[int]:
|
| 500 |
+
return self.model_runner.list_loras()
|
| 501 |
+
|
| 502 |
+
def pin_lora(self, lora_id: int) -> bool:
|
| 503 |
+
return self.model_runner.pin_lora(lora_id)
|
| 504 |
+
|
| 505 |
+
def check_health(self) -> None:
|
| 506 |
+
# worker will always be healthy as long as it's running.
|
| 507 |
+
return
|
| 508 |
+
|
| 509 |
+
def _eplb_before_scale_down(self, old_ep_size: int,
|
| 510 |
+
new_ep_size: int) -> None:
|
| 511 |
+
from vllm.distributed.parallel_state import get_ep_group
|
| 512 |
+
if get_ep_group().rank == 0:
|
| 513 |
+
logger.info("[Elastic EP] Starting expert resharding "
|
| 514 |
+
"before scaling down...")
|
| 515 |
+
rank_mapping = {
|
| 516 |
+
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
| 517 |
+
for old_ep_rank in range(old_ep_size)
|
| 518 |
+
}
|
| 519 |
+
assert self.model_runner.eplb_state is not None
|
| 520 |
+
self.model_runner.eplb_state.rearrange(self.model_runner.model,
|
| 521 |
+
execute_shuffle=True,
|
| 522 |
+
global_expert_load=None,
|
| 523 |
+
rank_mapping=rank_mapping)
|
| 524 |
+
torch.cuda.synchronize()
|
| 525 |
+
if get_ep_group().rank == 0:
|
| 526 |
+
logger.info("[Elastic EP] Expert resharding completed!")
|
| 527 |
+
|
| 528 |
+
def _eplb_after_scale_up(
|
| 529 |
+
self, old_ep_size: int, new_ep_size: int,
|
| 530 |
+
global_expert_load: Optional[torch.Tensor]) -> None:
|
| 531 |
+
from vllm.distributed.parallel_state import get_ep_group
|
| 532 |
+
if get_ep_group().rank == 0:
|
| 533 |
+
logger.info("[Elastic EP] Starting expert resharding "
|
| 534 |
+
"after scaling up...")
|
| 535 |
+
rank_mapping = {
|
| 536 |
+
old_ep_rank: old_ep_rank
|
| 537 |
+
for old_ep_rank in range(old_ep_size)
|
| 538 |
+
}
|
| 539 |
+
assert self.model_runner.eplb_state is not None
|
| 540 |
+
self.model_runner.eplb_state.rearrange(
|
| 541 |
+
self.model_runner.model,
|
| 542 |
+
execute_shuffle=True,
|
| 543 |
+
global_expert_load=global_expert_load,
|
| 544 |
+
rank_mapping=rank_mapping)
|
| 545 |
+
if get_ep_group().rank == 0:
|
| 546 |
+
logger.info("[Elastic EP] Expert resharding completed!")
|
| 547 |
+
|
| 548 |
+
def _reconfigure_parallel_config(
|
| 549 |
+
self, reconfig_request: ReconfigureDistributedRequest) -> None:
|
| 550 |
+
"""
|
| 551 |
+
Update parallel config with provided reconfig_request
|
| 552 |
+
"""
|
| 553 |
+
parallel_config = self.vllm_config.parallel_config
|
| 554 |
+
parallel_config.data_parallel_size = \
|
| 555 |
+
reconfig_request.new_data_parallel_size
|
| 556 |
+
if reconfig_request.new_data_parallel_rank != \
|
| 557 |
+
ReconfigureRankType.KEEP_CURRENT_RANK:
|
| 558 |
+
parallel_config.data_parallel_rank = \
|
| 559 |
+
reconfig_request.new_data_parallel_rank
|
| 560 |
+
if reconfig_request.new_data_parallel_rank_local != \
|
| 561 |
+
ReconfigureRankType.KEEP_CURRENT_RANK:
|
| 562 |
+
parallel_config.data_parallel_rank_local = \
|
| 563 |
+
reconfig_request.new_data_parallel_rank_local
|
| 564 |
+
parallel_config.data_parallel_master_ip = \
|
| 565 |
+
reconfig_request.new_data_parallel_master_ip
|
| 566 |
+
parallel_config.data_parallel_master_port = \
|
| 567 |
+
reconfig_request.new_data_parallel_master_port
|
| 568 |
+
|
| 569 |
+
def _reconfigure_moe(self, old_ep_size: int,
|
| 570 |
+
new_ep_size: int) -> Optional[torch.Tensor]:
|
| 571 |
+
"""
|
| 572 |
+
Reconfigure MoE modules with provided reconfig_request
|
| 573 |
+
|
| 574 |
+
Return the global expert load if new_ep_size > old_ep_size,
|
| 575 |
+
otherwise None
|
| 576 |
+
"""
|
| 577 |
+
from vllm.distributed.parallel_state import (
|
| 578 |
+
get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
|
| 579 |
+
from vllm.model_executor.layers.fused_moe.layer import (
|
| 580 |
+
FusedMoEParallelConfig)
|
| 581 |
+
|
| 582 |
+
parallel_config = self.vllm_config.parallel_config
|
| 583 |
+
moe_modules = [
|
| 584 |
+
module for module in self.model_runner.model.modules()
|
| 585 |
+
if (module.__class__.__name__ == "FusedMoE"
|
| 586 |
+
or module.__class__.__name__ == "SharedFusedMoE")
|
| 587 |
+
]
|
| 588 |
+
num_local_experts = moe_modules[0].moe_config.num_local_experts
|
| 589 |
+
assert all(module.moe_config.num_local_experts == num_local_experts
|
| 590 |
+
for module in moe_modules), (
|
| 591 |
+
"All MoE modules must have the same number of experts")
|
| 592 |
+
for module in moe_modules:
|
| 593 |
+
module.moe_config.num_experts = num_local_experts * new_ep_size
|
| 594 |
+
module.global_num_experts = module.moe_config.num_experts
|
| 595 |
+
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
| 596 |
+
tp_size_=get_tp_group().world_size,
|
| 597 |
+
dp_size_=get_dp_group().world_size,
|
| 598 |
+
vllm_parallel_config=parallel_config,
|
| 599 |
+
)
|
| 600 |
+
module.moe_config.moe_parallel_config = module.moe_parallel_config
|
| 601 |
+
if new_ep_size < old_ep_size:
|
| 602 |
+
num_local_physical_experts = num_local_experts
|
| 603 |
+
assert self.model_runner.eplb_state is not None
|
| 604 |
+
new_physical_experts = \
|
| 605 |
+
self.model_runner.eplb_state.physical_to_logical_map.shape[1]
|
| 606 |
+
parallel_config.eplb_config.num_redundant_experts = (
|
| 607 |
+
new_physical_experts -
|
| 608 |
+
self.model_runner.eplb_state.logical_replica_count.shape[1])
|
| 609 |
+
global_expert_load = None
|
| 610 |
+
else:
|
| 611 |
+
num_local_physical_experts = torch.tensor([num_local_experts],
|
| 612 |
+
dtype=torch.int32,
|
| 613 |
+
device="cpu")
|
| 614 |
+
torch.distributed.broadcast(num_local_physical_experts,
|
| 615 |
+
group=get_ep_group().cpu_group,
|
| 616 |
+
group_src=0)
|
| 617 |
+
num_local_physical_experts = num_local_physical_experts.item()
|
| 618 |
+
new_physical_experts = num_local_physical_experts * new_ep_size
|
| 619 |
+
assert self.model_runner.eplb_state is not None
|
| 620 |
+
global_expert_load = self.model_runner.eplb_state.rearrange(
|
| 621 |
+
self.model_runner.model, execute_shuffle=False)
|
| 622 |
+
parallel_config.eplb_config.num_redundant_experts = (
|
| 623 |
+
new_physical_experts - global_expert_load.shape[1])
|
| 624 |
+
prepare_communication_buffer_for_model(self.model_runner.model)
|
| 625 |
+
self.model_runner.model.update_physical_experts_metadata(
|
| 626 |
+
num_physical_experts=new_physical_experts,
|
| 627 |
+
num_local_physical_experts=num_local_physical_experts)
|
| 628 |
+
return global_expert_load
|
| 629 |
+
|
| 630 |
+
def reinitialize_distributed(
|
| 631 |
+
self, reconfig_request: ReconfigureDistributedRequest) -> None:
|
| 632 |
+
from vllm.config import set_current_vllm_config
|
| 633 |
+
from vllm.distributed.parallel_state import (
|
| 634 |
+
cleanup_dist_env_and_memory, get_ep_group)
|
| 635 |
+
|
| 636 |
+
old_ep_size = get_ep_group().world_size
|
| 637 |
+
old_ep_rank = get_ep_group().rank
|
| 638 |
+
new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group(
|
| 639 |
+
).world_size * get_pp_group().world_size
|
| 640 |
+
if new_ep_size < old_ep_size:
|
| 641 |
+
self._eplb_before_scale_down(old_ep_size, new_ep_size)
|
| 642 |
+
|
| 643 |
+
cleanup_dist_env_and_memory()
|
| 644 |
+
|
| 645 |
+
if reconfig_request.new_data_parallel_rank == \
|
| 646 |
+
ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
|
| 647 |
+
assert old_ep_rank >= new_ep_size
|
| 648 |
+
# shutdown
|
| 649 |
+
return
|
| 650 |
+
|
| 651 |
+
self._reconfigure_parallel_config(reconfig_request)
|
| 652 |
+
|
| 653 |
+
with set_current_vllm_config(self.vllm_config):
|
| 654 |
+
init_worker_distributed_environment(self.vllm_config, self.rank,
|
| 655 |
+
self.distributed_init_method,
|
| 656 |
+
self.local_rank)
|
| 657 |
+
|
| 658 |
+
global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)
|
| 659 |
+
|
| 660 |
+
if new_ep_size > old_ep_size:
|
| 661 |
+
assert global_expert_load is not None
|
| 662 |
+
self._eplb_after_scale_up(old_ep_size, new_ep_size,
|
| 663 |
+
global_expert_load)
|
| 664 |
+
|
| 665 |
+
def save_sharded_state(
|
| 666 |
+
self,
|
| 667 |
+
path: str,
|
| 668 |
+
pattern: Optional[str] = None,
|
| 669 |
+
max_size: Optional[int] = None,
|
| 670 |
+
) -> None:
|
| 671 |
+
from vllm.model_executor.model_loader import ShardedStateLoader
|
| 672 |
+
ShardedStateLoader.save_model(
|
| 673 |
+
self.model_runner.model,
|
| 674 |
+
path,
|
| 675 |
+
pattern=pattern,
|
| 676 |
+
max_size=max_size,
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
def save_tensorized_model(
|
| 680 |
+
self,
|
| 681 |
+
tensorizer_config: "TensorizerConfig",
|
| 682 |
+
) -> None:
|
| 683 |
+
self.model_runner.save_tensorized_model(
|
| 684 |
+
tensorizer_config=tensorizer_config, )
|
| 685 |
+
|
| 686 |
+
def shutdown(self) -> None:
|
| 687 |
+
if runner := getattr(self, "model_runner", None):
|
| 688 |
+
runner.ensure_kv_transfer_shutdown()
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def init_worker_distributed_environment(
|
| 692 |
+
vllm_config: VllmConfig,
|
| 693 |
+
rank: int,
|
| 694 |
+
distributed_init_method: Optional[str] = None,
|
| 695 |
+
local_rank: int = -1,
|
| 696 |
+
backend: str = "nccl",
|
| 697 |
+
) -> None:
|
| 698 |
+
"""Initialize the distributed environment."""
|
| 699 |
+
parallel_config = vllm_config.parallel_config
|
| 700 |
+
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
| 701 |
+
|
| 702 |
+
init_distributed_environment(parallel_config.world_size, rank,
|
| 703 |
+
distributed_init_method, local_rank, backend)
|
| 704 |
+
|
| 705 |
+
ensure_model_parallel_initialized(
|
| 706 |
+
parallel_config.tensor_parallel_size,
|
| 707 |
+
parallel_config.pipeline_parallel_size,
|
| 708 |
+
parallel_config.decode_context_parallel_size)
|
| 709 |
+
|
| 710 |
+
ensure_kv_transfer_initialized(vllm_config)
|
vllm_hacked/worker_base.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar,
|
| 6 |
+
Union)
|
| 7 |
+
|
| 8 |
+
import cloudpickle
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from vllm.config import VllmConfig, set_current_vllm_config
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
from vllm.lora.request import LoRARequest
|
| 14 |
+
from vllm.sequence import ExecuteModelRequest
|
| 15 |
+
from vllm.utils import (enable_trace_function_call_for_thread,
|
| 16 |
+
resolve_obj_by_qualname, run_method,
|
| 17 |
+
update_environment_variables,
|
| 18 |
+
warn_for_unimplemented_methods)
|
| 19 |
+
from vllm.v1.outputs import SamplerOutput
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
_R = TypeVar("_R")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@warn_for_unimplemented_methods
|
| 27 |
+
class WorkerBase:
|
| 28 |
+
"""Worker interface that allows vLLM to cleanly separate implementations for
|
| 29 |
+
different hardware. Also abstracts control plane communication, e.g., to
|
| 30 |
+
communicate request metadata to other workers.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
vllm_config: VllmConfig,
|
| 36 |
+
) -> None:
|
| 37 |
+
self.vllm_config = vllm_config
|
| 38 |
+
self.model_config = vllm_config.model_config
|
| 39 |
+
self.cache_config = vllm_config.cache_config
|
| 40 |
+
self.lora_config = vllm_config.lora_config
|
| 41 |
+
self.load_config = vllm_config.load_config
|
| 42 |
+
self.parallel_config = vllm_config.parallel_config
|
| 43 |
+
self.scheduler_config = vllm_config.scheduler_config
|
| 44 |
+
self.device_config = vllm_config.device_config
|
| 45 |
+
self.speculative_config = vllm_config.speculative_config
|
| 46 |
+
self.observability_config = vllm_config.observability_config
|
| 47 |
+
self.kv_transfer_config = vllm_config.kv_transfer_config
|
| 48 |
+
self.compilation_config = vllm_config.compilation_config
|
| 49 |
+
from vllm.platforms import current_platform
|
| 50 |
+
self.current_platform = current_platform
|
| 51 |
+
|
| 52 |
+
def init_device(self) -> None:
|
| 53 |
+
"""Initialize device state, such as loading the model or other on-device
|
| 54 |
+
memory allocations.
|
| 55 |
+
"""
|
| 56 |
+
raise NotImplementedError
|
| 57 |
+
|
| 58 |
+
def initialize_cache(self, num_gpu_blocks: int,
|
| 59 |
+
num_cpu_blocks: int) -> None:
|
| 60 |
+
"""Initialize the KV cache with the given size in blocks.
|
| 61 |
+
"""
|
| 62 |
+
raise NotImplementedError
|
| 63 |
+
|
| 64 |
+
def get_model(self) -> nn.Module:
|
| 65 |
+
raise NotImplementedError
|
| 66 |
+
|
| 67 |
+
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
|
| 68 |
+
"""Apply a function on the model inside this worker."""
|
| 69 |
+
return fn(self.get_model())
|
| 70 |
+
|
| 71 |
+
def load_model(self) -> None:
|
| 72 |
+
"""Load model onto target device."""
|
| 73 |
+
raise NotImplementedError
|
| 74 |
+
|
| 75 |
+
def execute_model(
|
| 76 |
+
self,
|
| 77 |
+
execute_model_req: Optional[ExecuteModelRequest] = None
|
| 78 |
+
) -> Optional[List[SamplerOutput]]:
|
| 79 |
+
raise NotImplementedError
|
| 80 |
+
|
| 81 |
+
def start_worker_execution_loop(self) -> None:
|
| 82 |
+
"""Execute model loop in parallel worker.
|
| 83 |
+
|
| 84 |
+
You can stop the loop by executing a driver worker with an empty output.
|
| 85 |
+
See `stop_remote_worker_execution_loop` for more details.
|
| 86 |
+
"""
|
| 87 |
+
with self.current_platform.inference_mode():
|
| 88 |
+
while True:
|
| 89 |
+
output = self.execute_model(execute_model_req=None)
|
| 90 |
+
if output is None:
|
| 91 |
+
return None
|
| 92 |
+
|
| 93 |
+
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
| 94 |
+
"""Determine the number of available blocks for the GPU KV cache and
|
| 95 |
+
swappable CPU KV cache.
|
| 96 |
+
|
| 97 |
+
The implementation may run profiling or other heuristics to determine
|
| 98 |
+
the size of caches.
|
| 99 |
+
|
| 100 |
+
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
| 101 |
+
are blocks that are "active" on the device and can be appended to.
|
| 102 |
+
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
| 103 |
+
appended to.
|
| 104 |
+
"""
|
| 105 |
+
raise NotImplementedError
|
| 106 |
+
|
| 107 |
+
def get_cache_block_size_bytes(self) -> int:
|
| 108 |
+
"""Return the size of a single cache block, in bytes. Used in
|
| 109 |
+
speculative decoding.
|
| 110 |
+
"""
|
| 111 |
+
raise NotImplementedError
|
| 112 |
+
|
| 113 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
| 114 |
+
raise NotImplementedError
|
| 115 |
+
|
| 116 |
+
def remove_lora(self, lora_id: int) -> bool:
|
| 117 |
+
raise NotImplementedError
|
| 118 |
+
|
| 119 |
+
def pin_lora(self, lora_id: int) -> bool:
|
| 120 |
+
raise NotImplementedError
|
| 121 |
+
|
| 122 |
+
def list_loras(self) -> Set[int]:
|
| 123 |
+
raise NotImplementedError
|
| 124 |
+
|
| 125 |
+
@property
|
| 126 |
+
def vocab_size(self) -> int:
|
| 127 |
+
"""Get vocabulary size from model configuration."""
|
| 128 |
+
return self.model_config.get_vocab_size()
|
| 129 |
+
|
| 130 |
+
def shutdown(self) -> None:
|
| 131 |
+
"""Clean up resources held by the worker."""
|
| 132 |
+
return
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class WorkerWrapperBase:
|
| 136 |
+
"""
|
| 137 |
+
This class represents one process in an executor/engine. It is responsible
|
| 138 |
+
for lazily initializing the worker and handling the worker's lifecycle.
|
| 139 |
+
We first instantiate the WorkerWrapper, which remembers the worker module
|
| 140 |
+
and class name. Then, when we call `update_environment_variables`, and the
|
| 141 |
+
real initialization happens in `init_worker`.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
vllm_config: VllmConfig,
|
| 147 |
+
rpc_rank: int = 0,
|
| 148 |
+
) -> None:
|
| 149 |
+
"""
|
| 150 |
+
Initialize the worker wrapper with the given vllm_config and rpc_rank.
|
| 151 |
+
Note: rpc_rank is the rank of the worker in the executor. In most cases,
|
| 152 |
+
it is also the rank of the worker in the distributed group. However,
|
| 153 |
+
when multiple executors work together, they can be different.
|
| 154 |
+
e.g. in the case of SPMD-style offline inference with TP=2,
|
| 155 |
+
users can launch 2 engines/executors, each with only 1 worker.
|
| 156 |
+
All workers have rpc_rank=0, but they have different ranks in the TP
|
| 157 |
+
group.
|
| 158 |
+
"""
|
| 159 |
+
self.rpc_rank = rpc_rank
|
| 160 |
+
self.worker: Optional[WorkerBase] = None
|
| 161 |
+
self.vllm_config: Optional[VllmConfig] = None
|
| 162 |
+
# do not store this `vllm_config`, `init_worker` will set the final
|
| 163 |
+
# one. TODO: investigate if we can remove this field in
|
| 164 |
+
# `WorkerWrapperBase`, `init_cached_hf_modules` should be
|
| 165 |
+
# unnecessary now.
|
| 166 |
+
if vllm_config.model_config is not None:
|
| 167 |
+
# it can be None in tests
|
| 168 |
+
trust_remote_code = vllm_config.model_config.trust_remote_code
|
| 169 |
+
if trust_remote_code:
|
| 170 |
+
# note: lazy import to avoid importing torch before initializing
|
| 171 |
+
from vllm.utils import init_cached_hf_modules
|
| 172 |
+
init_cached_hf_modules()
|
| 173 |
+
|
| 174 |
+
def shutdown(self) -> None:
|
| 175 |
+
if self.worker is not None:
|
| 176 |
+
self.worker.shutdown()
|
| 177 |
+
|
| 178 |
+
def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
|
| 179 |
+
"""
|
| 180 |
+
Adjust the rpc_rank based on the given mapping.
|
| 181 |
+
It is only used during the initialization of the executor,
|
| 182 |
+
to adjust the rpc_rank of workers after we create all workers.
|
| 183 |
+
"""
|
| 184 |
+
if self.rpc_rank in rank_mapping:
|
| 185 |
+
self.rpc_rank = rank_mapping[self.rpc_rank]
|
| 186 |
+
|
| 187 |
+
def update_environment_variables(self, envs_list: List[Dict[str,
|
| 188 |
+
str]]) -> None:
|
| 189 |
+
envs = envs_list[self.rpc_rank]
|
| 190 |
+
key = 'CUDA_VISIBLE_DEVICES'
|
| 191 |
+
if key in envs and key in os.environ:
|
| 192 |
+
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
|
| 193 |
+
# suppress the warning in `update_environment_variables`
|
| 194 |
+
del os.environ[key]
|
| 195 |
+
update_environment_variables(envs)
|
| 196 |
+
|
| 197 |
+
def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
|
| 198 |
+
"""
|
| 199 |
+
Here we inject some common logic before initializing the worker.
|
| 200 |
+
Arguments are passed to the worker class constructor.
|
| 201 |
+
"""
|
| 202 |
+
kwargs = all_kwargs[self.rpc_rank]
|
| 203 |
+
self.vllm_config = kwargs.get("vllm_config")
|
| 204 |
+
assert self.vllm_config is not None, (
|
| 205 |
+
"vllm_config is required to initialize the worker")
|
| 206 |
+
enable_trace_function_call_for_thread(self.vllm_config)
|
| 207 |
+
|
| 208 |
+
from vllm.plugins import load_general_plugins
|
| 209 |
+
load_general_plugins()
|
| 210 |
+
|
| 211 |
+
if isinstance(self.vllm_config.parallel_config.worker_cls, str):
|
| 212 |
+
worker_class = resolve_obj_by_qualname(
|
| 213 |
+
self.vllm_config.parallel_config.worker_cls)
|
| 214 |
+
else:
|
| 215 |
+
logger.warning(
|
| 216 |
+
"passing worker_cls as a class object is strongly deprecated,"
|
| 217 |
+
" as the serialization of class objects can be tricky and"
|
| 218 |
+
" error-prone. To be safe, please keep the class in a separate"
|
| 219 |
+
" module and pass the qualified name of the class as a string."
|
| 220 |
+
)
|
| 221 |
+
assert isinstance(self.vllm_config.parallel_config.worker_cls,
|
| 222 |
+
bytes)
|
| 223 |
+
worker_class = cloudpickle.loads(
|
| 224 |
+
self.vllm_config.parallel_config.worker_cls)
|
| 225 |
+
if self.vllm_config.parallel_config.worker_extension_cls:
|
| 226 |
+
worker_extension_cls = resolve_obj_by_qualname(
|
| 227 |
+
self.vllm_config.parallel_config.worker_extension_cls)
|
| 228 |
+
extended_calls = []
|
| 229 |
+
if worker_extension_cls not in worker_class.__bases__:
|
| 230 |
+
# check any conflicts between worker and worker_extension_cls
|
| 231 |
+
for attr in dir(worker_extension_cls):
|
| 232 |
+
if attr.startswith("__"):
|
| 233 |
+
continue
|
| 234 |
+
assert not hasattr(worker_class, attr), (
|
| 235 |
+
f"Worker class {worker_class} already has an attribute"
|
| 236 |
+
f" {attr}, which conflicts with the worker"
|
| 237 |
+
f" extension class {worker_extension_cls}.")
|
| 238 |
+
if callable(getattr(worker_extension_cls, attr)):
|
| 239 |
+
extended_calls.append(attr)
|
| 240 |
+
# dynamically inherit the worker extension class
|
| 241 |
+
worker_class.__bases__ = worker_class.__bases__ + (
|
| 242 |
+
worker_extension_cls, )
|
| 243 |
+
logger.info(
|
| 244 |
+
"Injected %s into %s for extended collective_rpc calls %s",
|
| 245 |
+
worker_extension_cls, worker_class, extended_calls)
|
| 246 |
+
with set_current_vllm_config(self.vllm_config):
|
| 247 |
+
# To make vLLM config available during worker initialization
|
| 248 |
+
self.worker = worker_class(**kwargs)
|
| 249 |
+
assert self.worker is not None
|
| 250 |
+
|
| 251 |
+
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
|
| 252 |
+
kv_cache_config = kv_cache_configs[self.rpc_rank]
|
| 253 |
+
with set_current_vllm_config(self.vllm_config):
|
| 254 |
+
self.worker.initialize_from_config(kv_cache_config) # type: ignore
|
| 255 |
+
|
| 256 |
+
def init_device(self):
|
| 257 |
+
with set_current_vllm_config(self.vllm_config):
|
| 258 |
+
# To make vLLM config available during device initialization
|
| 259 |
+
self.worker.init_device() # type: ignore
|
| 260 |
+
|
| 261 |
+
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
|
| 262 |
+
try:
|
| 263 |
+
# method resolution order:
|
| 264 |
+
# if a method is defined in this class, it will be called directly.
|
| 265 |
+
# otherwise, since we define `__getattr__` and redirect attribute
|
| 266 |
+
# query to `self.worker`, the method will be called on the worker.
|
| 267 |
+
return run_method(self, method, args, kwargs)
|
| 268 |
+
except Exception as e:
|
| 269 |
+
# if the driver worker also execute methods,
|
| 270 |
+
# exceptions in the rest worker may cause deadlock in rpc like ray
|
| 271 |
+
# see https://github.com/vllm-project/vllm/issues/3455
|
| 272 |
+
# print the error and inform the user to solve the error
|
| 273 |
+
msg = (f"Error executing method {method!r}. "
|
| 274 |
+
"This might cause deadlock in distributed execution.")
|
| 275 |
+
logger.exception(msg)
|
| 276 |
+
raise e
|
| 277 |
+
|
| 278 |
+
def __getattr__(self, attr):
|
| 279 |
+
return getattr(self.worker, attr)
|
z_script.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
| 1 |
-
from hmac import new
|
| 2 |
-
import sys
|
| 3 |
-
import os
|
| 4 |
-
import argparse
|
| 5 |
-
from safetensors.torch import save_file
|
| 6 |
-
|
| 7 |
-
import time
|
| 8 |
-
import json
|
| 9 |
-
import torch
|
| 10 |
-
import torchaudio
|
| 11 |
-
import numpy as np
|
| 12 |
-
from omegaconf import OmegaConf
|
| 13 |
-
from codeclm.models import builders
|
| 14 |
-
import gc
|
| 15 |
-
from codeclm.trainer.codec_song_pl import CodecLM_PL
|
| 16 |
-
from codeclm.models import CodecLM
|
| 17 |
-
from third_party.demucs.models.pretrained import get_model_from_yaml
|
| 18 |
-
|
| 19 |
-
cfg_path = "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/songgeneration_base/config.yaml"
|
| 20 |
-
cfg = OmegaConf.load(cfg_path)
|
| 21 |
-
cfg.mode = 'inference'
|
| 22 |
-
# audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
|
| 23 |
-
# model = audio_tokenizer.model.model
|
| 24 |
-
# weights = {k: v.half() for k, v in model.state_dict().items() if isinstance(v, torch.Tensor) and v.numel() > 0}
|
| 25 |
-
# save_file(weights, '/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/encoder_fp16.safetensors')
|
| 26 |
-
# print(weights)
|
| 27 |
-
|
| 28 |
-
# seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
|
| 29 |
-
# model = seperate_tokenizer.model.model
|
| 30 |
-
# weights = {}
|
| 31 |
-
# for k, v in model.state_dict().items():
|
| 32 |
-
# if k.startswith("rvq_bestrq_bgm_emb") or k.startswith("rvq_bestrq_emb") or k.startswith("bestrq"):
|
| 33 |
-
# weights[k] = v.half()
|
| 34 |
-
# else:
|
| 35 |
-
# weights[k] = v
|
| 36 |
-
# # weights = {k: v.half() for k, v in model.state_dict().items() if isinstance(v, torch.Tensor) and v.numel() > 0}
|
| 37 |
-
# save_file(weights, '/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/encoder_fp16.safetensors')
|
| 38 |
-
# print(weights.keys())
|
| 39 |
-
|
| 40 |
-
ckpt_path = "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model_32.pt"
|
| 41 |
-
# audiolm = builders.get_lm_model(cfg)
|
| 42 |
-
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
| 43 |
-
audiolm_state_dict = {k: v.half() for k, v in checkpoint.items()}
|
| 44 |
-
torch.save(audiolm_state_dict, "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model.pt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|