|
|
""" |
|
|
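This module defines the configuration parameters for the VCF processing pipeline:
model, data, and Hugging Face Hub settings as dataclasses, plus a ConfigManager
that loads, saves, and validates them as JSON.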
|
|
""" |
|
|
|
|
|
import copy
import json
import os
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Dict, List, Optional, Any
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Architecture, training and input-size hyperparameters for the model.

    Every field has a default, so ``ModelConfig()`` is a usable baseline;
    ConfigManager.load_config can override any subset of the fields by name.
    """

    # Embedding / transformer sizes and regularisation.
    embed_dim: int = 32
    transformer_dim: int = 128
    nhead: int = 8
    num_layers: int = 2
    dropout: float = 0.1

    # Number of output classes and hidden layer widths (fresh list per instance).
    num_classes: int = 2
    hidden_dims: List[int] = field(default_factory=lambda: [256, 128])

    # Optimisation and training schedule.
    learning_rate: float = 1e-4
    batch_size: int = 16
    max_epochs: int = 100
    early_stopping_patience: int = 10

    # Size caps named ``max_<child>_per_<parent>``; where they are enforced
    # is outside this module.
    max_mutations_per_gene: int = 100
    max_genes_per_chromosome: int = 1000
    max_chromosomes_per_pathway: int = 50
    max_pathways_per_sample: int = 100
|
|
|
|
|
|
|
|
@dataclass
class DataConfig:
    """Input/output locations, accepted annotation values and per-sample bounds.

    Input paths default to ``None``; output and cache directories default to
    relative paths. Chromosome names are listed without a ``chr`` prefix
    (``"1"`` ... ``"22"``, ``"X"``, ``"Y"``, ``"MT"``).
    """

    # File system locations.
    vcf_file_path: Optional[str] = None
    gene_annotation_path: Optional[str] = None
    pathway_mapping_path: Optional[str] = None
    output_dir: str = "./outputs"
    cache_dir: str = "./cache"

    # Accepted categorical values (each instance gets its own list).
    supported_impacts: List[str] = field(
        default_factory=lambda: ["HIGH", "MODERATE", "LOW", "MODIFIER"]
    )
    # Numbered chromosomes 1-22, followed by X, Y and MT.
    supported_chromosomes: List[str] = field(
        default_factory=lambda: [str(n) for n in range(1, 23)] + ["X", "Y", "MT"]
    )

    # Special token strings keyed by role; presumably consumed by a
    # tokenizer/vocabulary elsewhere -- confirm with callers.
    special_tokens: Dict[str, str] = field(
        default_factory=lambda: {
            "pad_token": "[PAD]",
            "unk_token": "[UNK]",
            "sep_token": "[SEP]",
            "cls_token": "[CLS]",
        }
    )

    # Per-sample mutation count bounds (ConfigManager.validate_config requires
    # min > 0 and max > min).
    min_mutations_per_sample: int = 1
    max_mutations_per_sample: int = 10000
|
|
|
|
|
|
|
|
@dataclass
class HuggingFaceConfig:
    """Model metadata and Hugging Face Hub publishing options.

    ``hub_token`` is excluded from ``repr()`` so that printing or logging
    the config does not expose the credential.
    """

    # Model identity.
    model_name: str = "GvEM"
    model_version: str = "1.0.0"
    model_description: str = "Genomic Variant Embedding Model"

    # Hub publishing.
    push_to_hub: bool = False
    hub_model_id: Optional[str] = None
    # Secret access token. repr=False keeps it out of the generated __repr__
    # (and therefore out of logs/messages that print the config). It is still
    # compared by __eq__ and still written out by ConfigManager.save_config.
    hub_token: Optional[str] = field(default=None, repr=False)

    # Descriptive metadata: license, tags and links.
    license: str = "apache-2.0"
    tags: List[str] = field(default_factory=lambda: [
        "genomics", "vcf", "transformer", "hierarchical", "mutations"
    ])
    repository_url: Optional[str] = None
    paper_url: Optional[str] = None
|
|
|
|
|
|
|
|
class ConfigManager:
    """Hold the model, data and Hugging Face configs and (de)serialize them as JSON.

    The JSON file has up to three top-level sections, ``"model"``, ``"data"``
    and ``"huggingface"``, each mapping field names to values for
    ``model_config``, ``data_config`` and ``hf_config`` respectively.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Create a manager populated with default configs (no file is read).

        Args:
            config_path: Default file for ``load_config``/``save_config``;
                falls back to ``"config.json"`` when ``None`` or empty.
        """
        self.config_path = config_path or "config.json"
        self.model_config = ModelConfig()
        self.data_config = DataConfig()
        self.hf_config = HuggingFaceConfig()

    def load_config(self, config_path: Optional[str] = None) -> None:
        """Update the held configs in place from a JSON file.

        A missing file is ignored and current values are kept. Missing
        sections and unknown keys are ignored as well, so a partial file only
        overrides the fields it names.

        Args:
            config_path: File to read; defaults to ``self.config_path``.

        Raises:
            ValueError: If the file is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError``), or if the top level or a known section is
                not a JSON object.
        """
        path = config_path or self.config_path

        if not os.path.exists(path):
            return

        # JSON is UTF-8; don't depend on the platform's default text encoding.
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)

        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Config file {path!r} must contain a JSON object at the top level"
            )

        sections = (
            ('model', self.model_config),
            ('data', self.data_config),
            ('huggingface', self.hf_config),
        )
        for section, target in sections:
            if section not in config_dict:
                continue
            values = config_dict[section]
            if not isinstance(values, dict):
                raise ValueError(
                    f"Section {section!r} in {path!r} must be a JSON object"
                )
            self._update_dataclass(target, values)

    def save_config(self, config_path: Optional[str] = None) -> None:
        """Write all three configs to a JSON file, creating parent directories.

        Args:
            config_path: File to write; defaults to ``self.config_path``.
        """
        path = config_path or self.config_path

        config_dict = {
            'model': self._dataclass_to_dict(self.model_config),
            'data': self._dataclass_to_dict(self.data_config),
            'huggingface': self._dataclass_to_dict(self.hf_config),
        }

        # os.path.dirname() is '' for a bare filename such as the default
        # "config.json", and os.makedirs('') raises FileNotFoundError, so only
        # create a directory when the path actually contains one.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # NOTE(review): hf_config.hub_token is written here in plaintext when
        # set; consider keeping tokens out of files that may be shared.
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=2)

    def _update_dataclass(self, dataclass_obj: Any, update_dict: Dict) -> None:
        """Update dataclass fields from dictionary.

        Only public names are assigned: for a dataclass instance, its declared
        fields; otherwise, attributes that already exist on the object. Other
        keys are silently ignored. Values are assigned as-is, without type
        conversion or validation.
        """
        if is_dataclass(dataclass_obj):
            settable = {f.name for f in fields(dataclass_obj)}
        else:
            settable = None

        for key, value in update_dict.items():
            # hasattr() is True for names like "__class__" or "__eq__"; a config
            # file must not be able to overwrite those. Skipping '_' names also
            # mirrors _dataclass_to_dict, which never emits them.
            if key.startswith('_'):
                continue
            if settable is not None:
                if key not in settable:
                    continue
            elif not hasattr(dataclass_obj, key):
                continue
            setattr(dataclass_obj, key, value)

    def _dataclass_to_dict(self, dataclass_obj: Any) -> Dict:
        """Convert dataclass to dictionary.

        Returns the object's public instance attributes (names not starting
        with ``_``) in attribute order. Values are deep-copied so that mutating
        the result (e.g. appending to a list field) cannot change the live config.
        """
        return {
            key: copy.deepcopy(value)
            for key, value in vars(dataclass_obj).items()
            if not key.startswith('_')
        }

    def validate_config(self) -> bool:
        """Validate configuration parameters.

        Returns:
            True when every check passes.

        Raises:
            AssertionError: On the first failed check. Raised explicitly rather
                than via ``assert`` so validation still runs under ``python -O``.
        """
        def require(condition: bool, message: str) -> None:
            if not condition:
                raise AssertionError(message)

        model = self.model_config
        data = self.data_config

        # Checks run in order and stop at the first failure.
        require(model.embed_dim > 0, "embed_dim must be positive")
        require(model.nhead > 0, "nhead must be positive")
        require(model.num_classes > 1, "num_classes must be > 1")
        require(0 <= model.dropout <= 1, "dropout must be in [0, 1]")

        require(data.min_mutations_per_sample > 0,
                "min_mutations_per_sample must be positive")
        require(data.max_mutations_per_sample > data.min_mutations_per_sample,
                "max_mutations_per_sample must be > min_mutations_per_sample")

        return True

    def get_model_config_dict(self) -> Dict:
        """Return the model config as a flat dict with architecture metadata.

        The ``architectures`` and ``model_type`` entries come first, followed by
        every model config field (which win on key collisions). The layout
        presumably targets a Hugging Face-style ``config.json`` -- confirm with
        callers.
        """
        return {
            'architectures': ['HierarchicalVCFModel'],
            'model_type': 'hierarchical-vcf',
            **self._dataclass_to_dict(self.model_config)
        }
|
|
|
|
|
# Shared module-level manager populated with default configs. Construction does
# not read any file; call load_config() to apply a JSON file.
default_config: ConfigManager = ConfigManager(config_path=None)
|
|
|
|
|
# Example of the JSON layout ConfigManager.load_config expects: one section per
# config dataclass, each overriding only the fields it lists.
EXAMPLE_CONFIG = {
    "model": {
        "embed_dim": 64,
        "transformer_dim": 256,
        "nhead": 8,
        "num_layers": 3,
        "num_classes": 5,
        "learning_rate": 5e-4,
        "batch_size": 32,
    },
    "data": {
        "vcf_file_path": "/path/to/variants.vcf",
        "gene_annotation_path": "/path/to/gene_annotations.json",
        "pathway_mapping_path": "/path/to/pathway_mappings.json",
        "output_dir": "./results",
        "min_mutations_per_sample": 5,
        "max_mutations_per_sample": 5000,
    },
    "huggingface": {
        "model_name": "my-vcf-model",
        "push_to_hub": True,
        "hub_model_id": "username/my-vcf-model",
        "license": "mit",
    },
}