root commited on
Commit
69ad7bf
·
1 Parent(s): fe9d2df

compatible with L40

Browse files
levo_inference.py CHANGED
@@ -48,7 +48,7 @@ class LeVoInference(torch.nn.Module):
48
  enforce_eager=True,
49
  dtype="bfloat16",
50
  gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
51
- max_num_seqs=32,
52
  tokenizer=None,
53
  skip_tokenizer_init=True,
54
  enable_prompt_embeds=True,
 
48
  enforce_eager=True,
49
  dtype="bfloat16",
50
  gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
51
+ max_num_seqs=4,
52
  tokenizer=None,
53
  skip_tokenizer_init=True,
54
  enable_prompt_embeds=True,
vllm_hacked/v1/sample/sampler.py CHANGED
@@ -205,6 +205,13 @@ class Sampler(nn.Module):
205
  The various logits processing functions called in this method
206
  may update the logits tensor in-place.
207
  """
 
 
 
 
 
 
 
208
 
209
  assert not (sampling_metadata.all_greedy
210
  and sampling_metadata.all_random)
@@ -223,17 +230,10 @@ class Sampler(nn.Module):
223
 
224
  assert sampling_metadata.temperature is not None
225
 
226
- print("logits.shape:", logits.shape)
227
  # Apply temperature.
228
  logits = self.apply_temperature(logits, sampling_metadata.temperature,
229
  sampling_metadata.all_random)
230
 
231
- if logits.dim() == 1:
232
- logits = logits.unsqueeze(0)
233
- if logits.size(0) != sampling_metadata.top_k.size(0):
234
- target_batch = sampling_metadata.top_k.size(0)
235
- logits = logits.expand(target_batch, -1).contiguous()
236
-
237
  # Apply logits processors that only apply to random sampling
238
  # (argmax invariant)
239
  for processor in sampling_metadata.logitsprocs.argmax_invariant:
 
205
  The various logits processing functions called in this method
206
  may update the logits tensor in-place.
207
  """
208
+ target_batch = sampling_metadata.top_k.size(0)
209
+ actual_batch = logits.size(0) if logits.dim() > 1 else 1
210
+
211
+ if actual_batch != target_batch:
212
+ if logits.dim() == 1:
213
+ logits = logits.unsqueeze(0)
214
+ logits = logits[0:1, :].expand(target_batch, -1).contiguous()
215
 
216
  assert not (sampling_metadata.all_greedy
217
  and sampling_metadata.all_random)
 
230
 
231
  assert sampling_metadata.temperature is not None
232
 
 
233
  # Apply temperature.
234
  logits = self.apply_temperature(logits, sampling_metadata.temperature,
235
  sampling_metadata.all_random)
236
 
 
 
 
 
 
 
237
  # Apply logits processors that only apply to random sampling
238
  # (argmax invariant)
239
  for processor in sampling_metadata.logitsprocs.argmax_invariant: