import os
import time

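# Enable hf_transfer for faster Hugging Face Hub downloads; this must be set
# before huggingface_hub (imported transitively by vllm) reads its configuration.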
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from vllm import LLM, SamplingParams
import torch
from cog import BasePredictor, Input, ConcatenateIterator
import typing as t


MODEL_ID = "TheBloke/Mistral-7B-OpenOrca-AWQ"
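# ChatML-style prompt template expected by Mistral-7B-OpenOrca; the caller's
# prompt is substituted for the {prompt} placeholder.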
PROMPT_TEMPLATE = """\
<|im_start|>system
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!
<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""

DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_P = 0.95
DEFAULT_TOP_K = 50
DEFAULT_PRESENCE_PENALTY = 0.0  # 1.15
DEFAULT_FREQUENCY_PENALTY = 0.0  # 0.2


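# Streaming helper bound to a vLLM engine (passed as `self`): it registers a single
# request and, after each engine step, yields the text generated so far together
# with the running token count.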
def vllm_generate_iterator(
    self, prompt: str, /, *, echo: bool = False, stop: t.Optional[t.Union[str, t.List[str]]] = None,
    stop_token_ids: t.Optional[t.List[int]] = None, sampling_params: t.Optional[SamplingParams] = None, **attrs: t.Any
) -> t.Iterator[t.Dict[str, t.Any]]:
    request_id: t.Optional[str] = attrs.pop('request_id', None)
    if request_id is None: raise ValueError('request_id must not be None.')
    if stop_token_ids is None: stop_token_ids = []
    stop_token_ids.append(self.tokenizer.eos_token_id)
    # Collect stop strings; currently only consumed by the commented-out config path below.
    stop_ = set()
    if isinstance(stop, str) and stop != '': stop_.add(stop)
    elif isinstance(stop, list) and stop != []: stop_.update(stop)
    for tid in stop_token_ids:
        if tid: stop_.add(self.tokenizer.decode(tid))

    # if self.config['temperature'] <= 1e-5: top_p = 1.0
    # else: top_p = self.config['top_p']
    # config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
    self.add_request(request_id=request_id, prompt=prompt, sampling_params=sampling_params)

    # Step the engine until the request completes, yielding the accumulated
    # output text (not a delta) after every step.
    while self.has_unfinished_requests():
        for request_output in self.step():
            for output in request_output.outputs:
                text = output.text
                yield {'text': text, 'error_code': 0, 'num_tokens': len(output.token_ids)}

            if request_output.finished: break


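# Cog predictor: loads the model once in setup() and streams incremental text from predict().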
class Predictor(BasePredictor):

    def setup(self):
        # Load the AWQ-quantized weights into a vLLM engine once at startup so that
        # each prediction only pays for generation, not model loading.
        self.llm = LLM(
            model=MODEL_ID,
            quantization="awq",
            dtype="float16",
        )

    def predict(
        self,
        prompt: str,
        max_new_tokens: int = Input(
            description="The maximum number of tokens the model should generate as output.",
            default=DEFAULT_MAX_NEW_TOKENS,
        ),
        temperature: float = Input(
            description="The value used to modulate the next token probabilities.", default=DEFAULT_TEMPERATURE
        ),
        top_p: float = Input(
            description="A probability threshold for generating the output. If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).",
            default=DEFAULT_TOP_P,
        ),
        top_k: int = Input(
            description="The number of highest probability tokens to consider for generating the output. If > 0, only keep the top k tokens with highest probability (top-k filtering).",
            default=DEFAULT_TOP_K,
        ),
        presence_penalty: float = Input(
            description="Presence penalty",
            default=DEFAULT_PRESENCE_PENALTY,
        ),
        frequency_penalty: float = Input(
            description="Frequency penalty",
            default=DEFAULT_FREQUENCY_PENALTY,
        ),
        prompt_template: str = Input(
            description="The template used to format the prompt. The input prompt is inserted into the template using the `{prompt}` placeholder.",
            default=PROMPT_TEMPLATE,
        )
    ) -> ConcatenateIterator:
        prompts = [
            (
                prompt_template.format(prompt=prompt),
                SamplingParams(
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty
                )
            )
        ]
        start = time.time()
        while True:
            if prompts:
                prompt, sampling_params = prompts.pop(0)
                gen = vllm_generate_iterator(
                    self.llm.llm_engine,
                    prompt,
                    echo=False,
                    stop=None,
                    stop_token_ids=None,
                    sampling_params=sampling_params,
                    request_id="0",
                )
                # The iterator yields the full text generated so far; emit only the
                # newly generated suffix so the output streams incrementally.
                last = ""
                num_tokens = 0
                for x in gen:
                    if x["text"] == "":
                        continue
                    yield x["text"][len(last):]
                    last = x["text"]
                    num_tokens = x["num_tokens"]
                print(f"\nGenerated {num_tokens} tokens in {time.time() - start:.2f} seconds.")

                if not (self.llm.llm_engine.has_unfinished_requests() or prompts):
                    break


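# Run this module directly for a quick local smoke test that streams a sample generation to stdout.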
if __name__ == '__main__':
    import sys
    p = Predictor()
    p.setup()
    gen = p.predict(
        "Write me an itinerary for my dog's birthday party.",
        512,
        0.8,
        0.95,
        50,
        1.0,
        0.2,
        PROMPT_TEMPLATE,
    )
    for out in gen:
        print(out, end="")
        sys.stdout.flush()