"""
Combined training loss with Human-Pattern Term:

    L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic + λ₃ · L_human_pattern

Where:
    L_CE            = cross-entropy language model loss (standard token prediction)
    L_style         = style consistency loss (cosine distance between output and target style vectors)
    L_semantic      = semantic similarity loss (cosine distance between sentence embeddings)
    L_human_pattern = 1 - HumanPatternClassifier.score(output_text)

    λ₁ = style loss weight (default 0.3)
    λ₂ = semantic loss weight (default 0.5)
    λ₃ = human pattern weight (default 0.4)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from typing import Optional, List, Dict
from loguru import logger


class CombinedCorrectionLoss(nn.Module):
    """V1 combined loss: L_CE + λ₁·L_style + λ₂·L_semantic."""

    def __init__(
        self,
        lambda_style: float = 0.3,
        lambda_semantic: float = 0.5,
        sem_model_name: str = "all-mpnet-base-v2",
        device: str = "cpu",
    ):
        super().__init__()
        self.lambda_style = lambda_style
        self.lambda_semantic = lambda_semantic
        self.device = device

        # Cross-entropy loss; ignore_index=-100 skips padded/masked positions
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Frozen sentence transformer for semantic similarity
        logger.info(f"Loading sentence transformer for loss: {sem_model_name}")
        self.sem_model = SentenceTransformer(sem_model_name, device=device)
        self.sem_model.eval()

        # Freeze sentence transformer weights
        for param in self.sem_model.parameters():
            param.requires_grad = False

    def _style_loss(
        self,
        output_style_vec: torch.Tensor,
        target_style_vec: torch.Tensor,
    ) -> torch.Tensor:
        """1 - cosine_similarity(output_style, target_style)."""
        # Promote single vectors to a batch dimension for cosine_similarity
        if output_style_vec.dim() == 1:
            output_style_vec = output_style_vec.unsqueeze(0)
        if target_style_vec.dim() == 1:
            target_style_vec = target_style_vec.unsqueeze(0)
        cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
        return (1.0 - cos_sim).mean()

    def _semantic_loss(
        self,
        input_texts: List[str],
        output_texts: List[str],
    ) -> torch.Tensor:
        """Penalises meaning change between input and output."""
        # encode() runs under no_grad, so this term is a scalar penalty: it
        # raises the reported loss when meaning drifts, but contributes no
        # gradient back to the main model.
        with torch.no_grad():
            input_embeddings = self.sem_model.encode(input_texts, convert_to_tensor=True)
            output_embeddings = self.sem_model.encode(output_texts, convert_to_tensor=True)
        cos_sim = F.cosine_similarity(input_embeddings, output_embeddings, dim=-1)
        # Loss = 1 - similarity (we want high similarity = low loss)
        return (1.0 - cos_sim).mean()

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        output_style_vec: Optional[torch.Tensor] = None,
        target_style_vec: Optional[torch.Tensor] = None,
        input_texts: Optional[List[str]] = None,
        output_texts: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute the combined loss; optional terms are skipped when their inputs are absent."""
        losses = {}

        # L_CE: cross-entropy over flattened token positions
        # logits: [batch, seq_len, vocab_size], labels: [batch, seq_len]
        if logits.dim() == 3:
            # reshape (rather than view) also handles non-contiguous logits
            ce_logits = logits.reshape(-1, logits.size(-1))
            ce_labels = labels.reshape(-1)
        else:
            ce_logits = logits
            ce_labels = labels
        l_ce = self.ce_loss(ce_logits, ce_labels)
        losses["ce_loss"] = l_ce
        total = l_ce

        # L_style
        if output_style_vec is not None and target_style_vec is not None:
            l_style = self._style_loss(output_style_vec, target_style_vec)
            losses["style_loss"] = l_style
            total = total + self.lambda_style * l_style

        # L_semantic
        if input_texts is not None and output_texts is not None:
            l_semantic = self._semantic_loss(input_texts, output_texts)
            losses["semantic_loss"] = l_semantic
            total = total + self.lambda_semantic * l_semantic

        losses["total_loss"] = total
        return losses
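

# Minimal usage sketch for the V1 loss. The shapes, texts, and the helper name
# are illustrative assumptions, not part of the training pipeline.
def _demo_combined_loss_v1() -> None:
    batch, seq_len, vocab = 2, 8, 100
    loss_fn = CombinedCorrectionLoss(device="cpu")
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))
    losses = loss_fn(
        logits,
        labels,
        output_style_vec=torch.randn(batch, 16),
        target_style_vec=torch.randn(batch, 16),
        input_texts=["the original sentence", "a second input"],
        output_texts=["the corrected sentence", "a second output"],
    )
    logger.info(f"demo total_loss: {losses['total_loss'].item():.4f}")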


class CombinedCorrectionLossV2(nn.Module):
    """V2 combined loss with human-pattern term: L_CE + λ₁·L_style + λ₂·L_semantic + λ₃·L_human_pattern."""

    def __init__(
        self,
        lambda_style: float = 0.3,
        lambda_semantic: float = 0.5,
        lambda_human_pattern: float = 0.4,
        classifier_path: str = "checkpoints/human_pattern_classifier.pt",
        sem_model_name: str = "all-mpnet-base-v2",
        device: str = "cpu",
    ):
        super().__init__()
        self.lambda_style = lambda_style
        self.lambda_semantic = lambda_semantic
        self.lambda_human_pattern = lambda_human_pattern
        self.device = device

        # V1 components
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Sentence transformer on CPU to save GPU VRAM for the main model
        logger.info(f"Loading sentence transformer for loss: {sem_model_name} (on CPU)")
        self.sem_model = SentenceTransformer(sem_model_name, device="cpu")
        self.sem_model.eval()

        # Load the frozen human pattern classifier (deferred import of the sibling module)
        from .human_pattern_extractor import HumanPatternClassifier, HumanPatternFeatureExtractor

        self.hp_classifier = HumanPatternClassifier()
        try:
            state_dict = torch.load(classifier_path, map_location=device, weights_only=True)
            self.hp_classifier.load_state_dict(state_dict)
            logger.info(f"Loaded human pattern classifier from {classifier_path}")
        except FileNotFoundError:
            logger.warning(f"Human pattern classifier not found at {classifier_path}, using random weights")
        self.hp_classifier.eval()
        for param in self.hp_classifier.parameters():
            param.requires_grad = False

        # Feature extractor on CPU to save GPU VRAM for the main model
        self.hp_extractor = HumanPatternFeatureExtractor(device="cpu")

    def _human_pattern_loss(
        self,
        output_texts: List[str],
        compute_device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Loss = 1 - human_score. Penalises AI-like outputs."""
        # Scores come from the frozen classifier over decoded text, so this
        # term (like L_semantic) is a scalar penalty with no gradient path
        # back to the main model.
        scores = []
        for text in output_texts:
            score = self.hp_classifier.score(text, self.hp_extractor)
            scores.append(score)
        device = compute_device if compute_device is not None else self.device
        human_scores = torch.tensor(scores, dtype=torch.float32, device=device)
        return (1.0 - human_scores).mean()

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        output_style_vec: Optional[torch.Tensor] = None,
        target_style_vec: Optional[torch.Tensor] = None,
        input_texts: Optional[List[str]] = None,
        output_texts: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute the combined loss with the human-pattern term."""
        losses = {}

        # L_CE
        if logits.dim() == 3:
            # reshape (rather than view) also handles non-contiguous logits
            ce_logits = logits.reshape(-1, logits.size(-1))
            ce_labels = labels.reshape(-1)
        else:
            ce_logits = logits
            ce_labels = labels
        l_ce = self.ce_loss(ce_logits, ce_labels)
        losses["ce_loss"] = l_ce
        total = l_ce

        # L_style
        if output_style_vec is not None and target_style_vec is not None:
            # Ensure both vectors are on the same device (style vectors may
            # come from a CPU-side fingerprinter)
            compute_device = logits.device
            output_style_vec = output_style_vec.to(compute_device)
            target_style_vec = target_style_vec.to(compute_device)
            if output_style_vec.dim() == 1:
                output_style_vec = output_style_vec.unsqueeze(0)
            if target_style_vec.dim() == 1:
                target_style_vec = target_style_vec.unsqueeze(0)
            cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
            l_style = (1.0 - cos_sim).mean()
            losses["style_loss"] = l_style
            total = total + self.lambda_style * l_style

        # L_semantic
        if input_texts is not None and output_texts is not None:
            with torch.no_grad():
                input_emb = self.sem_model.encode(input_texts, convert_to_tensor=True)
                output_emb = self.sem_model.encode(output_texts, convert_to_tensor=True)
            # sem_model lives on CPU, so move embeddings to the compute device
            input_emb = input_emb.to(logits.device)
            output_emb = output_emb.to(logits.device)
            cos_sim = F.cosine_similarity(input_emb, output_emb, dim=-1)
            l_semantic = (1.0 - cos_sim).mean()
            losses["semantic_loss"] = l_semantic
            total = total + self.lambda_semantic * l_semantic

        # L_human_pattern
        if output_texts is not None:
            l_human = self._human_pattern_loss(output_texts, compute_device=logits.device)
            losses["human_pattern_loss"] = l_human
            total = total + self.lambda_human_pattern * l_human

        losses["total_loss"] = total
        return losses
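

# Smoke-test entry point (a sketch: running it downloads the sentence
# transformer; the V2 loss additionally needs the human-pattern classifier
# checkpoint and the sibling human_pattern_extractor module, so only the V1
# path is exercised here).
if __name__ == "__main__":
    _demo_combined_loss_v1()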