import html
import logging
import signal
import sys

import gradio as gr
import torch
import transformers

logging.basicConfig(level=logging.INFO)
# Keep only the most recent exchanges so the assembled prompt stays small
# enough for GPT-2's 1,024-token context window.
MAX_HISTORY_LENGTH = 5


def shutdown_handler(signum, frame):
    """Exit cleanly on Ctrl-C instead of leaving the server half-stopped."""
    logging.info("Shutting down gracefully...")
    sys.exit(0)


signal.signal(signal.SIGINT, shutdown_handler)
def system_message_selector(choice, custom_message):
    """Return the system prompt: a custom message, if given, wins over the preset."""
    if custom_message:
        return custom_message
    presets = {
        "Friendly Chatbot": "You are a friendly and helpful chatbot.",
        "Professional Assistant": "You are a highly knowledgeable and professional assistant.",
        "Curious Researcher": "You are a curious researcher who loves to explore new ideas.",
    }
    return presets.get(choice, "You are a helpful assistant.")


def sanitize_input(text):
    """Escape HTML special characters so user input renders as plain text."""
    return html.escape(text)
def validate_parameters(max_tokens, temperature, top_p):
    """Check that the sampling parameters fall within the slider ranges."""
    if not (1 <= max_tokens <= 1024):
        return False, "Error: 'Max new tokens' must be between 1 and 1024."
    if not (0.1 <= temperature <= 4.0):
        return False, "Error: 'Temperature' must be between 0.1 and 4.0."
    if not (0.1 <= top_p <= 1.0):
        return False, "Error: 'Top-p' must be between 0.1 and 1.0."
    return True, ""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "gpt2"

# Load the tokenizer and model once at startup; bail out early if either fails.
try:
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    model.eval()
except Exception as e:
    logging.error(f"Failed to load model {model_name}: {e}")
    sys.exit(1)
def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
    is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
    if not is_valid:
        return error_message

    # Gradio's ChatInterface passes `history` as a list of (user, bot) pairs.
    safe_message = sanitize_input(message)
    safe_history = [(sanitize_input(u), sanitize_input(b)) for u, b in history]
    truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
    system_message = system_message_selector(persona_choice, custom_persona)

    # GPT-2 has no chat template, so assemble a plain-text transcript prompt.
    conversation = system_message + "\n\n"
    for user_msg, bot_msg in truncated_history:
        conversation += f"User: {user_msg}\n"
        conversation += f"Assistant: {bot_msg}\n"
    conversation += f"User: {safe_message}\nAssistant:"

    logging.info(f"Received message: {safe_message}")

    try:
        input_ids = tokenizer.encode(conversation, return_tensors="pt").to(device)
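
        # Added guard (not in the original script): GPT-2 attends to at most
        # model.config.n_positions (1,024) tokens, so drop the oldest prompt
        # tokens to leave room for the requested completion.
        context_limit = max(1, model.config.n_positions - int(max_tokens))
        if input_ids.shape[-1] > context_limit:
            input_ids = input_ids[:, -context_limit:]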

        # A single unpadded prompt needs an all-ones attention mask; passing it
        # explicitly silences the warning transformers emits because GPT-2
        # reuses EOS as its pad token.
        output_ids = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, skipping the echoed prompt.
        generated_text = tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
        )
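
        # Added post-processing (an assumption, not in the original script):
        # base GPT-2 tends to keep writing both sides of the dialogue, so cut
        # the completion at the first hallucinated "User:" turn, if any.
        generated_text = generated_text.split("User:")[0]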
        return generated_text.strip()
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return "I'm sorry, but something went wrong. Please try again."

# UI controls shown alongside the chat box.
system_message_radio = gr.Radio(
    choices=["Friendly Chatbot", "Professional Assistant", "Curious Researcher"],
    value="Friendly Chatbot",
    label="Choose a Persona",
)

system_message_textbox = gr.Textbox(
    placeholder="Enter custom persona or system message...",
    label="Custom Persona (Optional)",
)

max_tokens_slider = gr.Slider(
    minimum=1, maximum=1024, value=50, step=1, label="Max new tokens"
)

temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"
)

top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"
)

# Wire `respond` to the chat UI; the additional inputs are passed to
# `respond` in the order listed here.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        system_message_radio,
        system_message_textbox,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ],
    title="Customizable Chatbot Interface",
    description="Choose a persona or enter a custom one, and adjust parameters as needed.",
)

if __name__ == "__main__":
    demo.launch()