# SocioCausaNet / config.py
# rasoultilburg's picture
# Uploading joint causal model files into the Hugging Face Hub
# d2ca2ac verified
"""
Configuration settings for the joint causal learning model.
"""
import torch
# Compute device: the CUDA device when torch reports CUDA as available,
# otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fixed random seed, kept here so runs can be made reproducible.
SEED = 8642
# Model architecture settings.
MODEL_CONFIG = {
    # Default encoder checkpoint name.
    "encoder_name": "bert-base-uncased",
    # Sentence-level head: binary causal / non-causal classification.
    "num_cls_labels": 2,
    # Token-level head: number of BIO labels used for span detection.
    "num_bio_labels": 7,
    # Relation head: number of relation labels (updated from 3 to 2).
    "num_rel_labels": 2,
    # Dropout rate.
    "dropout": 0.2,
}
# Training hyperparameters.
# NOTE(review): the consumers of these keys are not part of this file, so the
# exact meaning of each value (e.g. whether "patience_epochs" drives early
# stopping) should be confirmed against the training loop.
TRAINING_CONFIG = {
    "batch_size": 16,
    "num_epochs": 20,
    "learning_rate": 1e-5,
    "weight_decay": 0.1,
    "gradient_clip_val": 1.0,
    "patience_epochs": 10,
    # File name under which the best checkpoint is stored.
    "model_save_path": "best_joint_causal_model.pt",
}
# Dataset / preprocessing settings.
DATASET_CONFIG = {
    # Maximum sequence length for tokenization.
    "max_length": 512,
    # Rate of negative relation samples to generate.
    "negative_relation_rate": 2.0,
    # Maximum length for random negative spans.
    "max_random_span_len": 5,
    # Label ID that is ignored in loss computation.
    "ignore_id": -100,
}
# Label mappings

# BIO tags for span detection. A tag's position in the tuple is its integer ID.
id2label_bio = dict(
    enumerate(
        (
            "B-C",   # Beginning of Cause
            "I-C",   # Inside of Cause
            "B-E",   # Beginning of Effect
            "I-E",   # Inside of Effect
            "B-CE",  # Beginning of Cause-Effect
            "I-CE",  # Inside of Cause-Effect
            "O",     # Outside
        )
    )
)
# Reverse lookup: BIO tag string -> integer ID.
label2id_bio = {tag: idx for idx, tag in id2label_bio.items()}
# Entity label -> suffix used in the BIO tag names (e.g. "cause" -> "C",
# as in "B-C" / "I-C").
entity_label_to_bio_prefix = {
    "cause": "C",
    "effect": "E",
    "internal_CE": "CE",
    # Non-causal entities map to the bare "O" (Outside) tag.
    "non-causal": "O",
}
# Relation labels. A label's position in the tuple is its integer ID.
id2label_rel = dict(enumerate(("Rel_None", "Rel_CE")))
# Reverse lookup: relation label string -> integer ID.
label2id_rel = {name: idx for idx, name in id2label_rel.items()}
# Sentence-level classification labels. A label's position in the tuple is
# its integer ID.
id2label_cls = dict(enumerate(("non-causal", "causal")))
# Reverse lookup: classification label string -> integer ID.
label2id_cls = dict(zip(id2label_cls.values(), id2label_cls.keys()))
# Relation type mappings

# ID of the "Rel_None" label, looked up from the relation label table.
NEGATIVE_SAMPLE_REL_ID = label2id_rel["Rel_None"]

# Positive relation type name -> relation label ID.
POSITIVE_RELATION_TYPE_TO_ID = {"Rel_CE": 1}
# Inference settings.
# "cls_threshold" is the threshold for the causal / non-causal classification.
INFERENCE_CONFIG = {"cls_threshold": 0.5}