File size: 3,916 Bytes
09d1245
00ae6eb
09d1245
00ae6eb
 
 
 
 
 
 
 
 
 
 
 
 
 
09d1245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00ae6eb
 
 
 
 
 
 
 
 
 
 
 
 
09d1245
 
 
 
 
00ae6eb
 
 
 
 
 
 
 
 
09d1245
 
 
 
 
 
 
 
00ae6eb
09d1245
00ae6eb
 
 
 
09d1245
 
 
00ae6eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09d1245
00ae6eb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# src/train.py
import os
from typing import Dict, List, Optional

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, TaskType, get_peft_model


def _format_as_chat(tokenizer, ex: Dict) -> str:
    system = (ex.get("system") or "").strip()
    user = (ex.get("user") or "").strip()
    assistant = (ex.get("assistant") or "").strip()

    # Preferred: model-native chat template (Llama/Qwen/Mistral Instruct, etc.)
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        messages: List[Dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    # Fallback: simple transcript
    parts = []
    if system:
        parts.append(f"### System:\n{system}")
    parts.append(f"### User:\n{user}")
    parts.append(f"### Assistant:\n{assistant}")
    return "\n\n".join(parts)


def finetune_lora(
    base_model: str,
    dataset_id: str,
    output_dir: str,
    max_train_samples: int = 2000,
    max_steps: int = 100,
    learning_rate: float = 2e-4,
    batch_size: int = 2,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    target_modules: Optional[List[str]] = None,
    max_length: int = 1024,
) -> str:
    """Fine-tune ``base_model`` with LoRA on a system/user/assistant dataset.

    Loads the ``train`` split of ``dataset_id``, renders each row with
    ``_format_as_chat``, tokenizes it, trains a LoRA adapter with the Hugging
    Face ``Trainer`` and saves adapter + tokenizer to ``<output_dir>/adapter``.

    Args:
        base_model: Model id or path for ``AutoModelForCausalLM`` / ``AutoTokenizer``.
        dataset_id: Dataset id or path for ``load_dataset``; its ``train`` split
            must contain ``system``, ``user`` and ``assistant`` columns.
        output_dir: Trainer output directory; the adapter is saved under
            ``adapter/`` inside it.
        max_train_samples: Keep only the first N rows; ``0``, ``None`` or a
            negative value uses the whole split.
        max_steps: Number of training steps.
        learning_rate: Optimizer learning rate.
        batch_size: Per-device training batch size.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: Dropout used inside the LoRA layers.
        target_modules: Module names to wrap with LoRA. ``None`` (or empty)
            keeps the GPT-2 default ``["c_attn", "c_proj"]``; Llama-style models
            typically need ``["q_proj", "k_proj", "v_proj", "o_proj"]``.
        max_length: Maximum tokens per example; longer examples are truncated.

    Returns:
        A status message: the adapter directory on success, or a string that
        starts with ``"ERROR:"`` when required columns are missing (no
        exception is raised in that case).
    """
    ds = load_dataset(dataset_id, split="train")

    needed = {"system", "user", "assistant"}
    missing = needed.difference(ds.column_names)
    if missing:
        return f"ERROR: dataset missing columns {sorted(missing)}. Found: {ds.column_names}"

    if max_train_samples and max_train_samples > 0:
        ds = ds.select(range(min(len(ds), int(max_train_samples))))

    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    if tokenizer.pad_token is None:
        # Many causal-LM tokenizers (e.g. GPT-2) have no pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Same condition _format_as_chat uses to choose the model-native template.
    # Text rendered by a chat template generally already carries the model's
    # special tokens (e.g. BOS), so only let the tokenizer add them for the
    # plain-transcript fallback; otherwise they would be duplicated.
    uses_chat_template = bool(
        hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None)
    )

    def batched_map(batch):
        # ds.map(batched=True) passes a dict-of-lists; rebuild one dict per row.
        rows = [dict(zip(batch.keys(), vals)) for vals in zip(*batch.values())]
        texts = [_format_as_chat(tokenizer, row) for row in rows]
        return tokenizer(
            texts,
            truncation=True,
            max_length=int(max_length),
            add_special_tokens=not uses_chat_template,
        )

    tokenized = ds.map(batched_map, batched=True, remove_columns=ds.column_names)

    model = AutoModelForCausalLM.from_pretrained(base_model)
    model.config.pad_token_id = tokenizer.pad_token_id

    # target_modules depends on the architecture: GPT-2 uses c_attn/c_proj,
    # Llama uses q_proj/k_proj/v_proj/o_proj, Qwen varies. Default stays GPT-2.
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=int(lora_r),
        lora_alpha=int(lora_alpha),
        lora_dropout=float(lora_dropout),
        bias="none",
        target_modules=list(target_modules) if target_modules else ["c_attn", "c_proj"],
    )
    model = get_peft_model(model, lora_cfg)

    # NOTE(review): with mlm=False this collator masks (-100) every label equal to
    # pad_token_id. Because pad may have been set to eos above, real EOS tokens are
    # then excluded from the loss too — confirm this is acceptable if the tuned
    # model needs to learn when to stop.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # fp16 mixed precision only when a CUDA device is available.
    fp16 = torch.cuda.is_available()
    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=int(batch_size),
        learning_rate=float(learning_rate),
        max_steps=int(max_steps),
        logging_steps=10,
        # Only the final adapter below is wanted: disable intermediate checkpoints
        # explicitly instead of relying on how save_steps=0 is interpreted.
        save_strategy="no",
        report_to=[],
        fp16=fp16,
    )

    trainer = Trainer(model=model, args=args, train_dataset=tokenized, data_collator=collator)
    trainer.train()

    adapter_dir = os.path.join(output_dir, "adapter")
    model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)

    return f"Saved LoRA adapter + tokenizer to {adapter_dir}"