File size: 2,714 Bytes
c5d9ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
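"""Shakespeare Constrained Decoder — logit-level style biasing.

Apply this LogitsProcessor during generation to steer output toward Shakespearean
vocabulary. It operates on raw logits rather than on learned behavior, so a prompt
cannot switch it off. Its biases are additive and finite, however: archaic words become
more likely and modern words less likely, but neither is guaranteed.
"""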
from transformers import LogitsProcessor
import torch

class ShakespeareLogitProcessor(LogitsProcessor):
    """Boosts archaic tokens and suppresses modern tokens at every generation step.

    Each listed word is encoded in several surface forms: as given,
    capitalized, upper-case, and with a leading space. For suppress words the
    lower-case form is added as well. How a form is applied depends on how
    many token ids it encodes to. This mirrors the convention of Hugging
    Face's ``SequenceBiasLogitsProcessor``:

    * A single-token form is biased unconditionally on every step. These ids
      are collected in ``boost_ids`` / ``suppress_ids``.
    * A multi-token form biases only its final token, and only on rows whose
      history (``input_ids``) already ends with the form's preceding tokens.
      Biasing every fragment unconditionally would distort ordinary
      sub-words across the whole vocabulary. For example, "happy to help"
      would otherwise penalize " to" on every step.

    A token reachable from both lists is boosted, never suppressed. The
    biases are additive, so suppressed tokens become unlikely, not impossible.

    Args:
        tokenizer: Object providing ``encode(text, add_special_tokens=False)``
            that returns a list of token ids (e.g. a Hugging Face tokenizer).
        boost: Logit offset added to archaic tokens.
        suppress: Logit offset added to modern tokens (normally negative).
        boost_words: Optional replacement for ``DEFAULT_BOOST_WORDS``.
        suppress_words: Optional replacement for ``DEFAULT_SUPPRESS_WORDS``.
    """

    DEFAULT_BOOST_WORDS = (
        "thee", "thou", "thy", "thine", "hast", "dost", "doth", "ye",
        "hath", "art", "wilt", "shalt", "wouldst", "shouldst", "canst",
        "didst", "prithee", "forsooth", "hark", "wherefore", "methinks",
        "verily", "perchance", "mayhap", "alas", "alack", "anon",
        "betwixt", "hence", "thence", "whence", "ere", "oft", "nay",
        "aye", "yonder", "yon", "fie", "lo", "'tis", "'twas", "'twere",
        "o'er", "e'er", "ne'er", "morn", "eve", "morrow", "quill",
        "hearken", "beseech", "tarry", "naught", "nought", "dew",
        "mortal", "immortal", "beauteous", "wondrous", "valiant",
        "whilst", "unto", "thereof", "herein", "wherein", "hither",
        "thither", "whither",
    )

    DEFAULT_SUPPRESS_WORDS = (
        "AI", "chatbot", "assistant", "algorithm", "neural",
        "GPT", "LLM", "okay", "OK", "sure", "yeah", "awesome",
        "cool", "basically", "literally", "actually", "honestly",
        "definitely", "absolutely", "totally", "internet", "wifi",
        "app", "website", "download", "upload", "database", "server",
        "API", "URL", "no problem", "happy to help", "let me know",
    )

    def __init__(self, tokenizer, boost=3.0, suppress=-8.0, *,
                 boost_words=None, suppress_words=None):
        super().__init__()
        self.boost = boost
        self.suppress = suppress
        # Ids of single-token forms; biased on every step.
        self.boost_ids = set()
        self.suppress_ids = set()
        # Multi-token forms: prefix tuple (all ids but the last) -> set of final
        # ids to bias once a row's history ends with that prefix.
        self._boost_seqs = {}
        self._suppress_seqs = {}

        if boost_words is None:
            boost_words = self.DEFAULT_BOOST_WORDS
        if suppress_words is None:
            suppress_words = self.DEFAULT_SUPPRESS_WORDS

        for w in boost_words:
            cap = self._capitalize_first_letter(w)
            for v in (w, cap, w.upper(), f" {w}", f" {cap}"):
                self._register(tokenizer.encode(v, add_special_tokens=False),
                               self.boost_ids, self._boost_seqs)
        for w in suppress_words:
            cap = self._capitalize_first_letter(w)
            # Capitalized forms catch sentence-initial openers such as "Sure" / "Okay".
            for v in (w, w.lower(), w.upper(), cap, f" {w}", f" {w.lower()}", f" {cap}"):
                self._register(tokenizer.encode(v, add_special_tokens=False),
                               self.suppress_ids, self._suppress_seqs)

        # A token produced by both lists is boosted, never suppressed.
        self.suppress_ids -= self.boost_ids
        # Distinct prefix lengths, ascending; __call__ only looks back as far as the longest.
        self._prefix_lengths = sorted(
            {len(prefix) for prefix in (*self._boost_seqs, *self._suppress_seqs)}
        )

    @staticmethod
    def _capitalize_first_letter(word):
        """Like ``str.capitalize`` but skips leading punctuation ("'tis" -> "'Tis")."""
        for i, ch in enumerate(word):
            if ch.isalpha():
                return word[:i] + word[i:].capitalize()
        return word

    @staticmethod
    def _register(token_ids, single_ids, sequences):
        """Record one encoded form as a single id or as a (prefix -> final id) entry."""
        if len(token_ids) == 1:
            single_ids.add(token_ids[0])
        elif len(token_ids) > 1:
            sequences.setdefault(tuple(token_ids[:-1]), set()).add(token_ids[-1])

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add the configured biases to ``scores`` in place and return it.

        Args:
            input_ids: Token ids generated so far; row ``i`` is the history for
                row ``i`` of ``scores``.
            scores: Next-token logits whose last dimension indexes the vocabulary.

        Returns:
            The same ``scores`` tensor, modified in place.
        """
        vocab_size = scores.shape[-1]
        for ids, delta in ((self.boost_ids, self.boost), (self.suppress_ids, self.suppress)):
            # Ids that fall outside the logits' last dimension are skipped.
            in_vocab = sorted(t for t in ids if t < vocab_size)
            if in_vocab:
                # One indexed update per set instead of one tensor op per token id.
                index = torch.tensor(in_vocab, dtype=torch.long, device=scores.device)
                scores[:, index] += delta
        if self._prefix_lengths:
            self._apply_sequence_bias(input_ids, scores, vocab_size)
        return scores

    def _apply_sequence_bias(self, input_ids, scores, vocab_size):
        """Bias the final token of each multi-token form whose prefix ends a row's history."""
        history = input_ids[:, -self._prefix_lengths[-1]:].tolist()
        # Tokens already biased unconditionally are not biased a second time.
        static_ids = self.boost_ids | self.suppress_ids
        for row, tail in enumerate(history):
            row_bias = {}
            for length in self._prefix_lengths:
                if length > len(tail):
                    break
                prefix = tuple(tail[-length:])
                for tid in self._suppress_seqs.get(prefix, ()):
                    row_bias.setdefault(tid, self.suppress)
                # Boost overrides suppress, matching the single-token rule.
                for tid in self._boost_seqs.get(prefix, ()):
                    row_bias[tid] = self.boost
            for tid, delta in row_bias.items():
                if tid < vocab_size and tid not in static_ids:
                    scores[row, tid] += delta