abd-ur committed on
Commit
cc67f2f
·
verified ·
1 Parent(s): 47f723d

Create config.py

Browse files

config file for the pipeline

Files changed (1) hide show
  1. config.py +198 -0
config.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains all configuration parameters for the VCF processing pipeline
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional, Any
7
+ import json
8
+ import os
9
+
10
+
11
@dataclass
class ModelConfig:
    """Model, training, and data-size settings for the VCF pipeline.

    Every field carries a default, so ``ModelConfig()`` can be built with no
    arguments and individual values overridden by keyword.
    """

    # --- Embedding dimensions ---
    embed_dim: int = 32
    transformer_dim: int = 128

    # --- Transformer parameters ---
    nhead: int = 8
    num_layers: int = 2
    dropout: float = 0.1

    # --- Model architecture ---
    num_classes: int = 2
    # Built through a factory so each instance owns a separate list.
    hidden_dims: List[int] = field(default_factory=lambda: [256, 128])

    # --- Training parameters ---
    learning_rate: float = 1e-4
    batch_size: int = 16
    max_epochs: int = 100
    early_stopping_patience: int = 10

    # --- Data processing ---
    max_mutations_per_gene: int = 100
    max_genes_per_chromosome: int = 1000
    max_chromosomes_per_pathway: int = 50
    max_pathways_per_sample: int = 100
39
+
40
+
41
@dataclass
class DataConfig:
    """File locations and VCF preprocessing settings.

    The three input paths default to ``None``; list and dict fields are
    produced by factories so instances never share a mutable default.
    """

    # --- File paths ---
    vcf_file_path: Optional[str] = None
    gene_annotation_path: Optional[str] = None
    pathway_mapping_path: Optional[str] = None
    output_dir: str = "./outputs"
    cache_dir: str = "./cache"

    # --- VCF processing ---
    supported_impacts: List[str] = field(
        default_factory=lambda: ["HIGH", "MODERATE", "LOW", "MODIFIER"]
    )
    # The labels "1" through "22", followed by "X", "Y" and "MT".
    supported_chromosomes: List[str] = field(
        default_factory=lambda: [str(n) for n in range(1, 23)] + ["X", "Y", "MT"]
    )

    # --- Tokenization ---
    special_tokens: Dict[str, str] = field(
        default_factory=lambda: {
            "pad_token": "[PAD]",
            "unk_token": "[UNK]",
            "sep_token": "[SEP]",
            "cls_token": "[CLS]",
        }
    )

    # --- Data validation ---
    min_mutations_per_sample: int = 1
    max_mutations_per_sample: int = 10000
73
+
74
+
75
@dataclass
class HuggingFaceConfig:
    """Model identity, Hub publishing options, and model-card metadata.

    Field order and defaults are part of the constructor signature and are
    kept stable. ``hub_token`` is a credential: it is excluded from the
    generated ``repr`` so that printing or logging a config does not expose
    it. It still takes part in equality and remains an ordinary attribute.
    """

    model_name: str = "GvEM"
    model_version: str = "1.0.0"
    model_description: str = "Genomic Variant Embedding Model"

    # --- Hub configuration ---
    push_to_hub: bool = False
    hub_model_id: Optional[str] = None
    # repr=False keeps the secret out of repr()/str() output and tracebacks.
    hub_token: Optional[str] = field(default=None, repr=False)

    # --- Model card information ---
    license: str = "apache-2.0"
    tags: List[str] = field(
        default_factory=lambda: [
            "genomics", "vcf", "transformer", "hierarchical", "mutations",
        ]
    )

    # --- Repository information ---
    repository_url: Optional[str] = None
    paper_url: Optional[str] = None
