prisma-chatbot / app.py
RolandM's picture
Let impressions panel match chat width on mobile
3e5167a
"""Gradio app entry point for PRISMA.
Wires prompt construction, inference, evaluation parsing, and an
always-visible impressions panel (bar-style colored cells plus a trajectory
plot) into a Gradio Blocks interface with a custom dark theme.
State held in ``gr.State``:
{
"history": list[dict], # OpenAI-format messages (system + chat)
"evaluations": list[dict], # one per assistant turn
"turn_count": int, # completed user turns
}
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import gradio as gr
import matplotlib
matplotlib.use("Agg") # non-interactive backend, required for server-side use
from matplotlib.figure import Figure # noqa: E402
from dotenv import load_dotenv # noqa: E402
from src.config import ( # noqa: E402
ATTRIBUTE_COLORS,
DEFAULT_ATTRIBUTES,
MAX_SCORE,
SESSION_TURN_CAP,
)
from src.evaluation import ( # noqa: E402
INTENSIFIER_SCALE,
EvaluationParseError,
)
from src.inference import ( # noqa: E402
InferenceError,
PrismaInferenceClient,
)
from src.prompt import build_system_prompt # noqa: E402
# ---------------------------------------------------------------------------
# One-time setup
# ---------------------------------------------------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
raise RuntimeError(
"HF_TOKEN not found. Set it in .env at the repo root "
"(see .env.example)."
)
CLIENT = PrismaInferenceClient(token=HF_TOKEN)
SYSTEM_PROMPT = build_system_prompt()
# Load the (small) footer figure inline if available; otherwise show a
# discreet placeholder rectangle. Drop your finalized small figure at
# assets/prisma-figure-footer.svg to replace the placeholder.
FOOTER_FIGURE_PATH = Path(__file__).parent / "assets" / "prisma-figure-footer.svg"
FOOTER_FIGURE_PLACEHOLDER = """
<svg width="130" height="90" viewBox="0 0 130 90"
xmlns="http://www.w3.org/2000/svg" role="img"
aria-label="PRISMA figure placeholder">
<rect width="130" height="90" rx="6"
fill="#1f1f33" stroke="#3d3d68" stroke-width="1"/>
<text x="65" y="42" text-anchor="middle"
fill="#9ca3af" font-family="serif"
font-size="11" font-style="italic">figure</text>
<text x="65" y="58" text-anchor="middle"
fill="#9ca3af" font-family="serif"
font-size="11" font-style="italic">placeholder</text>
</svg>
"""
FOOTER_FIGURE_SVG = (
FOOTER_FIGURE_PATH.read_text()
if FOOTER_FIGURE_PATH.exists()
else FOOTER_FIGURE_PLACEHOLDER
)
# ---------------------------------------------------------------------------
# Theme & CSS
# ---------------------------------------------------------------------------
THEME = gr.themes.Base(
primary_hue="violet",
neutral_hue="slate",
).set(
body_background_fill="#0f0f1a",
body_background_fill_dark="#0f0f1a",
block_background_fill="#1a1a2e",
block_background_fill_dark="#1a1a2e",
body_text_color="#e5e7eb",
body_text_color_dark="#e5e7eb",
border_color_primary="#2a2a44",
border_color_primary_dark="#2a2a44",
input_background_fill="#1a1a2e",
input_background_fill_dark="#1a1a2e",
)
CUSTOM_CSS = """
#prisma-header {
padding: 0.5rem 0 1rem 0;
text-align: left;
}
#prisma-header h1 {
font-size: 2.5rem;
margin: 0.25rem 0 0.25rem 0;
letter-spacing: 0.05em;
}
#prisma-header .tagline {
font-size: 1.2rem;
font-style: italic;
color: #9ca3af;
margin: 0 0 0.5rem 0;
}
#prisma-header .description {
font-size: 1rem;
color: #cbd5e1;
line-height: 1.45;
margin: 0;
}
#prisma-header .disclaimer {
font-size: 0.85rem;
color: #9ca3af;
font-style: italic;
line-height: 1.4;
margin: 0.5rem 0 0 0;
}
/* Dark backgrounds for text inputs (overrides theme defaults) */
textarea,
input[type="text"],
input[type="search"] {
background-color: #1a1a2e !important;
color: #e5e7eb !important;
border-color: #2a2a44 !important;
}
/* Dropdown trigger */
.gr-dropdown,
.gr-dropdown > div,
.gr-dropdown input {
background-color: #1a1a2e !important;
color: #e5e7eb !important;
}
/* Dropdown options when open */
ul[role="listbox"],
ul.options {
background-color: #1a1a2e !important;
color: #e5e7eb !important;
border: 1px solid #2a2a44 !important;
}
ul[role="listbox"] li,
ul.options li {
background-color: #1a1a2e !important;
color: #e5e7eb !important;
}
ul[role="listbox"] li:hover,
ul.options li:hover,
ul[role="listbox"] li.selected,
ul.options li.selected {
background-color: #2a2a44 !important;
}
#impressions-panel {
flex: 0 0 360px !important;
max-width: 360px !important;
min-width: 360px !important;
}
#impressions-panel h3 {
font-size: 1.4rem;
margin: 0 0 0.75rem 0;
}
.impressions-header {
font-size: 1.05rem;
font-weight: 600;
margin: 0.5rem 0 0.75rem 0;
color: #e5e7eb;
}
.impression-row {
padding: 0.55rem 0.85rem;
margin: 0.35rem 0;
border-radius: 6px;
color: #ffffff;
font-weight: 500;
font-size: 0.95rem;
white-space: nowrap;
text-shadow: 0 1px 2px rgba(0, 0, 0, 0.55);
letter-spacing: 0.01em;
}
.impressions-empty {
font-style: italic;
color: #9ca3af;
padding: 0.5rem 0;
}
/* Chat message bubbles — override default light backgrounds */
.message,
.bubble,
.bubble-wrap,
.message-wrap .message,
.message-row .message,
[data-testid="user"] .message,
[data-testid="bot"] .message,
.user .bubble,
.bot .bubble,
.assistant .bubble {
background-color: #2a2a44 !important;
color: #e5e7eb !important;
}
/* User messages (right side) — slightly different shade for contrast */
.message-row.user-row .message,
[data-testid="user"] .message,
.user .bubble {
background-color: #3d3d68 !important;
}
/* Highlight for the user message corresponding to the selected turn */
.selected-turn {
position: relative;
}
.selected-turn::after {
content: "";
position: absolute;
top: -3px; left: -3px; right: -3px; bottom: -3px;
border-radius: 10px;
border: 2px solid #fcd34d;
box-shadow: 0 0 14px rgba(252, 211, 77, 0.45);
pointer-events: none;
}
/* Warning/info/error toast notifications */
.toast,
.toast-body,
.toast-text,
.gr-toast,
[class~="toast"] {
background-color: #2a2a44 !important;
color: #e5e7eb !important;
border: 1px solid #ef4444 !important;
}
.toast .icon,
.toast svg,
.gr-toast svg,
[class~="toast"] svg {
color: #ef4444 !important;
fill: #ef4444 !important;
}
/* Footer */
#prisma-footer {
padding: 1.5rem 1rem 0.75rem 1rem;
margin-top: 1.5rem;
border-top: 1px solid #2a2a44;
}
#prisma-footer .footer-row {
display: flex;
align-items: center;
justify-content: space-between;
gap: 2rem;
}
#prisma-footer .footer-left {
flex: 0 0 auto;
}
#prisma-footer .footer-left svg {
width: 260px;
height: auto;
display: block;
}
#prisma-footer .footer-center {
flex: 1;
text-align: center;
}
#prisma-footer .footer-right {
flex: 0 0 auto;
text-align: right;
font-size: 0.95rem;
}
#prisma-footer .prisma-fullname {
font-size: 1.1rem;
font-style: italic;
color: #9ca3af;
letter-spacing: 0.03em;
margin: 0 0 0.4rem 0;
}
#prisma-footer .footer-contact {
font-size: 0.9rem;
color: #9ca3af;
margin: 0;
}
#prisma-footer .footer-right a {
color: #93c5fd;
text-decoration: none;
margin-left: 0.6rem;
}
#prisma-footer .footer-right a:hover {
color: #fcd34d;
text-decoration: underline;
}
/* Mobile: stack footer columns vertically and center them. */
@media (max-width: 768px) {
#prisma-footer .footer-row {
flex-direction: column;
text-align: center;
gap: 1rem;
}
#prisma-footer .footer-left svg {
width: 200px;
}
#prisma-footer .footer-center,
#prisma-footer .footer-right {
text-align: center;
}
#prisma-footer .footer-right a {
margin: 0 0.4rem;
}
/* Let the impressions panel match the chat-column width below it. */
#impressions-panel {
flex: 1 1 auto !important;
min-width: 0 !important;
max-width: 100% !important;
width: 100% !important;
}
}
"""
# JS that highlights the user message at the currently-selected turn index.
# Since errored attempts are no longer added to the chat, the dropdown's
# turn index maps directly to the Nth user message in the DOM.
HIGHLIGHT_TURN_JS = """
(turn_index) => {
document.querySelectorAll('.selected-turn').forEach(el => {
el.classList.remove('selected-turn');
});
if (turn_index === null || turn_index === undefined) {
return turn_index;
}
const candidates = [
'.message-row.user-row',
'[data-testid="user"]',
'.message.user',
'.user'
];
for (const selector of candidates) {
const messages = document.querySelectorAll(selector);
if (messages.length > 0) {
if (messages[turn_index]) {
messages[turn_index].classList.add('selected-turn');
}
break;
}
}
return turn_index;
}
"""
# ---------------------------------------------------------------------------
# State helpers
# ---------------------------------------------------------------------------
def initial_state() -> dict[str, Any]:
"""Return a fresh conversation state for a new session."""
return {
"history": [{"role": "system", "content": SYSTEM_PROMPT}],
"evaluations": [],
"turn_count": 0,
}
# ---------------------------------------------------------------------------
# Chat handler
# ---------------------------------------------------------------------------
def chat_step(
user_message: str,
chat_display: list[dict[str, str]],
state: dict[str, Any],
):
"""Process one user turn: call the model, update state and UI.
On success, the user message and assistant response are added to
``chat_display`` and a new evaluation is recorded. On failure, the
chat is NOT modified — the error is surfaced via gr.Warning, and the
user's text is kept in the input box so they can edit and retry.
Returns updates for (chatbot, state, msg_in, turn_dropdown).
"""
user_message = (user_message or "").strip()
if not user_message:
return chat_display, state, "", gr.Dropdown()
# Session cap reached — refuse further requests.
if state["turn_count"] >= SESSION_TURN_CAP:
notice = (
f"Session complete — Prisma has chatted with you for "
f"{SESSION_TURN_CAP} turns. Refresh the page to start over."
)
chat_display = chat_display + [
{"role": "user", "content": user_message},
{"role": "assistant", "content": notice},
]
return chat_display, state, "", gr.Dropdown()
state["history"].append({"role": "user", "content": user_message})
try:
parsed = CLIENT.generate(state["history"])
state["history"].append(
{"role": "assistant", "content": parsed.response}
)
state["evaluations"].append(parsed.evaluation)
state["turn_count"] += 1
chat_display = chat_display + [
{"role": "user", "content": user_message},
{"role": "assistant", "content": parsed.response},
]
msg_in_value = "" # clear input on success
except (InferenceError, EvaluationParseError) as exc:
# Roll back the unanswered user message so retries send clean history.
state["history"].pop()
# Log technical details to the container log for debugging.
print(f"[error] {type(exc).__name__}: {exc}")
# Surface a friendly notification to the user without polluting the
# chat history. The error attempt does not appear as a bubble.
gr.Warning(
"I wasn't able to respond properly to that. "
"Try rephrasing or asking something else."
)
# Keep the user's text in the input box so they can edit and retry.
msg_in_value = user_message
n_evals = len(state["evaluations"])
if n_evals > 0:
choices = [(f"Turn {i + 1}", i) for i in range(n_evals)]
dropdown_update = gr.Dropdown(choices=choices, value=n_evals - 1)
else:
dropdown_update = gr.Dropdown(choices=[], value=None)
return chat_display, state, msg_in_value, dropdown_update
# ---------------------------------------------------------------------------
# Impressions rendering
# ---------------------------------------------------------------------------
def render_impression(state: dict[str, Any], turn_index: int | None) -> str:
"""Build HTML for the impressions panel: header + colored bar cells.
Each row uses a linear-gradient background that fills up to (score/MAX)
of the row's width with the attribute's saturated color, then continues
with the same color at low alpha for the remainder. This doubles the
text label as a per-attribute bar plot.
"""
evaluations = state.get("evaluations", [])
if not evaluations:
return (
'<div class="impressions-empty">'
"No impressions yet — say something to Prisma."
"</div>"
)
if turn_index is None or turn_index < 0 or turn_index >= len(evaluations):
turn_index = len(evaluations) - 1
evaluation = evaluations[turn_index]
header = (
f'<div class="impressions-header">After turn {turn_index + 1}:</div>'
)
rows: list[str] = []
for attr in DEFAULT_ATTRIBUTES:
score = evaluation[attr]
color = ATTRIBUTE_COLORS[attr]
intensifier = INTENSIFIER_SCALE[score]
pct = (score / MAX_SCORE) * 100
# Two-stop linear gradient: saturated up to `pct`, then ~20% alpha.
# `{color}33` appends 0x33 (~20%) alpha to the hex color.
gradient = (
f"linear-gradient(to right, "
f"{color} 0%, {color} {pct:.1f}%, "
f"{color}33 {pct:.1f}%, {color}33 100%)"
)
rows.append(
f'<div class="impression-row" style="background: {gradient};">'
f"{intensifier} {attr} ({score}/{MAX_SCORE})"
f"</div>"
)
return header + "\n" + "\n".join(rows)
def render_trajectory(state: dict[str, Any]):
"""Render a line plot of scores per attribute across turns.
Colors match the bar cells so the rating list above acts as the legend.
A small fixed y-offset per attribute spreads overlapping points so every
attribute's marker remains visible when several share the same score on
the same turn.
"""
evaluations = state.get("evaluations", [])
fig = Figure(figsize=(5, 3), facecolor="#1a1a2e")
ax = fig.add_subplot(111)
ax.set_facecolor("#1a1a2e")
if not evaluations:
ax.text(
0.5,
0.5,
"No data yet",
ha="center",
va="center",
color="#9ca3af",
fontsize=12,
fontstyle="italic",
transform=ax.transAxes,
)
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
spine.set_visible(False)
fig.tight_layout()
return fig
# Small fixed y-offset per attribute so overlapping points stay visible.
# Total spread is ±0.15 score units around the true score.
n = len(DEFAULT_ATTRIBUTES)
jitter_step = 0.06
jitter = {
attr: (i - (n - 1) / 2) * jitter_step
for i, attr in enumerate(DEFAULT_ATTRIBUTES)
}
turns = list(range(1, len(evaluations) + 1))
for attr in DEFAULT_ATTRIBUTES:
scores = [e[attr] + jitter[attr] for e in evaluations]
ax.plot(
turns,
scores,
color=ATTRIBUTE_COLORS[attr],
marker="o",
linewidth=2,
markersize=5,
)
ax.set_xlabel("Turn", color="#e5e7eb")
ax.set_ylabel("Score", color="#e5e7eb")
ax.set_ylim(0.5, 7.5)
ax.set_yticks(range(1, MAX_SCORE + 1))
ax.set_xticks(turns)
ax.tick_params(colors="#e5e7eb")
ax.grid(True, alpha=0.15, color="#9ca3af")
for spine_name in ("top", "right"):
ax.spines[spine_name].set_visible(False)
for spine_name in ("bottom", "left"):
ax.spines[spine_name].set_color("#9ca3af")
fig.tight_layout()
return fig
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="PRISMA") as demo:
gr.HTML(
"""
<div id="prisma-header">
<h1>
<span style="color: #e11d48;">P</span><span style="color: #f97316;">R</span><span style="color: #eab308;">I</span><span style="color: #22c55e;">S</span><span style="color: #3b82f6;">M</span><span style="color: #a855f7;">A</span>
</h1>
<p class="tagline">Have you ever wondered what your chatbot thinks about you?</p>
<p class="description">
Chat with Prisma. She'll respond — and form impressions of you based on how you write.
</p>
<p class="disclaimer">
Research demo. Evaluations are a language model's judgments, not a validated assessment.
</p>
</div>
"""
)
state = gr.State(initial_state())
with gr.Row():
with gr.Column(scale=1):
chatbot = gr.Chatbot(
label="Chat with Prisma",
height=600,
)
with gr.Row():
msg_in = gr.Textbox(
placeholder="Say something to Prisma...",
show_label=False,
scale=4,
)
send_btn = gr.Button("Send", variant="primary", scale=1)
with gr.Column(scale=0, min_width=360, elem_id="impressions-panel"):
gr.Markdown("### Prisma's impressions of you")
turn_dropdown = gr.Dropdown(
choices=[],
label="Show impression after turn:",
interactive=True,
)
impressions_html = gr.HTML(
value=(
'<div class="impressions-empty">'
"No impressions yet — say something to Prisma."
"</div>"
),
)
trajectory_plot = gr.Plot(
value=render_trajectory(initial_state()), label=None
)
# Footer: small figure, colored acronym expansion, contact + links.
# The figure SVG file goes at assets/prisma-figure-footer.svg;
# a placeholder rectangle is shown if the file is missing.
gr.HTML(
f"""
<div id="prisma-footer">
<div class="footer-row">
<div class="footer-left">
{FOOTER_FIGURE_SVG}
</div>
<div class="footer-center">
<p class="prisma-fullname">
<span style="color: #e11d48;">P</span>ragmatic
<span style="color: #f97316;">R</span>eal-time
<span style="color: #eab308;">I</span>nference of
<span style="color: #22c55e;">S</span>ocial
<span style="color: #3b82f6;">M</span>eaning in
<span style="color: #a855f7;">A</span>gents
</p>
<p class="footer-contact">
Roland Mühlenbernd · Leibniz-Centre General Linguistics, Berlin
</p>
</div>
<div class="footer-right">
<a href="https://muehlenbernd.net/" target="_blank" rel="noopener">Website</a>
<a href="https://github.com/muehlenbernd/prisma-chatbot" target="_blank" rel="noopener">GitHub</a>
<a href="https://www.linkedin.com/in/rolandmuehlenbernd/" target="_blank" rel="noopener">LinkedIn</a>
</div>
</div>
</div>
"""
)
# Same submit handler for Enter-key and Send button.
for trigger in (send_btn.click, msg_in.submit):
trigger(
chat_step,
inputs=[msg_in, chatbot, state],
outputs=[chatbot, state, msg_in, turn_dropdown],
).then(
render_impression,
inputs=[state, turn_dropdown],
outputs=impressions_html,
).then(
render_trajectory,
inputs=state,
outputs=trajectory_plot,
).then(
fn=None,
inputs=turn_dropdown,
outputs=None,
js=HIGHLIGHT_TURN_JS,
)
turn_dropdown.change(
render_impression,
inputs=[state, turn_dropdown],
outputs=impressions_html,
).then(
fn=None,
inputs=turn_dropdown,
outputs=None,
js=HIGHLIGHT_TURN_JS,
)
if __name__ == "__main__":
demo.launch(ssr_mode=False)