| """ | |
| Configuration settings for Legal-BERT training and risk discovery | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Dict, Any | |
| import torch | |
| class LegalBertConfig: | |
| """Configuration for Legal-BERT model and training""" | |
| # Model parameters | |
| bert_model_name: str = "bert-base-uncased" | |
| num_risk_categories: int = 7 # Will be dynamically determined by risk discovery | |
| max_sequence_length: int = 512 | |
| dropout_rate: float = 0.1 | |
| # Hierarchical model parameters (ALWAYS USED) | |
| hierarchical_hidden_dim: int = 512 | |
| hierarchical_num_lstm_layers: int = 2 | |
| # Training parameters - OPTIMIZED FOR BEST RESULTS | |
| batch_size: int = 16 | |
| num_epochs: int = 20 # Increased to 20 for better convergence | |
| learning_rate: float = 2e-5 # Increased for OneCycleLR scheduler | |
| weight_decay: float = 0.01 | |
| warmup_steps: int = 1000 | |
| gradient_clip_norm: float = 1.0 # Prevent gradient explosion with high classification weight | |
| early_stopping_patience: int = 3 # Stop if val loss doesn't improve for 3 epochs | |
| # Multi-task loss weights - REBALANCED (Phase 1 improvements) | |
| # Changed from 10:1:1 to 20:0.5:0.5 to prioritize classification | |
| task_weights: Dict[str, float] = None | |
| # Focal Loss parameters for hard example mining | |
| use_focal_loss: bool = True # Use Focal Loss instead of CrossEntropyLoss | |
| focal_loss_gamma: float = 2.5 # Focus heavily on hard-to-classify examples | |
| minority_class_boost: float = 1.8 # Boost weight for Classes 0 and 5 by 80% | |
| # Learning rate scheduling | |
| use_lr_scheduler: bool = True # Use OneCycleLR for better convergence | |
| scheduler_pct_start: float = 0.1 # 10% of training for warmup | |
| # Device configuration | |
| device: str = "cuda" if torch.cuda.is_available() else "cpu" | |
| # Paths | |
| data_path: str = "dataset/CUAD_v1/CUAD_v1.json" | |
| model_save_path: str = "models/legal_bert" | |
| checkpoint_dir: str = "checkpoints" | |
| # Risk discovery parameters - OPTIMIZED FOR BETTER PATTERN DISCOVERY | |
| risk_discovery_method: str = "lda" # Options: 'lda', 'kmeans', 'hierarchical', 'nmf', 'gmm', etc. | |
| risk_discovery_clusters: int = 7 # Number of risk patterns/topics to discover | |
| tfidf_max_features: int = 15000 # Increased from 10000 for better vocabulary coverage | |
| tfidf_ngram_range: tuple = (1, 3) | |
| # LDA-specific parameters (used when risk_discovery_method='lda') - OPTIMIZED | |
| lda_doc_topic_prior: float = 0.1 # Alpha - controls document-topic density (lower = more focused) | |
| lda_topic_word_prior: float = 0.01 # Beta - controls topic-word density (lower = more focused) | |
| lda_max_iter: int = 50 # Increased from 20 to 50 for better convergence | |
| lda_max_features: int = 8000 # Increased from 5000 for richer topic modeling | |
| lda_learning_method: str = 'batch' # 'batch' or 'online' | |
| def __post_init__(self): | |
| if self.task_weights is None: | |
| # PHASE 1 IMPROVEMENT: Rebalanced from 10:1:1 to 20:0.5:0.5 | |
| # This prioritizes classification learning over regression | |
| self.task_weights = { | |
| 'classification': 20.0, # Increased from 1.0 to 20.0 | |
| 'severity': 0.5, # Decreased from 0.5 to 0.5 | |
| 'importance': 0.5 # Decreased from 0.5 to 0.5 | |
| } | |
| # Global configuration instance | |
| config = LegalBertConfig() |