""" HuggingFace-compatible generator for custom LanguageModel. Wraps the custom model in a minimal HF-compatible interface so we can use transformers.generate() with all its bells and whistles (beam search, contrastive search, repetition penalty, etc.) while keeping our own weights. Usage: from GeneratorHF import HFTextGenerator gen = HFTextGenerator(model, tokenizer, device, context_size=1024) text = gen.generate("Once upon a time", max_new_tokens=200) """ import torch import torch.nn as nn from transformers import GenerationConfig, GenerationMixin, PretrainedConfig class LanguageModelHF(nn.Module, GenerationMixin): """Thin wrapper that makes our custom LanguageModel compatible with HuggingFace's GenerationMixin (generate, beam_search, sample, etc.).""" _is_stateful = False _supports_cache_class = False def __init__(self, model, context_size, device): super().__init__() self.model = model self.config = _make_hf_config(model, context_size) self.generation_config = GenerationConfig( max_new_tokens=200, do_sample=True, temperature=0.8, top_k=40, top_p=0.95, repetition_penalty=1.1, eos_token_id=50256, pad_token_id=50256, ) self.main_input_name = "input_ids" @property def device(self): return next(self.parameters()).device def forward(self, input_ids, attention_mask=None, **kwargs): # Trim to context window (HF might feed longer sequences) input_ids = input_ids[:, -self.config.n_positions:] logits = self.model(input_ids) return _CausalLMOutput(logits=logits) def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs): return {"input_ids": input_ids} def can_generate(self): return True def _reorder_cache(self, past, beam_idx): return past class _CausalLMOutput: """Minimal output container matching HF's CausalLMOutput interface.""" def __init__(self, logits): self.logits = logits self.past_key_values = None self.hidden_states = None self.attentions = None def __getitem__(self, key): return getattr(self, key) def __contains__(self, key): return hasattr(self, key) and getattr(self, key) is not None def keys(self): return [k for k in ["logits", "past_key_values", "hidden_states", "attentions"] if getattr(self, k, None) is not None] class _MinimalConfig(PretrainedConfig): """Config that satisfies HF GenerationMixin by inheriting PretrainedConfig.""" model_type = "custom_gpt2" def __init__(self, **kwargs): super().__init__(**kwargs) self.is_encoder_decoder = False self.n_positions = kwargs.get("n_positions", 1024) self.vocab_size = kwargs.get("vocab_size", 50257) self.eos_token_id = 50256 self.pad_token_id = 50256 self.bos_token_id = 50256 self.use_cache = False self.num_hidden_layers = kwargs.get("num_hidden_layers", 12) def _make_hf_config(model, context_size): vocab_size = 50257 n_positions = context_size num_hidden_layers = 12 if hasattr(model, 'config') and isinstance(model.config, dict): vocab_size = model.config.get("vocab_size", 50257) n_positions = model.config.get("context_length", context_size) num_hidden_layers = model.config.get("n_layers", 12) cfg = _MinimalConfig(vocab_size=vocab_size, n_positions=n_positions, num_hidden_layers=num_hidden_layers) return cfg class HFTextGenerator: """Drop-in replacement for TextGenerator using HF generate().""" def __init__(self, model, tokenizer, device, context_size=1024): self.tokenizer = tokenizer self.device = device self.context_size = context_size self.hf_model = LanguageModelHF(model, context_size, device) self.hf_model.to(device) self.hf_model.eval() def generate(self, prompt, max_new_tokens=200, temperature=0.8, top_k=40, 
top_p=0.95, repetition_penalty=1.1, num_beams=1, do_sample=True, eos_id=50256): """Generate text from a string prompt. Returns the generated string.""" input_ids = torch.tensor( [self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})] ).to(self.device) gen_config = GenerationConfig( max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, num_beams=num_beams, eos_token_id=eos_id, pad_token_id=eos_id, ) attention_mask = torch.ones_like(input_ids) with torch.no_grad(): output_ids = self.hf_model.generate( input_ids, attention_mask=attention_mask, generation_config=gen_config, ) # Decode only the newly generated tokens new_tokens = output_ids[0, input_ids.shape[1]:] # Filter out the EOS token if it was generated if len(new_tokens) > 0 and new_tokens[-1].item() == eos_id: new_tokens = new_tokens[:-1] decoded_text = self.tokenizer.decode(new_tokens.tolist()) return decoded_text.replace("<|endoftext|>", "").strip() def generate_ids(self, idx, max_new_tokens=200, temperature=0.8, top_k=40, top_p=0.95, repetition_penalty=1.1, num_beams=1, do_sample=True, eos_id=50256): """Generate from token IDs tensor (same interface as original). Returns full sequence including prompt.""" gen_config = GenerationConfig( max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, num_beams=num_beams, eos_token_id=eos_id, pad_token_id=eos_id, ) idx = idx.to(self.device) attention_mask = torch.ones_like(idx) with torch.no_grad(): output_ids = self.hf_model.generate( idx, attention_mask=attention_mask, generation_config=gen_config, ) return output_ids # ── Backward-compatible helpers (match TextGenerator interface) ─────── def token_ids_to_text(self, token_ids): flat = token_ids.squeeze(0) return self.tokenizer.decode(flat.tolist()) def text_to_token_ids(self, text): encoded = self.tokenizer.encode(text, allowed_special={"<|endoftext|>"}) return torch.tensor(encoded).unsqueeze(0).to(self.device)
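

# ── Demo ──────────────────────────────────────────────────────────────
# A minimal usage sketch. It assumes the tokenizer is tiktoken's GPT-2
# encoding (its encode() accepts the allowed_special kwarg used above);
# the `model` import is a hypothetical placeholder for the custom
# LanguageModel this module wraps — adjust the import path and weight
# loading to your repo.
if __name__ == "__main__":
    import tiktoken

    from model import LanguageModel  # hypothetical: your custom model class

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = tiktoken.get_encoding("gpt2")
    model = LanguageModel()  # load your trained weights here

    gen = HFTextGenerator(model, tokenizer, device, context_size=1024)

    # Nucleus + top-k sampling (mirrors the defaults set above)
    print(gen.generate("Once upon a time", max_new_tokens=50))

    # Beam search: disable sampling and raise num_beams
    print(gen.generate("Once upon a time", num_beams=4, do_sample=False))

    # Token-ID interface: returns the full sequence including the prompt
    ids = gen.text_to_token_ids("Once upon a time")
    print(gen.token_ids_to_text(gen.generate_ids(ids, max_new_tokens=20)))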