# sms-classifier-api / training / eval_report.py
# cmeneses99's picture
# Initial deploy: SMS Classifier API
# f9ac587
import json
import sys
import unicodedata
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from transformers import pipeline
from config import DATASET_CSV, MODEL_OUTPUT_DIR, MAX_LENGTH
# Location of the model and tokenizer loaded by the evaluation pipeline,
# resolved relative to this file's directory.
MODEL_DIR = Path(__file__).parent / MODEL_OUTPUT_DIR
def normalize(text: str) -> str:
    """Return *text* lowercased with its combining marks removed.

    The string is lowercased, decomposed with Unicode NFD, and every
    character whose Unicode category is ``Mn`` (nonspacing mark) is dropped,
    so accented letters collapse to their base letter.
    """
    decomposed = unicodedata.normalize("NFD", text.lower())
    kept = [ch for ch in decomposed if unicodedata.category(ch) != "Mn"]
    return "".join(kept)
def main():
    """Evaluate the saved classifier on the validation split and print metrics.

    Steps:
      1. Load the dataset CSV and normalize its ``text`` column.
      2. Rebuild the 15% stratified validation split (``random_state=42``).
      3. Run the model stored in ``MODEL_DIR`` over the validation texts.
      4. Print a classification report, a confusion matrix, and a warning
         listing every category whose F1 score is below 0.80.

    Both reports and the confusion matrix are computed over the same
    explicit label list (the sorted categories found in the CSV), so every
    category always has a row and a report entry.
    """
    f1_threshold = 0.80

    csv_path = Path(__file__).parent / DATASET_CSV
    df = pd.read_csv(csv_path)

    categories = sorted(df["category"].unique().tolist())
    label2id = {name: i for i, name in enumerate(categories)}

    df["text"] = df["text"].apply(normalize)
    df["label_id"] = df["category"].map(label2id)

    # Per the original note, this mirrors the split done in train.py:
    # 85/15, stratified by label, random_state=42.
    _, val_df = train_test_split(
        df, test_size=0.15, stratify=df["label_id"], random_state=42
    )
    print(f"Evaluando sobre {len(val_df)} ejemplos de validacion...")

    classifier = pipeline(
        "text-classification",
        model=str(MODEL_DIR),
        tokenizer=str(MODEL_DIR),
        top_k=1,
        device=-1,
        truncation=True,
        max_length=MAX_LENGTH,
    )

    texts = val_df["text"].tolist()
    true_labels = val_df["category"].tolist()

    print("Corriendo inferencia...")
    results = classifier(texts, batch_size=32)
    # With top_k=1 each result is a one-element list; keep its label.
    pred_labels = [r[0]["label"] for r in results]

    # Predictions outside the dataset's categories would be silently left
    # out of the confusion matrix, so surface them explicitly.
    known = set(categories)
    unknown = sorted({label for label in pred_labels if label not in known})
    if unknown:
        print(f"\nADVERTENCIA: el modelo predijo etiquetas desconocidas: {unknown}")

    # Shared arguments keep the printed report, the dict report and the
    # confusion matrix consistent; zero_division=0 avoids undefined-metric
    # warnings for categories that are never predicted.
    report_kwargs = {"labels": categories, "zero_division": 0}

    print("\n" + "=" * 60)
    print("CLASSIFICATION REPORT")
    print("=" * 60)
    print(classification_report(true_labels, pred_labels, digits=3, **report_kwargs))

    print("CONFUSION MATRIX")
    print("=" * 60)
    cm = confusion_matrix(true_labels, pred_labels, labels=categories)
    header = f"{'':25s}" + "".join(f"{c[:6]:>8s}" for c in categories)
    print(header)
    for row_label, counts in zip(categories, cm):
        row = f"{row_label:25s}" + "".join(f"{v:>8d}" for v in counts)
        print(row)

    # Flag categories with F1 below the threshold. Because labels=categories
    # is passed, every category has an entry, so none can slip through with
    # an implicit "perfect" default.
    report = classification_report(
        true_labels, pred_labels, output_dict=True, **report_kwargs
    )
    weak = [cat for cat in categories if report[cat]["f1-score"] < f1_threshold]

    if weak:
        print(f"\nADVERTENCIA: F1 < {f1_threshold:.2f} en: {weak}")
        print("Considera agregar mas ejemplos o diversificar templates para estas categorias.")
    else:
        print(f"\nTodas las categorias tienen F1 >= {f1_threshold:.2f}")
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()