import torch
from datasets import load_dataset
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model
from accelerate import Accelerator
from torch.utils.data import DataLoader

def train():
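    # The effective batch size per optimizer step is BATCH_SIZE; gradients are
    # accumulated over BATCH_SIZE // MICRO_BATCH_SIZE micro-batches to reach it.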
    MICRO_BATCH_SIZE = 1
    BATCH_SIZE = 16
    GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
    EPOCHS = 2
    LEARNING_RATE = 2e-10
    LORA_R = 4
    LORA_ALPHA = 8
    LORA_DROPOUT = 0.05
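    CUTOFF_LEN = 256  # assumed maximum prompt length in tokens, used for truncation/padding below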

    accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)

    # Load the base model in 8-bit so prepare_model_for_int8_training below has
    # an int8 model to work with; each process keeps its copy on its own device.
    model = LlamaForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",
        load_in_8bit=True,
        device_map={"": accelerator.process_index},
    )
    tokenizer = LlamaTokenizer.from_pretrained(
        "decapoda-research/llama-7b-hf", add_eos_token=True
    )

    # Freezes the base weights and casts selected layers (layer norms, lm_head)
    # to fp32 so int8 training stays numerically stable.
    model = prepare_model_for_int8_training(model)

    # Inject LoRA adapters only into the attention query and value projections;
    # everything else stays frozen.
    config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    # LLaMA's tokenizer defines no pad token; use id 0 for padding.
    tokenizer.pad_token_id = 0
    data = load_dataset("json", data_files="samples.json")
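    # samples.json is expected to contain records with "instruction", "input"
    # and "output" fields, which generate_prompt below turns into one prompt.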

    def generate_prompt(data_point):
        if data_point["input"]:
            prompt = f"""### Instruction:
{data_point["instruction"]}
### Input:
{data_point["input"]}
### Response:
{data_point["output"]}"""
        else:
            prompt = f"""### Instruction:
{data_point["instruction"]}
### Response:
{data_point["output"]}"""

        # Tokenize the full prompt once; for causal-LM fine-tuning the labels
        # are the input ids, with padding positions masked out of the loss.
        tokens = tokenizer(
            prompt,
            truncation=True,
            max_length=CUTOFF_LEN,
            padding="max_length",
        )
        tokens["labels"] = [
            -100 if token == tokenizer.pad_token_id else token
            for token in tokens["input_ids"]
        ]
        return tokens

    # Drop the raw text columns so only token ids reach the data collator.
    data = data.shuffle().map(
        generate_prompt, remove_columns=data["train"].column_names
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    model, optimizer = accelerator.prepare(model, optimizer)

    train_dataloader = DataLoader(
        data["train"],
        batch_size=MICRO_BATCH_SIZE,
        shuffle=True,
        collate_fn=transformers.default_data_collator,
    )
    train_dataloader = accelerator.prepare(train_dataloader)

    model.train()
    for epoch in range(EPOCHS):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                # The model computes the causal-LM loss itself when labels are
                # passed in; accelerate only performs the actual optimizer step
                # once every GRADIENT_ACCUMULATION_STEPS micro-batches.
                outputs = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                )
                loss = outputs.loss

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

    # Save only the LoRA adapter weights, once, from the main process.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(model).save_pretrained("lora-smartscraper")

if __name__ == "__main__":
    train()
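
# A minimal sketch of how the script could be launched; the filename is a
# placeholder for whatever this file is actually called:
#   accelerate launch finetune_lora.py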