root commited on
Commit
c8c0ef5
·
1 Parent(s): 57d225d

push to levo2.0

Browse files
Files changed (41) hide show
  1. Dockerfile +5 -1
  2. app.py +12 -13
  3. codeclm/models/builders.py +1 -1
  4. codeclm/models/codeclm_gen.py +326 -0
  5. codeclm/models/levo.py +2 -2
  6. codeclm/models/llama/modeling_llama.py +4 -1
  7. codeclm/modules/conditioners.py +29 -37
  8. codeclm/tokenizer/Flow1dVAE/generate_1rvq.py +28 -56
  9. codeclm/tokenizer/Flow1dVAE/model_1rvq.py +10 -29
  10. codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config.py +55 -0
  11. codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_reflow.py +2 -2
  12. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/features.py +14 -5
  13. codeclm/tokenizer/audio_tokenizer.py +2 -2
  14. generate.py +106 -500
  15. generate.sh +9 -64
  16. levo_inference.py +57 -50
  17. requirements.txt +1 -0
  18. sample/lyrics.jsonl +2 -3
  19. vllm_hacked/model_executor/layers/utils.py +196 -0
  20. vllm_hacked/model_executor/layers/utils_ori.py +195 -0
  21. vllm_hacked/model_executor/models/llama.py +688 -0
  22. vllm_hacked/model_executor/sampling_metadata.py +596 -0
  23. vllm_hacked/model_executor/sampling_metadata_ori.py +596 -0
  24. vllm_hacked/sampling_params.py +596 -0
  25. vllm_hacked/sampling_params_ori.py +593 -0
  26. ckpt/.gitkeep → vllm_hacked/v1/sample/__init__ori.py +0 -0
  27. vllm_hacked/v1/sample/metadata.py +45 -0
  28. vllm_hacked/v1/sample/metadata_ori.py +43 -0
  29. vllm_hacked/v1/sample/ops/penalties_ori.py +43 -0
  30. vllm_hacked/v1/sample/sampler.py +338 -0
  31. vllm_hacked/v1/sample/sampler_ori.py +285 -0
  32. vllm_hacked/v1/spec_decode/utils.py +18 -0
  33. vllm_hacked/v1/spec_decode/utils_ori.py +14 -0
  34. vllm_hacked/v1/utils_ori.py +396 -0
  35. vllm_hacked/v1/worker/gpu_input_batch.py +669 -0
  36. vllm_hacked/v1/worker/gpu_input_batch_ori.py +863 -0
  37. vllm_hacked/v1/worker/gpu_model_runner.py +0 -0
  38. vllm_hacked/v1/worker/gpu_model_runner_ori.py +0 -0
  39. vllm_hacked/v1/worker/gpu_worker.py +710 -0
  40. vllm_hacked/worker_base.py +279 -0
  41. z_script.py +0 -44
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM juhayna/song-generation-levo:hf0613
2
 
3
  USER root
4
 
@@ -13,6 +13,10 @@ ENV PATH="/home/user/.local/bin:$PATH"
13
 
14
  WORKDIR /app
15
 
 
 
 
 
16
  COPY --chown=user ./requirements.txt requirements.txt
17
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
18
 
 
1
+ FROM witszhang/songgeneration_vllm:v0
2
 
3
  USER root
4
 
 
13
 
14
  WORKDIR /app
15
 
16
+ COPY --chown=user ./vllm_hacked/model_executor/models/llama.py /opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/llama.py
17
+ COPY --chown=user ./vllm_hacked/v1/sample/sampler.py /opt/conda/lib/python3.11/site-packages/vllm/v1/sample/sampler.py
18
+ COPY --chown=user ./vllm_hacked/v1/sample/metadata.py /opt/conda/lib/python3.11/site-packages/vllm/v1/sample/metadata.py
19
+ COPY --chown=user ./vllm_hacked/sampling_params.py /opt/conda/lib/python3.11/site-packages/vllm/sampling_params.py
20
  COPY --chown=user ./requirements.txt requirements.txt
21
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
22
 
app.py CHANGED
@@ -15,14 +15,12 @@ from download import download_model
15
 
16
  # 下载模型
17
  APP_DIR = op.dirname(op.abspath(__file__))
18
- download_model(APP_DIR)
19
- large_model_path = op.join(APP_DIR, "ckpt", "SongGeneration-v1.5-beta")
20
- download_model(large_model_path, repo_id="waytan22/SongGeneration-v1.5-beta", revision="db10f47")
21
  print("Successful downloaded model.")
22
 
23
  # 模型初始化
24
  from levo_inference import LeVoInference
25
- MODEL = LeVoInference(large_model_path)
26
 
27
  EXAMPLE_LYRICS = """
28
  [intro-medium]
@@ -159,7 +157,7 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co
159
  # 创建Gradio界面
160
  with gr.Blocks(title="SongGeneration Demo Space") as demo:
161
  gr.Markdown("# 🎵 SongGeneration Demo Space")
162
- gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song. The code is in [GIT](https://github.com/tencent-ailab/SongGeneration)")
163
 
164
  with gr.Row():
165
  with gr.Column():
@@ -215,7 +213,7 @@ lyrics
215
  minimum=0.1,
216
  maximum=3.0,
217
  step=0.1,
218
- value=1.5,
219
  interactive=True,
220
  elem_id="cfg-coef",
221
  )
@@ -239,7 +237,7 @@ lyrics
239
  # )
240
  with gr.Row():
241
  generate_btn = gr.Button("Generate Song", variant="primary")
242
- generate_bgm_btn = gr.Button("Generate Pure Music", variant="primary")
243
 
244
  with gr.Column():
245
  output_audio = gr.Audio(label="Generated Song", type="filepath")
@@ -267,18 +265,19 @@ lyrics
267
  # 生成按钮点击事件
268
  generate_btn.click(
269
  fn=generate_song,
270
- inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50)],
271
- outputs=[output_audio, output_json]
272
- )
273
- generate_bgm_btn.click(
274
- fn=generate_song,
275
- inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50), gr.State("bgm")],
276
  outputs=[output_audio, output_json]
277
  )
 
 
 
 
 
278
 
279
 
280
  # 启动应用
281
  if __name__ == "__main__":
282
  torch.set_num_threads(1)
 
283
  demo.launch(server_name="0.0.0.0", server_port=7860)
284
 
 
15
 
16
  # 下载模型
17
  APP_DIR = op.dirname(op.abspath(__file__))
18
+ download_model(APP_DIR, repo_id="waytan22/SongGeneration-v2.0", revision="ffd9215")
 
 
19
  print("Successful downloaded model.")
20
 
21
  # 模型初始化
22
  from levo_inference import LeVoInference
23
+ Model = None
24
 
25
  EXAMPLE_LYRICS = """
26
  [intro-medium]
 
157
  # 创建Gradio界面
158
  with gr.Blocks(title="SongGeneration Demo Space") as demo:
159
  gr.Markdown("# 🎵 SongGeneration Demo Space")
160
+ gr.Markdown("Push to Levo 2.0 faster and more controllable. The code is in [GIT](https://github.com/tencent-ailab/SongGeneration)")
161
 
162
  with gr.Row():
163
  with gr.Column():
 
213
  minimum=0.1,
214
  maximum=3.0,
215
  step=0.1,
216
+ value=1.8,
217
  interactive=True,
218
  elem_id="cfg-coef",
219
  )
 
237
  # )
238
  with gr.Row():
239
  generate_btn = gr.Button("Generate Song", variant="primary")
240
+ # generate_bgm_btn = gr.Button("Generate Pure Music", variant="primary")
241
 
242
  with gr.Column():
243
  output_audio = gr.Audio(label="Generated Song", type="filepath")
 
265
  # 生成按钮点击事件
266
  generate_btn.click(
267
  fn=generate_song,
268
+ inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(5000)],
 
 
 
 
 
269
  outputs=[output_audio, output_json]
270
  )
271
+ # generate_bgm_btn.click(
272
+ # fn=generate_song,
273
+ # inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50), gr.State("bgm")],
274
+ # outputs=[output_audio, output_json]
275
+ # )
276
 
277
 
278
  # 启动应用
279
  if __name__ == "__main__":
280
  torch.set_num_threads(1)
281
+ MODEL = LeVoInference(op.join(APP_DIR, "ckpt"))
282
  demo.launch(server_name="0.0.0.0", server_port=7860)
283
 
