File size: 4,444 Bytes
87e219b
50bb5db
 
2dc57ec
 
 
 
 
 
c3543c7
2dc57ec
 
 
 
 
 
 
 
 
c3543c7
2dc57ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6df87c2
 
2dc57ec
 
 
 
 
 
15c5b99
 
 
50bb5db
6df87c2
50bb5db
 
6df87c2
15c5b99
 
50bb5db
 
 
 
 
2dc57ec
 
 
50bb5db
c3543c7
2dc57ec
 
 
 
c3543c7
6df87c2
50bb5db
2dc57ec
50bb5db
 
 
c3543c7
2dc57ec
c3543c7
2dc57ec
15c5b99
8d2d1dc
 
 
 
50bb5db
 
8d2d1dc
 
 
2dc57ec
8d2d1dc
6df87c2
 
 
8d2d1dc
2dc57ec
 
50bb5db
c3543c7
 
 
 
 
 
 
 
2dc57ec
 
 
 
 
 
6df87c2
2dc57ec
 
 
6df87c2
2dc57ec
 
 
6df87c2
2dc57ec
 
50bb5db
c3543c7
2dc57ec
c3543c7
 
2dc57ec
 
 
 
c3543c7
2dc57ec
 
 
c3543c7
60c03ab
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import html
import logging
import signal
import sys

import gradio as gr
import torch
import transformers

# Configure root-logger output at INFO level for the whole process.
logging.basicConfig(level=logging.INFO)

# Constants
MAX_HISTORY_LENGTH = 5  # Max number of (user, bot) turns kept in the prompt; adjust as needed.

# Graceful shutdown handler
def shutdown_handler(signum, frame):
    """Log a shutdown message and exit cleanly on SIGINT (Ctrl+C).

    Args:
        signum: Signal number delivered by the OS.
        frame: Current stack frame at interrupt time (unused).
    """
    logging.info("Shutting down gracefully...")
    # sys.exit raises SystemExit with code 0; the site-provided `exit()`
    # builtin is intended for interactive sessions and may be missing
    # when Python runs without the `site` module (python -S).
    sys.exit(0)

signal.signal(signal.SIGINT, shutdown_handler)

def system_message_selector(choice, custom_message):
    """Resolve the system message for the conversation.

    A non-empty custom message always wins; otherwise the preset matching
    *choice* is used, falling back to a generic assistant persona.
    """
    presets = {
        "Friendly Chatbot": "You are a friendly and helpful chatbot.",
        "Professional Assistant": "You are a highly knowledgeable and professional assistant.",
        "Curious Researcher": "You are a curious researcher who loves to explore new ideas.",
    }
    if custom_message:
        return custom_message
    return presets.get(choice, "You are a helpful assistant.")

def sanitize_input(text):
    """Escape HTML-special characters (&, <, >, and quotes) so user text
    cannot inject markup into the rendered chat."""
    return html.escape(text, quote=True)

def validate_parameters(max_tokens, temperature, top_p):
    """Check generation parameters against their UI slider ranges.

    Returns:
        (True, "") when every value is in range, otherwise
        (False, <error message>) for the first violation found,
        checked in order: max_tokens, temperature, top_p.
    """
    checks = (
        (1 <= max_tokens <= 1024, "Error: 'Max new tokens' must be between 1 and 1024."),
        (0.1 <= temperature <= 4.0, "Error: 'Temperature' must be between 0.1 and 4.0."),
        (0.1 <= top_p <= 1.0, "Error: 'Top-p' must be between 0.1 and 1.0."),
    )
    for in_range, message in checks:
        if not in_range:
            return False, message
    return True, ""

# Determine the device: prefer GPU when CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and tokenizer
model_name = "gpt2"  # Use GPT-2 model

try:
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    # Inference-only service: eval mode disables dropout and other
    # training-time behavior.
    model.eval()
except Exception as e:
    # Loading can fail for many reasons (no network, corrupt cache, OOM);
    # abort at startup rather than serve a broken app.
    logging.error(f"Failed to load model {model_name}: {e}")
    exit(1)

