nivakaran commited on
Commit
d71fbb3
·
verified ·
1 Parent(s): c65255d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -183
app.py CHANGED
@@ -1,31 +1,44 @@
1
- import gradio as gr
2
  import uuid
3
  import logging
4
  from datetime import datetime
5
  import os
 
 
6
 
7
  from src.graphs.finalAgentGraph import sparrowAgent
8
  from langchain_core.messages import HumanMessage, AIMessage
9
 
10
- # Setup logging
 
 
 
 
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
 
 
 
 
 
15
  def ensure_langchain_message(message):
16
  """Ensure a message is a proper LangChain message object"""
17
  if isinstance(message, (HumanMessage, AIMessage)):
18
  return message
19
  elif isinstance(message, dict):
 
20
  content = message.get('content', str(message))
21
- message_type = message.get('type', 'ai')
22
  if message_type == 'human':
23
  return HumanMessage(content=content)
24
  else:
25
  return AIMessage(content=content)
26
  elif isinstance(message, str):
 
27
  return AIMessage(content=message)
28
  else:
 
29
  return AIMessage(content=str(message))
30
 
31
 
@@ -38,239 +51,235 @@ def clean_messages_list(messages):
38
  return cleaned_messages
39
 
40
 
41
- def initialize_conversation():
42
- """Initialize a new conversation state"""
43
- return {
44
- 'thread_id': str(uuid.uuid4()),
45
- 'messages': [],
46
- 'notes': [],
47
- 'query_brief': '',
48
- 'final_message': '',
49
- 'created_at': datetime.now(),
50
- 'last_updated': datetime.now()
51
- }
52
 
53
 
54
- def process_message(user_message, history, conversation_state):
55
- """
56
- Process user message and return response
57
-
58
- Args:
59
- user_message: The user's input message
60
- history: Gradio chat history (list of [user_msg, bot_msg] pairs)
61
- conversation_state: Dictionary containing conversation context
62
-
63
- Returns:
64
- Tuple of (empty string, updated history, updated conversation state, status message)
65
- """
66
  try:
67
- if not user_message or not user_message.strip():
68
- return "", history, conversation_state, "Please enter a message"
 
 
 
 
69
 
70
- # Initialize conversation state if None
71
- if conversation_state is None:
72
- conversation_state = initialize_conversation()
 
 
73
 
74
- thread_id = conversation_state['thread_id']
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # Add user message to conversation
77
  human_message = HumanMessage(content=user_message)
78
- conversation_state['messages'].append(human_message)
79
- conversation_state['last_updated'] = datetime.now()
 
80
 
81
- # Clean messages
82
- cleaned_messages = clean_messages_list(conversation_state['messages'])
83
 
84
- # Prepare input for sparrow agent
85
  sparrow_input = {
86
  'messages': cleaned_messages,
87
- 'notes': conversation_state.get('notes', []),
88
- 'query_brief': conversation_state.get('query_brief', ''),
89
- 'final_message': conversation_state.get('final_message', '')
90
  }
91
-
92
  logger.info(f"[{thread_id}] Processing message: {user_message[:100]}")
93
  logger.info(f"[{thread_id}] Input messages count: {len(cleaned_messages)}")
 
94
 
95
- # Invoke the sparrow agent
96
  result = sparrowAgent.invoke(sparrow_input)
 
97
 
98
- # Extract response message
99
  response_message = ""
100
  ai_message = None
101
-
102
  if result.get('final_message'):
103
  response_message = result['final_message']
104
  ai_message = AIMessage(content=response_message)
105
  else:
 
106
  result_messages = clean_messages_list(result.get('messages', []))
107
 
108
- # Find last user message index
109
  last_user_index = -1
110
  for i, msg in enumerate(result_messages):
111
  if isinstance(msg, HumanMessage):
112
  last_user_index = i
113
 
114
- # Get first AI message after last user message
115
  for i in range(last_user_index + 1, len(result_messages)):
116
  msg = result_messages[i]
117
  if isinstance(msg, AIMessage) and msg.content and msg.content.strip():
118
  response_message = msg.content
119
  ai_message = msg
