shakespeare-lora-gemma4 / shakespeare_constrained_decoder.py
cabdru's picture
Shakespeare LoRA: rank 256, SFT+DPO, constrained decoder, 15/15 FSV
c5d9ebe verified
"""Shakespeare Constrained Decoder — deterministic style enforcement.
Apply this LogitsProcessor during generation to bias every decoding step toward Shakespearean vocabulary.
This is the unjailbreakable layer — the bias is applied to raw logits on every step, independent of learned behavior.
"""
from transformers import LogitsProcessor
import torch
class ShakespeareLogitProcessor(LogitsProcessor):
    """Boost archaic tokens and suppress modern tokens at every generation step.

    At construction time each word in ``BOOST_WORDS`` / ``SUPPRESS_WORDS`` is
    expanded into a handful of casing / leading-space variants, each variant
    is encoded with the supplied tokenizer, and *every* token id returned is
    collected into ``boost_ids`` / ``suppress_ids``.  Ids that end up in both
    sets are kept only in ``boost_ids``.

    At generation time ``boost`` is added to the score of every boosted id and
    ``suppress`` is added to the score of every suppressed id.  The biases are
    additive, so they shift the distribution rather than hard-masking tokens.

    NOTE(review): because all ids of a multi-token encoding are collected, a
    variant that the tokenizer splits into several pieces (presumably phrases
    such as "no problem" or rare words) biases each piece individually, not
    just the full word.  Confirm this is intended for the tokenizer in use.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns an iterable of integer token ids.
        boost: Value added to the score of each id in ``boost_ids``.
        suppress: Value added to the score of each id in ``suppress_ids``
            (negative by default, which lowers those scores).
    """

    # Archaic vocabulary whose token ids receive the ``boost`` bias.
    BOOST_WORDS = (
        "thee", "thou", "thy", "thine", "hast", "dost", "doth", "ye",
        "hath", "art", "wilt", "shalt", "wouldst", "shouldst", "canst",
        "didst", "prithee", "forsooth", "hark", "wherefore", "methinks",
        "verily", "perchance", "mayhap", "alas", "alack", "anon",
        "betwixt", "hence", "thence", "whence", "ere", "oft", "nay",
        "aye", "yonder", "yon", "fie", "lo", "'tis", "'twas", "'twere",
        "o'er", "e'er", "ne'er", "morn", "eve", "morrow", "quill",
        "hearken", "beseech", "tarry", "naught", "nought", "dew",
        "mortal", "immortal", "beauteous", "wondrous", "valiant",
        "whilst", "unto", "thereof", "herein", "wherein", "hither",
        "thither", "whither",
    )

    # Modern vocabulary whose token ids receive the ``suppress`` bias.
    SUPPRESS_WORDS = (
        "AI", "chatbot", "assistant", "algorithm", "neural",
        "GPT", "LLM", "okay", "OK", "sure", "yeah", "awesome",
        "cool", "basically", "literally", "actually", "honestly",
        "definitely", "absolutely", "totally", "internet", "wifi",
        "app", "website", "download", "upload", "database", "server",
        "API", "URL", "no problem", "happy to help", "let me know",
    )

    def __init__(self, tokenizer, boost=3.0, suppress=-8.0):
        super().__init__()
        # Boost variants: as written, Capitalized, UPPER, and the first two
        # again with a leading space.
        self.boost_ids = self._encode_variants(
            tokenizer,
            (
                (w, w.capitalize(), w.upper(), f" {w}", f" {w.capitalize()}")
                for w in self.BOOST_WORDS
            ),
        )
        # Suppress variants: as written, lower, UPPER, and leading-space
        # forms of the as-written and lower-cased spellings.
        self.suppress_ids = self._encode_variants(
            tokenizer,
            (
                (w, w.lower(), w.upper(), f" {w}", f" {w.lower()}")
                for w in self.SUPPRESS_WORDS
            ),
        )
        # An id must never be both boosted and suppressed; boosting wins.
        self.suppress_ids -= self.boost_ids
        self.boost = boost
        self.suppress = suppress

    @staticmethod
    def _encode_variants(tokenizer, variant_groups):
        """Return the set of all token ids produced by encoding every variant."""
        token_ids = set()
        for variants in variant_groups:
            for text in variants:
                token_ids.update(tokenizer.encode(text, add_special_tokens=False))
        return token_ids

    @staticmethod
    def _add_bias(scores, token_ids, bias):
        """Add ``bias`` in place to the columns of ``scores`` named by ``token_ids``.

        Ids at or beyond the last dimension of ``scores`` are skipped, which
        covers a tokenizer whose id space is larger than the score vector.
        """
        vocab_size = scores.shape[-1]
        in_range = sorted(tid for tid in token_ids if tid < vocab_size)
        if not in_range:
            return
        index = torch.tensor(in_range, dtype=torch.long, device=scores.device)
        # The ids are unique (they come from a set), so a single indexed
        # read-modify-write is equivalent to adding the bias column by column,
        # but costs one tensor op instead of one per token id.
        scores[:, index] += bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the boost and suppress biases to ``scores`` and return it.

        ``scores`` is modified in place and indexed as ``scores[:, ids]``, so
        it must be two-dimensional.  ``input_ids`` is not used.  The id sets
        and bias values are read on every call, so changes made to them after
        construction take effect on the next step.
        """
        self._add_bias(scores, self.boost_ids, self.boost)
        self._add_bias(scores, self.suppress_ids, self.suppress)
        return scores