toilachuoituyet committed on
Commit
43532c9
·
verified ·
1 Parent(s): 1b3984b

Upload project files

Browse files
Files changed (2) hide show
  1. MASC_finetune.py +19 -2
  2. model.py +50 -5
MASC_finetune.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import argparse
3
  import os
 
4
  os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
5
  import random
6
  from transformers import BertTokenizer
@@ -14,10 +15,22 @@ import logging
14
  from accelerate import Accelerator
15
  from tqdm import tqdm
16
  from torch.optim import AdamW
17
- from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
18
  from eval_tools import *
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def set_seed(seed):
22
  torch.manual_seed(seed)
23
  torch.cuda.manual_seed_all(seed)
@@ -249,7 +262,11 @@ if __name__ == "__main__":
249
 
250
  optimizer = AdamW(params=filter(lambda p: p.requires_grad, model.parameters()),
251
  lr=args.lr, betas=(0.9, 0.98), weight_decay=0.05)
252
- scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2*1000, eta_min=5e-5)
 
 
 
 
253
  print("start training")
254
  finetune(
255
  model=model,
 
1
  import torch
2
  import argparse
3
  import os
4
+ import math
5
  os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
6
  import random
7
  from transformers import BertTokenizer
 
15
  from accelerate import Accelerator
16
  from tqdm import tqdm
17
  from torch.optim import AdamW
18
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
19
  from eval_tools import *
20
 
21
 
22
+ def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
23
+ """Cosine LR scheduler with linear warmup for stable training."""
24
+ def lr_lambda(current_step):
25
+ if current_step < num_warmup_steps:
26
+ # Linear warmup
27
+ return float(current_step) / float(max(1, num_warmup_steps))
28
+ # Cosine annealing
29
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
30
+ return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))
31
+ return LambdaLR(optimizer, lr_lambda)
32
+
33
+
34
  def set_seed(seed):
35
  torch.manual_seed(seed)
36
  torch.cuda.manual_seed_all(seed)
 
262
 
263
  optimizer = AdamW(params=filter(lambda p: p.requires_grad, model.parameters()),
264
  lr=args.lr, betas=(0.9, 0.98), weight_decay=0.05)
265
+ # Estimate total steps: ~500 optimizer steps per dataset chunk per epoch
266
+ num_training_steps = 500 * args.epoch * 6 # 6 dataset chunks
267
+ num_warmup_steps = min(500, num_training_steps // 10) # 10% warmup
268
+ scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1)
269
+ print(f"Using cosine LR with {num_warmup_steps} warmup steps, {num_training_steps} total steps")
270
  print("start training")
271
  finetune(
272
  model=model,
model.py CHANGED
@@ -44,6 +44,46 @@ class LayerNorm(nn.Module):
44
  std = x.std(-1, keepdim=True)
45
  return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def attention(query, key, mask=None, dropout=None):
48
  d_k = query.size(-1)
49
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
@@ -126,11 +166,16 @@ class DASCO(nn.Module):
126
  self.classifier = nn.Linear(self.hidden_size*2, 2)
127
  self.criterion = nn.CrossEntropyLoss()
128
  elif self.task == 'MASC':
129
- self.classifier = nn.Linear(self.hidden_size*2, 3)
130
- # Class weights for imbalanced data: POS=80.8%, NEU=5.8%, NEG=13.3%
131
- # Aggressive weights + label_smoothing for better minority class handling
132
- self.register_buffer('class_weights', torch.tensor([1.0, 18.0, 8.0]))
133
- self.criterion = nn.CrossEntropyLoss(weight=self.class_weights, label_smoothing=0.1)
 
 
 
 
 
134
 
135
  def projection(self, z: torch.Tensor) -> torch.Tensor:
136
  z = F.elu(self.fc1(z))
 
44
  std = x.std(-1, keepdim=True)
45
  return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
46
 
47
+
48
class FocalLoss(nn.Module):
    """Focal Loss for handling extreme class imbalance.

    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    - Reduces loss for well-classified samples (high p_t)
    - Focuses training on hard misclassified samples
    - alpha: class weights (inversely proportional to frequency)
    - gamma: focusing parameter (higher = more focus on hard samples)

    Args:
        alpha: optional 1-D tensor of per-class weights, indexed by target.
        gamma: focusing exponent (gamma=0 reduces to weighted cross-entropy).
        reduction: 'mean', 'sum', or anything else for per-sample losses.
        label_smoothing: smoothing applied to the cross-entropy *base* loss.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean', label_smoothing=0.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha              # class weights tensor or None
        self.gamma = gamma              # focusing parameter
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """Compute focal loss for logits ``inputs`` [N, C] and ``targets`` [N]."""
        # Base loss: per-sample CE, with label smoothing if configured.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
                                  label_smoothing=self.label_smoothing)
        # BUGFIX: p_t must be the model's probability of the *target* class.
        # exp(-ce_loss) is wrong once label_smoothing > 0, because smoothed CE
        # mixes in mass from non-target classes; derive p_t from log-softmax.
        log_probs = F.log_softmax(inputs, dim=-1)
        pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).exp()

        # Focal modulation: down-weight easy (high-p_t) samples.
        focal_weight = (1 - pt) ** self.gamma

        # Optional per-class alpha weighting.
        if self.alpha is not None:
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_weight * ce_loss
        else:
            focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
86
+
87
  def attention(query, key, mask=None, dropout=None):
88
  d_k = query.size(-1)
89
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
 
166
  self.classifier = nn.Linear(self.hidden_size*2, 2)
167
  self.criterion = nn.CrossEntropyLoss()
168
  elif self.task == 'MASC':
169
+ # Enhanced classifier with hidden layer and dropout for better generalization
170
+ self.classifier = nn.Sequential(
171
+ nn.Linear(self.hidden_size*2, 256),
172
+ nn.GELU(),
173
+ nn.Dropout(0.3),
174
+ nn.Linear(256, 3)
175
+ )
176
+ # Focal Loss for extreme class imbalance: POS=80.8%, NEU=5.8%, NEG=13.3%
177
+ self.register_buffer('class_weights', torch.tensor([1.0, 12.0, 5.0]))
178
+ self.criterion = FocalLoss(alpha=self.class_weights, gamma=2.0, label_smoothing=0.05)
179
 
180
  def projection(self, z: torch.Tensor) -> torch.Tensor:
181
  z = F.elu(self.fc1(z))