120
  break
121
-
122
  if not response_message:
123
  response_message = "I'm processing your request. Could you provide more details?"
124
  ai_message = AIMessage(content=response_message)
 
125
 
126
- # Update conversation state
127
- if result.get('messages'):
128
- conversation_state['messages'] = clean_messages_list(result['messages'])
129
- else:
130
- conversation_state['messages'].append(ai_message)
131
-
132
- # Remove consecutive duplicates
133
- cleaned_conversation_messages = []
134
- prev_content = None
135
- prev_type = None
136
-
137
- for msg in conversation_state['messages']:
138
- current_content = msg.content if hasattr(msg, 'content') else str(msg)
139
- current_type = type(msg).__name__
140
-
141
- if current_content != prev_content or current_type != prev_type:
142
- cleaned_conversation_messages.append(msg)
143
- prev_content = current_content
144
- prev_type = current_type
145
-
146
- conversation_state['messages'] = cleaned_conversation_messages
147
- conversation_state['notes'] = result.get('notes', conversation_state.get('notes', []))
148
- conversation_state['query_brief'] = result.get('query_brief', conversation_state.get('query_brief', ''))
149
- conversation_state['final_message'] = result.get('final_message', conversation_state.get('final_message', ''))
150
- conversation_state['last_updated'] = datetime.now()
151
-
152
- # Update Gradio chat history
153
- history.append([user_message, response_message])
154
-
155
- # Create status message
156
- status_info = f"Thread: {thread_id[:8]}... | Messages: {len(conversation_state['messages'])}"
157
  if result.get('execution_jobs'):
158
- status_info += f" | Executed: {', '.join(result['execution_jobs'])}"
159
  elif result.get('notes') and isinstance(result['notes'], list) and result['notes']:
160
- status_info += f" | Note: {str(result['notes'][-1])[:50]}"
 
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  logger.info(f"[{thread_id}] Response generated: {response_message[:100]}")
163
- logger.info(f"[{thread_id}] Final messages count: {len(conversation_state['messages'])}")
 
164
 
165
- return "", history, conversation_state, status_info
166
-
167
- except Exception as e:
168
- logger.error(f"Error processing message: {str(e)}", exc_info=True)
169
- error_msg = f"An error occurred: {str(e)}"
170
- history.append([user_message, error_msg])
171
- return "", history, conversation_state, f"Error: {str(e)}"
172
 
 
 
 
 
 
 
173
 
174
- def clear_conversation():
175
- """Clear conversation and start fresh"""
176
- new_state = initialize_conversation()
177
- logger.info(f"[{new_state['thread_id']}] New conversation started")
178
- return [], new_state, f"New conversation started (ID: {new_state['thread_id'][:8]}...)"
179
 
180
 
181
- def get_conversation_info(conversation_state):
182
- """Get current conversation information"""
183
- if conversation_state is None:
184
- return "No active conversation"
 
185
 
186
- info_lines = [
187
- f"**Thread ID:** {conversation_state['thread_id']}",
188
- f"**Messages:** {len(conversation_state.get('messages', []))}",
189
- f"**Notes:** {len(conversation_state.get('notes', []))}",
190
- f"**Has Query Brief:** {bool(conversation_state.get('query_brief'))}",
191
- f"**Has Final Message:** {bool(conversation_state.get('final_message'))}",
192
- f"**Created:** {conversation_state.get('created_at', 'N/A')}",
193
- f"**Last Updated:** {conversation_state.get('last_updated', 'N/A')}"
194
- ]
195
 
196
- return "\n\n".join(info_lines)
197
 
198
 
199
- # Create Gradio interface
200
- with gr.Blocks(title="Sparrow Agent Chat", theme=gr.themes.Soft()) as demo:
201
- gr.Markdown("# 🦜 Sparrow Agent Chat")
202
- gr.Markdown("Interact with the Sparrow AI Agent. Ask questions and get intelligent responses!")
203
-
204
- # State to store conversation context
205
- conversation_state = gr.State(initialize_conversation())
206
 
