File size: 2,225 Bytes
41a3927 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 | from vllm import LLMEngine, EngineArgs, SamplingParams, TokensPrompt
class TokenGenerator:
    """Stream generated tokens from a vLLM ``LLMEngine`` for token-id prompts.

    Each call to :meth:`generate` submits one request to the engine and drives
    the engine synchronously with ``step()``, yielding tokens as they appear.
    """

    def __init__(self, model_path: str, tensor_parallel_size: int = 1):
        """Create the underlying engine.

        Args:
            model_path: Forwarded to ``EngineArgs(model=...)``.
            tensor_parallel_size: Forwarded to ``EngineArgs``.
        """
        args = EngineArgs(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.engine = LLMEngine.from_engine_args(args)
        # Counter used to mint a distinct string request id per generate() call.
        self.request_id = 0

    def generate(self,
                 prompt_tokens: list[int],
                 stop_tokens: list[int] | None = None,
                 temperature: float = 1.0,
                 max_tokens: int = 0,
                 return_logprobs: bool = False):
        """Generate a continuation of ``prompt_tokens``, yielding tokens as produced.

        Args:
            prompt_tokens: Prompt as a list of token ids.
            stop_tokens: Token ids that end generation. They are passed to
                vLLM as ``stop_token_ids``. A generated stop token is still
                yielded, and then the generator stops.
            temperature: Sampling temperature forwarded to ``SamplingParams``.
            max_tokens: Maximum number of new tokens. ``0`` is translated to
                ``None``, so no explicit cap is passed to ``SamplingParams``.
            return_logprobs: If true, yield ``(token_id, logprob)`` tuples.
                ``logprob`` is ``None`` when the engine did not report one for
                that token. If false, yield bare token ids.

        Yields:
            ``int`` token ids, or ``(int, float | None)`` tuples when
            ``return_logprobs`` is true.

        If the caller stops iterating early, or an exception propagates, the
        request is aborted in the engine. That keeps it from lingering and
        being stepped by later calls.
        """
        if max_tokens == 0:
            max_tokens = None
        request_id = str(self.request_id)
        self.request_id += 1
        sampling_params = SamplingParams(temperature=temperature,
                                         max_tokens=max_tokens,
                                         stop_token_ids=stop_tokens,
                                         # 0 => only the sampled token's logprob.
                                         logprobs=0 if return_logprobs else None)
        prompt = TokensPrompt(prompt_token_ids=prompt_tokens)
        self.engine.add_request(request_id, prompt, sampling_params)

        # Built once so the per-token membership test is O(1).
        stop_set = frozenset(stop_tokens) if stop_tokens else frozenset()
        # Number of tokens already yielded. The engine's token_ids are
        # cumulative, so this marks where the new tokens start.
        num_emitted = 0
        # Set from the engine's own `finished` flag. It decides whether the
        # request still needs aborting on exit.
        finished = False
        try:
            while self.engine.has_unfinished_requests():
                # step() may return outputs for other requests, or none at all
                # for this one. Only look at this request's output.
                ours = [out for out in self.engine.step()
                        if out.request_id == request_id]
                if not ours:
                    continue
                request_output = ours[-1]
                finished = request_output.finished

                completion = request_output.outputs[0]
                token_ids = completion.token_ids
                logprobs_list = getattr(completion, "logprobs", None)
                new_token_ids = token_ids[num_emitted:]
                if logprobs_list is not None:
                    new_logprobs = logprobs_list[num_emitted:]
                else:
                    new_logprobs = [None] * len(new_token_ids)

                for token_id, logprobs in zip(new_token_ids, new_logprobs):
                    num_emitted += 1
                    if return_logprobs:
                        logprob_val = None
                        if logprobs is not None and token_id in logprobs:
                            logprob_val = logprobs[token_id].logprob
                        yield (token_id, logprob_val)
                    else:
                        yield token_id
                    if token_id in stop_set:
                        # Stop generation entirely, not just this step's batch.
                        return

                if finished:
                    return
        finally:
            # Covers early generator close (GeneratorExit) and errors. Without
            # this, an abandoned request would stay in the engine and be
            # stepped by subsequent generate() calls.
            if not finished:
                self.engine.abort_request(request_id)