persona-ui / utils /chat.py
Jac-Zac
add session-scoped NDIF execution and improve cold-load UX
ae347c6
from __future__ import annotations
import logging
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
from persona_data.prompts import format_messages, format_prompt, normalize_messages
if TYPE_CHECKING:
import torch
from nnterp import StandardizedTransformer
from persona_data.synth_persona import PersonaData
# Module-level logger named after this module; used for debug traces of the
# prompt-rendering fallbacks below.
logger = logging.getLogger(__name__)
# Closed set of system-prompt modes accepted by ``resolve_system_prompt``.
SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
@dataclass
class ChatReply:
    """Result of one assistant generation, as returned by ``generate_chat_reply``."""

    # Decoded assistant text.
    text: str
    # Token ids of the generated continuation; ``generate_chat_reply`` stores a
    # detached CPU tensor here. Defaults to ``None`` when no ids are attached.
    generated_ids: Any | None = None
def build_chat_messages(
    system_prompt: str | None,
    messages: list[dict[str, str]],
) -> list[dict[str, str]]:
    """Return a new chat history with the system prompt placed first.

    A falsy ``system_prompt`` (``None`` or ``""``) adds nothing; the result is
    then a fresh copy of ``messages``. The input list is never mutated.
    """
    leading: list[dict[str, str]] = []
    if system_prompt:
        leading.append({"role": "system", "content": system_prompt})
    # List concatenation always builds a new list, even when ``leading`` is empty.
    return leading + messages
