Spaces:
Running on L4
Running on L4
System Prompt Explorer: dual model, multi-turn chat, configurable presets
Browse files- Dual model architecture: base (Llama-3.2-3B) + chat (Llama-3.2-3B-Instruct)
- Multi-turn chat with gr.State for clean history (Chatbot display-only)
- 11 configurable presets via admin panel or SYSTEM_PROMPT_PRESETS env var
- All config values overridable via env vars (Secrets vs Variables documented)
- No auto-reset on prompt changes; green terminal collapsed by default
- Educational note: no hidden system prompt, helpfulness from RLHF
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- app.py +316 -117
- config.json +15 -1
- models.py +204 -89
app.py
CHANGED
|
@@ -17,6 +17,8 @@ from datetime import datetime, timezone, timedelta
|
|
| 17 |
|
| 18 |
import gradio as gr
|
| 19 |
|
|
|
|
|
|
|
| 20 |
from models import AVAILABLE_MODELS, manager, demo_tokenizer
|
| 21 |
|
| 22 |
# ---------------------------------------------------------------------------
|
|
@@ -663,16 +665,12 @@ def tokenize_text(text):
|
|
| 663 |
# Tab 3: System Prompt Explorer
|
| 664 |
# ---------------------------------------------------------------------------
|
| 665 |
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
"
|
| 671 |
-
|
| 672 |
-
"Hostile / Rude": "You are rude and dismissive. You answer questions but with obvious annoyance and sarcasm.",
|
| 673 |
-
"Haiku Only": "You must respond only in haiku (5-7-5 syllable format). Never break this rule.",
|
| 674 |
-
"Spanish Tutor": "You are a Spanish language tutor. Respond in Spanish, then provide the English translation in parentheses.",
|
| 675 |
-
}
|
| 676 |
|
| 677 |
|
| 678 |
def _esc_terminal(text: str) -> str:
|
|
@@ -680,82 +678,139 @@ def _esc_terminal(text: str) -> str:
|
|
| 680 |
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
| 681 |
|
| 682 |
|
| 683 |
-
def
|
| 684 |
-
"""
|
| 685 |
-
if not manager.is_ready():
|
| 686 |
-
return (
|
| 687 |
-
f"<div class='green-terminal'><span class='sp-special'>Error: {manager.status_message()}</span></div>",
|
| 688 |
-
"<div class='response-card' style='color:red;'>No model loaded. Load an instruct model from the Admin tab.</div>",
|
| 689 |
-
)
|
| 690 |
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
)
|
| 696 |
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 702 |
|
| 703 |
result = manager.generate_chat(
|
| 704 |
-
|
| 705 |
-
user_message=user_message,
|
| 706 |
max_new_tokens=int(max_tokens),
|
| 707 |
temperature=temperature,
|
| 708 |
seed=int(seed),
|
| 709 |
)
|
| 710 |
|
| 711 |
if "error" in result:
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 716 |
|
| 717 |
-
# Render the formatted prompt in the green terminal
|
| 718 |
-
formatted = result["formatted_prompt"]
|
| 719 |
-
# Color-code the special tokens and roles
|
| 720 |
-
terminal_html = _esc_terminal(formatted)
|
| 721 |
-
# Highlight Qwen-style chat template tokens
|
| 722 |
-
for tag in ["<|im_start|>", "<|im_end|>"]:
|
| 723 |
-
terminal_html = terminal_html.replace(
|
| 724 |
-
_esc_terminal(tag),
|
| 725 |
-
f"<span class='sp-special'>{_esc_terminal(tag)}</span>",
|
| 726 |
-
)
|
| 727 |
-
terminal_html = terminal_html.replace(
|
| 728 |
-
"system\n", "<span class='sp-label'>system</span>\n"
|
| 729 |
-
)
|
| 730 |
-
terminal_html = terminal_html.replace(
|
| 731 |
-
"user\n", "<span class='sp-label'>user</span>\n"
|
| 732 |
-
)
|
| 733 |
-
terminal_html = terminal_html.replace(
|
| 734 |
-
"assistant\n", "<span class='sp-label'>assistant</span>\n"
|
| 735 |
-
)
|
| 736 |
-
# Highlight the system prompt content
|
| 737 |
-
if system_prompt.strip():
|
| 738 |
-
terminal_html = terminal_html.replace(
|
| 739 |
-
_esc_terminal(system_prompt),
|
| 740 |
-
f"<span class='sp-system'>{_esc_terminal(system_prompt)}</span>",
|
| 741 |
-
)
|
| 742 |
-
# Highlight user message
|
| 743 |
-
terminal_html = terminal_html.replace(
|
| 744 |
-
_esc_terminal(user_message),
|
| 745 |
-
f"<span class='sp-user'>{_esc_terminal(user_message)}</span>",
|
| 746 |
-
)
|
| 747 |
|
| 748 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
|
| 750 |
-
# Render the response in a clean card
|
| 751 |
-
response_html = f"<div class='response-card'>{_esc(result['response'])}</div>"
|
| 752 |
|
| 753 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
|
| 755 |
|
| 756 |
def on_preset_change(preset_name):
|
| 757 |
-
"""Update system prompt textbox when a preset is selected."""
|
| 758 |
-
return
|
| 759 |
|
| 760 |
|
| 761 |
# ---------------------------------------------------------------------------
|
|
@@ -770,13 +825,21 @@ def admin_login(password):
|
|
| 770 |
|
| 771 |
|
| 772 |
def admin_load_model(model_name):
|
| 773 |
-
"""Load a new model from admin panel."""
|
| 774 |
status = manager.load_model(model_name)
|
| 775 |
cfg = manager.get_config()
|
| 776 |
header_status = f"**{manager.status_message()}**"
|
| 777 |
return status, json.dumps(cfg, indent=2), header_status
|
| 778 |
|
| 779 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
def admin_save_defaults(prompt, tokenizer_text, temperature, top_k, steps, seed):
|
| 781 |
"""Save default settings and return updated values for all outputs."""
|
| 782 |
manager.update_config(
|
|
@@ -800,6 +863,41 @@ def admin_save_defaults(prompt, tokenizer_text, temperature, top_k, steps, seed)
|
|
| 800 |
)
|
| 801 |
|
| 802 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 803 |
# ---------------------------------------------------------------------------
|
| 804 |
# Build the Gradio app
|
| 805 |
# ---------------------------------------------------------------------------
|
|
@@ -978,44 +1076,36 @@ def create_app():
|
|
| 978 |
gr.Markdown("### System Prompt Explorer")
|
| 979 |
gr.Markdown(
|
| 980 |
"See how **system prompts** change an LLM's behavior. "
|
|
|
|
| 981 |
"The green terminal shows exactly what the model receives — "
|
| 982 |
-
"special
|
| 983 |
-
"Try different presets or write your own."
|
| 984 |
)
|
| 985 |
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
"Load one from the Admin tab.*",
|
| 990 |
-
elem_classes=["param-help"],
|
| 991 |
-
)
|
| 992 |
|
| 993 |
sp_preset = gr.Dropdown(
|
| 994 |
label="Preset",
|
| 995 |
-
choices=
|
| 996 |
-
value=
|
| 997 |
interactive=True,
|
| 998 |
)
|
| 999 |
sp_system = gr.Textbox(
|
| 1000 |
label="System Prompt",
|
| 1001 |
-
value=
|
| 1002 |
lines=3,
|
| 1003 |
placeholder="Enter a system prompt, or select a preset above...",
|
| 1004 |
)
|
| 1005 |
-
sp_user = gr.Textbox(
|
| 1006 |
-
label="User Message",
|
| 1007 |
-
value="What is Huston-Tillotson University?",
|
| 1008 |
-
lines=2,
|
| 1009 |
-
)
|
| 1010 |
|
| 1011 |
with gr.Accordion("Settings", open=False):
|
| 1012 |
sp_max_tokens = gr.Slider(
|
| 1013 |
label="Max tokens",
|
| 1014 |
-
minimum=32, maximum=
|
| 1015 |
-
value=
|
| 1016 |
)
|
| 1017 |
gr.Markdown(
|
| 1018 |
-
"Maximum number of tokens
|
| 1019 |
elem_classes=["param-help"],
|
| 1020 |
)
|
| 1021 |
sp_temperature = gr.Slider(
|
|
@@ -1029,33 +1119,69 @@ def create_app():
|
|
| 1029 |
precision=0,
|
| 1030 |
)
|
| 1031 |
|
| 1032 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1033 |
|
| 1034 |
-
gr.Markdown("####
|
| 1035 |
gr.Markdown(
|
| 1036 |
-
"
|
| 1037 |
-
"
|
| 1038 |
-
"
|
|
|
|
|
|
|
| 1039 |
elem_classes=["param-help"],
|
| 1040 |
)
|
| 1041 |
-
|
| 1042 |
-
|
| 1043 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1044 |
|
| 1045 |
-
|
| 1046 |
-
sp_response = gr.HTML(value="")
|
| 1047 |
|
| 1048 |
-
#
|
| 1049 |
sp_preset.change(
|
| 1050 |
fn=on_preset_change,
|
| 1051 |
inputs=[sp_preset],
|
| 1052 |
outputs=[sp_system],
|
| 1053 |
)
|
| 1054 |
|
| 1055 |
-
|
| 1056 |
-
|
| 1057 |
-
|
| 1058 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1059 |
)
|
| 1060 |
|
| 1061 |
# ==================================================================
|
|
@@ -1076,16 +1202,26 @@ def create_app():
|
|
| 1076 |
|
| 1077 |
# Admin controls (hidden until login)
|
| 1078 |
with gr.Group(visible=False) as admin_controls:
|
| 1079 |
-
gr.Markdown("#### Model")
|
| 1080 |
with gr.Row():
|
| 1081 |
admin_model_dropdown = gr.Dropdown(
|
| 1082 |
choices=list(AVAILABLE_MODELS.keys()),
|
| 1083 |
-
value=manager.current_model_name or cfg.get("model", "
|
| 1084 |
label="Select model",
|
| 1085 |
)
|
| 1086 |
-
admin_load_btn = gr.Button("Load
|
| 1087 |
admin_model_status = gr.Markdown("")
|
| 1088 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1089 |
gr.Markdown("---")
|
| 1090 |
gr.Markdown("#### Default Settings")
|
| 1091 |
admin_prompt = gr.Textbox(
|
|
@@ -1120,13 +1256,48 @@ def create_app():
|
|
| 1120 |
admin_save_msg = gr.Markdown("")
|
| 1121 |
|
| 1122 |
gr.Markdown("---")
|
| 1123 |
-
gr.Markdown("####
|
| 1124 |
gr.Markdown(
|
| 1125 |
-
"
|
|
|
|
| 1126 |
elem_classes=["param-help"],
|
| 1127 |
)
|
| 1128 |
-
|
| 1129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1130 |
|
| 1131 |
gr.Markdown("---")
|
| 1132 |
gr.Markdown("#### Current Config")
|
|
@@ -1136,6 +1307,15 @@ def create_app():
|
|
| 1136 |
interactive=False,
|
| 1137 |
)
|
| 1138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1139 |
# Login wiring
|
| 1140 |
admin_login_btn.click(
|
| 1141 |
fn=admin_login,
|
|
@@ -1143,13 +1323,20 @@ def create_app():
|
|
| 1143 |
outputs=[admin_controls, admin_login_group, admin_login_msg],
|
| 1144 |
)
|
| 1145 |
|
| 1146 |
-
# Model loading
|
| 1147 |
admin_load_btn.click(
|
| 1148 |
fn=admin_load_model,
|
| 1149 |
inputs=[admin_model_dropdown],
|
| 1150 |
outputs=[admin_model_status, admin_config_display, status_display],
|
| 1151 |
)
|
| 1152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1153 |
# Save defaults — updates config display + Probability Explorer + Tokenizer controls
|
| 1154 |
admin_save_btn.click(
|
| 1155 |
fn=admin_save_defaults,
|
|
@@ -1166,6 +1353,13 @@ def create_app():
|
|
| 1166 |
],
|
| 1167 |
)
|
| 1168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1169 |
# Export slides — uses current Probability Explorer settings
|
| 1170 |
admin_export_btn.click(
|
| 1171 |
fn=generate_slideshow,
|
|
@@ -1327,12 +1521,17 @@ def create_app():
|
|
| 1327 |
# ---------------------------------------------------------------------------
|
| 1328 |
|
| 1329 |
if __name__ == "__main__":
|
| 1330 |
-
# Load default model on startup
|
| 1331 |
cfg = manager.get_config()
|
| 1332 |
-
|
| 1333 |
-
|
| 1334 |
-
|
| 1335 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1336 |
|
| 1337 |
app = create_app()
|
| 1338 |
app.launch(
|
|
|
|
| 17 |
|
| 18 |
import gradio as gr
|
| 19 |
|
| 20 |
+
import re
|
| 21 |
+
|
| 22 |
from models import AVAILABLE_MODELS, manager, demo_tokenizer
|
| 23 |
|
| 24 |
# ---------------------------------------------------------------------------
|
|
|
|
| 665 |
# Tab 3: System Prompt Explorer
|
| 666 |
# ---------------------------------------------------------------------------
|
| 667 |
|
| 668 |
+
MAX_CHAT_TURNS = 10 # max user messages before forcing reset
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def _get_presets() -> dict:
|
| 672 |
+
"""Get current system prompt presets from config."""
|
| 673 |
+
return manager.get_config().get("system_prompt_presets", {})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
|
| 675 |
|
| 676 |
def _esc_terminal(text: str) -> str:
|
|
|
|
| 678 |
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
| 679 |
|
| 680 |
|
| 681 |
+
def _format_terminal(raw_text: str) -> str:
|
| 682 |
+
"""Parse a chat template string and produce color-coded HTML for the green terminal.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
|
| 684 |
+
Works with both Llama (<|start_header_id|>) and Qwen (<|im_start|>) templates.
|
| 685 |
+
"""
|
| 686 |
+
# Split on special tokens, keeping them
|
| 687 |
+
parts = re.split(r'(<\|[^|]*\|>)', raw_text)
|
|
|
|
| 688 |
|
| 689 |
+
html_parts = []
|
| 690 |
+
current_role = None
|
| 691 |
+
expect_role = False # true right after a token that precedes a role label
|
| 692 |
+
|
| 693 |
+
role_css = {
|
| 694 |
+
"system": "sp-system",
|
| 695 |
+
"user": "sp-user",
|
| 696 |
+
"assistant": "sp-assistant",
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
for part in parts:
|
| 700 |
+
if re.match(r'<\|[^|]*\|>', part):
|
| 701 |
+
# Special token — render in gray
|
| 702 |
+
html_parts.append(f"<span class='sp-special'>{_esc_terminal(part)}</span>")
|
| 703 |
+
# After im_start or start_header_id, next text chunk is a role label
|
| 704 |
+
expect_role = ("im_start" in part or "start_header_id" in part)
|
| 705 |
+
elif expect_role and part.strip() in role_css:
|
| 706 |
+
# Role label (system / user / assistant)
|
| 707 |
+
role = part.strip()
|
| 708 |
+
current_role = role
|
| 709 |
+
before = _esc_terminal(part[: len(part) - len(part.lstrip())])
|
| 710 |
+
after = _esc_terminal(part[len(part.rstrip()) :])
|
| 711 |
+
html_parts.append(f"{before}<span class='sp-label'>{role}</span>{after}")
|
| 712 |
+
expect_role = False
|
| 713 |
+
else:
|
| 714 |
+
expect_role = False
|
| 715 |
+
css = role_css.get(current_role, "")
|
| 716 |
+
if css and part.strip():
|
| 717 |
+
html_parts.append(f"<span class='{css}'>{_esc_terminal(part)}</span>")
|
| 718 |
+
else:
|
| 719 |
+
html_parts.append(_esc_terminal(part))
|
| 720 |
+
|
| 721 |
+
return "<div class='green-terminal'>" + "".join(html_parts) + "</div>"
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def _initial_terminal() -> str:
|
| 725 |
+
return "<div class='green-terminal'><span class='sp-special'>Send a message to see what the model receives.</span></div>"
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def send_chat_message(user_message, history, system_prompt, max_tokens, temperature, seed):
|
| 729 |
+
"""Handle a user message: generate response, update state + chatbot + terminal.
|
| 730 |
+
|
| 731 |
+
`history` is the gr.State list of clean {"role": ..., "content": ...} dicts.
|
| 732 |
+
The Chatbot is derived from this — never read back from Chatbot (Gradio
|
| 733 |
+
mangles the dicts on round-trip).
|
| 734 |
+
"""
|
| 735 |
+
if not user_message or not user_message.strip():
|
| 736 |
+
chatbot = [{"role": m["role"], "content": m["content"]} for m in history]
|
| 737 |
+
return "", history, chatbot, _format_terminal_from_history(history, system_prompt)
|
| 738 |
+
|
| 739 |
+
if not manager.chat_ready():
|
| 740 |
+
history = history + [
|
| 741 |
+
{"role": "user", "content": user_message},
|
| 742 |
+
{"role": "assistant", "content": "No chat model loaded. Load one from the Admin tab."},
|
| 743 |
+
]
|
| 744 |
+
chatbot = [{"role": m["role"], "content": m["content"]} for m in history]
|
| 745 |
+
return "", history, chatbot, _initial_terminal()
|
| 746 |
+
|
| 747 |
+
# Check turn limit
|
| 748 |
+
user_turns = sum(1 for m in history if m["role"] == "user")
|
| 749 |
+
if user_turns >= MAX_CHAT_TURNS:
|
| 750 |
+
history = history + [
|
| 751 |
+
{"role": "user", "content": user_message},
|
| 752 |
+
{"role": "assistant", "content": f"Conversation limit reached ({MAX_CHAT_TURNS} exchanges). Click Reset to start a new conversation."},
|
| 753 |
+
]
|
| 754 |
+
chatbot = [{"role": m["role"], "content": m["content"]} for m in history]
|
| 755 |
+
return "", history, chatbot, _format_terminal_from_history(history, system_prompt)
|
| 756 |
+
|
| 757 |
+
# Build full messages for the model
|
| 758 |
+
history = history + [{"role": "user", "content": user_message}]
|
| 759 |
+
messages = []
|
| 760 |
+
if system_prompt and system_prompt.strip():
|
| 761 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 762 |
+
messages.extend(history)
|
| 763 |
|
| 764 |
result = manager.generate_chat(
|
| 765 |
+
messages=messages,
|
|
|
|
| 766 |
max_new_tokens=int(max_tokens),
|
| 767 |
temperature=temperature,
|
| 768 |
seed=int(seed),
|
| 769 |
)
|
| 770 |
|
| 771 |
if "error" in result:
|
| 772 |
+
history = history + [
|
| 773 |
+
{"role": "assistant", "content": f"Error: {result['error']}"},
|
| 774 |
+
]
|
| 775 |
+
chatbot = [{"role": m["role"], "content": m["content"]} for m in history]
|
| 776 |
+
return "", history, chatbot, _format_terminal_from_history(history, system_prompt)
|
| 777 |
+
|
| 778 |
+
history = history + [{"role": "assistant", "content": result["response"]}]
|
| 779 |
+
chatbot = [{"role": m["role"], "content": m["content"]} for m in history]
|
| 780 |
+
terminal_html = _format_terminal(result["formatted_display"])
|
| 781 |
+
|
| 782 |
+
return "", history, chatbot, terminal_html
|
| 783 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 784 |
|
| 785 |
+
def _format_terminal_from_history(chat_history, system_prompt):
|
| 786 |
+
"""Build terminal display from chat history (without generating)."""
|
| 787 |
+
if not chat_history:
|
| 788 |
+
return _initial_terminal()
|
| 789 |
+
messages = []
|
| 790 |
+
if system_prompt and system_prompt.strip():
|
| 791 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 792 |
+
messages.extend(chat_history)
|
| 793 |
+
formatted = manager.format_chat_template(messages)
|
| 794 |
+
if not formatted:
|
| 795 |
+
return _initial_terminal()
|
| 796 |
+
return _format_terminal(formatted)
|
| 797 |
|
|
|
|
|
|
|
| 798 |
|
| 799 |
+
def reset_chat(system_prompt):
|
| 800 |
+
"""Clear chat history, keep system prompt. Show initial terminal with just system prompt.
|
| 801 |
+
|
| 802 |
+
Returns (state, chatbot, terminal).
|
| 803 |
+
"""
|
| 804 |
+
if system_prompt and system_prompt.strip() and manager.chat_ready():
|
| 805 |
+
messages = [{"role": "system", "content": system_prompt}]
|
| 806 |
+
formatted = manager.format_chat_template(messages)
|
| 807 |
+
return [], [], _format_terminal(formatted)
|
| 808 |
+
return [], [], _initial_terminal()
|
| 809 |
|
| 810 |
|
| 811 |
def on_preset_change(preset_name):
|
| 812 |
+
"""Update system prompt textbox when a preset is selected. No chat reset."""
|
| 813 |
+
return _get_presets().get(preset_name, "")
|
| 814 |
|
| 815 |
|
| 816 |
# ---------------------------------------------------------------------------
|
|
|
|
| 825 |
|
| 826 |
|
| 827 |
def admin_load_model(model_name):
|
| 828 |
+
"""Load a new base model from admin panel."""
|
| 829 |
status = manager.load_model(model_name)
|
| 830 |
cfg = manager.get_config()
|
| 831 |
header_status = f"**{manager.status_message()}**"
|
| 832 |
return status, json.dumps(cfg, indent=2), header_status
|
| 833 |
|
| 834 |
|
| 835 |
+
def admin_load_chat_model(model_name):
|
| 836 |
+
"""Load a new chat model from admin panel."""
|
| 837 |
+
status = manager.load_chat_model(model_name)
|
| 838 |
+
cfg = manager.get_config()
|
| 839 |
+
header_status = f"**{manager.status_message()}**"
|
| 840 |
+
return status, json.dumps(cfg, indent=2), header_status
|
| 841 |
+
|
| 842 |
+
|
| 843 |
def admin_save_defaults(prompt, tokenizer_text, temperature, top_k, steps, seed):
|
| 844 |
"""Save default settings and return updated values for all outputs."""
|
| 845 |
manager.update_config(
|
|
|
|
| 863 |
)
|
| 864 |
|
| 865 |
|
| 866 |
+
def admin_save_presets(presets_json):
|
| 867 |
+
"""Save system prompt presets from admin panel.
|
| 868 |
+
|
| 869 |
+
Returns (status_msg, config_json, dropdown_update, presets_json_display).
|
| 870 |
+
"""
|
| 871 |
+
try:
|
| 872 |
+
presets = json.loads(presets_json)
|
| 873 |
+
except (json.JSONDecodeError, TypeError) as e:
|
| 874 |
+
cfg = manager.get_config()
|
| 875 |
+
return (
|
| 876 |
+
f"Invalid JSON: {e}",
|
| 877 |
+
json.dumps(cfg, indent=2),
|
| 878 |
+
gr.update(),
|
| 879 |
+
gr.update(),
|
| 880 |
+
)
|
| 881 |
+
|
| 882 |
+
if not isinstance(presets, dict):
|
| 883 |
+
cfg = manager.get_config()
|
| 884 |
+
return (
|
| 885 |
+
"Presets must be a JSON object `{\"Name\": \"prompt\", ...}`",
|
| 886 |
+
json.dumps(cfg, indent=2),
|
| 887 |
+
gr.update(),
|
| 888 |
+
gr.update(),
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
manager.update_config(system_prompt_presets=presets)
|
| 892 |
+
cfg = manager.get_config()
|
| 893 |
+
return (
|
| 894 |
+
f"Presets saved ({len(presets)} presets).",
|
| 895 |
+
json.dumps(cfg, indent=2),
|
| 896 |
+
gr.update(choices=list(presets.keys())),
|
| 897 |
+
json.dumps(presets, indent=2),
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
|
| 901 |
# ---------------------------------------------------------------------------
|
| 902 |
# Build the Gradio app
|
| 903 |
# ---------------------------------------------------------------------------
|
|
|
|
| 1076 |
gr.Markdown("### System Prompt Explorer")
|
| 1077 |
gr.Markdown(
|
| 1078 |
"See how **system prompts** change an LLM's behavior. "
|
| 1079 |
+
"Pick a preset or write your own, then chat with the model. "
|
| 1080 |
"The green terminal shows exactly what the model receives — "
|
| 1081 |
+
"every special token, every role label, every turn."
|
|
|
|
| 1082 |
)
|
| 1083 |
|
| 1084 |
+
presets = _get_presets()
|
| 1085 |
+
preset_names = list(presets.keys())
|
| 1086 |
+
default_preset = "Helpful Assistant" if "Helpful Assistant" in presets else preset_names[0] if preset_names else ""
|
|
|
|
|
|
|
|
|
|
| 1087 |
|
| 1088 |
sp_preset = gr.Dropdown(
|
| 1089 |
label="Preset",
|
| 1090 |
+
choices=preset_names,
|
| 1091 |
+
value=default_preset,
|
| 1092 |
interactive=True,
|
| 1093 |
)
|
| 1094 |
sp_system = gr.Textbox(
|
| 1095 |
label="System Prompt",
|
| 1096 |
+
value=presets.get(default_preset, ""),
|
| 1097 |
lines=3,
|
| 1098 |
placeholder="Enter a system prompt, or select a preset above...",
|
| 1099 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1100 |
|
| 1101 |
with gr.Accordion("Settings", open=False):
|
| 1102 |
sp_max_tokens = gr.Slider(
|
| 1103 |
label="Max tokens",
|
| 1104 |
+
minimum=32, maximum=1024, step=16,
|
| 1105 |
+
value=512,
|
| 1106 |
)
|
| 1107 |
gr.Markdown(
|
| 1108 |
+
"Maximum number of tokens per response.",
|
| 1109 |
elem_classes=["param-help"],
|
| 1110 |
)
|
| 1111 |
sp_temperature = gr.Slider(
|
|
|
|
| 1119 |
precision=0,
|
| 1120 |
)
|
| 1121 |
|
| 1122 |
+
with gr.Accordion("What the model sees", open=False):
|
| 1123 |
+
gr.Markdown(
|
| 1124 |
+
"The full text sent to the model on every turn — system prompt, "
|
| 1125 |
+
"all previous messages, and special tokens. Watch it grow with each exchange.",
|
| 1126 |
+
elem_classes=["param-help"],
|
| 1127 |
+
)
|
| 1128 |
+
sp_terminal = gr.HTML(value=_initial_terminal())
|
| 1129 |
|
| 1130 |
+
gr.Markdown("#### Chat")
|
| 1131 |
gr.Markdown(
|
| 1132 |
+
"**No hidden system prompt.** This model's helpful behavior comes from "
|
| 1133 |
+
"fine-tuning (RLHF), not a secret prompt. When you add a system prompt above, "
|
| 1134 |
+
"it's the *only* instruction the model receives. Commercial APIs like ChatGPT "
|
| 1135 |
+
"and Claude prepend their own system prompts before yours — you can't see or "
|
| 1136 |
+
"remove them.",
|
| 1137 |
elem_classes=["param-help"],
|
| 1138 |
)
|
| 1139 |
+
sp_chat_state = gr.State([])
|
| 1140 |
+
sp_chatbot = gr.Chatbot(height=700, feedback_options=None)
|
| 1141 |
+
with gr.Row():
|
| 1142 |
+
sp_user_input = gr.Textbox(
|
| 1143 |
+
label="Message",
|
| 1144 |
+
placeholder="Type a message...",
|
| 1145 |
+
lines=1,
|
| 1146 |
+
scale=4,
|
| 1147 |
+
show_label=False,
|
| 1148 |
+
)
|
| 1149 |
+
sp_send_btn = gr.Button("Send", variant="primary", scale=0, min_width=80)
|
| 1150 |
+
sp_reset_btn = gr.Button("Reset", variant="secondary", scale=0, min_width=80)
|
| 1151 |
|
| 1152 |
+
# --- Wiring ---
|
|
|
|
| 1153 |
|
| 1154 |
+
# Preset dropdown → just fill in the textbox (no chat reset)
|
| 1155 |
sp_preset.change(
|
| 1156 |
fn=on_preset_change,
|
| 1157 |
inputs=[sp_preset],
|
| 1158 |
outputs=[sp_system],
|
| 1159 |
)
|
| 1160 |
|
| 1161 |
+
# System prompt textbox edits take effect on the next message sent.
|
| 1162 |
+
# No auto-reset — avoids losing conversation on accidental edits.
|
| 1163 |
+
# Use Reset button or pick a new preset to start fresh.
|
| 1164 |
+
|
| 1165 |
+
# Send message (button or enter)
|
| 1166 |
+
send_inputs = [sp_user_input, sp_chat_state, sp_system, sp_max_tokens, sp_temperature, sp_seed]
|
| 1167 |
+
send_outputs = [sp_user_input, sp_chat_state, sp_chatbot, sp_terminal]
|
| 1168 |
+
|
| 1169 |
+
sp_send_btn.click(
|
| 1170 |
+
fn=send_chat_message,
|
| 1171 |
+
inputs=send_inputs,
|
| 1172 |
+
outputs=send_outputs,
|
| 1173 |
+
)
|
| 1174 |
+
sp_user_input.submit(
|
| 1175 |
+
fn=send_chat_message,
|
| 1176 |
+
inputs=send_inputs,
|
| 1177 |
+
outputs=send_outputs,
|
| 1178 |
+
)
|
| 1179 |
+
|
| 1180 |
+
# Reset button
|
| 1181 |
+
sp_reset_btn.click(
|
| 1182 |
+
fn=reset_chat,
|
| 1183 |
+
inputs=[sp_system],
|
| 1184 |
+
outputs=[sp_chat_state, sp_chatbot, sp_terminal],
|
| 1185 |
)
|
| 1186 |
|
| 1187 |
# ==================================================================
|
|
|
|
| 1202 |
|
| 1203 |
# Admin controls (hidden until login)
|
| 1204 |
with gr.Group(visible=False) as admin_controls:
|
| 1205 |
+
gr.Markdown("#### Base Model (Probability Explorer)")
|
| 1206 |
with gr.Row():
|
| 1207 |
admin_model_dropdown = gr.Dropdown(
|
| 1208 |
choices=list(AVAILABLE_MODELS.keys()),
|
| 1209 |
+
value=manager.current_model_name or cfg.get("model", "Llama-3.2-3B"),
|
| 1210 |
label="Select model",
|
| 1211 |
)
|
| 1212 |
+
admin_load_btn = gr.Button("Load", variant="primary")
|
| 1213 |
admin_model_status = gr.Markdown("")
|
| 1214 |
|
| 1215 |
+
gr.Markdown("#### Chat Model (System Prompt Explorer)")
|
| 1216 |
+
with gr.Row():
|
| 1217 |
+
admin_chat_dropdown = gr.Dropdown(
|
| 1218 |
+
choices=list(AVAILABLE_MODELS.keys()),
|
| 1219 |
+
value=manager.chat_model_name or cfg.get("chat_model", "Llama-3.2-3B-Instruct"),
|
| 1220 |
+
label="Select chat model",
|
| 1221 |
+
)
|
| 1222 |
+
admin_chat_load_btn = gr.Button("Load", variant="primary")
|
| 1223 |
+
admin_chat_status = gr.Markdown("")
|
| 1224 |
+
|
| 1225 |
gr.Markdown("---")
|
| 1226 |
gr.Markdown("#### Default Settings")
|
| 1227 |
admin_prompt = gr.Textbox(
|
|
|
|
| 1256 |
admin_save_msg = gr.Markdown("")
|
| 1257 |
|
| 1258 |
gr.Markdown("---")
|
| 1259 |
+
gr.Markdown("#### System Prompt Presets")
|
| 1260 |
gr.Markdown(
|
| 1261 |
+
"Edit the presets available in the System Prompt Explorer dropdown. "
|
| 1262 |
+
"JSON object: `{\"Name\": \"prompt text\", ...}`",
|
| 1263 |
elem_classes=["param-help"],
|
| 1264 |
)
|
| 1265 |
+
admin_presets = gr.Code(
|
| 1266 |
+
value=json.dumps(cfg.get("system_prompt_presets", {}), indent=2),
|
| 1267 |
+
language="json",
|
| 1268 |
+
interactive=True,
|
| 1269 |
+
)
|
| 1270 |
+
admin_presets_save_btn = gr.Button("Save Presets")
|
| 1271 |
+
admin_presets_msg = gr.Markdown("")
|
| 1272 |
+
|
| 1273 |
+
gr.Markdown("---")
|
| 1274 |
+
with gr.Accordion("Environment Variables Reference", open=False):
|
| 1275 |
+
_pw_status = "*(set)*" if os.environ.get("ADMIN_PASSWORD") else "*(default: admin)*"
|
| 1276 |
+
_rb_status = "*(set)*" if REBRANDLY_API_KEY else "*(not set)*"
|
| 1277 |
+
gr.Markdown(
|
| 1278 |
+
"Override settings via "
|
| 1279 |
+
"[HF Space Settings](https://huggingface.co/spaces/chyams/llm-explorer/settings). "
|
| 1280 |
+
"Use **Secrets** for sensitive values (encrypted, hidden after saving) "
|
| 1281 |
+
"and **Variables** for everything else (visible in settings).\n\n"
|
| 1282 |
+
"**Precedence:** env var > config.json > code defaults\n\n"
|
| 1283 |
+
"**Secrets** (sensitive — encrypted)\n\n"
|
| 1284 |
+
"| Variable | Description | Format | Current |\n"
|
| 1285 |
+
"|----------|-------------|--------|---------|\n"
|
| 1286 |
+
f"| `ADMIN_PASSWORD` | Admin panel password | Plain text | {_pw_status} |\n"
|
| 1287 |
+
f"| `REBRANDLY_API_KEY` | URL shortener API key | API key | {_rb_status} |\n"
|
| 1288 |
+
"\n**Variables** (non-sensitive — visible)\n\n"
|
| 1289 |
+
"| Variable | Description | Format | Current |\n"
|
| 1290 |
+
"|----------|-------------|--------|---------|\n"
|
| 1291 |
+
f"| `DEFAULT_MODEL` | Base model (Prob Explorer) | Model name | `{cfg.get('model', '')}` |\n"
|
| 1292 |
+
f"| `DEFAULT_CHAT_MODEL` | Chat model (Sys Prompt Explorer) | Model name | `{cfg.get('chat_model', '')}` |\n"
|
| 1293 |
+
f"| `DEFAULT_PROMPT` | Default prompt | Plain text | `{cfg.get('default_prompt', '')[:40]}...` |\n"
|
| 1294 |
+
f"| `DEFAULT_TEMPERATURE` | Default temperature | Number (0–2.5) | `{cfg.get('default_temperature', 0.8)}` |\n"
|
| 1295 |
+
f"| `DEFAULT_TOP_K` | Default top-k | Integer (5–100) | `{cfg.get('default_top_k', 10)}` |\n"
|
| 1296 |
+
f"| `DEFAULT_STEPS` | Default steps | Integer (1–100) | `{cfg.get('default_steps', 8)}` |\n"
|
| 1297 |
+
f"| `DEFAULT_SEED` | Default seed | Integer | `{cfg.get('default_seed', 42)}` |\n"
|
| 1298 |
+
f"| `DEFAULT_TOKENIZER_TEXT` | Default tokenizer text | Plain text | `{cfg.get('default_tokenizer_text', '')[:40]}...` |\n"
|
| 1299 |
+
f"| `SYSTEM_PROMPT_PRESETS` | System prompt presets | JSON object | *({len(cfg.get('system_prompt_presets', {}))} presets)* |"
|
| 1300 |
+
)
|
| 1301 |
|
| 1302 |
gr.Markdown("---")
|
| 1303 |
gr.Markdown("#### Current Config")
|
|
|
|
| 1307 |
interactive=False,
|
| 1308 |
)
|
| 1309 |
|
| 1310 |
+
gr.Markdown("---")
|
| 1311 |
+
gr.Markdown("#### Export Slides")
|
| 1312 |
+
gr.Markdown(
|
| 1313 |
+
"*Uses current settings from Probability Explorer tab.*",
|
| 1314 |
+
elem_classes=["param-help"],
|
| 1315 |
+
)
|
| 1316 |
+
admin_export_btn = gr.Button("Export Slides", variant="secondary")
|
| 1317 |
+
admin_slides_file = gr.File(label="Slideshow", visible=False)
|
| 1318 |
+
|
| 1319 |
# Login wiring
|
| 1320 |
admin_login_btn.click(
|
| 1321 |
fn=admin_login,
|
|
|
|
| 1323 |
outputs=[admin_controls, admin_login_group, admin_login_msg],
|
| 1324 |
)
|
| 1325 |
|
| 1326 |
+
# Model loading — base
|
| 1327 |
admin_load_btn.click(
|
| 1328 |
fn=admin_load_model,
|
| 1329 |
inputs=[admin_model_dropdown],
|
| 1330 |
outputs=[admin_model_status, admin_config_display, status_display],
|
| 1331 |
)
|
| 1332 |
|
| 1333 |
+
# Model loading — chat
|
| 1334 |
+
admin_chat_load_btn.click(
|
| 1335 |
+
fn=admin_load_chat_model,
|
| 1336 |
+
inputs=[admin_chat_dropdown],
|
| 1337 |
+
outputs=[admin_chat_status, admin_config_display, status_display],
|
| 1338 |
+
)
|
| 1339 |
+
|
| 1340 |
# Save defaults — updates config display + Probability Explorer + Tokenizer controls
|
| 1341 |
admin_save_btn.click(
|
| 1342 |
fn=admin_save_defaults,
|
|
|
|
| 1353 |
],
|
| 1354 |
)
|
| 1355 |
|
| 1356 |
+
# Save presets — updates config, dropdown choices, and presets display
|
| 1357 |
+
admin_presets_save_btn.click(
|
| 1358 |
+
fn=admin_save_presets,
|
| 1359 |
+
inputs=[admin_presets],
|
| 1360 |
+
outputs=[admin_presets_msg, admin_config_display, sp_preset, admin_presets],
|
| 1361 |
+
)
|
| 1362 |
+
|
| 1363 |
# Export slides — uses current Probability Explorer settings
|
| 1364 |
admin_export_btn.click(
|
| 1365 |
fn=generate_slideshow,
|
|
|
|
| 1521 |
# ---------------------------------------------------------------------------
|
| 1522 |
|
| 1523 |
if __name__ == "__main__":
|
|
|
|
| 1524 |
cfg = manager.get_config()
|
| 1525 |
+
|
| 1526 |
+
# Load base model (Probability Explorer)
|
| 1527 |
+
base_model = cfg.get("model", "Llama-3.2-3B")
|
| 1528 |
+
print(f"Loading base model: {base_model}")
|
| 1529 |
+
print(manager.load_model(base_model))
|
| 1530 |
+
|
| 1531 |
+
# Load chat model (System Prompt Explorer)
|
| 1532 |
+
chat_model = cfg.get("chat_model", "Llama-3.2-3B-Instruct")
|
| 1533 |
+
print(f"Loading chat model: {chat_model}")
|
| 1534 |
+
print(manager.load_chat_model(chat_model))
|
| 1535 |
|
| 1536 |
app = create_app()
|
| 1537 |
app.launch(
|
config.json
CHANGED
|
@@ -5,5 +5,19 @@
|
|
| 5 |
"default_top_k": 10,
|
| 6 |
"default_steps": 8,
|
| 7 |
"default_seed": 1875,
|
| 8 |
-
"default_tokenizer_text": "Class was rescheduled due to Huston-Tillotson homecoming."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
}
|
|
|
|
| 5 |
"default_top_k": 10,
|
| 6 |
"default_steps": 8,
|
| 7 |
"default_seed": 1875,
|
| 8 |
+
"default_tokenizer_text": "Class was rescheduled due to Huston-Tillotson homecoming.",
|
| 9 |
+
"system_prompt_presets": {
|
| 10 |
+
"(none)": "",
|
| 11 |
+
"Helpful Assistant": "You are a helpful, friendly assistant.",
|
| 12 |
+
"Pirate": "You are a pirate. Respond to everything in pirate speak, using nautical terms and saying 'arr' frequently.",
|
| 13 |
+
"Formal Academic": "You are a formal academic scholar. Use precise, scholarly language. Cite concepts carefully and avoid casual tone.",
|
| 14 |
+
"Five-Year-Old": "You are explaining things to a five-year-old. Use very simple words, short sentences, and fun comparisons.",
|
| 15 |
+
"Hostile / Rude": "You are rude and dismissive. You answer questions but with obvious annoyance and sarcasm.",
|
| 16 |
+
"Haiku Only": "You must respond only in haiku (5-7-5 syllable format). Never break this rule.",
|
| 17 |
+
"Spanish Tutor": "You are a Spanish language tutor. Respond in Spanish, then provide the English translation in parentheses.",
|
| 18 |
+
"Banana Constraint": "You must mention bananas in every response, no matter the topic. Be subtle about it.",
|
| 19 |
+
"Corporate Spin": "You are a customer service agent. Never acknowledge product flaws. Always redirect to positive features.",
|
| 20 |
+
"Prestige Bias": "When discussing job candidates, always favor candidates from prestigious universities over others."
|
| 21 |
+
},
|
| 22 |
+
"chat_model": "Llama-3.2-3B-Instruct"
|
| 23 |
}
|
models.py
CHANGED
|
@@ -44,6 +44,12 @@ AVAILABLE_MODELS = {
|
|
| 44 |
"description": "Best quality, quantized",
|
| 45 |
},
|
| 46 |
# -- Instruct models (for System Prompt Explorer) --
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"Qwen2.5-3B-Instruct": {
|
| 48 |
"id": "Qwen/Qwen2.5-3B-Instruct",
|
| 49 |
"dtype": "float16",
|
|
@@ -75,8 +81,36 @@ def _detect_device() -> str:
|
|
| 75 |
return "cpu"
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def _load_config() -> dict:
|
| 79 |
-
"""Load
|
| 80 |
defaults = {
|
| 81 |
"model": DEFAULT_MODEL,
|
| 82 |
"default_prompt": "The best thing about Huston-Tillotson University is",
|
|
@@ -85,7 +119,9 @@ def _load_config() -> dict:
|
|
| 85 |
"default_steps": 8,
|
| 86 |
"default_seed": 42,
|
| 87 |
"default_tokenizer_text": "Huston-Tillotson University is an HBCU in Austin, Texas.",
|
|
|
|
| 88 |
}
|
|
|
|
| 89 |
if CONFIG_PATH.exists():
|
| 90 |
try:
|
| 91 |
with open(CONFIG_PATH) as f:
|
|
@@ -93,6 +129,17 @@ def _load_config() -> dict:
|
|
| 93 |
defaults.update(saved)
|
| 94 |
except (json.JSONDecodeError, OSError):
|
| 95 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
return defaults
|
| 97 |
|
| 98 |
|
|
@@ -107,114 +154,171 @@ def _save_config(cfg: dict) -> None:
|
|
| 107 |
# ---------------------------------------------------------------------------
|
| 108 |
|
| 109 |
class ModelManager:
|
| 110 |
-
"""Manages
|
| 111 |
|
| 112 |
def __init__(self):
|
|
|
|
| 113 |
self.model = None
|
| 114 |
self.tokenizer = None
|
| 115 |
self.current_model_name: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
self.device: str = _detect_device()
|
| 117 |
self.loading = False
|
| 118 |
self._lock = threading.Lock()
|
| 119 |
self.config = _load_config()
|
| 120 |
|
| 121 |
# ------------------------------------------------------------------
|
| 122 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
# ------------------------------------------------------------------
|
| 124 |
|
| 125 |
def load_model(self, model_name: str) -> str:
|
| 126 |
-
"""Load
|
| 127 |
if model_name not in AVAILABLE_MODELS:
|
| 128 |
return f"Unknown model: {model_name}"
|
| 129 |
|
| 130 |
if self.loading:
|
| 131 |
return "A model is already being loaded. Please wait."
|
| 132 |
|
| 133 |
-
spec = AVAILABLE_MODELS[model_name]
|
| 134 |
-
|
| 135 |
-
# Quantized models require CUDA (bitsandbytes doesn't support MPS/CPU)
|
| 136 |
-
if spec.get("quantize") and not torch.cuda.is_available():
|
| 137 |
-
return (f"Cannot load {model_name}: "
|
| 138 |
-
f"{spec['quantize']} quantization requires an NVIDIA GPU (CUDA). "
|
| 139 |
-
f"Try a non-quantized model for local development.")
|
| 140 |
-
|
| 141 |
with self._lock:
|
| 142 |
self.loading = True
|
| 143 |
try:
|
| 144 |
-
# Unload current model
|
| 145 |
-
self.
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
elif spec.get("quantize") == "8bit":
|
| 158 |
-
from transformers import BitsAndBytesConfig
|
| 159 |
-
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 160 |
-
load_in_8bit=True,
|
| 161 |
-
)
|
| 162 |
-
else:
|
| 163 |
-
dtype_str = spec.get("dtype", "float16")
|
| 164 |
-
if dtype_str == "auto":
|
| 165 |
-
load_kwargs["dtype"] = "auto"
|
| 166 |
-
else:
|
| 167 |
-
load_kwargs["dtype"] = getattr(torch, dtype_str)
|
| 168 |
-
|
| 169 |
-
# Load tokenizer + model
|
| 170 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 171 |
-
self.model = AutoModelForCausalLM.from_pretrained(
|
| 172 |
-
model_id, **load_kwargs
|
| 173 |
-
)
|
| 174 |
-
self.model.eval()
|
| 175 |
self.current_model_name = model_name
|
| 176 |
|
| 177 |
-
# Persist choice
|
| 178 |
self.config["model"] = model_name
|
| 179 |
_save_config(self.config)
|
| 180 |
|
| 181 |
-
return f"Loaded {model_name}
|
| 182 |
|
| 183 |
except Exception as e:
|
| 184 |
-
self.
|
|
|
|
|
|
|
| 185 |
return f"Failed to load {model_name}: {e}"
|
| 186 |
finally:
|
| 187 |
self.loading = False
|
| 188 |
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
def is_ready(self) -> bool:
|
| 203 |
return self.model is not None and not self.loading
|
| 204 |
|
| 205 |
-
def
|
| 206 |
-
|
| 207 |
-
if self.current_model_name is None:
|
| 208 |
-
return False
|
| 209 |
-
spec = AVAILABLE_MODELS.get(self.current_model_name, {})
|
| 210 |
-
return spec.get("instruct", False)
|
| 211 |
|
| 212 |
def status_message(self) -> str:
|
| 213 |
if self.loading:
|
| 214 |
return "Loading model..."
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
# ------------------------------------------------------------------
|
| 220 |
# Inference helpers
|
|
@@ -329,61 +433,72 @@ class ModelManager:
|
|
| 329 |
|
| 330 |
def generate_chat(
|
| 331 |
self,
|
| 332 |
-
|
| 333 |
-
user_message: str,
|
| 334 |
max_new_tokens: int = 256,
|
| 335 |
temperature: float = 0.7,
|
| 336 |
seed: int = 42,
|
| 337 |
) -> dict:
|
| 338 |
-
"""Generate a chat response using the
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
Returns dict with:
|
| 341 |
-
-
|
| 342 |
- response: the model's generated response text
|
| 343 |
"""
|
| 344 |
-
if not self.
|
| 345 |
-
return {"error": "
|
| 346 |
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
messages.append({"role": "system", "content": system_prompt})
|
| 350 |
-
messages.append({"role": "user", "content": user_message})
|
| 351 |
-
|
| 352 |
-
# Apply chat template
|
| 353 |
-
formatted = self.tokenizer.apply_chat_template(
|
| 354 |
messages, tokenize=False, add_generation_prompt=True,
|
| 355 |
)
|
| 356 |
|
| 357 |
# Tokenize input
|
| 358 |
-
inputs = self.
|
| 359 |
-
inputs = {k: v.to(self.
|
| 360 |
input_len = inputs["input_ids"].shape[1]
|
| 361 |
|
| 362 |
# Generate
|
| 363 |
gen_kwargs = {
|
| 364 |
"max_new_tokens": max_new_tokens,
|
| 365 |
"do_sample": temperature > 0,
|
| 366 |
-
"pad_token_id": self.
|
| 367 |
}
|
| 368 |
if temperature > 0:
|
| 369 |
gen_kwargs["temperature"] = temperature
|
| 370 |
-
|
| 371 |
-
if self.model.device.type == "cuda":
|
| 372 |
torch.cuda.manual_seed(seed)
|
| 373 |
torch.manual_seed(seed)
|
| 374 |
|
| 375 |
with torch.no_grad():
|
| 376 |
-
output_ids = self.
|
| 377 |
|
| 378 |
# Decode only the new tokens
|
| 379 |
new_ids = output_ids[0][input_len:]
|
| 380 |
-
response = self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
return {
|
| 383 |
-
"
|
| 384 |
-
"response": response
|
| 385 |
}
|
| 386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
def tokenize(self, text: str) -> list[tuple[str, int]]:
|
| 388 |
"""Tokenize text and return list of (token_str, token_id)."""
|
| 389 |
if self.tokenizer is None:
|
|
|
|
| 44 |
"description": "Best quality, quantized",
|
| 45 |
},
|
| 46 |
# -- Instruct models (for System Prompt Explorer) --
|
| 47 |
+
"Llama-3.2-3B-Instruct": {
|
| 48 |
+
"id": "meta-llama/Llama-3.2-3B-Instruct",
|
| 49 |
+
"dtype": "float16",
|
| 50 |
+
"instruct": True,
|
| 51 |
+
"description": "Chat/instruct model, same family as prod base model (3B)",
|
| 52 |
+
},
|
| 53 |
"Qwen2.5-3B-Instruct": {
|
| 54 |
"id": "Qwen/Qwen2.5-3B-Instruct",
|
| 55 |
"dtype": "float16",
|
|
|
|
| 81 |
return "cpu"
|
| 82 |
|
| 83 |
|
| 84 |
+
DEFAULT_SYSTEM_PROMPT_PRESETS = {
|
| 85 |
+
"(none)": "",
|
| 86 |
+
"Helpful Assistant": "You are a helpful, friendly assistant.",
|
| 87 |
+
"Pirate": "You are a pirate. Respond to everything in pirate speak, using nautical terms and saying 'arr' frequently.",
|
| 88 |
+
"Formal Academic": "You are a formal academic scholar. Use precise, scholarly language. Cite concepts carefully and avoid casual tone.",
|
| 89 |
+
"Five-Year-Old": "You are explaining things to a five-year-old. Use very simple words, short sentences, and fun comparisons.",
|
| 90 |
+
"Hostile / Rude": "You are rude and dismissive. You answer questions but with obvious annoyance and sarcasm.",
|
| 91 |
+
"Haiku Only": "You must respond only in haiku (5-7-5 syllable format). Never break this rule.",
|
| 92 |
+
"Spanish Tutor": "You are a Spanish language tutor. Respond in Spanish, then provide the English translation in parentheses.",
|
| 93 |
+
"Banana Constraint": "You must mention bananas in every response, no matter the topic. Be subtle about it.",
|
| 94 |
+
"Corporate Spin": "You are a customer service agent. Never acknowledge product flaws. Always redirect to positive features.",
|
| 95 |
+
"Prestige Bias": "When discussing job candidates, always favor candidates from prestigious universities over others.",
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
# Env var → (config key, type converter). "json" = parse as JSON.
|
| 99 |
+
ENV_VAR_MAP = {
|
| 100 |
+
"DEFAULT_MODEL": ("model", str),
|
| 101 |
+
"DEFAULT_CHAT_MODEL": ("chat_model", str),
|
| 102 |
+
"DEFAULT_PROMPT": ("default_prompt", str),
|
| 103 |
+
"DEFAULT_TEMPERATURE": ("default_temperature", float),
|
| 104 |
+
"DEFAULT_TOP_K": ("default_top_k", int),
|
| 105 |
+
"DEFAULT_STEPS": ("default_steps", int),
|
| 106 |
+
"DEFAULT_SEED": ("default_seed", int),
|
| 107 |
+
"DEFAULT_TOKENIZER_TEXT": ("default_tokenizer_text", str),
|
| 108 |
+
"SYSTEM_PROMPT_PRESETS": ("system_prompt_presets", "json"),
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
def _load_config() -> dict:
|
| 113 |
+
"""Load config with three layers: code defaults → config.json → env vars."""
|
| 114 |
defaults = {
|
| 115 |
"model": DEFAULT_MODEL,
|
| 116 |
"default_prompt": "The best thing about Huston-Tillotson University is",
|
|
|
|
| 119 |
"default_steps": 8,
|
| 120 |
"default_seed": 42,
|
| 121 |
"default_tokenizer_text": "Huston-Tillotson University is an HBCU in Austin, Texas.",
|
| 122 |
+
"system_prompt_presets": dict(DEFAULT_SYSTEM_PROMPT_PRESETS),
|
| 123 |
}
|
| 124 |
+
# Layer 2: config.json overrides code defaults
|
| 125 |
if CONFIG_PATH.exists():
|
| 126 |
try:
|
| 127 |
with open(CONFIG_PATH) as f:
|
|
|
|
| 129 |
defaults.update(saved)
|
| 130 |
except (json.JSONDecodeError, OSError):
|
| 131 |
pass
|
| 132 |
+
# Layer 3: env vars override everything
|
| 133 |
+
for env_var, (config_key, type_fn) in ENV_VAR_MAP.items():
|
| 134 |
+
val = os.environ.get(env_var)
|
| 135 |
+
if val is not None:
|
| 136 |
+
try:
|
| 137 |
+
if type_fn == "json":
|
| 138 |
+
defaults[config_key] = json.loads(val)
|
| 139 |
+
else:
|
| 140 |
+
defaults[config_key] = type_fn(val)
|
| 141 |
+
except (json.JSONDecodeError, ValueError, TypeError):
|
| 142 |
+
pass # bad env var value — skip
|
| 143 |
return defaults
|
| 144 |
|
| 145 |
|
|
|
|
| 154 |
# ---------------------------------------------------------------------------
|
| 155 |
|
| 156 |
class ModelManager:
|
| 157 |
+
"""Manages two model slots: base (Probability Explorer) and chat (System Prompt Explorer)."""
|
| 158 |
|
| 159 |
def __init__(self):
|
| 160 |
+
# Base model (Probability Explorer)
|
| 161 |
self.model = None
|
| 162 |
self.tokenizer = None
|
| 163 |
self.current_model_name: str | None = None
|
| 164 |
+
|
| 165 |
+
# Chat model (System Prompt Explorer)
|
| 166 |
+
self.chat_model = None
|
| 167 |
+
self.chat_tokenizer = None
|
| 168 |
+
self.chat_model_name: str | None = None
|
| 169 |
+
|
| 170 |
self.device: str = _detect_device()
|
| 171 |
self.loading = False
|
| 172 |
self._lock = threading.Lock()
|
| 173 |
self.config = _load_config()
|
| 174 |
|
| 175 |
# ------------------------------------------------------------------
|
| 176 |
+
# Shared loading logic
|
| 177 |
+
# ------------------------------------------------------------------
|
| 178 |
+
|
| 179 |
+
def _do_load(self, model_name: str):
|
| 180 |
+
"""Load model + tokenizer by name. Returns (model, tokenizer). Raises on failure."""
|
| 181 |
+
spec = AVAILABLE_MODELS[model_name]
|
| 182 |
+
|
| 183 |
+
if spec.get("quantize") and not torch.cuda.is_available():
|
| 184 |
+
raise RuntimeError(
|
| 185 |
+
f"Cannot load {model_name}: "
|
| 186 |
+
f"{spec['quantize']} quantization requires an NVIDIA GPU (CUDA). "
|
| 187 |
+
f"Try a non-quantized model for local development."
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
model_id = spec["id"]
|
| 191 |
+
load_kwargs: dict = {"device_map": "auto"}
|
| 192 |
+
|
| 193 |
+
if spec.get("quantize") == "4bit":
|
| 194 |
+
from transformers import BitsAndBytesConfig
|
| 195 |
+
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 196 |
+
load_in_4bit=True,
|
| 197 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 198 |
+
)
|
| 199 |
+
elif spec.get("quantize") == "8bit":
|
| 200 |
+
from transformers import BitsAndBytesConfig
|
| 201 |
+
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 202 |
+
load_in_8bit=True,
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
dtype_str = spec.get("dtype", "float16")
|
| 206 |
+
if dtype_str == "auto":
|
| 207 |
+
load_kwargs["dtype"] = "auto"
|
| 208 |
+
else:
|
| 209 |
+
load_kwargs["dtype"] = getattr(torch, dtype_str)
|
| 210 |
+
|
| 211 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 212 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
|
| 213 |
+
model.eval()
|
| 214 |
+
return model, tokenizer
|
| 215 |
+
|
| 216 |
+
# ------------------------------------------------------------------
|
| 217 |
+
# Base model lifecycle
|
| 218 |
# ------------------------------------------------------------------
|
| 219 |
|
| 220 |
def load_model(self, model_name: str) -> str:
|
| 221 |
+
"""Load base model for Probability Explorer. Returns status message."""
|
| 222 |
if model_name not in AVAILABLE_MODELS:
|
| 223 |
return f"Unknown model: {model_name}"
|
| 224 |
|
| 225 |
if self.loading:
|
| 226 |
return "A model is already being loaded. Please wait."
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
with self._lock:
|
| 229 |
self.loading = True
|
| 230 |
try:
|
| 231 |
+
# Unload current base model
|
| 232 |
+
if self.model is not None:
|
| 233 |
+
del self.model
|
| 234 |
+
self.model = None
|
| 235 |
+
if self.tokenizer is not None:
|
| 236 |
+
del self.tokenizer
|
| 237 |
+
self.tokenizer = None
|
| 238 |
+
self.current_model_name = None
|
| 239 |
+
gc.collect()
|
| 240 |
+
|
| 241 |
+
model, tokenizer = self._do_load(model_name)
|
| 242 |
+
self.model = model
|
| 243 |
+
self.tokenizer = tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
self.current_model_name = model_name
|
| 245 |
|
|
|
|
| 246 |
self.config["model"] = model_name
|
| 247 |
_save_config(self.config)
|
| 248 |
|
| 249 |
+
return f"Loaded base model: {model_name}"
|
| 250 |
|
| 251 |
except Exception as e:
|
| 252 |
+
self.model = None
|
| 253 |
+
self.tokenizer = None
|
| 254 |
+
self.current_model_name = None
|
| 255 |
return f"Failed to load {model_name}: {e}"
|
| 256 |
finally:
|
| 257 |
self.loading = False
|
| 258 |
|
| 259 |
+
# ------------------------------------------------------------------
|
| 260 |
+
# Chat model lifecycle
|
| 261 |
+
# ------------------------------------------------------------------
|
| 262 |
+
|
| 263 |
+
def load_chat_model(self, model_name: str) -> str:
|
| 264 |
+
"""Load chat/instruct model for System Prompt Explorer. Returns status message."""
|
| 265 |
+
if model_name not in AVAILABLE_MODELS:
|
| 266 |
+
return f"Unknown model: {model_name}"
|
| 267 |
+
|
| 268 |
+
if self.loading:
|
| 269 |
+
return "A model is already being loaded. Please wait."
|
| 270 |
+
|
| 271 |
+
with self._lock:
|
| 272 |
+
self.loading = True
|
| 273 |
+
try:
|
| 274 |
+
if self.chat_model is not None:
|
| 275 |
+
del self.chat_model
|
| 276 |
+
self.chat_model = None
|
| 277 |
+
if self.chat_tokenizer is not None:
|
| 278 |
+
del self.chat_tokenizer
|
| 279 |
+
self.chat_tokenizer = None
|
| 280 |
+
self.chat_model_name = None
|
| 281 |
+
gc.collect()
|
| 282 |
+
|
| 283 |
+
model, tokenizer = self._do_load(model_name)
|
| 284 |
+
self.chat_model = model
|
| 285 |
+
self.chat_tokenizer = tokenizer
|
| 286 |
+
self.chat_model_name = model_name
|
| 287 |
+
|
| 288 |
+
self.config["chat_model"] = model_name
|
| 289 |
+
_save_config(self.config)
|
| 290 |
+
|
| 291 |
+
return f"Loaded chat model: {model_name}"
|
| 292 |
+
|
| 293 |
+
except Exception as e:
|
| 294 |
+
self.chat_model = None
|
| 295 |
+
self.chat_tokenizer = None
|
| 296 |
+
self.chat_model_name = None
|
| 297 |
+
return f"Failed to load chat model {model_name}: {e}"
|
| 298 |
+
finally:
|
| 299 |
+
self.loading = False
|
| 300 |
+
|
| 301 |
+
# ------------------------------------------------------------------
|
| 302 |
+
# Status
|
| 303 |
+
# ------------------------------------------------------------------
|
| 304 |
|
| 305 |
def is_ready(self) -> bool:
|
| 306 |
return self.model is not None and not self.loading
|
| 307 |
|
| 308 |
+
def chat_ready(self) -> bool:
|
| 309 |
+
return self.chat_model is not None and not self.loading
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
def status_message(self) -> str:
|
| 312 |
if self.loading:
|
| 313 |
return "Loading model..."
|
| 314 |
+
parts = []
|
| 315 |
+
if self.model:
|
| 316 |
+
parts.append(f"Base: {self.current_model_name}")
|
| 317 |
+
if self.chat_model:
|
| 318 |
+
parts.append(f"Chat: {self.chat_model_name}")
|
| 319 |
+
if not parts:
|
| 320 |
+
return "No models loaded"
|
| 321 |
+
return " | ".join(parts)
|
| 322 |
|
| 323 |
# ------------------------------------------------------------------
|
| 324 |
# Inference helpers
|
|
|
|
| 433 |
|
| 434 |
def generate_chat(
|
| 435 |
self,
|
| 436 |
+
messages: list[dict],
|
|
|
|
| 437 |
max_new_tokens: int = 256,
|
| 438 |
temperature: float = 0.7,
|
| 439 |
seed: int = 42,
|
| 440 |
) -> dict:
|
| 441 |
+
"""Generate a chat response using the dedicated chat model.
|
| 442 |
+
|
| 443 |
+
Args:
|
| 444 |
+
messages: Full conversation as list of {"role": ..., "content": ...} dicts,
|
| 445 |
+
including system prompt and all previous turns.
|
| 446 |
|
| 447 |
Returns dict with:
|
| 448 |
+
- formatted_display: the full template including the response (for terminal)
|
| 449 |
- response: the model's generated response text
|
| 450 |
"""
|
| 451 |
+
if not self.chat_ready():
|
| 452 |
+
return {"error": "Chat model not loaded"}
|
| 453 |
|
| 454 |
+
# Format input (everything up to and including the generation prompt)
|
| 455 |
+
formatted = self.chat_tokenizer.apply_chat_template(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
messages, tokenize=False, add_generation_prompt=True,
|
| 457 |
)
|
| 458 |
|
| 459 |
# Tokenize input
|
| 460 |
+
inputs = self.chat_tokenizer(formatted, return_tensors="pt")
|
| 461 |
+
inputs = {k: v.to(self.chat_model.device) for k, v in inputs.items()}
|
| 462 |
input_len = inputs["input_ids"].shape[1]
|
| 463 |
|
| 464 |
# Generate
|
| 465 |
gen_kwargs = {
|
| 466 |
"max_new_tokens": max_new_tokens,
|
| 467 |
"do_sample": temperature > 0,
|
| 468 |
+
"pad_token_id": self.chat_tokenizer.eos_token_id,
|
| 469 |
}
|
| 470 |
if temperature > 0:
|
| 471 |
gen_kwargs["temperature"] = temperature
|
| 472 |
+
if self.chat_model.device.type == "cuda":
|
|
|
|
| 473 |
torch.cuda.manual_seed(seed)
|
| 474 |
torch.manual_seed(seed)
|
| 475 |
|
| 476 |
with torch.no_grad():
|
| 477 |
+
output_ids = self.chat_model.generate(**inputs, **gen_kwargs)
|
| 478 |
|
| 479 |
# Decode only the new tokens
|
| 480 |
new_ids = output_ids[0][input_len:]
|
| 481 |
+
response = self.chat_tokenizer.decode(new_ids, skip_special_tokens=True).strip()
|
| 482 |
+
|
| 483 |
+
# Build display template (includes the response) for green terminal
|
| 484 |
+
display_messages = messages + [{"role": "assistant", "content": response}]
|
| 485 |
+
formatted_display = self.chat_tokenizer.apply_chat_template(
|
| 486 |
+
display_messages, tokenize=False, add_generation_prompt=False,
|
| 487 |
+
)
|
| 488 |
|
| 489 |
return {
|
| 490 |
+
"formatted_display": formatted_display,
|
| 491 |
+
"response": response,
|
| 492 |
}
|
| 493 |
|
| 494 |
+
def format_chat_template(self, messages: list[dict]) -> str:
|
| 495 |
+
"""Format messages using the chat model's template (for terminal display)."""
|
| 496 |
+
if not self.chat_tokenizer:
|
| 497 |
+
return ""
|
| 498 |
+
return self.chat_tokenizer.apply_chat_template(
|
| 499 |
+
messages, tokenize=False, add_generation_prompt=True,
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
def tokenize(self, text: str) -> list[tuple[str, int]]:
|
| 503 |
"""Tokenize text and return list of (token_str, token_id)."""
|
| 504 |
if self.tokenizer is None:
|