import hmac
import json
import logging
import os
import re
import uuid
from datetime import datetime, timezone
from urllib.parse import parse_qs, urlparse

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Set the browser tab title. This must run before any other Streamlit command:
# older Streamlit releases raise StreamlitAPIException when st.set_page_config()
# is called after something (e.g. the sidebar below) has already been rendered.
st.set_page_config(page_title="BDI-II Study")

LOG_DIR = "/data/session_logs"
os.makedirs(LOG_DIR, exist_ok=True)

BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = "Anxo/erisk26-task1-patient-13-adapter"

# Value of the ?admin=... query parameter that unlocks the session-log sidebar.
# Set BDI_ADMIN_KEY in the environment so the real key does not live in source
# control; "1234" is only a backward-compatible fallback.
ADMIN_KEY = os.environ.get("BDI_ADMIN_KEY", "1234")

# Characters outside this set are replaced before a participant-supplied value
# becomes part of a log filename (blocks "../" traversal and stray separators).
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9_-]")
_MAX_FILENAME_ID_LEN = 64

SYSTEM_PROMPT = (
    "You are a simulated patient. Act realistically based on your internal training. "
    "Ensure contextual realism. Avoid overly detailed or formal speech. "
    "Keep natural speaking style (e.g., short answers, hesitations, casual expressions). "
    "Do not mention you are an AI."
)

# Numeric id of the simulated patient, parsed from the adapter repo name
# ("...-patient-13-adapter" -> 13); falls back to 21 if the name has no match.
_patient_match = re.search(r"patient-(\d+)", ADAPTER_PATH)
st.session_state.patient_id = int(_patient_match.group(1)) if _patient_match else 21

# ─── BDI-II items ─────────────────────────────────────────────────────────────
# Maps each BDI-II item name to the response options shown to the annotator.
# The options are bare score labels ("0", "1", ...), not the questionnaire
# statements. Items 16 and 18 offer seven options ("0"–"6") instead of four —
# presumably the BDI-II 0/1a/1b/2a/2b/3a/3b choices (TODO confirm); note that
# the chosen label is stored as-is in the "score" field for those items too.
_BDI_ITEM_NAMES = [
    "1. Sadness",
    "2. Pessimism",
    "3. Past Failure",
    "4. Loss of Pleasure",
    "5. Guilty Feelings",
    "6. Punishment Feelings",
    "7. Self-Dislike",
    "8. Self-Criticalness",
    "9. Suicidal Thoughts or Wishes",
    "10. Crying",
    "11. Agitation",
    "12. Loss of Interest",
    "13. Indecisiveness",
    "14. Worthlessness",
    "15. Loss of Energy",
    "16. Changes in Sleeping Pattern",
    "17. Irritability",
    "18. Changes in Appetite",
    "19. Concentration Difficulty",
    "20. Tiredness or Fatigue",
    "21. Loss of Interest in Sex",
]
_SEVEN_OPTION_ITEMS = {"16. Changes in Sleeping Pattern", "18. Changes in Appetite"}
BDI_ITEMS = {
    name: [str(score) for score in range(7 if name in _SEVEN_OPTION_ITEMS else 4)]
    for name in _BDI_ITEM_NAMES
}


def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Same format as the deprecated ``datetime.utcnow().isoformat()`` (no
    ``+00:00`` suffix), so timestamps in stored logs keep their existing shape.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


# ─── Admin sidebar (session-log browser) ──────────────────────────────────────
query_params = st.query_params
# Constant-time comparison so the key cannot be probed via response timing.
is_admin = hmac.compare_digest(
    str(query_params.get("admin", "")).encode("utf-8"),
    ADMIN_KEY.encode("utf-8"),
)

if is_admin:
    st.sidebar.subheader("Admin: Session Logs")
    if os.path.isdir(LOG_DIR):
        # Only completed snapshots; in-flight "*.tmp" files are skipped.
        files = sorted(name for name in os.listdir(LOG_DIR) if name.endswith(".json"))
        if files:
            selected = st.sidebar.selectbox("Select a session", files)
            fpath = os.path.join(LOG_DIR, selected)
            try:
                with open(fpath, "rb") as f:
                    raw = f.read()
            except OSError as err:
                st.sidebar.error(f"Could not read {selected}: {err}")
            else:
                # Brief metadata; a corrupt file is reported, not fatal.
                try:
                    data = json.loads(raw)
                except (UnicodeDecodeError, json.JSONDecodeError):
                    st.sidebar.warning(f"{selected} is not valid JSON; metadata unavailable.")
                else:
                    st.sidebar.write(f"Session ID: {data.get('session_id')}")
                    st.sidebar.write(f"Patient ID: {data.get('patient_id')}")
                    st.sidebar.write(f"Prolific ID: {data.get('prolific_id', 'N/A')}")
                    st.sidebar.write(f"Start: {data.get('start_time')}")
                # The raw bytes are offered even when the JSON could not be parsed.
                st.sidebar.download_button(
                    label=f"⬇️ Download {selected}",
                    data=raw,
                    file_name=selected,
                    mime="application/json",
                    key=f"dl_{selected}",
                )
        else:
            st.sidebar.info("No logs yet.")
    else:
        st.sidebar.warning("Log directory not found.")

