bdip13 / src /streamlit_app.py
isinec's picture
Update src/streamlit_app.py
cd16a9a verified
Raw
History Blame Contribute Delete
14.7 kB
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import os
import logging
from datetime import datetime
from urllib.parse import urlparse, parse_qs
import json
import uuid
import re
# Directory that receives one JSON snapshot per session; created at startup
# if it does not exist yet.
LOG_DIR = "/data/session_logs"
os.makedirs(LOG_DIR, exist_ok=True)

# Identifiers of the base model and of the LoRA adapter attached on top of it.
BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = "Anxo/erisk26-task1-patient-13-adapter"

# The patient number is parsed from the "patient-<N>" part of the adapter
# name; 21 is stored when the name contains no such number.
match = re.search(r'patient-(\d+)', ADAPTER_PATH)
if match:
    st.session_state.patient_id = int(match.group(1))
else:
    st.session_state.patient_id = 21
# ─── BDI-II items ────────────────────────────────────────────────────────────
# Maps each of the 21 symptom headings to the option labels offered for it.
# Labels are numeric strings: "0".."3" for most items, "0".."6" for the two
# items (16 and 18) that offer seven options.
# NOTE(review): labels above "3" exist for items 16 and 18 — confirm how they
# are mapped onto item scores wherever totals are computed.
BDI_ITEMS = {
    symptom: [str(level) for level in range(option_count)]
    for symptom, option_count in (
        ("1. Sadness", 4),
        ("2. Pessimism", 4),
        ("3. Past Failure", 4),
        ("4. Loss of Pleasure", 4),
        ("5. Guilty Feelings", 4),
        ("6. Punishment Feelings", 4),
        ("7. Self-Dislike", 4),
        ("8. Self-Criticalness", 4),
        ("9. Suicidal Thoughts or Wishes", 4),
        ("10. Crying", 4),
        ("11. Agitation", 4),
        ("12. Loss of Interest", 4),
        ("13. Indecisiveness", 4),
        ("14. Worthlessness", 4),
        ("15. Loss of Energy", 4),
        ("16. Changes in Sleeping Pattern", 7),
        ("17. Irritability", 4),
        ("18. Changes in Appetite", 7),
        ("19. Concentration Difficulty", 4),
        ("20. Tiredness or Fatigue", 4),
        ("21. Loss of Interest in Sex", 4),
    )
}
# ─── ADMIN SIDEBAR ───────────────────────────────────────────────────────────
# Admin mode is switched on with the `?admin=<key>` query parameter and lets
# the operator inspect and download the stored session logs.
query_params = st.query_params
# NOTE(review): the key travels in the URL and falls back to a hard-coded
# value. Set the ADMIN_KEY environment variable in deployment so the secret
# is not the one committed to the source.
is_admin = query_params.get("admin") == os.environ.get("ADMIN_KEY", "1234")

if is_admin:
    st.sidebar.subheader("Admin: Session Logs")
    if os.path.isdir(LOG_DIR):
        files = sorted(os.listdir(LOG_DIR))
        if files:
            selected = st.sidebar.selectbox("Select a session", files)
            fpath = os.path.join(LOG_DIR, selected)

            # Read the file once; the same bytes feed both the metadata
            # preview and the download button.
            try:
                with open(fpath, "rb") as f:
                    raw = f.read()
            except OSError as exc:
                raw = None
                st.sidebar.error(f"Could not read {selected}: {exc}")

            if raw is not None:
                # A log that is not valid UTF-8 JSON must not take the whole
                # app down; the raw file stays downloadable.
                try:
                    data = json.loads(raw.decode("utf-8"))
                except ValueError:  # UnicodeDecodeError / json.JSONDecodeError
                    data = None

                if isinstance(data, dict):
                    # Optional: show brief metadata
                    st.sidebar.write(f"Session ID: {data.get('session_id')}")
                    st.sidebar.write(f"Patient ID: {data.get('patient_id')}")
                    st.sidebar.write(f"Prolific ID: {data.get('prolific_id', 'N/A')}")
                    st.sidebar.write(f"Start: {data.get('start_time')}")
                else:
                    st.sidebar.warning(
                        "This file is not a readable JSON session record; "
                        "metadata is unavailable."
                    )

                # Download button
                st.sidebar.download_button(
                    label=f"⬇️ Download {selected}",
                    data=raw,
                    file_name=selected,
                    mime="application/json",
                    key=f"dl_{selected}",
                )
        else:
            st.sidebar.info("No logs yet.")
    else:
        st.sidebar.warning("Log directory not found.")
