File size: 8,598 Bytes
99c28ab
 
77c2d62
b279884
a9950fb
a89a7f1
c607869
a89a7f1
77c2d62
a89a7f1
99c28ab
c607869
99c28ab
c607869
99c28ab
77c2d62
a89a7f1
 
 
 
 
 
c607869
a89a7f1
 
77c2d62
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
220e208
a89a7f1
 
d8ae160
220e208
88f2164
eb41f91
a89a7f1
 
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c30bbc5
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c30bbc5
 
c607869
 
c30bbc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
 
 
 
 
 
c607869
 
a89a7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b279884
ae347c6
a89a7f1
 
 
 
12cdb17
a89a7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
12cdb17
a89a7f1
 
c607869
 
a89a7f1
77c2d62
a89a7f1
 
 
 
 
e2cecb1
 
 
a89a7f1
 
 
 
 
 
 
e2cecb1
 
ae347c6
 
 
 
 
 
b279884
db3d901
 
b279884
 
 
 
 
 
db3d901
 
a89a7f1
220e208
a89a7f1
 
e2cecb1
a89a7f1
 
 
 
 
 
 
a9950fb
a89a7f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
from __future__ import annotations

import logging
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

from persona_data.prompts import format_messages, format_prompt, normalize_messages

if TYPE_CHECKING:
    import torch
    from nnterp import StandardizedTransformer
    from persona_data.synth_persona import PersonaData

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Closed set of system-prompt mode names accepted by ``resolve_system_prompt``.
SystemPromptMode = Literal["empty", "templated", "biography", "custom"]


@dataclass
class ChatReply:
    """Result of one generated assistant turn: decoded text plus its token ids."""

    # Decoded assistant text.
    text: str
    # Token ids behind ``text``; defaults to ``None`` when not supplied.
    generated_ids: Any | None = None


def build_chat_messages(
    system_prompt: str | None,
    messages: list[dict[str, str]],
) -> list[dict[str, str]]:
    """Return the chat history with a leading system message when one is set.

    A falsy ``system_prompt`` (``None`` or ``""``) contributes nothing. The
    result is always a freshly built list; ``messages`` is left untouched.
    """

    leading: list[dict[str, str]] = []
    if system_prompt:
        leading.append({"role": "system", "content": system_prompt})
    return leading + messages


def resolve_system_prompt(
    persona: PersonaData | None,
    mode: SystemPromptMode,
) -> str:
    """Render the system prompt that should be active for a chat session.

    Args:
        persona: Selected persona, or ``None`` when no persona is chosen.
        mode: Prompt mode selected in the UI.

    Returns:
        ``""`` when there is no persona or the mode is ``"empty"``; otherwise
        the string produced by ``format_prompt`` in conversational mode.

    Raises:
        ValueError: If a persona is given and ``mode`` is not a known mode.
    """

    # Without a persona, or in "empty" mode, there is nothing to render.
    if persona is None or mode == "empty":
        return ""

    if mode == "custom":
        # "custom" renders the templated prompt (presumably as the starting
        # text the user then edits — confirm against the caller).
        template_name = "templated"
    elif mode in ("templated", "biography"):
        template_name = mode
    else:
        raise ValueError(f"Unsupported system prompt mode: {mode}")

    return format_prompt(persona, template_name, mode="conversational")


def _format_plain_messages(
    messages: list[dict[str, str]],
    *,
    add_generation_prompt: bool,
) -> str:
    """Render messages as ``Role: content`` paragraphs without a chat template.

    System messages with empty content are dropped. Every other message
    becomes one paragraph labelled with its title-cased role. When
    ``add_generation_prompt`` is set, a bare ``Assistant:`` cue is appended
    unless the last rendered paragraph already starts with ``Assistant:``.
    """

    paragraphs: list[str] = []
    for entry in messages:
        speaker = entry["role"]
        body = entry["content"]
        # An empty system message carries no information; leave it out.
        if speaker == "system" and not body:
            continue
        # "system"/"user"/"assistant" title-case to the same labels used for
        # any other role, so one format string covers every case.
        paragraphs.append(f"{speaker.title()}: {body}")

    if add_generation_prompt and not (
        paragraphs and paragraphs[-1].startswith("Assistant:")
    ):
        paragraphs.append("Assistant:")

    return "\n\n".join(paragraphs)


def format_generation_prompt(
    messages: list[dict[str, str]],
    tokenizer: object,
    *,
    add_generation_prompt: bool = True,
) -> tuple[str, int]:
    """Render chat messages to a prompt string and count its tokens.

    The ``persona-data`` ``format_messages`` helper is tried first. If it
    raises, the function falls back, in order, to:

    1. ``tokenizer.apply_chat_template`` on the raw messages,
    2. ``tokenizer.apply_chat_template`` on ``normalize_messages(messages)``,
    3. the plain-text ``_format_plain_messages`` rendering of the normalized
       messages.

    Each failure is logged at debug level. On any fallback path the token
    count is taken by tokenizing the rendered prompt.

    Returns:
        ``(prompt, prompt_token_count)``.
    """

    try:
        rendered, token_total = format_messages(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
        )
    except Exception:
        logger.debug("persona-data format_messages failed", exc_info=True)
    else:
        return rendered, token_total

    # Both chat-template attempts share the same keyword arguments.
    template_kwargs = {
        "tokenize": False,
        "add_generation_prompt": add_generation_prompt,
    }
    try:
        rendered = tokenizer.apply_chat_template(messages, **template_kwargs)
    except Exception:
        logger.debug("Chat template failed on raw messages", exc_info=True)
        cleaned = normalize_messages(messages)
        try:
            rendered = tokenizer.apply_chat_template(cleaned, **template_kwargs)
        except Exception:
            logger.debug("Chat template fallback failed", exc_info=True)
            rendered = _format_plain_messages(
                cleaned,
                add_generation_prompt=add_generation_prompt,
            )

    encoded = tokenizer(rendered, return_tensors="pt")
    return rendered, encoded.input_ids.shape[1]


def resolve_saved_tensor(value: object) -> torch.Tensor:
    """Turn an nnsight ``.save()`` proxy, or a raw tensor, into a CPU tensor.

    If ``value`` has a non-``None`` ``value`` attribute, that attribute is
    used; otherwise ``value`` itself is. The result is detached and moved to
    the CPU.

    Raises:
        TypeError: If the resolved object is not a ``torch.Tensor``.
    """
    import torch

    inner = getattr(value, "value", None)
    candidate = value if inner is None else inner
    if isinstance(candidate, torch.Tensor):
        return candidate.detach().cpu()
    raise TypeError(f"Trace result did not resolve to a tensor: {type(candidate)!r}")


def decode_token(tokenizer: object, token_id: int) -> str:
    """Decode one token id, keeping special tokens.

    ``clean_up_tokenization_spaces=False`` is requested first; if the
    tokenizer's ``decode`` raises ``TypeError`` the call is retried without
    that argument.
    """
    shared_kwargs = {"skip_special_tokens": False}
    try:
        return tokenizer.decode(
            [token_id],
            clean_up_tokenization_spaces=False,
            **shared_kwargs,
        )
    except TypeError:
        return tokenizer.decode([token_id], **shared_kwargs)


@contextmanager
def _seeded_rng(seed: int | None):
    """Run the ``with`` body under a forked, deterministically seeded torch RNG.

    With ``seed=None`` this is a no-op context. Otherwise the RNG state is
    forked via ``torch.random.fork_rng`` for every visible CUDA device (and
    for MPS when that backend is available), ``torch.manual_seed(seed)`` is
    applied, and the body runs inside the forked contexts.
    """
    if seed is not None:
        import torch

        gpu_fork = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            mps_fork = torch.random.fork_rng(devices=range(1), device_type="mps")
        else:
            mps_fork = nullcontext()

        with gpu_fork, mps_fork:
            torch.manual_seed(seed)
            yield
    else:
        yield


def generate_chat_reply(
    model: StandardizedTransformer,
    messages: list[dict[str, str]],
    remote: bool,
    max_new_tokens: int = 256,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
    seed: int | None = None,
    on_status: Callable[[str, str, str], None] | None = None,
    ndif_api_key: str | None = None,
) -> ChatReply:
    """Generate one assistant reply from a full chat history.

    The helper uses ``model.generate`` so it works with both local and NDIF-backed
    nnsight models. The full conversation is re-rendered each turn.

    Args:
        model: Loaded standardized nnterp model.
        messages: Full chat history, including any system prompt as the first message.
        remote: Whether to execute the generation on NDIF.
        max_new_tokens: Maximum number of assistant tokens to generate.
        do_sample: Whether to sample from the model distribution. Always
            forwarded explicitly, so ``False`` is not overridden by a
            sampling default in the model's own generation config.
        temperature: Sampling temperature, used only when sampling is enabled.
        top_p: Nucleus sampling threshold, used only when sampling is enabled.
        top_k: Top-k cutoff, used only when sampling is enabled.
        repetition_penalty: Repetition penalty applied during decoding; only
            forwarded when it differs from ``1.0``.
        seed: Optional local RNG seed for sampled generation. Ignored for
            remote runs and when sampling is disabled.
        on_status: Optional callback handed to ``remote_backend`` for remote
            runs; unused locally.
        ndif_api_key: Optional API key handed to ``remote_backend`` for remote
            runs; unused locally.

    Returns:
        ChatReply with generated text and token ids.

    Raises:
        TypeError: If the generation result does not resolve to a tensor of
            token ids.
    """

    import torch

    tokenizer = model.tokenizer
    prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)

    generation_kwargs: dict[str, object] = {
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        # State the decoding mode explicitly. Leaving it out when it is False
        # would let a checkpoint whose generation config defaults to sampling
        # sample anyway — and unseeded, since the seed is only applied when
        # do_sample is true.
        "do_sample": do_sample,
    }
    if not remote:
        # No need for this in remote which also slows down download drastically
        generation_kwargs["return_dict_in_generate"] = True
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
        generation_kwargs["top_k"] = top_k
    if repetition_penalty != 1.0:
        generation_kwargs["repetition_penalty"] = repetition_penalty

    if remote:
        from utils.runtime import remote_backend

        backend = remote_backend(model, ndif_api_key, on_status=on_status)
    else:
        backend = None

    # `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
    # forwarded to the underlying model's generate
    with (
        _seeded_rng(seed if do_sample and not remote else None),
        model.generate(
            prompt,
            remote=remote,
            backend=backend,
            **generation_kwargs,
        ) as tracer,
    ):
        generated = tracer.result.save()

    # Unwrap a saved proxy when the result exposes its payload via `.value`.
    if getattr(generated, "value", None) is not None:
        generated = generated.value

    # With return_dict_in_generate the ids live on `.sequences`; otherwise the
    # result is used as the id tensor directly.
    sequences = generated.sequences if hasattr(generated, "sequences") else generated
    if not isinstance(sequences, torch.Tensor):
        raise TypeError("Generated sequences must be a tensor")

    # Drop the prompt tokens so only the newly generated ids remain.
    generated_ids = sequences[0, prompt_token_count:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return ChatReply(
        text=text,
        generated_ids=generated_ids.detach().cpu(),
    )