File size: 2,854 Bytes
47bc13b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Model configuration for address NER."""

from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Configuration for BERT-CRF NER model."""

    # Base model - IndicBERTv2-SS recommended for Indian languages
    # Options: "bert-base-multilingual-cased", "ai4bharat/IndicBERTv2-SS",
    #          "google/muril-base-cased", "xlm-roberta-base"
    model_name: str = "ai4bharat/IndicBERTv2-SS"
    use_crf: bool = True

    # Architecture
    hidden_size: int = 768
    num_labels: int = 31  # O + 15 entity types * 2 (B-/I-)
    hidden_dropout_prob: float = 0.1
    classifier_dropout: float = 0.1

    # CRF settings
    crf_reduction: str = "mean"  # 'mean' or 'sum'

    # Training
    max_length: int = 128
    learning_rate: float = 5e-5
    crf_learning_rate: float = 1e-3  # Higher LR for CRF
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    num_epochs: int = 10
    batch_size: int = 16
    gradient_accumulation_steps: int = 1

    # Label smoothing
    label_smoothing: float = 0.0

    # Early stopping
    early_stopping_patience: int = 5
    early_stopping_threshold: float = 0.001

    # Layer-wise learning rate decay
    lr_decay: float = 0.95

    # Paths
    output_dir: str = "./models"
    cache_dir: str | None = None

    # ONNX export
    onnx_opset_version: int = 14

    @classmethod
    def from_pretrained_name(cls, name: str) -> ModelConfig:
        """Create config for known pretrained models."""
        configs = {
            "mbert": cls(
                model_name="bert-base-multilingual-cased",
                hidden_size=768,
            ),
            "indicbert": cls(
                model_name="ai4bharat/IndicBERTv2-SS",
                hidden_size=768,
            ),
            "distilbert": cls(
                model_name="distilbert-base-multilingual-cased",
                hidden_size=768,
            ),
            "xlm-roberta": cls(
                model_name="xlm-roberta-base",
                hidden_size=768,
            ),
            "muril": cls(
                model_name="google/muril-base-cased",
                hidden_size=768,
            ),
        }
        return configs.get(name, cls())


# Entity label definitions (must match schemas.py)
ENTITY_LABELS = [
    "AREA",
    "SUBAREA",
    "HOUSE_NUMBER",
    "SECTOR",
    "GALI",
    "COLONY",
    "BLOCK",
    "CAMP",
    "POLE",
    "KHASRA",
    "FLOOR",
    "PLOT",
    "PINCODE",
    "CITY",
    "STATE",
]

# Derive the BIO tag set: the outside tag first, then every B- tag,
# then every I- tag, preserving ENTITY_LABELS order within each group.
BIO_LABELS = [
    "O",
    *(f"B-{tag}" for tag in ENTITY_LABELS),
    *(f"I-{tag}" for tag in ENTITY_LABELS),
]

# Bidirectional tag <-> index mappings used by the model head.
ID2LABEL = dict(enumerate(BIO_LABELS))
LABEL2ID = {tag: idx for idx, tag in ID2LABEL.items()}
NUM_LABELS = len(BIO_LABELS)