File size: 10,194 Bytes
782bbd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import logging
import os
import uuid
from datetime import datetime

import gradio as gr
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage

from src.graphs.finalAgentGraph import sparrowAgent

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def ensure_langchain_message(message):
    """Coerce *message* into a LangChain message object.

    - Any ``BaseMessage`` (human, AI, system, tool, ...) is returned unchanged,
      so system prompts and AI-tool-call / tool-result pairs survive being
      stored in and re-sent from the conversation state. (Previously only
      Human/AI messages passed through; everything else was replaced by an
      ``AIMessage`` holding the object's ``str()`` repr.)
    - A dict is converted using its ``'content'`` (default: ``str(message)``)
      and ``'type'`` (default: ``'ai'``): ``'human'`` yields a
      ``HumanMessage``, any other type an ``AIMessage``.
    - A string becomes an ``AIMessage`` with that text; any other value becomes
      an ``AIMessage`` with its ``str()`` form.
    """
    if isinstance(message, BaseMessage):
        return message
    if isinstance(message, dict):
        content = message.get('content', str(message))
        if message.get('type', 'ai') == 'human':
            return HumanMessage(content=content)
        return AIMessage(content=content)
    if isinstance(message, str):
        return AIMessage(content=message)
    return AIMessage(content=str(message))


def clean_messages_list(messages):
    """Return a new list with each entry of *messages* passed through ``ensure_langchain_message``.

    Order is preserved and the input is not modified.
    """
    return [ensure_langchain_message(entry) for entry in messages]


def initialize_conversation():
    """Create a fresh conversation-state dict.

    Returns:
        dict with a random UUID4 string ``thread_id``, empty ``messages`` and
        ``notes`` lists, empty ``query_brief`` and ``final_message`` strings,
        and ``created_at`` / ``last_updated`` set to the same local (naive)
        ``datetime``.
    """
    # Read the clock once so a brand-new conversation reports identical
    # creation and last-update timestamps.
    now = datetime.now()
    return {
        'thread_id': str(uuid.uuid4()),
        'messages': [],
        'notes': [],
        'query_brief': '',
        'final_message': '',
        'created_at': now,
        'last_updated': now,
    }


def _extract_response(result):
    """Choose the reply to show the user from a ``sparrowAgent`` result.

    Args:
        result: Mapping returned by ``sparrowAgent.invoke``.

    Returns:
        Tuple ``(text, message)``. A truthy ``result['final_message']`` wins.
        Otherwise the first ``AIMessage`` with non-blank string content that
        follows the last ``HumanMessage`` in ``result['messages']`` is used.
        Returns ``("", None)`` when neither exists.
    """
    # NOTE(review): process_message passes the previous turn's final_message
    # back into the agent; if the graph does not reset it, a stale reply could
    # be returned here on later turns — confirm against the graph definition.
    final_message = result.get('final_message')
    if final_message:
        return final_message, AIMessage(content=final_message)

    result_messages = clean_messages_list(result.get('messages') or [])

    # Position of the most recent user turn (-1 if none), so only replies
    # produced for the current message are considered.
    last_user_index = -1
    for index, msg in enumerate(result_messages):
        if isinstance(msg, HumanMessage):
            last_user_index = index

    for msg in result_messages[last_user_index + 1:]:
        # Content may be a list of content blocks; only non-blank plain text is
        # usable as a chat reply (calling .strip() on a list would raise).
        if isinstance(msg, AIMessage) and isinstance(msg.content, str) and msg.content.strip():
            return msg.content, msg

    return "", None


def _drop_consecutive_duplicates(messages):
    """Return *messages* without entries repeating the previous entry's type and content."""
    deduped = []
    prev_key = None
    for msg in messages:
        key = (type(msg).__name__, msg.content if hasattr(msg, 'content') else str(msg))
        if key != prev_key:
            deduped.append(msg)
            prev_key = key
    return deduped


def _build_status(thread_id, message_count, result):
    """Build the one-line status text shown under the chat after a turn."""
    status_info = f"Thread: {thread_id[:8]}... | Messages: {message_count}"
    execution_jobs = result.get('execution_jobs')
    notes = result.get('notes')
    if execution_jobs:
        # Jobs are not guaranteed to be strings; coerce so join() cannot raise.
        status_info += f" | Executed: {', '.join(map(str, execution_jobs))}"
    elif isinstance(notes, list) and notes:
        status_info += f" | Note: {str(notes[-1])[:50]}"
    return status_info


