fxxkingusername committed on
Commit
56037cf
·
verified ·
1 Parent(s): 222f269

Upload src/training\trainer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/training//trainer.py +414 -0
src/training//trainer.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced trainer for architectural style classification.
3
+ Includes advanced optimization techniques for better accuracy.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR, ReduceLROnPlateau
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
12
+ from pytorch_lightning.loggers import TensorBoardLogger
13
+ import numpy as np
14
+ from typing import Dict, List, Optional, Tuple, Any
15
+ import os
16
+ import json
17
+ from datetime import datetime
18
+
19
+ from .losses import HierarchicalLoss, ContrastiveLoss, StyleRelationshipLoss, FocalLoss, LabelSmoothingLoss
20
+ from .metrics import ArchitecturalMetrics
21
+ from .data_loader import EnhancedArchitecturalDataLoader
22
+
23
+
24
class EnhancedArchitecturalTrainer(pl.LightningModule):
    """Lightning module for architectural style classification.

    Wraps ``model`` and trains it with a main classification loss plus
    optional hierarchical, contrastive and style-relationship terms. All
    hyperparameters are read from ``config`` with the defaults shown in
    ``__init__``.
    """

    def __init__(self, model: nn.Module, config: Dict[str, Any]):
        """Store the model, read hyperparameters and build losses/metrics.

        Args:
            model: Network to train; it is invoked as ``model(images)``.
            config: Hyperparameter mapping. Any missing key falls back to
                the default passed to ``config.get`` below.
        """
        super().__init__()
        self.model = model
        self.config = config
        self.save_hyperparameters(ignore=['model'])

        # Enhanced configuration
        self.learning_rate = config.get('learning_rate', 1e-4)
        self.weight_decay = config.get('weight_decay', 1e-4)
        self.batch_size = config.get('batch_size', 8)
        self.num_classes = config.get('num_classes', 25)
        self.use_mixed_precision = config.get('use_mixed_precision', True)
        self.use_early_stopping = config.get('use_early_stopping', True)
        self.patience = config.get('patience', 15)
        self.gradient_clip_val = config.get('gradient_clip_val', 1.0)
        self.accumulate_grad_batches = config.get('accumulate_grad_batches', 2)

        # Enhanced loss functions
        self.use_focal_loss = config.get('use_focal_loss', True)
        self.use_label_smoothing = config.get('use_label_smoothing', True)
        self.use_contrastive_loss = config.get('use_contrastive_loss', True)

        # Initialize loss functions
        self._init_loss_functions()

        # Initialize metrics
        self.metrics = ArchitecturalMetrics(num_classes=self.num_classes)

        # Curriculum learning state (only changed by update_curriculum()).
        self.curriculum_stage = 0
        self.curriculum_classes_count = self.num_classes

        # Learning rate scheduling.
        # NOTE(review): these three values are stored but not consumed by
        # configure_optimizers() in this class.
        self.scheduler_step_size = config.get('scheduler_step_size', 10)
        self.scheduler_gamma = config.get('scheduler_gamma', 0.5)
        self.warmup_epochs = config.get('warmup_epochs', 5)

        # TensorBoard logger; the run name is timestamped at construction.
        self.tensorboard_logger = TensorBoardLogger(
            save_dir='logs',
            name=f'architectural_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
            version=None
        )

    def _init_loss_functions(self):
        """Instantiate the loss modules selected by the ``use_*`` flags.

        Focal loss takes precedence over label smoothing when both flags are
        set; plain cross-entropy is used when neither is set.
        """
        # Main classification loss
        if self.use_focal_loss:
            self.classification_loss = FocalLoss(
                alpha=1.0,
                gamma=2.0,
                num_classes=self.num_classes
            )
        elif self.use_label_smoothing:
            self.classification_loss = LabelSmoothingLoss(
                smoothing=0.1,
                num_classes=self.num_classes
            )
        else:
            self.classification_loss = nn.CrossEntropyLoss()

        # Additional loss functions. `contrastive_loss` only exists when the
        # flag is set; training_step() checks the same flag before using it.
        if self.use_contrastive_loss:
            self.contrastive_loss = ContrastiveLoss(temperature=0.07)

        # Hierarchical loss for multi-scale features
        self.hierarchical_loss = HierarchicalLoss(
            num_classes=self.num_classes,
            hierarchy_weights=[1.0, 0.5, 0.25]
        )

        # Style relationship loss
        self.style_relationship_loss = StyleRelationshipLoss(
            num_classes=self.num_classes,
            temperature=0.1
        )

    @staticmethod
    def _extract_logits(outputs: Any) -> Any:
        """Return the logits from a model output.

        Dict outputs are looked up under ``'fine_logits'`` first and
        ``'logits'`` second; any other output is returned unchanged.
        """
        if isinstance(outputs, dict):
            return outputs.get('fine_logits', outputs.get('logits'))
        return outputs

    def _log_scalar_metrics(self, prefix: str, metrics: Dict[str, Any]) -> None:
        """Log every plain ``int``/``float`` entry of ``metrics`` as ``{prefix}_{key}``."""
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                self.log(f'{prefix}_{key}', float(value), prog_bar=True)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass through the wrapped model."""
        return self.model(x)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
        """Compute the combined training loss for one batch and log it.

        The total is ``main + 0.3 * hierarchical + 0.1 * contrastive +
        0.05 * style``; the hierarchical and contrastive terms are only added
        when the model output provides ``'hierarchical_outputs'`` /
        ``'features'`` (and contrastive loss is enabled).

        Args:
            batch: ``(images, labels)`` pair.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            Dict with the individual loss terms and the total under ``'loss'``.
        """
        images, labels = batch

        # Forward pass
        outputs = self(images)

        logits = self._extract_logits(outputs)
        if isinstance(outputs, dict):
            features = outputs.get('features', None)
            hierarchical_outputs = outputs.get('hierarchical_outputs', None)
        else:
            features = None
            hierarchical_outputs = None

        # Main classification loss (same call whichever loss was selected).
        main_loss = self.classification_loss(logits, labels)

        # Accumulate out-of-place: `total_loss += ...` on a tensor would add
        # in place and overwrite `main_loss`, which shares the same object.
        total_loss = main_loss
        loss_dict = {'main_loss': main_loss}

        # Hierarchical loss
        if hierarchical_outputs is not None:
            hierarchical_loss = self.hierarchical_loss(hierarchical_outputs, labels)
            total_loss = total_loss + 0.3 * hierarchical_loss
            loss_dict['hierarchical_loss'] = hierarchical_loss

        # Contrastive loss
        if self.use_contrastive_loss and features is not None:
            contrastive_loss = self.contrastive_loss(features, labels)
            total_loss = total_loss + 0.1 * contrastive_loss
            loss_dict['contrastive_loss'] = contrastive_loss

        # Style relationship loss
        style_loss = self.style_relationship_loss(logits, labels)
        total_loss = total_loss + 0.05 * style_loss
        loss_dict['style_loss'] = style_loss

        # Metrics do not need gradients.
        with torch.no_grad():
            metrics = self.metrics.compute(logits, labels)
            self._log_scalar_metrics('train', metrics)

        # Log losses
        loss_dict['loss'] = total_loss
        for key, value in loss_dict.items():
            self.log(f'train_{key}', value, prog_bar=True)

        return loss_dict

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
        """Evaluate one validation batch with the main classification loss only.

        Returns:
            Dict with ``'val_loss'``, ``'logits'`` and ``'labels'``.
        """
        images, labels = batch

        logits = self._extract_logits(self(images))

        # Calculate loss
        val_loss = self.classification_loss(logits, labels)

        # Calculate metrics
        metrics = self.metrics.compute(logits, labels)

        # Log validation metrics
        self.log('val_loss', val_loss, prog_bar=True)
        self._log_scalar_metrics('val', metrics)

        return {'val_loss': val_loss, 'logits': logits, 'labels': labels}

    def on_validation_epoch_end(self) -> None:
        """Log the curriculum state and the first param group's learning rate."""
        # Log curriculum learning progress
        self.log('curriculum_stage', float(self.curriculum_stage), prog_bar=True)
        self.log('curriculum_classes_count', float(self.curriculum_classes_count), prog_bar=True)

        # optimizers() may hand back a list instead of a single optimizer
        # (presumably an empty one outside of fit() -- TODO confirm); skip
        # the learning-rate entry rather than fail in that case.
        optimizer = self.optimizers()
        if isinstance(optimizer, (list, tuple)):
            optimizer = optimizer[0] if optimizer else None
        if optimizer is not None:
            current_lr = optimizer.param_groups[0]['lr']
            self.log('learning_rate', current_lr, prog_bar=True)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
        """Compute and log metrics for one test batch.

        Returns:
            Dict with ``'logits'`` and ``'labels'``.
        """
        images, labels = batch

        logits = self._extract_logits(self(images))

        # Calculate metrics
        metrics = self.metrics.compute(logits, labels)

        # Log test metrics
        self._log_scalar_metrics('test', metrics)

        return {'logits': logits, 'labels': labels}

    def on_test_epoch_end(self) -> None:
        """Write the confusion matrix and a JSON metrics summary under ``results/``."""
        # Create the output directory first: np.save below does not create
        # missing parent directories.
        os.makedirs('results', exist_ok=True)

        # Save confusion matrix
        confusion_matrix = self.metrics.confusion_matrix
        if confusion_matrix is not None:
            np.save('results/confusion_matrix.npy', confusion_matrix.cpu().numpy())

        # Save detailed results
        results = {
            'model_name': self.model.__class__.__name__,
            'config': self.config,
            'test_metrics': {
                'accuracy': float(self.metrics.accuracy),
                'precision_macro': float(self.metrics.precision_macro),
                'recall_macro': float(self.metrics.recall_macro),
                'f1_macro': float(self.metrics.f1_macro),
                'precision_weighted': float(self.metrics.precision_weighted),
                'recall_weighted': float(self.metrics.recall_weighted),
                'f1_weighted': float(self.metrics.f1_weighted),
            }
        }

        # Save results
        with open(f'results/{self.config.get("experiment_name", "test")}_results.json', 'w') as f:
            json.dump(results, f, indent=2)

    def configure_optimizers(self):
        """Return an AdamW optimizer with a cosine warm-restart schedule.

        Returns:
            Dict with ``'optimizer'`` and an ``'lr_scheduler'`` config that
            steps once per epoch.
        """
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8
        )

        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,  # Restart every 10 epochs
            T_mult=2,  # Double the restart interval each time
            eta_min=1e-7  # Minimum learning rate
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1
            }
        }

    def create_callbacks(self) -> List[pl.Callback]:
        """Build the checkpoint, LR-monitor and (optional) early-stopping callbacks.

        Checkpointing and early stopping both monitor ``'val_loss'`` in
        ``'min'`` mode.
        """
        callbacks = []

        # Model checkpointing
        checkpoint_callback = ModelCheckpoint(
            dirpath='models/checkpoints',
            filename=f'{self.config.get("experiment_name", "model")}-{{epoch:02d}}-{{val_loss:.4f}}',
            monitor='val_loss',
            mode='min',
            save_top_k=3,
            save_last=True
        )
        callbacks.append(checkpoint_callback)

        # Learning rate monitoring
        lr_monitor = LearningRateMonitor(logging_interval='epoch')
        callbacks.append(lr_monitor)

        # Early stopping (optional)
        if self.use_early_stopping:
            early_stopping = EarlyStopping(
                monitor='val_loss',
                mode='min',
                patience=self.patience,
                verbose=True
            )
            callbacks.append(early_stopping)

        return callbacks

    def create_data_loaders(self, data_path: str) -> Tuple[Any, Any, Any]:
        """Create train/val/test loaders from ``data_path``.

        Sample counts are 70% / 15% / 15% of the length of the default train
        loader's dataset, with validation and test each at least 1.

        Args:
            data_path: Directory handed to ``EnhancedArchitecturalDataLoader``.

        Returns:
            ``(train_loader, val_loader, test_loader)``.
        """
        data_loader = EnhancedArchitecturalDataLoader(
            data_dir=data_path,
            batch_size=self.batch_size,
            num_workers=4,
            use_albumentations=True  # Use advanced augmentation
        )

        # Calculate sample sizes based on available data
        total_samples = len(data_loader.get_train_loader().dataset)
        train_samples = int(0.7 * total_samples)
        val_samples = max(1, int(0.15 * total_samples))
        test_samples = max(1, int(0.15 * total_samples))

        print(f"Data split: Train={train_samples}, Val={val_samples}, Test={test_samples}")

        train_loader = data_loader.get_train_loader(train_samples)
        val_loader = data_loader.get_val_loader(val_samples)
        test_loader = data_loader.get_test_loader(test_samples)

        return train_loader, val_loader, test_loader

    def update_curriculum(self, epoch: int):
        """Set the curriculum stage and class count for ``epoch``.

        Epochs below 10 use stage 0 (up to 10 classes), epochs below 30 use
        stage 1 (up to 20 classes), later epochs use stage 2 (all classes).
        NOTE(review): nothing in this class calls this method; callers must
        invoke it themselves.
        """
        if epoch < 10:
            self.curriculum_stage = 0
            self.curriculum_classes_count = min(10, self.num_classes)
        elif epoch < 30:
            self.curriculum_stage = 1
            self.curriculum_classes_count = min(20, self.num_classes)
        else:
            self.curriculum_stage = 2
            self.curriculum_classes_count = self.num_classes

        # Update model for current curriculum stage
        self.update_model_for_stage()

    def update_model_for_stage(self):
        """Hook for adapting the model to the current curriculum stage (no-op here)."""
        # This can be implemented to modify model behavior based on curriculum stage
        pass
