import gradio as gr
import transformers
import torch
import logging
import html
import signal
import sys
# Set up logging
logging.basicConfig(level=logging.INFO)
# Constants
MAX_HISTORY_LENGTH = 5  # Number of past (user, assistant) exchanges kept in the prompt
# Graceful shutdown handler
def shutdown_handler(signum, frame):
    logging.info("Shutting down gracefully...")
    sys.exit(0)
signal.signal(signal.SIGINT, shutdown_handler)
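# Hosting platforms (e.g. Hugging Face Spaces) typically stop apps with
# SIGTERM rather than SIGINT, so register the same handler for it as well.
signal.signal(signal.SIGTERM, shutdown_handler)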
def system_message_selector(choice, custom_message):
    """Return the custom system message if one was entered, else the preset for the chosen persona."""
    if custom_message:
        return custom_message
    presets = {
        "Friendly Chatbot": "You are a friendly and helpful chatbot.",
        "Professional Assistant": "You are a highly knowledgeable and professional assistant.",
        "Curious Researcher": "You are a curious researcher who loves to explore new ideas.",
    }
    return presets.get(choice, "You are a helpful assistant.")
def sanitize_input(text):
return html.escape(text)
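# Example: sanitize_input("<b>hi</b>") -> "&lt;b&gt;hi&lt;/b&gt;"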
def validate_parameters(max_tokens, temperature, top_p):
if not (1 <= max_tokens <= 1024):
return False, "Error: 'Max new tokens' must be between 1 and 1024."
if not (0.1 <= temperature <= 4.0):
return False, "Error: 'Temperature' must be between 0.1 and 4.0."
if not (0.1 <= top_p <= 1.0):
return False, "Error: 'Top-p' must be between 0.1 and 1.0."
return True, ""
# Determine the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the model and tokenizer
model_name = "gpt2" # Use GPT-2 model
try:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)
model.eval()
except Exception as e:
    logging.error(f"Failed to load model {model_name}: {e}")
    sys.exit(1)
def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
if not is_valid:
return error_message
    # gr.ChatInterface passes history as a list of (user, bot) message pairs.
    safe_message = sanitize_input(message)
    safe_history = [(sanitize_input(u), sanitize_input(b)) for u, b in history]
    truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
system_message = system_message_selector(persona_choice, custom_persona)
# Build the conversation prompt
conversation = system_message + "\n\n"
for user_msg, bot_msg in truncated_history:
conversation += f"User: {user_msg}\n"
conversation += f"Assistant: {bot_msg}\n"
conversation += f"User: {safe_message}\nAssistant:"
logging.info(f"Received message: {safe_message}")
try:
input_ids = tokenizer.encode(conversation, return_tensors="pt").to(device)
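        # Safeguard (assumes GPT-2's 1024-token context window, read from the
        # model config): drop the oldest prompt tokens so the prompt plus the
        # requested completion still fits in the context.
        max_context = getattr(model.config, "n_positions", 1024)
        keep = max(max_context - max_tokens, 1)
        if input_ids.shape[-1] > keep:
            input_ids = input_ids[:, -keep:]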
        # Run generation without autograd bookkeeping to save memory.
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; reuse EOS
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens (everything after the prompt).
        generated_text = tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
        )
return generated_text.strip()
except Exception as e:
logging.error(f"An error occurred: {e}")
return "I'm sorry, but something went wrong. Please try again."
# Create the UI components
system_message_radio = gr.Radio(
choices=["Friendly Chatbot", "Professional Assistant", "Curious Researcher"],
value="Friendly Chatbot",
label="Choose a Persona",
)
system_message_textbox = gr.Textbox(
placeholder="Enter custom persona or system message...",
label="Custom Persona (Optional)",
)
max_tokens_slider = gr.Slider(
minimum=1, maximum=1024, value=50, step=1, label="Max new tokens"
)
temperature_slider = gr.Slider(
minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"
)
top_p_slider = gr.Slider(
minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"
)
# Create the ChatInterface
demo = gr.ChatInterface(
fn=respond,
additional_inputs=[
system_message_radio,
system_message_textbox,
max_tokens_slider,
temperature_slider,
top_p_slider,
],
title="Customizable Chatbot Interface",
description="Choose a persona or enter a custom one, and adjust parameters as needed.",
)
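# Optional (assumption: the default Gradio queue settings are acceptable):
# enable request queuing so concurrent users don't overload the model, e.g.
#   demo = demo.queue()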
if __name__ == "__main__":
demo.launch()