Andreas Varvarigos committed on
Commit
91e5195
·
verified ·
1 Parent(s): 3feb29a

Delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -385
train.py DELETED
@@ -1,385 +0,0 @@
1
- import json
2
- import torch
3
- import random
4
- import transformers
5
- import networkx as nx
6
- from tqdm import tqdm
7
- from peft import (LoraConfig, get_peft_model,
8
- prepare_model_for_kbit_training)
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
10
-
11
-
12
-
13
class QloraTrainer_CS:
    """QLoRA trainer that fine-tunes a causal LM on citation-graph prompts.

    Loads an Alpaca-style prompt template at construction time; the model,
    tokenizer, and trainer are populated later by load_base_model()/train().
    """

    def __init__(self, config: dict, use_predefined_graph=False):
        """Store the run configuration and read the Alpaca prompt template."""
        self.config = config
        self.use_predefined_graph = use_predefined_graph

        # Filled in later by load_base_model() / train(); declared here so the
        # object's full lifecycle is visible in one place.
        self.tokenizer = None
        self.base_model = None
        self.adapter_model = None
        self.merged_model = None
        self.transformer_trainer = None
        self.test_data = None

        # Prompt template providing a "prompt_input" format string with
        # {instruction} and {input} placeholders.
        with open('configs/alpaca.json') as handle:
            self.template = json.load(handle)
