# Author: subashpoudel
# Last commit: b623e6c ("Updated commit")
# File size: 6.04 kB
from .state import State , ValidationFormatter , ImproverResponseFormatter
from .tools import retrieve_tool
from langgraph.prebuilt import create_react_agent
from utils.models_loader import ideator_llm, critic_llm , improver_llm , validator_llm
from langchain_core.messages import SystemMessage , HumanMessage
from .prompts import ideator_prompt , critic_prompt , improver_prompt , validator_prompt
# One prebuilt ReAct agent per role, each wrapping that role's chat model.
# The ideator and critic may call ``retrieve_tool``; the improver agent is
# built with an empty tool list. (NOTE(review): ``improver_agent`` is not
# referenced in this section of the file — confirm it is used elsewhere.)
ideator_agent = create_react_agent(model=ideator_llm, tools=[retrieve_tool])
critic_agent = create_react_agent(model=critic_llm, tools=[retrieve_tool])
improver_agent = create_react_agent(model=improver_llm, tools=[])
def ideator(state:State):
    """Generate story ideas for the latest topic and record them on the state.

    Tries the tool-enabled ReAct ``ideator_agent`` first. If that raises, it
    falls back to a plain ``ideator_llm`` call with the same messages. In both
    cases the resulting text is appended to ``state.ideator_response``.

    Args:
        state: Graph state. Reads ``state.topic[-1]`` (plus whatever
            ``ideator_prompt`` reads) and appends to ``state.ideator_response``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    template = ideator_prompt(state)
    messages = [SystemMessage(content=template),
                HumanMessage(content=f'''The topic of the video is:\n{state.topic[-1]}\n''')]
    # Keep the try body to the agent call only, so a failure elsewhere
    # cannot trigger a second, duplicate generation.
    try:
        agent_result = ideator_agent.invoke({'messages': messages})
        response = agent_result['messages'][-1].content
    except Exception as exc:
        # Best-effort fallback to the bare model. Catch Exception rather than
        # using a bare ``except:`` so Ctrl-C / SystemExit still propagate, and
        # report why the agent path failed instead of hiding it.
        print('Ideator agent failed, using backup LLM:', repr(exc))
        response = ideator_llm.invoke(messages).content
        print('Ideator backup Response:', response)
    else:
        print('Ideator Response:', response)
        print('Ideator Generated the story')
    state.ideator_response.append(response)
    return state
def critic(state:State):
    """Critique the current story and record the critique on the state.

    Tries the tool-enabled ReAct ``critic_agent`` first. If that raises, it
    falls back to a plain ``critic_llm`` call with the same messages. In both
    cases the resulting text is appended to ``state.critic_response``.

    Args:
        state: Graph state. Reads ``state.topic[-1]`` and
            ``state.business_details[-1]`` (plus whatever ``critic_prompt``
            reads) and appends to ``state.critic_response``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    template = critic_prompt(state)
    # NOTE(review): unlike the other nodes, this human message has a stray
    # "\n." and no colon after "business_details is". Left unchanged because
    # it is prompt text the model sees.
    messages = [SystemMessage(content=template),
                HumanMessage(content=f'''The topic of the video is:\n{state.topic[-1]}\n. The business_details is\n{state.business_details[-1]}\n''')]
    # Keep the try body to the agent call only, so a failure elsewhere
    # cannot trigger a second, duplicate critique.
    try:
        agent_result = critic_agent.invoke({'messages': messages})
        response = agent_result['messages'][-1].content
    except Exception as exc:
        # Best-effort fallback to the bare model. Catch Exception rather than
        # using a bare ``except:`` so Ctrl-C / SystemExit still propagate, and
        # report why the agent path failed instead of hiding it.
        print('Critic agent failed, using backup LLM:', repr(exc))
        response = critic_llm.invoke(messages).content
        print('Critic backup Response:', response)
    else:
        print('Critic Response:', response)
        print('Critic Evaluated the story')
    state.critic_response.append(response)
    return state
def improver(state:State):
    """Ask the improver model for four refined ideas plus the faults it found.

    The model is called with structured output (``ImproverResponseFormatter``).
    The four ``improved_idea*`` fields are collected into a list, and that
    list's ``str()`` is appended to ``state.improver_response``. The
    ``faults`` field is appended to ``state.critic_fault``.

    Args:
        state: Graph state. Reads ``state.topic[-1]`` and
            ``state.business_details[-1]`` (plus whatever ``improver_prompt``
            reads).

    Returns:
        The same ``state`` object, mutated in place.
    """
    system_text = improver_prompt(state)
    user_text = f'''The topic of the video is:\n{state.topic[-1]}\n The business_details is:\n{state.business_details[-1]}'''
    prompt_messages = [SystemMessage(content=system_text),
                       HumanMessage(content=user_text)]
    print('Improver Prompt:', prompt_messages)

    structured_model = improver_llm.with_structured_output(ImproverResponseFormatter)
    parsed = structured_model.invoke(prompt_messages)

    improved_ideas = [
        parsed.improved_idea1,
        parsed.improved_idea2,
        parsed.improved_idea3,
        parsed.improved_idea4,
    ]
    # Stored as the list's string form, matching what downstream reads.
    state.improver_response.append(str(improved_ideas))
    state.critic_fault.append(parsed.faults)
    print('Improver response:', improved_ideas)
    return state
