SoyeonHH committed
Commit d449979 · verified · 1 Parent(s): 9c197c4

Upload train_deberta_multimodal.py with huggingface_hub

Files changed (1):
  1. train_deberta_multimodal.py +556 -0
train_deberta_multimodal.py ADDED
@@ -0,0 +1,556 @@
"""
DeBERTa-v3-Large based Multimodal Sentiment Analysis
Uses raw text with DeBERTa encoder + audio/video features
"""

# Keep transformers from importing TensorFlow; these must be set before the
# transformers import below.
import os
os.environ['USE_TF'] = '0'
os.environ['TRANSFORMERS_NO_TF'] = '1'

import argparse
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_cosine_schedule_with_warmup
from tqdm import tqdm
from sklearn.metrics import f1_score
import warnings
warnings.filterwarnings('ignore')


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


class MOSEIDataset(Dataset):
    """Dataset with raw text for DeBERTa encoding"""

    def __init__(self, data, tokenizer, max_length=128):
        self.raw_text = data['raw_text']
        self.audio = torch.tensor(data['audio'], dtype=torch.float32)
        self.video = torch.tensor(data['vision'], dtype=torch.float32)
        self.labels = torch.tensor(data['regression_labels'], dtype=torch.float32)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = str(self.raw_text[idx])

        # Tokenize text
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'audio': self.audio[idx],
            'video': self.video[idx],
            'label': self.labels[idx]
        }
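
# Expected pickle layout, inferred from the fields this dataset accesses (a
# sketch; N and T are placeholders, and the feature dims just echo the model
# defaults below):
# {
#     'train': {'raw_text': list of str,
#               'audio': float array (N, T, 74),
#               'vision': float array (N, T, 35),
#               'regression_labels': float array (N,) in [-3, 3]},
#     'valid': {...},
#     'test': {...},
# }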


class DeBERTaMultimodalModel(nn.Module):
    """
    DeBERTa-v3-Large + Audio/Video Fusion
    """

    def __init__(
        self,
        model_name='microsoft/deberta-v3-large',
        audio_dim=74,
        video_dim=35,
        hidden_size=512,
        num_heads=8,
        num_classes=7,
        dropout=0.2,
        freeze_deberta_layers=20  # Freeze first N layers
    ):
        super().__init__()

        # DeBERTa encoder
        self.deberta = AutoModel.from_pretrained(model_name)
        self.text_dim = self.deberta.config.hidden_size  # 1024 for large

        # Freeze some layers
        if freeze_deberta_layers > 0:
            for param in self.deberta.embeddings.parameters():
                param.requires_grad = False
            for i, layer in enumerate(self.deberta.encoder.layer):
                if i < freeze_deberta_layers:
                    for param in layer.parameters():
                        param.requires_grad = False

        # Audio encoder (temporal)
        self.audio_encoder = nn.Sequential(
            nn.Linear(audio_dim, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.audio_temporal = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=hidden_size * 4,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ),
            num_layers=2
        )

        # Video encoder (temporal)
        self.video_encoder = nn.Sequential(
            nn.Linear(video_dim, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.video_temporal = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=hidden_size * 4,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ),
            num_layers=2
        )

        # Project text to hidden_size
        self.text_proj = nn.Sequential(
            nn.Linear(self.text_dim, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Cross-modal attention
        self.text_to_audio_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.text_to_video_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.audio_to_text_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )
        self.video_to_text_attn = nn.MultiheadAttention(
            hidden_size, num_heads, dropout=dropout, batch_first=True
        )

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(hidden_size * 6, hidden_size * 2),  # 6 features: t, a, v, t2a, t2v, multimodal
            nn.LayerNorm(hidden_size * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Main classifier
        self.classifier = nn.Linear(hidden_size, num_classes)

        # Auxiliary classifiers
        self.text_classifier = nn.Linear(hidden_size, num_classes)
        self.audio_classifier = nn.Linear(hidden_size, num_classes)
        self.video_classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, audio, video):
        # Text encoding with DeBERTa
        text_output = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        text_hidden = text_output.last_hidden_state  # (B, seq_len, 1024)

        # Project text; position 0 is the [CLS] token
        text_proj = self.text_proj(text_hidden)  # (B, seq_len, hidden)
        text_cls_proj = text_proj[:, 0]  # (B, hidden)

        # Audio encoding
        audio_hidden = self.audio_encoder(audio)  # (B, 500, hidden)
        audio_hidden = self.audio_temporal(audio_hidden)
        audio_pooled = audio_hidden.mean(dim=1)  # (B, hidden)

        # Video encoding
        video_hidden = self.video_encoder(video)  # (B, 500, hidden)
        video_hidden = self.video_temporal(video_hidden)
        video_pooled = video_hidden.mean(dim=1)  # (B, hidden)

        # Cross-modal attention
        # Text attends to audio/video
        text_to_audio, _ = self.text_to_audio_attn(
            text_proj, audio_hidden, audio_hidden
        )
        text_to_video, _ = self.text_to_video_attn(
            text_proj, video_hidden, video_hidden
        )
        text_to_audio_pooled = text_to_audio[:, 0]  # (B, hidden)
        text_to_video_pooled = text_to_video[:, 0]  # (B, hidden)

        # Audio/Video attend to text; padded text positions are masked as keys
        audio_to_text, _ = self.audio_to_text_attn(
            audio_hidden, text_proj, text_proj,
            key_padding_mask=(attention_mask == 0)
        )
        video_to_text, _ = self.video_to_text_attn(
            video_hidden, text_proj, text_proj,
            key_padding_mask=(attention_mask == 0)
        )

        # Multimodal representation
        multimodal = (audio_to_text.mean(dim=1) + video_to_text.mean(dim=1)) / 2

        # Fusion
        fused = torch.cat([
            text_cls_proj,
            audio_pooled,
            video_pooled,
            text_to_audio_pooled,
            text_to_video_pooled,
            multimodal
        ], dim=-1)

        fused = self.fusion(fused)

        # Classification
        logits = self.classifier(fused)
        text_logits = self.text_classifier(text_cls_proj)
        audio_logits = self.audio_classifier(audio_pooled)
        video_logits = self.video_classifier(video_pooled)

        return logits, text_logits, audio_logits, video_logits
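
# Quick shape check (a sketch, left as comments so nothing runs at import time;
# building the model downloads microsoft/deberta-v3-large on first use, and
# T=500 merely mirrors the shape comments above since any length works):
#   model = DeBERTaMultimodalModel()
#   ids = torch.randint(0, 1000, (2, 128))
#   mask = torch.ones(2, 128, dtype=torch.long)
#   audio, video = torch.randn(2, 500, 74), torch.randn(2, 500, 35)
#   logits, t_log, a_log, v_log = model(ids, mask, audio, video)  # each (2, 7)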


def regression_to_class(pred, num_classes=7):
    """Convert regression prediction to class (0-6)"""
    pred = torch.clamp(pred, -3, 3)
    # Map [-3, 3] to [0, 6]
    return torch.round(pred + 3).long().clamp(0, num_classes - 1)
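
# Worked example for regression_to_class: a prediction of -1.4 passes the
# clamp unchanged, shifts to 1.6, and rounds to class 2; the endpoints -3 and
# +3 land on classes 0 and 6.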


def compute_metrics(preds, labels, num_classes=7):
    """Compute evaluation metrics"""
    # Convert to numpy
    preds = preds.cpu().numpy() if torch.is_tensor(preds) else preds
    labels = labels.cpu().numpy() if torch.is_tensor(labels) else labels

    # Binary accuracy (positive/negative)
    has0_pred = (preds >= 0).astype(int)
    has0_label = (labels >= 0).astype(int)
    has0_acc = (has0_pred == has0_label).mean()
    has0_f1 = f1_score(has0_label, has0_pred, average='weighted')

    # Non-zero binary
    non0_mask = labels != 0
    if non0_mask.sum() > 0:
        non0_pred = (preds[non0_mask] > 0).astype(int)
        non0_label = (labels[non0_mask] > 0).astype(int)
        non0_acc = (non0_pred == non0_label).mean()
        non0_f1 = f1_score(non0_label, non0_pred, average='weighted')
    else:
        non0_acc = 0.0
        non0_f1 = 0.0

    # Multi-class accuracy (5 classes: map to 0-4)
    pred_5 = np.clip(np.round(preds + 2), 0, 4).astype(int)
    label_5 = np.clip(np.round(labels + 2), 0, 4).astype(int)
    mult_acc_5 = (pred_5 == label_5).mean()

    # Multi-class accuracy (7 classes: map to 0-6)
    pred_7 = np.clip(np.round(preds + 3), 0, 6).astype(int)
    label_7 = np.clip(np.round(labels + 3), 0, 6).astype(int)
    mult_acc_7 = (pred_7 == label_7).mean()

    # MAE and Correlation
    mae = np.abs(preds - labels).mean()
    corr = np.corrcoef(preds, labels)[0, 1] if len(preds) > 1 else 0.0

    return {
        'Has0_acc_2': has0_acc,
        'Has0_F1_score': has0_f1,
        'Non0_acc_2': non0_acc,
        'Non0_F1_score': non0_f1,
        'Mult_acc_5': mult_acc_5,
        'Mult_acc_7': mult_acc_7,
        'MAE': mae,
        'Corr': corr
    }
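
# A note on the conventions above (these describe exactly what compute_metrics
# returns; the naming follows MMSA-style CMU-MOSEI evaluation, which is an
# assumption about provenance): Has0_acc_2 counts label >= 0 as positive, so
# exact zeros are positive; Non0_acc_2 drops exactly-zero labels and scores
# only sign agreement on the rest.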


def train_epoch(model, loader, optimizer, scheduler, device,
                cls_weight=0.7, aux_weight=0.1, mixup_prob=0.5, mixup_alpha=0.4):
    model.train()
    total_loss = 0

    for batch in tqdm(loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        audio = batch['audio'].to(device)
        video = batch['video'].to(device)
        labels = batch['label'].to(device)

        # Convert to class labels
        class_labels = regression_to_class(labels)

        # Mixup
        if random.random() < mixup_prob:
            lam = np.random.beta(mixup_alpha, mixup_alpha)
            idx = torch.randperm(input_ids.size(0))

            # Mixup audio and video (can't mixup text easily)
            audio = lam * audio + (1 - lam) * audio[idx]
            video = lam * video + (1 - lam) * video[idx]

            # Forward
            logits, text_logits, audio_logits, video_logits = model(
                input_ids, attention_mask, audio, video
            )

            # Mixup loss
            loss_main = lam * F.cross_entropy(logits, class_labels) + \
                        (1 - lam) * F.cross_entropy(logits, class_labels[idx])
            loss_text = F.cross_entropy(text_logits, class_labels)  # Text not mixed
            loss_audio = lam * F.cross_entropy(audio_logits, class_labels) + \
                         (1 - lam) * F.cross_entropy(audio_logits, class_labels[idx])
            loss_video = lam * F.cross_entropy(video_logits, class_labels) + \
                         (1 - lam) * F.cross_entropy(video_logits, class_labels[idx])
        else:
            # Forward
            logits, text_logits, audio_logits, video_logits = model(
                input_ids, attention_mask, audio, video
            )

            loss_main = F.cross_entropy(logits, class_labels)
            loss_text = F.cross_entropy(text_logits, class_labels)
            loss_audio = F.cross_entropy(audio_logits, class_labels)
            loss_video = F.cross_entropy(video_logits, class_labels)

        # Total loss
        loss = cls_weight * loss_main + \
               aux_weight * (loss_text + loss_audio + loss_video)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    return total_loss / len(loader)
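
# Mixup recap (the standard formulation of Zhang et al., 2018, as applied in
# train_epoch): with lam ~ Beta(alpha, alpha), inputs mix as
# x_mix = lam * x + (1 - lam) * x[idx], and the loss mixes the same way:
# lam * CE(f(x_mix), y) + (1 - lam) * CE(f(x_mix), y[idx]).
# Only audio/video are mixed above; text stays intact, so the text auxiliary
# loss keeps the original labels.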


@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0

    for batch in tqdm(loader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        audio = batch['audio'].to(device)
        video = batch['video'].to(device)
        labels = batch['label'].to(device)

        logits, _, _, _ = model(input_ids, attention_mask, audio, video)

        # Convert logits to regression predictions (argmax over logits equals
        # argmax over softmax probabilities)
        class_preds = torch.argmax(logits, dim=-1)
        reg_preds = class_preds.float() - 3  # Map [0, 6] back to [-3, 3]

        # Loss
        class_labels = regression_to_class(labels)
        loss = F.cross_entropy(logits, class_labels)
        total_loss += loss.item()

        all_preds.append(reg_preds.cpu())
        all_labels.append(labels.cpu())

    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()

    metrics = compute_metrics(preds, labels)
    metrics['loss'] = total_loss / len(loader)

    return metrics


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pkl_path', type=str, required=True)
    parser.add_argument('--model_name', type=str, default='microsoft/deberta-v3-large')
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--freeze_layers', type=int, default=20)
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--deberta_lr', type=float, default=5e-6)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--early_stop', type=int, default=15)
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--mixup_prob', type=float, default=0.5)
    parser.add_argument('--mixup_alpha', type=float, default=0.4)
    parser.add_argument('--cls_weight', type=float, default=0.7)
    parser.add_argument('--aux_weight', type=float, default=0.1)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_deberta')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    print(f"Loading data from {args.pkl_path}")
    with open(args.pkl_path, 'rb') as f:
        data = pickle.load(f)

    # Load tokenizer
    print(f"Loading tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Create datasets
    train_dataset = MOSEIDataset(data['train'], tokenizer, args.max_length)
    valid_dataset = MOSEIDataset(data['valid'], tokenizer, args.max_length)
    test_dataset = MOSEIDataset(data['test'], tokenizer, args.max_length)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=4, pin_memory=True
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size * 2, shuffle=False,
        num_workers=4, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size * 2, shuffle=False,
        num_workers=4, pin_memory=True
    )

    print(f"Train: {len(train_dataset)}, Valid: {len(valid_dataset)}, Test: {len(test_dataset)}")

    # Create model
    print(f"Creating model with hidden_size={args.hidden_size}")
    model = DeBERTaMultimodalModel(
        model_name=args.model_name,
        hidden_size=args.hidden_size,
        num_heads=args.num_heads,
        dropout=args.dropout,
        freeze_deberta_layers=args.freeze_layers
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Optimizer with different learning rates for DeBERTa vs. the new layers
    deberta_params = list(model.deberta.parameters())
    other_params = [p for n, p in model.named_parameters() if 'deberta' not in n]

    optimizer = torch.optim.AdamW([
        {'params': [p for p in deberta_params if p.requires_grad], 'lr': args.deberta_lr},
        {'params': other_params, 'lr': args.lr}
    ], weight_decay=0.01)

    # Scheduler: 10% linear warmup, then cosine decay
    total_steps = len(train_loader) * args.epochs
    warmup_steps = int(total_steps * 0.1)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, warmup_steps, total_steps
    )

    # Training
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    best_acc = 0
    patience = 0

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")

        train_loss = train_epoch(
            model, train_loader, optimizer, scheduler, device,
            cls_weight=args.cls_weight,
            aux_weight=args.aux_weight,
            mixup_prob=args.mixup_prob,
            mixup_alpha=args.mixup_alpha
        )
        print(f"Train Loss: {train_loss:.4f}")

        # Validation
        valid_metrics = evaluate(model, valid_loader, device)
        print(f"Valid Loss: {valid_metrics['loss']:.4f}")
        print(f"Mult_acc_7: {valid_metrics['Mult_acc_7']:.4f} | "
              f"Mult_acc_5: {valid_metrics['Mult_acc_5']:.4f} | "
              f"Has0_acc: {valid_metrics['Has0_acc_2']:.4f}")
        print(f"MAE: {valid_metrics['MAE']:.4f} | Corr: {valid_metrics['Corr']:.4f}")

        # Save best model
        if valid_metrics['Mult_acc_7'] > best_acc:
            best_acc = valid_metrics['Mult_acc_7']
            patience = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'args': args
            }, os.path.join(args.checkpoint_dir, 'best_model.pt'))
            print(f"*** New best model saved! Mult_acc_7: {best_acc:.4f} ***")
        else:
            patience += 1
            if patience >= args.early_stop:
                print(f"\nEarly stopping at epoch {epoch+1}")
                break

    # Load best model and evaluate on test
    print("\nLoading best model for final evaluation")
    checkpoint = torch.load(os.path.join(args.checkpoint_dir, 'best_model.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])

    print("\n" + "=" * 60)
    print("Final Test Evaluation")
    print("=" * 60)

    test_metrics = evaluate(model, test_loader, device)
    print(f"Test Loss: {test_metrics['loss']:.4f}")
    print("\nTest Metrics:")
    print("-" * 40)
    for k, v in test_metrics.items():
        if k != 'loss':
            print(f"  {k}: {v:.4f}")
    print("-" * 40)
    print(f"\n*** Final Mult_acc_7: {test_metrics['Mult_acc_7']:.4f} ***")


if __name__ == '__main__':
    main()