from vllm import LLMEngine, EngineArgs, SamplingParams, TokensPrompt


class TokenGenerator:
    """Streams tokens from a vLLM engine, one request at a time."""

    def __init__(self, model_path: str, tensor_parallel_size: int = 1):
        args = EngineArgs(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.engine = LLMEngine.from_engine_args(args)
        self.request_id = 0

    def generate(
        self,
        prompt_tokens: list[int],
        stop_tokens: list[int] | None = None,
        temperature: float = 1.0,
        max_tokens: int = 0,
        return_logprobs: bool = False,
    ):
        # 0 is a sentinel for "no explicit cap": vLLM defaults max_tokens
        # to 16, so we pass None to lift the limit.
        if max_tokens == 0:
            max_tokens = None

        request_id = str(self.request_id)
        self.request_id += 1

        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            stop_token_ids=stop_tokens,
            # logprobs=0 requests the logprob of the sampled token only.
            logprobs=0 if return_logprobs else None,
        )
        prompt = TokensPrompt(prompt_token_ids=prompt_tokens)
        self.engine.add_request(request_id, prompt, sampling_params)

        emitted_token_ids: list[int] = []
        while self.engine.has_unfinished_requests():
            step_outputs = self.engine.step()
            # Only one request is in flight, so inspect the first output.
            output = step_outputs[0].outputs[0]
            token_ids = output.token_ids
            logprobs_list = output.logprobs  # None unless logprobs were requested

            # output.token_ids is cumulative across steps; skip the tokens
            # we have already yielded.
            new_token_ids = token_ids[len(emitted_token_ids):]
            new_logprobs = (
                logprobs_list[len(emitted_token_ids):]
                if logprobs_list is not None
                else [None] * len(new_token_ids)
            )

            for token_id, logprobs in zip(new_token_ids, new_logprobs):
                emitted_token_ids.append(token_id)
                if return_logprobs:
                    logprob_val = None
                    if logprobs is not None and token_id in logprobs:
                        logprob_val = logprobs[token_id].logprob
                    yield (token_id, logprob_val)
                else:
                    yield token_id
                # The engine also honors stop_token_ids, so the request will
                # finish on its own; this break just stops yielding early
                # within the current step.
                if stop_tokens is not None and token_id in stop_tokens:
                    break
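
# A minimal usage sketch, not part of the class above: the model path and the
# use of transformers.AutoTokenizer for encoding/decoding are assumptions for
# illustration, not requirements of TokenGenerator itself.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_path = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical model choice
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    generator = TokenGenerator(model_path)
    prompt_tokens = tokenizer.encode("The capital of France is")

    # Greedy decoding (temperature=0.0), capped at 16 new tokens.
    for token_id in generator.generate(prompt_tokens, temperature=0.0, max_tokens=16):
        print(tokenizer.decode([token_id]), end="", flush=True)
    print()

    # With return_logprobs=True, the generator yields (token_id, logprob) pairs.
    for token_id, logprob in generator.generate(
        prompt_tokens, temperature=0.0, max_tokens=4, return_logprobs=True
    ):
        print(token_id, logprob)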