# NOTE: the three lines below are scrape residue from the Hugging Face
# file-viewer header (author, commit message, commit hash); commented out
# so the module parses as valid Python.
# dianacasti's picture
# Add HF Space app with adapters (no venv)
# fb55abc
"""
Personality Chatbot - Multi-personality LLM with LoRA adapters
Deployed on Hugging Face Spaces
"""
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import os
# Configuration
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTERS = {
"🧠 Brainrot": "qwen-brainrot-lora-stage1-final",
"πŸ΄β€β˜ οΈ Pirate": "pirate-lora-adapter",
"πŸ§™ Yoda": "yoda-lora-adapter",
"πŸ€“ Nerd": "nerd-lora-adapter",
}
# Global state
base_model = None
tokenizer = None
current_adapter = None
current_personality = None
device = None
def load_base_model():
    """Initialize the shared tokenizer and base model (run once at startup).

    Populates the module-level ``base_model``, ``tokenizer`` and ``device``
    globals, then returns a human-readable status string.
    """
    global base_model, tokenizer, device

    print("🔄 Loading base model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_gpu = device == "cuda"

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    # The model may ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16 if use_gpu else torch.float32,
        device_map="auto" if use_gpu else None,  # accelerate placement on GPU only
        trust_remote_code=True,
    )
    if not use_gpu:
        # device_map was None on CPU, so move the weights explicitly.
        base_model = base_model.to(device)

    print(f"✅ Base model loaded on {device}")
    return f"Base model loaded on {device}"
def switch_personality(personality_name):
    """Switch the active LoRA personality adapter.

    Args:
        personality_name: A key of ``ADAPTERS`` (the dropdown label).

    Returns:
        A human-readable status string (success or error); never raises.

    Bug fix: the previous version called ``PeftModel.from_pretrained(base_model, ...)``
    on every switch, re-wrapping (and mutating) the same base model each time —
    stacking LoRA modules and leaking memory. We now wrap the base model exactly
    once and use ``load_adapter`` / ``set_adapter`` for subsequent personalities,
    caching already-loaded adapters via ``peft_config``.
    """
    global current_adapter, current_personality
    if personality_name == current_personality:
        return f"Already using {personality_name}"
    adapter_path = ADAPTERS.get(personality_name)
    if not adapter_path:
        return f"❌ Personality '{personality_name}' not found"
    if not os.path.exists(adapter_path):
        return f"❌ Adapter folder '{adapter_path}' not found. Make sure adapters are uploaded."
    try:
        print(f"🔄 Loading {personality_name} adapter from {adapter_path}...")
        if current_adapter is None:
            # First switch: wrap the base model exactly once.
            current_adapter = PeftModel.from_pretrained(
                base_model,
                adapter_path,
                adapter_name=personality_name,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            )
        elif personality_name not in current_adapter.peft_config:
            # New personality: load its weights into the existing wrapper.
            current_adapter.load_adapter(adapter_path, adapter_name=personality_name)
        # Activate the requested adapter (no-op cost if already loaded).
        current_adapter.set_adapter(personality_name)
        current_adapter.eval()
        current_personality = personality_name
        print(f"✅ Switched to {personality_name}")
        return f"✅ Switched to {personality_name}"
    except Exception as e:
        return f"❌ Error loading adapter: {str(e)}"
def generate_response(message, history, temperature=0.7, max_tokens=256):
    """Generate a reply from the currently active personality adapter.

    ``history`` is accepted for Gradio-callback compatibility but unused —
    the prompt is single-turn. Returns the decoded reply or an error string.
    """
    if current_adapter is None:
        return "⚠️ Please select a personality first!"
    try:
        # Build a single-turn ChatML-style prompt.
        prompt = f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            generated = current_adapter.generate(
                **encoded,
                max_new_tokens=max_tokens,       # cap new tokens, not total length
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,          # discourage degenerate loops
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Keep only the newly generated tokens (everything after the prompt).
        prompt_len = encoded['input_ids'].shape[1]
        response = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()
        # Scrub any ChatML markers the model may have echoed back.
        response = response.replace("<|im_start|>", "").replace("<|im_end|>", "")
        response = response.replace("assistant\n", "").strip()
        return response
    except Exception as e:
        return f"❌ Error generating response: {str(e)}"
def handle_personality_change(personality_name):
    """Gradio dropdown callback: delegate to switch_personality, return its status."""
    return switch_personality(personality_name)
# Load base model on startup
print("πŸš€ Starting application...")
load_base_model()
# Create Gradio interface. Component wiring is order-dependent, so the layout
# and event registrations below are kept exactly as written.
with gr.Blocks(theme=gr.themes.Soft(), title="Personality Chatbot") as demo:
    gr.Markdown(
        """
        # 🎭 Multi-Personality Chatbot
        Chat with AI personalities powered by LoRA adapters on Qwen2-0.5B-Instruct
        **Select a personality** and start chatting!
        """
    )
    with gr.Row():
        # Left column: personality picker, status readout, generation settings.
        with gr.Column(scale=1):
            personality_dropdown = gr.Dropdown(
                choices=list(ADAPTERS.keys()),
                label="🎭 Select Personality",
                value=list(ADAPTERS.keys())[0],
                interactive=True,
            )
            status_box = gr.Textbox(
                label="Status",
                value="Select a personality to begin",
                interactive=False,
                lines=2,
            )
            with gr.Accordion("⚙️ Generation Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1,
                    label="Temperature (creativity)",
                )
                max_tokens_slider = gr.Slider(
                    minimum=30,
                    maximum=256,
                    value=100,
                    step=10,
                    label="Max tokens (response length)",
                )
            gr.Markdown(
                """
                ### 📝 Personality Descriptions
                - **🧠 Brainrot**: Internet slang and Gen-Z speak
                - **🏴‍☠️ Pirate**: Arr matey, talks like a pirate!
                - **🧙 Yoda**: Wise Jedi master, speaks in reverse
                - **🤓 Nerd**: Intellectual, loves facts and science
                """
            )
        # Right column: chat window plus message input and action buttons.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Chat",
                height=500,
                show_label=True,
            )
            msg_box = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                lines=2,
            )
            with gr.Row():
                submit_btn = gr.Button("Send 💬", variant="primary")
                clear_btn = gr.Button("Clear 🗑️", variant="secondary")

    # Event handlers
    def respond(message, chat_history, temperature, max_tokens):
        # Append a (user, bot) pair to the history and clear the input box;
        # empty/whitespace-only messages are ignored.
        if not message.strip():
            return chat_history, ""
        bot_response = generate_response(message, chat_history, temperature, max_tokens)
        chat_history.append((message, bot_response))
        return chat_history, ""

    # Personality change handler
    personality_dropdown.change(
        fn=handle_personality_change,
        inputs=[personality_dropdown],
        outputs=[status_box],
    )
    # Chat handlers: button click and Enter in the textbox share one callback.
    submit_btn.click(
        fn=respond,
        inputs=[msg_box, chatbot, temperature_slider, max_tokens_slider],
        outputs=[chatbot, msg_box],
    )
    msg_box.submit(
        fn=respond,
        inputs=[msg_box, chatbot, temperature_slider, max_tokens_slider],
        outputs=[chatbot, msg_box],
    )
    clear_btn.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, msg_box],
    )
    # Load first personality on startup (fires when the page first renders).
    demo.load(
        fn=handle_personality_change,
        inputs=[personality_dropdown],
        outputs=[status_box],
    )
# Entry point: serve on all interfaces at 7860, the port HF Spaces expects.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)