subashpoudel committed on
Commit
a9f99c3
·
1 Parent(s): b55b8d4

Updated memory management and API routing

Browse files
__pycache__/main.cpython-312.pyc CHANGED
Binary files a/__pycache__/main.cpython-312.pyc and b/__pycache__/main.cpython-312.pyc differ
 
main.py CHANGED
@@ -1,25 +1,25 @@
1
- from fastapi import FastAPI
 
2
  from pydantic import BaseModel
3
  from my_agent.agent import build_graph
4
  import pandas as pd
5
- from typing import Optional
6
  from my_agent.utils.initial_interaction import BusinessInteractionChatbot
 
 
 
 
7
 
8
  app = FastAPI()
9
  interaction_chatbot = BusinessInteractionChatbot()
10
  graph = build_graph()
11
 
12
 
13
- class RequestInput(BaseModel):
14
- query: list
15
- preferred_topics: Optional[list] = []
16
-
17
 
18
 
19
 
20
  class UserMessage(BaseModel):
21
  message: str
22
-
23
  details_for_brainstrom = {}
24
  @app.post("/business-interaction")
25
  def business_chat(msg: UserMessage):
@@ -32,14 +32,62 @@ def business_chat(msg: UserMessage):
32
  return {"response": response, "complete": False}
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @app.post("/brainstrom")
36
- def run_graph(input_data: RequestInput):
37
- # business_details = details_for_brainstrom
38
- result = graph.invoke({'topic' : input_data.query , 'business_details': details_for_brainstrom})
39
- # RequestInput.preferred_topics=result['preferred_topics']
40
- return {'final_story': result['final_story'],
41
- 'business_details':result['business_details'],
42
- }
43
-
 
 
 
44
 
 
 
 
 
 
 
 
45
 
 
 
 
 
1
+ from fastapi import FastAPI , UploadFile , File , Form
2
+ from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
4
  from my_agent.agent import build_graph
5
  import pandas as pd
6
+ from typing import Optional , List
7
  from my_agent.utils.initial_interaction import BusinessInteractionChatbot
8
+ import base64
9
+ from PIL import Image
10
+ from io import BytesIO
11
+ import json
12
 
13
  app = FastAPI()
14
  interaction_chatbot = BusinessInteractionChatbot()
15
  graph = build_graph()
16
 
17
 
 
 
 
 
18
 
19
 
20
 
21
  class UserMessage(BaseModel):
22
  message: str
 
23
  details_for_brainstrom = {}
24
  @app.post("/business-interaction")
25
  def business_chat(msg: UserMessage):
 
32
  return {"response": response, "complete": False}
33
 
34
 
35
+
36
+ # class RequestInput(BaseModel):
37
+ # query: list
38
+ # preferred_topics: Optional[list] = []
39
+ # images: Optional[list[str]] = [] # base64-encoded image strings
40
+
41
+ # @app.post("/brainstrom")
42
+ # def run_graph(input_data: RequestInput):
43
+ # image_objects = []
44
+ # for img_b64 in input_data.images:
45
+ # image_objects.append(process_image(img_b64)) # decode and load images
46
+
47
+ # result = graph.invoke({
48
+ # 'topic': input_data.query,
49
+ # 'images': image_objects,
50
+ # 'business_details': details_for_brainstrom
51
+ # })
52
+
53
+ # return {
54
+ # 'final_story': result['final_story'],
55
+ # 'business_details': result['business_details'],
56
+ # }
57
+
58
+
59
+
60
+
61
+ # Convert uploaded image to base64 string
62
+ def encode_image_to_base64(uploaded_file: UploadFile) -> str:
63
+ return base64.b64encode(uploaded_file.file.read()).decode("utf-8")
64
+
65
+ # Convert base64 string to PIL image (optional for LangGraph processing)
66
+ def process_image(base64_str: str) -> Image.Image:
67
+ image_data = base64.b64decode(base64_str)
68
+ return Image.open(BytesIO(image_data))
69
+
70
  @app.post("/brainstrom")
