File size: 5,099 Bytes
8ce97f0
 
583f6dd
8ce97f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f57d05
8ce97f0
 
3c1150c
38cf703
 
 
 
 
 
8ce97f0
38cf703
 
 
6f57d05
8ce97f0
 
 
 
 
6f57d05
8ce97f0
 
 
 
6f57d05
8ce97f0
6f57d05
8ce97f0
6f57d05
8ce97f0
 
 
6f57d05
8ce97f0
 
 
 
 
 
38cf703
 
 
 
 
 
 
 
 
 
 
 
 
8ce97f0
38cf703
 
6f57d05
 
8ce97f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f57d05
8ce97f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38cf703
6874dac
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from .prompts import tool_return_prompt , extract_user_reference_prompt , query_response_prompt,captioning_prompt
from langchain_core.messages import SystemMessage, HumanMessage, FunctionMessage
from src.genai.utils.models_loader import llm_gpt
from src.genai.utils.base_endpoint import base_url
from .state import State
from .tools import InfluencerRetrievalTool
from .schemas import ToolResponseFormatter , UserReferenceResponseFormatter
import os
import requests
from groq import Groq

# Module-level influencer retrieval tool, constructed once when this module is
# imported; QueryResponseNode.run uses it to fetch data for the LLM prompt.
retriever=InfluencerRetrievalTool()

class ImageCaptionNode:
    """Graph node that captions the most recent user-supplied image with a Groq vision model.

    The returned state update always contains ``image_caption``: the model's
    reply when ``state['image_base64']`` holds at least one image, otherwise
    ``None``.
    """

    def __init__(self, api_key=None):
        """Create the Groq client.

        Args:
            api_key: Groq API key. When ``None`` (the default) the
                ``GROQ_API_KEY`` environment variable is read here, at
                construction time, rather than once when the module is imported.
        """
        if api_key is None:
            api_key = os.environ.get('GROQ_API_KEY')
        self.client = Groq(api_key=api_key)

    @staticmethod
    def _latest_image(state):
        """Return the last entry of ``state['image_base64']``, or ``None`` if it is missing or empty."""
        # NOTE(review): assumes image_base64 is a list of base64 strings — confirm against State.
        images = state.get('image_base64') or []
        return images[-1] if images else None

    def run(self, state: State):
        """Caption the latest image in ``state['image_base64']``.

        Args:
            state: Graph state; reads ``image_base64`` and ``messages``.

        Returns:
            ``{'image_caption': <model reply>}`` when an image is present,
            otherwise ``{'image_caption': None}``.
        """
        image_b64 = self._latest_image(state)
        if image_b64 is None:
            print('No image provided')
            return {'image_caption': None}

        print('Captioning image')
        chat_completion = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": captioning_prompt(state['messages'])},
                        {
                            "type": "image_url",
                            # Only the most recent image is sent to the model.
                            # NOTE(review): MIME type is hard-coded as image/jpg regardless
                            # of the actual image format — confirm uploads are JPEG.
                            "image_url": {
                                "url": f"data:image/jpg;base64,{image_b64}",
                            },
                        },
                    ],
                }
            ],
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            max_completion_tokens=50,  # captions are kept short
            temperature=1,
        )
        return {'image_caption': chat_completion.choices[0].message.content}


class ToolReturnNode:
    """Node for determining which tools to use based on user messages.

    Sends ``tool_return_prompt`` plus the (possibly trimmed) conversation to the
    LLM, parses the reply into ``ToolResponseFormatter`` and reports the chosen
    tools back into the graph state.
    """

    # When the conversation has more than _TRIM_THRESHOLD messages, only the
    # last _KEEP_LAST are sent to the LLM (values kept from the original code).
    _TRIM_THRESHOLD = 23
    _KEEP_LAST = 18

    def __init__(self, llm=llm_gpt):
        """Store the chat model used for tool selection (defaults to ``llm_gpt``)."""
        self.llm = llm

    def run(self, state: State):
        """Ask the LLM which tools to invoke for the current conversation.

        The incoming ``state`` is not modified; trimming only affects the prompt
        built here.

        Returns:
            A state update containing an assistant message naming the selected
            tools, and ``tools`` set to the structured response's ``tools``.
        """
        history = state["messages"]
        if len(history) > self._TRIM_THRESHOLD:
            history = history[-self._KEEP_LAST:]
        template = [SystemMessage(content=tool_return_prompt)] + history
        response = self.llm.with_structured_output(
            ToolResponseFormatter, method='function_calling'
        ).invoke(template)
        print('The response is:', response)
        return {"messages": [{'role': 'assistant', 'content': f"Tool invoked: {response.tools}"}],
                "tools": response.tools}
                

