"""
inference.py (MULTI-LABEL)

Inference for extraction of MULTIPLE ENTITY types (SKILL, EXPERIENCE_DURATION).

Usage:

    python scripts/inference.py --model-path models/skill_ner_multi --text "Experienced Python developer with 5+ years of experience."
"""
|
|
|
|
|
|
import argparse
import logging
from typing import Dict, List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
|
|
|
|
|
|
# One-time root-logger configuration; every module logger inherits this format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class EntityExtractor:
    """Entity extractor (SKILL, EXPERIENCE_DURATION) backed by a trained NER model.

    Wraps a Hugging Face token-classification model in a ``pipeline`` with
    ``aggregation_strategy="simple"`` so sub-word tokens are merged back into
    whole entity spans.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Load tokenizer and model from *model_path* and build the NER pipeline.

        Args:
            model_path: Directory (or hub id) of the fine-tuned NER model.
            device: ``"cuda"`` forces GPU, any other string forces CPU;
                ``None`` auto-selects GPU when available.

        Raises:
            Exception: re-raised from ``from_pretrained`` when loading fails.
        """
        self.model_path = model_path

        # transformers pipelines use 0 for the first GPU and -1 for CPU.
        if device is None:
            self.device = 0 if torch.cuda.is_available() else -1
        else:
            self.device = 0 if device == "cuda" else -1

        logger.info(f"Carregando modelo de {model_path}...")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        except Exception:
            # logger.exception keeps the traceback in the log (error() dropped it).
            logger.exception(
                f"Erro ao carregar modelo de {model_path}. Verifique o caminho."
            )
            # Bare raise preserves the original exception and traceback intact.
            raise

        logger.info(f"Usando device: {'GPU' if self.device == 0 else 'CPU'}")

        # "simple" aggregation merges B-/I- sub-tokens into whole entities.
        self.ner_pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
            device=self.device,
        )

        logger.info("Modelo carregado com sucesso!")

    def extract_skills(self, text: str, confidence_threshold: float = 0.5) -> List[str]:
        """Extract ONLY skills (kept for backward compatibility).

        Args:
            text: Free-form text to analyse.
            confidence_threshold: Minimum pipeline score to keep an entity.

        Returns:
            Skill surface forms, case-insensitively de-duplicated, in order
            of first appearance (the first casing seen wins).
        """
        entities = self.extract_entities_with_details(text, confidence_threshold)
        skills = [e["entity"] for e in entities if e["label"] == "SKILL"]

        # De-duplicate case-insensitively while preserving first-seen order.
        seen = set()
        unique_skills = []
        for skill in skills:
            key = skill.lower()
            if key not in seen:
                seen.add(key)
                unique_skills.append(skill)
        return unique_skills

    def extract_entities_with_details(
        self, text: str, confidence_threshold: float = 0.5
    ) -> List[Dict]:
        """Extract ALL entities (SKILL, EXPERIENCE_DURATION, ...) with details.

        Args:
            text: Text to analyse; ``None``, non-strings and blank text yield ``[]``.
            confidence_threshold: Minimum score for an entity to be kept.

        Returns:
            One dict per entity: ``entity`` (punctuation-trimmed surface form),
            ``label``, ``start``/``end`` character offsets, and ``confidence``
            rounded to 3 decimal places.
        """
        if not text or not isinstance(text, str):
            return []

        # Collapse all whitespace runs into single spaces.
        text = " ".join(text.split())
        if not text:
            return []

        try:
            entities = self.ner_pipeline(text)
        except Exception as e:
            # Best-effort: log and return nothing rather than crash callers.
            logger.error(f"Erro durante a inferência do pipeline: {e}")
            return []

        detailed_entities = []
        for entity in entities:
            if entity["score"] >= confidence_threshold:
                detailed_entities.append(
                    {
                        # Trim stray punctuation the aggregator may attach.
                        "entity": entity["word"].strip(" .,;:"),
                        "label": entity["entity_group"],
                        "start": entity["start"],
                        "end": entity["end"],
                        "confidence": round(float(entity["score"]), 3),
                    }
                )

        return detailed_entities
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: extract entities from --text or --file and print them."""
    parser = argparse.ArgumentParser(description="Extrai entidades de textos")
    parser.add_argument(
        "--model-path", type=str, required=True, help="Caminho do modelo treinado"
    )
    parser.add_argument("--text", type=str, help="Texto para extrair entidades")
    parser.add_argument("--file", type=str, help="Arquivo de texto para processar")
    parser.add_argument(
        "--confidence", type=float, default=0.5, help="Threshold de confidence (0-1)"
    )

    args = parser.parse_args()

    try:
        extractor = EntityExtractor(args.model_path)
    except Exception as e:
        logger.error(f"Falha ao inicializar EntityExtractor: {e}")
        return

    if args.file:
        logger.info(f"Lendo texto de {args.file}...")
        try:
            with open(args.file, "r", encoding="utf-8") as f:
                text = f.read()
        except FileNotFoundError:
            logger.error(f"Arquivo não encontrado: {args.file}")
            return
        except OSError as e:
            # Robustness: permission/IO errors previously crashed the CLI
            # with a raw traceback instead of a clean error message.
            logger.error(f"Erro ao ler arquivo {args.file}: {e}")
            return
    elif args.text:
        text = args.text
    else:
        logger.error("Error: Especifique --text ou --file")
        return

    results = extractor.extract_entities_with_details(text, args.confidence)

    print(f"\nExtracted {len(results)} entities:\n")
    for result in results:
        print(
            f" [{result['label']:<21}] {result['entity']:<30} confidence: {result['confidence']:.3f}"
        )


if __name__ == "__main__":
    main()