def resolve_system_prompt(
    persona: PersonaData | None,
    mode: SystemPromptMode,
) -> str:
    """Render the system prompt that corresponds to the selected mode.

    Args:
        persona: Selected persona, or ``None`` when nothing is selected.
        mode: Prompt mode selected in the UI.

    Returns:
        ``""`` when there is no persona or the mode is ``"empty"``; otherwise
        the string produced by ``format_prompt`` in conversational mode.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    nothing_to_render = persona is None or mode == "empty"
    if nothing_to_render:
        return ""

    # "custom" renders the templated prompt.
    # NOTE(review): presumably a starting point the user then edits — confirm.
    template = "templated" if mode == "custom" else mode
    if template not in ("templated", "biography"):
        raise ValueError(f"Unsupported system prompt mode: {mode}")
    return format_prompt(persona, template, mode="conversational")
def _format_plain_messages(
messages: list[dict[str, str]],
*,
add_generation_prompt: bool,
) -> str:
"""Format messages as plain text when no tokenizer chat template is usable."""
lines: list[str] = []
for message in messages:
role = message["role"]
content = message["content"]
if role == "system":
if content:
lines.append(f"System: {content}")
elif role == "user":
lines.append(f"User: {content}")
elif role == "assistant":
lines.append(f"Assistant: {content}")
else:
lines.append(f"{role.title()}: {content}")
if add_generation_prompt and (not lines or not lines[-1].startswith("Assistant:")):
lines.append("Assistant:")
return "\n\n".join(lines)
def format_generation_prompt(
    messages: list[dict[str, str]],
    tokenizer: object,
    *,
    add_generation_prompt: bool = True,
) -> tuple[str, int]:
    """Turn a chat history into a prompt string plus its token count.

    The normal path is ``persona-data``'s ``format_messages``. The remaining
    steps are a best-effort fallback chain for tokenizers whose chat template
    is broken or missing, tried in this order:

    1. the tokenizer's chat template on the raw messages,
    2. the tokenizer's chat template on ``normalize_messages(messages)``,
    3. a plain ``"Role: content"`` transcript of the normalized messages.

    Returns:
        ``(prompt, prompt_token_count)``. On the fallback path the count is
        obtained by tokenizing the rendered prompt.
    """
    try:
        rendered, token_count = format_messages(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
        )
    except Exception:
        logger.debug("persona-data format_messages failed", exc_info=True)
    else:
        return rendered, token_count

    template_options = {
        "tokenize": False,
        "add_generation_prompt": add_generation_prompt,
    }
    try:
        rendered = tokenizer.apply_chat_template(messages, **template_options)
    except Exception:
        logger.debug("Chat template failed on raw messages", exc_info=True)
        cleaned = normalize_messages(messages)
        try:
            rendered = tokenizer.apply_chat_template(cleaned, **template_options)
        except Exception:
            logger.debug("Chat template fallback failed", exc_info=True)
            rendered = _format_plain_messages(
                cleaned,
                add_generation_prompt=add_generation_prompt,
            )

    encoded = tokenizer(rendered, return_tensors="pt")
    return rendered, encoded.input_ids.shape[1]
def resolve_saved_tensor(value: object) -> torch.Tensor:
    """Unwrap an nnsight ``.save()`` proxy (or a raw tensor) into a detached CPU tensor.

    Raises:
        TypeError: If the unwrapped object is not a ``torch.Tensor``.
    """
    import torch

    # Saved proxies expose the payload on ``.value``; raw tensors are used as-is.
    payload = getattr(value, "value", None)
    candidate = value if payload is None else payload
    if isinstance(candidate, torch.Tensor):
        return candidate.detach().cpu()
    raise TypeError(f"Trace result did not resolve to a tensor: {type(candidate)!r}")
def decode_token(tokenizer: object, token_id: int) -> str:
    """Decode one token id, keeping special tokens.

    ``clean_up_tokenization_spaces=False`` is requested first; tokenizers whose
    ``decode`` raises ``TypeError`` for that call are retried without it.
    """
    try:
        decoded = tokenizer.decode(
            [token_id],
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
    except TypeError:
        decoded = tokenizer.decode([token_id], skip_special_tokens=False)
    return decoded
@contextmanager
def _seeded_rng(seed: int | None):
"""Context manager that forks the RNG state and sets a deterministic seed."""
if seed is None:
yield
return
import torch
cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
mps_ctx = (
torch.random.fork_rng(devices=range(1), device_type="mps")
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
else nullcontext()
)
with cuda_ctx, mps_ctx:
torch.manual_seed(seed)
yield
def generate_chat_reply(
    model: StandardizedTransformer,
    messages: list[dict[str, str]],
    remote: bool,
    max_new_tokens: int = 256,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
    seed: int | None = None,
    on_status: Callable[[str, str, str], None] | None = None,
    ndif_api_key: str | None = None,
) -> ChatReply:
    """Generate one assistant reply from a full chat history.

    The helper uses ``model.generate`` so it works with both local and
    NDIF-backed nnsight models. The full conversation is re-rendered each turn.

    Args:
        model: Loaded standardized nnterp model.
        messages: Full chat history, including any system prompt as the first
            message.
        remote: Whether to execute the generation on NDIF.
        max_new_tokens: Maximum number of assistant tokens to generate.
        do_sample: Whether to sample from the model distribution. Always
            forwarded explicitly, so ``False`` requests non-sampling decoding
            even when the checkpoint's own generation config enables sampling.
        temperature: Sampling temperature, used only when sampling is enabled.
        top_p: Nucleus sampling threshold, used only when sampling is enabled.
        top_k: Top-k cutoff, used only when sampling is enabled.
        repetition_penalty: Repetition penalty applied during decoding; only
            forwarded when it differs from ``1.0``.
        seed: Optional local RNG seed. Applied only for local (non-remote)
            sampled generation; ignored otherwise.
        on_status: Optional callback forwarded to ``remote_backend`` for remote
            runs; unused for local runs.
        ndif_api_key: Optional API key forwarded to ``remote_backend`` for
            remote runs; unused for local runs.

    Returns:
        ChatReply with the decoded, stripped reply text and the generated token
        ids (prompt tokens removed) as a detached CPU tensor.

    Raises:
        TypeError: If the generation result does not resolve to a tensor of
            token ids.
    """
    import torch

    tokenizer = model.tokenizer
    prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)

    generation_kwargs: dict[str, object] = {
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        # Pass this explicitly in both directions: if it were omitted, a
        # checkpoint whose generation config enables sampling would sample
        # even though the caller asked for do_sample=False.
        "do_sample": do_sample,
    }
    if not remote:
        # Not needed for remote runs, where it also slows the download
        # drastically.
        generation_kwargs["return_dict_in_generate"] = True
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
        generation_kwargs["top_k"] = top_k
    if repetition_penalty != 1.0:
        generation_kwargs["repetition_penalty"] = repetition_penalty

    # `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
    # forwarded to the underlying model's generate.
    if remote:
        from utils.runtime import remote_backend

        backend = remote_backend(model, ndif_api_key, on_status=on_status)
    else:
        backend = None

    with (
        # Seeding only applies to local sampling.
        _seeded_rng(seed if do_sample and not remote else None),
        model.generate(
            prompt,
            remote=remote,
            backend=backend,
            **generation_kwargs,
        ) as tracer,
    ):
        generated = tracer.result.save()

    # Depending on the nnsight version the saved object is either a proxy that
    # exposes ``.value`` or the result itself.
    if getattr(generated, "value", None) is not None:
        generated = generated.value

    # Local runs return a generate-output object with ``.sequences``; remote
    # runs (no ``return_dict_in_generate``) return the id tensor directly.
    sequences = generated.sequences if hasattr(generated, "sequences") else generated
    if not isinstance(sequences, torch.Tensor):
        raise TypeError("Generated sequences must be a tensor")

    # Drop the prompt tokens so only the newly generated continuation remains.
    generated_ids = sequences[0, prompt_token_count:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return ChatReply(
        text=text,
        generated_ids=generated_ids.detach().cpu(),
    )