Spaces:
Runtime error
Runtime error
| import os | |
| from dotenv import load_dotenv | |
| from langchain_core.prompts import PromptTemplate | |
| from openai import OpenAI | |
| from log_util import logger | |
| from time_it import time_it | |
| from util import load_prompt | |
# Load variables from a .env file (if present) into the process environment
# before reading the image-generation settings below.
load_dotenv()

# Connection settings for the OpenAI-compatible image-generation endpoint.
IMAGE_GEN_API_BASE_URL = os.getenv('IMAGE_GEN_API_BASE_URL')
IMAGE_GEN_API_KEY = os.getenv('IMAGE_GEN_API_KEY')
IMAGE_GEN_MODEL = os.getenv('IMAGE_GEN_MODEL')

# Maximum prompt length (in characters) sent to the API. Required: without
# this check, a missing variable surfaces as an opaque
# "int() argument must be ... not 'NoneType'" TypeError at import time.
_max_prompt_len_raw = os.getenv('IMAGE_GEN_MAX_PROMPT_LEN')
if _max_prompt_len_raw is None:
    raise RuntimeError('Environment variable IMAGE_GEN_MAX_PROMPT_LEN must be set (integer prompt length limit).')
IMAGE_GEN_MAX_PROMPT_LEN = int(_max_prompt_len_raw)

# Extra, provider-specific request fields forwarded via `extra_body` in
# generate_image(). NOTE(review): these keys are not part of the standard
# OpenAI images API; presumably the backing server understands them — confirm.
IMAGE_GEN_OPTIONS = {
    'response_extension': 'png',
    'width': 1024,
    'height': 1024,
    'num_inference_steps': int(os.getenv('NUM_INFERENCE_STEPS', '16')),
    'negative_prompt': '',
    'seed': -1,  # NOTE(review): -1 presumably requests a random seed — confirm with the provider.
}
def generate_image(prompt_file: str, input: dict) -> str:
    """Render a prompt template and generate one image from it.

    The template text is obtained with ``load_prompt(prompt_file)``, filled in
    with ``input`` through a LangChain ``PromptTemplate``, truncated to
    ``IMAGE_GEN_MAX_PROMPT_LEN`` characters if needed, and sent to the
    OpenAI-compatible images endpoint configured by ``IMAGE_GEN_API_BASE_URL``
    / ``IMAGE_GEN_API_KEY`` / ``IMAGE_GEN_MODEL``.

    Args:
        prompt_file: Prompt identifier passed to ``load_prompt``.
        input: Values for the template's placeholders.

    Returns:
        The URL of the first image in the API response.
    """
    template_text = load_prompt(prompt_file)
    prompt = PromptTemplate.from_template(template_text).invoke(input).to_string()

    # Enforce the length limit on the *rendered* prompt, which is what the API
    # receives. Truncating the template instead could cut a `{placeholder}` in
    # half (breaking template parsing), and the substituted values could still
    # push the final prompt past the limit.
    if len(prompt) > IMAGE_GEN_MAX_PROMPT_LEN:
        logger.info(f'Prompt length {len(prompt)} exceeds {IMAGE_GEN_MAX_PROMPT_LEN} characters, will be truncated.')
        prompt = prompt[:IMAGE_GEN_MAX_PROMPT_LEN]

    images_client = OpenAI(base_url=IMAGE_GEN_API_BASE_URL, api_key=IMAGE_GEN_API_KEY).images
    # Provider-specific options (size, steps, seed, ...) are not standard
    # images.generate() parameters, so they are forwarded via `extra_body`.
    response = images_client.generate(
        model=IMAGE_GEN_MODEL,
        prompt=prompt,
        response_format='url',
        extra_body=IMAGE_GEN_OPTIONS,
    )
    image_url = response.data[0].url
    logger.info(f'{image_url=}')
    return image_url