class QueryResponseNode:
    """Answer the user's query directly, or acknowledge the tools that will run.

    When ``state['tools']`` is empty or missing, influencer data is fetched via
    the module-level ``retriever`` and passed to the LLM together with the
    conversation. Otherwise a fixed acknowledgement naming the tools is returned.
    """

    def __init__(self, llm=llm_gpt):
        """Store the chat model used to answer queries (defaults to ``llm_gpt``)."""
        self.llm = llm

    def run(self, state: State):
        """Produce ``query_response`` for the current turn.

        Returns:
            With no tools selected: an assistant message plus ``query_response``
            holding the LLM's answer. With tools selected: only
            ``query_response`` holding an acknowledgement string.
        """
        print('Entered to query response')
        # A missing or None ``tools`` entry is treated the same as an empty list.
        tools = state.get('tools') or []
        if tools:
            return {
                "query_response": f'''Okay i will perform {" ".join(tools)} for you.'''
            }

        print('Going for retrieval')
        retrieved_data = retriever.retrieve_for_orchestration(state['messages'])
        print('The data is retrieved.')
        template = [SystemMessage(content=query_response_prompt),
                    FunctionMessage(name='inf-data-retrieval', content=retrieved_data)] + state["messages"]
        response = self.llm.invoke(template)
        print('Query Response:', response)
        return {"messages": [{'role': 'assistant', 'content': response.content}],
                "query_response": response.content}


class ExtractUserReferenceNode:
    """Node for extracting video idea and story from user's messages.

    Only the most recent ``HumanMessage`` in ``state['messages']`` is sent to the
    LLM with ``extract_user_reference_prompt``; the structured reply
    (``UserReferenceResponseFormatter``) supplies ``video_idea`` and
    ``video_story``.
    """

    def __init__(self, llm=llm_gpt):
        """Store the chat model used for extraction (defaults to ``llm_gpt``)."""
        self.llm = llm

    def run(self, state):
        """Extract the video idea and story from the latest human message.

        Returns:
            ``{'video_idea': ..., 'video_story': ...}`` from the LLM's structured
            output, or both set to ``None`` when ``state['messages']`` contains
            no ``HumanMessage``.
        """
        latest_human_message = next(
            (msg for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)),
            None,
        )
        print('Latest human message:', latest_human_message)
        if latest_human_message is None:
            # Nothing to extract from; skip the LLM call instead of failing on None.content.
            return {'video_idea': None, 'video_story': None}

        template = [SystemMessage(content=extract_user_reference_prompt),
                    HumanMessage(content=latest_human_message.content)]
        response = self.llm.with_structured_output(
            UserReferenceResponseFormatter, method='function_calling'
        ).invoke(template)
        print('The extracted reference:', response)
        return {
            'video_idea': response.video_idea,
            'video_story': response.video_story,
        }


class InvokeToolNode:
    def __init__(self, api_key=None):
        """Configure the base URL and request headers used to call tool endpoints.

        Args:
            api_key: Bearer token for the ``Authorization`` header. When ``None``
                (the default) the placeholder ``YOUR_API_KEY`` is sent, exactly as
                before — pass a real key if the endpoints require authentication.
        """
        self.base_url = base_url
        # TODO(review): the placeholder fallback is kept for backward compatibility;
        # consider sourcing the key from configuration instead.
        token = "YOUR_API_KEY" if api_key is None else api_key
        self.headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    def run(self,state:State):
        latest_human_message = next(
            (msg for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)),
            None
        )
        data_to_return=[]
        for tool in state['tools']:
            if 'analytics' in tool:
                url = f'''{self.base_url}{tool}'''
                response = requests.get(url, params=latest_human_message.content,headers=self.headers)
                return {
                    'analytics_response':response.json()
                }