subashpoudel's picture
Debugged the unmatched f-string
32131c3
raw
history blame
6.34 kB
import pandas as pd
import ast
from .state import State
from .tools import StoryFormatter, BrainstromTopicFormatter
from langchain_core.messages import SystemMessage
from .models_loader import llm , ST
from .data_loader import load_influencer_data
from groq import Groq
import os
from .prompts import image_captioning_prompt , initial_story_prompt , refined_story_prompt , brainstroming_prompt , final_story_prompt
def caption_image(state: State) -> State:
    """Caption the most recent image in ``state.images`` with a Groq vision model.

    Behaviour by case:
      * ``state.images`` is empty: append a ``None`` placeholder to both
        ``state.images`` and ``state.image_captions``.
      * the latest image is ``None``: leave the state unchanged.
      * otherwise: send the latest image (embedded as a JPEG base64 data URL)
        together with ``image_captioning_prompt`` to the Groq chat-completions
        API and append the returned text to ``state.image_captions``.

    Every path returns ``state``. Previously, one path (non-empty images whose
    last entry is ``None``) fell off the end and returned ``None``.
    """
    if not state.images:
        # Keep images and captions the same length when no image was supplied.
        state.images.append(None)
        state.image_captions.append(None)
        return state

    latest_image = state.images[-1]
    if latest_image is None:
        # Nothing to caption; return the state explicitly rather than None.
        return state

    print('Captioning image')
    # NOTE(review): reads GROQ_API_KEY from the environment on every call;
    # a missing key is only detected by the Groq client itself.
    client = Groq(api_key=os.environ.get('GROQ_API_KEY'))
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": image_captioning_prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            # assumes latest_image is a base64-encoded JPEG string — TODO confirm
                            "url": f"data:image/jpeg;base64,{latest_image}",
                        },
                    },
                ],
            }
        ],
        model="meta-llama/llama-4-scout-17b-16e-instruct",
    )
    response = chat_completion.choices[0].message.content
    state.image_captions.append(response)
    return state
def _nearest_stories(topics, k=1):
    """Return one list of ``{username: agentic_story}`` dicts per topic, built from its ``k`` nearest examples."""
    # Loaded once per call instead of once per topic: the dataset is loop-invariant.
    # presumably a Hugging Face Dataset carrying an index named "embeddings" — TODO confirm
    data = load_influencer_data()
    results = []
    for topic in topics:
        embedded_query = ST.encode(topic)
        _scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
        results.append([
            {user: story}
            for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])
        ])
    return results


def retrieve(state: State) -> State:
    """Retrieve the nearest influencer story per topic and append the batch to ``state.retrievals``.

    * If ``state.latest_preferred_topics`` is empty, the queries are ``state.topic``.
    * Otherwise, ``state.latest_preferred_topics`` is first appended to
      ``state.preferred_topics``. Those topics are then used as the queries, and
      ``state.latest_preferred_topics`` is reset to ``[]``.

    Each appended batch is a list (one entry per topic) of lists of
    ``{username: agentic_story}`` dicts.
    """
    print('Moving to retrieval process')
    if not state.latest_preferred_topics:
        # assumes state.topic is a list of topic strings (a bare str would be iterated per character)
        retrievals = _nearest_stories(state.topic)
        print('Retrieval process completed......')
    else:
        print('The preferred_topics are:', state.latest_preferred_topics)
        state.preferred_topics.append(state.latest_preferred_topics)
        retrievals = _nearest_stories(state.preferred_topics[-1])
        print('Retrieval process completed for preferred_topics......')
        # Consume the pending selection so the next pass starts from state.topic again.
        state.latest_preferred_topics = []
    state.retrievals.append(retrievals)
    return state
