fxxkingusername committed on
Commit
d12ab12
·
verified ·
1 Parent(s): 3406633

Upload src/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/train.py +314 -0
src/train.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Main training script for Architectural Style Classification
4
+ Advanced Deep Learning Approach with Hierarchical Multi-Modal Architecture
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import json
10
+ import argparse
11
+ from typing import Dict, Any
12
+ import torch
13
+ import pytorch_lightning as pl
14
+
15
+ # Add src to path
16
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
17
+
18
+ from src.models import HierarchicalArchitecturalClassifier, BaselineModels
19
+ from src.training.trainer import ArchitecturalTrainer, ExperimentRunner
20
+ from src.training.losses import CombinedLoss
21
+ from src.utils.config import load_config, save_config
22
+
23
+
24
def create_experiment_configs() -> Dict[str, Dict[str, Any]]:
    """Build the catalogue of predefined experiment configurations.

    Returns:
        A dict keyed by experiment name. Each value is the configuration
        dict for that experiment: three single-backbone baselines followed
        by three hierarchical variants that add curriculum stages.
    """
    style_count = 25
    broad_groups = ['ancient', 'medieval', 'modern']

    def baseline(name, model_type, batch_size):
        # Baselines share everything except the backbone and batch size.
        return {
            'experiment_name': name,
            'model_type': model_type,
            'num_classes': style_count,
            'learning_rate': 1e-4,
            'max_epochs': 50,
            'batch_size': batch_size,
            'use_hierarchical_loss': False,
            'use_contrastive_loss': False,
            'use_style_relationship_loss': False,
            'use_wandb': False,
        }

    def curriculum(first_epochs, *later_epochs):
        # The first stage uses the broad groups; later stages use all styles.
        stages = [{'epochs': first_epochs, 'classes': list(broad_groups)}]
        for epochs in later_epochs:
            stages.append({'epochs': epochs, 'classes': list(range(style_count))})
        return stages

    def hierarchical(name, learning_rate, max_epochs, contrastive,
                     stage_epochs, wandb, extras=None):
        cfg = {
            'experiment_name': name,
            'model_type': 'hierarchical',
            'num_classes': style_count,
            'num_broad_classes': 5,
            'num_fine_classes': style_count,
            'learning_rate': learning_rate,
            'max_epochs': max_epochs,
            'batch_size': 16,
            'use_hierarchical_loss': True,
            'use_contrastive_loss': contrastive,
            'use_style_relationship_loss': True,
        }
        # Optional trainer settings sit between the loss flags and the
        # curriculum so the key order of the dumped config is unchanged.
        if extras:
            cfg.update(extras)
        cfg['curriculum_stages'] = curriculum(*stage_epochs)
        cfg['use_wandb'] = wandb
        return cfg

    catalogue = {}

    # Baseline experiments (ViT uses a smaller batch size).
    for name, model_type, batch_size in (
        ('baseline_resnet', 'resnet', 32),
        ('baseline_efficientnet', 'efficientnet', 32),
        ('baseline_vit', 'vit', 16),
    ):
        catalogue[name] = baseline(name, model_type, batch_size)

    # Hierarchical model experiments.
    catalogue['hierarchical_basic'] = hierarchical(
        'hierarchical_basic', 1e-4, 100, False, (20, 80), False,
    )
    catalogue['hierarchical_contrastive'] = hierarchical(
        'hierarchical_contrastive', 1e-4, 100, True, (20, 80), False,
    )

    # Advanced experiment.
    catalogue['hierarchical_advanced'] = hierarchical(
        'hierarchical_advanced', 5e-5, 150, True, (30, 60, 60), True,
        extras={
            'use_mixed_precision': True,
            'gradient_clip_val': 1.0,
            'accumulate_grad_batches': 2,
        },
    )

    return catalogue