# ─── SYMPTOM CHECKLIST SIDEBAR ───────────────────────────────────────────────
st.sidebar.markdown("---")

# Build a lookup: symptom → label chosen for it. Annotations are visited in
# stored order, so when a symptom was rated in several turns the most recent
# rating is the one that is kept.
annotated_symptoms: dict[str, str] = {}
for annotation in st.session_state.get("turn_annotations", []):
    for sym, entry in annotation.get("scores", {}).items():
        annotated_symptoms[sym] = entry.get("label", "")

# ─── Progress bar ─────────────────────────────────────────────────────────────
n_annotated = len(annotated_symptoms)
n_total = len(BDI_ITEMS)
is_complete = n_annotated >= n_total
st.sidebar.caption(f"{n_annotated} / {n_total} symptoms annotated")
# The progress widget only accepts values up to 1.0, so clamp in case stored
# annotations contain keys that are not part of BDI_ITEMS.
st.sidebar.progress(min(n_annotated / n_total, 1.0))

st.sidebar.subheader("🩺 BDI-II Symptom Checklist")
# Render one row per BDI-II item: ticked with its label once annotated,
# an empty box otherwise.
for symptom in BDI_ITEMS:
    if symptom in annotated_symptoms:
        label_text = annotated_symptoms[symptom]
        st.sidebar.markdown(f"✅ **{symptom}** — *{label_text}*")
    else:
        st.sidebar.markdown(f"⬜ {symptom}")
st.sidebar.markdown("---")
def _safe_filename_part(value, max_len=64):
    """Return ``str(value)`` reduced to characters that are safe in a file name.

    Every character other than ASCII letters, digits, ``_`` and ``-`` is
    replaced by ``_`` and the result is cut to *max_len* characters, so the
    value cannot introduce path separators or ``..`` segments.
    """
    cleaned = re.sub(r"[^A-Za-z0-9_-]", "_", str(value))
    return cleaned[:max_len] or "_"


def save_session_snapshot():
    """Write the current session (messages and annotations) to a JSON file.

    The file lives in ``LOG_DIR`` and is named
    ``<patient_id>_<prolific_id>_<session_id>.json``; each call overwrites
    the previous snapshot of the same session. The record stores the
    Prolific ID exactly as entered, while the file name uses a sanitised
    copy because the ID is free text typed by the participant.
    """
    record = {
        "session_id": st.session_state.session_id,
        "patient_id": st.session_state.patient_id,
        "prolific_id": st.session_state.get("prolific_id"),
        "start_time": st.session_state.start_time,
        "timestamp": datetime.utcnow().isoformat(),
        "messages": st.session_state.messages,
        "turn_annotations": st.session_state.turn_annotations,
    }
    safe_prolific_id = _safe_filename_part(st.session_state.prolific_id)
    fname = f"{st.session_state.patient_id}_{safe_prolific_id}_{st.session_state.session_id}.json"
    fpath = os.path.join(LOG_DIR, fname)

    # Write to a sibling temp file and swap it in, so a reader never sees a
    # half-written snapshot.
    tmp_path = f"{fpath}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(record, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, fpath)
@st.cache_resource
def load_model():
    """Load the tokenizer and the 4-bit base model with the LoRA adapter attached.

    The tokenizer comes from the adapter repository (``ADAPTER_PATH``) and
    its pad token is set to its EOS token. The base model (``BASE_MODEL``)
    is loaded with NF4 4-bit quantisation and float16 compute, then wrapped
    with the adapter and switched to eval mode.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    adapter_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    adapter_tokenizer.pad_token = adapter_tokenizer.eos_token

    # Explicit 4-bit quantisation settings for the base weights.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    backbone = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quantization,
        device_map="auto",
        dtype=torch.float16,  # replaces deprecated torch_dtype
    )

    # Attach the LoRA adapter on top of the quantised base model.
    patient_model = PeftModel.from_pretrained(backbone, ADAPTER_PATH)
    patient_model.eval()
    return adapter_tokenizer, patient_model
# Tokenizer/model pair shared by every chat turn.
tokenizer, model = load_model()

# Generation stops on either the tokenizer's EOS token or "<|eot_id|>".
eot_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
terminators = [tokenizer.eos_token_id, eot_token_id]
# --- init session state ---
# Each key is filled only when it is missing, so values survive reruns.
# Defaults are wrapped in callables so a new UUID / timestamp / list is
# produced only when it is actually needed.
session_defaults = {
    "submitted": lambda: False,
    "session_id": lambda: str(uuid.uuid4()),
    "start_time": lambda: datetime.utcnow().isoformat(),
    "prolific_id": lambda: None,
    "messages": lambda: [
        {"role": "system", "content": "You are a simulated patient. Act realistically based on your internal training. Ensure contextual realism. Avoid overly detailed or formal speech. Keep natural speaking style (e.g., short answers, hesitations, casual expressions). Do not mention you are an AI."},
    ],
    "turn_annotations": lambda: [],
}
for state_key, make_default in session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = make_default()
# Set browser tab title
# NOTE(review): other Streamlit calls (sidebar, cached model load) run before
# this line; confirm the deployed Streamlit version accepts set_page_config
# when it is not the first command.
st.set_page_config(page_title="BDI-II Study")
st.title("BDI-II Patient Chatbot 💬")

# ── Prolific ID gate ──────────────────────────────────────────────────────────
# Until a non-empty Prolific ID has been stored, only this form is rendered.
if st.session_state.prolific_id is None:
    st.markdown("### Before we begin")
    st.markdown(
        "Please enter your **Prolific ID** below. "
        "You can find it in your Prolific dashboard. "
        "This step is important for the payment to be approved."
    )
    with st.form("prolific_form"):
        pid = st.text_input("Prolific ID", placeholder="e.g. 5e3f2a1b9c8d7e6f5a4b3c2d")
        submitted = st.form_submit_button("Continue →")
        if submitted:
            pid = pid.strip()
            if pid:
                st.session_state.prolific_id = pid
                save_session_snapshot()  # persist ID immediately
                st.rerun()
            else:
                st.error("Please enter your Prolific ID before continuing.")
    st.stop()  # blocks everything below until the form is filled
# ─── Thank-you screen (shown after submission) ─────────────────────────────────
# Checked before the chat instructions so a finished session shows only the
# confirmation and nothing that invites further chatting.
if st.session_state.get("submitted", False):
    st.success("## 🎉 Thank you for your participation!")
    st.markdown(
        """
