Tameem7 commited on
Commit
fd35193
·
1 Parent(s): a129880
app.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio web application for chatting with 3 persona LoRA adapters.
4
+ Personas: Dog, Cat, Bird
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import sys
11
+ import types
12
+ import json
13
+ import gc
14
+ import gradio as gr
15
+ import torch
16
+ from pathlib import Path
17
+
18
+ # Disable torch.compile and prevent bitsandbytes issues
19
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
20
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
21
+ os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1"
22
+
23
+ # Patch import system to prevent bitsandbytes import
24
+ _original_import = __builtins__.__import__
25
+
26
+ def _patched_import(name, globals=None, locals=None, fromlist=(), level=0):
27
+ if name == "bitsandbytes" or (name and name.startswith("bitsandbytes")):
28
+ if name not in sys.modules:
29
+ dummy = types.ModuleType(name)
30
+ dummy.__version__ = "0.0.0"
31
+ dummy.nn = types.ModuleType("nn")
32
+ dummy.optim = types.ModuleType("optim")
33
+ dummy.cuda_setup = types.ModuleType("cuda_setup")
34
+
35
+ class DummyLinear8bitLt:
36
+ pass
37
+ class DummyLinear4bit:
38
+ pass
39
+
40
+ dummy.nn.Linear8bitLt = DummyLinear8bitLt
41
+ dummy.nn.Linear4bit = DummyLinear4bit
42
+
43
+ sys.modules[name] = dummy
44
+ sys.modules[f"{name}.nn"] = dummy.nn
45
+ sys.modules[f"{name}.optim"] = dummy.optim
46
+ sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup
47
+ return sys.modules[name]
48
+ return _original_import(name, globals, locals, fromlist, level)
49
+
50
+ if isinstance(__builtins__, dict):
51
+ __builtins__["__import__"] = _patched_import
52
+ else:
53
+ __builtins__.__import__ = _patched_import
54
+
55
+ # Disable torch.compile
56
+ try:
57
+ torch._dynamo.config.suppress_errors = True
58
+ torch._dynamo.config.disable = True
59
+ except:
60
+ pass
61
+
62
+ if hasattr(torch, "compile"):
63
+ _original_torch_compile = torch.compile
64
+ def _noop_compile(func=None, *args, **kwargs):
65
+ if func is not None:
66
+ return func
67
+ def decorator(f):
68
+ return f
69
+ return decorator
70
+ torch.compile = _noop_compile
71
+
72
+ from peft import PeftModel
73
+ from transformers import AutoModelForCausalLM, AutoTokenizer
74
+
75
# Configuration
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Hub repo + subfolder holding each persona's trained LoRA adapter.
ADAPTER_PATHS = {
    "dog": "Tameem7/Persona-Animal/dog",
    "cat": "Tameem7/Persona-Animal/cat",
    "bird": "Tameem7/Persona-Animal/bird",
}

# Global variables — single-adapter cache: only one persona is resident at a time.
base_model = None        # shared base LM, loaded once by load_base_model()
base_tokenizer = None    # shared tokenizer, loaded once
current_persona = None   # key of the adapter currently loaded
current_model = None     # PeftModel wrapping a fresh base-model copy
current_tokenizer = None  # alias of base_tokenizer once a persona is active
current_config = None    # parsed persona_config.json (or fallback dict)
90
+
91
+
92
def load_base_model():
    """Load the shared base model and tokenizer once, caching them in globals.

    Returns:
        Tuple of (model, tokenizer). Subsequent calls return the cached pair.
    """
    global base_model, base_tokenizer

    if base_model is not None:
        return base_model, base_tokenizer

    print(f"Loading base model: {BASE_MODEL}")
    base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if base_tokenizer.pad_token is None:
        # TinyLlama ships without a pad token; reuse EOS so padding works.
        base_tokenizer.pad_token = base_tokenizer.eos_token

    # Determine device and dtype (bf16 on GPU, fp32 on CPU).
    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"
    dtype = torch.bfloat16 if use_cuda else torch.float32

    if use_cuda:
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            # `torch_dtype` (not `dtype`) is the kwarg supported across the
            # transformers>=4.40 floor declared in requirements.txt, and
            # matches test_persona.py.
            torch_dtype=dtype,
            device_map="auto",
        )
    else:
        print("💻 Running on CPU")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=dtype,
        )
        base_model = base_model.to(device)

    base_model.eval()
    print("✅ Base model loaded")
    return base_model, base_tokenizer
