| | import gradio as gr |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | import torch |
| |
|
| | |
| | |
| | |
# Hugging Face checkpoint id for the chat backbone.
# Bug fix: "disstilgpt2" is a typo and not a valid model id on the Hub;
# the intended checkpoint is "distilgpt2".
MODEL_NAME = "distilgpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# Inference only: switch off dropout / training-mode layers.
model.eval()
| |
|
| | |
| | |
| | |
# Index -> label mapping for the emotions selectable in the UI dropdown.
emotions_dict = dict(enumerate(["happy", "sad", "excited", "angry", "neutral"]))
| |
|
| | |
| | |
| | |
def build_prompt(user_input, emotion_str, history):
    """Assemble the generation prompt for one chat turn.

    Up to the last 6 (user, bot) exchanges are included in chronological
    order (oldest first), followed by the emotion-tagged current input.
    The original version prepended each exchange inside the loop, which
    reversed the conversation (newest exchange ended up first) — fixed here.

    Args:
        user_input: The current user message.
        emotion_str: Emotion label; upper-cased into a "[EMOTION]" tag.
        history: List of (user, bot) string pairs, oldest first. May be
            empty or None.

    Returns:
        The full prompt string fed to the tokenizer.
    """
    parts = []
    if history:
        for user_turn, bot_turn in history[-6:]:
            parts.append(f"User: {user_turn}\nChatDBS-1: {bot_turn}\n")
    parts.append(f"[{emotion_str.upper()}] {user_input}\n")
    return "".join(parts)


def chat_with_emotion(user_input, emotion_str, history):
    """Generate one bot reply and append the exchange to the history.

    Args:
        user_input: The current user message.
        emotion_str: Emotion label selected in the UI dropdown.
        history: Mutable list of (user, bot) pairs; updated in place.

    Returns:
        (history, history) — the same updated list twice, matching the
        Gradio outputs [chat_box, history_state].
    """
    prompt_text = build_prompt(user_input, emotion_str, history)

    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    # max_new_tokens (not max_length) so the reply budget does not shrink
    # as the prompt grows; GPT-2 has no pad token, so reuse EOS to
    # silence the generate() warning.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=150,
            do_sample=True,
            temperature=1.0,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(
        output_ids[0][input_ids.shape[1]:], skip_special_tokens=True
    )

    history.append((user_input, response))
    return history, history
| |
|
| | |
| | |
| | |
with gr.Blocks() as iface:
    # Per-session conversation memory: list of (user, bot) pairs.
    history_state = gr.State([])

    with gr.Row():
        message_box = gr.Textbox(label="Your message")
        emotion_dropdown = gr.Dropdown(list(emotions_dict.values()), label="Emotion")

    conversation_view = gr.Chatbot(label="ChatDBS-1")

    # Pressing Enter in the textbox runs the model and refreshes both the
    # visible chat and the stored history.
    message_box.submit(
        chat_with_emotion,
        inputs=[message_box, emotion_dropdown, history_state],
        outputs=[conversation_view, history_state],
    )

iface.launch()