"""
Combined training loss with Human-Pattern Term:
L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic + λ₃ · L_human_pattern
Where:
L_CE = cross-entropy language model loss (standard token prediction)
L_style = style consistency loss (cosine distance between output and target style vectors)
L_semantic = semantic similarity loss (cosine distance between sentence embeddings)
L_human_pattern = 1 - HumanPatternClassifier.score(output_text)
λ₁ = style loss weight (default 0.3)
λ₂ = semantic loss weight (default 0.5)
λ₃ = human pattern weight (default 0.4)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from typing import Optional, List, Dict
from loguru import logger
class CombinedCorrectionLoss(nn.Module):
"""V1 combined loss: L_CE + λ₁·L_style + λ₂·L_semantic."""
def __init__(
self,
lambda_style: float = 0.3,
lambda_semantic: float = 0.5,
sem_model_name: str = "all-mpnet-base-v2",
device: str = "cpu",
):
super().__init__()
self.lambda_style = lambda_style
self.lambda_semantic = lambda_semantic
self.device = device
# Cross-entropy loss
self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
# Frozen sentence transformer for semantic similarity
logger.info(f"Loading sentence transformer for loss: {sem_model_name}")
self.sem_model = SentenceTransformer(sem_model_name, device=device)
self.sem_model.eval()
# Freeze sentence transformer weights
for param in self.sem_model.parameters():
param.requires_grad = False
    def _style_loss(
        self,
        output_style_vec: torch.Tensor,
        target_style_vec: torch.Tensor,
    ) -> torch.Tensor:
        """1 - cosine_similarity(output_style, target_style)."""
        # Align devices: style vectors may come from a CPU fingerprinter.
        target_style_vec = target_style_vec.to(output_style_vec.device)
        if output_style_vec.dim() == 1:
            output_style_vec = output_style_vec.unsqueeze(0)
        if target_style_vec.dim() == 1:
            target_style_vec = target_style_vec.unsqueeze(0)
        cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
        return (1.0 - cos_sim).mean()
    def _semantic_loss(
        self,
        input_texts: List[str],
        output_texts: List[str],
    ) -> torch.Tensor:
        """Penalises meaning change between input and output.

        Note: embeddings are computed from decoded text under no_grad, so this
        term carries no gradient back to the main model; it shifts the scalar
        loss and is mainly useful for monitoring.
        """
        with torch.no_grad():
            input_embeddings = self.sem_model.encode(input_texts, convert_to_tensor=True)
            output_embeddings = self.sem_model.encode(output_texts, convert_to_tensor=True)
        cos_sim = F.cosine_similarity(input_embeddings, output_embeddings, dim=-1)
        # Loss = 1 - similarity (we want high similarity = low loss)
        return (1.0 - cos_sim).mean()
def forward(
self,
logits: torch.Tensor,
labels: torch.Tensor,
output_style_vec: Optional[torch.Tensor] = None,
target_style_vec: Optional[torch.Tensor] = None,
input_texts: Optional[List[str]] = None,
output_texts: Optional[List[str]] = None,
) -> Dict[str, torch.Tensor]:
"""Compute combined loss."""
losses = {}
# L_CE: cross-entropy
# logits: [batch, seq_len, vocab_size]
# labels: [batch, seq_len]
        if logits.dim() == 3:
            # Flatten [batch, seq_len, vocab] -> [batch*seq_len, vocab];
            # reshape() also handles non-contiguous logits (e.g. after a shift).
            ce_logits = logits.reshape(-1, logits.size(-1))
            ce_labels = labels.reshape(-1)
        else:
            ce_logits = logits
            ce_labels = labels
l_ce = self.ce_loss(ce_logits, ce_labels)
losses["ce_loss"] = l_ce
total = l_ce
# L_style
if output_style_vec is not None and target_style_vec is not None:
l_style = self._style_loss(output_style_vec, target_style_vec)
losses["style_loss"] = l_style
total = total + self.lambda_style * l_style
# L_semantic
if input_texts is not None and output_texts is not None:
l_semantic = self._semantic_loss(input_texts, output_texts)
losses["semantic_loss"] = l_semantic
total = total + self.lambda_semantic * l_semantic
losses["total_loss"] = total
return losses
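# Minimal smoke-test sketch for the V1 loss. Illustrative only: the shapes,
# vocab size, style-vector dimension, and texts below are made up, and the
# sentence transformer is downloaded on first use.
def _demo_combined_loss_v1() -> None:
    batch, seq_len, vocab = 2, 8, 100
    loss_fn = CombinedCorrectionLoss(device="cpu")
    logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
    labels = torch.randint(0, vocab, (batch, seq_len))
    losses = loss_fn(
        logits,
        labels,
        output_style_vec=torch.randn(batch, 16),
        target_style_vec=torch.randn(batch, 16),
        input_texts=["teh quick brown fox", "i has a apple"],
        output_texts=["the quick brown fox", "I have an apple"],
    )
    losses["total_loss"].backward()  # gradient flows through L_CE only
    for name, value in losses.items():
        logger.info(f"{name}: {value.item():.4f}")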
class CombinedCorrectionLossV2(nn.Module):
"""V2 combined loss with human-pattern term: L_CE + λ₁·L_style + λ₂·L_semantic + λ₃·L_human_pattern."""
def __init__(
self,
lambda_style: float = 0.3,
lambda_semantic: float = 0.5,
lambda_human_pattern: float = 0.4,
classifier_path: str = "checkpoints/human_pattern_classifier.pt",
sem_model_name: str = "all-mpnet-base-v2",
device: str = "cpu",
):
super().__init__()
self.lambda_style = lambda_style
self.lambda_semantic = lambda_semantic
self.lambda_human_pattern = lambda_human_pattern
self.device = device
# V1 components
self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
# Sentence transformer on CPU to save GPU VRAM for main model
logger.info(f"Loading sentence transformer for loss: {sem_model_name} (on CPU)")
self.sem_model = SentenceTransformer(sem_model_name, device="cpu")
self.sem_model.eval()
        # Load frozen human-pattern classifier (imported lazily so that
        # importing this module does not require the extractor package)
        from .human_pattern_extractor import HumanPatternClassifier, HumanPatternFeatureExtractor
self.hp_classifier = HumanPatternClassifier()
try:
state_dict = torch.load(classifier_path, map_location=device, weights_only=True)
self.hp_classifier.load_state_dict(state_dict)
logger.info(f"Loaded human pattern classifier from {classifier_path}")
except FileNotFoundError:
logger.warning(f"Human pattern classifier not found at {classifier_path}, using random weights")
self.hp_classifier.eval()
for param in self.hp_classifier.parameters():
param.requires_grad = False
# Feature extractor on CPU to save GPU VRAM for main model
self.hp_extractor = HumanPatternFeatureExtractor(device="cpu")
    def _human_pattern_loss(
        self,
        output_texts: List[str],
        compute_device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Loss = 1 - human_score. Penalise AI-like outputs.

        Scores come from the frozen classifier on decoded text, so this term
        is constant w.r.t. model parameters (no gradient flows through it).
        """
        scores = [self.hp_classifier.score(text, self.hp_extractor) for text in output_texts]
        device = compute_device if compute_device is not None else self.device
        human_scores = torch.tensor(scores, dtype=torch.float32, device=device)
        return (1.0 - human_scores).mean()
def forward(
self,
logits: torch.Tensor,
labels: torch.Tensor,
output_style_vec: Optional[torch.Tensor] = None,
target_style_vec: Optional[torch.Tensor] = None,
input_texts: Optional[List[str]] = None,
output_texts: Optional[List[str]] = None,
) -> Dict[str, torch.Tensor]:
"""Compute combined loss with human pattern term."""
losses = {}
# L_CE
        if logits.dim() == 3:
            # Flatten [batch, seq_len, vocab] -> [batch*seq_len, vocab];
            # reshape() also handles non-contiguous logits (e.g. after a shift).
            ce_logits = logits.reshape(-1, logits.size(-1))
            ce_labels = labels.reshape(-1)
        else:
            ce_logits = logits
            ce_labels = labels
l_ce = self.ce_loss(ce_logits, ce_labels)
losses["ce_loss"] = l_ce
total = l_ce
# L_style
if output_style_vec is not None and target_style_vec is not None:
# Ensure both vectors are on the same device (style vecs may come from CPU fingerprinter)
compute_device = logits.device
output_style_vec = output_style_vec.to(compute_device)
target_style_vec = target_style_vec.to(compute_device)
if output_style_vec.dim() == 1:
output_style_vec = output_style_vec.unsqueeze(0)
if target_style_vec.dim() == 1:
target_style_vec = target_style_vec.unsqueeze(0)
cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
l_style = (1.0 - cos_sim).mean()
losses["style_loss"] = l_style
total = total + self.lambda_style * l_style
# L_semantic
if input_texts is not None and output_texts is not None:
with torch.no_grad():
input_emb = self.sem_model.encode(input_texts, convert_to_tensor=True)
output_emb = self.sem_model.encode(output_texts, convert_to_tensor=True)
# sem_model is on CPU, move embeddings to compute device
input_emb = input_emb.to(logits.device)
output_emb = output_emb.to(logits.device)
cos_sim = F.cosine_similarity(input_emb, output_emb, dim=-1)
l_semantic = (1.0 - cos_sim).mean()
losses["semantic_loss"] = l_semantic
total = total + self.lambda_semantic * l_semantic
# L_human_pattern
if output_texts is not None:
l_human = self._human_pattern_loss(output_texts, compute_device=logits.device)
losses["human_pattern_loss"] = l_human
total = total + self.lambda_human_pattern * l_human
losses["total_loss"] = total
return losses
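# Hedged usage sketch: only the V1 demo runs when this file is executed
# directly. The V2 loss is skipped because it needs the human-pattern
# classifier checkpoint, and its relative import of human_pattern_extractor
# fails outside the package context.
if __name__ == "__main__":
    _demo_combined_loss_v1()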