Your session has been saved successfully.
**Next steps:**
- Return to Prolific and complete the submission there using the code "CI3N0FC3".
- Your payment will be processed once your Prolific ID is verified.
You may now close this window.
"""
    )
    st.balloons()
    st.stop()

# Streamlit UI — instructions for an active (not yet submitted) session.
st.markdown("Please talk to the chatbot to administer the BDI-II test.")
# --- display history (skip system) ---
# Every assistant message is followed by an annotation expander. Turns are
# numbered from 1 in the order the assistant messages appear.
assistant_turn_counter = 0
for msg in st.session_state.messages:
    if msg["role"] == "system":
        continue

    with st.chat_message("user" if msg["role"] == "user" else "assistant"):
        st.markdown(msg["content"])

    if msg["role"] == "assistant":
        assistant_turn_counter += 1
        turn_idx = assistant_turn_counter

        # Annotation already saved for this turn, if any.
        existing = next(
            (a for a in st.session_state.turn_annotations if a["turn"] == turn_idx), None
        )

        with st.expander(
            f"🩺 Annotate Turn {turn_idx}" + (" ✅" if existing else ""),
            expanded=(existing is None),
        ):
            if existing:
                # Read-only summary of what was saved.
                st.markdown(
                    f"**Saved symptoms:** {', '.join(existing['symptoms']) if existing['symptoms'] else '—'}"
                )
                for sym, entry in existing["scores"].items():
                    st.markdown(f"- **{sym}:** {entry['label']} (score: {entry['score']})")
            else:
                selected_symptoms = st.multiselect(
                    "Which BDI-II symptoms do you infer from this turn?",
                    options=list(BDI_ITEMS.keys()),
                    key=f"symptoms_{turn_idx}",
                )

                # One radio group per selected symptom; the chosen option is
                # stored as both the label and the score.
                per_symptom_scores = {}
                if selected_symptoms:
                    st.markdown("**Select the best-matching statement for each symptom:**")
                    for sym in selected_symptoms:
                        options = BDI_ITEMS[sym]
                        chosen = st.radio(
                            sym,
                            options=options,
                            key=f"radio_{turn_idx}_{sym}",
                        )
                        per_symptom_scores[sym] = {
                            "label": chosen,
                            "score": chosen,
                        }

                # Saving is allowed with no symptoms selected (an empty
                # annotation for the turn).
                if st.button("Next turn ->", key=f"save_{turn_idx}"):
                    st.session_state.turn_annotations.append({
                        "turn": turn_idx,
                        "symptoms": selected_symptoms,
                        "scores": per_symptom_scores,
                        "annotated_at": datetime.utcnow().isoformat(),
                    })
                    save_session_snapshot()
                    st.success(f"Annotation for turn {turn_idx} saved!")
                    st.rerun()
# ─── Inline Submit button (only shown once every symptom is annotated, above chat input) ───
if is_complete:
    st.markdown("---")
    # The count comes from n_total so the message stays correct if the item
    # list changes.
    st.info(
        f"✅ **All {n_total} symptoms have been annotated.** "
        "You may continue chatting, or submit your session when ready."
    )
    if st.button("🚀 Submit & Finish Session", type="primary", key="inline_submit"):
        # Flag the session as finished, persist it, and rerun so the
        # thank-you screen is shown.
        st.session_state.submitted = True
        save_session_snapshot()
        st.rerun()
    st.markdown("---")
# --- chat input ---
user_text = st.chat_input("Type your next BDI-II question or remark...")
if user_text:
    # Store and echo the interviewer's message.
    st.session_state.messages.append({"role": "user", "content": user_text})
    with st.chat_message("user"):
        st.markdown(user_text)

    # Render the whole conversation through the chat template. With
    # return_dict=True the result carries the attention mask as well as the
    # input ids.
    encoded = tokenizer.apply_chat_template(
        st.session_state.messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    encoded = encoded.to(model.device)
    prompt_length = encoded.input_ids.shape[-1]

    # Sampling settings for the simulated patient's reply.
    generation_kwargs = {
        "max_new_tokens": 256,
        "eos_token_id": terminators,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    }
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            **generation_kwargs,
        )

    # Keep only the newly generated tokens so the prompt is not echoed back.
    reply_ids = generated[0][prompt_length:]
    assistant_text = tokenizer.decode(reply_ids, skip_special_tokens=True)

    with st.chat_message("assistant"):
        st.markdown(assistant_text)

    # Record the reply, persist the session, and rerun so the history loop
    # renders the new turn together with its annotation controls.
    st.session_state.messages.append({"role": "assistant", "content": assistant_text})
    save_session_snapshot()
    st.rerun()