| """Model configuration for address NER."""
|
|
|
| from dataclasses import dataclass
|
|
|
|
|
| @dataclass
|
| class ModelConfig:
|
| """Configuration for BERT-CRF NER model."""
|
|
|
|
|
|
|
|
|
| model_name: str = "ai4bharat/IndicBERTv2-SS"
|
| use_crf: bool = True
|
|
|
|
|
| hidden_size: int = 768
|
| num_labels: int = 31
|
| hidden_dropout_prob: float = 0.1
|
| classifier_dropout: float = 0.1
|
|
|
|
|
| crf_reduction: str = "mean"
|
|
|
|
|
| max_length: int = 128
|
| learning_rate: float = 5e-5
|
| crf_learning_rate: float = 1e-3
|
| weight_decay: float = 0.01
|
| warmup_ratio: float = 0.1
|
| num_epochs: int = 10
|
| batch_size: int = 16
|
| gradient_accumulation_steps: int = 1
|
|
|
|
|
| label_smoothing: float = 0.0
|
|
|
|
|
| early_stopping_patience: int = 5
|
| early_stopping_threshold: float = 0.001
|
|
|
|
|
| lr_decay: float = 0.95
|
|
|
|
|
| output_dir: str = "./models"
|
| cache_dir: str | None = None
|
|
|
|
|
| onnx_opset_version: int = 14
|
|
|
| @classmethod
|
| def from_pretrained_name(cls, name: str) -> ModelConfig:
|
| """Create config for known pretrained models."""
|
| configs = {
|
| "mbert": cls(
|
| model_name="bert-base-multilingual-cased",
|
| hidden_size=768,
|
| ),
|
| "indicbert": cls(
|
| model_name="ai4bharat/IndicBERTv2-SS",
|
| hidden_size=768,
|
| ),
|
| "distilbert": cls(
|
| model_name="distilbert-base-multilingual-cased",
|
| hidden_size=768,
|
| ),
|
| "xlm-roberta": cls(
|
| model_name="xlm-roberta-base",
|
| hidden_size=768,
|
| ),
|
| "muril": cls(
|
| model_name="google/muril-base-cased",
|
| hidden_size=768,
|
| ),
|
| }
|
| return configs.get(name, cls())
|
|
|
|
|
|
|
# Entity types recognised by the tagger, in the order that fixes label ids.
ENTITY_LABELS = [
    "AREA",
    "SUBAREA",
    "HOUSE_NUMBER",
    "SECTOR",
    "GALI",
    "COLONY",
    "BLOCK",
    "CAMP",
    "POLE",
    "KHASRA",
    "FLOOR",
    "PLOT",
    "PINCODE",
    "CITY",
    "STATE",
]

# BIO tag set: "O" first, then every B- tag, then every I- tag, each group
# following the order of ENTITY_LABELS.
BIO_LABELS = [
    "O",
    *(f"B-{entity}" for entity in ENTITY_LABELS),
    *(f"I-{entity}" for entity in ENTITY_LABELS),
]

# Tag <-> integer id lookups; an id is the tag's position in BIO_LABELS.
LABEL2ID = dict(zip(BIO_LABELS, range(len(BIO_LABELS))))
ID2LABEL = dict(enumerate(BIO_LABELS))

# Total number of distinct tags.
NUM_LABELS = len(BIO_LABELS)
|
|
|