133
+
134
+
135
def run_single_experiment(config: Dict[str, Any], data_path: str = None):
    """Run one experiment described by ``config``.

    Args:
        config: Experiment configuration. Must contain ``'experiment_name'``
            and ``'model_type'``; the whole mapping is handed to
            ``ExperimentRunner``.
        data_path: Accepted for interface compatibility with the callers in
            this file. This function does not read it.

    Returns:
        The ``(trainer, pl_trainer)`` pair returned by
        ``ExperimentRunner.run_experiment()``.

    Raises:
        Exception: Any error raised while running the experiment is reported
            on stdout and then re-raised unchanged.
    """
    experiment_name = config['experiment_name']

    print(f"Starting experiment: {experiment_name}")
    print(f"Model type: {config['model_type']}")
    # default=str keeps this purely informational dump from aborting the run
    # when a custom config carries a value JSON cannot encode (e.g. a Path).
    print(f"Configuration: {json.dumps(config, indent=2, default=str)}")

    # Initialize experiment runner
    runner = ExperimentRunner(config)

    # Run experiment
    try:
        trainer, pl_trainer = runner.run_experiment()
    except Exception as e:
        print(f"Experiment {experiment_name} failed: {str(e)}")
        raise
    else:
        print(f"Experiment {experiment_name} completed successfully!")
        return trainer, pl_trainer
152
+
153
+
154
def run_experiment_suite(experiment_names: list = None, data_path: str = None):
    """Run several predefined experiments in sequence and summarise them.

    Args:
        experiment_names: Names of the predefined experiments to run. When
            ``None``, every experiment from ``create_experiment_configs()``
            is run. Unknown names are reported and skipped.
        data_path: Forwarded to ``run_single_experiment`` for each run.

    Returns:
        A dict keyed by experiment name. Successful runs map to
        ``{'status': 'success', 'trainer': ..., 'pl_trainer': ...}``; failed
        runs map to ``{'status': 'failed', 'error': <message>}``.
    """
    available = create_experiment_configs()
    selected = experiment_names if experiment_names is not None else list(available.keys())

    separator = '=' * 50
    outcomes = {}

    for name in selected:
        if name in available:
            print("\n" + separator)
            print(f"Running experiment: {name}")
            print(separator)

            # A failing experiment is recorded and the suite moves on.
            try:
                trainer, pl_trainer = run_single_experiment(available[name], data_path)
            except Exception as exc:
                print(f"Experiment {name} failed: {str(exc)}")
                outcomes[name] = {'status': 'failed', 'error': str(exc)}
            else:
                outcomes[name] = {
                    'status': 'success',
                    'trainer': trainer,
                    'pl_trainer': pl_trainer,
                }
        else:
            print(f"Warning: Experiment {name} not found in configurations")

    # Save results summary
    save_experiment_results(outcomes)

    return outcomes
190
+
191
+
192
def save_experiment_results(results: Dict[str, Any]):
    """Write a JSON summary of experiment outcomes.

    Args:
        results: Mapping of experiment name to a result dict with a
            ``'status'`` key. Entries whose status is ``'success'`` must also
            carry a ``'trainer'`` exposing ``model`` and ``hparams``; any
            other entry is summarised as failed, using its ``'error'`` value
            when present.

    The summary is written to ``results/experiment_summary.json`` relative
    to the current working directory; the directory is created if missing.
    Values JSON cannot encode are written via ``str``.
    """

    def summarise(entry):
        if entry['status'] != 'success':
            return {
                'status': 'failed',
                'error': entry.get('error', 'Unknown error'),
            }
        return {
            'status': 'success',
            'model_type': entry['trainer'].model.__class__.__name__,
            'hyperparameters': entry['trainer'].hparams,
        }

    digest = {}
    for name in results:
        digest[name] = summarise(results[name])

    # Save to file
    os.makedirs('results', exist_ok=True)
    payload = json.dumps(digest, indent=2, default=str)
    with open('results/experiment_summary.json', 'w') as out:
        out.write(payload)

    print("\nExperiment summary saved to results/experiment_summary.json")
