# StruCTA / config.py
# YOUSSEF88 - "Upload config.py" (commit 704b1b5, verified)
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class StruCTAConfig:
    """Configuration for StruCTA privacy-preserving transformer.

    Every field has a default, so ``StruCTAConfig()`` yields a complete
    configuration; individual values can be overridden by keyword, e.g.
    ``StruCTAConfig(hidden_dim=256)``.

    Fields are grouped by the section comments below. This class only stores
    values: it performs no validation and no cross-field consistency checks.
    """

    # Model dimensions
    hidden_dim: int = 768
    num_encoder_layers: int = 12
    num_decoder_layers: int = 12
    num_heads: int = 12
    ffn_dim: int = 3072
    dropout: float = 0.1
    attention_dropout: float = 0.1

    # Graph structural encodings (Graphormer)
    # NOTE(review): the max_* values look like upper bounds / table sizes for
    # the corresponding encodings -- confirm against the model code.
    max_degree: int = 512
    max_spatial_dist: int = 128
    max_edge_features: int = 32
    use_centrality_encoding: bool = True
    use_spatial_encoding: bool = True
    use_edge_encoding: bool = True

    # Abstract entity types
    # NOTE(review): the default map below defines 13 labels (ids 0-12) while
    # num_abstract_types is 32; presumably the remaining ids are reserved for
    # additional types -- confirm with the consumer of these fields.
    num_abstract_types: int = 32
    # Built through default_factory so each instance gets its own dict and
    # mutating one config's map never leaks into another.
    abstract_type_map: dict = field(default_factory=lambda: {
        "PERSON": 0,
        "ORG": 1,
        "LOC": 2,
        "GPE": 3,
        "MONEY": 4,
        "DATE": 5,
        "PHONE": 6,
        "EMAIL": 7,
        "SSN": 8,
        "ID": 9,
        "PRODUCT": 10,
        "EVENT": 11,
        "MISC": 12,
    })

    # Vocabulary
    vocab_size: int = 50000

    # Privacy
    # NOTE(review): "dp" presumably stands for differential privacy
    # (epsilon / delta budget and gradient clip norm) -- confirm in trainer.
    use_dp_training: bool = True
    dp_epsilon: float = 3.0
    dp_delta: float = 1e-5
    dp_clip_norm: float = 1.0

    # Decoder cross-modal
    max_graph_nodes: int = 256

    # Training
    max_seq_length: int = 512
    batch_size: int = 32
    learning_rate: float = 2e-4
    warmup_steps: int = 60000
    num_training_steps: int = 1000000

    # Inference
    use_privacy_verification: bool = True
    privacy_threshold: float = 0.95