126
+
127
+
128
def load_persona_adapter(persona_key: str):
    """Load (or switch to) the LoRA adapter for *persona_key*.

    Keeps at most one adapter resident; switching personas frees the previous
    adapter before loading the new one. Always reloads when the requested
    persona is not currently active — the original skipped the reload when
    `current_persona == persona_key` even if the model had been unloaded,
    which left `base_model_copy` unbound.

    Args:
        persona_key: One of the keys in ADAPTER_PATHS ("dog", "cat", "bird").

    Returns:
        Tuple of (model, tokenizer, persona_config).

    Raises:
        ValueError: If persona_key is not a known persona.
    """
    global current_persona, current_model, current_tokenizer, current_config, base_model, base_tokenizer

    # Fast path: requested persona already active.
    if current_persona == persona_key and current_model is not None:
        return current_model, current_tokenizer, current_config

    # Make sure the shared tokenizer (and base weights) exist.
    if base_model is None:
        load_base_model()

    # Drop any previously loaded adapter and reclaim memory.
    if current_model is not None:
        print(f"Unloading previous adapter: {current_persona}")
        del current_model
        current_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    adapter_path = ADAPTER_PATHS.get(persona_key)
    if not adapter_path:
        raise ValueError(f"Unknown persona: {persona_key}")

    print(f"Loading adapter: {adapter_path}")

    # PEFT mutates the model it wraps, so give each adapter a fresh copy of
    # the base weights instead of reusing the shared `base_model`.
    print(f"Creating base model copy for {persona_key} adapter...")
    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"
    dtype = torch.bfloat16 if use_cuda else torch.float32

    if use_cuda:
        base_model_copy = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=dtype,  # kwarg supported across transformers>=4.40
            device_map="auto",
        )
    else:
        base_model_copy = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=dtype,
        )
        base_model_copy = base_model_copy.to(device)

    # Load adapter from Hugging Face
    print(f"Loading adapter from: {adapter_path}")
    current_model = PeftModel.from_pretrained(base_model_copy, adapter_path)
    current_model.eval()

    # Optional persona metadata shipped with the adapter; a missing or
    # unreadable config is non-fatal.
    try:
        from huggingface_hub import hf_hub_download

        # NOTE(review): adapter_path looks like "repo/subfolder"; presumably
        # hf_hub_download tolerates it or this always falls through — confirm.
        config_path = hf_hub_download(
            repo_id=adapter_path,
            filename="persona_config.json",
            repo_type="model",
        )
        with open(config_path, 'r') as f:
            current_config = json.load(f)
    except Exception:  # was a bare except; still best-effort, but explicit
        current_config = {"persona_name": persona_key.title(), "persona_description": ""}

    current_persona = persona_key
    current_tokenizer = base_tokenizer
    print(f"✅ Loaded {persona_key} persona")

    return current_model, current_tokenizer, current_config
203
+
204
+
205
def _clean_response(response: str, tokenizer) -> str:
    """Strip special tokens, chat-template tags, and extra whitespace.

    Mirrors the cleanup sequence used by test_persona.py so both entry
    points post-process model output identically.
    """
    response = response.strip()
    if tokenizer.eos_token:
        response = response.replace(tokenizer.eos_token, "").strip()
    if tokenizer.pad_token:
        response = response.replace(tokenizer.pad_token, "").strip()

    # Remove chat template artifacts that can leak into decoded text.
    for tag in ("system", "user", "assistant"):
        response = response.replace(f"<|{tag}|>", "").replace(f"</|{tag}|>", "")
    response = response.replace("<|", "").replace("|>", "")

    # Collapse runs of whitespace.
    return " ".join(response.split()).strip()


def generate_response(persona_key: str, message: str, history: list, max_tokens: int = 80):
    """Generate a persona reply and append the exchange to *history*.

    Args:
        persona_key: Which adapter to use ("dog", "cat", "bird").
        message: The user's new message; blank input is a no-op.
        history: Chat history in Gradio "messages" format (role/content dicts).
        max_tokens: Cap on newly generated tokens.

    Returns:
        (updated_history, textbox_value) — textbox_value is "" on success or
        an error string on failure (shown back in the input box).
    """
    global current_model, current_tokenizer, current_config

    if not message or not message.strip():
        return history, ""

    try:
        # Load (or switch to) the requested adapter.
        model, tokenizer, config = load_persona_adapter(persona_key)

        system_prompt = ""
        if config:
            system_prompt = f"You are {config.get('persona_name', '')}. {config.get('persona_description', '')}"

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Keep only the last 10 messages (~5 exchanges) to bound context size.
        for msg in history[-10:]:
            if isinstance(msg, dict) and "role" in msg:
                messages.append(msg)
            else:
                # Legacy (user, assistant) tuple fallback (shouldn't happen
                # with type='messages').
                user_msg, assistant_msg = msg
                messages.append({"role": "user", "content": user_msg})
                messages.append({"role": "assistant", "content": assistant_msg})

        messages.append({"role": "user", "content": message})

        formatted = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize and move inputs to the model's device.
        inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=512)
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
            )

        # Decode only the newly generated tokens (after the prompt).
        input_length = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_length:]
        response = _clean_response(
            tokenizer.decode(generated_tokens, skip_special_tokens=True),
            tokenizer,
        )

        # Append the exchange in Gradio messages format.
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})

        return history, ""

    except Exception as e:
        # Boundary handler: surface the error in the UI instead of crashing.
        error_msg = f"Error generating response: {str(e)}"
        print(error_msg)
        return history, error_msg
