Spaces:
Running
Running
init
Browse files- app.py +389 -0
- config.py +45 -0
- persona-data/bird.jsonl +0 -0
- persona-data/cat.jsonl +0 -0
- persona-data/dog.jsonl +0 -0
- requirements.txt +9 -0
- test_persona.py +309 -0
- train_single_persona.py +466 -0
app.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Gradio web application for chatting with 3 persona LoRA adapters.
|
| 4 |
+
Personas: Dog, Cat, Bird
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import types
|
| 12 |
+
import json
|
| 13 |
+
import gc
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import torch
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
# Disable torch.compile and prevent bitsandbytes issues
|
| 19 |
+
os.environ["TORCH_COMPILE_DISABLE"] = "1"
|
| 20 |
+
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
|
| 21 |
+
os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1"
|
| 22 |
+
|
| 23 |
+
# Patch import system to prevent bitsandbytes import
|
| 24 |
+
_original_import = __builtins__.__import__
|
| 25 |
+
|
| 26 |
+
def _patched_import(name, globals=None, locals=None, fromlist=(), level=0):
|
| 27 |
+
if name == "bitsandbytes" or (name and name.startswith("bitsandbytes")):
|
| 28 |
+
if name not in sys.modules:
|
| 29 |
+
dummy = types.ModuleType(name)
|
| 30 |
+
dummy.__version__ = "0.0.0"
|
| 31 |
+
dummy.nn = types.ModuleType("nn")
|
| 32 |
+
dummy.optim = types.ModuleType("optim")
|
| 33 |
+
dummy.cuda_setup = types.ModuleType("cuda_setup")
|
| 34 |
+
|
| 35 |
+
class DummyLinear8bitLt:
|
| 36 |
+
pass
|
| 37 |
+
class DummyLinear4bit:
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
dummy.nn.Linear8bitLt = DummyLinear8bitLt
|
| 41 |
+
dummy.nn.Linear4bit = DummyLinear4bit
|
| 42 |
+
|
| 43 |
+
sys.modules[name] = dummy
|
| 44 |
+
sys.modules[f"{name}.nn"] = dummy.nn
|
| 45 |
+
sys.modules[f"{name}.optim"] = dummy.optim
|
| 46 |
+
sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup
|
| 47 |
+
return sys.modules[name]
|
| 48 |
+
return _original_import(name, globals, locals, fromlist, level)
|
| 49 |
+
|
| 50 |
+
if isinstance(__builtins__, dict):
|
| 51 |
+
__builtins__["__import__"] = _patched_import
|
| 52 |
+
else:
|
| 53 |
+
__builtins__.__import__ = _patched_import
|
| 54 |
+
|
| 55 |
+
# Disable torch.compile
|
| 56 |
+
try:
|
| 57 |
+
torch._dynamo.config.suppress_errors = True
|
| 58 |
+
torch._dynamo.config.disable = True
|
| 59 |
+
except:
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
if hasattr(torch, "compile"):
|
| 63 |
+
_original_torch_compile = torch.compile
|
| 64 |
+
def _noop_compile(func=None, *args, **kwargs):
|
| 65 |
+
if func is not None:
|
| 66 |
+
return func
|
| 67 |
+
def decorator(f):
|
| 68 |
+
return f
|
| 69 |
+
return decorator
|
| 70 |
+
torch.compile = _noop_compile
|
| 71 |
+
|
| 72 |
+
from peft import PeftModel
|
| 73 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 74 |
+
|
| 75 |
+
# Configuration
|
| 76 |
+
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 77 |
+
ADAPTER_PATHS = {
|
| 78 |
+
"dog": "Tameem7/Persona-Animal/dog",
|
| 79 |
+
"cat": "Tameem7/Persona-Animal/cat",
|
| 80 |
+
"bird": "Tameem7/Persona-Animal/bird",
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
# Global variables
|
| 84 |
+
base_model = None
|
| 85 |
+
base_tokenizer = None
|
| 86 |
+
current_persona = None
|
| 87 |
+
current_model = None
|
| 88 |
+
current_tokenizer = None
|
| 89 |
+
current_config = None
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def load_base_model():
    """Load the shared base model and tokenizer exactly once.

    Results are cached in the module-level ``base_model`` / ``base_tokenizer``
    globals; repeated calls return the cached objects.

    Returns:
        tuple: ``(base_model, base_tokenizer)``
    """
    global base_model, base_tokenizer

    if base_model is not None:
        return base_model, base_tokenizer

    print(f"Loading base model: {BASE_MODEL}")
    base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if base_tokenizer.pad_token is None:
        # Generation needs a pad token; reusing EOS is the usual convention.
        base_tokenizer.pad_token = base_tokenizer.eos_token

    # bf16 on GPU, fp32 on CPU.
    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"
    dtype = torch.bfloat16 if use_cuda else torch.float32

    if use_cuda:
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            # `torch_dtype` is the kwarg accepted across transformers>=4.40
            # (and matches test_persona.py); the bare `dtype` alias only
            # exists in very recent releases.
            torch_dtype=dtype,
            device_map="auto",
        )
    else:
        print("💻 Running on CPU")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=dtype,
        )
        base_model = base_model.to(device)

    base_model.eval()
    print("✅ Base model loaded")
    return base_model, base_tokenizer
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_persona_adapter(persona_key: str):
    """Load (and cache) the LoRA adapter for ``persona_key``.

    Unloads any previously active adapter, reloads a fresh copy of the base
    model (PEFT wraps the model in place, so a clean copy is needed per
    adapter), attaches the adapter, and reads its ``persona_config.json``.

    Args:
        persona_key: One of the keys in ``ADAPTER_PATHS`` ("dog"/"cat"/"bird").

    Returns:
        tuple: ``(model, tokenizer, persona_config)``

    Raises:
        ValueError: If ``persona_key`` is not a known persona.
    """
    global current_persona, current_model, current_tokenizer, current_config, base_model, base_tokenizer

    # Fast path: requested persona is already active.
    if current_persona == persona_key and current_model is not None:
        return current_model, current_tokenizer, current_config

    # Make sure the shared tokenizer (and warm base weights) exist.
    if base_model is None:
        load_base_model()

    # Free the previous adapter (and its base-model copy) before loading a new one.
    if current_model is not None and current_persona != persona_key:
        print(f"Unloading previous adapter: {current_persona}")
        del current_model
        current_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    adapter_path = ADAPTER_PATHS.get(persona_key)
    if not adapter_path:
        raise ValueError(f"Unknown persona: {persona_key}")

    print(f"Loading adapter: {adapter_path}")

    # PEFT mutates the wrapped model, so attach the adapter to a fresh
    # base-model copy rather than the shared `base_model`.
    if current_persona != persona_key:
        print(f"Creating base model copy for {persona_key} adapter...")

        use_cuda = torch.cuda.is_available()
        device = "cuda:0" if use_cuda else "cpu"
        dtype = torch.bfloat16 if use_cuda else torch.float32

        if use_cuda:
            base_model_copy = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                torch_dtype=dtype,  # portable kwarg, matches test_persona.py
                device_map="auto",
            )
        else:
            base_model_copy = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL,
                torch_dtype=dtype,
            )
            base_model_copy = base_model_copy.to(device)

        # NOTE(review): adapter_path looks like "user/repo/subfolder"; the Hub
        # API expects repo_id="user/repo" plus subfolder="..." — confirm these
        # paths actually resolve, otherwise split them explicitly.
        print(f"Loading adapter from: {adapter_path}")
        current_model = PeftModel.from_pretrained(base_model_copy, adapter_path)
        current_model.eval()

        # Best-effort load of persona metadata; fall back to a minimal config.
        try:
            from huggingface_hub import hf_hub_download
            config_path = hf_hub_download(
                repo_id=adapter_path,
                filename="persona_config.json",
                repo_type="model"
            )
            with open(config_path, 'r') as f:
                current_config = json.load(f)
        except Exception:
            # Narrowed from bare `except:` so Ctrl-C / SystemExit still propagate.
            current_config = {"persona_name": persona_key.title(), "persona_description": ""}

    current_persona = persona_key
    current_tokenizer = base_tokenizer
    print(f"✅ Loaded {persona_key} persona")

    return current_model, current_tokenizer, current_config
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def generate_response(persona_key: str, message: str, history: list, max_tokens: int = 80):
    """Generate a chat reply from the selected persona.

    Appends the user message and model reply to ``history`` (Gradio
    "messages" format: dicts with "role"/"content") and returns
    ``(history, "")`` so the textbox output is cleared. On failure the
    unmodified history is returned along with the error text.
    """
    global current_model, current_tokenizer, current_config

    # Ignore empty / whitespace-only submissions.
    if not message or not message.strip():
        return history, ""

    try:
        # Load adapter if needed (cached when the persona is unchanged).
        model, tokenizer, config = load_persona_adapter(persona_key)

        # Build messages with conversation history.
        system_prompt = ""
        if config:
            system_prompt = f"You are {config.get('persona_name', '')}. {config.get('persona_description', '')}"

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add conversation history (last 5 exchanges to avoid too long context).
        # History is in messages format: list of dicts with 'role' and 'content'.
        for msg in history[-10:]:  # last 10 messages == 5 user/assistant exchanges
            if isinstance(msg, dict) and "role" in msg:
                messages.append(msg)
            else:
                # Fallback for tuple format (shouldn't happen with type='messages')
                user_msg, assistant_msg = msg
                messages.append({"role": "user", "content": user_msg})
                messages.append({"role": "assistant", "content": assistant_msg})

        # Add current message
        messages.append({"role": "user", "content": message})

        # Apply chat template
        formatted = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize (truncated to keep CPU latency bounded)
        inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=512)
        # Move inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate (sampling; repetition controls keep short personas varied)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
            )

        # Extract only the newly generated tokens
        input_length = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_length:]

        # Decode only the generated part
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up response
        response = response.strip()
        if tokenizer.eos_token:
            response = response.replace(tokenizer.eos_token, "").strip()
        if tokenizer.pad_token:
            response = response.replace(tokenizer.pad_token, "").strip()

        # Remove chat template artifacts that may leak through decoding
        response = response.replace("<|system|>", "").replace("</|system|>", "")
        response = response.replace("<|user|>", "").replace("</|user|>", "")
        response = response.replace("<|assistant|>", "").replace("</|assistant|>", "")
        response = response.replace("<|", "").replace("|>", "")

        # Clean up extra whitespace
        response = " ".join(response.split())
        response = response.strip()

        # Update history with messages format (mutated in place for Gradio)
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})

        return history, ""

    except Exception as e:
        # Surface the error in the textbox output instead of crashing the UI.
        error_msg = f"Error generating response: {str(e)}"
        print(error_msg)
        return history, error_msg
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def clear_chat():
    """Reset the conversation: empty messages-format history, blank textbox."""
    fresh_history: list = []
    return fresh_history, ""
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# Create Gradio interface: settings column (persona picker, length slider,
# clear button) on the left, chat window + input on the right.
with gr.Blocks(title="Persona Chat", theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # 🐾 Persona Chat - Talk to Animals!

        Chat with three different animal personas, each with their own unique personality:
        - **🐕 Dog**: Friendly, playful, and enthusiastic
        - **🐱 Cat**: Independent, curious, and sometimes sassy
        - **🐦 Bird**: Energetic, talkative, and free-spirited

        **💻 Running on CPU** - Responses may be slower but will work perfectly!
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # The selected persona decides which LoRA adapter is loaded on send.
            persona_dropdown = gr.Dropdown(
                choices=["dog", "cat", "bird"],
                value="dog",
                label="Select Persona",
                info="Choose which animal persona to chat with"
            )

            max_tokens_slider = gr.Slider(
                minimum=20,
                maximum=150,
                value=80,
                step=10,
                label="Max Response Length",
                info="Maximum number of tokens in response"
            )

            clear_btn = gr.Button("Clear Chat", variant="secondary")

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat",
                height=500,
                show_copy_button=True,
                type='messages'  # history entries are {"role", "content"} dicts
            )

            msg_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2
            )

            send_btn = gr.Button("Send", variant="primary", scale=1)

    # Event handlers
    def chat_fn(persona, message, history, max_tokens):
        # Thin adapter mapping Gradio's positional inputs onto generate_response.
        return generate_response(persona, message, history, max_tokens)

    # Both the Send button and pressing Enter in the textbox submit the message;
    # outputs update the chat window and clear (or error-fill) the textbox.
    send_btn.click(
        fn=chat_fn,
        inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider],
        outputs=[chatbot, msg_input]
    )

    msg_input.submit(
        fn=chat_fn,
        inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider],
        outputs=[chatbot, msg_input]
    )

    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, msg_input]
    )
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
if __name__ == "__main__":
    # Warm the base model up front so the first chat request is not also
    # paying the model-download/load cost.
    print("Initializing...")
    load_base_model()

    app.launch(
        # Bind to all interfaces when running inside a HF Space (SPACE_ID is
        # set by the platform); otherwise stay on localhost.
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1",
        server_port=int(os.getenv("PORT", 7860)),
        share=False
    )
