import os
from groq import Groq
from .state import State
from .tools import Retrieval
from .state import BrainstromTopicFormatter
from src.genai.utils.models_loader import llm_gpt , captioning_model
from langchain_core.messages import SystemMessage ,HumanMessage, FunctionMessage
from .prompts import image_captioning_prompt , initial_story_prompt , refined_story_prompt , brainstroming_prompt
class ImageCaptioner:
def __init__(self):
self.captioning_model = captioning_model
self.client = Groq(api_key=os.environ.get('GROQ_API_KEY'))
def run(self, state: State) -> State:
if len(state.images)>0:
if state.images[-1]!=None:
print('Captioning image')
chat_completion = self.client.chat.completions.create(
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": image_captioning_prompt(state)},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{state.images[-1]}",
},
},
],
}
],
model=self.captioning_model,
max_completion_tokens=50,
temperature = 1
)
response=chat_completion.choices[0].message.content
state.image_captions.append(response)
return state
else:
state.images.append(None)
state.image_captions.append(None)
return state
class Retriever:
    """Pipeline node that retrieves influencer passages for the active topics.

    Uses the user's refined topics (``latest_preferred_topics``) when present,
    otherwise the original ideas, and appends the list of retrieval results to
    ``state.retrievals``.
    """

    # BGE-family embedding models expect this instruction *prepended* to the
    # query text (it is a query prefix, not a suffix).
    _QUERY_PREFIX = 'Represent this sentence for searching relevant passages: '

    def __init__(self):
        # Kept for backward compatibility with any external readers, but reset
        # on each run. Original bug: this list was created once and appended to
        # on every call, so results accumulated across runs and every entry of
        # state.retrievals aliased the same ever-growing list.
        self.retrievals = []

    def run(self, state: State) -> State:
        """Retrieve passages for each query and return the updated state."""
        self.retrievals = []
        if state.latest_preferred_topics:
            # The user refined the topics: archive the new preference set and
            # retrieve for it, then clear the one-shot "latest" field.
            state.preferred_topics.append(state.latest_preferred_topics)
            queries = state.preferred_topics[-1]
            state.latest_preferred_topics = []
        else:
            # First pass: retrieve for the original ideas.
            queries = state.idea
        for query in queries:
            # Bug fix vs. original: the instruction was concatenated *after*
            # the query (idea + prefix); it must come first.
            result = Retrieval(self._QUERY_PREFIX + query).influencers_data()
            self.retrievals.append(result)
        state.retrievals.append(self.retrievals)
        return state
class StoryGenerator:
    """Pipeline node that drafts (or refines) a story with the chat LLM.

    Picks the initial prompt on the first pass and the refinement prompt once
    the user has expressed topic preferences, then appends the generated story
    to ``state.stories``.
    """

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State) -> State:
        """Invoke the LLM with idea, business, retrieval and image context."""
        # Refinement prompt once the user has picked topics; initial otherwise.
        has_preferences = bool(state.preferred_topics)
        template = refined_story_prompt(state) if has_preferences else initial_story_prompt(state)

        system_msg = SystemMessage(content=template)
        human_msg = HumanMessage(content=f'''The idea of the video is:\n{state.idea}\n''')
        function_msg = FunctionMessage(
            name='generate_story_function',
            content=f'''The business details is:\n{state.business_details}\n
The retrieved data of influencers is:\n{state.retrievals[-1]}\n
The information from the image is:\n{state.image_captions[-1]} ''',
        )

        reply = self.llm.invoke([system_msg, human_msg, function_msg])
        state.stories.append(reply.content)
        return state
class BrainstromTopicGenerator:
    """Pipeline node that brainstorms follow-up topics for the latest story.

    Calls the LLM with structured output (``BrainstromTopicFormatter``) and
    appends the resulting dict to ``state.brainstroming_topics``.
    """

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State) -> State:
        """Produce structured brainstorming topics and return the state."""
        msgs = [
            SystemMessage(content=brainstroming_prompt(state)),
            HumanMessage(content=f'''Here is the story to you for brainstorming:\n{state.stories[-1]}'''),
            FunctionMessage(content=f'''The details of business is:\n{state.business_details}\n''', name="brainstorm_tool"),
        ]
        print('Message for brainstorming:', msgs)

        # Constrain the reply to the BrainstromTopicFormatter schema, then
        # store it as a plain dict.
        structured_llm = self.llm.with_structured_output(BrainstromTopicFormatter)
        topics = structured_llm.invoke(msgs).model_dump()
        state.brainstroming_topics.append(topics)
        print('The brainstroming topics are:', state.brainstroming_topics)
        return state