codeclm/models/builders.py CHANGED
@@ -52,7 +52,7 @@ def get_audio_tokenizer_model_cpu(checkpoint_path: str, cfg: omegaconf.DictConfi
52
  return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
53
 
54
 
55
- def get_lm_model(cfg: omegaconf.DictConfig, version: str = 'v1.0'): #-> LMModel:
56
  """Instantiate a LM."""
57
  lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
58
 
 
52
  return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
53
 
54
 
55
+ def get_lm_model(cfg: omegaconf.DictConfig, version: str = 'v1.5'): #-> LMModel:
56
  """Instantiate a LM."""
57
  lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
58
 
codeclm/models/codeclm_gen.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main model for using CodecLM. This will combine all the required components
3
+ and provide easy access to the generation API.
4
+ """
5
+
6
+ import typing as tp
7
+ import warnings
8
+
9
+ import torch
10
+
11
+ from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
12
+ # from .lm_llama import LMModel
13
+ from ..utils.autocast import TorchAutocast
14
+ import torch
15
+ from torch.nn import functional as F
16
+ import torchaudio
17
+ # from optim.ema import EMA
18
+ from codeclm.utils.utils import dict_from_config
19
+ from codeclm.modules.pattern import (
20
+ CodebooksPatternProvider,
21
+ DelayedPatternProvider,
22
+ )
23
+ from codeclm.modules.conditioners import (
24
+ ConditioningAttributes,
25
+ AudioCondition,
26
+ BaseConditioner,
27
+ QuantizedEmbeddingConditioner,
28
+ ConditionerProvider,
29
+ ConditionFuser,
30
+ QwTextConditioner,
31
+ QwTokenizerConditioner,
32
+ ClassifierFreeGuidanceDropoutInference,
33
+ )
34
+ import omegaconf
35
+
36
+ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig, version: str = 'v1.0') -> ConditionerProvider:
37
+ """Instantiate a conditioning model."""
38
+ cfg = getattr(cfg, 'conditioners')
39
+ dict_cfg = {} if cfg is None else dict_from_config(cfg)
40
+ conditioners: tp.Dict[str, BaseConditioner] = {}
41
+ condition_provider_args = dict_cfg.pop('args', {})
42
+
43
+ for cond, cond_cfg in dict_cfg.items():
44
+ model_type = cond_cfg['model']
45
+ model_args = cond_cfg[model_type]
46
+ if model_type == 'QwTokenizer':
47
+ conditioners[str(cond)] = QwTokenizerConditioner(
48
+ output_dim=output_dim,
49
+ **model_args
50
+ )
51
+ elif model_type == "QwTextTokenizer":
52
+ conditioners[str(cond)] = QwTextConditioner(
53
+ output_dim=output_dim,
54
+ version=version,
55
+ **model_args
56
+ )
57
+ elif model_type == "qt_embedding":
58
+ conditioners[str(cond)] = QuantizedEmbeddingConditioner(
59
+ dim=output_dim,
60
+ **model_args
61
+ )
62
+ else:
63
+ raise ValueError(f"Unrecognized conditioning model: {model_type}")
64
+ conditioner = ConditionerProvider(conditioners, **condition_provider_args)
65
+ return conditioner
66
+
67
+ def get_codebooks_pattern_provider(code_depth: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
68
+ """Instantiate a codebooks pattern provider object."""
69
+ pattern_providers = {
70
+ 'delay': DelayedPatternProvider,
71
+ }
72
+ name = cfg.modeling
73
+ kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
74
+ klass = pattern_providers[name]
75
+ return klass(code_depth, **kwargs)
76
+
77
+ MelodyList = tp.List[tp.Optional[torch.Tensor]]
78
+ MelodyType = tp.Union[torch.Tensor, MelodyList]
79
+
80
+ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
81
+ """Instantiate a condition fuser object."""
82
+ fuser_cfg = getattr(cfg, 'fuser')
83
+ fuser_methods = ['sum', 'prepend']
84
+ fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
85
+ kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
86
+ fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
87
+ return fuser
88
+
89
+ class CodecLM_gen:
90
+ """CodecLM main model with convenient generation API.
91
+
92
+ Args:
93
+ name (str): name of the model.
94
+ compression_model (CompressionModel): Compression model
95
+ used to map audio to invertible discrete representations.
96
+ lm (LMModel): Language model over discrete representations.
97
+ max_duration (float, optional): maximum duration the model can produce,
98
+ otherwise, inferred from the training params.
99
+ """
100
+ def __init__(self, cfg, name: str, audiotokenizer: AudioTokenizer,
101
+ max_duration: tp.Optional[float] = None):
102
+ self.cfg = cfg
103
+ self.name = name
104
+ self.audiotokenizer = audiotokenizer
105
+ self.seperate_tokenizer = None
106
+ if max_duration is None:
107
+ max_duration = self.cfg.max_dur
108
+ assert max_duration is not None
109
+
110
+ self.max_duration: float = max_duration
111
+ # self.device = next(iter(lm.parameters())).device
112
+ # self.device = next(iter(audiotokenizer.parameters())).device
113
+ self.generation_params: dict = {}
114
+ # self.set_generation_params(duration=15) # 15 seconds by default
115
+ self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
116
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
117
+ self.autocast = TorchAutocast(enabled=False)
118
+ self.condition_provider = get_conditioner_provider(cfg.lm.dim, self.cfg)
119
+ codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
120
+ self.pattern_provider = get_codebooks_pattern_provider(cfg.lm.code_depth, codebooks_pattern_cfg)
121
+ self.fuser = get_condition_fuser(cfg)
122
+ self.eos_token_id = cfg.lm.code_size
123
+
124
+
125
+
126
+ @property
127
+ def frame_rate(self) -> float:
128
+ """Roughly the number of AR steps per seconds."""
129
+ return self.audiotokenizer.frame_rate
130
+
131
+ @property
132
+ def sample_rate(self) -> int:
133
+ """Sample rate of the generated audio."""
134
+ return self.audiotokenizer.sample_rate
135
+
136
+ @property
137
+ def audio_channels(self) -> int:
138
+ """Audio channels of the generated audio."""
139
+ return self.audiotokenizer.channels
140
+
141
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
142
+ top_p: float = 0.0, temperature: float = 1.0,
143
+ duration: float = 30.0, cfg_coef: float = 3.0,
144
+ extend_stride: float = 18, record_tokens: bool = False,
145
+ record_window: int = 50):
146
+ """Set the generation parameters for CodecLM.
147
+
148
+ Args:
149
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
150
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
151
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
152
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
153
+ duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
154
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
155
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
156
+ instead of batching together the two. This has some impact on how things
157
+ are padded but seems to have little impact in practice.
158
+ extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
159
+ should we extend the audio each time. Larger values will mean less context is
160
+ preserved, and shorter value will require extra computations.
161
+ """
162
+ assert extend_stride <= self.max_duration, "Cannot stride by more than max generation duration."
163
+ self.extend_stride = extend_stride
164
+ self.duration = duration
165
+ self.generation_params = {
166
+ 'use_sampling': use_sampling,
167
+ 'temp': temperature,
168
+ 'top_k': top_k,
169
+ 'top_p': top_p,
170
+ 'cfg_coef': cfg_coef,
171
+ 'record_tokens': record_tokens,
172
+ 'record_window': record_window,
173
+ }
174
+
175
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
176
+ """Override the default progress callback."""
177
+ self._progress_callback = progress_callback
178
+
179
+ # Inference
180
+ def generate_condition(self, descriptions: tp.List[str],
181
+ melody_wavs: torch.Tensor = None,
182
+ return_tokens: bool = False,
183
+ melody_is_wav: bool = True,
184
+ type_info: tp.List[str] = None,
185
+ embeded_eosp1: torch.Tensor = None,
186
+ ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
187
+ if melody_wavs is not None:
188
+ if melody_wavs.dim() == 2:
189
+ melody_wavs = melody_wavs[None]
190
+ if melody_wavs.dim() != 3:
191
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
192
+ melody_wavs = list(melody_wavs)
193
+
194
+ # if melody_is_wav:
195
+ # melody_wavs = [wav.mean(dim=-2) for wav in melody_wavs]
196
+
197
+ texts, audio_qt_embs = self._prepare_tokens_and_attributes(descriptions=descriptions,
198
+ melody_wavs=melody_wavs,
199
+ melody_is_wav=melody_is_wav)
200
+ fused_input = self.get_condition_tensors(texts, audio_qt_embs, type_info, embeded_eosp1)
201
+
202
+ return fused_input, audio_qt_embs
203
+
204
+
205
+ @torch.no_grad()
206
+ def _prepare_tokens_and_attributes(
207
+ self,
208
+ descriptions: tp.Sequence[tp.Optional[str]],
209
+ melody_wavs: tp.Optional[MelodyList] = None,
210
+ melody_is_wav = True
211
+ ) -> tp.Tuple[tp.List[str], tp.List[torch.Tensor]]:
212
+ """Prepare model inputs.
213
+
214
+ Args:
215
+ descriptions (list of str): A list of strings used as text conditioning.
216
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
217
+ melody_wavs (torch.Tensor, optional): A batch of waveforms
218
+ used as melody conditioning. Defaults to None.
219
+ """
220
+ texts = [description for description in descriptions]
221
+ audio_qt_embs = []
222
+
223
+ if melody_wavs is None:
224
+ audio_qt_embs = None
225
+ elif melody_wavs is not None:
226
+ if 'prompt_audio' not in self.condition_provider.conditioners:
227
+ raise RuntimeError("This model doesn't support melody conditioning. "
228
+ "Use the `melody` model.")
229
+ assert len(melody_wavs) == len(texts), \
230
+ f"number of melody wavs must match number of descriptions! " \
231
+ f"got melody len={len(melody_wavs)}, and descriptions len={len(texts)}"
232
+ if type(melody_wavs) == list:
233
+ melody_wavs = torch.stack(melody_wavs, dim=0)
234
+ # melody_wavs = melody_wavs.to(self.device)
235
+ print(melody_wavs.shape)
236
+ if melody_is_wav:
237
+ melody_tokens, scale = self.audiotokenizer.encode(melody_wavs)
238
+ else:
239
+ melody_tokens = melody_wavs
240
+ target_melody_token_len = self.cfg.prompt_len * self.audiotokenizer.frame_rate
241
+ print(melody_tokens.shape, target_melody_token_len)
242
+ print(melody_tokens)
243
+ if melody_tokens.shape[-1] > target_melody_token_len:
244
+ melody_tokens = melody_tokens[...,:target_melody_token_len]
245
+ for melody in melody_tokens:
246
+ audio_qt_embs.append(melody.long())
247
+ return texts, audio_qt_embs
248
+
249
+ @torch.no_grad()
250
+ def prepare_condition_tensors(self,
251
+ batch_size = 1,
252
+ text: tp.Optional[tp.List[str]] = None,
253
+ audio_qt_emb: tp.Optional[tp.List[torch.Tensor]] = None,
254
+ type_info: tp.Optional[tp.List[str]] = None,
255
+ prepare_null_condition = False,
256
+ ):
257
+ conditions = []
258
+ for i in range(batch_size):
259
+ attr = ConditioningAttributes()
260
+ if 'description' in self.condition_provider.conditioners:
261
+ attr["text"]["description"] = ""
262
+ if text is not None:
263
+ attr["text"]["description"] = text[i]
264
+ if 'prompt_audio' in self.condition_provider.conditioners:
265
+ if audio_qt_emb is None: # tokenize stage will padding to max length
266
+ attr["audio"]['prompt_audio'] = AudioCondition(
267
+ wav=torch.zeros((1, self.cfg.audio_tokenizer_code_depth, 0)).long().cuda() + 16385,
268
+ length=torch.Tensor([0]).long(),
269
+ sample_rate=[self.cfg.sample_rate],)
270
+ else:
271
+ aT = audio_qt_emb[i].shape[-1]
272
+ pattern = self.pattern_provider.get_pattern(aT)
273
+ audio_qt_seq, _, _ = pattern.build_pattern_sequence(audio_qt_emb[i][None],
274
+ self.eos_token_id, keep_only_valid_steps=False)
275
+ attr["audio"]['prompt_audio'] = AudioCondition(
276
+ wav=audio_qt_seq.long().cuda(),
277
+ length=torch.Tensor([audio_qt_seq.shape[-1]]).long(),
278
+ sample_rate=[self.cfg.sample_rate],)
279
+ if 'type_info' in self.condition_provider.conditioners:
280
+ attr["text"]["type_info"] = ""
281
+ if type_info is not None:
282
+ attr["text"]["type_info"] = type_info[i]
283
+ conditions.append(attr)
284
+ # print("conditions", conditions)
285
+ if prepare_null_condition:
286
+ cfg_inference = ClassifierFreeGuidanceDropoutInference()
287
+ null_conditions = cfg_inference(conditions, condition_types=["audio", "text"],
288
+ customized=None)
289
+ conditions = conditions + null_conditions
290
+ tokenized_conditions = self.condition_provider.tokenize(conditions)
291
+ # import pdb; pdb.set_trace()
292
+ condition_tensors = self.condition_provider(tokenized_conditions)
293
+ return condition_tensors
294
+
295
+ def get_condition_tensors(self, texts, audio_qt_embs, type_info, embeded_eosp1):
296
+ condition_tensors = self.prepare_condition_tensors(batch_size=1, text=texts, audio_qt_emb=audio_qt_embs, type_info=type_info, prepare_null_condition=self.cfg.vllm.cfg)
297
+ if self.cfg.vllm.cfg:
298
+ input_ = torch.cat((embeded_eosp1, embeded_eosp1), dim=0)
299
+ else:
300
+ input_ = embeded_eosp1
301
+ fused_input = self.fuser(input_, condition_tensors)
302
+ return fused_input
303
+
304
+ @torch.no_grad()
305
+ def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False, chunk_size=128, gen_type='mixed'):
306
+ """Generate Audio from tokens"""
307
+ assert gen_tokens.dim() == 3
308
+ if self.seperate_tokenizer is not None:
309
+ gen_tokens_song = gen_tokens[:, [0], :]
310
+ gen_tokens_vocal = gen_tokens[:, [1], :]
311
+ gen_tokens_bgm = gen_tokens[:, [2], :]
312
+ if gen_type == 'bgm':
313
+ gen_tokens_vocal = torch.full_like(gen_tokens_vocal, 3142)
314
+ if vocal_prompt is not None:
315
+ vocal_prompt = torch.zeros_like(vocal_prompt)
316
+ elif gen_type == 'vocal':
317
+ gen_tokens_bgm = torch.full_like(gen_tokens_bgm, 9670)
318
+ if bgm_prompt is not None:
319
+ bgm_prompt = torch.zeros_like(bgm_prompt)
320
+ else:
321
+ assert gen_type == 'mixed', f"gen_type {gen_type} not supported"
322
+ gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked, chunk_size=chunk_size)
323
+ return gen_audio_seperate
324
+ else:
325
+ gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
326
+ return gen_audio
codeclm/models/levo.py CHANGED
@@ -96,7 +96,7 @@ class LmModel(LlamaModel_base):
96
  self.vocab_size = config.vocab_size
97
  layer_cls = LlamaDecoderLayer # cross attention decoder layer can be overwritten here
98
 
99
- assert version.parse(transformers.__version__) < version.parse("4.40")
100
 
101
  self.layers = nn.ModuleList([layer_cls(config) for _ in range(config.num_hidden_layers)])
102
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -221,4 +221,4 @@ class LmModel(LlamaModel_base):
221
  hidden_states=all_hidden_states,
222
  attentions=all_self_attns,
223
  )
224
-
 
96
  self.vocab_size = config.vocab_size
97
  layer_cls = LlamaDecoderLayer # cross attention decoder layer can be overwritten here
98
 
99
+ #assert version.parse(transformers.__version__) < version.parse("4.40")
100
 
101
  self.layers = nn.ModuleList([layer_cls(config) for _ in range(config.num_hidden_layers)])
102
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
221
  hidden_states=all_hidden_states,
222
  attentions=all_self_attns,
223
  )
224
+
codeclm/models/llama/modeling_llama.py CHANGED
@@ -34,10 +34,13 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
34
  from transformers.utils import (
35
  add_start_docstrings,
36
  add_start_docstrings_to_model_forward,
37
- is_flash_attn_available,
38
  logging,
39
  replace_return_docstrings,
40
  )
 
 
 
 
41
  from .configuration_llama import LlamaConfig
42
 
43
 
 
34
  from transformers.utils import (
35
  add_start_docstrings,
36
  add_start_docstrings_to_model_forward,
 
37
  logging,
38
  replace_return_docstrings,
39
  )
40
+ try:
41
+ from transformers.utils import is_flash_attn_available
42
+ except ImportError:
43
+ from transformers.utils import is_flash_attn_2_available as is_flash_attn_available
44
  from .configuration_llama import LlamaConfig
45
 
46
 
codeclm/modules/conditioners.py CHANGED
@@ -112,6 +112,7 @@ class QwTokenizerConditioner(TextConditioner):
112
  token_path = "",
113
  max_len = 300,
114
  add_token_list=[]): #""
 
115
  from transformers import Qwen2Tokenizer
116
  self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
117
  if add_token_list != []:
@@ -157,9 +158,6 @@ class QwTokenizerConditioner(TextConditioner):
157
  tp_cover_range[b, st: sp_list[i+1]] = tokens[b, st] - 151645
158
 
159
  if self.max_len is not None:
160
- if inputs['input_ids'].shape[-1] > self.max_len:
161
- warnings.warn(f"Max len limit ({self.max_len}) Exceed! \
162
- {[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
163
  tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device)
164
  mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
165
  tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device)
@@ -168,7 +166,7 @@ class QwTokenizerConditioner(TextConditioner):
168
  structure_embeds = self.structure_emb(tp_cover_range.to(device))
169
 
170
  embeds = content_embeds + structure_embeds
171
- return embeds, embeds, mask
172
 
173
  def pad_2d_tensor(self, x, max_len, pad_id):
174
  batch_size, seq_len = x.size()
@@ -192,9 +190,9 @@ class QwTextConditioner(TextConditioner):
192
  version: str = 'v1.0'): #""
193
 
194
  from transformers import Qwen2Tokenizer
195
- self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
196
- if version == 'v1.5':
197
- self.text_tokenizer.add_tokens(['[Musicality-very-high]', '[Musicality-high]', '[Musicality-medium]', '[Musicality-low]', '[Musicality-very-low]'], special_tokens=True)
198
  voc_size = len(self.text_tokenizer.get_vocab())
199
  # here initialize a output_proj (nn.Embedding) layer
200
  super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
@@ -223,7 +221,7 @@ class QwTextConditioner(TextConditioner):
223
  mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
224
 
225
  embeds = self.output_proj(tokens)
226
- return embeds, embeds, mask
227
 
228
  def pad_2d_tensor(self, x, max_len, pad_id):
229
  batch_size, seq_len = x.size()
@@ -255,7 +253,6 @@ class QuantizedEmbeddingConditioner(AudioConditioner):
255
  self.emb = nn.ModuleList([nn.Embedding(code_size+2, dim, padding_idx=code_size+1) for _ in range(code_depth)])
256
  # add End-Of-Text embedding
257
  self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
258
- self.layer2_EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
259
  self.output_proj = None
260
  self.max_len = max_len
261
  self.vocab_size = code_size
@@ -274,20 +271,20 @@ class QuantizedEmbeddingConditioner(AudioConditioner):
274
  wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size+1)
275
  else:
276
  wav = wav[:, :, :self.max_len-1]
277
- embeds1 = self.emb[0](wav[:, 0])
278
- embeds1 = torch.cat((self.EOT_emb.unsqueeze(0).repeat(B, 1, 1),
279
- embeds1), dim=1)
280
- embeds2 = sum([self.emb[k](wav[:, k]) for k in range(1, self.code_depth)]) # B,T,D
281
- embeds2 = torch.cat((self.layer2_EOT_emb.unsqueeze(0).repeat(B, 1, 1),
282
- embeds2), dim=1)
283
  lengths = lengths + 1
284
  lengths = torch.clamp(lengths, max=self.max_len)
285
 
286
  if lengths is not None:
287
- mask = length_to_mask(lengths, max_len=embeds1.shape[1]).int() # type: ignore
288
  else:
289
- mask = torch.ones((B, self.code_depth), device=embeds1.device, dtype=torch.int)
290
- return embeds1, embeds2, mask
291
 
292
 
293
  # ================================================================
@@ -356,10 +353,10 @@ class ConditionerProvider(nn.Module):
356
  output = {}
357
  for attribute, inputs in tokenized.items():
358
  if attribute == 'description' and structure_dur is not None:
359
- condition1, condition2, mask = self.conditioners[attribute](inputs, structure_dur = structure_dur)
360
  else:
361
- condition1, condition2, mask = self.conditioners[attribute](inputs)
362
- output[attribute] = (condition1, condition2, mask)
363
  return output
364
 
365
  def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
@@ -460,8 +457,7 @@ class ConditionFuser(StreamingModule):
460
 
461
  def forward(
462
  self,
463
- input1: torch.Tensor,
464
- input2: torch.Tensor,
465
  conditions: tp.Dict[str, ConditionType]
466
  ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
467
  """Fuse the conditions to the provided model input.
@@ -475,14 +471,14 @@ class ConditionFuser(StreamingModule):
475
  used for cross-attention or None if no cross attention inputs exist.
476
  """
477
  #import pdb; pdb.set_trace()
478
- B, T, _ = input1.shape
479
 
480
  if 'offsets' in self._streaming_state:
481
  first_step = False
482
  offsets = self._streaming_state['offsets']
483
  else:
484
  first_step = True
485
- offsets = torch.zeros(input1.shape[0], dtype=torch.long, device=input1.device)
486
 
487
  assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
488
  f"given conditions contain unknown attributes for fuser, " \
@@ -491,31 +487,28 @@ class ConditionFuser(StreamingModule):
491
  # if 'prepend' mode is used,
492
  # the concatenation order will be the SAME with the conditions in config:
493
  # prepend: ['description', 'prompt_audio'] (then goes the input)
494
- fused_input_1 = input1
495
- fused_input_2 = input2
496
  for fuse_op in self.fuse2cond.keys():
497
  fuse_op_conditions = self.fuse2cond[fuse_op]
498
  if fuse_op == 'sum' and len(fuse_op_conditions) > 0:
499
  for cond in fuse_op_conditions:
500
- this_cond_1, this_cond_2, cond_mask = conditions[cond]
501
- fused_input_1 += this_cond_1
502
- fused_input_2 += this_cond_2
503
  elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0:
504
  if not first_step:
505
  continue
506
  reverse_list = deepcopy(fuse_op_conditions)
507
  reverse_list.reverse()
508
  for cond in reverse_list:
509
- this_cond_1, this_cond_2, cond_mask = conditions[cond]
510
- fused_input_1 = torch.cat((this_cond_1, fused_input_1), dim=1) # concat along T dim
511
- fused_input_2 = torch.cat((this_cond_2, fused_input_2), dim=1) # concat along T dim
512
  elif fuse_op not in self.FUSING_METHODS:
513
  raise ValueError(f"unknown op ({fuse_op})")
514
 
515
  if self._is_streaming:
516
  self._streaming_state['offsets'] = offsets + T
517
 
518
- return fused_input_1, fused_input_2
519
 
520
 
521
 
@@ -575,8 +568,7 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
575
  self.check(sample, condition_type, condition)
576
 
577
  if condition_type == 'audio':
578
- audio_cond = sample.audio[condition]
579
- depth = audio_cond.wav.shape[1]
580
  sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
581
  else:
582
  sample.text[condition] = None
@@ -639,7 +631,7 @@ class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
639
  sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
640
  else:
641
  if customized is None:
642
- if condition in ['type_info'] and sample.text[condition] is not None:
643
  if "[Musicality-very-high]" in sample.text[condition]:
644
  sample.text[condition] = "[Musicality-very-low], ."
645
  print(f"cfg unconditioning: change sample.text[condition] to [Musicality-very-low]")
 
112
  token_path = "",
113
  max_len = 300,
114
  add_token_list=[]): #""
115
+ add_token_list.append('.')
116
  from transformers import Qwen2Tokenizer
117
  self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
118
  if add_token_list != []:
 
158
  tp_cover_range[b, st: sp_list[i+1]] = tokens[b, st] - 151645
159
 
160
  if self.max_len is not None:
 
 
 
161
  tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device)
162
  mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
163
  tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device)
 
166
  structure_embeds = self.structure_emb(tp_cover_range.to(device))
167
 
168
  embeds = content_embeds + structure_embeds
169
+ return embeds, mask
170
 
171
  def pad_2d_tensor(self, x, max_len, pad_id):
172
  batch_size, seq_len = x.size()
 
190
  version: str = 'v1.0'): #""
191
 
192
  from transformers import Qwen2Tokenizer
193
+ self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
194
+ self.text_tokenizer.add_tokens(['[Musicality-very-high]', '[Musicality-high]', '[Musicality-medium]', '[Musicality-low]', '[Musicality-very-low]', '[Pure-Music]', '.'], special_tokens=True)
195
+ print(self.text_tokenizer)
196
  voc_size = len(self.text_tokenizer.get_vocab())
197
  # here initialize a output_proj (nn.Embedding) layer
198
  super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
 
221
  mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
222
 
223
  embeds = self.output_proj(tokens)
224
+ return embeds, mask
225
 
226
  def pad_2d_tensor(self, x, max_len, pad_id):
227
  batch_size, seq_len = x.size()
 
253
  self.emb = nn.ModuleList([nn.Embedding(code_size+2, dim, padding_idx=code_size+1) for _ in range(code_depth)])
254
  # add End-Of-Text embedding
255
  self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
 
256
  self.output_proj = None
257
  self.max_len = max_len
258
  self.vocab_size = code_size
 
271
  wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size+1)
272
  else:
273
  wav = wav[:, :, :self.max_len-1]
274
+ # self.emb.to(wav.device) # 都放cuda
275
+ wav = wav.to(self.emb[0].weight.device)
276
+ embeds = sum([self.emb[k](wav[:, k]) for k in range(self.code_depth)]) # B,T,D
277
+ # self.EOT_emb.data = self.EOT_emb.data.to(embeds.device)
278
+ embeds = torch.cat((self.EOT_emb.unsqueeze(0).repeat(B, 1, 1),
279
+ embeds), dim=1)
280
  lengths = lengths + 1
281
  lengths = torch.clamp(lengths, max=self.max_len)
282
 
283
  if lengths is not None:
284
+ mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
285
  else:
286
+ mask = torch.ones((B, self.code_depth), device=embeds.device, dtype=torch.int)
287
+ return embeds, mask
288
 
289
 
290
  # ================================================================
 
353
  output = {}
354
  for attribute, inputs in tokenized.items():
355
  if attribute == 'description' and structure_dur is not None:
356
+ condition, mask = self.conditioners[attribute](inputs, structure_dur = structure_dur)
357
  else:
358
+ condition, mask = self.conditioners[attribute](inputs)
359
+ output[attribute] = (condition, mask)
360
  return output
361
 
362
  def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
 
457
 
458
  def forward(
459
  self,
460
+ input: torch.Tensor,
 
461
  conditions: tp.Dict[str, ConditionType]
462
  ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
463
  """Fuse the conditions to the provided model input.
 
471
  used for cross-attention or None if no cross attention inputs exist.
472
  """
473
  #import pdb; pdb.set_trace()
474
+ B, T, _ = input.shape
475
 
476
  if 'offsets' in self._streaming_state:
477
  first_step = False
478
  offsets = self._streaming_state['offsets']
479
  else:
480
  first_step = True
481
+ offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
482
 
483
  assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
484
  f"given conditions contain unknown attributes for fuser, " \
 
487
  # if 'prepend' mode is used,
488
  # the concatenation order will be the SAME with the conditions in config:
489
  # prepend: ['description', 'prompt_audio'] (then goes the input)
490
+ fused_input = input
 
491
  for fuse_op in self.fuse2cond.keys():
492
  fuse_op_conditions = self.fuse2cond[fuse_op]
493
  if fuse_op == 'sum' and len(fuse_op_conditions) > 0:
494
  for cond in fuse_op_conditions:
495
+ this_cond, cond_mask = conditions[cond]
496
+ fused_input += this_cond
 
497
  elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0:
498
  if not first_step:
499
  continue
500
  reverse_list = deepcopy(fuse_op_conditions)
501
  reverse_list.reverse()
502
  for cond in reverse_list:
503
+ this_cond, cond_mask = conditions[cond]
504
+ fused_input = torch.cat((this_cond, fused_input), dim=1) # concat along T dim
 
505
  elif fuse_op not in self.FUSING_METHODS:
506
  raise ValueError(f"unknown op ({fuse_op})")
507
 
508
  if self._is_streaming:
509
  self._streaming_state['offsets'] = offsets + T
510
 
511
+ return fused_input
512
 
513
 
514
 
 
568
  self.check(sample, condition_type, condition)
569
 
570
  if condition_type == 'audio':
571
+ audio_cond = sample.audio[condition]
 
572
  sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
573
  else:
574
  sample.text[condition] = None
 
631
  sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
632
  else:
633
  if customized is None:
634
+ if condition in ['type_info']:
635
  if "[Musicality-very-high]" in sample.text[condition]:
636
  sample.text[condition] = "[Musicality-very-low], ."
637
  print(f"cfg unconditioning: change sample.text[condition] to [Musicality-very-low]")
codeclm/tokenizer/Flow1dVAE/generate_1rvq.py CHANGED
@@ -10,6 +10,7 @@ import math
10
  import numpy as np
11
  import tools.torch_tools as torch_tools
12
  from safetensors.torch import load_file
 
13
 
14
  class Tango:
15
  def __init__(self, \
@@ -23,9 +24,9 @@ class Tango:
23
  scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
24
  self.device = device
25
 
26
- # self.vae = get_model(vae_config, vae_model)
27
- # self.vae = self.vae.to(device)
28
- # self.vae=self.vae.eval()
29
  self.layer_num = layer_num
30
 
31
  self.MAX_DURATION = 360
@@ -52,43 +53,34 @@ class Tango:
52
  # scheduler_name, subfolder="scheduler")
53
  # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
54
 
55
- # def sound2sound(self, orig_samples, lyric, st_et, batch_size=1, duration=40.96, steps=200, disable_progress=False,scenario = "start_seg"):
56
- # """ Genrate audio without condition. """
57
- # with torch.no_grad():
58
- # if(orig_samples.shape[-1]<int(duration*48000)+480):
59
- # orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000+480)-orig_samples.shape[-1], \
60
- # dtype=orig_samples.dtype, device=orig_samples.device)], -1)
61
-
62
- # orig_samples = orig_samples.to(self.device)
63
- # saved_samples = orig_samples[:,0:40*48000].clamp(-1,1)
64
- # orig_samples = orig_samples[:,0:40*48000].clamp(-1,1)
65
- # max_volume = orig_samples.abs().max(dim=-1)[0]
66
- # orig_samples = orig_samples/max_volume.unsqueeze(-1)
67
- # print("orig_samples.shape", orig_samples.shape)
68
 
69
- # latent_length = int((st_et[1] - st_et[0]) * 48000) // 1920 + 1
 
 
 
 
70
 
71
- # true_latents = self.vae.encode_audio(orig_samples).permute(0,2,1)
72
 
73
- # print("true_latents.shape", true_latents.shape)
74
- # latents = self.model.inference(orig_samples.repeat(batch_size, 1), [lyric, ]*batch_size, true_latents, latent_length, additional_feats=[], guidance_scale=1.5, num_steps = steps, disable_progress=disable_progress,layer=6, scenario = scenario)
75
- # print("latents.shape", latents.shape)
76
- # print("latent_length", latent_length)
77
 
78
- # latents = latents[:,:,:latent_length]
79
- # audio = self.vae.decode_audio(latents)
80
- # print("audio.shape:",audio.shape)
81
- # audio = torch.cat((audio, torch.zeros(audio.shape[0],audio.shape[1], 48000*40 - audio.shape[-1], dtype=audio.dtype, device=audio.device)), dim=-1)
82
- # print("audio.shape:",audio.shape)
83
- # # audio = audio.reshape(audio.shape[0]//2, 2, -1)
84
- # # audio = torch.from_numpy(audio)
85
 
86
- # if(saved_samples.shape[-1]<audio.shape[-1]):
87
- # saved_samples = torch.cat([saved_samples, torch.zeros(saved_samples.shape[0], audio.shape[-1]-saved_samples.shape[-1], dtype=saved_samples.dtype, device=saved_samples.device)],-1)
88
- # else:
89
- # saved_samples = saved_samples[:,0:audio.shape[-1]]
90
- # output = torch.cat([saved_samples.detach().cpu(),audio[0].detach().cpu()],0)
91
- # return output
92
 
93
  @torch.no_grad()
94
  @torch.autocast(device_type="cuda", dtype=torch.float32)
@@ -105,7 +97,6 @@ class Tango:
105
  min_samples = int(40 * self.sample_rate)
106
  # 40秒对应10个token
107
  output_len = int(orig_length / float(self.sample_rate) * 25) + 1
108
- print("output_len: ", output_len)
109
 
110
  while(audios.shape[-1] < min_samples):
111
  audios = torch.cat([audios, audios], -1)
@@ -117,10 +108,8 @@ class Tango:
117
  audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
118
 
119
  for audio_inx in range(0, audio_input.shape[0], batch_size):
120
- # import pdb; pdb.set_trace()
121
  codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
122
  codes_list.append(torch.cat(codes, 1))
123
- # print("codes_list",codes_list[0].shape)
124
 
125
  codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
126
  codes=codes[:,:,:output_len]
@@ -159,21 +148,13 @@ class Tango:
159
  # else choose from 20.48s which might includes verse or chorus
160
  prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
161
 
162
- true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
163
- # print("true_latent.shape", true_latent.shape)
164
- # print("first_latent.shape", first_latent.shape)
165
- #true_latent.shape torch.Size([1, 250, 64])
166
- # first_latent.shape torch.Size([1, 1000, 64])
167
-
168
  first_latent[:,0:true_latent.shape[1],:] = true_latent
169
  first_latent_length = true_latent.shape[1]
170
  first_latent_codes = self.sound2code(prompt)
171
  first_latent_codes_length = first_latent_codes.shape[-1]
172
  codes = torch.cat([first_latent_codes, codes], -1)
173
 
174
-
175
-
176
-
177
  codes_len= codes.shape[-1]
178
  target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
179
  # target_len = int(codes_len / 100 * 4 * self.sample_rate)
@@ -196,17 +177,12 @@ class Tango:
196
  codes_input=[]
197
  codes_input.append(codes[:,:,sinx:sinx+min_samples])
198
  if(sinx == 0):
199
- # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
200
  incontext_length = first_latent_length
201
  latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
202
  latent_list.append(latents)
203
  else:
204
- # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
205
  true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
206
- print("true_latent.shape", true_latent.shape)
207
  len_add_to_1000 = min_samples - true_latent.shape[-2]
208
- # print("len_add_to_1000", len_add_to_1000)
209
- # exit()
210
  incontext_length = true_latent.shape[-2]
211
  true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
212
  latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
@@ -228,8 +204,6 @@ class Tango:
228
  else:
229
  ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
230
  ov_win = torch.cat([ov_win, 1 - ov_win], -1)
231
- print("output.shape", output.shape)
232
- print("ov_win.shape", ov_win.shape)
233
  output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
234
  output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
235
  output = output[:, 0:target_len]
@@ -248,9 +222,7 @@ class Tango:
248
  @torch.no_grad()
249
  def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
250
  codes = self.sound2code(sound)
251
- # print(codes.shape)
252
  wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
253
- # print(fname, wave.shape)
254
  return wave
255
 
256
  def to(self, device=None, dtype=None, non_blocking=False):
 
10
  import numpy as np
11
  import tools.torch_tools as torch_tools
12
  from safetensors.torch import load_file
13
+ from tools.get_1dvae_large import get_model
14
 
15
  class Tango:
16
  def __init__(self, \
 
24
  scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
25
  self.device = device
26
 
27
+ self.vae = get_model(vae_config, vae_model)
28
+ self.vae = self.vae.to(device)
29
+ self.vae=self.vae.eval()
30
  self.layer_num = layer_num
31
 
32
  self.MAX_DURATION = 360
 
53
  # scheduler_name, subfolder="scheduler")
54
  # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
55
 
56
    def sound2sound(self, orig_samples, lyric, st_et, batch_size=1, duration=40.96, steps=200, disable_progress=False,scenario = "start_seg"):
        """Regenerate audio from reference audio plus a lyric prompt.

        Args:
            orig_samples: reference waveform, shape (batch, num_samples);
                assumed 48 kHz from the hard-coded 48000 factors — TODO confirm.
            lyric: lyric string, replicated across the generation batch.
            st_et: (start_sec, end_sec) pair; only its span sets latent_length.
            batch_size: number of parallel generations fed to the model.
            duration: minimum input duration in seconds to zero-pad to.
            steps: number of diffusion sampling steps.
            disable_progress: hide the sampler's progress bar.
            scenario: scenario tag forwarded to ``self.model.inference``.

        Returns:
            CPU tensor stacking the clamped 40 s reference audio and the
            first generated sample along dim 0.
        """
        with torch.no_grad():
            # Zero-pad so at least `duration` seconds plus 480 samples are available.
            if(orig_samples.shape[-1]<int(duration*48000)+480):
                orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000+480)-orig_samples.shape[-1], \
                    dtype=orig_samples.dtype, device=orig_samples.device)], -1)

            orig_samples = orig_samples.to(self.device)
            # Keep a clamped copy of the first 40 s as the returned reference track.
            saved_samples = orig_samples[:,0:40*48000].clamp(-1,1)
            orig_samples = orig_samples[:,0:40*48000].clamp(-1,1)
            # Peak-normalize each batch item before encoding.
            max_volume = orig_samples.abs().max(dim=-1)[0]
            orig_samples = orig_samples/max_volume.unsqueeze(-1)

            # One latent frame per 1920 samples (40 ms at 48 kHz).
            latent_length = int((st_et[1] - st_et[0]) * 48000) // 1920 + 1

            # Ground-truth latents for the reference segment, (B, T, C).
            true_latents = self.vae.encode_audio(orig_samples).permute(0,2,1)

            # NOTE(review): layer=6 and guidance_scale=1.5 are fixed here — confirm intended.
            latents = self.model.inference(orig_samples.repeat(batch_size, 1), [lyric, ]*batch_size, true_latents, latent_length, additional_feats=[], guidance_scale=1.5, num_steps = steps, disable_progress=disable_progress,layer=6, scenario = scenario)
            # Trim to the requested segment length, then decode back to audio.
            latents = latents[:,:,:latent_length]
            audio = self.vae.decode_audio(latents)
            # Pad the decoded audio out to exactly 40 s.
            audio = torch.cat((audio, torch.zeros(audio.shape[0],audio.shape[1], 48000*40 - audio.shape[-1], dtype=audio.dtype, device=audio.device)), dim=-1)

            # Length-match the reference so both tracks can be concatenated.
            if(saved_samples.shape[-1]<audio.shape[-1]):
                saved_samples = torch.cat([saved_samples, torch.zeros(saved_samples.shape[0], audio.shape[-1]-saved_samples.shape[-1], dtype=saved_samples.dtype, device=saved_samples.device)],-1)
            else:
                saved_samples = saved_samples[:,0:audio.shape[-1]]
            output = torch.cat([saved_samples.detach().cpu(),audio[0].detach().cpu()],0)
            return output
84
 
85
  @torch.no_grad()
86
  @torch.autocast(device_type="cuda", dtype=torch.float32)
 
97
  min_samples = int(40 * self.sample_rate)
98
  # 40秒对应10个token
99
  output_len = int(orig_length / float(self.sample_rate) * 25) + 1
 
100
 
101
  while(audios.shape[-1] < min_samples):
102
  audios = torch.cat([audios, audios], -1)
 
108
  audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
109
 
110
  for audio_inx in range(0, audio_input.shape[0], batch_size):
 
111
  codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
112
  codes_list.append(torch.cat(codes, 1))
 
113
 
114
  codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
115
  codes=codes[:,:,:output_len]
 
148
  # else choose from 20.48s which might includes verse or chorus
149
  prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
150
 
151
+ true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
 
 
 
 
 
152
  first_latent[:,0:true_latent.shape[1],:] = true_latent
153
  first_latent_length = true_latent.shape[1]
154
  first_latent_codes = self.sound2code(prompt)
155
  first_latent_codes_length = first_latent_codes.shape[-1]
156
  codes = torch.cat([first_latent_codes, codes], -1)
157
 
 
 
 
158
  codes_len= codes.shape[-1]
159
  target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
160
  # target_len = int(codes_len / 100 * 4 * self.sample_rate)
 
177
  codes_input=[]
178
  codes_input.append(codes[:,:,sinx:sinx+min_samples])
179
  if(sinx == 0):
 
180
  incontext_length = first_latent_length
181
  latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
182
  latent_list.append(latents)
183
  else:
 
184
  true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
 
185
  len_add_to_1000 = min_samples - true_latent.shape[-2]
 
 
186
  incontext_length = true_latent.shape[-2]
187
  true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
188
  latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
 
204
  else:
205
  ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
206
  ov_win = torch.cat([ov_win, 1 - ov_win], -1)
 
 
207
  output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
208
  output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
209
  output = output[:, 0:target_len]
 
222
  @torch.no_grad()
223
  def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
224
  codes = self.sound2code(sound)
 
225
  wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
 
226
  return wave
227
 
228
  def to(self, device=None, dtype=None, non_blocking=False):
codeclm/tokenizer/Flow1dVAE/model_1rvq.py CHANGED
@@ -301,17 +301,17 @@ class PromptCondAudioDiffusion(nn.Module):
301
  # for v in self.hubert.parameters():v.requires_grad = False
302
  self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
303
  # self.xvecmodel = XVECModel()
304
- # config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
305
- # unet = GPT2Model(config)
306
- # mlp = nn.Sequential(
307
- # nn.Linear(1200, 1024),
308
- # nn.SiLU(),
309
- # nn.Linear(1024, 1024),
310
- # nn.SiLU(),
311
- # nn.Linear(1024, 768)
312
- # )
313
  self.set_from = "random"
314
- # self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
315
  self.mask_emb = torch.nn.Embedding(3, 48)
316
  print("Transformer initialized from pretrain.")
317
  torch.cuda.empty_cache()
@@ -602,38 +602,20 @@ class PromptCondAudioDiffusion(nn.Module):
602
  dtype = self.dtype
603
  # codes_bestrq_middle, codes_bestrq_last = codes
604
  codes_bestrq_emb = codes[0]
605
-
606
-
607
  batch_size = codes_bestrq_emb.shape[0]
608
-
609
-
610
  quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
611
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
612
  quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
613
- print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
614
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
615
-
616
-
617
-
618
-
619
  if('spk' in additional_feats):
620
  spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
621
-
622
  num_frames = quantized_bestrq_emb.shape[1]
623
-
624
  num_channels_latents = self.num_channels
625
  shape = (batch_size, num_frames, 64)
626
  latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
627
-
628
-
629
-
630
  latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
631
  latent_masks[:,0:latent_length] = 2
632
  if(scenario=='other_seg'):
633
  latent_masks[:,0:incontext_length] = 1
634
 
635
-
636
-
637
  quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
638
  + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
639
  true_latents = true_latents.permute(0,2,1).contiguous()
@@ -642,7 +624,6 @@ class PromptCondAudioDiffusion(nn.Module):
642
  incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
643
  incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
644
 
645
-
646
  attention_mask=(latent_masks > 0.5)
647
  B, L = attention_mask.size()
648
  attention_mask = attention_mask.view(B, 1, L)
 
301
  # for v in self.hubert.parameters():v.requires_grad = False
302
  self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
303
  # self.xvecmodel = XVECModel()
304
+ config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
305
+ unet = GPT2Model(config)
306
+ mlp = nn.Sequential(
307
+ nn.Linear(1200, 1024),
308
+ nn.SiLU(),
309
+ nn.Linear(1024, 1024),
310
+ nn.SiLU(),
311
+ nn.Linear(1024, 768)
312
+ )
313
  self.set_from = "random"
314
+ self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
315
  self.mask_emb = torch.nn.Embedding(3, 48)
316
  print("Transformer initialized from pretrain.")
317
  torch.cuda.empty_cache()
 
602
  dtype = self.dtype
603
  # codes_bestrq_middle, codes_bestrq_last = codes
604
  codes_bestrq_emb = codes[0]
 
 
605
  batch_size = codes_bestrq_emb.shape[0]
 
 
606
  quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
 
607
  quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
 
 
 
 
 
 
608
  if('spk' in additional_feats):
609
  spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
 
610
  num_frames = quantized_bestrq_emb.shape[1]
 
611
  num_channels_latents = self.num_channels
612
  shape = (batch_size, num_frames, 64)
613
  latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
 
 
 
614
  latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
615
  latent_masks[:,0:latent_length] = 2
616
  if(scenario=='other_seg'):
617
  latent_masks[:,0:incontext_length] = 1
618
 
 
 
619
  quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
620
  + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
621
  true_latents = true_latents.permute(0,2,1).contiguous()
 
624
  incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
625
  incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
626
 
 
627
  attention_mask=(latent_masks > 0.5)
628
  B, L = attention_mask.size()
629
  attention_mask = attention_mask.view(B, 1, L)
codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config.py CHANGED
@@ -18,6 +18,8 @@
18
  from collections import OrderedDict
19
  from typing import Any, List, Mapping, Optional
20
 
 
 
21
  from transformers import PreTrainedTokenizer, TensorType, is_torch_available
22
  from transformers.configuration_utils import PretrainedConfig
23
  from transformers.onnx import OnnxConfigWithPast, PatchingSpec
@@ -27,6 +29,59 @@ from transformers.utils import logging
27
  logger = logging.get_logger(__name__)
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class GPT2Config(PretrainedConfig):
31
  """
32
  This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
 
18
  from collections import OrderedDict
19
  from typing import Any, List, Mapping, Optional
20
 
21
+ import torch
22
+ import torch.nn as nn
23
  from transformers import PreTrainedTokenizer, TensorType, is_torch_available
24
  from transformers.configuration_utils import PretrainedConfig
25
  from transformers.onnx import OnnxConfigWithPast, PatchingSpec
 
29
  logger = logging.get_logger(__name__)
30
 
31
 
32
class SequenceSummary(nn.Module):
    """Compute a single vector summary of a sequence of hidden states.

    Local replacement for the ``SequenceSummary`` helper that older
    ``transformers`` releases exported from ``modeling_utils``.  All
    settings are read from the config with ``getattr`` defaults, so any
    config-like object works, not only ``PretrainedConfig``.
    """

    def __init__(self, config: "PretrainedConfig"):
        super().__init__()
        # How to reduce the time axis: 'last' | 'first' | 'mean' | 'cls_index'.
        self.summary_type = getattr(config, "summary_type", "last")
        self.summary_use_proj = getattr(config, "summary_use_proj", True)
        self.summary_activation = getattr(config, "summary_activation", None)
        self.summary_last_dropout = getattr(config, "summary_last_dropout", 0.0)
        self.summary_first_dropout = getattr(config, "summary_first_dropout", 0.0)
        self.summary_proj_to_labels = getattr(config, "summary_proj_to_labels", True)

        if self.summary_use_proj:
            # Project to num_labels when classifying, else keep the hidden size.
            if self.summary_proj_to_labels and hasattr(config, "num_labels"):
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        self.activation = nn.Tanh() if self.summary_activation == "tanh" else None
        self.first_dropout = nn.Dropout(self.summary_first_dropout) if self.summary_first_dropout > 0 else None
        self.last_dropout = nn.Dropout(self.summary_last_dropout) if self.summary_last_dropout > 0 else None

    def forward(self, hidden_states, cls_index=None):
        """Summarize ``hidden_states`` of shape (batch, seq_len, hidden).

        Args:
            hidden_states: float tensor, (batch, seq_len, hidden_size).
            cls_index: optional per-sample token indices used only when
                ``summary_type == 'cls_index'``; accepts shape (batch,),
                (batch, 1) or (batch, 1, 1). Defaults to the last token.

        Returns:
            Tensor of shape (batch, hidden_size) — or (batch, num_labels)
            when the projection maps to labels.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            batch_size = hidden_states.size(0)
            if cls_index is None:
                # Default: summarize at the last position of every sequence.
                cls_index = torch.full(
                    (batch_size,),
                    hidden_states.size(1) - 1,
                    dtype=torch.long,
                    device=hidden_states.device,
                )
            else:
                # Normalize (B,), (B,1) or (B,1,1) to a flat (B,) index so the
                # advanced indexing below picks one vector per sample.  (The
                # previous code kept a (B,1) index, which broadcast against
                # arange(B) into a spurious (B, B, hidden) result.)
                cls_index = cls_index.reshape(batch_size).long()
            output = hidden_states[torch.arange(batch_size, device=hidden_states.device), cls_index]
        else:
            output = hidden_states[:, -1]  # default to last

        if self.first_dropout:
            output = self.first_dropout(output)

        if self.summary_use_proj:
            output = self.summary(output)

        if self.activation:
            output = self.activation(output)

        if self.last_dropout:
            output = self.last_dropout(output)

        return output
83
+
84
+
85
  class GPT2Config(PretrainedConfig):
86
  """
87
  This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_reflow.py CHANGED
@@ -37,7 +37,7 @@ from transformers.modeling_outputs import (
37
  SequenceClassifierOutputWithPast,
38
  TokenClassifierOutput,
39
  )
40
- from transformers.modeling_utils import PreTrainedModel, SequenceSummary
41
  from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
42
  from transformers.utils import (
43
  ModelOutput,
@@ -50,7 +50,7 @@ from transformers.utils import (
50
  replace_return_docstrings,
51
  )
52
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
53
- from models_gpt.models.gpt2_config import GPT2Config
54
 
55
 
56
  if is_flash_attn_2_available():
 
37
  SequenceClassifierOutputWithPast,
38
  TokenClassifierOutput,
39
  )
40
+ from transformers.modeling_utils import PreTrainedModel
41
  from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
42
  from transformers.utils import (
43
  ModelOutput,
 
50
  replace_return_docstrings,
51
  )
52
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
53
+ from models_gpt.models.gpt2_config import GPT2Config, SequenceSummary
54
 
55
 
56
  if is_flash_attn_2_available():
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/features.py CHANGED
@@ -15,7 +15,7 @@
15
 
16
  import torchaudio
17
  from torch import nn
18
-
19
 
20
  class MelSTFT(nn.Module):
21
  def __init__(
@@ -39,7 +39,16 @@ class MelSTFT(nn.Module):
39
  self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
40
 
41
  def forward(self, waveform):
42
- if self.is_db:
43
- return self.amplitude_to_db(self.mel_stft(waveform))
44
- else:
45
- return self.mel_stft(waveform)
 
 
 
 
 
 
 
 
 
 
15
 
16
  import torchaudio
17
  from torch import nn
18
+ import torch
19
 
20
  class MelSTFT(nn.Module):
21
  def __init__(
 
39
  self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
40
 
41
  def forward(self, waveform):
42
+ # 将数据移至 CPU 处理 STFT,再移回 GPU
43
+ device = waveform.device
44
+ waveform_cpu = waveform.cpu()
45
+ # 强制在 CPU 上运行
46
+ with torch.cpu.amp.autocast(enabled=False):
47
+ if self.is_db:
48
+ spec = self.amplitude_to_db(self.mel_stft.to('cpu')(waveform_cpu))
49
+ else:
50
+ spec = self.mel_stft.to('cpu')(waveform_cpu)
51
+ # 结果移回原设备,并将 mel_stft 移回原设备供下次使用(或者克隆一个 cpu 版的)
52
+ spec = spec.to(device)
53
+ self.mel_stft.to(device)
54
+ return spec
codeclm/tokenizer/audio_tokenizer.py CHANGED
@@ -136,7 +136,7 @@ class Flow1dVAE1rvq(AudioTokenizer):
136
  @torch.no_grad()
137
  def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
138
  wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5,
139
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
140
  return wav[None]
141
 
142
 
@@ -222,7 +222,7 @@ class Flow1dVAESeparate(AudioTokenizer):
222
  @torch.no_grad()
223
  def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False, chunk_size=128):
224
  wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
225
- num_steps=50, disable_progress=False, chunked=chunked, chunk_size=chunk_size) # [B,N,T] -> [B,T]
226
  return wav[None]
227
 
228
 
 
136
  @torch.no_grad()
137
  def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
138
  wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5,
139
+ num_steps=10, disable_progress=False) # [B,N,T] -> [B,T]
140
  return wav[None]
141
 
142
 
 
222
  @torch.no_grad()
223
  def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False, chunk_size=128):
224
  wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5,
225
+ num_steps=10, disable_progress=False, chunked=chunked, chunk_size=chunk_size) # [B,N,T] -> [B,T]
226
  return wav[None]
227
 
228
 
generate.py CHANGED
@@ -1,22 +1,19 @@
1
- from hmac import new
2
- import sys
3
- import os
4
- import argparse
5
-
6
  import time
7
- import json
8
  import torch
 
 
 
 
9
  import torchaudio
10
  import numpy as np
11
- from omegaconf import OmegaConf
12
- from codeclm.models import builders
13
- import gc
14
- from codeclm.trainer.codec_song_pl import CodecLM_PL
15
- from codeclm.models import CodecLM
16
- from third_party.demucs.models.pretrained import get_model_from_yaml
17
  import re
 
 
 
18
 
19
- auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
20
 
21
  def check_language_by_text(text):
22
  chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
@@ -32,563 +29,172 @@ def check_language_by_text(text):
32
  else:
33
  return "en"
34
 
35
- class Separator:
36
- def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
37
- if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
38
- self.device = torch.device(f"cuda:{gpu_id}")
39
- else:
40
- self.device = torch.device("cpu")
41
- self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
42
-
43
- def init_demucs_model(self, model_path, config_path):
44
- model = get_model_from_yaml(config_path, model_path)
45
- model.to(self.device)
46
- model.eval()
47
- return model
48
-
49
- def load_audio(self, f):
50
- a, fs = torchaudio.load(f)
51
- if (fs != 48000):
52
- a = torchaudio.functional.resample(a, fs, 48000)
53
- if a.shape[-1] >= 48000*10:
54
- a = a[..., :48000*10]
55
- return a[:, 0:48000*10]
56
-
57
- def run(self, audio_path, output_dir='tmp', ext=".flac"):
58
- os.makedirs(output_dir, exist_ok=True)
59
- name, _ = os.path.splitext(os.path.split(audio_path)[-1])
60
- output_paths = []
61
 
62
- for stem in self.demucs_model.sources:
63
- output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
64
- if os.path.exists(output_path):
65
- output_paths.append(output_path)
66
- if len(output_paths) == 1: # 4
67
- vocal_path = output_paths[0]
68
- else:
69
- drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
70
- for path in [drums_path, bass_path, other_path]:
71
- os.remove(path)
72
- full_audio = self.load_audio(audio_path)
73
- vocal_audio = self.load_audio(vocal_path)
74
- bgm_audio = full_audio - vocal_audio
75
- return full_audio, vocal_audio, bgm_audio
76
 
77
 
78
  def parse_args():
79
  parser = argparse.ArgumentParser(description='Song Generation Script')
80
 
81
  # 必需参数
82
- parser.add_argument('--ckpt_path', type=str, required=True,
83
- help='Path to the checkpoint directory containing config.yaml and model.pt')
84
  parser.add_argument('--input_jsonl', type=str, required=True,
85
  help='Path to input JSONL file containing generation tasks')
86
  parser.add_argument('--save_dir', type=str, required=True,
87
  help='Directory to save generated audio files and results')
88
- # 可选参数
89
- parser.add_argument('--generate_type', type=str, default='mixed',
90
- help='Type of generation: "vocal" or "bgm" or "separate" or "mixed" (default: "mixed")')
91
- parser.add_argument('--use_flash_attn', action='store_true',
92
- help='Whether to use flash attention (default: False)')
93
- parser.add_argument('--low_mem', action='store_true',
94
- help='Whether to use low memory mode (default: False)')
95
  return parser.parse_args()
96
 
97
- def generate(args, version = 'v1.0'):
 
98
  torch.set_num_threads(1)
99
- ckpt_path = args.ckpt_path
 
 
 
 
 
 
100
  input_jsonl = args.input_jsonl
101
  save_dir = args.save_dir
102
- cfg_path = os.path.join(ckpt_path, 'config.yaml')
103
- ckpt_path = os.path.join(ckpt_path, 'model.pt')
104
  cfg = OmegaConf.load(cfg_path)
