Spaces:
Running on Zero
Running on Zero
File size: 5,221 Bytes
8c029ff 075fe02 8c029ff 625f637 8c029ff 6141415 8c029ff 6141415 8c029ff 39c5cd5 8c029ff 6141415 b7a23a9 8c029ff b7a23a9 8c029ff dc314e6 8c029ff b7a23a9 43876e7 b7a23a9 43876e7 b7a23a9 43876e7 b7a23a9 329bc40 b7a23a9 dc314e6 8c029ff 075fe02 d29591c 8c029ff b7a23a9 329bc40 8c029ff 43876e7 b7a23a9 43876e7 b7a23a9 329bc40 8c029ff b7a23a9 43876e7 8c029ff b7a23a9 8c029ff 075fe02 8c029ff dc314e6 8c029ff b7a23a9 8c029ff 3c5913d 075fe02 8c029ff b7a23a9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 | import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Markdown rendered as the description of the chat interface.
# NOTE: this is user-facing text; the spelling of "powerful" was corrected.
DESCRIPTION = """\
# GRM2
GRM2 is Orion's latest iteration of powerful open LLMs.
This is a demo of [`OrionLLM/GRM2-3b`](https://huggingface.co/OrionLLM/GRM2-3b), fine-tuned for long reasoning for general reasoning tasks.
"""
# Upper bound and initial position of the "Max new tokens" slider.
MAX_NEW_TOKENS_LIMIT = 262_144
DEFAULT_MAX_NEW_TOKENS = 262_144
# Token budget shared by the prompt and the generated continuation; it can be
# overridden through the MAX_INPUT_TOKENS environment variable.
MAX_INPUT_TOKENS = int(os.environ.get("MAX_INPUT_TOKENS", "262144"))
# Identifier handed to `from_pretrained` for both the tokenizer and the model.
MODEL_ID = "OrionLLM/GRM2-3b"
# Tokenizer and model are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Weights are loaded in bfloat16 with device placement delegated to
# `device_map="auto"`; `.eval()` returns the module itself, so chaining it
# leaves `model` bound to the loaded model switched to inference mode.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype=torch.bfloat16).eval()
@spaces.GPU(duration=90)
def _generate_on_gpu(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    """Run sampling-based generation and stream the accumulated output text.

    `model.generate` runs in a background thread and feeds a
    `TextIteratorStreamer`; this generator consumes the streamer and, after
    every new piece of text, yields the whole response produced so far.

    NOTE(review): the `spaces.GPU(duration=90)` decorator presumably scopes a
    ZeroGPU allocation of up to 90 seconds to each call -- confirm against the
    `spaces` package documentation.

    Args:
        input_ids: Tokenized prompt; moved to the model's device here.
        attention_mask: Attention mask matching `input_ids`; moved likewise.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty forwarded to `model.generate`.

    Yields:
        The concatenation of all text chunks received so far.

    Raises:
        gr.Error: If `model.generate` raised inside the worker thread.
    """
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    # skip_prompt / skip_special_tokens: only newly generated, human-readable
    # text is streamed. The timeout bounds how long the consumer waits for the
    # next chunk before the streamer's queue raises.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }

    # Exceptions cannot cross the thread boundary on their own, so the worker
    # parks them here for the consumer to re-raise.
    exception_holder: list[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generate_kwargs)
        except Exception as e:  # noqa: BLE001
            exception_holder.append(e)
            # generate() died before signalling end-of-stream. Without this the
            # consumer loop below would block until the streamer timeout and
            # fail with queue.Empty, hiding the real error stored above.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    chunks: list[str] = []
    for text in streamer:
        chunks.append(text)
        yield "".join(chunks)

    thread.join()
    if exception_holder:
        error_msg = f"Generation failed: {exception_holder[0]}"
        raise gr.Error(error_msg) from exception_holder[0]
def validate_input(message: str) -> dict:
    """Check that the submitted chat message contains non-whitespace text.

    Returns whatever `gr.validate` produces for the check, using the message
    "Please enter a message." for the failing case.
    """
    has_text = bool(message and message.strip())
    return gr.validate(has_text, "Please enter a message.")
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 32768,
    temperature: float = 1.0,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Build the prompt from the chat history and stream the model's reply.

    Args:
        message: The new user message.
        chat_history: Earlier turns as dicts with "role" and "content" keys;
            "content" is either a list of parts (only parts whose "type" is
            "text" are used) or any other value, which is stringified.
        max_new_tokens: Requested generation length; reduced so that prompt
            plus generation fits within MAX_INPUT_TOKENS.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty forwarded to the generation call.

    Yields:
        Whatever `_generate_on_gpu` yields for the tokenized prompt.

    Raises:
        gr.Error: If the prompt exceeds MAX_INPUT_TOKENS, or if it leaves no
            room for any new token.
    """

    def _as_text(content: object) -> str:
        # Multi-part content: keep only the text parts, joined in order.
        if not isinstance(content, list):
            return str(content)
        return "".join(part["text"] for part in content if part["type"] == "text")

    conversation = []
    for turn in chat_history:
        flattened = _as_text(turn["content"])
        conversation.append({"role": turn["role"], "content": flattened})
    conversation.append({"role": "user", "content": message})

    encoded = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = encoded.input_ids
    attention_mask = encoded.attention_mask

    prompt_len = input_ids.shape[1]
    if prompt_len > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input too long ({prompt_len} tokens). Maximum is {MAX_INPUT_TOKENS} tokens.")

    # Whatever the prompt did not consume is the most that may be generated.
    remaining = MAX_INPUT_TOKENS - prompt_len
    max_new_tokens = min(max_new_tokens, remaining)
    if max_new_tokens <= 0:
        raise gr.Error("Input uses the entire context window. No room to generate new tokens.")

    yield from _generate_on_gpu(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
# (label, minimum, maximum, step, initial value) for each extra control. The
# order matters: the sliders are handed to `generate` positionally after the
# message and the history, i.e. as max_new_tokens, temperature, top_p, top_k,
# repetition_penalty.
_SLIDER_SPECS = (
    ("Max new tokens", 1, MAX_NEW_TOKENS_LIMIT, 1, DEFAULT_MAX_NEW_TOKENS),
    ("Temperature", 0.1, 4.0, 0.1, 0.6),
    ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
    ("Top-k", 1, 1000, 1, 50),
    ("Repetition penalty", 1.0, 2.0, 0.05, 1.2),
)

# One-message example prompts offered in the UI.
_EXAMPLE_PROMPTS = [
    ["Hello there! How are you doing?"],
    ["Can you explain briefly to me what is the Python programming language?"],
    ["Explain the plot of Cinderella in a sentence."],
    ["How many hours does it take a man to eat a Helicopter?"],
    ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
]

# Chat UI: `generate` produces the replies, `validate_input` checks each
# submitted message, and the sliders expose the sampling parameters.
demo = gr.ChatInterface(
    fn=generate,
    validator=validate_input,
    additional_inputs=[
        gr.Slider(label=label, minimum=low, maximum=high, step=step, value=initial)
        for label, low, high, step, initial in _SLIDER_SPECS
    ],
    examples=_EXAMPLE_PROMPTS,
    cache_examples=False,
    description=DESCRIPTION,
    fill_height=True,
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(css_paths="style.css")
|