subashpoudel's picture
Updated the tools
1ce8b88
raw
history blame
2.52 kB
from langchain_groq import ChatGroq
from pydantic import BaseModel, Field
from dotenv import load_dotenv
load_dotenv()
import os
import numpy as np
from langchain_core.tools import tool
from .data_loader import load_influencer_data
from .models_loader import ST , llm
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from langchain_core.messages import SystemMessage
import re
# Mirror GROQ_API_KEY (possibly loaded from a .env file by load_dotenv() above)
# into os.environ. os.getenv returns None when the variable is unset, and
# assigning None to os.environ raises TypeError at import time, so only copy a
# real value.
_groq_api_key = os.getenv('GROQ_API_KEY')
if _groq_api_key is not None:
    os.environ['GROQ_API_KEY'] = _groq_api_key
class StoryFormatter(BaseModel):
    """Always use this tool to structure your response to the user."""
    # NOTE(review): the docstring above and every Field description below are
    # runtime text. Pydantic emits them in the model's JSON schema, and the
    # docstring indicates this class is meant to be used as an LLM "tool", so
    # they presumably end up in the prompt. Edit them as prompt text, and keep
    # field names unchanged (they are the keys of the structured response).
    story: str = Field(description="How to introduce the scene and set the tone. What is happening in the scene? Describe key visuals and actions")
    narration: str = Field(description="Suggestions for narration or voiceover that complements the visuals.")
    # Mixed-case field name is part of the schema; do not rename without
    # updating every consumer of the structured output.
    text_in_the_Video: str = Field(description="Propose important text overlays for key moments.")
    transitions: str = Field(description="Smooth transitions between scenes to maintain flow.")
    # Closing parenthesis added: the original description left "(e.g., ..." open.
    emotional_tone: str = Field(description="The mood and energy of the scenes (e.g., excitement, calm, tension, joy)")
    key_visuals: str = Field(description="Important props, locations, sound effects, or background music to enhance the video.")
# Structured-output schema holding four brainstorming topics for a story.
# NOTE(review): the class name misspells "Brainstorm"; it is presumably imported
# by this name elsewhere, so rename only together with its callers.
# NOTE(review): there is deliberately no class docstring here — a docstring would
# become the JSON-schema description and change what an LLM sees.
class BrainstromTopicFormatter(BaseModel):
    # Each Field description is runtime schema text, not just documentation.
    topic1:str=Field(description="First brainstorming topic of the story")
    topic2:str=Field(description="Second brainstorming topic of the story")
    topic3:str=Field(description="Third brainstorming topic of the story")
    topic4:str=Field(description="Fourth brainstorming topic of the story")
# Argument schema for the influencer data-retrieval tool (passed as args_schema
# to the @tool decorator on retrieve_tool). Field names must match that
# function's parameter names; the descriptions are runtime schema text.
class QueryFormatter(BaseModel):
    # Free-text user query; embedded together with business_details.
    messages:str = Field(description="The user query")
    # Free-text description of the user's business, appended to the query text.
    business_details: str = Field(description="The details of the business of that user.")
# NOTE(review): the tool name contains a space and an apostrophe. OpenAI-style
# function-calling APIs restrict names to [a-zA-Z0-9_-]; confirm Groq accepts it
# when this tool is bound to the LLM. Renaming changes the name the LLM calls.
@tool("influencer's data-retrieval-tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve influencer-related data for a given query.")
def retrieve_tool(messages, business_details):
    """Always invoke this tool once.

    Embed the user query together with the business details and look up the
    two nearest examples in the influencer dataset.

    Args:
        messages: The user query (see ``QueryFormatter.messages``).
        business_details: Free-text details of the user's business.

    Returns:
        A list of single-entry dicts ``{username: agentic_story}``, one per
        retrieved example, in the order returned by ``get_nearest_examples``.
    """
    print('The query for retrieval is:', messages)
    # Join with a separator: plain concatenation fused the last word of the
    # query with the first word of the business details (e.g. "cafe" +
    # "Bakery" -> "cafeBakery"), which distorts the sentence embedding.
    query_text = f"{messages}\n{business_details}"
    embedded_query = ST.encode(query_text)
    # NOTE(review): the dataset is loaded on every call; confirm that
    # load_influencer_data caches, otherwise consider loading it once.
    data = load_influencer_data()
    # Assumes `data` has a nearest-neighbour index on the "embeddings" column
    # (required by get_nearest_examples) — TODO confirm. Scores are unused.
    _scores, retrieved = data.get_nearest_examples("embeddings", embedded_query, k=2)
    result = [
        {user: story}
        for user, story in zip(retrieved['username'], retrieved['agentic_story'])
    ]
    print('Tool response:', result)
    return result