71
+ async def run_graph(
72
+ query: List[str], # sent as JSON body
73
+ preferred_topics: Optional[list] = [],
74
+ images: Optional[List[UploadFile]] = [], # ✅ Optional UploadFile list
75
+ thread_id: Optional[str] = "default-session"
76
+ ):
77
+ # Convert uploaded images to base64
78
+ image_base64_list = [encode_image_to_base64(img) for img in images]
79
+
80
+ # Convert base64 to image objects (if LangGraph expects PIL.Image)
81
+ image_objects = [process_image(img_b64) for img_b64 in image_base64_list]
82
 
83
+ # Invoke LangGraph
84
+ result = graph.invoke({
85
+ 'topic': query,
86
+ 'images': image_base64_list,
87
+ 'latest_preferred_topics':preferred_topics
88
+ },
89
+ config={"configurable": {"thread_id": thread_id}})
90
 
91
+ return {
92
+ 'response': result,
93
+ }
my_agent/__pycache__/agent.cpython-312.pyc CHANGED
Binary files a/my_agent/__pycache__/agent.cpython-312.pyc and b/my_agent/__pycache__/agent.cpython-312.pyc differ
 
my_agent/agent.py CHANGED
@@ -1,9 +1,13 @@
1
  from langgraph.graph import StateGraph, START, END
2
  from .utils.state import State
3
- from .utils.nodes import retrieve, generate_story, generate_brainstroming , generate_final_story, route_after_selection, select_preferred_topics
 
4
 
5
- def build_graph():
 
 
6
  builder = StateGraph(State)
 
7
  builder.add_node(retrieve)
8
  builder.add_node(generate_story)
9
  builder.add_node(generate_brainstroming)
@@ -12,8 +16,13 @@ def build_graph():
12
 
13
 
14
  # Normal edges
 
 
15
  builder.add_edge(START, "retrieve")
16
  builder.add_edge("retrieve", "generate_story")
 
 
 
17
  builder.add_edge("generate_story", "generate_brainstroming")
18
  builder.add_edge("generate_brainstroming", "select_preferred_topics")
19
 
@@ -21,4 +30,33 @@ def build_graph():
21
  builder.add_conditional_edges("select_preferred_topics", route_after_selection,{True:'retrieve',False:'generate_final_story'})
22
  builder.add_edge("generate_final_story",END)
23
 
24
- return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langgraph.graph import StateGraph, START, END
2
  from .utils.state import State
3
+ from .utils.nodes import retrieve, generate_story, generate_brainstroming , generate_final_story, route_after_selection, select_preferred_topics,caption_image
4
+ from langgraph.checkpoint.memory import MemorySaver
5
 
6
+ memory = MemorySaver()
7
+
8
+ def build_graph_old():
9
  builder = StateGraph(State)
10
+ # builder.add_node(caption_image)
11
  builder.add_node(retrieve)
12
  builder.add_node(generate_story)
13
  builder.add_node(generate_brainstroming)
 
16
 
17
 
18
  # Normal edges
19
+ # builder.add_edge(START, "caption_image")
20
+
21
  builder.add_edge(START, "retrieve")
22
  builder.add_edge("retrieve", "generate_story")
23
+
24
+ # builder.add_edge("caption_image", "retrieve")
25
+ # builder.add_edge("retrieve", "generate_story")
26
  builder.add_edge("generate_story", "generate_brainstroming")
27
  builder.add_edge("generate_brainstroming", "select_preferred_topics")
28
 
 
30
  builder.add_conditional_edges("select_preferred_topics", route_after_selection,{True:'retrieve',False:'generate_final_story'})
31
  builder.add_edge("generate_final_story",END)
32
 
