"""
inference.py (MULTI-LABEL)
Inferência de extração de MÚLTIPLAS ENTIDADES (SKILL, EXPERIENCE_DURATION).
Execução:
python scripts/inference.py --model-path models/skill_ner_multi --text "Experienced Python developer with 5+ years of experience."
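python scripts/inference.py --model-path models/skill_ner_multi --file resume.txt --confidence 0.7
(resume.txt is a placeholder for any UTF-8 text file.)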
"""
import argparse
import logging
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class EntityExtractor:
    """Extracts entities (SKILL, EXPERIENCE_DURATION) with a trained NER model."""

    def __init__(self, model_path: str, device: Optional[str] = None):
        self.model_path = model_path
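        # transformers pipelines take an integer device: 0 = first CUDA GPU, -1 = CPU.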
        if device is None:
            self.device = 0 if torch.cuda.is_available() else -1
        else:
            self.device = 0 if device == "cuda" else -1
logger.info(f"Carregando modelo de {model_path}...")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForTokenClassification.from_pretrained(model_path)
except Exception as e:
logger.error(
f"Erro ao carregar modelo de {model_path}. Verifique o caminho."
)
raise e
logger.info(f"Usando device: {'GPU' if self.device == 0 else 'CPU'}")
        # Create the NER pipeline.
        # The 'simple' aggregation strategy groups B-SKILL/I-SKILL into SKILL
        # and B-EXPERIENCE_DURATION/I-EXPERIENCE_DURATION into EXPERIENCE_DURATION.
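        # Other strategies ('first', 'average', 'max') resolve subword-level label
        # disagreements differently; see the transformers TokenClassificationPipeline docs.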
        self.ner_pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
            device=self.device,
        )
        logger.info("Model loaded successfully!")

    def extract_skills(
        self, text: str, confidence_threshold: float = 0.5
    ) -> List[str]:
        """
        Extract ONLY skills (kept for backward compatibility).
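
        Example (illustrative; actual output depends on the trained model):
            >>> extractor.extract_skills("Expert in Python and SQL")  # doctest: +SKIP
            ['Python', 'SQL']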
"""
        entities = self.extract_entities_with_details(text, confidence_threshold)
        # Keep only SKILL entities
        skills = [e["entity"] for e in entities if e["label"] == "SKILL"]
        # Deduplicate case-insensitively, preserving first-seen order
        seen = set()
        unique_skills = []
        for skill in skills:
            skill_lower = skill.lower()
            if skill_lower not in seen:
                seen.add(skill_lower)
                unique_skills.append(skill)
        return unique_skills

    def extract_entities_with_details(
        self, text: str, confidence_threshold: float = 0.5
    ) -> List[Dict]:
        """
        Extract ALL entities with details (SKILL, EXPERIENCE_DURATION, etc.).
        """
        if not text or not isinstance(text, str):
            return []
        # Collapse runs of whitespace into single spaces
        text = " ".join(text.split())
        if not text:
            return []
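        # Note: the start/end offsets returned below refer to this normalized
        # text, not to the caller's original string.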
        try:
            # The pipeline returns all entities, already aggregated
            entities = self.ner_pipeline(text)
        except Exception as e:
            logger.error(f"Pipeline inference failed: {e}")
            return []
        detailed_entities = []
        for entity in entities:
            if entity["score"] >= confidence_threshold:
                detailed_entities.append(
                    {
                        "entity": entity["word"].strip(" .,;:"),
                        "label": entity["entity_group"],  # e.g. 'SKILL' or 'EXPERIENCE_DURATION'
                        "start": entity["start"],
                        "end": entity["end"],
                        "confidence": round(float(entity["score"]), 3),
                    }
                )
        return detailed_entities


def main():
    parser = argparse.ArgumentParser(description="Extract entities from text")
    parser.add_argument(
        "--model-path", type=str, required=True, help="Path to the trained model"
    )
    parser.add_argument("--text", type=str, help="Text to extract entities from")
    parser.add_argument("--file", type=str, help="Text file to process")
    parser.add_argument(
        "--confidence", type=float, default=0.5, help="Confidence threshold (0-1)"
    )
    # --detailed flag removed; detailed output is now the default
    args = parser.parse_args()
    try:
        extractor = EntityExtractor(args.model_path)
    except Exception as e:
        logger.error(f"Failed to initialize EntityExtractor: {e}")
        return
    if args.file:
        logger.info(f"Reading text from {args.file}...")
        try:
            with open(args.file, "r", encoding="utf-8") as f:
                text = f.read()
        except FileNotFoundError:
            logger.error(f"File not found: {args.file}")
            return
    elif args.text:
        text = args.text
    else:
        logger.error("Error: specify --text or --file")
        return
    # Extract all entities
    results = extractor.extract_entities_with_details(text, args.confidence)
    print(f"\nExtracted {len(results)} entities:\n")
    for result in results:
        print(
            f"  [{result['label']:<21}] {result['entity']:<30} confidence: {result['confidence']:.3f}"
        )
if __name__ == "__main__":
main()
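
# Example run (illustrative output; entities and scores depend on the trained model):
#   $ python scripts/inference.py --model-path models/skill_ner_multi \
#         --text "Experienced Python developer with 5+ years of experience."
#
#   Extracted 2 entities:
#
#     [SKILL                ] Python                         confidence: 0.998
#     [EXPERIENCE_DURATION  ] 5+ years                       confidence: 0.985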