Trace2333 committed
Commit 692a701 · 1 Parent(s): 0032d0a

online version 1

Files changed (3):
  1. build_openprompt.py +25 -21
  2. gpt2_generation.py +122 -207
  3. sft.py +15 -24
build_openprompt.py CHANGED
@@ -1,46 +1,50 @@
 import csv
-import random
+import pandas as pd
 import json
-import numpy as np
+import random
+
+from tqdm import tqdm
 
-from sklearn.model_selection import ShuffleSplit
 
 samples = {
     "x": [],
     "y": [],
 }
 little = False
-all_loaded_sample = 500000
-# 200,000 rows
+all_loaded_sample = 400000
+
+s_pro = all_loaded_sample / 1e+7
+# per-row read probability
 with open("./data/prompts.csv") as f:
     csv_reader = csv.DictReader(f)
-    for row_number, row in enumerate(csv_reader):
-        # if row_number == random.randint(0, 1000):
-        #     break
+    process_reader = tqdm(enumerate(csv_reader))
+    for row_number, row in process_reader:
+        num_samples = len(samples['x'])
+        process_reader.set_description(f"got data num: {num_samples}")
+        if random.uniform(0, 1) > s_pro:
+            continue
         if little:
-            if row_number > 100:
+            if len(samples["x"]) > 100:
                 break
-        if row_number > all_loaded_sample:
+        if len(samples["x"]) > all_loaded_sample:
             break
 
         datum = row
+        prompt = datum['prompt']
         modifiers = json.loads(datum['raw_data'])['modifiers']
-        n = random.randint(1, 11)
-        if len(modifiers) < 3:
+        if len(modifiers) < 4:
             continue
-        label = ",".join(modifiers) if len(modifiers) > 1 else modifiers[0]
-        if 0<n and n<=6:
-            x = modifiers[0]
-        elif n>6 and n<=9:
-            x = ",".join(modifiers[:2])
-        else:
-            x = ",".join(modifiers[:3])
+
+        # TODO: plug in an entity recognizer and filter out samples that contain named entities
+
+        label = prompt
+        x = prompt
         # small text to large text, so x is the shorter side; x is allocated in a 6:3:1 ratio
 
         samples["x"].append(x)
         samples["y"].append(label)
 
 
-with open("./data/dataset_openprompt.json", "w") as f:
+with open(f"./data/dataset_openprompt.json", "w") as f:
     json.dump(samples, f, indent=4, ensure_ascii=False)
-print("*"*40, "save train done.", "with little" if little else "", "*"*40)
+print("*"*40, f"save {num_samples} train samples done.", "with little" if little else "", "*"*40)
gpt2_generation.py CHANGED
@@ -1,32 +1,10 @@
 #!/usr/bin/env python
 # coding=utf-8
-# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
-# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
-"""
-
-
-import argparse
 import inspect
-import time
 import logging
 from typing import Tuple
 
 import torch
-from accelerate import PartialState
-from accelerate.utils import set_seed
 
 from transformers import (
     AutoTokenizer,
@@ -49,17 +27,16 @@ from transformers import (
     XLMWithLMHeadModel,
     XLNetLMHeadModel,
     XLNetTokenizer,
+    AutoModelForSeq2SeqLM,
 )
 from transformers.modeling_outputs import CausalLMOutputWithPast
-
+from forbidden import FORBIDDEN_NOUN
 
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
     level=logging.INFO,
 )
-logger = logging.getLogger(__name__)
-
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
 
 MODEL_CLASSES = {
@@ -75,33 +52,33 @@ MODEL_CLASSES = {
     "opt": (OPTForCausalLM, GPT2Tokenizer),
 }
 
-# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
-# in https://github.com/rusiaaman/XLNet-gen#methodology
-# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
-(except for Alexei and Maria) are discovered.
-The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
-remainder of the story. 1883 Western Siberia,
-a young Grigori Rasputin is asked by his father and a group of men to perform magic.
-Rasputin has a vision and denounces one of the men as a horse thief. Although his
-father initially slaps him for making such an accusation, Rasputin watches as the
-man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
-the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
-with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
 
+FORBIDDEN_NOUN = set(FORBIDDEN_NOUN)
+
+class Translator:
+    def __init__(self, model_name):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+    def translate(self, text):
+        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
+        outputs = self.model.generate(**inputs)
+        translated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return translated_text
+
+    def __call__(self, text):
+        return self.translate(text)
 
 #
 # Functions to prepare models' input
 #
-
-
 def prepare_ctrl_input(args, _, tokenizer, prompt_text):
-    if args.temperature > 0.7:
-        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
+    if args["temperature"] > 0.7:
+        pass
 
     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
     if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
-        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
+        pass
     return prompt_text
 
 
@@ -112,8 +89,8 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
     use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
     if hasattr(model.config, "lang2id") and use_lang_emb:
         available_languages = model.config.lang2id.keys()
-        if args.xlm_language in available_languages:
-            language = args.xlm_language
+        if args["xlm_language"] in available_languages:
+            language = args["xlm_language"]
         else:
             language = None
             while language not in available_languages:
@@ -132,13 +109,13 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
 
 
 def prepare_xlnet_input(args, _, tokenizer, prompt_text):
-    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
+    prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else ""
     prompt_text = prefix + prompt_text
     return prompt_text
 
 
 def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
-    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
+    prefix = args["prefix"] if args["prefix"] else args["padding_text"] if args["padding_text"] else ""
     prompt_text = prefix + prompt_text
     return prompt_text
 
@@ -284,170 +261,108 @@ class _ModelFallbackWrapper(GenerationMixin):
         return self._default._reorder_cache(past_key_values, beam_idx)
 
 
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--model_type",
-        default="gpt2",
-        type=str,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-    )
-    parser.add_argument(
-        "--model_name_or_path",
-        default="./output/gpt2_openprpmpt/checkpoint-218500",
-        type=str,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-    )
-
-    parser.add_argument("--prompt", type=str, default="")
-    parser.add_argument("--length", type=int, default=60)
-    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
-
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=1.0,
-        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
-    )
-    parser.add_argument(
-        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
-    )
-    parser.add_argument("--k", type=int, default=3)
-    parser.add_argument("--p", type=float, default=0.9)
-
-    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
-    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
-    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
-
-    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
-    parser.add_argument(
-        "--use_cpu",
-        action="store_true",
-        help="Whether or not to use cpu. If set to False, " "we will use gpu/npu or mps device if available",
-    )
-    parser.add_argument("--num_return_sequences", type=int, default=4, help="The number of samples to generate.")
-    parser.add_argument(
-        "--fp16",
-        action="store_true",
-        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
-    )
-    parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
-    args = parser.parse_args()
-
-    # Initialize the distributed state.
-    distributed_state = PartialState(cpu=args.use_cpu)
-
-    logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}")
-
-    if args.seed is not None:
-        set_seed(args.seed)
-
-    # Initialize the model and tokenizer
-    try:
-        args.model_type = args.model_type.lower()
-        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    except KeyError:
-        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
-
-    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, padding_side='left')
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.mask_token = tokenizer.eos_token
-    model = model_class.from_pretrained(args.model_name_or_path)
-
-    # Set the model to the right device
-    model.to(distributed_state.device)
-
-    if args.fp16:
-        model.half()
+def generate_prompt(
+    prompt_text,
+    args,
+    zh_en_translator,
+    nlp,
+    model,
+    tokenizer,
+    distributed_state,
+):
     max_seq_length = getattr(model.config, "max_position_embeddings", 0)
-    args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
-    logger.info(args)
-
-    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
-
-    # Different models need different input formatting and/or extra arguments
-    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
-    if requires_preprocessing:
-        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
-        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
-
-        if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
-            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
+    args["length"] = adjust_length_to_model(args["length"], max_sequence_length=max_seq_length)
+    while(1):
+        prompt_text = zh_en_translator(prompt_text)
+        # only support single input.
+
+        # Different models need different input formatting and/or extra arguments
+        requires_preprocessing = args["model_type"] in PREPROCESSING_FUNCTIONS.keys()
+        if requires_preprocessing:
+            prepare_input = PREPROCESSING_FUNCTIONS.get(args["model_type"])
+            preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
+
+            if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
+                tokenizer_kwargs = {"add_space_before_punct_symbol": True}
+            else:
+                tokenizer_kwargs = {}
+
+            encoded_prompt = tokenizer.encode(
+                preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
+            )
         else:
-            tokenizer_kwargs = {}
-
-        encoded_prompt = tokenizer.encode(
-            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
-        )
-    else:
-        prefix = args.prefix if args.prefix else args.padding_text
-        encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
-    encoded_prompt = encoded_prompt.to(distributed_state.device)
+            prefix = args["prefix"] if args["prefix"] else args["padding_text"]
+            encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
+        encoded_prompt = encoded_prompt.to(distributed_state.device)
 
-    if encoded_prompt.size()[-1] == 0:
-        input_ids = None
-    else:
-        input_ids = encoded_prompt
-
-    if args.jit:
-        jit_input_texts = ["enable jit"]
-        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
-        torch._C._jit_set_texpr_fuser_enabled(False)
-        model.config.return_dict = False
-        if hasattr(model, "forward"):
-            sig = inspect.signature(model.forward)
-        else:
-            sig = inspect.signature(model.__call__)
-        jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
-        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
-        traced_model = torch.jit.freeze(traced_model.eval())
-        traced_model(*jit_inputs)
-        traced_model(*jit_inputs)
-
-        model = _ModelFallbackWrapper(traced_model, model)
-    t1 = time.time()
-    output_sequences = model.generate(
-        input_ids=input_ids,
-        max_length=args.length + len(encoded_prompt[0]),
-        temperature=args.temperature,
-        top_k=args.k,
-        top_p=args.p,
-        repetition_penalty=args.repetition_penalty,
-        do_sample=True,
-        num_return_sequences=args.num_return_sequences,
-    )
-
-    # Remove the batch dimension when returning multiple sequences
-    if len(output_sequences.shape) > 2:
-        output_sequences.squeeze_()
-
-    generated_sequences = []
-
-    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
-        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
-        generated_sequence = generated_sequence.tolist()
-
-        # Decode text
-        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-
-        # Remove all text after the stop token
-        text = text[: text.find(args.stop_token) if args.stop_token else None]
-
-        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
-        total_sequence = (
-            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
-        )
-
-        generated_sequences.append(total_sequence)
-        print(total_sequence)
-
-    t2 = time.time()
-    print("*"*60)
-    print(f"Time cost: {t2-t1}")
-
-    return generated_sequences
+        if encoded_prompt.size()[-1] == 0:
+            input_ids = None
+        else:
+            input_ids = encoded_prompt
+
+        if args["jit"]:
+            jit_input_texts = ["enable jit"]
+            jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
+            torch._C._jit_set_texpr_fuser_enabled(False)
+            model.config.return_dict = False
+            if hasattr(model, "forward"):
+                sig = inspect.signature(model.forward)
+            else:
+                sig = inspect.signature(model.__call__)
+            jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
+            traced_model = torch.jit.trace(model, jit_inputs, strict=False)
+            traced_model = torch.jit.freeze(traced_model.eval())
+            traced_model(*jit_inputs)
+            traced_model(*jit_inputs)
+
+            model = _ModelFallbackWrapper(traced_model, model)
+
+        generated_sequences = []
+
+        for generated_sequence_idx in range(args["num_return_sequences"]):
+            repeat_gen_time = 0
+            while(1):
+                repeat_gen_time = repeat_gen_time + 1
+                generated_sequence = model.generate(
+                    input_ids=input_ids,
+                    max_length=args["length"] + len(encoded_prompt[0]),
+                    temperature=args["temperature"],
+                    top_k=args["k"],
+                    top_p=args["p"],
+                    repetition_penalty=args["repetition_penalty"],
+                    do_sample=True,
+                    num_return_sequences=1,
+                    pad_token_id=tokenizer.pad_token_id
+                )
+                # Remove the n_sequence dimension when returning a single sequence
+                if len(generated_sequence.shape) > 1:
+                    generated_sequence.squeeze_()
+
+                generated_sequence = generated_sequence.tolist()
+
+                # Decode text
+                text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+
+                # Remove all text after the stop token
+                text = text[: text.find(args["stop_token"]) if args["stop_token"] else None]
+
+                # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
+                total_sequence = (
+                    prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
+                )
+                # no checking for prompt_text.
+                docs = nlp(text)
+                nouns = [token.text for token in docs if token.pos_ == 'NOUN']
+                nouns = set(nouns)
+                if nouns.intersection(FORBIDDEN_NOUN) and repeat_gen_time < 10:
+                    continue
+                else:
+                    break
+            generated_sequences.append(total_sequence)
+
+        return generated_sequences
 
 
 if __name__ == "__main__":
-    main()
+    generate_prompt()
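
This commit turns the old argparse-driven `main()` into an importable `generate_prompt()` that expects an `args` dict plus externally constructed components: a zh→en `Translator`, a spaCy-style `nlp` pipeline whose POS tags drive the `FORBIDDEN_NOUN` rejection loop (up to 10 regenerations per sequence), and the model/tokenizer/device state. A hypothetical wiring sketch follows; the "gpt2" checkpoint, the Helsinki-NLP translation model, and `en_core_web_sm` are assumptions for illustration, not part of this commit:

    import spacy
    from accelerate import PartialState
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    from gpt2_generation import Translator, generate_prompt

    # Generation settings mirroring the old argparse defaults.
    args = {
        "model_type": "gpt2", "length": 60, "stop_token": None,
        "temperature": 1.0, "k": 3, "p": 0.9, "repetition_penalty": 1.0,
        "prefix": "", "padding_text": "", "num_return_sequences": 4, "jit": False,
    }

    distributed_state = PartialState()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2", padding_side="left")  # assumed checkpoint
    tokenizer.pad_token = tokenizer.eos_token  # needed for pad_token_id during generate()
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(distributed_state.device)

    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")  # assumed zh-en model
    nlp = spacy.load("en_core_web_sm")  # assumed POS tagger for the noun filter

    sequences = generate_prompt("花园里的一只猫",  # "a cat in the garden"
                                args, zh_en_translator, nlp,
                                model, tokenizer, distributed_state)

Note that the `if __name__ == "__main__":` block still calls `generate_prompt()` with no arguments, so the module is now practically usable only as an import.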
sft.py CHANGED
@@ -1,10 +1,12 @@
 import time
 import evaluate
 import numpy as np
-
+import math
 from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
 from transformers import TrainingArguments, Trainer
 
+from transformers.trainer_callback import TrainerCallback
+
 from utils import (
     get_dataset,
     get_tok_and_model,
@@ -17,34 +19,27 @@ tokenizer, model = get_tok_and_model(f"./models/{base_model}")
 tokenizer.pad_token = tokenizer.eos_token
 rouge = evaluate.load("rouge")
 
-# train_data, test_data = get_open_prompt_data("./data")
-# train_dataset, test_dataset = get_dataset(train_data, test_data)
 dict_data = get_dict_dataset("./data")
 dataset = get_advance_dataset(dict_data)
-dataset = dataset.train_test_split(test_size=0.2)
+dataset = dataset.train_test_split(test_size=0.05)
 
 def preprocess_function(examples):
     x_inputs = [x for x in examples["x"]]
     y_inputs = examples["y"]
-    model_inputs = tokenizer(x_inputs, max_length=128, truncation=True)
+    model_inputs = tokenizer(x_inputs, max_length=256, truncation=True)
 
-    labels = tokenizer(text_target=y_inputs, max_length=128, truncation=True)
+    labels = tokenizer(y_inputs, max_length=256, truncation=True)
 
-    model_inputs["labels"] = model_inputs["input_ids"]
+    model_inputs["labels"] = labels["input_ids"]
     return model_inputs
 
-def compute_metrics(eval_pred):
-    predictions, labels = eval_pred
-    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
-    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
-    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-
-    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
-
-    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
-    result["gen_len"] = np.mean(prediction_lens)
-
-    return {k: round(v, 4) for k, v in result.items()}
+class CustomCallback(TrainerCallback):
+    def on_epoch_end(self, args, state, control, **kwargs):
+        control.should_evaluate = True
+
+    def on_evaluate(self, args, state, control, **kwargs):
+        eval_results = kwargs["metrics"]
+        print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\n")
 
 # data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
@@ -69,7 +64,7 @@ training_args = TrainingArguments(
     adam_beta1=0.9,
     adam_beta2=0.98,
     save_total_limit=1,
-    num_train_epochs=100,
+    num_train_epochs=80,
     fp16=True,
     push_to_hub=False,
 )
@@ -81,12 +76,8 @@ trainer = Trainer(
     eval_dataset=tokenized_dataset["test"],
     tokenizer=tokenizer,
     data_collator=data_collator,
+    callbacks=[CustomCallback]
 )
 
 trainer.train()
 
-import math
-
-eval_results = trainer.evaluate()
-print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
-
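
In sft.py, the one-off post-training `trainer.evaluate()` and perplexity print are replaced by a `TrainerCallback` that forces an evaluation at the end of every epoch and reports perplexity as `exp(eval_loss)`. Since the eval loss is the mean per-token cross-entropy, its exponential is the standard perplexity; a quick sanity check of that formula (a sketch, not part of the commit):

    import math

    def perplexity(eval_loss: float) -> float:
        # exp of mean token-level cross-entropy = perplexity
        return math.exp(eval_loss)

    assert abs(perplexity(0.0) - 1.0) < 1e-9  # zero loss -> perplexity 1
    print(f"{perplexity(2.3):.2f}")           # ~9.97, i.e. ~10 effective choices per token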