File size: 5,026 Bytes
543acf2
 
 
 
2fd8b66
 
30a788f
543acf2
035958b
543acf2
 
 
 
 
 
 
 
 
 
 
 
a665c25
543acf2
 
3c94672
543acf2
d66eb96
 
e54b757
 
 
2fd8b66
27465b5
543acf2
 
 
 
1985088
543acf2
 
 
 
 
 
 
 
 
3c94672
 
 
 
 
 
 
 
 
 
 
 
 
543acf2
3c94672
543acf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e687cd5
37e0497
543acf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d84f689
 
 
 
 
 
 
 
 
2af87db
65e30ff
d84f689
543acf2
 
 
 
 
 
 
30a788f
 
 
 
 
d84f689
 
 
543acf2
 
 
 
 
3195421
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
from langchain_community.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory, SimpleMemory
from langchain.agents import initialize_agent, AgentType
from langchain_community.callbacks import ClearMLCallbackHandler
from langchain_core.callbacks import StdOutCallbackHandler
from clearml import Logger, Task
from dotenv import load_dotenv
from dotenv import load_dotenv, find_dotenv
import os
import agent.planning_agent as planning_agent
import logging

# Logging: root logger at INFO, plus a logger named after this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared application state. Every slot starts out as None and is filled in
# by initialize_components().
llm = chat_memory = query_memory = clearml_callback = None

def initialize_components():
    """Create the LLM, memories and callbacks and hand them to the planning agent.

    Populates the module-level ``llm``, ``chat_memory``, ``query_memory`` and
    ``clearml_callback`` globals, then passes them (together with a stdout
    callback) to ``planning_agent.initialize_planning_agent``.

    The OpenAI key is read from the ``OPENAI_API_KEY`` environment variable;
    ``load_dotenv()`` runs first so a local ``.env`` file can supply it.

    :raises RuntimeError: if ``OPENAI_API_KEY`` is not set
    """
    global llm, chat_memory, query_memory, clearml_callback
    load_dotenv()

    # SECURITY: credentials must never be written into source code. Any key
    # that has ever been committed to this file should be considered leaked
    # and rotated at the provider.
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; define it in the environment or in a .env file"
        )

    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        temperature=0,
        openai_api_key=openai_api_key,
    )

    # Initialize memories
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    query_memory = SimpleMemory()

    # Setup and use the ClearML Callback
    # NOTE(review): ClearML credentials are presumably picked up from the
    # environment / clearml.conf -- confirm for the deployment target.
    clearml_callback = ClearMLCallbackHandler(
        task_type="inference",
        project_name="langchain_callback_demo",
        task_name="llm",
        tags=["test"],
        # Change the following parameters based on the amount of detail you want tracked
        visualize=True,
        complexity_metrics=True,
        stream_logs=True,
    )

    callbacks = [StdOutCallbackHandler(), clearml_callback]

    # Initialize planning agent with both memories
    planning_agent.initialize_planning_agent(llm, chat_memory, query_memory, callbacks)

    logger.info("Components initialized successfully")

def process_query(query, history):
    """Run one user query through the planning agent and return its reply.

    :param query: The text the user just submitted
    :param history: Previous turns as passed by the UI, iterated as
        ``(human_msg, ai_msg)`` pairs; may be empty or ``None``
    :return: The agent's response, or an ``"Error processing query: ..."``
        string if anything raised while handling the query
    """
    # The conversation buffer exposes its message log as ``.chat_memory``;
    # it is only usable once initialize_components() has run.
    message_log = None
    if chat_memory is not None and hasattr(chat_memory, 'chat_memory'):
        message_log = chat_memory.chat_memory

    try:
        # Restore chat history from Gradio's history.
        # ``history`` already holds every earlier turn, and this function
        # also records each turn below, so appending would store every
        # message again on each call. Rebuild the log instead.
        if history and message_log is not None:
            message_log.clear()
            for human_msg, ai_msg in history:
                # A turn may have a missing side (e.g. no reply yet); skip it
                if human_msg is not None:
                    message_log.add_user_message(human_msg)
                if ai_msg is not None:
                    message_log.add_ai_message(ai_msg)

        # Store original query in query memory
        query_memory.memories['original_query'] = query

        # Execute query through planning agent
        response = planning_agent.execute(query)

        # Add current interaction to chat memory
        if message_log is not None:
            message_log.add_user_message(query)
            message_log.add_ai_message(response)

        return response

    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        # logger.exception keeps the traceback, which the message alone loses
        logger.exception("Error details: %s", e)

        if message_log is not None:
            message_log.add_user_message(query)
            message_log.add_ai_message(error_msg)

        return error_msg

def clear_context():
    """Wipe the agent context and both memories.

    :return: A pair of fresh empty lists
        (NOTE(review): presumably the reset values for the UI components
        wired to this callback -- confirm against the interface module)
    """
    # Reset order is unchanged: agent first, then chat memory, then query memory
    planning_agent.clear_context()
    chat_memory.clear()
    stored_queries = query_memory.memories
    stored_queries.clear()

    first_empty, second_empty = [], []
    return first_empty, second_empty

def create_gradio_app():
    """Build the UI by passing this module's callbacks to ``create_interface``.

    :return: Whatever ``interface.create_interface`` returns for the
        ``process_query`` / ``clear_context`` callback pair
    """
    # Deferred import: ``interface`` is only loaded when the app is built
    import interface

    app = interface.create_interface(process_query, clear_context)
    return app

def report_table(loer, iteration=0, csv_path='./data/cleaned_dataset_full.csv'):
    """
    Report a CSV file as a table to the plots section.

    :param loer: The task logger (e.g. ``task.get_logger()``) used for sending the table
    :param iteration: The iteration number of the current report
    :param csv_path: Path of the CSV file to report; defaults to the
        cleaned capstone data set
    """
    # Report table - CSV from path
    loer.report_table("Data Set Capstone", "remote csv", iteration=iteration, csv=csv_path)

def main():
    """Main application entry point.

    Initializes the components, builds and launches the Gradio app, and
    afterwards reports to the ClearML task ``llm`` in project
    ``langchain_callback_demo``.

    :raises Exception: any error is logged with its traceback and re-raised
    """
    try:
        initialize_components()
        app = create_gradio_app()
        app.queue()
        # NOTE(review): launch() normally blocks until the server stops when
        # run as a script, so the reporting below would only happen at
        # shutdown -- confirm that this is the intended moment.
        app.launch(server_name="0.0.0.0", server_port=7860, share=True)

        a_task = Task.get_task(project_name='langchain_callback_demo', task_name='llm')
        if a_task is None:
            # Guard in case the lookup yields no task; nothing to report to
            logger.warning("ClearML task 'llm' not found; skipping reports")
            return
        loer = a_task.get_logger()

        # report_logs / report_debug_text are not defined or imported in the
        # visible part of this module; calling them directly raises
        # NameError. Run them only when they actually exist.
        for helper_name in ("report_logs", "report_debug_text"):
            helper = globals().get(helper_name)
            if callable(helper):
                helper(loer)
            else:
                logger.warning("Skipping %s: not defined in this module", helper_name)

        # report the data set table
        report_table(loer)
        loer.flush()
    except Exception as e:
        logger.exception("Error in main: %s", e)
        raise

# Start the application only when this file is run as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()