105
- cfg.lm.use_flash_attn_2 = args.use_flash_attn
106
- print(f"use_flash_attn: {args.use_flash_attn}")
107
  cfg.mode = 'inference'
108
  max_duration = cfg.max_dur
109
- gen_type = args.generate_type
110
 
111
-
112
- separator = Separator()
113
- auto_prompt = torch.load('tools/new_auto_prompt.pt')
114
  audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
 
 
 
 
115
  audio_tokenizer = audio_tokenizer.eval().cuda()
116
- with open(input_jsonl, "r") as fp:
117
- lines = fp.readlines()
118
-
119
-
120
- new_items = []
121
- for line in lines:
122
- item = json.loads(line)
123
- target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
124
- # get prompt audio
125
- if "prompt_audio_path" in item:
126
- assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
127
- assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
128
- with torch.no_grad():
129
- pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
130
- item['raw_pmt_wav'] = pmt_wav
131
- item['raw_vocal_wav'] = vocal_wav
132
- item['raw_bgm_wav'] = bgm_wav
133
- if pmt_wav.dim() == 2:
134
- pmt_wav = pmt_wav[None]
135
- if pmt_wav.dim() != 3:
136
- raise ValueError("Melody wavs should have a shape [B, C, T].")
137
- pmt_wav = list(pmt_wav)
138
- if vocal_wav.dim() == 2:
139
- vocal_wav = vocal_wav[None]
140
- if vocal_wav.dim() != 3:
141
- raise ValueError("Vocal wavs should have a shape [B, C, T].")
142
- vocal_wav = list(vocal_wav)
143
- if bgm_wav.dim() == 2:
144
- bgm_wav = bgm_wav[None]
145
- if bgm_wav.dim() != 3:
146
- raise ValueError("BGM wavs should have a shape [B, C, T].")
147
- bgm_wav = list(bgm_wav)
148
- if type(pmt_wav) == list:
149
- pmt_wav = torch.stack(pmt_wav, dim=0)
150
- if type(vocal_wav) == list:
151
- vocal_wav = torch.stack(vocal_wav, dim=0)
152
- if type(bgm_wav) == list:
153
- bgm_wav = torch.stack(bgm_wav, dim=0)
154
- pmt_wav = pmt_wav
155
- vocal_wav = vocal_wav
156
- bgm_wav = bgm_wav
157
- with torch.no_grad():
158
- pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
159
- melody_is_wav = False
160
- elif "auto_prompt_audio_type" in item:
161
- assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
162
- if item['auto_prompt_audio_type'] == 'Auto':
163
- lang = check_language_by_text(item['gt_lyric'])
164
- prompt_token = auto_prompt['Auto'][lang][np.random.randint(0, len(auto_prompt['Auto'][lang]))]
165
- else:
166
- prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
167
- pmt_wav = prompt_token[:,[0],:]
168
- vocal_wav = prompt_token[:,[1],:]
169
- bgm_wav = prompt_token[:,[2],:]
170
- melody_is_wav = False
171
- else:
172
- pmt_wav = None
173
- vocal_wav = None
174
- bgm_wav = None
175
- melody_is_wav = True
176
- item['pmt_wav'] = pmt_wav
177
- item['vocal_wav'] = vocal_wav
178
- item['bgm_wav'] = bgm_wav
179
- item['melody_is_wav'] = melody_is_wav
180
- item["idx"] = f"{item['idx']}"
181
- item["wav_path"] = target_wav_name
182
- new_items.append(item)
183
-
184
- del audio_tokenizer
185
- del separator
186
-
187
- torch.cuda.empty_cache()
188
-
189
- if "audio_tokenizer_checkpoint_sep" in cfg.keys():
190
- seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
191
- else:
192
- seperate_tokenizer = None
193
-
194
- if seperate_tokenizer is not None:
195
- seperate_tokenizer = seperate_tokenizer.eval().cuda()
196
-
197
- for item in new_items:
198
- if "prompt_audio_path" in item:
199
- with torch.no_grad():
200
- vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
201
- item['vocal_wav'] = vocal_wav
202
- item['bgm_wav'] = bgm_wav
203
-
204
- torch.cuda.empty_cache()
205
- audiolm = builders.get_lm_model(cfg, version=version)
206
- checkpoint = torch.load(ckpt_path, map_location='cpu')
207
- audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
208
- audiolm.load_state_dict(audiolm_state_dict, strict=False)
209
- audiolm = audiolm.eval()
210
- audiolm = audiolm.cuda().to(torch.float16)
211
-
212
- model = CodecLM(name = "tmp",
213
- lm = audiolm,
214
- audiotokenizer = None,
215
- max_duration = max_duration,
216
- seperate_tokenizer = seperate_tokenizer,
217
  )
 
 
218
 
219
- cfg_coef = 1.5 #25
220
- temp = 0.9
221
- top_k = 50
222
- top_p = 0.0
223
- record_tokens = True
224
- record_window = 50
225
-
226
- model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
227
- top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
228
  os.makedirs(save_dir, exist_ok=True)
229
  os.makedirs(save_dir + "/audios", exist_ok=True)
230
  os.makedirs(save_dir + "/jsonl", exist_ok=True)
231
-
232
- for item in new_items:
233
- lyric = item["gt_lyric"]
234
- if version == 'v1.0':
235
- descriptions = item["descriptions"] if "descriptions" in item else None
236
- else:
237
- descriptions = item["descriptions"] if "descriptions" in item else '.'
238
- descriptions = '[Musicality-very-high]' + ', ' + descriptions
239
- pmt_wav = item['pmt_wav']
240
- vocal_wav = item['vocal_wav']
241
- bgm_wav = item['bgm_wav']
242
- melody_is_wav = item['melody_is_wav']
243
- target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
244
-
245
-
246
- generate_inp = {
247
- 'lyrics': [lyric.replace(" ", " ")],
248
- 'descriptions': [descriptions],
249
- 'melody_wavs': pmt_wav,
250
- 'vocal_wavs': vocal_wav,
251
- 'bgm_wavs': bgm_wav,
252
- 'melody_is_wav': melody_is_wav,
253
- }
254
- start_time = time.time()
255
- with torch.autocast(device_type="cuda", dtype=torch.float16):
256
- with torch.no_grad():
257
- tokens = model.generate(**generate_inp, return_tokens=True)
258
- mid_time = time.time()
259
-
260
- with torch.no_grad():
261
- if 'raw_pmt_wav' in item:
262
- if gen_type == 'separate':
263
- wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='mixed')
264
- wav_vocal = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='vocal')
265
- wav_bgm = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True, gen_type='bgm')
266
- elif gen_type == 'mixed':
267
- wav_seperate = model.generate_audio(tokens, item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
268
- else:
269
- wav_seperate = model.generate_audio(tokens,chunked=True, gen_type=gen_type)
270
- del item['raw_pmt_wav']
271
- del item['raw_vocal_wav']
272
- del item['raw_bgm_wav']
273
- else:
274
- if gen_type == 'separate':
275
- wav_vocal = model.generate_audio(tokens, chunked=True, gen_type='vocal')
276
- wav_bgm = model.generate_audio(tokens, chunked=True, gen_type='bgm')
277
- wav_seperate = model.generate_audio(tokens, chunked=True, gen_type='mixed')
278
- else:
279
- wav_seperate = model.generate_audio(tokens, chunked=True, gen_type=gen_type)
280
- del item['pmt_wav']
281
- del item['vocal_wav']
282
- del item['bgm_wav']
283
- del item['melody_is_wav']
284
- end_time = time.time()
285
- if gen_type == 'separate':
286
- torchaudio.save(target_wav_name.replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
287
- torchaudio.save(target_wav_name.replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
288
- torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
289
- else:
290
- torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
291
-
292
- print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
293
- item["idx"] = f"{item['idx']}"
294
- item["wav_path"] = target_wav_name
295
-
296
- src_jsonl_name = os.path.split(input_jsonl)[-1]
297
- with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
298
- for item in new_items:
299
- fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
300
-
301
- def generate_lowmem(args):
302
- torch.set_num_threads(1)
303
- ckpt_path = args.ckpt_path
304
- input_jsonl = args.input_jsonl
305
- save_dir = args.save_dir
306
- cfg_path = os.path.join(ckpt_path, 'config.yaml')
307
- ckpt_path = os.path.join(ckpt_path, 'model.pt')
308
- cfg = OmegaConf.load(cfg_path)
309
- cfg.lm.use_flash_attn_2 = args.use_flash_attn
310
- print(f"use_flash_attn: {args.use_flash_attn}")
311
- cfg.mode = 'inference'
312
- max_duration = cfg.max_dur
313
- gen_type = args.generate_type
314
- chunk_size = 128
315
- use_audio_tokenizer = False
316
  with open(input_jsonl, "r") as fp:
317
  lines = fp.readlines()
318
- for line in lines:
319
- item = json.loads(line)
320
- if "prompt_audio_path" in item:
321
- use_audio_tokenizer = True
322
- break
323
- if use_audio_tokenizer:
324
- separator = Separator()
325
- audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
326
- audio_tokenizer = audio_tokenizer.eval().cuda()
327
- auto_prompt = torch.load('tools/new_prompt.pt')
328
  new_items = []
329
  for line in lines:
330
  item = json.loads(line)
 
 
 
331
  target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
332
- # get prompt audio
 
333
  if "prompt_audio_path" in item:
334
  assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
335
  assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
336
  with torch.no_grad():
337
- pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
338
  item['raw_pmt_wav'] = pmt_wav
339
- item['raw_vocal_wav'] = vocal_wav
340
- item['raw_bgm_wav'] = bgm_wav
341
  if pmt_wav.dim() == 2:
342
  pmt_wav = pmt_wav[None]
343
  if pmt_wav.dim() != 3:
344
  raise ValueError("Melody wavs should have a shape [B, C, T].")
345
  pmt_wav = list(pmt_wav)
346
- if vocal_wav.dim() == 2:
347
- vocal_wav = vocal_wav[None]
348
- if vocal_wav.dim() != 3:
349
- raise ValueError("Vocal wavs should have a shape [B, C, T].")
350
- vocal_wav = list(vocal_wav)
351
- if bgm_wav.dim() == 2:
352
- bgm_wav = bgm_wav[None]
353
- if bgm_wav.dim() != 3:
354
- raise ValueError("BGM wavs should have a shape [B, C, T].")
355
- bgm_wav = list(bgm_wav)
356
  if type(pmt_wav) == list:
357
  pmt_wav = torch.stack(pmt_wav, dim=0)
358
- if type(vocal_wav) == list:
359
- vocal_wav = torch.stack(vocal_wav, dim=0)
360
- if type(bgm_wav) == list:
361
- bgm_wav = torch.stack(bgm_wav, dim=0)
362
  with torch.no_grad():
363
  pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
 
364
  melody_is_wav = False
365
  elif "auto_prompt_audio_type" in item:
366
  assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
367
- prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
 
368
  pmt_wav = prompt_token[:,[0],:]
369
- vocal_wav = prompt_token[:,[1],:]
370
- bgm_wav = prompt_token[:,[2],:]
371
  melody_is_wav = False
372
  else:
373
  pmt_wav = None
374
- vocal_wav = None
375
- bgm_wav = None
376
  melody_is_wav = True
377
- item['pmt_wav'] = pmt_wav
378
- item['vocal_wav'] = vocal_wav
379
- item['bgm_wav'] = bgm_wav
380
- item['melody_is_wav'] = melody_is_wav
381
  item["idx"] = f"{item['idx']}"
382
  item["wav_path"] = target_wav_name
383
- new_items.append(item)
384
-
385
- if use_audio_tokenizer:
386
- del audio_tokenizer
387
- del separator
388
-
389
- torch.cuda.empty_cache()
390
-
391
- if "audio_tokenizer_checkpoint_sep" in cfg.keys() and use_audio_tokenizer:
392
- seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
393
- else:
394
- seperate_tokenizer = None
395
-
396
- if seperate_tokenizer is not None:
397
- seperate_tokenizer = seperate_tokenizer.eval().cuda()
398
-
399
- for item in new_items:
400
- if "prompt_audio_path" in item:
401
- with torch.no_grad():
402
- vocal_wav, bgm_wav = seperate_tokenizer.encode(item['vocal_wav'].cuda(), item['bgm_wav'].cuda())
403
- item['vocal_wav'] = vocal_wav
404
- item['bgm_wav'] = bgm_wav
405
-
406
- if use_audio_tokenizer:
407
- del seperate_tokenizer
408
-
409
- torch.cuda.empty_cache()
410
-
411
- # Define model or load pretrained model
412
- audiolm = builders.get_lm_model(cfg)
413
- checkpoint = torch.load(ckpt_path, map_location='cpu')
414
- audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
415
- audiolm.load_state_dict(audiolm_state_dict, strict=False)
416
- audiolm = audiolm.eval()
417
-
418
- offload_audiolm = True if 'offload' in cfg.keys() and 'audiolm' in cfg.offload else False
419
- if offload_audiolm:
420
- audiolm_offload_param = OffloadParamParse.parse_config(audiolm, cfg.offload.audiolm)
421
- audiolm_offload_param.show()
422
- offload_profiler = OffloadProfiler(device_index=0, **(audiolm_offload_param.init_param_dict()))
423
- offload_profiler.offload_layer(**(audiolm_offload_param.offload_layer_param_dict()))
424
- offload_profiler.clean_cache_wrapper(**(audiolm_offload_param.clean_cache_param_dict()))
425
- else:
426
- audiolm = audiolm.cuda().to(torch.float16)
427
-
428
- model = CodecLM(name = "tmp",
429
- lm = audiolm,
430
- audiotokenizer = None,
431
- max_duration = max_duration,
432
- seperate_tokenizer = None,
433
- )
434
-
435
- cfg_coef = 1.5 #25
436
- temp = 0.9
437
- top_k = 50
438
- top_p = 0.0
439
- record_tokens = True
440
- record_window = 50
441
-
442
-
443
- model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
444
- top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
445
- os.makedirs(save_dir, exist_ok=True)
446
- os.makedirs(save_dir + "/audios", exist_ok=True)
447
- os.makedirs(save_dir + "/jsonl", exist_ok=True)
448
-
449
-
450
- for item in new_items:
451
- lyric = item["gt_lyric"]
452
- descriptions = item["descriptions"] if "descriptions" in item else None
453
- pmt_wav = item['pmt_wav']
454
- vocal_wav = item['vocal_wav']
455
- bgm_wav = item['bgm_wav']
456
- melody_is_wav = item['melody_is_wav']
457
-
458
  generate_inp = {
459
- 'lyrics': [lyric.replace(" ", " ")],
460
- 'descriptions': [descriptions],
461
  'melody_wavs': pmt_wav,
462
- 'vocal_wavs': vocal_wav,
463
- 'bgm_wavs': bgm_wav,
464
  'melody_is_wav': melody_is_wav,
 
465
  }
466
- with torch.autocast(device_type="cuda", dtype=torch.float16):
467
- with torch.no_grad():
468
- tokens = model.generate(**generate_inp, return_tokens=True)
469
- if offload_audiolm:
470
- offload_profiler.reset_empty_cache_mem_line()
471
- item['tokens'] = tokens
472
- if offload_audiolm:
473
- offload_profiler.stop()
474
- del offload_profiler
475
- del audiolm_offload_param
476
- del model
477
- audiolm = audiolm.cpu()
478
- del audiolm
479
- del checkpoint
480
- gc.collect()
481
- torch.cuda.empty_cache()
482
-
483
- seperate_tokenizer = builders.get_audio_tokenizer_model_cpu(cfg.audio_tokenizer_checkpoint_sep, cfg)
484
- device = "cuda:0"
485
- seperate_tokenizer.model.device = device
486
- seperate_tokenizer.model.vae = seperate_tokenizer.model.vae.to(device)
487
- seperate_tokenizer.model.model.device = torch.device(device)
488
- seperate_tokenizer = seperate_tokenizer.eval()
489
-
490
- # offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False
491
- offload_wav_tokenizer_diffusion = False
492
- if offload_wav_tokenizer_diffusion:
493
- sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion)
494
- sep_offload_param.show()
495
- sep_offload_profiler = OffloadProfiler(device_index=0, **(sep_offload_param.init_param_dict()))
496
- sep_offload_profiler.offload_layer(**(sep_offload_param.offload_layer_param_dict()))
497
- sep_offload_profiler.clean_cache_wrapper(**(sep_offload_param.clean_cache_param_dict()))
498
- else:
499
- seperate_tokenizer.model.model = seperate_tokenizer.model.model.to(device)
500
-
501
- model = CodecLM(name = "tmp",
502
- lm = None,
503
- audiotokenizer = None,
504
- max_duration = max_duration,
505
- seperate_tokenizer = seperate_tokenizer,
506
- )
507
 
508
- for item in new_items:
509
  with torch.no_grad():
 
510
  if 'raw_pmt_wav' in item:
511
- if gen_type == 'separate':
512
- wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type='mixed')
513
- wav_vocal = model.generate_audio(item['tokens'],chunked=True, gen_type='vocal')
514
- wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
515
- elif gen_type == 'mixed':
516
- wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'],chunked=True, gen_type=gen_type)
517
- else:
518
- wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
519
  del item['raw_pmt_wav']
520
- del item['raw_vocal_wav']
521
- del item['raw_bgm_wav']
522
  else:
523
- if gen_type == 'separate':
524
- wav_vocal = model.generate_audio(item['tokens'], chunked=True, gen_type='vocal')
525
- wav_bgm = model.generate_audio(item['tokens'], chunked=True, gen_type='bgm')
526
- wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type='mixed')
527
- else:
528
- wav_seperate = model.generate_audio(item['tokens'], chunked=True, gen_type=gen_type)
529
- if gen_type == 'separate':
530
- torchaudio.save(item['wav_path'].replace('.flac', '_vocal.flac'), wav_vocal[0].cpu().float(), cfg.sample_rate)
531
- torchaudio.save(item['wav_path'].replace('.flac', '_bgm.flac'), wav_bgm[0].cpu().float(), cfg.sample_rate)
532
- torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
533
- else:
534
- torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate)
535
- del item['tokens']
536
- del item['pmt_wav']
537
- del item['vocal_wav']
538
- del item['bgm_wav']
539
- del item['melody_is_wav']
540
- if offload_wav_tokenizer_diffusion:
541
- sep_offload_profiler.reset_empty_cache_mem_line()
542
 
543
- if offload_wav_tokenizer_diffusion:
544
- sep_offload_profiler.stop()
545
- torch.cuda.empty_cache()
546
  src_jsonl_name = os.path.split(input_jsonl)[-1]
547
  with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
548
  for item in new_items:
549
  fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
550
 
551
-
552
  if __name__ == "__main__":
553
- torch.backends.cudnn.enabled = False
554
- OmegaConf.register_new_resolver("eval", lambda x: eval(x))
555
- OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
556
- OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
557
- OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
558
- np.random.seed(int(time.time()))
559
- # 解析命令行参数
560
- args = parse_args()
561
- if torch.cuda.is_available():
562
- device = torch.cuda.current_device()
563
- reserved = torch.cuda.memory_reserved(device)
564
- total = torch.cuda.get_device_properties(device).total_memory
565
- res_mem = (total - reserved) / 1024 / 1024 / 1024
566
- print(f"reserved memory: {res_mem}GB")
567
-
568
- model_name = args.ckpt_path.split("/")[-1].lower().replace('-', '_')
569
- assert model_name in ['songgeneration_base', 'songgeneration_base_new', 'songgeneration_base_full', 'songgeneration_large', 'songgeneration_new_small', 'songgeneration_new_large', 'songgeneration_new_medium'], f'{model_name} is not supported, currently only songgeneration_base, songgeneration_base_new, songgeneration_base_full, songgeneration_large are supported. Please download correct files and rename the folder to the corresponding version name.'
570
- if model_name == 'songgeneration_base' or model_name == 'songgeneration_base_new' or model_name == 'songgeneration_base_full':
571
- if res_mem > 24 and not args.low_mem:
572
- print("use generate")
573
- generate(args)
574
- else:
575
- from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
576
- print("use generate_lowmem")
577
- generate_lowmem(args)
578
- elif model_name == 'songgeneration_large':
579
- if res_mem > 36 and not args.low_mem:
580
- print("use generate")
581
- generate(args)
582
- else:
583
- print("use generate_lowmem")
584
- from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse
585
- generate_lowmem(args)
586
- elif model_name == 'songgeneration_new_small' or model_name == 'songgeneration_new_large' or model_name == 'songgeneration_new_medium':
587
- print("use generate")
588
- generate(args, version = 'v1.5')
589
-
590
-
591
- else:
592
- print("CUDA is not available")
593
- exit()
594
-
 
1
+ import glob
 
 
 
 
2
  import time
 
3
  import torch
4
+ from codeclm.models.codeclm_gen import CodecLM_gen
5
+ from codeclm.models import builders
6
+ import sys
7
+ import os
8
  import torchaudio
9
  import numpy as np
10
+ import json
11
+ from vllm import LLM, SamplingParams
 
 
 
 
12
  import re
13
+ import argparse
14
+ import librosa
15
+ auto_prompt_type = ['Pop', 'Latin', 'Rock', 'Electronic', 'Metal', 'Country', 'R&B/Soul', 'Ballad', 'Jazz', 'World', 'Hip-Hop', 'Funk', 'Soundtrack','Auto']
16
 
 
17
 
18
  def check_language_by_text(text):
19
  chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
 
29
  else:
30
  return "en"
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ def load_audio(f):
34
+ a, fs= librosa.load(f, sr=48000)
35
+ a = torch.tensor(a).unsqueeze(0)
36
+ if (fs != 48000):
37
+ a = torchaudio.functional.resample(a, fs, 48000)
38
+ if a.shape[-1] >= 48000*10:
39
+ a = a[..., :48000*10]
40
+ return a[:, 0:48000*10]
 
 
 
 
 
 
41
 
42
 
43
  def parse_args():
44
  parser = argparse.ArgumentParser(description='Song Generation Script')
45
 
46
  # 必需参数
 
 
47
  parser.add_argument('--input_jsonl', type=str, required=True,
48
  help='Path to input JSONL file containing generation tasks')
49
  parser.add_argument('--save_dir', type=str, required=True,
50
  help='Directory to save generated audio files and results')
51
+ parser.add_argument('--config_path', type=str, required=True,
52
+ help='Path to the config file')
 
 
 
 
 
53
  return parser.parse_args()
54
 
55
+
56
+ def main():
57
  torch.set_num_threads(1)
58
+ torch.backends.cudnn.enabled = False #taiji的某些傻呗node会报奇奇怪怪的错
59
+ from omegaconf import OmegaConf
60
+ OmegaConf.register_new_resolver("eval", lambda x: eval(x))
61
+ OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
62
+ OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
63
+ OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
64
+ args = parse_args()
65
  input_jsonl = args.input_jsonl
66
  save_dir = args.save_dir
67
+ cfg_path = args.config_path
68
+
69
  cfg = OmegaConf.load(cfg_path)
 
 
70
  cfg.mode = 'inference'
71
  max_duration = cfg.max_dur
 
72
 
 
 
 
73
  audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
74
+ if audio_tokenizer is not None:
75
+ for param in audio_tokenizer.parameters():
76
+ param.requires_grad = False
77
+ print("Audio tokenizer successfully loaded!")
78
  audio_tokenizer = audio_tokenizer.eval().cuda()
79
+ model_condition = CodecLM_gen(cfg=cfg,name = "tmp",audiotokenizer = audio_tokenizer,max_duration = max_duration)
80
+ model_condition.condition_provider.conditioners.load_state_dict(torch.load(cfg.lm_checkpoint+"/conditioners_weights.pth"))
81
+ print('Conditioner successfully loaded!')
82
+ llm = LLM(
83
+ model=cfg.lm_checkpoint,
84
+ trust_remote_code=True,
85
+ tensor_parallel_size=cfg.vllm.device_num,
86
+ enforce_eager=False,
87
+ dtype="bfloat16",
88
+ gpu_memory_utilization=cfg.vllm.gpu_memory_utilization,
89
+ tokenizer=None,
90
+ skip_tokenizer_init=True,
91
+ enable_prompt_embeds=True,
92
+ enable_chunked_prefill=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  )
94
+ print("LLM 初始化成功")
95
+ auto_prompt = torch.load('tools/new_prompt.pt')
96
 
97
+ guidance_scale = cfg.vllm.guidance_scale
98
+ temp = cfg.vllm.temp
99
+ top_k = cfg.vllm.top_k
100
+ sum_time = 0
101
+ sum_wav_len = 0
 
 
 
 
102
  os.makedirs(save_dir, exist_ok=True)
103
  os.makedirs(save_dir + "/audios", exist_ok=True)
104
  os.makedirs(save_dir + "/jsonl", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  with open(input_jsonl, "r") as fp:
106
  lines = fp.readlines()
 
 
 
 
 
 
 
 
 
 
107
  new_items = []
108
  for line in lines:
109
  item = json.loads(line)
110
+ lyric = item["gt_lyric"]
111
+ descriptions = item["descriptions"].lower() if "descriptions" in item else '.'
112
+ descriptions = '[Musicality-very-high]' + ', ' + descriptions
113
  target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
114
+ if os.path.exists(target_wav_name):
115
+ continue
116
  if "prompt_audio_path" in item:
117
  assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
118
  assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
119
  with torch.no_grad():
120
+ pmt_wav = load_audio(item['prompt_audio_path'])
121
  item['raw_pmt_wav'] = pmt_wav
 
 
122
  if pmt_wav.dim() == 2:
123
  pmt_wav = pmt_wav[None]
124
  if pmt_wav.dim() != 3:
125
  raise ValueError("Melody wavs should have a shape [B, C, T].")
126
  pmt_wav = list(pmt_wav)
 
 
 
 
 
 
 
 
 
 
127
  if type(pmt_wav) == list:
128
  pmt_wav = torch.stack(pmt_wav, dim=0)
 
 
 
 
129
  with torch.no_grad():
130
  pmt_wav, _ = audio_tokenizer.encode(pmt_wav.cuda())
131
+ print(pmt_wav.shape)
132
  melody_is_wav = False
133
  elif "auto_prompt_audio_type" in item:
134
  assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
135
+ lang = check_language_by_text(item['gt_lyric'])
136
+ prompt_token = auto_prompt[item["auto_prompt_audio_type"]][lang][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]][lang]))]
137
  pmt_wav = prompt_token[:,[0],:]
 
 
138
  melody_is_wav = False
139
  else:
140
  pmt_wav = None
 
 
141
  melody_is_wav = True
 
 
 
 
142
  item["idx"] = f"{item['idx']}"
143
  item["wav_path"] = target_wav_name
144
+ embeded_eosp1 = torch.load(cfg.lm_checkpoint+'/embeded_eosp1.pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  generate_inp = {
146
+ 'descriptions': [lyric.replace(" ", " ")],
147
+ 'type_info': [descriptions],
148
  'melody_wavs': pmt_wav,
 
 
149
  'melody_is_wav': melody_is_wav,
150
+ 'embeded_eosp1': embeded_eosp1,
151
  }
152
+ fused_input, audio_qt_embs = model_condition.generate_condition(**generate_inp, return_tokens=True)
153
+ prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
154
+ allowed_token_ids = [x for x in range(cfg.lm.code_size+1) if x not in prompt_token]
155
+ sampling_params = SamplingParams(
156
+ max_tokens=cfg.audio_tokenizer_frame_rate*cfg.max_dur,
157
+ temperature=temp,
158
+ stop_token_ids=[cfg.lm.code_size],
159
+ top_k=top_k,
160
+ frequency_penalty=0.2,
161
+ seed=int(time.time() * 1000000) % (2**32) if cfg.vllm.cfg else -1,
162
+ allowed_token_ids=allowed_token_ids,
163
+ guidance_scale=guidance_scale
164
+ )
165
+ # 拆成现支持的batch 3 CFG形式
166
+ prompts = [{"prompt_embeds": embed} for embed in fused_input]
167
+ promptss = []
168
+ for _ in range(2):
169
+ promptss+=prompts
170
+ uncondi = prompts[1]
171
+ promptss = promptss[::2] + [uncondi]
172
+ start_time = time.time()
173
+ outputs = llm.generate(promptss, sampling_params=sampling_params)
174
+ mid_time = time.time()
175
+ token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
176
+ token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
 
178
  with torch.no_grad():
179
+ # wav_nocfg = model_condition.generate_audio(token_ids)
180
  if 'raw_pmt_wav' in item:
181
+ wav_cfg = model_condition.generate_audio(token_ids_CFG, item['raw_pmt_wav'])
 
 
 
 
 
 
 
182
  del item['raw_pmt_wav']
 
 
183
  else:
184
+ wav_cfg = model_condition.generate_audio(token_ids_CFG)
185
+ end_time = time.time()
186
+ torchaudio.save(target_wav_name, wav_cfg[0].cpu().float(), cfg.sample_rate)
187
+ sum_time += end_time - start_time
188
+ sum_wav_len += (token_ids_CFG.shape[-1] / 25)
189
+ print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}, rtf {(end_time - start_time) / token_ids_CFG.shape[-1] * 25:.2f}")
190
+ new_items.append(item)
191
+ print(f"Total time: {sum_time:.4f} seconds, total wav length: {sum_wav_len:.4f} seconds, rtf {sum_time/sum_wav_len:.2f}")
 
 
 
 
 
 
 
 
 
 
 
192
 
 
 
 
193
  src_jsonl_name = os.path.split(input_jsonl)[-1]
194
  with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
195
  for item in new_items:
196
  fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
197
 
 
198
  if __name__ == "__main__":
199
+ main()
200
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generate.sh CHANGED
@@ -3,70 +3,15 @@ export PYTHONDONTWRITEBYTECODE=1
3
  export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
4
  export NCCL_HOME=/usr/local/tccl
5
  export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
 
 
 
6
 
7
- CKPT_PATH=$1
 
8
  JSONL=$2
9
  SAVE_DIR=$3
10
- USE_FLASH_ATTN="True"
11
- LOW_MEM="False"
12
- GENERATE_TYPE="mixed"
13
- for arg in "$@"; do
14
- if [[ $arg == "--not_use_flash_attn" ]]; then
15
- USE_FLASH_ATTN="False"
16
- fi
17
- done
18
- for arg in "$@"; do
19
- if [[ $arg == "--low_mem" ]]; then
20
- LOW_MEM="True"
21
- fi
22
- done
23
- for arg in "$@"; do
24
- if [[ $arg == "--separate" ]]; then
25
- GENERATE_TYPE="separate"
26
- fi
27
- done
28
- for arg in "$@"; do
29
- if [[ $arg == "--bgm" ]]; then
30
- GENERATE_TYPE="bgm"
31
- fi
32
- done
33
- for arg in "$@"; do
34
- if [[ $arg == "--vocal" ]]; then
35
- GENERATE_TYPE="vocal"
36
- fi
37
- done
38
-
39
-
40
- if [ "$USE_FLASH_ATTN" == "True" ] && [ "$LOW_MEM" == "True" ]; then
41
- echo "Use Flash Attention + Low Memory Mode"
42
- python3 generate.py \
43
- --ckpt_path $CKPT_PATH \
44
- --input_jsonl $JSONL \
45
- --save_dir $SAVE_DIR \
46
- --generate_type $GENERATE_TYPE \
47
- --use_flash_attn \
48
- --low_mem
49
- elif [ "$USE_FLASH_ATTN" == "True" ] && [ "$LOW_MEM" == "False" ]; then
50
- echo "Use Flash Attention + Auto Memory Mode"
51
- python3 generate.py \
52
- --ckpt_path $CKPT_PATH \
53
- --input_jsonl $JSONL \
54
- --save_dir $SAVE_DIR \
55
- --generate_type $GENERATE_TYPE \
56
- --use_flash_attn
57
- elif [ "$USE_FLASH_ATTN" == "False" ] && [ "$LOW_MEM" == "False" ]; then
58
- echo "Not Use Flash Attention + Auto Memory Mode"
59
- python3 generate.py \
60
- --ckpt_path $CKPT_PATH \
61
- --input_jsonl $JSONL \
62
- --generate_type $GENERATE_TYPE \
63
- --save_dir $SAVE_DIR
64
- elif [ "$USE_FLASH_ATTN" == "False" ] && [ "$LOW_MEM" == "True" ]; then
65
- echo "Not Use Flash Attention + Low Memory Mode"
66
- python3 generate.py \
67
- --ckpt_path $CKPT_PATH \
68
- --input_jsonl $JSONL \
69
- --save_dir $SAVE_DIR \
70
- --generate_type $GENERATE_TYPE \
71
- --low_mem
72
- fi
 
