File size: 1,533 Bytes
d9187c8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | from dataclasses import dataclass, field
from typing import Optional
@dataclass
class StruCTAConfig:
    """Configuration for the StruCTA privacy-preserving transformer.

    Groups model dimensions, Graphormer-style graph structural encodings,
    the abstract entity-type vocabulary, differential-privacy (DP) training
    settings, the training schedule, and inference-time privacy checks.

    Cross-field invariants are checked in ``__post_init__`` so that an
    inconsistent configuration fails at construction time rather than later
    during model building or training.

    Raises:
        ValueError: if a field is out of range or inconsistent with another
            field (see ``__post_init__``).
    """

    # Model dimensions
    hidden_dim: int = 768
    num_encoder_layers: int = 12
    num_decoder_layers: int = 12
    num_heads: int = 12
    ffn_dim: int = 3072
    dropout: float = 0.1
    attention_dropout: float = 0.1

    # Graph structural encodings (Graphormer)
    max_degree: int = 512
    max_spatial_dist: int = 128
    max_edge_features: int = 32
    use_centrality_encoding: bool = True
    use_spatial_encoding: bool = True
    use_edge_encoding: bool = True

    # Abstract entity types.
    # num_abstract_types is the size of the type-id space. It may exceed
    # len(abstract_type_map): by default 13 names are mapped (ids 0-12), which
    # presumably leaves room for additional types -- TODO confirm.
    num_abstract_types: int = 32
    # Entity-type name -> integer type id; every id must lie in
    # [0, num_abstract_types). default_factory gives each instance its own
    # dict, so mutating one config's map never affects another's.
    abstract_type_map: dict = field(default_factory=lambda: {
        "PERSON": 0,
        "ORG": 1,
        "LOC": 2,
        "GPE": 3,
        "MONEY": 4,
        "DATE": 5,
        "PHONE": 6,
        "EMAIL": 7,
        "SSN": 8,
        "ID": 9,
        "PRODUCT": 10,
        "EVENT": 11,
        "MISC": 12,
    })

    # Vocabulary
    vocab_size: int = 50000

    # Privacy (differentially-private training)
    use_dp_training: bool = True
    dp_epsilon: float = 3.0
    dp_delta: float = 1e-5
    dp_clip_norm: float = 1.0

    # Decoder cross-modal
    max_graph_nodes: int = 256

    # Training
    max_seq_length: int = 512
    batch_size: int = 32
    learning_rate: float = 2e-4
    warmup_steps: int = 60000
    num_training_steps: int = 1000000

    # Inference
    use_privacy_verification: bool = True
    privacy_threshold: float = 0.95

    def __post_init__(self) -> None:
        """Validate field ranges and cross-field invariants.

        Raises:
            ValueError: on the first violated invariant.
        """
        if self.num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {self.num_heads}")
        # Assumes the attention layers split hidden_dim evenly across heads,
        # as standard multi-head attention does -- confirm against the model.
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )

        # Probabilities / thresholds must lie in the closed unit interval.
        for name in ("dropout", "attention_dropout", "privacy_threshold"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value}")

        # Every mapped type id must fit inside the declared type-id space.
        out_of_range = {
            type_name: type_id
            for type_name, type_id in self.abstract_type_map.items()
            if not 0 <= type_id < self.num_abstract_types
        }
        if out_of_range:
            raise ValueError(
                f"abstract_type_map ids must be in [0, {self.num_abstract_types}), "
                f"got out-of-range entries {out_of_range}"
            )

        # DP parameters only matter when DP training is enabled.
        if self.use_dp_training:
            if self.dp_epsilon <= 0:
                raise ValueError(f"dp_epsilon must be positive, got {self.dp_epsilon}")
            if not 0.0 <= self.dp_delta < 1.0:
                raise ValueError(f"dp_delta must be in [0, 1), got {self.dp_delta}")
            if self.dp_clip_norm <= 0:
                raise ValueError(
                    f"dp_clip_norm must be positive, got {self.dp_clip_norm}"
                )
|