File size: 2,225 Bytes
41a3927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from vllm import LLMEngine, EngineArgs, SamplingParams, TokensPrompt


class TokenGenerator:
    """Stream sampled token ids from a vLLM ``LLMEngine``, one token at a time.

    Each call to :meth:`generate` submits a single request to the engine and
    drives ``engine.step()`` until that request completes, yielding tokens as
    soon as they appear in the engine's output.
    """

    def __init__(self, model_path: str, tensor_parallel_size: int = 1):
        """Create the underlying engine.

        Args:
            model_path: Passed to ``EngineArgs`` as ``model``.
            tensor_parallel_size: Passed to ``EngineArgs`` unchanged.
        """
        args = EngineArgs(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.engine = LLMEngine.from_engine_args(args)
        # Monotonic counter used to mint a distinct request id per generate().
        self.request_id = 0

    def generate(self,
                 prompt_tokens: list[int],
                 stop_tokens: list[int] | None = None,
                 temperature: float = 1.0,
                 max_tokens: int = 0,
                 return_logprobs: bool = False):
        """Yield newly sampled tokens for ``prompt_tokens``.

        Args:
            prompt_tokens: Token ids of the prompt.
            stop_tokens: Token ids that end generation. A stop token is still
                yielded before the generator finishes.
            temperature: Sampling temperature forwarded to ``SamplingParams``.
            max_tokens: Maximum number of tokens to sample; ``0`` is forwarded
                as ``None`` (no explicit limit set by this wrapper).
            return_logprobs: When true, yield ``(token_id, logprob)`` tuples
                instead of bare token ids. ``logprob`` is ``None`` when the
                engine reported no entry for the sampled token.

        Yields:
            ``int`` token ids, or ``(int, float | None)`` tuples when
            ``return_logprobs`` is set.
        """
        request_id = str(self.request_id)
        self.request_id += 1

        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=None if max_tokens == 0 else max_tokens,
            stop_token_ids=stop_tokens,
            logprobs=0 if return_logprobs else None,
        )
        prompt = TokensPrompt(prompt_token_ids=prompt_tokens)
        self.engine.add_request(request_id, prompt, sampling_params)

        # Built once so the per-token membership test is O(1).
        stop_ids = frozenset(stop_tokens) if stop_tokens is not None else frozenset()
        emitted = 0  # how many tokens of this request were already yielded
        finished = False
        try:
            while not finished and self.engine.has_unfinished_requests():
                # A step may return nothing, or outputs for other requests;
                # pick out the one that belongs to this call.
                request_output = next(
                    (out for out in self.engine.step()
                     if out.request_id == request_id),
                    None,
                )
                if request_output is None:
                    continue
                finished = request_output.finished
                if not request_output.outputs:
                    continue

                completion = request_output.outputs[0]
                # The slicing below treats token_ids/logprobs as cumulative
                # over the whole request, so only the unseen tail is new.
                new_token_ids = completion.token_ids[emitted:]
                logprobs_list = getattr(completion, "logprobs", None)
                if logprobs_list is None:
                    new_logprobs = [None] * len(new_token_ids)
                else:
                    new_logprobs = logprobs_list[emitted:]

                for token_id, logprobs in zip(new_token_ids, new_logprobs):
                    emitted += 1
                    if return_logprobs:
                        logprob_val = None
                        if logprobs is not None and token_id in logprobs:
                            logprob_val = logprobs[token_id].logprob
                        yield (token_id, logprob_val)
                    else:
                        yield token_id
                    if token_id in stop_ids:
                        # End the whole generation, not just this step's batch.
                        return
        finally:
            # Runs on normal exit, on a stop token, and when the consumer
            # closes the generator early. An unfinished request would
            # otherwise stay queued and leak into the next generate() call.
            if not finished:
                self.engine.abort_request([request_id])