def generate_story(state: State) -> State:
    """Generate a story from the latest retrieval batch and append it to ``state.stories``.

    All retrieved ``agentic_story`` values in ``state.retrievals[-1]`` are joined
    with spaces into one context string. That string goes to
    ``initial_story_prompt`` when no preferred topics have been recorded yet, and
    to ``refined_story_prompt`` otherwise.

    The LLM is bound to the ``StoryFormatter`` tool. The stored result is:
      * the first tool call's ``args`` when the model called the tool;
      * otherwise the message ``content``;
      * otherwise the string ``"No response"``.

    Topics with no neighbours, or an empty retrieval history, no longer raise
    IndexError; they simply contribute no context.
    """
    print('The state retrieval is:', state.retrievals)
    retrieval_list = state.retrievals[-1] if state.retrievals else []
    agentic_stories = []
    for item in retrieval_list:
        # item is a list of {username: agentic_story} dicts, one per nearest example.
        for example in item:
            print('item:', example.values())
            agentic_stories.extend(example.values())  # Add all stories to the list
    retrieval = " ".join(agentic_stories)

    if not state.preferred_topics:
        template = initial_story_prompt(retrieval, state)
    else:
        template = refined_story_prompt(retrieval, state)

    messages = [SystemMessage(content=template)]
    response = llm.bind_tools([StoryFormatter]).invoke(messages)
    print('The response is:', response)
    if hasattr(response, 'tool_calls') and response.tool_calls:
        response = response.tool_calls[0]['args']
    elif hasattr(response, 'content'):
        response = response.content
    else:
        response = "No response"
    state.stories.append(response)
    return state
def generate_brainstroming(state: State) -> State:
    """Ask the LLM for brainstorming topics and append them to ``state.brainstroming_topics``.

    The prompt comes from ``brainstroming_prompt(state)``, and the model is bound
    to the ``BrainstromTopicFormatter`` tool. The appended value is:
      * the first tool call's ``args`` if any tool call was made;
      * otherwise the message ``content``;
      * otherwise the literal ``"No response"``.
    """
    prompt_text = brainstroming_prompt(state)
    conversation = [SystemMessage(content=prompt_text)]
    tool_llm = llm.bind_tools([BrainstromTopicFormatter])
    reply = tool_llm.invoke(conversation)
    print('The response is:', reply)

    calls = getattr(reply, 'tool_calls', None)
    if calls:
        topics = calls[0]['args']
    elif hasattr(reply, 'content'):
        topics = reply.content
    else:
        topics = "No response"

    state.brainstroming_topics.append(topics)
    print('The brainstroming topics are:', state.brainstroming_topics)
    return state
def select_preferred_topics(state: State) -> State:
    """Interactively let the user pick brainstormed topics from stdin.

    The latest brainstorming result (expected to be a dict of topics) is listed
    with 1-based numbers. The user enters comma-separated numbers; out-of-range
    numbers are ignored.

    Sets ``state.carry_on`` to ``True`` only when at least one valid topic was
    chosen; that selection is then appended to ``state.preferred_topics``. An
    empty, non-numeric or fully out-of-range answer sets ``carry_on`` to
    ``False`` and leaves ``state.preferred_topics`` untouched. Previously, an
    empty selection was appended before being rejected.
    """
    print("---human_feedback---")
    latest = state.brainstroming_topics[-1] if state.brainstroming_topics else None
    if not isinstance(latest, dict):
        # generate_brainstroming stores a plain string when the model did not call the tool.
        state.carry_on = False
        print("No brainstorming topics available. Ending process.")
        return state

    topic_values = list(latest.values())
    print("Available topics:")
    for idx, topic in enumerate(topic_values, 1):
        print(f"{idx}. {topic}")

    raw_input_str = input("Enter the numbers of your preferred topics (comma-separated), or press Enter to skip: ").strip()
    if not raw_input_str:
        state.carry_on = False
        print("No topics selected. Ending process.")
        return state

    try:
        preferred_indices = [int(i.strip()) for i in raw_input_str.split(",")]
    except ValueError:
        state.carry_on = False
        print("Invalid input. Please try again.")
        return state

    # 1-based indices; anything outside 1..len(topic_values) is silently dropped.
    preferred_topics = [topic_values[i - 1] for i in preferred_indices if 0 < i <= len(topic_values)]
    if not preferred_topics:
        state.carry_on = False
        print("No valid topics selected. Ending process.")
        return state

    # NOTE(review): route_after_selection and retrieve read latest_preferred_topics,
    # which this function never sets — confirm how the graph routes after this node.
    state.preferred_topics.append(preferred_topics)
    print("You selected:")
    print(preferred_topics)
    state.carry_on = True
    return state
def route_after_selection(state: State):
    """Return ``True`` when ``state.latest_preferred_topics`` has entries, else ``False``.

    Used as a routing predicate. The previous ``if/elif`` had no final branch,
    leaving an (unreachable) implicit ``None`` return.
    """
    return bool(state.latest_preferred_topics)