Spaces:
Sleeping
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| from datasets import load_dataset | |
| import random | |
| import os | |
# Prefer the locally fine-tuned checkpoint when its directory is present;
# otherwise fall back to the public base model on the Hugging Face Hub.
FINETUNED_MODEL_DIR = "./customer_support_chatbot"
BASE_MODEL_ID = "microsoft/DialoGPT-medium"

if os.path.exists(FINETUNED_MODEL_DIR):
    model_path = FINETUNED_MODEL_DIR
else:
    model_path = BASE_MODEL_ID

# The tokenizer and model are loaded from the same checkpoint so their
# vocabularies agree.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Customer-support conversations dataset.
# NOTE(review): `dataset` is not referenced anywhere else in the visible file,
# so this download only adds startup time unless other code relies on it.
DATASET_ID = "Victorano/customer-support-1k"
dataset = load_dataset(DATASET_ID)
def generate_response(message, history, max_new_tokens=200, context_limit=1000):
    """Generate the support agent's reply to ``message``.

    The chat so far is rendered as alternating ``Customer:`` / ``Support:``
    lines that end with an open ``Support:`` turn, which the model completes.

    Args:
        message: The customer's newest message.
        history: Earlier turns as ``(user_msg, bot_msg)`` pairs (the
            tuple-style value of the ``gr.Chatbot`` component).
        max_new_tokens: Upper bound on the number of tokens generated for
            the reply.
        context_limit: Token budget for prompt plus reply. The oldest prompt
            tokens are dropped so that the prompt leaves room for
            ``max_new_tokens`` of output.

    Returns:
        The generated reply, cut off before any ``Customer:`` turn the model
        invents, with surrounding whitespace stripped. May be empty.
    """
    # Build the transcript; the trailing "Support:" cues the model to answer.
    turns = [
        f"Customer: {user_msg}\nSupport: {bot_msg}\n" for user_msg, bot_msg in history
    ]
    turns.append(f"Customer: {message}\nSupport:")
    conversation = "".join(turns)

    input_ids = tokenizer.encode(conversation, return_tensors="pt")
    # Keep only the most recent prompt tokens so that prompt + reply stay
    # within context_limit. GPT-2-based DialoGPT models have a 1024-position
    # window, so an untruncated long chat would not fit.
    prompt_budget = max(1, context_limit - max_new_tokens)
    input_ids = input_ids[:, -prompt_budget:]
    prompt_len = input_ids.shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            # A single unpadded sequence, so every position is attended to.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # temperature/top_k/top_p only take effect in sampling mode;
            # without do_sample=True, generate() decodes greedily and ignores them.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, so the prompt text can never
    # leak into the reply.
    reply = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    # Drop any follow-up customer turn (and beyond) that the model hallucinates.
    reply = reply.split("Customer:", 1)[0]
    return reply.strip()
# Create the Gradio interface
with gr.Blocks(css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # 🤖 Customer Support Chatbot
        This chatbot is fine-tuned on customer support conversations using DialoGPT-medium.
        """
    )
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=1"),
        height=500,
        show_copy_button=True,
    )
    with gr.Row():
        txt = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            container=False,
        )
        submit_btn = gr.Button("Send", variant="primary")

    # Handle user input and generate response
    def user_input(message, history):
        """Append the customer's message and the bot's reply to the chat.

        Args:
            message: Current contents of the textbox.
            history: Current chatbot value as a list of ``[user, bot]`` pairs.

        Returns:
            A tuple ``("", new_history)``; the empty string clears the textbox.
            Blank messages leave the history unchanged and skip generation.
        """
        history = history or []
        # Don't run the model on an empty or whitespace-only submission.
        if not message or not message.strip():
            return "", history
        return "", history + [[message, generate_response(message, history)]]

    # Pressing Enter in the textbox and clicking "Send" run the same handler.
    txt.submit(user_input, [txt, chatbot], [txt, chatbot])
    submit_btn.click(user_input, [txt, chatbot], [txt, chatbot])

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()