File size: 7,440 Bytes
83dd916
 
 
 
 
32b8af1
 
 
 
 
 
 
 
 
 
 
 
83dd916
 
32b8af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83dd916
32b8af1
 
 
 
83dd916
32b8af1
 
 
 
 
 
 
 
 
83dd916
32b8af1
 
 
 
 
 
 
 
 
 
 
 
 
 
83dd916
32b8af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139362b
 
32b8af1
 
139362b
32b8af1
 
 
 
 
 
 
 
 
139362b
 
32b8af1
 
139362b
32b8af1
 
 
 
 
 
 
 
139362b
 
32b8af1
 
 
 
139362b
32b8af1
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.

"""
RND1 Generation Utilities.

This module provides generation utilities and mixins for RND1 models,
including the main GenerationMixin class that integrates with HuggingFace.
"""

import torch
from typing import Optional, Union, Dict, Any
from transformers import GenerationMixin as HFGenerationMixin
from transformers.generation import GenerationConfig

from .generation_config import RND1GenerationConfig
from .sampling import diffusion_sample


class RND1GenerationMixin(HFGenerationMixin):
    """
    Generation mixin for RND1 models.

    Exposes a ``generate()`` entry point with the HuggingFace call signature,
    but delegates the actual decoding to ``diffusion_sample`` instead of the
    autoregressive loop of the parent mixin.
    """

    @staticmethod
    def _rnd1_first_not_none(*candidates):
        """Return the first candidate that is not None (None if all are)."""
        return next((value for value in candidates if value is not None), None)

    def generate(
        self,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        # RND1-specific parameters
        prefix_ids: Optional[torch.LongTensor] = None,
        suffix_ids: Optional[torch.LongTensor] = None,
        infill_length: Optional[int] = None,
        return_dict_in_generate: Optional[bool] = None,
        **kwargs,  # Accept all kwargs to be compatible with pipelines
    ) -> Union[torch.LongTensor, Dict[str, Any]]:
        """
        Generate text using RND1's diffusion-based sampling.

        Supports standard (prefix-conditioned or unconditional) generation and
        infilling between a prefix and an optional suffix.

        Args:
            inputs: Input token IDs to use as prefix (standard HF parameter).
                Takes precedence over ``prefix_ids``. The sequence length is
                read from dimension 1.
            generation_config: Generation configuration object. When omitted,
                a fresh ``RND1GenerationConfig`` is built and updated from
                ``kwargs``. When given, it is used as-is and ``kwargs`` are
                NOT merged into it.
            prefix_ids: Alternative to ``inputs`` for infilling tasks.
            suffix_ids: Optional suffix for infilling tasks.
            infill_length: Length of the infill region. Only affects the
                total sequence length when a prefix is also available.
            return_dict_in_generate: Whether to return a
                ``GenerateDecoderOnlyOutput``. ``None`` defers to the
                generation config; an explicit ``True``/``False`` wins.
            **kwargs: Additional arguments (accepted for compatibility).
                ``input_ids`` is used as the prefix when neither ``inputs``
                nor ``prefix_ids`` is given; ``visualizer`` is forwarded to
                the sampler.

        Returns:
            Generated token IDs, or a ``GenerateDecoderOnlyOutput`` wrapping
            them when dict output is requested.
        """
        if generation_config is not None:
            gen_config = generation_config
            model_kwargs = kwargs.copy()
        else:
            # Only prepare config from kwargs if no config was provided
            gen_config, model_kwargs = self._prepare_generation_config(RND1GenerationConfig(), **kwargs)

        device = next(self.parameters()).device

        # Resolve the prompt: `inputs` > `prefix_ids` > `input_ids` keyword.
        # The last form is how HF-style callers commonly pass the prompt;
        # without this fallback it would be silently ignored.
        if inputs is not None:
            prefix_ids = inputs
        elif prefix_ids is None:
            prefix_ids = model_kwargs.get("input_ids")
        if prefix_ids is not None:
            prefix_ids = prefix_ids.to(device)

        if suffix_ids is not None:
            suffix_ids = suffix_ids.to(device)

        # Use explicit None checks rather than `or` so that a legitimate
        # token id of 0 is not mistaken for "unset".
        pick = self._rnd1_first_not_none
        eos_token_id = pick(gen_config.eos_token_id, getattr(self.config, "eos_token_id", 151645))
        pad_token_id = pick(gen_config.pad_token_id, getattr(self.config, "pad_token_id", 151643))
        bos_token_id = pick(gen_config.bos_token_id, getattr(self.config, "bos_token_id", None))
        # An attribute that exists but is None counts as unset, so the lookup
        # continues to the model config and finally the built-in default.
        mask_token_id = pick(
            getattr(gen_config, "mask_token_id", None),
            getattr(self.config, "mask_token_id", None),
            151669,
        )

        if infill_length is not None and prefix_ids is not None:
            # Infilling mode: prefix + requested gap + optional suffix
            suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
            seq_len = prefix_ids.shape[1] + infill_length + suffix_len
        elif prefix_ids is not None and gen_config.max_new_tokens is not None:
            # Standard generation with an explicit new-token budget
            seq_len = prefix_ids.shape[1] + gen_config.max_new_tokens
        else:
            # No prefix, or no new-token budget: use the absolute length cap
            seq_len = gen_config.max_length or self.config.max_position_embeddings

        num_diffusion_steps = getattr(
            gen_config,
            "num_diffusion_steps",
            getattr(self.config, "num_diffusion_steps", 256),
        )

        # A temperature of None (unset) would make float() raise; treat as 1.0.
        temperature = getattr(gen_config, "temperature", None)
        temperature = 1.0 if temperature is None else float(temperature)
        top_k = getattr(gen_config, "top_k", None)
        top_p = getattr(gen_config, "top_p", None)

        # An explicit `greedy` attribute wins; otherwise greedy is the
        # negation of `do_sample`, defaulting to greedy decoding.
        if hasattr(gen_config, "do_sample"):
            greedy_default = not bool(gen_config.do_sample)
        else:
            greedy_default = True
        greedy = getattr(gen_config, "greedy", greedy_default)

        with torch.inference_mode():
            sequences = diffusion_sample(
                model=self,
                seq_len=seq_len,
                num_steps=num_diffusion_steps,
                mask_token_id=mask_token_id,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                prefix_ids=prefix_ids,
                suffix_ids=suffix_ids,
                infill_length=infill_length,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                bos_token_id=bos_token_id,
                device=device,
                visualizer=model_kwargs.get("visualizer", None),  # Optional visualizer from kwargs
            )

        # Only consult the config when the caller did not decide explicitly,
        # so that `return_dict_in_generate=False` is honoured.
        if return_dict_in_generate is None:
            return_dict_in_generate = getattr(gen_config, "return_dict_in_generate", False)

        if return_dict_in_generate:
            from transformers.generation.utils import GenerateDecoderOnlyOutput
            return GenerateDecoderOnlyOutput(sequences=sequences)

        return sequences

    def generate_with_visualization(
        self,
        tokenizer,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        suffix_ids: Optional[torch.LongTensor] = None,
        infill_length: Optional[int] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generate with live visualization (for demos).

        Builds a ``TerminalVisualizer`` around ``tokenizer`` and runs
        ``generate()`` with it attached. For production use, prefer
        `generate()`.

        Args:
            tokenizer: Tokenizer handed to the visualizer for decoding
            inputs: Input token IDs to use as prefix
            generation_config: Generation configuration object
            suffix_ids: Optional suffix token IDs
            infill_length: Length of infill region
            **kwargs: Additional arguments forwarded to ``generate()``.
                ``visualizer`` and ``return_dict_in_generate`` are fixed by
                this method and are dropped if supplied.

        Returns:
            Generated token IDs as LongTensor
        """
        from .terminal_visualizer import TerminalVisualizer
        visualizer = TerminalVisualizer(tokenizer, show_visualization=True)

        # These two are set explicitly below; forwarding duplicates from
        # kwargs would raise "got multiple values for keyword argument".
        kwargs.pop("visualizer", None)
        kwargs.pop("return_dict_in_generate", None)

        return self.generate(
            inputs=inputs,
            generation_config=generation_config,
            suffix_ids=suffix_ids,
            infill_length=infill_length,
            visualizer=visualizer,
            return_dict_in_generate=False,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for generation (required by HuggingFace).

        RND1 does not run the standard autoregressive loop, so no cache or
        mask bookkeeping is done here: extra keyword arguments are ignored
        and the token IDs are returned unchanged under ``"input_ids"``.
        """
        return {"input_ids": input_ids}