|
|
""" |
|
|
=== MIA · Agent Emotion Predict Classifier (Text/BETO Embedder + MLP) === |
|
|
Objetivo: predecir la emoción del AGENTE (label_agent: 0..5) a partir de: |
|
|
- el TEXTO del usuario |
|
|
- la EMOCIÓN del texto (label del usuario: 0..5) |
|
|
|
|
|
Arquitectura: |
|
|
Texto ──▶ Embedder (TextEmbedder o BETOEmbedder) ─▶ h_text ∈ R^D
|
|
Label usuario (0..5) ─▶ one-hot(6) ─▶ (feature dropout opcional) |
|
|
Concatenación [h_text ; onehot_label] ─▶ MLP ─▶ logits (6) |
|
|
|
|
|
Notas: |
|
|
- Si usas BETOEmbedder, se recomienda congelarlo (freeze) para esta segunda red. |
|
|
- El feature dropout en la one-hot del label obliga al modelo a mirar el TEXTO en los casos ambiguos. |
|
|
""" |
|
|
|
|
|
from typing import List, Optional |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from emotion_classifier_model import TextEmbedder, BETOEmbedder |
|
|
|
|
|
|
|
|
class FeatureDropout(nn.Module):
    """Randomly zeroes (with probability ``p``) the ENTIRE one-hot label branch during training.

    The mask is drawn per SAMPLE (one Bernoulli draw per row), so with p=0.2
    roughly 20% of the samples in a batch lose their label feature and the
    model must decide from the text alone.
    """

    def __init__(self, p: float = 0.0):
        super().__init__()
        # Explicit validation instead of `assert`, which is stripped under `python -O`.
        if not 0.0 <= p < 1.0:
            raise ValueError(f"p must be in [0.0, 1.0), got {p}")
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity in eval mode or when dropout is disabled.
        if not self.training or self.p <= 0.0:
            return x

        # One draw per sample, broadcast across features: keep or zero the whole row.
        mask = (torch.rand(x.size(0), 1, device=x.device) > self.p).float()
        return x * mask
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Two-hidden-layer perceptron head: input -> hidden1 -> hidden2 -> logits.

    Dropout follows each ReLU activation; the output layer is linear, so the
    network returns raw (unnormalized) logits.
    """

    def __init__(self, input_dim: int, hidden1: int = 256, hidden2: int = 64, num_classes: int = 6, dropout: float = 0.2):
        super().__init__()
        # Layers are created in this exact order so state_dict keys match.
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.out = nn.Linear(hidden2, num_classes)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.drop(F.relu(self.fc1(x)))
        hidden = self.drop(F.relu(self.fc2(hidden)))
        return self.out(hidden)
|
|
|
|
|
|
|
|
class AgentEmotionPredictClassifier(nn.Module):
    """
    Second network: predicts the AGENT's emotion (0..5) from (text, user_label).

    Architecture:
        text ──▶ embedder (TextEmbedder or BETOEmbedder) ─▶ h_text
        user label (0..5) ─▶ one-hot(6) ─▶ optional feature dropout
        concat [h_text ; onehot] ─▶ MLP ─▶ logits (num_classes)

    Key parameters:
    - pretrained_encoder: "beto" -> BETOEmbedder (768-d hidden size);
      anything else -> TextEmbedder (emb_dim).
    - label_feature_dropout: occasionally zeroes the one-hot branch so the
      model is forced to use the text in ambiguous cases.
    """

    def __init__(
        self,
        model_name: str = "dccuchile/bert-base-spanish-wwm-cased",
        pretrained_encoder: Optional[str] = "beto",
        emb_dim: int = 300,
        max_length: int = 128,
        hidden1: int = 256,
        hidden2: int = 64,
        num_classes: int = 6,
        dropout: float = 0.2,
        label_feature_dropout: float = 0.15,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if pretrained_encoder == "beto":
            self.embedder = BETOEmbedder(model_name=model_name, max_length=max_length, device=self.device)
            embed_dim = 768  # BERT-base hidden size produced by BETOEmbedder
        else:
            self.embedder = TextEmbedder(model_name=model_name, emb_dim=emb_dim, max_length=max_length, device=self.device)
            embed_dim = emb_dim

        # The USER label space is fixed at 6 classes (0..5), independently of
        # num_classes, which sizes the AGENT-side output.
        self.label_dim = 6
        self.feat_drop = FeatureDropout(p=label_feature_dropout)
        self.classifier = MLP(input_dim=embed_dim + self.label_dim,
                              hidden1=hidden1, hidden2=hidden2,
                              num_classes=num_classes, dropout=dropout)
        self.to(self.device)

    @staticmethod
    def _one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Return a float32 one-hot encoding of integer labels, shape [B, num_classes]."""
        return F.one_hot(labels.long(), num_classes=num_classes).float()

    def freeze_encoder(self):
        """Disable gradients for the text encoder (recommended when using BETO)."""
        for p in self.embedder.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        """Re-enable gradients for the text encoder."""
        for p in self.embedder.parameters():
            p.requires_grad = True

    def forward(self, texts: List[str], user_labels: torch.Tensor) -> torch.Tensor:
        """Compute agent-emotion logits.

        Args:
            texts: list of B strings.
            user_labels: tensor [B] with the user's labels (0..5).

        Returns:
            Logits tensor of shape [B, num_classes].
        """
        h_text = self.embedder.embed_batch(texts)
        onehot = self._one_hot(user_labels.to(h_text.device), self.label_dim)
        onehot = self.feat_drop(onehot)
        x = torch.cat([h_text, onehot], dim=-1)
        return self.classifier(x)

    @torch.inference_mode()
    def predict(self, texts: List[str], user_labels: torch.Tensor):
        """Return (predicted labels [B], class probabilities [B, num_classes]).

        Bug fix: the original called ``self.eval()`` and never restored the
        previous mode, so calling predict() mid-training silently left the
        whole model (dropout, etc.) in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            logits = self.forward(texts, user_labels)
            probs = logits.softmax(dim=-1)
            preds = probs.argmax(dim=-1)
        finally:
            # Restore the caller's train/eval state.
            if was_training:
                self.train()
        return preds, probs
|
|
|
|
|
|
|
|
|
|
|
|