357
+
358
+
359
class EnhancedExperimentRunner:
    """Runs one fit-then-test cycle for a given experiment configuration."""

    def __init__(self, config: Dict[str, Any]):
        """Keep ``config`` and resolve the experiment name (default ``'enhanced_experiment'``)."""
        self.experiment_name = config.get('experiment_name', 'enhanced_experiment')
        self.config = config

    def run_experiment(self, model: nn.Module, data_path: str):
        """Train and test ``model`` on the data found at ``data_path``.

        Args:
            model: Network wrapped in an ``EnhancedArchitecturalTrainer``.
            data_path: Location passed to the trainer's ``create_data_loaders``.

        Returns:
            The ``EnhancedArchitecturalTrainer`` instance after fit and test.
        """
        cfg = self.config
        print(f"Starting enhanced experiment: {self.experiment_name}")

        # Lightning module plus its loaders and callbacks.
        module = EnhancedArchitecturalTrainer(model, cfg)
        loaders = module.create_data_loaders(data_path)
        train_loader, val_loader, test_loader = loaders
        callback_list = module.create_callbacks()

        if cfg.get('use_mixed_precision', True):
            precision = '16-mixed'
        else:
            precision = '32'

        # Keyword arguments for the Lightning trainer, gathered in one place.
        trainer_options = {
            'max_epochs': cfg.get('epochs', 100),
            'accelerator': 'auto',
            'devices': 'auto',
            'precision': precision,
            'gradient_clip_val': cfg.get('gradient_clip_val', 1.0),
            'accumulate_grad_batches': cfg.get('accumulate_grad_batches', 2),
            'callbacks': callback_list,
            'logger': module.tensorboard_logger,
            'log_every_n_steps': 10,
            'val_check_interval': 0.5,  # Validate twice per epoch
            'enable_progress_bar': True,
            'enable_model_summary': True,
            'enable_checkpointing': True,
        }
        lightning_trainer = pl.Trainer(**trainer_options)

        # Fit first, then evaluate on the held-out loader.
        lightning_trainer.fit(module, train_loader, val_loader)
        lightning_trainer.test(module, test_loader)

        print(f"Enhanced experiment {self.experiment_name} completed successfully!")

        return module
405
+
406
+
407
+ # Keep backward compatibility
408
class ArchitecturalTrainer(EnhancedArchitecturalTrainer):
    """Legacy name for ``EnhancedArchitecturalTrainer``; adds no behavior of its own."""
411
+
412
class ExperimentRunner(EnhancedExperimentRunner):
    """Legacy name for ``EnhancedExperimentRunner``; adds no behavior of its own."""