from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class StruCTAConfig:
    """Configuration for the StruCTA privacy-preserving transformer.

    Groups the hyperparameters for the encoder/decoder backbone, the graph
    structural encodings, the abstract entity-type mapping, differential
    privacy (DP) training, the training schedule and privacy verification.

    ``__post_init__`` rejects out-of-range or mutually inconsistent values at
    construction time, so a bad config fails immediately with a clear message
    rather than later inside the model or training loop.

    Raises:
        ValueError: If any validated field is out of range or inconsistent
            with another field.
    """

    # --- Transformer backbone ---
    hidden_dim: int = 768
    num_encoder_layers: int = 12
    num_decoder_layers: int = 12
    num_heads: int = 12
    ffn_dim: int = 3072
    dropout: float = 0.1
    attention_dropout: float = 0.1

    # --- Graph structural encodings (centrality / spatial / edge) ---
    # NOTE(review): the max_* limits presumably size the corresponding
    # encoding tables — confirm against the model code.
    max_degree: int = 512
    max_spatial_dist: int = 128
    max_edge_features: int = 32
    use_centrality_encoding: bool = True
    use_spatial_encoding: bool = True
    use_edge_encoding: bool = True

    # --- Abstract entity types ---
    # Size of the abstract-type id space; every id in abstract_type_map must
    # be smaller than this (enforced in __post_init__).
    num_abstract_types: int = 32
    # Entity label -> abstract type id. default_factory builds a fresh dict
    # per instance, so mutating one config's map never affects another.
    abstract_type_map: Dict[str, int] = field(default_factory=lambda: {
        "PERSON": 0,
        "ORG": 1,
        "LOC": 2,
        "GPE": 3,
        "MONEY": 4,
        "DATE": 5,
        "PHONE": 6,
        "EMAIL": 7,
        "SSN": 8,
        "ID": 9,
        "PRODUCT": 10,
        "EVENT": 11,
        "MISC": 12,
    })

    # --- Vocabulary ---
    vocab_size: int = 50000

    # --- Differential-privacy training ---
    use_dp_training: bool = True
    # (epsilon, delta) differential-privacy budget.
    dp_epsilon: float = 3.0
    dp_delta: float = 1e-5
    # Presumably the per-sample gradient clipping bound — confirm in trainer.
    dp_clip_norm: float = 1.0

    # --- Graph size ---
    max_graph_nodes: int = 256

    # --- Training schedule ---
    max_seq_length: int = 512
    batch_size: int = 32
    learning_rate: float = 2e-4
    warmup_steps: int = 60000
    num_training_steps: int = 1000000

    # --- Privacy verification ---
    use_privacy_verification: bool = True
    # NOTE(review): the scale of this threshold is not visible here — confirm
    # against the verification code.
    privacy_threshold: float = 0.95

    def __post_init__(self) -> None:
        """Validate field ranges and cross-field consistency.

        Raises:
            ValueError: On the first out-of-range or inconsistent setting.
        """
        # Sizes that must be strictly positive for the config to describe a
        # usable model / data pipeline.
        for name in (
            "hidden_dim",
            "num_heads",
            "ffn_dim",
            "num_abstract_types",
            "vocab_size",
            "max_graph_nodes",
            "max_seq_length",
            "batch_size",
            "num_training_steps",
        ):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value!r}")

        # Counts and limits where zero may be a deliberate choice (e.g. an
        # ablation), so only negative values are rejected.
        for name in (
            "num_encoder_layers",
            "num_decoder_layers",
            "max_degree",
            "max_spatial_dist",
            "max_edge_features",
            "warmup_steps",
        ):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be non-negative, got {value!r}")

        # There is no separate head_dim field, so the per-head width can only
        # be hidden_dim / num_heads — it has to divide evenly.
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )

        # Written as a chained range check so NaN is rejected as well.
        for name in ("dropout", "attention_dropout"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value!r}")

        # `not x > 0` (rather than `x <= 0`) also rejects NaN.
        if not self.learning_rate > 0:
            raise ValueError(
                f"learning_rate must be positive, got {self.learning_rate!r}"
            )

        if self.warmup_steps > self.num_training_steps:
            raise ValueError(
                f"warmup_steps ({self.warmup_steps}) must not exceed "
                f"num_training_steps ({self.num_training_steps})"
            )

        for label, type_id in self.abstract_type_map.items():
            if not 0 <= type_id < self.num_abstract_types:
                raise ValueError(
                    f"abstract type {label!r} has id {type_id!r}, outside "
                    f"[0, num_abstract_types={self.num_abstract_types})"
                )

        # The DP budget is validated only when DP training is enabled, so a
        # config that turns DP off may carry placeholder values.
        if self.use_dp_training:
            if not self.dp_epsilon > 0:
                raise ValueError(
                    f"dp_epsilon must be positive, got {self.dp_epsilon!r}"
                )
            if not 0.0 <= self.dp_delta < 1.0:
                raise ValueError(
                    f"dp_delta must be in [0, 1), got {self.dp_delta!r}"
                )
            if not self.dp_clip_norm > 0:
                raise ValueError(
                    f"dp_clip_norm must be positive, got {self.dp_clip_norm!r}"
                )

