Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| from gtts import gTTS | |
# ---------------- CONFIG ----------------
# Base checkpoint that every persona LoRA adapter is applied on top of.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# LoRA folders in the same repo level as app.py
# Keys double as the persona names shown in the UI dropdown and as
# PEFT adapter names; values are local directories in the Space repo.
ADAPTER_PATHS = {
    "Sunny Extrovert": "lora_persona_0",
    "Analytical Introvert": "lora_persona_1",
    "Dramatic Worrier": "lora_persona_2",
}

# Used as the "system" description of the persona.
# Injected as the first *user* turn of the prompt (see build_prompt),
# not as the system message itself.
PERSONA_PROMPTS = {
    "Sunny Extrovert": (
        "You are an EXTREMELY upbeat, friendly, outgoing assistant named Sunny. "
        "You ALWAYS sound cheerful and optimistic. You love using casual language, encouragement, and a light, playful tone. "
        "You often use exclamation marks and sometimes simple emojis like :) or :D. "
        "You never say that you are just an AI or that you have no personality. "
        "You sound like an enthusiastic friend who genuinely believes in the user."
    ),
    "Analytical Introvert": (
        "You are a very quiet, highly analytical assistant named Alex. "
        "You focus on logic, structure, and precision, and you strongly avoid small talk and emotional language. "
        "You prefer short, dense sentences and structured explanations: numbered lists, bullet points, clear steps. "
        "You never use emojis or exclamation marks unless absolutely necessary. "
        "If asked, you describe yourself as reserved, methodical, and systematic, and you often start answers with 'Analysis:'."
    ),
    "Dramatic Worrier": (
        "You are a VERY emotional, expressive, and dramatic assistant named Casey. "
        "You tend to overthink, worry a lot, and often imagine worst-case scenarios, but you still try to be supportive. "
        "Your tone is dramatic and full of feelings: you frequently use phrases like 'Oh no', 'Honestly', "
        "'I can’t help worrying that...', and you sometimes ask rhetorical questions. "
        "You describe yourself as sensitive, dramatic, and a bit anxious, but caring."
    ),
}

# A first example reply per persona to strongly prime style.
# Injected as the first *assistant* turn of the prompt (see build_prompt)
# so the model imitates the voice in subsequent turns.
PERSONA_PRIMERS = {
    "Sunny Extrovert": (
        "Hey there!! :D I’m Sunny, your super cheerful study buddy!\n"
        "I’m all about hyping you up, keeping things positive, and making even stressful tasks feel lighter and more fun!"
    ),
    "Analytical Introvert": (
        "Analysis:\n"
        "I will respond with concise, structured, and technical explanations. "
        "I will focus on logic, clarity, and step-by-step reasoning."
    ),
    "Dramatic Worrier": (
        "Oh no, this already sounds like something important we could overthink together...\n"
        "I’m Casey, and I worry a LOT, but that just means I’ll take your situation very seriously and try to guide you carefully."
    ),
}

# Different decoding settings per persona to exaggerate style:
# hotter sampling for the expressive personas, cooler for the analytical one.
# The UI temperature slider may override "temperature" (see generate_reply).
PERSONA_GEN_PARAMS = {
    "Sunny Extrovert": {"temperature": 0.95, "top_p": 0.9},
    "Analytical Introvert": {"temperature": 0.6, "top_p": 0.8},
    "Dramatic Worrier": {"temperature": 1.05, "top_p": 0.95},
}
# Runs on CPU (free Space hardware); all tensors are moved here explicitly.
device = "cpu"
print(f"[INIT] Using device: {device}")

