chyams Claude Opus 4.6 committed on
Commit
beb8b02
·
1 Parent(s): f2b5e3f

System Prompt Explorer: dual model, multi-turn chat, configurable presets

Browse files

- Dual model architecture: base (Llama-3.2-3B) + chat (Llama-3.2-3B-Instruct)
- Multi-turn chat with gr.State for clean history (Chatbot display-only)
- 11 configurable presets via admin panel or SYSTEM_PROMPT_PRESETS env var
- All config values overridable via env vars (Secrets vs Variables documented)
- No auto-reset on prompt changes; green terminal collapsed by default
- Educational note: no hidden system prompt, helpfulness from RLHF

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +316 -117
  2. config.json +15 -1
  3. models.py +204 -89
app.py CHANGED
@@ -17,6 +17,8 @@ from datetime import datetime, timezone, timedelta
17
 
18
  import gradio as gr
19
 
 
 
20
  from models import AVAILABLE_MODELS, manager, demo_tokenizer
21
 
22
  # ---------------------------------------------------------------------------
@@ -663,16 +665,12 @@ def tokenize_text(text):
663
  # Tab 3: System Prompt Explorer
664
  # ---------------------------------------------------------------------------
665
 
666
- SYSTEM_PROMPT_PRESETS = {
667
- "(none)": "",
668
- "Helpful Assistant": "You are a helpful, friendly assistant.",
669
- "Pirate": "You are a pirate. Respond to everything in pirate speak, using nautical terms and saying 'arr' frequently.",
670
- "Formal Academic": "You are a formal academic scholar. Use precise, scholarly language. Cite concepts carefully and avoid casual tone.",
671
- "Five-Year-Old": "You are explaining things to a five-year-old. Use very simple words, short sentences, and fun comparisons.",
672
- "Hostile / Rude": "You are rude and dismissive. You answer questions but with obvious annoyance and sarcasm.",
673
- "Haiku Only": "You must respond only in haiku (5-7-5 syllable format). Never break this rule.",
674
- "Spanish Tutor": "You are a Spanish language tutor. Respond in Spanish, then provide the English translation in parentheses.",
675
- }
676
 
677
 
678
  def _esc_terminal(text: str) -> str:
@@ -680,82 +678,139 @@ def _esc_terminal(text: str) -> str:
680
  return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
681
 
682
 
683
- def run_system_prompt_explorer(system_prompt, user_message, max_tokens, temperature, seed):
684
- """Generate a chat response and return formatted terminal + response HTML."""
685
- if not manager.is_ready():
686
- return (
687
- f"<div class='green-terminal'><span class='sp-special'>Error: {manager.status_message()}</span></div>",
688
- "<div class='response-card' style='color:red;'>No model loaded. Load an instruct model from the Admin tab.</div>",
689
- )
690
 
691
- if not manager.is_instruct():
692
- return (
693
- "<div class='green-terminal'><span class='sp-special'>⚠ Current model is not an instruct model.\n\nLoad an instruct model (e.g. Qwen2.5-3B-Instruct) from the Admin tab.</span></div>",
694
- "<div class='response-card'>The System Prompt Explorer requires an instruct/chat model. Base models don't understand system prompts.</div>",
695
- )
696
 
