# Stack-2-9-finetuned/src/enhancements/nlp/entity_recognition.py
# walidsobhie-code
# feat: Add comprehensive enhancement modules for Stack 2.9
# 8f05ad1
"""
Entity Recognition Module
Provides Named Entity Recognition (NER) for extracting entities from text.
"""
import logging
import re
from collections import Counter
from typing import Any, Dict, List, Optional
class EntityRecognizer:
    """Extract named entities from text using pattern matching and NER.

    Two sources are combined:

    * an optional transformer token-classification model, which yields
      PERSON / ORGANIZATION / LOCATION / MISC entities, and
    * a set of regular expressions (emails, URLs, phone numbers, ...).

    Overlapping candidates are resolved by keeping the one with the higher
    score, then the longer span.
    """

    # Regex source for the rule-based recognizer, keyed by entity type.
    # Insertion order matters only as the final tie-break in overlap
    # resolution (the sort there is stable).
    _DEFAULT_PATTERNS = {
        # TLD class is [A-Za-z]: a '|' inside [...] is a literal pipe, not
        # alternation, and would let "a@b.c|d" match as an email.
        "EMAIL": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        "URL": r'https?://[^\s]+',
        "PHONE": r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
        "IP_ADDRESS": r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b',
        "DATE": r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',
        "TIME": r'\b\d{1,2}:\d{2}(?:\s*[AaPp][Mm])?\b',
        "FILE_PATH": r'(?:/[a-zA-Z0-9_.-]+)+',
        "CODE": r'`[^`]+`',
        "QUOTED_STRING": r'"[^"]*"|\'[^\']*\'',
    }

    # Model tag -> public entity type. Tags not listed here are dropped.
    _TAG_MAPPING = {
        "PER": "PERSON",
        "ORG": "ORGANIZATION",
        "LOC": "LOCATION",
        "MISC": "MISC",
    }

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        use_transformers: bool = True,
        model_name: str = "dslim/bert-base-NER",
    ):
        """
        Initialize the entity recognizer.

        Args:
            use_transformers: Whether to use transformer-based NER. Silently
                downgraded to False when the ``transformers`` package cannot
                be imported.
            model_name: Name of the NER model (if using transformers)
        """
        self.use_transformers = use_transformers and self._check_transformers()
        self.model_name = model_name
        # Model, tokenizer and pipeline are created lazily on first use.
        self._model = None
        self._tokenizer = None
        self._ner_pipeline = None
        # Per-instance copy so one instance's patterns can be adjusted
        # without affecting others.
        self._patterns = dict(self._DEFAULT_PATTERNS)

    def _check_transformers(self) -> bool:
        """Check if transformers is available."""
        try:
            import transformers  # noqa: F401 - imported only to probe availability
        except ImportError:
            return False
        return True

    def _load_transformer_model(self):
        """Lazy load the transformer NER model.

        On any failure a warning is logged, ``use_transformers`` is switched
        off and ``_model`` / ``_tokenizer`` stay ``None``.
        """
        if self._model is not None:
            return
        try:
            from transformers import AutoModelForTokenClassification, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModelForTokenClassification.from_pretrained(self.model_name)
            model.eval()
        except Exception as exc:
            self._logger.warning("Could not load transformer NER model: %s", exc)
            self.use_transformers = False
            return
        # Publish both only after both loaded, so the pair is never half-set.
        self._tokenizer = tokenizer
        self._model = model

    def recognize_entities(self, text: str) -> List[Dict[str, Any]]:
        """
        Recognize entities in text.

        Args:
            text: Input text

        Returns:
            List of entity dictionaries with 'text', 'type', 'start', 'end'
            and 'score', sorted by start offset and free of overlaps. Empty
            input yields an empty list.
        """
        if not text:
            return []

        entities: List[Dict[str, Any]] = []

        # Transformer NER is best-effort: a failure must not prevent the
        # rule-based results from being returned, but it should be visible.
        if self.use_transformers:
            try:
                entities.extend(self._recognize_transformers(text))
            except Exception:
                self._logger.warning(
                    "Transformer NER failed; using rule-based patterns only",
                    exc_info=True,
                )

        entities.extend(self._recognize_patterns(text))
        return self._resolve_overlaps(entities)

    def _recognize_transformers(self, text: str) -> List[Dict[str, Any]]:
        """Use transformer model for NER.

        Returns an empty list when the model could not be loaded.
        """
        self._load_transformer_model()
        if self._model is None:
            # Do not build a pipeline with model=None: per the Transformers
            # docs, pipeline() then falls back to the task's default
            # checkpoint instead of failing.
            return []

        if self._ner_pipeline is None:
            from transformers import pipeline

            self._ner_pipeline = pipeline(
                "ner",
                model=self._model,
                tokenizer=self._tokenizer,
                aggregation_strategy="simple",
            )

        entities = []
        for result in self._ner_pipeline(text):
            entity_type = self._map_ner_tag(result.get("entity_group", ""))
            if not entity_type:
                continue
            entities.append({
                "text": result["word"],
                "type": entity_type,
                # Offsets may be missing or None (presumably when the
                # tokenizer provides no offset mapping); normalise to 0 so
                # overlap arithmetic cannot hit a TypeError.
                "start": result.get("start") or 0,
                "end": result.get("end") or 0,
                # Plain float, regardless of the numeric type the pipeline
                # returns.
                "score": float(result.get("score", 1.0)),
            })
        return entities

    def _map_ner_tag(self, tag: str) -> Optional[str]:
        """Map NER tags to standard entity types (None for unknown tags)."""
        return self._TAG_MAPPING.get(tag)

    def _recognize_patterns(self, text: str) -> List[Dict[str, Any]]:
        """Use pattern matching for entity recognition.

        Every regex match becomes an entity with score 1.0.
        """
        return [
            {
                "text": match.group(),
                "type": entity_type,
                "start": match.start(),
                "end": match.end(),
                "score": 1.0,
            }
            for entity_type, pattern in self._patterns.items()
            for match in re.finditer(pattern, text)
        ]

    def _resolve_overlaps(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Remove overlapping entities, keeping the higher confidence one.

        Candidates are ranked by score (descending), then span length
        (descending); each is accepted only if it overlaps nothing already
        accepted. The survivors are returned sorted by start offset.
        """
        if not entities:
            return []

        ranked = sorted(
            entities,
            key=lambda e: (-e.get("score", 1.0), -(e["end"] - e["start"])),
        )

        kept: List[Dict[str, Any]] = []
        for candidate in ranked:
            if not any(self._overlaps(candidate, chosen) for chosen in kept):
                kept.append(candidate)

        kept.sort(key=lambda e: e["start"])
        return kept

    def _overlaps(self, e1: Dict[str, Any], e2: Dict[str, Any]) -> bool:
        """Check if two entities overlap (half-open spans; touching is not overlap)."""
        return e1["start"] < e2["end"] and e2["start"] < e1["end"]

    def extract_entities_by_type(
        self,
        text: str,
        entity_type: str,
    ) -> List[str]:
        """
        Extract all entities of a specific type.

        Args:
            text: Input text
            entity_type: Type of entity to extract

        Returns:
            List of entity texts
        """
        return [
            entity["text"]
            for entity in self.recognize_entities(text)
            if entity["type"] == entity_type
        ]

    def get_entity_summary(self, text: str) -> Dict[str, int]:
        """
        Get a summary of entity counts by type.

        Args:
            text: Input text

        Returns:
            Dictionary mapping entity type to count
        """
        counts = Counter(entity["type"] for entity in self.recognize_entities(text))
        return dict(counts)

    def __repr__(self) -> str:
        return f"EntityRecognizer(use_transformers={self.use_transformers}, model='{self.model_name}')"