33
+ return builder.compile(checkpointer=memory)
34
+
35
+
36
+
37
+
38
+ def build_graph():
39
+ builder = StateGraph(State)
40
+ # builder.add_node(caption_image)
41
+ builder.add_node(retrieve)
42
+ builder.add_node(generate_story)
43
+ builder.add_node(generate_brainstroming)
44
+ builder.add_node(select_preferred_topics)
45
+ builder.add_node(generate_final_story)
46
+
47
+
48
+ # Normal edges
49
+ # builder.add_edge(START, "caption_image")
50
+
51
+ builder.add_edge(START, "retrieve")
52
+ builder.add_edge("retrieve", "generate_story")
53
+
54
+ # builder.add_edge("caption_image", "retrieve")
55
+ # builder.add_edge("retrieve", "generate_story")
56
+ builder.add_edge("generate_story", "generate_brainstroming")
57
+
58
+ # Conditional edge
59
+ builder.add_edge("generate_brainstroming", END)
60
+ # builder.add_edge("generate_final_story",END)
61
+
62
+ return builder.compile(checkpointer=memory)
my_agent/utils/__pycache__/nodes.cpython-312.pyc CHANGED
Binary files a/my_agent/utils/__pycache__/nodes.cpython-312.pyc and b/my_agent/utils/__pycache__/nodes.cpython-312.pyc differ
 
my_agent/utils/__pycache__/state.cpython-312.pyc CHANGED
Binary files a/my_agent/utils/__pycache__/state.cpython-312.pyc and b/my_agent/utils/__pycache__/state.cpython-312.pyc differ
 
my_agent/utils/nodes.py CHANGED
@@ -5,12 +5,50 @@ from .tools import StoryFormatter, BrainstromTopicFormatter
5
  from langchain_core.messages import SystemMessage
6
  from .models_loader import llm , ST
7
  from .data_loader import load_influencer_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  def retrieve(state: State) -> State:
11
  print('Moving to retrieval process')
12
  retrievals=[]
13
- if len(state.preferred_topics)==0:
14
  for topic in state.topic: # Loop through each topic
15
  embedded_query = ST.encode(topic) # Embed each topic
16
  data = load_influencer_data()
@@ -22,7 +60,9 @@ def retrieve(state: State) -> State:
22
  print('Retrieval process completed......')
23
  state.retrievals.append(retrievals)
24
 
25
- if len (state.preferred_topics)>0:
 
 
26
  for topic in state.preferred_topics[-1]: # Loop through each topic
27
  embedded_query = ST.encode(topic) # Embed each topic
28
  data = load_influencer_data()
@@ -31,10 +71,11 @@ def retrieve(state: State) -> State:
31
  # Construct a list of dictionaries for this topic
32
  result = [{user: story} for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])]
33
  retrievals.append(result)
34
- print('Retrieval process completed......')
 
35
  state.retrievals.append(retrievals)
36
 
37
- print('The retrieval is:\n',state.retrievals )
38
  # return State(messages="Retrieved",topic=state.topic,retrievals=state.retrievals)
39
  return state
40
 
@@ -63,6 +104,7 @@ def generate_story(state:State)-> State:
63
 
