Upload src/utils\config.py with huggingface_hub
Browse files- src/utils//config.py +85 -0
src/utils//config.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration utilities for architectural style classification experiments.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from typing import Dict, Any, Optional
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_config(config_path: str) -> Dict[str, Any]:
|
| 12 |
+
"""Load configuration from JSON or YAML file."""
|
| 13 |
+
if not os.path.exists(config_path):
|
| 14 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 15 |
+
|
| 16 |
+
with open(config_path, 'r') as f:
|
| 17 |
+
if config_path.endswith('.json'):
|
| 18 |
+
config = json.load(f)
|
| 19 |
+
elif config_path.endswith('.yaml') or config_path.endswith('.yml'):
|
| 20 |
+
config = yaml.safe_load(f)
|
| 21 |
+
else:
|
| 22 |
+
raise ValueError(f"Unsupported config file format: {config_path}")
|
| 23 |
+
|
| 24 |
+
return config
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def save_config(config: Dict[str, Any], config_path: str):
|
| 28 |
+
"""Save configuration to JSON or YAML file."""
|
| 29 |
+
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
| 30 |
+
|
| 31 |
+
with open(config_path, 'w') as f:
|
| 32 |
+
if config_path.endswith('.json'):
|
| 33 |
+
json.dump(config, f, indent=2, default=str)
|
| 34 |
+
elif config_path.endswith('.yaml') or config_path.endswith('.yml'):
|
| 35 |
+
yaml.dump(config, f, default_flow_style=False)
|
| 36 |
+
else:
|
| 37 |
+
raise ValueError(f"Unsupported config file format: {config_path}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def validate_config(config: Dict[str, Any]) -> bool:
|
| 41 |
+
"""Validate configuration parameters."""
|
| 42 |
+
required_fields = ['experiment_name', 'model_type', 'num_classes']
|
| 43 |
+
|
| 44 |
+
for field in required_fields:
|
| 45 |
+
if field not in config:
|
| 46 |
+
raise ValueError(f"Missing required field: {field}")
|
| 47 |
+
|
| 48 |
+
# Validate model type
|
| 49 |
+
valid_model_types = ['hierarchical', 'resnet', 'efficientnet', 'vit']
|
| 50 |
+
if config['model_type'] not in valid_model_types:
|
| 51 |
+
raise ValueError(f"Invalid model_type: {config['model_type']}. Must be one of {valid_model_types}")
|
| 52 |
+
|
| 53 |
+
# Validate numeric fields
|
| 54 |
+
numeric_fields = ['num_classes', 'learning_rate', 'max_epochs', 'batch_size']
|
| 55 |
+
for field in numeric_fields:
|
| 56 |
+
if field in config and not isinstance(config[field], (int, float)):
|
| 57 |
+
raise ValueError(f"Field {field} must be numeric")
|
| 58 |
+
|
| 59 |
+
return True
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_default_config() -> Dict[str, Any]:
|
| 63 |
+
"""Create a default configuration."""
|
| 64 |
+
return {
|
| 65 |
+
'experiment_name': 'architectural_classification',
|
| 66 |
+
'model_type': 'hierarchical',
|
| 67 |
+
'num_classes': 25,
|
| 68 |
+
'num_broad_classes': 5,
|
| 69 |
+
'num_fine_classes': 25,
|
| 70 |
+
'learning_rate': 1e-4,
|
| 71 |
+
'weight_decay': 1e-5,
|
| 72 |
+
'max_epochs': 100,
|
| 73 |
+
'batch_size': 16,
|
| 74 |
+
'use_hierarchical_loss': True,
|
| 75 |
+
'use_contrastive_loss': False,
|
| 76 |
+
'use_style_relationship_loss': True,
|
| 77 |
+
'use_mixed_precision': False,
|
| 78 |
+
'gradient_clip_val': 1.0,
|
| 79 |
+
'accumulate_grad_batches': 1,
|
| 80 |
+
'use_wandb': False,
|
| 81 |
+
'curriculum_stages': [
|
| 82 |
+
{'epochs': 20, 'classes': ['ancient', 'medieval', 'modern']},
|
| 83 |
+
{'epochs': 80, 'classes': list(range(25))}
|
| 84 |
+
]
|
| 85 |
+
}
|