File size: 11,149 Bytes
9ec3d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
# ABOUTME: Fine-tune Qwen2.5-3B with LoRA on diary classification dataset
# ABOUTME: Outputs a lightweight adapter that can be merged with base model

import json
from pathlib import Path

import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)


def resolve_dataset_paths(paths: list[str]) -> list[Path]:
    """
    Turn a list of file and/or directory paths into a flat list of JSONL files.

    Regular files are kept as given (no extension check). Directories are
    replaced by their top-level ``*.jsonl`` files in sorted order; an empty
    directory only produces a printed warning.

    Raises:
        FileNotFoundError: if an entry is neither an existing file nor a directory.
    """
    files: list[Path] = []
    for candidate in map(Path, paths):
        if candidate.is_file():
            files.append(candidate)
            continue
        if not candidate.is_dir():
            raise FileNotFoundError(f"Dataset path not found: {candidate}")
        matches = sorted(candidate.glob("*.jsonl"))
        if len(matches) == 0:
            print(f"  Warning: No .jsonl files found in {candidate}")
        files += matches
    return files


def load_dataset_from_jsonl(paths: list[str]) -> Dataset:
    """
    Load one or more JSONL datasets with the format:
    {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

    Args:
        paths: List of file paths or directory paths. Directories are expanded
               to all *.jsonl files within them.

    Multiple files are concatenated into a single dataset. Blank lines are
    skipped. Records are not validated against the schema above.

    Raises:
        ValueError: if no dataset files are found.
        json.JSONDecodeError: if a line is not valid JSON; the message is
            prefixed with ``<file>:<line>`` to locate the bad record.
    """
    file_paths = resolve_dataset_paths(paths)

    if not file_paths:
        raise ValueError("No dataset files found")

    examples = []
    for file_path in file_paths:
        print(f"  Loading: {file_path}")
        count = 0
        # utf-8-sig strips a leading byte-order mark if present (otherwise
        # json.loads rejects the first record) and reads plain UTF-8 unchanged.
        with open(file_path, "r", encoding="utf-8-sig") as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    examples.append(json.loads(line))
                except json.JSONDecodeError as err:
                    # Same exception type as before, but say where the bad record is.
                    raise json.JSONDecodeError(
                        f"{file_path}:{line_no}: {err.msg}", err.doc, err.pos
                    ) from err
                count += 1
        print(f"    -> {count} examples")

    return Dataset.from_list(examples)


def format_chat_example(example: dict, tokenizer) -> dict:
    """
    Render an example's ``messages`` list into one string via the chat template.

    Args:
        example: A record holding a ``"messages"`` list of role/content dicts.
        tokenizer: Any tokenizer exposing ``apply_chat_template``.

    Returns:
        ``{"text": <rendered conversation>}`` (untokenized).
    """
    # The assistant turn is already part of the conversation, so no extra
    # generation prompt is appended (add_generation_prompt=False).
    rendered = tokenizer.apply_chat_template(
        example["messages"], tokenize=False, add_generation_prompt=False
    )
    return dict(text=rendered)


def tokenize_example(example: dict, tokenizer, max_length: int = 512) -> dict:
    """
    Tokenize ``example["text"]`` and attach causal-LM labels.

    Args:
        example: Record with a ``"text"`` field (output of format_chat_example).
        tokenizer: Callable tokenizer; called with truncation on and no padding.
        max_length: Token budget passed to the tokenizer's truncation.

    Returns:
        The tokenizer's output plus ``"labels"``, a copy of ``"input_ids"``.
        No prompt masking is applied, so every token (user prompt included)
        is a training target.

    NOTE(review): with the tokenizer's default right-side truncation, a text
    longer than ``max_length`` loses its tail — which is where the assistant
    reply sits — so over-long diaries may train without their label. Confirm
    the dataset fits within ``max_length``.
    """
    encoded = tokenizer(
        example["text"], truncation=True, max_length=max_length, padding=False
    )
    # Targets are an independent copy so later edits to one list don't affect the other.
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded


