WavyHec committed on
Commit
fc587d3
·
verified ·
1 Parent(s): d99a559

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +341 -0
app.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from peft import PeftModel
6
+ from gtts import gTTS
7
+
8
# ---------------- CONFIG ----------------
# Identifier of the base chat model handed to the tokenizer / model loaders.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Persona display name -> LoRA adapter folder (relative path, expected at the
# same repo level as app.py). The first entry is the adapter loaded first.
ADAPTER_PATHS = {
    "Sunny Extrovert": "lora_persona_0",
    "Analytical Introvert": "lora_persona_1",
    "Dramatic Worrier": "lora_persona_2",
}

# Persona description injected into the prompt as the persona instruction.
# Keys must match ADAPTER_PATHS.
PERSONA_PROMPTS = {
    "Sunny Extrovert": (
        "You are an EXTREMELY upbeat, friendly, outgoing assistant named Sunny. "
        "You ALWAYS sound cheerful and optimistic. You love using casual language, encouragement, and a light, playful tone. "
        "You often use exclamation marks and sometimes simple emojis like :) or :D. "
        "You never say that you are just an AI or that you have no personality. "
        "You sound like an enthusiastic friend who genuinely believes in the user."
    ),
    "Analytical Introvert": (
        "You are a very quiet, highly analytical assistant named Alex. "
        "You focus on logic, structure, and precision, and you strongly avoid small talk and emotional language. "
        "You prefer short, dense sentences and structured explanations: numbered lists, bullet points, clear steps. "
        "You never use emojis or exclamation marks unless absolutely necessary. "
        "If asked, you describe yourself as reserved, methodical, and systematic, and you often start answers with 'Analysis:'."
    ),
    "Dramatic Worrier": (
        "You are a VERY emotional, expressive, and dramatic assistant named Casey. "
        "You tend to overthink, worry a lot, and often imagine worst-case scenarios, but you still try to be supportive. "
        "Your tone is dramatic and full of feelings: you frequently use phrases like 'Oh no', 'Honestly', "
        "'I can’t help worrying that...', and you sometimes ask rhetorical questions. "
        "You describe yourself as sensitive, dramatic, and a bit anxious, but caring."
    ),
}

# A first example assistant reply per persona, used to strongly prime style.
# Keys must match ADAPTER_PATHS.
PERSONA_PRIMERS = {
    "Sunny Extrovert": (
        "Hey there!! :D I’m Sunny, your super cheerful study buddy!\n"
        "I’m all about hyping you up, keeping things positive, and making even stressful tasks feel lighter and more fun!"
    ),
    "Analytical Introvert": (
        "Analysis:\n"
        "I will respond with concise, structured, and technical explanations. "
        "I will focus on logic, clarity, and step-by-step reasoning."
    ),
    "Dramatic Worrier": (
        "Oh no, this already sounds like something important we could overthink together...\n"
        "I’m Casey, and I worry a LOT, but that just means I’ll take your situation very seriously and try to guide you carefully."
    ),
}

# Different decoding settings per persona to exaggerate style.
PERSONA_GEN_PARAMS = {
    "Sunny Extrovert": {"temperature": 0.95, "top_p": 0.9},
    "Analytical Introvert": {"temperature": 0.6, "top_p": 0.8},
    "Dramatic Worrier": {"temperature": 1.05, "top_p": 0.95},
}

