# Converted code to OOP — subashpoudel (commit ef9fa4b)
from langchain_core.messages import SystemMessage,HumanMessage, FunctionMessage
from .tools import Retrieval
import base64
from PIL import Image
from io import BytesIO
from fastapi import UploadFile
from huggingface_hub import InferenceClient
from .prompts import story_to_prompt , final_story_prompt
import os
from langgraph.prebuilt import create_react_agent
from src.genai.utils.models_loader import llm_gpt, image_generation_model
class FinalStoryGenenrator:
    """Generates the final scene-by-scene video story by streaming from a ReAct agent."""

    def __init__(self):
        self.llm = llm_gpt
        # Agent with an empty tool list: used purely to stream LLM responses.
        self.agent = create_react_agent(model=llm_gpt, tools=[])

    def _stream_story(self, messages):
        """Stream the agent's response to *messages*, yielding text chunks."""
        for message_chunk, metadata in self.agent.stream(
            {'messages': messages}, stream_mode='messages'
        ):
            yield message_chunk.content

    def generate_final_story(self, final_state):
        """Yield the final story as a stream of text chunks.

        Three paths depending on *final_state*:
        - 'preferred_topics' present and non-empty: build messages from the
          idea, business details and latest retrieval, then stream a story.
        - 'preferred_topics' present but empty: replay the last stored story.
        - 'preferred_topics' absent: fetch influencer data via Retrieval and
          stream a story from the raw state.
        """
        if 'preferred_topics' in final_state:
            if final_state['preferred_topics']:
                template = final_story_prompt(final_state)
                messages = [
                    SystemMessage(content=template),
                    HumanMessage(content=f'''The idea of the video is:\n{final_state['idea']}\n '''),
                    FunctionMessage(
                        content=f'''The business details is:\n{final_state['business_details']}\nThe data of influencers is:\n{final_state['retrievals'][-1]}''',
                        name='final_story_tool',
                    ),
                ]
                yield from self._stream_story(messages)
            else:
                # No preferred topics selected: replay the most recent story.
                yield from final_state['stories'][-1]
        else:
            template = final_story_prompt(final_state)
            influencers_data = Retrieval(str(final_state)).influencers_data()
            messages = [
                SystemMessage(content=template),
                FunctionMessage(
                    content=f'''The business details is:\n{str(final_state)}\nThe data of influencers is:\n{influencers_data}''',
                    name='final_story_tool',
                ),
            ]
            yield from self._stream_story(messages)
class ImageGenerator:
    """Turns a finished story into an illustrative image via Hugging Face inference."""

    def __init__(self):
        self.llm = llm_gpt
        self.image_generation_model = image_generation_model

    def generate_prompt(self, final_story, business_details, refined_ideation):
        """Ask the LLM to condense the story into a single text-to-image prompt."""
        prompt_messages = [
            SystemMessage(content=story_to_prompt()),
            HumanMessage(content=f'''The scene-by-scene video story is {final_story}'''),
            FunctionMessage(
                content=f'''The business details is:\n{business_details}\nThe idea is{refined_ideation}''',
                name='prompt_generation_id',
            ),
        ]
        response = self.llm.invoke(prompt_messages)
        return response.content

    def generate_image(self, final_story, business_details, refined_ideation):
        """Generate an image for the story and return it as a base64-encoded PNG string."""
        image_prompt = self.generate_prompt(final_story, business_details, refined_ideation)
        client = InferenceClient(
            provider="hf-inference",
            api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'),
        )
        generated = client.text_to_image(image_prompt, model=self.image_generation_model)
        png_buffer = BytesIO()
        generated.save(png_buffer, format="PNG")
        png_buffer.seek(0)
        return base64.b64encode(png_buffer.read()).decode("utf-8")
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Read the uploaded file's raw bytes and return them base64-encoded as text."""
    raw_bytes = uploaded_file.file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
def process_image(base64_str: str) -> Image.Image:
    """Decode a base64 string back into a PIL Image."""
    decoded = base64.b64decode(base64_str)
    buffer = BytesIO(decoded)
    return Image.open(buffer)