"""Gradio chat demo that streams replies from a Hugging Face causal LM."""

import os
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hub repository to load; override with the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "GenueAI/Inelly-4.5-Blaze")

# Populated on first use by load_model() and reused by later calls.
tokenizer = None
model = None


def load_model():
    """Load the tokenizer and model once and return the cached pair.

    The model is loaded in float16 with ``device_map="auto"`` when CUDA is
    available, otherwise in float32 on the CPU, and is switched to eval mode.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return tokenizer, model

    # NOTE: trust_remote_code=True runs Python code shipped in the model
    # repository; only point MODEL_ID at repositories you trust.
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if loaded_tokenizer.pad_token_id is None:
        # generate() is given pad_token_id below, so fall back to EOS.
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    use_cuda = torch.cuda.is_available()
    kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.float16 if use_cuda else torch.float32,
        "low_cpu_mem_usage": True,
    }
    if use_cuda:
        kwargs["device_map"] = "auto"

    loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
    if not use_cuda:
        loaded_model = loaded_model.to("cpu")
    loaded_model.eval()

    # Publish both globals together so a failed model load never leaves a
    # tokenizer cached without its model.
    tokenizer, model = loaded_tokenizer, loaded_model
    return tokenizer, model


def _has_chat_template(tok):
    """Return True when ``tok`` can render messages with its own chat template."""
    return hasattr(tok, "apply_chat_template") and bool(
        getattr(tok, "chat_template", None)
    )


def _content_to_text(content):
    """Reduce the ``content`` of a message dict to plain text.

    Strings are returned unchanged. Lists are flattened by concatenating
    their string items and the ``"text"`` field of their dict items.
    Anything else yields an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        # presumably a {"type": "text", "text": ...} content block -- confirm
        # against the installed Gradio version.
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, (list, tuple)):
        return "".join(_content_to_text(part) for part in content)
    return ""


def _history_to_messages(history):
    """Convert chat history into a list of ``{"role", "content"}`` dicts.

    Two entry shapes are accepted: ``(user, assistant)`` pairs and dicts
    carrying ``"role"`` and ``"content"`` keys. Empty turns are skipped.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            role = entry.get("role")
            text = _content_to_text(entry.get("content"))
            if role and text:
                messages.append({"role": role, "content": text})
            continue

        user_message, assistant_message = entry
        if user_message:
            messages.append({"role": "user", "content": user_message})
        if assistant_message:
            messages.append({"role": "assistant", "content": assistant_message})
    return messages


def build_prompt(message, history, system_prompt):
    """Render the conversation as a single prompt string.

    Args:
        message: The new user message.
        history: Earlier turns, as ``(user, assistant)`` pairs or as
            ``{"role", "content"}`` dicts.
        system_prompt: Optional system instruction; ignored when blank.

    Returns:
        The tokenizer's chat-template rendering (with the generation prompt
        appended) when a template exists, otherwise a plain
        ``"Role: content"`` transcript ending in ``"Assistant:"``.
    """
    messages = []
    system_text = (system_prompt or "").strip()
    if system_text:
        messages.append({"role": "system", "content": system_text})
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    tok, _ = load_model()
    if _has_chat_template(tok):
        return tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    transcript = "".join(
        f"{item['role'].capitalize()}: {item['content']}\n" for item in messages
    )
    return transcript + "Assistant:"


def chat(message, history, system_prompt, max_new_tokens, temperature, top_p, repetition_penalty):
    """Stream the model's reply to ``message``.

    Args:
        message: The new user message; a blank message yields ``""`` once.
        history: Earlier turns, forwarded to ``build_prompt``.
        system_prompt: Optional system instruction.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus sampling threshold, used only when sampling.
        repetition_penalty: Repetition penalty passed to ``generate``.

    Yields:
        The accumulated response text after each streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has ended.
    """
    if not message or not message.strip():
        yield ""
        return

    tok, mdl = load_model()
    prompt = build_prompt(message, history, system_prompt)
    # A chat template already emits the special tokens it needs, so adding
    # them again at tokenization time would duplicate them.
    inputs = tok(
        prompt,
        return_tensors="pt",
        add_special_tokens=not _has_chat_template(tok),
    ).to(mdl.device)
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    sampling_temperature = float(temperature)
    do_sample = sampling_temperature > 0
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": float(repetition_penalty),
        "do_sample": do_sample,
        "pad_token_id": tok.pad_token_id,
        "eos_token_id": tok.eos_token_id,
    }
    if do_sample:
        # Sampling-only knobs are meaningless for greedy decoding, so they
        # are passed only when sampling is enabled.
        generation_kwargs["temperature"] = sampling_temperature
        generation_kwargs["top_p"] = float(top_p)

    failures = []

    def run_generation():
        try:
            mdl.generate(**generation_kwargs)
        except Exception as exc:  # boundary of the worker thread
            failures.append(exc)
            # Without an end signal the consumer loop below would block
            # forever waiting for the next chunk.
            streamer.end()

    thread = Thread(target=run_generation)
    thread.start()

    response = ""
    for token in streamer:
        response += token
        yield response

    thread.join()
    if failures:
        raise failures[0]


with gr.Blocks(title="Matrix Prime 8B Chat") as demo:
    gr.Markdown("# Matrix Prime 8B Chat")
    gr.Markdown(f"Chat with `{MODEL_ID}` from Hugging Face.")

    with gr.Row():
        with gr.Column(scale=4):
            # The additional inputs are passed to chat() after message and
            # history, in the order listed here.
            chatbot = gr.ChatInterface(
                fn=chat,
                additional_inputs=[
                    gr.Textbox(
                        label="System prompt",
                        value="You are a helpful assistant.",
                        lines=3,
                    ),
                    gr.Slider(64, 4096, value=512, step=32, label="Max new tokens"),
                    gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature"),
                    gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="Top-p"),
                    gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition penalty"),
                ],
                textbox=gr.Textbox(
                    placeholder="Ask Matrix Prime 8B anything...",
                    container=False,
                    scale=7,
                ),
                submit_btn="Send",
                stop_btn="Stop",
            )


if __name__ == "__main__":
    demo.queue()
    demo.launch()