pineapple-lover committed
Commit 7d0662d · 1 Parent(s): 132d83f

Release HuPER Corrector weights and inference code

.gitattributes CHANGED
@@ -1,35 +1 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
edit_seq_speech/__init__.py ADDED
File without changes
edit_seq_speech/config/vocab.json ADDED
@@ -0,0 +1,95 @@
+ {
+   "op_to_id": {
+     "KEEP": 0,
+     "DEL": 1,
+     "SUB:<PAD>": 2,
+     "SUB:AA": 3,
+     "SUB:AE": 4,
+     "SUB:AH": 5,
+     "SUB:AO": 6,
+     "SUB:AW": 7,
+     "SUB:AY": 8,
+     "SUB:B": 9,
+     "SUB:CH": 10,
+     "SUB:D": 11,
+     "SUB:DH": 12,
+     "SUB:DX": 13,
+     "SUB:EH": 14,
+     "SUB:ER": 15,
+     "SUB:EY": 16,
+     "SUB:F": 17,
+     "SUB:G": 18,
+     "SUB:HH": 19,
+     "SUB:IH": 20,
+     "SUB:IY": 21,
+     "SUB:JH": 22,
+     "SUB:K": 23,
+     "SUB:L": 24,
+     "SUB:M": 25,
+     "SUB:N": 26,
+     "SUB:NG": 27,
+     "SUB:OW": 28,
+     "SUB:OY": 29,
+     "SUB:P": 30,
+     "SUB:R": 31,
+     "SUB:S": 32,
+     "SUB:SH": 33,
+     "SUB:T": 34,
+     "SUB:TH": 35,
+     "SUB:UH": 36,
+     "SUB:UW": 37,
+     "SUB:V": 38,
+     "SUB:W": 39,
+     "SUB:Y": 40,
+     "SUB:Z": 41,
+     "SUB:ZH": 42
+   },
+   "insert_to_id": {
+     "<NONE>": 0,
+     "<PAD>": 1,
+     "AA": 2,
+     "AE": 3,
+     "AH": 4,
+     "AO": 5,
+     "AW": 6,
+     "AY": 7,
+     "B": 8,
+     "CH": 9,
+     "D": 10,
+     "DH": 11,
+     "DX": 12,
+     "EH": 13,
+     "ER": 14,
+     "EY": 15,
+     "F": 16,
+     "G": 17,
+     "HH": 18,
+     "IH": 19,
+     "IY": 20,
+     "JH": 21,
+     "K": 22,
+     "L": 23,
+     "M": 24,
+     "N": 25,
+     "NG": 26,
+     "OW": 27,
+     "OY": 28,
+     "P": 29,
+     "R": 30,
+     "S": 31,
+     "SH": 32,
+     "T": 33,
+     "TH": 34,
+     "UH": 35,
+     "UW": 36,
+     "V": 37,
+     "W": 38,
+     "Y": 39,
+     "Z": 40,
+     "ZH": 41
+   },
+   "stats": {
+     "num_ops": 43,
+     "num_inserts": 42
+   }
+ }
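
The vocab above fixes the label spaces for the two prediction heads: op_to_id packs KEEP (0), DEL (1), and one SUB:* entry per phoneme (43 ops total), while insert_to_id covers the insertion head with <NONE>/<PAD> plus the same 40 ARPAbet phonemes (42 labels). A minimal sketch of turning one predicted (op_id, ins_id) pair back into phonemes, assuming vocab.json is read from this path; the apply_edit helper is illustrative, not part of the release:

import json

with open("edit_seq_speech/config/vocab.json") as f:
    vocab = json.load(f)

id2op = {v: k for k, v in vocab["op_to_id"].items()}
id2ins = {v: k for k, v in vocab["insert_to_id"].items()}

def apply_edit(src_phoneme, op_id, ins_id):
    """Illustrative helper: apply one (op, insert) prediction to a source phoneme."""
    out = []
    op = id2op[op_id]
    if op == "KEEP":
        out.append(src_phoneme)
    elif op.startswith("SUB:"):
        out.append(op.split(":", 1)[1])  # "SUB:AA" -> "AA"
    # "DEL" emits nothing
    ins = id2ins[ins_id]
    if ins not in ("<NONE>", "<PAD>"):
        out.append(ins)  # insertion goes after the edited position
    return out

# e.g. source "T" with op SUB:D and insert AH -> ["D", "AH"]
print(apply_edit("T", vocab["op_to_id"]["SUB:D"], vocab["insert_to_id"]["AH"]))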
edit_seq_speech/inference.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import torch.nn as nn
+ import torchaudio
+ import json
+ import re
+ import os
+ from g2p_en import G2p
+ import pytorch_lightning as pl
+
+ from .model import PhonemeCorrector
+ from transformers import Wav2Vec2Processor, HubertModel
+
+ class PhonemeCorrectionInference:
+     def __init__(self, checkpoint_path, vocab_path, audio_model_name="facebook/hubert-large-ls960-ft", device=None):
+         self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # 1. Load Vocab / Config
+         print(f"Loading config from {vocab_path}...")
+         with open(vocab_path, 'r') as f:
+             self.config = json.load(f)
+
+         self.op_map = self.config['op_to_id']
+         self.ins_map = self.config['insert_to_id']
+
+         # Create Reverse Maps (ID -> String)
+         self.id2op = {v: k for k, v in self.op_map.items()}
+         self.id2ins = {v: k for k, v in self.ins_map.items()}
+
+         # 2. Load G2P
+         self.g2p = G2p()
+
+         # 3. Load Model
+         print(f"Loading model from {checkpoint_path}...")
+         if os.path.exists(checkpoint_path):
+             checkpoint = torch.load(checkpoint_path, map_location=self.device)
+             hparams = checkpoint.get('hyper_parameters', {})
+
+             vocab_size = max(self.ins_map.values()) + 1
+             audio_vocab_size = hparams.get('audio_vocab_size', 2048)
+
+             self.model = PhonemeCorrector.load_from_checkpoint(
+                 checkpoint_path,
+                 map_location=self.device,
+                 vocab_size=vocab_size,
+                 audio_vocab_size=audio_vocab_size
+             )
+         else:
+             raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
+
+         self.model.to(self.device)
+         self.model.eval()
+
+         # 4. Load Audio Tokenizer
+         print(f"Loading Audio Tokenizer: {audio_model_name}")
+         self.audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_name)
+         self.audio_model = HubertModel.from_pretrained(audio_model_name).eval().to(self.device)
+
+     def _clean_phn(self, phn_list):
+         """Standard cleaning to match training: strip stress digits, drop non-phonemes."""
+         IGNORED = {"SIL", "'", "SPN", " "}
+         return [p.rstrip('012') for p in phn_list if p.rstrip('012') not in IGNORED]
+
+     def _get_audio_tokens(self, wav_path):
+         """
+         Runs the audio tokenizer.
+         IMPORTANT: This must match your training data generation logic.
+         """
+         waveform, sr = torchaudio.load(wav_path)
+         if sr != 16000:
+             resampler = torchaudio.transforms.Resample(sr, 16000)
+             waveform = resampler(waveform)
+
+         inputs = self.audio_processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
+         input_values = inputs.input_values.to(self.device)
+
+         with torch.no_grad():
+             outputs = self.audio_model(input_values)
+
+         # Placeholder Quantization (Argmax) - Replace if using K-Means
+         features = outputs.last_hidden_state
+         tokens = torch.argmax(features, dim=-1).squeeze()
+
+         # Downsample to 25Hz (Assuming model is 50Hz)
+         tokens = tokens[::2]
+         return tokens.unsqueeze(0)  # (1, T)
+
+     def predict(self, wav_path, text):
+         # A. Prepare Inputs
+         # 1. Text -> Phonemes -> IDs
+         # raw_phns = self.g2p(text)
+         raw_phns = text.split()  # Assuming input text is already phonemized for inference
+         src_phns = self._clean_phn(raw_phns)
+
+         # Create text vocab from insert_to_id (same as dataset)
+         text_vocab = {k: v for k, v in self.ins_map.items() if k not in ['<NONE>', '<PAD>']}
+         text_ids = [text_vocab.get(p, text_vocab.get("AA", 2)) for p in src_phns]
+         text_tensor = torch.tensor([text_ids], dtype=torch.long).to(self.device)
+
+         # 2. Audio -> Tokens
+         audio_tensor = self._get_audio_tokens(wav_path)
+
+         # B. Run Model
+         with torch.no_grad():
+             # Create masks
+             txt_mask = torch.ones_like(text_tensor)
+             aud_mask = torch.ones_like(audio_tensor)
+
+             logits_op, logits_ins = self.model(
+                 text_tensor, audio_tensor, txt_mask, aud_mask
+             )
+
+         # C. Decode
+         pred_ops = torch.argmax(logits_op, dim=-1).squeeze().tolist()
+         pred_ins = torch.argmax(logits_ins, dim=-1).squeeze().tolist()
+
+         # Ensure lists (squeeze collapses length-1 sequences to scalars)
+         if not isinstance(pred_ops, list): pred_ops = [pred_ops]
+         if not isinstance(pred_ins, list): pred_ins = [pred_ins]
+
+         # D. Reconstruct Sequence
+         final_phonemes = []
+         log = []
+
+         for i, (orig, op_id, ins_id) in enumerate(zip(src_phns, pred_ops, pred_ins)):
+
+             # 1. Apply Operation
+             op_str = self.id2op.get(op_id, "KEEP")
+             curr_log = {"src": orig, "op": op_str, "ins": "NONE"}
+
+             if op_str == "KEEP":
+                 final_phonemes.append(orig)
+             elif op_str == "DEL":
+                 pass  # Do not append
+             elif op_str.startswith("SUB:"):
+                 # Extract phoneme: "SUB:AA" -> "AA"
+                 new_phn = op_str.split(":")[1]
+                 final_phonemes.append(new_phn)
+
+             # 2. Apply Insertion (never emit <PAD> into the output)
+             ins_str = self.id2ins.get(ins_id, "<NONE>")
+             if ins_str not in ("<NONE>", "<PAD>"):
+                 final_phonemes.append(ins_str)
+                 curr_log["ins"] = ins_str
+
+             log.append(curr_log)
+
+         return final_phonemes, log
+
+ if __name__ == "__main__":
+     ckpt_path = "/data/chenxu/checkpoints/edit_seq_speech/phoneme-corrector/last.ckpt"
+     vocab_path = "edit_seq_speech/config/vocab.json"
+     wav_file = "test.wav"
+     text_input = "L AE S T S AH N D EY"  # "Last Sunday", pre-phonemized; predict() expects phoneme strings
+
+     if os.path.exists(ckpt_path) and os.path.exists(wav_file):
+         infer = PhonemeCorrectionInference(ckpt_path, vocab_path)
+         result, details = infer.predict(wav_file, text_input)
+
+         print(f"Input Text: {text_input}")
+         print(f"Result Phn: {result}")
+         print("-" * 20)
+         for step in details:
+             print(f"{step['src']} -> {step['op']} + Insert({step['ins']})")
+     else:
+         print("Please set valid paths for checkpoint and wav file.")
edit_seq_speech/model.py ADDED
@@ -0,0 +1,330 @@
+ import torch
+ import torch.nn as nn
+ import math
+ import pytorch_lightning as pl
+
+ class PhonemeCorrector(pl.LightningModule):
+     def __init__(self, vocab_size, audio_vocab_size, d_model=256, nhead=4, num_layers=4, dropout=0.1, lr=1e-4,
+                  weight_decay=0.01, scheduler_config=None, optimizer_config=None):
+         super().__init__()
+         self.save_hyperparameters()
+         self.scheduler_config = scheduler_config or {}
+         self.optimizer_config = optimizer_config or {}
+
+         # 1. Embeddings
+         self.text_embedding = nn.Embedding(vocab_size, d_model)
+         self.audio_embedding = nn.Embedding(audio_vocab_size, d_model)
+
+         # Positional Encoding (Standard Sinusoidal)
+         self.pos_encoder = PositionalEncoding(d_model, dropout)
+
+         # 2. The Core Transformer (Text querying Audio)
+         decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
+         self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+         # 3. Prediction Heads - 2-head architecture
+         # Head 1: Operation (KEEP, DEL, SUB:AA, SUB:AE, ...)
+         # num_ops = vocab_size + 2 (KEEP=0, DEL=1, SUB:phonemes=2+)
+         # This matches the precomputed op_ids format
+         num_ops = vocab_size + 2
+         self.head_op = nn.Linear(d_model, num_ops)
+
+         # Head 2: Insertion (NONE=0, AA, AE, ...)
+         # num_inserts = vocab_size (NONE=0, then phonemes)
+         num_inserts = vocab_size
+         self.head_ins = nn.Linear(d_model, num_inserts)
+
+     def forward(self, text_ids, audio_ids, text_mask=None, audio_mask=None):
+         """
+         text_ids: (Batch, Text_Len)
+         audio_ids: (Batch, Audio_Len)
+         masks: (Batch, Len) - 1 for valid, 0 for pad.
+         """
+         text_emb = self.pos_encoder(self.text_embedding(text_ids))
+         audio_emb = self.pos_encoder(self.audio_embedding(audio_ids))
+
+         txt_pad_mask = (text_mask == 0) if text_mask is not None else None
+         aud_pad_mask = (audio_mask == 0) if audio_mask is not None else None
+
+         encoded_features = self.transformer(
+             tgt=text_emb,
+             memory=audio_emb,
+             tgt_key_padding_mask=txt_pad_mask,
+             memory_key_padding_mask=aud_pad_mask
+         )
+
+         logits_op = self.head_op(encoded_features)
+         logits_ins = self.head_ins(encoded_features)
+
+         return logits_op, logits_ins
+
+     def training_step(self, batch, batch_idx):
+         input_ids = batch['input_ids']
+         audio_tokens = batch['audio_tokens']
+         lbl_op = batch['labels']['op']
+         lbl_ins = batch['labels']['ins']
+         txt_mask = batch['masks']['text']
+         audio_mask = batch['masks']['audio']
+
+         logits_op, logits_ins = self(input_ids, audio_tokens, txt_mask, audio_mask)
+
+         # Active loss mask (only compute loss on valid text tokens)
+         active_loss = txt_mask.view(-1) == 1
+
+         # OP LOSS (includes KEEP, DEL, and all SUB:phoneme operations)
+         num_ops = self.hparams.vocab_size + 2
+         loss_op = nn.functional.cross_entropy(
+             logits_op.view(-1, num_ops)[active_loss],
+             lbl_op.view(-1)[active_loss]
+         )
+
+         # INS LOSS
+         loss_ins = nn.functional.cross_entropy(
+             logits_ins.view(-1, self.hparams.vocab_size)[active_loss],
+             lbl_ins.view(-1)[active_loss]
+         )
+
+         loss = loss_op + loss_ins
+         self.log('train_loss', loss, prog_bar=True)
+         self.log('train_loss_op', loss_op)
+         self.log('train_loss_ins', loss_ins)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         input_ids = batch['input_ids']
+         audio_tokens = batch['audio_tokens']
+         lbl_op = batch['labels']['op']
+         lbl_ins = batch['labels']['ins']
+         txt_mask = batch['masks']['text']
+         audio_mask = batch['masks']['audio']
+
+         logits_op, logits_ins = self(input_ids, audio_tokens, txt_mask, audio_mask)
+
+         # Compute losses
+         active_loss = txt_mask.view(-1) == 1
+         num_ops = self.hparams.vocab_size + 2
+
+         loss_op = nn.functional.cross_entropy(
+             logits_op.view(-1, num_ops)[active_loss],
+             lbl_op.view(-1)[active_loss]
+         )
+
+         loss_ins = nn.functional.cross_entropy(
+             logits_ins.view(-1, self.hparams.vocab_size)[active_loss],
+             lbl_ins.view(-1)[active_loss]
+         )
+
+         loss = loss_op + loss_ins
+
+         # Compute accuracy
+         pred_op = torch.argmax(logits_op, dim=-1)
+         pred_ins = torch.argmax(logits_ins, dim=-1)
+
+         # OP accuracy (.bool() keeps the bitwise-and valid for integer masks)
+         op_correct = (pred_op == lbl_op) & txt_mask.bool()
+         op_acc = op_correct.sum().float() / txt_mask.sum().float()
+
+         # INS accuracy
+         ins_correct = (pred_ins == lbl_ins) & txt_mask.bool()
+         ins_acc = ins_correct.sum().float() / txt_mask.sum().float()
+
+         # Overall accuracy: correct OP prediction
+         overall_acc = op_acc
+
+         # Per-operation accuracy (KEEP=0, DEL=1, SUB>=2)
+         keep_mask = (lbl_op == 0) & txt_mask.bool()
+         del_mask = (lbl_op == 1) & txt_mask.bool()
+         sub_op_mask = (lbl_op >= 2) & txt_mask.bool()
+
+         keep_acc = torch.tensor(0.0, device=loss.device)
+         del_acc = torch.tensor(0.0, device=loss.device)
+         sub_op_acc = torch.tensor(0.0, device=loss.device)
+
+         if keep_mask.sum() > 0:
+             keep_correct = (pred_op == lbl_op) & keep_mask
+             keep_acc = keep_correct.sum().float() / keep_mask.sum().float()
+
+         if del_mask.sum() > 0:
+             del_correct = (pred_op == lbl_op) & del_mask
+             del_acc = del_correct.sum().float() / del_mask.sum().float()
+
+         if sub_op_mask.sum() > 0:
+             sub_op_correct = (pred_op == lbl_op) & sub_op_mask
+             sub_op_acc = sub_op_correct.sum().float() / sub_op_mask.sum().float()
+
+         # Log metrics
+         self.log('val_loss', loss, prog_bar=True, sync_dist=True)
+         self.log('val_loss_op', loss_op, sync_dist=True)
+         self.log('val_loss_ins', loss_ins, sync_dist=True)
+         self.log('val_acc', overall_acc, prog_bar=True, sync_dist=True)
+         self.log('val_acc_op', op_acc, sync_dist=True)
+         self.log('val_acc_ins', ins_acc, sync_dist=True)
+         self.log('val_acc_keep', keep_acc, sync_dist=True)
+         self.log('val_acc_del', del_acc, sync_dist=True)
+         self.log('val_acc_sub_op', sub_op_acc, sync_dist=True)
+
+         return {
+             'val_loss': loss,
+             'val_acc': overall_acc,
+             'val_acc_op': op_acc,
+             'val_acc_ins': ins_acc
+         }
+
+     def configure_optimizers(self):
+         # Get optimizer configuration
+         optimizer_name = self.optimizer_config.get("name", "adamw").lower()
+         lr = self.hparams.lr
+         weight_decay = getattr(self.hparams, 'weight_decay', 0.01)
+
+         if optimizer_name == "adamw":
+             optimizer = torch.optim.AdamW(
+                 self.parameters(),
+                 lr=lr,
+                 weight_decay=weight_decay,
+                 betas=self.optimizer_config.get("betas", [0.9, 0.999]),
+                 eps=self.optimizer_config.get("eps", 1.0e-8)
+             )
+         elif optimizer_name == "adam":
+             optimizer = torch.optim.Adam(
+                 self.parameters(),
+                 lr=lr,
+                 weight_decay=weight_decay,
+                 betas=self.optimizer_config.get("betas", [0.9, 0.999]),
+                 eps=self.optimizer_config.get("eps", 1.0e-8)
+             )
+         else:
+             raise ValueError(f"Unknown optimizer: {optimizer_name}")
+
+         # Configure scheduler
+         scheduler_type = self.scheduler_config.get("type", "cosine").lower()
+
+         # Calculate total training steps
+         max_epochs = getattr(self.trainer, 'max_epochs', 50)
+         if self.trainer and hasattr(self.trainer, 'estimated_stepping_batches'):
+             total_steps = self.trainer.estimated_stepping_batches
+         else:
+             # Fallback: estimate steps per epoch
+             estimated_steps_per_epoch = 1000  # Conservative estimate
+             total_steps = max_epochs * estimated_steps_per_epoch
+
+         warmup_ratio = self.scheduler_config.get("warmup_ratio", 0.1)
+         warmup_steps = max(1, int(total_steps * warmup_ratio))
+
+         if scheduler_type == "cosine":
+             # Use transformers' cosine scheduler with warmup
+             eta_min = self.scheduler_config.get("eta_min", 1.0e-6)
+             try:
+                 from transformers import get_cosine_schedule_with_warmup
+                 scheduler = get_cosine_schedule_with_warmup(
+                     optimizer,
+                     num_warmup_steps=warmup_steps,
+                     num_training_steps=total_steps,
+                     num_cycles=0.5,  # Default cosine cycles
+                     last_epoch=-1
+                 )
+             except ImportError:
+                 # Fallback to PyTorch implementation
+                 from torch.optim.lr_scheduler import LambdaLR
+                 def lr_lambda(step):
+                     if step < warmup_steps:
+                         return step / warmup_steps
+                     else:
+                         # Cosine annealing after warmup
+                         progress = (step - warmup_steps) / (total_steps - warmup_steps)
+                         cosine_value = 0.5 * (1 + math.cos(math.pi * progress))
+                         return eta_min / lr + (1 - eta_min / lr) * cosine_value
+                 scheduler = LambdaLR(optimizer, lr_lambda)
+
+         elif scheduler_type == "linear":
+             # Use transformers' linear scheduler with warmup
+             try:
+                 from transformers import get_linear_schedule_with_warmup
+                 scheduler = get_linear_schedule_with_warmup(
+                     optimizer,
+                     num_warmup_steps=warmup_steps,
+                     num_training_steps=total_steps
+                 )
+             except ImportError:
+                 # Fallback to PyTorch implementation
+                 from torch.optim.lr_scheduler import LambdaLR
+                 def lr_lambda(step):
+                     if step < warmup_steps:
+                         return step / warmup_steps
+                     else:
+                         progress = (step - warmup_steps) / (total_steps - warmup_steps)
+                         return max(0.0, 1.0 - progress)
+                 scheduler = LambdaLR(optimizer, lr_lambda)
+
+         elif scheduler_type == "polynomial":
+             # Use transformers' polynomial scheduler with warmup
+             power = self.scheduler_config.get("power", 1.0)  # defined before try so the fallback can use it
+             try:
+                 from transformers import get_polynomial_decay_schedule_with_warmup
+                 scheduler = get_polynomial_decay_schedule_with_warmup(
+                     optimizer,
+                     num_warmup_steps=warmup_steps,
+                     num_training_steps=total_steps,
+                     power=power
+                 )
+             except ImportError:
+                 # Fallback: polynomial decay in pure PyTorch
+                 from torch.optim.lr_scheduler import LambdaLR
+                 def lr_lambda(step):
+                     if step < warmup_steps:
+                         return step / warmup_steps
+                     else:
+                         progress = (step - warmup_steps) / (total_steps - warmup_steps)
+                         return max(0.0, (1.0 - progress) ** power)
+                 scheduler = LambdaLR(optimizer, lr_lambda)
+
+         elif scheduler_type == "reduce_on_plateau":
+             from torch.optim.lr_scheduler import ReduceLROnPlateau
+             scheduler = ReduceLROnPlateau(
+                 optimizer,
+                 mode='min',
+                 factor=self.scheduler_config.get("factor", 0.5),
+                 patience=self.scheduler_config.get("patience", 3),
+                 min_lr=self.scheduler_config.get("min_lr", 1.0e-6),
+                 verbose=True
+             )
+             return {
+                 "optimizer": optimizer,
+                 "lr_scheduler": {
+                     "scheduler": scheduler,
+                     "monitor": "train_loss",
+                     "interval": "epoch",
+                     "frequency": 1,
+                 }
+             }
+         else:
+             # No scheduler
+             return optimizer
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {
+                 "scheduler": scheduler,
+                 "interval": "step",
+                 "frequency": 1,
+             }
+         }
+
+ # Helper for Positional Encoding
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super().__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer('pe', pe.unsqueeze(0))
+
+     def forward(self, x):
+         # x: (Batch, Seq, Dim)
+         x = x + self.pe[:, :x.size(1)]
+         return self.dropout(x)
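
A minimal shape check for the corrector, assuming the vocab above: vocab_size = 42 (the insertion label count derived in inference.py), so head_op allocates vocab_size + 2 = 44 logits, one more than the 43 ops listed in vocab.json, with id 43 simply unused; audio_vocab_size = 2048 matches the default in inference.py. The random ids below stand in for real phoneme and audio tokens:

import torch
from edit_seq_speech.model import PhonemeCorrector

model = PhonemeCorrector(vocab_size=42, audio_vocab_size=2048).eval()

text_ids = torch.randint(2, 42, (1, 10))     # (Batch, Text_Len) phoneme ids 2..41, as in insert_to_id
audio_ids = torch.randint(0, 2048, (1, 50))  # (Batch, Audio_Len) audio token ids
txt_mask = torch.ones_like(text_ids)         # 1 = valid, 0 = pad
aud_mask = torch.ones_like(audio_ids)

with torch.no_grad():
    logits_op, logits_ins = model(text_ids, audio_ids, txt_mask, aud_mask)

print(logits_op.shape)   # torch.Size([1, 10, 44]) -> vocab_size + 2 op logits per text position
print(logits_ins.shape)  # torch.Size([1, 10, 42]) -> vocab_size insertion logits per text position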