주영 committed
Commit 2eff4f8 · 1 Parent(s): a68472a

Add training script for fine-tuning the first KoBART

Files changed (1)
  1. first_kobart/train_stt2pron_eos.py +98 -0
first_kobart/train_stt2pron_eos.py ADDED
@@ -0,0 +1,98 @@
+ import os
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast
+ from torch.optim import AdamW
+ from transformers import get_scheduler
+ from tqdm import tqdm
+
+ # ✅ Configuration
+ MODEL_DIR = "gogamza/kobart-base-v2"
+ SAVE_DIR = "./kobart_stt2pron_with_eos"
+ DATA_PATH = "data/train_stt2pron_with_eos.pt"
+ BATCH_SIZE = 8
+ EPOCHS = 7
+ LR = 5e-5
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # ✅ Dataset class
+ class STT2PronDataset(Dataset):
+     def __init__(self, data, tokenizer, max_length=128):
+         self.data = data
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         item = self.data[idx]
+         source = item["stt"]
+         target = item["pronunciation"]
+         input_enc = self.tokenizer(
+             source,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="pt"
+         )
+         target_enc = self.tokenizer(
+             target,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="pt"
+         )
+         labels = target_enc["input_ids"]
+         labels[labels == self.tokenizer.pad_token_id] = -100  # ignored by the cross-entropy loss
+
+         return {
+             "input_ids": input_enc["input_ids"].squeeze(),
+             "attention_mask": input_enc["attention_mask"].squeeze(),
+             "labels": labels.squeeze()
+         }
+
+ # ✅ Load model and tokenizer
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_DIR)
+ model = BartForConditionalGeneration.from_pretrained(MODEL_DIR).to(DEVICE)
+
+ # ✅ Load data and create the DataLoader
+ data = torch.load(DATA_PATH)
+ dataset = STT2PronDataset(data, tokenizer)
+ loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+ # ✅ Optimizer & scheduler
+ optimizer = AdamW(model.parameters(), lr=LR)
+ lr_scheduler = get_scheduler(
+     name="linear", optimizer=optimizer, num_warmup_steps=0,
+     num_training_steps=len(loader) * EPOCHS
+ )
+
+ # ✅ Training loop
+ model.train()
+ for epoch in range(EPOCHS):
+     print(f"\n🌟 Epoch {epoch+1}/{EPOCHS}")
+     loop = tqdm(loader)
+     total_loss = 0
+     for batch in loop:
+         for k in batch:
+             batch[k] = batch[k].to(DEVICE)
+         outputs = model(**batch)
+         loss = outputs.loss
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         lr_scheduler.step()
+
+         total_loss += loss.item()
+         loop.set_description(f"Loss: {loss.item():.4f}")
+
+     avg_loss = total_loss / len(loader)
+     print(f"✅ Epoch {epoch+1} average loss: {avg_loss:.4f}")
+
+ # ✅ Save the model
+ os.makedirs(SAVE_DIR, exist_ok=True)
+ model.save_pretrained(SAVE_DIR)
+ tokenizer.save_pretrained(SAVE_DIR)
+ print(f"\n📦 Model saved to: {SAVE_DIR}")
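
Note on the data file: `torch.load(DATA_PATH)` must return a list of dicts with string fields "stt" (the STT transcript) and "pronunciation" (the target), as `__getitem__` expects. The file name suggests the targets already carry an EOS marker. A minimal preparation sketch, where the example pair and the EOS handling are illustrative assumptions, not part of this commit:

import torch
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained("gogamza/kobart-base-v2")

# Illustrative placeholder pairs: (STT transcript, target pronunciation)
pairs = [
    ("안녕하세요", "안녕하세요"),
]

# Assumption: the EOS token is appended to the target text, since the KoBART
# tokenizer does not add it automatically (hence "with_eos" in the file name).
data = [
    {"stt": src, "pronunciation": tgt + tokenizer.eos_token}
    for src, tgt in pairs
]
torch.save(data, "data/train_stt2pron_with_eos.pt")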
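
After training, the checkpoint saved to SAVE_DIR can be loaded back for inference. A minimal sketch; the helper name `stt_to_pron` and the decoding parameters are illustrative, not part of this commit:

import torch
from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast

SAVE_DIR = "./kobart_stt2pron_with_eos"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = PreTrainedTokenizerFast.from_pretrained(SAVE_DIR)
model = BartForConditionalGeneration.from_pretrained(SAVE_DIR).to(DEVICE)
model.eval()

def stt_to_pron(text: str, max_length: int = 128) -> str:
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,                          # illustrative decoding setting
            eos_token_id=tokenizer.eos_token_id,  # stop at the learned EOS
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(stt_to_pron("안녕하세요"))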