# NOTE(review): the original file carried scraped page artifacts here
# ("Spaces: / Sleeping / Sleeping") that are not valid Python; replaced
# with this comment so the module header is well-formed.
| from .prompts import tool_return_prompt , extract_user_reference_prompt , query_response_prompt,captioning_prompt | |
| from langchain_core.messages import SystemMessage, HumanMessage, FunctionMessage | |
| from src.genai.utils.models_loader import llm_gpt | |
| from src.genai.utils.base_endpoint import base_url | |
| from .state import State | |
| from .tools import InfluencerRetrievalTool | |
| from .schemas import ToolResponseFormatter , UserReferenceResponseFormatter | |
| import os | |
| import requests | |
| from groq import Groq | |
# Module-level retriever shared by QueryResponseNode (used in its run() path
# for influencer-data retrieval). Instantiated once at import time.
retriever=InfluencerRetrievalTool()
class ImageCaptionNode:
    """Node that captions the most recent user-supplied image.

    Sends the latest base64 image in the graph state to Groq's multimodal
    llama-4-scout model and stores the short caption under 'image_caption'.
    """

    def __init__(self, api_key=os.environ.get('GROQ_API_KEY')):
        # NOTE(review): the default is evaluated at import time, so
        # GROQ_API_KEY must be set before this module is imported —
        # otherwise pass api_key explicitly.
        self.client = Groq(api_key=api_key)

    def run(self, state: State):
        """Return {'image_caption': str} for the latest image, or
        {'image_caption': None} when no image is present in state."""
        if not state['image_base64']:
            print('No image provided')
            return {'image_caption': None}
        print('Captioning image')
        chat_completion = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": captioning_prompt(state['messages'])},
                        {
                            "type": "image_url",
                            "image_url": {
                                # BUG FIX: the original wrote
                                # state['image_base64[-1]'] — the index was
                                # inside the key string, which raises KeyError.
                                # Index the list, not the key.
                                "url": f"data:image/jpg;base64,{state['image_base64'][-1]}",
                            },
                        },
                    ],
                }
            ],
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            max_completion_tokens=50,
            temperature=1,
        )
        response = chat_completion.choices[0].message.content
        return {'image_caption': response}
class ToolReturnNode:
    """Node for determining which tools to use based on user messages."""

    def __init__(self, llm=llm_gpt):
        self.llm = llm

    def run(self, state: State):
        """Pick the tools for the latest turn via structured LLM output.

        Returns an assistant message noting the chosen tools plus the
        'tools' list itself for downstream routing.
        """
        history = state["messages"]
        # Bound the prompt size: once history grows past 23 entries,
        # keep only the 18 most recent messages (also trimmed in state).
        if len(history) > 23:
            history = history[-18:]
            state["messages"] = history
        prompt = [SystemMessage(content=tool_return_prompt), *history]
        structured_llm = self.llm.with_structured_output(
            ToolResponseFormatter, method='function_calling'
        )
        response = structured_llm.invoke(prompt)
        print('The response is:', response)
        assistant_note = {'role': 'assistant', 'content': f"Tool invoked: {response.tools}"}
        return {"messages": [assistant_note], "tools": response.tools}
class QueryResponseNode:
    """Node that answers the user's query directly when no tool was chosen,
    or acknowledges the pending tool invocation otherwise."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        print('Entered to query response')
        if state['tools']:
            # A tool was selected — acknowledge instead of answering here.
            return {
                "query_response": f'''Okay i will perform {" ".join(state['tools'])} for you.'''
            }
        # No tool selected: retrieve influencer data and answer with the LLM.
        print('Going for retrieval')
        retrieved_data = retriever.retrieve_for_orchestration(state['messages'])
        print('The data is retrieved.')
        prompt = [
            SystemMessage(content=query_response_prompt),
            FunctionMessage(name='inf-data-retrieval', content=retrieved_data),
            *state["messages"],
        ]
        answer = self.llm.invoke(prompt)
        print('Query Response:', answer)
        return {
            "messages": [{'role': 'assistant', 'content': answer.content}],
            "query_response": answer.content,
        }
class ExtractUserReferenceNode:
    """Node for extracting video idea and story from user's messages."""

    def __init__(self, llm=llm_gpt):
        self.llm = llm

    def run(self, state):
        """Extract the video idea/story from the most recent human message.

        Returns a dict with 'video_idea' and 'video_story'; both are None
        when the conversation contains no HumanMessage.
        """
        latest_human_message = next(
            (msg for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)),
            None
        )
        print('Latest human message:', latest_human_message)
        # BUG FIX: the original dereferenced .content unconditionally, so an
        # all-assistant/system history crashed with AttributeError on None.
        if latest_human_message is None:
            return {'video_idea': None, 'video_story': None}
        template = [
            SystemMessage(content=extract_user_reference_prompt),
            HumanMessage(content=latest_human_message.content),
        ]
        response = self.llm.with_structured_output(
            UserReferenceResponseFormatter, method='function_calling'
        ).invoke(template)
        print('The extracted reference:', response)
        return {
            'video_idea': response.video_idea,
            'video_story': response.video_story
        }
class InvokeToolNode:
    """Node that calls the backend HTTP endpoint for the selected tool."""

    def __init__(self):
        self.base_url = base_url
        self.headers = {
            # SECURITY/TODO(review): placeholder credential hard-coded in
            # source — inject the real key from configuration/env instead.
            "Authorization": "Bearer YOUR_API_KEY",  # replace with your API key if needed
            "Content-Type": "application/json"
        }

    def run(self, state: State):
        """Invoke the first analytics tool listed in state['tools'].

        Returns {'analytics_response': <decoded JSON>} for the first tool
        whose name contains 'analytics'. Returns None (state unchanged) when
        no analytics tool was selected; non-analytics tools are ignored here.
        """
        latest_human_message = next(
            (msg for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)),
            None
        )
        for tool in state['tools']:
            if 'analytics' in tool:
                # BUG FIX: the original crashed with AttributeError when the
                # conversation contained no HumanMessage.
                query = latest_human_message.content if latest_human_message else None
                url = f'{self.base_url}{tool}'
                # NOTE(review): `params` receives the raw message string, which
                # requests appends verbatim as the query string — confirm the
                # endpoint expects this rather than a key/value mapping.
                response = requests.get(url, params=query, headers=self.headers)
                return {'analytics_response': response.json()}
        # No analytics tool selected: fall through (implicit None), matching
        # the original behavior. (The original also built an unused
        # data_to_return list, removed here as dead code.)