697
- if not user_message.strip():
698
- return (
699
- "<div class='green-terminal'><span class='sp-special'>Enter a message below and click Generate.</span></div>",
700
- "",
701
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702
 
703
  result = manager.generate_chat(
704
- system_prompt=system_prompt,
705
- user_message=user_message,
706
  max_new_tokens=int(max_tokens),
707
  temperature=temperature,
708
  seed=int(seed),
709
  )
710
 
711
  if "error" in result:
712
- return (
713
- f"<div class='green-terminal'><span class='sp-special'>Error: {_esc_terminal(result['error'])}</span></div>",
714
- "",
715
- )
 
 
 
 
 
 
 
716
 
717
- # Render the formatted prompt in the green terminal
718
- formatted = result["formatted_prompt"]
719
- # Color-code the special tokens and roles
720
- terminal_html = _esc_terminal(formatted)
721
- # Highlight Qwen-style chat template tokens
722
- for tag in ["<|im_start|>", "<|im_end|>"]:
723
- terminal_html = terminal_html.replace(
724
- _esc_terminal(tag),
725
- f"<span class='sp-special'>{_esc_terminal(tag)}</span>",
726
- )
727
- terminal_html = terminal_html.replace(
728
- "system\n", "<span class='sp-label'>system</span>\n"
729
- )
730
- terminal_html = terminal_html.replace(
731
- "user\n", "<span class='sp-label'>user</span>\n"
732
- )
733
- terminal_html = terminal_html.replace(
734
- "assistant\n", "<span class='sp-label'>assistant</span>\n"
735
- )
736
- # Highlight the system prompt content
737
- if system_prompt.strip():
738
- terminal_html = terminal_html.replace(
739
- _esc_terminal(system_prompt),
740
- f"<span class='sp-system'>{_esc_terminal(system_prompt)}</span>",
741
- )
742
- # Highlight user message
743
- terminal_html = terminal_html.replace(
744
- _esc_terminal(user_message),
745
- f"<span class='sp-user'>{_esc_terminal(user_message)}</span>",
746
- )
747
 
748
- terminal_out = f"<div class='green-terminal'>{terminal_html}</div>"
 
 
 
 
 
 
 
 
 
 
 
749
 
750
- # Render the response in a clean card
751
- response_html = f"<div class='response-card'>{_esc(result['response'])}</div>"
752
 
753
- return terminal_out, response_html
 
 
 
 
 
 
 
 
 
754
 
755
 
756
  def on_preset_change(preset_name):
757
- """Update system prompt textbox when a preset is selected."""
758
- return SYSTEM_PROMPT_PRESETS.get(preset_name, "")
759
 
760
 
761
  # ---------------------------------------------------------------------------
@@ -770,13 +825,21 @@ def admin_login(password):
770
 
771
 
772
  def admin_load_model(model_name):
773
- """Load a new model from admin panel."""
774
  status = manager.load_model(model_name)
775
  cfg = manager.get_config()
776
  header_status = f"**{manager.status_message()}**"
777
  return status, json.dumps(cfg, indent=2), header_status
778
 
779
 
 
 
 
 
 
 
 
 
780
  def admin_save_defaults(prompt, tokenizer_text, temperature, top_k, steps, seed):
781
  """Save default settings and return updated values for all outputs."""
782
  manager.update_config(
@@ -800,6 +863,41 @@ def admin_save_defaults(prompt, tokenizer_text, temperature, top_k, steps, seed)
800
  )
801
 
802
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
803
  # ---------------------------------------------------------------------------
804
  # Build the Gradio app
805
  # ---------------------------------------------------------------------------
@@ -978,44 +1076,36 @@ def create_app():
978
  gr.Markdown("### System Prompt Explorer")
979
  gr.Markdown(
980
  "See how **system prompts** change an LLM's behavior. "
 
981
  "The green terminal shows exactly what the model receives — "
982
- "special tokens, role labels, and all. "
983
- "Try different presets or write your own."
984
  )
985
 
986
- # Instruct model notice
987
- sp_model_notice = gr.Markdown(
988
- "*Requires an instruct model (e.g. Qwen2.5-3B-Instruct). "
989
- "Load one from the Admin tab.*",
990
- elem_classes=["param-help"],
991
- )
992
 
993
  sp_preset = gr.Dropdown(
994
  label="Preset",
995
- choices=list(SYSTEM_PROMPT_PRESETS.keys()),
996
- value="Helpful Assistant",
997
  interactive=True,
998
  )
999
  sp_system = gr.Textbox(
1000
  label="System Prompt",
1001
- value=SYSTEM_PROMPT_PRESETS["Helpful Assistant"],
1002
  lines=3,
1003
  placeholder="Enter a system prompt, or select a preset above...",
1004
  )
1005
- sp_user = gr.Textbox(
1006
- label="User Message",
1007
- value="What is Huston-Tillotson University?",
1008
- lines=2,
1009
- )
1010
 
1011
  with gr.Accordion("Settings", open=False):
1012
  sp_max_tokens = gr.Slider(
1013
  label="Max tokens",
1014
- minimum=32, maximum=512, step=16,
1015
- value=256,
1016
  )
1017
  gr.Markdown(
1018
- "Maximum number of tokens in the response.",
1019
  elem_classes=["param-help"],
1020
  )
1021
  sp_temperature = gr.Slider(
@@ -1029,33 +1119,69 @@ def create_app():
1029
  precision=0,
1030
  )
1031
 
1032
- sp_generate_btn = gr.Button("Generate", variant="primary")
 
 
 
 
 
 
1033
 
1034
- gr.Markdown("#### What the model sees")
1035
  gr.Markdown(
1036
- "This is the actual text sent to the model, including special tokens "
1037
- "that mark where system instructions, user messages, and assistant "
1038
- "responses begin and end.",
 
 
1039
  elem_classes=["param-help"],
1040
  )
1041
- sp_terminal = gr.HTML(
1042
- value="<div class='green-terminal'><span class='sp-special'>Select a preset and enter a message, then click Generate.</span></div>",
1043
- )
 
 
 
 
 
 
 
 
 
1044
 
1045
- gr.Markdown("#### Model response")
1046
- sp_response = gr.HTML(value="")
1047
 
1048
- # Wiring
1049
  sp_preset.change(
1050
  fn=on_preset_change,
1051
  inputs=[sp_preset],
1052
  outputs=[sp_system],
1053
  )
1054
 
1055
- sp_generate_btn.click(
1056
- fn=run_system_prompt_explorer,
1057
- inputs=[sp_system, sp_user, sp_max_tokens, sp_temperature, sp_seed],
1058
- outputs=[sp_terminal, sp_response],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1059
  )
1060
 
1061
  # ==================================================================
@@ -1076,16 +1202,26 @@ def create_app():
1076
 
1077
  # Admin controls (hidden until login)
1078
  with gr.Group(visible=False) as admin_controls:
1079
- gr.Markdown("#### Model")
1080
  with gr.Row():
1081
  admin_model_dropdown = gr.Dropdown(
1082
  choices=list(AVAILABLE_MODELS.keys()),
1083
- value=manager.current_model_name or cfg.get("model", "Qwen2.5-3B"),
1084
  label="Select model",
1085
  )
1086
- admin_load_btn = gr.Button("Load Model", variant="primary")
1087
  admin_model_status = gr.Markdown("")
1088
 
 
 
 
 
 
 
 
 
 
 
1089
  gr.Markdown("---")
1090
  gr.Markdown("#### Default Settings")
1091
  admin_prompt = gr.Textbox(
@@ -1120,13 +1256,48 @@ def create_app():
1120
  admin_save_msg = gr.Markdown("")
1121
 
1122
  gr.Markdown("---")
1123
- gr.Markdown("#### Export Slides")
1124
  gr.Markdown(
1125
- "*Uses current settings from Probability Explorer tab.*",
 
1126
  elem_classes=["param-help"],
1127
  )
1128
- admin_export_btn = gr.Button("Export Slides", variant="secondary")
1129
- admin_slides_file = gr.File(label="Slideshow", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1130
 
1131
  gr.Markdown("---")
1132
  gr.Markdown("#### Current Config")
@@ -1136,6 +1307,15 @@ def create_app():
1136
  interactive=False,
1137
  )
1138
 
 
 
 
 
 
 
 
 
 
1139
  # Login wiring
1140
  admin_login_btn.click(
1141
  fn=admin_login,
@@ -1143,13 +1323,20 @@ def create_app():
1143
  outputs=[admin_controls, admin_login_group, admin_login_msg],
1144
  )
1145
 
1146
- # Model loading
1147
  admin_load_btn.click(
1148
  fn=admin_load_model,
1149
  inputs=[admin_model_dropdown],
1150
  outputs=[admin_model_status, admin_config_display, status_display],
1151
  )
1152
 
 
 
 
 
 
 
 
1153
  # Save defaults — updates config display + Probability Explorer + Tokenizer controls
1154
  admin_save_btn.click(
1155
  fn=admin_save_defaults,
@@ -1166,6 +1353,13 @@ def create_app():
1166
  ],
1167
  )
1168
 
 
 
 
 
 
 
 
1169
  # Export slides — uses current Probability Explorer settings
1170
  admin_export_btn.click(
1171
  fn=generate_slideshow,
@@ -1327,12 +1521,17 @@ def create_app():
1327
  # ---------------------------------------------------------------------------
1328
 
1329
  if __name__ == "__main__":
1330
- # Load default model on startup
1331
  cfg = manager.get_config()
1332
- model_to_load = cfg.get("model", "Qwen2.5-3B")
1333
- print(f"Loading default model: {model_to_load}")
1334
- status = manager.load_model(model_to_load)
1335
- print(status)
 
 
 
 
 
 
1336
 
1337
  app = create_app()
1338
  app.launch(
 
17
 
18
  import gradio as gr
19
 
20
+ import re
21
+
22
  from models import AVAILABLE_MODELS, manager, demo_tokenizer
23
 
24
  # ---------------------------------------------------------------------------
 
665
  # Tab 3: System Prompt Explorer
666
  # ---------------------------------------------------------------------------
667
 
668
+ MAX_CHAT_TURNS = 10 # max user messages before forcing reset
669
+
670
+
671
def _get_presets() -> dict:
    """Return the ``system_prompt_presets`` entry of the current config.

    An empty dict is returned when the config has no such entry.
    """
    config = manager.get_config()
    return config.get("system_prompt_presets", {})
 
 
 
 
674
 
675
 
676
  def _esc_terminal(text: str) -> str:
 
678
  return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
679
 
680
 
681
def _format_terminal(raw_text: str) -> str:
    """Render a chat-template string as color-coded HTML for the green terminal.

    The text is split on special tokens of the form ``<|...|>``:

    * every special token is wrapped in an ``sp-special`` span;
    * a role label (system / user / assistant) that directly follows an
      ``im_start`` or ``start_header_id`` token is wrapped in ``sp-label``;
    * any other non-blank text is wrapped in the span class of the most
      recently seen role (``sp-system`` / ``sp-user`` / ``sp-assistant``).

    Two label layouts are recognised:

    * the label is the whole chunk between two special tokens
      (``<|start_header_id|>user<|end_header_id|>``);
    * the label is the first line of a chunk that also carries the message
      body (``<|im_start|>user``, a newline, then the message text).

    Args:
        raw_text: The formatted prompt text to display.

    Returns:
        HTML wrapped in ``<div class='green-terminal'>``.
    """
    role_css = {
        "system": "sp-system",
        "user": "sp-user",
        "assistant": "sp-assistant",
    }

    def content_html(text, role):
        """Escape message text and color it by role (blank text stays bare)."""
        css = role_css.get(role, "")
        if css and text.strip():
            return f"<span class='{css}'>{_esc_terminal(text)}</span>"
        return _esc_terminal(text)

    html_parts = []
    current_role = None
    expect_role = False  # true right after a token that precedes a role label

    # Split on special tokens, keeping them in the result.
    for part in re.split(r'(<\|[^|]*\|>)', raw_text):
        if re.match(r'<\|[^|]*\|>', part):
            html_parts.append(f"<span class='sp-special'>{_esc_terminal(part)}</span>")
            # After im_start or start_header_id, the next chunk starts with a role label.
            expect_role = "im_start" in part or "start_header_id" in part
            continue

        label, rest = part, ""
        if expect_role and part.strip() not in role_css:
            # The chunk is not a bare label: the label may be its first line,
            # with the message body following the newline.
            first_line, newline, remainder = part.partition("\n")
            if first_line.strip() in role_css:
                label, rest = first_line, newline + remainder

        is_label = expect_role and label.strip() in role_css
        expect_role = False

        if not is_label:
            html_parts.append(content_html(part, current_role))
            continue

        current_role = label.strip()
        # Preserve whitespace around the label outside the highlighted span.
        lead = _esc_terminal(label[: len(label) - len(label.lstrip())])
        trail = _esc_terminal(label[len(label.rstrip()):])
        html_parts.append(f"{lead}<span class='sp-label'>{current_role}</span>{trail}")
        if rest:
            html_parts.append(content_html(rest, current_role))

    return "<div class='green-terminal'>" + "".join(html_parts) + "</div>"
722
+
723
+
724
+ def _initial_terminal() -> str:
725
+ return "<div class='green-terminal'><span class='sp-special'>Send a message to see what the model receives.</span></div>"
726
+
727
+
728
def send_chat_message(user_message, history, system_prompt, max_tokens, temperature, seed):
    """Handle a user message: generate a response, update state + chatbot + terminal.

    `history` is the gr.State list of clean {"role": ..., "content": ...} dicts
    and is the only source for what is sent to the model. The Chatbot value is
    rebuilt from it on every call -- never read back from the Chatbot (Gradio
    mangles the dicts on round-trip).

    App-generated notices (no chat model loaded, turn limit reached, generation
    error) are shown in the Chatbot for the current turn only. They are NOT
    stored in `history`: otherwise they would be replayed to the model as if
    the assistant had said them, and a failed user turn would be left without
    a real reply.

    Args:
        user_message: Text typed by the user; blank input is a no-op.
        history: Conversation so far (user / assistant dicts, no system entry).
        system_prompt: Optional system prompt, prepended when non-blank.
        max_tokens: Response token budget (converted with ``int``).
        temperature: Sampling temperature, passed through unchanged.
        seed: Sampling seed (converted with ``int``).

    Returns:
        A 4-tuple ``(textbox value, history state, chatbot messages,
        terminal HTML)``; the textbox value is always ``""``.
    """

    def chatbot_view(msgs):
        """Copy only role/content so the Chatbot never shares dicts with the state."""
        return [{"role": m["role"], "content": m["content"]} for m in msgs]

    if not user_message or not user_message.strip():
        return "", history, chatbot_view(history), _format_terminal_from_history(history, system_prompt)

    user_turn = {"role": "user", "content": user_message}

    def with_notice(text):
        """Chatbot view of the stored history plus a display-only exchange."""
        return chatbot_view(history + [user_turn, {"role": "assistant", "content": text}])

    if not manager.chat_ready():
        return (
            "",
            history,
            with_notice("No chat model loaded. Load one from the Admin tab."),
            _initial_terminal(),
        )

    # Check turn limit
    user_turns = sum(1 for m in history if m["role"] == "user")
    if user_turns >= MAX_CHAT_TURNS:
        limit_notice = f"Conversation limit reached ({MAX_CHAT_TURNS} exchanges). Click Reset to start a new conversation."
        return (
            "",
            history,
            with_notice(limit_notice),
            _format_terminal_from_history(history, system_prompt),
        )

    # Build full messages for the model
    pending = history + [user_turn]
    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(pending)

    result = manager.generate_chat(
        messages=messages,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        seed=int(seed),
    )

    if "error" in result:
        # Keep the stored history as it was before this turn so a retry starts clean.
        return (
            "",
            history,
            with_notice(f"Error: {result['error']}"),
            _format_terminal_from_history(history, system_prompt),
        )

    history = pending + [{"role": "assistant", "content": result["response"]}]
    terminal_html = _format_terminal(result["formatted_display"])

    return "", history, chatbot_view(history), terminal_html
783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
784
 
785
def _format_terminal_from_history(chat_history, system_prompt):
    """Render the terminal view for an existing conversation (no generation).

    Falls back to the placeholder terminal when there is no history or when
    the chat template yields nothing.
    """
    if not chat_history:
        return _initial_terminal()

    has_system = bool(system_prompt and system_prompt.strip())
    system_part = [{"role": "system", "content": system_prompt}] if has_system else []
    rendered = manager.format_chat_template(system_part + list(chat_history))

    return _format_terminal(rendered) if rendered else _initial_terminal()
797
 
 
 
798
 
799
def reset_chat(system_prompt):
    """Clear chat history, keep system prompt.

    When a non-blank system prompt is set and a chat model is ready, the
    terminal shows the template rendered with just that system prompt;
    otherwise (or when the template yields nothing) it shows the placeholder.

    Returns (state, chatbot, terminal).
    """
    if system_prompt and system_prompt.strip() and manager.chat_ready():
        messages = [{"role": "system", "content": system_prompt}]
        formatted = manager.format_chat_template(messages)
        # Same guard as _format_terminal_from_history: an empty template result
        # must not be handed to _format_terminal.
        if formatted:
            return [], [], _format_terminal(formatted)
    return [], [], _initial_terminal()
809
 
810
 
811
  def on_preset_change(preset_name):
812
+ """Update system prompt textbox when a preset is selected. No chat reset."""
813
+ return _get_presets().get(preset_name, "")
814
 
815
 
816
  # ---------------------------------------------------------------------------
 
825
 
826
 
827
  def admin_load_model(model_name):
828
+ """Load a new base model from admin panel."""
829
  status = manager.load_model(model_name)
830
  cfg = manager.get_config()
831
  header_status = f"**{manager.status_message()}**"
832
  return status, json.dumps(cfg, indent=2), header_status
833
 
834
 
835
def admin_load_chat_model(model_name):
    """Load the selected chat model from the admin panel.

    Returns (load status, pretty-printed config JSON, bold header status).
    """
    load_status = manager.load_chat_model(model_name)
    config_json = json.dumps(manager.get_config(), indent=2)
    header = f"**{manager.status_message()}**"
    return load_status, config_json, header
841
+
842
+
843
  def admin_save_defaults(prompt, tokenizer_text, temperature, top_k, steps, seed):
844
  """Save default settings and return updated values for all outputs."""
845
  manager.update_config(
 
863
  )
864
 
865
 
866
def admin_save_presets(presets_json):
    """Validate and save system prompt presets from the admin panel.

    The input must be a JSON object mapping preset names to prompt strings.
    Anything else is rejected without touching the config, because each
    preset value is later placed in the system prompt textbox and treated
    as text.

    Args:
        presets_json: JSON text from the admin presets editor.

    Returns:
        (status_msg, config_json, dropdown_update, presets_json_display).
        On rejection the last two are no-op ``gr.update()`` values so the
        dropdown and the editor contents stay as they are.
    """

    def rejected(message):
        """Build the 4-tuple for a failed save (config shown, nothing updated)."""
        cfg = manager.get_config()
        return (
            message,
            json.dumps(cfg, indent=2),
            gr.update(),
            gr.update(),
        )

    try:
        presets = json.loads(presets_json)
    except (json.JSONDecodeError, TypeError) as e:
        return rejected(f"Invalid JSON: {e}")

    if not isinstance(presets, dict):
        return rejected("Presets must be a JSON object `{\"Name\": \"prompt\", ...}`")

    # JSON object keys are always strings; values are not, so check them.
    bad_names = [name for name, prompt in presets.items() if not isinstance(prompt, str)]
    if bad_names:
        listed = ", ".join(f"`{name}`" for name in bad_names)
        return rejected(f"Preset prompts must be strings. Check: {listed}")

    manager.update_config(system_prompt_presets=presets)
    cfg = manager.get_config()
    return (
        f"Presets saved ({len(presets)} presets).",
        json.dumps(cfg, indent=2),
        gr.update(choices=list(presets.keys())),
        json.dumps(presets, indent=2),
    )
899
+
900
+
901
  # ---------------------------------------------------------------------------
902
  # Build the Gradio app
903
  # ---------------------------------------------------------------------------
 
1076
  gr.Markdown("### System Prompt Explorer")
1077
  gr.Markdown(
1078
  "See how **system prompts** change an LLM's behavior. "
1079
+ "Pick a preset or write your own, then chat with the model. "
1080
  "The green terminal shows exactly what the model receives — "
1081
+ "every special token, every role label, every turn."
 
1082
  )
1083
 
1084
+ presets = _get_presets()
1085
+ preset_names = list(presets.keys())
1086
+ default_preset = "Helpful Assistant" if "Helpful Assistant" in presets else preset_names[0] if preset_names else ""
 
 
 
1087
 
1088
  sp_preset = gr.Dropdown(
1089
  label="Preset",
1090
+ choices=preset_names,
1091
+ value=default_preset,
1092
  interactive=True,
1093
  )
1094
  sp_system = gr.Textbox(
1095
  label="System Prompt",
1096
+ value=presets.get(default_preset, ""),
1097
  lines=3,
1098
  placeholder="Enter a system prompt, or select a preset above...",
1099
  )
 
 
 
 
 
1100
 
1101
  with gr.Accordion("Settings", open=False):
1102
  sp_max_tokens = gr.Slider(
1103
  label="Max tokens",
1104
+ minimum=32, maximum=1024, step=16,
1105
+ value=512,
1106
  )
1107
  gr.Markdown(
1108
+ "Maximum number of tokens per response.",
1109
  elem_classes=["param-help"],
1110
  )
1111
  sp_temperature = gr.Slider(
 
1119
  precision=0,
1120
  )
1121
 
1122
+ with gr.Accordion("What the model sees", open=False):
1123
+ gr.Markdown(
1124
+ "The full text sent to the model on every turn — system prompt, "
1125
+ "all previous messages, and special tokens. Watch it grow with each exchange.",
1126
+ elem_classes=["param-help"],
1127
+ )
1128
+ sp_terminal = gr.HTML(value=_initial_terminal())
1129
 
1130
+ gr.Markdown("#### Chat")
1131
  gr.Markdown(
1132
+ "**No hidden system prompt.** This model's helpful behavior comes from "
1133
+ "fine-tuning (RLHF), not a secret prompt. When you add a system prompt above, "
1134
+ "it's the *only* instruction the model receives. Commercial APIs like ChatGPT "
1135
+ "and Claude prepend their own system prompts before yours — you can't see or "
1136
+ "remove them.",
1137
  elem_classes=["param-help"],
1138
  )
1139
+ sp_chat_state = gr.State([])
1140
+ sp_chatbot = gr.Chatbot(height=700, feedback_options=None)
1141
+ with gr.Row():
1142
+ sp_user_input = gr.Textbox(
1143
+ label="Message",
1144
+ placeholder="Type a message...",
1145
+ lines=1,
1146
+ scale=4,
1147
+ show_label=False,
1148
+ )
1149
+ sp_send_btn = gr.Button("Send", variant="primary", scale=0, min_width=80)
1150
+ sp_reset_btn = gr.Button("Reset", variant="secondary", scale=0, min_width=80)
1151
 
1152
+ # --- Wiring ---
 
1153
 
1154
+ # Preset dropdown → just fill in the textbox (no chat reset)
1155
  sp_preset.change(
1156
  fn=on_preset_change,
1157
  inputs=[sp_preset],
1158
  outputs=[sp_system],
1159
  )
1160
 
1161
+ # System prompt textbox edits take effect on the next message sent.
1162
+ # No auto-reset — avoids losing conversation on accidental edits.
1163
+ # Use Reset button or pick a new preset to start fresh.
1164
+
1165
+ # Send message (button or enter)
1166
+ send_inputs = [sp_user_input, sp_chat_state, sp_system, sp_max_tokens, sp_temperature, sp_seed]
1167
+ send_outputs = [sp_user_input, sp_chat_state, sp_chatbot, sp_terminal]
1168
+
1169
+ sp_send_btn.click(
1170
+ fn=send_chat_message,
1171
+ inputs=send_inputs,
1172
+ outputs=send_outputs,
1173
+ )
1174
+ sp_user_input.submit(
1175
+ fn=send_chat_message,
1176
+ inputs=send_inputs,
1177
+ outputs=send_outputs,
1178
+ )
1179
+
1180
+ # Reset button
1181
+ sp_reset_btn.click(
1182
+ fn=reset_chat,
1183
+ inputs=[sp_system],
1184
+ outputs=[sp_chat_state, sp_chatbot, sp_terminal],
1185
  )
1186
 
1187
  # ==================================================================
 
1202
 
1203
  # Admin controls (hidden until login)
1204
  with gr.Group(visible=False) as admin_controls:
1205
+ gr.Markdown("#### Base Model (Probability Explorer)")
1206
  with gr.Row():
1207
  admin_model_dropdown = gr.Dropdown(
1208
  choices=list(AVAILABLE_MODELS.keys()),
1209
+ value=manager.current_model_name or cfg.get("model", "Llama-3.2-3B"),
1210
  label="Select model",
1211
  )
1212
+ admin_load_btn = gr.Button("Load", variant="primary")
1213
  admin_model_status = gr.Markdown("")
1214
 
1215
+ gr.Markdown("#### Chat Model (System Prompt Explorer)")
1216
+ with gr.Row():
1217
+ admin_chat_dropdown = gr.Dropdown(
1218
+ choices=list(AVAILABLE_MODELS.keys()),
1219
+ value=manager.chat_model_name or cfg.get("chat_model", "Llama-3.2-3B-Instruct"),
1220
+ label="Select chat model",
1221
+ )
1222
+ admin_chat_load_btn = gr.Button("Load", variant="primary")
1223
+ admin_chat_status = gr.Markdown("")
1224
+
1225
  gr.Markdown("---")
1226
  gr.Markdown("#### Default Settings")
1227
  admin_prompt = gr.Textbox(
 
1256
  admin_save_msg = gr.Markdown("")
1257
 
1258
  gr.Markdown("---")
1259
+ gr.Markdown("#### System Prompt Presets")
1260
  gr.Markdown(
1261
+ "Edit the presets available in the System Prompt Explorer dropdown. "
1262
+ "JSON object: `{\"Name\": \"prompt text\", ...}`",
1263
  elem_classes=["param-help"],
1264
  )
1265
+ admin_presets = gr.Code(
1266
+ value=json.dumps(cfg.get("system_prompt_presets", {}), indent=2),
1267
+ language="json",
1268
+ interactive=True,
1269
+ )
1270
+ admin_presets_save_btn = gr.Button("Save Presets")
1271
+ admin_presets_msg = gr.Markdown("")
1272
+
1273
+ gr.Markdown("---")
1274
+ with gr.Accordion("Environment Variables Reference", open=False):
1275
+ _pw_status = "*(set)*" if os.environ.get("ADMIN_PASSWORD") else "*(default: admin)*"
1276
+ _rb_status = "*(set)*" if REBRANDLY_API_KEY else "*(not set)*"
1277
+ gr.Markdown(
1278
+ "Override settings via "
1279
+ "[HF Space Settings](https://huggingface.co/spaces/chyams/llm-explorer/settings). "
1280
+ "Use **Secrets** for sensitive values (encrypted, hidden after saving) "
1281
+ "and **Variables** for everything else (visible in settings).\n\n"
1282
+ "**Precedence:** env var > config.json > code defaults\n\n"
1283
+ "**Secrets** (sensitive — encrypted)\n\n"
1284
+ "| Variable | Description | Format | Current |\n"
1285
+ "|----------|-------------|--------|---------|\n"
1286
+ f"| `ADMIN_PASSWORD` | Admin panel password | Plain text | {_pw_status} |\n"
1287
+ f"| `REBRANDLY_API_KEY` | URL shortener API key | API key | {_rb_status} |\n"
1288
+ "\n**Variables** (non-sensitive — visible)\n\n"
1289
+ "| Variable | Description | Format | Current |\n"
1290
+ "|----------|-------------|--------|---------|\n"
1291
+ f"| `DEFAULT_MODEL` | Base model (Prob Explorer) | Model name | `{cfg.get('model', '')}` |\n"
1292
+ f"| `DEFAULT_CHAT_MODEL` | Chat model (Sys Prompt Explorer) | Model name | `{cfg.get('chat_model', '')}` |\n"
1293
+ f"| `DEFAULT_PROMPT` | Default prompt | Plain text | `{cfg.get('default_prompt', '')[:40]}...` |\n"
1294
+ f"| `DEFAULT_TEMPERATURE` | Default temperature | Number (0–2.5) | `{cfg.get('default_temperature', 0.8)}` |\n"
1295
+ f"| `DEFAULT_TOP_K` | Default top-k | Integer (5–100) | `{cfg.get('default_top_k', 10)}` |\n"
1296
+ f"| `DEFAULT_STEPS` | Default steps | Integer (1–100) | `{cfg.get('default_steps', 8)}` |\n"
1297
+ f"| `DEFAULT_SEED` | Default seed | Integer | `{cfg.get('default_seed', 42)}` |\n"
1298
+ f"| `DEFAULT_TOKENIZER_TEXT` | Default tokenizer text | Plain text | `{cfg.get('default_tokenizer_text', '')[:40]}...` |\n"
1299
+ f"| `SYSTEM_PROMPT_PRESETS` | System prompt presets | JSON object | *({len(cfg.get('system_prompt_presets', {}))} presets)* |"
1300
+ )
1301
 
1302
  gr.Markdown("---")
1303
  gr.Markdown("#### Current Config")
 
1307
  interactive=False,
1308
  )
1309
 
1310
+ gr.Markdown("---")
1311
+ gr.Markdown("#### Export Slides")
1312
+ gr.Markdown(
1313
+ "*Uses current settings from Probability Explorer tab.*",
1314
+ elem_classes=["param-help"],
1315
+ )
1316
+ admin_export_btn = gr.Button("Export Slides", variant="secondary")
1317
+ admin_slides_file = gr.File(label="Slideshow", visible=False)
1318
+
1319
  # Login wiring
1320
  admin_login_btn.click(
1321
  fn=admin_login,
 
1323
  outputs=[admin_controls, admin_login_group, admin_login_msg],
1324
  )
1325
 
1326
+ # Model loading — base
1327
  admin_load_btn.click(
1328
  fn=admin_load_model,
1329
  inputs=[admin_model_dropdown],
1330
  outputs=[admin_model_status, admin_config_display, status_display],
1331
  )
1332
 
1333
+ # Model loading — chat
1334
+ admin_chat_load_btn.click(
1335
+ fn=admin_load_chat_model,
1336
+ inputs=[admin_chat_dropdown],
1337
+ outputs=[admin_chat_status, admin_config_display, status_display],
1338
+ )
1339
+
1340
  # Save defaults — updates config display + Probability Explorer + Tokenizer controls
1341
  admin_save_btn.click(
1342
  fn=admin_save_defaults,
 
1353
  ],
1354
  )
1355
 
1356
+ # Save presets — updates config, dropdown choices, and presets display
1357
+ admin_presets_save_btn.click(
1358
+ fn=admin_save_presets,
1359
+ inputs=[admin_presets],
1360
+ outputs=[admin_presets_msg, admin_config_display, sp_preset, admin_presets],
1361
+ )
1362
+
1363
  # Export slides — uses current Probability Explorer settings
1364
  admin_export_btn.click(
1365
  fn=generate_slideshow,
 
1521
  # ---------------------------------------------------------------------------
1522
 
1523
  if __name__ == "__main__":
 
1524
  cfg = manager.get_config()
1525
+
1526
+ # Load base model (Probability Explorer)
1527
+ base_model = cfg.get("model", "Llama-3.2-3B")
1528
+ print(f"Loading base model: {base_model}")
1529
+ print(manager.load_model(base_model))
1530
+
1531
+ # Load chat model (System Prompt Explorer)
1532
+ chat_model = cfg.get("chat_model", "Llama-3.2-3B-Instruct")
1533
+ print(f"Loading chat model: {chat_model}")
1534
+ print(manager.load_chat_model(chat_model))
1535
 
1536
  app = create_app()
1537
  app.launch(
config.json CHANGED
@@ -5,5 +5,19 @@
5
  "default_top_k": 10,
6
  "default_steps": 8,
7
  "default_seed": 1875,
8
- "default_tokenizer_text": "Class was rescheduled due to Huston-Tillotson homecoming."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  }
 
5
  "default_top_k": 10,
6
  "default_steps": 8,
7
  "default_seed": 1875,
8
+ "default_tokenizer_text": "Class was rescheduled due to Huston-Tillotson homecoming.",
9
+ "system_prompt_presets": {
10
+ "(none)": "",
11
+ "Helpful Assistant": "You are a helpful, friendly assistant.",
12
+ "Pirate": "You are a pirate. Respond to everything in pirate speak, using nautical terms and saying 'arr' frequently.",
13
+ "Formal Academic": "You are a formal academic scholar. Use precise, scholarly language. Cite concepts carefully and avoid casual tone.",
14
+ "Five-Year-Old": "You are explaining things to a five-year-old. Use very simple words, short sentences, and fun comparisons.",
15
+ "Hostile / Rude": "You are rude and dismissive. You answer questions but with obvious annoyance and sarcasm.",
16
+ "Haiku Only": "You must respond only in haiku (5-7-5 syllable format). Never break this rule.",
17
+ "Spanish Tutor": "You are a Spanish language tutor. Respond in Spanish, then provide the English translation in parentheses.",
18
+ "Banana Constraint": "You must mention bananas in every response, no matter the topic. Be subtle about it.",
19
+ "Corporate Spin": "You are a customer service agent. Never acknowledge product flaws. Always redirect to positive features.",
20
+ "Prestige Bias": "When discussing job candidates, always favor candidates from prestigious universities over others."
21
+ },
22
+ "chat_model": "Llama-3.2-3B-Instruct"
23
  }
models.py CHANGED
@@ -44,6 +44,12 @@ AVAILABLE_MODELS = {
44
  "description": "Best quality, quantized",
45
  },
46
  # -- Instruct models (for System Prompt Explorer) --
 
 
 
 
 
 
47
  "Qwen2.5-3B-Instruct": {
48
  "id": "Qwen/Qwen2.5-3B-Instruct",
49
  "dtype": "float16",
@@ -75,8 +81,36 @@ def _detect_device() -> str:
75
  return "cpu"
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def _load_config() -> dict:
79
- """Load persisted config or return defaults."""
80
  defaults = {
81
  "model": DEFAULT_MODEL,
82
  "default_prompt": "The best thing about Huston-Tillotson University is",
@@ -85,7 +119,9 @@ def _load_config() -> dict:
85
  "default_steps": 8,
86
  "default_seed": 42,
87
  "default_tokenizer_text": "Huston-Tillotson University is an HBCU in Austin, Texas.",
 
88
  }
 
89
  if CONFIG_PATH.exists():
90
  try:
91
  with open(CONFIG_PATH) as f:
@@ -93,6 +129,17 @@ def _load_config() -> dict:
93
  defaults.update(saved)
94
  except (json.JSONDecodeError, OSError):
95
  pass
 
 
 
 
 
 
 
 
 
 
 
96
  return defaults
97
 
98
 
@@ -107,114 +154,171 @@ def _save_config(cfg: dict) -> None:
107
  # ---------------------------------------------------------------------------
108
 
109
  class ModelManager:
110
- """Manages a single active model with hot-swap capability."""
111
 
112
  def __init__(self):
 
113
  self.model = None
114
  self.tokenizer = None
115
  self.current_model_name: str | None = None
 
 
 
 
 
 
116
  self.device: str = _detect_device()
117
  self.loading = False
118
  self._lock = threading.Lock()
119
  self.config = _load_config()
120
 
121
  # ------------------------------------------------------------------
122
- # Model lifecycle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  # ------------------------------------------------------------------
124
 
125
  def load_model(self, model_name: str) -> str:
126
- """Load a model by its display name. Returns status message."""
127
  if model_name not in AVAILABLE_MODELS:
128
  return f"Unknown model: {model_name}"
129
 
130
  if self.loading:
131
  return "A model is already being loaded. Please wait."
132
 
133
- spec = AVAILABLE_MODELS[model_name]
134
-
135
- # Quantized models require CUDA (bitsandbytes doesn't support MPS/CPU)
136
- if spec.get("quantize") and not torch.cuda.is_available():
137
- return (f"Cannot load {model_name}: "
138
- f"{spec['quantize']} quantization requires an NVIDIA GPU (CUDA). "
139
- f"Try a non-quantized model for local development.")
140
-
141
  with self._lock:
142
  self.loading = True
143
  try:
144
- # Unload current model
145
- self._unload()
146
-
147
- # Determine load kwargs
148
- model_id = spec["id"]
149
- load_kwargs: dict = {"device_map": "auto"}
150
-
151
- if spec.get("quantize") == "4bit":
152
- from transformers import BitsAndBytesConfig
153
- load_kwargs["quantization_config"] = BitsAndBytesConfig(
154
- load_in_4bit=True,
155
- bnb_4bit_compute_dtype=torch.float16,
156
- )
157
- elif spec.get("quantize") == "8bit":
158
- from transformers import BitsAndBytesConfig
159
- load_kwargs["quantization_config"] = BitsAndBytesConfig(
160
- load_in_8bit=True,
161
- )
162
- else:
163
- dtype_str = spec.get("dtype", "float16")
164
- if dtype_str == "auto":
165
- load_kwargs["dtype"] = "auto"
166
- else:
167
- load_kwargs["dtype"] = getattr(torch, dtype_str)
168
-
169
- # Load tokenizer + model
170
- self.tokenizer = AutoTokenizer.from_pretrained(model_id)
171
- self.model = AutoModelForCausalLM.from_pretrained(
172
- model_id, **load_kwargs
173
- )
174
- self.model.eval()
175
  self.current_model_name = model_name
176
 
177
- # Persist choice
178
  self.config["model"] = model_name
179
  _save_config(self.config)
180
 
181
- return f"Loaded {model_name} ({model_id})"
182
 
183
  except Exception as e:
184
- self._unload()
 
 
185
  return f"Failed to load {model_name}: {e}"
186
  finally:
187
  self.loading = False
188
 
189
- def _unload(self) -> None:
190
- """Release current model and free memory."""
191
- if self.model is not None:
192
- del self.model
193
- self.model = None
194
- if self.tokenizer is not None:
195
- del self.tokenizer
196
- self.tokenizer = None
197
- self.current_model_name = None
198
- gc.collect()
199
- if torch.cuda.is_available():
200
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  def is_ready(self) -> bool:
203
  return self.model is not None and not self.loading
204
 
205
- def is_instruct(self) -> bool:
206
- """Check if the current model is an instruct/chat model."""
207
- if self.current_model_name is None:
208
- return False
209
- spec = AVAILABLE_MODELS.get(self.current_model_name, {})
210
- return spec.get("instruct", False)
211
 
212
  def status_message(self) -> str:
213
  if self.loading:
214
  return "Loading model..."
215
- if self.model is None:
216
- return "No model loaded"
217
- return f"Model: {self.current_model_name}"
 
 
 
 
 
218
 
219
  # ------------------------------------------------------------------
220
  # Inference helpers
@@ -329,61 +433,72 @@ class ModelManager:
329
 
330
  def generate_chat(
331
  self,
332
- system_prompt: str,
333
- user_message: str,
334
  max_new_tokens: int = 256,
335
  temperature: float = 0.7,
336
  seed: int = 42,
337
  ) -> dict:
338
- """Generate a chat response using the instruct model's chat template.
 
 
 
 
339
 
340
  Returns dict with:
341
- - formatted_prompt: the full tokenized prompt with special tokens
342
  - response: the model's generated response text
343
  """
344
- if not self.is_ready():
345
- return {"error": "Model not loaded"}
346
 
347
- messages = []
348
- if system_prompt.strip():
349
- messages.append({"role": "system", "content": system_prompt})
350
- messages.append({"role": "user", "content": user_message})
351
-
352
- # Apply chat template
353
- formatted = self.tokenizer.apply_chat_template(
354
  messages, tokenize=False, add_generation_prompt=True,
355
  )
356
 
357
  # Tokenize input
358
- inputs = self.tokenizer(formatted, return_tensors="pt")
359
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
360
  input_len = inputs["input_ids"].shape[1]
361
 
362
  # Generate
363
  gen_kwargs = {
364
  "max_new_tokens": max_new_tokens,
365
  "do_sample": temperature > 0,
366
- "pad_token_id": self.tokenizer.eos_token_id,
367
  }
368
  if temperature > 0:
369
  gen_kwargs["temperature"] = temperature
370
- # Set seed for reproducibility
371
- if self.model.device.type == "cuda":
372
  torch.cuda.manual_seed(seed)
373
  torch.manual_seed(seed)
374
 
375
  with torch.no_grad():
376
- output_ids = self.model.generate(**inputs, **gen_kwargs)
377
 
378
  # Decode only the new tokens
379
  new_ids = output_ids[0][input_len:]
380
- response = self.tokenizer.decode(new_ids, skip_special_tokens=True)
 
 
 
 
 
 
381
 
382
  return {
383
- "formatted_prompt": formatted,
384
- "response": response.strip(),
385
  }
386
 
 
 
 
 
 
 
 
 
387
  def tokenize(self, text: str) -> list[tuple[str, int]]:
388
  """Tokenize text and return list of (token_str, token_id)."""
389
  if self.tokenizer is None:
 
44
  "description": "Best quality, quantized",
45
  },
46
  # -- Instruct models (for System Prompt Explorer) --
47
+ "Llama-3.2-3B-Instruct": {
48
+ "id": "meta-llama/Llama-3.2-3B-Instruct",
49
+ "dtype": "float16",
50
+ "instruct": True,
51
+ "description": "Chat/instruct model, same family as prod base model (3B)",
52
+ },
53
  "Qwen2.5-3B-Instruct": {
54
  "id": "Qwen/Qwen2.5-3B-Instruct",
55
  "dtype": "float16",
 
81
  return "cpu"
82
 
83
 
84
# Built-in system prompts offered by the System Prompt Explorer.
# Insertion order is preserved, so keep "(none)" first.
DEFAULT_SYSTEM_PROMPT_PRESETS = {
    "(none)": "",
    "Helpful Assistant": "You are a helpful, friendly assistant.",
    "Pirate": (
        "You are a pirate. Respond to everything in pirate speak, "
        "using nautical terms and saying 'arr' frequently."
    ),
    "Formal Academic": (
        "You are a formal academic scholar. Use precise, scholarly language. "
        "Cite concepts carefully and avoid casual tone."
    ),
    "Five-Year-Old": (
        "You are explaining things to a five-year-old. "
        "Use very simple words, short sentences, and fun comparisons."
    ),
    "Hostile / Rude": (
        "You are rude and dismissive. "
        "You answer questions but with obvious annoyance and sarcasm."
    ),
    "Haiku Only": (
        "You must respond only in haiku (5-7-5 syllable format). "
        "Never break this rule."
    ),
    "Spanish Tutor": (
        "You are a Spanish language tutor. "
        "Respond in Spanish, then provide the English translation in parentheses."
    ),
    "Banana Constraint": (
        "You must mention bananas in every response, no matter the topic. "
        "Be subtle about it."
    ),
    "Corporate Spin": (
        "You are a customer service agent. Never acknowledge product flaws. "
        "Always redirect to positive features."
    ),
    "Prestige Bias": (
        "When discussing job candidates, always favor candidates "
        "from prestigious universities over others."
    ),
}

# Environment variable name -> (config key, converter).
# The converter is a callable applied to the raw string, except for the
# sentinel "json", which means "decode the value as JSON".
ENV_VAR_MAP = {
    "DEFAULT_MODEL": ("model", str),
    "DEFAULT_CHAT_MODEL": ("chat_model", str),
    "DEFAULT_PROMPT": ("default_prompt", str),
    "DEFAULT_TEMPERATURE": ("default_temperature", float),
    "DEFAULT_TOP_K": ("default_top_k", int),
    "DEFAULT_STEPS": ("default_steps", int),
    "DEFAULT_SEED": ("default_seed", int),
    "DEFAULT_TOKENIZER_TEXT": ("default_tokenizer_text", str),
    "SYSTEM_PROMPT_PRESETS": ("system_prompt_presets", "json"),
}
110
+
111
+
112
  def _load_config() -> dict:
113
+ """Load config with three layers: code defaults → config.json → env vars."""
114
  defaults = {
115
  "model": DEFAULT_MODEL,
116
  "default_prompt": "The best thing about Huston-Tillotson University is",
 
119
  "default_steps": 8,
120
  "default_seed": 42,
121
  "default_tokenizer_text": "Huston-Tillotson University is an HBCU in Austin, Texas.",
122
+ "system_prompt_presets": dict(DEFAULT_SYSTEM_PROMPT_PRESETS),
123
  }
124
+ # Layer 2: config.json overrides code defaults
125
  if CONFIG_PATH.exists():
126
  try:
127
  with open(CONFIG_PATH) as f:
 
129
  defaults.update(saved)
130
  except (json.JSONDecodeError, OSError):
131
  pass
132
+ # Layer 3: env vars override everything
133
+ for env_var, (config_key, type_fn) in ENV_VAR_MAP.items():
134
+ val = os.environ.get(env_var)
135
+ if val is not None:
136
+ try:
137
+ if type_fn == "json":
138
+ defaults[config_key] = json.loads(val)
139
+ else:
140
+ defaults[config_key] = type_fn(val)
141
+ except (json.JSONDecodeError, ValueError, TypeError):
142
+ pass # bad env var value — skip
143
  return defaults
144
 
145
 
 
154
  # ---------------------------------------------------------------------------
155
 
156
  class ModelManager:
157
+ """Manages two model slots: base (Probability Explorer) and chat (System Prompt Explorer)."""
158
 
159
  def __init__(self):
160
+ # Base model (Probability Explorer)
161
  self.model = None
162
  self.tokenizer = None
163
  self.current_model_name: str | None = None
164
+
165
+ # Chat model (System Prompt Explorer)
166
+ self.chat_model = None
167
+ self.chat_tokenizer = None
168
+ self.chat_model_name: str | None = None
169
+
170
  self.device: str = _detect_device()
171
  self.loading = False
172
  self._lock = threading.Lock()
173
  self.config = _load_config()
174
 
175
  # ------------------------------------------------------------------
176
+ # Shared loading logic
177
+ # ------------------------------------------------------------------
178
+
179
def _do_load(self, model_name: str):
    """Instantiate the model and tokenizer registered under *model_name*.

    Args:
        model_name: Display name; must be a key of ``AVAILABLE_MODELS``.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.

    Raises:
        RuntimeError: The spec asks for quantization but CUDA is unavailable.
        Any error raised by the underlying ``from_pretrained`` calls
        propagates unchanged.
    """
    spec = AVAILABLE_MODELS[model_name]
    quant_mode = spec.get("quantize")

    # Quantized specs are refused up front when there is no CUDA device.
    if quant_mode and not torch.cuda.is_available():
        raise RuntimeError(
            f"Cannot load {model_name}: "
            f"{quant_mode} quantization requires an NVIDIA GPU (CUDA). "
            f"Try a non-quantized model for local development."
        )

    model_id = spec["id"]
    load_kwargs: dict = {"device_map": "auto"}

    if quant_mode in ("4bit", "8bit"):
        # Imported lazily so non-quantized loads never touch this symbol.
        from transformers import BitsAndBytesConfig

        if quant_mode == "4bit":
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
        else:
            quant_config = BitsAndBytesConfig(load_in_8bit=True)
        load_kwargs["quantization_config"] = quant_config
    else:
        # "auto" is passed through as a string; anything else names a torch dtype.
        dtype_name = spec.get("dtype", "float16")
        load_kwargs["dtype"] = (
            "auto" if dtype_name == "auto" else getattr(torch, dtype_name)
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    model.eval()
    return model, tokenizer
215
+
216
+ # ------------------------------------------------------------------
217
+ # Base model lifecycle
218
  # ------------------------------------------------------------------
219
 
220
  def load_model(self, model_name: str) -> str:
221
+ """Load base model for Probability Explorer. Returns status message."""
222
  if model_name not in AVAILABLE_MODELS:
223
  return f"Unknown model: {model_name}"
224
 
225
  if self.loading:
226
  return "A model is already being loaded. Please wait."
227
 
 
 
 
 
 
 
 
 
228
  with self._lock:
229
  self.loading = True
230
  try:
231
+ # Unload current base model
232
+ if self.model is not None:
233
+ del self.model
234
+ self.model = None
235
+ if self.tokenizer is not None:
236
+ del self.tokenizer
237
+ self.tokenizer = None
238
+ self.current_model_name = None
239
+ gc.collect()
240
+
241
+ model, tokenizer = self._do_load(model_name)
242
+ self.model = model
243
+ self.tokenizer = tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  self.current_model_name = model_name
245
 
 
246
  self.config["model"] = model_name
247
  _save_config(self.config)
248
 
249
+ return f"Loaded base model: {model_name}"
250
 
251
  except Exception as e:
252
+ self.model = None
253
+ self.tokenizer = None
254
+ self.current_model_name = None
255
  return f"Failed to load {model_name}: {e}"
256
  finally:
257
  self.loading = False
258
 
259
+ # ------------------------------------------------------------------
260
+ # Chat model lifecycle
261
+ # ------------------------------------------------------------------
262
+
263
def load_chat_model(self, model_name: str) -> str:
    """Load the chat/instruct model used by the System Prompt Explorer.

    The current chat model is dropped first, then *model_name* is loaded
    via ``_do_load`` and the choice is persisted to the config file.

    Args:
        model_name: Display name; must be a key of ``AVAILABLE_MODELS``.

    Returns:
        A human-readable status string. Failures are reported in the
        returned message rather than raised.
    """
    if model_name not in AVAILABLE_MODELS:
        return f"Unknown model: {model_name}"

    if self.loading:
        return "A model is already being loaded. Please wait."

    with self._lock:
        self.loading = True
        try:
            # Drop the current chat model before allocating the new one.
            self.chat_model = None
            self.chat_tokenizer = None
            self.chat_model_name = None
            gc.collect()
            # Hand cached blocks back to the driver. Without this the freed
            # VRAM stays in PyTorch's allocator cache, where it is still
            # counted as used when free memory is queried from the driver,
            # so the next load may see less room than exists.
            # NOTE(review): confirm how device_map="auto" measures free memory.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.chat_model, self.chat_tokenizer = self._do_load(model_name)
            self.chat_model_name = model_name

            # Remember the selection in the persisted config.
            self.config["chat_model"] = model_name
            _save_config(self.config)

            return f"Loaded chat model: {model_name}"

        except Exception as e:
            # Leave the slot consistently empty and release anything a
            # partially completed load may have allocated.
            self.chat_model = None
            self.chat_tokenizer = None
            self.chat_model_name = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return f"Failed to load chat model {model_name}: {e}"
        finally:
            self.loading = False
300
+
301
+ # ------------------------------------------------------------------
302
+ # Status
303
+ # ------------------------------------------------------------------
304
 
305
  def is_ready(self) -> bool:
306
  return self.model is not None and not self.loading
307
 
308
def chat_ready(self) -> bool:
    """Return True when a chat model is loaded and no load is in progress."""
    if self.chat_model is None:
        return False
    return not self.loading
 
 
 
 
310
 
311
  def status_message(self) -> str:
312
  if self.loading:
313
  return "Loading model..."
314
+ parts = []
315
+ if self.model:
316
+ parts.append(f"Base: {self.current_model_name}")
317
+ if self.chat_model:
318
+ parts.append(f"Chat: {self.chat_model_name}")
319
+ if not parts:
320
+ return "No models loaded"
321
+ return " | ".join(parts)
322
 
323
  # ------------------------------------------------------------------
324
  # Inference helpers
 
433
 
434
  def generate_chat(
435
  self,
436
+ messages: list[dict],
 
437
  max_new_tokens: int = 256,
438
  temperature: float = 0.7,
439
  seed: int = 42,
440
  ) -> dict:
441
+ """Generate a chat response using the dedicated chat model.
442
+
443
+ Args:
444
+ messages: Full conversation as list of {"role": ..., "content": ...} dicts,
445
+ including system prompt and all previous turns.
446
 
447
  Returns dict with:
448
+ - formatted_display: the full template including the response (for terminal)
449
  - response: the model's generated response text
450
  """
451
+ if not self.chat_ready():
452
+ return {"error": "Chat model not loaded"}
453
 
454
+ # Format input (everything up to and including the generation prompt)
455
+ formatted = self.chat_tokenizer.apply_chat_template(
 
 
 
 
 
456
  messages, tokenize=False, add_generation_prompt=True,
457
  )
458
 
459
  # Tokenize input
460
+ inputs = self.chat_tokenizer(formatted, return_tensors="pt")
461
+ inputs = {k: v.to(self.chat_model.device) for k, v in inputs.items()}
462
  input_len = inputs["input_ids"].shape[1]
463
 
464
  # Generate
465
  gen_kwargs = {
466
  "max_new_tokens": max_new_tokens,
467
  "do_sample": temperature > 0,
468
+ "pad_token_id": self.chat_tokenizer.eos_token_id,
469
  }
470
  if temperature > 0:
471
  gen_kwargs["temperature"] = temperature
472
+ if self.chat_model.device.type == "cuda":
 
473
  torch.cuda.manual_seed(seed)
474
  torch.manual_seed(seed)
475
 
476
  with torch.no_grad():
477
+ output_ids = self.chat_model.generate(**inputs, **gen_kwargs)
478
 
479
  # Decode only the new tokens
480
  new_ids = output_ids[0][input_len:]
481
+ response = self.chat_tokenizer.decode(new_ids, skip_special_tokens=True).strip()
482
+
483
+ # Build display template (includes the response) for green terminal
484
+ display_messages = messages + [{"role": "assistant", "content": response}]
485
+ formatted_display = self.chat_tokenizer.apply_chat_template(
486
+ display_messages, tokenize=False, add_generation_prompt=False,
487
+ )
488
 
489
  return {
490
+ "formatted_display": formatted_display,
491
+ "response": response,
492
  }
493
 
494
def format_chat_template(self, messages: list[dict]) -> str:
    """Render *messages* with the chat model's template (for terminal display).

    Args:
        messages: Conversation as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The templated text ending with the generation prompt, or ``""``
        when no chat tokenizer is loaded.
    """
    # Explicit None check: truthiness on a tokenizer object would go
    # through its __len__/__bool__ rather than test for presence.
    if self.chat_tokenizer is None:
        return ""
    return self.chat_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )
501
+
502
  def tokenize(self, text: str) -> list[tuple[str, int]]:
503
  """Tokenize text and return list of (token_str, token_id)."""
504
  if self.tokenizer is None: