""" Combined training loss with Human-Pattern Term: L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic + λ₃ · L_human_pattern Where: L_CE = cross-entropy language model loss (standard token prediction) L_style = style consistency loss (cosine distance between output and target style vectors) L_semantic = semantic similarity loss (cosine distance between sentence embeddings) L_human_pattern = 1 - HumanPatternClassifier.score(output_text) λ₁ = style loss weight (default 0.3) λ₂ = semantic loss weight (default 0.5) λ₃ = human pattern weight (default 0.4) """ import torch import torch.nn as nn import torch.nn.functional as F from sentence_transformers import SentenceTransformer from typing import Optional, List, Dict from loguru import logger class CombinedCorrectionLoss(nn.Module): """V1 combined loss: L_CE + λ₁·L_style + λ₂·L_semantic.""" def __init__( self, lambda_style: float = 0.3, lambda_semantic: float = 0.5, sem_model_name: str = "all-mpnet-base-v2", device: str = "cpu", ): super().__init__() self.lambda_style = lambda_style self.lambda_semantic = lambda_semantic self.device = device # Cross-entropy loss self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100) # Frozen sentence transformer for semantic similarity logger.info(f"Loading sentence transformer for loss: {sem_model_name}") self.sem_model = SentenceTransformer(sem_model_name, device=device) self.sem_model.eval() # Freeze sentence transformer weights for param in self.sem_model.parameters(): param.requires_grad = False def _style_loss( self, output_style_vec: torch.Tensor, target_style_vec: torch.Tensor, ) -> torch.Tensor: """1 - cosine_similarity(output_style, target_style).""" if output_style_vec.dim() == 1: output_style_vec = output_style_vec.unsqueeze(0) if target_style_vec.dim() == 1: target_style_vec = target_style_vec.unsqueeze(0) cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1) return (1.0 - cos_sim).mean() def _semantic_loss( self, input_texts: List[str], output_texts: List[str], ) -> torch.Tensor: """Penalises meaning change between input and output.""" with torch.no_grad(): input_embeddings = self.sem_model.encode(input_texts, convert_to_tensor=True) output_embeddings = self.sem_model.encode(output_texts, convert_to_tensor=True) cos_sim = F.cosine_similarity(input_embeddings, output_embeddings, dim=-1) # Loss = 1 - similarity (we want high similarity = low loss) return (1.0 - cos_sim).mean() def forward( self, logits: torch.Tensor, labels: torch.Tensor, output_style_vec: Optional[torch.Tensor] = None, target_style_vec: Optional[torch.Tensor] = None, input_texts: Optional[List[str]] = None, output_texts: Optional[List[str]] = None, ) -> Dict[str, torch.Tensor]: """Compute combined loss.""" losses = {} # L_CE: cross-entropy # logits: [batch, seq_len, vocab_size] # labels: [batch, seq_len] if logits.dim() == 3: ce_logits = logits.view(-1, logits.size(-1)) ce_labels = labels.view(-1) else: ce_logits = logits ce_labels = labels l_ce = self.ce_loss(ce_logits, ce_labels) losses["ce_loss"] = l_ce total = l_ce # L_style if output_style_vec is not None and target_style_vec is not None: l_style = self._style_loss(output_style_vec, target_style_vec) losses["style_loss"] = l_style total = total + self.lambda_style * l_style # L_semantic if input_texts is not None and output_texts is not None: l_semantic = self._semantic_loss(input_texts, output_texts) losses["semantic_loss"] = l_semantic total = total + self.lambda_semantic * l_semantic losses["total_loss"] = total return losses class 


class CombinedCorrectionLossV2(nn.Module):
    """V2 combined loss with human-pattern term:
    L_CE + λ₁·L_style + λ₂·L_semantic + λ₃·L_human_pattern.
    """

    def __init__(
        self,
        lambda_style: float = 0.3,
        lambda_semantic: float = 0.5,
        lambda_human_pattern: float = 0.4,
        classifier_path: str = "checkpoints/human_pattern_classifier.pt",
        sem_model_name: str = "all-mpnet-base-v2",
        device: str = "cpu",
    ):
        super().__init__()
        self.lambda_style = lambda_style
        self.lambda_semantic = lambda_semantic
        self.lambda_human_pattern = lambda_human_pattern
        self.device = device

        # V1 components
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Sentence transformer on CPU to save GPU VRAM for main model
        logger.info(f"Loading sentence transformer for loss: {sem_model_name} (on CPU)")
        self.sem_model = SentenceTransformer(sem_model_name, device="cpu")
        self.sem_model.eval()
        # Freeze sentence transformer weights (as in V1)
        for param in self.sem_model.parameters():
            param.requires_grad = False

        # Load frozen human pattern classifier
        from .human_pattern_extractor import HumanPatternClassifier, HumanPatternFeatureExtractor

        self.hp_classifier = HumanPatternClassifier()
        try:
            state_dict = torch.load(classifier_path, map_location=device, weights_only=True)
            self.hp_classifier.load_state_dict(state_dict)
            logger.info(f"Loaded human pattern classifier from {classifier_path}")
        except FileNotFoundError:
            logger.warning(
                f"Human pattern classifier not found at {classifier_path}, using random weights"
            )
        self.hp_classifier.eval()
        for param in self.hp_classifier.parameters():
            param.requires_grad = False

        # Feature extractor on CPU to save GPU VRAM for main model
        self.hp_extractor = HumanPatternFeatureExtractor(device="cpu")

    def _human_pattern_loss(
        self,
        output_texts: List[str],
        compute_device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Loss = 1 - human_score. Penalise AI-like outputs."""
        scores = []
        for text in output_texts:
            score = self.hp_classifier.score(text, self.hp_extractor)
            scores.append(score)
        device = compute_device if compute_device is not None else self.device
        human_scores = torch.tensor(scores, dtype=torch.float32, device=device)
        return (1.0 - human_scores).mean()

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        output_style_vec: Optional[torch.Tensor] = None,
        target_style_vec: Optional[torch.Tensor] = None,
        input_texts: Optional[List[str]] = None,
        output_texts: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute combined loss with human pattern term."""
        losses = {}

        # L_CE
        if logits.dim() == 3:
            ce_logits = logits.view(-1, logits.size(-1))
            ce_labels = labels.view(-1)
        else:
            ce_logits = logits
            ce_labels = labels
        l_ce = self.ce_loss(ce_logits, ce_labels)
        losses["ce_loss"] = l_ce
        total = l_ce

        # L_style
        if output_style_vec is not None and target_style_vec is not None:
            # Ensure both vectors are on the same device
            # (style vecs may come from a CPU fingerprinter)
            compute_device = logits.device
            output_style_vec = output_style_vec.to(compute_device)
            target_style_vec = target_style_vec.to(compute_device)
            if output_style_vec.dim() == 1:
                output_style_vec = output_style_vec.unsqueeze(0)
            if target_style_vec.dim() == 1:
                target_style_vec = target_style_vec.unsqueeze(0)
            cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
            l_style = (1.0 - cos_sim).mean()
            losses["style_loss"] = l_style
            total = total + self.lambda_style * l_style

        # L_semantic
        if input_texts is not None and output_texts is not None:
            with torch.no_grad():
                input_emb = self.sem_model.encode(input_texts, convert_to_tensor=True)
                output_emb = self.sem_model.encode(output_texts, convert_to_tensor=True)
            # sem_model is on CPU, move embeddings to compute device
            input_emb = input_emb.to(logits.device)
            output_emb = output_emb.to(logits.device)
            cos_sim = F.cosine_similarity(input_emb, output_emb, dim=-1)
            l_semantic = (1.0 - cos_sim).mean()
            losses["semantic_loss"] = l_semantic
            total = total + self.lambda_semantic * l_semantic

        # L_human_pattern
        if output_texts is not None:
            l_human = self._human_pattern_loss(output_texts, compute_device=logits.device)
            losses["human_pattern_loss"] = l_human
            total = total + self.lambda_human_pattern * l_human

        losses["total_loss"] = total
        return losses
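

# --- Training-step sketch (illustrative, not part of the original API) ---
# Shows how CombinedCorrectionLossV2 might be wired into a training loop.
# `model`, `tokenizer`, `get_style_vec`, and the `batch` keys are hypothetical
# stand-ins: any causal LM that returns logits, its tokenizer, and a callable
# producing style vectors for a list of texts.
def _example_v2_training_step(model, tokenizer, get_style_vec, loss_fn, optimizer, batch):
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

    # Decode the argmax tokens so the text-level terms (semantic, human-pattern)
    # can score actual strings; sampling or gold corrections would also work.
    pred_ids = outputs.logits.argmax(dim=-1)
    output_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

    losses = loss_fn(
        logits=outputs.logits,
        labels=batch["labels"],
        output_style_vec=get_style_vec(output_texts),
        target_style_vec=batch["target_style_vec"],
        input_texts=batch["input_texts"],
        output_texts=output_texts,
    )

    # Note: terms computed from decoded strings carry no gradient back to the
    # model; in this sketch only the CE term is differentiable w.r.t. its weights.
    optimizer.zero_grad()
    losses["total_loss"].backward()
    optimizer.step()
    return {k: v.item() for k, v in losses.items()}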