# GRM2-Chat / app.py
# (Hugging Face Space file-viewer residue preserved as a comment:
#  uploader DedeProGames — commit 39c5cd5 "Update app.py", verified)
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """\
# GRM2
GRM2 is Orion's latest iteration of powerful open LLMs.
This is a demo of [`OrionLLM/GRM2-3b`](https://huggingface.co/OrionLLM/GRM2-3b), fine-tuned for long reasoning for general reasoning tasks.
"""

# Ceiling (and default) for the "Max new tokens" slider.
MAX_NEW_TOKENS_LIMIT = 262144
DEFAULT_MAX_NEW_TOKENS = 262144
# Total context budget shared by prompt + completion; overridable via env var.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "262144"))

MODEL_ID = "OrionLLM/GRM2-3b"

# Load tokenizer and model once at import time. device_map="auto" lets
# accelerate place the weights on whatever device(s) are available.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,
)
model.eval()  # inference only — disables dropout etc.
@spaces.GPU(duration=90)
def _generate_on_gpu(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    """Run sampling generation on the GPU, streaming the reply incrementally.

    Yields the *cumulative* response text after each decoded chunk (not
    deltas), so the chat UI can simply replace the displayed message.
    Raises gr.Error if the background generation thread raised.
    """
    # Move inputs to wherever device_map="auto" placed the model.
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)
    # skip_prompt keeps the echoed input out of the stream; timeout bounds how
    # long the consumer loop below blocks waiting for the next chunk.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # model.generate() blocks, so it runs in a worker thread while this
    # generator consumes the streamer; any exception is captured and
    # re-raised only after the stream (and thread) have finished.
    exception_holder: list[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generate_kwargs)
        except Exception as e:  # noqa: BLE001
            exception_holder.append(e)

    thread = Thread(target=_generate)
    thread.start()
    chunks: list[str] = []
    for text in streamer:
        chunks.append(text)
        # Yield everything decoded so far — callers render the full message.
        yield "".join(chunks)
    thread.join()
    if exception_holder:
        error_msg = f"Generation failed: {exception_holder[0]}"
        raise gr.Error(error_msg)
def validate_input(message: str) -> dict:
    """Reject empty or whitespace-only messages before generation is queued."""
    has_text = bool(message) and message.strip() != ""
    return gr.validate(has_text, "Please enter a message.")
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 32768,
    temperature: float = 1.0,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Build a chat prompt from history, enforce the context budget, stream a reply.

    Raises gr.Error when the tokenized prompt exceeds MAX_INPUT_TOKENS, or
    when it fills the context window leaving no room for new tokens.
    """

    def _flatten(content) -> str:
        # Gradio may store a message's content as a list of typed parts;
        # keep only the text parts, concatenated in order.
        if isinstance(content, list):
            return "".join(part["text"] for part in content if part["type"] == "text")
        return str(content)

    conversation = [
        {"role": msg["role"], "content": _flatten(msg["content"])} for msg in chat_history
    ]
    conversation.append({"role": "user", "content": message})

    encoded = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    n_input_tokens = encoded.input_ids.shape[1]
    if n_input_tokens > MAX_INPUT_TOKENS:
        error_msg = f"Input too long ({n_input_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
        raise gr.Error(error_msg)

    # Cap generation so prompt + completion stay inside the context budget.
    budget = min(max_new_tokens, MAX_INPUT_TOKENS - n_input_tokens)
    if budget <= 0:
        raise gr.Error("Input uses the entire context window. No room to generate new tokens.")

    yield from _generate_on_gpu(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=budget,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
# Chat UI wiring: `generate` streams the reply; each slider below feeds the
# keyword parameter of `generate` in the same order. Note the slider values
# (not `generate`'s signature defaults) are what the UI actually submits.
demo = gr.ChatInterface(
    fn=generate,
    # NOTE(review): `validator=` is a newer ChatInterface kwarg — confirm the
    # Space pins a Gradio version that supports it.
    validator=validate_input,
    additional_inputs=[
        # Defaults to the full context size; `generate` re-caps it per prompt.
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS_LIMIT,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    # Caching examples would run GPU generation at build time — keep it off.
    cache_examples=False,
    description=DESCRIPTION,
    fill_height=True,
)
if __name__ == "__main__":
    # Start the Gradio server; style.css is expected next to this file.
    demo.launch(css_paths="style.css")