# subashpoudel — "Refined chatbot" (commit 6f57d05)
from .prompts import tool_return_prompt , extract_user_reference_prompt , query_response_prompt,captioning_prompt
from langchain_core.messages import SystemMessage, HumanMessage, FunctionMessage
from src.genai.utils.models_loader import llm_gpt
from src.genai.utils.base_endpoint import base_url
from .state import State
from .tools import InfluencerRetrievalTool
from .schemas import ToolResponseFormatter , UserReferenceResponseFormatter
import os
import requests
from groq import Groq
# Shared retrieval tool, created once at import time and reused by
# QueryResponseNode for every query.
retriever = InfluencerRetrievalTool()
class ImageCaptionNode:
    """Graph node that captions the most recent user-supplied image via Groq.

    Reads ``state['image_base64']`` (a sequence of base64-encoded images; only
    the last one is captioned) and ``state['messages']`` (passed to
    ``captioning_prompt`` to build the text part of the request). Writes
    ``image_caption`` to the state, or ``None`` when no image is present.
    """

    def __init__(self, api_key=None):
        """Create the Groq client.

        Args:
            api_key: Groq API key. When omitted, ``GROQ_API_KEY`` is read from
                the environment when the node is constructed, not once at
                import time as a default-argument lookup would.
        """
        if api_key is None:
            api_key = os.environ.get('GROQ_API_KEY')
        self.client = Groq(api_key=api_key)

    def run(self, state: State):
        """Caption the latest image in ``state``.

        Returns:
            ``{'image_caption': <model caption>}``, or
            ``{'image_caption': None}`` when ``image_base64`` is missing or empty.
        """
        images = state.get('image_base64')
        if not images:
            print('No image provided')
            return {'image_caption': None}

        print('Captioning image')
        chat_completion = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": captioning_prompt(state['messages'])},
                        {
                            "type": "image_url",
                            "image_url": {
                                # Index the list itself; the previous code looked up a
                                # key literally named 'image_base64[-1]'.
                                # NOTE(review): assumes the payload is JPEG — confirm.
                                "url": f"data:image/jpeg;base64,{images[-1]}",
                            },
                        },
                    ],
                }
            ],
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            max_completion_tokens=50,
            temperature=1,
        )
        return {'image_caption': chat_completion.choices[0].message.content}
class ToolReturnNode:
    """Node for determining which tools to use based on user messages.

    Sends the tool-selection system prompt plus the (possibly truncated)
    conversation to the LLM and parses the reply into ``ToolResponseFormatter``.
    """

    def __init__(self, llm=llm_gpt, max_history=23, keep_history=18):
        """
        Args:
            llm: Chat model used for tool selection.
            max_history: Once the conversation has more than this many
                messages, only the most recent ``keep_history`` are sent.
            keep_history: Number of trailing messages kept after truncation.
        """
        self.llm = llm
        self.max_history = max_history
        self.keep_history = keep_history

    def run(self, state: State):
        """Ask the LLM which tools to invoke.

        Returns:
            A state update with an assistant message recording the chosen
            tools, and ``tools`` set to the structured output's ``tools``.
        """
        messages = state["messages"]
        if len(messages) > self.max_history:
            # Truncate a local copy only; the caller's state is not mutated.
            # max(..., 0) keeps keep_history=0 meaning "keep nothing".
            messages = messages[max(len(messages) - self.keep_history, 0):]
        template = [SystemMessage(content=tool_return_prompt)] + list(messages)
        response = self.llm.with_structured_output(
            ToolResponseFormatter, method='function_calling'
        ).invoke(template)
        print('The response is:', response)
        return {
            "messages": [{'role': 'assistant', 'content': f"Tool invoked: {response.tools}"}],
            "tools": response.tools,
        }