299
+
300
+
301
def clear_chat():
    """Reset the conversation: empty messages-format history, cleared input."""
    fresh_history: list = []
    return fresh_history, ""
304
+
305
+
306
# Create Gradio interface
with gr.Blocks(title="Persona Chat", theme=gr.themes.Soft()) as app:
    # Intro / instructions panel.
    gr.Markdown(
        """
        # 🐾 Persona Chat - Talk to Animals!

        Chat with three different animal personas, each with their own unique personality:
        - **🐕 Dog**: Friendly, playful, and enthusiastic
        - **🐱 Cat**: Independent, curious, and sometimes sassy
        - **🐦 Bird**: Energetic, talkative, and free-spirited

        **💻 Running on CPU** - Responses may be slower but will work perfectly!
        """
    )

    with gr.Row():
        # Left column: persona selection and generation controls.
        with gr.Column(scale=1):
            persona_dropdown = gr.Dropdown(
                choices=["dog", "cat", "bird"],
                value="dog",
                label="Select Persona",
                info="Choose which animal persona to chat with"
            )

            max_tokens_slider = gr.Slider(
                minimum=20,
                maximum=150,
                value=80,
                step=10,
                label="Max Response Length",
                info="Maximum number of tokens in response"
            )

            clear_btn = gr.Button("Clear Chat", variant="secondary")

        # Right column: transcript and input box.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat",
                height=500,
                show_copy_button=True,
                # history is a list of {"role", "content"} dicts
                type='messages'
            )

            msg_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2
            )

            send_btn = gr.Button("Send", variant="primary", scale=1)

    # Event handlers
    def chat_fn(persona, message, history, max_tokens):
        # Thin wrapper so Gradio maps widget values to positional args.
        return generate_response(persona, message, history, max_tokens)

    # Send button and Enter key both submit the current message.
    send_btn.click(
        fn=chat_fn,
        inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider],
        outputs=[chatbot, msg_input]
    )

    msg_input.submit(
        fn=chat_fn,
        inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider],
        outputs=[chatbot, msg_input]
    )

    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, msg_input]
    )
377
+
378
+
379
if __name__ == "__main__":
    # Load base model first so the first chat request isn't slow.
    print("Initializing...")
    load_base_model()

    app.launch(
        # Bind to all interfaces only inside an HF Space container
        # (SPACE_ID is set there); stay on loopback locally.
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1",
        server_port=int(os.getenv("PORT", 7860)),
        share=False
    )
389
+
config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Configuration for persona LoRA fine-tuning.

