| """
|
| HuggingFace-compatible generator for custom LanguageModel.
|
|
|
| Wraps the custom model in a minimal HF-compatible interface so we can use
|
| transformers.generate() with all its bells and whistles (beam search,
|
| contrastive search, repetition penalty, etc.) while keeping our own weights.
|
|
|
| Usage:
|
| from GeneratorHF import HFTextGenerator
|
|
|
| gen = HFTextGenerator(model, tokenizer, device, context_size=1024)
|
| text = gen.generate("Once upon a time", max_new_tokens=200)
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from transformers import GenerationConfig, GenerationMixin, PretrainedConfig
|
|
|
|
|
class LanguageModelHF(nn.Module, GenerationMixin):
    """Thin wrapper that makes our custom LanguageModel compatible with
    HuggingFace's GenerationMixin (generate, beam_search, sample, etc.)."""

    _is_stateful = False
    _supports_cache_class = False

    def __init__(self, model, context_size, device):
        super().__init__()
        self.model = model
        self.config = _make_hf_config(model, context_size)
        self.generation_config = GenerationConfig(
            max_new_tokens=200,
            do_sample=True,
            temperature=0.8,
            top_k=40,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=50256,
            pad_token_id=50256,
        )
        self.main_input_name = "input_ids"

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_ids, attention_mask=None, **kwargs):
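        # No KV cache: clip to the context window and re-encode the whole
        # (clipped) sequence on every decoding step.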
        input_ids = input_ids[:, -self.config.n_positions:]
        logits = self.model(input_ids)
        return _CausalLMOutput(logits=logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
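        # Stateless model: always hand forward() the full sequence; the
        # context-window clipping happens there.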
| return {"input_ids": input_ids}
|
|
|
| def can_generate(self):
|
| return True
|
|
|
| def _reorder_cache(self, past, beam_idx):
|
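        # There is no KV cache, so nothing needs reordering during beam search.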
        return past


class _CausalLMOutput:
    """Minimal output container matching HF's CausalLMOutput interface."""

    def __init__(self, logits):
        self.logits = logits
        self.past_key_values = None
        self.hidden_states = None
        self.attentions = None

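    # generate() sometimes indexes model outputs like a mapping, so mimic
    # just enough of the dict protocol for that to work.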
    def __getitem__(self, key):
        return getattr(self, key)

    def __contains__(self, key):
        return hasattr(self, key) and getattr(self, key) is not None

    def keys(self):
        return [k for k in ["logits", "past_key_values", "hidden_states", "attentions"]
                if getattr(self, k, None) is not None]

class _MinimalConfig(PretrainedConfig):
    """Config that satisfies HF GenerationMixin by inheriting PretrainedConfig."""

    model_type = "custom_gpt2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.is_encoder_decoder = False
        self.n_positions = kwargs.get("n_positions", 1024)
        self.vocab_size = kwargs.get("vocab_size", 50257)
        self.eos_token_id = 50256
        self.pad_token_id = 50256
        self.bos_token_id = 50256
        self.use_cache = False
        self.num_hidden_layers = kwargs.get("num_hidden_layers", 12)


def _make_hf_config(model, context_size):
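    # GPT-2-style defaults; overridden below when the custom model carries
    # its own config dict.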
    vocab_size = 50257
    n_positions = context_size
    num_hidden_layers = 12
    if hasattr(model, "config") and isinstance(model.config, dict):
        vocab_size = model.config.get("vocab_size", 50257)
        n_positions = model.config.get("context_length", context_size)
        num_hidden_layers = model.config.get("n_layers", 12)

    return _MinimalConfig(
        vocab_size=vocab_size,
        n_positions=n_positions,
        num_hidden_layers=num_hidden_layers,
    )


class HFTextGenerator:
    """Drop-in replacement for TextGenerator using HF generate()."""

    def __init__(self, model, tokenizer, device, context_size=1024):
        self.tokenizer = tokenizer
        self.device = device
        self.context_size = context_size
        self.hf_model = LanguageModelHF(model, context_size, device)
        self.hf_model.to(device)
        self.hf_model.eval()

    def generate(self, prompt, max_new_tokens=200,
                 temperature=0.8, top_k=40, top_p=0.95,
                 repetition_penalty=1.1, num_beams=1,
                 do_sample=True, eos_id=50256):
        """Generate text from a string prompt. Returns the generated string."""
        input_ids = torch.tensor(
            [self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})]
        ).to(self.device)

        gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_beams=num_beams,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
        )
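        # Single unpadded prompt, so the attention mask is all ones.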
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            output_ids = self.hf_model.generate(
                input_ids,
                attention_mask=attention_mask,
                generation_config=gen_config,
            )
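        # generate() returns prompt + continuation; keep only the new tokens.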
        new_tokens = output_ids[0, input_ids.shape[1]:]
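        # Drop a trailing EOS so it doesn't leak into the decoded text.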
        if len(new_tokens) > 0 and new_tokens[-1].item() == eos_id:
            new_tokens = new_tokens[:-1]

        decoded_text = self.tokenizer.decode(new_tokens.tolist())
        return decoded_text.replace("<|endoftext|>", "").strip()

    def generate_ids(self, idx, max_new_tokens=200,
                     temperature=0.8, top_k=40, top_p=0.95,
                     repetition_penalty=1.1, num_beams=1,
                     do_sample=True, eos_id=50256):
        """Generate from a tensor of token IDs (same interface as the original
        TextGenerator). Returns the full sequence including the prompt."""
        gen_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_beams=num_beams,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
        )

        idx = idx.to(self.device)
        attention_mask = torch.ones_like(idx)

        with torch.no_grad():
            output_ids = self.hf_model.generate(
                idx,
                attention_mask=attention_mask,
                generation_config=gen_config,
            )

        return output_ids
    def token_ids_to_text(self, token_ids):
        flat = token_ids.squeeze(0)
        return self.tokenizer.decode(flat.tolist())

    def text_to_token_ids(self, text):
        encoded = self.tokenizer.encode(text, allowed_special={"<|endoftext|>"})
        return torch.tensor(encoded).unsqueeze(0).to(self.device)
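

# Minimal smoke-test sketch, not part of the module's API. It assumes a
# custom LanguageModel class that maps (B, T) token IDs to (B, T, vocab)
# logits and a tiktoken GPT-2 tokenizer; the LanguageModel import path is
# hypothetical and should be adjusted to your own project.
if __name__ == "__main__":
    import tiktoken
    from LanguageModel import LanguageModel  # hypothetical import

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = LanguageModel()
    tokenizer = tiktoken.get_encoding("gpt2")

    gen = HFTextGenerator(model, tokenizer, device, context_size=1024)
    print(gen.generate("Once upon a time", max_new_tokens=50, num_beams=3))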