File size: 3,413 Bytes
9b1c753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21613a7
 
9b1c753
 
21613a7
 
9b1c753
21613a7
 
9b1c753
 
21613a7
 
 
 
 
 
 
 
 
9b1c753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21613a7
 
9b1c753
21613a7
 
 
9b1c753
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
Configuration settings for Legal-BERT training and risk discovery
"""
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch

@dataclass
class LegalBertConfig:
    """Configuration for Legal-BERT model training and risk discovery.

    Groups the hyperparameters for the BERT encoder and hierarchical head,
    the optimisation loop, the multi-task / focal loss, filesystem paths,
    and the unsupervised risk-discovery step (TF-IDF features plus LDA or a
    clustering method).

    Every field has a default, so ``LegalBertConfig()`` is a complete
    configuration; override individual fields as keyword arguments.
    ``task_weights`` is filled in by ``__post_init__`` when left as ``None``.
    """

    # Model parameters
    bert_model_name: str = "bert-base-uncased"
    # Will be dynamically determined by risk discovery.
    # NOTE(review): defaults to the same value as risk_discovery_clusters;
    # confirm which of the two downstream code treats as authoritative.
    num_risk_categories: int = 7
    max_sequence_length: int = 512
    dropout_rate: float = 0.1

    # Hierarchical model parameters (ALWAYS USED)
    hierarchical_hidden_dim: int = 512
    hierarchical_num_lstm_layers: int = 2

    # Training parameters - OPTIMIZED FOR BEST RESULTS
    batch_size: int = 16
    num_epochs: int = 20  # Increased to 20 for better convergence
    learning_rate: float = 2e-5  # Increased for OneCycleLR scheduler
    weight_decay: float = 0.01
    # NOTE(review): warmup is also expressed by scheduler_pct_start below;
    # confirm which setting applies when use_lr_scheduler is True.
    warmup_steps: int = 1000
    gradient_clip_norm: float = 1.0  # Prevent gradient explosion with high classification weight
    early_stopping_patience: int = 3  # Stop if val loss doesn't improve for 3 epochs

    # Multi-task loss weights - REBALANCED (Phase 1 improvements).
    # Changed from 10:1:1 to 20:0.5:0.5 to prioritize classification.
    # ``None`` means "use the defaults": __post_init__ then builds a fresh
    # dict for each instance, so default instances never share one dict.
    task_weights: Optional[Dict[str, float]] = None

    # Focal Loss parameters for hard example mining
    use_focal_loss: bool = True  # Use Focal Loss instead of CrossEntropyLoss
    focal_loss_gamma: float = 2.5  # Focus heavily on hard-to-classify examples
    # Boost weight for Classes 0 and 5 by 80%.
    # NOTE(review): class ids come from unsupervised risk discovery; confirm
    # that 0 and 5 are still the minority classes after re-running it.
    minority_class_boost: float = 1.8

    # Learning rate scheduling
    use_lr_scheduler: bool = True  # Use OneCycleLR for better convergence
    scheduler_pct_start: float = 0.1  # 10% of training for warmup

    # Device configuration. This default expression is evaluated once, when
    # the class body runs (module import), not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Paths
    data_path: str = "dataset/CUAD_v1/CUAD_v1.json"
    model_save_path: str = "models/legal_bert"
    checkpoint_dir: str = "checkpoints"

    # Risk discovery parameters - OPTIMIZED FOR BETTER PATTERN DISCOVERY
    risk_discovery_method: str = "lda"  # Options: 'lda', 'kmeans', 'hierarchical', 'nmf', 'gmm', etc.
    risk_discovery_clusters: int = 7  # Number of risk patterns/topics to discover
    tfidf_max_features: int = 15000  # Increased from 10000 for better vocabulary coverage
    tfidf_ngram_range: Tuple[int, int] = (1, 3)  # (min_n, max_n) n-gram sizes

    # LDA-specific parameters (used when risk_discovery_method='lda') - OPTIMIZED
    lda_doc_topic_prior: float = 0.1  # Alpha - controls document-topic density (lower = more focused)
    lda_topic_word_prior: float = 0.01  # Beta - controls topic-word density (lower = more focused)
    lda_max_iter: int = 50  # Increased from 20 to 50 for better convergence
    lda_max_features: int = 8000  # Increased from 5000 for richer topic modeling
    lda_learning_method: str = 'batch'  # 'batch' or 'online'

    def __post_init__(self) -> None:
        """Populate the default multi-task loss weights if none were given.

        Only ``None`` triggers the defaults. A caller-supplied dict is kept
        as-is; it is neither copied nor merged with the defaults.
        """
        if self.task_weights is None:
            # PHASE 1 IMPROVEMENT: rebalanced from 10:1:1 to 20:0.5:0.5.
            # This prioritizes classification learning over regression.
            self.task_weights = {
                'classification': 20.0,  # was 10 in the 10:1:1 scheme
                'severity': 0.5,         # was 1 in the 10:1:1 scheme
                'importance': 0.5        # was 1 in the 10:1:1 scheme
            }

# Module-level default configuration: one LegalBertConfig built at import
# time with every field at its default. Because it is a single shared
# object, mutating it is visible to every module that imports it.
config: LegalBertConfig = LegalBertConfig()