subashpoudel's picture
Created Ideation agent
3a3fe92
raw
history blame
5.18 kB
from .state import State , ValidationFormatter
from .tools import retrieve_tool
from langgraph.prebuilt import create_react_agent
from utils.models_loader import ideator_llm, critic_llm , improver_llm , validator_llm
from langchain_core.messages import SystemMessage , HumanMessage
from .prompts import ideator_prompt , critic_prompt , improver_prompt , validator_prompt
def ideator(state: State):
    """Produce the ideator's story for the video topic using a ReAct agent.

    Builds a ReAct agent from ``ideator_llm`` with ``retrieve_tool`` available.
    The agent gets the system prompt returned by ``ideator_prompt(state)`` and a
    human message naming ``state.topic``. The content of the agent's last
    message is printed and stored on ``state.ideator_response``.

    Args:
        state: Shared graph state. Reads ``topic``; writes ``ideator_response``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    agent = create_react_agent(model=ideator_llm, tools=[retrieve_tool])
    system_text = ideator_prompt(state)
    conversation = [
        SystemMessage(content=system_text),
        HumanMessage(content=f"The topic of the video is:\n{state.topic}\n"),
    ]
    agent_output = agent.invoke({"messages": conversation})
    # The agent's final answer is the content of the last message in its transcript.
    final_text = agent_output["messages"][-1].content
    print("Ideator Response:", final_text)
    state.ideator_response = final_text
    print("Ideator Generated the story")
    return state
def critic(state: State):
    """Have a ReAct agent critique the current story.

    Builds a ReAct agent from ``critic_llm`` with ``retrieve_tool`` available.
    The agent gets the system prompt returned by ``critic_prompt(state)`` and a
    human message carrying ``state.topic`` and ``state.business_details``. The
    content of the agent's last message is printed and stored on
    ``state.critic_response``.

    Args:
        state: Shared graph state. Reads ``topic`` and ``business_details``;
            writes ``critic_response``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    agent = create_react_agent(model=critic_llm, tools=[retrieve_tool])
    system_text = critic_prompt(state)
    user_text = (
        f"The topic of the video is:\n{state.topic}\n"
        f". The business_details is\n{state.business_details}\n"
    )
    conversation = [SystemMessage(content=system_text), HumanMessage(content=user_text)]
    agent_output = agent.invoke({"messages": conversation})
    # The agent's final answer is the content of the last message in its transcript.
    final_text = agent_output["messages"][-1].content
    print("Critic Response:", final_text)
    state.critic_response = final_text
    print("Critic Evaluated the story")
    return state