def respond(message, history, persona_choice, custom_persona, max_tokens, temperature, top_p):
    """Generate one assistant reply for the Gradio ChatInterface.

    Args:
        message: The latest user message (raw text from the textbox).
        history: List of (user, assistant) string pairs from prior turns.
        persona_choice: Selected preset persona label from the radio group.
        custom_persona: Optional free-text system message; overrides the preset.
        max_tokens: Maximum number of new tokens to generate (1-1024).
        temperature: Sampling temperature (0.1-4.0).
        top_p: Nucleus-sampling probability mass (0.1-1.0).

    Returns:
        The generated reply text, or a human-readable error string when
        parameters are out of range or generation fails.
    """
    is_valid, error_message = validate_parameters(max_tokens, temperature, top_p)
    if not is_valid:
        return error_message

    # HTML-escape everything that goes into the prompt so user text cannot
    # inject markup into the rendered chat transcript.
    safe_message = sanitize_input(message)
    safe_history = [(sanitize_input(u), sanitize_input(b)) for u, b in history]
    # Keep only the most recent turns to bound prompt growth.
    truncated_history = safe_history[-MAX_HISTORY_LENGTH:]
    system_message = system_message_selector(persona_choice, custom_persona)

    # Build the plain-text conversation prompt GPT-2 will continue.
    conversation = system_message + "\n\n"
    for user_msg, bot_msg in truncated_history:
        conversation += f"User: {user_msg}\n"
        conversation += f"Assistant: {bot_msg}\n"
    conversation += f"User: {safe_message}\nAssistant:"

    logging.info(f"Received message: {safe_message}")

    try:
        # Sliders can deliver floats; generate() expects an int count.
        max_new = int(max_tokens)
        input_ids = tokenizer.encode(conversation, return_tensors="pt")

        # GPT-2 has a fixed context window (n_positions, 1024 tokens).
        # Keep only the most recent prompt tokens so prompt + generation
        # never exceeds it; otherwise generate() errors on long chats.
        context_limit = getattr(model.config, "n_positions", 1024)
        max_prompt_len = max(context_limit - max_new, 1)
        if input_ids.shape[-1] > max_prompt_len:
            input_ids = input_ids[:, -max_prompt_len:]
        input_ids = input_ids.to(device)

        with torch.no_grad():  # inference only; skip autograd bookkeeping
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_new,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated continuation, not the prompt.
        generated_text = tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
        )
        return generated_text.strip()
    except Exception as e:
        # Top-level boundary for the UI callback: log and return a
        # friendly message instead of surfacing a traceback to the user.
        logging.error(f"An error occurred: {e}")
        return "I'm sorry, but something went wrong. Please try again."

# Create the UI components
# Preset persona picker; the custom textbox below overrides it when filled
# (see system_message_selector).
system_message_radio = gr.Radio(
    choices=["Friendly Chatbot", "Professional Assistant", "Curious Researcher"],
    value="Friendly Chatbot",
    label="Choose a Persona",
)

# Optional free-text system message; any non-empty value takes precedence
# over the radio selection.
system_message_textbox = gr.Textbox(
    placeholder="Enter custom persona or system message...",
    label="Custom Persona (Optional)",
)

# Slider ranges mirror the bounds enforced in validate_parameters.
max_tokens_slider = gr.Slider(
    minimum=1, maximum=1024, value=50, step=1, label="Max new tokens"
)

temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"
)

top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"
)

# Create the ChatInterface wiring respond() to the controls above.
# The extra controls are appended after the chat box via additional_inputs,
# in the same order as respond()'s trailing parameters.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        system_message_radio,
        system_message_textbox,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ],
    # NOTE: the previous `allow_reset_history=True` is not a valid
    # gr.ChatInterface keyword argument and raises TypeError at
    # construction; ChatInterface already provides a built-in clear
    # control, so no replacement kwarg is needed.
    title="Customizable Chatbot Interface",
    description="Choose a persona or enter a custom one, and adjust parameters as needed.",
)

if __name__ == "__main__":
    demo.launch()