# Scraped from the Hugging Face file viewer (author: subashpoudel).
# Commit 96b0973: "Updated the tool in final story generation".
# File size: 2.32 kB.
from langchain_core.messages import SystemMessage
from .tools import StoryFormatter , retrieve_tool
from .models_loader import llm
import base64
from PIL import Image
from io import BytesIO
from fastapi import UploadFile
from huggingface_hub import InferenceClient
from .prompts import story_to_prompt , final_story_prompt
import os
from langgraph.prebuilt import create_react_agent
def generate_final_story(final_state):
    """Produce the final story text for the current graph state.

    If the state lists preferred topics, a LangGraph ReAct agent driven by
    ``llm`` is run on the system prompt built by ``final_story_prompt``, with
    ``retrieve_tool`` available to it. The content of the agent's last
    message is returned. Otherwise the most recent entry of
    ``final_state['stories']`` is returned unchanged.

    Args:
        final_state: Graph state mapping. Only the ``'preferred_topics'``
            and ``'stories'`` keys are read here (presumably both lists;
            confirm against the state schema).

    Returns:
        The final story. In the agent branch this is the last message's
        ``content``. In the fallback branch it is whatever ``'stories'``
        holds as its last element.
    """
    # Truthiness (rather than len(...) > 0) also treats a None topics value
    # as "no topics" instead of raising TypeError.
    if not final_state['preferred_topics']:
        return final_state['stories'][-1]

    template = final_story_prompt(final_state)
    messages = [SystemMessage(content=template)]
    tools = [retrieve_tool]
    # create_react_agent binds ``tools`` to the model itself, so the plain
    # model is passed. Pre-binding with llm.bind_tools(tools) was redundant
    # and had to be kept in sync with the ``tools`` argument by hand.
    react_agent = create_react_agent(model=llm, tools=tools)
    result = react_agent.invoke({'messages': messages})
    # The agent's answer is the last message of the returned conversation.
    return result['messages'][-1].content
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Return the bytes of an uploaded file as base64 text.

    The upload's underlying ``file`` stream is read from its current
    position to the end, so the stream is consumed by this call.

    Args:
        uploaded_file: Upload whose ``.file`` stream is read.

    Returns:
        The base64 encoding of the bytes read, decoded to a UTF-8 ``str``.
    """
    raw_bytes = uploaded_file.file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
def process_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    Optional helper for LangGraph processing. It is the inverse
    direction of turning image bytes into base64 text.

    Args:
        base64_str: Base64 text of an encoded image file.

    Returns:
        The image opened by ``PIL.Image.open`` over an in-memory buffer
        holding the decoded bytes.
    """
    buffer = BytesIO(base64.b64decode(base64_str))
    return Image.open(buffer)
def generate_prompt(final_story):
    """Ask ``llm`` to turn a story into a text-to-image prompt.

    The module-level ``story_to_prompt`` text is sent as the system message
    and ``final_story`` as the human message. A progress banner and the raw
    model response are printed to stdout.

    Args:
        final_story: Story text to convert.

    Returns:
        The ``content`` of the model's response message.
    """
    print('************Entering prompt generator****************')
    conversation = [("system", story_to_prompt)]
    conversation.append(("human", final_story))
    response = llm.invoke(conversation)
    print('The prompt is:', response)
    return response.content
def generate_image(final_story, *, output_path='image.png',
                   model="black-forest-labs/FLUX.1-schnell"):
    """Render an illustration for a story and save it to disk.

    The story is first converted into a text-to-image prompt by
    ``generate_prompt``. That prompt is sent to the Hugging Face Inference
    API (``hf-inference`` provider), and the returned image is written to
    ``output_path``.

    Args:
        final_story: Story text to illustrate.
        output_path: Destination file for the image. Defaults to
            ``'image.png'`` in the current working directory, as before.
            Every call writes to this path, so callers that may run
            concurrently should pass a unique path per call.
        model: Hugging Face model id used for text-to-image generation.

    Returns:
        The status string ``"Image Created"``. The image itself is only
        available through the saved file.

    Raises:
        Any exception from ``generate_prompt``, ``InferenceClient`` or
        ``image.save`` propagates; nothing is caught here.
    """
    prompt = generate_prompt(final_story)
    print('************Finished prompt generator****************')
    # If the env var is unset, api_key=None is passed and authentication is
    # left to InferenceClient's own defaults.
    client = InferenceClient(
        provider="hf-inference",
        api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'),
    )
    print('************Finished calling generator****************')
    # text_to_image returns a PIL.Image.Image.
    image = client.text_to_image(prompt, model=model)
    print('*****************Image Created*******************')
    image.save(output_path)
    print('*****************Image Saved*******************')
    return "Image Created"