"""
model.py
Dual BERTweet architecture for algospeak content moderation.
Two independent BERTweet encoders trained jointly with supervised InfoNCE loss:
- supervised encoder: receives "[CLASS_LABEL]: text" (class-aware during training)
- unsupervised encoder: receives raw text only (the inference model)
At inference, only the unsupervised encoder is used. Its embeddings are compared
to class prototypes (built from training data) via cosine similarity.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


class BERTweetEncoder(nn.Module):
"""
Wraps vinai/bertweet-base and returns an L2-normalized CLS token embedding.
"""
def __init__(self, model_name: str):
super().__init__()
self.bert = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
cls_emb = outputs.last_hidden_state[:, 0, :] # [B, 768]
return F.normalize(cls_emb, dim=-1) # L2 normalize -> cosine-ready
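

# Illustrative usage sketch (an assumption, not part of the pipeline defined in this file):
# encoding a batch of raw tweets with BERTweetEncoder. The tokenizer name and the
# `normalization=True` flag follow the vinai/bertweet-base model card; adjust to whatever
# tokenizer the project actually uses.
def _example_encode_batch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)
    encoder = BERTweetEncoder("vinai/bertweet-base")
    encoder.eval()

    texts = ["example tweet one", "example tweet two"]  # raw text, no label prefix
    batch = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
    with torch.no_grad():
        emb = encoder(batch["input_ids"], batch["attention_mask"])  # [2, 768], L2-normalized
    return emb
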
class DualEncoderModel(nn.Module):
"""
Two independent BERTweet encoders trained with supervised InfoNCE loss.
supervised encoder:
Input: "[CLASS_LABEL]: <text>" (e.g. "Offensive Language: I hate you")
Produces class-aware embeddings during training.
Discarded after training.
unsupervised encoder:
Input: raw text
Trained (via InfoNCE) to match the supervised encoder's embedding space.
Used exclusively at inference.
"""
def __init__(self, model_name: str, temperature: float):
super().__init__()
self.supervised = BERTweetEncoder(model_name)
self.unsupervised = BERTweetEncoder(model_name)
self.temperature = temperature

    def forward(
self,
sup_ids: torch.Tensor,
sup_mask: torch.Tensor,
unsup_ids: torch.Tensor,
unsup_mask: torch.Tensor,
labels: torch.Tensor,
):
e_s = self.supervised(sup_ids, sup_mask) # [B, D]
e_u = self.unsupervised(unsup_ids, unsup_mask) # [B, D]
loss = supervised_infonce_loss(e_s, e_u, labels, self.temperature)
return loss, e_s, e_u
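

# Illustrative training-step sketch. The optimizer, learning rate, and the "Neutral" class
# name are assumptions for demonstration; the real training script is not part of this file.
# Supervised inputs follow the "[CLASS_LABEL]: <text>" format described in the docstring above.
def _example_training_step():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)
    model = DualEncoderModel("vinai/bertweet-base", temperature=0.07)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    texts = ["I hate you", "have a nice day"]
    label_names = ["Offensive Language", "Neutral"]  # "Neutral" is a hypothetical class name
    labels = torch.tensor([0, 1])

    # Supervised view gets the label prefix; unsupervised view sees raw text only.
    sup_texts = [f"{name}: {text}" for name, text in zip(label_names, texts)]
    sup = tokenizer(sup_texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
    unsup = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")

    loss, _, _ = model(sup["input_ids"], sup["attention_mask"],
                       unsup["input_ids"], unsup["attention_mask"], labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
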
def supervised_infonce_loss(
e_s: torch.Tensor,
e_u: torch.Tensor,
labels: torch.Tensor,
temperature: float,
) -> torch.Tensor:
"""
    Supervised InfoNCE loss computed across the two encoders: unsupervised embeddings
    act as queries, supervised embeddings as keys.
For each unsupervised embedding e_u_i:
Positives: all supervised embeddings e_s_j where label_j == label_i
Negatives: all supervised embeddings e_s_j where label_j != label_i
    Loss = mean_i [ -log( sum_{j: pos} exp(sim_ij/τ) / sum_j exp(sim_ij/τ) ) ]
Both e_s and e_u are L2-normalized so sim = dot product = cosine similarity.
Args:
e_s: [B, D] supervised encoder embeddings
e_u: [B, D] unsupervised encoder embeddings
labels: [B] integer class labels
        temperature: scalar τ (typically 0.07)
Returns:
Scalar loss.
"""
    # Similarity matrix (unsupervised embeddings as queries, supervised as keys), [B, B]
sim = torch.mm(e_u, e_s.T) / temperature
    # Positive mask: 1.0 where label_j == label_i, 0.0 otherwise, [B, B]
pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
# Numerical stability: subtract row max before exp
sim_max, _ = sim.max(dim=1, keepdim=True)
sim = sim - sim_max.detach()
exp_sim = torch.exp(sim)
pos_sum = (exp_sim * pos_mask).sum(dim=1) # [B]
all_sum = exp_sim.sum(dim=1) # [B]
    # Guard against rows with no in-batch positives (cannot actually occur here, since
    # each sample's own supervised embedding always shares its label)
valid = pos_sum > 0
if not valid.any():
return torch.tensor(0.0, requires_grad=True, device=e_s.device)
loss = -torch.log(pos_sum[valid] / all_sum[valid])
return loss.mean()
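

# Illustrative inference sketch for the prototype comparison described in the module
# docstring. The prototype construction shown here (per-class mean of L2-normalized
# training embeddings, re-normalized) and the function signature are assumptions; the
# actual prototypes are built elsewhere from the training data.
def _example_prototype_inference(
    encoder: BERTweetEncoder,
    train_emb: torch.Tensor,     # [N, D] unsupervised-encoder embeddings of training texts
    train_labels: torch.Tensor,  # [N] integer class labels
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    # One prototype per class: mean of that class's embeddings, re-normalized to unit length.
    num_classes = int(train_labels.max().item()) + 1
    prototypes = torch.stack([
        F.normalize(train_emb[train_labels == c].mean(dim=0), dim=-1)
        for c in range(num_classes)
    ])  # [C, D]

    with torch.no_grad():
        query = encoder(input_ids, attention_mask)  # [B, D], already L2-normalized
    scores = torch.mm(query, prototypes.T)          # cosine similarities, [B, C]
    return scores.argmax(dim=1)                     # predicted class index per input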