from langchain_core.messages import SystemMessage
from .tools import StoryFormatter , retrieve_tool
from .models_loader import llm
import base64
from PIL import Image
from io import BytesIO
from fastapi import UploadFile
from huggingface_hub import InferenceClient
from .prompts import story_to_prompt , final_story_prompt
import os
from langgraph.prebuilt import create_react_agent
def generate_final_story(final_state):
    """Produce the final story for the session.

    If the user selected preferred topics, a ReAct agent equipped with the
    retrieval tool composes a story grounded in retrieved context; otherwise
    the most recently generated story is returned unchanged.

    Args:
        final_state: dict with at least 'preferred_topics' (list) and
            'stories' (non-empty list of generated story strings).

    Returns:
        str: the composed story text, or the last existing story.
    """
    # Guard clause: no preferred topics means nothing to retrieve,
    # so reuse the latest story as-is.
    if not final_state['preferred_topics']:
        return final_state['stories'][-1]

    template = final_story_prompt(final_state)
    messages = [SystemMessage(content=template)]
    tools = [retrieve_tool]
    # create_react_agent binds the tools to the model itself; passing an
    # already-bound model (llm.bind_tools(tools)) made the tools be bound
    # twice. Pass the raw model instead.
    react_agent = create_react_agent(model=llm, tools=tools)
    response = react_agent.invoke({'messages': messages})
    # The agent's final answer is the content of the last message.
    return response['messages'][-1].content
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Read the uploaded file's raw bytes and return them as a base64 string."""
    raw = uploaded_file.file.read()
    encoded = base64.b64encode(raw)
    return encoded.decode("utf-8")
# Convert base64 string to PIL image (optional for LangGraph processing)
def process_image(base64_str: str) -> Image.Image:
    """Decode a base64 string and load the resulting bytes as a PIL image."""
    raw_bytes = base64.b64decode(base64_str)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
def generate_prompt(final_story):
    """Ask the LLM to condense *final_story* into an image-generation prompt.

    Returns the prompt text (str) produced by the model.
    """
    print('************Entering prompt generator****************')
    chat_input = [
        ("system", story_to_prompt),
        ("human", final_story),
    ]
    result = llm.invoke(chat_input)
    print('The prompt is:', result)
    return result.content
def generate_image(final_story, output_path='image.png'):
    """Generate an illustration for *final_story* and save it to disk.

    The story is first condensed into an image prompt by the LLM
    (via generate_prompt), then rendered with the FLUX.1-schnell model
    through the Hugging Face inference API.

    Args:
        final_story: the story text to illustrate.
        output_path: file path for the saved PNG (default 'image.png',
            preserving the previous hard-coded behavior).

    Returns:
        str: the status message "Image Created".
    """
    prompt = generate_prompt(final_story)
    print('************Finished prompt generator****************')
    client = InferenceClient(
        provider="hf-inference",
        api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'),
    )
    print('************Finished calling generator****************')
    # text_to_image returns a PIL.Image object
    image = client.text_to_image(
        prompt,
        model="black-forest-labs/FLUX.1-schnell",
    )
    print('*****************Image Created*******************')
    image.save(output_path)
    print('*****************Image Saved*******************')
    return "Image Created"