File size: 5,896 Bytes
be3a5c4
 
 
322e9b1
be3a5c4
93a5bf9
 
a9f99c3
 
0c51449
fb491f0
 
 
93a5bf9
fb491f0
a9f99c3
 
 
415ac2b
 
 
32131c3
415ac2b
 
 
 
 
 
 
 
 
 
 
 
a9f99c3
415ac2b
 
 
 
 
 
 
 
a9f99c3
 
415ac2b
a9f99c3
 
 
be3a5c4
 
 
 
 
a9f99c3
b55b8d4
 
 
 
 
 
 
 
 
 
 
a9f99c3
 
 
b55b8d4
 
 
 
 
 
 
 
a9f99c3
 
b55b8d4
be3a5c4
a9f99c3
be3a5c4
 
 
 
fb491f0
 
 
 
 
 
 
 
 
 
 
be3a5c4
a9f99c3
be3a5c4
fb491f0
 
 
7fb95cb
fb491f0
 
 
be3a5c4
 
 
 
0c51449
 
be3a5c4
 
93a5bf9
 
be3a5c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b55b8d4
be3a5c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9f99c3
 
 
 
 
85a68fb
415ac2b
be3a5c4
 
 
a9f99c3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import pandas as pd
import ast
from .state import State
from .tools import retrieve_tool
from langchain_core.messages import SystemMessage
from utils.models_loader import llm , ST
from utils.data_loader import load_influencer_data
from groq import Groq
import os
from .prompts import image_captioning_prompt , initial_story_prompt , refined_story_prompt , brainstroming_prompt , final_story_prompt
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel , Field
from langchain_core.tools import tool
from .state import BrainstromTopicFormatter



def caption_image(state: State) -> State:
    """Caption the most recent image in ``state.images`` and record the caption.

    Behaviour by case:

    * ``state.images`` is empty: append ``None`` to both ``state.images`` and
      ``state.image_captions`` as placeholders.
    * The latest image is ``None``: nothing to caption; the state is
      returned unchanged.
    * Otherwise: send the latest image plus ``image_captioning_prompt`` to
      the Groq chat-completions API and append the model's reply to
      ``state.image_captions``.

    Every path returns ``state``. Previously, the ``None``-image path fell
    through and implicitly returned ``None``.

    Args:
        state: Graph state; ``images`` and ``image_captions`` are mutated in place.

    Returns:
        The same (mutated) ``state`` object.
    """
    if not state.images:
        # No image was supplied: record placeholders in both lists.
        state.images.append(None)
        state.image_captions.append(None)
        return state

    latest_image = state.images[-1]
    if latest_image is None:
        # Placeholder slot; there is nothing to send to the vision model.
        return state

    print('Captioning image')
    client = Groq(api_key=os.environ.get('GROQ_API_KEY'))

    # NOTE(review): the data URL hard-codes image/jpeg; this assumes entries
    # in state.images are base64-encoded JPEG bytes -- confirm with the caller.
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": image_captioning_prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{latest_image}",
                        },
                    },
                ],
            }
        ],
        model="meta-llama/llama-4-scout-17b-16e-instruct",
    )
    state.image_captions.append(chat_completion.choices[0].message.content)
    return state
    


def _retrieve_for_topics(topics, completion_message):
    """Return, per topic, the nearest influencer story as ``[{username: agentic_story}]``.

    Each topic is embedded with ``ST.encode`` and looked up (``k=1``) against
    the ``"embeddings"`` index of the dataset from ``load_influencer_data()``.
    The dataset is loaded lazily and at most once per call, so an empty
    ``topics`` never triggers a load. ``completion_message`` is printed after
    each topic's lookup.
    """
    retrievals = []
    data = None
    for topic in topics:
        embedded_query = ST.encode(topic)
        if data is None:
            # Loop-invariant: load once instead of once per topic.
            data = load_influencer_data()
        _scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=1)

        # One {username: story} mapping per returned neighbour.
        result = [
            {user: story}
            for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])
        ]
        retrievals.append(result)
        print(completion_message)
    return retrievals


def retrieve(state: State) -> State:
    """Retrieve example stories for the current topics and append them to ``state.retrievals``.

    If ``state.latest_preferred_topics`` is empty, retrieval runs over
    ``state.topic``.

    Otherwise, the latest preferred topics are:

    1. appended to ``state.preferred_topics``;
    2. used for retrieval;
    3. cleared from ``state.latest_preferred_topics``.

    In both cases one list (one entry per topic) is appended to
    ``state.retrievals``.

    Args:
        state: Graph state, mutated in place.

    Returns:
        The same (mutated) ``state`` object.
    """
    print('Moving to retrieval process')

    if not state.latest_preferred_topics:
        state.retrievals.append(
            _retrieve_for_topics(state.topic, 'Retrieval process completed......')
        )
    else:
        print('The preferred_topics are:', state.latest_preferred_topics)
        state.preferred_topics.append(state.latest_preferred_topics)
        retrievals = _retrieve_for_topics(
            state.preferred_topics[-1],
            'Retrieval process completed for preferred_topics......',
        )
        # Consume the selection so a later router sees no pending topics.
        state.latest_preferred_topics = []
        state.retrievals.append(retrievals)

    return state