3
  export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
4
  export NCCL_HOME=/usr/local/tccl
5
  export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
6
+ export OMP_NUM_THREADS=1
7
+ export MKL_NUM_THREADS=1
8
+ export CUDA_LAUNCH_BLOCKING=0
9
 
10
+
11
+ CONFIG_PATH=$1
12
  JSONL=$2
13
  SAVE_DIR=$3
14
+ python3 generate.py \
15
+ --input_jsonl $JSONL \
16
+ --save_dir $SAVE_DIR \
17
+ --config_path $CONFIG_PATH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
levo_inference.py CHANGED
@@ -1,22 +1,19 @@
1
  import os
2
  import sys
3
-
4
 
5
  sys.path.append('./codeclm/tokenizer')
6
  sys.path.append('./codeclm/tokenizer/Flow1dVAE')
7
  sys.path.append('.')
8
 
9
  import torch
10
-
11
- import json
12
  import numpy as np
13
  from omegaconf import OmegaConf
 
14
 
15
  from codeclm.models import builders
16
- from codeclm.models import CodecLM
17
-
18
- from separator import Separator
19
- from generate import check_language_by_text
20
 
21
 
22
  class LeVoInference(torch.nn.Module):
@@ -30,39 +27,37 @@ class LeVoInference(torch.nn.Module):
30
  OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
31
 
32
  cfg_path = os.path.join(ckpt_path, 'config.yaml')
33
- pt_path = os.path.join(ckpt_path, 'model.pt')
34
-
35
  self.cfg = OmegaConf.load(cfg_path)
36
  self.cfg.mode = 'inference'
37
  self.max_duration = self.cfg.max_dur
38
 
39
- # Define model or load pretrained model
40
- audiolm = builders.get_lm_model(self.cfg, version='v1.5')
41
- checkpoint = torch.load(pt_path, map_location='cpu')
42
- audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
43
- audiolm.load_state_dict(audiolm_state_dict, strict=False)
44
- audiolm = audiolm.eval()
45
- audiolm = audiolm.cuda().to(torch.float16)
46
-
47
  audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
48
- audio_tokenizer = audio_tokenizer.eval()
49
-
50
- seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
51
- seperate_tokenizer = seperate_tokenizer.eval()
52
-
53
- self.model = CodecLM(name = "tmp",
54
- lm = audiolm,
55
- audiotokenizer = audio_tokenizer,
56
- max_duration = self.max_duration,
57
- seperate_tokenizer = seperate_tokenizer,
 
 
 
 
 
 
 
 
 
 
58
  )
59
- self.separator = Separator()
60
-
61
 
62
  self.default_params = dict(
63
- cfg_coef = 1.5,
64
- temperature = 1.0,
65
- top_k = 50,
66
  top_p = 0.0,
67
  record_tokens = True,
68
  record_window = 50,
@@ -70,14 +65,11 @@ class LeVoInference(torch.nn.Module):
70
  duration = self.max_duration,
71
  )
72
 
73
- self.model.set_generation_params(**self.default_params)
74
-
75
  def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
76
  params = {**self.default_params, **params}
77
- self.model.set_generation_params(**params)
78
 
79
  if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
80
- pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
81
  melody_is_wav = True
82
  elif genre is not None and auto_prompt_path is not None:
83
  auto_prompt = torch.load(auto_prompt_path)
@@ -87,33 +79,48 @@ class LeVoInference(torch.nn.Module):
87
  else:
88
  prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
89
  pmt_wav = prompt_token[:,[0],:]
90
- vocal_wav = prompt_token[:,[1],:]
91
- bgm_wav = prompt_token[:,[2],:]
92
  melody_is_wav = False
93
  else:
94
  pmt_wav = None
95
- vocal_wav = None
96
- bgm_wav = None
97
  melody_is_wav = True
98
 
99
  description = description if description else '.'
100
  description = '[Musicality-very-high]' + ', ' + description
101
  generate_inp = {
102
- 'lyrics': [lyric.replace(" ", " ")],
103
- 'descriptions': [description],
104
  'melody_wavs': pmt_wav,
105
- 'vocal_wavs': vocal_wav,
106
- 'bgm_wavs': bgm_wav,
107
  'melody_is_wav': melody_is_wav,
 
108
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- with torch.autocast(device_type="cuda", dtype=torch.float16):
111
- tokens = self.model.generate(**generate_inp, return_tokens=True)
112
-
113
  with torch.no_grad():
114
  if melody_is_wav:
115
- wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type)
116
  else:
117
- wav_seperate = self.model.generate_audio(tokens, gen_type=gen_type)
118
 
119
- return wav_seperate[0]
 
1
  import os
2
  import sys
3
+ import time
4
 
5
  sys.path.append('./codeclm/tokenizer')
6
  sys.path.append('./codeclm/tokenizer/Flow1dVAE')
7
  sys.path.append('.')
8
 
9
  import torch
 
 
10
  import numpy as np
11
  from omegaconf import OmegaConf
12
+ from vllm import LLM, SamplingParams
13
 
14
  from codeclm.models import builders
15
+ from codeclm.models.codeclm_gen import CodecLM_gen
16
+ from generate import check_language_by_text, load_audio
 
 
17
 
18
 
19
  class LeVoInference(torch.nn.Module):
 
27
  OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
28
 
29
  cfg_path = os.path.join(ckpt_path, 'config.yaml')
 
 
30
  self.cfg = OmegaConf.load(cfg_path)
31
  self.cfg.mode = 'inference'
32
  self.max_duration = self.cfg.max_dur
33
 
 
 
 
 
 
 
 
 
34
  audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
35
+ if audio_tokenizer is not None:
36
+ for param in audio_tokenizer.parameters():
37
+ param.requires_grad = False
38
+ print("Audio tokenizer successfully loaded!")
39
+ audio_tokenizer = audio_tokenizer.eval().cuda()
40
+ self.model_condition = CodecLM_gen(cfg=self.cfg,name = "tmp",audiotokenizer = audio_tokenizer,max_duration = self.max_duration)
41
+ self.model_condition.condition_provider.conditioners.load_state_dict(torch.load(self.cfg.lm_checkpoint+"/conditioners_weights.pth"))
42
+ self.embeded_eosp1 = torch.load(self.cfg.lm_checkpoint+'/embeded_eosp1.pt')
43
+ print('Conditioner successfully loaded!')
44
+ self.llm = LLM(
45
+ model=self.cfg.lm_checkpoint,
46
+ trust_remote_code=True,
47
+ tensor_parallel_size=self.cfg.vllm.device_num,
48
+ enforce_eager=False,
49
+ dtype="bfloat16",
50
+ gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
51
+ tokenizer=None,
52
+ skip_tokenizer_init=True,
53
+ enable_prompt_embeds=True,
54
+ enable_chunked_prefill=True,
55
  )
 
 
56
 
57
  self.default_params = dict(
58
+ cfg_coef = 1.8,
59
+ temperature = 0.8,
60
+ top_k = 5000,
61
  top_p = 0.0,
62
  record_tokens = True,
63
  record_window = 50,
 
65
  duration = self.max_duration,
66
  )
67
 
 
 
68
  def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params = dict()):
69
  params = {**self.default_params, **params}
 
70
 
71
  if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
72
+ pmt_wav = load_audio(prompt_audio_path)
73
  melody_is_wav = True
74
  elif genre is not None and auto_prompt_path is not None:
75
  auto_prompt = torch.load(auto_prompt_path)
 
79
  else:
80
  prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
81
  pmt_wav = prompt_token[:,[0],:]
 
 
82
  melody_is_wav = False
83
  else:
84
  pmt_wav = None
 
 
85
  melody_is_wav = True
86
 
87
  description = description if description else '.'
88
  description = '[Musicality-very-high]' + ', ' + description
89
  generate_inp = {
90
+ 'descriptions': [lyric.replace(" ", " ")],
91
+ 'type_info': [description],
92
  'melody_wavs': pmt_wav,
 
 
93
  'melody_is_wav': melody_is_wav,
94
+ 'embeded_eosp1': self.embeded_eosp1,
95
  }
96
+ fused_input, audio_qt_embs = self.model_condition.generate_condition(**generate_inp, return_tokens=True)
97
+ prompt_token = audio_qt_embs[0][0].tolist() if audio_qt_embs else []
98
+ allowed_token_ids = [x for x in range(self.cfg.lm.code_size+1) if x not in prompt_token]
99
+ sampling_params = SamplingParams(
100
+ max_tokens=self.cfg.audio_tokenizer_frame_rate*self.max_duration,
101
+ temperature=params["temperature"],
102
+ stop_token_ids=[self.cfg.lm.code_size],
103
+ top_k=params["top_k"],
104
+ frequency_penalty=0.2,
105
+ seed=int(time.time() * 1000000) % (2**32) if self.cfg.vllm.cfg else -1,
106
+ allowed_token_ids=allowed_token_ids,
107
+ guidance_scale=params["cfg_coef"]
108
+ )
109
+ # 拆成现支持的batch 3 CFG形式
110
+ prompts = [{"prompt_embeds": embed} for embed in fused_input]
111
+ promptss = []
112
+ for _ in range(2):
113
+ promptss+=prompts
114
+ uncondi = prompts[1]
115
+ promptss = promptss[::2] + [uncondi]
116
+ outputs = self.llm.generate(promptss, sampling_params=sampling_params)
117
+ token_ids_CFG = torch.tensor(outputs[1].outputs[0].token_ids)
118
+ token_ids_CFG = token_ids_CFG[:-1].unsqueeze(0).unsqueeze(0)
119
 
 
 
 
120
  with torch.no_grad():
121
  if melody_is_wav:
122
+ wav_cfg = self.model_condition.generate_audio(token_ids_CFG, pmt_wav)
123
  else:
124
+ wav_cfg = self.model_condition.generate_audio(token_ids_CFG)
125
 
126
+ return wav_cfg[0]
requirements.txt CHANGED
@@ -0,0 +1 @@
 
 
1
+ gradio>=6.5.1
sample/lyrics.jsonl CHANGED
@@ -1,4 +1,3 @@
1
- {"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"}
2
  {"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
3
- {"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
4
- {"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "input/sample_prompt_audio.wav"}
 
 
1
  {"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
2
+ {"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, guitar and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
3
+ {"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "sample/sample_prompt_audio.wav"}
vllm_hacked/model_executor/layers/utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """Utility methods for model layers."""
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+
8
+ from vllm import _custom_ops as ops
9
+ from vllm import envs
10
+ from vllm.platforms import CpuArchEnum, current_platform
11
+ from vllm.utils import direct_register_custom_op
12
+
13
+
14
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of the last dimension of *w*.

    The gate and up projections live in the first and second half of the
    last axis; this folds them into adjacent pairs so a fused triton
    swiglu kernel can read both operands with one contiguous load.

    Example:
        [[1, 2, 3, 4, 5, 6],      ->  [[1, 4, 2, 5, 3, 6],
         [7, 8, 9, 10, 11, 12]]        [7, 10, 8, 11, 9, 12]]
    """
    half = w.shape[-1] // 2
    # View as (..., 2, half), swap the trailing axes to pair element i of
    # the first half with element i of the second, then flatten back.
    paired = w.reshape(*w.shape[:-1], 2, half).transpose(-1, -2)
    return paired.reshape(w.shape)
33
+
34
+
35
def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-sequence histogram of token ids plus its occupancy mask.

    Returns ``(counts, mask)`` where ``counts[s, t]`` is how often token
    ``t`` appears in row ``s`` of *tokens*, and ``mask = counts > 0``.
    """
    # One extra bin absorbs the padding value `vocab_size`, which is used
    # as filler precisely because it is not a valid token id.
    hist = tokens.new_zeros((num_seqs, vocab_size + 1))
    hist.scatter_add_(1, tokens, torch.ones_like(tokens))
    # Drop the padding bin before returning.
    hist = hist[:, :vocab_size]
    return hist, hist > 0
50
+
51
+
52
def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor
    logits : The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )

    Returns the same `logits` tensor (mutated in place).
    """
    num_seqs, vocab_size = logits.shape
    # Prompt side: only the "was this token seen" mask is needed.
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    # Output side: counts drive the frequency penalty, mask the presence one.
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Apply repetition penalties as a custom op
    from vllm._custom_ops import apply_repetition_penalties
    apply_repetition_penalties(logits, prompt_mask, output_mask,
                               repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    # NOTE(review): an alternative multiplicative frequency penalty was
    # considered but deliberately left disabled — logits carry both signs,
    # so dividing would *encourage* (boost) negative logits instead of
    # penalizing them:
    # logits /= (1+frequency_penalties).unsqueeze(dim=1) ** output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits
87
+
88
+
89
def default_unquantized_gemm(layer: torch.nn.Module,
                             x: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None):
    """Plain dense fallback GEMM: ``y = x @ weight.T (+ bias)``.

    `layer` is accepted only to match the dispatch signature; it is unused.
    """
    product = x.matmul(weight.t())
    return product if bias is None else product + bias
94
+
95
+
96
def rocm_unquantized_gemm_impl(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """ROCm GEMM that routes tall/skinny shapes to specialized kernels.

    Falls back to the regular dense ``F.linear`` whenever the skinny
    kernels are not applicable.
    """
    from vllm.platforms.rocm import on_gfx9
    k = weight.shape[1]
    # Skinny kernels require: env opt-in, gfx9 hardware, half-precision
    # activations, and K divisible by 8.
    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
        x.dtype in [torch.float16, torch.bfloat16] \
        and k % 8 == 0)

    if use_skinny is not True:
        return torch.nn.functional.linear(x, weight, bias)

    # Flatten leading dims so n is the total number of token rows.
    x_view = x.view(-1, x.size(-1))
    n = x_view.shape[0]
    m = weight.shape[0]
    cu_count = current_platform.get_cu_count()

    if m > 8 and 0 < n <= 4:
        # Very few rows: wvSplitK splits the reduction across CUs.
        out = ops.wvSplitK(weight, x_view, cu_count, bias)
        return out.view(*x.shape[:-1], weight.shape[0])
    elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        # Single-row, no-bias case handled by the LLMM1 kernel.
        out = ops.LLMM1(weight, x_view, 4)
        return out.view(*x.shape[:-1], weight.shape[0])
    return torch.nn.functional.linear(x, weight, bias)
121
+
122
+
123
def rocm_unquantized_gemm_impl_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Shape-only fake for the custom op: allocates the output without
    computing it, so torch.compile can trace metadata."""
    out_shape = (*x.shape[:-1], weight.shape[0])
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)
128
+
129
+
130
def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Dispatch-table entry for ROCm: forwards to the registered custom op
    so torch.compile sees the fake (shape-only) implementation during
    tracing. `layer` is unused; it only matches the dispatch signature."""
    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
135
+
136
+
137
# Register the ROCm GEMM as a torch custom op, pairing the real kernel
# with its shape-only fake so torch.compile can trace through it.
direct_register_custom_op(
    op_name="rocm_unquantized_gemm_impl",
    op_func=rocm_unquantized_gemm_impl,
    fake_impl=rocm_unquantized_gemm_impl_fake,
)
142
+
143
+
144
+ def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
145
+ return (torch._C._cpu._is_amx_tile_supported()
146
+ and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0
147
+ and n % 16 == 0)
148
+
149
+
150
def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
    """Pick a CPU linear kernel for `layer` and bind it as `layer.cpu_linear`.

    Preference order: AMX SGL kernel (if enabled and shape/dtype allow),
    then oneDNN on x86, then plain ``F.linear``. When `remove_weight` is
    True the original weight is replaced by an empty Parameter once the
    chosen backend has captured its own (packed/handled) copy.
    """
    N, K = layer.weight.size()
    dtype = layer.weight.dtype
    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
        # Pack the weight once up front; the lambda closes over the packed
        # tensor and the fp32 bias, not over layer.weight.
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        if getattr(layer, "bias", None) is not None:
            bias_f32 = layer.bias.to(torch.float32)
        else:
            bias_f32 = None
        layer.cpu_linear = (
            lambda x, weight, bias: torch.ops._C.weight_packed_linear(
                x, packed_weight, bias_f32
                if bias is not None else None, True))
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
    elif (ops._supports_onednn
          and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
        # Keep a reference before (optionally) dropping layer.weight, since
        # the oneDNN handle is built from the transposed original weight.
        origin_weight = layer.weight
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
        handler = ops.create_onednn_mm(origin_weight.t(), 32)
        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(
            handler, x, bias)
    else:
        # Portable fallback: ignore the captured backend state entirely.
        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
            x, weight, bias)
181
+
182
+
183
def cpu_unquantized_gemm(layer: torch.nn.Module,
                         x: torch.Tensor,
                         weight: torch.Tensor,
                         bias: Optional[torch.Tensor] = None):
    """Forward to whichever CPU kernel dispatch_cpu_unquantized_gemm bound
    to this layer (`layer.cpu_linear`)."""
    kernel = layer.cpu_linear
    return kernel(x, weight, bias)
188
+
189
+
190
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    """Return the unquantized GEMM entry point for the current platform
    (ROCm custom op, CPU-bound kernel, or the plain F.linear fallback)."""
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    elif current_platform.is_cpu():
        return cpu_unquantized_gemm
    else:
        return default_unquantized_gemm
vllm_hacked/model_executor/layers/utils_ori.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """Utility methods for model layers."""
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+
8
+ from vllm import _custom_ops as ops
9
+ from vllm import envs
10
+ from vllm.platforms import CpuArchEnum, current_platform
11
+ from vllm.utils import direct_register_custom_op
12
+
13
+
14
+ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
15
+ # Shuffle weight along the last dimension so that
16
+ # we folded the weights to adjance location
17
+ # Example:
18
+ # input:
19
+ # [[1, 2, 3, 4, 5, 6],
20
+ # [7, 8, 9, 10, 11, 12]]
21
+ # output:
22
+ # [[1, 4, 2, 5, 3, 6],
23
+ # [7, 10, 8, 11, 9, 12]]
24
+ # This will be used together with triton swiglu kernel
25
+ shape = w.shape
26
+ N = shape[-1]
27
+ first = w[..., :N // 2]
28
+ second = w[..., N // 2:]
29
+
30
+ stacked = torch.stack((first, second), dim=-1)
31
+ w_shuffled = stacked.reshape(shape)
32
+ return w_shuffled
33
+
34
+
35
+ def get_token_bin_counts_and_mask(
36
+ tokens: torch.Tensor,
37
+ vocab_size: int,
38
+ num_seqs: int,
39
+ ) -> tuple[torch.Tensor, torch.Tensor]:
40
+ # Compute the bin counts for the tokens.
41
+ # vocab_size + 1 for padding.
42
+ bin_counts = torch.zeros((num_seqs, vocab_size + 1),
43
+ dtype=torch.long,
44
+ device=tokens.device)
45
+ bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
46
+ bin_counts = bin_counts[:, :vocab_size]
47
+ mask = bin_counts > 0
48
+
49
+ return bin_counts, mask
50
+
51
+
52
+ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
53
+ output_tokens_tensor: torch.Tensor,
54
+ presence_penalties: torch.Tensor,
55
+ frequency_penalties: torch.Tensor,
56
+ repetition_penalties: torch.Tensor) -> torch.Tensor:
57
+ """
58
+ Applies penalties in place to the logits tensor
59
+ logits : The input logits tensor of shape [num_seqs, vocab_size]
60
+ prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
61
+ are padded to the maximum prompt length within the batch using
62
+ `vocab_size` as the padding value. The value `vocab_size` is used
63
+ for padding because it does not correspond to any valid token ID
64
+ in the vocabulary.
65
+ output_tokens_tensor: The output tokens tensor.
66
+ presence_penalties: The presence penalties of shape (num_seqs, )
67
+ frequency_penalties: The frequency penalties of shape (num_seqs, )
68
+ repetition_penalties: The repetition penalties of shape (num_seqs, )
69
+ """
70
+ num_seqs, vocab_size = logits.shape
71
+ _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
72
+ vocab_size, num_seqs)
73
+ output_bin_counts, output_mask = get_token_bin_counts_and_mask(
74
+ output_tokens_tensor, vocab_size, num_seqs)
75
+
76
+ # Apply repetition penalties as a custom op
77
+ from vllm._custom_ops import apply_repetition_penalties
78
+ apply_repetition_penalties(logits, prompt_mask, output_mask,
79
+ repetition_penalties)
80
+
81
+ # We follow the definition in OpenAI API.
82
+ # Refer to https://platform.openai.com/docs/api-reference/parameter-details
83
+ logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
84
+ logits -= presence_penalties.unsqueeze(dim=1) * output_mask
85
+ return logits
86
+
87
+
88
+ def default_unquantized_gemm(layer: torch.nn.Module,
89
+ x: torch.Tensor,
90
+ weight: torch.Tensor,
91
+ bias: Optional[torch.Tensor] = None):
92
+ return torch.nn.functional.linear(x, weight, bias)
93
+
94
+
95
+ def rocm_unquantized_gemm_impl(
96
+ x: torch.Tensor,
97
+ weight: torch.Tensor,
98
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
99
+ from vllm.platforms.rocm import on_gfx9
100
+ k = weight.shape[1]
101
+ use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
102
+ x.dtype in [torch.float16, torch.bfloat16] \
103
+ and k % 8 == 0)
104
+
105
+ if use_skinny is not True:
106
+ return torch.nn.functional.linear(x, weight, bias)
107
+
108
+ x_view = x.view(-1, x.size(-1))
109
+ n = x_view.shape[0]
110
+ m = weight.shape[0]
111
+ cu_count = current_platform.get_cu_count()
112
+
113
+ if m > 8 and 0 < n <= 4:
114
+ out = ops.wvSplitK(weight, x_view, cu_count, bias)
115
+ return out.view(*x.shape[:-1], weight.shape[0])
116
+ elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
117
+ out = ops.LLMM1(weight, x_view, 4)
118
+ return out.view(*x.shape[:-1], weight.shape[0])
119
+ return torch.nn.functional.linear(x, weight, bias)
120
+
121
+
122
+ def rocm_unquantized_gemm_impl_fake(
123
+ x: torch.Tensor,
124
+ weight: torch.Tensor,
125
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
126
+ return x.new_empty((*x.shape[:-1], weight.shape[0]))
127
+
128
+
129
+ def rocm_unquantized_gemm(layer: torch.nn.Module,
130
+ x: torch.Tensor,
131
+ weight: torch.Tensor,
132
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
133
+ return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
134
+
135
+
136
+ direct_register_custom_op(
137
+ op_name="rocm_unquantized_gemm_impl",
138
+ op_func=rocm_unquantized_gemm_impl,
139
+ fake_impl=rocm_unquantized_gemm_impl_fake,
140
+ )
141
+
142
+
143
+ def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
144
+ return (torch._C._cpu._is_amx_tile_supported()
145
+ and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0
146
+ and n % 16 == 0)
147
+
148
+
149
+ def dispatch_cpu_unquantized_gemm(
150
+ layer: torch.nn.Module,
151
+ remove_weight: bool,
152
+ ) -> None:
153
+ N, K = layer.weight.size()
154
+ dtype = layer.weight.dtype
155
+ if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
156
+ packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
157
+ if getattr(layer, "bias", None) is not None:
158
+ bias_f32 = layer.bias.to(torch.float32)
159
+ else:
160
+ bias_f32 = None
161
+ layer.cpu_linear = (
162
+ lambda x, weight, bias: torch.ops._C.weight_packed_linear(
163
+ x, packed_weight, bias_f32
164
+ if bias is not None else None, True))
165
+ if remove_weight:
166
+ layer.weight = torch.nn.Parameter(torch.empty(0),
167
+ requires_grad=False)
168
+ elif (ops._supports_onednn
169
+ and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
170
+ origin_weight = layer.weight
171
+ if remove_weight:
172
+ layer.weight = torch.nn.Parameter(torch.empty(0),
173
+ requires_grad=False)
174
+ handler = ops.create_onednn_mm(origin_weight.t(), 32)
175
+ layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(
176
+ handler, x, bias)
177
+ else:
178
+ layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
179
+ x, weight, bias)
180
+
181
+
182
+ def cpu_unquantized_gemm(layer: torch.nn.Module,
183
+ x: torch.Tensor,
184
+ weight: torch.Tensor,
185
+ bias: Optional[torch.Tensor] = None):
186
+ return layer.cpu_linear(x, weight, bias)
187
+
188
+
189
+ def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
190
+ if current_platform.is_rocm():
191
+ return rocm_unquantized_gemm
192
+ elif current_platform.is_cpu():
193
+ return cpu_unquantized_gemm
194
+ else:
195
+ return default_unquantized_gemm
vllm_hacked/model_executor/models/llama.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ # Adapted from
5
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
6
+ # Copyright 2023 The vLLM team.
7
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10
+ # and OPT implementations in this library. It has been modified from its
11
+ # original forms to accommodate minor architectural differences compared
12
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13
+ #
14
+ # Licensed under the Apache License, Version 2.0 (the "License");
15
+ # you may not use this file except in compliance with the License.
16
+ # You may obtain a copy of the License at
17
+ #
18
+ # http://www.apache.org/licenses/LICENSE-2.0
19
+ #
20
+ # Unless required by applicable law or agreed to in writing, software
21
+ # distributed under the License is distributed on an "AS IS" BASIS,
22
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23
+ # See the License for the specific language governing permissions and
24
+ # limitations under the License.
25
+ """Inference-only LLaMA model compatible with HuggingFace weights."""
26
+ from collections.abc import Iterable
27
+ from itertools import islice
28
+ from typing import Any, Optional, Union
29
+
30
+ import torch
31
+ from torch import nn
32
+ from transformers import LlamaConfig
33
+
34
+ from vllm.attention import Attention, AttentionType
35
+ from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
36
+ from vllm.compilation.decorators import support_torch_compile
37
+ from vllm.config import CacheConfig, VllmConfig
38
+ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
39
+ from vllm.model_executor.layers.activation import SiluAndMul
40
+ from vllm.model_executor.layers.layernorm import RMSNorm
41
+ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
42
+ QKVParallelLinear,
43
+ RowParallelLinear)
44
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
+ from vllm.model_executor.layers.quantization import QuantizationConfig
46
+ from vllm.model_executor.layers.rotary_embedding import get_rope
47
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
48
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
49
+ from vllm.model_executor.model_loader.weight_utils import (
50
+ default_weight_loader, maybe_remap_kv_scale_name)
51
+ from vllm.sequence import IntermediateTensors
52
+
53
+ from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
54
+ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
55
+ is_pp_missing_parameter,
56
+ make_empty_intermediate_tensors_factory, make_layers,
57
+ maybe_prefix)
58
+
59
+
60
class LlamaMLP(nn.Module):
    """Llama SwiGLU feed-forward block.

    A fused gate/up column-parallel projection feeds SiluAndMul
    (``silu(gate) * up``); the result is projected back to the hidden
    size by a row-parallel linear.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel GEMM
        # producing 2 * intermediate_size outputs.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        # Only SiLU (SwiGLU) is implemented.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        down, _ = self.down_proj(self.act_fn(gate_up))
        return down
101
+
102
+
103
class LlamaAttention(nn.Module):
    """Multi-head (optionally grouped-query) attention for Llama.

    Splits Q/K/V heads across the tensor-parallel group, applies rotary
    position embeddings, and projects the attention output back to the
    hidden size. Optionally uses per-layer sliding-window attention when
    the config declares ``layer_types``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        # The layer index is recovered from the module prefix (".layers.N.").
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Fused Q/K/V projection, column-parallel across the TP group.
        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config,
                              rope_scaling=rope_scaling,
                              quant_config=quant_config)

        sliding_window = None
        if layer_types := getattr(config, "layer_types", None):
            # Fix for Eagle3 compatibility:
            # for draft models, subtract target layer count
            # to get draft-relative layer index starting from 0
            if hasattr(config, 'target_layer_count'):
                # This is a draft model,
                # adjust layer_idx to be relative to draft layers
                effective_layer_idx = layer_idx - config.target_layer_count
            else:
                # This is a target model, use layer_idx directly
                effective_layer_idx = layer_idx
            assert effective_layer_idx < len(layer_types), \
                f"effective_layer_idx: {effective_layer_idx} \
                is out of bounds for layer_types: {layer_types}"

            is_sliding = layer_types[
                effective_layer_idx] == "sliding_attention"
            if is_sliding:
                sliding_window = config.sliding_window

        # Encoder-only (bidirectional) attention uses a dedicated class;
        # everything else goes through the standard Attention layer.
        attn_cls = (EncoderOnlyAttention
                    if attn_type == AttentionType.ENCODER_ONLY else Attention)

        self.attn = attn_cls(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE, attend, and project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(self, config: LlamaConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        """Build the rotary embedding.

        GGUF-quantized llama checkpoints store weights in GPT-J (non-NeoX)
        rotary layout, so the style flag is flipped for them.
        """
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )
239
+
240
+
241
class LlamaDecoderLayer(nn.Module):
    """One Llama transformer layer: pre-norm attention + pre-norm SwiGLU MLP,
    with the residual stream threaded through the fused RMSNorm calls."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 config: Optional[LlamaConfig] = None) -> None:
        super().__init__()

        # An explicit config overrides the one attached to vllm_config.
        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias

        # By default, Llama uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP sub-layers.

        `residual` is None only for the first layer of a rank; afterwards
        the fused RMSNorm consumes and re-emits the residual stream.
        Returns (hidden_states, residual) for the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
330
+
331
+
332
@support_torch_compile
class LlamaModel(nn.Module):
    """Llama transformer stack with a custom multi-codebook input embedding.

    Differs from stock vLLM ``LlamaModel`` in one way: input token ids are
    embedded through ``config.code_depth`` parallel embedding tables whose
    outputs are summed, instead of the single ``embed_tokens`` lookup.
    ``embed_tokens`` is still constructed so checkpoint loading and
    tied-weight ``lm_head`` handling keep working.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.aux_hidden_state_layers = tuple[int, ...]()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Custom embedding layers: one table per codebook level; the model
        # input embedding is the sum over all `code_depth` tables.
        # NOTE(review): vocab_size + 1 presumably reserves one extra id
        # (e.g. a special/pad code) — confirm against the code tokenizer.
        self.emb = nn.ModuleList([
            nn.Embedding(config.vocab_size + 1, config.hidden_size)
            for _ in range(self.config.code_depth)
        ])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed `input_ids` as the sum of the per-codebook embeddings."""
        return sum(self.emb[k](input_ids)
                   for k in range(self.config.code_depth))

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        """Run the decoder stack.

        On the first PP rank, either `inputs_embeds` is consumed directly or
        `input_ids` is embedded via the custom multi-codebook tables; later
        ranks resume from `intermediate_tensors`. Returns the final hidden
        states (optionally with aux hidden states for Eagle3), or
        IntermediateTensors on non-last ranks.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                # Custom multi-codebook embedding (not self.embed_tokens).
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
                islice(self.layers, self.start_layer, self.end_layer)):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing q/k/v and gate/up shards into the
        merged parallel layers. Returns the set of parameter names loaded."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                    (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
504
+
505
+
506
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    """Llama causal-LM wrapper: the decoder stack plus a (possibly tied)
    parallel LM head and logits processor on the last pipeline rank."""

    # Fused modules and the checkpoint shards that feed them.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"),
                                      layer_type=layer_type)
        # The LM head and logits processor only exist on the last PP rank.
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)


    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select which decoder layers also emit aux hidden states (Eagle3)."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default Eagle3 aux layer choice: early, middle, and late layers."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Factory hook so subclasses can substitute their own model class."""
        return LlamaModel(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the underlying LlamaModel."""
        model_output = self.model(input_ids, positions, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project final hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights, remapping mistral-format names first; the LM head is
        skipped when it is tied to the input embeddings."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Translate a mistral-format parameter name (and permute rotary
        weights) into the HF/vLLM naming scheme. No-op for HF names."""

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # Re-interleave rotary halves from mistral layout to HF layout.
            attn_in = self.config.head_dim * n_heads

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads,
                                    self.config.hidden_size)
        elif "wk" in modules and modules[
                -1] == "qscale_weight" and loaded_weight.numel() > 1:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads,
                                    self.config.hidden_size)
        elif "wq" in modules and modules[
                -1] == "qscale_weight" and loaded_weight.numel() > 1:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads, 1)

        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            # Two-part keys (e.g. "kv_fake_quantizer.qscale_act") take
            # precedence over single-part ones.
            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
vllm_hacked/model_executor/sampling_metadata.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from array import array
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import torch
8
+
9
+ from vllm.sampling_params import SamplingParams, SamplingType
10
+ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
11
+ SequenceGroupMetadata)
12
+ from vllm.utils import (PyObjectCache, async_tensor_h2d,
13
+ is_pin_memory_available, make_tensor_with_pad)
14
+
15
+ _SAMPLING_EPS = 1e-5
16
+
17
+
18
@dataclass
class SequenceGroupToSample:
    """Per-sequence-group sampling state for one scheduler step.

    The diagram below shows how the length fields relate for a sequence
    that gains new tokens in iteration N.
    """
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    # |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if it
    # is in a decode stage. The length of query_len <= seq_len if chunked
    # prefill is enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits. to compute prompt logprob. Empty if
    # prompt logprob is not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        """True when this group has at least one token index to sample."""
        return len(self.sample_indices) > 0

    def __post_init__(self):
        # Sanity checks: prompt-logprob indices require the corresponding
        # sampling param, and prefill groups must carry both lengths.
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None
61
+
62
+
63
def gen_seq_group_to_sample_builder(num_seqs: int):
    """Return a factory that builds blank `SequenceGroupToSample` objects
    sized for `num_seqs` sequences (used by the object cache)."""

    def _build() -> "SequenceGroupToSample":
        return SequenceGroupToSample(
            seq_ids=[0] * num_seqs,
            sampling_params=None,
            seq_data=None,  # type: ignore
            seq_len=0,
            query_len=0,
            generator=None,
            is_prompt=True,
            prompt_logprob_indices=[],
            sample_indices=[],
        )

    return _build
75
+
76
+
77
class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations"""

    def __init__(self):
        # One PyObjectCache per distinct sequence count.
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        """Fetch (creating on first use) a pooled SequenceGroupToSample
        object sized for `num_seqs` sequences."""
        cache = self._seq_group_to_sample_cache.get(num_seqs)
        if cache is None:
            cache = PyObjectCache(gen_seq_group_to_sample_builder(num_seqs))
            self._seq_group_to_sample_cache[num_seqs] = cache
        return cache.get_object()

    def reset(self):
        """Return all pooled objects to their caches for reuse."""
        for cache in self._seq_group_to_sample_cache.values():
            cache.reset()
94
+
95
+
96
+ class SamplingMetadata:
97
+ """Metadata for input sequences. Used in sampler.
98
+
99
+ The usage is as follow;
100
+ ```
101
+ hidden_states = execute_model(...)
102
+ logits = hidden_states[sampling_metadata.selected_token_indices]
103
+ sample(logits)
104
+
105
+ def sample(logits):
106
+ # Use categorized_sample_indices for sampling....
107
+ ```
108
+
109
+ Args:
110
+ seq_groups: List of batched sequence groups.
111
+ selected_token_indices: (num_query_tokens_to_logprob). Indices to find
112
+ logits from the initial model output hidden states.
113
+ categorized_sample_indices: SamplingType -> token indices to sample.
114
+ Each token indices is 2D tensor of (num_indices, num_indices) where
115
+ the first item means the sample index within the returned logit
116
+ (before pruning padding), and the second item means the sample
117
+ index after pruning using selected_token_indices.
118
+ For example, if the returned logit is [1, 2, 3], and we select
119
+ [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
120
+ The first tuple is [1, 2] (sampled index within original logit),
121
+ and the second tuple is [0, 1] (sampled index within pruned logit).
122
+ num_prompts: Number of prompt sequence groups in seq_groups.
123
+ skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
124
+ serialization of token outputs.
125
+ reuse_sampling_tensors: Indicates if we want to reuse sampling
126
+ tensors that are part of the sampler forward pass. Currently,
127
+ it is mainly used for multi-step decode.
128
+
129
+ """
130
+
131
+ def __init__(
132
+ self,
133
+ seq_groups: List[SequenceGroupToSample],
134
+ selected_token_indices: torch.Tensor,
135
+ categorized_sample_indices: Dict[SamplingType, torch.Tensor],
136
+ num_prompts: int,
137
+ skip_sampler_cpu_output: bool = False,
138
+ reuse_sampling_tensors: bool = False,
139
+ ) -> None:
140
+ self.seq_groups = seq_groups
141
+ self.selected_token_indices = selected_token_indices
142
+ self.categorized_sample_indices = categorized_sample_indices
143
+ self.num_prompts = num_prompts
144
+ self.skip_sampler_cpu_output = skip_sampler_cpu_output
145
+ self.reuse_sampling_tensors = reuse_sampling_tensors
146
+
147
@staticmethod
def prepare(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: List[int],
    device: str,
    pin_memory: bool,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> "SamplingMetadata":
    """Build a SamplingMetadata for one scheduler step.

    Delegates all index bookkeeping to `_prepare_seq_groups`, then copies
    the resulting host-side index lists to `device` (asynchronously when
    `pin_memory` is True) and wraps everything in a SamplingMetadata.
    """
    (
        seq_groups,
        selected_token_indices,
        categorized_sample_indices,
        num_prompts,
    ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                            device, generators, cache)
    # Host -> device copy of the logits-pruning indices.
    selected_token_indices = async_tensor_h2d(
        selected_token_indices,
        dtype=torch.long,
        target_device=device,
        pin_memory=pin_memory,
    )
    # One device tensor of sample indices per SamplingType bucket.
    categorized_sample_indices = {
        t:
        async_tensor_h2d(
            seq_ids,
            dtype=torch.int,
            target_device=device,
            pin_memory=pin_memory,
        )
        for t, seq_ids in categorized_sample_indices.items()
    }

    sampling_metadata = SamplingMetadata(
        seq_groups=seq_groups,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        num_prompts=num_prompts,
    )
    return sampling_metadata
188
+
189
def __repr__(self) -> str:
    """Debug representation listing the three core index fields.

    The trailing '), ' is preserved from the original implementation.
    """
    groups = self.seq_groups
    selected = self.selected_token_indices
    categorized = self.categorized_sample_indices
    return (f"SamplingMetadata(seq_groups={groups}, "
            f"selected_token_indices={selected}, "
            f"categorized_sample_indices={categorized}), ")
195
+
196
+
197
def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: List[int],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[
    List[SequenceGroupToSample],
    List[int],
    Dict[SamplingType, List[int]],
    int,
]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence group to batch.
        seq_lens: A list of sequence lens per sequence group.
            Index of prompt len should match with seq_group_metadata_list.
        query_lens: A list of query lengths. Prompt lens include the length
            of entire prompt tokens, and it could be shorter.
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.
        cache: Optional object pool; when given, SequenceGroupToSample
            instances are reused across scheduler iterations.

    Returns:
        seq_groups: A list of sequence group to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the outcome logits from the model for the performance.
    selected_token_indices: List[int] = []
    # Running cursor into the (unpruned) model output rows; feeds
    # selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    # indices to sample/prompt logprob within pruned output logits,
    # indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            # Reuse a pooled object; overwrite its seq ids in place and
            # clear the index lists from the previous iteration.
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        # With a cache these alias the pooled object's lists, so the
        # extend() calls below mutate the cached object directly.
        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                             if cache is not None else [])
        sample_indices: List[int] = (sample_obj.sample_indices
                                     if cache is not None else [])
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                # Seeded request: create (and register) a fresh generator.
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            # NOTE(review): decode assumes query_len tokens per sequence
            # (multi-token decode), unlike upstream's one token — confirm
            # this matches the hacked model runner.
            query_len = query_lens[i] if query_lens is not None and len(
                query_lens) > 0 else 1
            sample_len = len(seq_ids) * query_len if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This blocks computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(range(logit_idx, logit_idx + sample_len)))
            logit_idx += sample_len

        if cache is not None:
            # Finish populating the pooled object.
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices),
            )

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
369
+
370
+
371
@dataclass
class SamplingTensors:
    """Tensors for sampling.

    All fields are device tensors with one entry per selected logit row
    (temperatures/top_ps/... ) or one row per penalized sequence
    (prompt_tokens/output_tokens).
    """

    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple["SamplingTensors", bool, bool, bool]:
        """Flatten per-seq-group sampling params into per-logit-row lists
        and materialize them as device tensors.

        Returns `(tensors, do_penalties, do_top_p_top_k, do_min_p)` where
        the three flags tell the sampler which kernels can be skipped for
        the whole batch.
        """
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            min_p = sampling_params.min_p

            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True

            is_prompt = seq_group.is_prompt
            if is_prompt and sampling_params.prompt_logprobs is not None:
                # For tokens in the prompt that we only need to get
                # their logprobs: use neutral penalties (0, 0, 1) so they
                # do not distort the prompt-logprob rows.
                query_len = seq_group.query_len
                assert query_len is not None
                prefill_len = len(seq_group.prompt_logprob_indices)
                temperatures += [temperature] * prefill_len
                top_ps += [top_p] * prefill_len
                top_ks += [top_k] * prefill_len
                min_ps += [min_p] * prefill_len
                presence_penalties += [0] * prefill_len
                frequency_penalties += [0] * prefill_len
                repetition_penalties += [1] * prefill_len

            if seq_group.do_sample:
                sample_lens = len(seq_group.sample_indices)
                assert sample_lens >= len(seq_ids)
                temperatures += [temperature] * sample_lens
                top_ps += [top_p] * sample_lens
                top_ks += [top_k] * sample_lens
                min_ps += [min_p] * sample_lens
                presence_penalties += [p] * sample_lens
                frequency_penalties += [f] * sample_lens
                repetition_penalties += [r] * sample_lens

        if do_penalties:
            # Second pass: gather token histories (needed only when some
            # request in the batch uses a penalty).
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                sampling_params = seq_group.sampling_params
                if (seq_group.is_prompt
                        and sampling_params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                    output_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures,
            top_ps,
            top_ks,
            min_ps,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            prompt_tokens,
            output_tokens,
            vocab_size,
            device,
            dtype,
        )
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)

    @classmethod
    def from_lists(
        cls,
        temperatures: List[float],
        top_ps: List[float],
        top_ks: List[int],
        min_ps: List[float],
        presence_penalties: List[float],
        frequency_penalties: List[float],
        repetition_penalties: List[float],
        prompt_tokens: List[array],
        output_tokens: List[array],
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "SamplingTensors":
        """Turn the flattened lists into device tensors.

        Tensors are first created on pinned CPU memory (when available)
        and then copied with non_blocking=True.
        """
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()

        do_penalties = prompt_tokens or output_tokens

        if do_penalties:
            # Token histories are ragged; pad with vocab_size (an id that
            # cannot collide with a real token).
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
        )
vllm_hacked/model_executor/sampling_metadata_ori.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from array import array
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import torch
8
+
9
+ from vllm.sampling_params import SamplingParams, SamplingType
10
+ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
11
+ SequenceGroupMetadata)
12
+ from vllm.utils import (PyObjectCache, async_tensor_h2d,
13
+ is_pin_memory_available, make_tensor_with_pad)
14
+
15
+ _SAMPLING_EPS = 1e-5
16
+
17
+
18
@dataclass
class SequenceGroupToSample:
    """Per-sequence-group sampling state for one model forward step."""

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if it
    # is in a decode stage. The length of query_len <= seq_len if chunked
    # prefill is enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits. to compute prompt logprob. Empty if
    # prompt logprob is not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        # Sampling is requested iff at least one sample index was assigned.
        return len(self.sample_indices) > 0

    def __post_init__(self):
        # Invariants: prompt-logprob indices imply the request asked for
        # prompt logprobs; prefill groups must carry both lengths.
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None
+
62
+
63
def gen_seq_group_to_sample_builder(num_seqs: int):
    """Return a zero-argument factory that creates a blank
    SequenceGroupToSample sized for `num_seqs` sequences.

    Used by SamplingMetadataCache/PyObjectCache to pre-allocate pooled
    objects; every field is a placeholder to be overwritten later.
    """

    def _build():
        return SequenceGroupToSample(
            seq_ids=[0] * num_seqs,
            sampling_params=None,
            seq_data=None,  # type: ignore
            seq_len=0,
            query_len=0,
            generator=None,
            is_prompt=True,
            prompt_logprob_indices=[],
            sample_indices=[],
        )

    return _build
+ )
75
+
76
+
77
class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations"""

    def __init__(self):
        # One object pool per group size (number of sequences in the group).
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        """Return a pooled SequenceGroupToSample for a group of `num_seqs`
        sequences, creating the pool on first use."""
        if num_seqs not in self._seq_group_to_sample_cache:
            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))

        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
        return obj

    def reset(self):
        """Mark all pooled objects as reusable for the next iteration."""
        for cache in self._seq_group_to_sample_cache.values():
            cache.reset()
