subashpoudel commited on
Commit
96b0973
·
1 Parent(s): ca75c57

Updated the tool in final story generation

Browse files
Files changed (1) hide show
  1. my_agent/utils/utils.py +13 -11
my_agent/utils/utils.py CHANGED
@@ -1,6 +1,6 @@
1
 
2
  from langchain_core.messages import SystemMessage
3
- from .tools import StoryFormatter
4
  from .models_loader import llm
5
  import base64
6
  from PIL import Image
@@ -9,21 +9,23 @@ from fastapi import UploadFile
9
  from huggingface_hub import InferenceClient
10
  from .prompts import story_to_prompt , final_story_prompt
11
  import os
 
 
 
 
12
 
13
  def generate_final_story(final_state):
14
  if len(final_state['preferred_topics'])>0:
15
  template = final_story_prompt(final_state)
16
  messages = [SystemMessage(content=template)]
17
- response = llm.bind_tools([StoryFormatter]).invoke(messages)
18
- print('The final response is:',response)
19
- if hasattr(response, 'tool_calls') and response.tool_calls:
20
- response = response.tool_calls[0]['args']
21
- elif hasattr(response, 'content'):
22
- response = response.content
23
- else:
24
- response = "No response"
25
- # state.final_story.append(response)
26
- # state.stories.append(response)
27
  return response
28
 
29
  else:
 
1
 
2
  from langchain_core.messages import SystemMessage
3
+ from .tools import StoryFormatter , retrieve_tool
4
  from .models_loader import llm
5
  import base64
6
  from PIL import Image
 
9
  from huggingface_hub import InferenceClient
10
  from .prompts import story_to_prompt , final_story_prompt
11
  import os
12
+ from langgraph.prebuilt import create_react_agent
13
+
14
+
15
+
16
 
17
  def generate_final_story(final_state):
18
  if len(final_state['preferred_topics'])>0:
19
  template = final_story_prompt(final_state)
20
  messages = [SystemMessage(content=template)]
21
+
22
+ tools = [retrieve_tool]
23
+ react_agent=create_react_agent(
24
+ model=llm.bind_tools(tools),
25
+ tools=tools)
26
+
27
+ response = react_agent.invoke({'messages':messages})
28
+ response = response['messages'][-1].content
 
 
29
  return response
30
 
31
  else: