| from __future__ import annotations |
|
|
| import logging |
| from collections.abc import Callable |
| from contextlib import contextmanager, nullcontext |
| from dataclasses import dataclass |
| from typing import TYPE_CHECKING, Any, Literal |
|
|
| from persona_data.prompts import format_messages, format_prompt, normalize_messages |
|
|
| if TYPE_CHECKING: |
| import torch |
| from nnterp import StandardizedTransformer |
| from persona_data.synth_persona import PersonaData |
|
|
| logger = logging.getLogger(__name__) |
| SystemPromptMode = Literal["empty", "templated", "biography", "custom"] |
|
|
|
|
| @dataclass |
| class ChatReply: |
| text: str |
| generated_ids: Any | None = None |
|
|
|
|
| def build_chat_messages( |
| system_prompt: str | None, |
| messages: list[dict[str, str]], |
| ) -> list[dict[str, str]]: |
| """Prepend the active system prompt to a chat history when present.""" |
|
|
| return ( |
| [{"role": "system", "content": system_prompt}] if system_prompt else [] |
| ) + messages |
|
|
|
|
| def resolve_system_prompt( |
| persona: PersonaData | None, |
| mode: SystemPromptMode, |
| ) -> str: |
| """Resolve the active system prompt for chat. |
| |
| Args: |
| persona: Selected persona, if any. |
| mode: Prompt mode selected in the UI. |
| |
| Returns: |
| The rendered system prompt string. |
| """ |
|
|
| if persona is None or mode == "empty": |
| return "" |
| if mode == "custom": |
| return format_prompt(persona, "templated", mode="conversational") |
| if mode in ("templated", "biography"): |
| return format_prompt(persona, mode, mode="conversational") |
| raise ValueError(f"Unsupported system prompt mode: {mode}") |
|
|
|
|
| def _format_plain_messages( |
| messages: list[dict[str, str]], |
| *, |
| add_generation_prompt: bool, |
| ) -> str: |
| """Format messages as plain text when no tokenizer chat template is usable.""" |
|
|
| lines: list[str] = [] |
| for message in messages: |
| role = message["role"] |
| content = message["content"] |
| if role == "system": |
| if content: |
| lines.append(f"System: {content}") |
| elif role == "user": |
| lines.append(f"User: {content}") |
| elif role == "assistant": |
| lines.append(f"Assistant: {content}") |
| else: |
| lines.append(f"{role.title()}: {content}") |
|
|
| if add_generation_prompt and (not lines or not lines[-1].startswith("Assistant:")): |
| lines.append("Assistant:") |
|
|
| return "\n\n".join(lines) |
|
|
|
|
| def format_generation_prompt( |
| messages: list[dict[str, str]], |
| tokenizer: object, |
| *, |
| add_generation_prompt: bool = True, |
| ) -> tuple[str, int]: |
| """Render chat messages and count prompt tokens. |
| |
| ``persona-data`` owns the standard chat-template path. The fallback below is |
| only for tokenizers with broken or missing chat templates. |
| """ |
|
|
| try: |
| prompt, prompt_token_count = format_messages( |
| messages, |
| tokenizer, |
| add_generation_prompt=add_generation_prompt, |
| ) |
| return prompt, prompt_token_count |
| except Exception: |
| logger.debug("persona-data format_messages failed", exc_info=True) |
|
|
| try: |
| prompt = tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=add_generation_prompt, |
| ) |
| except Exception: |
| logger.debug("Chat template failed on raw messages", exc_info=True) |
| normalized = normalize_messages(messages) |
| try: |
| prompt = tokenizer.apply_chat_template( |
| normalized, |
| tokenize=False, |
| add_generation_prompt=add_generation_prompt, |
| ) |
| except Exception: |
| logger.debug("Chat template fallback failed", exc_info=True) |
| prompt = _format_plain_messages( |
| normalized, |
| add_generation_prompt=add_generation_prompt, |
| ) |
|
|
| prompt_token_count = tokenizer(prompt, return_tensors="pt").input_ids.shape[1] |
| return prompt, prompt_token_count |
|
|
|
|
| def resolve_saved_tensor(value: object) -> torch.Tensor: |
| """Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor.""" |
| import torch |
|
|
| resolved = value.value if getattr(value, "value", None) is not None else value |
| if not isinstance(resolved, torch.Tensor): |
| raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}") |
| return resolved.detach().cpu() |
|
|
|
|
| def decode_token(tokenizer: object, token_id: int) -> str: |
| """Decode a single token id, falling back when ``clean_up_tokenization_spaces`` is unsupported.""" |
| try: |
| return tokenizer.decode( |
| [token_id], |
| skip_special_tokens=False, |
| clean_up_tokenization_spaces=False, |
| ) |
| except TypeError: |
| return tokenizer.decode([token_id], skip_special_tokens=False) |
|
|
|
|
| @contextmanager |
| def _seeded_rng(seed: int | None): |
| """Context manager that forks the RNG state and sets a deterministic seed.""" |
| if seed is None: |
| yield |
| return |
|
|
| import torch |
|
|
| cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count())) |
| mps_ctx = ( |
| torch.random.fork_rng(devices=range(1), device_type="mps") |
| if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() |
| else nullcontext() |
| ) |
|
|
| with cuda_ctx, mps_ctx: |
| torch.manual_seed(seed) |
| yield |
|
|
|
|
| def generate_chat_reply( |
| model: StandardizedTransformer, |
| messages: list[dict[str, str]], |
| remote: bool, |
| max_new_tokens: int = 256, |
| do_sample: bool = False, |
| temperature: float = 1.0, |
| top_p: float = 1.0, |
| top_k: int = 50, |
| repetition_penalty: float = 1.0, |
| seed: int | None = None, |
| on_status: Callable[[str, str, str], None] | None = None, |
| ndif_api_key: str | None = None, |
| ) -> ChatReply: |
| """Generate one assistant reply from a full chat history. |
| |
| The helper uses ``model.generate`` so it works with both local and NDIF-backed |
| nnsight models. The full conversation is re-rendered each turn. |
| |
| Args: |
| model: Loaded standardized nnterp model. |
| messages: Full chat history, including any system prompt as the first message. |
| remote: Whether to execute the generation on NDIF. |
| max_new_tokens: Maximum number of assistant tokens to generate. |
| do_sample: Whether to sample from the model distribution. |
| temperature: Sampling temperature, used only when sampling is enabled. |
| top_p: Nucleus sampling threshold, used only when sampling is enabled. |
| top_k: Top-k cutoff, used only when sampling is enabled. |
| repetition_penalty: Repetition penalty applied during decoding. |
| seed: Optional local RNG seed for sampled generation. |
| |
| Returns: |
| ChatReply with generated text and token ids. |
| """ |
|
|
| import torch |
|
|
| tokenizer = model.tokenizer |
| prompt, prompt_token_count = format_generation_prompt(messages, tokenizer) |
|
|
| generation_kwargs: dict[str, object] = { |
| "max_new_tokens": max_new_tokens, |
| "use_cache": True, |
| } |
| if not remote: |
| |
| generation_kwargs["return_dict_in_generate"] = True |
| if do_sample: |
| generation_kwargs["do_sample"] = True |
| generation_kwargs["temperature"] = temperature |
| generation_kwargs["top_p"] = top_p |
| generation_kwargs["top_k"] = top_k |
| if repetition_penalty != 1.0: |
| generation_kwargs["repetition_penalty"] = repetition_penalty |
| |
| |
| if remote: |
| from utils.runtime import remote_backend |
|
|
| backend = remote_backend(model, ndif_api_key, on_status=on_status) |
| else: |
| backend = None |
|
|
| with ( |
| _seeded_rng(seed if do_sample and not remote else None), |
| model.generate( |
| prompt, |
| remote=remote, |
| backend=backend, |
| **generation_kwargs, |
| ) as tracer, |
| ): |
| generated = tracer.result.save() |
|
|
| if getattr(generated, "value", None) is not None: |
| generated = generated.value |
|
|
| sequences = generated.sequences if hasattr(generated, "sequences") else generated |
| if not isinstance(sequences, torch.Tensor): |
| raise TypeError("Generated sequences must be a tensor") |
|
|
| generated_ids = sequences[0, prompt_token_count:] |
| text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() |
| return ChatReply( |
| text=text, |
| generated_ids=generated_ids.detach().cpu(), |
| ) |
|
|