|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Hugging Face Hub identifier of the conversational checkpoint to load.
model_name = "microsoft/DialoGPT-medium"

# Tokenizer and causal-LM weights are loaded once, at import time, and are
# then shared by every call to ``chat``.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)

# Token ids of the running conversation (prompt + generated replies);
# ``None`` until the first turn has been generated.
chat_history_ids = None
|
|
|
|
|
|
|
|
# Maximum number of conversation tokens fed back into the model each turn.
# ``generate`` below caps prompt + reply at ``max_length=1000`` and the
# GPT-2 backbone has 1024 position embeddings, so an unbounded history would
# first leave no room for a reply and eventually overflow the model.
_MAX_CONTEXT_TOKENS = 800


def chat(user_input, history=None):
    """Generate one DialoGPT reply and append the turn to the chat history.

    Args:
        user_input: The user's message text.
        history: List of ``(user_message, bot_reply)`` tuples shown so far.
            ``None`` or an empty list starts a new conversation.

    Returns:
        ``(history, history)``: the updated list, once for the Chatbot
        display and once for the session state.
    """
    global chat_history_ids

    # Work on a copy: never mutate the caller's list. This also avoids the
    # shared-mutable-default pitfall of ``history=[]``.
    history = [] if history is None else list(history)

    # An empty history means a fresh conversation, so discard model context
    # left over from an earlier one instead of silently continuing it.
    if not history:
        chat_history_ids = None

    # DialoGPT marks the end of each turn with the EOS token.
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")

    if chat_history_ids is not None:
        input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
    else:
        input_ids = new_input_ids

    # Keep only the most recent tokens so the prompt always leaves room for a
    # reply. This may cut an old turn mid-way, which is acceptable context loss.
    input_ids = input_ids[:, -_MAX_CONTEXT_TOKENS:]

    chat_history_ids = model.generate(
        input_ids,
        # pad == eos here, so transformers cannot infer a mask; every
        # position is real input, so an explicit all-ones mask is correct.
        attention_mask=torch.ones_like(input_ids),
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.7,
        top_p=0.9,
        no_repeat_ngram_size=3,
        top_k=50,
        do_sample=True,
    )

    # The output sequence starts with the prompt; decode only the new tokens.
    response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    history.append((user_input, response))
    return history, history
|
|
|
|
|
|
|
|
# Header banner markup (centered card, #1e1e2f background, white text) with
# the assistant's title and a one-line tagline; consumed by
# ``launch_gradio_interface`` when building the UI.
custom_html = """


<div style="text-align:center; padding: 20px; background-color: #1e1e2f; color: white; border-radius: 12px;">


<h1>🤖 Smart AI Assistant</h1>


<p>Talk to DialoGPT and experience AI conversation in real time.</p>


</div>


"""
|
|
|
|
|
|
|
|
def launch_gradio_interface():
    """Build the chat UI around ``chat`` and start the Gradio server.

    The ``custom_html`` banner is rendered through ``Interface``'s
    ``description`` argument (which accepts Markdown/HTML and is shown under
    the title). This replaces the previous ``demo.add_component(...)`` call,
    which is not a Gradio API and raised AttributeError.
    """
    demo = gr.Interface(
        fn=chat,
        # The State input/output pair carries the per-session history list.
        inputs=[gr.Textbox(placeholder="Say something...", elem_id="user-input"), gr.State([])],
        outputs=[gr.Chatbot(), gr.State()],
        title="Smart AI Assistant",
        description=custom_html,
        # Submit explicitly instead of live mode: live=True re-ran ``chat``
        # (generating a reply and appending a turn) on every keystroke.
        live=False,
    )
    demo.launch()
|
|
|
|
|
# Start the server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    launch_gradio_interface()
|
|
|