"""inference.py (MULTI-LABEL)

Inference for extracting MULTIPLE entity types (SKILL, EXPERIENCE_DURATION)
with a trained token-classification (NER) model.

Usage:
    python scripts/inference.py --model-path models/skill_ner_multi --text "Experienced Python developer with 5+ years of experience."
"""

import argparse
import logging
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Characters trimmed from both edges of every extracted entity span.
_STRIP_CHARS = " .,;:"


class EntityExtractor:
    """Entity extractor (SKILL, EXPERIENCE_DURATION) backed by a trained NER model."""

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Load tokenizer and model from ``model_path`` and build the NER pipeline.

        Args:
            model_path: Path (or identifier) accepted by ``from_pretrained``.
            device: ``None`` to auto-select (GPU 0 if CUDA is available,
                otherwise CPU); ``"cuda"`` or ``"cuda:N"`` for a GPU; any other
                value selects the CPU.

        Raises:
            ValueError: if ``device`` looks like ``"cuda:<index>"`` but the
                index is not a non-negative integer.
            Exception: whatever ``from_pretrained`` raises when loading fails
                (logged, then re-raised unchanged).
        """
        self.model_path = model_path
        self.device = self._resolve_device(device)

        logger.info("Carregando modelo de %s...", model_path)
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        except Exception:
            logger.error(
                "Erro ao carregar modelo de %s. Verifique o caminho.", model_path
            )
            # Bare raise keeps the original traceback intact for the caller.
            raise

        logger.info("Usando device: %s", "GPU" if self.device >= 0 else "CPU")

        # Build the NER pipeline. The 'simple' aggregation strategy merges
        # B-SKILL/I-SKILL tokens into SKILL and
        # B-EXPERIENCE_DURATION/I-EXPERIENCE_DURATION into EXPERIENCE_DURATION.
        self.ner_pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
            device=self.device,
        )
        logger.info("Modelo carregado com sucesso!")

    @staticmethod
    def _resolve_device(device: Optional[str]) -> int:
        """Translate a device spec into the integer index ``pipeline`` expects (-1 = CPU)."""
        if device is None:
            return 0 if torch.cuda.is_available() else -1
        spec = device.strip().lower()
        if spec == "cuda":
            return 0
        if spec.startswith("cuda:"):
            index = spec.split(":", 1)[1]
            if not index.isdigit():
                raise ValueError(f"Invalid CUDA device spec: {device!r}")
            return int(index)
        return -1

    def extract_skills(self, text: str, confidence_threshold: float = 0.5) -> List[str]:
        """Extract only SKILL entities (kept for backward compatibility).

        Args:
            text: Input text.
            confidence_threshold: Minimum pipeline score for an entity to be kept.

        Returns:
            Skill strings in order of first appearance, de-duplicated
            case-insensitively (the first occurrence's spelling is kept).
        """
        # dict preserves insertion order, so this keeps first-seen order.
        unique: Dict[str, str] = {}
        for entity in self.extract_entities_with_details(text, confidence_threshold):
            if entity["label"] != "SKILL":
                continue
            unique.setdefault(entity["entity"].casefold(), entity["entity"])
        return list(unique.values())

    def extract_entities_with_details(
        self, text: str, confidence_threshold: float = 0.5
    ) -> List[Dict]:
        """Extract ALL entities (SKILL, EXPERIENCE_DURATION, ...) with details.

        The input is whitespace-normalized (runs of whitespace collapsed to a
        single space, ends trimmed) before inference. ``start``/``end`` are
        character offsets into that normalized text, i.e.
        ``" ".join(text.split())``, not into the raw input.

        Args:
            text: Input text; non-string or empty input yields ``[]``.
            confidence_threshold: Minimum pipeline score for an entity to be kept.

        Returns:
            A list of dicts with keys ``entity`` (surface text with
            ``" .,;:"`` trimmed from its edges), ``label`` (aggregated entity
            group, e.g. ``'SKILL'`` or ``'EXPERIENCE_DURATION'``),
            ``start``/``end`` (offsets of ``entity`` in the normalized text)
            and ``confidence`` (score rounded to 3 decimals). Returns ``[]``
            if the pipeline raises (the error is logged with its traceback).
        """
        if not text or not isinstance(text, str):
            return []

        text = " ".join(text.split())
        if not text:
            return []

        try:
            # The pipeline returns all entities already grouped.
            entities = self.ner_pipeline(text)
        except Exception:
            logger.exception("Erro durante a inferência do pipeline")
            return []

        detailed_entities = []
        for entity in entities:
            score = float(entity["score"])
            if score < confidence_threshold:
                continue

            start, end = entity.get("start"), entity.get("end")
            if start is not None and end is not None:
                # Slice the normalized text so the result keeps its original
                # casing/spacing (decoded 'word' can be lower-cased or contain
                # tokenizer artifacts), and shift offsets to match the trim.
                surface = text[start:end]
                word = surface.strip(_STRIP_CHARS)
                start += len(surface) - len(surface.lstrip(_STRIP_CHARS))
                end = start + len(word)
            else:
                # Offsets are unavailable (e.g. slow tokenizer): fall back to
                # the decoded word and keep whatever offsets were reported.
                word = entity["word"].strip(_STRIP_CHARS)

            if not word:
                # Span consisted only of punctuation/whitespace.
                continue

            detailed_entities.append(
                {
                    "entity": word,
                    "label": entity["entity_group"],
                    "start": start,
                    "end": end,
                    "confidence": round(score, 3),
                }
            )

        return detailed_entities


def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` uses ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success, 1 on input or model-loading errors.
        (argparse itself exits with code 2 on invalid arguments.)
    """
    parser = argparse.ArgumentParser(description="Extrai entidades de textos")
    parser.add_argument(
        "--model-path", type=str, required=True, help="Caminho do modelo treinado"
    )
    # --text and --file are alternative input sources; reject ambiguous use of both.
    source = parser.add_mutually_exclusive_group()
    source.add_argument("--text", type=str, help="Texto para extrair entidades")
    source.add_argument("--file", type=str, help="Arquivo de texto para processar")
    parser.add_argument(
        "--confidence", type=float, default=0.5, help="Threshold de confidence (0-1)"
    )
    # Output is always detailed, so there is no --detailed flag.
    args = parser.parse_args(argv)

    # Also rejects NaN, since every comparison with NaN is False.
    if not 0.0 <= args.confidence <= 1.0:
        parser.error("--confidence deve estar entre 0 e 1")

    # Resolve the input before loading the (expensive) model so bad input fails fast.
    if args.file:
        logger.info("Lendo texto de %s...", args.file)
        try:
            with open(args.file, "r", encoding="utf-8") as f:
                text = f.read()
        except FileNotFoundError:
            logger.error("Arquivo não encontrado: %s", args.file)
            return 1
        except (OSError, UnicodeDecodeError) as e:
            logger.error("Erro ao ler arquivo %s: %s", args.file, e)
            return 1
    elif args.text:
        text = args.text
    else:
        logger.error("Error: Especifique --text ou --file")
        return 1

    try:
        extractor = EntityExtractor(args.model_path)
    except Exception as e:
        logger.error("Falha ao inicializar EntityExtractor: %s", e)
        return 1

    # Extract every entity type.
    results = extractor.extract_entities_with_details(text, args.confidence)

    print(f"\nExtracted {len(results)} entities:\n")
    for result in results:
        print(
            f" [{result['label']:<21}] {result['entity']:<30} confidence: {result['confidence']:.3f}"
        )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())