97
+
98
+
99
class ConfigManager:
    """Hold the pipeline's configuration sections and move them to/from JSON.

    The manager owns one ``ModelConfig``, one ``DataConfig`` and one
    ``HuggingFaceConfig``. On disk they are stored as a single JSON object
    with the top-level keys ``model``, ``data`` and ``huggingface``.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Create a manager whose sections all hold their default values.

        Args:
            config_path: Path used by ``load_config``/``save_config`` when no
                explicit path is passed. Falls back to ``"config.json"``.
        """
        self.config_path = config_path or "config.json"
        self.model_config = ModelConfig()
        self.data_config = DataConfig()
        self.hf_config = HuggingFaceConfig()

    def _sections(self):
        """Return ``(json_key, section_object)`` pairs in file order."""
        return (
            ('model', self.model_config),
            ('data', self.data_config),
            ('huggingface', self.hf_config),
        )

    def load_config(self, config_path: Optional[str] = None) -> None:
        """Overlay values from a JSON file onto the current sections.

        A missing file is not an error: the method returns and the current
        values stay in place. Sections absent from the file are left alone.

        Args:
            config_path: File to read; defaults to ``self.config_path``.
        """
        path = config_path or self.config_path

        if not os.path.exists(path):
            return

        # JSON is read as UTF-8 explicitly rather than the locale default.
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)

        for section_key, section in self._sections():
            if section_key in config_dict:
                self._update_dataclass(section, config_dict[section_key])

    def save_config(self, config_path: Optional[str] = None) -> None:
        """Write all three sections to a JSON file.

        NOTE(review): every public attribute of each section is written,
        which includes ``hub_token`` when it is set -- treat the output file
        as sensitive.

        Args:
            config_path: File to write; defaults to ``self.config_path``.
        """
        path = config_path or self.config_path

        config_dict = {
            section_key: self._dataclass_to_dict(section)
            for section_key, section in self._sections()
        }

        # A bare filename such as the default "config.json" has an empty
        # dirname, and os.makedirs("") raises FileNotFoundError.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=2)

    def _update_dataclass(self, dataclass_obj: Any, update_dict: Dict) -> None:
        """Copy matching entries of ``update_dict`` onto ``dataclass_obj``.

        Only declared dataclass fields are assigned. For a non-dataclass
        object, only existing public instance attributes are assigned. Keys
        that match nothing are ignored, so input can never overwrite methods
        or dunder attributes such as ``__class__``.
        """
        from dataclasses import fields, is_dataclass

        if is_dataclass(dataclass_obj):
            allowed = {f.name for f in fields(dataclass_obj)}
        else:
            allowed = {
                name for name in vars(dataclass_obj)
                if not name.startswith('_')
            }

        for key, value in update_dict.items():
            if key in allowed:
                setattr(dataclass_obj, key, value)

    def _dataclass_to_dict(self, dataclass_obj: Any) -> Dict:
        """Return the public instance attributes of ``dataclass_obj`` as a dict.

        Values are deep-copied so that mutating the returned dict (for
        example appending to a list) cannot change the live configuration.
        """
        from copy import deepcopy

        return {
            key: deepcopy(value)
            for key, value in vars(dataclass_obj).items()
            if not key.startswith('_')
        }

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise ``AssertionError(message)`` when ``condition`` is false.

        An explicit raise is used instead of an ``assert`` statement because
        ``assert`` is removed when Python runs with ``-O``, which would turn
        validation into a no-op. The exception type is unchanged for callers.
        """
        if not condition:
            raise AssertionError(message)

    def validate_config(self) -> bool:
        """Validate configuration parameters.

        Returns:
            ``True`` when every check passes.

        Raises:
            AssertionError: On the first failing check, carrying a message
                that names the offending parameter.
        """
        model = self.model_config
        data = self.data_config

        # Model validation
        self._require(model.embed_dim > 0, "embed_dim must be positive")
        self._require(model.nhead > 0, "nhead must be positive")
        self._require(model.num_classes > 1, "num_classes must be > 1")
        self._require(0 <= model.dropout <= 1, "dropout must be in [0, 1]")

        # Data validation
        self._require(
            data.min_mutations_per_sample > 0,
            "min_mutations_per_sample must be positive",
        )
        self._require(
            data.max_mutations_per_sample > data.min_mutations_per_sample,
            "max_mutations_per_sample must be > min_mutations_per_sample",
        )

        return True

    def get_model_config_dict(self) -> Dict:
        """Return the model section plus ``architectures`` and ``model_type``.

        The two fixed keys come first; a model field with the same name
        would override them because it is unpacked last.
        """
        return {
            'architectures': ['HierarchicalVCFModel'],
            'model_type': 'hierarchical-vcf',
            **self._dataclass_to_dict(self.model_config)
        }
171
+
172
# Module-level manager created at import time; all sections hold defaults and
# the path falls back to "config.json".
default_config: ConfigManager = ConfigManager()
173
+
174
# Sample override document using the three top-level sections ("model",
# "data", "huggingface") that ConfigManager.load_config looks for. Keys left
# out of a section keep whatever value the section already has.
EXAMPLE_CONFIG = {
    # Overrides for ModelConfig fields.
    "model": {
        "embed_dim": 64,
        "transformer_dim": 256,
        "nhead": 8,
        "num_layers": 3,
        "num_classes": 5,
        "learning_rate": 5e-4,
        "batch_size": 32,
    },
    # Overrides for DataConfig fields; the paths are placeholders.
    "data": {
        "vcf_file_path": "/path/to/variants.vcf",
        "gene_annotation_path": "/path/to/gene_annotations.json",
        "pathway_mapping_path": "/path/to/pathway_mappings.json",
        "output_dir": "./results",
        "min_mutations_per_sample": 5,
        "max_mutations_per_sample": 5000,
    },
    # Overrides for HuggingFaceConfig fields.
    "huggingface": {
        "model_name": "my-vcf-model",
        "push_to_hub": True,
        "hub_model_id": "username/my-vcf-model",
        "license": "mit",
    },
}