# NOTE: commit d98138c (subashpoudel) — "Fixed importing errors".
# (Pasted GitHub header converted to a comment so the module parses.)
import os
from groq import Groq
from .state import State
from .tools import Retrieval
from .state import BrainstromTopicFormatter
from src.genai.utils.models_loader import llm_gpt , captioning_model
from langchain_core.messages import SystemMessage ,HumanMessage, FunctionMessage
from .prompts import image_captioning_prompt , initial_story_prompt , refined_story_prompt , brainstroming_prompt
class ImageCaptioner:
    """Captions the most recent image in the state via the Groq vision API.

    ``state.images`` and ``state.image_captions`` are kept index-aligned:
    when no image was supplied, a ``None`` placeholder is appended to both.
    """

    def __init__(self):
        # Vision model name shared via models_loader; Groq client reads GROQ_API_KEY.
        self.captioning_model = captioning_model
        self.client = Groq(api_key=os.environ.get('GROQ_API_KEY'))

    def run(self, state: "State") -> "State":
        """Caption the latest image (if present) and record it on the state.

        Bug fix: the original fell through without a ``return`` when
        ``state.images`` was non-empty but its last entry was ``None`` (the
        placeholder appended by a previous run), so the node returned ``None``.
        ``state`` is now returned on every path.
        """
        if len(state.images) > 0:
            # `is not None` rather than `!= None` (identity check for the sentinel).
            if state.images[-1] is not None:
                print('Captioning image')
                chat_completion = self.client.chat.completions.create(
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": image_captioning_prompt(state)},
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        # Images are stored as base64-encoded JPEG strings.
                                        "url": f"data:image/jpeg;base64,{state.images[-1]}",
                                    },
                                },
                            ],
                        }
                    ],
                    model=self.captioning_model,
                    max_completion_tokens=50,
                    temperature=1,
                )
                state.image_captions.append(chat_completion.choices[0].message.content)
        else:
            # No image this turn: keep both lists aligned with None placeholders.
            state.images.append(None)
            state.image_captions.append(None)
        return state
class Retriever:
    """Retrieves influencer passages for the current ideas or preferred topics.

    Appends exactly one list of retrieval results per run to ``state.retrievals``.
    """

    def __init__(self):
        # Kept for backward compatibility with any code that inspects it; per-run
        # results are now collected in a local list (see bug note in ``run``).
        self.retrievals = []

    def run(self, state: "State") -> "State":
        """Run retrieval and append this run's results to ``state.retrievals``.

        Bug fixes vs. the original:
        * results were accumulated on ``self.retrievals`` across runs, so every
          call re-appended all previous runs' stale results into the state;
        * the BGE-style instruction ("Represent this sentence for searching
          relevant passages: ") was concatenated *after* the query text even
          though it is defined as a query prefix.
        """
        query_prompt = 'Represent this sentence for searching relevant passages: '
        results = []
        if len(state.latest_preferred_topics) == 0:
            # First pass: search with the raw video ideas.
            for idea in state.idea:
                results.append(Retrieval(query_prompt + idea).influencers_data())
        else:
            # Refinement pass: archive the new preferences, then search with them.
            state.preferred_topics.append(state.latest_preferred_topics)
            for topic in state.preferred_topics[-1]:
                results.append(Retrieval(query_prompt + topic).influencers_data())
            state.latest_preferred_topics = []
        state.retrievals.append(results)
        return state
class StoryGenerator:
    """Generates (or refines) a story for the video idea with the shared GPT model."""
    def __init__(self):
        # Shared LLM handle created once in models_loader.
        self.llm = llm_gpt
    def run(self,state:State)-> State:
        """Draft a story and append it to ``state.stories``.

        Uses the initial prompt until the user has expressed preferred topics,
        then switches to the refinement prompt.
        """
        if len(state.preferred_topics)==0:
            template = initial_story_prompt(state)
        else:
            template = refined_story_prompt(state)
        # NOTE(review): the [-1] indexing assumes the Retriever and ImageCaptioner
        # nodes ran before this one — IndexError on empty lists otherwise; confirm
        # the graph wiring guarantees that ordering.
        messages = [SystemMessage(content=template),
                    HumanMessage(content=f'''The idea of the video is:\n{state.idea}\n'''),
                    FunctionMessage(name='generate_story_function',content=f'''The business details is:\n{state.business_details}\n
                    The retrieved data of influencers is:\n{state.retrievals[-1]}\n
                    The information from the image is:\n{state.image_captions[-1]} ''')]
        response = self.llm.invoke(messages)
        response = response.content
        state.stories.append(response)
        return state
class BrainstromTopicGenerator:
    """Brainstorms follow-up topics for the newest story via structured LLM output."""

    def __init__(self):
        # Shared GPT handle loaded at module import.
        self.llm = llm_gpt

    def run(self, state: State) -> State:
        """Append a ``BrainstromTopicFormatter`` dump for the latest story to the state."""
        system_prompt = brainstroming_prompt(state)
        chat = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f'''Here is the story to you for brainstorming:\n{state.stories[-1]}'''),
            FunctionMessage(content=f'''The details of business is:\n{state.business_details}\n''', name="brainstorm_tool"),
        ]
        print('Message for brainstorming:', chat)
        structured_llm = self.llm.with_structured_output(BrainstromTopicFormatter)
        topics = structured_llm.invoke(chat).model_dump()
        state.brainstroming_topics.append(topics)
        print('The brainstroming topics are:', state.brainstroming_topics)
        return state