File size: 4,444 Bytes
87e219b 50bb5db 2dc57ec c3543c7 2dc57ec c3543c7 2dc57ec 6df87c2 2dc57ec 15c5b99 50bb5db 6df87c2 50bb5db 6df87c2 15c5b99 50bb5db 2dc57ec 50bb5db c3543c7 2dc57ec c3543c7 6df87c2 50bb5db 2dc57ec 50bb5db c3543c7 2dc57ec c3543c7 2dc57ec 15c5b99 8d2d1dc 50bb5db 8d2d1dc 2dc57ec 8d2d1dc 6df87c2 8d2d1dc 2dc57ec 50bb5db c3543c7 2dc57ec 6df87c2 2dc57ec 6df87c2 2dc57ec 6df87c2 2dc57ec 50bb5db c3543c7 2dc57ec c3543c7 2dc57ec c3543c7 2dc57ec c3543c7 60c03ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import html
import logging
import signal
import sys

import gradio as gr
import torch
import transformers
# Setup logging
# INFO level so per-request messages from respond() are visible on the console.
logging.basicConfig(level=logging.INFO)
# Constants
# Maximum number of past (user, bot) exchange pairs included in the prompt;
# keeps the GPT-2 context from growing without bound.
MAX_HISTORY_LENGTH = 5 # Adjust as needed
# Graceful shutdown handler
def shutdown_handler(signum, frame):
    """Log a shutdown notice and exit cleanly on SIGINT (Ctrl+C).

    Args:
        signum: The received signal number.
        frame: The interrupted stack frame (unused).
    """
    logging.info("Shutting down gracefully...")
    # sys.exit() instead of the site-injected exit(): the latter is not
    # guaranteed to exist (e.g. under `python -S` or in frozen builds).
    sys.exit(0)
signal.signal(signal.SIGINT, shutdown_handler)
def system_message_selector(choice, custom_message):
    """Resolve the system message for the conversation.

    A non-empty custom message always wins; otherwise the preset matching
    the persona choice is used, falling back to a generic assistant prompt.

    Args:
        choice: The persona name picked in the UI radio group.
        custom_message: Optional free-text override entered by the user.

    Returns:
        The system message string to prepend to the prompt.
    """
    if custom_message:
        return custom_message
    presets = {
        "Friendly Chatbot": "You are a friendly and helpful chatbot.",
        "Professional Assistant": "You are a highly knowledgeable and professional assistant.",
        "Curious Researcher": "You are a curious researcher who loves to explore new ideas.",
    }
    return presets.get(choice, "You are a helpful assistant.")
def sanitize_input(text):
    """Return *text* with HTML-special characters escaped.

    Prevents raw user input from being interpreted as markup when it is
    echoed back into the chat UI.
    """
    escaped = html.escape(text)
    return escaped
def validate_parameters(max_tokens, temperature, top_p):
    """Check that the sampling parameters fall inside their allowed ranges.

    Args:
        max_tokens: Maximum number of new tokens to generate (1..1024).
        temperature: Sampling temperature (0.1..4.0).
        top_p: Nucleus-sampling threshold (0.1..1.0).

    Returns:
        A ``(is_valid, error_message)`` tuple; the message is empty when valid.
    """
    constraints = (
        (1 <= max_tokens <= 1024, "Error: 'Max new tokens' must be between 1 and 1024."),
        (0.1 <= temperature <= 4.0, "Error: 'Temperature' must be between 0.1 and 4.0."),
        (0.1 <= top_p <= 1.0, "Error: 'Top-p' must be between 0.1 and 1.0."),
    )
    for in_range, message in constraints:
        if not in_range:
            return False, message
    return True, ""
# Determine the device
# Prefer GPU when available; the model and all input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the model and tokenizer
model_name = "gpt2" # Use GPT-2 model
try:
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    # eval() switches off dropout and other training-only behavior.
    model.eval()
except Exception as e:
    # A missing model or a download failure is fatal for the app: log and stop.
    logging.error(f"Failed to load model {model_name}: {e}")
    exit(1)
def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
    """Generate one assistant reply for the Gradio ChatInterface.

    Args:
        message: The latest user message.
        history: Previous turns as (user, bot) pairs; the bot slot may be
            None for a still-pending turn (Gradio tuple format).
        persona_choice: Selected persona from the radio group.
        custom_persona: Optional free-text system message override.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        The generated reply, a validation error message, or a generic
        failure message if generation raised.
    """
    is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
    if not is_valid:
        return error_message
    safe_message = sanitize_input(message)
    # Guard against None entries: Gradio can pass None for a pending bot
    # turn, and html.escape(None) would raise AttributeError.
    safe_history = [
        (sanitize_input(u or ""), sanitize_input(b or "")) for u, b in history
    ]
    truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
    system_message = system_message_selector(persona_choice, custom_persona)
    # Build the conversation prompt (plain text: GPT-2 has no chat template)
    conversation = system_message + "\n\n"
    for user_msg, bot_msg in truncated_history:
        conversation += f"User: {user_msg}\n"
        conversation += f"Assistant: {bot_msg}\n"
    conversation += f"User: {safe_message}\nAssistant:"
    logging.info(f"Received message: {safe_message}")
    try:
        input_ids = tokenizer.encode(conversation, return_tensors="pt").to(device)
        # Pure inference: skip autograd bookkeeping to save memory/time.
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        generated_text = tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
        )
        return generated_text.strip()
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return "I'm sorry, but something went wrong. Please try again."
# Create the UI components
# Persona presets; resolved to system messages by system_message_selector().
system_message_radio = gr.Radio(
    choices=["Friendly Chatbot", "Professional Assistant", "Curious Researcher"],
    value="Friendly Chatbot",
    label="Choose a Persona",
)
# Free-text override: any non-empty value here wins over the radio choice.
system_message_textbox = gr.Textbox(
    placeholder="Enter custom persona or system message...",
    label="Custom Persona (Optional)",
)
# Slider bounds mirror the ranges enforced by validate_parameters().
max_tokens_slider = gr.Slider(
    minimum=1, maximum=1024, value=50, step=1, label="Max new tokens"
)
temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"
)
top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"
)
# Create the ChatInterface
# NOTE(review): `allow_reset_history` is not a gr.ChatInterface parameter and
# would raise TypeError at startup, so it has been removed; the built-in
# clear/retry controls cover history reset.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        system_message_radio,
        system_message_textbox,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ],
    title="Customizable Chatbot Interface",
    description="Choose a persona or enter a custom one, and adjust parameters as needed.",
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()
|