+ cache.reset()
94
+
95
+
96
+ class SamplingMetadata:
97
+ """Metadata for input sequences. Used in sampler.
98
+
99
+ The usage is as follow;
100
+ ```
101
+ hidden_states = execute_model(...)
102
+ logits = hidden_states[sampling_metadata.selected_token_indices]
103
+ sample(logits)
104
+
105
+ def sample(logits):
106
+ # Use categorized_sample_indices for sampling....
107
+ ```
108
+
109
+ Args:
110
+ seq_groups: List of batched sequence groups.
111
+ selected_token_indices: (num_query_tokens_to_logprob). Indices to find
112
+ logits from the initial model output hidden states.
113
+ categorized_sample_indices: SamplingType -> token indices to sample.
114
+ Each token indices is 2D tensor of (num_indices, num_indices) where
115
+ the first item means the sample index within the returned logit
116
+ (before pruning padding), and the second item means the sample
117
+ index after pruning using selected_token_indices.
118
+ For example, if the returned logit is [1, 2, 3], and we select
119
+ [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
120
+ The first tuple is [1, 2] (sampled index within original logit),
121
+ and the second tuple is [0, 1] (sampled index within pruned logit).
122
+ num_prompts: Number of prompt sequence groups in seq_groups.
123
+ skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
124
+ serialization of token outputs.
125
+ reuse_sampling_tensors: Indicates if we want to reuse sampling
126
+ tensors that are part of the sampler forward pass. Currently,
127
+ it is mainly used for multi-step decode.
128
+
129
+ """
130
+
131
+ def __init__(
132
+ self,
133
+ seq_groups: List[SequenceGroupToSample],
134
+ selected_token_indices: torch.Tensor,
135
+ categorized_sample_indices: Dict[SamplingType, torch.Tensor],
136
+ num_prompts: int,
137
+ skip_sampler_cpu_output: bool = False,
138
+ reuse_sampling_tensors: bool = False,
139
+ ) -> None:
140
+ self.seq_groups = seq_groups
141
+ self.selected_token_indices = selected_token_indices
142
+ self.categorized_sample_indices = categorized_sample_indices
143
+ self.num_prompts = num_prompts
144
+ self.skip_sampler_cpu_output = skip_sampler_cpu_output
145
+ self.reuse_sampling_tensors = reuse_sampling_tensors
146
+
147
+ @staticmethod
148
+ def prepare(
149
+ seq_group_metadata_list: List[SequenceGroupMetadata],
150
+ seq_lens: List[int],
151
+ query_lens: List[int],
152
+ device: str,
153
+ pin_memory: bool,
154
+ generators: Optional[Dict[str, torch.Generator]] = None,
155
+ cache: Optional[SamplingMetadataCache] = None,
156
+ ) -> "SamplingMetadata":
157
+ (
158
+ seq_groups,
159
+ selected_token_indices,
160
+ categorized_sample_indices,
161
+ num_prompts,
162
+ ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
163
+ device, generators, cache)
164
+ selected_token_indices = async_tensor_h2d(
165
+ selected_token_indices,
166
+ dtype=torch.long,
167
+ target_device=device,
168
+ pin_memory=pin_memory,
169
+ )
170
+ categorized_sample_indices = {
171
+ t:
172
+ async_tensor_h2d(
173
+ seq_ids,
174
+ dtype=torch.int,
175
+ target_device=device,
176
+ pin_memory=pin_memory,
177
+ )
178
+ for t, seq_ids in categorized_sample_indices.items()
179
+ }
180
+
181
+ sampling_metadata = SamplingMetadata(
182
+ seq_groups=seq_groups,
183
+ selected_token_indices=selected_token_indices,
184
+ categorized_sample_indices=categorized_sample_indices,
185
+ num_prompts=num_prompts,
186
+ )
187
+ return sampling_metadata
188
+
189
+ def __repr__(self) -> str:
190
+ return (
191
+ "SamplingMetadata("
192
+ f"seq_groups={self.seq_groups}, "
193
+ f"selected_token_indices={self.selected_token_indices}, "
194
+ f"categorized_sample_indices={self.categorized_sample_indices}), ")
195
+
196
+
197
def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: List[int],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[
    List[SequenceGroupToSample],
    List[int],
    Dict[SamplingType, List[int]],
    int,
]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence group to batch.
        seq_lens: A list of sequence lens per sequence group.
            Index of prompt len should match with seq_group_metadata_list.
        query_lens: A list of query lengths. Prompt lens include the length
            of entire prompt tokens, and it could be shorter.
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.
        cache: Optional object pool; when given, SequenceGroupToSample
            instances are reused across scheduler iterations.

    Returns:
        seq_groups: A list of sequence group to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the outcome logits from the model for the performance.
    selected_token_indices: List[int] = []
    # Running cursor into the (unpruned) model output rows; feeds
    # selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    # indices to sample/prompt logprob within pruned output logits,
    # indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            # Reuse a pooled object; overwrite its seq ids in place and
            # clear the index lists from the previous iteration.
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        # With a cache these alias the pooled object's lists, so the
        # extend() calls below mutate the cached object directly.
        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                             if cache is not None else [])
        sample_indices: List[int] = (sample_obj.sample_indices
                                     if cache is not None else [])
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                # Seeded request: create (and register) a fresh generator.
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            # NOTE(review): decode assumes query_len tokens per sequence
            # (multi-token decode), unlike upstream's one token — confirm
            # this matches the hacked model runner.
            query_len = query_lens[i] if query_lens is not None and len(
                query_lens) > 0 else 1
            sample_len = len(seq_ids) * query_len if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This blocks computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(range(logit_idx, logit_idx + sample_len)))
            logit_idx += sample_len

        if cache is not None:
            # Finish populating the pooled object.
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices),
            )

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
369
+
370
+
371
+ @dataclass
372
+ class SamplingTensors:
373
+ """Tensors for sampling."""
374
+
375
+ temperatures: torch.Tensor
376
+ top_ps: torch.Tensor
377
+ top_ks: torch.Tensor
378
+ min_ps: torch.Tensor
379
+ presence_penalties: torch.Tensor
380
+ frequency_penalties: torch.Tensor
381
+ repetition_penalties: torch.Tensor
382
+ prompt_tokens: torch.Tensor
383
+ output_tokens: torch.Tensor
384
+
385
@classmethod
def from_sampling_metadata(
    cls,
    sampling_metadata: "SamplingMetadata",
    vocab_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> Tuple["SamplingTensors", bool, bool, bool]:
    """Flatten per-seq-group sampling params into per-logit-row lists and
    materialize them as device tensors via `from_lists`.

    Returns `(tensors, do_penalties, do_top_p_top_k, do_min_p)` where the
    three flags tell the sampler which kernels can be skipped for the
    whole batch.
    """
    prompt_tokens: List[array] = []
    output_tokens: List[array] = []
    top_ks: List[int] = []
    temperatures: List[float] = []
    top_ps: List[float] = []
    min_ps: List[float] = []
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    repetition_penalties: List[float] = []
    do_penalties = False
    do_top_p_top_k = False
    do_min_p = False

    assert sampling_metadata.seq_groups is not None
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        temperature = sampling_params.temperature
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
        r = sampling_params.repetition_penalty
        top_p = sampling_params.top_p
        min_p = sampling_params.min_p

        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        top_k = vocab_size if top_k == -1 else top_k
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                   or top_k != vocab_size):
            do_top_p_top_k = True
        if not do_min_p and min_p > _SAMPLING_EPS:
            do_min_p = True
        if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                 or abs(f) >= _SAMPLING_EPS
                                 or abs(r - 1.0) >= _SAMPLING_EPS):
            do_penalties = True

        is_prompt = seq_group.is_prompt
        if is_prompt and sampling_params.prompt_logprobs is not None:
            # For tokens in the prompt that we only need to get
            # their logprobs: use neutral penalties (0, 0, 1) so they
            # do not distort the prompt-logprob rows.
            query_len = seq_group.query_len
            assert query_len is not None
            prefill_len = len(seq_group.prompt_logprob_indices)
            temperatures += [temperature] * prefill_len
            top_ps += [top_p] * prefill_len
            top_ks += [top_k] * prefill_len
            min_ps += [min_p] * prefill_len
            presence_penalties += [0] * prefill_len
            frequency_penalties += [0] * prefill_len
            repetition_penalties += [1] * prefill_len

        if seq_group.do_sample:
            sample_lens = len(seq_group.sample_indices)
            assert sample_lens >= len(seq_ids)
            temperatures += [temperature] * sample_lens
            top_ps += [top_p] * sample_lens
            top_ks += [top_k] * sample_lens
            min_ps += [min_p] * sample_lens
            presence_penalties += [p] * sample_lens
            frequency_penalties += [f] * sample_lens
            repetition_penalties += [r] * sample_lens

    if do_penalties:
        # Second pass: gather token histories (needed only when some
        # request in the batch uses a penalty).
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            if (seq_group.is_prompt
                    and sampling_params.prompt_logprobs is not None):
                prefill_len = len(seq_group.prompt_logprob_indices)
                prompt_tokens.extend(
                    array(VLLM_TOKEN_ID_ARRAY_TYPE)
                    for _ in range(prefill_len))
                output_tokens.extend(
                    array(VLLM_TOKEN_ID_ARRAY_TYPE)
                    for _ in range(prefill_len))
            if seq_group.do_sample:
                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]
                    prompt_tokens.append(seq_data.prompt_token_ids_array)
                    output_tokens.append(seq_data.output_token_ids_array)

    sampling_tensors = SamplingTensors.from_lists(
        temperatures,
        top_ps,
        top_ks,
        min_ps,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
        prompt_tokens,
        output_tokens,
        vocab_size,
        device,
        dtype,
    )
    return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
495
+
496
+ @classmethod
497
+ def from_lists(
498
+ cls,
499
+ temperatures: List[float],
500
+ top_ps: List[float],
501
+ top_ks: List[int],
502
+ min_ps: List[float],
503
+ presence_penalties: List[float],
504
+ frequency_penalties: List[float],
505
+ repetition_penalties: List[float],
506
+ prompt_tokens: List[array],
507
+ output_tokens: List[array],
508
+ vocab_size: int,
509
+ device: torch.device,
510
+ dtype: torch.dtype,
511
+ ) -> "SamplingTensors":
512
+ # Note that the performance will be very bad without
513
+ # pinned memory.
514
+ pin_memory = is_pin_memory_available()
515
+
516
+ do_penalties = prompt_tokens or output_tokens
517
+
518
+ if do_penalties:
519
+ prompt_t = make_tensor_with_pad(
520
+ prompt_tokens,
521
+ vocab_size,
522
+ device="cpu",
523
+ dtype=torch.int64,
524
+ pin_memory=pin_memory,
525
+ )
526
+ output_t = make_tensor_with_pad(
527
+ output_tokens,
528
+ vocab_size,
529
+ device="cpu",
530
+ dtype=torch.int64,
531
+ pin_memory=pin_memory,
532
+ )
533
+ else:
534
+ empty_tensor = torch.empty(0, device=device, dtype=torch.long)
535
+ prompt_t = empty_tensor
536
+ output_t = empty_tensor
537
+
538
+ temperatures_t = torch.tensor(
539
+ temperatures,
540
+ device="cpu",
541
+ dtype=dtype,
542
+ pin_memory=pin_memory,
543
+ )
544
+ top_ps_t = torch.tensor(
545
+ top_ps,
546
+ device="cpu",
547
+ dtype=dtype,
548
+ pin_memory=pin_memory,
549
+ )
550
+ min_ps_t = torch.tensor(
551
+ min_ps,
552
+ device="cpu",
553
+ dtype=dtype,
554
+ pin_memory=pin_memory,
555
+ )
556
+ presence_penalties_t = torch.tensor(
557
+ presence_penalties,
558
+ device="cpu",
559
+ dtype=dtype,
560
+ pin_memory=pin_memory,
561
+ )
562
+ frequency_penalties_t = torch.tensor(
563
+ frequency_penalties,
564
+ device="cpu",
565
+ dtype=dtype,
566
+ pin_memory=pin_memory,
567
+ )
568
+ repetition_penalties_t = torch.tensor(
569
+ repetition_penalties,
570
+ device="cpu",
571
+ dtype=dtype,
572
+ pin_memory=pin_memory,
573
+ )
574
+ top_ks_t = torch.tensor(
575
+ top_ks,
576
+ device="cpu",
577
+ dtype=torch.int,
578
+ pin_memory=pin_memory,
579
+ )
580
+ # Because the memory is pinned, we can do non-blocking
581
+ # transfer to device.
582
+
583
+ return cls(
584
+ temperatures=temperatures_t.to(device=device, non_blocking=True),
585
+ top_ps=top_ps_t.to(device=device, non_blocking=True),
586
+ top_ks=top_ks_t.to(device=device, non_blocking=True),
587
+ min_ps=min_ps_t.to(device=device, non_blocking=True),
588
+ presence_penalties=presence_penalties_t.to(device=device,
589
+ non_blocking=True),
590
+ frequency_penalties=frequency_penalties_t.to(device=device,
591
+ non_blocking=True),
592
+ repetition_penalties=repetition_penalties_t.to(device=device,
593
+ non_blocking=True),
594
+ prompt_tokens=prompt_t.to(device=device, non_blocking=True),
595
+ output_tokens=output_t.to(device=device, non_blocking=True),
596
+ )
vllm_hacked/sampling_params.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """Sampling parameters for text generation."""
4
+ import copy
5
+ import warnings
6
+ from dataclasses import field
7
+ from enum import Enum, IntEnum
8
+ from functools import cached_property
9
+ from typing import Annotated, Any, Optional, Union
10
+
11
+ import msgspec
12
+ from pydantic.dataclasses import dataclass
13
+
14
+ from vllm.logger import init_logger
15
+ from vllm.logits_process import LogitsProcessor
16
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
17
+
18
+ logger = init_logger(__name__)
19
+
20
+ _SAMPLING_EPS = 1e-5
21
+ _MAX_TEMP = 1e-2
22
+
23
+
24
class SamplingType(IntEnum):
    """Strategy used to pick the next token for a request.

    Derived from a request's SamplingParams (see
    SamplingParams.sampling_type): near-zero temperature maps to GREEDY,
    otherwise RANDOM_SEED when an explicit seed is supplied, else RANDOM.
    """
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2
28
+
29
+
30
+ # maybe make msgspec?
31
@dataclass
class StructuredOutputsParams:
    """Constraints that force generation to follow a structured format.

    At most one of the constraint fields (`json`, `regex`, `choice`,
    `grammar`, `json_object`) may be set; the remaining fields tune how the
    chosen backend applies that constraint.
    """

    # One of these fields will be used to build a logit processor.
    json: Optional[Union[str, dict]] = None
    regex: Optional[str] = None
    choice: Optional[list[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # These are other options that can be set.
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
    whitespace_pattern: Optional[str] = None
    structural_tag: Optional[str] = None

    _backend: Optional[str] = field(default=None, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""

    def __post_init__(self):
        """Reject configurations that combine several constraint kinds."""
        constraint_fields = (self.json, self.regex, self.choice, self.grammar,
                             self.json_object)
        active = sum(value is not None for value in constraint_fields)
        if active > 1:
            raise ValueError(
                "You can only use one kind of structured outputs constraint "
                f"but multiple are specified: {self.__dict__}")
61
+
62
+
63
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated alias of StructuredOutputsParams, kept for back-compat."""

    def __post_init__(self):
        """Emit the deprecation warning, then run the parent validation."""
        deprecation_msg = (
            "GuidedDecodingParams is deprecated. This will be removed in "
            "v0.12.0 or v1.0.0, which ever is soonest. Please use "
            "StructuredOutputsParams instead.")
        warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2)
        return super().__post_init__()
74
+
75
+
76
class RequestOutputKind(Enum):
    """Controls how much of a request's output each RequestOutput carries."""
    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2
83
+
84
+
85
class SamplingParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
    best_of: Optional[int] = None
    """Number of output sequences that are generated from the prompt. From
    these `best_of` sequences, the top `n` sequences are returned. `best_of`
    must be greater than or equal to `n`. By default, `best_of` is set to `n`.
    Warning, this is only supported in V0."""
    _real_n: Optional[int] = None
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: Optional[int] = None
    """Random seed to use for the generation."""
    stop: Optional[Union[str, list[str]]] = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: Optional[list[int]] = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: Optional[int] = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: Optional[int] = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: Optional[int] = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # Optional[list[LogitsProcessor]] type. We use Any here because
    # Optional[list[LogitsProcessor]] type is not supported by msgspec.
    logits_processors: Optional[Any] = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Optional[Annotated[int,
                                               msgspec.Meta(ge=-1)]] = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: Optional[StructuredOutputsParams] = None
    """Parameters for configuring structured outputs."""
    guided_decoding: Optional[GuidedDecodingParams] = None
    """Deprecated alias for structured_outputs."""
    logit_bias: Optional[dict[int, float]] = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: Optional[list[int]] = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: Optional[dict[str, Any]] = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""
    # NOTE(review): `guidance_scale` is not part of upstream vLLM's
    # SamplingParams; it appears to be the local "hacked" addition (plumbed
    # through from_optional below). Presumably a classifier-free-guidance
    # scale for the patched sampler — confirm against the sampler code.
    guidance_scale: Optional[float] = None

    # Fields used for bad words
    bad_words: Optional[list[str]] = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: Optional[list[list[int]]] = None

    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
        best_of: Optional[int] = None,
        presence_penalty: Optional[float] = 0.0,
        frequency_penalty: Optional[float] = 0.0,
        repetition_penalty: Optional[float] = 1.0,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        stop: Optional[Union[str, list[str]]] = None,
        stop_token_ids: Optional[list[int]] = None,
        bad_words: Optional[list[str]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[list[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int,
                                                   msgspec.Meta(
                                                       ge=-1)]] = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: Optional[StructuredOutputsParams] = None,
        guided_decoding: Optional[GuidedDecodingParams] = None,
        logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
        allowed_token_ids: Optional[list[int]] = None,
        extra_args: Optional[dict[str, Any]] = None,
        guidance_scale: Optional[float] = None,
    ) -> "SamplingParams":
        """Build a SamplingParams, coercing `None` arguments to defaults.

        Unlike the constructor, every scalar here accepts `None` and falls
        back to the documented default value.
        """
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }
        if guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2)
            structured_outputs = guided_decoding
            guided_decoding = None

        return SamplingParams(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=0.0
            if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0
            if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            guidance_scale=guidance_scale,
        )

    def __post_init__(self) -> None:
        """Normalize fields, then validate them via _verify_args."""
        # how we deal with `best_of``:
        # if `best_of`` is not set, we default to `n`;
        # if `best_of`` is set, we set `n`` to `best_of`,
        # and set `_real_n`` to the original `n`.
        # when we return the result, we will check
        # if we need to return `n` or `_real_n` results
        if self.best_of:
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}.")
            if not self._real_n:
                self._real_n = self.n
                self.n = self.best_of

        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature, _MAX_TEMP, _MAX_TEMP)
            self.temperature = max(self.temperature, _MAX_TEMP)

        # -1 is treated the same as "no seed requested".
        if self.seed == -1:
            self.seed = None

        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        # Tolerate booleans for the logprob counts (True -> 1).
        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2)
            self.structured_outputs = self.guided_decoding
            self.guided_decoding = None

    def _verify_args(self) -> None:
        """Raise ValueError/TypeError for out-of-range or inconsistent fields."""
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of "
                             f"type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of is not None:
            if not isinstance(self.best_of, int):
                raise ValueError(
                    f"best_of must be an integer, got {type(self.best_of)}")
            if self.best_of < 1:
                raise ValueError(
                    f"best_of must be at least 1, got {self.best_of}")
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        # quietly accept -1 as disabled, but prefer 0
        # NOTE(review): the numeric comparison runs before the isinstance
        # check, so a non-int top_k surfaces as a TypeError from `<` rather
        # than the message below.
        if self.top_k < -1:
            raise ValueError(f"top_k must be 0 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError("min_p must be in [0, 1], got "
                             f"{self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if (self.logprobs is not None and self.logprobs != -1
                and self.logprobs < 0):
            raise ValueError(
                f"logprobs must be non-negative or -1, got {self.logprobs}.")
        if (self.prompt_logprobs is not None and self.prompt_logprobs != -1
                and self.prompt_logprobs < 0):
            raise ValueError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}.")
        if (self.truncate_prompt_tokens is not None
                and (self.truncate_prompt_tokens == 0
                     or self.truncate_prompt_tokens < -1)):
            raise ValueError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}")
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(f"stop_token_ids must contain only integers, "
                             f"got {self.stop_token_ids}.")
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")
        if self.best_of != self._real_n and self.output_kind == (
                RequestOutputKind.DELTA):
            raise ValueError("best_of must equal n to use output_kind=DELTA")

    def _verify_greedy_sampling(self) -> None:
        """Extra validation applied only when temperature is ~0 (greedy)."""
        if self.n > 1:
            raise ValueError("n must be 1 when using greedy sampling, "
                             f"got {self.n}.")

    def update_from_generation_config(
            self,
            generation_config: dict[str, Any],
            model_eos_token_id: Optional[int] = None) -> None:
        """Update if there are non-default values from generation_config"""

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Tokenize `bad_words` into `_bad_words_token_ids` and validate ids.

        Raises ValueError if any resulting token id falls outside the
        tokenizer's vocabulary.
        """
        if not self.bad_words:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            for add_prefix_space in [False, True]:
                prefix = " " if add_prefix_space else ""
                prompt = prefix + bad_word.lstrip()
                prompt_token_ids = tokenizer.encode(text=prompt,
                                                    add_special_tokens=False)

                # If no space at the beginning
                # or if prefix space produces a new word token
                if (not add_prefix_space) or (
                        add_prefix_space and prompt_token_ids[0]
                        != self._bad_words_token_ids[-1][0]
                        and len(prompt_token_ids) == len(
                            self._bad_words_token_ids[-1])):
                    self._bad_words_token_ids.append(prompt_token_ids)

        invalid_token_ids = [
            token_id for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {tokenizer.max_token_id+1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}.")

    @cached_property
    def sampling_type(self) -> SamplingType:
        """Classify this request as GREEDY / RANDOM / RANDOM_SEED."""
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    @property
    def all_stop_token_ids(self) -> set[int]:
        """Union of user stop_token_ids and engine-added EOS ids."""
        return self._all_stop_token_ids

    @property
    def bad_words_token_ids(self) -> Optional[list[list[int]]]:
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
        See https://github.com/vllm-project/vllm/issues/3087
        """

        # Seeding deepcopy's memo with the processors makes it reuse the
        # (possibly shared) processor objects instead of copying them.
        logit_processor_refs = None if self.logits_processors is None else {
            id(lp): lp.clone() if hasattr(lp, 'clone') else lp
            for lp in self.logits_processors
        }
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})")
583
+
584
+
585
class BeamSearchParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Beam search parameters for text generation."""
    # Number of candidate beams kept at each step.
    beam_width: int
    # Maximum number of tokens to generate.
    max_tokens: int
    # When True, do not terminate beams on the EOS token.
    ignore_eos: bool = False
    temperature: float = 0.0
    # NOTE(review): presumably a length-normalization factor used when
    # ranking finished beams — confirm against the beam search implementation.
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False
vllm_hacked/sampling_params_ori.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """Sampling parameters for text generation."""
4
+ import copy
5
+ import warnings
6
+ from dataclasses import field
7
+ from enum import Enum, IntEnum
8
+ from functools import cached_property
9
+ from typing import Annotated, Any, Optional, Union
10
+
11
+ import msgspec
12
+ from pydantic.dataclasses import dataclass
13
+
14
+ from vllm.logger import init_logger
15
+ from vllm.logits_process import LogitsProcessor
16
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
17
+
18
+ logger = init_logger(__name__)
19
+
20
+ _SAMPLING_EPS = 1e-5
21
+ _MAX_TEMP = 1e-2
22
+
23
+
24
class SamplingType(IntEnum):
    """How the next token is selected for a request.

    The integer values are part of the serialized protocol; do not reorder.
    """
    GREEDY = 0        # argmax decoding (temperature below epsilon)
    RANDOM = 1        # stochastic sampling with a fresh generator
    RANDOM_SEED = 2   # stochastic sampling with a user-provided seed
28
+
29
+
30
+ # maybe make msgspec?
31
@dataclass
class StructuredOutputsParams:
    """Constraints for structured (guided) output generation.

    Exactly one of `json`, `regex`, `choice`, `grammar`, or `json_object`
    may be set; `__post_init__` enforces this mutual exclusion.
    """
    # One of these fields will be used to build a logit processor.
    json: Optional[Union[str, dict]] = None
    regex: Optional[str] = None
    choice: Optional[list[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # These are other options that can be set.
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
    whitespace_pattern: Optional[str] = None
    structural_tag: Optional[str] = None

    _backend: Optional[str] = field(default=None, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
        count = sum([
            self.json is not None, self.regex is not None, self.choice
            is not None, self.grammar is not None, self.json_object is not None
        ])
        if count > 1:
            raise ValueError(
                "You can only use one kind of structured outputs constraint "
                f"but multiple are specified: {self.__dict__}")
61
+
62
+
63
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated alias for `StructuredOutputsParams`.

    Emits a `DeprecationWarning` on construction, then defers to the
    parent's mutual-exclusion validation.
    """

    def __post_init__(self):
        warnings.warn(
            "GuidedDecodingParams is deprecated. This will be removed in "
            "v0.12.0 or v1.0.0, which ever is soonest. Please use "
            "StructuredOutputsParams instead.",
            DeprecationWarning,
            stacklevel=2)
        return super().__post_init__()
