Spaces:
Sleeping
Sleeping
| from langchain_core.messages import SystemMessage | |
| from .tools import StoryFormatter | |
| from .models_loader import llm | |
| import base64 | |
| from PIL import Image | |
| from io import BytesIO | |
| from fastapi import UploadFile | |
| from huggingface_hub import InferenceClient | |
| from .prompts import story_to_prompt , final_story_prompt | |
| import os | |
def generate_final_story(final_state):
    """Produce the final story for the current session state.

    If the user supplied preferred topics, ask the LLM (bound to the
    ``StoryFormatter`` tool) to compose a new story from the
    ``final_story_prompt`` template; otherwise fall back to the most
    recently generated story.

    Args:
        final_state: dict-like state with at least the keys
            ``'preferred_topics'`` (list) and ``'stories'`` (non-empty list).

    Returns:
        The tool-call args dict when the model emitted a tool call, the
        plain message content otherwise, or ``"No response"`` as a last
        resort; with no preferred topics, the last existing story.
    """
    # No preferences -> nothing new to generate; reuse the latest story.
    if not final_state['preferred_topics']:
        return final_state['stories'][-1]

    template = final_story_prompt(final_state)
    messages = [SystemMessage(content=template)]
    response = llm.bind_tools([StoryFormatter]).invoke(messages)
    print('The final response is:', response)

    # Prefer the structured tool-call payload; fall back to raw content.
    if getattr(response, 'tool_calls', None):
        return response.tool_calls[0]['args']
    if hasattr(response, 'content'):
        return response.content
    return "No response"
def encode_image_to_base64(uploaded_file: UploadFile) -> str:
    """Read the uploaded file's raw bytes and return them base64-encoded.

    Args:
        uploaded_file: FastAPI upload whose ``.file`` handle is read fully.

    Returns:
        UTF-8 string containing the base64 encoding of the file contents.
    """
    raw_bytes = uploaded_file.file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
# Convert base64 string to PIL image (optional for LangGraph processing)
def process_image(base64_str: str) -> Image.Image:
    """Decode a base64 string into an in-memory PIL image.

    Args:
        base64_str: Base64-encoded image bytes.

    Returns:
        A ``PIL.Image.Image`` opened from the decoded bytes.
    """
    raw = base64.b64decode(base64_str)
    buffer = BytesIO(raw)
    return Image.open(buffer)
def generate_prompt(final_story):
    """Turn a finished story into an image-generation prompt via the LLM.

    Sends the ``story_to_prompt`` system instruction plus the story text
    to the model and returns the model's reply content.

    Args:
        final_story: The story text to condense into an image prompt.

    Returns:
        The LLM response content (string) to feed the image generator.
    """
    print('************Entering prompt generator****************')
    chat_messages = [
        ("system", story_to_prompt),
        ("human", final_story),
    ]
    reply = llm.invoke(chat_messages)
    print('The prompt is:', reply)
    return reply.content
def generate_image(final_story, output_path='image.png'):
    """Generate an illustration for *final_story* and save it to disk.

    Derives a text-to-image prompt from the story, calls the HuggingFace
    Inference API (FLUX.1-schnell), and writes the resulting image to
    ``output_path``.

    Args:
        final_story: Story text used to derive the image prompt.
        output_path: Where to save the generated image
            (defaults to ``'image.png'`` for backward compatibility).

    Returns:
        The literal status string ``"Image Created"``.
    """
    prompt = generate_prompt(final_story)
    print('************Finished prompt generator****************')
    client = InferenceClient(
        provider="hf-inference",
        # Token is read from the environment; None if unset, in which case
        # the client falls back to unauthenticated access.
        api_key=os.environ.get('HUGGINGFACEHUB_ACCESS_TOKEN'),
    )
    print('************Finished calling generator****************')
    # text_to_image returns a PIL.Image object
    image = client.text_to_image(
        prompt,
        model="black-forest-labs/FLUX.1-schnell",
    )
    print('*****************Image Created*******************')
    image.save(output_path)
    print('*****************Image Saved*******************')
    return "Image Created"