File size: 2,638 Bytes
f9ac587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import json
import sys
import unicodedata
from pathlib import Path

# Put this script's directory first on the module search path; presumably so
# the local `config` module imported below resolves regardless of the current
# working directory (verify against how the script is launched).
sys.path.insert(0, str(Path(__file__).parent))

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from transformers import pipeline

from config import DATASET_CSV, MODEL_OUTPUT_DIR, MAX_LENGTH

# Location passed to the transformers pipeline in main() as both model and
# tokenizer; MODEL_OUTPUT_DIR is joined onto this script's directory.
MODEL_DIR = Path(__file__).parent / MODEL_OUTPUT_DIR


def normalize(text: str) -> str:
    """Return *text* lowercased with combining accent marks removed.

    The string is lowercased, decomposed with Unicode NFD, and every code
    point in the "Mn" (nonspacing mark) category is dropped, so for example
    an accented vowel collapses to its base letter.
    """
    decomposed = unicodedata.normalize("NFD", text.lower())
    base_chars = [
        ch for ch in decomposed if unicodedata.category(ch) != "Mn"
    ]
    return "".join(base_chars)


def main():
    """Evaluate the saved text classifier on the held-out validation split.

    Reads the dataset CSV, applies ``normalize`` to the ``text`` column,
    re-creates the stratified 85/15 split (``random_state=42``), runs the
    model stored in ``MODEL_DIR`` on the validation texts and prints:

    * a per-category classification report,
    * a confusion matrix whose rows/columns follow the sorted category names,
    * a warning listing every category whose F1 score is below 0.80.

    Takes no arguments and returns ``None``; all output goes to stdout.
    """
    csv_path = Path(__file__).parent / DATASET_CSV
    df = pd.read_csv(csv_path)

    # Sorted category names give a deterministic name -> integer id mapping.
    categories = sorted(df["category"].unique().tolist())
    label2id = {name: i for i, name in enumerate(categories)}

    df["text"] = df["text"].apply(normalize)
    df["label_id"] = df["category"].map(label2id)

    # Same split as in train.py: stratified 85/15 with random_state=42
    _, val_df = train_test_split(
        df, test_size=0.15, stratify=df["label_id"], random_state=42
    )
    print(f"Evaluando sobre {len(val_df)} ejemplos de validacion...")

    classifier = pipeline(
        "text-classification",
        model=str(MODEL_DIR),
        tokenizer=str(MODEL_DIR),
        top_k=1,
        device=-1,
        truncation=True,
        max_length=MAX_LENGTH,
    )

    texts = val_df["text"].tolist()
    true_labels = val_df["category"].tolist()

    print("Corriendo inferencia...")
    results = classifier(texts, batch_size=32)
    # With top_k=1 each result is indexed as a one-element list of
    # {"label", "score"} dicts; keep only the predicted label name.
    pred_labels = [r[0]["label"] for r in results]

    print("\n" + "=" * 60)
    print("CLASSIFICATION REPORT")
    print("=" * 60)
    # zero_division=0 matches the dict report below and avoids
    # UndefinedMetricWarning noise when a category is never predicted.
    print(
        classification_report(
            true_labels, pred_labels, digits=3, zero_division=0
        )
    )

    print("CONFUSION MATRIX")
    print("=" * 60)
    cm = confusion_matrix(true_labels, pred_labels, labels=categories)
    # Column headers are truncated to 6 characters to keep the table narrow.
    header = f"{'':25s}" + "".join(f"{c[:6]:>8s}" for c in categories)
    print(header)
    for i, row_label in enumerate(categories):
        row = f"{row_label:25s}" + "".join(f"{v:>8d}" for v in cm[i])
        print(row)

    # Detect categories with F1 < 0.80 so they can be flagged
    report = classification_report(
        true_labels, pred_labels, output_dict=True, zero_division=0
    )
    # A category missing from the report falls back to 1.0 and is therefore
    # not flagged.
    weak = [
        cat for cat in categories
        if report.get(cat, {}).get("f1-score", 1.0) < 0.80
    ]
    if weak:
        print(f"\nADVERTENCIA: F1 < 0.80 en: {weak}")
        print("Considera agregar mas ejemplos o diversificar templates para estas categorias.")
    else:
        print("\nTodas las categorias tienen F1 >= 0.80")


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()