74
+
75
+
76
class RequestOutputKind(Enum):
    """What each streamed `RequestOutput` contains."""
    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2
83
+
84
+
85
class SamplingParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE:
        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
        are generated and streamed cumulatively per request. To see all `n`
        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
        in `SamplingParams`."""
    best_of: Optional[int] = None
    """Number of output sequences that are generated from the prompt. From
    these `best_of` sequences, the top `n` sequences are returned. `best_of`
    must be greater than or equal to `n`. By default, `best_of` is set to `n`.
    Warning, this is only supported in V0."""
    # Original user-requested `n`, stashed when `best_of` overrides `n`.
    _real_n: Optional[int] = None
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: Optional[int] = None
    """Random seed to use for the generation."""
    stop: Optional[Union[str, list[str]]] = None
    """String(s) that stop the generation when they are generated. The returned
    output will not contain the stop strings."""
    stop_token_ids: Optional[list[int]] = None
    """Token IDs that stop the generation when they are generated. The returned
    output will contain the stop tokens unless the stop tokens are special
    tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating
    tokens after the EOS token is generated."""
    max_tokens: Optional[int] = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated"""
    logprobs: Optional[int] = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: Optional[int] = None
    """Number of log probabilities to return per prompt token.
    When set to -1, return all `vocab_size` log probabilities."""
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    # Optional[list[LogitsProcessor]] type. We use Any here because
    # Optional[list[LogitsProcessor]] type is not supported by msgspec.
    logits_processors: Optional[Any] = None
    """Functions that modify logits based on previously generated tokens, and
    optionally prompt tokens as a first argument."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Optional[Annotated[int,
                                               msgspec.Meta(ge=-1)]] = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: Optional[StructuredOutputsParams] = None
    """Parameters for configuring structured outputs."""
    guided_decoding: Optional[GuidedDecodingParams] = None
    """Deprecated alias for structured_outputs."""
    logit_bias: Optional[dict[int, float]] = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: Optional[list[int]] = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: Optional[dict[str, Any]] = None
    """Arbitrary additional args, that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: Optional[list[str]] = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
    _bad_words_token_ids: Optional[list[list[int]]] = None

    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
        best_of: Optional[int] = None,
        presence_penalty: Optional[float] = 0.0,
        frequency_penalty: Optional[float] = 0.0,
        repetition_penalty: Optional[float] = 1.0,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: Optional[int] = None,
        stop: Optional[Union[str, list[str]]] = None,
        stop_token_ids: Optional[list[int]] = None,
        bad_words: Optional[list[str]] = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
        min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        logits_processors: Optional[list[LogitsProcessor]] = None,
        truncate_prompt_tokens: Optional[Annotated[int,
                                                   msgspec.Meta(
                                                       ge=-1)]] = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: Optional[StructuredOutputsParams] = None,
        guided_decoding: Optional[GuidedDecodingParams] = None,
        logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
        allowed_token_ids: Optional[list[int]] = None,
        extra_args: Optional[dict[str, Any]] = None,
    ) -> "SamplingParams":
        """Build a `SamplingParams`, coercing `None` inputs to defaults."""
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }
        if guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2)
            structured_outputs = guided_decoding
            guided_decoding = None

        return SamplingParams(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=0.0
            if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0
            if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            logits_processors=logits_processors,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
        )

    def __post_init__(self) -> None:
        # how we deal with `best_of``:
        # if `best_of`` is not set, we default to `n`;
        # if `best_of`` is set, we set `n`` to `best_of`,
        # and set `_real_n`` to the original `n`.
        # when we return the result, we will check
        # if we need to return `n` or `_real_n` results
        if self.best_of:
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}.")
            if not self._real_n:
                self._real_n = self.n
                self.n = self.best_of

        if 0 < self.temperature < _MAX_TEMP:
            # Tiny non-zero temperatures blow up 1/T scaling; clamp them.
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors nan or inf in tensors. We have maxed it out to %s.",
                self.temperature, _MAX_TEMP, _MAX_TEMP)
            self.temperature = max(self.temperature, _MAX_TEMP)

        if self.seed == -1:
            self.seed = None

        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.guided_decoding is not None:
            warnings.warn(
                "guided_decoding is deprecated. This will be removed in "
                "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                "structured_outputs instead.",
                DeprecationWarning,
                stacklevel=2)
            self.structured_outputs = self.guided_decoding
            self.guided_decoding = None

    def _verify_args(self) -> None:
        """Raise on any out-of-range or ill-typed sampling parameter."""
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of "
                             f"type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.best_of is not None:
            if not isinstance(self.best_of, int):
                raise ValueError(
                    f"best_of must be an integer, got {type(self.best_of)}")
            if self.best_of < 1:
                raise ValueError(
                    f"best_of must be at least 1, got {self.best_of}")
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}.")
        if self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(f"top_k must be 0 (disable), or at least 1, "
                             f"got {self.top_k}.")
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}")
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError("min_p must be in [0, 1], got "
                             f"{self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if (self.logprobs is not None and self.logprobs != -1
                and self.logprobs < 0):
            raise ValueError(
                f"logprobs must be non-negative or -1, got {self.logprobs}.")
        if (self.prompt_logprobs is not None and self.prompt_logprobs != -1
                and self.prompt_logprobs < 0):
            raise ValueError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}.")
        if (self.truncate_prompt_tokens is not None
                and (self.truncate_prompt_tokens == 0
                     or self.truncate_prompt_tokens < -1)):
            raise ValueError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}")
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(f"stop_token_ids must contain only integers, "
                             f"got {self.stop_token_ids}.")
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")
        if self.best_of != self._real_n and self.output_kind == (
                RequestOutputKind.DELTA):
            raise ValueError("best_of must equal n to use output_kind=DELTA")

    def _verify_greedy_sampling(self) -> None:
        """Greedy decoding yields a single deterministic sequence."""
        if self.n > 1:
            raise ValueError("n must be 1 when using greedy sampling, "
                             f"got {self.n}.")

    def update_from_generation_config(
            self,
            generation_config: dict[str, Any],
            model_eos_token_id: Optional[int] = None) -> None:
        """Update if there are non-default values from generation_config"""

        if model_eos_token_id is not None:
            # Add the eos token id into the sampling_params to support
            # min_tokens processing.
            self._all_stop_token_ids.add(model_eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if model_eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Tokenize `bad_words` and validate the ids against the vocab."""
        if not self.bad_words:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            for add_prefix_space in [False, True]:
                prefix = " " if add_prefix_space else ""
                prompt = prefix + bad_word.lstrip()
                prompt_token_ids = tokenizer.encode(text=prompt,
                                                    add_special_tokens=False)

                # If no space at the beginning
                # or if prefix space produces a new word token
                if (not add_prefix_space) or (
                        add_prefix_space and prompt_token_ids[0]
                        != self._bad_words_token_ids[-1][0]
                        and len(prompt_token_ids) == len(
                            self._bad_words_token_ids[-1])):
                    self._bad_words_token_ids.append(prompt_token_ids)

        invalid_token_ids = [
            token_id for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {tokenizer.max_token_id+1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}.")

    @cached_property
    def sampling_type(self) -> SamplingType:
        # cached_property: sampling params are frozen after __post_init__,
        # so the classification never changes for a given instance.
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    @property
    def all_stop_token_ids(self) -> set[int]:
        return self._all_stop_token_ids

    @property
    def bad_words_token_ids(self) -> Optional[list[list[int]]]:
        # For internal use only. Backward compatibility not guaranteed
        return self._bad_words_token_ids

    def clone(self) -> "SamplingParams":
        """Deep copy, but maybe not the LogitsProcessor objects.

        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
        data that is expensive to copy. However, if not copied, the processor
        needs to support parallel decoding for multiple sequences
        See https://github.com/vllm-project/vllm/issues/3087
        """

        logit_processor_refs = None if self.logits_processors is None else {
            id(lp): lp.clone() if hasattr(lp, 'clone') else lp
            for lp in self.logits_processors
        }
        return copy.deepcopy(self, memo=logit_processor_refs)

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})")
580
+
581
+
582
class BeamSearchParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Beam search parameters for text generation.

    Keeps `beam_width` candidate sequences alive and ranks them with
    `length_penalty` rather than sampling token-by-token.
    """
    # Number of beams (candidate sequences) kept at each step.
    beam_width: int
    # Hard cap on generated tokens per sequence.
    max_tokens: int
    ignore_eos: bool = False
    # 0.0 means deterministic expansion of each beam.
    temperature: float = 0.0
    # >1.0 favors longer sequences, <1.0 favors shorter ones.
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False
ckpt/.gitkeep → vllm_hacked/v1/sample/__init__ori.py RENAMED
File without changes
vllm_hacked/v1/sample/metadata.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from vllm.v1.sample.logits_processor import LogitsProcessors
10
+
11
+
12
@dataclass
class SamplingMetadata:
    """Per-batch tensors and flags consumed by the V1 sampler.

    All tensors are batch-aligned: row i corresponds to request i of the
    current scheduling step.
    """

    # Per-request temperature; None when every request is greedy.
    temperature: Optional[torch.Tensor]
    all_greedy: bool
    all_random: bool

    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]

    # Seeded torch generators, keyed by request index.
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors

    # NOTE(review): field added by this fork; presumably a classifier-free
    # guidance scale applied in the hacked sampler — confirm against
    # vllm_hacked/v1/sample/sampler.py before relying on the default.
    guidance_scale: Optional[float] = 1.8
vllm_hacked/v1/sample/metadata_ori.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from vllm.v1.sample.logits_processor import LogitsProcessors
10
+
11
+
12
@dataclass
class SamplingMetadata:
    """Per-batch tensors and flags consumed by the V1 sampler.

    All tensors are batch-aligned: row i corresponds to request i of the
    current scheduling step.
    """

    # Per-request temperature; None when every request is greedy.
    temperature: Optional[torch.Tensor]
    all_greedy: bool
    all_random: bool

    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]

    # Seeded torch generators, keyed by request index.
    generators: dict[int, torch.Generator]

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]
    frequency_penalties: torch.Tensor
    presence_penalties: torch.Tensor
    repetition_penalties: torch.Tensor

    output_token_ids: list[list[int]]

    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
    # vocab size).
    allowed_token_ids_mask: Optional[torch.Tensor]

    # req_index -> bad_words_token_ids
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessors
vllm_hacked/v1/sample/ops/penalties_ori.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import torch
5
+
6
+ from vllm.model_executor.layers.utils import apply_penalties
7
+ from vllm.utils import is_pin_memory_available, make_tensor_with_pad
8
+
9
+
10
def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.

    `output_token_ids` (one ragged list per request) is padded and moved to
    the logits device before the shared `apply_penalties` kernel is invoked.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)
27
+
28
+
29
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.

    Builds the padded tensor on pinned CPU memory (when available) so the
    copy to `device` can be asynchronous.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
vllm_hacked/v1/sample/sampler.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """A layer that samples the next tokens from the model's outputs."""
4
+
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from vllm.config import LogprobsMode
11
+ from vllm.utils import is_pin_memory_available
12
+ from vllm.v1.outputs import LogprobsTensors, SamplerOutput
13
+ from vllm.v1.sample.metadata import SamplingMetadata
14
+ from vllm.v1.sample.ops.bad_words import apply_bad_words
15
+ from vllm.v1.sample.ops.logprobs import batched_count_greater_than
16
+ from vllm.v1.sample.ops.penalties import apply_all_penalties
17
+ from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
18
+
19
+ _SAMPLING_EPS = 1e-5
20
+
21
+
22
class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    0. (Local modification) Apply classifier-free guidance (CFG): blend
       the conditional and unconditional logit rows produced for the same
       request. See `_apply_classifier_free_guidance` for the assumed
       batch layout.
    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
        as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
        as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
    i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
        return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
        the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
        randomly sampled tokens and final logprobs if requested. Else,
        return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
    (if requested). Note that if the sampled token is within the top
    `max_num_logprobs`, the logprob will be eventually merged in
    `LogprobsProcessor` during output processing. Therefore, the
    final output may contain either `max_num_logprobs + 1` or
    `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def _apply_classifier_free_guidance(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Blend conditional/unconditional logit rows with CFG.

        Single-request CFG (original comment: "单条推理CFG"). The last two
        rows of `logits` are assumed to be the conditional and unconditional
        predictions for the same request, in that order; with 3 rows, row 0
        is presumably an independent non-CFG request — TODO confirm against
        the model runner that builds these batches.
        """
        num_rows = logits.shape[0]
        # NOTE(review): 1024 looks like a sentinel for a profiling/warm-up
        # batch that must not be modified — confirm where it originates.
        if num_rows <= 1 or num_rows == 1024:
            return logits
        if num_rows not in (2, 3):
            # Fix over the original hack: previously any other multi-row
            # batch collapsed `logits` into a single 1-D guided row, which
            # breaks the downstream 2-D sampling ops. Leave such batches
            # untouched instead.
            return logits
        logits = logits.to(torch.float32)
        scores = torch.nn.functional.log_softmax(logits, dim=-1)
        scores_cond = scores[-2]
        scores_uncond = scores[-1]
        # `guidance_scale` is carried on the (locally patched)
        # SamplingMetadata; scale the cond/uncond log-prob gap.
        guided = (sampling_metadata.guidance_scale *
                  (scores_cond - scores_uncond) + scores_uncond)
        # Both CFG rows receive the guided scores so that either row's
        # sampled token is consistent. `torch.stack` copies its inputs.
        if num_rows == 3:
            return torch.stack([scores[0], guided, guided])
        return torch.stack([guided, guided])

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).

        # Local modification: classifier-free guidance over the batch rows.
        logits = self._apply_classifier_free_guidance(logits,
                                                      sampling_metadata)

        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)

        # Sample the next token.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide logits in-place by per-request temperatures."""
        # Use in-place division to avoid creating a new tensor.
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Pick the argmax token per request."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature,
                                        sampling_metadata.all_random)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: greedy rows (temperature ~ 0) keep the argmax token,
        # the rest keep the randomly sampled token.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Log-softmax over the vocab dimension in float32."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply presence/frequency/repetition penalties when any is set."""
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Mask out tokens excluded by the per-request allowed-id whitelist."""
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Suppress continuations that would complete a bad-word sequence."""
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits
vllm_hacked/v1/sample/sampler_ori.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """A layer that samples the next tokens from the model's outputs."""
4
+
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from vllm.config import LogprobsMode
11
+ from vllm.utils import is_pin_memory_available
12
+ from vllm.v1.outputs import LogprobsTensors, SamplerOutput
13
+ from vllm.v1.sample.metadata import SamplingMetadata
14
+ from vllm.v1.sample.ops.bad_words import apply_bad_words
15
+ from vllm.v1.sample.ops.logprobs import batched_count_greater_than
16
+ from vllm.v1.sample.ops.penalties import apply_all_penalties
17
+ from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
18
+
19
+ _SAMPLING_EPS = 1e-5
20
+
21
+
22
class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
        as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
        as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
    i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
        return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
        the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
        randomly sampled tokens and final logprobs if requested. Else,
        return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
    (if requested). Note that if the sampled token is within the top
    `max_num_logprobs`, the logprob will be eventually merged in
    `LogprobsProcessor` during output processing. Therefore, the
    final output may contain either `max_num_logprobs + 1` or
    `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

    # NOTE(review): the "_ori" filename suggests this is the unmodified
    # upstream vLLM sampler kept alongside the patched sampler.py for
    # reference — confirm before editing.

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Run the full sampling pipeline and return a `SamplerOutput`."""
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)

        # Sample the next token.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide logits in-place by per-request temperatures."""
        # Use in-place division to avoid creating a new tensor.
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Pick the argmax token per request."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature,
                                        sampling_metadata.all_random)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: greedy rows (temperature ~ 0) keep the argmax token.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Log-softmax over the vocab dimension in float32."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply presence/frequency/repetition penalties when any is set."""
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Mask out tokens excluded by the per-request allowed-id whitelist."""
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Suppress continuations that would complete a bad-word sequence."""
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits
vllm_hacked/v1/spec_decode/utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from vllm.v1.worker.gpu_input_batch import InputBatch
3
+
4
+
5
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
    """Return True when `req_id` is eligible for speculative decoding."""
    # Spec decode doesn't support min_p sampling.
    if req_id in input_batch.min_p_reqs:
        return False
    # Spec decode doesn't support penalties.
    penalty_groups = (
        input_batch.frequency_penalties_reqs,
        input_batch.presence_penalties_reqs,
        input_batch.repetition_penalties_reqs,
    )
    if any(req_id in group for group in penalty_groups):
        return False
    # Spec decode doesn't support logprobs.
    return req_id not in input_batch.num_logprobs
vllm_hacked/v1/spec_decode/utils_ori.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ from vllm.sampling_params import SamplingParams
4
+
5
+ _SAMPLING_EPS = 1e-5
6
+
7
+
8
def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
    """True if request is incompatible with speculative decoding"""
    # Logprobs and min_p sampling are unsupported by spec decode.
    if sampling_params.logprobs is not None:
        return True
    if sampling_params.min_p > _SAMPLING_EPS:
        return True
    # Any non-default penalty also disqualifies the request.
    return (sampling_params.frequency_penalty != 0.0
            or sampling_params.presence_penalty != 0.0
            or sampling_params.repetition_penalty != 1.0)
vllm_hacked/v1/utils_ori.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ import argparse
4
+ import contextlib
5
+ import multiprocessing
6
+ import time
7
+ import weakref
8
+ from collections.abc import Sequence
9
+ from contextlib import AbstractContextManager
10
+ from multiprocessing import connection
11
+ from multiprocessing.process import BaseProcess
12
+ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
13
+ Union, overload)
14
+
15
+ import torch
16
+ from torch.autograd.profiler import record_function
17
+
18
+ import vllm.envs as envs
19
+ from vllm.logger import init_logger
20
+ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
21
+ usage_message)
22
+ from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
23
+ kill_process_tree)
24
+
25
+ if TYPE_CHECKING:
26
+ import numpy as np
27
+
28
+ from vllm.v1.engine.coordinator import DPCoordinator
29
+ from vllm.v1.engine.utils import (CoreEngineActorManager,
30
+ CoreEngineProcManager)
31
+
32
+ logger = init_logger(__name__)
33
+
34
T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    """Read-only view over a list: every mutating operation raises."""

    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise TypeError("Cannot append to a constant list")

    def extend(self, item):
        raise TypeError("Cannot extend a constant list")

    def insert(self, item):
        raise TypeError("Cannot insert into a constant list")

    def pop(self, item):
        raise TypeError("Cannot pop from a constant list")

    def remove(self, item):
        raise TypeError("Cannot remove from a constant list")

    def clear(self):
        raise TypeError("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        # A stop of None means "search to the end", matching list.index.
        end = len(self._x) if stop is None else stop
        return self._x.index(item, start, end)

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        # Delegates to the underlying list, so slices return plain lists.
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: T, /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise TypeError("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise TypeError("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"
103
+
104
+
105
class CpuGpuBuffer:
    """Buffer to easily copy tensors between CPU and GPU."""

    def __init__(
        self,
        *size: Union[int, torch.SymInt],
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        with_numpy: bool = True,
    ) -> None:
        # The CPU tensor is the canonical allocation; the device-side copy
        # mirrors its shape and dtype.
        self.cpu = torch.zeros(*size,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
        self.gpu = self.cpu.to(device)
        self.np: np.ndarray
        # To keep type hints simple (avoiding generics and subclasses), we
        # only conditionally create the numpy array attribute. This can cause
        # AttributeError if `self.np` is accessed when `with_numpy=False`.
        if with_numpy:
            if dtype == torch.bfloat16:
                raise ValueError(
                    "Bfloat16 torch tensors cannot be directly cast to a "
                    "numpy array, so call CpuGpuBuffer with with_numpy=False")
            self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        """Copy (a prefix of) the CPU tensor to the GPU tensor."""
        src = self.cpu if n is None else self.cpu[:n]
        dst = self.gpu if n is None else self.gpu[:n]
        return dst.copy_(src, non_blocking=True)

    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
        """NOTE: Because this method is non-blocking, explicit synchronization
        is needed to ensure the data is copied to CPU."""
        src = self.gpu if n is None else self.gpu[:n]
        dst = self.cpu if n is None else self.cpu[:n]
        return dst.copy_(src, non_blocking=True)
143
+
144
+
145
def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Assign a new ZMQ socket address.

    If local_only is True, participants are colocated and so a unique IPC
    address will be returned.

    Otherwise, the provided host and port will be used to construct a TCP
    address (port == 0 means assign an available port)."""

    if local_only:
        return get_open_zmq_ipc_path()
    return get_tcp_uri(host, port or get_open_port())
158
+
159
+
160
class APIServerProcessManager:
    """Manages a group of API server processes.

    Handles creation, monitoring, and termination of API server worker
    processes. Also monitors extra processes to check if they are healthy.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers.  "spawn" is used (rather than fork) so each
        # server starts from a clean interpreter state.
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        # zip truncates to the shortest input, so at most num_servers
        # processes are started even if more addresses are supplied.
        for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
                                        output_addresses):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_count": num_servers,
                "client_index": i
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            proc = spawn_context.Process(target=target_server_fn,
                                         name=f"ApiServer_{i}",
                                         args=(listen_address, sock, args,
                                               client_config))
            self.processes.append(proc)
            proc.start()

        logger.info("Started %d API server processes", len(self.processes))

        # Shutdown only the API server processes on garbage collection
        # The extra processes are managed by their owners
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

    def close(self) -> None:
        """Terminate the managed API server processes (idempotent; the
        weakref finalizer runs at most once)."""
        self._finalizer()
224
+
225
+
226
def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        engine_manager: Optional[Union["CoreEngineProcManager",
                                       "CoreEngineActorManager"]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    Raises an exception if any process exits with a non-zero status.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.
    """

    # Imported locally to avoid a circular import at module load time.
    from vllm.v1.engine.utils import (CoreEngineActorManager,
                                      CoreEngineProcManager)

    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Check if any process terminates
        while sentinel_to_proc or actor_run_refs:
            # Wait for any process to terminate.  Iterating the dict yields
            # its keys, i.e. the sentinels connection.wait() expects.
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
                                                         timeout=5)

            # Process any terminated processes
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if process exited with error
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}")

            if actor_run_refs:
                import ray
                # ray.wait returns (ready, not_ready); keep polling the
                # refs that have not finished yet.
                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s",
                         str(e))
        raise
    finally:
        # Always tear everything down, whatever the exit path was.
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()
298
+
299
+
300
+ # Note(rob): shutdown function cannot be a bound method,
301
+ # else the gc cannot collect the object.
302
def shutdown(procs: list[BaseProcess]):
    """Terminate `procs`, wait up to 5 seconds, then kill any survivors.

    NOTE: this must stay a free function (not a bound method) so that the
    garbage collector can collect objects that register it as a finalizer.
    """
    # Politely ask every live process to stop.
    for p in procs:
        if p.is_alive():
            p.terminate()

    # Share a single 5-second grace period across all processes.
    deadline = time.monotonic() + 5
    for p in procs:
        budget = deadline - time.monotonic()
        if budget <= 0:
            break
        if p.is_alive():
            p.join(budget)

    # Hard-kill whole process trees for anything that survived.
    for p in procs:
        if p.is_alive() and (pid := p.pid) is not None:
            kill_process_tree(pid)
320
+
321
+
322
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
    Copy the first length elements of a tensor into another tensor in a
    non-blocking manner.

    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.

    Returns the sliced target tensor.
    """
    dst = to_tensor[:length]
    # copy_ returns its receiver, so this is the destination slice.
    return dst.copy_(from_tensor[:length], non_blocking=True)
333
+
334
+
335
def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Report usage statistics if enabled."""

    # No-op when the user has opted out of usage reporting.
    if not is_usage_stats_enabled():
        return

    # Imported lazily to avoid pulling in the model loader at import time.
    from vllm.model_executor.model_loader import get_architecture_class_name

    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    parallel_config = vllm_config.parallel_config

    extra_kvs = {
        # Common configuration
        "dtype": str(model_config.dtype),
        "tensor_parallel_size": parallel_config.tensor_parallel_size,
        "block_size": cache_config.block_size,
        "gpu_memory_utilization": cache_config.gpu_memory_utilization,
        "kv_cache_memory_bytes": cache_config.kv_cache_memory_bytes,
        # Quantization
        "quantization": model_config.quantization,
        "kv_cache_dtype": str(cache_config.cache_dtype),
        # Feature flags
        "enable_lora": bool(vllm_config.lora_config),
        "enable_prefix_caching": cache_config.enable_prefix_caching,
        "enforce_eager": model_config.enforce_eager,
        "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
    }
    usage_message.report_usage(
        get_architecture_class_name(model_config),
        usage_context,
        extra_kvs=extra_kvs,
    )
376
+
377
+
378
# Cached profiler-scope factory; resolved lazily on first use.
_PROFILER_FUNC = None


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a named profiling scope, or a no-op context manager.

    The concrete factory (torch record_function, nvtx.annotate, or
    contextlib.nullcontext) is chosen once from the env flags and cached
    in `_PROFILER_FUNC` so subsequent calls skip the resolution logic.
    """
    global _PROFILER_FUNC

    if _PROFILER_FUNC is None:
        if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
            resolved = record_function
        elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
            import nvtx
            resolved = nvtx.annotate
        else:
            resolved = contextlib.nullcontext
        _PROFILER_FUNC = resolved

    return _PROFILER_FUNC(name)
vllm_hacked/v1/worker/gpu_input_batch.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Datastructures defining an input batch
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Optional, cast
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from vllm.lora.request import LoRARequest
11
+ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
12
+ from vllm.sampling_params import SamplingParams, SamplingType
13
+ from vllm.utils import swap_dict_values
14
+ from vllm.v1.outputs import LogprobsTensors
15
+ from vllm.v1.sample.metadata import SamplingMetadata
16
+ from vllm.v1.utils import copy_slice
17
+ from vllm.v1.worker.block_table import BlockTable
18
+
19
+ _SAMPLING_EPS = 1e-5
20
+
21
+
22
@dataclass
class CachedRequestState:
    """Per-request state cached by the GPU worker between scheduler steps."""

    req_id: str
    # Tokenized prompt; `prompt` is the optional raw text it came from.
    prompt_token_ids: list[int]
    prompt: Optional[str]
    # Multimodal inputs and where their placeholders sit in the prompt.
    mm_inputs: list[MultiModalKwargs]
    mm_positions: list[PlaceholderRange]
    sampling_params: SamplingParams
    # Per-request RNG; None means the request uses the shared generator.
    generator: Optional[torch.Generator]

    # KV-cache block ids assigned to this request.
    block_ids: list[int]
    num_computed_tokens: int
    output_token_ids: list[int]

    # M-RoPE positions (used by models with multimodal rotary embeddings);
    # None when not applicable.
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    @property
    def num_tokens(self) -> int:
        """Total token count: prompt plus generated output."""
        return len(self.prompt_token_ids) + len(self.output_token_ids)
46
+
47
+ class InputBatch:
48
+
49
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
    ):
        """Pre-allocate all per-request batch state.

        Buffers are sized for `max_num_reqs` rows up front; per-request data
        is written into row slots by add_request()/condense().  Most sampling
        parameters keep a pinned CPU tensor, a numpy view of it, and a GPU
        mirror that is refreshed by _make_sampling_metadata().
        """
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # Row slot -> req_id (None only transiently during updates).
        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        # numpy view sharing storage with token_ids_cpu_tensor.
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = BlockTable(
            max_num_reqs=max_num_reqs,
            max_num_blocks_per_req=max_num_blocks_per_req,
            pin_memory=pin_memory,
            device=device,
        )

        # Sampling-related.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        self.min_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
        self.min_p_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
        )
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # req_index -> (min_tokens, stop_token_ids)
        self.min_tokens: dict[int, tuple[int, set[int]]] = {}

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        self.logit_bias: list[Optional[dict[int,
                                            float]]] = [None] * max_num_reqs
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()
216
+
217
    @property
    def req_ids(self) -> list[str]:
        """Request ids in batch-slot order (cast drops the Optional)."""
        # None elements should only be present transiently
        # while performing state updates to the batch.
        return cast(list[str], self._req_ids)
222
+
223
    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        """Write `request`'s state into batch row `req_index`.

        When req_index is None the request is appended after the current
        batch.  Populates the token buffers, block table and every
        sampling-parameter row, and registers the request in the various
        per-feature bookkeeping sets/dicts.
        """
        if req_index is None:
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            # Reusing an existing (freed) slot.
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        sampling_params = request.sampling_params
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            # top_k out of range (or disabled) means "keep the whole vocab".
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k
        self.min_p_cpu[req_index] = sampling_params.min_p
        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.min_p > _SAMPLING_EPS:
            self.min_p_reqs.add(req_id)
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)
        if sampling_params.min_tokens:
            self.min_tokens[req_index] = (sampling_params.min_tokens,
                                          sampling_params.all_stop_token_ids)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
        if sampling_params.logit_bias is not None:
            self.logit_bias[req_index] = sampling_params.logit_bias

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
                                                          self.vocab_size,
                                                          dtype=torch.bool,
                                                          device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[
                req_index] = sampling_params.bad_words_token_ids

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0
344
+
345
    def remove_request(self, req_id: str) -> Optional[int]:
        """Clear all per-request bookkeeping for `req_id` and return its
        (now empty) row index, or None if the request is not in the batch.

        This method must always be followed by a call to condense()."""

        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        # Mark the slot empty; condense() will compact it away.
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.min_p_reqs.discard(req_id)
        self.min_tokens.pop(req_index, None)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        # LoRA: drop the reverse mappings, and the LoRA itself when this was
        # its last request.
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            self.lora_id_to_request_ids[lora_id].discard(req_id)
            if len(self.lora_id_to_request_ids[lora_id]) == 0:
                self.lora_id_to_request_ids.pop(lora_id)
                self.lora_id_to_lora_request.pop(lora_id)
            self.request_lora_mapping[req_index] = 0

        self.logit_bias[req_index] = None
        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index
