Spaces:
Sleeping
Sleeping
| import os | |
| from threading import Thread | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# Hugging Face model repository to serve; the MODEL_ID environment variable
# overrides the built-in default.
MODEL_ID = os.environ.get("MODEL_ID", "GenueAI/Inelly-4.5-Blaze")

# Process-wide tokenizer/model singletons. Both start out empty and are
# filled in by load_model() the first time it runs.
tokenizer = model = None
def load_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is stored in the module-level ``tokenizer`` and ``model``
    globals, so calls made after a successful load return immediately
    without touching the Hub or the disk again.

    Returns:
        A ``(tokenizer, model)`` tuple for ``MODEL_ID``; the model has been
        switched to eval mode.
    """
    global tokenizer, model

    already_loaded = tokenizer is not None and model is not None
    if already_loaded:
        return tokenizer, model

    has_cuda = torch.cuda.is_available()

    # NOTE: trust_remote_code=True lets the model repository run its own
    # Python code during loading; only point MODEL_ID at trusted repos.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        # Reuse the EOS token for padding when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    if has_cuda:
        # Half precision on GPU, with automatic device placement.
        load_options = {
            "trust_remote_code": True,
            "torch_dtype": torch.float16,
            "low_cpu_mem_usage": True,
            "device_map": "auto",
        }
    else:
        # Full precision on CPU.
        load_options = {
            "trust_remote_code": True,
            "torch_dtype": torch.float32,
            "low_cpu_mem_usage": True,
        }

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_options)
    if not has_cuda:
        model = model.to("cpu")
    model.eval()
    return tokenizer, model
def build_prompt(message, history, system_prompt):
    """Render the conversation into a single prompt string for the model.

    Args:
        message: The new user message to answer.
        history: Earlier turns of the conversation. Two layouts are
            accepted: a sequence of ``(user_message, assistant_message)``
            pairs, or a sequence of ``{"role": ..., "content": ...}`` dicts
            (Gradio's "messages" history format). ``None`` is treated as an
            empty history.
        system_prompt: Optional system instruction; ``None``, empty, or
            whitespace-only values are ignored.

    Returns:
        The prompt text. When the tokenizer defines a chat template the
        template is applied with the generation prompt appended; otherwise
        a plain ``"Role: content"`` transcript ending in ``"Assistant:"`` is
        produced.
    """
    messages = []

    system_text = (system_prompt or "").strip()
    if system_text:
        messages.append({"role": "system", "content": system_text})

    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-style history. Unpacking such a dict as a pair
            # would yield its *keys* ("role", "content"), not the text.
            role = turn.get("role")
            content = turn.get("content")
            # Only plain-text entries are forwarded; the prompt is text-only.
            if role and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue

        user_message, assistant_message = turn
        if user_message:
            messages.append({"role": "user", "content": user_message})
        if assistant_message:
            messages.append({"role": "assistant", "content": assistant_message})

    messages.append({"role": "user", "content": message})

    tok, _ = load_model()
    # getattr: not every tokenizer object exposes a chat_template attribute.
    if hasattr(tok, "apply_chat_template") and getattr(tok, "chat_template", None):
        return tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # Fallback for tokenizers without a chat template: simple transcript.
    transcript = "".join(
        f"{item['role'].capitalize()}: {item['content']}\n" for item in messages
    )
    return transcript + "Assistant:"
def chat(message, history, system_prompt, max_new_tokens, temperature, top_p, repetition_penalty):
    """Stream the model's reply to ``message`` for the Gradio chat UI.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator yields the reply accumulated so
    far each time a new piece of text arrives.

    Args:
        message: The new user message. ``None``, empty, or whitespace-only
            input yields a single empty string.
        history: Earlier conversation turns, forwarded to ``build_prompt``.
        system_prompt: System instruction, forwarded to ``build_prompt``.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        repetition_penalty: Repetition penalty passed to ``generate``.

    Yields:
        The reply text generated so far.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread
            is re-raised here once the stream has been drained.
    """
    if not message or not message.strip():
        yield ""
        return

    tok, mdl = load_model()
    prompt = build_prompt(message, history, system_prompt)

    # A chat template already emits the special tokens it needs (e.g. BOS);
    # letting the tokenizer add them again would duplicate them.
    uses_chat_template = hasattr(tok, "apply_chat_template") and bool(
        getattr(tok, "chat_template", None)
    )
    inputs = tok(
        prompt,
        return_tensors="pt",
        add_special_tokens=not uses_chat_template,
    ).to(mdl.device)
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    temperature = float(temperature)
    do_sample = temperature > 0
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": float(repetition_penalty),
        "do_sample": do_sample,
        "pad_token_id": tok.pad_token_id,
        "eos_token_id": tok.eos_token_id,
    }
    if do_sample:
        # Sampling-only knobs: passing them together with do_sample=False
        # (temperature 0) is an inconsistent generation config.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = float(top_p)

    worker_errors = []

    def run_generation():
        try:
            mdl.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: surfaced to the caller below
            worker_errors.append(exc)
            # Without the end-of-stream signal the consumer loop below would
            # block forever waiting for text that will never arrive.
            streamer.end()

    thread = Thread(target=run_generation)
    thread.start()

    response = ""
    for token in streamer:
        response += token
        yield response

    thread.join()
    if worker_errors:
        raise worker_errors[0]
# Gradio UI: a page header followed by a chat interface wired to chat().
# NOTE(review): the page title says "Matrix Prime 8B" while the default
# MODEL_ID names a different model; confirm which label is intended.
with gr.Blocks(title="Matrix Prime 8B Chat") as demo:
    gr.Markdown("# Matrix Prime 8B Chat")
    gr.Markdown(f"Chat with `{MODEL_ID}` from Hugging Face.")

    with gr.Row(), gr.Column(scale=4):
        # Extra controls, listed in the order of chat()'s parameters that
        # follow `message` and `history`.
        system_prompt_box = gr.Textbox(
            label="System prompt",
            value="You are a helpful assistant.",
            lines=3,
        )
        max_tokens_slider = gr.Slider(64, 4096, value=512, step=32, label="Max new tokens")
        temperature_slider = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
        top_p_slider = gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="Top-p")
        repetition_slider = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition penalty")

        # Message entry field used by the chat interface.
        message_box = gr.Textbox(
            placeholder="Ask Matrix Prime 8B anything...",
            container=False,
            scale=7,
        )

        chatbot = gr.ChatInterface(
            fn=chat,
            additional_inputs=[
                system_prompt_box,
                max_tokens_slider,
                temperature_slider,
                top_p_slider,
                repetition_slider,
            ],
            textbox=message_box,
            submit_btn="Send",
            stop_btn="Stop",
        )
# Script entry point: enable the request queue, then start the web server.
if __name__ == "__main__":
    demo.queue().launch()