#!/usr/bin/env python3
"""
Gradio web application for chatting with 3 persona LoRA adapters.
Personas: Dog, Cat, Bird
"""
from __future__ import annotations
import os
import sys
import types
import json
import gc
import gradio as gr
import torch
from pathlib import Path
# Disable torch.compile and prevent bitsandbytes issues
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1"
# Patch import system to prevent bitsandbytes import
_original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__
def _patched_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == "bitsandbytes" or (name and name.startswith("bitsandbytes")):
if name not in sys.modules:
dummy = types.ModuleType(name)
dummy.__version__ = "0.0.0"
dummy.nn = types.ModuleType("nn")
dummy.optim = types.ModuleType("optim")
dummy.cuda_setup = types.ModuleType("cuda_setup")
class DummyLinear8bitLt:
pass
class DummyLinear4bit:
pass
dummy.nn.Linear8bitLt = DummyLinear8bitLt
dummy.nn.Linear4bit = DummyLinear4bit
sys.modules[name] = dummy
sys.modules[f"{name}.nn"] = dummy.nn
sys.modules[f"{name}.optim"] = dummy.optim
sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup
return sys.modules[name]
return _original_import(name, globals, locals, fromlist, level)
if isinstance(__builtins__, dict):
__builtins__["__import__"] = _patched_import
else:
    __builtins__.__import__ = _patched_import
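# The import patch must be installed before transformers/peft are imported below, so
# that any optional bitsandbytes import resolves to the stub module instead of failing.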
# Disable torch.compile
try:
    torch._dynamo.config.suppress_errors = True
    torch._dynamo.config.disable = True
except Exception:
    pass
if hasattr(torch, "compile"):
    _original_torch_compile = torch.compile
    def _noop_compile(func=None, *args, **kwargs):
        if func is not None:
            return func
        def decorator(f):
            return f
        return decorator
    torch.compile = _noop_compile
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Configuration
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_REPO = "Tameem7/Persona-Animal" # Main repository
ADAPTER_SUBDIRS = {
"dog": "dog",
"cat": "cat",
"bird": "bird",
}
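# Expected layout of the adapter repository (an assumption inferred from the loading
# code below): each persona subdirectory holds the PEFT files (adapter_config.json,
# adapter weights) plus a persona_config.json providing persona_name and
# persona_description.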
# Global variables
base_model = None
base_tokenizer = None
current_persona = None
current_model = None
current_tokenizer = None
current_config = None
def load_base_model():
"""Load the base model and tokenizer (only once)."""
global base_model, base_tokenizer
if base_model is not None:
return base_model, base_tokenizer
print(f"Loading base model: {BASE_MODEL}")
base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if base_tokenizer.pad_token is None:
base_tokenizer.pad_token = base_tokenizer.eos_token
# Determine device and dtype
use_cuda = torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
dtype = torch.bfloat16 if use_cuda else torch.float32
if use_cuda:
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
dtype=dtype,
device_map="auto",
)
else:
print("πŸ’» Running on CPU")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
dtype=dtype,
)
base_model = base_model.to(device)
base_model.eval()
print("βœ… Base model loaded")
return base_model, base_tokenizer
def load_persona_adapter(persona_key: str):
"""Load a persona adapter."""
global current_persona, current_model, current_tokenizer, current_config, base_model, base_tokenizer
# If same persona is already loaded, return
if current_persona == persona_key and current_model is not None:
return current_model, current_tokenizer, current_config
# Load base model if not loaded
if base_model is None:
load_base_model()
# Unload previous adapter
if current_model is not None and current_persona != persona_key:
print(f"Unloading previous adapter: {current_persona}")
del current_model
current_model = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Get adapter subdirectory
adapter_subdir = ADAPTER_SUBDIRS.get(persona_key)
if not adapter_subdir:
raise ValueError(f"Unknown persona: {persona_key}")
print(f"Loading adapter: {ADAPTER_REPO}/{adapter_subdir}")
# Create a fresh copy of base model for the adapter
# (PEFT needs a clean base model)
if current_persona != persona_key:
# Reload base model for new adapter
print(f"Creating base model copy for {persona_key} adapter...")
# Determine device and dtype
use_cuda = torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
dtype = torch.bfloat16 if use_cuda else torch.float32
if use_cuda:
base_model_copy = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
dtype=dtype,
device_map="auto",
)
else:
base_model_copy = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
dtype=dtype,
)
base_model_copy = base_model_copy.to(device)
# Download adapter files from subdirectory and load
try:
from huggingface_hub import snapshot_download
# Download the adapter files from the subdirectory
print(f"Downloading adapter files from {ADAPTER_REPO}/{adapter_subdir}...")
# Download the entire repo (or use cache if available)
repo_cache_dir = snapshot_download(
repo_id=ADAPTER_REPO,
repo_type="model",
cache_dir=None, # Use default cache
)
# The adapter files should be in repo_cache_dir/adapter_subdir/
adapter_local_path = os.path.join(repo_cache_dir, adapter_subdir)
if not os.path.exists(adapter_local_path):
raise FileNotFoundError(f"Adapter subdirectory not found: {adapter_local_path}")
# Check if adapter_config.json exists
adapter_config_path = os.path.join(adapter_local_path, "adapter_config.json")
if not os.path.exists(adapter_config_path):
raise FileNotFoundError(f"adapter_config.json not found in {adapter_local_path}")
# Load adapter from local path
print(f"Loading adapter from: {adapter_local_path}")
current_model = PeftModel.from_pretrained(base_model_copy, adapter_local_path)
current_model.eval()
except Exception as e:
error_msg = f"Failed to load adapter from {ADAPTER_REPO}/{adapter_subdir}: {str(e)}"
print(f"❌ {error_msg}")
print("Make sure the adapter files are uploaded to the correct subdirectory on Hugging Face.")
raise RuntimeError(error_msg) from e
# Load persona config
try:
from huggingface_hub import hf_hub_download
config_path = hf_hub_download(
repo_id=ADAPTER_REPO,
filename=f"{adapter_subdir}/persona_config.json",
repo_type="model"
)
with open(config_path, 'r') as f:
current_config = json.load(f)
except Exception as e:
print(f"⚠️ Could not load persona_config.json: {e}")
print("Using default persona config.")
current_config = {"persona_name": persona_key.title(), "persona_description": ""}
current_persona = persona_key
current_tokenizer = base_tokenizer
print(f"βœ… Loaded {persona_key} persona")
return current_model, current_tokenizer, current_config
def generate_response(persona_key: str, message: str, history: list, max_tokens: int = 80):
"""Generate a response from the selected persona."""
global current_model, current_tokenizer, current_config
if not message or not message.strip():
return history, ""
try:
# Load adapter if needed
model, tokenizer, config = load_persona_adapter(persona_key)
# Build messages with conversation history
system_prompt = ""
if config:
system_prompt = f"You are {config.get('persona_name', '')}. {config.get('persona_description', '')}"
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Add conversation history (last 5 exchanges to avoid too long context)
# Handle both tuple format [user_msg, assistant_msg] and messages format
for item in history[-5:]:
if isinstance(item, dict) and "role" in item:
# Messages format
messages.append(item)
else:
# Tuple format [user_msg, assistant_msg]
user_msg, assistant_msg = item
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
# Add current message
messages.append({"role": "user", "content": message})
# Apply chat template
formatted = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Tokenize
inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=512)
# Move inputs to the same device as the model
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# Generate
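        # Sampling settings below are heuristics: temperature/top_p add variety, while
        # repetition_penalty and no_repeat_ngram_size curb the repetitive loops small
        # chat models such as TinyLlama tend to fall into.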
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_tokens),  # slider values may arrive as floats
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
            )
        # Extract only the newly generated tokens
        input_length = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_length:]
        # Decode only the generated part
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        # Clean up response
        response = response.strip()
        if tokenizer.eos_token:
            response = response.replace(tokenizer.eos_token, "").strip()
        if tokenizer.pad_token:
            response = response.replace(tokenizer.pad_token, "").strip()
        # Remove chat template artifacts
        response = response.replace("<|system|>", "").replace("</|system|>", "")
        response = response.replace("<|user|>", "").replace("</|user|>", "")
        response = response.replace("<|assistant|>", "").replace("</|assistant|>", "")
        response = response.replace("<|", "").replace("|>", "")
        # Clean up extra whitespace
        response = " ".join(response.split())
        response = response.strip()
        # Update history with messages format (Gradio expects this format)
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        return history, ""
    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        print(error_msg)
        return history, error_msg
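# Hypothetical example of driving the generator outside the Gradio UI, e.g. for a
# quick local smoke test:
#   chat_history, _ = generate_response("dog", "Hello there!", [], max_tokens=60)
#   print(chat_history[-1]["content"])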
def clear_chat():
"""Clear the chat history."""
return [], "" # Empty list for messages format
# Create Gradio interface
with gr.Blocks(title="Persona Chat") as app:
    gr.Markdown(
        """
# 🐾 Persona Chat - Talk to Animals!

Chat with three different animal personas, each with their own unique personality:

- **🐕 Dog**: Friendly, playful, and enthusiastic
- **🐱 Cat**: Independent, curious, and sometimes sassy
- **🐦 Bird**: Energetic, talkative, and free-spirited

**💻 Running on CPU** - Responses may be slower but will work perfectly!
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            persona_dropdown = gr.Dropdown(
                choices=["dog", "cat", "bird"],
                value="dog",
                label="Select Persona"
            )
            max_tokens_slider = gr.Slider(
                minimum=20,
                maximum=150,
                value=80,
                step=10,
                label="Max Response Length"
            )
            clear_btn = gr.Button("Clear Chat")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",  # history entries are {"role", "content"} dicts
                height=400
            )
            msg_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2
            )
            send_btn = gr.Button("Send")
    # Event handlers
    def chat_fn(persona, message, history, max_tokens):
        return generate_response(persona, message, history, max_tokens)
    send_btn.click(
        fn=chat_fn,
        inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider],
        outputs=[chatbot, msg_input]
    )
    msg_input.submit(
        fn=chat_fn,
        inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider],
        outputs=[chatbot, msg_input]
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, msg_input]
    )
if __name__ == "__main__":
    # Load base model first
    print("Initializing...")
    load_base_model()
    app.launch(
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1",
        server_port=int(os.getenv("PORT", 7860)),
        share=False
    )