def process_message(user_message, history, conversation_state):
    """
    Process user message and return response

    Sends the stored conversation plus *user_message* to ``sparrowAgent``,
    records the reply in *conversation_state* and *history*, and builds a
    one-line status summary.

    Args:
        user_message: The user's input message
        history: Gradio chat history (list of [user_msg, bot_msg] pairs);
            appended to in place. ``None`` is treated as an empty history.
        conversation_state: Dictionary containing conversation context (as
            built by ``initialize_conversation``); created when ``None``.

    Returns:
        Tuple of (empty string, updated history, updated conversation state, status message)

    Exceptions are not propagated: they are logged, shown as the bot reply,
    and reported in the status message. The stored conversation messages are
    left unchanged when the agent call fails.
    """
    # Gradio normally supplies a list; guard so the error path can always append.
    if history is None:
        history = []

    try:
        if not user_message or not user_message.strip():
            return "", history, conversation_state, "Please enter a message"

        if conversation_state is None:
            conversation_state = initialize_conversation()

        thread_id = conversation_state['thread_id']

        # Build the agent input from a copy of the stored messages so a failed
        # invocation does not leave an unanswered HumanMessage in the state.
        cleaned_messages = clean_messages_list(
            [*conversation_state['messages'], HumanMessage(content=user_message)]
        )

        sparrow_input = {
            'messages': cleaned_messages,
            'notes': conversation_state.get('notes', []),
            'query_brief': conversation_state.get('query_brief', ''),
            'final_message': conversation_state.get('final_message', '')
        }

        logger.info("[%s] Processing message: %s", thread_id, user_message[:100])
        logger.info("[%s] Input messages count: %s", thread_id, len(cleaned_messages))

        result = sparrowAgent.invoke(sparrow_input)

        response_message, ai_message = _extract_response(result)
        if not response_message:
            response_message = "I'm processing your request. Could you provide more details?"
            ai_message = AIMessage(content=response_message)

        # Prefer the agent's full transcript; otherwise append our own reply.
        if result.get('messages'):
            updated_messages = clean_messages_list(result['messages'])
        else:
            updated_messages = [*cleaned_messages, ai_message]

        conversation_state['messages'] = _drop_consecutive_duplicates(updated_messages)
        conversation_state['notes'] = result.get('notes', conversation_state.get('notes', []))
        conversation_state['query_brief'] = result.get('query_brief', conversation_state.get('query_brief', ''))
        conversation_state['final_message'] = result.get('final_message', conversation_state.get('final_message', ''))
        conversation_state['last_updated'] = datetime.now()

        # Build the status before touching history so a formatting failure
        # cannot leave both the reply and an error entry in the chat.
        status_info = _build_status(thread_id, len(conversation_state['messages']), result)

        history.append([user_message, response_message])

        logger.info("[%s] Response generated: %s", thread_id, response_message[:100])
        logger.info("[%s] Final messages count: %s", thread_id, len(conversation_state['messages']))

        return "", history, conversation_state, status_info

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the handler.
        logger.exception("Error processing message: %s", e)
        error_msg = f"An error occurred: {str(e)}"
        history.append([user_message, error_msg])
        return "", history, conversation_state, f"Error: {str(e)}"


def clear_conversation():
    """Reset the chat for the "New Conversation" button.

    Returns:
        Tuple of (empty chat history, freshly initialized conversation state,
        status text containing the first 8 characters of the new thread id).
    """
    fresh_state = initialize_conversation()
    full_id = fresh_state['thread_id']
    logger.info("[%s] New conversation started", full_id)
    short_id = full_id[:8]
    return [], fresh_state, f"New conversation started (ID: {short_id}...)"


def get_conversation_info(conversation_state):
    """Render a Markdown summary of *conversation_state* for the debug panel.

    Returns ``"No active conversation"`` when the state is ``None``. Otherwise
    each field is emitted as a bold label on its own paragraph. Raises
    ``KeyError`` if ``thread_id`` is missing.
    """
    if conversation_state is None:
        return "No active conversation"

    state = conversation_state
    fields = (
        ("Thread ID", state['thread_id']),
        ("Messages", len(state.get('messages', []))),
        ("Notes", len(state.get('notes', []))),
        ("Has Query Brief", bool(state.get('query_brief'))),
        ("Has Final Message", bool(state.get('final_message'))),
        ("Created", state.get('created_at', 'N/A')),
        ("Last Updated", state.get('last_updated', 'N/A')),
    )
    return "\n\n".join(f"**{label}:** {value}" for label, value in fields)


def _start_session():
    """Create a per-page-load conversation and its initial status line.

    Runs on every page load, so each browser session gets its own
    ``thread_id`` and ``created_at``. Without it, every session would start
    from the single state built once when this module is imported.

    Returns:
        Tuple of (new conversation state, "Ready" status text).
    """
    state = initialize_conversation()
    return state, f"Ready | Thread: {state['thread_id'][:8]}..."


# Create Gradio interface
with gr.Blocks(title="Sparrow Agent Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦜 Sparrow Agent Chat")
    gr.Markdown("Interact with the Sparrow AI Agent. Ask questions and get intelligent responses!")

    # State to store conversation context. The value built here is only a
    # placeholder; _start_session replaces it on each page load.
    conversation_state = gr.State(initialize_conversation())

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=500,
                show_copy_button=True
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Row():
                clear_btn = gr.Button("New Conversation", variant="secondary")
                status_box = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=1
                )

        with gr.Column(scale=1):
            gr.Markdown("### Debug Info")
            info_btn = gr.Button("Show Conversation Info")
            info_display = gr.Markdown("Click button to show info")

    # Event handlers: the Send button and pressing Enter both submit.
    submit_btn.click(
        fn=process_message,
        inputs=[msg, chatbot, conversation_state],
        outputs=[msg, chatbot, conversation_state, status_box]
    )

    msg.submit(
        fn=process_message,
        inputs=[msg, chatbot, conversation_state],
        outputs=[msg, chatbot, conversation_state, status_box]
    )

    clear_btn.click(
        fn=clear_conversation,
        inputs=[],
        outputs=[chatbot, conversation_state, status_box]
    )

    info_btn.click(
        fn=get_conversation_info,
        inputs=[conversation_state],
        outputs=[info_display]
    )

    # On load, give this session its own conversation and show its thread id.
    demo.load(
        fn=_start_session,
        inputs=[],
        outputs=[conversation_state, status_box]
    )


# Launch the app when run as a script: bind to all interfaces, take the port
# from the PORT environment variable (7860 if unset), and disable share links.
if __name__ == "__main__":
    launch_port = int(os.getenv('PORT', '7860'))
    demo.launch(server_name="0.0.0.0", server_port=launch_port, share=False)