"""
=== MIA · Agent Emotion Predict Classifier (Text/BETO Embedder + MLP) ===
Objetivo: predecir la emoción del AGENTE (label_agent: 0..5) a partir de:
- el TEXTO del usuario
- la EMOCIÓN del texto (label del usuario: 0..5)
Arquitectura:
Texto ──▶ Embedder (TextEmbedder ó BETOEmbedder) ─▶ h_text ∈ R^D
Label usuario (0..5) ─▶ one-hot(6) ─▶ (feature dropout opcional)
Concatenación [h_text ; onehot_label] ─▶ MLP ─▶ logits (6)
Notas:
- Si usas BETOEmbedder, se recomienda congelarlo (freeze) para esta segunda red.
- El feature dropout en la one-hot del label obliga al modelo a mirar el TEXTO en los casos ambiguos.
"""
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from emotion_classifier_model import TextEmbedder, BETOEmbedder  # replace with your actual import

class FeatureDropout(nn.Module):
    """Randomly zeroes (with probability p) the ENTIRE label one-hot branch during training.
    With p=0.2, the model must decide from the text alone for roughly 20% of the samples.
    """
def __init__(self, p: float = 0.0):
super().__init__()
assert 0.0 <= p < 1.0
self.p = p
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.training or self.p <= 0.0:
return x
        # With probability p, zero out the whole vector (per sample)
mask = (torch.rand(x.size(0), 1, device=x.device) > self.p).float()
return x * mask
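
# Quick sanity sketch (illustrative only, not part of the model): in train mode,
# roughly p of the rows in a batch come back all-zero; in eval mode the input
# passes through untouched.
#
#   fd = FeatureDropout(p=0.5)
#   fd.train()
#   dropped = fd(torch.ones(8, 6))          # ~half the rows are all zeros
#   fd.eval()
#   assert torch.equal(fd(torch.ones(8, 6)), torch.ones(8, 6))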

class MLP(nn.Module):
    """Classification head: input_dim → hidden1 → hidden2 → num_classes, with ReLU and dropout."""
    def __init__(self, input_dim: int, hidden1: int = 256, hidden2: int = 64, num_classes: int = 6, dropout: float = 0.2):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden1)
self.fc2 = nn.Linear(hidden1, hidden2)
self.out = nn.Linear(hidden2, num_classes)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.relu(self.fc1(x))
x = self.drop(x)
x = F.relu(self.fc2(x))
x = self.drop(x)
return self.out(x)
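
# Shape sketch (illustrative): with BETO embeddings (768-D) plus the 6-D label
# one-hot, the head maps [B, 774] → [B, 6].
#
#   head = MLP(input_dim=768 + 6)
#   logits = head(torch.randn(4, 774))      # → torch.Size([4, 6])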

class AgentEmotionPredictClassifier(nn.Module):
    """
    Second network: predicts the AGENT's emotion (0..5) from (text, user_label).
    Key parameters:
    - pretrained_encoder: None   → TextEmbedder (emb_dim)
                          "beto" → BETOEmbedder (768-D)
    - label_feature_dropout: occasionally zeroes the one-hot to force the model to use the text in ambiguous cases.
    """
def __init__(
self,
model_name: str = "dccuchile/bert-base-spanish-wwm-cased",
pretrained_encoder: Optional[str] = "beto",
emb_dim: int = 300,
max_length: int = 128,
hidden1: int = 256,
hidden2: int = 64,
num_classes: int = 6,
dropout: float = 0.2,
label_feature_dropout: float = 0.15,
device: Optional[torch.device] = None,
):
super().__init__()
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
if pretrained_encoder == "beto":
self.embedder = BETOEmbedder(model_name=model_name, max_length=max_length, device=self.device)
embed_dim = 768
else:
self.embedder = TextEmbedder(model_name=model_name, emb_dim=emb_dim, max_length=max_length, device=self.device)
embed_dim = emb_dim
self.label_dim = 6 # one-hot(6)
self.feat_drop = FeatureDropout(p=label_feature_dropout)
        self.classifier = MLP(input_dim=embed_dim + self.label_dim,
                              hidden1=hidden1, hidden2=hidden2,
                              num_classes=num_classes, dropout=dropout)  # num_classes = size of the AGENT's output
self.to(self.device)
# ---------- Utils ----------
@staticmethod
def _one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
# labels: [B] int64 → one-hot [B, C]
return F.one_hot(labels.long(), num_classes=num_classes).float()
def freeze_encoder(self):
for p in self.embedder.parameters():
p.requires_grad = False
def unfreeze_encoder(self):
for p in self.embedder.parameters():
p.requires_grad = True
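    # Typical training setup (sketch, following the freeze recommendation in the
    # module docstring; assumes a standard cross-entropy loop): freeze the
    # encoder and optimize only the parameters that still require gradients,
    # i.e. the MLP head:
    #
    #   model = AgentEmotionPredictClassifier(pretrained_encoder="beto")
    #   model.freeze_encoder()
    #   optim = torch.optim.AdamW(
    #       (p for p in model.parameters() if p.requires_grad), lr=1e-3)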
# ---------- Forward / Predict ----------
    def forward(self, texts: List[str], user_labels: torch.Tensor) -> torch.Tensor:
        """texts: list of strings (len=B)
        user_labels: tensor [B] with the user's labels (0..5)
        """
h_text = self.embedder.embed_batch(texts) # [B, D]
onehot = self._one_hot(user_labels.to(h_text.device), self.label_dim) # [B, 6]
        onehot = self.feat_drop(onehot)                                   # feature dropout (train mode only)
x = torch.cat([h_text, onehot], dim=-1) # [B, D+6]
logits = self.classifier(x) # [B, 6]
return logits
@torch.inference_mode()
def predict(self, texts: List[str], user_labels: torch.Tensor):
self.eval()
logits = self.forward(texts, user_labels)
probs = logits.softmax(dim=-1)
preds = probs.argmax(dim=-1)
return preds, probs
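
# Minimal smoke test (sketch). Assumes TextEmbedder / BETOEmbedder expose
# `embed_batch(List[str]) -> Tensor[B, D]` as used in forward() above; the
# texts and labels below are made-up examples.
if __name__ == "__main__":
    model = AgentEmotionPredictClassifier(pretrained_encoder="beto")
    model.freeze_encoder()
    texts = ["hola, necesito ayuda", "gracias, todo resuelto"]  # Spanish, since BETO is a Spanish model
    user_labels = torch.tensor([3, 1])                          # user emotions in 0..5
    preds, probs = model.predict(texts, user_labels)
    print(preds.tolist(), probs.shape)                          # e.g. [..., ...] torch.Size([2, 6])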