# docling / analyzers / ner_analyzer.py
# thlinhares's picture
# Rename analyzers/ner_analyzer_bkp.py to analyzers/ner_analyzer.py
# d82b8cb verified
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
from typing import List, Tuple
import logging
from .base_analyzer import BaseAnalyzer
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class NERAnalyzer(BaseAnalyzer):
    """Find person names in a text with a token-classification (NER) model.

    The analyzer runs the pretrained model named in ``self.model_name``
    (a Portuguese legal-domain model, judging by its name) over the text,
    keeps the tokens tagged as entities and stitches the ``PESSOA`` tokens
    back into full names.
    """

    def __init__(self):
        """Load the pretrained NER model and its tokenizer via ``from_pretrained``."""
        self.model_name = "dominguesm/ner-legal-bert-base-cased-ptbr"
        logger.info("Carregando o modelo NER: %s", self.model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logger.info("Modelo NER e tokenizador carregados com sucesso")

    def extract_entities(self, text: str) -> List[Tuple[str, str]]:
        """Tag ``text`` and return the tokens that were labelled as entities.

        Args:
            text: Raw text to analyze. The input is truncated to 512 tokens,
                so anything beyond that limit is not analyzed.

        Returns:
            ``(token, label)`` pairs, in input order, for every token whose
            predicted label is not ``"O"``. Tokens are the tokenizer's own
            pieces, so word continuations still carry their ``##`` prefix.
        """
        logger.debug("Iniciando extração de entidades com NER")
        inputs = self.tokenizer(text, max_length=512, truncation=True, return_tensors="pt")
        tokens = inputs.tokens()

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Single sequence in the batch: take row 0 and convert to plain Python
        # ints so the id2label lookup does not depend on NumPy scalar hashing.
        label_ids = torch.argmax(logits, dim=2)[0].tolist()
        id2label = self.model.config.id2label

        entities = []
        for token, label_id in zip(tokens, label_ids):
            entity_label = id2label[label_id]
            if entity_label != "O":
                entities.append((token, entity_label))

        logger.info("tokens: %s", entities)
        return entities

    def extract_representatives(self, entities: List[Tuple[str, str]]) -> List[str]:
        """Rebuild person names from tagged tokens.

        Only ``B-PESSOA`` / ``I-PESSOA`` tokens contribute to the result; any
        other label merely closes the name being built.

        A name is closed when either a non-person label is seen or a new
        ``B-PESSOA`` word starts. The second rule matters because
        ``extract_entities`` drops ``"O"`` tokens: two people separated only
        by untagged words arrive here back to back and would otherwise be
        fused into a single name.

        Args:
            entities: ``(token, label)`` pairs as produced by
                ``extract_entities``.

        Returns:
            The person names in order of appearance, words separated by a
            single space and without leading/trailing whitespace.
        """
        person_labels = ("B-PESSOA", "I-PESSOA")
        representatives: List[str] = []
        current_person = ""

        for token, label in entities:
            if label not in person_labels:
                if current_person:
                    representatives.append(current_person)
                    current_person = ""
                continue

            piece = token.replace("##", "")

            if token.startswith("##"):
                # Word continuation: glue to the previous piece, whatever its
                # B-/I- prefix, since it belongs to the same word.
                current_person += piece
                continue

            if label == "B-PESSOA" and current_person:
                # A "begin" tag on a fresh word starts another person.
                representatives.append(current_person)
                current_person = ""

            current_person = f"{current_person} {piece}" if current_person else piece

        if current_person:
            representatives.append(current_person)

        return representatives

    def analyze(self, text: str) -> List[str]:
        """Run entity extraction on ``text`` and return the person names found."""
        entities = self.extract_entities(text)
        return self.extract_representatives(entities)

    def format_output(self, representatives: List[str]) -> str:
        """Render the names as a titled, dash-bulleted text report.

        Args:
            representatives: Names to list, one per line.

        Returns:
            The report header followed by one ``"- <name>"`` line per entry;
            only the header when the list is empty.
        """
        header = (
            "ANÁLISE DO CONTRATO SOCIAL (NER)\n\n"
            "REPRESENTANTES IDENTIFICADOS:\n"
        )
        return header + "".join(f"- {rep}\n" for rep in representatives)