def create_model_and_tokenizer(model_name: str = "Qwen/Qwen2.5-3B-Instruct"):
    """
    Build the tokenizer and a LoRA-wrapped causal LM ready for fine-tuning.

    The backend is chosen in priority order Apple MPS -> CUDA -> CPU, and the
    weight dtype follows it (float16, bfloat16, float32 respectively).

    Args:
        model_name: Hugging Face model ID (or local path) for ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the result of
        ``get_peft_model`` applied to the loaded base model.
    """
    print(f"Loading model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Batch padding needs a pad token; reuse EOS when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if torch.backends.mps.is_available():
        print("Using Apple MPS (Metal) backend")
        placement, weight_dtype = {"": "mps"}, torch.float16
    elif torch.cuda.is_available():
        print("Using CUDA backend")
        placement, weight_dtype = "auto", torch.bfloat16
    else:
        print("Using CPU backend (this will be slow)")
        placement, weight_dtype = {"": "cpu"}, torch.float32

    # NOTE(review): trust_remote_code runs code shipped in the model repo;
    # it may be unnecessary for architectures transformers supports natively.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=weight_dtype,
        device_map=placement,
        trust_remote_code=True,
    )

    print("Applying LoRA configuration...")
    # Target the attention (q/k/v/o) and MLP (gate/up/down) projections by module name.
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]
    peft_model = get_peft_model(
        base_model,
        LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=attention_projections + mlp_projections,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        ),
    )
    peft_model.print_trainable_parameters()

    return peft_model, tokenizer


def _adapter_size_mb(output_dir: str) -> float:
    """Return the combined size, in MB, of the ``adapter_*`` files in ``output_dir``."""
    # output_dir also receives the tokenizer files (tokenizer.save_pretrained)
    # and Trainer checkpoint folders, so only the adapter files are counted.
    adapter_files = (f for f in Path(output_dir).glob("adapter_*") if f.is_file())
    return sum(f.stat().st_size for f in adapter_files) / 1024 / 1024


