| """ |
| Combined training loss with Human-Pattern Term: |
| |
| L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic + λ₃ · L_human_pattern |
| |
| Where: |
| L_CE = cross-entropy language model loss (standard token prediction) |
| L_style = style consistency loss (cosine distance between output and target style vectors) |
| L_semantic = semantic similarity loss (cosine distance between sentence embeddings) |
| L_human_pattern = 1 - HumanPatternClassifier.score(output_text) |
| λ₁ = style loss weight (default 0.3) |
| λ₂ = semantic loss weight (default 0.5) |
| λ₃ = human pattern weight (default 0.4) |
| """ |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from typing import Optional, List, Dict
from loguru import logger
|
|
|
|
class CombinedCorrectionLoss(nn.Module):
    """V1 combined loss: L_CE + λ₁·L_style + λ₂·L_semantic."""

    def __init__(
        self,
        lambda_style: float = 0.3,
        lambda_semantic: float = 0.5,
        sem_model_name: str = "all-mpnet-base-v2",
        device: str = "cpu",
    ):
        super().__init__()
        self.lambda_style = lambda_style
        self.lambda_semantic = lambda_semantic
        self.device = device

        # Token-level LM loss; label positions set to -100 (padding/prompt) are ignored.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Frozen sentence encoder, used only to score the semantic term.
        logger.info(f"Loading sentence transformer for loss: {sem_model_name}")
        self.sem_model = SentenceTransformer(sem_model_name, device=device)
        self.sem_model.eval()
        for param in self.sem_model.parameters():
            param.requires_grad = False
|
|
    def _style_loss(
        self,
        output_style_vec: torch.Tensor,
        target_style_vec: torch.Tensor,
    ) -> torch.Tensor:
        """1 - cosine_similarity(output_style, target_style)."""
        # Promote single vectors to a batch of one so cosine_similarity
        # broadcasts over the batch dimension.
        if output_style_vec.dim() == 1:
            output_style_vec = output_style_vec.unsqueeze(0)
        if target_style_vec.dim() == 1:
            target_style_vec = target_style_vec.unsqueeze(0)
        cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
        return (1.0 - cos_sim).mean()
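    # Example: identical style vectors give cosine similarity 1.0 and a style
    # loss of 0.0; orthogonal vectors give 1.0. Since cosine similarity lies in
    # [-1, 1], the term is bounded to [0, 2].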
|
|
    def _semantic_loss(
        self,
        input_texts: List[str],
        output_texts: List[str],
    ) -> torch.Tensor:
        """Penalise meaning change between input and output.

        Note: both arguments are decoded strings, so no gradient flows back to
        the model through this term; it contributes a scalar penalty only.
        """
        with torch.no_grad():
            input_embeddings = self.sem_model.encode(input_texts, convert_to_tensor=True)
            output_embeddings = self.sem_model.encode(output_texts, convert_to_tensor=True)

        cos_sim = F.cosine_similarity(input_embeddings, output_embeddings, dim=-1)
        return (1.0 - cos_sim).mean()
|
|
    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        output_style_vec: Optional[torch.Tensor] = None,
        target_style_vec: Optional[torch.Tensor] = None,
        input_texts: Optional[List[str]] = None,
        output_texts: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute the combined loss; optional terms are skipped when their inputs are None."""
        losses = {}

        # Cross-entropy: flatten (batch, seq, vocab) logits to (batch * seq, vocab).
        # reshape() is used instead of view() so non-contiguous logits also work.
        if logits.dim() == 3:
            ce_logits = logits.reshape(-1, logits.size(-1))
            ce_labels = labels.reshape(-1)
        else:
            ce_logits = logits
            ce_labels = labels
        l_ce = self.ce_loss(ce_logits, ce_labels)
        losses["ce_loss"] = l_ce

        total = l_ce

        # Style term: only when both style vectors are provided.
        if output_style_vec is not None and target_style_vec is not None:
            l_style = self._style_loss(output_style_vec, target_style_vec)
            losses["style_loss"] = l_style
            total = total + self.lambda_style * l_style

        # Semantic term: only when decoded input/output texts are provided.
        if input_texts is not None and output_texts is not None:
            l_semantic = self._semantic_loss(input_texts, output_texts)
            losses["semantic_loss"] = l_semantic
            total = total + self.lambda_semantic * l_semantic

        losses["total_loss"] = total
        return losses
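# A minimal usage sketch for the V1 loss (commented out; shapes, the style
# vectors, and the texts are illustrative dummies, not part of this module):
#
#     loss_fn = CombinedCorrectionLoss(lambda_style=0.3, lambda_semantic=0.5)
#     logits = torch.randn(2, 16, 32000, requires_grad=True)   # (batch, seq, vocab)
#     labels = torch.randint(0, 32000, (2, 16))                # -100 marks ignored positions
#     losses = loss_fn(
#         logits, labels,
#         output_style_vec=torch.randn(2, 768),
#         target_style_vec=torch.randn(2, 768),
#         input_texts=["teh cat sat", "i has a apple"],
#         output_texts=["the cat sat", "I have an apple"],
#     )
#     losses["total_loss"].backward()   # here only the CE term carries gradient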
|
|
|
|
class CombinedCorrectionLossV2(nn.Module):
    """V2 combined loss with human-pattern term: L_CE + λ₁·L_style + λ₂·L_semantic + λ₃·L_human_pattern."""

    def __init__(
        self,
        lambda_style: float = 0.3,
        lambda_semantic: float = 0.5,
        lambda_human_pattern: float = 0.4,
        classifier_path: str = "checkpoints/human_pattern_classifier.pt",
        sem_model_name: str = "all-mpnet-base-v2",
        device: str = "cpu",
    ):
        super().__init__()
        self.lambda_style = lambda_style
        self.lambda_semantic = lambda_semantic
        self.lambda_human_pattern = lambda_human_pattern
        self.device = device

        # Token-level LM loss; label positions set to -100 (padding/prompt) are ignored.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Frozen sentence encoder for the semantic term. It stays on CPU; its
        # embeddings are moved onto the compute device in forward().
        logger.info(f"Loading sentence transformer for loss: {sem_model_name} (on CPU)")
        self.sem_model = SentenceTransformer(sem_model_name, device="cpu")
        self.sem_model.eval()
        for param in self.sem_model.parameters():
            param.requires_grad = False

        # Frozen human-pattern classifier; falls back to random weights if the
        # checkpoint is missing so training can still run.
        from .human_pattern_extractor import HumanPatternClassifier, HumanPatternFeatureExtractor
        self.hp_classifier = HumanPatternClassifier()
        try:
            state_dict = torch.load(classifier_path, map_location=device, weights_only=True)
            self.hp_classifier.load_state_dict(state_dict)
            logger.info(f"Loaded human pattern classifier from {classifier_path}")
        except FileNotFoundError:
            logger.warning(f"Human pattern classifier not found at {classifier_path}, using random weights")

        self.hp_classifier.eval()
        for param in self.hp_classifier.parameters():
            param.requires_grad = False

        # Feature extractor used to featurise output texts for the classifier.
        self.hp_extractor = HumanPatternFeatureExtractor(device="cpu")
|
|
    def _human_pattern_loss(
        self,
        output_texts: List[str],
        compute_device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Loss = 1 - human_score; penalises AI-like outputs.

        Scores come from decoded text, so this term is non-differentiable and
        contributes a scalar penalty only.
        """
        scores = [self.hp_classifier.score(text, self.hp_extractor) for text in output_texts]
        device = compute_device if compute_device is not None else self.device
        human_scores = torch.tensor(scores, dtype=torch.float32, device=device)
        return (1.0 - human_scores).mean()
|
|
    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        output_style_vec: Optional[torch.Tensor] = None,
        target_style_vec: Optional[torch.Tensor] = None,
        input_texts: Optional[List[str]] = None,
        output_texts: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute the combined loss with the human-pattern term."""
        losses = {}

        # Cross-entropy: flatten (batch, seq, vocab) logits to (batch * seq, vocab).
        if logits.dim() == 3:
            ce_logits = logits.reshape(-1, logits.size(-1))
            ce_labels = labels.reshape(-1)
        else:
            ce_logits = logits
            ce_labels = labels
        l_ce = self.ce_loss(ce_logits, ce_labels)
        losses["ce_loss"] = l_ce
        total = l_ce

        # Style term: move both vectors onto the logits device before comparing.
        if output_style_vec is not None and target_style_vec is not None:
            compute_device = logits.device
            output_style_vec = output_style_vec.to(compute_device)
            target_style_vec = target_style_vec.to(compute_device)
            if output_style_vec.dim() == 1:
                output_style_vec = output_style_vec.unsqueeze(0)
            if target_style_vec.dim() == 1:
                target_style_vec = target_style_vec.unsqueeze(0)
            cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
            l_style = (1.0 - cos_sim).mean()
            losses["style_loss"] = l_style
            total = total + self.lambda_style * l_style

        # Semantic term: embeddings come from the CPU encoder, so move them
        # onto the logits device before comparing.
        if input_texts is not None and output_texts is not None:
            with torch.no_grad():
                input_emb = self.sem_model.encode(input_texts, convert_to_tensor=True)
                output_emb = self.sem_model.encode(output_texts, convert_to_tensor=True)
            input_emb = input_emb.to(logits.device)
            output_emb = output_emb.to(logits.device)
            cos_sim = F.cosine_similarity(input_emb, output_emb, dim=-1)
            l_semantic = (1.0 - cos_sim).mean()
            losses["semantic_loss"] = l_semantic
            total = total + self.lambda_semantic * l_semantic

        # Human-pattern term: needs only the decoded output texts.
        if output_texts is not None:
            l_human = self._human_pattern_loss(output_texts, compute_device=logits.device)
            losses["human_pattern_loss"] = l_human
            total = total + self.lambda_human_pattern * l_human

        losses["total_loss"] = total
        return losses
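# A minimal V2 sketch (commented out; dummy shapes, and it assumes the sibling
# human_pattern_extractor module and classifier checkpoint are available):
#
#     loss_fn = CombinedCorrectionLossV2(lambda_human_pattern=0.4)
#     logits = torch.randn(2, 16, 32000, requires_grad=True)
#     labels = torch.randint(0, 32000, (2, 16))
#     losses = loss_fn(
#         logits, labels,
#         output_texts=["the cat sat on the mat", "I have an apple"],
#     )
#     # With no style vectors or input_texts, the style and semantic terms are
#     # skipped: total = ce_loss + 0.4 * human_pattern_loss.
#     losses["total_loss"].backward()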
|
|