subashpoudel's picture
Included CI CD
583f6dd
raw
history blame
2.31 kB
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.checkpoint.memory import MemorySaver
from .utils.nodes import tool_return_node, extract_user_reference_node
from src.genai.utils.models_loader import llm_gpt
from .utils.state import ValidationFormatter
from .utils.utils import caption_image , extract_latest_response_block
from .utils.tools import retrieve_data_for_orchestration
import re
from langchain_core.messages import SystemMessage
# Module-level LangGraph checkpointer passed to every graph compiled by
# orchestration_graph(). Because it is created once at import time, the same
# checkpointer instance is shared by every call in this process.
memory = MemorySaver()
def orchestration_graph():
    """Compile the orchestration workflow START -> chatbot1 -> chatbot2 -> END.

    ``chatbot1`` runs ``tool_return_node`` and ``chatbot2`` runs
    ``extract_user_reference_node``, both over a ``MessagesState``. The graph
    is compiled with the module-level ``memory`` checkpointer, so invocations
    that pass the same ``thread_id`` share checkpointed state.

    Returns:
        The compiled LangGraph graph, ready for ``invoke``.
    """
    builder = StateGraph(MessagesState)

    # Register the processing steps in the order they execute.
    pipeline = (
        ("chatbot1", tool_return_node),
        ("chatbot2", extract_user_reference_node),
    )
    for node_name, node_fn in pipeline:
        builder.add_node(node_name, node_fn)

    # Wire a strictly linear flow through the two steps.
    builder.add_edge(START, "chatbot1")
    builder.add_edge("chatbot1", "chatbot2")
    builder.add_edge("chatbot2", END)

    return builder.compile(checkpointer=memory)
# Process-wide record of recent user turns, read and trimmed by
# orchestration_chat() to build the retrieval query. Each entry is a dict with
# 'role' ('human' or 'image_caption') and 'content'. It is module-level, so it
# is shared by every caller in this process (not per user/session).
user_input_history = []
def orchestration_chat(user_input: str, image_base64=None):
    """Run one orchestration turn and return the structured validation result.

    The user's message is recorded in the module-level ``user_input_history``.
    If images are supplied, a caption produced by ``caption_image`` is recorded
    too. The human/caption entries in that history are joined into a query for
    ``retrieve_data_for_orchestration``. The retrieved data, the caption and
    the user message are then sent through the compiled orchestration graph.
    Finally, the latest response block of the graph output is converted into a
    ``ValidationFormatter`` via ``llm_gpt.with_structured_output``.

    Args:
        user_input: The user's text message for this turn.
        image_base64: Optional image payload forwarded to ``caption_image``.
            Presumably a list of base64-encoded images; confirm against
            ``caption_image``. ``None`` (the default) or an empty sequence
            means "no image". ``None`` replaces the former shared mutable
            ``[]`` default.

    Returns:
        The result of invoking ``llm_gpt`` with structured output
        ``ValidationFormatter`` on the latest response block.
    """
    user_input_history.append({'role': 'human', 'content': user_input})

    if image_base64:
        caption_response = caption_image(image_base64, user_input)
        user_input_history.append({'role': 'image_caption', 'content': caption_response})
        print('Caption Response:', caption_response)
    else:
        caption_response = ''

    # Once the history grows past 4 entries, keep only the 2 most recent.
    # Trim in place (instead of rebinding the global) so any other module
    # holding a reference to this list sees the same, trimmed object.
    # NOTE(review): this yields a sawtooth window of 2..4 entries; confirm it
    # is intended rather than a fixed-size sliding window.
    if len(user_input_history) > 4:
        del user_input_history[:-2]
    print('Length of history', len(user_input_history))

    query_for_retrieval = ' '.join(
        msg['content']
        for msg in user_input_history
        if msg['role'] in ('human', 'image_caption')
    )
    influencers_data = retrieve_data_for_orchestration(query_for_retrieval)

    # NOTE(review): the graph is recompiled on every call, and every caller
    # uses the same thread_id with the shared module-level checkpointer. So
    # checkpointed conversation state is presumably shared process-wide
    # rather than per user. Confirm that this is intended.
    agent = orchestration_graph()
    config = {"configurable": {"thread_id": "orchestration-thread"}}
    graph_messages = agent.invoke(
        {"messages": [
            {'role': 'human', 'content': user_input},
            {'role': 'function', 'name': 'data_of_influencers', 'content': influencers_data},
            {'role': 'function', 'name': 'information_of_image', 'content': caption_response},
        ]},
        config,
    )['messages']
    print('Orchestrator Response', graph_messages)

    latest_block = extract_latest_response_block(graph_messages)
    return llm_gpt.with_structured_output(ValidationFormatter).invoke(latest_block)