import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Identifier handed to from_pretrained() for both the tokenizer and the model.
MODEL_ID = "LivSterling/rc-tutor-llama3-merged"

# Tokenizer: reuse the end-of-sequence token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

# Model: float16 weights, with device placement delegated via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.float16,
)
# Keep the model config's padding id in step with the tokenizer's choice above.
model.config.pad_token_id = tokenizer.eos_token_id


def chat_fn(message, history):
    """Generate the assistant's reply to ``message`` given the prior turns.

    Args:
        message: The latest user message.
        history: Earlier turns of the conversation. Two layouts are accepted:
            ``(user, bot)`` pairs, and ``{"role": ..., "content": ...}``
            dicts. ``None`` or an empty sequence means no prior turns.

    Returns:
        The newly generated text (prompt tokens excluded, special tokens
        removed), stripped of leading and trailing whitespace.
    """
    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": message})

    # return_dict=True makes the tokenizer hand back its own attention mask
    # alongside the ids. Deriving the mask from `ids != pad_token_id` is wrong
    # here: pad_token is aliased to eos_token at module level, so genuine
    # end-of-sequence tokens inside the prompt would be masked out.
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=256,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = input_ids.shape[-1]
    return tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    ).strip()


def _history_to_messages(history):
    """Convert chat history (pairs or role/content dicts) to chat-template messages."""
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            turns = [(entry.get("role"), entry.get("content"))]
        else:
            # Pair layout: one (user, bot) couple per exchange.
            user_content, bot_content = entry
            turns = [("user", user_content), ("assistant", bot_content)]

        for role, content in turns:
            text = _content_to_text(content)
            # Drop turns with no role or no text (e.g. a still-empty reply slot)
            # instead of feeding None into the chat template.
            if isinstance(role, str) and text is not None:
                messages.append({"role": role, "content": text})
    return messages


def _content_to_text(content):
    """Return ``content`` as plain text, or ``None`` when it holds no text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # List of content blocks: keep bare strings and the "text" field of
        # dict blocks; anything else is ignored.
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "\n".join(parts) if parts else None
    return None


# Chat UI whose message handler is chat_fn.
demo = gr.ChatInterface(fn=chat_fn, title="RC Tutor LLaMA-3")


if __name__ == "__main__":
    # Launch settings for running this file as a script: listen on 0.0.0.0:7860,
    # no share link, SSR mode off, and errors shown.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "ssr_mode": False,
        "show_error": True,
    }
    demo.launch(**launch_options)