def generate_story(state: State) -> State:
    """Produce the next story draft with a ReAct agent and record it on the state.

    A ReAct agent is built around ``llm`` with ``retrieve_tool`` as its only
    tool. The system prompt depends on ``state.preferred_topics``:

    * empty: ``initial_story_prompt(state)``;
    * non-empty: ``refined_story_prompt(state)``.

    The content of the agent's last message is appended to ``state.stories``.

    Args:
        state: Graph state; ``stories`` is mutated in place.

    Returns:
        The same (mutated) ``state`` object.
    """
    agent_tools = [retrieve_tool]
    story_agent = create_react_agent(
        model=llm.bind_tools(agent_tools),
        tools=agent_tools,
    )

    # First pass (no preferred topics yet) vs. a refinement pass.
    build_prompt = initial_story_prompt if len(state.preferred_topics) == 0 else refined_story_prompt
    system_prompt = build_prompt(state)

    agent_output = story_agent.invoke({'messages': [SystemMessage(content=system_prompt)]})
    final_message = agent_output['messages'][-1]
    state.stories.append(final_message.content)
    return state



def generate_brainstroming(state: State) -> State:
    """Ask the LLM for brainstorming topics and append them to the state.

    Steps:

    1. Send ``brainstroming_prompt(state)`` as a system message to ``llm``,
       with output constrained to the ``BrainstromTopicFormatter`` schema.
    2. Convert the parsed result to a plain dict via ``model_dump()``.
    3. Append it to ``state.brainstroming_topics``.

    Args:
        state: Graph state; ``brainstroming_topics`` is mutated in place.

    Returns:
        The same (mutated) ``state`` object.
    """
    system_prompt = brainstroming_prompt(state)
    structured_llm = llm.with_structured_output(BrainstromTopicFormatter)
    parsed_topics = structured_llm.invoke([SystemMessage(content=system_prompt)])

    state.brainstroming_topics.append(parsed_topics.model_dump())
    print('The brainstroming topics are:', state.brainstroming_topics)
    return state



def _end_topic_selection(state, message):
    """Mark the selection as finished with no topics chosen, print *message*, and return *state*."""
    # Keep the router input consistent with carry_on=False.
    state.latest_preferred_topics = []
    state.carry_on = False
    print(message)
    return state


def select_preferred_topics(state: State) -> State:
    """Ask the user on stdin which brainstormed topics to pursue next.

    The values of the most recent ``state.brainstroming_topics`` entry are
    printed as a 1-based numbered list. The user enters comma-separated
    numbers.

    On a valid selection:

    * the chosen topics (duplicates and out-of-range numbers dropped) are
      stored in ``state.latest_preferred_topics``, the field that
      ``route_after_selection`` and ``retrieve`` read (``retrieve`` is what
      appends them to ``state.preferred_topics``);
    * ``state.carry_on`` is set to True.

    On skip, non-numeric input, an empty valid selection, closed stdin, or no
    brainstormed topics:

    * ``latest_preferred_topics`` is cleared;
    * ``carry_on`` is set to False.

    Args:
        state: Graph state, mutated in place.

    Returns:
        The same (mutated) ``state`` object.
    """
    print("---human_feedback---")

    if not state.brainstroming_topics:
        # Indexing [-1] below would otherwise raise IndexError.
        return _end_topic_selection(state, "No brainstorming topics available. Ending process.")

    topic_values = list(state.brainstroming_topics[-1].values())

    print("Available topics:")
    for idx, topic in enumerate(topic_values, 1):
        print(f"{idx}. {topic}")

    try:
        raw_input_str = input("Enter the numbers of your preferred topics (comma-separated), or press Enter to skip: ").strip()
    except EOFError:
        # Non-interactive / closed stdin: behave as if the user pressed Enter.
        raw_input_str = ""

    if not raw_input_str:
        return _end_topic_selection(state, "No topics selected. Ending process.")

    try:
        preferred_indices = [int(i.strip()) for i in raw_input_str.split(",")]
    except ValueError:
        # int() is the only call here that can fail on user input.
        return _end_topic_selection(state, "Invalid input. Please try again.")

    # dict.fromkeys drops repeated numbers while keeping the user's order;
    # numbers outside 1..len(topic_values) are ignored.
    preferred_topics = [
        topic_values[i - 1]
        for i in dict.fromkeys(preferred_indices)
        if 0 < i <= len(topic_values)
    ]

    if not preferred_topics:
        return _end_topic_selection(state, "No valid topics selected. Ending process.")

    state.latest_preferred_topics = preferred_topics
    print("You selected:")
    print(preferred_topics)
    state.carry_on = True
    return state

def route_after_selection(state: State):
    """Return True if ``state.latest_preferred_topics`` is non-empty, else False.

    Presumably used as a conditional-edge router after topic selection
    (verify against the graph definition).
    """
    has_pending_selection = len(state.latest_preferred_topics) > 0
    return has_pending_selection