Spaces:
Sleeping
Sleeping
| import os | |
| from groq import Groq | |
| from .state import State | |
| from .tools import Retrieval | |
| from .state import BrainstromTopicFormatter | |
| from src.genai.utils.models_loader import llm_gpt , captioning_model | |
| from langchain_core.messages import SystemMessage ,HumanMessage, FunctionMessage | |
| from .prompts import image_captioning_prompt , initial_story_prompt , refined_story_prompt , brainstroming_prompt | |
class ImageCaptioner:
    """Graph node that captions the most recent image held in the state.

    Captions are produced by a Groq chat-completion call that uses
    ``captioning_model`` and the text returned by
    ``image_captioning_prompt(state)``.
    """

    def __init__(self):
        self.captioning_model = captioning_model
        # NOTE(review): os.environ.get returns None when GROQ_API_KEY is unset;
        # confirm how the Groq client behaves in that case.
        self.client = Groq(api_key=os.environ.get('GROQ_API_KEY'))

    def run(self, state: State) -> State:
        """Caption ``state.images[-1]`` and return the mutated state.

        * ``state.images`` is empty: ``None`` is appended to both
          ``state.images`` and ``state.image_captions``.
        * The latest image is ``None``: ``None`` is appended to
          ``state.image_captions``.
        * Otherwise the latest image is sent to the captioning model, and the
          reply text is appended to ``state.image_captions``.

        Args:
            state: Graph state. Its ``images`` and ``image_captions`` lists
                are mutated in place.

        Returns:
            The same ``state`` object on every path.
        """
        if not state.images:
            state.images.append(None)
            state.image_captions.append(None)
            return state

        latest_image = state.images[-1]
        if latest_image is None:
            # Without this branch the method returned None instead of the state.
            # It would also leave no caption entry for this run, which breaks
            # consumers that read image_captions[-1].
            state.image_captions.append(None)
            return state

        print('Captioning image')
        chat_completion = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": image_captioning_prompt(state)},
                        {
                            "type": "image_url",
                            "image_url": {
                                # Embedded as a base64 data URL. This assumes
                                # state.images holds base64-encoded JPEG data.
                                # TODO confirm for other image formats.
                                "url": f"data:image/jpeg;base64,{latest_image}",
                            },
                        },
                    ],
                }
            ],
            model=self.captioning_model,
            # Captions are meant to be short; cap the reply at 50 tokens.
            max_completion_tokens=50,
            temperature=1,
        )
        state.image_captions.append(chat_completion.choices[0].message.content)
        return state
class Retriever:
    """Graph node that looks up influencer data for the current topics.

    On the first pass (the user has not picked preferred topics yet), each
    entry of ``state.idea`` is used as a query. After that, the topics the user
    has just picked (``state.latest_preferred_topics``) are used instead.
    """

    # BGE-style retrieval instruction. It must be *prepended* to the query
    # text; appending it after the query defeats its purpose.
    QUERY_PROMPT = 'Represent this sentence for searching relevant passages: '

    def __init__(self):
        # Results of the most recent run() call. This is rebuilt on every call
        # so that results from earlier runs or sessions do not leak in.
        self.retrievals = []

    @staticmethod
    def _query_for(text):
        """Return the retrieval query for *text*, with the instruction prefixed."""
        return Retriever.QUERY_PROMPT + text

    def _retrieve(self, queries):
        """Run one influencer lookup per query and collect the results."""
        return [Retrieval(self._query_for(q)).influencers_data() for q in queries]

    def run(self, state: State) -> State:
        """Retrieve influencer data and append this run's results to the state.

        If ``state.latest_preferred_topics`` is non-empty, that list is moved
        into ``state.preferred_topics`` and used for the queries. The attribute
        is reset to ``[]`` only after retrieval succeeds. Otherwise every item
        of ``state.idea`` is queried.

        Args:
            state: Graph state. ``retrievals`` (and, on the preferred-topics
                path, ``preferred_topics`` / ``latest_preferred_topics``) are
                mutated.

        Returns:
            The same ``state`` object. ``state.retrievals[-1]`` holds a fresh
            list with one result per query from this call only.
        """
        if state.latest_preferred_topics:
            state.preferred_topics.append(state.latest_preferred_topics)
            results = self._retrieve(state.preferred_topics[-1])
            state.latest_preferred_topics = []
        else:
            # NOTE(review): this assumes state.idea is a list of strings. If it
            # is a plain str, this iterates characters. Confirm against State.
            results = self._retrieve(state.idea)
        self.retrievals = results
        state.retrievals.append(results)
        return state
class StoryGenerator:
    """LLM node that drafts a video story from the accumulated state."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State) -> State:
        """Generate one story and append its text to ``state.stories``.

        ``initial_story_prompt`` provides the system prompt until
        ``state.preferred_topics`` has entries; after that
        ``refined_story_prompt`` is used. The video idea is sent as the human
        turn. The business details, the latest retrieval results, and the
        latest image caption are sent as a function message.

        Returns:
            The same ``state`` object, with the model's reply appended.
        """
        has_preferences = len(state.preferred_topics) != 0
        if has_preferences:
            system_text = refined_story_prompt(state)
        else:
            system_text = initial_story_prompt(state)

        conversation = [
            SystemMessage(content=system_text),
            HumanMessage(content=f'''The idea of the video is:\n{state.idea}\n'''),
            FunctionMessage(
                name='generate_story_function',
                content=f'''The business details is:\n{state.business_details}\n
The retrieved data of influencers is:\n{state.retrievals[-1]}\n
The information from the image is:\n{state.image_captions[-1]} ''',
            ),
        ]
        story = self.llm.invoke(conversation).content
        state.stories.append(story)
        return state
class BrainstromTopicGenerator:
    """LLM node that turns the latest story into structured brainstorming topics."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State) -> State:
        """Brainstorm topics for ``state.stories[-1]`` and record them.

        The model is bound to the ``BrainstromTopicFormatter`` schema. The
        parsed result is converted to a plain dict with ``model_dump()`` and
        appended to ``state.brainstroming_topics``. The outgoing messages and
        the accumulated topics are printed for debugging.

        Returns:
            The same ``state`` object.
        """
        system_text = brainstroming_prompt(state)
        story_turn = HumanMessage(
            content=f'''Here is the story to you for brainstorming:\n{state.stories[-1]}'''
        )
        business_turn = FunctionMessage(
            content=f'''The details of business is:\n{state.business_details}\n''',
            name="brainstorm_tool",
        )
        conversation = [SystemMessage(content=system_text), story_turn, business_turn]
        print('Message for brainstorming:', conversation)

        structured_llm = self.llm.with_structured_output(BrainstromTopicFormatter)
        topics = structured_llm.invoke(conversation).model_dump()
        state.brainstroming_topics.append(topics)
        print('The brainstroming topics are:', state.brainstroming_topics)
        return state