27
-
28
-
29
- def load_base_model(self):
30
- model_id = self.config['inference']["base_model"]
31
- print(model_id)
32
-
33
- bnb_config = BitsAndBytesConfig(
34
- load_in_8bit=True,
35
- bnb_8bit_use_double_quant=True,
36
- bnb_8bit_quant_type="nf8",
37
- bnb_8bit_compute_dtype=torch.bfloat16
38
- )
39
- tokenizer = AutoTokenizer.from_pretrained(model_id)
40
- tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
41
- if not tokenizer.pad_token:
42
- tokenizer.pad_token = tokenizer.eos_token
43
- model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16)
44
- if model.device.type != 'cuda':
45
- model.to('cuda')
46
-
47
- model.gradient_checkpointing_enable()
48
- model = prepare_model_for_kbit_training(model)
49
-
50
- self.tokenizer = tokenizer
51
- self.base_model = model
52
-
53
-
54
- def train(self):
55
- # Set up lora config or load pre-trained adapter
56
- lora_config = LoraConfig(
57
- r=self.config['training']['qlora']['rank'],
58
- lora_alpha=self.config['training']['qlora']['lora_alpha'],
59
- target_modules=self.config['training']['qlora']['target_modules'],
60
- lora_dropout=self.config['training']['qlora']['lora_dropout'],
61
- bias="none",
62
- task_type="CAUSAL_LM",
63
- )
64
- model = get_peft_model(self.base_model, lora_config)
65
- self._print_trainable_parameters(model)
66
-
67
- print("Start data preprocessing")
68
- train_data = self._process_data_instruction()
69
-
70
- print('Length of dataset: ', len(train_data))
71
-
72
- print("Start training")
73
- self.transformer_trainer = transformers.Trainer(
74
- model=model,
75
- train_dataset=train_data,
76
- args=transformers.TrainingArguments(
77
- per_device_train_batch_size=self.config["training"]['trainer_args']["per_device_train_batch_size"],
78
- gradient_accumulation_steps=self.config['model_saving']['index'],
79
- warmup_steps=self.config["training"]['trainer_args']["warmup_steps"],
80
- num_train_epochs=self.config["training"]['trainer_args']["num_train_epochs"],
81
- learning_rate=self.config["training"]['trainer_args']["learning_rate"],
82
- lr_scheduler_type=self.config["training"]['trainer_args']["lr_scheduler_type"],
83
- fp16=self.config["training"]['trainer_args']["fp16"],
84
- logging_steps=self.config["training"]['trainer_args']["logging_steps"],
85
- output_dir=self.config["training"]['trainer_args']["trainer_output_dir"],
86
- report_to="wandb",
87
- save_steps=self.config["training"]['trainer_args']["save_steps"],
88
- ),
89
- data_collator=transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
90
- )
91
-
92
- model.config.use_cache = False
93
-
94
- self.transformer_trainer.train()
95
-
96
- model_save_path = f"{self.config['model_saving']['model_output_dir']}/{self.config['model_saving']['model_name']}_{self.config['model_saving']['index']}_adapter_test_graph"
97
- self.transformer_trainer.save_model(model_save_path)
98
-
99
- self.adapter_model = model
100
- print(f"Training complete, adapter model saved in {model_save_path}")
101
-
102
-
103
- def _print_trainable_parameters(self, model):
104
- """
105
- Prints the number of trainable parameters in the model.
106
- """
107
- trainable_params = 0
108
- all_param = 0
109
- for _, param in model.named_parameters():
110
- all_param += param.numel()
111
- if param.requires_grad:
112
- trainable_params += param.numel()
113
- print(
114
- f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
115
- )
116
-
117
-
118
    def _process_data_instruction(self):
        """Build the instruction-tuning dataset from the citation graph.

        Loads a GEXF citation graph, holds out the first 10% of nodes as a test
        split, then renders seven prompt task families over the remaining
        train-only edges (link prediction, abstract->title, paper retrieval,
        citation-sentence generation, abstract completion, title->abstract,
        intro->abstract) and returns the shuffled, tokenizer-processed samples.
        """
        context_window = self.tokenizer.model_max_length
        # Load the graph either from a predefined path under datasets/ or from
        # the configured download directory.
        if self.use_predefined_graph:
            graph_data = nx.read_gexf('datasets/' + self.config["training"]["predefined_graph_path"], node_type=None, relabel=False, version='1.2draft')
        else:
            graph_path = self.config['data_downloading']['download_directory'] + 'description/' + self.config['data_downloading']['gexf_file']
            graph_data = nx.read_gexf(graph_path, node_type=None, relabel=False, version='1.2draft')
        raw_graph = graph_data

        # First 10% of nodes (in GEXF file order) form the held-out test set.
        test_set_size = len(graph_data.nodes()) // 10

        all_test_nodes = set(list(graph_data.nodes())[:test_set_size])
        all_train_nodes = set(list(graph_data.nodes())[test_set_size:])

        # node id -> [title, abstract], train nodes only.
        raw_id_2_title_abs = dict()
        for paper_id in list(graph_data.nodes())[test_set_size:]:
            title = graph_data.nodes()[paper_id]['title']
            abstract = graph_data.nodes()[paper_id]['abstract']
            raw_id_2_title_abs[paper_id] = [title, abstract]

        # node id -> introduction text, for train nodes with a non-empty intro.
        raw_id_2_intro = dict()
        for paper_id in list(graph_data.nodes())[test_set_size:]:
            if graph_data.nodes[paper_id]['introduction'] != '':
                intro = graph_data.nodes[paper_id]['introduction']
                raw_id_2_intro[paper_id] = intro

        # (source, target) edge -> the sentence with which source cites target.
        raw_id_pair_2_sentence = dict()
        for edge in list(graph_data.edges()):
            sentence = graph_data.edges()[edge]['sentence']
            raw_id_pair_2_sentence[edge] = sentence


        # Split edges: any edge touching a test node goes to test_data, the
        # rest are usable for training.
        # NOTE(review): test_data is collected but never stored (e.g. on
        # self.test_data, which __init__ declares) — confirm whether the test
        # split is meant to be exposed.
        test_data = []
        edge_list = []
        for edge in list(raw_graph.edges()):
            src, tar = edge
            if src not in all_test_nodes and tar not in all_test_nodes:
                edge_list.append(edge)
            else:
                test_data.append(edge)
        train_num = int(len(edge_list))

        data_LP = []
        data_abstract_2_title = []
        data_paper_retrieval = []
        data_citation_sentence = []
        data_abs_completion = []
        data_title_2_abs = []
        data_intro_2_abs = []


        # Each task family re-samples all train edges in a fresh random order
        # (random.sample with k == len(edge_list) is a shuffled copy).
        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # LP prompt: one positive pair plus one random negative per edge.
            # NOTE(review): the random negative is not checked against the true
            # neighbor set, so occasional false negatives are possible.
            rand_ind = random.choice(list(raw_id_2_title_abs.keys()))
            neg_title, neg_abs = raw_id_2_title_abs[rand_ind]
            data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'label':'yes'})
            data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':neg_title, 't_abs':neg_abs, 'label':'no'})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # abs_2_title prompt
            data_abstract_2_title.append({'title':source_title, 'abs':source_abs})
            data_abstract_2_title.append({'title':target_title, 'abs':target_abs})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # paper_retrieval prompt: 5 random non-neighbor distractors plus
            # the true cited paper, shuffled together.
            neighbors = list(nx.all_neighbors(raw_graph, source))
            sample_node_list = list(all_train_nodes - set(neighbors) - set([source]) - set([target]))
            sampled_neg_nodes = random.sample(sample_node_list, 5) + [target]
            random.shuffle(sampled_neg_nodes)
            data_paper_retrieval.append({'title':source_title, 'abs':source_abs, 'sample_title': [raw_id_2_title_abs[node][0] for node in sampled_neg_nodes], 'right_title':target_title})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # citation_sentence prompt: the edge sentence, looked up in either
            # direction since edge orientation is not guaranteed.
            citation_sentence = raw_id_pair_2_sentence[(source, target)] if (source, target) in raw_id_pair_2_sentence.keys() else raw_id_pair_2_sentence[(target, source)]
            data_citation_sentence.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'sentence': citation_sentence})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # abs_complete prompt
            data_abs_completion.append({'title':source_title, 'abs':source_abs})
            data_abs_completion.append({'title':target_title, 'abs':target_abs})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            source_title, source_abs = raw_id_2_title_abs[source]
            target_title, target_abs = raw_id_2_title_abs[target]
            # title_2_abs prompt
            data_title_2_abs.append({'title':source_title, 'right_abs':source_abs})
            data_title_2_abs.append({'title':target_title, 'right_abs':target_abs})

        for sample in tqdm(random.sample(edge_list, train_num)):
            source, target = sample[0], sample[1]
            # intro_2_abs prompt: only for endpoints that have an introduction.
            if source in raw_id_2_intro:
                source_intro = raw_id_2_intro[source]
                _, source_abs = raw_id_2_title_abs[source]
                data_intro_2_abs.append({'intro':source_intro, 'abs':source_abs})
            if target in raw_id_2_intro:
                target_intro = raw_id_2_intro[target]
                _, target_abs = raw_id_2_title_abs[target]
                data_intro_2_abs.append({'intro':target_intro, 'abs':target_abs})

        # Render every collected data point through its prompt builder.
        data_prompt = []
        data_prompt += [self._generate_paper_retrieval_prompt(data_point) for data_point in data_paper_retrieval]
        data_prompt += [self._generate_LP_prompt(data_point) for data_point in data_LP]
        data_prompt += [self._generate_abstract_2_title_prompt(data_point) for data_point in data_abstract_2_title]
        data_prompt += [self._generate_citation_sentence_prompt(data_point) for data_point in data_citation_sentence]
        data_prompt += [self._generate_abstract_completion_prompt(data_point) for data_point in data_abs_completion]
        data_prompt += [self._generate_title_2_abstract_prompt(data_point) for data_point in data_title_2_abs]
        data_prompt += [self._generate_intro_2_abstract_prompt(data_point, context_window) for data_point in data_intro_2_abs]

        print("Total prompts:", len(data_prompt))
        random.shuffle(data_prompt)
        if self.tokenizer.chat_template is None:
            data_tokenized = [self.tokenizer(sample, max_length=context_window, truncation=True) for sample in tqdm(data_prompt)]
        else:
            # NOTE(review): with tokenize=False this branch returns rendered
            # chat *strings*, not token encodings, and max_length/truncation
            # have no effect — verify the downstream collator expects this.
            data_tokenized = [self.tokenizer.apply_chat_template(sample, max_length=context_window, truncation=True, tokenize=False) for sample in tqdm(data_prompt)]

        return data_tokenized
