Spaces:
Runtime error
Runtime error
root commited on
Commit ·
69ad7bf
1
Parent(s): fe9d2df
compatible with L40
Browse files- levo_inference.py +1 -1
- vllm_hacked/v1/sample/sampler.py +7 -7
levo_inference.py
CHANGED
|
@@ -48,7 +48,7 @@ class LeVoInference(torch.nn.Module):
|
|
| 48 |
enforce_eager=True,
|
| 49 |
dtype="bfloat16",
|
| 50 |
gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
|
| 51 |
-
max_num_seqs=
|
| 52 |
tokenizer=None,
|
| 53 |
skip_tokenizer_init=True,
|
| 54 |
enable_prompt_embeds=True,
|
|
|
|
| 48 |
enforce_eager=True,
|
| 49 |
dtype="bfloat16",
|
| 50 |
gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
|
| 51 |
+
max_num_seqs=4,
|
| 52 |
tokenizer=None,
|
| 53 |
skip_tokenizer_init=True,
|
| 54 |
enable_prompt_embeds=True,
|
vllm_hacked/v1/sample/sampler.py
CHANGED
|
@@ -205,6 +205,13 @@ class Sampler(nn.Module):
|
|
| 205 |
The various logits processing functions called in this method
|
| 206 |
may update the logits tensor in-place.
|
| 207 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
assert not (sampling_metadata.all_greedy
|
| 210 |
and sampling_metadata.all_random)
|
|
@@ -223,17 +230,10 @@ class Sampler(nn.Module):
|
|
| 223 |
|
| 224 |
assert sampling_metadata.temperature is not None
|
| 225 |
|
| 226 |
-
print("logits.shape:", logits.shape)
|
| 227 |
# Apply temperature.
|
| 228 |
logits = self.apply_temperature(logits, sampling_metadata.temperature,
|
| 229 |
sampling_metadata.all_random)
|
| 230 |
|
| 231 |
-
if logits.dim() == 1:
|
| 232 |
-
logits = logits.unsqueeze(0)
|
| 233 |
-
if logits.size(0) != sampling_metadata.top_k.size(0):
|
| 234 |
-
target_batch = sampling_metadata.top_k.size(0)
|
| 235 |
-
logits = logits.expand(target_batch, -1).contiguous()
|
| 236 |
-
|
| 237 |
# Apply logits processors that only apply to random sampling
|
| 238 |
# (argmax invariant)
|
| 239 |
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
|
|
|
| 205 |
The various logits processing functions called in this method
|
| 206 |
may update the logits tensor in-place.
|
| 207 |
"""
|
| 208 |
+
target_batch = sampling_metadata.top_k.size(0)
|
| 209 |
+
actual_batch = logits.size(0) if logits.dim() > 1 else 1
|
| 210 |
+
|
| 211 |
+
if actual_batch != target_batch:
|
| 212 |
+
if logits.dim() == 1:
|
| 213 |
+
logits = logits.unsqueeze(0)
|
| 214 |
+
logits = logits[0:1, :].expand(target_batch, -1).contiguous()
|
| 215 |
|
| 216 |
assert not (sampling_metadata.all_greedy
|
| 217 |
and sampling_metadata.all_random)
|
|
|
|
| 230 |
|
| 231 |
assert sampling_metadata.temperature is not None
|
| 232 |
|
|
|
|
| 233 |
# Apply temperature.
|
| 234 |
logits = self.apply_temperature(logits, sampling_metadata.temperature,
|
| 235 |
sampling_metadata.all_random)
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
# Apply logits processors that only apply to random sampling
|
| 238 |
# (argmax invariant)
|
| 239 |
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|