""" This module contains all configuration parameters for the VCF processing pipeline """ from dataclasses import dataclass, field from typing import Dict, List, Optional, Any import json import os @dataclass class ModelConfig: """Configurations""" # Embedding dimensions embed_dim: int = 32 transformer_dim: int = 128 # Transformer parameters nhead: int = 8 num_layers: int = 2 dropout: float = 0.1 # Model architecture num_classes: int = 2 hidden_dims: List[int] = field(default_factory=lambda: [256, 128]) # Training parameters learning_rate: float = 1e-4 batch_size: int = 16 max_epochs: int = 100 early_stopping_patience: int = 10 # Data processing max_mutations_per_gene: int = 100 max_genes_per_chromosome: int = 1000 max_chromosomes_per_pathway: int = 50 max_pathways_per_sample: int = 100 @dataclass class DataConfig: """Configurations""" # File paths vcf_file_path: Optional[str] = None gene_annotation_path: Optional[str] = None pathway_mapping_path: Optional[str] = None output_dir: str = "./outputs" cache_dir: str = "./cache" # VCF processing supported_impacts: List[str] = field(default_factory=lambda: [ "HIGH", "MODERATE", "LOW", "MODIFIER" ]) supported_chromosomes: List[str] = field(default_factory=lambda: [ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22", "X", "Y", "MT" ]) # Tokenization special_tokens: Dict[str, str] = field(default_factory=lambda: { "pad_token": "[PAD]", "unk_token": "[UNK]", "sep_token": "[SEP]", "cls_token": "[CLS]" }) # Data validation min_mutations_per_sample: int = 1 max_mutations_per_sample: int = 10000 @dataclass class HuggingFaceConfig: """Configurations""" model_name: str = "GvEM" model_version: str = "1.0.0" model_description: str = "Genomic Variant Embedding Model" # Hub configuration push_to_hub: bool = False hub_model_id: Optional[str] = None hub_token: Optional[str] = None # Model card information license: str = "apache-2.0" tags: List[str] = field(default_factory=lambda: [ "genomics", "vcf", "transformer", "hierarchical", "mutations" ]) # Repository information repository_url: Optional[str] = None paper_url: Optional[str] = None class ConfigManager: """Manage configurations""" def __init__(self, config_path: Optional[str] = None): self.config_path = config_path or "config.json" self.model_config = ModelConfig() self.data_config = DataConfig() self.hf_config = HuggingFaceConfig() def load_config(self, config_path: Optional[str] = None) -> None: path = config_path or self.config_path if os.path.exists(path): with open(path, 'r') as f: config_dict = json.load(f) # Update configurations if 'model' in config_dict: self._update_dataclass(self.model_config, config_dict['model']) if 'data' in config_dict: self._update_dataclass(self.data_config, config_dict['data']) if 'huggingface' in config_dict: self._update_dataclass(self.hf_config, config_dict['huggingface']) def save_config(self, config_path: Optional[str] = None) -> None: path = config_path or self.config_path config_dict = { 'model': self._dataclass_to_dict(self.model_config), 'data': self._dataclass_to_dict(self.data_config), 'huggingface': self._dataclass_to_dict(self.hf_config) } os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, 'w') as f: json.dump(config_dict, f, indent=2) def _update_dataclass(self, dataclass_obj: Any, update_dict: Dict) -> None: """Update dataclass fields from dictionary.""" for key, value in update_dict.items(): if hasattr(dataclass_obj, key): setattr(dataclass_obj, key, value) def _dataclass_to_dict(self, dataclass_obj: Any) -> Dict: """Convert dataclass to dictionary.""" result = {} for key, value in dataclass_obj.__dict__.items(): if not key.startswith('_'): result[key] = value return result def validate_config(self) -> bool: """Validate configuration parameters.""" # Model validation assert self.model_config.embed_dim > 0, "embed_dim must be positive" assert self.model_config.nhead > 0, "nhead must be positive" assert self.model_config.num_classes > 1, "num_classes must be > 1" assert 0 <= self.model_config.dropout <= 1, "dropout must be in [0, 1]" # Data validation assert self.data_config.min_mutations_per_sample > 0, "min_mutations_per_sample must be positive" assert self.data_config.max_mutations_per_sample > self.data_config.min_mutations_per_sample, \ "max_mutations_per_sample must be > min_mutations_per_sample" return True def get_model_config_dict(self) -> Dict: return { 'architectures': ['HierarchicalVCFModel'], 'model_type': 'hierarchical-vcf', **self._dataclass_to_dict(self.model_config) } default_config = ConfigManager() EXAMPLE_CONFIG = { "model": { "embed_dim": 64, "transformer_dim": 256, "nhead": 8, "num_layers": 3, "num_classes": 5, "learning_rate": 5e-4, "batch_size": 32 }, "data": { "vcf_file_path": "/path/to/variants.vcf", "gene_annotation_path": "/path/to/gene_annotations.json", "pathway_mapping_path": "/path/to/pathway_mappings.json", "output_dir": "./results", "min_mutations_per_sample": 5, "max_mutations_per_sample": 5000 }, "huggingface": { "model_name": "my-vcf-model", "push_to_hub": True, "hub_model_id": "username/my-vcf-model", "license": "mit" } }