File size: 11,725 Bytes
706667c
a44bd4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c26f5b3
 
 
a44bd4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a725a4
 
 
 
 
 
 
 
 
 
 
 
 
 
a44bd4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
from src.agents.chat_agent import BaseChatAgent
from src.utils.utils import save_as_json

from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import sys
import time
import traceback
import boto3
import argparse
import pytz
import json
from datetime import datetime

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

class Server:
    """State container for one doctor/patient conversation session.

    One instance is created per user (see the `/api/initialization` route)
    and holds the two chat agents, the running transcript, and bookkeeping
    needed to serialize the session to JSON.
    """

    def __init__(self) -> None:
        # Filled by set_patient(); stays a string until then.
        self.patient_info = ""

        # Ordered transcript: list of ("doctor" | "patient", utterance) pairs.
        self.conversation = []

        # Most recent utterance from each side; None before the first turn.
        self.patient_out = None
        self.doctor_out = None

        # Agent objects, wired up via set_patient() / set_doctor().
        self.patient = None
        self.doctor = None

        self.conversation_round = 0
        self.interview_protocol_index = None

        # Set by set_timestamp(); initialized here so to_dict() cannot
        # raise AttributeError if it runs before set_timestamp().
        self.timestamp = None

    def set_timestamp(self):
        """Stamp the session with the current US/Eastern wall-clock time."""
        self.timestamp = datetime.now(pytz.timezone('US/Eastern')).strftime("%m/%d/%Y-%H:%M:%S")

    def set_patient(self, patient):
        """Attach the patient agent and record its model configuration."""
        self.patient = patient
        self.patient_info = {
            "patient_model_config": patient.agent_config,
        }

    def set_doctor(self, doctor):
        """Attach the doctor (counselor) agent."""
        self.doctor = doctor

    def set_interview_protocol_index(self, interview_protocol_index):
        """Record which interview protocol this session follows."""
        self.interview_protocol_index = interview_protocol_index

    def generate_doctor_response(self):
        '''
        Produce the doctor's next utterance from the latest patient input.
        Must be called after setting the patient and doctor.
        '''
        self.doctor_out = self.doctor.talk_to_user(
            self.patient_out, conversations=self.conversation)[0]
        return self.doctor_out

    def submit_doctor_response(self, response):
        """Append a doctor utterance to the transcript and agent context."""
        self.conversation.append(("doctor", response))
        self.doctor.context.add_assistant_prompt(response)

    def submit_patient_response(self, response):
        """Append a patient utterance to the transcript and agent context."""
        self.conversation.append(("patient", response))
        self.patient.context.add_assistant_prompt(response)

    def get_response(self, patient_prompt):
        """Run one full exchange: record the patient prompt, generate and
        record the doctor's reply, and return it as a response dict."""
        self.patient_out = patient_prompt
        self.submit_patient_response(patient_prompt)
        print(f'Round {self.conversation_round} Patient: {patient_prompt}')

        # Round counter advances between the patient and doctor prints,
        # so the doctor's line is labeled with the incremented round.
        if patient_prompt is not None:
            self.conversation_round += 1

        doctor_out = self.generate_doctor_response()
        self.submit_doctor_response(doctor_out)
        print(f'Round {self.conversation_round} Doctor: {doctor_out}')

        return {"response": doctor_out}

    def to_dict(self):
        """Serialize the full session (agents, contexts, transcript) for
        saving to JSON."""
        return {
            'time_stamp': self.timestamp,
            'patient': {
                'patient_user_id': self.patient.patient_id,
                'patient_info': self.patient_info,
                'patient_context': self.patient.context.msg
            },
            'doctor': {
                'doctor_model_config': self.doctor.agent_config,
                'doctor_context': self.doctor.context.msg
            },
            "conversation": self.conversation,
            "interview_protocol_index": self.interview_protocol_index
        }

    def __json__(self):
        """JSON-serialization hook; delegates to to_dict()."""
        return self.to_dict()

    def reset(self):
        """Clear the transcript and round counter, and reset either agent
        if it exposes a callable reset()."""
        self.conversation = []
        self.conversation_round = 0
        # getattr with a None default collapses the hasattr+callable pair.
        if callable(getattr(self.doctor, 'reset', None)):
            self.doctor.reset()
        if callable(getattr(self.patient, 'reset', None)):
            self.patient.reset()


