import html
import logging
import signal
import sys

import gradio as gr
import torch
import transformers

# Module-level logger is preferred over the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maximum number of (user, bot) exchanges kept when building the prompt.
MAX_HISTORY_LENGTH = 5  # Adjust as needed


def shutdown_handler(signum, frame):
    """Log and exit cleanly when SIGINT (Ctrl-C) is received."""
    logger.info("Shutting down gracefully...")
    # sys.exit is the canonical form; the builtin exit() is a REPL convenience.
    sys.exit(0)


signal.signal(signal.SIGINT, shutdown_handler)


def system_message_selector(choice, custom_message):
    """Return the system prompt for the conversation.

    A non-empty custom message always wins; otherwise the persona choice is
    mapped to its canned prompt, falling back to a generic assistant prompt.
    """
    if custom_message:
        return custom_message
    personas = {
        "Friendly Chatbot": "You are a friendly and helpful chatbot.",
        "Professional Assistant": "You are a highly knowledgeable and professional assistant.",
        "Curious Researcher": "You are a curious researcher who loves to explore new ideas.",
    }
    return personas.get(choice, "You are a helpful assistant.")


def sanitize_input(text):
    """HTML-escape user text so it cannot inject markup into the chat UI."""
    return html.escape(text)


def validate_parameters(max_tokens, temperature, top_p):
    """Validate the generation sliders.

    Returns:
        (is_valid, error_message): error_message is "" when valid.
    """
    if not (1 <= max_tokens <= 1024):
        return False, "Error: 'Max new tokens' must be between 1 and 1024."
    if not (0.1 <= temperature <= 4.0):
        return False, "Error: 'Temperature' must be between 0.1 and 4.0."
    if not (0.1 <= top_p <= 1.0):
        return False, "Error: 'Top-p' must be between 0.1 and 1.0."
    return True, ""


# Prefer GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and tokenizer once at import time; abort if that fails.
model_name = "gpt2"  # Use GPT-2 model
try:
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    model.eval()  # inference mode (disables dropout etc.)
except Exception as e:
    logger.error(f"Failed to load model {model_name}: {e}")
    sys.exit(1)


def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
    """Generate a reply for the Gradio ChatInterface.

    Args:
        message: latest user message.
        history: list of (user, bot) message pairs from the chat so far.
        persona_choice, custom_persona, max_tokens, temperature, top_p:
            values of the `additional_inputs` widgets, in wiring order.

    Returns:
        The generated reply string, or a user-facing error string when
        validation fails or generation raises.
    """
    is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
    if not is_valid:
        return error_message

    safe_message = sanitize_input(message)
    safe_history = [(sanitize_input(u), sanitize_input(b)) for u, b in history]
    # Keep only the most recent exchanges to bound the prompt length.
    truncated_history = safe_history[-MAX_HISTORY_LENGTH:]

    system_message = system_message_selector(persona_choice, custom_persona)

    # Build a plain-text conversation prompt (GPT-2 has no chat template).
    parts = [system_message, ""]
    for user_msg, bot_msg in truncated_history:
        parts.append(f"User: {user_msg}")
        parts.append(f"Assistant: {bot_msg}")
    parts.append(f"User: {safe_message}")
    conversation = "\n".join(parts) + "\nAssistant:"

    logger.info(f"Received message: {safe_message}")
    try:
        input_ids = tokenizer.encode(conversation, return_tensors="pt").to(device)
        # no_grad: inference only — avoids building the autograd graph.
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        generated_text = tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:],
            skip_special_tokens=True,
        )
        return generated_text.strip()
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return "I'm sorry, but something went wrong. Please try again."
# --- UI components wired into the ChatInterface's additional inputs ---
system_message_radio = gr.Radio(
    choices=["Friendly Chatbot", "Professional Assistant", "Curious Researcher"],
    value="Friendly Chatbot",
    label="Choose a Persona",
)
system_message_textbox = gr.Textbox(
    placeholder="Enter custom persona or system message...",
    label="Custom Persona (Optional)",
)
max_tokens_slider = gr.Slider(
    minimum=1, maximum=1024, value=50, step=1, label="Max new tokens"
)
temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"
)
top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"
)

# Chat UI: `respond` receives (message, history, *additional_inputs) in the
# order the widgets are listed below.
# NOTE(review): the original passed allow_reset_history=True, which is not a
# gr.ChatInterface parameter and raises TypeError at startup — removed.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        system_message_radio,
        system_message_textbox,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ],
    title="Customizable Chatbot Interface",
    description="Choose a persona or enter a custom one, and adjust parameters as needed.",
)

if __name__ == "__main__":
    demo.launch()