# src/components/model_nlp_ner.py
# Last update: commit 1339523 by samithcs (verified).
import tensorflow as tf
from transformers import DistilBertTokenizerFast, TFDistilBertForTokenClassification, pipeline
from huggingface_hub import snapshot_download
import numpy as np
import joblib
import os
# Fetch the fine-tuned DistilBERT NER artifacts from the Hugging Face Hub.
# This module-level code runs once, when the module is imported.
print("Downloading NER model from Hugging Face...")
repo_path = snapshot_download(
    repo_id="samithcs/nlp_ner",
    repo_type="model",
)
print(f"NER model downloaded to: {repo_path}")

# Artifact layout inside the downloaded snapshot.
NER_MODEL_PATH = os.path.join(repo_path, "nlp_ner", "ner_model")
NER_TOKENIZER_PATH = os.path.join(repo_path, "nlp_ner", "ner_tokenizer")
LABEL2ID_PATH = os.path.join(repo_path, "nlp_ner", "label2id.joblib")

ner_model = TFDistilBertForTokenClassification.from_pretrained(NER_MODEL_PATH)
ner_tokenizer = DistilBertTokenizerFast.from_pretrained(NER_TOKENIZER_PATH)
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code. That is only acceptable because the repository is trusted; consider
# pinning `revision=` in snapshot_download or storing the mapping as JSON.
label2id = joblib.load(LABEL2ID_PATH)
# Inverse mapping: predicted class index -> tag string (e.g. "B-LOC").
id2label = {i: t for t, i in label2id.items()}

print("Loading Hugging Face NER pipeline...")
# `grouped_entities=True` is deprecated in transformers; its documented
# equivalent is `aggregation_strategy="simple"`. That setting merges token
# pieces into whole entities exposing an `entity_group` key.
hf_ner = pipeline(
    "ner",
    aggregation_strategy="simple",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
)
print("NER models loaded successfully!")
def _bio_spans(words, labels, entity):
    """Group ``words`` into entity strings using strict BIO ``labels`` for one entity type.

    A ``B-<entity>`` tag opens a new span and an ``I-<entity>`` tag extends the
    open span. Any other tag closes it. An ``I-`` tag with no open span is
    ignored, not promoted to a new span.
    """
    begin, inside = f"B-{entity}", f"I-{entity}"
    spans, current = [], []
    for word, label in zip(words, labels):
        if label == begin:
            if current:
                spans.append(" ".join(current))
            current = [word]
        elif label == inside and current:
            current.append(word)
        else:
            if current:
                spans.append(" ".join(current))
            current = []
    if current:
        spans.append(" ".join(current))
    return spans


def extract_entities_pipeline(text: str, max_length: int = 32) -> dict:
    """Extract location and event mentions from ``text``.

    Combines the fine-tuned DistilBERT tagger with the generic Hugging Face
    CoNLL-03 pipeline. The DistilBERT tagger yields locations and events,
    decoded from BIO tags. The pipeline's ``LOC`` entities are merged into
    the location list.

    Args:
        text: Input text; it is split on whitespace into words.
        max_length: Maximum number of tokens (including special tokens) fed
            to the DistilBERT model. Words beyond this limit get no DistilBERT
            prediction. Defaults to 32.

    Returns:
        ``{"location": [...], "event": [...]}``, where each value is a list of
        entity strings. Locations are de-duplicated in first-seen order:
        DistilBERT spans first, then pipeline entities.
    """
    tokens = text.split()
    if not tokens:
        return {"location": [], "event": []}

    encoding = ner_tokenizer(
        [tokens],
        is_split_into_words=True,
        return_tensors="tf",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    outputs = ner_model({k: v for k, v in encoding.items() if k != "labels"})
    pred_ids = np.argmax(outputs.logits.numpy()[0], axis=-1)

    # pred_ids is indexed by token position, not by word: special tokens
    # (e.g. [CLS]) and multi-piece words shift positions. Use word_ids() to
    # give each word the prediction of its first token. Words with no
    # token (dropped by truncation) stay "O".
    word_labels = ["O"] * len(tokens)
    previous_word = None
    for position, word_index in enumerate(encoding.word_ids(batch_index=0)):
        if word_index is not None and word_index != previous_word:
            word_labels[word_index] = id2label[int(pred_ids[position])]
        previous_word = word_index

    entities = {
        "location": _bio_spans(tokens, word_labels, "LOC"),
        "event": _bio_spans(tokens, word_labels, "EVENT"),
    }

    hf_results = hf_ner(text)
    hf_locations = [ent["word"] for ent in hf_results if ent["entity_group"] == "LOC"]
    # Order-preserving de-duplication (a set union gave arbitrary order).
    entities["location"] = list(dict.fromkeys(entities["location"] + hf_locations))
    return entities