Spaces:
Sleeping
Sleeping
Update app.py
#1
by WavyHec - opened
app.py
CHANGED
|
@@ -192,7 +192,12 @@ def stylize_reply(reply: str, persona_name: str) -> str:
|
|
| 192 |
return reply
|
| 193 |
|
| 194 |
|
| 195 |
-
def generate_reply(history, persona_name, tts_enabled):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
try:
|
| 197 |
model.set_adapter(persona_name)
|
| 198 |
except Exception as e:
|
|
@@ -205,12 +210,22 @@ def generate_reply(history, persona_name, tts_enabled):
|
|
| 205 |
inputs = tokenizer(prompt, return_tensors="pt")
|
| 206 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 207 |
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
with torch.no_grad():
|
| 211 |
output_ids = model.generate(
|
| 212 |
**inputs,
|
| 213 |
-
max_new_tokens=
|
| 214 |
do_sample=True,
|
| 215 |
top_p=params["top_p"],
|
| 216 |
temperature=params["temperature"],
|
|
@@ -242,35 +257,91 @@ def generate_reply(history, persona_name, tts_enabled):
|
|
| 242 |
return history, history, audio_path
|
| 243 |
|
| 244 |
|
| 245 |
-
# ---------------- GRADIO UI ----------------
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
with gr.Row():
|
| 250 |
persona_dropdown = gr.Dropdown(
|
| 251 |
choices=list(ADAPTER_PATHS.keys()),
|
| 252 |
-
value=
|
| 253 |
-
label="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
)
|
| 255 |
-
tts_checkbox = gr.Checkbox(label="Text-to-Audio", value=False)
|
| 256 |
|
| 257 |
-
|
| 258 |
-
msg = gr.Textbox(label="Your message")
|
| 259 |
-
audio_out = gr.Audio(label="Voice", autoplay=True)
|
| 260 |
clear_btn = gr.Button("Clear Chat")
|
| 261 |
|
| 262 |
def user_submit(user_message, history):
|
| 263 |
history = history or []
|
|
|
|
|
|
|
| 264 |
return "", history + [[user_message, None]]
|
| 265 |
|
| 266 |
msg.submit(
|
| 267 |
-
user_submit,
|
|
|
|
|
|
|
|
|
|
| 268 |
).then(
|
| 269 |
-
generate_reply,
|
|
|
|
|
|
|
| 270 |
)
|
| 271 |
|
| 272 |
clear_btn.click(lambda: ([], None), outputs=[chat, audio_out])
|
| 273 |
|
| 274 |
|
| 275 |
if __name__ == "__main__":
|
| 276 |
-
demo.launch()
|
|
|
|
| 192 |
return reply
|
| 193 |
|
| 194 |
|
| 195 |
+
def generate_reply(history, persona_name, tts_enabled, temperature=0.8, max_tokens=120):
|
| 196 |
+
"""
|
| 197 |
+
history: chatbot history with last entry [user, None].
|
| 198 |
+
persona_name: which adapter/persona to use.
|
| 199 |
+
temperature, max_tokens: UI-controlled; override persona defaults lightly.
|
| 200 |
+
"""
|
| 201 |
try:
|
| 202 |
model.set_adapter(persona_name)
|
| 203 |
except Exception as e:
|
|
|
|
| 210 |
inputs = tokenizer(prompt, return_tensors="pt")
|
| 211 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 212 |
|
| 213 |
+
# Start from persona defaults
|
| 214 |
+
params = PERSONA_GEN_PARAMS.get(
|
| 215 |
+
persona_name, {"temperature": 0.8, "top_p": 0.9}
|
| 216 |
+
).copy()
|
| 217 |
+
|
| 218 |
+
# Override temperature if slider is set
|
| 219 |
+
if temperature is not None:
|
| 220 |
+
params["temperature"] = float(temperature)
|
| 221 |
+
|
| 222 |
+
# Clamp / cast max_tokens
|
| 223 |
+
max_tokens = int(max_tokens) if max_tokens is not None else 120
|
| 224 |
|
| 225 |
with torch.no_grad():
|
| 226 |
output_ids = model.generate(
|
| 227 |
**inputs,
|
| 228 |
+
max_new_tokens=max_tokens,
|
| 229 |
do_sample=True,
|
| 230 |
top_p=params["top_p"],
|
| 231 |
temperature=params["temperature"],
|
|
|
|
| 257 |
return history, history, audio_path
|
| 258 |
|
| 259 |
|
| 260 |
+
# ---------------- GRADIO UI (UPDATED) ----------------
|
| 261 |
+
|
| 262 |
+
# Custom CSS for UTRGV orange theme
|
| 263 |
+
custom_css = """
|
| 264 |
+
.gradio-container {
|
| 265 |
+
background: #1a1a1a !important;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
h1, h2, h3 {
|
| 269 |
+
color: #FF6600 !important;
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
label {
|
| 273 |
+
color: #FF6600 !important;
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
.message.user {
|
| 277 |
+
background: #FF6600 !important;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
input[type="range"] {
|
| 281 |
+
accent-color: #FF6600 !important;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
input:focus, textarea:focus, select:focus {
|
| 285 |
+
border-color: #FF6600 !important;
|
| 286 |
+
}
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
|
| 290 |
+
gr.Markdown("# Multi-Personality AI Chatbot")
|
| 291 |
|
| 292 |
with gr.Row():
|
| 293 |
persona_dropdown = gr.Dropdown(
|
| 294 |
choices=list(ADAPTER_PATHS.keys()),
|
| 295 |
+
value=first_persona,
|
| 296 |
+
label="Select Personality",
|
| 297 |
+
)
|
| 298 |
+
tts_checkbox = gr.Checkbox(label="Enable Text-to-Speech", value=False)
|
| 299 |
+
|
| 300 |
+
chat = gr.Chatbot(label="Conversation")
|
| 301 |
+
|
| 302 |
+
msg = gr.Textbox(
|
| 303 |
+
label="Your message",
|
| 304 |
+
placeholder="Type your message...",
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
with gr.Row():
|
| 308 |
+
temperature = gr.Slider(
|
| 309 |
+
minimum=0.1,
|
| 310 |
+
maximum=1.5,
|
| 311 |
+
value=0.8,
|
| 312 |
+
step=0.1,
|
| 313 |
+
label="Temperature",
|
| 314 |
+
)
|
| 315 |
+
max_tokens = gr.Slider(
|
| 316 |
+
minimum=50,
|
| 317 |
+
maximum=500,
|
| 318 |
+
value=120,
|
| 319 |
+
step=10,
|
| 320 |
+
label="Max Tokens",
|
| 321 |
)
|
|
|
|
| 322 |
|
| 323 |
+
audio_out = gr.Audio(label="Audio Response", autoplay=True)
|
|
|
|
|
|
|
| 324 |
clear_btn = gr.Button("Clear Chat")
|
| 325 |
|
| 326 |
def user_submit(user_message, history):
|
| 327 |
history = history or []
|
| 328 |
+
if not user_message.strip():
|
| 329 |
+
return "", history
|
| 330 |
return "", history + [[user_message, None]]
|
| 331 |
|
| 332 |
msg.submit(
|
| 333 |
+
user_submit,
|
| 334 |
+
[msg, chat],
|
| 335 |
+
[msg, chat],
|
| 336 |
+
queue=False,
|
| 337 |
).then(
|
| 338 |
+
generate_reply,
|
| 339 |
+
[chat, persona_dropdown, tts_checkbox, temperature, max_tokens],
|
| 340 |
+
[chat, chat, audio_out],
|
| 341 |
)
|
| 342 |
|
| 343 |
clear_btn.click(lambda: ([], None), outputs=[chat, audio_out])
|
| 344 |
|
| 345 |
|
| 346 |
if __name__ == "__main__":
|
| 347 |
+
demo.launch(share=False, server_name="127.0.0.1", show_error=True, inbrowser=True)
|