# simple/ner.py — spaCy NER wrapper (Hugging Face Hub model with local fallback)
# (provenance: commit e611cc9, uploaded by kn29)
import logging
import os
from typing import Any, Dict, List, Optional

import spacy
from huggingface_hub import snapshot_download
HF_MODEL_ID = "kn29/my-ner-model"
logger = logging.getLogger(__name__)
# Global variable to store the loaded model
_nlp_model = None
def _initialize_model(model_id: Optional[str] = None):
    """Load the spaCy NER pipeline, caching it in the module global.

    Tries the Hugging Face Hub repo first (honoring HUGGINGFACE_TOKEN /
    HF_TOKEN for private repos), then falls back to the stock
    "en_core_web_sm" pipeline.

    Args:
        model_id: Hugging Face repo id; defaults to HF_MODEL_ID.
            NOTE: ignored once a model is cached — the first loaded model wins
            for the lifetime of the process.

    Returns:
        The loaded spaCy Language object.

    Raises:
        RuntimeError: if both the Hub model and the fallback fail to load.
    """
    global _nlp_model
    if _nlp_model is not None:
        return _nlp_model
    if model_id is None:
        model_id = HF_MODEL_ID
    try:
        logger.info(f"Loading NER model from Hugging Face: {model_id}")
        token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
        # os.getenv already yields None when unset; `or None` also maps "" to None.
        local_dir = snapshot_download(repo_id=model_id, token=token or None)
        _nlp_model = spacy.load(local_dir)
        logger.info(
            f"Successfully loaded NER model from {model_id} (token={'yes' if token else 'no'})"
        )
    except Exception as e:
        logger.error(f"Failed to load NER model from {model_id}: {str(e)}")
        # Fallback to standard English model
        try:
            logger.info("Falling back to standard English model")
            _nlp_model = spacy.load("en_core_web_sm")
        except Exception as fallback_error:
            logger.error(f"Fallback model also failed: {str(fallback_error)}")
            # BUG FIX: raise a concrete exception type and chain the cause so
            # the full failure story survives in the traceback (original raised
            # bare Exception with no __cause__). RuntimeError still satisfies
            # callers that catch Exception.
            raise RuntimeError(f"No spaCy model available: {str(e)}") from fallback_error
    return _nlp_model
def process_text(text: str, model_id: Optional[str] = None) -> Dict[str, Any]:
    """Run NER over *text* and return the entities found.

    Args:
        text: Input text to analyze.
        model_id: Optional Hugging Face repo id; defaults to HF_MODEL_ID
            (only honored on the first call — the loaded model is cached).

    Returns:
        On success: {"entities": [{text, label, start, end}, ...],
        "entity_counts": {label: {"entities": [...], "count": n}},
        "total_entities": int, "unique_labels": [...]}.
        On failure: {"error": str, "entities": [], "entity_counts": {},
        "total_entities": 0} — errors never propagate to the caller.
    """
    try:
        nlp = _initialize_model(model_id)
        # NOTE(review): spaCy's default nlp.max_length is 1,000,000 chars, so a
        # single call on texts between 1M and 4M chars may still raise —
        # confirm the loaded pipeline's max_length covers this threshold.
        if len(text) > 4000000:
            logger.info(f"Text too large ({len(text)} chars), processing in chunks")
            return _process_large_text(text, nlp)
        doc = nlp(text)
        entities = []
        entity_counts = {}
        for ent in doc.ents:
            for entity_text, entity_label in _process_entity(ent):
                # BUG FIX: when _process_entity splits a span (e.g. "A and B"),
                # locate each part inside the span text so start/end describe
                # the part itself rather than reusing the whole span's offsets.
                offset = ent.text.find(entity_text)
                start = ent.start_char + (offset if offset >= 0 else 0)
                entities.append({
                    "text": entity_text,
                    "label": entity_label,
                    "start": start,
                    "end": start + len(entity_text),
                })
                entity_counts.setdefault(entity_label, []).append(entity_text)
        # Collapse each label's mentions into a unique list + count summary.
        for label, mentions in entity_counts.items():
            unique_entities = list(set(mentions))
            entity_counts[label] = {
                "entities": unique_entities,
                "count": len(unique_entities),
            }
        return {
            "entities": entities,
            "entity_counts": entity_counts,
            "total_entities": len(entities),
            "unique_labels": list(entity_counts.keys()),
        }
    except Exception as e:
        # Broad catch is deliberate: callers always get a structured payload.
        logger.error(f"Error processing text with NER: {str(e)}")
        return {
            "error": str(e),
            "entities": [],
            "entity_counts": {},
            "total_entities": 0,
        }
def _process_large_text(text: str, nlp, chunk_size: int = 3000000) -> Dict[str, Any]:
    """Process text too large for a single spaCy call by fixed-size chunking.

    NOTE: chunks are split at fixed character offsets, so an entity straddling
    a chunk boundary may be missed or truncated.

    Args:
        text: Full input text.
        nlp: Loaded spaCy pipeline.
        chunk_size: Characters per chunk.

    Returns:
        Same shape as process_text's success payload, plus
        "processed_in_chunks": True and "num_chunks".
    """
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    all_entities = []
    all_entity_counts = {}
    for i, chunk in enumerate(chunks):
        logger.info(f"Processing chunk {i+1}/{len(chunks)}")
        try:
            doc = nlp(chunk)
        except Exception as e:
            # Best-effort: a failed chunk is skipped, the rest still processed.
            logger.error(f"Error processing chunk {i+1}: {str(e)}")
            continue
        base = i * chunk_size  # map chunk-local offsets back to the full text
        for ent in doc.ents:
            for entity_text, entity_label in _process_entity(ent):
                # BUG FIX: when _process_entity splits a span, locate each part
                # within the span so start/end are exact instead of reusing the
                # whole span's offsets.
                offset = ent.text.find(entity_text)
                start = base + ent.start_char + (offset if offset >= 0 else 0)
                all_entities.append({
                    "text": entity_text,
                    "label": entity_label,
                    "start": start,
                    "end": start + len(entity_text),
                })
                all_entity_counts.setdefault(entity_label, []).append(entity_text)
    # Collapse each label's mentions into a unique list + count summary.
    for label, mentions in all_entity_counts.items():
        unique_entities = list(set(mentions))
        all_entity_counts[label] = {
            "entities": unique_entities,
            "count": len(unique_entities),
        }
    return {
        "entities": all_entities,
        "entity_counts": all_entity_counts,
        "total_entities": len(all_entities),
        "unique_labels": list(all_entity_counts.keys()),
        "processed_in_chunks": True,
        "num_chunks": len(chunks),
    }
def _process_entity(ent) -> List[tuple]:
"""Process individual entity, handling special cases"""
if ent.label_ in ["PRECEDENT", "ORG"] and " and " in ent.text:
parts = ent.text.split(" and ")
return [(p.strip(), "ORG") for p in parts]
return [(ent.text, ent.label_)]