| import altair as alt |
| import numpy as np |
| import pandas as pd |
| import streamlit as st |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| from peft import PeftModel |
| import os |
| import logging |
| from datetime import datetime |
| from urllib.parse import urlparse, parse_qs |
| import json |
| import uuid |
| import re |
|
|
# Directory that receives one JSON snapshot file per session; created on
# import so later writes can assume it exists.
LOG_DIR = "/data/session_logs"
os.makedirs(LOG_DIR, exist_ok=True)

# Base checkpoint and the adapter that is layered on top of it.
BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = "Anxo/erisk26-task1-patient-13-adapter"

# The patient id is read from the "patient-<digits>" part of the adapter
# name; 21 is used when the name carries no such part.
match = re.search(r'patient-(\d+)', ADAPTER_PATH)
if match:
    st.session_state.patient_id = int(match.group(1))
else:
    st.session_state.patient_id = 21
|
|
| |
| |
def _option_labels(count):
    """Return the option labels "0" .. str(count - 1) as a fresh list."""
    return [str(value) for value in range(count)]


# Maps each BDI-II item title to the list of selectable option labels.
# Items 16 and 18 offer seven options; every other item offers four.
# Each item gets its own list object, as in a plain literal.
BDI_ITEMS = {
    "1. Sadness": _option_labels(4),
    "2. Pessimism": _option_labels(4),
    "3. Past Failure": _option_labels(4),
    "4. Loss of Pleasure": _option_labels(4),
    "5. Guilty Feelings": _option_labels(4),
    "6. Punishment Feelings": _option_labels(4),
    "7. Self-Dislike": _option_labels(4),
    "8. Self-Criticalness": _option_labels(4),
    "9. Suicidal Thoughts or Wishes": _option_labels(4),
    "10. Crying": _option_labels(4),
    "11. Agitation": _option_labels(4),
    "12. Loss of Interest": _option_labels(4),
    "13. Indecisiveness": _option_labels(4),
    "14. Worthlessness": _option_labels(4),
    "15. Loss of Energy": _option_labels(4),
    "16. Changes in Sleeping Pattern": _option_labels(7),
    "17. Irritability": _option_labels(4),
    "18. Changes in Appetite": _option_labels(7),
    "19. Concentration Difficulty": _option_labels(4),
    "20. Tiredness or Fatigue": _option_labels(4),
    "21. Loss of Interest in Sex": _option_labels(4),
}
|
|
# Admin mode is switched on by the ``admin`` query parameter.
# NOTE(review): the token is hard-coded and travels in the URL, so anyone who
# sees the link can read every session log — consider moving it to a secret.
query_params = st.query_params
is_admin = query_params.get("admin") == "1234"

