GenueApp / app.py
GenueAI's picture
Update app.py
694ff6a verified
Raw
History Blame Contribute Delete
4.05 kB
import os
from threading import Thread
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face repo id of the model to serve; override with the MODEL_ID env var.
MODEL_ID = os.environ.get("MODEL_ID", "GenueAI/Inelly-4.5-Blaze")

# Module-level cache populated by load_model() on its first call.
tokenizer = model = None
def load_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is cached in the module-level ``tokenizer`` and ``model``
    globals, so only the first call downloads and initialises anything.
    With CUDA available the weights are loaded in float16 and placed with
    ``device_map="auto"``; otherwise they are loaded in float32 and moved
    to the CPU.  The model is switched to eval mode before being returned.
    """
    global tokenizer, model

    # Fast path: an earlier call already populated both globals.
    if tokenizer is not None and model is not None:
        return tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Checkpoints without a pad token reuse EOS so padding still works.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    on_gpu = torch.cuda.is_available()
    if on_gpu:
        placement = {"torch_dtype": torch.float16, "device_map": "auto"}
    else:
        placement = {"torch_dtype": torch.float32}

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        **placement,
    )
    if not on_gpu:
        model = model.to("cpu")
    model.eval()
    return tokenizer, model
def _history_to_messages(message, history, system_prompt):
    """Build the role/content message list from system prompt, history and new message."""
    messages = []
    system_text = (system_prompt or "").strip()
    if system_text:
        messages.append({"role": "system", "content": system_text})

    for turn in history or []:
        if isinstance(turn, dict):
            # Gradio "messages" format: one dict per turn.  Only plain-text
            # user/assistant turns are forwarded; anything else (e.g. file
            # attachments or other roles) is skipped.
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
        else:
            # Gradio "tuples" format: one (user, assistant) pair per exchange.
            user_message, assistant_message = turn
            if user_message:
                messages.append({"role": "user", "content": user_message})
            if assistant_message:
                messages.append({"role": "assistant", "content": assistant_message})

    messages.append({"role": "user", "content": message})
    return messages


def build_prompt(message, history, system_prompt, tok=None):
    """Render the conversation into a single prompt string for the model.

    Args:
        message: The new user message, appended as the final turn.
        history: Prior turns from ``gr.ChatInterface``, either as
            ``(user, assistant)`` pairs or as ``{"role", "content"}`` dicts.
            ``None`` is treated as no history.
        system_prompt: Optional system instruction; skipped when blank or None.
        tok: Tokenizer used for formatting.  Defaults to the one returned by
            ``load_model()``.

    Returns:
        The prompt text: the tokenizer's chat template output (with the
        generation prompt appended) when it defines one, otherwise a plain
        ``"Role: content"`` transcript ending in ``"Assistant:"``.
    """
    messages = _history_to_messages(message, history, system_prompt)
    if tok is None:
        tok, _ = load_model()

    if hasattr(tok, "apply_chat_template") and getattr(tok, "chat_template", None):
        return tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    transcript = "".join(
        f"{item['role'].capitalize()}: {item['content']}\n" for item in messages
    )
    return transcript + "Assistant:"
def chat(message, history, system_prompt, max_new_tokens, temperature, top_p, repetition_penalty):
    """Stream a model reply to ``message`` for ``gr.ChatInterface``.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``; this generator yields the accumulated reply
    after every decoded chunk.

    Args:
        message: The new user message.  A blank message yields ``""`` and stops.
        history: Prior turns, forwarded to ``build_prompt``.
        system_prompt: System instruction, forwarded to ``build_prompt``.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.
        repetition_penalty: Passed through to ``generate``.

    Yields:
        The reply text generated so far.

    Raises:
        Any exception raised by ``model.generate``, re-raised here after the
        stream ends so the failure surfaces instead of hanging the UI.
    """
    if not message.strip():
        yield ""
        return

    # Local imports keep the file's top-level import block unchanged.
    from threading import Event
    from transformers import StoppingCriteria, StoppingCriteriaList

    tok, mdl = load_model()
    prompt = build_prompt(message, history, system_prompt)
    inputs = tok(prompt, return_tensors="pt").to(mdl.device)
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    # Set once this generator stops consuming (finished, or closed because the
    # user pressed "Stop"), so the background generate() can end early instead
    # of running to max_new_tokens.
    cancelled = Event()

    class _StopWhenCancelled(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            return torch.full(
                (input_ids.shape[0],),
                cancelled.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    do_sample = float(temperature) > 0
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": float(repetition_penalty),
        "do_sample": do_sample,
        "pad_token_id": tok.pad_token_id,
        "eos_token_id": tok.eos_token_id,
        "stopping_criteria": StoppingCriteriaList([_StopWhenCancelled()]),
    }
    if do_sample:
        # temperature/top_p only matter when sampling; passing them with
        # do_sample=False makes transformers warn about unused flags.
        generation_kwargs["temperature"] = float(temperature)
        generation_kwargs["top_p"] = float(top_p)

    errors = []

    def _run_generate():
        try:
            mdl.generate(**generation_kwargs)
        except Exception as exc:  # re-raised on the consumer side below
            errors.append(exc)
            # Without an end signal the streamer iterator would block forever.
            streamer.end()

    worker = Thread(target=_run_generate)
    worker.start()

    response = ""
    try:
        for chunk in streamer:
            response += chunk
            yield response
    finally:
        cancelled.set()

    worker.join()
    if errors:
        raise errors[0]
# Human-readable model name shown in the page title, header and placeholder.
# Derived from MODEL_ID (overridable via MODEL_DISPLAY_NAME) so the UI names
# the model that is actually loaded instead of a hard-coded label.
DISPLAY_NAME = os.getenv("MODEL_DISPLAY_NAME") or MODEL_ID.rsplit("/", 1)[-1]

with gr.Blocks(title=f"{DISPLAY_NAME} Chat") as demo:
    gr.Markdown(f"# {DISPLAY_NAME} Chat")
    gr.Markdown(f"Chat with `{MODEL_ID}` from Hugging Face.")
    with gr.Row():
        with gr.Column(scale=4):
            # additional_inputs are passed to chat() positionally after
            # (message, history): system_prompt, max_new_tokens, temperature,
            # top_p, repetition_penalty.
            chatbot = gr.ChatInterface(
                fn=chat,
                additional_inputs=[
                    gr.Textbox(
                        label="System prompt",
                        value="You are a helpful assistant.",
                        lines=3,
                    ),
                    gr.Slider(64, 4096, value=512, step=32, label="Max new tokens"),
                    gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature"),
                    gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="Top-p"),
                    gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition penalty"),
                ],
                textbox=gr.Textbox(
                    placeholder=f"Ask {DISPLAY_NAME} anything...",
                    container=False,
                    scale=7,
                ),
                submit_btn="Send",
                stop_btn="Stop",
            )
if __name__ == "__main__":
    # Blocks.queue() returns the Blocks instance, so queueing and launching chain.
    demo.queue().launch()