class QueryResponseNode:
    """Answer the user's query from retrieved influencer data, or acknowledge tool use."""

    def __init__(self, llm=llm_gpt):
        """
        Args:
            llm: Chat model used to answer the query. This matches the other
                nodes' constructors; the default keeps the previous behaviour.
        """
        self.llm = llm

    def run(self, state: State):
        """Produce ``query_response`` for the current turn.

        If ``state['tools']`` is non-empty, return an acknowledgement naming the
        tools without calling the LLM. Otherwise, retrieve data with the
        module-level ``retriever``, pass it to the LLM as a function message
        together with the conversation, and return the answer both as an
        assistant message and as ``query_response``.
        """
        print('Entered to query response')
        # A missing or None 'tools' entry is treated as "no tools selected".
        tools = state.get('tools')
        if tools:
            return {
                "query_response": f'''Okay i will perform {" ".join(tools)} for you.'''
            }

        print('Going for retrieval')
        retrieved_data = retriever.retrieve_for_orchestration(state['messages'])
        print('The data is retrieved.')
        template = [
            SystemMessage(content=query_response_prompt),
            FunctionMessage(name='inf-data-retrieval', content=retrieved_data),
        ] + state["messages"]
        response = self.llm.invoke(template)
        print('Query Response:', response)
        return {
            "messages": [{'role': 'assistant', 'content': response.content}],
            "query_response": response.content,
        }
class ExtractUserReferenceNode:
    """Node for extracting video idea and story from user's messages."""

    def __init__(self, llm=llm_gpt):
        """
        Args:
            llm: Chat model used for structured extraction.
        """
        self.llm = llm

    def run(self, state):
        """Extract ``video_idea`` and ``video_story`` from the latest human message.

        Returns:
            ``{'video_idea': ..., 'video_story': ...}`` from the
            ``UserReferenceResponseFormatter`` structured output. Both values
            are ``None`` when the conversation contains no ``HumanMessage``.
        """
        latest_human_message = next(
            (msg for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)),
            None,
        )
        print('Latest human message:', latest_human_message)
        if latest_human_message is None:
            # Nothing to extract from; previously this crashed on `.content`.
            return {'video_idea': None, 'video_story': None}

        template = [
            SystemMessage(content=extract_user_reference_prompt),
            HumanMessage(content=latest_human_message.content),
        ]
        response = self.llm.with_structured_output(
            UserReferenceResponseFormatter, method='function_calling'
        ).invoke(template)
        print('The extracted reference:', response)
        return {
            'video_idea': response.video_idea,
            'video_story': response.video_story,
        }
class InvokeToolNode:
    """Call the backend analytics endpoint for the selected tool."""

    def __init__(self, api_key=None, timeout=30):
        """
        Args:
            api_key: Optional bearer token. When given, it is sent as
                ``Authorization: Bearer <api_key>``. Previously a literal
                placeholder token was always sent.
            timeout: Seconds to wait for the HTTP request. Without a timeout,
                ``requests`` can block forever.
        """
        self.base_url = base_url
        self.timeout = timeout
        self.headers = {"Content-Type": "application/json"}
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

    def run(self, state: State):
        """GET ``{base_url}{tool}`` for the first tool whose name contains 'analytics'.

        Only the first matching tool is called, which matches the previous
        early-return behaviour.

        Returns:
            ``{'analytics_response': <decoded JSON body>}``, or ``{}`` (no state
            update) when no analytics tool was selected or there is no human
            message to send.
        """
        analytics_tool = next(
            (tool for tool in (state.get('tools') or []) if 'analytics' in tool),
            None,
        )
        if analytics_tool is None:
            return {}

        latest_human_message = next(
            (msg for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)),
            None,
        )
        if latest_human_message is None:
            return {}

        url = f'{self.base_url}{analytics_tool}'
        # NOTE(review): the raw user text is passed as the query string, so '&'
        # and '=' in it can inject extra parameters. Pass a dict such as
        # {'<param>': text} once the endpoint's expected parameter name is
        # confirmed.
        response = requests.get(
            url,
            params=latest_human_message.content,
            headers=self.headers,
            timeout=self.timeout,
        )
        return {'analytics_response': response.json()}