File size: 29,628 Bytes
0dad901
 
ecd28db
e91e2b4
0dad901
e91e2b4
0dad901
e91e2b4
 
23c7579
 
48de182
 
42f0aac
ecd28db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e91e2b4
 
ecd28db
e91e2b4
ecd28db
e91e2b4
 
 
23c7579
 
 
27caa28
23c7579
 
 
 
ecd28db
 
 
 
e91e2b4
 
 
 
23c7579
ef68cd9
1dbbeef
 
 
 
23c7579
ecd28db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e91e2b4
 
 
 
23c7579
 
 
 
 
 
 
 
 
 
ecd28db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23c7579
 
ecd28db
 
 
 
 
 
23c7579
ecd28db
 
 
 
 
 
23c7579
 
d759549
f827eab
 
4a28790
f827eab
4a28790
 
 
15dc14b
 
 
 
 
 
3653027
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15dc14b
3653027
 
 
 
15dc14b
ad65183
15dc14b
 
ad65183
 
15dc14b
 
 
 
 
 
 
 
 
 
3653027
15dc14b
 
 
 
 
 
 
3653027
15dc14b
3653027
15dc14b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad65183
4a28790
ad65183
 
 
 
 
 
 
 
 
 
 
4a28790
ad65183
 
 
15dc14b
 
ad65183
 
 
 
 
 
15dc14b
 
4a28790
ad65183
 
4a28790
 
ad65183
 
 
 
 
 
 
 
4a28790
 
 
d759549
4a28790
d759549
f827eab
d759549
 
 
 
 
 
 
 
 
 
e91e2b4
422aadb
e91e2b4
 
 
 
ecd28db
 
 
 
 
 
e91e2b4
 
 
ecd28db
 
 
 
 
 
e91e2b4
 
 
 
 
 
cdd476e
e91e2b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecd28db
 
 
 
 
 
 
 
 
e91e2b4
 
 
 
 
 
ecd28db
e91e2b4
 
4d77281
 
 
 
 
 
e91e2b4
ecd28db
e91e2b4
 
4d77281
 
 
 
 
 
e91e2b4
 
ecd28db
e91e2b4
 
ecd28db
e91e2b4
 
4d77281
 
 
 
 
 
 
 
 
 
ecd28db
e91e2b4
23c7579
e91e2b4
 
 
 
 
 
 
ecd28db
 
 
 
 
 
 
 
e91e2b4
 
ecd28db
e91e2b4
 
 
 
 
 
ecd28db
 
 
 
 
 
 
e91e2b4
 
 
 
 
 
ecd28db
 
 
 
 
 
 
e91e2b4
 
 
ecd28db
 
 
 
 
 
 
 
27caa28
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
import html
import logging
import os
import secrets
import time
from typing import cast
from urllib.parse import urlencode

import requests
import streamlit as st

from gateway_client import delete_profile, ingest_and_rewrite
from llm import chat, set_model
from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER, MODEL_DISPLAY_NAMES



def _generate_session_name(base: str = "Session") -> str:
    """Return the first name of the form ``"<base> <n>"`` (n = 1, 2, ...) not yet used.

    Uniqueness is checked against ``st.session_state.session_order`` (treated
    as empty when that key has not been created yet).
    """
    taken = set(st.session_state.get("session_order", []))
    n = 1
    while f"{base} {n}" in taken:
        n += 1
    return f"{base} {n}"

def ensure_session_state() -> None:
    """Create the multi-session chat state on first run and repair it afterwards.

    After this call:
    * ``sessions`` (name -> ``{"history": [...]}``) and ``session_order`` exist;
    * ``active_session_id`` names an existing session (a fresh auto-named one
      is created when it is missing or stale);
    * ``session_select`` names an existing session;
    * the rename-tracking keys are initialised;
    * ``st.session_state.history`` is the active session's history list
      (the same list object, not a copy).
    """
    state = st.session_state
    if "sessions" not in state:
        state.sessions = {}
    if "session_order" not in state:
        state.session_order = []

    has_valid_active = (
        "active_session_id" in state and state.active_session_id in state.sessions
    )
    if not has_valid_active:
        fresh_name = _generate_session_name()
        state.sessions.setdefault(fresh_name, {"history": []})
        if fresh_name not in state.session_order:
            state.session_order.append(fresh_name)
        state.active_session_id = fresh_name

    active = state.active_session_id
    if "session_select" not in state or state.session_select not in state.sessions:
        state.session_select = active
    state.setdefault("rename_session_name", active)
    state.setdefault("rename_session_synced_to", active)
    state.history = cast(
        list[dict], state.sessions[active].setdefault("history", [])
    )


def create_session(session_name: str | None = None) -> tuple[bool, str]:
    """Create a new empty chat session and make it the active one.

    Args:
        session_name: Desired name; surrounding whitespace is stripped. An
            empty/None name is replaced by an auto-generated one.

    Returns:
        ``(True, name)`` when the session was created, or ``(False, name)``
        when a session with that name already exists (nothing is changed).
    """
    ensure_session_state()
    state = st.session_state
    name = (session_name or "").strip() or _generate_session_name()
    if name in state.sessions:
        return False, name

    state.sessions[name] = {"history": []}
    state.session_order.append(name)
    state.active_session_id = name
    state.session_select = name
    # Alias (not copy) the new session's history list.
    state.history = cast(list[dict], state.sessions[name]["history"])
    state.rename_session_name = name
    state.rename_session_synced_to = name
    return True, name


def rename_session(current_name: str, new_name: str) -> bool:
    """Rename the session ``current_name`` to ``new_name`` (whitespace-stripped).

    The session keeps its position in ``session_order`` and its history.

    Returns:
        True on success. False, with no state changed, when the new name is
        empty, equal to the current name, already taken, or when
        ``current_name`` is not an existing session.
    """
    ensure_session_state()
    state = st.session_state
    target = new_name.strip()
    if not target or target == current_name:
        return False
    if target in state.sessions:
        return False
    # Guard before mutating: previously a missing name raised KeyError, and a
    # name absent from session_order raised only after the session was moved.
    if current_name not in state.sessions:
        return False

    state.sessions[target] = state.sessions.pop(current_name)
    order = state.session_order
    if current_name in order:
        order[order.index(current_name)] = target
    else:
        # Inconsistent state (session not listed); list it so it stays visible.
        order.append(target)

    if state.active_session_id == current_name:
        state.active_session_id = target
        state.session_select = target
    state.history = cast(
        list[dict],
        state.sessions[state.active_session_id]["history"],
    )
    # The rename-tracking keys follow the *active* session (the sidebar resyncs
    # them to it), so point them there even when a non-active session was renamed.
    state.rename_session_name = state.active_session_id
    state.rename_session_synced_to = state.active_session_id
    return True


def delete_session(session_name: str) -> bool:
    """Delete a chat session, refusing to remove the last remaining one.

    When the deleted session was active, the last session in
    ``session_order`` becomes active and the selection/rename keys follow it.

    Returns:
        True if the session was deleted, False if it does not exist or is
        the only session left.
    """
    ensure_session_state()
    state = st.session_state
    if session_name not in state.sessions or len(state.session_order) <= 1:
        return False

    del state.sessions[session_name]
    state.session_order.remove(session_name)
    if state.active_session_id == session_name:
        successor = state.session_order[-1]
        state.active_session_id = successor
        state.session_select = successor
        state.rename_session_name = successor
        state.rename_session_synced_to = successor

    state.history = cast(
        list[dict], state.sessions[state.active_session_id]["history"]
    )
    return True


def rewrite_message(
    msg: str, persona_name: str, show_rationale: bool
) -> str:
    """Return the prompt to send to the model for ``msg``.

    For the ``control`` persona (case-insensitive) the message is used
    unchanged. For any other persona, the message goes through
    ``ingest_and_rewrite`` with ``persona_name`` as the user id.

    Args:
        msg: The raw user message.
        persona_name: Persona/user id whose memory is used.
        show_rationale: If True, append an instruction asking the model to
            open its answer with an italic "Persona Rationale" line.

    Returns:
        The (possibly rewritten) prompt text.

    Raises:
        Whatever ``ingest_and_rewrite`` raises, after showing it via ``st.error``.
    """
    if persona_name.lower() == "control":
        rewritten_msg = msg
        if show_rationale:
            rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied.'. Begin your answer on the next line."
        return rewritten_msg
    try:
        rewritten_msg = ingest_and_rewrite(
            user_id=persona_name, query=msg
        )
    except Exception as e:
        # Surface the failure in the UI, then let it propagate to the caller.
        st.error(f"Failed to ingest_and_rewrite message: {e}")
        raise
    if show_rationale:
        rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: ' followed by 1 sentence about how your reasoning for how the persona traits influenced this response, also in italics. Begin your answer on the next line."
    # Debug-level only: the rewritten prompt may embed personal memory content,
    # so it should not be printed to stdout unconditionally.
    logging.getLogger(__name__).debug("Rewritten message: %s", rewritten_msg)
    return rewritten_msg

# ──────────────────────────────────────────────────────────────
# Page setup & CSS
# ──────────────────────────────────────────────────────────────
st.set_page_config(page_title="MemMachine Chatbot", layout="wide")