def train(
    dataset_paths: list[str],
    output_dir: str = "outputs/lora-adapter",
    model_name: str = "Qwen/Qwen2.5-3B-Instruct",
    num_epochs: int = 3,
    batch_size: int = 2,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 2e-4,
    max_length: int = 512,
    val_split: float = 0.1,
):
    """
    Main training function: load data, fine-tune with LoRA, save the adapter.

    Args:
        dataset_paths: List of paths to JSONL training data files (or directories)
        output_dir: Where to save the LoRA adapter (and the tokenizer)
        model_name: HuggingFace model ID
        num_epochs: Number of training epochs
        batch_size: Per-device batch size
        gradient_accumulation_steps: Accumulate gradients over N steps
        learning_rate: Learning rate for AdamW optimizer
        max_length: Maximum sequence length
        val_split: Fraction of data to use for validation (0 disables validation)

    Returns:
        The ``(model, tokenizer)`` pair after training.
    """
    print("=" * 60)
    print("LoRA Fine-Tuning")
    print("=" * 60)

    # Load model and tokenizer
    model, tokenizer = create_model_and_tokenizer(model_name)

    # Load and process dataset
    print("\nLoading dataset(s):")
    dataset = load_dataset_from_jsonl(dataset_paths)
    print(f"  Total examples: {len(dataset)}")

    # Format with chat template
    print("Applying chat template...")
    dataset = dataset.map(
        lambda x: format_chat_example(x, tokenizer),
        desc="Formatting",
    )

    # Tokenize; drop the original columns so only tokenizer outputs + labels remain
    print("Tokenizing...")
    dataset = dataset.map(
        lambda x: tokenize_example(x, tokenizer, max_length),
        remove_columns=dataset.column_names,
        desc="Tokenizing",
    )

    # Split into train/validation (fixed seed -> reproducible split)
    if val_split > 0:
        split = dataset.train_test_split(test_size=val_split, seed=42)
        train_dataset = split["train"]
        eval_dataset = split["test"]
        print(f"  Train examples: {len(train_dataset)}")
        print(f"  Validation examples: {len(eval_dataset)}")
    else:
        train_dataset = dataset
        eval_dataset = None
        print(f"  Train examples: {len(train_dataset)}")

    # Evaluation/checkpoint selection only makes sense with a non-empty eval set
    has_eval = eval_dataset is not None and len(eval_dataset) > 0

    # Pads each batch dynamically (input_ids and labels) to its longest example
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        return_tensors="pt",
    )

    # Same backend priority as create_model_and_tokenizer: MPS, then CUDA, then CPU
    use_mps = torch.backends.mps.is_available()

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        logging_steps=10,
        save_strategy="epoch",
        # eval and save both run per epoch, as load_best_model_at_end requires
        eval_strategy="epoch" if has_eval else "no",
        load_best_model_at_end=has_eval,
        metric_for_best_model="eval_loss" if has_eval else None,
        greater_is_better=False,  # lower eval_loss is better
        fp16=use_mps,  # Use fp16 on MPS
        bf16=not use_mps and torch.cuda.is_available(),  # Use bf16 on CUDA
        dataloader_pin_memory=not use_mps,  # Disable on MPS
        report_to="none",  # Disable wandb/tensorboard
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )

    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)
    trainer.train()

    # Save the LoRA adapter (plus the tokenizer, so the directory is self-contained)
    print(f"\nSaving adapter to: {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    print("\n" + "=" * 60)
    print("Training complete!")
    print("=" * 60)
    print(f"\nAdapter saved to: {output_dir}")
    print(f"Adapter size: {_adapter_size_mb(output_dir):.1f} MB")

    return model, tokenizer


def test_model(model, tokenizer, test_diary: str):
    """
    Run greedy generation on one diary entry and extract a score from the reply.

    Args:
        model: Causal LM with ``generate``, ``eval`` and ``device`` (e.g. the
            model returned by ``train``). It is switched to eval mode.
        tokenizer: Tokenizer with a chat template and ``pad_token_id``.
        test_diary: Free-text diary entry inserted into the user prompt.

    Returns:
        ``(score, response)``: ``response`` is the decoded model reply only
        (the prompt is excluded); ``score`` is the last digit character in
        that reply, or ``None`` if the reply contains no digit.
    """
    messages = [
        {
            "role": "user",
            "content": f"Diary: {test_diary}\n\nWhat is the disease activity score for today?",
        }
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # Turn off dropout so greedy decoding is deterministic; the model may still
    # be in training mode when this is called right after training.
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + continuation; decode only the new tokens so the
    # digit search cannot pick up digits from the diary text in the prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Score = last digit character the model produced
    score = next((ch for ch in reversed(response) if ch.isdigit()), None)

    return score, response


# pytest collects module-level callables named ``test_*``; this is an inference
# helper, not a test, so opt it out of collection.
test_model.__test__ = False


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune Qwen2.5 with LoRA")
    parser.add_argument(
        "--dataset",
        type=str,
        nargs="+",
        required=True,
        help="Path(s) to training dataset(s) (JSONL). Multiple files are concatenated.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="outputs/lora-adapter",
        help="Output directory for the adapter",
    )
    # Defaults below mirror train()'s defaults, so omitting a flag keeps the
    # previous behavior.
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="HuggingFace model ID or local path",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Per-device batch size",
    )
    parser.add_argument(
        "--grad-accum",
        type=int,
        default=4,
        help="Gradient accumulation steps",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=512,
        help="Maximum sequence length in tokens",
    )
    parser.add_argument(
        "--val-split",
        type=float,
        default=0.1,
        help="Fraction of data held out for validation (0 disables validation)",
    )

    args = parser.parse_args()

    # Train
    model, tokenizer = train(
        dataset_paths=args.dataset,
        output_dir=args.output,
        model_name=args.model,
        num_epochs=args.epochs,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_length=args.max_length,
        val_split=args.val_split,
    )

    # Quick test
    print("\n" + "=" * 60)
    print("Testing the fine-tuned model...")
    print("=" * 60)

    test_diaries = [
        "I felt fine today, no pain at all. Went for a walk and felt great.",
        "Severe pain in my joints all day. Had to stay in bed. Medication didn't help much.",
        "Some stiffness this morning but it went away. Managed to work from home.",
    ]

    for diary in test_diaries:
        score, _ = test_model(model, tokenizer, diary)
        # Only add an ellipsis when the preview actually cuts the text
        preview = diary if len(diary) <= 60 else f"{diary[:60]}..."
        print(f"\nDiary: {preview}")
        print(f"Predicted score: {score}")