File size: 5,504 Bytes
de99208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""

inference.py (MULTI-LABEL)



Inferência de extração de MÚLTIPLAS ENTIDADES (SKILL, EXPERIENCE_DURATION).



Execução:

    python scripts/inference.py --model-path models/skill_ner_multi --text "Experienced Python developer with 5+ years of experience."

"""

import argparse
import logging
from typing import List, Dict
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Module-wide logging: timestamped, INFO-level messages.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class EntityExtractor:
    """Entity extractor (SKILL, EXPERIENCE_DURATION) backed by a trained NER model.

    Wraps a Hugging Face token-classification checkpoint in a ``pipeline``
    with ``aggregation_strategy="simple"`` so B-/I- sub-tokens are merged
    into whole entities.
    """

    def __init__(self, model_path: str, device: str = None):
        """Load tokenizer and model from *model_path* and build the NER pipeline.

        Args:
            model_path: Directory (or hub id) of the trained model.
            device: ``"cuda"`` forces GPU; any other non-None value forces CPU.
                When ``None``, the GPU is used if available.

        Raises:
            Exception: re-raised from ``from_pretrained`` when loading fails.
        """
        self.model_path = model_path

        # transformers pipelines expect 0 for the first GPU and -1 for CPU.
        if device is None:
            self.device = 0 if torch.cuda.is_available() else -1
        else:
            self.device = 0 if device == "cuda" else -1

        logger.info(f"Carregando modelo de {model_path}...")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForTokenClassification.from_pretrained(model_path)
        except Exception:
            # Bare `raise` re-raises the original exception with its full
            # traceback; the log line adds the offending path for the user.
            logger.error(
                f"Erro ao carregar modelo de {model_path}. Verifique o caminho."
            )
            raise

        logger.info(f"Usando device: {'GPU' if self.device == 0 else 'CPU'}")

        # 'simple' aggregation groups B-SKILL/I-SKILL into SKILL and
        # B-EXPERIENCE_DURATION/I-EXPERIENCE_DURATION into EXPERIENCE_DURATION.
        self.ner_pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
            device=self.device,
        )

        logger.info("Modelo carregado com sucesso!")

    def extract_skills(self, text: str, confidence_threshold: float = 0.5) -> List[str]:
        """Extract ONLY skills (kept for backward compatibility).

        Returns the surface forms of SKILL entities, de-duplicated
        case-insensitively while preserving first-seen order and casing.
        """
        entities = self.extract_entities_with_details(text, confidence_threshold)

        # Keep only SKILL entities.
        skills = [e["entity"] for e in entities if e["label"] == "SKILL"]

        # Case-insensitive de-duplication that keeps the first spelling seen.
        seen = set()
        unique_skills = []
        for skill in skills:
            key = skill.lower()
            if key not in seen:
                seen.add(key)
                unique_skills.append(skill)
        return unique_skills

    def extract_entities_with_details(
        self, text: str, confidence_threshold: float = 0.5
    ) -> List[Dict]:
        """Extract ALL entities (SKILL, EXPERIENCE_DURATION, ...) with details.

        Args:
            text: Input text; runs of whitespace are collapsed before inference.
            confidence_threshold: Minimum pipeline score for an entity to be kept.

        Returns:
            List of dicts with keys ``entity`` (surface form trimmed of
            trailing punctuation), ``label``, ``start``, ``end`` (offsets into
            the whitespace-normalized text, as reported by the pipeline) and
            ``confidence`` (score rounded to 3 decimals). Empty list for
            invalid input or pipeline errors.
        """
        if not text or not isinstance(text, str):
            return []

        # Normalize whitespace; bail out if nothing remains.
        text = " ".join(text.split())
        if not text:
            return []

        try:
            # The pipeline returns all entities, already aggregated.
            entities = self.ner_pipeline(text)
        except Exception as e:
            # Best-effort: log and return no entities rather than crash the caller.
            logger.error(f"Erro durante a inferência do pipeline: {e}")
            return []

        detailed_entities = []
        for entity in entities:
            if entity["score"] >= confidence_threshold:
                detailed_entities.append(
                    {
                        "entity": entity["word"].strip(" .,;:"),
                        "label": entity["entity_group"],  # e.g. 'SKILL' or 'EXPERIENCE_DURATION'
                        "start": entity["start"],
                        "end": entity["end"],
                        "confidence": round(float(entity["score"]), 3),
                    }
                )

        return detailed_entities


def main():
    """Command-line entry point: extract entities from --text or --file and print them."""
    arg_parser = argparse.ArgumentParser(description="Extrai entidades de textos")
    arg_parser.add_argument(
        "--model-path", type=str, required=True, help="Caminho do modelo treinado"
    )
    arg_parser.add_argument("--text", type=str, help="Texto para extrair entidades")
    arg_parser.add_argument("--file", type=str, help="Arquivo de texto para processar")
    arg_parser.add_argument(
        "--confidence", type=float, default=0.5, help="Threshold de confidence (0-1)"
    )
    opts = arg_parser.parse_args()

    # Model loading can fail (bad path, corrupt checkpoint) — report and stop.
    try:
        entity_extractor = EntityExtractor(opts.model_path)
    except Exception as e:
        logger.error(f"Falha ao inicializar EntityExtractor: {e}")
        return

    # Input text comes either from a file or directly from the command line.
    if opts.file:
        logger.info(f"Lendo texto de {opts.file}...")
        try:
            with open(opts.file, "r", encoding="utf-8") as fh:
                source_text = fh.read()
        except FileNotFoundError:
            logger.error(f"Arquivo não encontrado: {opts.file}")
            return
    elif opts.text:
        source_text = opts.text
    else:
        logger.error("Error: Especifique --text ou --file")
        return

    # Default output is the full detailed entity listing.
    found = entity_extractor.extract_entities_with_details(source_text, opts.confidence)

    print(f"\nExtracted {len(found)} entities:\n")
    for item in found:
        print(
            f"  [{item['label']:<21}] {item['entity']:<30} confidence: {item['confidence']:.3f}"
        )


if __name__ == "__main__":
    main()