# NOTE: removed non-Python extraction artifacts (file-size banner, git-blame
# commit hashes, and a line-number gutter) that were not part of the module.
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 Generation Utilities.
This module provides generation utilities and mixins for RND1 models,
including the main GenerationMixin class that integrates with HuggingFace.
"""
import torch
from typing import Optional, Union, Dict, Any
from transformers import GenerationMixin as HFGenerationMixin
from transformers.generation import GenerationConfig
from .generation_config import RND1GenerationConfig
from .sampling import diffusion_sample
class RND1GenerationMixin(HFGenerationMixin):
    """
    Generation mixin for RND1 models.

    This mixin provides generation methods compatible with HuggingFace's
    generation API while using RND1's diffusion-based sampling internally.
    """

    @staticmethod
    def _first_not_none(*values):
        """Return the first argument that is not None (or None if all are).

        Unlike ``a or b``, this treats a legitimate token id of 0 as
        present instead of silently falling through to the fallback.
        """
        for value in values:
            if value is not None:
                return value
        return None

    def generate(
        self,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        # RND1-specific parameters
        prefix_ids: Optional[torch.LongTensor] = None,
        suffix_ids: Optional[torch.LongTensor] = None,
        infill_length: Optional[int] = None,
        return_dict_in_generate: Optional[bool] = None,
        **kwargs,  # Accept all kwargs to be compatible with pipelines
    ) -> Union[torch.LongTensor, Dict[str, Any]]:
        """
        Generate text using RND1's diffusion-based sampling.

        Follows HuggingFace's standard generate API, using diffusion sampling
        internally. Supports both standard generation and infilling.

        Args:
            inputs: Input token IDs to use as prefix (standard HF parameter).
                Takes precedence over ``prefix_ids`` when both are given.
            generation_config: Generation configuration object. Default is
                RND1GenerationConfig. When provided, ``kwargs`` are NOT merged
                into it and are only consulted for extras (e.g. ``visualizer``).
            prefix_ids: Alternative to inputs for infilling tasks
            suffix_ids: Optional suffix for infilling tasks
            infill_length: Length of infill region (for infilling)
            return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
            **kwargs: Additional arguments (accepted for compatibility). These
                will be passed to the config constructor when no config is given.

        Returns:
            Generated token IDs or GenerateDecoderOnlyOutput
        """
        if generation_config is not None:
            gen_config = generation_config
            model_kwargs = kwargs.copy()
        else:
            # Only prepare config from kwargs if no config was provided.
            gen_config, model_kwargs = self._prepare_generation_config(
                RND1GenerationConfig(), **kwargs
            )

        device = next(self.parameters()).device

        # Move whichever prefix source was provided onto the model's device;
        # `inputs` (the standard HF name) wins over `prefix_ids`.
        if inputs is not None:
            prefix_ids = inputs.to(device)
        elif prefix_ids is not None:
            prefix_ids = prefix_ids.to(device)
        if suffix_ids is not None:
            suffix_ids = suffix_ids.to(device)

        # Resolve special token ids: generation config first, then the model
        # config, then RND1 defaults. Explicit `is not None` coalescing is used
        # (not `or`) so that a valid token id of 0 is respected.
        eos_token_id = self._first_not_none(
            gen_config.eos_token_id, getattr(self.config, "eos_token_id", None), 151645
        )
        pad_token_id = self._first_not_none(
            gen_config.pad_token_id, getattr(self.config, "pad_token_id", None), 151643
        )
        bos_token_id = self._first_not_none(
            gen_config.bos_token_id, getattr(self.config, "bos_token_id", None)
        )
        mask_token_id = self._first_not_none(
            getattr(gen_config, "mask_token_id", None),
            getattr(self.config, "mask_token_id", None),
            151669,
        )

        # Determine the total sequence length the sampler must fill.
        if infill_length is not None and prefix_ids is not None:
            # Infilling mode: total = prefix + fixed-size infill region + suffix.
            suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
            seq_len = prefix_ids.shape[1] + infill_length + suffix_len
        elif prefix_ids is not None and gen_config.max_new_tokens is not None:
            # Standard generation: extend the prefix by max_new_tokens.
            seq_len = prefix_ids.shape[1] + gen_config.max_new_tokens
        else:
            # No prefix (or no max_new_tokens): fall back to an absolute cap.
            seq_len = gen_config.max_length or self.config.max_position_embeddings

        # Sampling knobs, with model-config / RND1 fallbacks.
        num_diffusion_steps = getattr(
            gen_config, "num_diffusion_steps",
            getattr(self.config, "num_diffusion_steps", 256),
        )
        temperature = float(getattr(gen_config, "temperature", 1.0))
        top_k = getattr(gen_config, "top_k", None)
        top_p = getattr(gen_config, "top_p", None)
        # `greedy` mirrors HF's `do_sample=False`; default to greedy when the
        # config exposes neither attribute.
        greedy = getattr(
            gen_config, "greedy",
            not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True,
        )

        with torch.inference_mode():
            sequences = diffusion_sample(
                model=self,
                seq_len=seq_len,
                num_steps=num_diffusion_steps,
                mask_token_id=mask_token_id,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                prefix_ids=prefix_ids,
                suffix_ids=suffix_ids,
                infill_length=infill_length,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                bos_token_id=bos_token_id,
                device=device,
                visualizer=model_kwargs.get("visualizer", None),  # Optional visualizer from kwargs
            )

        if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
            # Imported lazily to keep the module import light.
            from transformers.generation.utils import GenerateDecoderOnlyOutput
            return GenerateDecoderOnlyOutput(sequences=sequences)
        return sequences

    def generate_with_visualization(
        self,
        tokenizer,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        suffix_ids: Optional[torch.LongTensor] = None,
        infill_length: Optional[int] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generate with live visualization (for demos).

        This method requires a tokenizer to display the generation process.
        For production use, prefer `generate()`.

        Args:
            tokenizer: Tokenizer for decoding tokens to text
            inputs: Input token IDs to use as prefix
            generation_config: Generation configuration object
            suffix_ids: Optional suffix token IDs
            infill_length: Length of infill region
            **kwargs: Additional arguments for backward compatibility

        Returns:
            Generated token IDs as LongTensor
        """
        # Imported lazily so the terminal visualizer is only required for demos.
        from .terminal_visualizer import TerminalVisualizer

        visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
        # Delegate to generate(); the visualizer rides along via kwargs and is
        # picked up from model_kwargs inside generate().
        return self.generate(
            inputs=inputs,
            generation_config=generation_config,
            suffix_ids=suffix_ids,
            infill_length=infill_length,
            visualizer=visualizer,
            return_dict_in_generate=False,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for generation (required by HuggingFace).

        For RND1, we don't use the standard autoregressive generation,
        so this just returns the input_ids.
        """
        return {"input_ids": input_ids}