File size: 11,154 Bytes
a0d95b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf, DictConfig
import torch
import yaml
from dataclasses import asdict

from datasets import load_dataset

import os
import transformers
from transformers import (AutoModelForCausalLM, AutoTokenizer, 
                          LlamaTokenizer, AutoModel, AutoConfig, 
                          TrainingArguments)
import inspect

import random
import numpy as np

# from XS_llama import IbaXs_LlamaModel, IbaXs_LlamaForCausalLM
# from utils import count_parameters
# from .configIBA import MainConfig
from iba import IbaXs_LlamaModel, IbaXs_LlamaForCausalLM, count_parameters, MainConfig

from transformers.models.llama.modeling_llama import (
    LlamaMLP,
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaModel,
    LlamaForCausalLM
)

# Obtain the ConfigStore instance so a structured schema can be registered.
cs = ConfigStore.instance()

# Register MainConfig in the store under the name "main_schema".
# NOTE(review): the @hydra.main decorator in this file loads
# config_name="config", not "main_schema" — presumably the YAML config refers
# to this schema by name (e.g. in its defaults list); confirm.
cs.store(name="main_schema", node=MainConfig)

# Device string passed to every `.to(DEVICE)` call in this file.
DEVICE = 'cuda'

def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Args:
        seed: Value handed unchanged to Python's ``random``, NumPy, PyTorch
            (default and all-CUDA-device generators) and ``transformers``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    )
    # Same call order as before; each seeder receives the identical value.
    for apply_seed in seeders:
        apply_seed(seed)

def test_generate(config, main_cfg):
    """Smoke-test text generation with a model built from ``config``.

    Loads the tokenizer named by ``main_cfg.model.base_model_name``,
    instantiates ``IbaXs_LlamaForCausalLM`` from ``config`` on ``DEVICE``,
    and prints a short sampled continuation for each hard-coded prompt.

    Args:
        config: Model config; ``config.model_type`` selects the tokenizer class.
        main_cfg: Project config; only ``model.base_model_name`` is read here.
    """
    model_name = main_cfg.model.base_model_name

    # Non-llama models use AutoTokenizer with remote code enabled. For llama,
    # names containing "lama-3" use AutoTokenizer, all others LlamaTokenizer.
    if config.model_type != 'llama':
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    elif "lama-3" in model_name:
        print("load llama-3 tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    else:
        tokenizer = LlamaTokenizer.from_pretrained(model_name, legacy=True)

    model = IbaXs_LlamaForCausalLM(config=config).to(DEVICE)
    model.eval()

    prompts = [
        "The capital of France is",
        # Disabled: "Here is a simple Python function to add two numbers:"
    ]

    # Sampling settings shared by every prompt; at most 4 new tokens each.
    generation_kwargs = dict(
        max_new_tokens=4,
        do_sample=True,
        temperature=0.7,
        top_k=50,
    )

    for prompt_number, prompt in enumerate(prompts, start=1):
        print(f"\n--- Prompt {prompt_number} ---")
        print(f"Input: {prompt}")

        # Tokenize the prompt into PyTorch tensors placed on DEVICE.
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

        # Inference only, so gradients are disabled.
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        # The generated sequence starts with the prompt tokens; drop them.
        prompt_length = inputs["input_ids"].shape[1]
        continuation_ids = outputs[0][prompt_length:]
        generated_text = tokenizer.decode(continuation_ids, skip_special_tokens=True)

        print(f"Output: {generated_text}")

def trainIBA(config, main_cfg):
    """Fine-tune an ``IbaXs_LlamaForCausalLM`` on an instruction dataset.

    Builds ``TrainingArguments`` from ``main_cfg.training``, instantiates the
    model from ``config``, tokenizes the dataset found at
    ``main_cfg.data.data_path`` using the ``generate_prompt`` template, and
    runs ``transformers.Trainer``.

    Args:
        config: Model config passed to ``IbaXs_LlamaForCausalLM``; its
            ``model_type`` selects the tokenizer class.
        main_cfg: Project config. This function reads ``training``
            (``batch_size``, ``per_device_train_batch_size``, ``cutoff_len``,
            ``train_on_inputs``, ``resume_from_checkpoint``, ``val_set_size``
            plus any key that is a ``TrainingArguments`` parameter),
            ``data.data_path`` and ``model.base_model_name``.
    """
    training_cfg = main_cfg.training
    data_cfg = main_cfg.data

    # The training config mixes TrainingArguments fields with settings used
    # only in this function, so forward just the keys TrainingArguments accepts.
    valid_hf_arg_names = set(inspect.signature(TrainingArguments).parameters)
    training_config_dict = OmegaConf.to_container(training_cfg, resolve=True)
    filtered_training_args_dict = {
        key: value
        for key, value in training_config_dict.items()
        if key in valid_hf_arg_names
    }
    trainer_args = TrainingArguments(**filtered_training_args_dict)

    # NOTE(review): this value is computed but never handed to
    # TrainingArguments, so it currently has no effect on training — confirm
    # whether it should be passed as `gradient_accumulation_steps`.
    gradient_accumulation_steps = training_cfg.batch_size // training_cfg.per_device_train_batch_size
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size != 1:
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    base_model_name = main_cfg.model.base_model_name

    # The model is instantiated from `config` rather than through
    # `from_pretrained`; the original author marked this line as a test setup.
    model = IbaXs_LlamaForCausalLM(config=config).to(DEVICE)

    # Non-llama models use AutoTokenizer with remote code enabled. For llama,
    # names containing "lama-3" use AutoTokenizer, all others LlamaTokenizer.
    if config.model_type != 'llama':
        tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    elif "lama-3" in base_model_name:
        print("load llama-3 tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    else:
        tokenizer = LlamaTokenizer.from_pretrained(base_model_name, legacy=True)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, max_length=training_cfg.cutoff_len, add_eos_token=True):
        """Tokenize ``prompt`` and attach ``labels`` copied from ``input_ids``.

        The text is truncated to ``max_length`` tokens. When ``add_eos_token``
        is set and there is room, an EOS token is appended unless the sequence
        already ends with one.
        """
        result = tokenizer(
            prompt,
            truncation=True,
            # Honour the parameter; it was previously ignored in favour of
            # training_cfg.cutoff_len (which is also its default).
            max_length=max_length,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        needs_eos = (
            add_eos_token
            and len(input_ids) < max_length
            # Guard the [-1] lookup so an empty encoding cannot raise.
            and (not input_ids or input_ids[-1] != tokenizer.eos_token_id)
        )
        if needs_eos:
            input_ids.append(tokenizer.eos_token_id)
            if "chatglm" not in base_model_name:
                result["attention_mask"].append(1)

        result["labels"] = input_ids.copy()

        # For chatglm models only input_ids and labels are returned.
        if "chatglm" in base_model_name:
            return {"input_ids": result["input_ids"], "labels": result["labels"]}
        return result

    def generate_and_tokenize_prompt(data_point):
        """Render one record with ``generate_prompt`` and tokenize it.

        When ``training_cfg.train_on_inputs`` is false, the label positions
        covering the prompt part (the record rendered with an empty output)
        are replaced by -100.
        """
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)

        if not training_cfg.train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = (
                [-100] * user_prompt_len
                + tokenized_full_prompt["labels"][user_prompt_len:]
            )
        return tokenized_full_prompt

    # A path ending in ".json" is loaded as a local JSON file; anything else
    # is passed to load_dataset as-is.
    if data_cfg.data_path.endswith(".json"):
        data = load_dataset("json", data_files=data_cfg.data_path)
    else:
        data = load_dataset(data_cfg.data_path)

    # Read the configured checkpoint location into a local first. The previous
    # code read a function-local `resume_from_checkpoint` before assigning it,
    # which raised UnboundLocalError whenever resuming was requested.
    checkpoint_dir = training_cfg.resume_from_checkpoint
    if checkpoint_dir:
        # Look for a full checkpoint first, then for adapter-only weights.
        checkpoint_name = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(checkpoint_dir, "adapter_model.bin")
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            # NOTE(review): this loads from a hard-coded "./my-saved-model"
            # path with IbaXs_LlamaModel instead of using `checkpoint_name`;
            # the original author marked this section "Check later". Confirm
            # the intended source and model class before relying on it.
            model = IbaXs_LlamaModel.from_pretrained("./my-saved-model")
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    if training_cfg.val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=training_cfg.val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].map(generate_and_tokenize_prompt, num_proc=8)
        val_data = train_val["test"].map(generate_and_tokenize_prompt)
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt, num_proc=8)
        val_data = None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=trainer_args,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    # The Trainer receives the configured value unchanged, as before.
    trainer.train(resume_from_checkpoint=training_cfg.resume_from_checkpoint)