64
  **Final Reminder** You have to strongly focus on these topics while creating the storyline: {state.preferred_topics[-1]}'''
65
 
 
66
 
67
  messages = [SystemMessage(content=template)]
68
  response = llm.bind_tools([StoryFormatter]).invoke(messages)
@@ -145,9 +187,11 @@ def select_preferred_topics(state: State)-> State:
145
  state.carry_on=True
146
  return state
147
 
148
-
149
-
150
-
 
 
151
 
152
  def generate_final_story(state:State)-> State:
153
  if len(state.preferred_topics)>0:
@@ -166,15 +210,15 @@ def generate_final_story(state:State)-> State:
166
  response = response.content
167
  else:
168
  response = "No response"
169
- state.final_story=response
170
  state.stories.append(response)
171
  return state
172
 
173
- state.final_story=state.stories[-1]
 
174
  return state
175
 
176
 
177
 
178
- def route_after_selection(state:State):
179
- print('The output is:',state.carry_on)
180
- return state.carry_on
 
5
  from langchain_core.messages import SystemMessage
6
  from .models_loader import llm , ST
7
  from .data_loader import load_influencer_data
8
+ from groq import Groq
9
+ import os
10
+
11
+
12
+ def caption_image(state: State) -> State:
13
+ if state.images[-1]!=None:
14
+ print('Captioning image')
15
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
16
+
17
+ chat_completion = client.chat.completions.create(
18
+ messages=[
19
+ {
20
+ "role": "user",
21
+ "content": [
22
+ {"type": "text", "text": "What's in this image?"},
23
+ {
24
+ "type": "image_url",
25
+ "image_url": {
26
+ "url": f"data:image/jpeg;base64,{state.images[-1]}",
27
+ },
28
+ },
29
+ ],
30
+ }
31
+ ],
32
+ model="meta-llama/llama-4-scout-17b-16e-instruct",
33
+ )
34
+ response=chat_completion.choices[0].message.content
35
+ state.image_captions.append(response)
36
+ return state
37
+
38
+ else:
39
+ state.image_captions.append(None)
40
+ return state
41
+
42
+ # elif state.images[-1]==None:
43
+ # state.image_captions.append(None)
44
+
45
+
46
 
47
 
48
  def retrieve(state: State) -> State:
49
  print('Moving to retrieval process')
50
  retrievals=[]
51
+ if len(state.latest_preferred_topics)==0:
52
  for topic in state.topic: # Loop through each topic
53
  embedded_query = ST.encode(topic) # Embed each topic
54
  data = load_influencer_data()
 
60
  print('Retrieval process completed......')
61
  state.retrievals.append(retrievals)
62
 
63
+ if len (state.latest_preferred_topics)>0:
64
+ print('The preferred_topics are:',state.latest_preferred_topics)
65
+ state.preferred_topics.append(state.latest_preferred_topics)
66
  for topic in state.preferred_topics[-1]: # Loop through each topic
67
  embedded_query = ST.encode(topic) # Embed each topic
68
  data = load_influencer_data()
 
71
  # Construct a list of dictionaries for this topic
72
  result = [{user: story} for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])]
73
  retrievals.append(result)
74
+ print('Retrieval process completed for preferred_topics......')
75
+ state.latest_preferred_topics=[]
76
  state.retrievals.append(retrievals)
77
 
78
+ # print('The retrieval is:\n',state.retrievals )
79
  # return State(messages="Retrieved",topic=state.topic,retrievals=state.retrievals)
80
  return state
81
 
 
104
 
105
  **Final Reminder** You have to strongly focus on these topics while creating the storyline: {state.preferred_topics[-1]}'''
106
 
107
+ # and {state.image_captions[-1]}
108
 
109
  messages = [SystemMessage(content=template)]
110
  response = llm.bind_tools([StoryFormatter]).invoke(messages)
 
187
  state.carry_on=True
188
  return state
189
 
190
+ def route_after_selection(state:State):
191
+ if len(state.latest_preferred_topics)==0:
192
+ return False
193
+ elif len(state.latest_preferred_topics)>0:
194
+ return True
195
 
196
  def generate_final_story(state:State)-> State:
197
  if len(state.preferred_topics)>0:
 
210
  response = response.content
211
  else:
212
  response = "No response"
213
+ state.final_story.append(response)
214
  state.stories.append(response)
215
  return state
216
 
217
+ state.final_story.append(state.stories[-1])
218
+ state.latest_preferred_topics=[]
219
  return state
220
 
221
 
222
 
223
+
224
+
 
my_agent/utils/state.py CHANGED
@@ -9,8 +9,10 @@ class State(BaseModel):
9
  brainstroming_topics: Optional[list] = []
10
  preferred_topics: Optional[list] = []
11
  stories : Optional[list]=[]
12
- final_story: Optional[str]=None
13
  retrievals : Optional[list]=[]
14
  business_details : Optional[dict]={}
15
  latest_preferred_topics: Optional[list] = []
 
 
16
  model_config = ConfigDict(arbitrary_types_allowed=True)
 
9
  brainstroming_topics: Optional[list] = []
10
  preferred_topics: Optional[list] = []
11
  stories : Optional[list]=[]
12
+ final_story: Optional[list]=[]
13
  retrievals : Optional[list]=[]
14
  business_details : Optional[dict]={}
15
  latest_preferred_topics: Optional[list] = []
16
+ images: Optional[list[str]] = [] # Base64-encoded strings of images
17
+ image_captions: Optional[list] = []
18
  model_config = ConfigDict(arbitrary_types_allowed=True)