File size: 5,221 Bytes
8c029ff
075fe02
8c029ff
 
 
 
 
625f637
8c029ff
 
6141415
8c029ff
6141415
 
8c029ff
 
39c5cd5
 
 
8c029ff
6141415
b7a23a9
8c029ff
b7a23a9
8c029ff
dc314e6
8c029ff
 
 
 
 
b7a23a9
 
43876e7
b7a23a9
 
 
 
 
 
 
43876e7
b7a23a9
 
 
 
43876e7
b7a23a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329bc40
 
b7a23a9
 
dc314e6
 
 
 
8c029ff
 
075fe02
d29591c
 
8c029ff
 
 
 
b7a23a9
329bc40
 
 
 
 
 
 
 
8c029ff
43876e7
b7a23a9
43876e7
 
 
b7a23a9
 
329bc40
 
8c029ff
b7a23a9
 
 
 
 
 
43876e7
8c029ff
b7a23a9
8c029ff
 
 
 
 
 
075fe02
8c029ff
dc314e6
8c029ff
 
 
 
b7a23a9
8c029ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c5913d
075fe02
 
8c029ff
 
 
 
b7a23a9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Markdown shown at the top of the chat UI.
DESCRIPTION = """\
# GRM2

GRM2 is Orion's latest iteration of powerful open LLMs.
This is a demo of [`OrionLLM/GRM2-3b`](https://huggingface.co/OrionLLM/GRM2-3b), fine-tuned for long reasoning for general reasoning tasks.
"""

# Hard upper bound (and initial value) for the "Max new tokens" slider.
MAX_NEW_TOKENS_LIMIT = 262144
DEFAULT_MAX_NEW_TOKENS = 262144
# Prompts longer than this many tokens are rejected before generation;
# can be lowered/raised via the MAX_INPUT_TOKENS environment variable.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "262144"))

# Hugging Face Hub repository the demo serves.
MODEL_ID = "OrionLLM/GRM2-3b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# bfloat16 weights with automatic device placement (GPU when available).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,
)
# Inference only: disable dropout and other training-mode behavior.
model.eval()


@spaces.GPU(duration=90)
def _generate_on_gpu(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    """Stream sampled text from the model, yielding the full text so far on each step.

    Generation runs in a background thread so this generator can consume the
    streamer incrementally. Any exception raised by ``model.generate`` is
    captured and re-raised as ``gr.Error`` after the stream ends.
    """
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    errors: list[Exception] = []

    def _worker() -> None:
        # Capture failures instead of letting them die silently in the thread.
        try:
            model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                num_beams=1,
                repetition_penalty=repetition_penalty,
            )
        except Exception as exc:  # noqa: BLE001
            errors.append(exc)

    worker = Thread(target=_worker)
    worker.start()

    # Yield the accumulated text (not just the delta) so the UI shows the
    # whole reply as it grows.
    accumulated = ""
    for piece in streamer:
        accumulated += piece
        yield accumulated

    worker.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}")


def validate_input(message: str) -> dict:
    """Block submission of empty or whitespace-only messages."""
    has_content = bool(message) and bool(message.strip())
    return gr.validate(has_content, "Please enter a message.")


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 32768,
    temperature: float = 1.0,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Tokenize the conversation, enforce the context budget, and stream a reply.

    Raises ``gr.Error`` when the prompt exceeds ``MAX_INPUT_TOKENS`` or leaves
    no room for new tokens.
    """

    def _as_text(content) -> str:
        # Multimodal history entries store a list of typed parts; keep text parts only.
        if isinstance(content, list):
            return "".join(part["text"] for part in content if part["type"] == "text")
        return str(content)

    conversation = [
        {"role": entry["role"], "content": _as_text(entry["content"])}
        for entry in chat_history
    ]
    conversation.append({"role": "user", "content": message})

    encoded = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    prompt_tokens = encoded.input_ids.shape[1]
    if prompt_tokens > MAX_INPUT_TOKENS:
        raise gr.Error(
            f"Input too long ({prompt_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
        )

    # Clamp the request so prompt + completion stays inside the context window.
    budget = min(max_new_tokens, MAX_INPUT_TOKENS - prompt_tokens)
    if budget <= 0:
        raise gr.Error("Input uses the entire context window. No room to generate new tokens.")

    yield from _generate_on_gpu(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=budget,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )


# Chat UI wiring: `generate` streams replies; `validate_input` gates empty messages.
demo = gr.ChatInterface(
    fn=generate,
    validator=validate_input,
    # Each slider maps positionally onto generate()'s sampling parameters.
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS_LIMIT,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            # NOTE(review): slider default (0.6) differs from generate()'s
            # signature default (1.0); the slider value is what the UI sends.
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    # Examples are not pre-generated: running them requires the GPU worker.
    cache_examples=False,
    description=DESCRIPTION,
    fill_height=True,
)


if __name__ == "__main__":
    # Serve the UI with the project's stylesheet (resolved relative to the CWD).
    demo.launch(css_paths="style.css")