"""Shakespeare Constrained Decoder — deterministic style enforcement. Apply this LogitsProcessor during generation to guarantee Shakespeare vocabulary. This is the unjailbreakable layer — it operates on raw logits, not learned behavior. """ from transformers import LogitsProcessor import torch class ShakespeareLogitProcessor(LogitsProcessor): """Boosts archaic tokens and suppresses modern tokens at every generation step.""" def __init__(self, tokenizer, boost=3.0, suppress=-8.0): super().__init__() self.boost_ids = set() self.suppress_ids = set() boost_words = [ "thee", "thou", "thy", "thine", "hast", "dost", "doth", "ye", "hath", "art", "wilt", "shalt", "wouldst", "shouldst", "canst", "didst", "prithee", "forsooth", "hark", "wherefore", "methinks", "verily", "perchance", "mayhap", "alas", "alack", "anon", "betwixt", "hence", "thence", "whence", "ere", "oft", "nay", "aye", "yonder", "yon", "fie", "lo", "'tis", "'twas", "'twere", "o'er", "e'er", "ne'er", "morn", "eve", "morrow", "quill", "hearken", "beseech", "tarry", "naught", "nought", "dew", "mortal", "immortal", "beauteous", "wondrous", "valiant", "whilst", "unto", "thereof", "herein", "wherein", "hither", "thither", "whither", ] suppress_words = [ "AI", "chatbot", "assistant", "algorithm", "neural", "GPT", "LLM", "okay", "OK", "sure", "yeah", "awesome", "cool", "basically", "literally", "actually", "honestly", "definitely", "absolutely", "totally", "internet", "wifi", "app", "website", "download", "upload", "database", "server", "API", "URL", "no problem", "happy to help", "let me know", ] for w in boost_words: for v in [w, w.capitalize(), w.upper(), f" {w}", f" {w.capitalize()}"]: self.boost_ids.update(tokenizer.encode(v, add_special_tokens=False)) for w in suppress_words: for v in [w, w.lower(), w.upper(), f" {w}", f" {w.lower()}"]: self.suppress_ids.update(tokenizer.encode(v, add_special_tokens=False)) self.suppress_ids -= self.boost_ids self.boost = boost self.suppress = suppress def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: for tid in self.boost_ids: if tid < scores.shape[-1]: scores[:, tid] += self.boost for tid in self.suppress_ids: if tid < scores.shape[-1]: scores[:, tid] += self.suppress return scores