subashpoudel's picture
Debugged the unmatched f-string
32131c3
raw
history blame
2.46 kB
from langchain_core.messages import SystemMessage
from .tools import StoryFormatter
from .models_loader import llm
import base64
from PIL import Image
from io import BytesIO
from fastapi import UploadFile
from huggingface_hub import InferenceClient
from .prompts import story_to_prompt , final_story_prompt
import os
def generate_final_story(final_state):
    """Return the final story for a finished story-building session.

    When ``final_state['preferred_topics']`` is non-empty, the LLM (bound to
    the ``StoryFormatter`` tool) is asked to write the final story from the
    prompt built by ``final_story_prompt``. Otherwise the most recent entry of
    ``final_state['stories']`` is returned as-is.

    Args:
        final_state: Mapping with at least ``'preferred_topics'`` and
            ``'stories'`` keys (presumably the LangGraph state — TODO confirm).

    Returns:
        - The ``args`` of the first tool call when the model called a tool.
        - Otherwise the response's ``content``.
        - ``"No response"`` if the response has neither.
        - The last stored story when no topics were preferred.
    """
    # Truthiness also treats a ``None`` topic list as empty, which ``len()``
    # would reject with a TypeError.
    if not final_state['preferred_topics']:
        return final_state['stories'][-1]

    template = final_story_prompt(final_state)
    messages = [SystemMessage(content=template)]
    response = llm.bind_tools([StoryFormatter]).invoke(messages)
    print('The final response is:', response)

    # Prefer the structured tool-call arguments; fall back to raw text.
    if getattr(response, 'tool_calls', None):
        return response.tool_calls[0]['args']
    if hasattr(response, 'content'):
        return response.content
    return "No response"
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Return the uploaded file's bytes as a base64-encoded text string.

    Reads the underlying file object from its current position to the end.
    """
    raw_bytes = uploaded_file.file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
def process_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    Optional helper (e.g. for LangGraph processing).
    """
    buffer = BytesIO(base64.b64decode(base64_str))
    return Image.open(buffer)
def generate_prompt(final_story):
    """Ask the LLM to turn a story into a text prompt for image generation.

    Sends ``story_to_prompt`` as the system message and the story as the
    human message, prints the raw reply, and returns the reply's text content.
    """
    print('************Entering prompt generator****************')
    conversation = [("system", story_to_prompt), ("human", final_story)]
    reply = llm.invoke(conversation)
    print('The prompt is:', reply)
    return reply.content
def generate_image(final_story, *, output_path='image.png',
                   model="black-forest-labs/FLUX.1-schnell"):
    """Generate an image for ``final_story`` and save it to disk.

    The story is first converted to a text prompt via ``generate_prompt``.
    The image is then rendered through the Hugging Face Inference API
    (``hf-inference`` provider).

    Args:
        final_story: Story text to illustrate.
        output_path: Destination file for the image; defaults to
            ``'image.png'`` in the current working directory (the previously
            hard-coded value).
        model: Hugging Face model id passed to ``text_to_image``; defaults
            to the previously hard-coded FLUX.1-schnell model.

    Returns:
        The status string ``"Image Created"``.
    """
    prompt = generate_prompt(final_story)
    print('************Finished prompt generator****************')
    client = InferenceClient(
        provider="hf-inference",
        api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'),
    )
    print('************Finished calling generator****************')
    # text_to_image returns a PIL.Image object.
    image = client.text_to_image(prompt, model=model)
    print('*****************Image Created*******************')
    image.save(output_path)
    print('*****************Image Saved*******************')
    return "Image Created"