384
+
385
    def swap_states(self, i1: int, i2: int) -> None:
        """Swap ALL per-request state between batch rows `i1` and `i2`,
        keeping every parallel buffer and mapping consistent."""
        old_id_i1 = self._req_ids[i1]
        old_id_i2 = self._req_ids[i2]
        self._req_ids[i1], self._req_ids[i2] =\
            self._req_ids[i2], self._req_ids[i1] # noqa
        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
        assert old_id_i1 is not None and old_id_i2 is not None
        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
        self.num_tokens[i1], self.num_tokens[i2] =\
            self.num_tokens[i2], self.num_tokens[i1]
        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
            self.temperature_cpu[i2], self.temperature_cpu[i1]
        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
            self.top_p_cpu[i2], self.top_p_cpu[i1]
        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
            self.top_k_cpu[i2], self.top_k_cpu[i1]
        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
        self.min_p_cpu[i1], self.min_p_cpu[i2] =\
            self.min_p_cpu[i2], self.min_p_cpu[i1]

        # NOTE: the following is unsafe
        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
        # instead, we need to temporiarily copy the data for one of the indices
        # TODO(lucas): optimize this by only copying valid indices
        tmp = self.token_ids_cpu[i1, ...].copy()
        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
        self.token_ids_cpu[i2, ...] = tmp

        swap_dict_values(self.generators, i1, i2)
        swap_dict_values(self.min_tokens, i1, i2)
        swap_dict_values(self.bad_words_token_ids, i1, i2)

        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
        self.logit_bias[i1], self.logit_bias[i2] =\
            self.logit_bias[i2], self.logit_bias[i1]

        if self.allowed_token_ids_mask_cpu_tensor is not None:
            self.allowed_token_ids_mask_cpu_tensor[i1], \
                self.allowed_token_ids_mask_cpu_tensor[i2] =\
                self.allowed_token_ids_mask_cpu_tensor[i2], \
                self.allowed_token_ids_mask_cpu_tensor[i1]
        self.block_table.swap_row(i1, i2)
442
+
443
    def condense(self, empty_req_indices: list[int]) -> None:
        """Compact the batch by moving requests from the tail into the empty
        slots left by remove_request(), so that rows [0, num_reqs) are dense.
        """
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                # Every remaining hole lies beyond the dense prefix; done.
                break

            # Swap the states.
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            # Only the valid prefix of the token row needs to move.
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            min_token = self.min_tokens.pop(last_req_index, None)
            if min_token is not None:
                self.min_tokens[empty_index] = min_token

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            self.logit_bias[empty_index] = self.logit_bias[last_req_index]

            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]
524
+
525
    def refresh_sampling_metadata(self):
        """Rebuild `self.sampling_metadata` after the batch membership or any
        per-request sampling parameter has changed."""
        self.sampling_metadata = self._make_sampling_metadata()
527
+
528
    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build a SamplingMetadata view over the first `num_reqs` rows,
        copying CPU-side parameter tensors to the GPU only for features that
        some request in the batch actually uses."""
        num_reqs = self.num_reqs
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, num_reqs)
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
        if not self.no_min_p:
            copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(self.frequency_penalties_cpu_tensor,
                       self.frequency_penalties, num_reqs)
            copy_slice(self.presence_penalties_cpu_tensor,
                       self.presence_penalties, num_reqs)
            copy_slice(self.repetition_penalties_cpu_tensor,
                       self.repetition_penalties, num_reqs)

            # The prompt tokens are used only for applying penalties during
            # the sampling process. Hence copy these tensors only when
            # there are requests which need penalties to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            min_p=None if self.no_min_p else self.min_p[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            min_tokens=self.min_tokens,
            no_penalties=self.no_penalties,
            logit_bias=self.logit_bias[:num_reqs],
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
        )
587
+
588
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Build a (num_reqs, max_prompt_len) int64 device tensor of prompt
        token ids, right-padded with ``vocab_size`` (an out-of-vocab value)
        so penalty kernels can ignore the padding."""
        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
        # Staged on pinned CPU memory so the device copy can be async.
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:self.
                                                 num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(self.num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)
605
+
606
+ def make_lora_inputs(
607
+ self, num_scheduled_tokens: np.ndarray
608
+ ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
609
+ """
610
+ Given the num_scheduled_tokens for each request in the batch, return
611
+ datastructures used to activate the current LoRAs.
612
+ Returns:
613
+ 1. prompt_lora_mapping: A tuple of size self.num_reqs where,
614
+ prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
615
+ 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
616
+ where, token_lora_mapping[i] is the LoRA id to use for ith token.
617
+ 3. lora_requests: Set of relevant LoRA requests.
618
+ """
619
+
620
+ req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
621
+ prompt_lora_mapping = tuple(req_lora_mapping)
622
+ token_lora_mapping = tuple(
623
+ req_lora_mapping.repeat(num_scheduled_tokens))
624
+ active_lora_requests: set[LoRARequest] = set(
625
+ self.lora_id_to_lora_request.values())
626
+
627
+ return prompt_lora_mapping, token_lora_mapping, active_lora_requests
628
+
629
+ @property
630
+ def num_reqs(self) -> int:
631
+ return len(self.req_id_to_index)
632
+
633
+ @property
634
+ def all_greedy(self) -> bool:
635
+ return len(self.random_reqs) == 0
636
+
637
+ @property
638
+ def all_random(self) -> bool:
639
+ return len(self.greedy_reqs) == 0
640
+
641
+ @property
642
+ def no_top_p(self) -> bool:
643
+ return len(self.top_p_reqs) == 0
644
+
645
+ @property
646
+ def no_top_k(self) -> bool:
647
+ return len(self.top_k_reqs) == 0
648
+
649
+ @property
650
+ def no_min_p(self) -> bool:
651
+ return len(self.min_p_reqs) == 0
652
+
653
+ @property
654
+ def no_penalties(self) -> bool:
655
+ return (len(self.presence_penalties_reqs) == 0
656
+ and len(self.frequency_penalties_reqs) == 0
657
+ and len(self.repetition_penalties_reqs) == 0)
658
+
659
+ @property
660
+ def max_num_logprobs(self) -> Optional[int]:
661
+ return max(self.num_logprobs.values()) if self.num_logprobs else None
662
+
663
+ @property
664
+ def no_prompt_logprob(self) -> bool:
665
+ return not self.num_prompt_logprobs
666
+
667
+ @property
668
+ def no_allowed_token_ids(self) -> bool:
669
+ return len(self.has_allowed_token_ids) == 0
vllm_hacked/v1/worker/gpu_input_batch_ori.py ADDED
@@ -0,0 +1,863 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ # Datastructures defining a GPU input batch
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional, cast
7
+
8
+ import numpy as np
9
+ import torch
10
+ from typing_extensions import deprecated
11
+
12
+ from vllm.lora.request import LoRARequest
13
+ from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
14
+ from vllm.pooling_params import PoolingParams
15
+ from vllm.sampling_params import SamplingParams, SamplingType
16
+ from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
17
+ from vllm.v1.outputs import LogprobsTensors
18
+ from vllm.v1.pool.metadata import PoolingMetadata
19
+ from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
20
+ LogitsProcessors,
21
+ MoveDirectionality)
22
+ from vllm.v1.sample.metadata import SamplingMetadata
23
+ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
24
+ from vllm.v1.utils import copy_slice
25
+ from vllm.v1.worker.block_table import MultiGroupBlockTable
26
+
27
+
28
@dataclass
class CachedRequestState:
    """Per-request state cached across scheduler steps.

    Holds the prompt (as token ids and/or embeddings), sampling or pooling
    parameters, generation outputs, and KV-cache block assignment for one
    request.
    """

    req_id: str
    # None when the prompt was supplied as embeddings only.
    prompt_token_ids: Optional[list[int]]
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None
    prompt_embeds: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Prompt length comes from the token ids when present, otherwise
        # from the embedding tensor.
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds)

    @property
    def num_tokens(self) -> int:
        """Total token count: prompt plus generated output."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    # Temporary back-compatibility for plugins that define model runner
    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                "removed in v0.13. Please use `mm_kwargs` instead.")
    def mm_inputs(self) -> list[MultiModalKwargsItems]:
        return [
            MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
            if f.data is not None
        ]

    def get_token_id(self, idx: int) -> int:
        """Return the token id at absolute position ``idx`` (prompt region
        first, then outputs), or -1 when ``idx`` is past the known tokens.

        Raises:
            ValueError: if ``idx`` falls in a prompt supplied only as
                embeddings, whose token ids are unknown.
        """
        if idx < self.num_prompt_tokens:
            if self.prompt_token_ids is None:
                raise ValueError(
                    f"Tried to access token index {idx}, but that token was "
                    "provided via prompt_embeds, and its ID is unknown.")
            return self.prompt_token_ids[idx]
        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
            return self.output_token_ids[idx - self.num_prompt_tokens]
        else:
            # Out-of-range sentinel; callers must treat -1 as "no token".
            return -1
77
+
78
+
79
+ class InputBatch:
80
+
81
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
    ):
        """Pre-allocate all per-slot batch state for up to ``max_num_reqs``
        concurrent requests.

        Most sampling parameters are kept as paired tensors: a pinned CPU
        tensor that is mutated as requests come and go, and a device tensor
        that the CPU side is copied into when sampling metadata is rebuilt.
        The companion ``*_reqs`` sets track which requests actually use each
        feature so unused groups can skip the host->device copy.
        """
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        self._req_ids: list[Optional[str]] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is big.
        # Find a way to reduce the CPU memory usage.
        # This buffer is not directly transferred to the GPU, so it does not
        # need to be pinned.
        self.token_ids_cpu_tensor = torch.zeros(
            (max_num_reqs, max_model_len),
            device="cpu",
            dtype=torch.int32,
            pin_memory=False,
        )
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        # Per-position flag: True where token_ids_cpu holds a real token id
        # (False for positions covered only by prompt embeddings).
        self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
                                        device="cpu",
                                        dtype=bool,
                                        pin_memory=False)
        # Store prompt embeddings per request to avoid OOM from large upfront
        # allocation if max_model_len is big.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            (max_num_reqs, ),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = \
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            num_speculative_tokens=num_speculative_tokens,
        )

        # Sampling-related.
        self.temperature = torch.empty((max_num_reqs, ),
                                       dtype=torch.float32,
                                       device=device)
        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                                  dtype=torch.float32,
                                                  device="cpu",
                                                  pin_memory=pin_memory)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = torch.empty((max_num_reqs, ),
                                 dtype=torch.float32,
                                 device=device)
        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = torch.empty((max_num_reqs, ),
                                 dtype=torch.int32,
                                 device=device)
        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty related data structures
        self.frequency_penalties = torch.empty((max_num_reqs, ),
                                               dtype=torch.float,
                                               device=device)
        self.frequency_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty related data structures
        self.presence_penalties = torch.empty((max_num_reqs, ),
                                              dtype=torch.float,
                                              device=device)
        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                         dtype=torch.float,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
        )
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty related data structures
        self.repetition_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
                                                device=device)
        self.repetition_penalties_cpu_tensor = torch.empty(
            (max_num_reqs, ),
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # Speculative decoding
        self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
                                                         dtype=torch.int64,
                                                         device="cpu",
                                                         pin_memory=pin_memory)
        self.num_accepted_tokens_cpu = \
            self.num_accepted_tokens_cpu_tensor.numpy()

        # lora related
        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                             dtype=np.int32)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): The indices of the requests that do not have their own
        # generator should not be included in the dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

        # To accumulate prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Internal representation of per-step batch state changes, used for
        # reordering persistent batch and generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
        # the value is False. Since we use masked_fill_ to set -inf.
        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
                                                          dtype=bool)

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

        self.pooling_params: dict[str, PoolingParams] = {}

        # Cached reference to the GPU tensor of previously sampled tokens
        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
        self.prev_req_id_to_index: Optional[dict[str, int]] = None
282
+
283
+ @property
284
+ def req_ids(self) -> list[str]:
285
+ # None elements should only be present transiently
286
+ # while performing state updates to the batch.
287
+ return cast(list[str], self._req_ids)
288
+
289
    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Track add-request operations for logits processors.
        Not applicable to pooling models.

        Returns the batch slot index assigned to the new request (a recycled
        empty slot if one exists, otherwise the next index at the end).
        """

        # Fill the next empty index if there is one.
        if (new_req_index := self.batch_update_builder.pop_removed()) is None:
            # Append to end otherwise.
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
        self.batch_update_builder.batch_changed = True
        if request.sampling_params:
            # Detailed added request metadata is only required for non-pooling
            # models, to support logitsprocs.
            self.batch_update_builder.added.append(
                (new_req_index, request.sampling_params,
                 request.prompt_token_ids, request.output_token_ids))

        return new_req_index
309
+
310
    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        """Insert ``request`` into the batch and return its slot index.

        Populates the token buffers, block table, and — depending on whether
        the request carries sampling or pooling params — the per-slot
        sampling state or pooling bookkeeping. Exactly one of
        ``sampling_params`` / ``pooling_params`` must be set.

        Raises:
            NotImplementedError: if the request has neither sampling nor
                pooling params.
        """
        req_index = self._register_add_request(request)

        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            # Reusing a previously vacated slot.
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids

        self.req_id_to_index[req_id] = req_index

        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        if request.prompt_token_ids is not None:
            self.token_ids_cpu[
                req_index, :num_prompt_tokens] = request.prompt_token_ids
            self.is_token_ids[req_index, :num_prompt_tokens] = True
        else:
            # Prompt supplied as embeddings only; no real ids at these slots.
            self.is_token_ids[req_index, :num_prompt_tokens] = False
        if request.prompt_embeds is not None:
            self.req_prompt_embeds[req_index] = request.prompt_embeds
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        self.is_token_ids[req_index, start_idx:end_idx] = True
        # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens

        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        if sampling_params := request.sampling_params:
            if (self.is_spec_decode
                    and is_spec_decode_unsupported(sampling_params)):
                self.spec_decode_unsupported_reqs.add(req_id)
            if sampling_params.sampling_type == SamplingType.GREEDY:
                # Should avoid division by zero later when apply_temperature.
                self.temperature_cpu[req_index] = 0.0
                self.greedy_reqs.add(req_id)
            else:
                self.temperature_cpu[req_index] = sampling_params.temperature
                self.random_reqs.add(req_id)

            self.top_p_cpu[req_index] = sampling_params.top_p
            if sampling_params.top_p < 1:
                self.top_p_reqs.add(req_id)
            top_k = sampling_params.top_k
            if 0 < top_k < self.vocab_size:
                self.top_k_reqs.add(req_id)
            else:
                # Out-of-range top_k is normalized to "no filtering".
                top_k = self.vocab_size
            self.top_k_cpu[req_index] = top_k
            self.frequency_penalties_cpu[
                req_index] = sampling_params.frequency_penalty
            if sampling_params.frequency_penalty != 0.0:
                self.frequency_penalties_reqs.add(req_id)
            self.presence_penalties_cpu[
                req_index] = sampling_params.presence_penalty
            if sampling_params.presence_penalty != 0.0:
                self.presence_penalties_reqs.add(req_id)
            self.repetition_penalties_cpu[
                req_index] = sampling_params.repetition_penalty
            if sampling_params.repetition_penalty != 1.0:
                self.repetition_penalties_reqs.add(req_id)

            # NOTE(woosuk): self.generators should not include the requests that
            # do not have their own generator.
            if request.generator is not None:
                self.generators[req_index] = request.generator

            if sampling_params.logprobs is not None:
                # -1 means "all logprobs", i.e. the full vocabulary.
                self.num_logprobs[req_id] = (self.vocab_size
                                             if sampling_params.logprobs == -1
                                             else sampling_params.logprobs)
            if sampling_params.prompt_logprobs is not None:
                self.num_prompt_logprobs[req_id] = (
                    self.vocab_size if sampling_params.prompt_logprobs == -1
                    else sampling_params.prompt_logprobs)

            if sampling_params.allowed_token_ids:
                self.has_allowed_token_ids.add(req_id)
                if self.allowed_token_ids_mask_cpu_tensor is None:
                    # Lazy allocation for this tensor, which can be large.
                    # False means we don't fill with -inf.
                    self.allowed_token_ids_mask = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device=self.device)
                    self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                        self.max_num_reqs,
                        self.vocab_size,
                        dtype=torch.bool,
                        device="cpu")
                self.allowed_token_ids_mask_cpu_tensor[req_index] = True
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask_cpu_tensor[req_index][
                    sampling_params.allowed_token_ids] = False

            if sampling_params.bad_words_token_ids:
                self.bad_words_token_ids[
                    req_index] = sampling_params.bad_words_token_ids
        elif pooling_params := request.pooling_params:
            self.pooling_params[req_id] = pooling_params
            self.logits_processing_needs_token_ids[req_index] = (
                pooling_params.requires_token_ids)
        else:
            raise NotImplementedError("Unrecognized request type")

        # Speculative decoding: by default 1 token is generated.
        self.num_accepted_tokens_cpu[req_index] = 1

        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()

            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0

        return req_index
447
+
448
    def remove_request(self, req_id: str) -> Optional[int]:
        """This method must always be followed by a call to condense().

        Marks the request's slot as vacated and scrubs all per-request
        bookkeeping (LoRA refcounts, sampling sets, logprob tracking).

        Args:
            req_id: request to remove

        Returns:
            Removed request index, or `None` if `req_id` not recognized
        """

        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None

        self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
        if lora_id != 0:
            lora_req_ids = self.lora_id_to_request_ids[lora_id]
            lora_req_ids.discard(req_id)
            if not lora_req_ids:
                # Last request using this LoRA; drop the adapter records.
                del self.lora_id_to_request_ids[lora_id]
                del self.lora_id_to_lora_request[lora_id]
            self.request_lora_mapping[req_index] = 0

        if self.is_pooling_model:
            # Pooling models keep no sampling state; nothing else to clear.
            self.pooling_params.pop(req_id, None)
            return req_index

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.spec_decode_unsupported_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.num_prompt_logprobs.pop(req_id, None)
        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

        self.has_allowed_token_ids.discard(req_id)
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
        self.bad_words_token_ids.pop(req_index, None)
        return req_index
499
+
500
+ def swap_states(self, i1: int, i2: int) -> None:
501
+ old_id_i1 = self._req_ids[i1]
502
+ old_id_i2 = self._req_ids[i2]
503
+ self._req_ids[i1], self._req_ids[i2] =\
504
+ self._req_ids[i2], self._req_ids[i1] # noqa
505
+ self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
506
+ self.req_output_token_ids[i2], self.req_output_token_ids[i1]
507
+ assert old_id_i1 is not None and old_id_i2 is not None
508
+ self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
509
+ self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
510
+ self.num_tokens[i1], self.num_tokens[i2] =\
511
+ self.num_tokens[i2], self.num_tokens[i1]
512
+ self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
513
+ self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
514
+ self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
515
+ self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
516
+ self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
517
+ self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
518
+
519
+ # NOTE: the following is unsafe
520
+ # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
521
+ # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
522
+ # instead, we need to temporiarily copy the data for one of the indices
523
+ # TODO(lucas): optimize this by only copying valid indices
524
+ tmp = self.token_ids_cpu[i1, ...].copy()
525
+ self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
526
+ self.token_ids_cpu[i2, ...] = tmp
527
+
528
+ self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
529
+
530
+ # Swap prompt embeddings if they exist
531
+ embeds_i1 = self.req_prompt_embeds.get(i1)
532
+ embeds_i2 = self.req_prompt_embeds.get(i2)
533
+ if embeds_i1 is not None:
534
+ self.req_prompt_embeds[i2] = embeds_i1
535
+ else:
536
+ self.req_prompt_embeds.pop(i2, None)
537
+ if embeds_i2 is not None:
538
+ self.req_prompt_embeds[i1] = embeds_i2
539
+ else:
540
+ self.req_prompt_embeds.pop(i1, None)
541
+
542
+ self.block_table.swap_row(i1, i2)
543
+
544
+ self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
545
+ self.request_lora_mapping[i2], self.request_lora_mapping[i1]
546
+
547
+ if self.is_pooling_model:
548
+ # Sampling and logits parameters don't apply to pooling models.
549
+ return
550
+
551
+ # For autoregressive models, track detailed request reordering info
552
+ # to support logitsprocs.
553
+ self.batch_update_builder.moved.append(
554
+ (i1, i2, MoveDirectionality.SWAP))
555
+
556
+ self.temperature_cpu[i1], self.temperature_cpu[i2] = \
557
+ self.temperature_cpu[i2], self.temperature_cpu[i1]
558
+ self.top_p_cpu[i1], self.top_p_cpu[i2] = \
559
+ self.top_p_cpu[i2], self.top_p_cpu[i1]
560
+ self.top_k_cpu[i1], self.top_k_cpu[i2] = \
561
+ self.top_k_cpu[i2], self.top_k_cpu[i1]
562
+ self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
563
+ self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
564
+ self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
565
+ self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
566
+ self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
567
+ self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
568
+ self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
569
+ self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
570
+
571
+ swap_dict_values(self.generators, i1, i2)
572
+ swap_dict_values(self.bad_words_token_ids, i1, i2)
573
+
574
+ if self.allowed_token_ids_mask_cpu_tensor is not None:
575
+ self.allowed_token_ids_mask_cpu_tensor[i1], \
576
+ self.allowed_token_ids_mask_cpu_tensor[i2] =\
577
+ self.allowed_token_ids_mask_cpu_tensor[i2], \
578
+ self.allowed_token_ids_mask_cpu_tensor[i1]
579
+
580
    def condense(self) -> None:
        """Slide non-empty requests down into lower, empty indices.

        Any consecutive empty indices at the very end of the list are not
        filled. Consumes the removed-index queue accumulated by
        remove_request() and records the resulting moves in the batch
        update builder; trims the slot lists to the live batch size.
        """
        num_reqs = self.num_reqs

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = self.batch_update_builder.peek_removed()
            assert empty_index is not None
            if empty_index >= last_req_index:
                # Remaining holes are all past the live region; done.
                break

            # Move active request down into empty request
            # index.
            self.batch_update_builder.pop_removed()
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
                last_req_index, :num_tokens]
            if last_req_index in self.req_prompt_embeds:
                self.req_prompt_embeds[
                    empty_index] = self.req_prompt_embeds.pop(last_req_index)
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            if self.is_pooling_model:
                last_req_index -= 1
                # Sampling state not used by pooling models.
                continue

            # Autoregressive models require detailed tracking of condense
            # operations to support logitsprocs
            self.batch_update_builder.moved.append(
                (last_req_index, empty_index,
                 MoveDirectionality.UNIDIRECTIONAL))

            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            self.num_accepted_tokens_cpu[
                empty_index] = self.num_accepted_tokens_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            # TODO convert these to LogitsProcessors
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids

            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
692
+
693
    def refresh_metadata(self):
        """Apply any batch updates to sampling metadata.

        For pooling models only the batch-update tracker is reset; for
        autoregressive models the accumulated batch update is also pushed
        into every logits processor. In both cases ``sampling_metadata`` is
        rebuilt only when the batch actually changed.
        """

        if self.is_pooling_model:
            # Pooling models have no logitsprocs to notify; just reset the
            # builder and rebuild metadata if anything changed.
            batch_changed = self.batch_update_builder.reset()
            if batch_changed:
                self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
        if batch_update:
            self.sampling_metadata = self._make_sampling_metadata()
710
+
711
    def _make_sampling_metadata(self) -> SamplingMetadata:
        """Build a SamplingMetadata snapshot for the first ``num_reqs``
        requests.

        Only the CPU-side tensors that are actually needed are copied to the
        device: temperature is skipped when all requests are greedy, penalty
        tensors when no request uses penalties, and prompt token ids unless
        penalties or a logits processor require them.
        """
        num_reqs = self.num_reqs
        if not self.all_greedy:
            temperature = copy_slice(self.temperature_cpu_tensor,
                                     self.temperature, num_reqs)
        else:
            temperature = None
        if not self.no_top_p:
            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
        if not self.no_top_k:
            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)

        if not self.no_penalties:
            # Since syncing these tensors is expensive only copy them
            # if necessary i.e. if there are requests which require
            # penalties to be applied during sampling.
            copy_slice(self.frequency_penalties_cpu_tensor,
                       self.frequency_penalties, num_reqs)
            copy_slice(self.presence_penalties_cpu_tensor,
                       self.presence_penalties, num_reqs)
            copy_slice(self.repetition_penalties_cpu_tensor,
                       self.repetition_penalties, num_reqs)

        needs_prompt_token_ids = (
            not self.no_penalties
            or self.logits_processing_needs_token_ids[:num_reqs].any())
        if needs_prompt_token_ids:
            # The prompt tokens are used only for applying penalties or
            # step pooling during the sampling/pooling process.
            # Hence copy these tensors only when there are requests which
            # need penalties/step_pooler to be applied.
            prompt_token_ids = self._make_prompt_token_ids_tensor()
        else:
            prompt_token_ids = None

        allowed_token_ids_mask: Optional[torch.Tensor] = None
        if not self.no_allowed_token_ids:
            assert self.allowed_token_ids_mask is not None
            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                       self.allowed_token_ids_mask, num_reqs)
            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]

        return SamplingMetadata(
            temperature=temperature,
            all_greedy=self.all_greedy,
            all_random=self.all_random,
            top_p=None if self.no_top_p else self.top_p[:num_reqs],
            top_k=None if self.no_top_k else self.top_k[:num_reqs],
            generators=self.generators,
            max_num_logprobs=self.max_num_logprobs,
            prompt_token_ids=prompt_token_ids,
            frequency_penalties=self.frequency_penalties[:num_reqs],
            presence_penalties=self.presence_penalties[:num_reqs],
            repetition_penalties=self.repetition_penalties[:num_reqs],
            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
            no_penalties=self.no_penalties,
            allowed_token_ids_mask=allowed_token_ids_mask,
            bad_words_token_ids=self.bad_words_token_ids,
            logitsprocs=self.logitsprocs,
        )
771
+
772
+ def get_pooling_params(self) -> list[PoolingParams]:
773
+ assert len(self.req_ids) == len(self.pooling_params)
774
+ return [self.pooling_params[req_id] for req_id in self.req_ids]
775
+
776
    def get_pooling_metadata(self) -> PoolingMetadata:
        """Assemble PoolingMetadata for the current batch.

        Prompt token ids are taken from the current ``sampling_metadata``
        snapshot, so this presumably expects refresh_metadata() to have run
        first — TODO confirm against callers.
        """
        pooling_params = self.get_pooling_params()

        return PoolingMetadata(
            prompt_lens=torch.from_numpy(
                self.num_prompt_tokens[:self.num_reqs]),
            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
            pooling_params=pooling_params,
        )
