Percy3822 committed on
Commit
de3a096
·
verified ·
1 Parent(s): 23b4c59

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +7 -4
train.py CHANGED
@@ -12,14 +12,17 @@ model = AutoModelForCausalLM.from_pretrained("distilgpt2")
12
  tokenizer.pad_token = tokenizer.eos_token
13
  model.config.pad_token_id = tokenizer.pad_token_id
14
 
15
- # Tokenize data
16
  def tokenize_function(example):
17
  full_text = example["prompt"] + example["completion"]
18
- return tokenizer(full_text, truncation=True, padding="max_length", max_length=512)
 
 
19
 
 
20
  tokenized_dataset = dataset["train"].map(tokenize_function)
21
 
22
- # Training config
23
  training_args = TrainingArguments(
24
  output_dir="./results",
25
  per_device_train_batch_size=2,
@@ -39,6 +42,6 @@ trainer = Trainer(
39
  # Train
40
  trainer.train()
41
 
42
- # Save model & tokenizer
43
  trainer.save_model("trained_model")
44
  tokenizer.save_pretrained("trained_model")
 
12
  tokenizer.pad_token = tokenizer.eos_token
13
  model.config.pad_token_id = tokenizer.pad_token_id
14
 
15
+ # Tokenize function: provide input_ids + labels (needed for loss)
16
  def tokenize_function(example):
17
  full_text = example["prompt"] + example["completion"]
18
+ tokens = tokenizer(full_text, truncation=True, padding="max_length", max_length=512)
19
+ tokens["labels"] = tokens["input_ids"].copy() # 👈 labels = input_ids for language modeling
20
+ return tokens
21
 
22
+ # Tokenize
23
  tokenized_dataset = dataset["train"].map(tokenize_function)
24
 
25
+ # Training configuration
26
  training_args = TrainingArguments(
27
  output_dir="./results",
28
  per_device_train_batch_size=2,
 
42
  # Train
43
  trainer.train()
44
 
45
+ # Save model and tokenizer
46
  trainer.save_model("trained_model")
47
  tokenizer.save_pretrained("trained_model")