Edit these values to customize your training setup.
"""

# Base Model Configuration
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # ~2GB, fits easily

# Persona Configuration — name and system-prompt description for training.
PERSONA_NAME = "Scooby Dog"
PERSONA_DESCRIPTION = (
    "You are Scooby Dog, a friendly and playful dog. You communicate like a dog would - "
    "with enthusiasm, simple language, and dog-like expressions. You use words like "
    "'woof', 'bark', 'ruff', and express excitement with 'yay!' or 'awesome!'. "
    "You're loyal, happy, and see the world from a dog's perspective. You get excited "
    "about treats, walks, playing fetch, and spending time with humans. You speak in "
    "short, enthusiastic sentences. You might mention things dogs care about like food, "
    "toys, belly rubs, and going outside. Keep responses natural and dog-like, but still "
    "helpful and friendly."
)

# Dataset Configuration
DATASET_NAME = "bavard/personachat_truecased"  # Persona-Chat dataset
# Alternative: "bavard/personachat" or "personachat"

# Training Configuration
NUM_EPOCHS = 3
BATCH_SIZE = 2  # Per device (reduce to 1-2 for 4GB GPU)
LEARNING_RATE = 2e-4
MAX_LENGTH = 512  # Reduce to 512 for 4GB GPU (2048 for 8GB+)
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch = BATCH_SIZE * this

# LoRA Configuration
LORA_R = 16  # Rank
LORA_ALPHA = 32  # LoRA alpha (scaling factor)
LORA_DROPOUT = 0.05
# Attention projection modules shared by LLaMA-family models, including the
# TinyLlama base model above (the original comment said "Mistral", but the
# same names apply to both architectures).
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

# Output Configuration
OUTPUT_DIR = "./lora-adapters-scooby-dog"

# Quantization (for Colab)
USE_QUANTIZATION = False  # Set to False if you have enough VRAM
45
+
persona-data/bird.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
persona-data/cat.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
persona-data/dog.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.40.0
2
+ accelerate>=0.29.0
3
+ datasets>=2.14.0
4
+ torch>=2.0.0
5
+ scikit-learn>=1.3.0
6
+ gradio>=4.0.0
7
+ peft>=0.10.0
8
+ huggingface-hub>=0.20.0
9
+
test_persona.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test a trained persona LoRA adapter.
4
+
5
+ Usage:
6
+ python test_persona.py --persona dog --message "Hey, how are you?"
7
+ python test_persona.py --persona dog # Interactive mode
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import sys
16
+ import types
17
+ from pathlib import Path
18
+
19
+ # Disable torch.compile and prevent bitsandbytes issues (same as training)
20
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
21
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
22
+ os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1"
23
+
24
+ # Patch import system to prevent bitsandbytes import
25
+ _original_import = __builtins__.__import__
26
+
27
+ def _patched_import(name, globals=None, locals=None, fromlist=(), level=0):
28
+ if name == "bitsandbytes" or (name and name.startswith("bitsandbytes")):
29
+ if name not in sys.modules:
30
+ dummy = types.ModuleType(name)
31
+ dummy.__version__ = "0.0.0"
32
+ dummy.nn = types.ModuleType("nn")
33
+ dummy.optim = types.ModuleType("optim")
34
+ dummy.cuda_setup = types.ModuleType("cuda_setup")
35
+
36
+ class DummyLinear8bitLt:
37
+ pass
38
+ class DummyLinear4bit:
39
+ pass
40
+
41
+ dummy.nn.Linear8bitLt = DummyLinear8bitLt
42
+ dummy.nn.Linear4bit = DummyLinear4bit
43
+
44
+ sys.modules[name] = dummy
45
+ sys.modules[f"{name}.nn"] = dummy.nn
46
+ sys.modules[f"{name}.optim"] = dummy.optim
47
+ sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup
48
+ return sys.modules[name]
49
+ return _original_import(name, globals, locals, fromlist, level)
50
+
51
+ if isinstance(__builtins__, dict):
52
+ __builtins__["__import__"] = _patched_import
53
+ else:
54
+ __builtins__.__import__ = _patched_import
55
+
56
+ import torch
57
+
58
+ # Disable torch.compile
59
+ try:
60
+ torch._dynamo.config.suppress_errors = True
61
+ torch._dynamo.config.disable = True
62
+ except:
63
+ pass
64
+
65
+ if hasattr(torch, "compile"):
66
+ _original_torch_compile = torch.compile
67
+ def _noop_compile(func=None, *args, **kwargs):
68
+ if func is not None:
69
+ return func
70
+ def decorator(f):
71
+ return f
72
+ return decorator
73
+ torch.compile = _noop_compile
74
+
75
+ from peft import PeftModel
76
+ from transformers import AutoModelForCausalLM, AutoTokenizer
77
+
78
+
79
def load_persona_model(persona_key: str, adapter_dir: Path, base_model: str):
    """Load the base model, attach the LoRA adapter, and read its persona config.

    Args:
        persona_key: Persona identifier (informational; adapter_dir decides what loads).
        adapter_dir: Directory containing the trained LoRA adapter files.
        base_model: Hub id or local path of the base model.

    Returns:
        Tuple of (model, tokenizer, persona_config) where persona_config is
        the parsed persona_config.json, or None if that file is absent.
    """
    print(f"Loading base model: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        # Model ships without a pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # NOTE: with device_map="auto", accelerate has already placed the weights
    # on the GPU — the original extra `model.to("cuda:0")` was redundant and
    # can raise on models dispatched with offloaded modules, so it is removed.

    print(f"Loading LoRA adapter from: {adapter_dir}")
    model = PeftModel.from_pretrained(model, str(adapter_dir))
    model.eval()

    # Load persona config (optional metadata saved next to the adapter).
    config_file = adapter_dir / "persona_config.json"
    persona_config = None
    if config_file.exists():
        with open(config_file, 'r') as f:
            persona_config = json.load(f)

    return model, tokenizer, persona_config
107
+
108
+
109
def generate_response(
    model,
    tokenizer,
    message: str,
    persona_config: dict | None = None,  # None => no system prompt is added
    max_new_tokens: int = 80,
    temperature: float = 0.7,
    top_p: float = 0.9,
):
    """Generate one persona reply to *message*.

    Args:
        model: Causal LM (PeftModel with the persona adapter attached).
        tokenizer: Matching tokenizer; its chat template formats the prompt.
        message: Single user turn (no conversation history in this CLI path).
        persona_config: Optional dict with 'persona_name' /
            'persona_description' used to build the system prompt.
        max_new_tokens: Cap on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        The cleaned response string (special tokens, template tags, and
        extra whitespace removed).
    """
    # Build messages
    system_prompt = ""
    if persona_config:
        system_prompt = f"You are {persona_config.get('persona_name', '')}. {persona_config.get('persona_description', '')}"

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": message})

    # Apply chat template
    formatted = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize
    inputs = tokenizer(formatted, return_tensors="pt")
    if torch.cuda.is_available():
        # Inputs must live on the same device as the model.
        inputs = {k: v.to("cuda:0") for k, v in inputs.items()}

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,  # Reduce repetition
            no_repeat_ngram_size=3,  # Prevent 3-gram repetition
        )

    # Extract only the newly generated tokens (after the input)
    input_length = inputs['input_ids'].shape[1]
    generated_tokens = outputs[0][input_length:]

    # Decode only the generated part
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Clean up response
    response = response.strip()

    # Remove special tokens
    if tokenizer.eos_token:
        response = response.replace(tokenizer.eos_token, "").strip()
    if tokenizer.pad_token:
        response = response.replace(tokenizer.pad_token, "").strip()

    # Remove any chat template artifacts that might leak through
    # Remove system/user/assistant tags if present
    response = response.replace("<|system|>", "").replace("</|system|>", "")
    response = response.replace("<|user|>", "").replace("</|user|>", "")
    response = response.replace("<|assistant|>", "").replace("</|assistant|>", "")

    # Remove any remaining formatting
    response = response.replace("<|", "").replace("|>", "")

    # Clean up extra whitespace (collapses internal runs to single spaces)
    response = " ".join(response.split())

    return response.strip()
183
+
184
+
185
def main():
    """CLI entry point: load a persona adapter and chat with it.

    Runs in single-message mode when --message is given; otherwise drops
    into an interactive REPL (type 'quit'/'exit'/'q' to leave).
    """
    parser = argparse.ArgumentParser(description="Test a trained persona LoRA adapter")
    parser.add_argument(
        "--persona",
        type=str,
        required=True,
        choices=["dog", "cat", "bird"],
        help="Which persona to test",
    )
    parser.add_argument(
        "--adapter-dir",
        type=str,
        default="./lora-adapters",
        help="Directory containing LoRA adapters",
    )
    parser.add_argument(
        "--message",
        type=str,
        default=None,
        help="Message to send (if not provided, enters interactive mode)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=80,
        help="Maximum tokens to generate (default: 80 for shorter responses)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Generation temperature",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Top-p sampling",
    )

    args = parser.parse_args()

    # Each persona's adapter lives in a subdirectory named after it.
    adapter_dir = Path(args.adapter_dir) / args.persona

    if not adapter_dir.exists():
        print(f"Error: Adapter directory not found: {adapter_dir}")
        print("Please train the persona first using train_single_persona.py")
        return

    # Load persona config to get base model
    config_file = adapter_dir / "persona_config.json"
    if config_file.exists():
        with open(config_file, 'r') as f:
            persona_config = json.load(f)
        base_model = persona_config.get("base_model", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    else:
        # Default fallback
        base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        persona_config = None

    print("=" * 60)
    print(f"Loading {args.persona} persona...")
    print("=" * 60)

    model, tokenizer, loaded_config = load_persona_model(
        args.persona,
        adapter_dir,
        base_model
    )

    # Prefer the config bundled with the adapter over the pre-read one.
    if loaded_config:
        persona_config = loaded_config
        print(f"\nPersona: {persona_config.get('persona_name', args.persona)}")
        print(f"Base model: {persona_config.get('base_model', base_model)}")

    print("\n" + "=" * 60)
    print("Ready! Type your messages (or 'quit' to exit)")
    print("=" * 60 + "\n")

    # Interactive or single message mode
    if args.message:
        # Single message mode
        print(f"You: {args.message}")
        response = generate_response(
            model,
            tokenizer,
            args.message,
            persona_config,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
        )
        print(f"{args.persona.capitalize()}: {response}")
    else:
        # Interactive mode
        while True:
            try:
                message = input("You: ").strip()
                if not message:
                    continue
                if message.lower() in ['quit', 'exit', 'q']:
                    break

                response = generate_response(
                    model,
                    tokenizer,
                    message,
                    persona_config,
                    max_new_tokens=args.max_tokens,
                    temperature=args.temperature,
                    top_p=args.top_p,
                )
                print(f"{args.persona.capitalize()}: {response}\n")
            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                # Keep the REPL alive on per-message failures.
                print(f"Error: {e}")
                import traceback
                traceback.print_exc()


if __name__ == "__main__":
    main()
309
+
train_single_persona.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train a LoRA adapter for a single persona.
4
+
5
+ This script trains one persona at a time in a separate process to avoid
6
+ bitsandbytes kernel registration conflicts.
7
+
8
+ Usage:
9
+ python train_single_persona.py --persona dog --base-model TinyLlama/TinyLlama-1.1B-Chat-v1.0
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import logging
17
+ import os
18
+ import sys
19
+ import types
20
+ from pathlib import Path
21
+
22
+ # Disable torch.compile and prevent bitsandbytes issues
23
+ os.environ["TORCH_COMPILE_DISABLE"] = "1"
24
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
25
+ os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1"
26
+
27
+ # CRITICAL: Patch import system BEFORE importing torch or any ML libraries
28
+ # This prevents bitsandbytes from being imported when not needed
29
+ _original_import = __builtins__.__import__
30
+
31
+ def _patched_import(name, globals=None, locals=None, fromlist=(), level=0):
32
+ # Block bitsandbytes import unless explicitly needed
33
+ if name == "bitsandbytes" or (name and name.startswith("bitsandbytes")):
34
+ # Create a minimal dummy module
35
+ if name not in sys.modules:
36
+ dummy = types.ModuleType(name)
37
+ # Add attributes that PEFT might check
38
+ dummy.__version__ = "0.0.0"
39
+ # Create dummy submodules and classes that PEFT might access
40
+ dummy.nn = types.ModuleType("nn")
41
+ dummy.optim = types.ModuleType("optim")
42
+ dummy.cuda_setup = types.ModuleType("cuda_setup")
43
+
44
+ # Dummy classes
45
+ class DummyLinear8bitLt:
46
+ pass
47
+ class DummyLinear4bit:
48
+ pass
49
+
50
+ dummy.nn.Linear8bitLt = DummyLinear8bitLt
51
+ dummy.nn.Linear4bit = DummyLinear4bit
52
+
53
+ # Add to sys.modules
54
+ sys.modules[name] = dummy
55
+ sys.modules[f"{name}.nn"] = dummy.nn
56
+ sys.modules[f"{name}.optim"] = dummy.optim
57
+ sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup
58
+ return sys.modules[name]
59
+ return _original_import(name, globals, locals, fromlist, level)
60
+
61
+ # Replace __import__ in builtins
62
+ if isinstance(__builtins__, dict):
63
+ __builtins__["__import__"] = _patched_import
64
+ else:
65
+ __builtins__.__import__ = _patched_import
66
+
67
+ import torch
68
+
69
+ # Disable torch.compile completely
70
+ try:
71
+ torch._dynamo.config.suppress_errors = True
72
+ torch._dynamo.config.disable = True
73
+ except:
74
+ pass
75
+
76
+ # Replace torch.compile with no-op
77
+ if hasattr(torch, "compile"):
78
+ _original_torch_compile = torch.compile
79
+ def _noop_compile(func=None, *args, **kwargs):
80
+ if func is not None:
81
+ return func
82
+ def decorator(f):
83
+ return f
84
+ return decorator
85
+ torch.compile = _noop_compile
86
+
87
+ from datasets import Dataset
88
+ from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
89
+ from transformers import (
90
+ AutoModelForCausalLM,
91
+ AutoTokenizer,
92
+ TrainingArguments,
93
+ Trainer,
94
+ DataCollatorForLanguageModeling,
95
+ BitsAndBytesConfig,
96
+ )
97
+
98
# Set up logging — INFO-level with timestamps for training progress output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger, per the stdlib logging convention.
logger = logging.getLogger(__name__)
105
+
106
+
107
# Persona configurations
# Each entry supplies the "name" and "description" interpolated into the
# system prompt during training; keys must match the --persona CLI choices.
PERSONAS = {
    "dog": {
        "name": "Scooby Dog",
        "description": (
            "You are Scooby Dog, a friendly and playful dog. You communicate like a dog would - "
            "with enthusiasm, simple language, and dog-like expressions. You use words like "
            "'woof', 'bark', 'ruff', and express excitement with 'yay!' or 'awesome!'. "
            "You're loyal, happy, and see the world from a dog's perspective. You get excited "
            "about treats, walks, playing fetch, and spending time with humans. You speak in "
            "short, enthusiastic sentences. You might mention things dogs care about like food, "
            "toys, belly rubs, and going outside. Keep responses natural and dog-like, but still "
            "helpful and friendly."
        )
    },
    "cat": {
        "name": "Whiskers Cat",
        "description": (
            "You are Whiskers Cat, a curious and independent cat. You communicate like a cat would - "
            "with a mix of aloofness and affection. You use words like 'meow', 'purr', 'hiss', "
            "and express yourself with subtle body language references. You're independent but "
            "appreciate attention on your own terms. You see the world from a cat's perspective - "
            "interested in napping, exploring, watching things from high places, and the occasional "
            "play session. You speak in a more reserved, sometimes mysterious way. You might mention "
            "things cats care about like sunbeams, boxes, catnip, and the mysterious ways of humans. "
            "Keep responses natural and cat-like, but still helpful and friendly."
        )
    },
    "bird": {
        "name": "Tweety Bird",
        "description": (
            "You are Tweety Bird, a cheerful and talkative bird. You communicate like a bird would - "
            "with chirps, tweets, and enthusiastic expressions. You use words like 'tweet', 'chirp', "
            "'squawk', and express excitement with 'yay!' or 'awesome!'. You're curious, social, and "
            "love to observe and comment on things. You see the world from a bird's perspective - "
            "interested in flying, perching, singing, and exploring. You speak in short, energetic "
            "sentences. You might mention things birds care about like seeds, perches, flying, "
            "and the view from above. Keep responses natural and bird-like, but still helpful and friendly."
        )
    }
}
148
+
149
+
150
def format_for_training(example: dict, tokenizer, persona_name: str, persona_description: str) -> dict:
    """Turn one dataset record into a chat-templated training string.

    The user turn is read from ``instruction`` (falling back to ``prompt``)
    and the assistant turn from ``response``. A persona system message is
    prepended, and the three-turn conversation is rendered via the
    tokenizer's chat template.

    Returns:
        A dict with a single ``text`` key holding the rendered string.
    """
    user_turn = example.get("instruction", example.get("prompt", ""))
    assistant_turn = example.get("response", "")

    # Persona identity is injected through the system message so that every
    # training sample carries the same character framing.
    conversation = [
        {"role": "system", "content": f"You are {persona_name}. {persona_description}"},
        {"role": "user", "content": user_turn},
        {"role": "assistant", "content": assistant_turn},
    ]

    # No generation prompt: the assistant turn is already present, so the
    # template renders a complete supervised example.
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )

    return {"text": rendered}
171
+
172
+
173
def tokenize_dataset(tokenizer, dataset: Dataset, max_length: int) -> Dataset:
    """Map the ``text`` column through the tokenizer.

    Every example is truncated and padded to exactly ``max_length`` tokens;
    the original columns are dropped so only tokenizer outputs remain.
    """

    def _encode(batch):
        # Fixed-length padding keeps all sequences uniform for batching.
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )

    return dataset.map(_encode, batched=True, remove_columns=dataset.column_names)
184
+
185
+
186
def get_lora_target_modules(base_model: str) -> list[str]:
    """Return the LoRA target module names for ``base_model``.

    Mistral, Llama/TinyLlama, Gemma, and the generic fallback all use the
    standard attention projection names, so one list covers every supported
    architecture. Kept as a function (taking the model name) so that
    architectures with different projection layer names can be special-cased
    later without changing any caller.
    """
    # NOTE(review): the original if/elif chain returned this exact same list
    # on every branch; the dead branches are collapsed here without changing
    # behavior for any input.
    return ["q_proj", "k_proj", "v_proj", "o_proj"]
197
+
198
+
199
def main():
    """CLI entry point: train a LoRA adapter for one persona.

    Pipeline: parse args -> load the persona's JSONL dataset -> format each
    record with the model's chat template -> tokenize -> load the base model
    (optionally 4-bit quantized via QLoRA) -> attach LoRA -> train with
    HF ``Trainer`` -> save the adapter, tokenizer, and a small
    ``persona_config.json`` next to them.
    """
    parser = argparse.ArgumentParser(description="Train LoRA adapter for a single persona")
    parser.add_argument(
        "--persona",
        type=str,
        required=True,
        choices=["dog", "cat", "bird"],
        help="Which persona to train",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./persona-data",
        help="Directory containing persona datasets",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./lora-adapters",
        help="Output directory for LoRA adapters",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.2",
        help="Base model name",
    )
    parser.add_argument(
        "--use-quantization",
        action="store_true",
        help="Use 4-bit quantization (recommended for 4GB GPU)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size per device (reduce for 4GB GPU)",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum sequence length (reduce for 4GB GPU)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=2e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=4,
        help="Gradient accumulation steps",
    )
    parser.add_argument(
        "--lora-r",
        type=int,
        default=16,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=32,
        help="LoRA alpha",
    )
    parser.add_argument(
        "--lora-dropout",
        type=float,
        default=0.05,
        help="LoRA dropout",
    )

    args = parser.parse_args()

    # Resolve the persona's display name and system-prompt description from
    # the module-level PERSONAS registry; argparse choices guarantee the key
    # exists.
    persona_key = args.persona
    persona_config = PERSONAS[persona_key]
    persona_name = persona_config["name"]
    persona_description = persona_config["description"]

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    # One JSONL file per persona, named after the persona key.
    dataset_path = data_dir / f"{persona_key}.jsonl"

    logger.info("=" * 60)
    logger.info(f"Training LoRA adapter for: {persona_name}")
    logger.info("=" * 60)
    logger.info(f"Dataset: {dataset_path}")
    logger.info(f"Base model: {args.base_model}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Epochs: {args.num_epochs}, Batch size: {args.batch_size}")
    logger.info(f"Quantization: {args.use_quantization}")
    logger.info("=" * 60)

    # Step 1: Load dataset
    logger.info("\nStep 1: Loading dataset...")
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

    # Load JSONL file: one JSON object per non-blank line, expected to carry
    # instruction/prompt + response fields (see format_for_training).
    data = []
    with open(dataset_path, 'r') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))

    if not data:
        raise ValueError(f"No data found in {dataset_path}")

    logger.info(f"Loaded {len(data)} samples")

    # Step 2: Load tokenizer
    logger.info(f"\nStep 2: Loading tokenizer from {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    # Many causal-LM tokenizers ship without a pad token; reuse EOS so that
    # "max_length" padding in tokenize_dataset works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Step 3: Format for training
    logger.info("\nStep 3: Formatting dataset for training...")
    dataset = Dataset.from_list(data)
    training_dataset = dataset.map(
        lambda x: format_for_training(x, tokenizer, persona_name, persona_description),
        remove_columns=dataset.column_names,
    )

    # Step 4: Tokenize
    logger.info("\nStep 4: Tokenizing dataset...")
    tokenized_dataset = tokenize_dataset(tokenizer, training_dataset, args.max_length)

    # Split into train/val (fixed seed for reproducible splits across runs).
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]

    logger.info(f"Train samples: {len(train_dataset)}")
    logger.info(f"Eval samples: {len(eval_dataset)}")

    # Step 5: Load model
    logger.info(f"\nStep 5: Loading model: {args.base_model}")
    if args.use_quantization:
        logger.info("Using 4-bit quantization (QLoRA)")
        try:
            # NF4 double quantization with bf16 compute: the standard QLoRA
            # recipe for fitting 7B models on small GPUs.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            model = AutoModelForCausalLM.from_pretrained(
                args.base_model,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
            model = prepare_model_for_kbit_training(model)
        except Exception as e:
            # Best-effort fallback: bitsandbytes is stubbed out at the top of
            # this file, so quantization may fail here by design.
            logger.warning(f"Quantization failed: {e}. Falling back to non-quantized model.")
            model = AutoModelForCausalLM.from_pretrained(
                args.base_model,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
            )
            # NOTE(review): calling .to("cuda:0") after device_map="auto" may
            # conflict with accelerate's own placement — confirm this is
            # intentional.
            if torch.cuda.is_available():
                model = model.to("cuda:0")
    else:
        # NOTE(review): this branch duplicates the fallback loading code
        # above; a shared helper would keep the two paths in sync.
        model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        if torch.cuda.is_available():
            model = model.to("cuda:0")

    # Enable gradient checkpointing (trades compute for activation memory).
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        logger.info("Gradient checkpointing enabled")

    # Step 6: Apply LoRA
    logger.info("\nStep 6: Applying LoRA configuration...")
    target_modules = get_lora_target_modules(args.base_model)
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Step 7: Training arguments
    # Each persona gets its own subdirectory under the shared output dir.
    persona_output_dir = output_dir / persona_key
    persona_output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(persona_output_dir),
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        warmup_steps=50,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        # bf16 pairs with the QLoRA compute dtype; fp16 otherwise (GPU only).
        fp16=torch.cuda.is_available() and not args.use_quantization,
        bf16=torch.cuda.is_available() and args.use_quantization,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=True,
        dataloader_pin_memory=False,
        report_to="none",
        save_total_limit=2,
    )

    # Data collator: mlm=False selects causal-LM label shifting.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # Step 8: Train
    logger.info("\nStep 8: Starting training...")
    trainer.train()

    # Step 9: Save
    logger.info(f"\nStep 9: Saving LoRA adapter to {persona_output_dir}")
    # save_pretrained on a PEFT model writes only the adapter weights.
    model.save_pretrained(str(persona_output_dir))
    tokenizer.save_pretrained(str(persona_output_dir))

    # Save persona config so inference code can reconstruct the system
    # prompt and locate the matching base model.
    persona_config_file = {
        "persona_name": persona_name,
        "persona_description": persona_description,
        "base_model": args.base_model,
    }
    with open(persona_output_dir / "persona_config.json", "w") as f:
        json.dump(persona_config_file, f, indent=2)

    logger.info("=" * 60)
    logger.info(f"Training complete for {persona_name}!")
    logger.info(f"Adapter saved to: {persona_output_dir}")
    logger.info("=" * 60)
462
+
463
+
464
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()
466
+