if is_admin:
    st.sidebar.subheader("Admin: Session Logs")

    if os.path.isdir(LOG_DIR):
        files = sorted(os.listdir(LOG_DIR))
        if files:
            selected = st.sidebar.selectbox("Select a session", files)
            fpath = os.path.join(LOG_DIR, selected)

            # Show the session metadata. A file that is unreadable or not
            # valid JSON is reported instead of crashing the whole page
            # (json.JSONDecodeError and UnicodeDecodeError are ValueErrors).
            try:
                with open(fpath, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except (OSError, ValueError) as exc:
                data = None
                st.sidebar.error(f"Could not read {selected}: {exc}")

            if isinstance(data, dict):
                st.sidebar.write(f"Session ID: {data.get('session_id')}")
                st.sidebar.write(f"Patient ID: {data.get('patient_id')}")
                st.sidebar.write(f"Prolific ID: {data.get('prolific_id', 'N/A')}")
                st.sidebar.write(f"Start: {data.get('start_time')}")

            # The raw file stays downloadable even when it could not be parsed.
            with open(fpath, "rb") as f:
                st.sidebar.download_button(
                    label=f"⬇️ Download {selected}",
                    data=f,
                    file_name=selected,
                    mime="application/json",
                    key=f"dl_{selected}",
                )
        else:
            st.sidebar.info("No logs yet.")
    else:
        st.sidebar.warning("Log directory not found.")

    st.sidebar.markdown("---")

# Collect the label recorded for every symptom annotated so far. When a
# symptom appears in several turns, the latest annotation wins.
annotated_symptoms: dict[str, str] = {}
for annotation in st.session_state.get("turn_annotations", []):
    for sym, entry in annotation.get("scores", {}).items():
        annotated_symptoms[sym] = entry.get("label", "")

# Progress summary; ``is_complete`` is read later to offer the submit button.
n_annotated = len(annotated_symptoms)
n_total = len(BDI_ITEMS)
is_complete = n_annotated >= n_total
st.sidebar.caption(f"{n_annotated} / {n_total} symptoms annotated")
# st.progress rejects values above 1.0, so clamp in case a stored annotation
# names a symptom that is not in BDI_ITEMS.
st.sidebar.progress(min(n_annotated / n_total, 1.0))
st.sidebar.subheader("🩺 BDI-II Symptom Checklist")

# One checklist row per BDI-II item: ticked with its label once annotated.
# NOTE(review): the "β" separator below looks like a mis-encoded dash or
# arrow whose trailing bytes were lost; it is kept as found — restore the
# intended character from the original file.
for symptom in BDI_ITEMS:
    if symptom in annotated_symptoms:
        label_text = annotated_symptoms[symptom]
        st.sidebar.markdown(f"✅ **{symptom}** β *{label_text}*")
    else:
        st.sidebar.markdown(f"⬜ {symptom}")
st.sidebar.markdown("---")
|
|
|
|
def save_session_snapshot():
    """Write the current session state to a JSON file inside ``LOG_DIR``.

    The record holds the session id, patient id, Prolific id, start time, a
    UTC timestamp of this save, the chat messages and the turn annotations.
    The file is named ``<patient_id>_<prolific_id>_<session_id>.json`` and is
    overwritten on every call, so it always reflects the latest state.

    Raises:
        OSError: if the file cannot be opened or written.
    """
    state = st.session_state
    prolific_id = state.get("prolific_id")
    record = {
        "session_id": state.session_id,
        "patient_id": state.patient_id,
        "prolific_id": prolific_id,
        "start_time": state.start_time,
        "timestamp": datetime.utcnow().isoformat(),
        "messages": state.messages,
        "turn_annotations": state.turn_annotations,
    }

    # The Prolific id is free text typed by the participant. Replace anything
    # that is not a plain file-name character so a value such as "../x" or
    # "a/b" cannot steer the write outside LOG_DIR. Ordinary alphanumeric ids
    # and the UUID session id pass through unchanged.
    name_parts = (state.patient_id, prolific_id, state.session_id)
    fname = "_".join(
        re.sub(r"[^A-Za-z0-9._-]", "_", str(part)) for part in name_parts
    ) + ".json"
    fpath = os.path.join(LOG_DIR, fname)

    with open(fpath, "w", encoding="utf-8") as f:
        json.dump(record, f, ensure_ascii=False, indent=2)
|
|
@st.cache_resource
def load_model():
    """Load the tokenizer and the adapter-wrapped, 4-bit quantized model.

    The tokenizer comes from ``ADAPTER_PATH`` and reuses its EOS token for
    padding. The base model ``BASE_MODEL`` is loaded with NF4 4-bit
    quantization (double quantization, float16 compute), then wrapped with
    the PEFT adapter at ``ADAPTER_PATH`` and switched to eval mode.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    tok = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    tok.pad_token = tok.eos_token

    backbone = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quantization,
        device_map="auto",
        dtype=torch.float16,
    )

    adapted = PeftModel.from_pretrained(backbone, ADAPTER_PATH)
    adapted.eval()
    return tok, adapted
|
|
tokenizer, model = load_model()

# Token ids that end generation: the tokenizer's EOS token and "<|eot_id|>".
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Session-state keys and factories for their initial values. A factory is
# only called when its key is missing, so existing values survive reruns.
_SESSION_DEFAULTS = {
    "submitted": lambda: False,
    "session_id": lambda: str(uuid.uuid4()),
    "start_time": lambda: datetime.utcnow().isoformat(),
    "prolific_id": lambda: None,
    "messages": lambda: [
        {
            "role": "system",
            "content": (
                "You are a simulated patient. "
                "Act realistically based on your internal training. "
                "Ensure contextual realism. "
                "Avoid overly detailed or formal speech. "
                "Keep natural speaking style (e.g., short answers, "
                "hesitations, casual expressions). "
                "Do not mention you are an AI."
            ),
        },
    ],
    "turn_annotations": lambda: [],
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
|
| |
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command on a page, yet sidebar widgets are rendered earlier
# in this script — confirm against the deployed Streamlit version.
st.set_page_config(page_title="BDI-II Study")
st.title("BDI-II Patient Chatbot 💬")

# Gate: nothing else is shown until the participant has entered a Prolific ID.
if st.session_state.prolific_id is None:
    st.markdown("### Before we begin")
    st.markdown(
        "Please enter your **Prolific ID** below. "
        "You can find it in your Prolific dashboard. "
        "This step is important for the payment to be approved."
    )
    with st.form("prolific_form"):
        pid = st.text_input("Prolific ID", placeholder="e.g. 5e3f2a1b9c8d7e6f5a4b3c2d")
        # NOTE(review): the trailing "β" looks like a mis-encoded arrow whose
        # remaining bytes were lost; kept as found — restore from the
        # original file.
        submitted = st.form_submit_button("Continue β")
    if submitted:
        pid = pid.strip()
        if pid:
            # Store the id, persist a first snapshot, and rerun so the chat
            # interface replaces this form.
            st.session_state.prolific_id = pid
            save_session_snapshot()
            st.rerun()
        else:
            st.error("Please enter your Prolific ID before continuing.")
    st.stop()

st.markdown("Please talk to the chatbot to administer the BDI-II test.")

# After submission only the thank-you page is rendered.
if st.session_state.get("submitted", False):
    # NOTE(review): the leading "π" looks like a mis-encoded emoji whose
    # remaining bytes were lost; kept as found.
    st.success("## π Thank you for your participation!")
    st.markdown(
        """
Your session has been saved successfully.

**Next steps:**
- Return to Prolific and complete the submission there using the code "CI3N0FC3".
- Your payment will be processed once your Prolific ID is verified.

You may now close this window.
"""
    )
    st.balloons()
    st.stop()
|
|
| |
# Index saved annotations by turn number once, instead of scanning the list
# for every message. ``setdefault`` keeps the first entry for a turn, which
# matches a first-match linear search.
annotations_by_turn = {}
for saved_annotation in st.session_state.turn_annotations:
    annotations_by_turn.setdefault(saved_annotation["turn"], saved_annotation)

# Replay the conversation. Each assistant reply gets an annotation expander:
# read-only when already annotated, an input form otherwise.
assistant_turn_counter = 0
for msg in st.session_state.messages:
    if msg["role"] == "system":
        continue
    with st.chat_message("user" if msg["role"] == "user" else "assistant"):
        st.markdown(msg["content"])

    if msg["role"] == "assistant":
        assistant_turn_counter += 1
        turn_idx = assistant_turn_counter  # 1-based index over assistant turns

        existing = annotations_by_turn.get(turn_idx)

        with st.expander(
            f"🩺 Annotate Turn {turn_idx}" + (" ✅" if existing else ""),
            expanded=(existing is None),
        ):
            if existing:
                # NOTE(review): the "β" fallback looks like a mis-encoded
                # dash whose remaining bytes were lost; kept as found.
                saved_names = (
                    ", ".join(existing["symptoms"]) if existing["symptoms"] else "β"
                )
                st.markdown(f"**Saved symptoms:** {saved_names}")
                for sym, entry in existing["scores"].items():
                    st.markdown(f"- **{sym}:** {entry['label']} (score: {entry['score']})")
            else:
                selected_symptoms = st.multiselect(
                    "Which BDI-II symptoms do you infer from this turn?",
                    options=list(BDI_ITEMS.keys()),
                    key=f"symptoms_{turn_idx}",
                )

                per_symptom_scores = {}
                if selected_symptoms:
                    st.markdown("**Select the best-matching statement for each symptom:**")
                    for sym in selected_symptoms:
                        options = BDI_ITEMS[sym]
                        chosen = st.radio(
                            sym,
                            options=options,
                            key=f"radio_{turn_idx}_{sym}",
                        )
                        # NOTE(review): the score is stored as the chosen
                        # option label (a string) — confirm that consumers of
                        # the logs expect this.
                        per_symptom_scores[sym] = {
                            "label": chosen,
                            "score": chosen,
                        }

                # Saving with no symptoms selected is allowed; it records an
                # empty annotation for the turn.
                if st.button("Next turn ->", key=f"save_{turn_idx}"):
                    st.session_state.turn_annotations.append({
                        "turn": turn_idx,
                        "symptoms": selected_symptoms,
                        "scores": per_symptom_scores,
                        "annotated_at": datetime.utcnow().isoformat(),
                    })
                    save_session_snapshot()
                    st.success(f"Annotation for turn {turn_idx} saved!")
                    st.rerun()
|
|
| |
# Once every symptom has an annotation, offer to finish the session.
if is_complete:
    st.markdown("---")
    st.info(
        "✅ **All 21 symptoms have been annotated.** "
        "You may continue chatting, or submit your session when ready."
    )
    # NOTE(review): the leading "π" looks like a mis-encoded emoji whose
    # remaining bytes were lost; kept as found — restore from the original
    # file.
    if st.button("π Submit & Finish Session", type="primary", key="inline_submit"):
        # Mark the session as submitted, persist it, and rerun so the
        # thank-you page is shown.
        st.session_state.submitted = True
        save_session_snapshot()
        st.rerun()
    st.markdown("---")
|
|
| |
# Handle one chat turn: record the user's message, sample the model's reply,
# record it, persist the session, and rerun to redraw the page.
if prompt := st.chat_input("Type your next BDI-II question or remark..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Render the whole conversation through the chat template and move the
    # resulting tensors to the model's device.
    encoded = tokenizer.apply_chat_template(
        st.session_state.messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    prompt_ids = encoded.input_ids

    sampling_options = {
        "max_new_tokens": 256,
        "eos_token_id": terminators,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    }
    with torch.no_grad():
        generated = model.generate(
            input_ids=prompt_ids,
            attention_mask=encoded.attention_mask,
            **sampling_options,
        )

    # The output sequence starts with the prompt; decode only what follows.
    prompt_length = prompt_ids.shape[-1]
    new_tokens = generated[0][prompt_length:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True)

    with st.chat_message("assistant"):
        st.markdown(reply)

    st.session_state.messages.append({"role": "assistant", "content": reply})

    save_session_snapshot()
    st.rerun()