from langchain_core.messages import SystemMessage, HumanMessage, FunctionMessage
from .tools import Retrieval
import base64
from PIL import Image
from io import BytesIO
from fastapi import UploadFile
from huggingface_hub import InferenceClient
from .prompts import story_to_prompt, final_story_prompt
import os
from langgraph.prebuilt import create_react_agent
from src.genai.utils.models_loader import llm_gpt, image_generation_model
from langsmith import traceable

class FinalStoryGenerator:
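    """Streams the final scene-by-scene video story using a ReAct agent built on the shared llm_gpt model."""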
    def __init__(self):
        self.llm = llm_gpt
        self.agent = create_react_agent(model=llm_gpt, tools=[])
    
    @traceable(name="final-story")
    def generate_final_story(self, final_state):
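        """Yield the final story as streamed text chunks.

        With non-empty preferred topics, the prompt is grounded in the stored business details
        and the latest influencer retrieval; with an empty topic list, the last drafted story
        is replayed; without the key, influencer data is retrieved from the raw state instead.
        """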
        if 'preferred_topics' in final_state:
            if len(final_state['preferred_topics']) > 0:
                template = final_story_prompt(final_state)
                messages = [SystemMessage(content=template),
                            HumanMessage(content=f'''The idea of the video is:\n{final_state['idea']}\n'''),
                            FunctionMessage(content=f'''The business details are:\n{final_state['business_details']}\nThe data of influencers is:\n{final_state['retrievals'][-1]}''', name='final_story_tool')]

                for message_chunk, metadata in self.agent.stream({'messages': messages}, stream_mode='messages'):
                    yield message_chunk.content

            else:
                yield from final_state['stories'][-1]
        else:
            template = final_story_prompt(final_state)
            influencers_data = Retrieval(str(final_state)).influencers_data()
            messages = [SystemMessage(content=template),
                        FunctionMessage(content=f'''The business details are:\n{str(final_state)}\nThe data of influencers is:\n{influencers_data}''', name='final_story_tool')]

            for message_chunk, metadata in self.agent.stream({'messages': messages}, stream_mode='messages'):
                yield message_chunk.content

class ImageGenerator:
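    """Turns the final story into a text-to-image prompt and renders it through the Hugging Face Inference API."""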
    def __init__(self):
        self.llm = llm_gpt
        self.image_generation_model = image_generation_model

    def generate_prompt(self, final_story, business_details, refined_ideation):
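        """Ask the chat model to turn the story, business details, and refined idea into a single image prompt."""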
        messages = [SystemMessage(content=story_to_prompt()),
                    HumanMessage(content=f'''The scene-by-scene video story is {final_story}'''),
                    FunctionMessage(content=f'''The business details are:\n{business_details}\nThe idea is:\n{refined_ideation}''', name='prompt_generation_id')
                    ]
        prompt = self.llm.invoke(messages)
        return prompt.content
    
    @traceable(name="image-generation")
    def generate_image(self, final_story, business_details, refined_ideation):
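        """Build the image prompt, call the Hugging Face text-to-image endpoint, and return the image as a base64-encoded PNG."""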
        prompt = self.generate_prompt(final_story, business_details, refined_ideation)
        client = InferenceClient(provider="hf-inference", api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'))
        image = client.text_to_image(prompt, model=self.image_generation_model)
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        buffered.seek(0)
        img_base64 = base64.b64encode(buffered.read()).decode("utf-8")
        return img_base64

def encode_image_to_base64(uploaded_file: UploadFile) -> str:
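    """Read the uploaded file and return its contents as a base64-encoded string."""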
    return base64.b64encode(uploaded_file.file.read()).decode("utf-8")

def process_image(base64_str: str) -> Image.Image:
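    """Decode a base64 string back into a PIL Image."""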
    image_data = base64.b64decode(base64_str)
    return Image.open(BytesIO(image_data))
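
# Illustrative usage sketch (not executed here): assumes a populated `final_state`
# carrying the keys referenced above; the variable names are examples only.
#
#   story_chunks = FinalStoryGenerator().generate_final_story(final_state)
#   final_story = ''.join(story_chunks)
#   img_b64 = ImageGenerator().generate_image(
#       final_story, final_state['business_details'], final_state['idea'])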