def create_app():
    """Construct the Flask application.

    Enables CORS for all routes and attaches an empty ``user_servers``
    registry (username -> Server) that the route handlers populate.
    """
    application = Flask(__name__)
    CORS(application)
    application.user_servers = {}
    return application


def configure_routes(app, args):
    """Register all HTTP routes on *app*.

    Args:
        app: Flask application from create_app(); must expose a
            ``user_servers`` dict mapping username -> Server.
        args: parsed CLI namespace; ``counselor_config_path`` is read
            when a new session is initialized.
    """

    @app.route('/', methods=['GET'])
    def home():
        '''
        This api will return the default prompts used in the backend, including system prompt, autobiography generation prompt, therapy prompt, and conversation instruction prompt
        Return:
            {
            system_prompt: String,
            autobio_generation_prompt: String,
            therapy_prompt: String,
            conv_instruction_prompt: String
            }
        '''
        # NOTE(review): prompt payload is not wired up yet; an empty JSON
        # object is returned for now.
        return jsonify({
        }), 200


    @app.route('/api/initialization', methods=['POST'])
    def initialization():
        '''
        This API processes user configurations to initialize conversation states. It specifically accepts the following parameters:
        api_key, username, chapter_name, topic_name, and prompts. The API will then:
        1. Initialize a Server() instance for managing conversations and sessions.
        2. Configure the user-defined prompts.
        3. Set up the chapter and topic for the conversation.
        4. Configure the save path for both local storage and Amazon S3.
        '''
        data = request.get_json()
        username = data.get('username')
        # Guard: without a username the session would be stored under the
        # key None and be unreachable by every other route.
        if not username:
            return jsonify({'error': 'Username not provided'}), 400

        api_key = data.get('api_key')
        # Only export a non-empty, whitespace-trimmed key.
        if api_key and isinstance(api_key, str) and api_key.strip():
            os.environ["OPENAI_API_KEY"] = api_key.strip()

        counselor = BaseChatAgent(config_path=args.counselor_config_path)
        print(counselor)
        server = Server()
        # Use the setter (the earlier `server.set_doctor = counselor` was
        # an assignment to the method, not a call).
        server.set_doctor(counselor)
        app.user_servers[username] = server

        return jsonify({"message": "API key set successfully"}), 200

    @app.route('/save/download_conversations', methods=['POST'])
    def download_conversations():
        """
        This API retrieves the user's conversation history based on their username and returns the conversation data to the frontend.
        Return:
            conversations: List[String]
        """
        data = request.get_json()
        username = data.get('username')
        chatbot_type = data.get('chatbot_type')
        if not username:
            return jsonify({'error': 'Username not provided'}), 400
        if not chatbot_type:
            return jsonify({'error': 'Chatbot type not provided'}), 400

        conversation_dir = os.path.join('user_data', chatbot_type, username, 'conversation')
        if not os.path.exists(conversation_dir):
            return jsonify({'error': 'User not found or no conversations available'}), 404

        # List all files in the conversation directory; sorted so the
        # response order is deterministic across platforms.
        files = sorted(os.listdir(conversation_dir))
        conversations = []

        # read each conversation file and append the conversation data to the list
        for file_name in files:
            file_path = os.path.join(conversation_dir, file_name)
            try:
                with open(file_path, 'r') as f:
                    conversation_data = json.load(f)
                    # extract the 'conversation' from the JSON
                    conversation_content = conversation_data.get('conversation', [])
                    conversations.append({
                        'file_name': file_name,
                        'conversation': conversation_content
                    })
            except Exception as e:
                # Best-effort: skip unreadable/corrupt files, keep the rest.
                print(f"Error reading {file_name}: {e}")
                continue

        return jsonify(conversations), 200

    @app.route('/save/end_and_save', methods=['POST'])
    def save_conversation_memory():
        """
        This API saves the current conversation history and memory events to the backend, then synchronizes the data with the Amazon S3 server.
        """
        data = request.get_json()
        username = data.get('username')
        chatbot_type = data.get('chatbot_type')

        if not username:
            return jsonify({"error": "Username not provided"}), 400
        if not chatbot_type:
            return jsonify({"error": "Chatbot type not provided"}), 400
        server = app.user_servers.get(username)
        if not server:
            return jsonify({"error": "User session not found"}), 400
        # Guard: save paths below are derived from the patient agent.
        if server.patient is None:
            return jsonify({"error": "Patient not initialized for this session"}), 400

        # save conversation history
        server.set_timestamp()
        # Server never defines chapter_name/topic_name itself; fall back to
        # placeholders instead of raising AttributeError when they are
        # missing.  TODO confirm where these are meant to be set.
        chapter_name = getattr(server, 'chapter_name', 'unknown-chapter')
        topic_name = getattr(server, 'topic_name', 'unknown-topic')
        save_name = f'{chapter_name}-{topic_name}-{server.timestamp}.json'
        # Sanitize characters that would break the file path.
        save_name = save_name.replace(' ', '-').replace('/', '-')
        print(save_name)

        # save to local file
        local_conv_file_path = os.path.join(server.patient.conv_history_path, save_name)
        save_as_json(local_conv_file_path, server.to_dict())

        local_memory_graph_file = os.path.join(server.patient.memory_graph_path, save_name)
        # if the chatbot type is 'baseline', create a dummy memory graph file
        if chatbot_type == 'baseline':
            save_as_json(local_memory_graph_file, {'time_indexed_memory_chain': []})
        else:
            # save memory graph
            server.doctor.memory_graph.save(local_memory_graph_file)

        # Auto-upload to Google Drive if authenticated
        try:
            sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            from google_drive_sync import auto_upload_to_drive

            # Upload conversation file
            auto_upload_to_drive(local_conv_file_path, user_id=username, folder_name="Chatbot_Conversations")
            # Upload memory graph file
            auto_upload_to_drive(local_memory_graph_file, user_id=username, folder_name="Chatbot_Conversations")
        except Exception as e:
            # Fail silently if Google Drive upload fails
            print(f"Google Drive auto-upload failed (non-critical): {str(e)}")

        return jsonify({"message": "Current conversation and memory graph are saved!"}), 200

    @app.route('/responses/doctor', methods=['POST'])
    def get_response():
        """
        This API retrieves the chatbot's response and returns both the response and updated memory events to the frontend.
        Return:
            {
            doctor_response: String,
            memory_events: List[dict]
            }
        """
        data = request.get_json()
        username = data.get('username')
        print('username', username)

        # Guard: previously a missing session crashed with AttributeError
        # on None; return a clean 400 instead.
        server = app.user_servers.get(username)
        if not server:
            return jsonify({"error": "User session not found"}), 400

        llm_chatbot = server.doctor
        response = llm_chatbot.talk_to_user(data)

        return jsonify({'doctor_response': response})


def main():
    """Entry point: parse CLI options, build and wire the Flask app, and
    serve it on the port given by the PORT env var (default 8080)."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--counselor-config-path', type=str,
                            default='./src/configs/counselor_config.yaml')
    arg_parser.add_argument('--store-dir',
                            type=str, default='./user_data')
    parsed_args = arg_parser.parse_args()

    application = create_app()
    configure_routes(application, parsed_args)

    listen_port = int(os.environ.get('PORT', 8080))
    application.run(port=listen_port, host='0.0.0.0', debug=False)


# Run the server when executed as a script (not on import).
if __name__ == '__main__':
    main()