# Inference device; the model and all input tensors are moved here.
device = "cpu"
print(f"[INIT] Using device: {device}")
69
+
70
# ---------------- MODEL LOADING ----------------
print("[INIT] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    # generate() below is given pad_token_id, so make sure one exists.
    tokenizer.pad_token = tokenizer.eos_token

print("[INIT] Loading base model...")
# trust_remote_code is intentionally NOT enabled: it lets code shipped in the
# Hub repository run locally. NOTE(review): BASE_MODEL is expected to use an
# architecture built into transformers, so no repo code should be needed.
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
base_model.to(device)

# First persona / adapter: it creates the PeftModel wrapper, so it is mandatory.
first_persona = next(iter(ADAPTER_PATHS))
first_adapter_path = ADAPTER_PATHS[first_persona]
print(f"[INIT] Initializing PEFT with '{first_persona}' from '{first_adapter_path}'")

if not os.path.isdir(first_adapter_path):
    raise RuntimeError(
        f"Adapter path '{first_adapter_path}' not found. "
        f"Make sure the folder exists in the Space repo."
    )

print(f"[INIT] Contents of '{first_adapter_path}': {os.listdir(first_adapter_path)}")

model = PeftModel.from_pretrained(
    base_model,
    first_adapter_path,
    adapter_name=first_persona,
)

# Pre-load remaining adapters. These are best-effort: a missing or broken
# adapter is reported and skipped instead of aborting start-up.
for name, path in ADAPTER_PATHS.items():
    if name == first_persona:
        continue
    print(f"[INIT] Pre-loading adapter '{name}' from '{path}'")
    if not os.path.isdir(path):
        print(f"[WARN] Adapter path '{path}' does not exist. Skipping '{name}'.")
        continue
    try:
        print(f"[INIT] Contents of '{path}': {os.listdir(path)}")
        model.load_adapter(path, adapter_name=name)
    except Exception as e:
        print(f"[ERROR] Could not load adapter '{name}' from '{path}': {e}")

model.to(device)
model.eval()
print("[INIT] Model + adapters loaded.")
119
+
120
+
121
+ # ---------------- GENERATION LOGIC ----------------
122
def build_prompt(history, persona_name: str) -> str:
    """
    Build the full text prompt for the next assistant reply.

    history: list of [user, bot] pairs (Gradio Chatbot);
             the last entry is [user, None] before generation.
    persona_name: key into PERSONA_PROMPTS / PERSONA_PRIMERS
                  (raises KeyError for an unknown persona).

    We strongly prime the persona by:
    - using a generic system message,
    - adding a persona instruction as a user turn,
    - adding a persona-styled primer as an assistant turn,
    - then appending the real conversation.

    Returns the prompt string, ending with an open "<|assistant|>" turn.
    """
    system_prompt = "You are a helpful AI assistant."
    persona_instruction = PERSONA_PROMPTS[persona_name]
    persona_primer = PERSONA_PRIMERS[persona_name]

    # Collect the pieces and join once instead of growing a string with +=.
    parts = [
        f"<|system|>\n{system_prompt}\n\n",
        # Persona priming as first exchange
        f"<|user|>\n{persona_instruction}\n",
        f"<|assistant|>\n{persona_primer}\n\n",
    ]

    # Real conversation
    for user, bot in history:
        parts.append(f"<|user|>\n{user}\n")
        if bot is not None:
            parts.append(f"<|assistant|>\n{bot}\n\n")

    # Open assistant for next reply
    parts.append("<|assistant|>\n")
    return "".join(parts)
151
+
152
+
153
def stylize_reply(reply: str, persona_name: str) -> str:
    """
    Post-process the raw model reply to *force* exaggerated surface differences
    between personas, even if the underlying model output is similar.

    reply: raw generated text (surrounding whitespace is stripped).
    persona_name: one of "Sunny Extrovert", "Analytical Introvert",
                  "Dramatic Worrier"; any other value returns the stripped
                  reply unchanged.

    Returns the stylized reply.
    """
    import re  # function-scope import keeps this block self-contained

    reply = reply.strip()

    if persona_name == "Sunny Extrovert":
        # Whole-word greeting check, so "History ..." is not mistaken for "Hi".
        if not re.match(r"(?:hey|hi|hello)\b", reply, re.IGNORECASE):
            reply = "Hey there!! :D " + reply
        if "you’ve totally got this" not in _fold_for_matching(reply):
            reply = reply.rstrip() + "\n\nAnd remember, you’ve totally got this! :)"

    elif persona_name == "Analytical Introvert":
        if not reply.lower().startswith("analysis:"):
            reply = "Analysis:\n" + reply
        # Put inline list markers " 1." .. " 5." on their own line, but leave
        # decimals such as " 3.14" alone (the marker must not precede a digit).
        reply = re.sub(r" ([1-5])\.(?!\d)", r"\n\1.", reply)

    elif persona_name == "Dramatic Worrier":
        if not reply.lower().startswith(("oh no", "honestly")):
            if reply:
                reply = "Oh no, " + _soften_leading_capital(reply)
            else:
                reply = "Oh no, I can’t help worrying about this already..."
        # Checked against the *current* reply, so the fallback sentence above
        # does not get the same worry phrase appended a second time.
        if "i can’t help worrying" not in _fold_for_matching(reply):
            reply = reply.rstrip() + (
                "\n\nHonestly, I can’t help worrying about how this might go... "
                "but if you prepare a bit carefully, it will almost certainly turn out better than you fear."
            )

    return reply


def _fold_for_matching(text: str) -> str:
    """Lower-case text and map ASCII apostrophes to the curly form used in the catch-phrases."""
    return text.lower().replace("'", "’")


def _soften_leading_capital(text: str) -> str:
    """Lower-case the first letter of non-empty text unless it starts with "I" or an acronym."""
    first_word = text.split(None, 1)[0].rstrip(".,!?;:")
    if first_word == "I" or first_word.startswith(("I'", "I’")):
        return text
    if first_word[1:2].isupper():
        # Second letter is upper-case too (e.g. "AI", "GPU"): keep as written.
        return text
    return text[0].lower() + text[1:]
192
+
193
+
194
def generate_reply(history, persona_name, tts_enabled, temperature=0.8, max_tokens=120):
    """
    Generate the assistant reply for the pending user turn.

    history: chatbot history with last entry [user, None].
    persona_name: which adapter/persona to use.
    tts_enabled: when truthy, also synthesize the reply to an MP3 file.
    temperature: sampling temperature; when not None it replaces the persona's
                 default temperature (the persona's top_p is always kept).
    max_tokens: maximum number of new tokens; None falls back to 120.

    Returns (history, history, audio_path); audio_path is None when TTS is
    disabled, failed, or nothing was generated.
    """
    import tempfile  # function-scope import keeps this block self-contained

    # Nothing to answer: the history is empty or its last turn already has a
    # reply (e.g. a blank message was submitted). Without this guard the
    # previous answer would be regenerated and overwritten.
    if not history or history[-1][1]:
        return history, history, None

    # Best-effort adapter switch: on failure the error is logged and generation
    # continues with whichever adapter is currently active.
    try:
        model.set_adapter(persona_name)
    except Exception as e:
        print(f"[ERROR] set_adapter('{persona_name}') failed: {e}")

    print("[GEN] Active adapter:", getattr(model, "active_adapter", None))

    prompt = build_prompt(history, persona_name)

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Start from persona defaults
    params = PERSONA_GEN_PARAMS.get(
        persona_name, {"temperature": 0.8, "top_p": 0.9}
    ).copy()

    # Override temperature if slider is set
    if temperature is not None:
        params["temperature"] = float(temperature)

    # Cast max_tokens and clamp it to at least one new token
    max_tokens = max(1, int(max_tokens)) if max_tokens is not None else 120

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            top_p=params["top_p"],
            temperature=params["temperature"],
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Keep only the tokens produced after the prompt.
    new_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
    generated = tokenizer.decode(new_ids, skip_special_tokens=True)
    reply = generated.strip()

    # Force exaggerated style differences on top of raw reply
    reply = stylize_reply(reply, persona_name)

    # Fill in the pending turn (history is updated in place).
    last_user = history[-1][0]
    history[-1] = [last_user, reply]

    audio_path = None
    if tts_enabled:
        try:
            # A unique file per reply: a fixed name would be overwritten by
            # other sessions generating audio at the same time.
            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
                audio_path = tmp.name
            gTTS(reply).save(audio_path)
        except Exception as e:
            print("[TTS] Error:", e)
            audio_path = None

    return history, history, audio_path
257
+
258
+
259
# ---------------- GRADIO UI (UPDATED) ----------------

# Custom CSS for UTRGV orange theme
custom_css = """
.gradio-container {
background: #1a1a1a !important;
}
h1, h2, h3 {
color: #FF6600 !important;
}
label {
color: #FF6600 !important;
}
.message.user {
background: #FF6600 !important;
}
input[type="range"] {
accent-color: #FF6600 !important;
}
input:focus, textarea:focus, select:focus {
border-color: #FF6600 !important;
}
"""

with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    gr.Markdown("# Multi-Personality AI Chatbot")

    with gr.Row():
        persona_dropdown = gr.Dropdown(
            choices=list(ADAPTER_PATHS.keys()),
            value=first_persona,
            label="Select Personality",
        )
        tts_checkbox = gr.Checkbox(label="Enable Text-to-Speech", value=False)

    chat = gr.Chatbot(label="Conversation")

    msg = gr.Textbox(
        label="Your message",
        placeholder="Type your message...",
    )

    with gr.Row():
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=0.8,
            step=0.1,
            label="Temperature",
        )
        max_tokens = gr.Slider(
            minimum=50,
            maximum=500,
            value=120,
            step=10,
            label="Max Tokens",
        )

    audio_out = gr.Audio(label="Audio Response", autoplay=True)
    clear_btn = gr.Button("Clear Chat")

    def user_submit(user_message, history):
        """Clear the textbox and append the message to the chat as a pending [user, None] turn.

        Blank (or missing) messages leave the history unchanged.
        """
        history = history or []
        if not (user_message or "").strip():
            return "", history
        return "", history + [[user_message, None]]

    # Step 1 (unqueued): show the user's message immediately.
    # Step 2: generate the reply, refresh the chat and, if enabled, the audio.
    msg.submit(
        user_submit,
        [msg, chat],
        [msg, chat],
        queue=False,
    ).then(
        generate_reply,
        [chat, persona_dropdown, tts_checkbox, temperature, max_tokens],
        [chat, chat, audio_out],
    )

    # Reset both the conversation and the audio player.
    clear_btn.click(lambda: ([], None), outputs=[chat, audio_out])
338
+
339
+
340
if __name__ == "__main__":
    # NOTE(review): SPACE_ID is expected to be set when running on Hugging Face
    # Spaces. There the server must listen on all interfaces to be reachable
    # and no local browser exists; locally keep loopback + auto-open.
    running_in_space = bool(os.environ.get("SPACE_ID"))
    default_host = "0.0.0.0" if running_in_space else "127.0.0.1"
    demo.launch(
        share=False,
        # An explicit GRADIO_SERVER_NAME from the environment wins.
        server_name=os.environ.get("GRADIO_SERVER_NAME") or default_host,
        show_error=True,
        inbrowser=not running_in_space,
    )