# ─── Symptom checklist sidebar ────────────────────────────────────────────────
st.sidebar.markdown("---")

# Latest label per symptom across all saved turn annotations: when a symptom is
# rated in several turns, later turns overwrite earlier ones.
annotated_symptoms: dict[str, str] = {}
for annotation in st.session_state.get("turn_annotations", []):
    for sym, entry in annotation.get("scores", {}).items():
        annotated_symptoms[sym] = entry.get("label", "")

# Progress counts only real BDI-II items, so the ratio never exceeds 1.
n_total = len(BDI_ITEMS)
n_annotated = sum(1 for item in BDI_ITEMS if item in annotated_symptoms)
is_complete = n_annotated >= n_total
st.sidebar.caption(f"{n_annotated} / {n_total} symptoms annotated")
st.sidebar.progress(n_annotated / n_total)

st.sidebar.subheader("🩺 BDI-II Symptom Checklist")
for symptom in BDI_ITEMS:
    if symptom in annotated_symptoms:
        st.sidebar.markdown(f"✅ **{symptom}** — *{annotated_symptoms[symptom]}*")
    else:
        st.sidebar.markdown(f"⬜ {symptom}")
st.sidebar.markdown("---")


def save_session_snapshot():
    """Write the current session state to ``LOG_DIR`` as pretty-printed JSON.

    The file is named ``{patient_id}_{prolific_id}_{session_id}.json`` and is
    overwritten on every call. The Prolific ID is typed by the participant, so
    it is sanitised (and length-capped) before being used in the filename; the
    record itself stores it verbatim. The JSON is written to a temporary file
    and then moved into place with ``os.replace`` so a reader (e.g. the admin
    sidebar) never sees a half-written snapshot.
    """
    state = st.session_state
    record = {
        "session_id": state.session_id,
        "patient_id": state.patient_id,
        "prolific_id": state.get("prolific_id"),
        "start_time": state.start_time,
        "timestamp": _utc_now_iso(),
        "messages": state.messages,
        "turn_annotations": state.turn_annotations,
    }
    safe_pid = _UNSAFE_FILENAME_CHARS.sub("_", str(state.prolific_id))[:_MAX_FILENAME_ID_LEN]
    fname = f"{state.patient_id}_{safe_pid}_{state.session_id}.json"
    fpath = os.path.join(LOG_DIR, fname)
    tmp_path = f"{fpath}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(record, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, fpath)
    finally:
        # Only present if writing or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@st.cache_resource
