File size: 4,443 Bytes
a9f99c3
ef9fa4b
 
 
93a5bf9
d98138c
ef9fa4b
 
fb491f0
a9f99c3
ef9fa4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb491f0
ef9fa4b
 
 
 
a0929ab
be3a5c4
ef9fa4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be3a5c4
 
 
415ac2b
be3a5c4
 
 
a9f99c3
8039e4b
a9f99c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from groq import Groq
from .state import State
from .tools import Retrieval
from .state import BrainstromTopicFormatter
from src.genai.utils.models_loader import llm_gpt , captioning_model
from langchain_core.messages import SystemMessage ,HumanMessage, FunctionMessage
from .prompts import image_captioning_prompt , initial_story_prompt , refined_story_prompt , brainstroming_prompt 


class ImageCaptioner:
    """Pipeline node that captions the most recent image on the state.

    If the latest entry in ``state.images`` is a base64-encoded JPEG, it is
    sent to the Groq vision model and the resulting caption is appended to
    ``state.image_captions``. If no image was ever provided, ``None``
    placeholders are appended to both lists so downstream nodes can index
    ``[-1]`` safely.
    """

    def __init__(self):
        # Vision model id comes from the shared models loader.
        self.captioning_model = captioning_model
        # NOTE(review): Groq() raises if GROQ_API_KEY is unset — presumably
        # guaranteed by deployment config; verify.
        self.client = Groq(api_key=os.environ.get('GROQ_API_KEY'))

    def run(self, state: State) -> State:
        """Caption the latest image (if any) and return the updated state.

        BUG FIX: the original returned ``None`` (dropping the state) when
        ``state.images`` was non-empty but its last element was ``None``;
        every path now returns ``state``.
        """
        if state.images and state.images[-1] is not None:
            print('Captioning image')

            chat_completion = self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": image_captioning_prompt(state)},
                            {
                                "type": "image_url",
                                "image_url": {
                                    # Images are stored base64-encoded on the state.
                                    "url": f"data:image/jpeg;base64,{state.images[-1]}",
                                },
                            },
                        ],
                    }
                ],
                model=self.captioning_model,
                max_completion_tokens=50,
                temperature=1,
            )
            state.image_captions.append(chat_completion.choices[0].message.content)
        elif not state.images:
            # No image supplied at all: pad both lists so [-1] stays valid.
            state.images.append(None)
            state.image_captions.append(None)
        # images non-empty with a None tail: nothing to caption, state unchanged.
        return state
    

class Retriever:
    """Pipeline node that retrieves influencer data for the current topics.

    Queries the retrieval tool once per topic — the initial ideas on the
    first pass, or the user's latest preferred topics on refinement passes —
    and appends the batch of results to ``state.retrievals``.
    """

    def __init__(self):
        # Kept for backward compatibility with any external readers; holds
        # the most recent batch of retrieval results.
        self.retrievals = []

    def run(self, state: State) -> State:
        """Run retrieval for the active topic list and return the state.

        BUG FIXES vs. original:
        - ``self.retrievals`` was never reset, so a second call appended the
          first call's (stale) results into ``state.retrievals`` as well.
          Each run now builds a fresh batch.
        - The BGE-style instruction was appended *after* the topic
          (``idea + query_prompt``); the convention for
          "Represent this sentence for searching relevant passages: " is to
          prefix the query, so it is now prepended.
        """
        query_prompt = 'Represent this sentence for searching relevant passages: '

        if state.latest_preferred_topics:
            # Refinement pass: promote the user's latest picks and clear them.
            state.preferred_topics.append(state.latest_preferred_topics)
            topics = state.preferred_topics[-1]
            state.latest_preferred_topics = []
        else:
            # First pass: retrieve for the original video ideas.
            topics = state.idea

        batch = [Retrieval(query_prompt + topic).influencers_data() for topic in topics]
        self.retrievals = batch  # expose latest batch on the instance
        state.retrievals.append(batch)
        return state

class StoryGenerator:
    """Pipeline node that drafts (or refines) a story for the video idea.

    Uses the initial-story prompt on the first pass (no preferred topics
    collected yet) and the refined-story prompt afterwards, then appends
    the generated story to ``state.stories``.
    """

    def __init__(self):
        # Shared chat model instance from the models loader.
        self.llm = llm_gpt

    def run(self, state: State) -> State:
        """Generate one story from the current state and return the state."""
        # First pass -> initial prompt; any later pass -> refined prompt.
        if state.preferred_topics:
            template = refined_story_prompt(state)
        else:
            template = initial_story_prompt(state)

        conversation = [
            SystemMessage(content=template),
            HumanMessage(content=f'''The idea of the video is:\n{state.idea}\n'''),
            FunctionMessage(name='generate_story_function', content=f'''The business details is:\n{state.business_details}\n
                                The retrieved data of influencers is:\n{state.retrievals[-1]}\n
                                The information from the image is:\n{state.image_captions[-1]} '''),
        ]

        story_text = self.llm.invoke(conversation).content
        state.stories.append(story_text)
        return state


class BrainstromTopicGenerator:
    """Pipeline node that brainstorms topic suggestions for the latest story.

    Invokes the LLM with structured output (``BrainstromTopicFormatter``)
    and appends the resulting dict to ``state.brainstroming_topics``.
    """

    def __init__(self):
        # Shared chat model instance from the models loader.
        self.llm = llm_gpt

    def run(self, state: State) -> State:
        """Brainstorm topics for ``state.stories[-1]`` and return the state."""
        template = brainstroming_prompt(state)

        conversation = [
            SystemMessage(content=template),
            HumanMessage(content=f'''Here is the story to you for brainstorming:\n{state.stories[-1]}'''),
            FunctionMessage(content=f'''The details of business is:\n{state.business_details}\n''', name="brainstorm_tool"),
        ]
        print('Message for brainstorming:', conversation)

        # Structured output guarantees the response parses into the formatter
        # schema; model_dump() converts the pydantic object to a plain dict.
        structured = self.llm.with_structured_output(BrainstromTopicFormatter)
        topics = structured.invoke(conversation).model_dump()
        state.brainstroming_topics.append(topics)
        print('The brainstroming topics are:', state.brainstroming_topics)
        return state