# Optional custom stylesheet; the app runs without it. Read it as UTF-8
# explicitly so the result does not depend on the platform's default encoding.
try:
    with open("./styles.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
except FileNotFoundError:
    pass

ensure_session_state()


# CSS for the "Powered by MemMachine" link bar shown at the top of the page
# (right-aligned; centred and wrapping on narrow screens).
HEADER_STYLE = """
<style>
.memmachine-header-wrapper {
    display: flex;
    justify-content: flex-end;
    margin-bottom: 1.2rem;
}
.memmachine-header-links {
    display: inline-flex;
    gap: 14px;
    align-items: center;
    background: transparent;
    padding: 0;
    border-radius: 0;
}
.memmachine-header-links .powered-by {
    color: #0a6cff;
    font-weight: 700;
    font-size: 16px;
    margin-right: 6px;
    white-space: nowrap;
}
.memmachine-header-links a {
    text-decoration: none;
    color: inherit;
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 0;
    border-radius: 0;
    transition: opacity 0.2s ease;
}
.memmachine-header-links a:hover {
    opacity: 0.7;
}
.memmachine-header-links img,
.memmachine-header-links svg {
    width: 22px;
    height: 22px;
}
@media (max-width: 768px) {
    .memmachine-header-wrapper {
        justify-content: center;
        margin-bottom: 0.8rem;
    }
    .memmachine-header-links {
        flex-wrap: wrap;
        row-gap: 8px;
        justify-content: center;
    }
}
</style>
"""

# Markup for the link bar: MemMachine website (logo image), GitHub repository
# and Discord (inline SVG icons); each link opens in a new tab.
HEADER_HTML = """
<div class="memmachine-header-wrapper">
  <div class="memmachine-header-links">
    <span class="powered-by">Powered by MemMachine</span>
    <a href="https://memmachine.ai/" target="_blank" title="MemMachine">
      <img src="https://avatars.githubusercontent.com/u/226739620?s=48&v=4" alt="MemMachine logo"/>
    </a>
    <a href="https://github.com/MemMachine/MemMachine" target="_blank" title="GitHub Repository">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor">
        <path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z"/>
      </svg>
    </a>
    <a href="https://discord.gg/usydANvKqD" target="_blank" title="Discord Community">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor">
        <path d="M20.317 4.37a19.791 19.791 0 0 0-4.885-1.515.074.074 0 0 0-.079.037c-.21.375-.444.864-.608 1.25a18.27 18.27 0 0 0-5.487 0 12.64 12.64 0 0 0-.617-1.25.077.077 0 0 0-.079-.037A19.736 19.736 0 0 0 3.677 4.37a.07.07 0 0 0-.032.027C.533 9.046-.32 13.58.099 18.057a.082.082 0 0 0 .031.057 19.9 19.9 0 0 0 5.993 3.03.078.078 0 0 0 .084-.028c.462-.63.874-1.295 1.226-1.994a.076.076 0 0 0-.041-.106 13.107 13.107 0 0 1-1.872-.892.077.077 0 0 1-.008-.128 10.2 10.2 0 0 0 .372-.292.074.074 0 0 1 .077-.01c3.928 1.793 8.18 1.793 12.062 0a.074.074 0 0 1 .078.01c.12.098.246.198.373.292a.077.077 0 0 1-.006.127 12.299 12.299 0 0 1-1.873.892.077.077 0 0 0-.041.107c.36.698.772 1.362 1.225 1.993a.076.076 0 0 0 .084.028 19.839 19.839 0 0 0 6.002-3.03.077.077 0 0 0 .032-.054c.5-5.177-.838-9.674-3.549-13.66a.061.061 0 0 0-.031-.03zM8.02 15.33c-1.183 0-2.157-1.085-2.157-2.419 0-1.333.956-2.419 2.157-2.419 1.21 0 2.176 1.096 2.157 2.42 0 1.333-.956 2.418-2.157 2.418zm7.975 0c-1.183 0-2.157-1.085-2.157-2.419 0-1.333.955-2.419 2.157-2.419 1.21 0 2.176 1.096 2.157 2.42 0 1.333-.946 2.418-2.157 2.418z"/>
      </svg>
    </a>
  </div>
</div>
"""

# unsafe_allow_html=True so Streamlit injects the raw <style>/<div> markup
# instead of escaping it.
st.markdown(HEADER_STYLE, unsafe_allow_html=True)
st.markdown(HEADER_HTML, unsafe_allow_html=True)



# ──────────────────────────────────────────────────────────────
# Sidebar
# ──────────────────────────────────────────────────────────────
# Defaults for the values the sidebar configures; a branch below may overwrite them.
default_model = MODEL_CHOICES[0] if MODEL_CHOICES else "gpt-4.1-mini"
model_id = default_model
provider = MODEL_TO_PROVIDER.get(default_model, "openai")
selected_persona = "Charlie"
persona_name = "Charlie"
skip_rewrite = False
compare_personas = False
show_rationale = False

# NOTE(review): several icon glyphs in the UI strings below (e.g. "β‹―", "πŸ”")
# look like UTF-8 emoji that were mis-decoded at some point; confirm the file is
# stored as UTF-8 and restore the intended characters.
with st.sidebar:
    st.markdown("#### Sessions")
    session_options = st.session_state.session_order
    active_session = st.session_state.active_session_id
    # Keep the rename-tracking state pointed at the currently active session.
    if st.session_state.rename_session_synced_to != active_session:
        st.session_state.rename_session_name = active_session
        st.session_state.rename_session_synced_to = active_session

    # One row per session: a select button plus an actions menu (rename / delete).
    # Every mutating action ends in st.rerun(), so the list is never iterated
    # after it has been modified.
    for session_name in session_options:
        is_active = session_name == active_session
        button_label = f"{session_name}"
        row = st.container()
        with row:
            button_col, menu_col = st.columns([0.8, 0.2])
            with button_col:
                if st.button(
                    button_label,
                    key=f"session_button_{session_name}",
                    use_container_width=True,
                    type="primary" if is_active else "secondary",
                ):
                    if not is_active:
                        st.session_state.active_session_id = session_name
                        st.session_state.session_select = session_name
                        st.session_state.history = cast(
                            list[dict],
                            st.session_state.sessions[session_name]["history"],
                        )
                        st.session_state.rename_session_name = session_name
                        st.session_state.rename_session_synced_to = session_name
                        st.rerun()
            with menu_col:
                if hasattr(st, "popover"):
                    menu_container = st.popover("β‹―", use_container_width=True)
                else:
                    # Fallback for Streamlit builds without st.popover. Their
                    # st.expander takes no `key` argument (passing one raises
                    # TypeError), and expanders are not widgets, so repeated
                    # labels need no unique key.
                    menu_container = st.expander("β‹―", expanded=False)
                with menu_container:
                    st.markdown(f"**Actions for {session_name}**")
                    rename_value = st.text_input(
                        "Rename session",
                        value=session_name,
                        key=f"rename_session_input_{session_name}",
                    )
                    if st.button(
                        "Rename",
                        use_container_width=True,
                        key=f"rename_session_button_{session_name}",
                    ):
                        rename_target = rename_value.strip()
                        if not rename_target:
                            st.warning("Enter a session name to rename.")
                        elif rename_target == session_name:
                            st.info("Session name unchanged.")
                        elif rename_target in st.session_state.sessions:
                            st.warning(f"Session '{rename_target}' already exists.")
                        elif rename_session(session_name, rename_target):
                            st.success(f"Session renamed to '{rename_target}'.")
                            st.rerun()
                        else:
                            st.error("Unable to rename session. Please try again.")

                    st.divider()
                    if st.button(
                        "Delete session",
                        use_container_width=True,
                        type="secondary",
                        key=f"delete_session_button_{session_name}",
                    ):
                        if delete_session(session_name):
                            new_active = st.session_state.active_session_id
                            st.session_state.session_select = new_active
                            st.session_state.rename_session_name = new_active
                            st.session_state.rename_session_synced_to = new_active
                            st.success(f"Session '{session_name}' deleted.")
                            st.rerun()
                        else:
                            st.warning("Cannot delete the last remaining session.")

    with st.form("create_session_form", clear_on_submit=True):
        new_session_name = st.text_input(
            "New session name",
            key="create_session_name",
            placeholder="Leave blank for automatic name",
        )
        if st.form_submit_button("Create session", use_container_width=True):
            success, created_name = create_session(new_session_name)
            if success:
                st.success(f"Session '{created_name}' created.")
                st.rerun()
            else:
                st.warning(f"Session '{created_name}' already exists.")

    st.divider()

    st.markdown("#### Choose Model")

    # Create display options with categories
    display_options = [MODEL_DISPLAY_NAMES[model] for model in MODEL_CHOICES]

    selected_display = st.selectbox(
        "Choose Model", display_options, index=0, label_visibility="collapsed"
    )

    # Get the actual model ID from the display name
    model_id = next(model for model, display in MODEL_DISPLAY_NAMES.items()
                   if display == selected_display)

    provider = MODEL_TO_PROVIDER[model_id]
    set_model(model_id)

    st.markdown("#### User Identity")

    # Get Hugging Face user ID if available (in HF Spaces)
    hf_user_id = os.getenv("SPACE_USER") or os.getenv("HF_USERNAME") or os.getenv("HF_USER")

    # Check if we're on Hugging Face Spaces (not local)
    is_hf_space = os.getenv("SPACE_ID") is not None or os.getenv("HF_ENDPOINT") is not None

    def validate_hf_token(token: str) -> tuple[bool, str, str]:
        """Validate a Hugging Face access token.

        Uses ``huggingface_hub.whoami`` when that package is importable,
        otherwise calls the ``/api/whoami`` HTTP endpoint directly.

        Returns:
            ``(is_valid, username, error_message)``. ``username`` is empty
            and ``error_message`` is set when validation fails.
        """
        token = token.strip()
        if not token:
            return False, "", "Token cannot be empty"

        # Remove any whitespace or newlines that might have been copied
        token = "".join(token.split())

        # Try using huggingface_hub library if available, otherwise fall back to API
        try:
            from huggingface_hub import whoami
            try:
                user_info = whoami(token=token)
                username = user_info.get("name") or user_info.get("username") or ""
                if username:
                    return True, username, ""
                else:
                    return False, "", "Token validated but username not found in response."
            except Exception as e:
                error_msg = str(e)
                if "401" in error_msg or "Unauthorized" in error_msg or "Invalid" in error_msg:
                    return False, "", f"Invalid token. Please verify your token is correct and has Read permissions. Error: {error_msg[:100]}"
                return False, "", f"Validation error: {error_msg[:150]}"
        except ImportError:
            # Fall back to direct API call if huggingface_hub not available
            pass

        # Fallback: Use the HF whoami endpoint directly
        endpoint = "https://huggingface.co/api/whoami"
        headers = {
            "Authorization": f"Bearer {token}",
            "User-Agent": "MemMachine-Playground/1.0"
        }

        try:
            resp = requests.get(endpoint, headers=headers, timeout=10)

            if resp.status_code == 200:
                user_data = resp.json()
                # Try different possible username fields
                username = (
                    user_data.get("name") or
                    user_data.get("username") or
                    user_data.get("user") or
                    ""
                )
                if username:
                    return True, username, ""
                else:
                    return False, "", f"Token validated but username not found. Response: {str(user_data)[:100]}"
            elif resp.status_code == 401:
                error_detail = ""
                try:
                    error_data = resp.json()
                    error_detail = error_data.get("error", "")
                except (ValueError, AttributeError):
                    # Body was not JSON (ValueError) or not a JSON object
                    # (AttributeError on .get); report without detail.
                    pass
                return False, "", f"Invalid token (401). The token may be expired, revoked, or incorrect. {error_detail} Please create a new Read token at https://huggingface.co/settings/tokens"
            elif resp.status_code == 403:
                return False, "", "Token access denied (403). Please ensure your token has Read permissions."
            else:
                try:
                    error_data = resp.json()
                    error_text = error_data.get("error", resp.text[:100])
                except (ValueError, AttributeError):
                    # Non-JSON / non-object body: fall back to the raw text.
                    error_text = resp.text[:100]
                return False, "", f"Authentication failed (Status {resp.status_code}): {error_text}"

        except requests.exceptions.Timeout:
            return False, "", "Request timed out. Please check your internet connection and try again."
        except requests.exceptions.RequestException as e:
            return False, "", f"Network error: {str(e)}. Please try again."
        except Exception as e:
            return False, "", f"Validation error: {str(e)}. Please try again."

    if is_hf_space:
        # On HF Spaces - require token authentication for security
        if "hf_authenticated_user" not in st.session_state:
            st.warning("πŸ” **Authentication Required**")
            st.caption("To protect your memories, please authenticate with your Hugging Face account.")

            token_input = st.text_input(
                "Enter your Hugging Face Access Token",
                key="hf_token_input",
                type="password",
                placeholder="hf_xxxxxxxxxxxxxxxxxxxxx",
                help="Get your token from: https://huggingface.co/settings/tokens"
            )

            if st.button("Authenticate", use_container_width=True, type="primary"):
                if token_input.strip():
                    with st.spinner("Validating token..."):
                        is_valid, username, error_msg = validate_hf_token(token_input.strip())
                    if is_valid and username:
                        st.session_state.hf_authenticated_user = username
                        st.session_state.hf_token = token_input.strip()  # Store for future use
                        st.success(f"βœ… Authenticated as **{username}**")
                        st.rerun()
                    else:
                        error_display = error_msg if error_msg else "Invalid token. Please check your Hugging Face access token."
                        st.error(f"❌ {error_display}")
                else:
                    st.error("Please enter your access token")
            st.info("πŸ’‘ **Privacy Note:** Your token is stored only in this session and never shared.")
            # Nothing below the sidebar renders until the user authenticates.
            st.stop()
        else:
            # User is authenticated - lock to their username
            persona_name = st.session_state.hf_authenticated_user
            st.success(f"πŸ” Authenticated as: **{persona_name}**")
            st.caption("Your memories are secured to your account only.")
            if st.button("πŸ”“ Sign Out", use_container_width=True):
                del st.session_state.hf_authenticated_user
                if "hf_token" in st.session_state:
                    del st.session_state.hf_token
                st.rerun()
    elif hf_user_id:
        # HF user ID detected automatically
        persona_name = hf_user_id
        st.info(f"πŸ‘€ Signed in as: **{hf_user_id}**")
        st.caption("Your memories are personalized to your account.")
    else:
        # Local/testing mode - allow persona selection
        selected_persona = st.selectbox(
            "Choose user persona",
            ["Charlie", "Jing", "Charles", "Control"],
            label_visibility="collapsed",
        )
        custom_persona = st.text_input("Or enter your name", "")
        persona_name = (
            custom_persona.strip() if custom_persona.strip() else selected_persona
        )

    compare_personas = st.checkbox("Compare without MemMachine")
    show_rationale = st.checkbox("Show Persona Rationale")

    st.divider()
    if st.button("Clear chat", use_container_width=True):
        active = st.session_state.active_session_id
        st.session_state.sessions[active]["history"].clear()
        st.session_state.history = cast(
            list[dict],
            st.session_state.sessions[active]["history"],
        )
        st.rerun()
    if st.button("Delete Profile", use_container_width=True):
        success = delete_profile(persona_name)
        # The active session's chat is cleared whether or not the deletion succeeded.
        active = st.session_state.active_session_id
        st.session_state.sessions[active]["history"].clear()
        st.session_state.history = cast(
            list[dict],
            st.session_state.sessions[active]["history"],
        )
        if success:
            st.success(f"Profile for '{persona_name}' deleted.")
        else:
            st.error(f"Failed to delete profile for '{persona_name}'.")
    st.divider()



# ──────────────────────────────────────────────────────────────
# Enforce alternating roles
# ──────────────────────────────────────────────────────────────
def clean_history(history: list[dict], persona: str) -> list[dict]:
    """Build a strictly alternating user/assistant transcript for ``persona``.

    Every user turn is kept. Only the assistant replies that ``persona``
    produced are kept:
    * plain ``assistant`` turns tagged with that persona, and
    * that persona's entry in compare-mode ``assistant_all`` turns, whose
      ``content`` is a ``{persona_key: answer}`` dict. Previously these were
      dropped, so the model never saw its own earlier compare-mode answers.

    Each kept message is reduced to ``{"role", "content"}``. A run of
    consecutive same-role messages (e.g. user turns whose replies came from a
    different persona) collapses to its most recent message. This keeps each
    assistant reply next to the user message that immediately preceded it,
    which is the message it was generated for.

    Args:
        history: Stored chat turns; not modified.
        persona: Persona whose assistant replies should be kept.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts.
    """
    out: list[dict] = []
    for turn in history:
        role = turn.get("role")
        if role == "user":
            out.append({"role": "user", "content": turn["content"]})
        elif role == "assistant" and turn.get("persona") == persona:
            out.append({"role": "assistant", "content": turn["content"]})
        elif role == "assistant_all":
            answers = turn.get("content")
            if isinstance(answers, dict) and persona in answers:
                out.append({"role": "assistant", "content": answers[persona]})

    cleaned: list[dict] = []
    for msg in out:
        if cleaned and cleaned[-1]["role"] == msg["role"]:
            # Same role twice in a row: the later message supersedes the earlier.
            cleaned[-1] = msg
        else:
            cleaned.append(msg)
    return cleaned


def append_user_turn(msgs: list[dict], new_user_msg: str) -> list[dict]:
    """Make ``new_user_msg`` the final user turn of ``msgs``.

    If ``msgs`` already ends with a user turn, that turn is overwritten so
    roles keep alternating; otherwise a new user turn is appended. ``msgs``
    is modified in place and also returned for convenience.
    """
    turn = {"role": "user", "content": new_user_msg}
    ends_with_user = bool(msgs) and msgs[-1]["role"] == "user"
    if ends_with_user:
        msgs[-1] = turn
    else:
        msgs.append(turn)
    return msgs


def typewriter_effect(text: str, speed: float = 0.02):
    """Yield ``text`` one space-separated word at a time for a typing effect.

    Every word after the first is prefixed with a single space, so joining
    the yielded pieces reproduces ``text`` exactly (runs of spaces included).
    After each yielded piece the generator sleeps ``speed`` seconds once it
    is resumed.
    """
    first, *rest = text.split(" ")
    yield first
    time.sleep(speed)
    for word in rest:
        yield " " + word
        time.sleep(speed)

msg = st.chat_input("Type your message…")
if msg:
    # Record the user's turn first so clean_history() includes it when the
    # prompt transcript is built below.
    st.session_state.history.append({"role": "user", "content": msg})
    if compare_personas:
        # Answer twice, in order: first as the selected persona, then as the
        # "Control" persona (whose chat call is made as "Arnold"). Each answer
        # is stored under its key in a single "assistant_all" turn.
        all_answers = {}
        for answer_key, chat_as in (
            (persona_name, persona_name),
            ("Control", "Arnold"),
        ):
            prompt = rewrite_message(msg, answer_key, show_rationale)
            conversation = append_user_turn(
                clean_history(st.session_state.history, answer_key), prompt
            )
            try:
                reply, lat, tok, tps = chat(conversation, chat_as)
                all_answers[answer_key] = reply
            except ValueError as e:
                st.error(f"❌ {str(e)}")
                st.stop()

        st.session_state.history.append(
            {"role": "assistant_all", "axis": "role", "content": all_answers, "is_new": True}
        )
    else:
        prompt = rewrite_message(msg, persona_name, show_rationale)
        conversation = append_user_turn(
            clean_history(st.session_state.history, persona_name), prompt
        )
        speaker = "Arnold" if persona_name == "Control" else persona_name
        try:
            reply, lat, tok, tps = chat(conversation, speaker)
            st.session_state.history.append(
                {"role": "assistant", "persona": persona_name, "content": reply, "is_new": True}
            )
        except ValueError as e:
            st.error(f"❌ {str(e)}")
            st.stop()
    st.rerun()


# ──────────────────────────────────────────────────────────────
# Chat history display
# ──────────────────────────────────────────────────────────────
def _render_answer(label: str, response: str, animate: bool) -> None:
    """Render one labelled answer of a compare-mode ("assistant_all") turn.

    New answers are streamed with the typewriter effect. Previously shown
    answers go inside the ``answer`` div. The response is HTML-escaped first:
    with ``unsafe_allow_html=True``, markup in model output (e.g.
    ``<img onerror=...>``) would otherwise be injected into the page.
    """
    st.markdown(f"**{label}**")
    if animate:
        st.write_stream(typewriter_effect(response))
    else:
        st.markdown(
            f'<div class="answer">{html.escape(str(response))}</div>',
            unsafe_allow_html=True,
        )


for turn in st.session_state.history:
    role = turn.get("role")
    if role == "user":
        st.chat_message("user").write(turn["content"])
    elif role == "assistant":
        with st.chat_message("assistant"):
            # Use typing effect for new messages, normal display for old ones
            if turn.get("is_new", False):
                st.write_stream(typewriter_effect(turn["content"]))
                # Mark as no longer new so it displays normally on rerun
                turn["is_new"] = False
            else:
                st.write(turn["content"])
    elif role == "assistant_all":
        content_items = list(turn["content"].items())
        is_new = turn.get("is_new", False)
        if len(content_items) >= 2:
            # Side by side: first answer (persona) | thin divider | second (control).
            left, divider, right = st.columns([1, 0.03, 1])
            persona_label, persona_response = content_items[0]
            control_label, control_response = content_items[1]
            with left:
                _render_answer(persona_label, persona_response, is_new)
            with divider:
                st.markdown(
                    '<div class="vertical-divider"></div>', unsafe_allow_html=True
                )
            with right:
                _render_answer(control_label, control_response, is_new)
        else:
            for label, response in content_items:
                _render_answer(label, response, is_new)
        # Mark as no longer new
        if is_new:
            turn["is_new"] = False