"""
=== MIA · Agent Emotion Predict Classifier (Text/BETO Embedder + MLP) ===
Objetivo: predecir la emoción del AGENTE (label_agent: 0..5) a partir de:
  - el TEXTO del usuario
  - la EMOCIÓN del texto (label del usuario: 0..5)

Arquitectura:
  Texto ──▶ Embedder (TextEmbedder ó BETOEmbedder) ─▶ h_text ∈ R^D
  Label usuario (0..5) ─▶ one-hot(6) ─▶ (feature dropout opcional)
  Concatenación [h_text ; onehot_label] ─▶ MLP ─▶ logits (6)

Notas:
- Si usas BETOEmbedder, se recomienda congelarlo (freeze) para esta segunda red.
- El feature dropout en la one-hot del label obliga al modelo a mirar el TEXTO en los casos ambiguos.
"""

from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from emotion_classifier_model import TextEmbedder, BETOEmbedder  # replace with your actual import


class FeatureDropout(nn.Module):
    """Apaga aleatoriamente (con prob p) TODA la rama de la one-hot del label en entrenamiento.
    Si p=0.2, en el 20% de los batches el modelo debe decidir solo con el texto.
    """
    def __init__(self, p: float = 0.0):
        super().__init__()
        assert 0.0 <= p < 1.0
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p <= 0.0:
            return x
        # With prob p, zero the entire vector (per sample); no inverted-dropout rescaling
        mask = (torch.rand(x.size(0), 1, device=x.device) > self.p).float()
        return x * mask


class MLP(nn.Module):
    """Two-hidden-layer classifier head: Linear → ReLU → Dropout, twice, then Linear."""
    def __init__(self, input_dim: int, hidden1: int = 256, hidden2: int = 64, num_classes: int = 6, dropout: float = 0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.out = nn.Linear(hidden2, num_classes)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        x = self.drop(x)
        x = F.relu(self.fc2(x))
        x = self.drop(x)
        return self.out(x)


class AgentEmotionPredictClassifier(nn.Module):
    """
    Segunda red: predice la emoción del AGENTE (0..5) a partir de (texto, label_usuario).

    Parámetros clave:
      - pretrained_encoder: None → TextEmbedder (emb_dim)
                            "beto" → BETOEmbedder (768D)
      - label_feature_dropout: apaga la one-hot a veces para forzar al modelo a usar el texto en casos ambiguos.
    """
    def __init__(
        self,
        model_name: str = "dccuchile/bert-base-spanish-wwm-cased",
        pretrained_encoder: Optional[str] = "beto",
        emb_dim: int = 300,
        max_length: int = 128,
        hidden1: int = 256,
        hidden2: int = 64,
        num_classes: int = 6,
        dropout: float = 0.2,
        label_feature_dropout: float = 0.15,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if pretrained_encoder == "beto":
            self.embedder = BETOEmbedder(model_name=model_name, max_length=max_length, device=self.device)
            embed_dim = 768
        else:
            self.embedder = TextEmbedder(model_name=model_name, emb_dim=emb_dim, max_length=max_length, device=self.device)
            embed_dim = emb_dim

        self.label_dim = 6  # one-hot(6)
        self.feat_drop = FeatureDropout(p=label_feature_dropout)
        self.classifier = MLP(
            input_dim=embed_dim + self.label_dim,
            hidden1=hidden1, hidden2=hidden2,
            num_classes=num_classes, dropout=dropout,  # num_classes = size of the AGENT's output space
        )
        self.to(self.device)

    # ---------- Utils ----------
    @staticmethod
    def _one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
        # labels: [B] int64 → one-hot [B, C]
        return F.one_hot(labels.long(), num_classes=num_classes).float()

    def freeze_encoder(self):
        """Freeze the embedder so only the fusion MLP is trained (recommended with BETO)."""
        for p in self.embedder.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        """Make the embedder trainable again, e.g. for later fine-tuning."""
        for p in self.embedder.parameters():
            p.requires_grad = True

    # ---------- Forward / Predict ----------
    def forward(self, texts: List[str], user_labels: torch.Tensor) -> torch.Tensor:
        """texts: lista de strings (len=B)
           user_labels: tensor [B] con labels del usuario (0..5)
        """
        h_text = self.embedder.embed_batch(texts)              # [B, D]
        onehot = self._one_hot(user_labels.to(h_text.device), self.label_dim)  # [B, 6]
        onehot = self.feat_drop(onehot)                        # feature dropout (training only)
        x = torch.cat([h_text, onehot], dim=-1)                # [B, D+6]
        logits = self.classifier(x)                            # [B, num_classes]
        return logits

    @torch.inference_mode()
    def predict(self, texts: List[str], user_labels: torch.Tensor):
        """Return (predicted agent labels [B], class probabilities [B, num_classes])."""
        self.eval()
        logits = self.forward(texts, user_labels)
        probs = logits.softmax(dim=-1)
        preds = probs.argmax(dim=-1)
        return preds, probs
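

# ---------- Usage sketch ----------
# A minimal smoke test, assuming emotion_classifier_model provides the embedders
# with the constructor signatures and the embed_batch() method used above. The
# texts, user labels, and agent targets below are dummy placeholders, not real
# MIA data; treat this as a sketch of the intended workflow, not a training script.
if __name__ == "__main__":
    model = AgentEmotionPredictClassifier(pretrained_encoder="beto")
    model.freeze_encoder()  # recommended: train only the fusion MLP on top of BETO

    texts = ["estoy muy contento hoy", "no sé qué hacer, todo sale mal"]  # dummy inputs
    user_labels = torch.tensor([3, 0])    # user emotions (0..5), dummy values
    agent_labels = torch.tensor([3, 4], device=model.device)  # dummy agent targets

    # One training step: optimize only the parameters left trainable after freezing.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-3
    )
    model.train()
    logits = model(texts, user_labels)
    loss = F.cross_entropy(logits, agent_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Inference: predicted agent emotion and per-class probabilities.
    preds, probs = model.predict(texts, user_labels)
    print("loss:", loss.item(), "preds:", preds.tolist())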