File size: 3,003 Bytes
dc59b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from datasets import load_from_disk
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments
)

# ======================================================
# CONFIG
# ======================================================
# Repeated string literals hoisted into constants so the checkpoint
# name and the save/load directory cannot drift apart between the
# load, output_dir, and save calls (previously duplicated inline).
MODEL_NAME = "t5-small"
TRAIN_PATH = "data/tokenized/train"
VAL_PATH = "data/tokenized/validation"
OUTPUT_DIR = "outputs/model"

# ======================================================
# DEVICE (Mac M1/M2/M3 Safe)
# ======================================================
# Prefer the Apple-silicon Metal backend when available, else CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

# ======================================================
# LOAD TOKENIZED DATASET (FIXED PATHS)
# ======================================================
# Datasets were tokenized offline and saved with Dataset.save_to_disk;
# load_from_disk restores them without re-tokenizing.
print("Loading tokenized dataset...")

train_dataset = load_from_disk(TRAIN_PATH)
val_dataset   = load_from_disk(VAL_PATH)

print("Train size:", len(train_dataset))
print("Validation size:", len(val_dataset))

# ======================================================
# LOAD MODEL
# ======================================================
print("Loading model (t5-small)...")

model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

# Disable the generation KV cache during training; it is useless while
# teacher-forcing and keeping it allocated was crashing on low-memory Macs.
model.config.use_cache = False

# T5 starts decoding from the pad token; pin the special-token ids
# explicitly so generate() during evaluation behaves consistently.
model.config.decoder_start_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# ======================================================
# DATA COLLATOR
# ======================================================
# Dynamically pads inputs/labels per batch and lets the collator build
# decoder_input_ids from the labels for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model
)

# ======================================================
# TRAINING ARGUMENTS (Mac Safe)
# ======================================================
print("Setting training config...")

# NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
# transformers >= 4.41 and the old name is removed in later releases —
# confirm against the pinned transformers version.
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,

    # Evaluate and checkpoint once per epoch.
    evaluation_strategy="epoch",
    save_strategy="epoch",

    learning_rate=3e-4,
    num_train_epochs=5,

    # Batch size 1 with 8-step gradient accumulation gives an effective
    # batch of 8 while keeping peak memory low on Apple silicon.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,

    logging_steps=50,

    # Mixed precision is unreliable on MPS; pinned memory is CUDA-only.
    fp16=False,
    bf16=False,
    dataloader_pin_memory=False,

    # Run generate() during evaluation so seq2seq metrics are possible.
    predict_with_generate=True,
    report_to="none"
)

# ======================================================
# TRAINER
# ======================================================
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# ======================================================
# TRAIN
# ======================================================
print("Training started 🚀")
trainer.train()

# ======================================================
# SAVE MODEL
# ======================================================
# Save the fine-tuned weights AND the tokenizer to the same directory
# so the output folder is directly loadable with from_pretrained.
print("Saving model...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print("\nDONE ✔ Base model trained successfully")