"""
HuggingFace-compatible generator for custom LanguageModel.

Wraps the custom model in a minimal HF-compatible interface so we can use
the transformers `generate()` method with all its bells and whistles (beam
search, contrastive search, repetition penalty, etc.) while keeping our own
weights.

Usage:
    from GeneratorHF import HFTextGenerator
    gen = HFTextGenerator(model, tokenizer, device, context_size=1024)
    text = gen.generate("Once upon a time", max_new_tokens=200)
"""
import torch
import torch.nn as nn
from transformers import GenerationConfig, GenerationMixin, PretrainedConfig
class LanguageModelHF(nn.Module, GenerationMixin):
    """Thin wrapper that makes our custom LanguageModel compatible with
    HuggingFace's GenerationMixin (generate, beam_search, sample, etc.).

    The wrapped model is called as ``model(input_ids)`` and whatever it
    returns is exposed as ``logits``; no key/value cache is produced, so
    every generation step re-runs the model on the (trimmed) sequence.

    Args:
        model: The custom language model to wrap.
        context_size: Context window used when the model carries no dict
            ``config`` with a ``"context_length"`` entry.
        device: Device reported by the ``device`` property when the wrapper
            holds neither parameters nor buffers.
    """

    # Flags consulted by GenerationMixin: this wrapper keeps no recurrent
    # state and does not use HF cache classes.
    _is_stateful = False
    _supports_cache_class = False

    def __init__(self, model, context_size, device):
        super().__init__()
        self.model = model
        self.config = _make_hf_config(model, context_size)
        # Defaults used by generate() when the caller supplies no config.
        self.generation_config = GenerationConfig(
            max_new_tokens=200,
            do_sample=True,
            temperature=0.8,
            top_k=40,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=50256,
            pad_token_id=50256,
        )
        self.main_input_name = "input_ids"
        # Remembered so `device` still has an answer for a model that owns
        # no tensors; previously this argument was ignored entirely.
        self._fallback_device = device

    @property
    def device(self):
        """Return the device of the first parameter, else the first buffer,
        else the device given to the constructor (CPU if that was None)."""
        first_param = next(self.parameters(), None)
        if first_param is not None:
            return first_param.device
        first_buffer = next(self.buffers(), None)
        if first_buffer is not None:
            return first_buffer.device
        if self._fallback_device is not None:
            return self._fallback_device
        return torch.device("cpu")

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run the wrapped model and wrap its output.

        ``attention_mask`` and any extra keyword arguments are accepted for
        interface compatibility but are not forwarded to the model.
        """
        # Trim to context window (HF might feed longer sequences)
        input_ids = input_ids[:, -self.config.n_positions:]
        logits = self.model(input_ids)
        return _CausalLMOutput(logits=logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Return the forward() inputs for one step: only the token IDs."""
        return {"input_ids": input_ids}

    def can_generate(self):
        """Tell HF that this module supports generate()."""
        return True

    def _reorder_cache(self, past, beam_idx):
        """Return ``past`` unchanged; there is no cache to reorder."""
        return past
class _CausalLMOutput:
    """Minimal output container matching HF's CausalLMOutput interface.

    Only ``logits`` is ever populated; the remaining fields exist so that
    attribute access does not fail and are always ``None``.

    Supports attribute access, string-key access (``out["logits"]``),
    positional access over the populated fields (``out[0]``), membership
    tests and ``keys()``.
    """

    # Declared output fields, in positional order.
    _FIELDS = ("logits", "past_key_values", "hidden_states", "attentions")

    def __init__(self, logits):
        self.logits = logits
        self.past_key_values = None
        self.hidden_states = None
        self.attentions = None

    def __getitem__(self, key):
        """Return a field by name, or by position among populated fields.

        Raises:
            AttributeError: If ``key`` is a string naming no attribute.
            IndexError: If ``key`` is an integer beyond the populated fields.
        """
        if isinstance(key, (int, slice)):
            return self.to_tuple()[key]
        return getattr(self, key)

    def __contains__(self, key):
        """Return True only for declared fields whose value is not None."""
        # Restricting to _FIELDS keeps methods such as "keys" from being
        # reported as output entries.
        return key in self._FIELDS and getattr(self, key) is not None

    def keys(self):
        """Return the names of the populated fields, in declaration order."""
        return [name for name in self._FIELDS
                if getattr(self, name, None) is not None]

    def to_tuple(self):
        """Return the populated field values as a tuple, in declaration order."""
        return tuple(getattr(self, name) for name in self.keys())
class _MinimalConfig(PretrainedConfig):
    """Config that satisfies HF GenerationMixin by inheriting PretrainedConfig.

    Keyword Args:
        n_positions: Context window length (default 1024).
        vocab_size: Vocabulary size (default 50257).
        num_hidden_layers: Number of layers (default 12).
        eos_token_id: End-of-sequence token ID (default 50256).
        pad_token_id: Padding token ID (default 50256).
        bos_token_id: Beginning-of-sequence token ID (default 50256).

    ``is_encoder_decoder`` and ``use_cache`` are always forced to False,
    whatever the keyword arguments say.
    """

    model_type = "custom_gpt2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.is_encoder_decoder = False
        self.n_positions = kwargs.get("n_positions", 1024)
        self.vocab_size = kwargs.get("vocab_size", 50257)
        # Special-token IDs honour caller-supplied values instead of
        # overwriting them with the default.
        self.eos_token_id = kwargs.get("eos_token_id", 50256)
        self.pad_token_id = kwargs.get("pad_token_id", 50256)
        self.bos_token_id = kwargs.get("bos_token_id", 50256)
        self.use_cache = False
        self.num_hidden_layers = kwargs.get("num_hidden_layers", 12)
