"""
Combined training loss with Human-Pattern Term:

L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic + λ₃ · L_human_pattern

Where:
  L_CE       = cross-entropy language model loss (standard token prediction)
  L_style    = style consistency loss (cosine distance between output and target style vectors)
  L_semantic = semantic similarity loss (cosine distance between sentence embeddings)
  L_human_pattern = 1 - HumanPatternClassifier.score(output_text)
  λ₁         = style loss weight (default 0.3)
  λ₂         = semantic loss weight (default 0.5)
  λ₃         = human pattern weight (default 0.4)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from typing import Optional, List, Dict
from loguru import logger


class CombinedCorrectionLoss(nn.Module):
    """V1 combined loss: L_CE + λ₁·L_style + λ₂·L_semantic."""

    def __init__(
        self,
        lambda_style: float = 0.3,
        lambda_semantic: float = 0.5,
        sem_model_name: str = "all-mpnet-base-v2",
        device: str = "cpu",
    ):
        super().__init__()
        self.lambda_style = lambda_style
        self.lambda_semantic = lambda_semantic
        self.device = device

        # Cross-entropy loss
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Frozen sentence transformer for semantic similarity
        logger.info(f"Loading sentence transformer for loss: {sem_model_name}")
        self.sem_model = SentenceTransformer(sem_model_name, device=device)
        self.sem_model.eval()
        # Freeze sentence transformer weights
        for param in self.sem_model.parameters():
            param.requires_grad = False

    def _style_loss(
        self,
        output_style_vec: torch.Tensor,
        target_style_vec: torch.Tensor,
    ) -> torch.Tensor:
        """1 - cosine_similarity(output_style, target_style)."""
        if output_style_vec.dim() == 1:
            output_style_vec = output_style_vec.unsqueeze(0)
        if target_style_vec.dim() == 1:
            target_style_vec = target_style_vec.unsqueeze(0)
        cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
        return (1.0 - cos_sim).mean()

    def _semantic_loss(
        self,
        input_texts: List[str],
        output_texts: List[str],
    ) -> torch.Tensor:
        """Penalises meaning change between input and output."""
        with torch.no_grad():
            input_embeddings = self.sem_model.encode(input_texts, convert_to_tensor=True)
            output_embeddings = self.sem_model.encode(output_texts, convert_to_tensor=True)

        cos_sim = F.cosine_similarity(input_embeddings, output_embeddings, dim=-1)
        # Loss = 1 - similarity (we want high similarity = low loss)
        return (1.0 - cos_sim).mean()

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        output_style_vec: Optional[torch.Tensor] = None,
        target_style_vec: Optional[torch.Tensor] = None,
        input_texts: Optional[List[str]] = None,
        output_texts: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute combined loss."""
        losses = {}

        # L_CE: cross-entropy
        # logits: [batch, seq_len, vocab_size]
        # labels: [batch, seq_len]
        if logits.dim() == 3:
            # reshape (rather than view) tolerates non-contiguous logits
            ce_logits = logits.reshape(-1, logits.size(-1))
            ce_labels = labels.reshape(-1)
        else:
            ce_logits = logits
            ce_labels = labels
        l_ce = self.ce_loss(ce_logits, ce_labels)
        losses["ce_loss"] = l_ce

        total = l_ce

        # L_style
        if output_style_vec is not None and target_style_vec is not None:
            l_style = self._style_loss(output_style_vec, target_style_vec)
            losses["style_loss"] = l_style
            total = total + self.lambda_style * l_style

        # L_semantic
        if input_texts is not None and output_texts is not None:
            l_semantic = self._semantic_loss(input_texts, output_texts)
            losses["semantic_loss"] = l_semantic
            total = total + self.lambda_semantic * l_semantic

        losses["total_loss"] = total
        return losses
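
# Sketch of how the V1 loss would typically be wired into a training step.
# `model`, `batch`, `decoded_outputs`, and `optimizer` are placeholders for
# illustration; they are not defined in this module:
#
#   loss_fn = CombinedCorrectionLoss(device="cuda")
#   outputs = model(**batch)
#   losses = loss_fn(
#       outputs.logits,
#       batch["labels"],
#       input_texts=batch["input_texts"],
#       output_texts=decoded_outputs,
#   )
#   losses["total_loss"].backward()
#   optimizer.step()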


class CombinedCorrectionLossV2(nn.Module):
    """V2 combined loss with human-pattern term: L_CE + λ₁·L_style + λ₂·L_semantic + λ₃·L_human_pattern."""

    def __init__(
        self,
        lambda_style: float = 0.3,
        lambda_semantic: float = 0.5,
        lambda_human_pattern: float = 0.4,
        classifier_path: str = "checkpoints/human_pattern_classifier.pt",
        sem_model_name: str = "all-mpnet-base-v2",
        device: str = "cpu",
    ):
        super().__init__()
        self.lambda_style = lambda_style
        self.lambda_semantic = lambda_semantic
        self.lambda_human_pattern = lambda_human_pattern
        self.device = device

        # V1 components
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # Sentence transformer on CPU to save GPU VRAM for the main model
        logger.info(f"Loading sentence transformer for loss: {sem_model_name} (on CPU)")
        self.sem_model = SentenceTransformer(sem_model_name, device="cpu")
        self.sem_model.eval()
        # Freeze sentence transformer weights, as in V1, so an optimizer built
        # from this module's parameters never picks them up
        for param in self.sem_model.parameters():
            param.requires_grad = False

        # Load frozen human pattern classifier
        from .human_pattern_extractor import HumanPatternClassifier, HumanPatternFeatureExtractor
        self.hp_classifier = HumanPatternClassifier()
        try:
            state_dict = torch.load(classifier_path, map_location=device, weights_only=True)
            self.hp_classifier.load_state_dict(state_dict)
            logger.info(f"Loaded human pattern classifier from {classifier_path}")
        except FileNotFoundError:
            logger.warning(f"Human pattern classifier not found at {classifier_path}, using random weights")

        self.hp_classifier.eval()
        for param in self.hp_classifier.parameters():
            param.requires_grad = False

        # Feature extractor on CPU to save GPU VRAM for main model
        self.hp_extractor = HumanPatternFeatureExtractor(device="cpu")

    def _human_pattern_loss(
        self,
        output_texts: List[str],
        compute_device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Loss = 1 - human_score. Penalise AI-like outputs."""
        # Scores come from decoded strings via the frozen classifier, so this
        # term carries no gradient w.r.t. the main model's parameters.
        scores = []
        for text in output_texts:
            score = self.hp_classifier.score(text, self.hp_extractor)
            scores.append(score)
        device = compute_device if compute_device is not None else self.device
        human_scores = torch.tensor(scores, dtype=torch.float32, device=device)
        return (1.0 - human_scores).mean()

    def forward(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        output_style_vec: Optional[torch.Tensor] = None,
        target_style_vec: Optional[torch.Tensor] = None,
        input_texts: Optional[List[str]] = None,
        output_texts: Optional[List[str]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute combined loss with human pattern term."""
        losses = {}

        # L_CE
        if logits.dim() == 3:
            # reshape (rather than view) tolerates non-contiguous logits
            ce_logits = logits.reshape(-1, logits.size(-1))
            ce_labels = labels.reshape(-1)
        else:
            ce_logits = logits
            ce_labels = labels
        l_ce = self.ce_loss(ce_logits, ce_labels)
        losses["ce_loss"] = l_ce
        total = l_ce

        # L_style
        if output_style_vec is not None and target_style_vec is not None:
            # Ensure both vectors are on the same device (style vecs may come from CPU fingerprinter)
            compute_device = logits.device
            output_style_vec = output_style_vec.to(compute_device)
            target_style_vec = target_style_vec.to(compute_device)
            if output_style_vec.dim() == 1:
                output_style_vec = output_style_vec.unsqueeze(0)
            if target_style_vec.dim() == 1:
                target_style_vec = target_style_vec.unsqueeze(0)
            cos_sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1)
            l_style = (1.0 - cos_sim).mean()
            losses["style_loss"] = l_style
            total = total + self.lambda_style * l_style

        # L_semantic
        if input_texts is not None and output_texts is not None:
            with torch.no_grad():
                input_emb = self.sem_model.encode(input_texts, convert_to_tensor=True)
                output_emb = self.sem_model.encode(output_texts, convert_to_tensor=True)
            # sem_model is on CPU, move embeddings to compute device
            input_emb = input_emb.to(logits.device)
            output_emb = output_emb.to(logits.device)
            cos_sim = F.cosine_similarity(input_emb, output_emb, dim=-1)
            l_semantic = (1.0 - cos_sim).mean()
            losses["semantic_loss"] = l_semantic
            total = total + self.lambda_semantic * l_semantic

        # L_human_pattern
        if output_texts is not None:
            l_human = self._human_pattern_loss(output_texts, compute_device=logits.device)
            losses["human_pattern_loss"] = l_human
            total = total + self.lambda_human_pattern * l_human

        losses["total_loss"] = total
        return losses
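

if __name__ == "__main__":
    # Minimal smoke test for the V1 loss, with made-up shapes and texts
    # (a sketch: it downloads the sentence-transformers model on first run).
    # V2 is not exercised here because it additionally needs the trained
    # human-pattern classifier checkpoint.
    torch.manual_seed(0)
    batch, seq_len, vocab = 2, 8, 100

    loss_fn = CombinedCorrectionLoss(device="cpu")
    logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
    labels = torch.randint(0, vocab, (batch, seq_len))

    losses = loss_fn(
        logits,
        labels,
        output_style_vec=torch.randn(batch, 16),
        target_style_vec=torch.randn(batch, 16),
        input_texts=["teh cat sat on teh mat", "i has a apple"],
        output_texts=["the cat sat on the mat", "I have an apple"],
    )
    for name, value in losses.items():
        print(f"{name}: {value.item():.4f}")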