File size: 1,862 Bytes
77c2d62
a89a7f1
c30bbc5
 
db3d901
 
77c2d62
 
 
 
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
 
b279884
a89a7f1
b279884
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
77c2d62
a89a7f1
 
 
 
 
 
 
a9950fb
77c2d62
a9950fb
 
 
 
 
 
 
 
 
 
 
 
 
db3d901
a89a7f1
 
 
db3d901
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Literal, NotRequired, TypedDict

import streamlit as st

from utils.helpers import session_key

# Closed set of tags naming a chat action that is still pending. The literal
# values suggest "a freshly submitted user prompt" vs. "regenerate the reply
# after a message was edited" — NOTE(review): inferred from the names only;
# this alias is not referenced in this file, so confirm against callers.
PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]


class ChatMessage(TypedDict):
    """One entry of ``ChatState["messages"]``.

    ``role`` and ``content`` are required keys. The two underscore-prefixed
    keys are ``NotRequired`` and may be missing from any given message.
    """

    # NOTE(review): typed as a free-form str; presumably values such as
    # "user"/"assistant" — confirm against the code that appends messages.
    role: str
    content: str
    # Optional, deliberately untyped payload. NOTE(review): its meaning is not
    # visible in this file; the leading underscore suggests internal
    # bookkeeping rather than data sent to a model — confirm.
    _contrast: NotRequired[object]
    # Optional flag. NOTE(review): presumably signals that ``_contrast`` still
    # needs to be produced for this message — confirm against callers.
    _needs_contrast: NotRequired[bool]


class ChatState(TypedDict):
    """Per-context chat state stored in ``st.session_state``.

    A fresh value comes from ``default_chat_state``. ``get_chat_state`` looks
    it up (creating it on first access), and ``reset_chat_context_state``
    overwrites all three keys.
    """

    # Conversation history; empty in a fresh or reset state.
    # NOTE(review): presumably kept in turn order — confirm with callers.
    messages: list[ChatMessage]
    # ``None`` in a fresh state; set to a persona id on reset.
    persona_id: str | None
    # "templated" in a fresh state; other accepted values are not visible here.
    prompt_mode: str


def chat_session_key(model_name: str, dataset_source: str) -> str:
    """Return the ``st.session_state`` key that holds one chat thread.

    Only ``dataset_source`` feeds into the key. Changing the model, or
    flipping between local and remote execution, alters how the *next* reply
    is produced. It does not change which conversation the user is looking
    at. Leaving the model out of the key keeps an existing thread on screen
    instead of hiding it behind a brand-new empty state.

    Args:
        model_name: Accepted but intentionally ignored. It stays in the
            signature so existing call sites keep working and so the
            requesting model remains explicit wherever chat state is fetched.
        dataset_source: The only input that distinguishes one stored
            conversation from another.

    Returns:
        The key built by ``session_key("chat_state", dataset_source)``.
    """

    del model_name  # deliberately unused; see the docstring above
    return session_key("chat_state", dataset_source)


def default_chat_state() -> ChatState:
    """Build a brand-new, empty chat state.

    Each call returns a distinct dict with its own ``messages`` list, so
    separate contexts never share mutable data. The persona starts unset
    (``None``) and the prompt mode starts as ``"templated"``.
    """

    # Calling the TypedDict produces a plain dict with the same key order.
    return ChatState(messages=[], persona_id=None, prompt_mode="templated")


def reset_chat_context_state(
    state: ChatState,
    persona_id: str,
    prompt_mode: str,
    *ui_keys: str,
) -> None:
    """Return ``state`` to an empty conversation and discard related widget keys.

    Args:
        state: Chat state to reset. It is mutated in place.
        persona_id: Persona id to record on the reset state.
        prompt_mode: Prompt mode to record on the reset state.
        *ui_keys: ``st.session_state`` keys to drop. Keys that are absent
            are skipped silently.
    """

    # ``messages`` is rebound to a new list rather than cleared in place, so
    # anything still holding the previous list keeps its old contents.
    state.update(
        {"messages": [], "persona_id": persona_id, "prompt_mode": prompt_mode}
    )
    session = st.session_state
    for widget_key in ui_keys:
        session.pop(widget_key, None)


def get_chat_state(model_name: str, dataset_source: str) -> ChatState:
    """Fetch the mutable chat state for a context, creating it on first use.

    The returned dict is the object held in ``st.session_state``. Mutating
    it in place therefore updates the stored conversation.

    Args:
        model_name: Forwarded to ``chat_session_key``, which ignores it.
        dataset_source: Selects which stored conversation to return.
    """

    state_key = chat_session_key(model_name, dataset_source)
    try:
        return st.session_state[state_key]
    except KeyError:
        # First access for this context: store a fresh default and hand it back.
        fresh_state = default_chat_state()
        st.session_state[state_key] = fresh_state
        return fresh_state