207
- with gr.Row():
208
- with gr.Column(scale=4):
209
- chatbot = gr.Chatbot(
210
- label="Conversation",
211
- height=500,
212
- show_copy_button=True
213
- )
214
-
215
- with gr.Row():
216
- msg = gr.Textbox(
217
- label="Your Message",
218
- placeholder="Type your message here...",
219
- lines=2,
220
- scale=4
221
- )
222
- submit_btn = gr.Button("Send", variant="primary", scale=1)
223
-
224
- with gr.Row():
225
- clear_btn = gr.Button("New Conversation", variant="secondary")
226
- status_box = gr.Textbox(
227
- label="Status",
228
- interactive=False,
229
- lines=1
230
- )
231
 
232
- with gr.Column(scale=1):
233
- gr.Markdown("### Debug Info")
234
- info_btn = gr.Button("Show Conversation Info")
235
- info_display = gr.Markdown("Click button to show info")
236
-
237
- # Event handlers
238
- submit_btn.click(
239
- fn=process_message,
240
- inputs=[msg, chatbot, conversation_state],
241
- outputs=[msg, chatbot, conversation_state, status_box]
242
- )
243
-
244
- msg.submit(
245
- fn=process_message,
246
- inputs=[msg, chatbot, conversation_state],
247
- outputs=[msg, chatbot, conversation_state, status_box]
248
- )
249
-
250
- clear_btn.click(
251
- fn=clear_conversation,
252
- inputs=[],
253
- outputs=[chatbot, conversation_state, status_box]
254
- )
255
-
256
- info_btn.click(
257
- fn=get_conversation_info,
258
- inputs=[conversation_state],
259
- outputs=[info_display]
260
- )
261
 
262
- # Initialize status on load
263
- demo.load(
264
- fn=lambda state: f"Ready | Thread: {state['thread_id'][:8]}...",
265
- inputs=[conversation_state],
266
- outputs=[status_box]
267
- )
268
-
269
-
270
- # Launch the app
271
- if __name__ == "__main__":
272
- demo.launch(
273
- server_name="0.0.0.0",
274
- server_port=int(os.environ.get('PORT', 7860)),
275
- share=False
276
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, render_template, session
2
  import uuid
3
  import logging
4
  from datetime import datetime
5
  import os
6
+ import threading
7
+ import time
8
 
9
  from src.graphs.finalAgentGraph import sparrowAgent
10
  from langchain_core.messages import HumanMessage, AIMessage
11
 
12
+
13
+ app = Flask(__name__)
14
+ app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'your-secret-key-here')
15
+
16
+
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
20
 
21
+ conversations = {}
22
+ conversations_lock = threading.Lock()
23
+
24
+
25
  def ensure_langchain_message(message):
26
  """Ensure a message is a proper LangChain message object"""
27
  if isinstance(message, (HumanMessage, AIMessage)):
28
  return message
29
  elif isinstance(message, dict):
30
+
31
  content = message.get('content', str(message))
32
+ message_type = message.get('type', 'ai')
33
  if message_type == 'human':
34
  return HumanMessage(content=content)
35
  else:
36
  return AIMessage(content=content)
37
  elif isinstance(message, str):
38
+
39
  return AIMessage(content=message)
40
  else:
41
+
42
  return AIMessage(content=str(message))
43
 
44
 
 
51
  return cleaned_messages
52
 
53
 
54
+ @app.route('/')
55
+ def index():
56
+ """Serve the main chat interface"""
57
+ return render_template('index.html')
 
 
 
 
 
 
 
58
 
59
 
60
+ @app.route('/chat', methods=['POST'])
61
+ def chat():
62
+ """Handle chat messages"""
 
 
 
 
 
 
 
 
 
63
  try:
64
+ data = request.get_json()
65
+ user_message = data.get('message', '').strip()
66
+
67
+ if not user_message:
68
+ return jsonify({'success': False, 'error': 'Empty message'})
69
+
70
 
71
+ thread_id = session.get('thread_id')
72
+ if not thread_id:
73
+ thread_id = str(uuid.uuid4())
74
+ session['thread_id'] = thread_id
75
+
76
 