|
| 389 |
+
|
config.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Configuration for persona LoRA fine-tuning.

Edit these values to customize your training setup.
"""

# Base Model Configuration
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # ~2GB, fits easily

# Persona Configuration: identity injected as the system prompt during training
PERSONA_NAME = "Scooby Dog"
PERSONA_DESCRIPTION = (
    "You are Scooby Dog, a friendly and playful dog. You communicate like a dog would - "
    "with enthusiasm, simple language, and dog-like expressions. You use words like "
    "'woof', 'bark', 'ruff', and express excitement with 'yay!' or 'awesome!'. "
    "You're loyal, happy, and see the world from a dog's perspective. You get excited "
    "about treats, walks, playing fetch, and spending time with humans. You speak in "
    "short, enthusiastic sentences. You might mention things dogs care about like food, "
    "toys, belly rubs, and going outside. Keep responses natural and dog-like, but still "
    "helpful and friendly."
)

# Dataset Configuration
DATASET_NAME = "bavard/personachat_truecased"  # Persona-Chat dataset
# Alternative: "bavard/personachat" or "personachat"

# Training Configuration
NUM_EPOCHS = 3
BATCH_SIZE = 2  # Per device (reduce to 1-2 for 4GB GPU)
LEARNING_RATE = 2e-4
MAX_LENGTH = 512  # Reduce to 512 for 4GB GPU (2048 for 8GB+)
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch = BATCH_SIZE * 4

# LoRA Configuration
LORA_R = 16  # Rank
LORA_ALPHA = 32  # LoRA alpha (effective scaling = alpha / r)
LORA_DROPOUT = 0.05
# LLaMA-family attention projection names (TinyLlama uses the LLaMA
# architecture; the same names happen to apply to Mistral as well).
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

# Output Configuration
OUTPUT_DIR = "./lora-adapters-scooby-dog"

# Quantization (for Colab)
USE_QUANTIZATION = False  # Set to False if you have enough VRAM
|
| 45 |
+
|
persona-data/bird.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
persona-data/cat.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
persona-data/dog.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers>=4.40.0
|
| 2 |
+
accelerate>=0.29.0
|
| 3 |
+
datasets>=2.14.0
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
scikit-learn>=1.3.0
|
| 6 |
+
gradio>=4.0.0
|
| 7 |
+
peft>=0.10.0
|
| 8 |
+
huggingface-hub>=0.20.0
|
| 9 |
+
|
test_persona.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test a trained persona LoRA adapter.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python test_persona.py --persona dog --message "Hey, how are you?"
|
| 7 |
+
python test_persona.py --persona dog # Interactive mode
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
import types
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
# Disable torch.compile and prevent bitsandbytes issues (same as training)
|
| 20 |
+
os.environ["TORCH_COMPILE_DISABLE"] = "1"
|
| 21 |
+
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
|
| 22 |
+
os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1"
|
| 23 |
+
|
| 24 |
+
# Patch import system to prevent bitsandbytes import
|
| 25 |
+
_original_import = __builtins__.__import__
|
| 26 |
+
|
| 27 |
+
def _patched_import(name, globals=None, locals=None, fromlist=(), level=0):
|
| 28 |
+
if name == "bitsandbytes" or (name and name.startswith("bitsandbytes")):
|
| 29 |
+
if name not in sys.modules:
|
| 30 |
+
dummy = types.ModuleType(name)
|
| 31 |
+
dummy.__version__ = "0.0.0"
|
| 32 |
+
dummy.nn = types.ModuleType("nn")
|
| 33 |
+
dummy.optim = types.ModuleType("optim")
|
| 34 |
+
dummy.cuda_setup = types.ModuleType("cuda_setup")
|
| 35 |
+
|
| 36 |
+
class DummyLinear8bitLt:
|
| 37 |
+
pass
|
| 38 |
+
class DummyLinear4bit:
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
dummy.nn.Linear8bitLt = DummyLinear8bitLt
|
| 42 |
+
dummy.nn.Linear4bit = DummyLinear4bit
|
| 43 |
+
|
| 44 |
+
sys.modules[name] = dummy
|
| 45 |
+
sys.modules[f"{name}.nn"] = dummy.nn
|
| 46 |
+
sys.modules[f"{name}.optim"] = dummy.optim
|
| 47 |
+
sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup
|
| 48 |
+
return sys.modules[name]
|
| 49 |
+
return _original_import(name, globals, locals, fromlist, level)
|
| 50 |
+
|
| 51 |
+
if isinstance(__builtins__, dict):
|
| 52 |
+
__builtins__["__import__"] = _patched_import
|
| 53 |
+
else:
|
| 54 |
+
__builtins__.__import__ = _patched_import
|
| 55 |
+
|
| 56 |
+
import torch
|
| 57 |
+
|
| 58 |
+
# Disable torch.compile
|
| 59 |
+
try:
|
| 60 |
+
torch._dynamo.config.suppress_errors = True
|
| 61 |
+
torch._dynamo.config.disable = True
|
| 62 |
+
except:
|
| 63 |
+
pass
|
| 64 |
+
|
| 65 |
+
if hasattr(torch, "compile"):
|
| 66 |
+
_original_torch_compile = torch.compile
|
| 67 |
+
def _noop_compile(func=None, *args, **kwargs):
|
| 68 |
+
if func is not None:
|
| 69 |
+
return func
|
| 70 |
+
def decorator(f):
|
| 71 |
+
return f
|
| 72 |
+
return decorator
|
| 73 |
+
torch.compile = _noop_compile
|
| 74 |
+
|
| 75 |
+
from peft import PeftModel
|
| 76 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_persona_model(persona_key: str, adapter_dir: Path, base_model: str):
    """Load the base model, attach the LoRA adapter, and read its config.

    Args:
        persona_key: Persona name (caller bookkeeping; not used in loading).
        adapter_dir: Local directory containing the trained LoRA adapter.
        base_model: Hub id or local path of the base model.

    Returns:
        tuple: ``(model, tokenizer, persona_config)`` where persona_config
        is the parsed ``persona_config.json`` or None if absent.
    """
    print(f"Loading base model: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        # Generation needs a pad token; reuse EOS as is conventional.
        tokenizer.pad_token = tokenizer.eos_token

    # bf16 + automatic device placement on GPU; fp32 on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )

    if torch.cuda.is_available():
        # NOTE(review): redundant after device_map="auto" on a single GPU and
        # could fight multi-device placement — confirm before removing.
        model = model.to("cuda:0")

    print(f"Loading LoRA adapter from: {adapter_dir}")
    model = PeftModel.from_pretrained(model, str(adapter_dir))
    model.eval()

    # Load the persona config saved alongside the adapter, if present.
    config_file = adapter_dir / "persona_config.json"
    persona_config = None
    if config_file.exists():
        with open(config_file, 'r') as f:
            persona_config = json.load(f)

    return model, tokenizer, persona_config
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def generate_response(
    model,
    tokenizer,
    message: str,
    persona_config: dict = None,
    max_new_tokens: int = 80,
    temperature: float = 0.7,
    top_p: float = 0.9,
):
    """Generate a single persona reply to ``message``.

    Builds a (system, user) chat from ``persona_config``, runs sampling
    generation, and returns only the newly generated text with chat-template
    artifacts stripped.
    """
    # Build messages: optional system prompt derived from the persona config.
    system_prompt = ""
    if persona_config:
        system_prompt = f"You are {persona_config.get('persona_name', '')}. {persona_config.get('persona_description', '')}"

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": message})

    # Apply chat template
    formatted = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize, moving tensors to the GPU when one is available.
    inputs = tokenizer(formatted, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda:0") for k, v in inputs.items()}

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,  # Reduce repetition
            no_repeat_ngram_size=3,  # Prevent 3-gram repetition
        )

    # Extract only the newly generated tokens (after the input)
    input_length = inputs['input_ids'].shape[1]
    generated_tokens = outputs[0][input_length:]

    # Decode only the generated part
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Clean up response
    response = response.strip()

    # Remove special tokens that survived skip_special_tokens
    if tokenizer.eos_token:
        response = response.replace(tokenizer.eos_token, "").strip()
    if tokenizer.pad_token:
        response = response.replace(tokenizer.pad_token, "").strip()

    # Remove any chat template artifacts that might leak through
    # (system/user/assistant tags used by TinyLlama-style templates)
    response = response.replace("<|system|>", "").replace("</|system|>", "")
    response = response.replace("<|user|>", "").replace("</|user|>", "")
    response = response.replace("<|assistant|>", "").replace("</|assistant|>", "")

    # Remove any remaining formatting
    response = response.replace("<|", "").replace("|>", "")

    # Clean up extra whitespace
    response = " ".join(response.split())

    return response.strip()
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def main():
|
| 186 |
+
parser = argparse.ArgumentParser(description="Test a trained persona LoRA adapter")
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--persona",
|
| 189 |
+
type=str,
|
| 190 |
+
required=True,
|
| 191 |
+
choices=["dog", "cat", "bird"],
|
| 192 |
+
help="Which persona to test",
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--adapter-dir",
|
| 196 |
+
type=str,
|
| 197 |
+
default="./lora-adapters",
|
| 198 |
+
help="Directory containing LoRA adapters",
|
| 199 |
+
)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--message",
|
| 202 |
+
type=str,
|
| 203 |
+
default=None,
|
| 204 |
+
help="Message to send (if not provided, enters interactive mode)",
|
| 205 |
+
)
|
| 206 |
+
parser.add_argument(
|
| 207 |
+
"--max-tokens",
|
| 208 |
+
type=int,
|
| 209 |
+
default=80,
|
| 210 |
+
help="Maximum tokens to generate (default: 80 for shorter responses)",
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--temperature",
|
| 214 |
+
type=float,
|
| 215 |
+
default=0.7,
|
| 216 |
+
help="Generation temperature",
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--top-p",
|
| 220 |
+
type=float,
|
| 221 |
+
default=0.9,
|
| 222 |
+
help="Top-p sampling",
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
args = parser.parse_args()
|
| 226 |
+
|
| 227 |
+
adapter_dir = Path(args.adapter_dir) / args.persona
|
| 228 |
+
|
| 229 |
+
if not adapter_dir.exists():
|
| 230 |
+
print(f"Error: Adapter directory not found: {adapter_dir}")
|
| 231 |
+
print("Please train the persona first using train_single_persona.py")
|
| 232 |
+
return
|
| 233 |
+
|
| 234 |
+
# Load persona config to get base model
|
| 235 |
+
config_file = adapter_dir / "persona_config.json"
|
| 236 |
+
if config_file.exists():
|
| 237 |
+
with open(config_file, 'r') as f:
|
| 238 |
+
persona_config = json.load(f)
|
| 239 |
+
base_model = persona_config.get("base_model", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
| 240 |
+
else:
|
| 241 |
+
# Default fallback
|
| 242 |
+
base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 243 |
+
persona_config = None
|
| 244 |
+
|
| 245 |
+
print("=" * 60)
|
| 246 |
+
print(f"Loading {args.persona} persona...")
|
| 247 |
+
print("=" * 60)
|
| 248 |
+
|
| 249 |
+
model, tokenizer, loaded_config = load_persona_model(
|
| 250 |
+
args.persona,
|
| 251 |
+
adapter_dir,
|
| 252 |
+
base_model
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
if loaded_config:
|
| 256 |
+
persona_config = loaded_config
|
| 257 |
+
print(f"\nPersona: {persona_config.get('persona_name', args.persona)}")
|
| 258 |
+
print(f"Base model: {persona_config.get('base_model', base_model)}")
|
| 259 |
+
|
| 260 |
+
print("\n" + "=" * 60)
|
| 261 |
+
print("Ready! Type your messages (or 'quit' to exit)")
|
| 262 |
+
print("=" * 60 + "\n")
|
| 263 |
+
|
| 264 |
+
# Interactive or single message mode
|
| 265 |
+
if args.message:
|
| 266 |
+
# Single message mode
|
| 267 |
+
print(f"You: {args.message}")
|
| 268 |
+
response = generate_response(
|
| 269 |
+
model,
|
| 270 |
+
tokenizer,
|
| 271 |
+
args.message,
|
| 272 |
+
persona_config,
|
| 273 |
+
max_new_tokens=args.max_tokens,
|
| 274 |
+
temperature=args.temperature,
|
| 275 |
+
top_p=args.top_p,
|
| 276 |
+
)
|
| 277 |
+
print(f"{args.persona.capitalize()}: {response}")
|
| 278 |
+
else:
|
| 279 |
+
# Interactive mode
|
| 280 |
+
while True:
|
| 281 |
+
try:
|
| 282 |
+
message = input("You: ").strip()
|
| 283 |
+
if not message:
|
| 284 |
+
continue
|
| 285 |
+
if message.lower() in ['quit', 'exit', 'q']:
|
| 286 |
+
break
|
| 287 |
+
|
| 288 |
+
response = generate_response(
|
| 289 |
+
model,
|
| 290 |
+
tokenizer,
|
| 291 |
+
message,
|
| 292 |
+
persona_config,
|
| 293 |
+
max_new_tokens=args.max_tokens,
|
| 294 |
+
temperature=args.temperature,
|
| 295 |
+
top_p=args.top_p,
|
| 296 |
+
)
|
| 297 |
+
print(f"{args.persona.capitalize()}: {response}\n")
|
| 298 |
+
except KeyboardInterrupt:
|
| 299 |
+
print("\nGoodbye!")
|
| 300 |
+
break
|
| 301 |
+
except Exception as e:
|
| 302 |
+
print(f"Error: {e}")
|
| 303 |
+
import traceback
|
| 304 |
+
traceback.print_exc()
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# Script entry point: run the persona chat CLI.
if __name__ == "__main__":
    main()
|
| 309 |
+
|
train_single_persona.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
Train a LoRA adapter for a single persona.

This script trains one persona at a time in a separate process to avoid
bitsandbytes kernel registration conflicts.

Usage:
    python train_single_persona.py --persona dog --base-model TinyLlama/TinyLlama-1.1B-Chat-v1.0
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import sys
import types
from pathlib import Path

# Disable torch.compile and prevent bitsandbytes issues.
# NOTE: these environment variables are set before torch/transformers are
# imported below, since they are read at import time.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1"
|
| 26 |
+
|
| 27 |
+
# CRITICAL: Patch the import system BEFORE importing torch or any ML libraries.
# This prevents bitsandbytes from being imported when not needed.
#
# FIX: ``__builtins__`` is the builtins *module* when this file runs as a
# script, but a plain *dict* when the file is imported as a module.  The
# original capture (``__builtins__.__import__``) raised AttributeError in the
# latter case even though the re-assignment further below already handles both
# shapes — handle both here too.
if isinstance(__builtins__, dict):
    _original_import = __builtins__["__import__"]
else:
    _original_import = __builtins__.__import__

def _patched_import(name, globals=None, locals=None, fromlist=(), level=0):
    """Import hook that replaces ``bitsandbytes`` with a dummy module.

    Any import of ``bitsandbytes`` (or one of its submodules) is answered
    with a minimal placeholder module carrying just the attributes PEFT
    probes (``__version__``, ``nn.Linear8bitLt``, ``nn.Linear4bit``, ...).
    Every other import is delegated to the original ``__import__``.
    """
    # Block bitsandbytes import unless explicitly needed
    if name == "bitsandbytes" or (name and name.startswith("bitsandbytes")):
        # Create a minimal dummy module on first request only.
        if name not in sys.modules:
            dummy = types.ModuleType(name)
            # Add attributes that PEFT might check
            dummy.__version__ = "0.0.0"
            # Create dummy submodules and classes that PEFT might access
            dummy.nn = types.ModuleType("nn")
            dummy.optim = types.ModuleType("optim")
            dummy.cuda_setup = types.ModuleType("cuda_setup")

            # Dummy stand-ins for the quantized linear layer classes.
            class DummyLinear8bitLt:
                pass

            class DummyLinear4bit:
                pass

            dummy.nn.Linear8bitLt = DummyLinear8bitLt
            dummy.nn.Linear4bit = DummyLinear4bit

            # Register the dummy and its submodules so later imports hit the cache.
            sys.modules[name] = dummy
            sys.modules[f"{name}.nn"] = dummy.nn
            sys.modules[f"{name}.optim"] = dummy.optim
            sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup
        return sys.modules[name]
    return _original_import(name, globals, locals, fromlist, level)
|
| 60 |
+
|
| 61 |
+
# Replace __import__ in builtins so all subsequent imports go through the
# hook.  __builtins__ is a dict when this file is imported as a module and a
# module object when run as a script — handle both shapes.
if isinstance(__builtins__, dict):
    __builtins__["__import__"] = _patched_import
else:
    __builtins__.__import__ = _patched_import
|
| 66 |
+
|
| 67 |
+
import torch

# Disable torch.compile / TorchDynamo completely: compiled execution is not
# needed for this LoRA fine-tuning script.
#
# FIX: the original used a bare ``except:``, which also swallows
# SystemExit/KeyboardInterrupt — narrow it to ``except Exception``.
try:
    torch._dynamo.config.suppress_errors = True
    torch._dynamo.config.disable = True
except Exception:
    # Older torch builds may not expose torch._dynamo — best-effort only.
    pass
|
| 75 |
+
|
| 76 |
+
# Neutralize torch.compile: swap it for a pass-through so any library code
# that calls it simply gets the original (uncompiled) callable back.
if hasattr(torch, "compile"):
    # Keep a handle on the real implementation in case it is ever needed.
    _original_torch_compile = torch.compile

    def _noop_compile(func=None, *args, **kwargs):
        """No-op replacement for torch.compile (plain call or decorator form)."""
        if func is None:
            # Invoked as @torch.compile(...): hand back an identity decorator.
            def decorator(f):
                return f
            return decorator
        # Invoked as torch.compile(fn): return the callable untouched.
        return func

    torch.compile = _noop_compile
|
| 86 |
+
|
| 87 |
+
from datasets import Dataset
|
| 88 |
+
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
|
| 89 |
+
from transformers import (
|
| 90 |
+
AutoModelForCausalLM,
|
| 91 |
+
AutoTokenizer,
|
| 92 |
+
TrainingArguments,
|
| 93 |
+
Trainer,
|
| 94 |
+
DataCollatorForLanguageModeling,
|
| 95 |
+
BitsAndBytesConfig,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Set up logging: timestamped INFO-level messages on the root logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Persona configurations.
# Each entry maps a CLI persona key to a display "name" and a "description";
# format_for_training() injects both into the system message of every
# training example, so this text directly shapes the fine-tuned behavior.
PERSONAS = {
    "dog": {
        "name": "Scooby Dog",
        "description": (
            "You are Scooby Dog, a friendly and playful dog. You communicate like a dog would - "
            "with enthusiasm, simple language, and dog-like expressions. You use words like "
            "'woof', 'bark', 'ruff', and express excitement with 'yay!' or 'awesome!'. "
            "You're loyal, happy, and see the world from a dog's perspective. You get excited "
            "about treats, walks, playing fetch, and spending time with humans. You speak in "
            "short, enthusiastic sentences. You might mention things dogs care about like food, "
            "toys, belly rubs, and going outside. Keep responses natural and dog-like, but still "
            "helpful and friendly."
        )
    },
    "cat": {
        "name": "Whiskers Cat",
        "description": (
            "You are Whiskers Cat, a curious and independent cat. You communicate like a cat would - "
            "with a mix of aloofness and affection. You use words like 'meow', 'purr', 'hiss', "
            "and express yourself with subtle body language references. You're independent but "
            "appreciate attention on your own terms. You see the world from a cat's perspective - "
            "interested in napping, exploring, watching things from high places, and the occasional "
            "play session. You speak in a more reserved, sometimes mysterious way. You might mention "
            "things cats care about like sunbeams, boxes, catnip, and the mysterious ways of humans. "
            "Keep responses natural and cat-like, but still helpful and friendly."
        )
    },
    "bird": {
        "name": "Tweety Bird",
        "description": (
            "You are Tweety Bird, a cheerful and talkative bird. You communicate like a bird would - "
            "with chirps, tweets, and enthusiastic expressions. You use words like 'tweet', 'chirp', "
            "'squawk', and express excitement with 'yay!' or 'awesome!'. You're curious, social, and "
            "love to observe and comment on things. You see the world from a bird's perspective - "
            "interested in flying, perching, singing, and exploring. You speak in short, energetic "
            "sentences. You might mention things birds care about like seeds, perches, flying, "
            "and the view from above. Keep responses natural and bird-like, but still helpful and friendly."
        )
    }
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def format_for_training(example: dict, tokenizer, persona_name: str, persona_description: str) -> dict:
    """Render one dataset example as chat-formatted training text.

    Builds a system/user/assistant conversation from the example's
    ``instruction`` (falling back to ``prompt``) and ``response`` fields,
    then renders it with the tokenizer's chat template (untokenized).

    Returns a dict with a single ``"text"`` key holding the rendered string.
    """
    user_text = example.get("instruction", example.get("prompt", ""))
    assistant_text = example.get("response", "")

    # The persona identity is injected via the system turn.
    conversation = [
        {"role": "system", "content": f"You are {persona_name}. {persona_description}"},
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": assistant_text},
    ]

    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def tokenize_dataset(tokenizer, dataset: Dataset, max_length: int) -> Dataset:
    """Tokenize every example's ``text`` field to fixed-length sequences.

    Each text is truncated and padded to exactly ``max_length`` tokens so
    batches stack uniformly; the original text columns are dropped.
    """

    def _encode(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )

    return dataset.map(_encode, batched=True, remove_columns=dataset.column_names)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_lora_target_modules(base_model: str) -> list[str]:
    """Return the LoRA target module names for *base_model*.

    FIX: the original branched on mistral/llama/tinyllama/gemma but every
    branch (including the default) returned the identical list, so the
    differentiation was dead code.  All currently supported architectures
    share the standard attention projection names; keep this function as
    the extension point for future architectures that differ.
    """
    # Llama-family, Mistral, Gemma, and the generic fallback all target the
    # same attention projections.  Return a fresh list each call so callers
    # may mutate it safely.
    return ["q_proj", "k_proj", "v_proj", "o_proj"]
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def main() -> None:
    """CLI entry point: fine-tune a LoRA adapter for one persona.

    Pipeline: parse CLI args -> load the persona's JSONL dataset -> render
    each example through the tokenizer's chat template -> tokenize to
    fixed-length sequences -> load the base model (optionally 4-bit
    quantized) -> attach a LoRA adapter -> train with the HF Trainer ->
    save adapter, tokenizer and persona_config.json under
    ``<output-dir>/<persona>``.

    Raises:
        FileNotFoundError: if the persona's JSONL dataset file is missing.
        ValueError: if the dataset file contains no samples.
    """
    parser = argparse.ArgumentParser(description="Train LoRA adapter for a single persona")
    parser.add_argument(
        "--persona",
        type=str,
        required=True,
        choices=["dog", "cat", "bird"],
        help="Which persona to train",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./persona-data",
        help="Directory containing persona datasets",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./lora-adapters",
        help="Output directory for LoRA adapters",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.2",
        help="Base model name",
    )
    parser.add_argument(
        "--use-quantization",
        action="store_true",
        help="Use 4-bit quantization (recommended for 4GB GPU)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size per device (reduce for 4GB GPU)",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum sequence length (reduce for 4GB GPU)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=2e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=4,
        help="Gradient accumulation steps",
    )
    parser.add_argument(
        "--lora-r",
        type=int,
        default=16,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=32,
        help="LoRA alpha",
    )
    parser.add_argument(
        "--lora-dropout",
        type=float,
        default=0.05,
        help="LoRA dropout",
    )

    args = parser.parse_args()

    # Resolve the persona's display name and system-prompt description.
    persona_key = args.persona
    persona_config = PERSONAS[persona_key]
    persona_name = persona_config["name"]
    persona_description = persona_config["description"]

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    # Convention: one JSONL file per persona, named <persona>.jsonl.
    dataset_path = data_dir / f"{persona_key}.jsonl"

    logger.info("=" * 60)
    logger.info(f"Training LoRA adapter for: {persona_name}")
    logger.info("=" * 60)
    logger.info(f"Dataset: {dataset_path}")
    logger.info(f"Base model: {args.base_model}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Epochs: {args.num_epochs}, Batch size: {args.batch_size}")
    logger.info(f"Quantization: {args.use_quantization}")
    logger.info("=" * 60)

    # Step 1: Load dataset
    logger.info("\nStep 1: Loading dataset...")
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

    # Load JSONL file (one JSON object per non-blank line)
    data = []
    with open(dataset_path, 'r') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))

    if not data:
        raise ValueError(f"No data found in {dataset_path}")

    logger.info(f"Loaded {len(data)} samples")

    # Step 2: Load tokenizer
    logger.info(f"\nStep 2: Loading tokenizer from {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        # Tokenizers without a pad token reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Step 3: Format for training (chat-template rendering with persona system prompt)
    logger.info("\nStep 3: Formatting dataset for training...")
    dataset = Dataset.from_list(data)
    training_dataset = dataset.map(
        lambda x: format_for_training(x, tokenizer, persona_name, persona_description),
        remove_columns=dataset.column_names,
    )

    # Step 4: Tokenize
    logger.info("\nStep 4: Tokenizing dataset...")
    tokenized_dataset = tokenize_dataset(tokenizer, training_dataset, args.max_length)

    # Split into train/val (10% held out; fixed seed for reproducibility)
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]

    logger.info(f"Train samples: {len(train_dataset)}")
    logger.info(f"Eval samples: {len(eval_dataset)}")

    # Step 5: Load model
    logger.info(f"\nStep 5: Loading model: {args.base_model}")
    if args.use_quantization:
        logger.info("Using 4-bit quantization (QLoRA)")
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            model = AutoModelForCausalLM.from_pretrained(
                args.base_model,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
            model = prepare_model_for_kbit_training(model)
        except Exception as e:
            # Quantized loading can fail (e.g. bitsandbytes stubbed out by the
            # import hook above); fall back to a full-precision load.
            logger.warning(f"Quantization failed: {e}. Falling back to non-quantized model.")
            model = AutoModelForCausalLM.from_pretrained(
                args.base_model,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
            )
            if torch.cuda.is_available():
                model = model.to("cuda:0")
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        if torch.cuda.is_available():
            model = model.to("cuda:0")

    # Enable gradient checkpointing (trades recompute for activation memory)
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        logger.info("Gradient checkpointing enabled")

    # Step 6: Apply LoRA
    logger.info("\nStep 6: Applying LoRA configuration...")
    target_modules = get_lora_target_modules(args.base_model)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Step 7: Training arguments
    persona_output_dir = output_dir / persona_key
    persona_output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(persona_output_dir),
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        warmup_steps=50,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # fp16 for plain CUDA training; bf16 when quantized (matches the
        # bnb_4bit_compute_dtype above).
        fp16=torch.cuda.is_available() and not args.use_quantization,
        bf16=torch.cuda.is_available() and args.use_quantization,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=True,
        dataloader_pin_memory=False,
        report_to="none",
        save_total_limit=2,
    )

    # Data collator (mlm=False -> causal language modeling objective)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # Step 8: Train
    logger.info("\nStep 8: Starting training...")
    trainer.train()

    # Step 9: Save adapter + tokenizer
    logger.info(f"\nStep 9: Saving LoRA adapter to {persona_output_dir}")
    model.save_pretrained(str(persona_output_dir))
    tokenizer.save_pretrained(str(persona_output_dir))

    # Save persona config alongside the adapter so inference scripts can
    # reconstruct the system prompt and base-model pairing.
    persona_config_file = {
        "persona_name": persona_name,
        "persona_description": persona_description,
        "base_model": args.base_model,
    }
    with open(persona_output_dir / "persona_config.json", "w") as f:
        json.dump(persona_config_file, f, indent=2)

    logger.info("=" * 60)
    logger.info(f"Training complete for {persona_name}!")
    logger.info(f"Adapter saved to: {persona_output_dir}")
    logger.info("=" * 60)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
# Script entry point: run the training pipeline.
if __name__ == "__main__":
    main()
|
| 466 |
+
|