250
-
251
-
252
- def _generate_LP_prompt(self, data_point: dict):
253
- instruction = "Determine if paper A will cite paper B."
254
-
255
- prompt_input = ""
256
- prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] != None else 'Unknown') + "\n"
257
- prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] != None else 'Unknown') + "\n"
258
- prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
259
- prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
260
-
261
- if self.tokenizer.chat_template is None:
262
- res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
263
- res = f"{res}{data_point['label']}"
264
- else:
265
- res = [
266
- {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
267
- {"role": "assistant", "content": data_point['label']}
268
- ]
269
-
270
- return res
271
-
272
- def _generate_abstract_2_title_prompt(self, data_point: dict):
273
- instruction = "Please generate the title of paper based on its abstract."
274
-
275
- prompt_input = ""
276
- prompt_input = prompt_input + "Abstract: " + data_point['abs'] + "\n"
277
-
278
- if self.tokenizer.chat_template is None:
279
- res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
280
- res = f"{res}{data_point['title']}"
281
- else:
282
- res = [
283
- {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
284
- {"role": "assistant", "content": data_point['title']}
285
- ]
286
-
287
- return res
288
-
289
- def _generate_paper_retrieval_prompt(self, data_point: dict):
290
- instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."
291
-
292
- prompt_input = ""
293
- prompt_input = prompt_input + "Title of the Paper A: " + data_point['title'] + "\n"
294
- prompt_input = prompt_input + "Abstract of the Paper A: " + data_point['abs'] + "\n"
295
- prompt_input = prompt_input + "candidate papers: " + "\n"
296
- for i in range(len(data_point['sample_title'])):
297
- prompt_input = prompt_input + str(i) + '. ' + data_point['sample_title'][i] + "\n"
298
-
299
- if self.tokenizer.chat_template is None:
300
- res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
301
- res = f"{res}{data_point['right_title']}"
302
- else:
303
- res = [
304
- {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
305
- {"role": "assistant", "content": data_point['right_title']}
306
- ]
307
-
308
- return res
309
-
310
- def _generate_citation_sentence_prompt(self, data_point: dict):
311
- instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section."
312
-
313
- prompt_input = ""
314
- prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] != None else 'Unknown') + "\n"
315
- prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] != None else 'Unknown') + "\n"
316
- prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
317
- prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
318
-
319
- if self.tokenizer.chat_template is None:
320
- res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
321
- res = f"{res}{data_point['sentence']}"
322
- else:
323
- res = [
324
- {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
325
- {"role": "assistant", "content": data_point['sentence']}
326
- ]
327
-
328
- return res
329
-
330
- def _generate_abstract_completion_prompt(self, data_point: dict):
331
- instruction = "Please complete the abstract of a paper."
332
-
333
- prompt_input = ""
334
- prompt_input = prompt_input + "Title: " + data_point['title'] if data_point['title'] != None else 'Unknown' + "\n"
335
-
336
- split_abs = data_point['abs'][: int(0.3*len(data_point['abs']))]
337
- prompt_input = prompt_input + "Part of abstract: " + split_abs + "\n"
338
-
339
- if self.tokenizer.chat_template is None:
340
- res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
341
- res = f"{res}{data_point['abs']}"
342
- else:
343
- res = [
344
- {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
345
- {"role": "assistant", "content": data_point['abs']}
346
- ]
347
-
348
- return res
349
-
350
- def _generate_title_2_abstract_prompt(self, data_point: dict):
351
- instruction = "Please generate the abstract of paper based on its title."
352
-
353
- prompt_input = ""
354
- prompt_input = prompt_input + "Title: " + data_point['title'] + "\n"
355
-
356
- if self.tokenizer.chat_template is None:
357
- res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
358
- res = f"{res}{data_point['right_abs']}"
359
- else:
360
- res = [
361
- {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
362
- {"role": "assistant", "content": data_point['right_abs']}
363
- ]
364
-
365
- return res
366
-
367
- def _generate_intro_2_abstract_prompt(self, data_point: dict, context_window):
368
- instruction = "Please generate the abstract of paper based on its introduction section."
369
-
370
- prompt_input = ""
371
- prompt_input = prompt_input + "Introduction: " + data_point['intro'] + "\n"
372
-
373
- # Reduce it to make it fit
374
- prompt_input = prompt_input[:int(context_window*2)]
375
-
376
- if self.tokenizer.chat_template is None:
377
- res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
378
- res = f"{res}{data_point['abs']}"
379
- else:
380
- res = [
381
- {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
382
- {"role": "assistant", "content": data_point['abs']}
383
- ]
384
-
385
- return res