Spaces:
Sleeping
Sleeping
| from transformers import AutoModelForTokenClassification, AutoTokenizer | |
| import torch | |
| from typing import List, Tuple | |
| import logging | |
| from .base_analyzer import BaseAnalyzer | |
# Per-module logger: its name is this module's dotted path, so its level and
# handlers can be configured independently of other modules.
logger: logging.Logger = logging.getLogger(__name__)
class NERAnalyzer(BaseAnalyzer):
    """Find the people named in a company contract using a NER model.

    Pipeline: :meth:`extract_entities` tags every sub-word token of the text
    with a Hugging Face token-classification model, then
    :meth:`extract_representatives` stitches consecutive ``PESSOA`` tokens
    back into person names. Other entity types are not collected; they only
    end the name currently being built.
    """

    #: Hub id (or local path) of the model loaded by default. The id suggests
    #: a BERT model fine-tuned for Brazilian-Portuguese legal NER.
    DEFAULT_MODEL_NAME = "dominguesm/ner-legal-bert-base-cased-ptbr"

    #: Tokens per forward pass, special tokens included. Texts longer than
    #: this are processed as consecutive windows instead of being truncated.
    MAX_TOKENS = 512

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the token-classification model and its tokenizer.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
                The model's labels are expected to include ``B-PESSOA`` /
                ``I-PESSOA``. Defaults to :attr:`DEFAULT_MODEL_NAME`, the
                model this analyzer has always used.
        """
        self.model_name = model_name
        logger.info("Carregando o modelo NER: %s", self.model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
        # extract_entities needs a *fast* tokenizer (per-window tokens() and
        # word_ids()); AutoTokenizer picks the fast variant by default.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logger.info("Modelo NER e tokenizador carregados com sucesso")

    def extract_entities(self, text: str) -> List[Tuple[str, str]]:
        """Tag *text* with the NER model and return its non-``"O"`` tokens.

        The text is split into consecutive windows of at most
        :attr:`MAX_TOKENS` tokens, and each window gets its own forward pass.
        Long documents are therefore analysed completely rather than cut off
        after the first window.

        Args:
            text: Raw document text.

        Returns:
            ``(token, label)`` pairs in document order for every sub-word
            token whose predicted label is not ``"O"``. Tokens are returned
            as the tokenizer produced them (continuation pieces keep their
            ``##`` prefix). Special tokens such as ``[CLS]``/``[SEP]`` are
            never returned.
        """
        logger.debug("Iniciando extração de entidades com NER")
        # With a fast tokenizer, return_overflowing_tokens keeps every
        # MAX_TOKENS-sized window as its own row instead of discarding the rest.
        encoding = self.tokenizer(
            text,
            max_length=self.MAX_TOKENS,
            truncation=True,
            return_overflowing_tokens=True,
        )
        # Bookkeeping key added by return_overflowing_tokens; not a model input.
        encoding.pop("overflow_to_sample_mapping", None)
        id2label = self.model.config.id2label

        entities: List[Tuple[str, str]] = []
        for window in range(len(encoding["input_ids"])):
            # Same inputs the model received before (input_ids, attention_mask,
            # ...), restricted to one window and given a batch dimension of 1.
            model_inputs = {
                name: torch.tensor([values[window]]) for name, values in encoding.items()
            }
            with torch.no_grad():
                logits = self.model(**model_inputs).logits
            predictions = torch.argmax(logits, dim=2)[0].tolist()

            for token, word_id, prediction in zip(
                encoding.tokens(window), encoding.word_ids(window), predictions
            ):
                # word_id is None for special tokens; any label predicted on
                # them is meaningless and must not leak into person names.
                if word_id is None:
                    continue
                label = id2label[prediction]
                if label != "O":
                    entities.append((token, label))

        # Token-level dump (contains personal names): debug level, lazy args.
        logger.debug("tokens: %s", entities)
        return entities

    def extract_representatives(self, entities: List[Tuple[str, str]]) -> List[str]:
        """Group ``PESSOA`` tokens into person names.

        :meth:`extract_entities` drops ``"O"`` tokens, so two people separated
        only by ordinary words arrive back to back. A ``B-PESSOA`` label is
        therefore the signal that a new name starts.

        Args:
            entities: ``(token, label)`` pairs in document order. Continuation
                pieces carry a ``##`` prefix.

        Returns:
            Person names in order of appearance (duplicates kept). A name
            ends at a ``B-PESSOA`` token that is not a ``##`` continuation,
            at any non-``PESSOA`` label, or at the end of the input. ``##``
            pieces are glued to the preceding word, and words are joined with
            single spaces (no leading or trailing space).
        """
        representatives: List[str] = []
        words: List[str] = []  # words of the name currently being built

        def flush() -> None:
            # Emit the pending name, if any, and start a fresh one.
            if words:
                representatives.append(" ".join(words))
                words.clear()

        for token, label in entities:
            if label not in ("B-PESSOA", "I-PESSOA"):
                # Any other entity type terminates the current name.
                flush()
                continue
            piece = token.replace("##", "")
            if token.startswith("##") and words:
                # Sub-word continuation: part of the same word, even if the
                # model tagged it B- rather than I-.
                words[-1] += piece
            else:
                if label == "B-PESSOA":
                    flush()
                words.append(piece)

        flush()
        return representatives

    def analyze(self, text: str) -> List[str]:
        """Return the person names found in *text*.

        Runs :meth:`extract_entities` and then :meth:`extract_representatives`.
        """
        entities = self.extract_entities(text)
        return self.extract_representatives(entities)

    def format_output(self, representatives: List[str]) -> str:
        """Render *representatives* as a plain-text report, one ``- name`` line each."""
        header = "ANÁLISE DO CONTRATO SOCIAL (NER)\n\nREPRESENTANTES IDENTIFICADOS:\n"
        return header + "".join(f"- {rep}\n" for rep in representatives)