215
+
216
+
217
def test_model_creation():
    """Smoke-test that every model and the combined loss can be constructed.

    Builds the hierarchical classifier, the three baseline models and the
    combined loss with default arguments, printing a line per success and
    the parameter count of each model.

    Returns:
        ``True`` when everything was constructed, ``False`` if any step
        raised (the error message is printed).
    """

    def report(label, module):
        # Announce first, then count, so a failure while counting still
        # leaves the "created" line on stdout.
        print(f"✓ {label} model created successfully")
        total = sum(p.numel() for p in module.parameters())
        print(f" Parameters: {total:,}")

    print("Testing model creation...")

    try:
        # Test hierarchical model
        report("Hierarchical", HierarchicalArchitecturalClassifier())

        # Test baseline models
        report("ResNet-50", BaselineModels.resnet50())
        report("EfficientNet-B4", BaselineModels.efficientnet_b4())
        report("ViT-Base", BaselineModels.vit_base())

        # Test loss functions
        CombinedLoss()
        print("✓ Combined loss function created successfully")
    except Exception as exc:
        print(f"Model test failed: {str(exc)}")
        return False

    print("\nAll model tests passed! ✓")
    return True
250
+
251
+
252
def main():
    """Command-line entry point.

    Parses the command line, seeds the random number generators and then
    performs the first applicable action, in this order:

    1. ``--test``: run ``test_model_creation()``. A failure returns 1. A
       success returns 0 unless ``--config``, ``--experiment`` or ``--suite``
       was also given, in which case that action runs next.
    2. ``--config``: load a custom config file and run it.
    3. ``--experiment``: run one predefined experiment by name.
    4. ``--suite``: run every predefined experiment.
    5. Otherwise: run the ``hierarchical_basic`` experiment.

    Returns:
        Process exit code: 0 on success, 1 when the setup test fails or the
        requested experiment name is unknown.
    """
    parser = argparse.ArgumentParser(description='Architectural Style Classification Training')
    parser.add_argument('--experiment', type=str, default=None,
                        help='Specific experiment to run')
    parser.add_argument('--suite', action='store_true',
                        help='Run the full experiment suite')
    parser.add_argument('--test', action='store_true',
                        help='Test model creation and setup')
    parser.add_argument('--data_path', type=str, default=None,
                        help='Path to dataset')
    parser.add_argument('--config', type=str, default=None,
                        help='Path to custom config file')

    args = parser.parse_args()

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    pl.seed_everything(42)

    print("Architectural Style Classification Training")
    print("=" * 50)

    # Test mode
    if args.test:
        if not test_model_creation():
            print("Setup test failed!")
            return 1
        print("Setup test completed successfully!")
        # A bare --test must stop here; without this the run would fall
        # through to the default branch and start a full training job.
        # Combining --test with an explicit action still runs that action.
        if not (args.config or args.experiment or args.suite):
            return 0

    # Load custom config if provided
    if args.config:
        config = load_config(args.config)
        run_single_experiment(config, args.data_path)
        return 0

    # Run specific experiment
    if args.experiment:
        configs = create_experiment_configs()
        if args.experiment not in configs:
            print(f"Experiment '{args.experiment}' not found!")
            print(f"Available experiments: {list(configs.keys())}")
            return 1

        run_single_experiment(configs[args.experiment], args.data_path)
        return 0

    # Run experiment suite
    if args.suite:
        run_experiment_suite(data_path=args.data_path)
        return 0

    # Default: run basic hierarchical experiment
    print("No specific experiment specified. Running basic hierarchical experiment...")
    configs = create_experiment_configs()
    run_single_experiment(configs['hierarchical_basic'], args.data_path)

    return 0
311
+
312
+
313
if __name__ == "__main__":
    # sys.exit is the supported way to set the process exit code; the bare
    # exit() helper is injected by the site module for interactive use and
    # is not available when Python runs with -S.
    sys.exit(main())