File size: 7,047 Bytes
1b5252f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#!/usr/bin/env python3
\"\"\"
Entrenamiento de modelo clasificador de emails empresariales (espa帽ol)
Plan 3: Dataset Marketplace con Modelos Especializados
\"\"\"

import json
import logging
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
import numpy as np

# ========================
# CONFIG
# ========================
MODEL_NAME = \"bert-base-multilingual-cased\"
OUTPUT_DIR = \"/tmp/email-classifier\"
HUB_MODEL_ID = \"CagliostroML/email-classifier-es\"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ========================
# DATASET DE ENTRENAMIENTO
# ========================
TRAINING_DATA = [
    {\"text\": \"Necesito el informe financiero del Q3 antes del viernes.\", \"label\": 0},
    {\"text\": \"El servidor est谩 ca铆do, no podemos acceder a los datos.\", \"label\": 1},
    {\"text\": \"Confirmo asistencia a la reuni贸n del lunes a las 10h.\", \"label\": 2},
    {\"text\": \"El pedido #12345 ha sido enviado, tracking: TRK998877.\", \"label\": 3},
    {\"text\": \"Por favor actualizar la direcci贸n de facturaci贸n.\", \"label\": 4},
    {\"text\": \"El pago de la factura est谩 pendiente desde hace 15 d铆as.\", \"label\": 0},
    {\"text\": \"No funciona el login en el portal, error 500.\", \"label\": 1},
    {\"text\": \"Solicito vacaciones del 15 al 20 de diciembre.\", \"label\": 5},
    {\"text\": \"El cliente XYZ ha rechazado la propuesta comercial.\", \"label\": 6},
    {\"text\": \"Necesitamos m谩s stock del producto ABC.\", \"label\": 3},
    {\"text\": \"El contrato con el proveedor est谩 listo para firma.\", \"label\": 7},
    {\"text\": \"El proyecto muestra un retraso de 3 d铆as.\", \"label\": 8},
    {\"text\": \"Solicito acceso al m贸dulo de reporting.\", \"label\": 1},
    {\"text\": \"Los n煤meros de ventas del mes muestran incremento del 12%.\", \"label\": 6},
    {\"text\": \"El evento de networking ser谩 el 22 de abril.\", \"label\": 9},
    {\"text\": \"Necesito autorizaci贸n para la compra de software.\", \"label\": 4},
    {\"text\": \"El cliente reported problemas with the shipment.\", \"label\": 3},
    {\"text\": \"La auditor铆a interna est谩 programada para la pr贸xima semana.\", \"label\": 10},
    {\"text\": \"El nuevo empleado necesita formaci贸n en el CRM.\", \"label\": 5},
    {\"text\": \"La plataforma presenta lentitud significativa desde ayer.\", \"label\": 1},
    {\"text\": \"Pueden ustedes confirmar el pago de la factura pendiente.\", \"label\": 0},
    {\"text\": \"Error en el sistema de facturaci贸n, no genera PDF.\", \"label\": 1},
    {\"text\": \"Reuni贸n de equipo a las 3pm en sala de conferencias.\", \"label\": 2},
    {\"text\": \"El env铆o lleg贸 en mal estado, necesito reembolso.\", \"label\": 3},
    {\"text\": \"Actualizar datos de contacto del proveedor.\", \"label\": 4},
    {\"text\": \"Solicito aumento de presupuesto para marketing.\", \"label\": 0},
    {\"text\": \"El website no carga correctamente en m贸vil.\", \"label\": 1},
    {\"text\": \"Solicito permiso para trabajar desde casa ma帽ana.\", \"label\": 5},
    {\"text\": \"Nuevo cliente potencial en el sector healthcare.\", \"label\": 6},
    {\"text\": \"Reponer inventario del warehouse central.\", \"label\": 3},
    {\"text\": \"El informe de gastos del mes est谩 listo para revisi贸n.\", \"label\": 0},
    {\"text\": \"El software de CRM muestra errores constantemente.\", \"label\": 1},
    {\"text\": \"La reuni贸n con proveedores fue muy productiva.\", \"label\": 2},
    {\"text\": \"Paquete recibido en almac茅n, listo para distribuci贸n.\", \"label\": 3},
    {\"text\": \"Actualizar la lista de precios del cat谩logo 2024.\", \"label\": 4},
    {\"text\": \"La inversi贸n en publicidad digital rindi贸 muy bien.\", \"label\": 0},
    {\"text\": \"El sistema de backups fall贸 esta noche.\", \"label\": 1},
    {\"text\": \"Solicito formaci贸n en herramientas de data analytics.\", \"label\": 5},
    {\"text\": \"El lead de Barcelona est谩 listo para cerrar negocio.\", \"label\": 6},
    {\"text\": \"La mercanc铆a del contenedor #4421 lleg贸 da帽ada.\", \"label\": 3},
]

LABEL_NAMES = [
    \"finance\", \"it_support\", \"meeting\", \"logistics\", \"admin\",
    \"hr\", \"sales\", \"legal\", \"project\", \"events\", \"compliance\"
]

# ========================
# METRICAS
# ========================
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    accuracy = np.mean(labels == predictions)
    return {\"accuracy\": float(accuracy)}

# ========================
# MAIN
# ========================
def main():
    logger.info(\"=== Training Email Classifier (Spanish) ===\")
    
    # Crear dataset
    ds = Dataset.from_list(TRAINING_DATA)
    ds = ds.train_test_split(test_size=0.2, seed=42)
    
    logger.info(f\"Train samples: {len(ds['train'])}\")
    logger.info(f\"Test samples: {len(ds['test'])}\")
    
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    
    def tokenize(batch):
        return tokenizer(batch[\"text\"], padding=True, truncation=True, max_length=128)
    
    ds = ds.map(tokenize, batched=True)
    
    # Modelo
    num_labels = len(LABEL_NAMES)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels
    )
    
    # Training args
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=10,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        warmup_steps=5,
        logging_dir=\"/tmp/logs\",
        logging_steps=5,
        eval_strategy=\"epoch\",
        save_strategy=\"epoch\",
        load_best_model_at_end=True,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        report_to=\"none\"
    )
    
    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds[\"train\"],
        eval_dataset=ds[\"test\"],
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )
    
    logger.info(\"Starting training...\")
    trainer.train()
    
    logger.info(\"Evaluating...\")
    results = trainer.evaluate()
    logger.info(f\"Results: {results}\")
    
    logger.info(\"Pushing to Hub...\")
    trainer.push_to_hub()
    
    # Guardar config
    config = {
        \"model_type\": \"text-classification\",
        \"language\": \"es\",
        \"labels\": LABEL_NAMES,
        \"num_labels\": num_labels,
        \"accuracy\": results.get(\"eval_accuracy\", 0),
        \"f1\": results.get(\"eval_f1\", 0)
    }
    
    with open(\"/tmp/model_config.json\", \"w\") as f:
        json.dump(config, f, indent=2)
    
    logger.info(\"=== Training Complete ===\")
    logger.info(f\"Model pushed to: https://huggingface.co/{HUB_MODEL_ID}\")
    
    return results

if __name__ == \"__main__\":
    main()