Andreas Varvarigos committed on
Commit
34dc5af
·
verified ·
1 Parent(s): 908351f

Delete src/no_UI

src/no_UI/README_eval.md DELETED
@@ -1,54 +0,0 @@
# **Evaluation for Literature-Based Tasks (Without User Interface)**

## **Overview**
`eval_noUI.py` is a script designed to **evaluate the performance of LLMs** on literature-related tasks, including **citation sentence generation, link prediction, abstract completion, title generation, paper retrieval, and introduction-to-abstract generation**. It provides a **batch evaluation pipeline** for models trained on **LitBench** datasets or other domain-specific literature datasets.

It loads a **citation graph dataset** and constructs evaluation prompts for each task. The script then uses the specified LLM to generate predictions and compares them against ground-truth outputs using **BERTScore and accuracy metrics**.

---

## **Usage**
To run `eval_noUI.py`, execute the following command:

```bash
CUDA_VISIBLE_DEVICES=0 python3.10 src/no_UI/eval_noUI.py \
    -config_path=configs/config_noUI.yaml \
    -model=lora \
    -lorapath=models/llama_1b_qlora_uncensored_1_adapter_test_graph
```

## **Command-Line Arguments**

- `config_path`: Path to the evaluation configuration file.
- `model`: Model type (e.g., `lora`).
- `lorapath`: Path to the LoRA adapter checkpoint to merge into the base model.
- `prompt_num`: Optional numeric index (default: 1; not currently used by the script).

---

## **Supported Evaluation Tasks**
The script evaluates model performance across six key literature-based tasks:

1. **Citation Sentence Generation** (`test_sentence`)
   * Generates a citation sentence describing how Paper A cites Paper B in the related work section.
   * Evaluates output coherence using BERTScore.
2. **Citation Link Prediction** (`test_LP`)
   * Determines whether Paper A is likely to cite Paper B based on their titles and abstracts.
   * Evaluates binary classification accuracy.
3. **Abstract Completion** (`test_abs_completion`)
   * Completes a partially given abstract.
   * Evaluates precision, recall, and F1 using BERTScore.
4. **Title Generation** (`test_title_generate`)
   * Predicts a paper's title based on its abstract.
   * Evaluates BERTScore similarity with the ground-truth title.
5. **Citation Recommendation** (`test_retrival_e`)
   * Given a paper and a set of candidate papers, selects the one most likely to be cited.
   * Evaluates retrieval accuracy.
6. **Introduction to Abstract** (`test_intro_2_abs`)
   * Predicts a paper's abstract based on its introduction section.
   * Evaluates BERTScore similarity.

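For the accuracy-based tasks, the evaluator only checks whether the first few characters of the model's reply contain "yes" or "no". A minimal sketch of that scoring logic (`parse_yes_no` is an illustrative name, not a function in the script):

```python
def parse_yes_no(generated: str) -> int:
    """Score 1 if the first few characters of the model reply say 'yes'."""
    head = generated.strip()[:4].lower()
    return 1 if "yes" in head else 0

# toy replies: two true-citation pairs (expect "yes") and one negative pair
answers = ["Yes, paper A is likely to cite paper B.", "yes", "No."]
labels = [1, 1, 0]
correct = [int(parse_yes_no(a) == lab) for a, lab in zip(answers, labels)]
accuracy = sum(correct) / len(correct)
print(accuracy)  # → 1.0
```

Accuracy is then just the mean of these per-pair scores over positive and negative pairs.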
---

## Dependencies

Install the required Python libraries by following the instructions in [README.md](../../README.md).
src/no_UI/README_finetune.md DELETED
@@ -1,40 +0,0 @@
# **Fine-Tuning for Literature-Based LLMs (Without User Interface)**

## **Overview**
`finetune_noUI.py` is a script designed to **fine-tune large language models (LLMs) on literature-based tasks** using **QLoRA**. It supports **training domain-specific models** for citation reasoning, abstract generation, retrieval, and more, and leverages **LoRA adapters** to enable efficient fine-tuning on consumer-grade GPUs.

The script reads a **citation graph dataset**, constructs training prompts for multiple tasks, and fine-tunes an LLM using the QLoRA framework. The resulting **LoRA-adapted model** can then be used for inference or further training.

---

## **Usage**
To run `finetune_noUI.py`, execute the following command:

```bash
python3.10 src/no_UI/finetune_noUI.py configs/config_noUI.yaml --index 1
```

## **Command-Line Arguments**

- `config_path`: Path to the YAML configuration file.
- `index`: Run index; also used as the gradient accumulation step count and in the saved adapter name (default: 1).
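Because `index` is reused as the gradient accumulation step count in the trainer, it directly scales the effective batch size. A quick illustration (the numbers below are illustrative, not values from the shipped config):

```python
# illustrative values, not from the shipped config_noUI.yaml
per_device_train_batch_size = 4
gradient_accumulation_steps = 2  # what running with --index 2 would set

# each optimizer step effectively sees this many training examples
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # → 8
```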
## **Supported Fine-Tuning Tasks**
The script fine-tunes LLMs on six key literature-based tasks, generating instruction-tuned training data:

1. **Citation Sentence Generation:** Trains the model to generate citation sentences describing how Paper A cites Paper B in the related work section.

2. **Citation Link Prediction:** Trains the model to predict whether Paper A is likely to cite Paper B based on their titles and abstracts.

3. **Abstract Completion:** Trains the model to complete an abstract given a partial abstract and a paper title.

4. **Title Generation:** Trains the model to generate a paper's title based on its abstract.

5. **Citation Recommendation:** Trains the model to select the most relevant paper from a set of candidates that Paper A is likely to cite.

6. **Introduction to Abstract:** Trains the model to generate an abstract based on a paper's introduction.

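Each task is converted into an Alpaca-style instruction prompt before tokenization. A minimal sketch of that step, where the inline `template` dict is a simplified stand-in for the real `configs/alpaca.json` (whose exact wording may differ):

```python
# simplified stand-in for the template loaded from configs/alpaca.json;
# the real file's wording may differ
template = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:\n"
    )
}

instruction = "Please generate the title of paper based on its abstract"
prompt_input = "Abstract: We study link prediction on citation graphs ...\n"
prompt = template["prompt_input"].format(instruction=instruction, input=prompt_input)
print(prompt)
```

The ground-truth answer is appended after `### Response:` to form one training example.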
---

## Dependencies

Install the required Python libraries by following the instructions in [README.md](../../README.md).
src/no_UI/eval_noUI.py DELETED
@@ -1,564 +0,0 @@
import argparse
import json
import os
import random

import networkx as nx
import numpy as np
import torch
from bert_score import score
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from utils.utils import read_yaml_file

"""
Ad-hoc sanity check to see if the model outputs something coherent.
Not a robust inference platform!
"""

def get_bert_score(candidate, reference):
    # BERTScore between a single candidate/reference pair
    P, R, F1 = score([candidate], [reference], lang="en")
    return P, R, F1

def _generate_LP_prompt(data_point: dict):
    instruction = "Determine if paper A will cite paper B."

    prompt_input = "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] is not None else 'Unknown') + "\n"
    prompt_input += "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] is not None else 'Unknown') + "\n"
    prompt_input += "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] is not None else 'Unknown') + "\n"
    prompt_input += "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] is not None else 'Unknown') + "\n"
    prompt_input += " Give me a direct answer of yes or no."

    return template["prompt_input"].format(instruction=instruction, input=prompt_input)

def _generate_retrival_prompt(data_point: dict):
    instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."

    prompt_input = "Title of the Paper A: " + data_point['s_title'] + "\n"
    prompt_input += "Abstract of the Paper A: " + data_point['s_abs'] + "\n"
    prompt_input += "candidate papers: " + "\n"
    for i, title in enumerate(data_point['nei_titles']):
        prompt_input += str(i) + '. ' + title + "\n"
    prompt_input += "Give me the title of your selected paper."

    res = template["prompt_input"].format(instruction=instruction, input=prompt_input)
    return res, str(data_point['t_title'])

def _generate_abstrat_2_title_prompt(data_point: dict):
    instruction = "Please generate the title of paper based on its abstract"

    prompt_input = "Abstract: " + data_point['abs'] + "\n"
    return template["prompt_input"].format(instruction=instruction, input=prompt_input)

def _generate_sentence_prompt(data_point):
    instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section. \n"

    # Parenthesize the conditionals so the 'Unknown' fallback applies to the
    # field, not to the whole concatenation
    prompt_input = "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] is not None else 'Unknown') + "\n"
    prompt_input += "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] is not None else 'Unknown') + "\n"
    prompt_input += "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] is not None else 'Unknown') + "\n"
    prompt_input += "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] is not None else 'Unknown') + "\n"

    return template["prompt_input"].format(instruction=instruction, input=prompt_input)

def _generate_abstrat_completion_prompt(data_point: dict):
    instruction = "Please complete the abstract of a paper."

    # Keep only the first 10% of the abstract as the visible prefix
    split_abs = data_point['abs'][: int(0.1 * len(data_point['abs']))]

    prompt_input = "Title: " + data_point['title'] + "\n"
    prompt_input += "Part of abstract: " + split_abs
    return template["prompt_input"].format(instruction=instruction, input=prompt_input)

def get_llm_response(prompt, task):
    # Dispatch to the task-specific generation pipeline
    pipes = {
        'sentence': pipe_sentence,
        'LP': pipe_LP,
        'abstract': pipe_abstract,
        'title': pipe_title,
        'retrieval': pipe_retrieval,
        'intro': pipe_intro,
    }
    return pipes[task](prompt)

def test_sentence():
    Bert_p_list = []
    Bert_r_list = []
    Bert_f_list = []

    result_dict = {}
    for i in tqdm(range(len(test_data))):
        source, target = test_data[i][0], test_data[i][1]
        source_title, source_abs = raw_id_2_tile_abs[source]
        target_title, target_abs = raw_id_2_tile_abs[target]

        s_nei = list(nx.all_neighbors(raw_graph, source))
        s_nei_list = list(set(s_nei) - {source} - {target})[:10]
        s_nei_titles = [raw_id_2_tile_abs[n][0] for n in s_nei_list]

        t_nei = list(nx.all_neighbors(raw_graph, target))
        t_nei_list = list(set(t_nei) - {source} - {target})[:10]
        t_nei_titles = [raw_id_2_tile_abs[n][0] for n in t_nei_list]

        t_nei_sentence = []
        for nei in t_nei_list:
            tmp_sentence = raw_id_pair_2_sentence.get((nei, target), '')
            if len(tmp_sentence) != 0:
                t_nei_sentence.append(tmp_sentence)

        citation_sentence = raw_id_pair_2_sentence[(source, target)] if (source, target) in raw_id_pair_2_sentence else raw_id_pair_2_sentence[(target, source)]

        datapoint = {'s_title': source_title, 's_abs': source_abs, 't_title': target_title, 't_abs': target_abs,
                     's_nei': s_nei_titles, 't_nei': t_nei_titles, 't_nei_sentence': t_nei_sentence,
                     'sentence': citation_sentence}

        prompt = _generate_sentence_prompt(datapoint)
        ans = get_llm_response(prompt, 'sentence')[0]['generated_text']
        res = ans.strip().split(human_instruction[1])[-1]

        result_dict[(source, target)] = [source_title, source_abs, target_title, target_abs, citation_sentence, res]
        Bert_p, Bert_r, Bert_f = get_bert_score(res, citation_sentence)

        print("Answer is:", ans)
        print("Stripped result is:", res)
        print("Citation sentence:", citation_sentence)

        Bert_p_list.append(Bert_p.item())
        Bert_r_list.append(Bert_r.item())
        Bert_f_list.append(Bert_f.item())
        print([len(Bert_p_list), np.mean(Bert_p_list), np.mean(Bert_r_list), np.mean(Bert_f_list)])

    return np.mean(Bert_p_list), np.mean(Bert_r_list), np.mean(Bert_f_list)

def test_LP():
    result_list = []
    # positive pairs (true edges): expect "yes"
    for i in tqdm(range(len(test_data))):
        source, target = test_data[i][0], test_data[i][1]
        source_title, source_abs = raw_id_2_tile_abs[source]
        target_title, target_abs = raw_id_2_tile_abs[target]

        s_nei = list(nx.all_neighbors(raw_graph, source))
        s_nei_list = list(set(s_nei) - {source} - {target})[:5]
        s_nei_titles = [raw_id_2_tile_abs[n][0] for n in s_nei_list]

        t_nei = list(nx.all_neighbors(raw_graph, target))
        t_nei_list = list(set(t_nei) - {source} - {target})[:5]
        t_nei_titles = [raw_id_2_tile_abs[n][0] for n in t_nei_list]

        datapoint = {'s_title': source_title, 's_abs': source_abs, 't_title': target_title, 't_abs': target_abs,
                     's_nei': s_nei_titles, 't_nei': t_nei_titles, 'label': 'yes'}

        prompt = _generate_LP_prompt(datapoint)
        ans = get_llm_response(prompt, 'LP')[0]['generated_text']

        res = ans.strip().split(human_instruction[1])[-1]
        print("Answer is:", res)
        if 'yes' in res[:4].lower():
            result_list.append(1)
        else:
            result_list.append(0)
        print("Current value:", np.mean(result_list))

    # negative pairs (random targets): expect "no"
    for i in tqdm(range(len(test_data))):
        source, target = test_data[i][0], random.sample(list(graph_data.nodes()), 1)[0]
        source_title, source_abs = raw_id_2_tile_abs[source]
        target_title, target_abs = raw_id_2_tile_abs[target]

        s_nei = list(nx.all_neighbors(raw_graph, source))
        s_nei_list = list(set(s_nei) - {source} - {target})[:5]
        s_nei_titles = [raw_id_2_tile_abs[n][0] for n in s_nei_list]

        try:
            t_nei = list(nx.all_neighbors(raw_graph, target))
        except nx.NetworkXError:
            t_nei = []
        t_nei_list = list(set(t_nei) - {source} - {target})[:5]
        t_nei_titles = [raw_id_2_tile_abs[n][0] for n in t_nei_list]

        datapoint = {'s_title': source_title, 's_abs': source_abs, 't_title': target_title, 't_abs': target_abs,
                     's_nei': s_nei_titles, 't_nei': t_nei_titles, 'label': 'no'}

        prompt = _generate_LP_prompt(datapoint)
        ans = get_llm_response(prompt, 'LP')[0]['generated_text']

        res = ans.strip().split(human_instruction[1])[-1]
        print("Answer is:", res)

        if 'no' in res[:4].lower():
            result_list.append(1)
        else:
            result_list.append(0)
        print("Current value:", np.mean(result_list))

    return np.mean(result_list)

def test_title_generate():
    result_dict = {}
    Bert_p_list = []
    Bert_r_list = []
    Bert_f_list = []
    for i in tqdm(range(len(test_data))):
        source, target = test_data[i][0], test_data[i][1]
        title, abstract = raw_id_2_tile_abs[source]
        if title is None or abstract is None:
            continue

        retrieval_nei = list(nx.all_neighbors(raw_graph, source))
        retrieval_nei_list = list(set(retrieval_nei) - {source} - {target})[:5]
        retrieval_nei_titles = [raw_id_2_tile_abs[n][0] for n in retrieval_nei_list]

        datapoint = {'title': title, 'abs': abstract, 'retrieval_nei_titles': retrieval_nei_titles}

        prompt = _generate_abstrat_2_title_prompt(datapoint)
        ans = get_llm_response(prompt, 'title')[0]['generated_text']

        res = ans.strip().split(human_instruction[1])[-1]

        result_dict[source] = [title, abstract, res]

        print(ans)
        print(res)
        print(title)

        Bert_p, Bert_r, Bert_f = get_bert_score(res, title)
        Bert_p_list.append(Bert_p.item())
        Bert_r_list.append(Bert_r.item())
        Bert_f_list.append(Bert_f.item())
        print([len(Bert_p_list), np.mean(Bert_p_list), np.mean(Bert_r_list), np.mean(Bert_f_list)])

    return np.mean(Bert_p_list), np.mean(Bert_r_list), np.mean(Bert_f_list)

def test_abs_completion():
    result_dict = {}
    Bert_p_list = []
    Bert_r_list = []
    Bert_f_list = []
    for i in tqdm(range(len(test_data))):
        source, target = test_data[i][0], test_data[i][1]
        title, abstract = raw_id_2_tile_abs[source]
        if title is None or abstract is None:
            continue

        retrieval_nei = list(nx.all_neighbors(raw_graph, source))
        retrieval_nei_list = list(set(retrieval_nei) - {source} - {target})[:5]
        retrieval_nei_abs = [raw_id_2_tile_abs[n][1] for n in retrieval_nei_list]

        datapoint = {'title': title, 'abs': abstract, 'nei_abs': retrieval_nei_abs}

        prompt = _generate_abstrat_completion_prompt(datapoint)
        ans = get_llm_response(prompt, 'abstract')[0]['generated_text']

        res = ans.strip().split(human_instruction[1])[-1]

        result_dict[source] = [title, abstract, res]

        print(ans)
        print(res)
        print(abstract)

        Bert_p, Bert_r, Bert_f = get_bert_score(res, abstract)

        Bert_p_list.append(Bert_p.item())
        Bert_r_list.append(Bert_r.item())
        Bert_f_list.append(Bert_f.item())
        print([len(Bert_p_list), np.mean(Bert_p_list), np.mean(Bert_r_list), np.mean(Bert_f_list)])

    return np.mean(Bert_p_list), np.mean(Bert_r_list), np.mean(Bert_f_list)

def test_retrival_e():
    result_list = []
    for i in tqdm(range(len(test_data))):
        source, target = test_data[i][0], test_data[i][1]
        source_title, source_abs = raw_id_2_tile_abs[source]
        target_title, _ = raw_id_2_tile_abs[target]

        # 5 random non-neighbors plus the true target, shuffled
        neighbors = list(nx.all_neighbors(raw_graph, source))
        sample_node_list = list(set(raw_graph.nodes()) - set(neighbors) - {source} - {target})
        sampled_neg_nodes = random.sample(sample_node_list, 5) + [target]
        random.shuffle(sampled_neg_nodes)

        retrieval_nei = list(nx.all_neighbors(raw_graph, source))
        retrieval_nei_list = list(set(retrieval_nei) - {source} - {target})[:3]
        retrieval_nei_titles = [raw_id_2_tile_abs[n][0] for n in retrieval_nei_list]

        datapoint = {'s_title': source_title, 's_abs': source_abs, 't_title': target_title,
                     'nei_titles': [raw_id_2_tile_abs[node][0] for node in sampled_neg_nodes],
                     'retrieval_nei_title': retrieval_nei_titles}
        prompt, _ = _generate_retrival_prompt(datapoint)
        ans = get_llm_response(prompt, 'retrieval')[0]['generated_text']

        res = ans.strip().split(human_instruction[1])[-1].lower()
        target_title = target_title.lower()

        print(ans)
        print("###GT: " + target_title)
        print(res)
        if target_title in res or res in target_title:
            result_list.append(1)
        else:
            result_list.append(0)
        print([sum(result_list), len(result_list)])

    print([sum(result_list), len(result_list)])
    return np.mean(result_list)

def _generate_intro_2_abstract_prompt(data_point: dict, context_window):
    instruction = "Please generate the abstract of paper based on its introduction section."

    prompt_input = "Introduction: " + data_point['intro'] + "\n"

    # Truncate so the prompt fits in the model's context window
    prompt_input = prompt_input[:int(context_window * 2)]

    return template["prompt_input"].format(instruction=instruction, input=prompt_input)

def test_intro_2_abs():
    result_dict = {}
    Bert_p_list = []
    Bert_r_list = []
    Bert_f_list = []
    for i in tqdm(range(len(test_data))):
        source, target = test_data[i][0], test_data[i][1]

        # Fall back to the target paper if the source has no introduction
        if source not in raw_id_2_intro:
            source = target
        if source not in raw_id_2_intro:
            continue

        title, abstract = raw_id_2_tile_abs[source]
        intro = raw_id_2_intro[source]

        datapoint = {'abs': abstract, 'intro': intro}
        prompt = _generate_intro_2_abstract_prompt(datapoint, tokenizer.model_max_length)
        ans = get_llm_response(prompt, 'intro')[0]['generated_text']

        res = ans.strip().split(human_instruction[1] + '\n')[-1]

        result_dict[source] = [title, abstract, res]

        print(ans)
        print(res)
        print(abstract)

        Bert_p, Bert_r, Bert_f = get_bert_score(res, abstract)

        Bert_p_list.append(Bert_p.item())
        Bert_r_list.append(Bert_r.item())
        Bert_f_list.append(Bert_f.item())
        print([len(Bert_p_list), np.mean(Bert_p_list), np.mean(Bert_r_list), np.mean(Bert_f_list)])

    return np.mean(Bert_p_list), np.mean(Bert_r_list), np.mean(Bert_f_list)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-config_path", help="Path to the config YAML file")
    parser.add_argument("-model", help="Model type (e.g., lora)")
    parser.add_argument("-lorapath", help="Path to the LoRA adapter checkpoint")
    parser.add_argument("-prompt_num", help="Optional prompt index (currently unused)", default=1)
    args = parser.parse_args()

    config = read_yaml_file(args.config_path)
    random.seed(42)
    print("Load model")
    model_path = config["eval"]["base_model"]

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # 8-bit weights are placed on the GPU via device_map; do not call .to('cuda')
    base_model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, load_in_8bit=True, device_map="auto"
    )
    tokenizer.model_max_length = 2048
    tokenizer.pad_token = tokenizer.eos_token

    adapter_save_path = args.lorapath
    model = PeftModel.from_pretrained(base_model, adapter_save_path)
    model = model.merge_and_unload()

    def _make_pipe(max_new_tokens):
        # All tasks share the same sampling settings; only the output budget differs
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15,
        )

    pipe_LP = _make_pipe(2)
    pipe_title = _make_pipe(100)
    pipe_abstract = _make_pipe(256)
    pipe_intro = _make_pipe(256)
    pipe_sentence = _make_pipe(64)
    pipe_retrieval = _make_pipe(64)

    graph_path = config["eval"]["graph_path"]
    graph_data = nx.read_gexf(graph_path, node_type=None, relabel=False, version='1.2draft')

    raw_graph = graph_data

    test_set_size = 50
    all_test_nodes = set(list(graph_data.nodes())[:test_set_size])
    all_train_nodes = set(list(graph_data.nodes())[test_set_size:])

    raw_id_2_tile_abs = dict()
    for paper_id in list(graph_data.nodes()):
        title = graph_data.nodes()[paper_id]['title']
        abstract = graph_data.nodes()[paper_id]['abstract']
        raw_id_2_tile_abs[paper_id] = [title, abstract]

    raw_id_pair_2_sentence = dict()
    for edge in list(graph_data.edges()):
        raw_id_pair_2_sentence[edge] = graph_data.edges()[edge].get('sentence', '')

    raw_id_2_intro = dict()
    for paper_id in list(graph_data.nodes())[test_set_size:]:
        if graph_data.nodes[paper_id]['introduction'] != '':
            raw_id_2_intro[paper_id] = graph_data.nodes[paper_id]['introduction']

    # Edges touching a held-out test node become test pairs
    test_data = []
    edge_list = []
    for edge in list(raw_graph.edges()):
        src, tar = edge
        if src not in all_test_nodes and tar not in all_test_nodes:
            edge_list.append(edge)
        else:
            test_data.append(edge)

    with open('configs/alpaca.json') as fp:
        template = json.load(fp)
    human_instruction = ['### Input:', '### Response:']

    LP_score = test_LP()
    retrieval_score = test_retrival_e()
    title_p, title_r, title_f = test_title_generate()
    sentence_p, sentence_r, sentence_f = test_sentence()
    abstract_p, abstract_r, abstract_f = test_abs_completion()
    intro_p, intro_r, intro_f = test_intro_2_abs()

    print("Retrieval Score:", retrieval_score)
    print("LP Score:", LP_score)
    print("Title:", [title_p, title_r, title_f])
    print("Sentence:", [sentence_p, sentence_r, sentence_f])
    print("Abstract:", [abstract_p, abstract_r, abstract_f])
    print("Intro:", [intro_p, intro_r, intro_f])

    results = {
        "LP_score": LP_score,
        "retrieval_score": retrieval_score,
        "title": {"precision": title_p, "recall": title_r, "f1": title_f},
        "sentence": {"precision": sentence_p, "recall": sentence_r, "f1": sentence_f},
        "abstract": {"precision": abstract_p, "recall": abstract_r, "f1": abstract_f},
        "intro": {"precision": intro_p, "recall": intro_r, "f1": intro_f},
    }

    graph_name = graph_path.split('/')[-1].split('.')[0]
    name_save = config["eval"]["model_name"]

    os.makedirs("eval", exist_ok=True)

    with open(f"eval/{name_save}_{graph_name}_results.json", "w") as f:
        json.dump(results, f)
src/no_UI/finetune_noUI.py DELETED
@@ -1,382 +0,0 @@
import argparse
import json
import random

import networkx as nx
import torch
import transformers
import wandb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from utils.utils import read_yaml_file


class QloraTrainer_CS:
    def __init__(self, config: dict, index, use_predefined_graph=False):
        self.config = config
        self.tokenizer = None
        self.base_model = None
        self.adapter_model = None
        self.merged_model = None
        self.index = index
        self.transformer_trainer = None
        self.test_data = None
        self.use_predefined_graph = use_predefined_graph

        template_file_path = 'configs/alpaca.json'
        with open(template_file_path) as fp:
            self.template = json.load(fp)

    def load_base_model(self):
        model_id = self.config['training']['base_model']
        print(model_id)

        # Standard QLoRA quantization: 4-bit NF4 with double quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        print('load llama 3')
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.model_max_length = self.config['training']['tokenizer']["max_length"]
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        # Quantized weights are placed on the GPU via device_map
        model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16, device_map="auto"
        )

        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)

        self.tokenizer = tokenizer
        self.base_model = model

    def train(self):
        # Set up LoRA config
        config = LoraConfig(
            r=self.config['training']['qlora']['rank'],
            lora_alpha=self.config['training']['qlora']['lora_alpha'],
            target_modules=self.config['training']['qlora']['target_modules'],
            lora_dropout=self.config['training']['qlora']['lora_dropout'],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(self.base_model, config)
        self._print_trainable_parameters(model)

        print("Start data preprocessing")
        train_data = self._process_data_instruction()

        print('Length of dataset: ', len(train_data))

        print("Start training")
        self.transformer_trainer = transformers.Trainer(
            model=model,
            train_dataset=train_data,
            args=transformers.TrainingArguments(
                per_device_train_batch_size=self.config['training']['trainer_args']["per_device_train_batch_size"],
                gradient_accumulation_steps=int(self.index),
                warmup_steps=self.config['training']['trainer_args']["warmup_steps"],
                num_train_epochs=self.config['training']['trainer_args']["num_train_epochs"],
                learning_rate=self.config['training']['trainer_args']["learning_rate"],
                lr_scheduler_type=self.config['training']['trainer_args']["lr_scheduler_type"],
                fp16=self.config['training']['trainer_args']["fp16"],
                logging_steps=self.config['training']['trainer_args']["logging_steps"],
                output_dir=self.config['training']['trainer_args']["trainer_output_dir"],
                report_to="wandb",
                save_steps=self.config['training']['trainer_args']["save_steps"],
            ),
            data_collator=transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
        )

        model.config.use_cache = False

        self.transformer_trainer.train()

        model_save_path = f"{self.config['training']['model_saving']['model_output_dir']}/{self.config['training']['model_saving']['model_name']}_{str(self.index)}_adapter_test_graph"
        self.transformer_trainer.save_model(model_save_path)

        self.adapter_model = model
        print(f"Training complete, adapter model saved in {model_save_path}")

    def _print_trainable_parameters(self, model):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        print(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
        )

-    def _process_data_instruction(self):
-        context_window = self.tokenizer.model_max_length
-        graph_data = nx.read_gexf(self.config["training"]["graph_path"], node_type=None, relabel=False, version='1.2draft')
-        raw_graph = graph_data
-
-        test_set_size = len(graph_data.nodes()) // 10
-
-        all_test_nodes = set(list(graph_data.nodes())[:test_set_size])
-        all_train_nodes = set(list(graph_data.nodes())[test_set_size:])
-
-        raw_id_2_title_abs = dict()
-        for paper_id in list(graph_data.nodes())[test_set_size:]:
-            title = graph_data.nodes()[paper_id]['title']
-            abstract = graph_data.nodes()[paper_id]['abstract']
-            raw_id_2_title_abs[paper_id] = [title, abstract]
-
-        raw_id_2_title_abs_test = dict()
-        for paper_id in list(graph_data.nodes()):
-            title = graph_data.nodes()[paper_id]['title']
-            abstract = graph_data.nodes()[paper_id]['abstract']
-            raw_id_2_title_abs_test[paper_id] = [title, abstract]
-
-        raw_id_2_intro = dict()
-        for paper_id in list(graph_data.nodes())[test_set_size:]:
-            if graph_data.nodes[paper_id]['introduction'] != '':
-                intro = graph_data.nodes[paper_id]['introduction']
-                raw_id_2_intro[paper_id] = intro
-
-        raw_id_pair_2_sentence = dict()
-        for edge in list(graph_data.edges()):
-            sentence = graph_data.edges()[edge]['sentence']
-            raw_id_pair_2_sentence[edge] = sentence
-
-
-        test_data = []
-        edge_list = []
-        for edge in list(raw_graph.edges()):
-            src, tar = edge
-            if src not in all_test_nodes and tar not in all_test_nodes:
-                edge_list.append(edge)
-            else:
-                test_data.append(edge)
-        train_num = int(len(edge_list))
-
-        data_LP = []
-        data_abstract_2_title = []
-        data_paper_retrieval = []
-        data_citation_sentence = []
-        data_abs_completion = []
-        data_intro_2_abs = []
-
-
-        for sample in tqdm(random.sample(edge_list, train_num)):
-            source, target = sample[0], sample[1]
-            source_title, source_abs = raw_id_2_title_abs[source]
-            target_title, target_abs = raw_id_2_title_abs[target]
-            # LP prompt
-            rand_ind = random.choice(list(raw_id_2_title_abs.keys()))
-            neg_title, neg_abs = raw_id_2_title_abs[rand_ind]
-            data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'label':'yes'})
-            data_LP.append({'s_title':source_title, 's_abs':source_abs, 't_title':neg_title, 't_abs':neg_abs, 'label':'no'})
-
-        for sample in tqdm(random.sample(edge_list, train_num)):
-            source, target = sample[0], sample[1]
-            source_title, source_abs = raw_id_2_title_abs[source]
-            target_title, target_abs = raw_id_2_title_abs[target]
-            # abs_2_title prompt
-            data_abstract_2_title.append({'title':source_title, 'abs':source_abs})
-            data_abstract_2_title.append({'title':target_title, 'abs':target_abs})
-
-        for sample in tqdm(random.sample(edge_list, train_num)):
-            source, target = sample[0], sample[1]
-            source_title, source_abs = raw_id_2_title_abs[source]
-            target_title, target_abs = raw_id_2_title_abs[target]
-            # paper_retrieval prompt
-            neighbors = list(nx.all_neighbors(raw_graph, source))
-            sample_node_list = list(all_train_nodes - set(neighbors) - set([source]) - set([target]))
-            sampled_neg_nodes = random.sample(sample_node_list, 5) + [target]
-            random.shuffle(sampled_neg_nodes)
-            data_paper_retrieval.append({'title':source_title, 'abs':source_abs, 'sample_title': [raw_id_2_title_abs[node][0] for node in sampled_neg_nodes], 'right_title':target_title})
-
-        for sample in tqdm(random.sample(edge_list, train_num)):
-            source, target = sample[0], sample[1]
-            source_title, source_abs = raw_id_2_title_abs[source]
-            target_title, target_abs = raw_id_2_title_abs[target]
-            # citation_sentence prompt
-            citation_sentence = raw_id_pair_2_sentence[(source, target)] if (source, target) in raw_id_pair_2_sentence.keys() else raw_id_pair_2_sentence[(target, source)]
-            data_citation_sentence.append({'s_title':source_title, 's_abs':source_abs, 't_title':target_title, 't_abs':target_abs, 'sentence': citation_sentence})
-
-        for sample in tqdm(random.sample(edge_list, train_num)):
-            source, target = sample[0], sample[1]
-            source_title, source_abs = raw_id_2_title_abs[source]
-            target_title, target_abs = raw_id_2_title_abs[target]
-            # abs_complete prompt
-            data_abs_completion.append({'title':source_title, 'abs':source_abs})
-            data_abs_completion.append({'title':target_title, 'abs':target_abs})
-
-        for sample in tqdm(random.sample(edge_list, train_num)):
-            source, target = sample[0], sample[1]
-            if source in raw_id_2_intro:
-                source_intro = raw_id_2_intro[source]
-                _, source_abs = raw_id_2_title_abs[source]
-                data_intro_2_abs.append({'intro':source_intro, 'abs':source_abs})
-            if target in raw_id_2_intro:
-                target_intro = raw_id_2_intro[target]
-                _, target_abs = raw_id_2_title_abs[target]
-                data_intro_2_abs.append({'intro':target_intro, 'abs':target_abs})
-
-        data_prompt = []
-        data_prompt += [self._generate_paper_retrieval_prompt(data_point) for data_point in data_paper_retrieval]
-        data_prompt += [self._generate_LP_prompt(data_point) for data_point in data_LP]
-        data_prompt += [self._generate_abstract_2_title_prompt(data_point) for data_point in data_abstract_2_title]
-        data_prompt += [self._generate_citation_sentence_prompt(data_point) for data_point in data_citation_sentence]
-        data_prompt += [self._generate_abstract_completion_prompt(data_point) for data_point in data_abs_completion]
-        data_prompt += [self._generate_intro_2_abstract_prompt(data_point, context_window) for data_point in data_intro_2_abs]
-
-        print("Total prompts:", len(data_prompt))
-        random.shuffle(data_prompt)
-        if self.tokenizer.chat_template is None:
-            data_tokenized = [self.tokenizer(sample, max_length=context_window, truncation=True) for sample in tqdm(data_prompt)]
-        else:
-            data_tokenized = [self.tokenizer.apply_chat_template(sample, max_length=context_window, truncation=True, tokenize=False) for sample in tqdm(data_prompt)]
-
-        return data_tokenized
-
-
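`_process_data_instruction` holds out the first 10% of graph nodes as a test set and keeps an edge for training only when neither endpoint is a test node. A toy sketch of that split (paper IDs and edges here are made up for illustration):

```python
# Toy citation graph: 20 node IDs and a few directed citation edges.
nodes = [f"paper_{i}" for i in range(20)]
edges = [("paper_0", "paper_15"), ("paper_12", "paper_18"), ("paper_3", "paper_5")]

# First 10% of nodes are held out for evaluation; the rest are train nodes.
test_set_size = len(nodes) // 10
all_test_nodes = set(nodes[:test_set_size])

# Training edges must have both endpoints among the train nodes; any edge
# touching a test node is reserved for the evaluation split.
edge_list = [(s, t) for s, t in edges if s not in all_test_nodes and t not in all_test_nodes]
test_data = [(s, t) for s, t in edges if s in all_test_nodes or t in all_test_nodes]
print(len(all_test_nodes), len(edge_list), len(test_data))
```

Here two nodes are held out, so the one edge touching `paper_0` lands in the evaluation split and the other two edges remain for training.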
-    def _generate_LP_prompt(self, data_point: dict):
-        instruction = "Determine if paper A will cite paper B."
-
-        prompt_input = ""
-        prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] is not None else 'Unknown') + "\n"
-        prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] is not None else 'Unknown') + "\n"
-        prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] is not None else 'Unknown') + "\n"
-        prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] is not None else 'Unknown') + "\n"
-
-        if self.tokenizer.chat_template is None:
-            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
-            res = f"{res}{data_point['label']}"
-        else:
-            res = [
-                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
-                {"role": "assistant", "content": data_point['label']}
-            ]
-
-        return res
-
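For tokenizers with a chat template, `_generate_LP_prompt` emits a two-turn conversation whose assistant turn is just the `yes`/`no` label. A self-contained sketch, with `TEMPLATE` as a hypothetical stand-in for the `self.template["prompt_input"]` string loaded from the config:

```python
# Hypothetical Alpaca-style template; the real one comes from the config file.
TEMPLATE = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"

def generate_lp_prompt(data_point: dict) -> list:
    instruction = "Determine if paper A will cite paper B."
    fields = [
        ("Title of Paper A", data_point["s_title"]),
        ("Abstract of Paper A", data_point["s_abs"]),
        ("Title of Paper B", data_point["t_title"]),
        ("Abstract of Paper B", data_point["t_abs"]),
    ]
    # Missing fields fall back to "Unknown", mirroring the trainer's None checks.
    prompt_input = "".join(
        f"{name}: {value if value is not None else 'Unknown'}\n" for name, value in fields
    )
    return [
        {"role": "user", "content": TEMPLATE.format(instruction=instruction, input=prompt_input)},
        {"role": "assistant", "content": data_point["label"]},
    ]

messages = generate_lp_prompt({
    "s_title": "Graph LLMs", "s_abs": "We fine-tune LLMs on citation graphs.",
    "t_title": None, "t_abs": "A survey of link prediction.", "label": "yes",
})
print(messages[1]["content"])  # the supervision target is just the label
```

Keeping the target to a single token-level label is what lets the evaluation side score this task with plain accuracy rather than BERTScore.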
-    def _generate_abstract_2_title_prompt(self, data_point: dict):
-        instruction = "Please generate the title of paper based on its abstract."
-
-        prompt_input = ""
-        prompt_input = prompt_input + "Abstract: " + data_point['abs'] + "\n"
-
-        if self.tokenizer.chat_template is None:
-            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
-            res = f"{res}{data_point['title']}"
-        else:
-            res = [
-                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
-                {"role": "assistant", "content": data_point['title']}
-            ]
-
-        return res
-
-    def _generate_paper_retrieval_prompt(self, data_point: dict):
-        instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."
-
-        prompt_input = ""
-        prompt_input = prompt_input + "Title of the Paper A: " + data_point['title'] + "\n"
-        prompt_input = prompt_input + "Abstract of the Paper A: " + data_point['abs'] + "\n"
-        prompt_input = prompt_input + "candidate papers: " + "\n"
-        for i in range(len(data_point['sample_title'])):
-            prompt_input = prompt_input + str(i) + '. ' + data_point['sample_title'][i] + "\n"
-
-        if self.tokenizer.chat_template is None:
-            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
-            res = f"{res}{data_point['right_title']}"
-        else:
-            res = [
-                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
-                {"role": "assistant", "content": data_point['right_title']}
-            ]
-
-        return res
-
-    def _generate_citation_sentence_prompt(self, data_point: dict):
-        instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section."
-
-        prompt_input = ""
-        prompt_input = prompt_input + "Title of Paper A: " + (data_point['s_title'] if data_point['s_title'] is not None else 'Unknown') + "\n"
-        prompt_input = prompt_input + "Abstract of Paper A: " + (data_point['s_abs'] if data_point['s_abs'] is not None else 'Unknown') + "\n"
-        prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] is not None else 'Unknown') + "\n"
-        prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] is not None else 'Unknown') + "\n"
-
-        if self.tokenizer.chat_template is None:
-            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
-            res = f"{res}{data_point['sentence']}"
-        else:
-            res = [
-                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
-                {"role": "assistant", "content": data_point['sentence']}
-            ]
-
-        return res
-
-    def _generate_abstract_completion_prompt(self, data_point: dict):
-        instruction = "Please complete the abstract of a paper."
-
-        prompt_input = ""
-        prompt_input = prompt_input + "Title: " + (data_point['title'] if data_point['title'] is not None else 'Unknown') + "\n"
-
-        split_abs = data_point['abs'][: int(0.3*len(data_point['abs']))]
-        prompt_input = prompt_input + "Part of abstract: " + split_abs + "\n"
-
-        if self.tokenizer.chat_template is None:
-            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
-            res = f"{res}{data_point['abs']}"
-        else:
-            res = [
-                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
-                {"role": "assistant", "content": data_point['abs']}
-            ]
-
-        return res
-
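The abstract-completion task shows the model the title plus roughly the first 30% of the abstract (a character-level slice) and supervises on the full abstract. The slicing step in isolation, with a toy abstract:

```python
# Toy abstract; the real ones come from the citation-graph node attributes.
abstract = "We study citation graphs. " * 4  # 104 characters

# Keep the first 30% of characters as the visible part of the prompt.
split_abs = abstract[: int(0.3 * len(abstract))]

prompt_input = "Title: A Toy Paper\n" + "Part of abstract: " + split_abs + "\n"
target = abstract  # the model is trained to emit the complete abstract
print(len(split_abs), len(abstract))
```

Note the slice is by characters, not tokens, so the fraction of the abstract the model actually sees in tokens varies with the tokenizer.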
-    def _generate_intro_2_abstract_prompt(self, data_point: dict, context_window):
-        instruction = "Please generate the abstract of paper based on its introduction section."
-
-        prompt_input = ""
-        prompt_input = prompt_input + "Introduction: " + data_point['intro'] + "\n"
-
-        # Reduce it to make it fit the context window
-        prompt_input = prompt_input[:int(context_window*2)]
-
-        if self.tokenizer.chat_template is None:
-            res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
-            res = f"{res}{data_point['abs']}"
-        else:
-            res = [
-                {"role": "user", "content": self.template["prompt_input"].format(instruction=instruction, input=prompt_input)},
-                {"role": "assistant", "content": data_point['abs']}
-            ]
-
-        return res
-
-
-if __name__ == "__main__":
-    wandb.init(project='qlora_train')
-    parser = argparse.ArgumentParser()
-    parser.add_argument("config_path", help="Path to the config YAML file")
-    parser.add_argument("--index", type=int, default=1, help="Index to specify the GPU or task number")
-    args = parser.parse_args()
-
-    config = read_yaml_file(args.config_path)
-    trainer = QloraTrainer_CS(config, args.index, True)
-
-    print("Load base model")
-    trainer.load_base_model()
-
-    print("Start training")
-    trainer.train()