Upload project files
Browse files- MASC_finetune.py +19 -2
- model.py +50 -5
MASC_finetune.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import argparse
|
| 3 |
import os
|
|
|
|
| 4 |
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
| 5 |
import random
|
| 6 |
from transformers import BertTokenizer
|
|
@@ -14,10 +15,22 @@ import logging
|
|
| 14 |
from accelerate import Accelerator
|
| 15 |
from tqdm import tqdm
|
| 16 |
from torch.optim import AdamW
|
| 17 |
-
from torch.optim.lr_scheduler import
|
| 18 |
from eval_tools import *
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def set_seed(seed):
|
| 22 |
torch.manual_seed(seed)
|
| 23 |
torch.cuda.manual_seed_all(seed)
|
|
@@ -249,7 +262,11 @@ if __name__ == "__main__":
|
|
| 249 |
|
| 250 |
optimizer = AdamW(params=filter(lambda p: p.requires_grad, model.parameters()),
|
| 251 |
lr=args.lr, betas=(0.9, 0.98), weight_decay=0.05)
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
print("start training")
|
| 254 |
finetune(
|
| 255 |
model=model,
|
|
|
|
| 1 |
import torch
|
| 2 |
import argparse
|
| 3 |
import os
|
| 4 |
+
import math
|
| 5 |
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
| 6 |
import random
|
| 7 |
from transformers import BertTokenizer
|
|
|
|
| 15 |
from accelerate import Accelerator
|
| 16 |
from tqdm import tqdm
|
| 17 |
from torch.optim import AdamW
|
| 18 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
|
| 19 |
from eval_tools import *
|
| 20 |
|
| 21 |
|
| 22 |
+
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
    """Build a LambdaLR that warms up linearly, then anneals on a cosine curve.

    Args:
        optimizer: wrapped optimizer whose base LR is scaled by the multiplier.
        num_warmup_steps: steps over which the multiplier ramps from 0 to 1.
        num_training_steps: total steps; the cosine phase spans the remainder.
        min_lr_ratio: floor for the multiplier after warmup (the cosine value
            is clamped from below, it is not rescaled to land on this floor).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` applying the per-step multiplier.
    """
    def lr_lambda(step):
        # Warmup phase: multiplier climbs linearly from 0 toward 1.
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        # Cosine phase: decay from 1 toward 0, clamped at min_lr_ratio.
        denom = max(1, num_training_steps - num_warmup_steps)
        frac = float(step - num_warmup_steps) / float(denom)
        cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
        return max(min_lr_ratio, cosine)

    return LambdaLR(optimizer, lr_lambda)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
def set_seed(seed):
|
| 35 |
torch.manual_seed(seed)
|
| 36 |
torch.cuda.manual_seed_all(seed)
|
|
|
|
| 262 |
|
| 263 |
optimizer = AdamW(params=filter(lambda p: p.requires_grad, model.parameters()),
|
| 264 |
lr=args.lr, betas=(0.9, 0.98), weight_decay=0.05)
|
| 265 |
+
# Estimate total steps: ~500 optimizer steps per epoch, scaled by args.epoch and the 6 dataset chunks
|
| 266 |
+
num_training_steps = 500 * args.epoch * 6 # 6 dataset chunks
|
| 267 |
+
num_warmup_steps = min(500, num_training_steps // 10) # 10% warmup
|
| 268 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1)
|
| 269 |
+
print(f"Using cosine LR with {num_warmup_steps} warmup steps, {num_training_steps} total steps")
|
| 270 |
print("start training")
|
| 271 |
finetune(
|
| 272 |
model=model,
|
model.py
CHANGED
|
@@ -44,6 +44,46 @@ class LayerNorm(nn.Module):
|
|
| 44 |
std = x.std(-1, keepdim=True)
|
| 45 |
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
def attention(query, key, mask=None, dropout=None):
|
| 48 |
d_k = query.size(-1)
|
| 49 |
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
|
|
@@ -126,11 +166,16 @@ class DASCO(nn.Module):
|
|
| 126 |
self.classifier = nn.Linear(self.hidden_size*2, 2)
|
| 127 |
self.criterion = nn.CrossEntropyLoss()
|
| 128 |
elif self.task == 'MASC':
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
def projection(self, z: torch.Tensor) -> torch.Tensor:
|
| 136 |
z = F.elu(self.fc1(z))
|
|
|
|
| 44 |
std = x.std(-1, keepdim=True)
|
| 45 |
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
| 46 |
|
| 47 |
+
|
| 48 |
+
class FocalLoss(nn.Module):
    """Focal Loss for handling extreme class imbalance.

    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    - Reduces loss for well-classified samples (high p_t)
    - Focuses training on hard misclassified samples
    - alpha: class weights (inversely proportional to frequency)
    - gamma: focusing parameter (higher = more focus on hard samples;
      gamma=0 reduces to plain cross-entropy)
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean', label_smoothing=0.0):
        super(FocalLoss, self).__init__()
        # FIX: register alpha as a buffer instead of a plain attribute.
        # A plain attribute is NOT moved by parent_module.to(device) /
        # .cuda() (Module._apply only rebinds its own registered tensors),
        # so self.alpha[targets] would fail on GPU inputs. A buffer travels
        # with the module.
        if alpha is not None and not torch.is_tensor(alpha):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.register_buffer('alpha', alpha)  # class weights tensor or None
        self.gamma = gamma  # focusing parameter
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: raw logits, shape [N, C].
            targets: integer class indices, shape [N].

        Returns:
            Scalar tensor for 'mean'/'sum' reduction, otherwise the
            per-sample loss tensor of shape [N].
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none',
                                  label_smoothing=self.label_smoothing)
        # Probability of the correct class. NOTE: with label_smoothing > 0
        # this is only an approximation, since ce_loss then includes the
        # smoothing term.
        pt = torch.exp(-ce_loss)

        # Apply focal weight: (1 - pt)^gamma — down-weights easy samples.
        focal_weight = (1 - pt) ** self.gamma

        # Apply class weights if provided
        if self.alpha is not None:
            # Defensive .to(): keeps working even if the loss module was
            # never moved alongside the inputs.
            alpha_t = self.alpha.to(inputs.device)[targets]
            focal_loss = alpha_t * focal_weight * ce_loss
        else:
            focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
|
| 86 |
+
|
| 87 |
def attention(query, key, mask=None, dropout=None):
|
| 88 |
d_k = query.size(-1)
|
| 89 |
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
|
|
|
|
| 166 |
self.classifier = nn.Linear(self.hidden_size*2, 2)
|
| 167 |
self.criterion = nn.CrossEntropyLoss()
|
| 168 |
elif self.task == 'MASC':
|
| 169 |
+
# Enhanced classifier with hidden layer and dropout for better generalization
|
| 170 |
+
self.classifier = nn.Sequential(
|
| 171 |
+
nn.Linear(self.hidden_size*2, 256),
|
| 172 |
+
nn.GELU(),
|
| 173 |
+
nn.Dropout(0.3),
|
| 174 |
+
nn.Linear(256, 3)
|
| 175 |
+
)
|
| 176 |
+
# Focal Loss for extreme class imbalance: POS=80.8%, NEU=5.8%, NEG=13.3%
|
| 177 |
+
self.register_buffer('class_weights', torch.tensor([1.0, 12.0, 5.0]))
|
| 178 |
+
self.criterion = FocalLoss(alpha=self.class_weights, gamma=2.0, label_smoothing=0.05)
|
| 179 |
|
| 180 |
def projection(self, z: torch.Tensor) -> torch.Tensor:
|
| 181 |
z = F.elu(self.fc1(z))
|