Spaces:
Sleeping
Sleeping
Commit ·
a44bd4e
0
Parent(s):
Initial Commit
Browse files- .gitignore +74 -0
- README.md +37 -0
- chatgpt.py +318 -0
- requirements.txt +67 -0
- src/agents/__init__.py +0 -0
- src/agents/chat_agent.py +39 -0
- src/agents/chat_llm.py +79 -0
- src/agents/context.py +41 -0
- src/configs/Mixtral-8x22B.yaml +9 -0
- src/configs/Qwen2-72B-Instruct.yaml +9 -0
- src/configs/anthony_trollope.yaml +12 -0
- src/configs/catherine_helen_spence.yaml +12 -0
- src/configs/counselor_config.yaml +9 -0
- src/configs/eval_config.yaml +9 -0
- src/configs/gpt-4o_counselor_config.yaml +9 -0
- src/configs/jane_eyre_config.yaml +12 -0
- src/configs/llama3-70b_counselor_config.yaml +9 -0
- src/configs/llama3-8b_counselor_config.yaml +9 -0
- src/configs/memory_graph_config.yaml +5 -0
- src/configs/metric_config.yaml +9 -0
- src/configs/obama_config.yaml +12 -0
- src/configs/patient_config.yaml +11 -0
- src/configs/patient_config_autobio_summary.yaml +29 -0
- src/conversation.py +101 -0
- src/server.py +297 -0
- src/utils/__init__.py +0 -0
- src/utils/utils.py +44 -0
.gitignore
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
**/__pycache__/
|
| 4 |
+
*.py[cod]
|
| 5 |
+
*$py.class
|
| 6 |
+
|
| 7 |
+
# Distribution / packaging
|
| 8 |
+
.eggs/
|
| 9 |
+
dist/
|
| 10 |
+
build/
|
| 11 |
+
sdist/
|
| 12 |
+
*.egg-info/
|
| 13 |
+
pip-log.txt
|
| 14 |
+
pip-wheel-metadata/
|
| 15 |
+
*.whl
|
| 16 |
+
|
| 17 |
+
# Environments
|
| 18 |
+
.venv/
|
| 19 |
+
venv/
|
| 20 |
+
env/
|
| 21 |
+
ENV/
|
| 22 |
+
*.env
|
| 23 |
+
|
| 24 |
+
# PyInstaller
|
| 25 |
+
# Usually these files are written by a python script; ignoring them is safe
|
| 26 |
+
# https://docs.pyinstaller.org/en/stable/when-things-go-wrong.html
|
| 27 |
+
_MEIPASS*
|
| 28 |
+
*.spec
|
| 29 |
+
|
| 30 |
+
# Unit test / coverage reports
|
| 31 |
+
htmlcov/
|
| 32 |
+
.tox/
|
| 33 |
+
.nox/
|
| 34 |
+
.coverage
|
| 35 |
+
.coverage.*
|
| 36 |
+
.cache
|
| 37 |
+
nosetests.xml
|
| 38 |
+
coverage.xml
|
| 39 |
+
*.log
|
| 40 |
+
*.rpt
|
| 41 |
+
|
| 42 |
+
# Translations
|
| 43 |
+
*.mo
|
| 44 |
+
*.pot
|
| 45 |
+
|
| 46 |
+
# Django stuff:
|
| 47 |
+
*.db
|
| 48 |
+
*.sqlite3
|
| 49 |
+
*/migrations
|
| 50 |
+
|
| 51 |
+
# Flask stuff:
|
| 52 |
+
instance/
|
| 53 |
+
*.ini
|
| 54 |
+
|
| 55 |
+
# Scrapy stuff:
|
| 56 |
+
.scrapy
|
| 57 |
+
|
| 58 |
+
# Sphinx stuff:
|
| 59 |
+
docs/_build
|
| 60 |
+
|
| 61 |
+
# Jupyter Notebook
|
| 62 |
+
.ipynb_checkpoints
|
| 63 |
+
|
| 64 |
+
# VS Code
|
| 65 |
+
.vscode/
|
| 66 |
+
|
| 67 |
+
# PyCharm
|
| 68 |
+
.idea/
|
| 69 |
+
|
| 70 |
+
# OS or Editor-specific files
|
| 71 |
+
.DS_Store
|
| 72 |
+
Thumbs.db
|
| 73 |
+
|
| 74 |
+
.idea/
|
README.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: llm-autobiography
|
| 3 |
+
app_file: chatgpt.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 3.50.2
|
| 6 |
+
---
|
| 7 |
+
# Chatbot Frontend
|
| 8 |
+
|
| 9 |
+
## Prerequisites
|
| 10 |
+
- Python 3.x
|
| 11 |
+
- pip
|
| 12 |
+
|
| 13 |
+
## Installation
|
| 14 |
+
1. Clone this repository
|
| 15 |
+
|
| 16 |
+
2. Navigate to the project directory:
|
| 17 |
+
```
|
| 18 |
+
cd chatbot-mimic-notes
|
| 19 |
+
```
|
| 20 |
+
3. Install the required packages:
|
| 21 |
+
```
|
| 22 |
+
pip install -r requirements.txt
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## Running the Application
|
| 26 |
+
Start the chatbot by running:
|
| 27 |
+
```
|
| 28 |
+
python chatgpt.py
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Usage
|
| 32 |
+
- Open Chrome and visit `http://localhost:7860` to interact with the chatbot.
|
| 33 |
+
|
| 34 |
+
## Run Server
|
| 35 |
+
```
|
| 36 |
+
python -m src.server
|
| 37 |
+
```
|
chatgpt.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import whisper
|
| 3 |
+
import asyncio
|
| 4 |
+
import httpx
|
| 5 |
+
import tempfile
|
| 6 |
+
import os
|
| 7 |
+
import requests
|
| 8 |
+
import time
|
| 9 |
+
import threading
|
| 10 |
+
from datetime import datetime, timedelta
|
| 11 |
+
|
| 12 |
+
session = requests.Session()
|
| 13 |
+
|
| 14 |
+
from interview_protocol import protocols as interview_protocols
|
| 15 |
+
|
| 16 |
+
model = whisper.load_model("base")
|
| 17 |
+
|
| 18 |
+
base_url = "http://localhost:8080"
|
| 19 |
+
|
| 20 |
+
timeout = 60
|
| 21 |
+
concurrency_count=10
|
| 22 |
+
|
| 23 |
+
async def initialization(api_key, username):
|
| 24 |
+
url = f"{base_url}/api/initialization"
|
| 25 |
+
headers = {'Content-Type': 'application/json'}
|
| 26 |
+
data = {
|
| 27 |
+
'api_key': api_key,
|
| 28 |
+
'username': username,
|
| 29 |
+
}
|
| 30 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 31 |
+
try:
|
| 32 |
+
response = await client.post(url, json=data, headers=headers)
|
| 33 |
+
if response.status_code == 200:
|
| 34 |
+
return "Initialization successful."
|
| 35 |
+
else:
|
| 36 |
+
return f"Initialization failed: {response.text}"
|
| 37 |
+
except asyncio.TimeoutError:
|
| 38 |
+
print("The request timed out")
|
| 39 |
+
return "Request timed out during initialization."
|
| 40 |
+
except Exception as e:
|
| 41 |
+
return f"Error in initialization: {str(e)}"
|
| 42 |
+
|
| 43 |
+
# def fetch_default_prompts(chatbot_type):
|
| 44 |
+
# url = f"{base_url}?chatbot_type={chatbot_type}"
|
| 45 |
+
# try:
|
| 46 |
+
# response = httpx.get(url, timeout=timeout)
|
| 47 |
+
# if response.status_code == 200:
|
| 48 |
+
# prompts = response.json()
|
| 49 |
+
# print(prompts)
|
| 50 |
+
# return prompts
|
| 51 |
+
# else:
|
| 52 |
+
# print(f"Failed to fetch prompts: {response.status_code} - {response.text}")
|
| 53 |
+
# return {}
|
| 54 |
+
# except Exception as e:
|
| 55 |
+
# print(f"Error fetching prompts: {str(e)}")
|
| 56 |
+
# return {}
|
| 57 |
+
|
| 58 |
+
async def get_backend_response(api_key, patient_prompt, username, chatbot_type):
|
| 59 |
+
url = f"{base_url}/responses/doctor"
|
| 60 |
+
headers = {'Content-Type': 'application/json'}
|
| 61 |
+
data = {
|
| 62 |
+
'username': username,
|
| 63 |
+
'patient_prompt': patient_prompt,
|
| 64 |
+
'chatbot_type': chatbot_type
|
| 65 |
+
}
|
| 66 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 67 |
+
try:
|
| 68 |
+
response = await client.post(url, json=data, headers=headers)
|
| 69 |
+
if response.status_code == 200:
|
| 70 |
+
response_data = response.json()
|
| 71 |
+
return response_data
|
| 72 |
+
else:
|
| 73 |
+
return f"Failed to fetch response from backend: {response.text}"
|
| 74 |
+
except Exception as e:
|
| 75 |
+
return f"Error contacting backend service: {str(e)}"
|
| 76 |
+
|
| 77 |
+
async def save_conversation_and_memory(username, chatbot_type):
|
| 78 |
+
url = f"{base_url}/save/end_and_save"
|
| 79 |
+
headers = {'Content-Type': 'application/json'}
|
| 80 |
+
data = {
|
| 81 |
+
'username': username,
|
| 82 |
+
'chatbot_type': chatbot_type
|
| 83 |
+
}
|
| 84 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 85 |
+
try:
|
| 86 |
+
response = await client.post(url, json=data, headers=headers)
|
| 87 |
+
if response.status_code == 200:
|
| 88 |
+
response_data = response.json()
|
| 89 |
+
return response_data.get('message', 'Saving Error!')
|
| 90 |
+
else:
|
| 91 |
+
return f"Failed to save conversations and memory graph: {response.text}"
|
| 92 |
+
except Exception as e:
|
| 93 |
+
return f"Error contacting backend service: {str(e)}"
|
| 94 |
+
|
| 95 |
+
async def get_conversation_histories(username, chatbot_type):
|
| 96 |
+
url = f"{base_url}/save/download_conversations"
|
| 97 |
+
headers = {'Content-Type': 'application/json'}
|
| 98 |
+
data = {
|
| 99 |
+
'username': username,
|
| 100 |
+
'chatbot_type': chatbot_type
|
| 101 |
+
}
|
| 102 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 103 |
+
try:
|
| 104 |
+
response = await client.post(url, json=data, headers=headers)
|
| 105 |
+
if response.status_code == 200:
|
| 106 |
+
conversation_data = response.json()
|
| 107 |
+
return conversation_data
|
| 108 |
+
else:
|
| 109 |
+
return []
|
| 110 |
+
except Exception as e:
|
| 111 |
+
return []
|
| 112 |
+
|
| 113 |
+
def download_conversations(username, chatbot_type):
|
| 114 |
+
conversation_histories = asyncio.run(get_conversation_histories(username, chatbot_type))
|
| 115 |
+
files = []
|
| 116 |
+
temp_dir = tempfile.mkdtemp()
|
| 117 |
+
for conversation_entry in conversation_histories:
|
| 118 |
+
file_name = conversation_entry.get('file_name', f"Conversation_{len(files)+1}.txt")
|
| 119 |
+
conversation = conversation_entry.get('conversation', [])
|
| 120 |
+
conversation_text = ""
|
| 121 |
+
for message_pair in conversation:
|
| 122 |
+
if isinstance(message_pair, list) and len(message_pair) == 2:
|
| 123 |
+
speaker, message = message_pair
|
| 124 |
+
conversation_text += f"{speaker.capitalize()}: {message}\n\n"
|
| 125 |
+
else:
|
| 126 |
+
conversation_text += f"Unknown format: {message_pair}\n\n"
|
| 127 |
+
temp_file_path = os.path.join(temp_dir, file_name)
|
| 128 |
+
with open(temp_file_path, 'w') as temp_file:
|
| 129 |
+
temp_file.write(conversation_text)
|
| 130 |
+
files.append(temp_file_path)
|
| 131 |
+
return files
|
| 132 |
+
|
| 133 |
+
# async def get_biography(username, chatbot_type):
|
| 134 |
+
# url = f"{base_url}/save/generate_autobiography"
|
| 135 |
+
# headers = {'Content-Type': 'application/json'}
|
| 136 |
+
# data = {
|
| 137 |
+
# 'username': username,
|
| 138 |
+
# 'chatbot_type': chatbot_type
|
| 139 |
+
# }
|
| 140 |
+
# async with httpx.AsyncClient(timeout=timeout) as client:
|
| 141 |
+
# try:
|
| 142 |
+
# response = await client.post(url, json=data, headers=headers)
|
| 143 |
+
# if response.status_code == 200:
|
| 144 |
+
# biography_data = response.json()
|
| 145 |
+
# biography_text = biography_data.get('biography', '')
|
| 146 |
+
# return biography_text
|
| 147 |
+
# else:
|
| 148 |
+
# return "Failed to generate biography."
|
| 149 |
+
# except Exception as e:
|
| 150 |
+
# return f"Error contacting backend service: {str(e)}"
|
| 151 |
+
|
| 152 |
+
# def download_biography(username, chatbot_type):
|
| 153 |
+
# biography_text = asyncio.run(get_biography(username, chatbot_type))
|
| 154 |
+
# if not biography_text or "Failed" in biography_text or "Error" in biography_text:
|
| 155 |
+
# return gr.update(value=None, visible=False), gr.update(value=biography_text, visible=True)
|
| 156 |
+
# temp_dir = tempfile.mkdtemp()
|
| 157 |
+
# temp_file_path = os.path.join(temp_dir, "biography.txt")
|
| 158 |
+
# with open(temp_file_path, 'w') as temp_file:
|
| 159 |
+
# temp_file.write(biography_text)
|
| 160 |
+
# return temp_file_path, gr.update(value=biography_text, visible=True)
|
| 161 |
+
|
| 162 |
+
def transcribe_audio(audio_file):
|
| 163 |
+
transcription = model.transcribe(audio_file)["text"]
|
| 164 |
+
return transcription
|
| 165 |
+
|
| 166 |
+
def submit_text_and_respond(edited_text, api_key, username, history, chatbot_type):
|
| 167 |
+
response = asyncio.run(get_backend_response(api_key, edited_text, username, chatbot_type))
|
| 168 |
+
print('------')
|
| 169 |
+
print(response)
|
| 170 |
+
if isinstance(response, str):
|
| 171 |
+
history.append((edited_text, response))
|
| 172 |
+
return history, "", []
|
| 173 |
+
doctor_response = response['doctor_response']
|
| 174 |
+
history.append((edited_text, doctor_response))
|
| 175 |
+
return history, "" # Return memory_graph as output
|
| 176 |
+
|
| 177 |
+
def set_initialize_button(api_key_input, username_input):
|
| 178 |
+
message = asyncio.run(initialization(api_key_input, username_input))
|
| 179 |
+
print(message)
|
| 180 |
+
return message, api_key_input
|
| 181 |
+
|
| 182 |
+
def save_conversation(username, chatbot_type):
|
| 183 |
+
response = asyncio.run(save_conversation_and_memory(username, chatbot_type))
|
| 184 |
+
return response
|
| 185 |
+
|
| 186 |
+
def start_recording(audio_file):
|
| 187 |
+
if not audio_file:
|
| 188 |
+
return ""
|
| 189 |
+
try:
|
| 190 |
+
transcription = transcribe_audio(audio_file)
|
| 191 |
+
return transcription
|
| 192 |
+
except Exception as e:
|
| 193 |
+
return f"Failed to transcribe: {str(e)}"
|
| 194 |
+
|
| 195 |
+
def update_methods(chapter):
|
| 196 |
+
return gr.update(choices=interview_protocols[chapter], value=interview_protocols[chapter][0])
|
| 197 |
+
|
| 198 |
+
# def update_memory_graph(memory_data):
|
| 199 |
+
# table_data = []
|
| 200 |
+
# for node in memory_data:
|
| 201 |
+
# table_data.append([
|
| 202 |
+
# node.get('date', ''),
|
| 203 |
+
# node.get('topic', ''),
|
| 204 |
+
# node.get('event_description', ''),
|
| 205 |
+
# node.get('people_involved', '')
|
| 206 |
+
# ])
|
| 207 |
+
# return table_data
|
| 208 |
+
|
| 209 |
+
# def update_prompts(chatbot_display_name):
|
| 210 |
+
# chatbot_type = display_to_value.get(chatbot_display_name, 'enhanced')
|
| 211 |
+
# prompts = fetch_default_prompts(chatbot_type)
|
| 212 |
+
# return (
|
| 213 |
+
# gr.update(value=prompts.get('system_prompt', '')),
|
| 214 |
+
# gr.update(value=prompts.get('conv_instruction_prompt', '')),
|
| 215 |
+
# gr.update(value=prompts.get('therapy_prompt', '')),
|
| 216 |
+
# gr.update(value=prompts.get('autobio_generation_prompt', '')),
|
| 217 |
+
# )
|
| 218 |
+
|
| 219 |
+
# def update_chatbot_type(chatbot_display_name):
|
| 220 |
+
# chatbot_type = display_to_value.get(chatbot_display_name, 'enhanced')
|
| 221 |
+
# return chatbot_type
|
| 222 |
+
|
| 223 |
+
# CSS to keep the buttons small
|
| 224 |
+
css = """
|
| 225 |
+
#start_button, #reset_button {
|
| 226 |
+
padding: 4px 10px !important;
|
| 227 |
+
font-size: 12px !important;
|
| 228 |
+
width: auto !important;
|
| 229 |
+
}
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
with gr.Blocks(css=css) as app:
|
| 233 |
+
chatbot_type_state = gr.State('enhanced')
|
| 234 |
+
api_key_state = gr.State()
|
| 235 |
+
prompt_visibility_state = gr.State(False)
|
| 236 |
+
|
| 237 |
+
is_running = gr.State()
|
| 238 |
+
target_timestamp = gr.State()
|
| 239 |
+
|
| 240 |
+
with gr.Row():
|
| 241 |
+
with gr.Column(scale=1, min_width=250):
|
| 242 |
+
gr.Markdown("## Settings")
|
| 243 |
+
|
| 244 |
+
with gr.Box():
|
| 245 |
+
gr.Markdown("### User Information")
|
| 246 |
+
username_input = gr.Textbox(
|
| 247 |
+
label="Username", placeholder="Enter your username"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
api_key_input = gr.Textbox(
|
| 251 |
+
label="OpenAI API Key",
|
| 252 |
+
placeholder="Enter your openai api key",
|
| 253 |
+
type="password"
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
initialize_button = gr.Button("Initialize", variant="primary", size="large")
|
| 257 |
+
initialization_status = gr.Textbox(
|
| 258 |
+
label="Status", interactive=False, placeholder="Initialization status will appear here."
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
initialize_button.click(
|
| 262 |
+
fn=set_initialize_button,
|
| 263 |
+
inputs=[api_key_input, username_input], # chatbot_type_dropdown replaced with None
|
| 264 |
+
outputs=[initialization_status, api_key_state],
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# define the function to toggle prompts visibility
|
| 268 |
+
# def toggle_prompts(visibility):
|
| 269 |
+
# new_visibility = not visibility
|
| 270 |
+
# button_text = "Hide Prompts" if new_visibility else "Show Prompts"
|
| 271 |
+
# return gr.update(value=button_text), gr.update(visible=new_visibility), new_visibility
|
| 272 |
+
|
| 273 |
+
with gr.Column(scale=3):
|
| 274 |
+
|
| 275 |
+
chatbot = gr.Chatbot(label="Provide Your Pre-Operation Notes Here", height=500)
|
| 276 |
+
|
| 277 |
+
with gr.Row():
|
| 278 |
+
transcription_box = gr.Textbox(
|
| 279 |
+
label="Transcription (You can edit this)", lines=3
|
| 280 |
+
)
|
| 281 |
+
audio_input = gr.Audio(
|
| 282 |
+
source="microphone", type="filepath", label="🎤 Record Audio"
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
with gr.Row():
|
| 286 |
+
submit_button = gr.Button("Submit", variant="primary", size="large")
|
| 287 |
+
save_conversation_button = gr.Button("End and Save Conversation", variant="secondary")
|
| 288 |
+
download_button = gr.Button("Download Conversations", variant="secondary")
|
| 289 |
+
|
| 290 |
+
audio_input.change(
|
| 291 |
+
fn=start_recording,
|
| 292 |
+
inputs=[audio_input],
|
| 293 |
+
outputs=[transcription_box]
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
state = gr.State([])
|
| 297 |
+
|
| 298 |
+
submit_button.click(
|
| 299 |
+
submit_text_and_respond,
|
| 300 |
+
inputs=[transcription_box, api_key_state, username_input, state, chatbot_type_state],
|
| 301 |
+
outputs=[chatbot, transcription_box]
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
download_button.click(
|
| 305 |
+
fn=download_conversations,
|
| 306 |
+
inputs=[username_input, chatbot_type_state],
|
| 307 |
+
outputs=gr.Files()
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
save_conversation_button.click(
|
| 311 |
+
fn=save_conversation,
|
| 312 |
+
inputs=[username_input, chatbot_type_state],
|
| 313 |
+
outputs=None
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
app.queue()
|
| 318 |
+
app.launch(share=True, max_threads=10)
|
requirements.txt
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==23.2.1
|
| 2 |
+
altair==5.4.1
|
| 3 |
+
annotated-types==0.7.0
|
| 4 |
+
anyio==4.5.0
|
| 5 |
+
attrs==24.2.0
|
| 6 |
+
certifi==2024.8.30
|
| 7 |
+
charset-normalizer==3.3.2
|
| 8 |
+
click==8.1.7
|
| 9 |
+
contourpy==1.3.0
|
| 10 |
+
cycler==0.12.1
|
| 11 |
+
exceptiongroup==1.2.2
|
| 12 |
+
fastapi==0.115.0
|
| 13 |
+
ffmpy==0.4.0
|
| 14 |
+
filelock==3.16.1
|
| 15 |
+
fonttools==4.53.1
|
| 16 |
+
fsspec==2024.9.0
|
| 17 |
+
gradio==3.50.2
|
| 18 |
+
gradio_client==0.6.1
|
| 19 |
+
h11==0.14.0
|
| 20 |
+
httpcore==1.0.5
|
| 21 |
+
httpx==0.27.2
|
| 22 |
+
huggingface-hub==0.25.0
|
| 23 |
+
idna==3.10
|
| 24 |
+
importlib_resources==6.4.5
|
| 25 |
+
Jinja2==3.1.4
|
| 26 |
+
jsonschema==4.23.0
|
| 27 |
+
jsonschema-specifications==2023.12.1
|
| 28 |
+
kiwisolver==1.4.7
|
| 29 |
+
llvmlite==0.43.0
|
| 30 |
+
MarkupSafe==2.1.5
|
| 31 |
+
matplotlib==3.9.2
|
| 32 |
+
more-itertools==10.5.0
|
| 33 |
+
mpmath==1.3.0
|
| 34 |
+
narwhals==1.8.1
|
| 35 |
+
networkx==3.3
|
| 36 |
+
numba==0.60.0
|
| 37 |
+
numpy==1.26.4
|
| 38 |
+
openai-whisper==20231117
|
| 39 |
+
orjson==3.10.7
|
| 40 |
+
packaging==24.1
|
| 41 |
+
pandas==2.2.2
|
| 42 |
+
pillow==10.4.0
|
| 43 |
+
pydantic==2.9.2
|
| 44 |
+
pydantic_core==2.23.4
|
| 45 |
+
pydub==0.25.1
|
| 46 |
+
pyparsing==3.1.4
|
| 47 |
+
python-dateutil==2.9.0.post0
|
| 48 |
+
python-multipart==0.0.9
|
| 49 |
+
pytz==2024.2
|
| 50 |
+
PyYAML==6.0.2
|
| 51 |
+
referencing==0.35.1
|
| 52 |
+
regex==2024.9.11
|
| 53 |
+
requests==2.32.3
|
| 54 |
+
rpds-py==0.20.0
|
| 55 |
+
semantic-version==2.10.0
|
| 56 |
+
six==1.16.0
|
| 57 |
+
sniffio==1.3.1
|
| 58 |
+
starlette==0.38.5
|
| 59 |
+
sympy==1.13.3
|
| 60 |
+
tiktoken==0.7.0
|
| 61 |
+
torch
|
| 62 |
+
tqdm==4.66.5
|
| 63 |
+
typing_extensions==4.12.2
|
| 64 |
+
tzdata==2024.1
|
| 65 |
+
urllib3==2.2.3
|
| 66 |
+
uvicorn==0.30.6
|
| 67 |
+
websockets==11.0.3
|
src/agents/__init__.py
ADDED
|
File without changes
|
src/agents/chat_agent.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from src.agents.context import Context
|
| 5 |
+
from src.utils.utils import load_config
|
| 6 |
+
from src.agents.chat_llm import chat_llm
|
| 7 |
+
|
| 8 |
+
class BaseChatAgent:
|
| 9 |
+
def __init__(self, config_path) -> None:
|
| 10 |
+
self.agent_config = load_config(config_path)
|
| 11 |
+
# conversation context
|
| 12 |
+
self.context = Context()
|
| 13 |
+
self.history_conversations = []
|
| 14 |
+
# set default interview protocal index
|
| 15 |
+
|
| 16 |
+
self.hist_conv_summarization = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def talk_to_user(
|
| 20 |
+
self,
|
| 21 |
+
user_response=None,
|
| 22 |
+
):
|
| 23 |
+
|
| 24 |
+
self.context.add_user_prompt(user_response['patient_prompt'])
|
| 25 |
+
|
| 26 |
+
response = chat_llm(
|
| 27 |
+
messages=self.context.msg,
|
| 28 |
+
model=self.agent_config.llm_model_path,
|
| 29 |
+
temperature=self.agent_config.temperature,
|
| 30 |
+
max_tokens=self.agent_config.max_tokens,
|
| 31 |
+
n=1,
|
| 32 |
+
timeout=self.agent_config.timeout,
|
| 33 |
+
stop=None,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
self.context.add_assistant_prompt(response)
|
| 37 |
+
|
| 38 |
+
return response["generations"][0]
|
| 39 |
+
|
src/agents/chat_llm.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
from box import Box
|
| 3 |
+
from langchain.chat_models import ChatOpenAI
|
| 4 |
+
from langchain_community.chat_models import ChatOpenAI
|
| 5 |
+
from langchain_community.llms import DeepInfra
|
| 6 |
+
from langchain.schema import HumanMessage, SystemMessage, AIMessage
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def chat_llm(messages, model, temperature, max_tokens, n, timeout=600, stop=None, return_tokens=False):
|
| 12 |
+
if model.__contains__("gpt"):
|
| 13 |
+
iterated_query = False
|
| 14 |
+
try:
|
| 15 |
+
chat = ChatOpenAI(model_name=model,
|
| 16 |
+
temperature=temperature,
|
| 17 |
+
max_tokens=max_tokens,
|
| 18 |
+
n=n,
|
| 19 |
+
request_timeout=timeout,
|
| 20 |
+
model_kwargs={"seed": 0,
|
| 21 |
+
"top_p": 0
|
| 22 |
+
})
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"Error in loading model: {e}")
|
| 25 |
+
return None
|
| 26 |
+
else:
|
| 27 |
+
# deepinfra
|
| 28 |
+
iterated_query = True
|
| 29 |
+
chat = ChatOpenAI(model_name=model,
|
| 30 |
+
openai_api_key=None,
|
| 31 |
+
temperature=temperature,
|
| 32 |
+
max_tokens=max_tokens,
|
| 33 |
+
n=1,
|
| 34 |
+
request_timeout=timeout,
|
| 35 |
+
openai_api_base="https://api.deepinfra.com/v1/openai")
|
| 36 |
+
|
| 37 |
+
longchain_msgs = []
|
| 38 |
+
for msg in messages:
|
| 39 |
+
if msg['role'] == 'system':
|
| 40 |
+
longchain_msgs.append(SystemMessage(content=msg['content']))
|
| 41 |
+
elif msg['role'] == 'user':
|
| 42 |
+
print('human message', msg)
|
| 43 |
+
longchain_msgs.append(HumanMessage(content=msg['content']))
|
| 44 |
+
elif msg['role'] == 'assistant':
|
| 45 |
+
longchain_msgs.append(AIMessage(content=msg['content']))
|
| 46 |
+
else:
|
| 47 |
+
raise NotImplementedError
|
| 48 |
+
# add an empty user message to avoid no user message error
|
| 49 |
+
longchain_msgs.append(HumanMessage(content=""))
|
| 50 |
+
if n > 1 and iterated_query:
|
| 51 |
+
response_list = []
|
| 52 |
+
total_completion_tokens = 0
|
| 53 |
+
total_prompt_tokens = 0
|
| 54 |
+
for n in range(n):
|
| 55 |
+
generations = chat.generate([longchain_msgs], stop=[
|
| 56 |
+
stop] if stop is not None else None)
|
| 57 |
+
responses = [
|
| 58 |
+
chat_gen.message.content for chat_gen in generations.generations[0]]
|
| 59 |
+
response_list.append(responses[0])
|
| 60 |
+
completion_tokens = generations.llm_output['token_usage']['completion_tokens']
|
| 61 |
+
prompt_tokens = generations.llm_output['token_usage']['prompt_tokens']
|
| 62 |
+
total_completion_tokens += completion_tokens
|
| 63 |
+
total_prompt_tokens += prompt_tokens
|
| 64 |
+
responses = response_list
|
| 65 |
+
completion_tokens = total_completion_tokens
|
| 66 |
+
prompt_tokens = total_prompt_tokens
|
| 67 |
+
else:
|
| 68 |
+
generations = chat.generate([longchain_msgs], stop=[
|
| 69 |
+
stop] if stop is not None else None)
|
| 70 |
+
responses = [
|
| 71 |
+
chat_gen.message.content for chat_gen in generations.generations[0]]
|
| 72 |
+
completion_tokens = generations.llm_output['token_usage']['completion_tokens']
|
| 73 |
+
prompt_tokens = generations.llm_output['token_usage']['prompt_tokens']
|
| 74 |
+
|
| 75 |
+
return {
|
| 76 |
+
'generations': responses,
|
| 77 |
+
'completion_tokens': completion_tokens,
|
| 78 |
+
'prompt_tokens': prompt_tokens
|
| 79 |
+
}
|
src/agents/context.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
class ContextAttributeError(BaseException):
|
| 3 |
+
def __init__(self, message):
|
| 4 |
+
self.message = message
|
| 5 |
+
super().__init__(self.message)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Context:
|
| 9 |
+
def __init__(self) -> None:
|
| 10 |
+
self.msg = []
|
| 11 |
+
pass
|
| 12 |
+
|
| 13 |
+
def add_system_prompt(self, prompt: str):
|
| 14 |
+
pt = {
|
| 15 |
+
'role': 'system',
|
| 16 |
+
'content': prompt
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
# replace the last system prompt if it exists or insert to 0 if it doesn't
|
| 20 |
+
if len(self.msg) > 0 and self.msg[0]['role'] == 'system':
|
| 21 |
+
self.msg[0] = pt
|
| 22 |
+
else:
|
| 23 |
+
self.msg.insert(0, pt)
|
| 24 |
+
|
| 25 |
+
def add_user_prompt(self, prompt: str):
|
| 26 |
+
pt = {
|
| 27 |
+
'role': 'user',
|
| 28 |
+
'content': prompt
|
| 29 |
+
}
|
| 30 |
+
self.msg.append(pt)
|
| 31 |
+
|
| 32 |
+
def remove_last_prompt(self):
|
| 33 |
+
self.msg.pop(-1)
|
| 34 |
+
|
| 35 |
+
def add_assistant_prompt(self, prompt: str):
|
| 36 |
+
pt = {
|
| 37 |
+
'role': 'assistant',
|
| 38 |
+
'content': prompt
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
self.msg.append(pt)
|
src/configs/Mixtral-8x22B.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Doctor
|
| 2 |
+
llm_model_path: mistralai/Mixtral-8x22B-Instruct-v0.1
|
| 3 |
+
max_tokens: 1024
|
| 4 |
+
timeout: 600
|
| 5 |
+
num_generations: 1
|
| 6 |
+
temperature: 0
|
| 7 |
+
use_emotion_module: False
|
| 8 |
+
session_max_tokens: 4096
|
| 9 |
+
patient_conv_history_dir: ./assets/conversation_history
|
src/configs/Qwen2-72B-Instruct.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Doctor
|
| 2 |
+
llm_model_path: Qwen/Qwen2-72B-Instruct
|
| 3 |
+
max_tokens: 1024
|
| 4 |
+
timeout: 600
|
| 5 |
+
num_generations: 1
|
| 6 |
+
temperature: 0
|
| 7 |
+
use_emotion_module: False
|
| 8 |
+
session_max_tokens: 4096
|
| 9 |
+
patient_conv_history_dir: ./assets/conversation_history
|
src/configs/anthony_trollope.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Patient
|
| 2 |
+
#llm_model_path: gpt-4o-2024-05-13
|
| 3 |
+
llm_model_path: gpt-4o
|
| 4 |
+
max_tokens: 1024
|
| 5 |
+
timeout: 600
|
| 6 |
+
num_generations: 1
|
| 7 |
+
temperature: 0
|
| 8 |
+
patient_id: 'Anthony_Trollope'
|
| 9 |
+
autobiography_path_for_summary: ./assets/autobiography/anthony_trollope/autobiography_chapters_Anthony_Trollope.json
|
| 10 |
+
autobiography_path_for_rag: ./assets/autobiography/anthony_trollope/An_Autobiography_of_Anthony_Trollope.txt
|
| 11 |
+
autobiography_pre_summary_path: ./assets/autobiography/anthony_trollope/autobiography_pre_summary_Anthony_Trollope.json
|
| 12 |
+
autobiography_pre_event_path: ./assets/autobiography/anthony_trollope/autobiography_pre_event_parsed_Anthony_Trollope.json
|
src/configs/catherine_helen_spence.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Patient
|
| 2 |
+
#llm_model_path: gpt-4o-2024-05-13
|
| 3 |
+
llm_model_path: gpt-4o
|
| 4 |
+
max_tokens: 1024
|
| 5 |
+
timeout: 600
|
| 6 |
+
num_generations: 1
|
| 7 |
+
temperature: 0
|
| 8 |
+
patient_id: 'Catherine_Helen_Spence'
|
| 9 |
+
autobiography_path_for_summary: ./assets/autobiography/catherine_helen_spence/autobiography_chapters_Catherine_Helen_Spence.json
|
| 10 |
+
autobiography_path_for_rag: ./assets/autobiography/catherine_helen_spence/An_Autobiography_by_Catherine_Helen_Spence.txt
|
| 11 |
+
autobiography_pre_summary_path: ./assets/autobiography/catherine_helen_spence/autobiography_pre_summary_Catherine_Helen_Spence.json
|
| 12 |
+
autobiography_pre_event_path: ./assets/autobiography/catherine_helen_spence/autobiography_pre_event_parsed_Catherine_Helen_Spence.json
|
src/configs/counselor_config.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Doctor
|
| 2 |
+
llm_model_path: gpt-3.5-turbo
|
| 3 |
+
max_tokens: 1024
|
| 4 |
+
timeout: 600
|
| 5 |
+
num_generations: 1
|
| 6 |
+
temperature: 0
|
| 7 |
+
use_emotion_module: False
|
| 8 |
+
session_max_tokens: 4096
|
| 9 |
+
patient_conv_history_dir: ./assets/conversation_history
|
src/configs/eval_config.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
llm_model_path: gpt-4-turbo-preview
|
| 2 |
+
max_tokens: 1024
|
| 3 |
+
timeout: 600
|
| 4 |
+
num_generations: 1
|
| 5 |
+
temperature: 0
|
| 6 |
+
autobiography_path_for_summary: ./assets/autobiography_chapters.json
|
| 7 |
+
autobiography_path_for_rag: ./assets/autobiography/13991012000013_Test.txt
|
| 8 |
+
autobiography_pre_summary_path: ./assets/autobiography_pre_summary.json
|
| 9 |
+
autobiography_pre_event_path: ./assets/autobiography_pre_event.json
|
src/configs/gpt-4o_counselor_config.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Doctor
|
| 2 |
+
llm_model_path: gpt-4o-2024-05-13
|
| 3 |
+
max_tokens: 1024
|
| 4 |
+
timeout: 600
|
| 5 |
+
num_generations: 1
|
| 6 |
+
temperature: 0
|
| 7 |
+
use_emotion_module: False
|
| 8 |
+
session_max_tokens: 4096
|
| 9 |
+
patient_conv_history_dir: ./assets/conversation_history
|
src/configs/jane_eyre_config.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Patient
|
| 2 |
+
#llm_model_path: gpt-4o-2024-05-13
|
| 3 |
+
llm_model_path: gpt-4o
|
| 4 |
+
max_tokens: 1024
|
| 5 |
+
timeout: 600
|
| 6 |
+
num_generations: 1
|
| 7 |
+
temperature: 0
|
| 8 |
+
patient_id: 'Jane_Eyre'
|
| 9 |
+
autobiography_path_for_summary: ./assets/autobiography/jane_eyre/autobiography_chapters_Jane_Eyre.json
|
| 10 |
+
autobiography_path_for_rag: ./assets/autobiography/jane_eyre/Jane_Eyre_An_Autobiography.txt
|
| 11 |
+
autobiography_pre_summary_path: ./assets/autobiography/jane_eyre/autobiography_pre_summary_Jane_Eyre.json
|
| 12 |
+
autobiography_pre_event_path: ./assets/autobiography/jane_eyre/autobiography_pre_event_parsed_Jane_Eyre.json
|
src/configs/llama3-70b_counselor_config.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Doctor
|
| 2 |
+
llm_model_path: meta-llama/Meta-Llama-3-70B-Instruct
|
| 3 |
+
max_tokens: 1024
|
| 4 |
+
timeout: 600
|
| 5 |
+
num_generations: 1
|
| 6 |
+
temperature: 0
|
| 7 |
+
use_emotion_module: False
|
| 8 |
+
session_max_tokens: 4096
|
| 9 |
+
patient_conv_history_dir: ./assets/conversation_history
|
src/configs/llama3-8b_counselor_config.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Doctor
|
| 2 |
+
llm_model_path: meta-llama/Meta-Llama-3-8B-Instruct
|
| 3 |
+
max_tokens: 1024
|
| 4 |
+
timeout: 600
|
| 5 |
+
num_generations: 1
|
| 6 |
+
temperature: 0
|
| 7 |
+
use_emotion_module: False
|
| 8 |
+
session_max_tokens: 4096
|
| 9 |
+
patient_conv_history_dir: ./assets/conversation_history
|
src/configs/memory_graph_config.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
llm_model_path: gpt-4o
|
| 2 |
+
max_tokens: 1024
|
| 3 |
+
timeout: 600
|
| 4 |
+
num_generations: 1
|
| 5 |
+
temperature: 0
|
src/configs/metric_config.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
llm_model_path: gpt-4-turbo-preview
|
| 2 |
+
max_tokens: 1024
|
| 3 |
+
timeout: 600
|
| 4 |
+
num_generations: 1
|
| 5 |
+
temperature: 0
|
| 6 |
+
autobiography_pre_memory_path: ./assets/demo/conv_history_w_interview_protocol_w_memorygraph/Obama/memory_graph
|
| 7 |
+
autobiography_pre_event_path: ./assets/autobiography_pre_event_parsed.json
|
| 8 |
+
autobiography_pre_event_emb_path: ./assets/autobiography_pre_event_emb.json
|
| 9 |
+
metrics_output_path: ./assets/eval_output/
|
src/configs/obama_config.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Patient
|
| 2 |
+
# llm_model_path: gpt-4o-2024-05-13
|
| 3 |
+
llm_model_path: gpt-4-turbo
|
| 4 |
+
max_tokens: 1024
|
| 5 |
+
timeout: 600
|
| 6 |
+
num_generations: 1
|
| 7 |
+
temperature: 0
|
| 8 |
+
patient_id: 'Obama'
|
| 9 |
+
autobiography_path_for_summary: ./assets/autobiography/obama/autobiography_chapters.json
|
| 10 |
+
autobiography_path_for_rag: ./assets/autobiography/obama/13991012000013_Test.txt
|
| 11 |
+
autobiography_pre_summary_path: ./assets/autobiography/obama/autobiography_pre_summary.json
|
| 12 |
+
autobiography_pre_event_path: ./assets/autobiography/obama/autobiography_pre_event_parsed.json
|
src/configs/patient_config.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Patient
|
| 2 |
+
#llm_model_path: gpt-4o-2024-05-13
|
| 3 |
+
llm_model_path: gpt-4-turbo
|
| 4 |
+
max_tokens: 1024
|
| 5 |
+
timeout: 600
|
| 6 |
+
num_generations: 1
|
| 7 |
+
temperature: 0
|
| 8 |
+
patient_id: 'Obama'
|
| 9 |
+
autobiography_path_for_summary: ./assets/autobiography/obama/autobiography_chapters.json
|
| 10 |
+
autobiography_path_for_rag: ./assets/autobiography/obama/13991012000013_Test.txt
|
| 11 |
+
autobiography_pre_summary_path: ./assets/autobiography/obama/autobiography_pre_summary.json
|
src/configs/patient_config_autobio_summary.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agent_name: Patient
|
| 2 |
+
#llm_model_path: gpt-4o-2024-05-13
|
| 3 |
+
llm_model_path: gpt-4-turbo
|
| 4 |
+
max_tokens: 1024
|
| 5 |
+
timeout: 600
|
| 6 |
+
num_generations: 1
|
| 7 |
+
temperature: 0
|
| 8 |
+
# patient_id: 'Obama'
|
| 9 |
+
# autobiography_path_for_summary: ./assets/autobiography_chapters_demo.json
|
| 10 |
+
# autobiography_path_for_rag: ./assets/autobiography/13991012000013_Test.txt
|
| 11 |
+
# autobiography_pre_summary_path: ./assets/autobiography_pre_summary.json
|
| 12 |
+
|
| 13 |
+
# patient_id: 'Trollope'
|
| 14 |
+
# autobiography_pre_summary_path: ./assets/autobiography_pre_summary_Anthony_Trollope.json
|
| 15 |
+
# autobiography_path_for_summary: ./assets/autobiography_chapters_Anthony_Trollope.json
|
| 16 |
+
# autobiography_pre_event_path: ./assets/autobiography_pre_event_Anthony_Trollope.json
|
| 17 |
+
# autobiography_event_parsed_path: ./assets/autobiography_pre_event_parsed_Anthony_Trollope.json
|
| 18 |
+
|
| 19 |
+
# patient_id: 'Spence'
|
| 20 |
+
# autobiography_pre_summary_path: ./assets/autobiography_pre_summary_Catherine_Helen_Spence.json
|
| 21 |
+
# autobiography_path_for_summary: ./assets/autobiography_chapters_Catherine_Helen_Spence.json
|
| 22 |
+
# autobiography_pre_event_path: ./assets/autobiography_pre_event_Catherine_Helen_Spence.json
|
| 23 |
+
# autobiography_event_parsed_path: ./assets/autobiography_pre_event_parsed_Catherine_Helen_Spence.json
|
| 24 |
+
|
| 25 |
+
patient_id: 'Jane_Eyre'
|
| 26 |
+
autobiography_pre_summary_path: ./assets/autobiography_pre_summary_Jane_Eyre.json
|
| 27 |
+
autobiography_path_for_summary: ./assets/autobiography_chapters_Jane_Eyre.json
|
| 28 |
+
autobiography_pre_event_path: ./assets/autobiography_pre_event_Jane_Eyre.json
|
| 29 |
+
autobiography_event_parsed_path: ./assets/autobiography_pre_event_parsed_Jane_Eyre.json
|
src/conversation.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os.path
|
| 3 |
+
import copy
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import pytz
|
| 6 |
+
|
| 7 |
+
class Conversation:
|
| 8 |
+
def __init__(self) -> None:
|
| 9 |
+
self.patient_info = ""
|
| 10 |
+
|
| 11 |
+
self.conversation = []
|
| 12 |
+
|
| 13 |
+
self.patient_out = None
|
| 14 |
+
self.doctor_out = None
|
| 15 |
+
|
| 16 |
+
self.patient = None
|
| 17 |
+
self.doctor = None
|
| 18 |
+
|
| 19 |
+
self.time_stamp = datetime.now(pytz.timezone('US/Eastern')).strftime("%m/%d/%Y, %H:%M:%S, %f")
|
| 20 |
+
|
| 21 |
+
def set_patient(self, patient):
|
| 22 |
+
self.patient = patient
|
| 23 |
+
self.patient_info = {
|
| 24 |
+
"patient_model_config": patient.agent_config,
|
| 25 |
+
# "patient_story": patient.agent_config
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
def set_doctor(self, doctor):
|
| 29 |
+
self.doctor = doctor
|
| 30 |
+
|
| 31 |
+
def generate_doctor_response(self):
|
| 32 |
+
'''
|
| 33 |
+
Must be called after setting the patient and doctor
|
| 34 |
+
'''
|
| 35 |
+
self.doctor_out = self.doctor.talk_to_patient(self.patient_out,
|
| 36 |
+
conversations=self.conversation)[0]
|
| 37 |
+
return self.doctor_out
|
| 38 |
+
|
| 39 |
+
def generate_patient_response(self):
|
| 40 |
+
'''
|
| 41 |
+
Must be called after setting the patient and doctor
|
| 42 |
+
'''
|
| 43 |
+
self.patient_out = self.patient.talk_to_doctor(self.doctor_out)[0]
|
| 44 |
+
return self.patient_out
|
| 45 |
+
|
| 46 |
+
def submit_doctor_response(self, response):
|
| 47 |
+
self.conversation.append(("doctor", response))
|
| 48 |
+
self.doctor.context.add_assistant_prompt(response)
|
| 49 |
+
|
| 50 |
+
def submit_patient_response(self, response):
|
| 51 |
+
self.conversation.append(("patient", response))
|
| 52 |
+
self.patient.context.add_assistant_prompt(response)
|
| 53 |
+
|
| 54 |
+
def get_virtual_doctor_token(self):
|
| 55 |
+
return self.doctor.current_tokens
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def start_session(self, num_conv_round):
|
| 59 |
+
self.conversation_round = 1
|
| 60 |
+
while True:
|
| 61 |
+
|
| 62 |
+
doctor_out = self.generate_doctor_response()
|
| 63 |
+
self.submit_doctor_response(doctor_out)
|
| 64 |
+
print(f'Round {self.conversation_round} Doctor: {doctor_out}')
|
| 65 |
+
|
| 66 |
+
patient_out = self.generate_patient_response()
|
| 67 |
+
self.submit_patient_response(patient_out)
|
| 68 |
+
print(f'Round {self.conversation_round} Patient: {patient_out}')
|
| 69 |
+
|
| 70 |
+
self.conversation_round += 1
|
| 71 |
+
# Condition when we jump out of the loop
|
| 72 |
+
if self.conversation_round >= num_conv_round:
|
| 73 |
+
break
|
| 74 |
+
|
| 75 |
+
def set_condition(self, _type, value=None):
|
| 76 |
+
# TODO not implemented
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
def to_dict(self):
|
| 80 |
+
return {
|
| 81 |
+
'time_stamp': self.time_stamp,
|
| 82 |
+
'patient': {
|
| 83 |
+
'patient_user_id': self.patient.patient_id,
|
| 84 |
+
'patient_info': self.patient_info,
|
| 85 |
+
'patient_context': self.patient.context.msg
|
| 86 |
+
},
|
| 87 |
+
'doctor': {
|
| 88 |
+
'doctor_model_config': self.doctor.agent_config,
|
| 89 |
+
'doctor_context': self.doctor.context.msg
|
| 90 |
+
},
|
| 91 |
+
"conversation": self.conversation,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
def __json__(self):
|
| 95 |
+
return self.to_dict()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == '__main__':
|
| 99 |
+
os.environ["OPENAI_API_KEY"] = "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"
|
| 100 |
+
c = Conversation()
|
| 101 |
+
c.start_session()
|
src/server.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agents.chat_agent import BaseChatAgent
|
| 2 |
+
from src.utils.utils import save_as_json
|
| 3 |
+
|
| 4 |
+
from flask import Flask, request, jsonify
|
| 5 |
+
from flask_cors import CORS
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
import traceback
|
| 10 |
+
import boto3
|
| 11 |
+
import argparse
|
| 12 |
+
import pytz
|
| 13 |
+
import json
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
|
| 16 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 17 |
+
|
| 18 |
+
class Server:
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
self.patient_info = ""
|
| 21 |
+
|
| 22 |
+
self.conversation = []
|
| 23 |
+
|
| 24 |
+
self.patient_out = None
|
| 25 |
+
self.doctor_out = None
|
| 26 |
+
self.patient = None
|
| 27 |
+
self.doctor = None
|
| 28 |
+
|
| 29 |
+
self.conversation_round = 0
|
| 30 |
+
self.interview_protocol_index = None
|
| 31 |
+
|
| 32 |
+
def set_timestamp(self):
|
| 33 |
+
self.timestamp = datetime.now(pytz.timezone('US/Eastern')).strftime("%m/%d/%Y-%H:%M:%S")
|
| 34 |
+
|
| 35 |
+
def set_patient(self, patient):
|
| 36 |
+
self.patient = patient
|
| 37 |
+
self.patient_info = {
|
| 38 |
+
"patient_model_config": patient.agent_config,
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
def set_doctor(self, doctor):
|
| 42 |
+
self.doctor = doctor
|
| 43 |
+
|
| 44 |
+
def set_interview_protocol_index(self, interview_protocol_index):
|
| 45 |
+
self.interview_protocol_index = interview_protocol_index
|
| 46 |
+
|
| 47 |
+
def generate_doctor_response(self):
|
| 48 |
+
'''
|
| 49 |
+
Must be called after setting the patient and doctor
|
| 50 |
+
'''
|
| 51 |
+
self.doctor_out = self.doctor.talk_to_user(
|
| 52 |
+
self.patient_out, conversations=self.conversation)[0]
|
| 53 |
+
return self.doctor_out
|
| 54 |
+
|
| 55 |
+
def submit_doctor_response(self, response):
|
| 56 |
+
self.conversation.append(("doctor", response))
|
| 57 |
+
self.doctor.context.add_assistant_prompt(response)
|
| 58 |
+
|
| 59 |
+
def submit_patient_response(self, response):
|
| 60 |
+
self.conversation.append(("patient", response))
|
| 61 |
+
self.patient.context.add_assistant_prompt(response)
|
| 62 |
+
|
| 63 |
+
def get_response(self, patient_prompt):
|
| 64 |
+
self.patient_out = patient_prompt
|
| 65 |
+
self.submit_patient_response(patient_prompt)
|
| 66 |
+
print(f'Round {self.conversation_round} Patient: {patient_prompt}')
|
| 67 |
+
|
| 68 |
+
if patient_prompt is not None:
|
| 69 |
+
self.conversation_round += 1
|
| 70 |
+
|
| 71 |
+
doctor_out = self.generate_doctor_response()
|
| 72 |
+
self.submit_doctor_response(doctor_out)
|
| 73 |
+
print(f'Round {self.conversation_round} Doctor: {doctor_out}')
|
| 74 |
+
|
| 75 |
+
return {"response": doctor_out}
|
| 76 |
+
|
| 77 |
+
def to_dict(self):
|
| 78 |
+
return {
|
| 79 |
+
'time_stamp': self.timestamp,
|
| 80 |
+
'patient': {
|
| 81 |
+
'patient_user_id': self.patient.patient_id,
|
| 82 |
+
'patient_info': self.patient_info,
|
| 83 |
+
'patient_context': self.patient.context.msg
|
| 84 |
+
},
|
| 85 |
+
'doctor': {
|
| 86 |
+
'doctor_model_config': self.doctor.agent_config,
|
| 87 |
+
'doctor_context': self.doctor.context.msg
|
| 88 |
+
},
|
| 89 |
+
"conversation": self.conversation,
|
| 90 |
+
"interview_protocol_index": self.interview_protocol_index
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
def __json__(self):
|
| 94 |
+
return self.to_dict()
|
| 95 |
+
|
| 96 |
+
def reset(self):
|
| 97 |
+
self.conversation = []
|
| 98 |
+
self.conversation_round = 0
|
| 99 |
+
if hasattr(self.doctor, 'reset') and callable(getattr(self.doctor, 'reset')):
|
| 100 |
+
self.doctor.reset()
|
| 101 |
+
if hasattr(self.patient, 'reset') and callable(getattr(self.patient, 'reset')):
|
| 102 |
+
self.patient.reset()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def create_app():
|
| 106 |
+
app = Flask(__name__)
|
| 107 |
+
CORS(app)
|
| 108 |
+
app.user_servers = {}
|
| 109 |
+
return app
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def configure_routes(app, args):
|
| 113 |
+
|
| 114 |
+
@app.route('/', methods=['GET'])
|
| 115 |
+
def home():
|
| 116 |
+
'''
|
| 117 |
+
This api will return the default prompts used in the backend, including system prompt, autobiography generation prompt, therapy prompt, and conversation instruction prompt
|
| 118 |
+
Return:
|
| 119 |
+
{
|
| 120 |
+
system_prompt: String,
|
| 121 |
+
autobio_generation_prompt: String,
|
| 122 |
+
therapy_prompt: String,
|
| 123 |
+
conv_instruction_prompt: String
|
| 124 |
+
}
|
| 125 |
+
'''
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
return jsonify({
|
| 129 |
+
}), 200
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@app.route('/api/initialization', methods=['POST'])
|
| 133 |
+
def initialization():
|
| 134 |
+
'''
|
| 135 |
+
This API processes user configurations to initialize conversation states. It specifically accepts the following parameters:
|
| 136 |
+
api_key, username, chapter_name, topic_name, and prompts. The API will then:
|
| 137 |
+
1. Initialize a Server() instance for managing conversations and sessions.
|
| 138 |
+
2. Configure the user-defined prompts.
|
| 139 |
+
3. Set up the chapter and topic for the conversation.
|
| 140 |
+
4. Configure the save path for both local storage and Amazon S3.
|
| 141 |
+
'''
|
| 142 |
+
data = request.get_json()
|
| 143 |
+
username = data.get('username')
|
| 144 |
+
|
| 145 |
+
os.environ["OPENAI_API_KEY"] = "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"
|
| 146 |
+
|
| 147 |
+
# initialize
|
| 148 |
+
# server.patient.patient_id = username
|
| 149 |
+
counselor = BaseChatAgent(config_path=args.counselor_config_path)
|
| 150 |
+
print(counselor)
|
| 151 |
+
server = Server()
|
| 152 |
+
# server.set_doctor = counselor
|
| 153 |
+
server.doctor = counselor
|
| 154 |
+
app.user_servers[username] = server
|
| 155 |
+
|
| 156 |
+
return jsonify({"message": "API key set successfully"}), 200
|
| 157 |
+
|
| 158 |
+
@app.route('/save/download_conversations', methods=['POST'])
|
| 159 |
+
def download_conversations():
|
| 160 |
+
"""
|
| 161 |
+
This API retrieves the user's conversation history based on their username and returns the conversation data to the frontend.
|
| 162 |
+
Return:
|
| 163 |
+
conversations: List[String]
|
| 164 |
+
"""
|
| 165 |
+
data = request.get_json()
|
| 166 |
+
username = data.get('username')
|
| 167 |
+
chatbot_type = data.get('chatbot_type')
|
| 168 |
+
if not username:
|
| 169 |
+
return jsonify({'error': 'Username not provided'}), 400
|
| 170 |
+
if not chatbot_type:
|
| 171 |
+
return jsonify({'error': 'Chatbot type not provided'}), 400
|
| 172 |
+
|
| 173 |
+
conversation_dir = os.path.join('user_data', chatbot_type, username, 'conversation')
|
| 174 |
+
if not os.path.exists(conversation_dir):
|
| 175 |
+
return jsonify({'error': 'User not found or no conversations available'}), 404
|
| 176 |
+
|
| 177 |
+
# Llist all files in the conversation directory
|
| 178 |
+
files = os.listdir(conversation_dir)
|
| 179 |
+
conversations = []
|
| 180 |
+
|
| 181 |
+
# read each conversation file and append the conversation data to the list
|
| 182 |
+
for file_name in files:
|
| 183 |
+
file_path = os.path.join(conversation_dir, file_name)
|
| 184 |
+
try:
|
| 185 |
+
with open(file_path, 'r') as f:
|
| 186 |
+
conversation_data = json.load(f)
|
| 187 |
+
# extract the 'conversation' from the JSON
|
| 188 |
+
conversation_content = conversation_data.get('conversation', [])
|
| 189 |
+
conversations.append({
|
| 190 |
+
'file_name': file_name,
|
| 191 |
+
'conversation': conversation_content
|
| 192 |
+
})
|
| 193 |
+
except Exception as e:
|
| 194 |
+
print(f"Error reading {file_name}: {e}")
|
| 195 |
+
continue
|
| 196 |
+
|
| 197 |
+
return jsonify(conversations), 200
|
| 198 |
+
|
| 199 |
+
@app.route('/save/end_and_save', methods=['POST'])
|
| 200 |
+
def save_conversation_memory():
|
| 201 |
+
"""
|
| 202 |
+
This API saves the current conversation history and memory events to the backend, then synchronizes the data with the Amazon S3 server.
|
| 203 |
+
"""
|
| 204 |
+
data = request.get_json()
|
| 205 |
+
username = data.get('username')
|
| 206 |
+
chatbot_type = data.get('chatbot_type')
|
| 207 |
+
|
| 208 |
+
if not username:
|
| 209 |
+
return jsonify({"error": "Username not provided"}), 400
|
| 210 |
+
if not chatbot_type:
|
| 211 |
+
return jsonify({"error": "Chatbot type not provided"}), 400
|
| 212 |
+
server = app.user_servers.get(username)
|
| 213 |
+
if not server:
|
| 214 |
+
return jsonify({"error": "User session not found"}), 400
|
| 215 |
+
|
| 216 |
+
# save conversation history
|
| 217 |
+
server.set_timestamp()
|
| 218 |
+
save_name = f'{server.chapter_name}-{server.topic_name}-{server.timestamp}.json'
|
| 219 |
+
save_name = save_name.replace(' ', '-').replace('/', '-')
|
| 220 |
+
print(save_name)
|
| 221 |
+
|
| 222 |
+
# save to local file
|
| 223 |
+
local_conv_file_path = os.path.join(server.patient.conv_history_path, save_name)
|
| 224 |
+
save_as_json(local_conv_file_path, server.to_dict())
|
| 225 |
+
|
| 226 |
+
local_memory_graph_file = os.path.join(server.patient.memory_graph_path, save_name)
|
| 227 |
+
# if the chatbot type is 'baseline', create a dummy memory graph file
|
| 228 |
+
if chatbot_type == 'baseline':
|
| 229 |
+
save_as_json(local_memory_graph_file, {'time_indexed_memory_chain': []})
|
| 230 |
+
else:
|
| 231 |
+
# save memory graph
|
| 232 |
+
server.doctor.memory_graph.save(local_memory_graph_file)
|
| 233 |
+
|
| 234 |
+
return jsonify({"message": "Current conversation and memory graph are saved!"}), 200
|
| 235 |
+
|
| 236 |
+
@app.route('/responses/doctor', methods=['POST'])
|
| 237 |
+
def get_response():
|
| 238 |
+
"""
|
| 239 |
+
This API retrieves the chatbot's response and returns both the response and updated memory events to the frontend.
|
| 240 |
+
Return:
|
| 241 |
+
{
|
| 242 |
+
doctor_response: String,
|
| 243 |
+
memory_events: List[dict]
|
| 244 |
+
}
|
| 245 |
+
"""
|
| 246 |
+
data = request.get_json()
|
| 247 |
+
username = data.get('username')
|
| 248 |
+
|
| 249 |
+
# patient_prompt = data.get('patient_prompt')
|
| 250 |
+
# chatbot_type = data.get('chatbot_type')
|
| 251 |
+
# if not username or not patient_prompt:
|
| 252 |
+
# return jsonify({"error": "Username or patient prompt not provided"}), 400
|
| 253 |
+
# if not chatbot_type:
|
| 254 |
+
# return jsonify({"error": "Chatbot type not provided"}), 400
|
| 255 |
+
# if not
|
| 256 |
+
# server = app.user_servers.get(username)
|
| 257 |
+
# if not server:
|
| 258 |
+
# return jsonify({"error": "User session not found"}), 400
|
| 259 |
+
|
| 260 |
+
# print(server.patient.patient_id, server.chapter_name, server.topic_name)
|
| 261 |
+
# doctor_response = server.get_response(patient_prompt=patient_prompt)
|
| 262 |
+
|
| 263 |
+
# if chatbot_type == 'baseline':
|
| 264 |
+
# memory_events = []
|
| 265 |
+
# else:
|
| 266 |
+
# memory_events = server.doctor.memory_graph.to_list()
|
| 267 |
+
print('username', username)
|
| 268 |
+
server = app.user_servers.get(username)
|
| 269 |
+
llm_chatbot = server.doctor
|
| 270 |
+
response = llm_chatbot.talk_to_user(data)
|
| 271 |
+
|
| 272 |
+
return jsonify({'doctor_response': response})
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def main():
|
| 276 |
+
parser = argparse.ArgumentParser()
|
| 277 |
+
# parser.add_argument('--patient-config-path', type=str,
|
| 278 |
+
# default='./src/configs/patient_config.yaml')
|
| 279 |
+
parser.add_argument('--counselor-config-path', type=str,
|
| 280 |
+
default='./src/configs/counselor_config.yaml')
|
| 281 |
+
# parser.add_argument('--retriever-config-path', type=str,
|
| 282 |
+
# default='./src/configs/retrievers/faiss_retriever.yaml')
|
| 283 |
+
parser.add_argument('--store-dir',
|
| 284 |
+
type=str, default='./user_data')
|
| 285 |
+
# parser.add_argument('--memory-graph-config', default='./src/configs/memory_graph_config.yaml')
|
| 286 |
+
# parser.add_argument('--num-conversation-round', type=int, default=30)
|
| 287 |
+
args = parser.parse_args()
|
| 288 |
+
|
| 289 |
+
app = create_app()
|
| 290 |
+
configure_routes(app, args)
|
| 291 |
+
|
| 292 |
+
port = int(os.environ.get('PORT', 8080))
|
| 293 |
+
app.run(port=port, host='0.0.0.0', debug=False)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
if __name__ == '__main__':
|
| 297 |
+
main()
|
src/utils/__init__.py
ADDED
|
File without changes
|
src/utils/utils.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import yaml
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from box import Box
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
MAX_INPUT_TOKEN_NUM = 128000
|
| 9 |
+
|
| 10 |
+
def load_config(config_path):
|
| 11 |
+
config = Box.from_yaml(
|
| 12 |
+
filename=config_path, Loader=yaml.FullLoader)
|
| 13 |
+
return config
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def save_as_json(path, data):
|
| 18 |
+
'''
|
| 19 |
+
outout a json file with indent equals to 2
|
| 20 |
+
'''
|
| 21 |
+
json_data = json.dumps(data, indent=2)
|
| 22 |
+
# Save JSON to a file
|
| 23 |
+
try:
|
| 24 |
+
# Create the directory structure if it doesn't exist
|
| 25 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 26 |
+
|
| 27 |
+
# Write data to the JSON file
|
| 28 |
+
with open(path, 'w') as json_file:
|
| 29 |
+
json_file.write(json_data)
|
| 30 |
+
|
| 31 |
+
print(f"Experiment results written to '{path}' successfully.")
|
| 32 |
+
|
| 33 |
+
class ContextAttributeError(BaseException):
|
| 34 |
+
def __init__(self, message):
|
| 35 |
+
self.message = message
|
| 36 |
+
super().__init__(self.message)
|
| 37 |
+
|
| 38 |
+
return True
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"Failed to write experiment results to '{path}': {e}")
|
| 41 |
+
return False
|
| 42 |
+
|
| 43 |
+
pass
|
| 44 |
+
|