"""
HuggingFace-compatible generator for custom LanguageModel.
Wraps the custom model in a minimal HF-compatible interface so we can use
transformers.generate() with all its bells and whistles (beam search,
contrastive search, repetition penalty, etc.) while keeping our own weights.
Usage:
from GeneratorHF import HFTextGenerator
gen = HFTextGenerator(model, tokenizer, device, context_size=1024)
text = gen.generate("Once upon a time", max_new_tokens=200)
"""
import torch
import torch.nn as nn
from transformers import GenerationConfig, GenerationMixin, PretrainedConfig
class LanguageModelHF(nn.Module, GenerationMixin):
"""Thin wrapper that makes our custom LanguageModel compatible with
HuggingFace's GenerationMixin (generate, beam_search, sample, etc.)."""
    # No KV cache is kept; generation re-runs the full sequence every step.
    _is_stateful = False
    _supports_cache_class = False
def __init__(self, model, context_size, device):
super().__init__()
self.model = model
self.config = _make_hf_config(model, context_size)
self.generation_config = GenerationConfig(
max_new_tokens=200,
do_sample=True,
temperature=0.8,
top_k=40,
top_p=0.95,
repetition_penalty=1.1,
            eos_token_id=50256,  # GPT-2 <|endoftext|>
            pad_token_id=50256,
)
self.main_input_name = "input_ids"
@property
def device(self):
return next(self.parameters()).device
def forward(self, input_ids, attention_mask=None, **kwargs):
# Trim to context window (HF might feed longer sequences)
input_ids = input_ids[:, -self.config.n_positions:]
logits = self.model(input_ids)
return _CausalLMOutput(logits=logits)
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache, so each decoding step feeds the full sequence to forward().
        return {"input_ids": input_ids}
def can_generate(self):
return True
    def _reorder_cache(self, past, beam_idx):
        # Nothing is cached, so there is nothing to reorder during beam search.
        return past
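
# Rough sketch (comments only, not executed) of how GenerationMixin.generate()
# drives this wrapper on each decoding step; the names and shapes below are
# illustrative assumptions, not exact transformers internals:
#
#   model_inputs = wrapper.prepare_inputs_for_generation(input_ids)  # {"input_ids": ...}
#   outputs = wrapper(**model_inputs)        # _CausalLMOutput, logits (batch, seq, vocab)
#   next_logits = outputs.logits[:, -1, :]   # only the last position is scored
#   ...logits processors (temperature, top-k/p, repetition penalty) run, a token
#   is sampled or the beam advances, and it is appended to input_ids
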
class _CausalLMOutput:
"""Minimal output container matching HF's CausalLMOutput interface."""
def __init__(self, logits):
self.logits = logits
self.past_key_values = None
self.hidden_states = None
self.attentions = None
def __getitem__(self, key):
return getattr(self, key)
def __contains__(self, key):
return hasattr(self, key) and getattr(self, key) is not None
def keys(self):
return [k for k in ["logits", "past_key_values", "hidden_states", "attentions"]
if getattr(self, k, None) is not None]
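
# Illustration of the dict-style protocol above (assumed usage; mirrors the way
# generate() probes model outputs):
#   out = _CausalLMOutput(logits=torch.zeros(1, 1, 50257))
#   out["logits"] is out.logits   # True
#   "past_key_values" in out      # False: the attribute exists but is None
#   out.keys()                    # ["logits"]
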
class _MinimalConfig(PretrainedConfig):
"""Config that satisfies HF GenerationMixin by inheriting PretrainedConfig."""
model_type = "custom_gpt2"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.is_encoder_decoder = False
self.n_positions = kwargs.get("n_positions", 1024)
self.vocab_size = kwargs.get("vocab_size", 50257)
self.eos_token_id = 50256
self.pad_token_id = 50256
self.bos_token_id = 50256
self.use_cache = False
self.num_hidden_layers = kwargs.get("num_hidden_layers", 12)
def _make_hf_config(model, context_size):
vocab_size = 50257
n_positions = context_size
num_hidden_layers = 12
    if hasattr(model, "config") and isinstance(model.config, dict):
vocab_size = model.config.get("vocab_size", 50257)
n_positions = model.config.get("context_length", context_size)
num_hidden_layers = model.config.get("n_layers", 12)
cfg = _MinimalConfig(vocab_size=vocab_size, n_positions=n_positions, num_hidden_layers=num_hidden_layers)
return cfg
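
# Example of the fallback logic above, assuming a dict-style model.config:
#   model.config = {"vocab_size": 50257, "context_length": 256, "n_layers": 6}
# gives n_positions=256 and num_hidden_layers=6, while a model without such a
# config keeps the GPT-2-small defaults (50257 / context_size / 12).
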
class HFTextGenerator:
"""Drop-in replacement for TextGenerator using HF generate()."""
def __init__(self, model, tokenizer, device, context_size=1024):
self.tokenizer = tokenizer
self.device = device
self.context_size = context_size
self.hf_model = LanguageModelHF(model, context_size, device)
self.hf_model.to(device)
self.hf_model.eval()
def generate(self, prompt, max_new_tokens=200,
temperature=0.8, top_k=40, top_p=0.95,
repetition_penalty=1.1, num_beams=1,
do_sample=True, eos_id=50256):
"""Generate text from a string prompt. Returns the generated string."""
input_ids = torch.tensor(
[self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})]
).to(self.device)
gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
eos_token_id=eos_id,
pad_token_id=eos_id,
)
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
output_ids = self.hf_model.generate(
input_ids,
attention_mask=attention_mask,
generation_config=gen_config,
)
# Decode only the newly generated tokens
new_tokens = output_ids[0, input_ids.shape[1]:]
        # Strip a trailing EOS token if one was generated
        if len(new_tokens) > 0 and new_tokens[-1].item() == eos_id:
            new_tokens = new_tokens[:-1]
decoded_text = self.tokenizer.decode(new_tokens.tolist())
return decoded_text.replace("<|endoftext|>", "").strip()
def generate_ids(self, idx, max_new_tokens=200,
temperature=0.8, top_k=40, top_p=0.95,
repetition_penalty=1.1, num_beams=1,
do_sample=True, eos_id=50256):
"""Generate from token IDs tensor (same interface as original).
Returns full sequence including prompt."""
gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
num_beams=num_beams,
eos_token_id=eos_id,
pad_token_id=eos_id,
)
idx = idx.to(self.device)
attention_mask = torch.ones_like(idx)
with torch.no_grad():
output_ids = self.hf_model.generate(
idx,
attention_mask=attention_mask,
generation_config=gen_config,
)
return output_ids
# ── Backward-compatible helpers (match TextGenerator interface) ───────
def token_ids_to_text(self, token_ids):
flat = token_ids.squeeze(0)
return self.tokenizer.decode(flat.tolist())
def text_to_token_ids(self, text):
encoded = self.tokenizer.encode(text, allowed_special={"<|endoftext|>"})
return torch.tensor(encoded).unsqueeze(0).to(self.device)
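

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original interface. It assumes
    # a tiktoken GPT-2 tokenizer (the encode() calls above use tiktoken's
    # allowed_special API) and a hypothetical LanguageModel class with a
    # dict-style .config; swap in your own model and checkpoint loading.
    import tiktoken
    from model import LanguageModel  # hypothetical module / class name

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = tiktoken.get_encoding("gpt2")
    model = LanguageModel({"vocab_size": 50257, "context_length": 1024, "n_layers": 12})

    gen = HFTextGenerator(model, tokenizer, device, context_size=1024)
    print(gen.generate("Once upon a time", max_new_tokens=50))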