# hirly-ner / scripts / inference.py
# (Hugging Face upload header: uploaded by feliponi, "Upload 3 files", commit de99208)
"""
inference.py (MULTI-LABEL)
Inferência de extração de MÚLTIPLAS ENTIDADES (SKILL, EXPERIENCE_DURATION).
Execução:
python scripts/inference.py --model-path models/skill_ner_multi --text "Experienced Python developer with 5+ years of experience."
"""
import argparse
import logging
from typing import Dict, List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
# Timestamped INFO-level logging for the whole script.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class EntityExtractor:
    """Entity extractor (SKILL, EXPERIENCE_DURATION) backed by a trained NER model."""

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Load the tokenizer/model from *model_path* and build the NER pipeline.

        Args:
            model_path: Path or hub id of a trained token-classification model.
            device: "cuda" to force GPU; any other string forces CPU; None
                auto-selects GPU when ``torch.cuda.is_available()``.

        Raises:
            Exception: re-raised from transformers when the model cannot be loaded.
        """
        self.model_path = model_path
        # transformers pipelines encode devices as 0 (first GPU) / -1 (CPU).
        if device is None:
            self.device = 0 if torch.cuda.is_available() else -1
        else:
            self.device = 0 if device == "cuda" else -1
        logger.info(f"Carregando modelo de {model_path}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        except Exception:
            logger.error(
                f"Erro ao carregar modelo de {model_path}. Verifique o caminho."
            )
            raise  # bare raise preserves the original traceback
        logger.info(f"Usando device: {'GPU' if self.device == 0 else 'CPU'}")
        # The "simple" aggregation strategy merges B-SKILL/I-SKILL into SKILL
        # and B-EXPERIENCE_DURATION/I-EXPERIENCE_DURATION into EXPERIENCE_DURATION.
        self.ner_pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
            device=self.device,
        )
        logger.info("Modelo carregado com sucesso!")

    def extract_skills(self, text: str, confidence_threshold: float = 0.5) -> List[str]:
        """Extract ONLY skill entities (kept for backward compatibility).

        Returns:
            Skill surface forms, case-insensitively de-duplicated while
            preserving first-seen casing and order of appearance.
        """
        entities = self.extract_entities_with_details(text, confidence_threshold)
        skills = [e["entity"] for e in entities if e["label"] == "SKILL"]
        # Case-insensitive de-duplication that keeps the first spelling seen.
        seen = set()
        unique_skills = []
        for skill in skills:
            key = skill.lower()
            if key not in seen:
                seen.add(key)
                unique_skills.append(skill)
        return unique_skills

    def extract_entities_with_details(
        self, text: str, confidence_threshold: float = 0.5
    ) -> List[Dict]:
        """Extract ALL entities with details (SKILL, EXPERIENCE_DURATION, ...).

        Args:
            text: Input text; non-string or empty input yields an empty list.
            confidence_threshold: Minimum pipeline score for an entity to be kept.

        Returns:
            One dict per entity with keys ``entity``, ``label``, ``start``,
            ``end``, ``confidence``. NOTE: start/end are offsets into the
            whitespace-normalized text, not the caller's original string.
        """
        if not text or not isinstance(text, str):
            return []
        # Collapse all whitespace runs into single spaces before inference.
        text = " ".join(text.split())
        if not text:
            return []
        try:
            # The pipeline returns all entities, already aggregated by label.
            entities = self.ner_pipeline(text)
        except Exception as e:
            # Best-effort contract: inference failures log and yield no entities.
            logger.error(f"Erro durante a inferência do pipeline: {e}")
            return []
        detailed_entities = []
        for entity in entities:
            if entity["score"] >= confidence_threshold:
                detailed_entities.append(
                    {
                        # Trim stray punctuation the tokenizer may attach.
                        "entity": entity["word"].strip(" .,;:"),
                        "label": entity["entity_group"],  # e.g. 'SKILL'
                        "start": entity["start"],
                        "end": entity["end"],
                        "confidence": round(float(entity["score"]), 3),
                    }
                )
        return detailed_entities
def main():
    """CLI entry point: parse args, load the model, print extracted entities.

    Reads input from --text or --file, extracts entities above --confidence,
    and prints one line per entity. Exits with status 1 on any failure
    (previously the function returned silently, so the process exited 0
    even when initialization or input handling failed).
    """
    parser = argparse.ArgumentParser(description="Extrai entidades de textos")
    parser.add_argument(
        "--model-path", type=str, required=True, help="Caminho do modelo treinado"
    )
    parser.add_argument("--text", type=str, help="Texto para extrair entidades")
    parser.add_argument("--file", type=str, help="Arquivo de texto para processar")
    parser.add_argument(
        "--confidence", type=float, default=0.5, help="Threshold de confidence (0-1)"
    )
    args = parser.parse_args()

    try:
        extractor = EntityExtractor(args.model_path)
    except Exception as e:
        logger.error(f"Falha ao inicializar EntityExtractor: {e}")
        raise SystemExit(1)

    # Input priority: --file wins over --text; at least one is required.
    if args.file:
        logger.info(f"Lendo texto de {args.file}...")
        try:
            with open(args.file, "r", encoding="utf-8") as f:
                text = f.read()
        except FileNotFoundError:
            logger.error(f"Arquivo não encontrado: {args.file}")
            raise SystemExit(1)
    elif args.text:
        text = args.text
    else:
        logger.error("Error: Especifique --text ou --file")
        raise SystemExit(1)

    # Extract every entity above the threshold and pretty-print the results.
    results = extractor.extract_entities_with_details(text, args.confidence)
    print(f"\nExtracted {len(results)} entities:\n")
    for result in results:
        print(
            f" [{result['label']:<21}] {result['entity']:<30} confidence: {result['confidence']:.3f}"
        )


if __name__ == "__main__":
    main()