def generate_prompt(data_point):
    """Render the instruction prompt for one dataset record.

    Args:
        data_point: Mapping with "instruction", "input" and "output" keys.

    Returns:
        The prompt text. A record whose "input" is falsy gets the shorter
        template without the "### Input:" section.

    NOTE: the continuation lines of both templates are indented in the source,
    and that indentation (like the trailing spaces after the first sentence)
    is part of the returned string.
    """
    context = data_point["input"]
    instruction = data_point["instruction"]
    response = data_point["output"]

    if not context:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  

                ### Instruction:
                {instruction}
                
                ### Response:
                {response}"""  # noqa: E501

    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 

                ### Instruction:
                {instruction}
                
                ### Input:
                {context}
                
                ### Response:
                {response}"""  # noqa: E501
    
@hydra.main(config_path="../conf_hydra", config_name="config", version_base='1.3')
def main(main_cfg: MainConfig):
    """Hydra entry point: prepare the model config, seed RNGs and train.

    Args:
        main_cfg: Config composed by Hydra; it is merged onto the
            ``MainConfig`` structured schema before use.
    """
    schema = OmegaConf.structured(MainConfig)
    main_cfg = OmegaConf.merge(schema, main_cfg)
    main_cfg_dict = OmegaConf.to_container(main_cfg, resolve=True)

    config = AutoConfig.from_pretrained(main_cfg.model.base_model_name)

    # Replace the loaded architecture sizes with small fixed values.
    size_overrides = (
        ("hidden_size", 128),
        ("intermediate_size", 290),
        ("num_hidden_layers", 3),
    )
    for attr_name, attr_value in size_overrides:
        setattr(config, attr_name, attr_value)
    # Keep head_dim consistent with the overridden hidden size.
    config.head_dim = config.hidden_size // config.num_attention_heads

    # Attach the resolved project config (as a plain container) to the
    # model config before it is handed to the model.
    config.main_cfg = main_cfg_dict

    set_seed(main_cfg.seed)
    trainIBA(config, main_cfg)

    

# Invoke the Hydra-decorated entry point only when run as a script.
if __name__ == "__main__":
    main()