import logging
import os

import gradio as gr
from clearml import Logger, Task
from dotenv import find_dotenv, load_dotenv
from langchain.agents import AgentType, initialize_agent
from langchain.memory import ConversationBufferMemory, SimpleMemory
from langchain_community.callbacks import ClearMLCallbackHandler
from langchain_community.chat_models import ChatOpenAI
from langchain_core.callbacks import StdOutCallbackHandler

import agent.planning_agent as planning_agent

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ClearML project/task used both by the LangChain callback (which creates the
# task) and by main() (which looks the task up again to report extra artifacts).
CLEARML_PROJECT_NAME = "langchain_callback_demo"
CLEARML_TASK_NAME = "llm"

# Dataset uploaded to ClearML by report_table() unless another path is given.
DEFAULT_CSV_PATH = "./data/cleaned_dataset_full.csv"

# Global variables, populated by initialize_components()
llm = None
chat_memory = None
query_memory = None
clearml_callback = None


def initialize_components():
    """Create the LLM, memories and ClearML callback, then initialize the planning agent.

    Credentials come from the process environment, optionally populated from a
    ``.env`` file by ``load_dotenv()``. They must never be hard-coded here.

    Raises:
        RuntimeError: if ``OPENAI_API_KEY`` is not set.
    """
    global llm, chat_memory, query_memory, clearml_callback

    load_dotenv()

    # SECURITY: API keys (OpenAI, ClearML, SerpAPI) were previously committed in
    # this file. Treat them as compromised and revoke/rotate them. Supply keys
    # via the environment or a git-ignored .env file instead. ClearML/SerpAPI
    # credentials are presumably read from the environment by their own SDKs /
    # the agent tools — confirm.
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; define it in the environment or in a .env file"
        )

    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        temperature=0,
        openai_api_key=openai_api_key,
    )

    # Initialize memories
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    query_memory = SimpleMemory()

    # Setup and use the ClearML Callback
    clearml_callback = ClearMLCallbackHandler(
        task_type="inference",
        project_name=CLEARML_PROJECT_NAME,
        task_name=CLEARML_TASK_NAME,
        tags=["test"],
        # Change the following parameters based on the amount of detail you want tracked
        visualize=True,
        complexity_metrics=True,
        stream_logs=True,
    )
    callbacks = [StdOutCallbackHandler(), clearml_callback]

    # Initialize planning agent with both memories
    planning_agent.initialize_planning_agent(llm, chat_memory, query_memory, callbacks)
    logger.info("Components initialized successfully")


def _has_chat_memory():
    """Return True when the chat memory is initialized and exposes a message store."""
    return bool(chat_memory) and hasattr(chat_memory, "chat_memory")


def _record_turn(user_msg, ai_msg):
    """Append one user/assistant exchange to the chat memory, if available."""
    if _has_chat_memory():
        chat_memory.chat_memory.add_user_message(user_msg)
        chat_memory.chat_memory.add_ai_message(ai_msg)


def _sync_chat_memory(history):
    """Rebuild the chat memory so it mirrors the Gradio conversation history.

    The memory is cleared first. Otherwise every call would re-append the
    whole history on top of the turns already stored, duplicating messages.
    """
    if not history or not _has_chat_memory():
        return
    chat_memory.clear()
    # NOTE(review): assumes tuple-style Gradio history of (user, assistant)
    # pairs; if the interface uses type="messages" this unpacking must change.
    for human_msg, ai_msg in history:
        _record_turn(human_msg, ai_msg)


def process_query(query, history):
    """Run one user query through the planning agent.

    Args:
        query: The user's message.
        history: Previous conversation turns supplied by Gradio.

    Returns:
        The agent's response, or an ``"Error processing query: ..."`` string
        if anything failed. Errors are returned, not raised, so the chat UI
        keeps working; the failed exchange is still recorded in chat memory.
    """
    try:
        # Restore chat history from Gradio's history
        _sync_chat_memory(history)

        # Store original query in query memory
        query_memory.memories['original_query'] = query

        # Execute query through planning agent
        response = planning_agent.execute(query)

        # Add current interaction to chat memory
        _record_turn(query, response)
        return response

    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception("Error details: %s", e)
        _record_turn(query, error_msg)
        return error_msg


def clear_context():
    """Reset the planning agent, chat memory and query memory.

    Returns:
        Two empty lists, used by the Gradio interface to reset its components.
    """
    planning_agent.clear_context()
    chat_memory.clear()
    query_memory.memories.clear()
    return [], []


def create_gradio_app():
    """Build the Gradio interface wired to process_query and clear_context."""
    from interface import create_interface
    return create_interface(process_query, clear_context)


def report_table(loer: Logger, iteration: int = 0, csv_path: str = DEFAULT_CSV_PATH) -> None:
    """Report a CSV file as a table in the ClearML plots section.

    :param loer: The ClearML task logger used to send the table.
    :param iteration: The iteration number of the current report.
    :param csv_path: Path of the CSV file to upload as the table.
    """
    loer.report_table("Data Set Capstone", "remote csv", iteration=iteration, csv=csv_path)


def main():
    """Main application entry point."""
    try:
        initialize_components()
        app = create_gradio_app()
        app.queue()
        # launch() blocks the main thread while the server runs (share=True also
        # publishes a public share link). The ClearML reporting below therefore
        # only runs after the server has been shut down.
        app.launch(server_name="0.0.0.0", server_port=7860, share=True)

        a_task = Task.get_task(project_name=CLEARML_PROJECT_NAME, task_name=CLEARML_TASK_NAME)
        task_logger = a_task.get_logger()
        # NOTE: report_logs()/report_debug_text() used to be called here, but
        # they are not defined in this module and raised NameError. Re-add the
        # calls once those helpers actually exist.
        report_table(task_logger)
        task_logger.flush()
    except Exception as e:
        logger.error(f"Error in main: {str(e)}")
        raise


if __name__ == "__main__":
    main()