fxxkingusername commited on
Commit
a4126d0
·
verified ·
1 Parent(s): 4a9f2e7

Upload src/utils\config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/utils//config.py +85 -0
src/utils//config.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration utilities for architectural style classification experiments.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from typing import Dict, Any, Optional
8
+ import yaml
9
+
10
+
11
+ def load_config(config_path: str) -> Dict[str, Any]:
12
+ """Load configuration from JSON or YAML file."""
13
+ if not os.path.exists(config_path):
14
+ raise FileNotFoundError(f"Config file not found: {config_path}")
15
+
16
+ with open(config_path, 'r') as f:
17
+ if config_path.endswith('.json'):
18
+ config = json.load(f)
19
+ elif config_path.endswith('.yaml') or config_path.endswith('.yml'):
20
+ config = yaml.safe_load(f)
21
+ else:
22
+ raise ValueError(f"Unsupported config file format: {config_path}")
23
+
24
+ return config
25
+
26
+
27
+ def save_config(config: Dict[str, Any], config_path: str):
28
+ """Save configuration to JSON or YAML file."""
29
+ os.makedirs(os.path.dirname(config_path), exist_ok=True)
30
+
31
+ with open(config_path, 'w') as f:
32
+ if config_path.endswith('.json'):
33
+ json.dump(config, f, indent=2, default=str)
34
+ elif config_path.endswith('.yaml') or config_path.endswith('.yml'):
35
+ yaml.dump(config, f, default_flow_style=False)
36
+ else:
37
+ raise ValueError(f"Unsupported config file format: {config_path}")
38
+
39
+
40
+ def validate_config(config: Dict[str, Any]) -> bool:
41
+ """Validate configuration parameters."""
42
+ required_fields = ['experiment_name', 'model_type', 'num_classes']
43
+
44
+ for field in required_fields:
45
+ if field not in config:
46
+ raise ValueError(f"Missing required field: {field}")
47
+
48
+ # Validate model type
49
+ valid_model_types = ['hierarchical', 'resnet', 'efficientnet', 'vit']
50
+ if config['model_type'] not in valid_model_types:
51
+ raise ValueError(f"Invalid model_type: {config['model_type']}. Must be one of {valid_model_types}")
52
+
53
+ # Validate numeric fields
54
+ numeric_fields = ['num_classes', 'learning_rate', 'max_epochs', 'batch_size']
55
+ for field in numeric_fields:
56
+ if field in config and not isinstance(config[field], (int, float)):
57
+ raise ValueError(f"Field {field} must be numeric")
58
+
59
+ return True
60
+
61
+
62
+ def create_default_config() -> Dict[str, Any]:
63
+ """Create a default configuration."""
64
+ return {
65
+ 'experiment_name': 'architectural_classification',
66
+ 'model_type': 'hierarchical',
67
+ 'num_classes': 25,
68
+ 'num_broad_classes': 5,
69
+ 'num_fine_classes': 25,
70
+ 'learning_rate': 1e-4,
71
+ 'weight_decay': 1e-5,
72
+ 'max_epochs': 100,
73
+ 'batch_size': 16,
74
+ 'use_hierarchical_loss': True,
75
+ 'use_contrastive_loss': False,
76
+ 'use_style_relationship_loss': True,
77
+ 'use_mixed_precision': False,
78
+ 'gradient_clip_val': 1.0,
79
+ 'accumulate_grad_batches': 1,
80
+ 'use_wandb': False,
81
+ 'curriculum_stages': [
82
+ {'epochs': 20, 'classes': ['ancient', 'medieval', 'modern']},
83
+ {'epochs': 80, 'classes': list(range(25))}
84
+ ]
85
+ }