def improver(state: State):
    """Have a tool-less ReAct agent improve the story.

    Builds a ReAct agent from ``improver_llm`` with an empty tool list. The
    agent gets the system prompt returned by ``improver_prompt(state)`` and a
    human message carrying ``state.topic`` and ``state.business_details``. The
    content of the agent's last message is printed and stored on
    ``state.improver_response``.

    Args:
        state: Shared graph state. Reads ``topic`` and ``business_details``;
            writes ``improver_response``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    agent = create_react_agent(model=improver_llm, tools=[])
    system_text = improver_prompt(state)
    user_text = (
        f"The topic of the video is:\n{state.topic}\n"
        f" The business_details is:\n{state.business_details}"
    )
    agent_output = agent.invoke(
        {"messages": [SystemMessage(content=system_text), HumanMessage(content=user_text)]}
    )
    # The agent's final answer is the content of the last message in its transcript.
    final_text = agent_output["messages"][-1].content
    print("Improver Response:", final_text)
    state.improver_response = final_text
    print("Improver Improved the story")
    return state
def validator1(state: State):
    """Run the first validation pass using ``validator_llm``.

    Sends the system prompt from ``validator_prompt(state)`` and a message
    carrying ``state.topic`` and ``state.business_details`` to ``validator_llm``.
    The model is asked for a ``ValidationFormatter`` structured output. Its
    ``result`` is stored on ``state.validator1_response``. When that result
    contains the substring ``'not validated'``, its ``reason`` is stored on
    ``state.disagreement_reason``.

    Args:
        state: Shared graph state, mutated in place.

    Returns:
        The same ``state`` object.
    """
    system_text = validator_prompt(state)
    user_text = (
        f"The topic of the video is:\n{state.topic}\n"
        f" The business_details is:\n{state.business_details}"
    )
    structured_llm = validator_llm.with_structured_output(ValidationFormatter)
    verdict = structured_llm.invoke(
        [SystemMessage(content=system_text), HumanMessage(content=user_text)]
    )
    print(f"Validator 1 response: {verdict}")
    state.validator1_response = verdict.result
    print("The state check:", state.validator1_response)
    # Substring match: any result containing 'not validated' counts as a rejection.
    if "not validated" in verdict.result:
        state.disagreement_reason = verdict.reason
    return state
def validator2(state: State):
    """Run the second validation pass using ``ideator_llm``.

    Sends the system prompt from ``validator_prompt(state)`` and a message
    carrying ``state.topic`` and ``state.business_details`` to ``ideator_llm``.
    The model is asked for a ``ValidationFormatter`` structured output. Its
    ``result`` is stored on ``state.validator2_response``. When that result
    contains the substring ``'not validated'``, its ``reason`` is stored on
    ``state.disagreement_reason``.

    Args:
        state: Shared graph state, mutated in place.

    Returns:
        The same ``state`` object.
    """
    template = validator_prompt(state)
    messages = [
        SystemMessage(content=template),
        HumanMessage(
            content=f'''The topic of the video is:\n{state.topic}\n The business_details is:\n{state.business_details}'''
        ),
    ]
    response = ideator_llm.with_structured_output(ValidationFormatter).invoke(messages)
    print(f'Validator 2 response: {response}')
    state.validator2_response = response.result
    # Report this validator's own stored result, not validator 1's.
    print('The state check:', state.validator2_response)
    # Substring match: any result containing 'not validated' counts as a rejection.
    if 'not validated' in response.result:
        state.disagreement_reason = response.reason
    return state
def validator3(state: State):
    """Run the third validation pass using ``critic_llm``.

    Sends the system prompt from ``validator_prompt(state)`` and a message
    carrying ``state.topic`` and ``state.business_details`` to ``critic_llm``.
    The model is asked for a ``ValidationFormatter`` structured output. Its
    ``result`` is stored on ``state.validator3_response``. When that result
    contains the substring ``'not validated'``, its ``reason`` is stored on
    ``state.disagreement_reason``.

    Args:
        state: Shared graph state, mutated in place.

    Returns:
        The same ``state`` object.
    """
    template = validator_prompt(state)
    messages = [
        SystemMessage(content=template),
        HumanMessage(
            content=f'''The topic of the video is:\n{state.topic}\n The business_details is:\n{state.business_details}'''
        ),
    ]
    response = critic_llm.with_structured_output(ValidationFormatter).invoke(messages)
    print(f'Validator 3 response: {response}')
    state.validator3_response = response.result
    # Report this validator's own stored result, not validator 1's.
    print('The state check:', state.validator3_response)
    # Substring match: any result containing 'not validated' counts as a rejection.
    if 'not validated' in response.result:
        state.disagreement_reason = response.reason
    return state
def validator4(state: State):
    """Run the fourth validation pass using ``improver_llm``.

    Sends the system prompt from ``validator_prompt(state)`` and a message
    carrying ``state.topic`` and ``state.business_details`` to ``improver_llm``.
    The model is asked for a ``ValidationFormatter`` structured output. Its
    ``result`` is stored on ``state.validator4_response``.

    When the result contains the substring ``'not validated'``, this validator
    also clears ``state.topic`` (sets it to ``None``). It then stores the
    ``reason`` on ``state.disagreement_reason``.

    Args:
        state: Shared graph state, mutated in place.

    Returns:
        The same ``state`` object.
    """
    template = validator_prompt(state)
    messages = [
        SystemMessage(content=template),
        HumanMessage(
            content=f'''The topic of the video is:\n{state.topic}\n The business_details is:\n{state.business_details}'''
        ),
    ]
    response = improver_llm.with_structured_output(ValidationFormatter).invoke(messages)
    print(f'Validator 4 response: {response}')
    state.validator4_response = response.result
    # Report this validator's own stored result, not validator 1's.
    print('The state check:', state.validator4_response)
    # Substring match: any result containing 'not validated' counts as a rejection.
    if 'not validated' in response.result:
        # NOTE(review): only validator4 clears the topic on rejection. Confirm this
        # asymmetry with validators 1-3 is intentional. Later nodes that format
        # state.topic into prompts would then see "None".
        state.topic = None
        state.disagreement_reason = response.reason
    return state
def route1_after_validation(state: State):
    """Routing predicate after validator 1.

    Returns ``False`` when ``state.validator1_response`` contains the substring
    ``'not validated'``, and ``True`` otherwise.
    """
    verdict = state.validator1_response
    return 'not validated' not in verdict
def route2_after_validation(state: State):
    """Routing predicate after validator 2.

    Returns ``False`` when ``state.validator2_response`` contains the substring
    ``'not validated'``, and ``True`` otherwise.
    """
    verdict = state.validator2_response
    return 'not validated' not in verdict
def route3_after_validation(state: State):
    """Routing predicate after validator 3.

    Returns ``False`` when ``state.validator3_response`` contains the substring
    ``'not validated'``, and ``True`` otherwise.
    """
    verdict = state.validator3_response
    return 'not validated' not in verdict
def route4_after_validation(state: State):
    """Routing predicate after validator 4.

    Returns ``False`` when ``state.validator4_response`` contains the substring
    ``'not validated'``, and ``True`` otherwise.
    """
    verdict = state.validator4_response
    return 'not validated' not in verdict