Spaces:
Sleeping
Sleeping
Commit
·
415ac2b
1
Parent(s):
0a786d2
Handled the captioning node
Browse files- my_agent/utils/nodes.py +26 -41
my_agent/utils/nodes.py
CHANGED
|
@@ -11,32 +11,34 @@ from .prompts import image_captioning_prompt , initial_story_prompt , refined_st
|
|
| 11 |
|
| 12 |
|
| 13 |
def caption_image(state: State) -> State:
|
| 14 |
-
if state.images
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
"
|
|
|
|
|
|
|
| 28 |
},
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
return state
|
| 38 |
|
| 39 |
else:
|
|
|
|
| 40 |
state.image_captions.append(None)
|
| 41 |
return state
|
| 42 |
|
|
@@ -173,24 +175,7 @@ def route_after_selection(state:State):
|
|
| 173 |
elif len(state.latest_preferred_topics)>0:
|
| 174 |
return True
|
| 175 |
|
| 176 |
-
|
| 177 |
-
if len(final_state['preferred_topics'])>0:
|
| 178 |
-
template = final_story_prompt(final_state)
|
| 179 |
-
messages = [SystemMessage(content=template)]
|
| 180 |
-
response = llm.bind_tools([StoryFormatter]).invoke(messages)
|
| 181 |
-
print('The final response is:',response)
|
| 182 |
-
if hasattr(response, 'tool_calls') and response.tool_calls:
|
| 183 |
-
response = response.tool_calls[0]['args']
|
| 184 |
-
elif hasattr(response, 'content'):
|
| 185 |
-
response = response.content
|
| 186 |
-
else:
|
| 187 |
-
response = "No response"
|
| 188 |
-
# state.final_story.append(response)
|
| 189 |
-
# state.stories.append(response)
|
| 190 |
-
return response
|
| 191 |
-
|
| 192 |
-
else:
|
| 193 |
-
return final_state['stories'][-1]
|
| 194 |
|
| 195 |
|
| 196 |
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def caption_image(state: State) -> State:
    """Caption the most recently added image in ``state.images``.

    Sends the last image (assumed to be a base64-encoded JPEG string —
    TODO confirm against the producer of ``state.images``) to the Groq
    vision model and appends the resulting caption text to
    ``state.image_captions``. When there is no usable image, ``None`` is
    appended instead so that ``images`` and ``image_captions`` stay in
    lockstep.

    Args:
        state: The graph state carrying ``images`` and ``image_captions``
            lists (mutated in place).

    Returns:
        The same ``state`` object, always — every path returns it.
    """
    # Guard clause: only call the vision model when a real image exists.
    if len(state.images) > 0 and state.images[-1] is not None:
        print('Captioning image')
        client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": image_captioning_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                # Groq vision API accepts data URLs for inline images.
                                "url": f"data:image/jpeg;base64,{state.images[-1]}",
                            },
                        },
                    ],
                }
            ],
            model="meta-llama/llama-4-scout-17b-16e-instruct",
        )
        response = chat_completion.choices[0].message.content
        state.image_captions.append(response)
    else:
        # No image available: pad both lists so they remain index-aligned.
        # (Previously the "last image is None" case fell through without
        # appending a caption and returned None, violating the -> State
        # contract; this branch now covers it.)
        if len(state.images) == 0:
            state.images.append(None)
        state.image_captions.append(None)
    return state
|
| 44 |
|
|
|
|
| 175 |
elif len(state.latest_preferred_topics)>0:
|
| 176 |
return True
|
| 177 |
|
| 178 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
|
| 181 |
|