def _make_hf_config(model, context_size):
    """Build the `_MinimalConfig` describing ``model``.

    When ``model.config`` exists and is a dict, its ``"vocab_size"``,
    ``"context_length"`` and ``"n_layers"`` entries are used; any entry
    that is missing (or a non-dict / absent config) falls back to 50257,
    ``context_size`` and 12 respectively.
    """
    declared = getattr(model, "config", None)
    # An empty mapping makes every lookup below resolve to its fallback.
    settings = declared if isinstance(declared, dict) else {}
    return _MinimalConfig(
        vocab_size=settings.get("vocab_size", 50257),
        n_positions=settings.get("context_length", context_size),
        num_hidden_layers=settings.get("n_layers", 12),
    )
class HFTextGenerator:
    """Drop-in replacement for TextGenerator using HF generate().

    Args:
        model: Custom language model; wrapped in `LanguageModelHF`.
        tokenizer: Object exposing ``encode(text, allowed_special=...)`` and
            ``decode(list_of_ids)``.
        device: Device the wrapped model and all input tensors are moved to.
        context_size: Context window passed to the wrapper (default 1024).
    """

    def __init__(self, model, tokenizer, device, context_size=1024):
        self.tokenizer = tokenizer
        self.device = device
        self.context_size = context_size
        self.hf_model = LanguageModelHF(model, context_size, device)
        self.hf_model.to(device)
        self.hf_model.eval()

    @staticmethod
    def _build_generation_config(max_new_tokens, temperature, top_k, top_p,
                                 repetition_penalty, num_beams, do_sample,
                                 eos_id):
        """Assemble the GenerationConfig shared by generate()/generate_ids()."""
        options = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "repetition_penalty": repetition_penalty,
            "num_beams": num_beams,
            "eos_token_id": eos_id,
            "pad_token_id": eos_id,
        }
        if do_sample:
            # Sampling knobs are only meaningful when sampling; leaving them
            # out for greedy/beam decoding avoids HF's "flag set but
            # do_sample=False" validation warnings.
            options.update(temperature=temperature, top_k=top_k, top_p=top_p)
        return GenerationConfig(**options)

    def generate(self, prompt, max_new_tokens=200,
                 temperature=0.8, top_k=40, top_p=0.95,
                 repetition_penalty=1.1, num_beams=1,
                 do_sample=True, eos_id=50256):
        """Generate text from a string prompt. Returns the generated string.

        Only the newly generated tokens are decoded; a trailing ``eos_id``
        token and any literal ``<|endoftext|>`` text are removed and the
        result is stripped of surrounding whitespace.

        Raises:
            ValueError: If the prompt encodes to no tokens and ``eos_id`` is
                None, so there is nothing to seed generation with.
        """
        prompt_ids = self.tokenizer.encode(
            prompt, allowed_special={"<|endoftext|>"}
        )
        if len(prompt_ids) == 0:
            # An empty prompt would otherwise yield a float tensor of shape
            # (1, 0); seed generation with the end-of-text token instead.
            if eos_id is None:
                raise ValueError(
                    "prompt encodes to no tokens and eos_id is None"
                )
            prompt_ids = [eos_id]
        input_ids = torch.tensor([prompt_ids], dtype=torch.long).to(self.device)

        output_ids = self.generate_ids(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_beams=num_beams,
            do_sample=do_sample,
            eos_id=eos_id,
        )

        # Decode only the newly generated tokens
        new_tokens = output_ids[0, input_ids.shape[1]:]
        # Filter out the EOS token if it was generated
        if new_tokens.numel() > 0 and new_tokens[-1].item() == eos_id:
            new_tokens = new_tokens[:-1]
        decoded_text = self.tokenizer.decode(new_tokens.tolist())
        return decoded_text.replace("<|endoftext|>", "").strip()

    def generate_ids(self, idx, max_new_tokens=200,
                     temperature=0.8, top_k=40, top_p=0.95,
                     repetition_penalty=1.1, num_beams=1,
                     do_sample=True, eos_id=50256):
        """Generate from token IDs tensor (same interface as original).

        Returns full sequence including prompt.
        """
        gen_config = self._build_generation_config(
            max_new_tokens, temperature, top_k, top_p,
            repetition_penalty, num_beams, do_sample, eos_id,
        )
        idx = idx.to(self.device)
        # Every prompt position is a real token, so the mask is all ones.
        attention_mask = torch.ones_like(idx)
        with torch.no_grad():
            output_ids = self.hf_model.generate(
                idx,
                attention_mask=attention_mask,
                generation_config=gen_config,
            )
        return output_ids

    # -- Backward-compatible helpers (match TextGenerator interface) --

    def token_ids_to_text(self, token_ids):
        """Decode a tensor of token IDs (leading batch dim of 1 allowed)."""
        flat = token_ids.squeeze(0)
        if flat.dim() == 0:
            # A one-element 1-D input collapses to a scalar; restore the
            # sequence dimension so decode() still receives a list.
            flat = flat.unsqueeze(0)
        return self.tokenizer.decode(flat.tolist())

    def text_to_token_ids(self, text):
        """Encode ``text`` into a (1, n_tokens) integer tensor on the device."""
        encoded = self.tokenizer.encode(text, allowed_special={"<|endoftext|>"})
        # Explicit dtype keeps an empty encoding from becoming a float tensor.
        return torch.tensor(encoded, dtype=torch.long).unsqueeze(0).to(self.device)
|