# ClearMLIISC / app.py
# Last updated by sahanacp: "Update app.py" (commit e54b757, verified)
import gradio as gr
from langchain_community.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory, SimpleMemory
from langchain.agents import initialize_agent, AgentType
from langchain_community.callbacks import ClearMLCallbackHandler
from langchain_core.callbacks import StdOutCallbackHandler
from clearml import Logger, Task
from dotenv import load_dotenv
from dotenv import load_dotenv, find_dotenv
import os
import agent.planning_agent as planning_agent
import logging
# Application-wide logging: INFO and above, via a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared components; all start unset and are populated by initialize_components().
llm = chat_memory = query_memory = clearml_callback = None
def initialize_components():
    """Create the LLM, memories and ClearML callback, then set up the planning agent.

    Populates the module globals ``llm``, ``chat_memory``, ``query_memory`` and
    ``clearml_callback``. It then passes them, together with a stdout callback,
    to ``planning_agent.initialize_planning_agent``.

    The OpenAI key is taken from the ``OPENAI_API_KEY`` environment variable.
    A ``.env`` file, if present, is loaded first via ``load_dotenv()``.

    :raises RuntimeError: if ``OPENAI_API_KEY`` is not set or is empty.
    """
    global llm, chat_memory, query_memory, clearml_callback
    load_dotenv()
    # Credentials must never be hard-coded. The keys that were previously
    # embedded in this file were exposed in source control and should be
    # revoked/rotated.
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise RuntimeError(
            "OPENAI_API_KEY is not set; define it in the environment or a .env file"
        )
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        temperature=0,
        openai_api_key=openai_api_key,
    )
    # Initialize memories: the conversation buffer (exposed to the agent as
    # "chat_history" messages) and a simple key/value memory.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    query_memory = SimpleMemory()
    # Set up the ClearML callback that logs LangChain runs to a ClearML task
    clearml_callback = ClearMLCallbackHandler(
        task_type="inference",
        project_name="langchain_callback_demo",
        task_name="llm",
        tags=["test"],
        # Change the following parameters based on the amount of detail you want tracked
        visualize=True,
        complexity_metrics=True,
        stream_logs=True,
    )
    callbacks = [StdOutCallbackHandler(), clearml_callback]
    # Initialize planning agent with both memories
    planning_agent.initialize_planning_agent(llm, chat_memory, query_memory, callbacks)
    logger.info("Components initialized successfully")
def process_query(query, history):
    """Run one chat turn through the planning agent.

    On every call the LangChain chat memory is rebuilt from the Gradio
    ``history``. After the agent responds, the current exchange is appended,
    so the memory mirrors the conversation shown in the UI.

    :param query: The user's message for this turn.
    :param history: Previous turns from the Gradio interface, iterated as
        ``(user_message, ai_message)`` pairs. This presumably matches Gradio's
        tuple-style chat history; confirm against ``interface.create_interface``.
    :return: The agent's response. If any step raises, returns an
        ``"Error processing query: ..."`` string instead and records that
        string in chat memory.
    """
    has_memory = bool(chat_memory) and hasattr(chat_memory, 'chat_memory')
    try:
        if has_memory:
            # Rebuild memory from the UI history. Without this clear, the full
            # history was appended on top of the turns already stored at the end
            # of each call, so every request duplicated all earlier messages.
            chat_memory.clear()
            for human_msg, ai_msg in history or []:
                chat_memory.chat_memory.add_user_message(human_msg)
                if ai_msg is not None:
                    # Skip empty (None) replies rather than storing them as AI messages.
                    chat_memory.chat_memory.add_ai_message(ai_msg)
        # Store original query in query memory
        query_memory.memories['original_query'] = query
        # Execute query through planning agent
        response = planning_agent.execute(query)
        # Add current interaction to chat memory
        if has_memory:
            chat_memory.chat_memory.add_user_message(query)
            chat_memory.chat_memory.add_ai_message(response)
        return response
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app,
        # but keep the full traceback in the logs.
        error_msg = f"Error processing query: {e}"
        logger.exception("Error details: %s", e)
        if has_memory:
            chat_memory.chat_memory.add_user_message(query)
            chat_memory.chat_memory.add_ai_message(error_msg)
        return error_msg
def clear_context():
    """Wipe all conversational state and return empty values for the UI.

    Resets the planning agent, empties the chat buffer memory and removes every
    entry from the query memory.

    :return: A pair of empty lists, presumably used by the interface to reset
        its chat outputs (confirm against ``interface.create_interface``).
    """
    planning_agent.clear_context()
    chat_memory.clear()
    query_memory.memories.clear()
    cleared_chat, cleared_state = [], []
    return cleared_chat, cleared_state
def create_gradio_app():
    """Build the Gradio UI, wiring in the query handler and the reset handler.

    :return: Whatever ``interface.create_interface`` returns; the caller
        queues and launches it.
    """
    # Imported lazily so the UI module is only loaded when the app is built.
    from interface import create_interface as build_interface

    handlers = (process_query, clear_context)
    return build_interface(*handlers)
def report_table(
    loer: "Logger",
    iteration: int = 0,
    csv_path: str = './data/cleaned_dataset_full.csv',
) -> None:
    """Report a CSV file as a table in the ClearML plots section.

    The table is reported under title "Data Set Capstone", series "remote csv".

    :param loer: The ClearML task logger (e.g. ``task.get_logger()``) used to
        send the table.
    :param iteration: The iteration number attached to the report.
    :param csv_path: Path of the CSV file to report. Defaults to the capstone
        dataset; a relative path is resolved against the current working
        directory.
    """
    loer.report_table("Data Set Capstone", "remote csv", iteration=iteration, csv=csv_path)
def main():
    """Main application entry point.

    Steps:

    1. Initialize the LangChain/ClearML components.
    2. Report the dataset table to the ClearML task
       ``langchain_callback_demo/llm``.
    3. Build, queue and launch the Gradio app on 0.0.0.0:7860.

    Any exception is logged with its traceback and re-raised.
    """
    try:
        initialize_components()
        # Report to ClearML *before* launch(): launch() blocks the main thread
        # while the server runs, so anything placed after it only ran at shutdown.
        # (report_logs/report_debug_text were called here previously but are not
        # defined in this module, which raised NameError; they have been removed.)
        a_task = Task.get_task(project_name='langchain_callback_demo', task_name='llm')
        task_logger = a_task.get_logger()
        # report graphs
        report_table(task_logger)
        task_logger.flush()

        app = create_gradio_app()
        app.queue()
        app.launch(server_name="0.0.0.0", server_port=7860, share=True)
    except Exception as e:
        logger.exception("Error in main: %s", e)
        raise


if __name__ == "__main__":
    main()