77
+ with conversations_lock:
78
+ if thread_id not in conversations:
79
+ conversations[thread_id] = {
80
+ 'messages': [],
81
+ 'notes': [],
82
+ 'query_brief': '',
83
+ 'final_message': '',
84
+ 'created_at': datetime.now(),
85
+ 'last_updated': datetime.now()
86
+ }
87
+ conversation = conversations[thread_id]
88
+
89
 
 
90
  human_message = HumanMessage(content=user_message)
91
+ conversation['messages'].append(human_message)
92
+ conversation['last_updated'] = datetime.now()
93
+
94
 
95
+ cleaned_messages = clean_messages_list(conversation['messages'])
96
+
97
 
 
98
  sparrow_input = {
99
  'messages': cleaned_messages,
100
+ 'notes': conversation.get('notes', []),
101
+ 'query_brief': conversation.get('query_brief', ''),
102
+ 'final_message': conversation.get('final_message', '')
103
  }
104
+
105
  logger.info(f"[{thread_id}] Processing message: {user_message[:100]}")
106
  logger.info(f"[{thread_id}] Input messages count: {len(cleaned_messages)}")
107
+
108
 
 
109
  result = sparrowAgent.invoke(sparrow_input)
110
+
111
 
 
112
  response_message = ""
113
  ai_message = None
114
+
115
  if result.get('final_message'):
116
  response_message = result['final_message']
117
  ai_message = AIMessage(content=response_message)
118
  else:
119
+
120
  result_messages = clean_messages_list(result.get('messages', []))
121
 
122
+
123
  last_user_index = -1
124
  for i, msg in enumerate(result_messages):
125
  if isinstance(msg, HumanMessage):
126
  last_user_index = i
127
 
128
+
129
  for i in range(last_user_index + 1, len(result_messages)):
130
  msg = result_messages[i]
131
  if isinstance(msg, AIMessage) and msg.content and msg.content.strip():
132
  response_message = msg.content
133
  ai_message = msg
134
  break
135
+
136
  if not response_message:
137
  response_message = "I'm processing your request. Could you provide more details?"
138
  ai_message = AIMessage(content=response_message)
139
+
140
 
141
+ status_info = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  if result.get('execution_jobs'):
143
+ status_info = f"Executed: {', '.join(result['execution_jobs'])}"
144
  elif result.get('notes') and isinstance(result['notes'], list) and result['notes']:
145
+ status_info = str(result['notes'][-1])
146
+
147
 
148
+ with conversations_lock:
149
+
150
+ if result.get('messages'):
151
+ conversation['messages'] = clean_messages_list(result['messages'])
152
+ else:
153
+
154
+ conversation['messages'].append(ai_message)
155
+
156
+
157
+ cleaned_conversation_messages = []
158
+ prev_content = None
159
+ prev_type = None
160
+
161
+ for msg in conversation['messages']:
162
+ current_content = msg.content if hasattr(msg, 'content') else str(msg)
163
+ current_type = type(msg).__name__
164
+
165
+
166
+ if current_content != prev_content or current_type != prev_type:
167
+ cleaned_conversation_messages.append(msg)
168
+ prev_content = current_content
169
+ prev_type = current_type
170
+
171
+ conversation['messages'] = cleaned_conversation_messages
172
+
173
+
174
+ conversation['notes'] = result.get('notes', conversation.get('notes', []))
175
+ conversation['query_brief'] = result.get('query_brief', conversation.get('query_brief', ''))
176
+ conversation['final_message'] = result.get('final_message', conversation.get('final_message', ''))
177
+ conversation['last_updated'] = datetime.now()
178
+ conversations[thread_id] = conversation
179
+
180
  logger.info(f"[{thread_id}] Response generated: {response_message[:100]}")
181
+ logger.info(f"[{thread_id}] Final messages count: {len(conversation['messages'])}")
182
+
183
 
184
+ message_types = [f"{type(msg).__name__}: {msg.content[:50] if hasattr(msg, 'content') else str(msg)[:50]}"
185
+ for msg in conversation['messages'][-3:]] # Show last 3 messages
186
+ logger.info(f"[{thread_id}] Recent message types: {message_types}")
 
 
 
 
187
 
