# MiaPredict / agent_emotion_predict_classifier.py
"""
=== MIA · Agent Emotion Predict Classifier (Text/BETO Embedder + MLP) ===
Objetivo: predecir la emoción del AGENTE (label_agent: 0..5) a partir de:
- el TEXTO del usuario
- la EMOCIÓN del texto (label del usuario: 0..5)
Arquitectura:
Texto ──▶ Embedder (TextEmbedder ó BETOEmbedder) ─▶ h_text ∈ R^D
Label usuario (0..5) ─▶ one-hot(6) ─▶ (feature dropout opcional)
Concatenación [h_text ; onehot_label] ─▶ MLP ─▶ logits (6)
Notas:
- Si usas BETOEmbedder, se recomienda congelarlo (freeze) para esta segunda red.
- El feature dropout en la one-hot del label obliga al modelo a mirar el TEXTO en los casos ambiguos.
"""
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from emotion_classifier_model import TextEmbedder, BETOEmbedder  # replace with your actual import
class FeatureDropout(nn.Module):
"""Apaga aleatoriamente (con prob p) TODA la rama de la one-hot del label en entrenamiento.
Si p=0.2, en el 20% de los batches el modelo debe decidir solo con el texto.
"""
def __init__(self, p: float = 0.0):
super().__init__()
assert 0.0 <= p < 1.0
self.p = p
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.training or self.p <= 0.0:
return x
        # With probability p, zero out the whole vector (per sample)
mask = (torch.rand(x.size(0), 1, device=x.device) > self.p).float()
return x * mask
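
# Illustrative example of FeatureDropout (values are assumptions, not part of
# the model):
#
#   fd = FeatureDropout(p=0.5)
#   fd.train()  # dropout only fires in training mode
#   onehot = F.one_hot(torch.tensor([0, 3, 5]), num_classes=6).float()
#   masked = fd(onehot)  # roughly half the rows come back all-zero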
class MLP(nn.Module):
def __init__(self, input_dim: int, hidden1: int = 256, hidden2: int = 64, num_classes: int = 6, dropout: float = 0.2):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden1)
self.fc2 = nn.Linear(hidden1, hidden2)
self.out = nn.Linear(hidden2, num_classes)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.relu(self.fc1(x))
x = self.drop(x)
x = F.relu(self.fc2(x))
x = self.drop(x)
return self.out(x)
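
# Shape sketch (illustrative, assuming the BETO path): input_dim = 768 + 6 = 774,
# so MLP(input_dim=774)(torch.randn(4, 774)) returns logits of shape [4, 6].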
class AgentEmotionPredictClassifier(nn.Module):
"""
Segunda red: predice la emoción del AGENTE (0..5) a partir de (texto, label_usuario).
Parámetros clave:
- pretrained_encoder: None → TextEmbedder (emb_dim)
"beto" → BETOEmbedder (768D)
- label_feature_dropout: apaga la one-hot a veces para forzar al modelo a usar el texto en casos ambiguos.
"""
def __init__(
self,
model_name: str = "dccuchile/bert-base-spanish-wwm-cased",
pretrained_encoder: Optional[str] = "beto",
emb_dim: int = 300,
max_length: int = 128,
hidden1: int = 256,
hidden2: int = 64,
num_classes: int = 6,
dropout: float = 0.2,
label_feature_dropout: float = 0.15,
device: Optional[torch.device] = None,
):
super().__init__()
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
if pretrained_encoder == "beto":
self.embedder = BETOEmbedder(model_name=model_name, max_length=max_length, device=self.device)
embed_dim = 768
else:
self.embedder = TextEmbedder(model_name=model_name, emb_dim=emb_dim, max_length=max_length, device=self.device)
embed_dim = emb_dim
self.label_dim = 6 # one-hot(6)
self.feat_drop = FeatureDropout(p=label_feature_dropout)
        self.classifier = MLP(input_dim=embed_dim + self.label_dim,
                              hidden1=hidden1, hidden2=hidden2,
                              num_classes=num_classes, dropout=dropout)  # num_classes = AGENT output size (6 by default)
self.to(self.device)
# ---------- Utils ----------
@staticmethod
def _one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
# labels: [B] int64 → one-hot [B, C]
return F.one_hot(labels.long(), num_classes=num_classes).float()
def freeze_encoder(self):
for p in self.embedder.parameters():
p.requires_grad = False
def unfreeze_encoder(self):
for p in self.embedder.parameters():
p.requires_grad = True
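
    # Sketch of head-only fine-tuning with a frozen encoder (illustrative; the
    # optimizer choice is an assumption, not prescribed by this file):
    #
    #   model = AgentEmotionPredictClassifier(pretrained_encoder="beto")
    #   model.freeze_encoder()  # recommended when reusing the BETO embedder
    #   optimizer = torch.optim.AdamW(
    #       (p for p in model.parameters() if p.requires_grad), lr=1e-3
    #   )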
# ---------- Forward / Predict ----------
def forward(self, texts: List[str], user_labels: torch.Tensor) -> torch.Tensor:
"""texts: lista de strings (len=B)
user_labels: tensor [B] con labels del usuario (0..5)
"""
h_text = self.embedder.embed_batch(texts) # [B, D]
        onehot = self._one_hot(user_labels.to(h_text.device), self.label_dim)  # [B, 6]
        onehot = self.feat_drop(onehot)  # feature dropout (training only)
x = torch.cat([h_text, onehot], dim=-1) # [B, D+6]
logits = self.classifier(x) # [B, 6]
return logits
@torch.inference_mode()
def predict(self, texts: List[str], user_labels: torch.Tensor):
self.eval()
logits = self.forward(texts, user_labels)
probs = logits.softmax(dim=-1)
preds = probs.argmax(dim=-1)
return preds, probs
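

if __name__ == "__main__":
    # Minimal smoke test (illustrative; assumes TextEmbedder/BETOEmbedder from
    # emotion_classifier_model are available and expose embed_batch()).
    model = AgentEmotionPredictClassifier(pretrained_encoder="beto")
    model.freeze_encoder()  # recommended: train only the MLP head
    texts = ["hola, ¿cómo estás?", "estoy muy molesto con el servicio"]
    user_labels = torch.tensor([3, 0])  # user-emotion labels (0..5)
    preds, probs = model.predict(texts, user_labels)
    print("preds:", preds.tolist())
    print("probs:", probs)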