"""Gradio chat demo for the OrionLLM/GRM2-3b reasoning model."""

import os
from collections.abc import Iterator
from threading import Event, Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

DESCRIPTION = """\
# GRM2

GRM2 is Orion's latest iteration of powerful open LLMs.
This is a demo of [`OrionLLM/GRM2-3b`](https://huggingface.co/OrionLLM/GRM2-3b), fine-tuned for long-form reasoning on general tasks.
"""

MAX_NEW_TOKENS_LIMIT = 262144
DEFAULT_MAX_NEW_TOKENS = 262144
# Upper bound on prompt tokens. generate() also uses it as the total context
# budget (prompt + completion) when clamping max_new_tokens.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "262144"))

MODEL_ID = "OrionLLM/GRM2-3b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.bfloat16,
)
model.eval()


class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends generation once ``event`` has been set."""

    def __init__(self, event: Event) -> None:
        self.event = event

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: object) -> torch.BoolTensor:
        # StoppingCriteriaList expects one "done" flag per sequence in the batch.
        return torch.full((input_ids.shape[0],), self.event.is_set(), dtype=torch.bool, device=input_ids.device)


# NOTE(review): 90 s is presumably far shorter than a generation anywhere near
# DEFAULT_MAX_NEW_TOKENS needs, so the GPU lease likely ends such runs early — confirm.
@spaces.GPU(duration=90)
def _generate_on_gpu(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    """Sample a completion on the model's device and stream it.

    ``model.generate`` runs in a background thread that feeds a
    TextIteratorStreamer; every yielded value is the full response produced
    so far (not just the newest chunk).

    Raises:
        gr.Error: If ``model.generate`` raised inside the worker thread.
    """
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    # NOTE: if no text arrives within 20 s (e.g. a very long prompt's prefill),
    # iterating the streamer raises queue.Empty and the request fails.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    # Set when the consumer is finished (or gone) so the worker stops early.
    stop_event = Event()
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
        "stopping_criteria": StoppingCriteriaList([_StopOnEvent(stop_event)]),
    }

    exception_holder: list[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generate_kwargs)
        except Exception as e:  # noqa: BLE001
            exception_holder.append(e)
            # generate() only signals end-of-stream when it returns normally;
            # without this the loop below would block until the streamer times
            # out and the real error would never be reported.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    response = ""
    try:
        for text in streamer:
            response += text
            yield response
    finally:
        # Runs on normal completion, on streamer timeouts, and when the consumer
        # closes this generator (e.g. client disconnect): stop and reap the worker.
        stop_event.set()
        thread.join()

    if exception_holder:
        error_msg = f"Generation failed: {exception_holder[0]}"
        raise gr.Error(error_msg)


def validate_input(message: str) -> dict:
    """Reject empty or whitespace-only chat messages before generation starts."""
    return gr.validate(bool(message and message.strip()), "Please enter a message.")


def _message_text(content: object) -> str:
    """Return the plain text of a chat-history message's ``content``.

    List content (multimodal parts) keeps only the ``"text"`` of parts whose
    ``"type"`` is ``"text"``; anything else is converted with ``str``.
    """
    if isinstance(content, list):
        return "".join(
            part.get("text", "") for part in content if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 32768,
    temperature: float = 1.0,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Chat handler for ``gr.ChatInterface``: build the prompt and stream a reply.

    Args:
        message: The new user message.
        chat_history: Prior turns as dicts with ``"role"`` and ``"content"`` keys.
        max_new_tokens: Requested completion length; clamped to the remaining budget.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded to ``model.generate``.

    Yields:
        The response text accumulated so far.

    Raises:
        gr.Error: If the prompt exceeds MAX_INPUT_TOKENS, leaves no room for
            new tokens, or generation fails.
    """
    conversation = [{"role": turn["role"], "content": _message_text(turn["content"])} for turn in chat_history]
    conversation.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt", return_dict=True
    )
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    n_input_tokens = input_ids.shape[1]
    if n_input_tokens > MAX_INPUT_TOKENS:
        error_msg = f"Input too long ({n_input_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
        raise gr.Error(error_msg)

    # MAX_INPUT_TOKENS is treated as the whole context budget (prompt + completion).
    # gr.Slider values are typed as float, while generate() needs integer counts.
    max_new_tokens = min(int(max_new_tokens), MAX_INPUT_TOKENS - n_input_tokens)
    if max_new_tokens <= 0:
        raise gr.Error("Input uses the entire context window. No room to generate new tokens.")

    yield from _generate_on_gpu(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )


demo = gr.ChatInterface(
    fn=generate,
    validator=validate_input,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS_LIMIT,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    cache_examples=False,
    description=DESCRIPTION,
    fill_height=True,
)


if __name__ == "__main__":
    demo.launch(css_paths="style.css")