# ---------------- MODEL LOADING ----------------
# Module-level, eager loading: tokenizer + base model + all persona adapters
# are loaded once at import time so generate_reply only has to switch adapters.
print("[INIT] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    # Models without a pad token (common for causal LMs) reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

print("[INIT] Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
)
base_model.to(device)

# First persona / adapter: PeftModel.from_pretrained needs one adapter to
# wrap the base model; the remaining adapters are attached afterwards.
first_persona = list(ADAPTER_PATHS.keys())[0]
first_adapter_path = ADAPTER_PATHS[first_persona]
print(f"[INIT] Initializing PEFT with '{first_persona}' from '{first_adapter_path}'")
if not os.path.isdir(first_adapter_path):
    # The first adapter is mandatory — fail fast with a clear message.
    raise RuntimeError(
        f"Adapter path '{first_adapter_path}' not found. "
        f"Make sure the folder exists in the Space repo."
    )
print(f"[INIT] Contents of '{first_adapter_path}': {os.listdir(first_adapter_path)}")
model = PeftModel.from_pretrained(
    base_model,
    first_adapter_path,
    adapter_name=first_persona,
)

# Pre-load remaining adapters so set_adapter() at chat time is instant.
# Unlike the first adapter, missing/broken extras only warn and are skipped.
for name, path in ADAPTER_PATHS.items():
    if name == first_persona:
        continue
    print(f"[INIT] Pre-loading adapter '{name}' from '{path}'")
    if not os.path.isdir(path):
        print(f"[WARN] Adapter path '{path}' does not exist. Skipping '{name}'.")
        continue
    try:
        print(f"[INIT] Contents of '{path}': {os.listdir(path)}")
        model.load_adapter(path, adapter_name=name)
    except Exception as e:
        print(f"[ERROR] Could not load adapter '{name}' from '{path}': {e}")

model.to(device)
model.eval()
print("[INIT] Model + adapters loaded.")
| # ---------------- GENERATION LOGIC ---------------- | |
def build_prompt(history, persona_name: str) -> str:
    """
    Build the TinyLlama-chat-style prompt string for the next generation.

    history: list of [user, bot] pairs (Gradio Chatbot);
        the last entry is [user, None] before generation.
    persona_name: key into PERSONA_PROMPTS / PERSONA_PRIMERS.

    The persona is strongly primed by:
      - a generic system message,
      - the persona instruction injected as a user turn,
      - a persona-styled primer injected as an assistant turn,
      - then the real conversation, ending with an open assistant tag.
    """
    # Fixed priming prefix: system + one synthetic user/assistant exchange.
    segments = [
        "<|system|>\nYou are a helpful AI assistant.\n\n",
        f"<|user|>\n{PERSONA_PROMPTS[persona_name]}\n",
        f"<|assistant|>\n{PERSONA_PRIMERS[persona_name]}\n\n",
    ]

    # Append the real conversation; the pending turn has bot_msg = None,
    # so no assistant segment is emitted for it.
    for user_msg, bot_msg in history:
        segments.append(f"<|user|>\n{user_msg}\n")
        if bot_msg is not None:
            segments.append(f"<|assistant|>\n{bot_msg}\n\n")

    # Open assistant tag so the model continues as the persona.
    segments.append("<|assistant|>\n")
    return "".join(segments)
def stylize_reply(reply: str, persona_name: str) -> str:
    """
    Post-process the raw model reply to *force* exaggerated surface
    differences between personas, even if the underlying model output
    is similar. Unknown personas pass through with only a strip().
    """
    reply = reply.strip()

    if persona_name == "Sunny Extrovert":
        # Ensure a cheerful greeting and a signature encouragement line.
        if not reply.lower().startswith(("hey", "hi", "hello")):
            reply = "Hey there!! :D " + reply
        if "you’ve totally got this" not in reply.lower():
            reply = reply.rstrip() + "\n\nAnd remember, you’ve totally got this! :)"
        return reply

    if persona_name == "Analytical Introvert":
        # Ensure the trademark "Analysis:" opener...
        if not reply.lstrip().lower().startswith("analysis:"):
            reply = "Analysis:\n" + reply
        # ...and break inline numbered items 1-5 onto their own lines.
        for item in range(1, 6):
            reply = reply.replace(f" {item}.", f"\n{item}.")
        return reply

    if persona_name == "Dramatic Worrier":
        # Checks are made against the reply as it arrived, before prefixing.
        arrived = reply.lower()
        if not arrived.startswith(("oh no", "honestly")):
            if reply:
                reply = "Oh no, " + reply[0].lower() + reply[1:]
            else:
                reply = "Oh no, I can’t help worrying about this already..."
        if "i can’t help worrying" not in arrived:
            reply = reply.rstrip() + (
                "\n\nHonestly, I can’t help worrying about how this might go... "
                "but if you prepare a bit carefully, it will almost certainly turn out better than you fear."
            )
        return reply

    return reply
def generate_reply(history, persona_name, tts_enabled, temperature=0.8, max_tokens=120):
    """
    Generate the next assistant reply for the chat.

    history: Gradio Chatbot history, a list of [user, bot] pairs; the last
        entry is expected to be [user, None] (awaiting a reply).
    persona_name: key into ADAPTER_PATHS / PERSONA_* — selects the LoRA
        adapter, prompt priming, and style post-processing.
    tts_enabled: when True, also synthesize the reply to an mp3 via gTTS.
    temperature, max_tokens: UI-controlled; temperature overrides the
        persona default, top_p always stays at the persona default.

    Returns (history, history, audio_path) — history twice because the UI
    wires the same Chatbot component as two outputs; audio_path is None
    when TTS is disabled or fails.
    """
    # Guard: only generate when the last turn is an unanswered user message.
    # Without this, an empty submission (user_submit returns history
    # unchanged) still fires the .then() chain, and the code below would
    # regenerate and OVERWRITE the previous bot reply.
    if not history or history[-1][1] is not None:
        return history, history, None

    try:
        model.set_adapter(persona_name)
    except Exception as e:
        # Best-effort: fall back to whatever adapter is currently active.
        print(f"[ERROR] set_adapter('{persona_name}') failed: {e}")
    print("[GEN] Active adapter:", getattr(model, "active_adapter", None))

    prompt = build_prompt(history, persona_name)
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Start from persona defaults, then apply the UI temperature override.
    params = PERSONA_GEN_PARAMS.get(
        persona_name, {"temperature": 0.8, "top_p": 0.9}
    ).copy()
    if temperature is not None:
        params["temperature"] = float(temperature)
    # Cast max_tokens (Gradio sliders may deliver floats).
    max_tokens = int(max_tokens) if max_tokens is not None else 120

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            top_p=params["top_p"],
            temperature=params["temperature"],
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens, skipping the prompt.
    new_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
    reply = tokenizer.decode(new_ids, skip_special_tokens=True).strip()

    # Force exaggerated style differences on top of the raw reply.
    reply = stylize_reply(reply, persona_name)

    # Fill in the pending turn (guaranteed to exist by the guard above).
    history[-1] = [history[-1][0], reply]

    audio_path = None
    if tts_enabled:
        try:
            # NOTE(review): fixed filename — concurrent users in a shared
            # Space would clobber each other's audio; consider tempfile.
            tts = gTTS(reply)
            audio_path = "tts_output.mp3"
            tts.save(audio_path)
        except Exception as e:
            # TTS is optional; never let it break the chat flow.
            print("[TTS] Error:", e)
            audio_path = None

    return history, history, audio_path
# ---------------- GRADIO UI (UPDATED) ----------------
# Custom CSS for UTRGV orange theme: #FF6600 accent on a #1a1a1a dark
# background. Passed to gr.Blocks(css=...) below; selectors target
# Gradio's generated DOM, hence the !important overrides.
custom_css = """
.gradio-container {
background: #1a1a1a !important;
}
h1, h2, h3 {
color: #FF6600 !important;
}
label {
color: #FF6600 !important;
}
.message.user {
background: #FF6600 !important;
}
input[type="range"] {
accent-color: #FF6600 !important;
}
input:focus, textarea:focus, select:focus {
border-color: #FF6600 !important;
}
"""
# UI layout + event wiring. `demo` is the top-level Blocks app.
with gr.Blocks(theme=gr.themes.Base(), css=custom_css) as demo:
    gr.Markdown("# Multi-Personality AI Chatbot")
    with gr.Row():
        # Persona choices mirror the loaded LoRA adapters.
        persona_dropdown = gr.Dropdown(
            choices=list(ADAPTER_PATHS.keys()),
            value=first_persona,
            label="Select Personality",
        )
        tts_checkbox = gr.Checkbox(label="Enable Text-to-Speech", value=False)
    chat = gr.Chatbot(label="Conversation")
    msg = gr.Textbox(
        label="Your message",
        placeholder="Type your message...",
    )
    with gr.Row():
        # Sampling controls fed into generate_reply.
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            value=0.8,
            step=0.1,
            label="Temperature",
        )
        max_tokens = gr.Slider(
            minimum=50,
            maximum=500,
            value=120,
            step=10,
            label="Max Tokens",
        )
    audio_out = gr.Audio(label="Audio Response", autoplay=True)
    clear_btn = gr.Button("Clear Chat")

    def user_submit(user_message, history):
        """Append the user's message as a pending [user, None] turn.

        Clears the textbox either way; blank/whitespace-only input
        leaves the history unchanged.
        """
        history = history or []
        if not user_message.strip():
            return "", history
        return "", history + [[user_message, None]]

    # Two-stage submit: echo the user message immediately (queue=False),
    # then run generation.
    # NOTE(review): the .then() chain fires even when user_submit rejected
    # a blank message — generate_reply should guard against a history whose
    # last turn is already answered; verify it does.
    msg.submit(
        user_submit,
        [msg, chat],
        [msg, chat],
        queue=False,
    ).then(
        generate_reply,
        [chat, persona_dropdown, tts_checkbox, temperature, max_tokens],
        [chat, chat, audio_out],
    )

    # Reset both the conversation and any stale audio.
    clear_btn.click(lambda: ([], None), outputs=[chat, audio_out])

if __name__ == "__main__":
    demo.launch()