# NOTE: the following header is Hugging Face Spaces page residue from scraping
# (Space status: "Sleeping"); it is not part of the program.
import logging
import os

import gradio as gr
from clearml import Logger, Task
from dotenv import load_dotenv, find_dotenv
from langchain.agents import initialize_agent, AgentType
from langchain.memory import ConversationBufferMemory, SimpleMemory
from langchain_community.callbacks import ClearMLCallbackHandler
from langchain_community.chat_models import ChatOpenAI
from langchain_core.callbacks import StdOutCallbackHandler

import agent.planning_agent as planning_agent

# Configure logging once at import time; every module logger inherits this.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level singletons, populated by initialize_components() before the
# Gradio handlers below are ever invoked.
llm = None
chat_memory = None
query_memory = None
clearml_callback = None
def initialize_components():
    """Initialize the LLM, memories, ClearML callback, and planning agent.

    Populates the module-level globals (``llm``, ``chat_memory``,
    ``query_memory``, ``clearml_callback``) used by the Gradio handlers.

    Raises:
        RuntimeError: if OPENAI_API_KEY is not present in the environment.
    """
    global llm, chat_memory, query_memory, clearml_callback

    load_dotenv()
    # SECURITY: the original code embedded literal OpenAI/ClearML/SerpAPI keys
    # right here. Any key ever committed must be treated as leaked and rotated.
    # Credentials now come exclusively from the environment (.env).
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; add it to the environment or a .env file"
        )

    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        temperature=0,
        openai_api_key=openai_api_key,
    )

    # Conversation history shared with the agent (returned as message objects).
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    # Lightweight key/value store for per-request data such as the raw query.
    query_memory = SimpleMemory()

    # ClearML callback: logs LLM activity under project "langchain_callback_demo".
    clearml_callback = ClearMLCallbackHandler(
        task_type="inference",
        project_name="langchain_callback_demo",
        task_name="llm",
        tags=["test"],
        # Change the following parameters based on the amount of detail you
        # want tracked.
        visualize=True,
        complexity_metrics=True,
        stream_logs=True,
    )
    callbacks = [StdOutCallbackHandler(), clearml_callback]

    # Wire both memories and the callback chain into the planning agent.
    planning_agent.initialize_planning_agent(llm, chat_memory, query_memory, callbacks)
    logger.info("Components initialized successfully")
def process_query(query, history):
    """Handle one chat turn from the Gradio UI.

    Args:
        query: the user's message for this turn.
        history: Gradio chat history as a list of (human, ai) message pairs.

    Returns:
        The agent's response string, or an error message on failure (the
        error is also appended to chat memory so the transcript stays whole).
    """
    try:
        # Rebuild agent memory from Gradio's history — the source of truth.
        # BUGFIX: clear first. chat_memory is a persistent module-global, so
        # replaying the full history into it on every call (as the original
        # code did) duplicated every prior message once per turn.
        if chat_memory and hasattr(chat_memory, 'chat_memory'):
            chat_memory.clear()
            for human_msg, ai_msg in history or []:
                chat_memory.chat_memory.add_user_message(human_msg)
                chat_memory.chat_memory.add_ai_message(ai_msg)

        # Keep the untouched user query available to the agent's tools.
        query_memory.memories['original_query'] = query

        # Execute query through planning agent
        response = planning_agent.execute(query)

        # Record the current turn so the agent sees it on the next call.
        if chat_memory and hasattr(chat_memory, 'chat_memory'):
            chat_memory.chat_memory.add_user_message(query)
            chat_memory.chat_memory.add_ai_message(response)
        return response
    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        # logger.exception keeps the traceback, unlike the original logger.error.
        logger.exception("Error details: %s", e)
        if chat_memory and hasattr(chat_memory, 'chat_memory'):
            chat_memory.chat_memory.add_user_message(query)
            chat_memory.chat_memory.add_ai_message(error_msg)
        return error_msg
def clear_context():
    """Reset all conversation state.

    Returns:
        A pair of empty lists, used by Gradio to blank the chatbot and state.
    """
    planning_agent.clear_context()
    # Guard against being called before initialize_components() has run,
    # consistent with the None checks used in process_query().
    if chat_memory:
        chat_memory.clear()
    if query_memory:
        query_memory.memories.clear()
    return [], []
def create_gradio_app():
    """Build and return the Gradio app, wiring in the query and clear handlers."""
    # Imported lazily so the UI module is only loaded when the app is built.
    from interface import create_interface

    return create_interface(process_query, clear_context)
def report_table(loer, iteration=0):
    # type: (Logger, int) -> None
    """Report the capstone dataset CSV as a table in ClearML's plots section.

    :param loer: the ClearML task logger to send the table through
    :param iteration: the iteration number attached to the report
    """
    # Report by file path: ClearML reads the CSV itself and renders a table.
    csv_path = './data/cleaned_dataset_full.csv'
    loer.report_table("Data Set Capstone", "remote csv", iteration=iteration, csv=csv_path)
def main():
    """Main application entry point: launch the UI, then report to ClearML."""
    try:
        initialize_components()
        app = create_gradio_app()
        app.queue()
        # NOTE: launch() blocks while the server runs, so everything below
        # executes only after the Gradio server shuts down.
        app.launch(server_name="0.0.0.0", server_port=7860, share=True)

        # Attach to the ClearML task created by the callback handler and
        # report the dataset table. BUGFIX: the original code also called
        # report_logs() and report_debug_text(), which are not defined
        # anywhere in this file and raised NameError here.
        a_task = Task.get_task(project_name='langchain_callback_demo', task_name='llm')
        loer = a_task.get_logger()
        report_table(loer)
        loer.flush()
    except Exception as e:
        logger.error(f"Error in main: {str(e)}")
        raise


if __name__ == "__main__":
    main()