188
+ return jsonify({
189
+ 'success': True,
190
+ 'response': response_message,
191
+ 'status': status_info,
192
+ 'thread_id': thread_id
193
+ })
194
 
195
+ except Exception as e:
196
+ logger.error(f"Error in chat endpoint: {str(e)}", exc_info=True)
197
+ return jsonify({'success': False, 'error': f"An error occurred: {str(e)}"})
 
 
198
 
199
 
200
+ @app.route('/new_conversation', methods=['POST'])
201
+ def new_conversation():
202
+ """Start a new conversation thread"""
203
+ old_thread_id = session.get('thread_id')
204
+ session.pop('thread_id', None)
205
 
206
+ if old_thread_id:
207
+ logger.info(f"[{old_thread_id}] Starting new conversation")
 
 
 
 
 
 
 
208
 
209
+ return jsonify({'success': True, 'message': 'New conversation started'})
210
 
211
 
212
+ @app.route('/conversation_info', methods=['GET'])
213
+ def conversation_info():
214
+ """Get current conversation information (for debugging)"""
215
+ thread_id = session.get('thread_id')
216
+ if not thread_id:
217
+ return jsonify({'error': 'No active conversation'})
 
218
 
219
+ with conversations_lock:
220
+ conversation = conversations.get(thread_id, {})
221
+ if not conversation:
222
+ return jsonify({'error': 'Conversation not found'})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
+
225
+ info = {
226
+ 'thread_id': thread_id,
227
+ 'message_count': len(conversation.get('messages', [])),
228
+ 'message_types': [type(msg).__name__ for msg in conversation.get('messages', [])],
229
+ 'notes_count': len(conversation.get('notes', [])),
230
+ 'has_query_brief': bool(conversation.get('query_brief')),
231
+ 'has_final_message': bool(conversation.get('final_message')),
232
+ 'created_at': conversation.get('created_at', '').isoformat() if conversation.get('created_at') else '',
233
+ 'last_updated': conversation.get('last_updated', '').isoformat() if conversation.get('last_updated') else ''
234
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
+ return jsonify(info)
237
+
238
+
239
+ @app.route('/health')
240
+ def health():
241
+ """Health check endpoint"""
242
+ with conversations_lock:
243
+ active_conversations = len(conversations)
244
+ return jsonify({
245
+ 'status': 'healthy',
246
+ 'timestamp': datetime.now().isoformat(),
247
+ 'active_conversations': active_conversations
248
+ })
249
+
250
+
251
+ @app.errorhandler(404)
252
+ def not_found(error):
253
+ return jsonify({'error': 'Endpoint not found'}), 404
254
+
255
+
256
+ @app.errorhandler(500)
257
+ def internal_error(error):
258
+ logger.error(f"Internal server error: {error}")
259
+ return jsonify({'error': 'Internal server error'}), 500
260
+
261
+
262
+ def cleanup_conversations():
263
+ """Remove old conversations older than 24 hours"""
264
+ while True:
265
+ time.sleep(3600) # every hour
266
+ cutoff = datetime.now().timestamp() - 24 * 3600
267
+ removed_count = 0
268
+ with conversations_lock:
269
+ threads_to_remove = [tid for tid, conv in conversations.items()
270
+ if conv['last_updated'].timestamp() < cutoff]
271
+ for tid in threads_to_remove:
272
+ del conversations[tid]
273
+ removed_count += 1
274
+ if removed_count > 0:
275
+ logger.info(f"Cleaned up {removed_count} old conversations")
276
+
277
+
278
+ if __name__ == '__main__':
279
+ cleanup_thread = threading.Thread(target=cleanup_conversations, daemon=True)
280
+ cleanup_thread.start()
281
+
282
+ port = int(os.environ.get('PORT', 5000))
283
+ debug = os.environ.get('FLASK_DEBUG', 'False').lower() == 'true'
284
+ logger.info(f"Starting Sparrow Agent Flask app on port {port}")
285
+ app.run(host='0.0.0.0', port=port, debug=debug)