Spaces:
Sleeping
Sleeping
File size: 5,026 Bytes
543acf2 2fd8b66 30a788f 543acf2 035958b 543acf2 a665c25 543acf2 3c94672 543acf2 d66eb96 e54b757 2fd8b66 27465b5 543acf2 1985088 543acf2 3c94672 543acf2 3c94672 543acf2 e687cd5 37e0497 543acf2 d84f689 2af87db 65e30ff d84f689 543acf2 30a788f d84f689 543acf2 3195421 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | import gradio as gr
from langchain_community.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory, SimpleMemory
from langchain.agents import initialize_agent, AgentType
from langchain_community.callbacks import ClearMLCallbackHandler
from langchain_core.callbacks import StdOutCallbackHandler
from clearml import Logger, Task
from dotenv import load_dotenv
from dotenv import load_dotenv, find_dotenv
import os
import agent.planning_agent as planning_agent
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables
llm = None
chat_memory = None
query_memory = None
clearml_callback = None
def initialize_components():
global llm, chat_memory, query_memory, clearml_callback
load_dotenv()
OPENAI_API_KEY="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"
#CLEARML_API_ACCESS_KEY="NYZZ07E2ZEW08V4DUGYY2PA7O6JX5F"
#CLEARML_API_SECRET_KEY="MkfQrIOuKNFRWHfCz32cN-UVm_19M7_vgxAwRn8twnvHYJ1xeqD9T2GZcIX9RwnD8mw"
#SERPAPI_API_KEY="619f2302253fbe56448bcf82565caf2a3263d845944682533f10b09a0d1650e6"
# openai.api_key=OPENAI_API_KEY
llm = ChatOpenAI(
model_name="gpt-3.5-turbo",
temperature=0,
openai_api_key="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"
)
# Initialize memories
chat_memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True
)
query_memory = SimpleMemory()
# Setup and use the ClearML Callback
clearml_callback = ClearMLCallbackHandler(
task_type="inference",
project_name="langchain_callback_demo",
task_name="llm",
tags=["test"],
# Change the following parameters based on the amount of detail you want tracked
visualize=True,
complexity_metrics=True,
stream_logs=True,)
callbacks = [StdOutCallbackHandler(), clearml_callback]
# Initialize planning agent with both memories
planning_agent.initialize_planning_agent(llm, chat_memory, query_memory, callbacks)
logger.info("Components initialized successfully")
def process_query(query, history):
try:
# Restore chat history from Gradio's history
if history:
for human_msg, ai_msg in history:
if chat_memory and hasattr(chat_memory, 'chat_memory'):
chat_memory.chat_memory.add_user_message(human_msg)
chat_memory.chat_memory.add_ai_message(ai_msg)
# Store original query in query memory
query_memory.memories['original_query'] = query
# Execute query through planning agent
response = planning_agent.execute(query)
#clearml_callback.flush_tracker(name="Planning agent", finish=True)
#clearml_callback.flush_tracker(langchain_asset=llm, name="simple_sequential")
# Add current interaction to chat memory
if chat_memory and hasattr(chat_memory, 'chat_memory'):
chat_memory.chat_memory.add_user_message(query)
chat_memory.chat_memory.add_ai_message(response)
return response
except Exception as e:
error_msg = f"Error processing query: {str(e)}"
logger.error(f"Error details: {str(e)}")
if chat_memory and hasattr(chat_memory, 'chat_memory'):
chat_memory.chat_memory.add_user_message(query)
chat_memory.chat_memory.add_ai_message(error_msg)
return error_msg
def clear_context():
planning_agent.clear_context()
chat_memory.clear()
query_memory.memories.clear()
return [], []
def create_gradio_app():
from interface import create_interface
return create_interface(process_query, clear_context)
def report_table(loer, iteration=0):
# type: (Logger, int) -> ()
"""
reporting tables to the plots section
:param logger: The task.logger to use for sending the plots
:param iteration: The iteration number of the current reports
"""
# report table
# Report table - CSV from path
csv_path = './data/cleaned_dataset_full.csv'
loer.report_table("Data Set Capstone", "remote csv", iteration=iteration, csv=csv_path)
def main():
"""Main application entry point"""
try:
initialize_components()
app = create_gradio_app()
app.queue()
app.launch(server_name="0.0.0.0", server_port=7860, share=True)
a_task = Task.get_task(project_name='langchain_callback_demo', task_name='llm')
loer = a_task.get_logger()
report_logs(loer)
# report text as debug example
report_debug_text(loer)
# report graphs
report_table(loer)
loer.flush()
except Exception as e:
logger.error(f"Error in main: {str(e)}")
raise
if __name__ == "__main__":
main() |