drixo committed on
Commit
7ae3549
·
verified ·
1 Parent(s): b09f0a2

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +31 -17
train.py CHANGED
@@ -9,36 +9,47 @@ from transformers import (
9
 
10
  from config import MODEL_NAME, MAX_LENGTH, DATASET_EN_ES
11
 
12
- # Load model
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
15
 
16
  # Load dataset
17
  dataset = load_dataset(DATASET_EN_ES)
18
 
19
- # Preprocess function
20
- def preprocess(batch):
21
- inputs = tokenizer(
22
- batch["term"]["en"],
23
- truncation=True,
24
- max_length=MAX_LENGTH
 
 
 
 
 
25
  )
26
 
27
- targets = tokenizer(
28
- batch["term"]["es"],
29
- truncation=True,
30
- max_length=MAX_LENGTH
 
31
  )
32
 
33
- inputs["labels"] = targets["input_ids"]
34
- return inputs
35
 
36
- dataset = dataset.map(preprocess)
 
37
 
 
38
  # Data collator
 
39
  data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
40
 
41
- # Training settings
 
 
42
  training_args = Seq2SeqTrainingArguments(
43
  output_dir="./my-translation-model",
44
  learning_rate=2e-5,
@@ -46,14 +57,17 @@ training_args = Seq2SeqTrainingArguments(
46
  num_train_epochs=3,
47
  save_strategy="epoch",
48
  logging_steps=50,
49
- evaluation_strategy="no"
 
50
  )
51
 
 
52
  # Trainer
 
53
  trainer = Seq2SeqTrainer(
54
  model=model,
55
  args=training_args,
56
- train_dataset=dataset["train"],
57
  tokenizer=tokenizer,
58
  data_collator=data_collator
59
  )
 
9
 
10
  from config import MODEL_NAME, MAX_LENGTH, DATASET_EN_ES
11
 
12
+ # Load tokenizer + model
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
15
 
16
  # Load dataset
17
  dataset = load_dataset(DATASET_EN_ES)
18
 
19
+ # -----------------------------
20
+ # FIX: proper preprocessing
21
+ # -----------------------------
22
+ def preprocess(example):
23
+ source = example["term"]["en"]
24
+ target = example["term"]["es"]
25
+
26
+ model_inputs = tokenizer(
27
+ source,
28
+ max_length=MAX_LENGTH,
29
+ truncation=True
30
  )
31
 
32
+ # IMPORTANT FIX: use text_target (correct way for seq2seq)
33
+ labels = tokenizer(
34
+ text_target=target,
35
+ max_length=MAX_LENGTH,
36
+ truncation=True
37
  )
38
 
39
+ model_inputs["labels"] = labels["input_ids"]
40
+ return model_inputs
41
 
42
+ # Apply preprocessing
43
+ tokenized_dataset = dataset.map(preprocess, remove_columns=dataset["train"].column_names)
44
 
45
+ # -----------------------------
46
  # Data collator
47
+ # -----------------------------
48
  data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
49
 
50
+ # -----------------------------
51
+ # Training arguments
52
+ # -----------------------------
53
  training_args = Seq2SeqTrainingArguments(
54
  output_dir="./my-translation-model",
55
  learning_rate=2e-5,
 
57
  num_train_epochs=3,
58
  save_strategy="epoch",
59
  logging_steps=50,
60
+ evaluation_strategy="no",
61
+ fp16=True # faster if GPU supports it
62
  )
63
 
64
+ # -----------------------------
65
  # Trainer
66
+ # -----------------------------
67
  trainer = Seq2SeqTrainer(
68
  model=model,
69
  args=training_args,
70
+ train_dataset=tokenized_dataset["train"],
71
  tokenizer=tokenizer,
72
  data_collator=data_collator
73
  )