def validator1(state:State):
    """Run validator 1 (``validator_llm``) and record its verdict.

    Appends the verdict's ``result`` to ``state.validator1_response``. If the
    result contains ``'not validated'``, also appends the verdict's ``reason``
    to ``state.disagreement_reason``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    system_text = validator_prompt(state)
    user_text = f'''The topic of the video is:\n{state.topic[-1]}\n The business_details is:\n{state.business_details[-1]}'''
    prompt_messages = [SystemMessage(content=system_text),
                       HumanMessage(content=user_text)]
    verdict = validator_llm.with_structured_output(ValidationFormatter).invoke(prompt_messages)
    print(f'Validator 1 response: {verdict}')

    state.validator1_response.append(verdict.result)
    print('The state check:', state.validator1_response[-1])
    if 'not validated' in verdict.result:
        state.disagreement_reason.append(verdict.reason)
    return state
def validator2(state:State):
    """Run validator 2 and record its verdict.

    Uses the shared ``validator_prompt`` but judges with ``ideator_llm``
    (presumably to get an opinion from a different model — confirm).
    Appends the verdict's ``result`` to ``state.validator2_response``. If the
    result contains ``'not validated'``, also appends the verdict's ``reason``
    to ``state.disagreement_reason``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    system_text = validator_prompt(state)
    user_text = f'''The topic of the video is:\n{state.topic[-1]}\n The business_details is:\n{state.business_details[-1]}'''
    prompt_messages = [SystemMessage(content=system_text),
                       HumanMessage(content=user_text)]
    verdict = ideator_llm.with_structured_output(ValidationFormatter).invoke(prompt_messages)
    print(f'Validator 2 response: {verdict}')

    state.validator2_response.append(verdict.result)
    print('The state check:', state.validator2_response[-1])
    if 'not validated' in verdict.result:
        state.disagreement_reason.append(verdict.reason)
    return state
def validator3(state:State):
    """Run validator 3 and record its verdict.

    Uses the shared ``validator_prompt`` but judges with ``critic_llm``
    (presumably to get an opinion from a different model — confirm).
    Appends the verdict's ``result`` to ``state.validator3_response``. If the
    result contains ``'not validated'``, also appends the verdict's ``reason``
    to ``state.disagreement_reason``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    template = validator_prompt(state)
    messages = [SystemMessage(content=template),
                HumanMessage(content=f'''The topic of the video is:\n{state.topic[-1]}\n The business_details is:\n{state.business_details[-1]}''')]
    response = critic_llm.with_structured_output(ValidationFormatter).invoke(messages)
    print(f'Validator 3 response: {response}')
    state.validator3_response.append(response.result)
    # Echo this validator's own latest entry. It previously read
    # validator1_response, which printed the wrong verdict and raised
    # IndexError when validator 1 had not produced one yet.
    print('The state check:', state.validator3_response[-1])
    if 'not validated' in response.result:
        state.disagreement_reason.append(response.reason)
    return state
def validator4(state:State):
    """Run validator 4 and record its verdict.

    Uses the shared ``validator_prompt`` but judges with ``improver_llm``
    (presumably to get an opinion from a different model — confirm).
    Appends the verdict's ``result`` to ``state.validator4_response``. If the
    result contains ``'not validated'``, also appends the verdict's ``reason``
    to ``state.disagreement_reason``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    template = validator_prompt(state)
    messages = [SystemMessage(content=template),
                HumanMessage(content=f'''The topic of the video is:\n{state.topic[-1]}\n The business_details is:\n{state.business_details[-1]}''')]
    response = improver_llm.with_structured_output(ValidationFormatter).invoke(messages)
    print(f'Validator 4 response: {response}')
    state.validator4_response.append(response.result)
    # Echo this validator's own latest entry. It previously read
    # validator1_response, which printed the wrong verdict and raised
    # IndexError when validator 1 had not produced one yet.
    print('The state check:', state.validator4_response[-1])
    if 'not validated' in response.result:
        state.disagreement_reason.append(response.reason)
    return state
def route1_after_validation(state:State):
    """Report whether validator 1's most recent verdict passed.

    Returns:
        ``False`` if ``state.validator1_response[-1]`` contains the substring
        ``'not validated'``, otherwise ``True``. Raises IndexError if the list
        is empty.
    """
    latest_verdict = state.validator1_response[-1]
    return 'not validated' not in latest_verdict
def route2_after_validation(state:State):
    """Report whether validator 2's most recent verdict passed.

    Returns:
        ``False`` if ``state.validator2_response[-1]`` contains the substring
        ``'not validated'``, otherwise ``True``. Raises IndexError if the list
        is empty.
    """
    latest_verdict = state.validator2_response[-1]
    return 'not validated' not in latest_verdict
def route3_after_validation(state:State):
    """Report whether validator 3's most recent verdict passed.

    Returns:
        ``False`` if ``state.validator3_response[-1]`` contains the substring
        ``'not validated'``, otherwise ``True``. Raises IndexError if the list
        is empty.
    """
    latest_verdict = state.validator3_response[-1]
    return 'not validated' not in latest_verdict
def route4_after_validation(state:State):
    """Report whether validator 4's most recent verdict passed.

    Returns:
        ``False`` if ``state.validator4_response[-1]`` contains the substring
        ``'not validated'``, otherwise ``True``. Raises IndexError if the list
        is empty.
    """
    latest_verdict = state.validator4_response[-1]
    return 'not validated' not in latest_verdict