785
+
786
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        """Copy prompt token ids to the device as a padded int64 tensor.

        Returns a (num_reqs, max_prompt_len) tensor where positions past each
        request's prompt length are padded with ``vocab_size``. The copy is
        issued with ``non_blocking=True`` from a pinned CPU tensor (when
        ``self.pin_memory`` is set).
        """
        num_reqs = self.num_reqs
        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        # Fill via the numpy view to avoid per-element tensor indexing.
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)
803
+
804
+ def make_lora_inputs(
805
+ self, num_scheduled_tokens: np.ndarray
806
+ ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
807
+ """
808
+ Given the num_scheduled_tokens for each request in the batch, return
809
+ datastructures used to activate the current LoRAs.
810
+ Returns:
811
+ 1. prompt_lora_mapping: A tuple of size self.num_reqs where,
812
+ prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
813
+ 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
814
+ where, token_lora_mapping[i] is the LoRA id to use for ith token.
815
+ 3. lora_requests: Set of relevant LoRA requests.
816
+ """
817
+
818
+ req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
819
+ prompt_lora_mapping = tuple(req_lora_mapping)
820
+ token_lora_mapping = tuple(
821
+ req_lora_mapping.repeat(num_scheduled_tokens))
822
+ active_lora_requests: set[LoRARequest] = set(
823
+ self.lora_id_to_lora_request.values())
824
+
825
+ return prompt_lora_mapping, token_lora_mapping, active_lora_requests
826
+
827
+ @property
828
+ def num_reqs(self) -> int:
829
+ return len(self.req_id_to_index)
830
+
831
+ @property
832
+ def all_greedy(self) -> bool:
833
+ return len(self.random_reqs) == 0
834
+
835
+ @property
836
+ def all_random(self) -> bool:
837
+ return len(self.greedy_reqs) == 0
838
+
839
+ @property
840
+ def no_top_p(self) -> bool:
841
+ return len(self.top_p_reqs) == 0
842
+
843
+ @property
844
+ def no_top_k(self) -> bool:
845
+ return len(self.top_k_reqs) == 0
846
+
847
+ @property
848
+ def no_penalties(self) -> bool:
849
+ return (len(self.presence_penalties_reqs) == 0
850
+ and len(self.frequency_penalties_reqs) == 0
851
+ and len(self.repetition_penalties_reqs) == 0)
852
+
853
+ @property
854
+ def max_num_logprobs(self) -> Optional[int]:
855
+ return max(self.num_logprobs.values()) if self.num_logprobs else None
856
+
857
+ @property
858
+ def no_prompt_logprob(self) -> bool:
859
+ return not self.num_prompt_logprobs
860
+
861
+ @property
862
+ def no_allowed_token_ids(self) -> bool:
863
+ return len(self.has_allowed_token_ids) == 0
vllm_hacked/v1/worker/gpu_model_runner.py ADDED
The diff for this file is too large to render. See raw diff
 
vllm_hacked/v1/worker/gpu_model_runner_ori.py ADDED
The diff for this file is too large to render. See raw diff
 
vllm_hacked/v1/worker/gpu_worker.py ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """A GPU worker class."""
4
+ import copy
5
+ import gc
6
+ import os
7
+ from contextlib import AbstractContextManager, nullcontext
8
+ from typing import TYPE_CHECKING, Any, Optional, Union
9
+
10
+ import torch
11
+ import torch.distributed
12
+ import torch.nn as nn
13
+
14
+ import vllm.envs as envs
15
+ from vllm.config import VllmConfig
16
+ from vllm.distributed import (ensure_model_parallel_initialized,
17
+ init_distributed_environment,
18
+ set_custom_all_reduce)
19
+ from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
20
+ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
21
+ from vllm.logger import init_logger
22
+ from vllm.lora.request import LoRARequest
23
+ from vllm.model_executor import set_random_seed
24
+ from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
25
+ from vllm.platforms import current_platform
26
+ from vllm.sequence import IntermediateTensors
27
+ from vllm.tasks import SupportedTask
28
+ from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
29
+ from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
30
+ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
31
+ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
32
+ DraftTokenIds, ModelRunnerOutput)
33
+ from vllm.v1.utils import report_usage_stats
34
+ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
35
+ from vllm.v1.worker.utils import is_residual_scattered_for_sp
36
+ from vllm.v1.worker.worker_base import WorkerBase
37
+
38
+ logger = init_logger(__name__)
39
+
40
+ if TYPE_CHECKING:
41
+ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
42
+ from vllm.v1.core.sched.output import SchedulerOutput
43
+
44
+
45
+ class Worker(WorkerBase):
46
+
47
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Initialize the GPU worker.

        Forwards the config/rank arguments to ``WorkerBase``, enables cached
        HF modules when ``trust_remote_code`` is set, and configures the
        optional torch profiler from the ``VLLM_TORCH_PROFILER_*`` env vars.
        """

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Buffers saved before sleep (populated by sleep(level=2)).
        self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None
97
+
98
    def sleep(self, level: int = 1) -> None:
        """Release GPU memory through the CuMem allocator.

        Level 1 offloads the "weights" tag; level 2 additionally snapshots
        all model buffers to CPU so wake_up() can restore them. Logs how
        much memory was freed and asserts that usage did not grow.
        """
        from vllm.device_allocator.cumem import CuMemAllocator

        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]

        # Save the buffers before level 2 sleep
        if level == 2:
            model = self.model_runner.model
            self._sleep_saved_buffers = {
                name: buffer.cpu().clone()
                for name, buffer in model.named_buffers()
            }

        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)
121
+
122
+ def wake_up(self, tags: Optional[list[str]] = None) -> None:
123
+ from vllm.device_allocator.cumem import CuMemAllocator
124
+
125
+ allocator = CuMemAllocator.get_instance()
126
+ allocator.wake_up(tags)
127
+
128
+ # Restore the buffers after level 2 sleep
129
+ if len(self._sleep_saved_buffers):
130
+ model = self.model_runner.model
131
+ for name, buffer in model.named_buffers():
132
+ if name in self._sleep_saved_buffers:
133
+ buffer.data.copy_(self._sleep_saved_buffers[name].data)
134
+ self._sleep_saved_buffers = {}
135
+
136
+ def _maybe_get_memory_pool_context(self,
137
+ tag: str) -> AbstractContextManager:
138
+ if self.vllm_config.model_config.enable_sleep_mode:
139
+ from vllm.device_allocator.cumem import CuMemAllocator
140
+
141
+ allocator = CuMemAllocator.get_instance()
142
+ if tag == "weights":
143
+ assert allocator.get_current_usage() == 0, (
144
+ "Sleep mode can only be "
145
+ "used for one instance per process.")
146
+ context = allocator.use_memory_pool(tag=tag)
147
+ else:
148
+ context = nullcontext()
149
+ return context
150
+
151
+ def initialize_cache(self, num_gpu_blocks: int,
152
+ num_cpu_blocks: int) -> None:
153
+ self.cache_config.num_gpu_blocks = num_gpu_blocks
154
+ self.cache_config.num_cpu_blocks = num_cpu_blocks
155
+
156
    def init_device(self):
        """Bind this worker to its CUDA device and set up distributed state.

        Order matters here: the distributed environment (NCCL) is initialized
        before the memory snapshot is taken so that NCCL buffers are counted
        as already-used memory, then the model runner is constructed.

        Raises:
            ValueError: if free GPU memory is below the requested
                ``gpu_memory_utilization`` fraction.
            RuntimeError: if the configured device type is not CUDA.
        """
        if self.device_config.device.type == "cuda":
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)

            current_platform.check_if_supports_dtype(self.model_config.dtype)

            # Initialize the distributed environment BEFORE taking
            # memory snapshot
            # This ensures NCCL buffers are allocated before we measure
            # available memory
            init_worker_distributed_environment(self.vllm_config, self.rank,
                                                self.distributed_init_method,
                                                self.local_rank,
                                                current_platform.dist_backend)

            # Set random seed.
            set_random_seed(self.model_config.seed)

            # Now take memory snapshot after NCCL is initialized
            gc.collect()
            torch.cuda.empty_cache()

            # take current memory snapshot
            self.init_snapshot = MemorySnapshot()
            self.requested_memory = (self.init_snapshot.total_memory *
                                     self.cache_config.gpu_memory_utilization)
            if self.init_snapshot.free_memory < self.requested_memory:
                GiB = lambda b: round(b / GiB_bytes, 2)
                raise ValueError(
                    f"Free memory on device "
                    f"({GiB(self.init_snapshot.free_memory)}/"
                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
                    f"is less than desired GPU memory utilization "
                    f"({self.cache_config.gpu_memory_utilization}, "
                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                    f"utilization or reduce GPU memory used by other processes."
                )
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)
207
+
208
+ # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
209
+ # to hijack tensor allocation.
210
+ def load_model(self) -> None:
211
+ eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
212
+ with self._maybe_get_memory_pool_context(tag="weights"):
213
+ self.model_runner.load_model(eep_scale_up=eep_scale_up)
214
+
215
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Forward config overrides to the model runner."""
        self.model_runner.update_config(overrides)

    def reload_weights(self) -> None:
        """Ask the model runner to reload model weights."""
        self.model_runner.reload_weights()
220
+
221
+ @torch.inference_mode()
222
+ def determine_available_memory(self) -> int:
223
+ """Profiles the peak memory usage of the model to determine how much
224
+ memory can be used for KV cache without OOMs.
225
+
226
+ The engine will first conduct a profiling of the existing memory usage.
227
+ Then, it calculates the free memory that can be used for KV cache in
228
+ bytes.
229
+
230
+ Tip:
231
+ You may limit the usage of GPU memory
232
+ by adjusting the `gpu_memory_utilization` parameter.
233
+ """
234
+ GiB = lambda b: b / GiB_bytes
235
+ if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
236
+ # still need a profile run which compiles the model for
237
+ # max_num_batched_tokens
238
+ self.model_runner.profile_run()
239
+
240
+ msg = (
241
+ f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
242
+ f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
243
+ "KV Cache as specified by kv_cache_memory_bytes config and "
244
+ "skipped memory profiling. This does does not respect the "
245
+ "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
246
+ "config when you want manual control of KV cache memory "
247
+ "size. If OOM'ed, check the difference of initial free "
248
+ "memory between the current run and the previous run "
249
+ "where kv_cache_memory_bytes is suggested and update it "
250
+ "correspondingly.")
251
+ logger.info(msg)
252
+ return kv_cache_memory_bytes
253
+
254
+ torch.cuda.empty_cache()
255
+ torch.cuda.reset_peak_memory_stats()
256
+
257
+ # Execute a forward pass with dummy inputs to profile the memory usage
258
+ # of the model.
259
+ with memory_profiling(
260
+ self.init_snapshot,
261
+ weights_memory=int(self.model_runner.model_memory_usage),
262
+ ) as profile_result:
263
+ self.model_runner.profile_run()
264
+
265
+ self.non_torch_memory = profile_result.non_torch_increase
266
+ self.peak_activation_memory = profile_result.torch_peak_increase
267
+
268
+ free_gpu_memory = profile_result.after_profile.free_memory
269
+ # NOTE(woosuk): Here we assume that the other processes using the same
270
+ # GPU did not change their memory usage during the profiling.
271
+ assert self.init_snapshot.free_memory > free_gpu_memory, (
272
+ "Error in memory profiling. "
273
+ f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
274
+ f"current free memory {GiB(free_gpu_memory)} GiB. "
275
+ "This happens when other processes sharing the same container "
276
+ "release GPU memory while vLLM is profiling during initialization. "
277
+ "To fix this, ensure consistent GPU memory allocation or "
278
+ "isolate vLLM in its own container.")
279
+ self.available_kv_cache_memory_bytes = self.requested_memory \
280
+ - profile_result.non_kv_cache_memory
281
+
282
+ unrequested_memory = self.init_snapshot.free_memory \
283
+ - self.requested_memory
284
+ logger.debug(
285
+ "Initial free memory: %.2f GiB; "
286
+ "Requested memory: %.2f (util), %.2f GiB",
287
+ GiB(self.init_snapshot.free_memory),
288
+ self.cache_config.gpu_memory_utilization,
289
+ GiB(self.requested_memory),
290
+ )
291
+ logger.debug(
292
+ "Free memory after profiling: %.2f GiB (total), "
293
+ "%.2f GiB (within requested)",
294
+ GiB(free_gpu_memory),
295
+ GiB(free_gpu_memory - unrequested_memory),
296
+ )
297
+ logger.debug(profile_result)
298
+ logger.info("Available KV cache memory: %.2f GiB",
299
+ GiB(self.available_kv_cache_memory_bytes))
300
+ gc.collect()
301
+
302
+ return int(self.available_kv_cache_memory_bytes)
303
+
304
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Return the per-layer KV cache specs from the model runner."""
        return self.model_runner.get_kv_cache_spec()
306
+
307
+ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
308
+ """Allocate GPU KV cache with the specified kv_cache_config."""
309
+
310
+ if self.vllm_config.model_config.enable_sleep_mode:
311
+ from vllm.device_allocator.cumem import CuMemAllocator
312
+
313
+ allocator = CuMemAllocator.get_instance()
314
+ context = allocator.use_memory_pool(tag="kv_cache")
315
+ else:
316
+ context = nullcontext()
317
+ with context:
318
+ self.model_runner.initialize_kv_cache(kv_cache_config)
319
+
320
    def compile_or_warm_up_model(self) -> None:
        """Compile, warm up, and (optionally) CUDA-graph-capture the model.

        Runs dummy forward passes for compile sizes not covered by cudagraph
        capture, warms up kernels, captures CUDA graphs unless eager mode is
        forced, logs a suggested --kv-cache-memory value based on profiling,
        and finally warms up the sampler/pooler on the last PP rank.
        """
        # warm up sizes that are not in cudagraph capture sizes,
        # but users still want to compile for better performance,
        # e.g. for the max-num-batched token size in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        # We skip EPLB here since we don't want to record dummy metrics
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size,
                                         skip_eplb=True,
                                         remove_lora=False)
        self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)

        # Warmup and tune the kernels used during model execution before
        # cuda graph capture.
        kernel_warmup(self)

        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
            cuda_graph_memory_bytes = self.model_runner.capture_model()

        if (self.cache_config.kv_cache_memory_bytes is None
                and hasattr(self, "peak_activation_memory")):
            # Suggests optimal kv cache memory size if we rely on
            # memory_profiling to guess the kv cache memory size which
            # provides peak_activation_memory and a few other memory
            # consumption. `memory_profiling` does not consider
            # CUDAGraph memory size and may not utilize all gpu memory.
            # Users may want fine-grained control to specify kv cache
            # memory size.
            GiB = lambda b: round(b / GiB_bytes, 2)

            # empirically observed that the memory profiling may
            # slightly underestimate the memory consumption.
            # So leave a small buffer (=150MiB) to avoid OOM.
            redundancy_buffer_memory = 150 * (1 << 20)
            non_kv_cache_memory = (self.model_runner.model_memory_usage +
                                   self.peak_activation_memory +
                                   self.non_torch_memory +
                                   cuda_graph_memory_bytes)
            kv_cache_memory_bytes_to_gpu_limit = (
                self.init_snapshot.free_memory - non_kv_cache_memory -
                redundancy_buffer_memory)
            kv_cache_memory_bytes_to_requested_limit = (
                int(self.requested_memory) - non_kv_cache_memory -
                redundancy_buffer_memory)

            msg = (
                f"Free memory on device "
                f"({GiB(self.init_snapshot.free_memory)}/"
                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
                f"Desired GPU memory utilization is "
                f"({self.cache_config.gpu_memory_utilization}, "
                f"{GiB(self.requested_memory)} GiB). "
                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
                f"config with `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_requested_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
                f"into requested memory, or `--kv-cache-memory="
                f"{kv_cache_memory_bytes_to_gpu_limit}` "
                f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
                f"utilize gpu memory. Current kv cache memory in use is "
                f"{GiB(self.available_kv_cache_memory_bytes)} GiB.")

            logger.debug(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
        # fragmentation issue.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)

            # We skip EPLB here since we don't want to record dummy metrics
            hidden_states, last_hidden_states = \
                self.model_runner._dummy_run(
                    num_tokens=max_num_reqs,
                    skip_eplb=True,
                )
            if self.model_runner.is_pooling_model:
                self.model_runner._dummy_pooler_run(hidden_states)
            else:
                self.model_runner._dummy_sampler_run(
                    hidden_states=last_hidden_states)

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
419
+
420
    def get_model(self) -> nn.Module:
        """Return the loaded model module from the runner."""
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the loaded model supports."""
        return self.model_runner.get_supported_tasks()
425
+
426
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
        """Run one model step for the scheduled tokens.

        On non-first pipeline-parallel ranks, intermediate tensors are first
        received from the previous rank; non-last ranks forward their
        IntermediateTensors to the next rank and return None (or an empty
        output carrying KV-connector state).
        """
        intermediate_tensors = None
        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        num_input_tokens = self.model_runner._get_num_input_tokens(
            num_scheduled_tokens)
        # Whether the residual tensor needs all-gather depends on sequence
        # parallelism scattering for this token count.
        all_gather_tensors = {
            "residual":
            not is_residual_scattered_for_sp(self.vllm_config,
                                             num_input_tokens)
        }
        if forward_pass and not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group(),
                    all_gather_tensors=all_gather_tensors))

        output = self.model_runner.execute_model(scheduler_output,
                                                 intermediate_tensors)
        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
            return output

        # Non-last PP rank: forward intermediate tensors downstream.
        assert isinstance(output, IntermediateTensors)
        parallel_config = self.vllm_config.parallel_config
        assert parallel_config.distributed_executor_backend != (
            "external_launcher") and not get_pp_group().is_last_rank

        get_pp_group().send_tensor_dict(output.tensors,
                                        all_gather_group=get_tp_group(),
                                        all_gather_tensors=all_gather_tensors)

        kv_connector_output = output.kv_connector_output
        if not kv_connector_output:
            return None

        # In case of PP with kv transfer, we need to pass through the
        # kv_connector_output
        if (not kv_connector_output.finished_sending
                and not kv_connector_output.finished_recving):
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output
474
+
475
    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Return (and consume) draft token ids from the model runner."""
        return self.model_runner.take_draft_token_ids()
477
+
478
+ def profile(self, is_start: bool = True):
479
+ if self.profiler is None:
480
+ raise RuntimeError("Profiler is not enabled.")
481
+ if is_start:
482
+ self.profiler.start()
483
+ else:
484
+ self.profiler.stop()
485
+ # only print profiler results on rank 0
486
+ if self.local_rank == 0:
487
+ print(self.profiler.key_averages().table(
488
+ sort_by="self_cuda_time_total"))
489
+
490
    def execute_dummy_batch(self) -> None:
        """Run a single-token dummy decode step (e.g. to keep ranks in
        lockstep)."""
        self.model_runner._dummy_run(1, uniform_decode=True)
492
+
493
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the model runner."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter by id."""
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of all registered LoRA adapters."""
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter so it stays resident."""
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return
508
+
509
    def _eplb_before_scale_down(self, old_ep_size: int,
                                new_ep_size: int) -> None:
        """Reshard experts onto the surviving EP ranks before scaling down.

        Ranks >= new_ep_size are mapped to -1 (dropped); the EPLB state then
        shuffles experts accordingly and we synchronize before returning.
        """
        from vllm.distributed.parallel_state import get_ep_group
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding "
                        "before scaling down...")
        rank_mapping = {
            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(self.model_runner.model,
                                               execute_shuffle=True,
                                               global_expert_load=None,
                                               rank_mapping=rank_mapping)
        torch.cuda.synchronize()
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")
527
+
528
    def _eplb_after_scale_up(
            self, old_ep_size: int, new_ep_size: int,
            global_expert_load: Optional[torch.Tensor]) -> None:
        """Reshard experts across the enlarged EP group after scaling up.

        Existing ranks keep their ids (identity mapping); the EPLB state
        redistributes experts using the provided global expert load.
        """
        from vllm.distributed.parallel_state import get_ep_group
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Starting expert resharding "
                        "after scaling up...")
        rank_mapping = {
            old_ep_rank: old_ep_rank
            for old_ep_rank in range(old_ep_size)
        }
        assert self.model_runner.eplb_state is not None
        self.model_runner.eplb_state.rearrange(
            self.model_runner.model,
            execute_shuffle=True,
            global_expert_load=global_expert_load,
            rank_mapping=rank_mapping)
        if get_ep_group().rank == 0:
            logger.info("[Elastic EP] Expert resharding completed!")
547
+
548
def _reconfigure_parallel_config(
        self, reconfig_request: ReconfigureDistributedRequest) -> None:
    """Copy the fields of ``reconfig_request`` into the parallel config.

    The data-parallel rank fields are only overwritten when the request
    does not ask to keep the current rank.
    """
    cfg = self.vllm_config.parallel_config
    req = reconfig_request
    keep = ReconfigureRankType.KEEP_CURRENT_RANK

    cfg.data_parallel_size = req.new_data_parallel_size
    if req.new_data_parallel_rank != keep:
        cfg.data_parallel_rank = req.new_data_parallel_rank
    if req.new_data_parallel_rank_local != keep:
        cfg.data_parallel_rank_local = req.new_data_parallel_rank_local
    cfg.data_parallel_master_ip = req.new_data_parallel_master_ip
    cfg.data_parallel_master_port = req.new_data_parallel_master_port
568
+
569
def _reconfigure_moe(self, old_ep_size: int,
                     new_ep_size: int) -> Optional[torch.Tensor]:
    """
    Reconfigure MoE modules with provided reconfig_request

    Return the global expert load if new_ep_size > old_ep_size,
    otherwise None
    """
    from vllm.distributed.parallel_state import (
        get_dp_group, get_ep_group, prepare_communication_buffer_for_model)
    from vllm.model_executor.layers.fused_moe.layer import (
        FusedMoEParallelConfig)

    parallel_config = self.vllm_config.parallel_config
    # Collect every (Shared)FusedMoE layer; matching on the class name
    # avoids importing the layer classes at module scope.
    moe_modules = [
        module for module in self.model_runner.model.modules()
        if (module.__class__.__name__ == "FusedMoE"
            or module.__class__.__name__ == "SharedFusedMoE")
    ]
    num_local_experts = moe_modules[0].moe_config.num_local_experts
    assert all(module.moe_config.num_local_experts == num_local_experts
               for module in moe_modules), (
                   "All MoE modules must have the same number of experts")
    # Re-derive the global expert count and the parallel layout for every
    # MoE layer under the new EP world size.
    for module in moe_modules:
        module.moe_config.num_experts = num_local_experts * new_ep_size
        module.global_num_experts = module.moe_config.num_experts
        module.moe_parallel_config = FusedMoEParallelConfig.make(
            tp_size_=get_tp_group().world_size,
            dp_size_=get_dp_group().world_size,
            vllm_parallel_config=parallel_config,
        )
        module.moe_config.moe_parallel_config = module.moe_parallel_config
    if new_ep_size < old_ep_size:
        # Scale-down: the physical expert layout was already reshuffled in
        # `_eplb_before_scale_down`; only recompute redundancy counts.
        num_local_physical_experts = num_local_experts
        assert self.model_runner.eplb_state is not None
        new_physical_experts = \
            self.model_runner.eplb_state.physical_to_logical_map.shape[1]
        parallel_config.eplb_config.num_redundant_experts = (
            new_physical_experts -
            self.model_runner.eplb_state.logical_replica_count.shape[1])
        global_expert_load = None
    else:
        # Scale-up: agree on the per-rank physical expert count (EP rank 0
        # is the source of truth), then gather the current expert load
        # without shuffling yet — the shuffle happens after re-init in
        # `_eplb_after_scale_up`.
        num_local_physical_experts = torch.tensor([num_local_experts],
                                                  dtype=torch.int32,
                                                  device="cpu")
        torch.distributed.broadcast(num_local_physical_experts,
                                    group=get_ep_group().cpu_group,
                                    group_src=0)
        num_local_physical_experts = num_local_physical_experts.item()
        new_physical_experts = num_local_physical_experts * new_ep_size
        assert self.model_runner.eplb_state is not None
        global_expert_load = self.model_runner.eplb_state.rearrange(
            self.model_runner.model, execute_shuffle=False)
        parallel_config.eplb_config.num_redundant_experts = (
            new_physical_experts - global_expert_load.shape[1])
    # Communication buffers depend on the (possibly changed) expert layout.
    prepare_communication_buffer_for_model(self.model_runner.model)
    self.model_runner.model.update_physical_experts_metadata(
        num_physical_experts=new_physical_experts,
        num_local_physical_experts=num_local_physical_experts)
    return global_expert_load
629
+
630
def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest) -> None:
    """Tear down and re-create the distributed state for a new DP size.

    Ordering matters: experts are reshuffled onto surviving ranks *before*
    the old process groups are destroyed, and the EPLB rebalance for a
    scale-up runs only *after* the new groups exist.
    """
    from vllm.config import set_current_vllm_config
    from vllm.distributed.parallel_state import (
        cleanup_dist_env_and_memory, get_ep_group)

    old_ep_size = get_ep_group().world_size
    old_ep_rank = get_ep_group().rank
    # EP world size = DP x TP x PP.
    new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group(
    ).world_size * get_pp_group().world_size
    if new_ep_size < old_ep_size:
        self._eplb_before_scale_down(old_ep_size, new_ep_size)

    cleanup_dist_env_and_memory()

    if reconfig_request.new_data_parallel_rank == \
        ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
        # This rank does not survive the resize; its experts were already
        # handed off above, so it can simply exit.
        assert old_ep_rank >= new_ep_size
        # shutdown
        return

    self._reconfigure_parallel_config(reconfig_request)

    with set_current_vllm_config(self.vllm_config):
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)

    global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)

    if new_ep_size > old_ep_size:
        assert global_expert_load is not None
        self._eplb_after_scale_up(old_ep_size, new_ep_size,
                                  global_expert_load)
664
+
665
def save_sharded_state(
    self,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    """Save the model's weights as sharded state files under ``path``.

    Args:
        path: Target directory for the shard files.
        pattern: Optional filename pattern for the shards.
        max_size: Optional maximum size per shard — presumably bytes;
            TODO confirm against ShardedStateLoader.save_model.
    """
    from vllm.model_executor.model_loader import ShardedStateLoader
    ShardedStateLoader.save_model(
        self.model_runner.model,
        path,
        pattern=pattern,
        max_size=max_size,
    )
678
+
679
def save_tensorized_model(
    self,
    tensorizer_config: "TensorizerConfig",
) -> None:
    """Serialize the model with tensorizer, delegating to the model runner."""
    self.model_runner.save_tensorized_model(
        tensorizer_config=tensorizer_config, )
685
+
686
def shutdown(self) -> None:
    """Shut down KV-transfer resources, if a model runner was ever created.

    Uses ``getattr`` so the method is safe to call even when ``__init__``
    failed before ``model_runner`` was assigned.
    """
    runner = getattr(self, "model_runner", None)
    if runner:
        runner.ensure_kv_transfer_shutdown()
689
+
690
+
691
def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "nccl",
) -> None:
    """Initialize the distributed environment."""
    parallel_config = vllm_config.parallel_config
    # Honor the config switch for the custom all-reduce kernel before any
    # process group is created.
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank, backend)

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size,
        parallel_config.decode_context_parallel_size)

    # Set up KV-transfer (disaggregated serving), if configured.
    ensure_kv_transfer_initialized(vllm_config)
vllm_hacked/worker_base.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import os
5
+ from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar,
6
+ Union)
7
+
8
+ import cloudpickle
9
+ import torch.nn as nn
10
+
11
+ from vllm.config import VllmConfig, set_current_vllm_config
12
+ from vllm.logger import init_logger
13
+ from vllm.lora.request import LoRARequest
14
+ from vllm.sequence import ExecuteModelRequest
15
+ from vllm.utils import (enable_trace_function_call_for_thread,
16
+ resolve_obj_by_qualname, run_method,
17
+ update_environment_variables,
18
+ warn_for_unimplemented_methods)
19
+ from vllm.v1.outputs import SamplerOutput
20
+
21
+ logger = init_logger(__name__)
22
+
23
+ _R = TypeVar("_R")
24
+
25
+
26
@warn_for_unimplemented_methods
class WorkerBase:
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware. Also abstracts control plane communication, e.g., to
    communicate request metadata to other workers.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        # Cache the sub-configs as attributes so subclasses and callers can
        # reach them without going through vllm_config every time.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        self.kv_transfer_config = vllm_config.kv_transfer_config
        self.compilation_config = vllm_config.compilation_config
        # Lazy import keeps platform detection out of module import time.
        from vllm.platforms import current_platform
        self.current_platform = current_platform

    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

    def get_model(self) -> nn.Module:
        """Return the model hosted by this worker."""
        raise NotImplementedError

    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
        """Apply a function on the model inside this worker."""
        return fn(self.get_model())

    def load_model(self) -> None:
        """Load model onto target device."""
        raise NotImplementedError

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Run one model execution step; see concrete workers for semantics."""
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.

        You can stop the loop by executing a driver worker with an empty output.
        See `stop_remote_worker_execution_loop` for more details.
        """
        with self.current_platform.inference_mode():
            while True:
                output = self.execute_model(execute_model_req=None)
                if output is None:
                    return None

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter; implemented by concrete workers."""
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        """Unload a LoRA adapter; implemented by concrete workers."""
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter in memory; implemented by concrete workers."""
        raise NotImplementedError

    def list_loras(self) -> Set[int]:
        """Return ids of loaded LoRA adapters; implemented by concrete workers."""
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Get vocabulary size from model configuration."""
        return self.model_config.get_vocab_size()

    def shutdown(self) -> None:
        """Clean up resources held by the worker."""
        return
133
+
134
+
135
class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then, when we call `update_environment_variables`, and the
    real initialization happens in `init_worker`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rpc_rank: int = 0,
    ) -> None:
        """
        Initialize the worker wrapper with the given vllm_config and rpc_rank.
        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.
        """
        self.rpc_rank = rpc_rank
        self.worker: Optional[WorkerBase] = None
        self.vllm_config: Optional[VllmConfig] = None
        # do not store this `vllm_config`, `init_worker` will set the final
        # one. TODO: investigate if we can remove this field in
        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
        # unnecessary now.
        if vllm_config.model_config is not None:
            # it can be None in tests
            trust_remote_code = vllm_config.model_config.trust_remote_code
            if trust_remote_code:
                # note: lazy import to avoid importing torch before initializing
                from vllm.utils import init_cached_hf_modules
                init_cached_hf_modules()

    def shutdown(self) -> None:
        # Forward shutdown to the real worker, if it was ever created.
        if self.worker is not None:
            self.worker.shutdown()

    def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
        """
        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.
        """
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(self, envs_list: List[Dict[str,
                                                                str]]) -> None:
        # Each rank receives its own env dict; pick ours by rpc_rank.
        envs = envs_list[self.rpc_rank]
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
        kwargs = all_kwargs[self.rpc_rank]
        self.vllm_config = kwargs.get("vllm_config")
        assert self.vllm_config is not None, (
            "vllm_config is required to initialize the worker")
        enable_trace_function_call_for_thread(self.vllm_config)

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        # worker_cls may be a dotted path (preferred) or a cloudpickled
        # class object (deprecated).
        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
            worker_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_cls)
        else:
            logger.warning(
                "passing worker_cls as a class object is strongly deprecated,"
                " as the serialization of class objects can be tricky and"
                " error-prone. To be safe, please keep the class in a separate"
                " module and pass the qualified name of the class as a string."
            )
            assert isinstance(self.vllm_config.parallel_config.worker_cls,
                              bytes)
            worker_class = cloudpickle.loads(
                self.vllm_config.parallel_config.worker_cls)
        if self.vllm_config.parallel_config.worker_extension_cls:
            worker_extension_cls = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_extension_cls)
            extended_calls = []
            if worker_extension_cls not in worker_class.__bases__:
                # check any conflicts between worker and worker_extension_cls
                for attr in dir(worker_extension_cls):
                    if attr.startswith("__"):
                        continue
                    assert not hasattr(worker_class, attr), (
                        f"Worker class {worker_class} already has an attribute"
                        f" {attr}, which conflicts with the worker"
                        f" extension class {worker_extension_cls}.")
                    if callable(getattr(worker_extension_cls, attr)):
                        extended_calls.append(attr)
                # dynamically inherit the worker extension class
                worker_class.__bases__ = worker_class.__bases__ + (
                    worker_extension_cls, )
                logger.info(
                    "Injected %s into %s for extended collective_rpc calls %s",
                    worker_extension_cls, worker_class, extended_calls)
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)
        assert self.worker is not None

    def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
        # Pick this rank's KV cache config and forward it under the vLLM
        # config context.
        kv_cache_config = kv_cache_configs[self.rpc_rank]
        with set_current_vllm_config(self.vllm_config):
            self.worker.initialize_from_config(kv_cache_config)  # type: ignore

    def init_device(self):
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during device initialization
            self.worker.init_device()  # type: ignore

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        try:
            # method resolution order:
            # if a method is defined in this class, it will be called directly.
            # otherwise, since we define `__getattr__` and redirect attribute
            # query to `self.worker`, the method will be called on the worker.
            return run_method(self, method, args, kwargs)
        except Exception as e:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            raise e

    def __getattr__(self, attr):
        # Any attribute not found on the wrapper is looked up on the wrapped
        # worker, making the wrapper transparent for RPC calls.
        return getattr(self.worker, attr)
z_script.py DELETED
@@ -1,44 +0,0 @@
1
- from hmac import new
2
- import sys
3
- import os
4
- import argparse
5
- from safetensors.torch import save_file
6
-
7
- import time
8
- import json
9
- import torch
10
- import torchaudio
11
- import numpy as np
12
- from omegaconf import OmegaConf
13
- from codeclm.models import builders
14
- import gc
15
- from codeclm.trainer.codec_song_pl import CodecLM_PL
16
- from codeclm.models import CodecLM
17
- from third_party.demucs.models.pretrained import get_model_from_yaml
18
-
19
- cfg_path = "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/songgeneration_base/config.yaml"
20
- cfg = OmegaConf.load(cfg_path)
21
- cfg.mode = 'inference'
22
- # audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
23
- # model = audio_tokenizer.model.model
24
- # weights = {k: v.half() for k, v in model.state_dict().items() if isinstance(v, torch.Tensor) and v.numel() > 0}
25
- # save_file(weights, '/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/encoder_fp16.safetensors')
26
- # print(weights)
27
-
28
- # seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
29
- # model = seperate_tokenizer.model.model
30
- # weights = {}
31
- # for k, v in model.state_dict().items():
32
- # if k.startswith("rvq_bestrq_bgm_emb") or k.startswith("rvq_bestrq_emb") or k.startswith("bestrq"):
33
- # weights[k] = v.half()
34
- # else:
35
- # weights[k] = v
36
- # # weights = {k: v.half() for k, v in model.state_dict().items() if isinstance(v, torch.Tensor) and v.numel() > 0}
37
- # save_file(weights, '/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/encoder_fp16.safetensors')
38
- # print(weights.keys())
39
-
40
- ckpt_path = "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model_32.pt"
41
- # audiolm = builders.get_lm_model(cfg)
42
- checkpoint = torch.load(ckpt_path, map_location='cpu')
43
- audiolm_state_dict = {k: v.half() for k, v in checkpoint.items()}
44
- torch.save(audiolm_state_dict, "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model.pt")