subashpoudel committed on
Commit
415ac2b
·
1 Parent(s): 0a786d2

Handled the captioning node

Browse files
Files changed (1) hide show
  1. my_agent/utils/nodes.py +26 -41
my_agent/utils/nodes.py CHANGED
@@ -11,32 +11,34 @@ from .prompts import image_captioning_prompt , initial_story_prompt , refined_st
11
 
12
 
13
  def caption_image(state: State) -> State:
14
- if state.images[-1]!=None:
15
- print('Captioning image')
16
- client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
17
-
18
- chat_completion = client.chat.completions.create(
19
- messages=[
20
- {
21
- "role": "user",
22
- "content": [
23
- {"type": "text", "text": image_captioning_prompt},
24
- {
25
- "type": "image_url",
26
- "image_url": {
27
- "url": f"data:image/jpeg;base64,{state.images[-1]}",
 
 
28
  },
29
- },
30
- ],
31
- }
32
- ],
33
- model="meta-llama/llama-4-scout-17b-16e-instruct",
34
- )
35
- response=chat_completion.choices[0].message.content
36
- state.image_captions.append(response)
37
- return state
38
 
39
  else:
 
40
  state.image_captions.append(None)
41
  return state
42
 
@@ -173,24 +175,7 @@ def route_after_selection(state:State):
173
  elif len(state.latest_preferred_topics)>0:
174
  return True
175
 
176
def generate_final_story(final_state):
    """Produce the final story for ``final_state``.

    If ``final_state['preferred_topics']`` is empty, the last entry of
    ``final_state['stories']`` is returned unchanged. Otherwise a prompt is
    built with ``final_story_prompt`` and sent as a single system message to
    ``llm`` with ``StoryFormatter`` bound as a tool.

    Returns:
        The ``args`` of the first tool call when the reply carries tool calls,
        else the reply's ``content`` when it has one, else the literal string
        ``"No response"``.
    """
    # Nothing to refine: fall back to the most recent story.
    if len(final_state['preferred_topics']) == 0:
        return final_state['stories'][-1]

    system_prompt = final_story_prompt(final_state)
    formatter_llm = llm.bind_tools([StoryFormatter])
    reply = formatter_llm.invoke([SystemMessage(content=system_prompt)])
    print('The final response is:',reply)

    # Prefer the structured tool-call payload over free-form text.
    if getattr(reply, 'tool_calls', None):
        return reply.tool_calls[0]['args']
    if hasattr(reply, 'content'):
        return reply.content
    return "No response"
194
 
195
 
196
 
 
11
 
12
 
13
  def caption_image(state: State) -> State:
14
+ if len(state.images)>0:
15
+ if state.images[-1]!=None:
16
+ print('Captioning image')
17
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
18
+
19
+ chat_completion = client.chat.completions.create(
20
+ messages=[
21
+ {
22
+ "role": "user",
23
+ "content": [
24
+ {"type": "text", "text": image_captioning_prompt},
25
+ {
26
+ "type": "image_url",
27
+ "image_url": {
28
+ "url": f"data:image/jpeg;base64,{state.images[-1]}",
29
+ },
30
  },
31
+ ],
32
+ }
33
+ ],
34
+ model="meta-llama/llama-4-scout-17b-16e-instruct",
35
+ )
36
+ response=chat_completion.choices[0].message.content
37
+ state.image_captions.append(response)
38
+ return state
 
39
 
40
  else:
41
+ state.images.append(None)
42
  state.image_captions.append(None)
43
  return state
44
 
 
175
  elif len(state.latest_preferred_topics)>0:
176
  return True
177
 
178
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
 
181