def load_model():
    """Load the tokenizer and the 4-bit quantised base model with the LoRA adapter.

    Decorated with ``st.cache_resource`` so the weights are reused across script
    reruns instead of being reloaded on every interaction.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    # Tokenizer comes from the adapter repo (includes its special tokens).
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_PATH)
    tokenizer.pad_token = tokenizer.eos_token

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.float16,  # replaces the deprecated torch_dtype kwarg
    )

    # Attach the LoRA adapter on top of the quantised base model.
    model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()
# Generation stops at either the tokenizer's EOS or Llama-3's end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# ─── Session state initialisation (first run of each browser session) ─────────
if "submitted" not in st.session_state:
    st.session_state.submitted = False
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())
if "start_time" not in st.session_state:
    st.session_state.start_time = _utc_now_iso()
if "prolific_id" not in st.session_state:
    st.session_state.prolific_id = None
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "system", "content": SYSTEM_PROMPT}]
if "turn_annotations" not in st.session_state:
    st.session_state.turn_annotations = []

st.title("BDI-II Patient Chatbot 💬")

# ─── Prolific ID gate ─────────────────────────────────────────────────────────
if st.session_state.prolific_id is None:
    st.markdown("### Before we begin")
    st.markdown(
        "Please enter your **Prolific ID** below. "
        "You can find it in your Prolific dashboard. "
        "This step is important for the payment to be approved."
    )
    with st.form("prolific_form"):
        pid = st.text_input("Prolific ID", placeholder="e.g. 5e3f2a1b9c8d7e6f5a4b3c2d")
        submitted = st.form_submit_button("Continue →")
        if submitted:
            pid = pid.strip()
            if pid:
                st.session_state.prolific_id = pid
                save_session_snapshot()  # persist the ID immediately
                st.rerun()
            else:
                st.error("Please enter your Prolific ID before continuing.")
    st.stop()  # nothing below renders until a Prolific ID has been entered

st.markdown("Please talk to the chatbot to administer the BDI-II test.")

# ─── Thank-you screen (shown after submission) ────────────────────────────────
if st.session_state.get("submitted", False):
    st.success("## 🎉 Thank you for your participation!")
    st.markdown(
        """
        Your session has been saved successfully.

        **Next steps:**
        - Return to Prolific and complete the submission there using the code "CI3N0FC3".
        - Your payment will be processed once your Prolific ID is verified.

        You may now close this window.
        """
    )
    st.balloons()
    st.stop()

# ─── Chat history with per-turn annotation panels (system prompt hidden) ──────
assistant_turn_counter = 0
for msg in st.session_state.messages:
    if msg["role"] == "system":
        continue
    with st.chat_message("user" if msg["role"] == "user" else "assistant"):
        st.markdown(msg["content"])
    if msg["role"] != "assistant":
        continue

    # Assistant turns are numbered 1, 2, ... in display order.
    assistant_turn_counter += 1
    turn_idx = assistant_turn_counter
    existing = next(
        (a for a in st.session_state.turn_annotations if a["turn"] == turn_idx),
        None,
    )
    with st.expander(
        f"🩺 Annotate Turn {turn_idx}" + (" ✅" if existing else ""),
        expanded=(existing is None),
    ):
        if existing:
            saved = ", ".join(existing["symptoms"]) if existing["symptoms"] else "—"
            st.markdown(f"**Saved symptoms:** {saved}")
            for sym, entry in existing["scores"].items():
                st.markdown(f"- **{sym}:** {entry['label']} (score: {entry['score']})")
        else:
            selected_symptoms = st.multiselect(
                "Which BDI-II symptoms do you infer from this turn?",
                options=list(BDI_ITEMS.keys()),
                key=f"symptoms_{turn_idx}",
            )
            per_symptom_scores = {}
            if selected_symptoms:
                st.markdown("**Select the best-matching statement for each symptom:**")
            for sym in selected_symptoms:
                chosen = st.radio(
                    sym,
                    options=BDI_ITEMS[sym],
                    key=f"radio_{turn_idx}_{sym}",
                )
                # "score" duplicates the chosen label (a string), matching existing logs.
                per_symptom_scores[sym] = {"label": chosen, "score": chosen}

            if st.button("Next turn ->", key=f"save_{turn_idx}"):
                st.session_state.turn_annotations.append({
                    "turn": turn_idx,
                    "symptoms": selected_symptoms,
                    "scores": per_symptom_scores,
                    "annotated_at": _utc_now_iso(),
                })
                save_session_snapshot()
                st.success(f"Annotation for turn {turn_idx} saved!")
                st.rerun()

# ─── Inline submit button (only once every symptom has been annotated) ────────
if is_complete:
    st.markdown("---")
    st.info(
        "✅ **All 21 symptoms have been annotated.** "
        "You may continue chatting, or submit your session when ready."
    )
    if st.button("🚀 Submit & Finish Session", type="primary", key="inline_submit"):
        st.session_state.submitted = True
        save_session_snapshot()
        st.rerun()
    st.markdown("---")

# ─── Chat input ───────────────────────────────────────────────────────────────
if prompt := st.chat_input("Type your next BDI-II question or remark..."):
    # 1. Record and show the user's turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # 2. Format the full history with the chat template; return_dict=True also
    #    returns the attention mask alongside the input ids.
    inputs = tokenizer.apply_chat_template(
        st.session_state.messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    # 3. Sample the simulated patient's reply. The attention mask is passed
    #    explicitly because pad and eos share an id (set in load_model).
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )

    # 4. Decode only the newly generated tokens so the prompt is not echoed back.
    response_tokens = outputs[0][inputs.input_ids.shape[-1]:]
    assistant_text = tokenizer.decode(response_tokens, skip_special_tokens=True)
    with st.chat_message("assistant"):
        st.markdown(assistant_text)

    # 5. Persist the reply, then rerun so it is shown with its annotation panel.
    st.session_state.messages.append({"role": "assistant", "content": assistant_text})
    save_session_snapshot()
    st.rerun()