# Hugging Face upload artifact (commit 5363153, user x2aqq) — kept as a comment
# so the module remains valid Python.
"""Rule-based post-processing for entity refinement."""
import re
from address_parser.postprocessing.gazetteer import DelhiGazetteer
from address_parser.schemas import AddressEntity
class RuleBasedRefiner:
    """
    Post-processing rules for refining NER predictions.
    Handles:
    - Pattern-based entity detection (pincodes, khasra numbers)
    - Entity boundary correction using gazetteer
    - Entity merging for fragmented predictions
    - Confidence adjustment
    - Validation and filtering
    """

    # Regex patterns for deterministic entities.
    PATTERNS = {
        # Indian PIN codes: exactly six digits, first digit non-zero.
        "PINCODE": re.compile(r'\b[1-9]\d{5}\b'),
        # Land-record numbers like "KH NO 123/4" or "KHASRA NO. 56/1-2".
        "KHASRA": re.compile(
            r'\b(?:KH\.?\s*(?:NO\.?)?\s*|KHASRA\s*(?:NO\.?)?\s*)[\d/]+(?:[/-]\d+)*\b',
            re.IGNORECASE
        ),
        # House/plot numbers; the prefix is optional, so bare "123" also matches.
        "HOUSE_NUMBER": re.compile(
            r'\b(?:H\.?\s*(?:NO\.?)?\s*|HOUSE\s*(?:NO\.?)?\s*|PLOT\s*(?:NO\.?)?\s*)?[A-Z]?\d+[A-Z]?(?:[-/]\d+)*\b',
            re.IGNORECASE
        ),
        # Floor designator; the ordinal prefix is optional ("FLOOR" alone matches).
        "FLOOR": re.compile(
            r'\b(?:GROUND|FIRST|SECOND|THIRD|FOURTH|FIFTH|1ST|2ND|3RD|4TH|5TH|GF|FF|SF|TF)?\s*(?:FLOOR|FLR)\b',
            re.IGNORECASE
        ),
        # NOTE(review): the bare "BL" alternative can match inside ordinary
        # words (e.g. "BLUE") — confirm whether a separator should be required.
        "BLOCK": re.compile(
            r'\b(?:BLOCK|BLK|BL)\s*[A-Z]?[-]?[A-Z0-9]+\b',
            re.IGNORECASE
        ),
        "SECTOR": re.compile(
            r'\b(?:SECTOR|SEC)\s*\d+[A-Z]?\b',
            re.IGNORECASE
        ),
        "GALI": re.compile(
            r'\b(?:GALI|GALLI|LANE)\s*(?:NO\.?)?\s*\d+[A-Z]?\b',
            re.IGNORECASE
        ),
    }

    # Area patterns - directional areas.
    # Fix: compound directions (SOUTH WEST DELHI, ...) are listed FIRST.
    # _fix_known_localities claims text spans in list order, so with the old
    # ordering "WEST DELHI" claimed the tail of "SOUTH WEST DELHI" and the
    # correct compound area was then skipped as overlapping.
    AREA_PATTERNS = [
        (re.compile(r'\bSOUTH\s+WEST\s+DELHI\b', re.IGNORECASE), "SOUTH WEST DELHI"),
        (re.compile(r'\bNORTH\s+WEST\s+DELHI\b', re.IGNORECASE), "NORTH WEST DELHI"),
        (re.compile(r'\bNORTH\s+EAST\s+DELHI\b', re.IGNORECASE), "NORTH EAST DELHI"),
        (re.compile(r'\bSOUTH\s+EAST\s+DELHI\b', re.IGNORECASE), "SOUTH EAST DELHI"),
        (re.compile(r'\bSOUTH\s+DELHI\b', re.IGNORECASE), "SOUTH DELHI"),
        (re.compile(r'\bNORTH\s+DELHI\b', re.IGNORECASE), "NORTH DELHI"),
        (re.compile(r'\bEAST\s+DELHI\b', re.IGNORECASE), "EAST DELHI"),
        (re.compile(r'\bWEST\s+DELHI\b', re.IGNORECASE), "WEST DELHI"),
        (re.compile(r'\bCENTRAL\s+DELHI\b', re.IGNORECASE), "CENTRAL DELHI"),
        (re.compile(r'\bOUTER\s+DELHI\b', re.IGNORECASE), "OUTER DELHI"),
    ]

    # City patterns ("NEW DELHI" listed before "DELHI" so the longer name wins).
    CITY_PATTERNS = [
        (re.compile(r'\bNEW\s+DELHI\b', re.IGNORECASE), "NEW DELHI"),
        (re.compile(r'\bDELHI\b', re.IGNORECASE), "DELHI"),
        (re.compile(r'\bNOIDA\b', re.IGNORECASE), "NOIDA"),
        (re.compile(r'\bGURUGRAM\b', re.IGNORECASE), "GURUGRAM"),
        (re.compile(r'\bGURGAON\b', re.IGNORECASE), "GURGAON"),
        (re.compile(r'\bFARIDABAD\b', re.IGNORECASE), "FARIDABAD"),
        (re.compile(r'\bGHAZIABAD\b', re.IGNORECASE), "GHAZIABAD"),
    ]

    # State patterns
    STATE_PATTERNS = [
        (re.compile(r'\bDELHI\b', re.IGNORECASE), "DELHI"),
        (re.compile(r'\bHARYANA\b', re.IGNORECASE), "HARYANA"),
        (re.compile(r'\bUTTAR\s+PRADESH\b', re.IGNORECASE), "UTTAR PRADESH"),
        # Fix: re.IGNORECASE added so "u.p." / "U.p." match, consistent with
        # every other pattern in this class.
        (re.compile(r'\bU\.?\s*P\.?\b', re.IGNORECASE), "UTTAR PRADESH"),
    ]

    # Colony/Nagar indicators (suffixes that mark a locality name).
    COLONY_SUFFIXES = [
        "NAGAR", "VIHAR", "COLONY", "ENCLAVE", "PARK", "GARDEN",
        "PURI", "BAGH", "KUNJ", "EXTENSION", "EXTN", "PHASE",
    ]

    # Known multi-word localities that the NER model tends to fragment.
    KNOWN_LOCALITIES = [
        "LAJPAT NAGAR", "MALVIYA NAGAR", "KAROL BAGH", "HAUZ KHAS",
        "GREEN PARK", "GREATER KAILASH", "DEFENCE COLONY", "SOUTH EXTENSION",
        "CHITTARANJAN PARK", "NEHRU PLACE", "SARITA VIHAR", "VASANT KUNJ",
        "CIVIL LINES", "MODEL TOWN", "MUKHERJEE NAGAR", "KAMLA NAGAR",
        "ASHOK VIHAR", "SHALIMAR BAGH", "PREET VIHAR", "MAYUR VIHAR",
        "LAKSHMI NAGAR", "GANDHI NAGAR", "DILSHAD GARDEN", "ANAND VIHAR",
        "UTTAM NAGAR", "TILAK NAGAR", "RAJOURI GARDEN", "PUNJABI BAGH",
        "PASCHIM VIHAR", "CONNAUGHT PLACE", "RAJENDER NAGAR", "PATEL NAGAR",
        "KIRTI NAGAR", "LODHI ROAD", "GOLF LINKS", "SANGAM VIHAR",
        "GOVINDPURI", "AMBEDKAR NAGAR", "LADO SARAI", "KAUNWAR SINGH NAGAR",
        "BABA HARI DAS COLONY", "SWARN PARK", "CHANCHAL PARK", "DURGA PARK",
        "RAJ NAGAR", "SADH NAGAR", "VIJAY ENCLAVE", "PALAM COLONY",
    ]
