"""Model configuration for address NER.""" from dataclasses import dataclass @dataclass class ModelConfig: """Configuration for BERT-CRF NER model.""" # Base model - IndicBERTv2-SS recommended for Indian languages # Options: "bert-base-multilingual-cased", "ai4bharat/IndicBERTv2-SS", # "google/muril-base-cased", "xlm-roberta-base" model_name: str = "ai4bharat/IndicBERTv2-SS" use_crf: bool = True # Architecture hidden_size: int = 768 num_labels: int = 31 # O + 15 entity types * 2 (B-/I-) hidden_dropout_prob: float = 0.1 classifier_dropout: float = 0.1 # CRF settings crf_reduction: str = "mean" # 'mean' or 'sum' # Training max_length: int = 128 learning_rate: float = 5e-5 crf_learning_rate: float = 1e-3 # Higher LR for CRF weight_decay: float = 0.01 warmup_ratio: float = 0.1 num_epochs: int = 10 batch_size: int = 16 gradient_accumulation_steps: int = 1 # Label smoothing label_smoothing: float = 0.0 # Early stopping early_stopping_patience: int = 5 early_stopping_threshold: float = 0.001 # Layer-wise learning rate decay lr_decay: float = 0.95 # Paths output_dir: str = "./models" cache_dir: str | None = None # ONNX export onnx_opset_version: int = 14 @classmethod def from_pretrained_name(cls, name: str) -> ModelConfig: """Create config for known pretrained models.""" configs = { "mbert": cls( model_name="bert-base-multilingual-cased", hidden_size=768, ), "indicbert": cls( model_name="ai4bharat/IndicBERTv2-SS", hidden_size=768, ), "distilbert": cls( model_name="distilbert-base-multilingual-cased", hidden_size=768, ), "xlm-roberta": cls( model_name="xlm-roberta-base", hidden_size=768, ), "muril": cls( model_name="google/muril-base-cased", hidden_size=768, ), } return configs.get(name, cls()) # Entity label definitions (must match schemas.py) ENTITY_LABELS = [ "AREA", "SUBAREA", "HOUSE_NUMBER", "SECTOR", "GALI", "COLONY", "BLOCK", "CAMP", "POLE", "KHASRA", "FLOOR", "PLOT", "PINCODE", "CITY", "STATE", ] # Generate BIO labels BIO_LABELS = ["O"] + [f"B-{label}" for label in ENTITY_LABELS] + [f"I-{label}" for label in ENTITY_LABELS] LABEL2ID = {label: i for i, label in enumerate(BIO_LABELS)} ID2LABEL = {i: label for i, label in enumerate(BIO_LABELS)} NUM_LABELS = len(BIO_LABELS)