# code-AI / app.py
# Author: tosei0000 — commit 3747093 ("Update app.py", verified)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr
# Hugging Face Hub repo holding the fine-tuned chat model and its tokenizer.
model_name = "tosei0000/chatbot"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pick the weight dtype from the hardware instead of hard-coding bfloat16:
# bf16 is only worthwhile on GPUs with native support. On CPU (and on older
# GPUs) fall back to float32, which every op supports, rather than running
# slow or partially-supported bf16 kernels.
if device.type == "cuda" and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = model.to(device)

# generate() needs a pad token id. Reuse EOS only when the tokenizer does not
# define its own pad token, instead of unconditionally overwriting it.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
def chat(user_input, history):
    """Generate one assistant reply and append the exchange to the history.

    The whole conversation is flattened into a plain-text
    ``User: ...\\nAssistant: ...`` transcript and fed to the causal LM. Only
    the first line of the model's continuation is kept as the reply.

    Args:
        user_input: The text the user just submitted.
        history: Sequence of ``(user, assistant)`` pairs from earlier turns,
            as held in the ``gr.State``.

    Returns:
        A tuple ``(history, history)`` holding the updated list of pairs:
        once for the ``gr.Chatbot`` display and once for the ``gr.State``.
    """
    # Ignore blank submissions instead of spending a generation on them.
    if not user_input or not user_input.strip():
        return history, history

    prompt = "".join(
        f"User: {u}\nAssistant: {a}\n" for u, a in history
    ) + f"User: {user_input}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # For a decoder-only model, generate() returns prompt + continuation.
    # Decode only the new token ids. Slicing the decoded string by len(prompt)
    # is fragile: decoding does not guarantee a character-exact round-trip of
    # the prompt, so the cut point can drift.
    prompt_len = inputs["input_ids"].shape[1]
    text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # The model tends to keep writing further transcript turns; keep only the
    # first line of its answer.
    reply = text.strip().split("\n")[0]

    # Build a new list rather than mutating the caller's history in place.
    history = list(history) + [(user_input, reply)]
    return history, history
# Gradio UI: a chat transcript, a question box and a reset button.
with gr.Blocks(title="Qwen2 Chatbot") as demo:
    # Header: "🤖 杜靖 chatbot".
    gr.Markdown("## 🤖 杜靖 聊天机器人")
    chatbot = gr.Chatbot()
    # Label: "Enter your question".
    msg = gr.Textbox(label="输入你的问题")
    # Button: "Clear conversation".
    clear = gr.Button("清除对话")
    # Conversation history as a list of (user, assistant) pairs, starting empty.
    state = gr.State([])

    # On Enter, run one chat turn. Afterwards, empty the textbox so the next
    # question can be typed without first deleting the previous one.
    msg.submit(chat, [msg, state], [chatbot, state]).then(
        lambda: "", None, msg
    )
    # Reset both the visible transcript and the stored history.
    clear.click(lambda: ([], []), None, [chatbot, state])

if __name__ == "__main__":
    demo.launch()
# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch
# model_path = "tosei0000/chatbot"
# tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
# def chat(prompt, max_new_tokens=100):
# inputs = tokenizer(prompt, return_tensors="pt").to(device)
# outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
# return tokenizer.decode(outputs[0], skip_special_tokens=True)
# response = chat("こんにちは!")
# print(response)