def __init__(self, use_gazetteer: bool = True):
    """
    Initialize refiner.

    Args:
        use_gazetteer: When True, build a DelhiGazetteer for
            validation/correction; otherwise gazetteer lookups are skipped.
    """
    self.gazetteer = None
    if use_gazetteer:
        self.gazetteer = DelhiGazetteer()
def refine(
    self,
    text: str,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """
    Refine entity predictions.

    Args:
        text: Original address text
        entities: Predicted entities from NER model

    Returns:
        Refined list of entities
    """
    # Stages that need the raw text, applied in order. Order matters:
    # gazetteer locality fixes run first, overlap removal and validation last.
    text_stages = (
        self._fix_known_localities,     # gazetteer-known localities
        self._add_pattern_entities,     # rule-based entities the model missed
        self._add_area_patterns,        # directional areas (SOUTH DELHI, ...)
        self._correct_boundaries,       # snap spans to full pattern matches
        self._merge_fragmented_entities,  # join split locality fragments
        self._adjust_confidence,        # pattern/gazetteer-based score tweaks
    )
    refined = list(entities)
    for stage in text_stages:
        refined = stage(text, refined)
    # Final stages operate on the entity list alone.
    refined = self._remove_overlaps(refined)
    return self._validate_entities(refined)
def _fix_known_localities(
    self,
    text: str,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """Fix fragmented known localities using gazetteer lookup.

    Scans the raw text for known multi-word localities and directional
    areas, emits high-confidence entities for them, and drops any
    locality-type model prediction that overlaps a detected span.

    Fixes over the previous version:
    - substring hits are now required to sit on word boundaries, so e.g.
      "GREEN PARK" no longer matches inside "GREEN PARKING";
    - locality hits are checked against already-claimed spans, so two
      overlapping list entries can no longer both produce entities.
    """
    text_upper = text.upper()
    used_ranges: list[tuple[int, int]] = []

    def _overlaps_used(start: int, end: int) -> bool:
        # Half-open interval overlap test against all claimed spans.
        return any(not (end <= s or start >= e) for s, e in used_ranges)

    locality_entities: list[AddressEntity] = []
    # First pass: find every occurrence of each known locality in text.
    for locality in self.KNOWN_LOCALITIES:
        idx = 0
        while True:
            pos = text_upper.find(locality, idx)
            if pos == -1:
                break
            end = pos + len(locality)
            # Require word boundaries on both sides of the substring hit.
            boundary_ok = (
                (pos == 0 or not text_upper[pos - 1].isalnum())
                and (end == len(text_upper) or not text_upper[end].isalnum())
            )
            if boundary_ok and not _overlaps_used(pos, end):
                locality_entities.append(AddressEntity(
                    label="SUBAREA",
                    value=text[pos:end],
                    start=pos,
                    end=end,
                    confidence=0.95
                ))
                used_ranges.append((pos, end))
            idx = end
    # Also check directional-area patterns (SOUTH DELHI, etc.).
    for pattern, area_name in self.AREA_PATTERNS:
        match = pattern.search(text)
        if match:
            start, end = match.start(), match.end()
            if not _overlaps_used(start, end):
                locality_entities.append(AddressEntity(
                    label="AREA",
                    value=area_name,
                    start=start,
                    end=end,
                    confidence=0.95
                ))
                used_ranges.append((start, end))
    # Drop fragmented locality-type predictions that overlap a detected span.
    result = []
    for entity in entities:
        overlaps_locality = _overlaps_used(entity.start, entity.end)
        if overlaps_locality and entity.label in ("AREA", "SUBAREA", "COLONY", "CITY"):
            # Skip this fragmented entity in favour of the detected locality.
            continue
        result.append(entity)
    # Add the locality entities.
    result.extend(locality_entities)
    return result
def _add_area_patterns(
    self,
    text: str,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """Pass-through stage for directional areas (SOUTH DELHI, etc.).

    Detection now lives in _fix_known_localities to avoid emitting
    duplicates; this stage is kept so the refine() pipeline stays explicit.
    """
    return entities
def _merge_fragmented_entities(
    self,
    text: str,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """Merge adjacent entities of same type that should be together.

    The NER model sometimes splits a multi-word locality (e.g. "LAJPAT" /
    "NAGAR") into consecutive fragments. Consecutive locality-type
    entities separated by at most 2 characters are merged when the
    combined text looks like a real locality (see _is_valid_merge).
    """
    if len(entities) < 2:
        return entities
    # Sort by position
    sorted_entities = sorted(entities, key=lambda e: e.start)
    result = []
    i = 0
    while i < len(sorted_entities):
        current = sorted_entities[i]
        # Look for adjacent entities to merge
        if current.label in ("AREA", "SUBAREA", "COLONY", "CITY"):
            merged_end = current.end
            merged_confidence = current.confidence
            j = i + 1
            # Check subsequent entities; the run grows greedily until a
            # fragment fails the gap/label/validity checks.
            while j < len(sorted_entities):
                next_ent = sorted_entities[j]
                # Check if adjacent (within 2 chars - allows for space)
                gap = next_ent.start - merged_end
                if gap <= 2 and next_ent.label in ("AREA", "SUBAREA", "COLONY", "CITY"):
                    # Check if the merged text forms a known locality
                    merged_text = text[current.start:next_ent.end].strip()
                    if self._is_valid_merge(merged_text):
                        merged_end = next_ent.end
                        # Merged entity keeps the best fragment confidence.
                        merged_confidence = max(merged_confidence, next_ent.confidence)
                        j += 1
                    else:
                        break
                else:
                    break
            # Create merged entity if we merged anything (the merged span
            # replaces every consumed fragment; label comes from the first).
            if j > i + 1:
                merged_value = text[current.start:merged_end].strip()
                result.append(AddressEntity(
                    label=current.label,
                    value=merged_value,
                    start=current.start,
                    end=merged_end,
                    confidence=merged_confidence
                ))
                # Skip past every fragment we consumed.
                i = j
                continue
        result.append(current)
        i += 1
    return result
def _is_valid_merge(self, text: str) -> bool:
    """Return True when the merged text looks like a single locality name."""
    candidate = text.upper().strip()
    # Exact hit in the hard-coded locality list.
    if candidate in self.KNOWN_LOCALITIES:
        return True
    # Fuzzy hit in the gazetteer, when one is configured.
    if self.gazetteer and self.gazetteer.is_known_locality(candidate, threshold=80):
        return True
    # A trailing colony/nagar-style suffix is strong evidence of a locality.
    return candidate.endswith(tuple(self.COLONY_SUFFIXES))
def _add_pattern_entities(
    self,
    text: str,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """Add entities detected by regex patterns.

    Adds a PINCODE, CITY and STATE entity only when the model produced
    none of that label. Note existing spans are compared for *exact*
    equality, not overlap, when deciding whether to add.
    """
    result = list(entities)
    existing_spans = {(e.start, e.end) for e in entities}
    # Check for pincode
    if not any(e.label == "PINCODE" for e in entities):
        match = self.PATTERNS["PINCODE"].search(text)
        if match and (match.start(), match.end()) not in existing_spans:
            result.append(AddressEntity(
                label="PINCODE",
                value=match.group(0),
                start=match.start(),
                end=match.end(),
                confidence=1.0  # Rule-based, high confidence
            ))
    # Check for city - DELHI addresses always have DELHI as city
    has_city = any(e.label == "CITY" for e in result)
    if not has_city:
        # If text contains DELHI anywhere, set city to DELHI
        if "DELHI" in text.upper():
            # Find the last occurrence of DELHI (usually the city mention)
            delhi_positions = [m.start() for m in re.finditer(r'\bDELHI\b', text.upper())]
            if delhi_positions:
                pos = delhi_positions[-1]  # Use last occurrence
                result.append(AddressEntity(
                    label="CITY",
                    value="DELHI",
                    start=pos,
                    end=pos + 5,  # len("DELHI") == 5
                    confidence=0.90
                ))
        else:
            # Check other city patterns; stop at the first that matches.
            for pattern, city_name in self.CITY_PATTERNS:
                if city_name == "DELHI":
                    continue  # Already handled above
                match = pattern.search(text)
                if match and (match.start(), match.end()) not in existing_spans:
                    result.append(AddressEntity(
                        label="CITY",
                        value=city_name,
                        start=match.start(),
                        end=match.end(),
                        confidence=0.95
                    ))
                    break
    # Check for state; first matching pattern wins.
    if not any(e.label == "STATE" for e in entities):
        for pattern, state_name in self.STATE_PATTERNS:
            match = pattern.search(text)
            if match and (match.start(), match.end()) not in existing_spans:
                # Avoid tagging "DELHI" as state if it's already a city
                if state_name == "DELHI" and any(e.label == "CITY" and "DELHI" in e.value.upper() for e in result):
                    continue
                result.append(AddressEntity(
                    label="STATE",
                    value=state_name,
                    start=match.start(),
                    end=match.end(),
                    confidence=0.90
                ))
                break
    return result
def _correct_boundaries(
    self,
    text: str,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """Snap KHASRA/BLOCK/FLOOR entities onto their full regex match.

    The NER model often tags only part of a compound token ("KH NO 123/4");
    when the label's pattern matches at the entity's position, the span is
    widened to the whole match. Fix over the previous version: only a match
    that OVERLAPS the entity is accepted — `pattern.search(text)` used to
    take the first match anywhere in the text, which could relocate an
    entity onto an unrelated span when the pattern occurred more than once.
    All values are also stripped of leading/trailing whitespace.
    """
    expandable = ("KHASRA", "BLOCK", "FLOOR")
    result = []
    for entity in entities:
        updates: dict[str, object] = {}
        if entity.label in expandable:
            pattern = self.PATTERNS[entity.label]
            # First match whose span intersects the entity's span, if any.
            match = next(
                (m for m in pattern.finditer(text)
                 if m.start() < entity.end and m.end() > entity.start),
                None,
            )
            if match:
                updates = {"value": match.group(0), "start": match.start(), "end": match.end()}
        # Clean up leading/trailing whitespace from whichever value applies
        # (.get with default avoids the falsy-string pitfall of `or`).
        final_value = updates.get("value", entity.value).strip()
        if final_value != entity.value or updates:
            updates["value"] = final_value
        result.append(entity.model_copy(update=updates) if updates else entity)
    return result
def _adjust_confidence(
    self,
    text: str,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """Nudge confidence scores using regex and gazetteer agreement."""
    adjusted = []
    for entity in entities:
        score = entity.confidence
        # A full-value regex match for the label is strong evidence: +0.1.
        pattern = self.PATTERNS.get(entity.label)
        if pattern is not None and pattern.fullmatch(entity.value):
            score = min(1.0, score + 0.1)
        # Gazetteer-confirmed locality names get a further +0.15.
        if self.gazetteer and entity.label in ("AREA", "SUBAREA", "COLONY"):
            if self.gazetteer.is_known_locality(entity.value):
                score = min(1.0, score + 0.15)
        # Very short values are frequently noise: -0.2.
        if len(entity.value) < 3:
            score = max(0.0, score - 0.2)
        adjusted.append(
            entity if score == entity.confidence
            else entity.model_copy(update={"confidence": score})
        )
    return adjusted
def _remove_overlaps(
    self,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """Drop overlapping entities, keeping the higher-confidence ones.

    CITY, PINCODE and STATE entities are exempt from overlap removal:
    they describe a different semantic level than AREA/SUBAREA and are
    always retained.
    """
    if not entities:
        return entities
    preserved = {"CITY", "PINCODE", "STATE"}
    always_kept = [e for e in entities if e.label in preserved]
    candidates = [e for e in entities if e.label not in preserved]
    # Greedy selection: highest confidence first, earlier start breaks ties.
    selected: list[AddressEntity] = []
    claimed: list[tuple[int, int]] = []
    for candidate in sorted(candidates, key=lambda e: (-e.confidence, e.start)):
        free = all(
            candidate.end <= start or candidate.start >= end
            for start, end in claimed
        )
        if free:
            selected.append(candidate)
            claimed.append((candidate.start, candidate.end))
    # Re-attach preserved entities and order everything by position.
    selected.extend(always_kept)
    return sorted(selected, key=lambda e: e.start)
def _validate_entities(
    self,
    entities: list[AddressEntity]
) -> list[AddressEntity]:
    """Final filter: drop empty/low-confidence entities and bad pincodes.

    A pincode outside the gazetteer's known range is kept but penalised
    (confidence scaled by 0.7) rather than discarded.
    """
    valid = []
    for entity in entities:
        # Guard clauses: empty values and very low confidence are dropped.
        if not entity.value.strip():
            continue
        if entity.confidence < 0.3:
            continue
        if entity.label == "PINCODE":
            # Must be exactly six digits, not starting with 0.
            if re.fullmatch(r'[1-9]\d{5}', entity.value) is None:
                continue
            if self.gazetteer and not self.gazetteer.validate_pincode(entity.value):
                # Pincode outside Delhi range - reduce confidence but keep.
                entity = entity.model_copy(update={"confidence": entity.confidence * 0.7})
        valid.append(entity)
    return valid
def extract_all_patterns(self, text: str) -> dict[str, list[str]]:
    """
    Extract all pattern-based entities from text.

    Returns dict of label -> list of matched values; labels with no
    matches are omitted entirely.
    """
    return {
        label: hits
        for label, pattern in self.PATTERNS.items()
        if (hits := pattern.findall(text))
    }