Spaces:
Sleeping
Sleeping
| from openai import OpenAI | |
| from pydantic import BaseModel | |
| import replicate | |
| from dotenv import load_dotenv | |
| import boto3 | |
| import os | |
| from uuid import uuid4 | |
| import requests | |
| from meta_data import meta_data_helper_function | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from typing import List, Optional | |
# Load variables from a local .env file (if any) before credentials are read below.
load_dotenv()

# Module-wide OpenAI client; the API key is taken from OPENAI_API_KEY.
gpt_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def generate_prompt(category, keyword, subkeyword, aspect_ratio):
    """Ask the chat model for an image-generation prompt describing an ad image.

    Args:
        category: Ad category inserted into the user message.
        keyword: Keyword inserted into the user message.
        subkeyword: Sub-keyword inserted into the user message.
        aspect_ratio: Aspect ratio text inserted into the user message.

    Returns:
        The parsed ``GeneratedPrompt`` instance (with a ``prompt`` string
        field) when the completion was parsed; otherwise the message's
        ``refusal`` value.
    """

    class GeneratedPrompt(BaseModel):
        # Structured-output schema handed to the API as ``response_format``.
        prompt: str

    def text_message(role, text):
        # Each message carries its text as a single "text" content part.
        return {
            "role": role,
            "content": [
                {
                    "type": "text",
                    "text": text
                },
            ]
        }

    # NOTE: continuation lines of this literal are intentionally unindented so
    # the text sent to the model contains no leading whitespace.
    system_prompt = """You are professional prompt generator for the images which is to be used in the ads for the performance marketing.
User will give you category, keyword and subkeyword as input. You will generate the prompt for the given category, keyword and subkeyword.
Make sure the prompt which you would generate for generating the images should relevant to information provided. The images should be stocky.
Strictly, their should be no text overlays mentioned in the prompt. I don't want any text on images so don't mention any text overlays in the prompt.
For generating the images, google/imagen-4 model will be used so, craft the prompt for that model."""
    user_prompt = f"Generate me the stocky images for the given category: {category} and keyword: {keyword} and subkeyword: {subkeyword}. The aspect ratio of the images should be: {aspect_ratio}."

    conversation = [
        text_message("system", system_prompt),
        text_message("user", user_prompt),
    ]
    completion = gpt_client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=conversation,
        response_format=GeneratedPrompt,
    )

    reply = completion.choices[0].message
    # Prefer the parsed object; fall back to the refusal when parsing yielded nothing.
    return reply.parsed or reply.refusal
def _first_output_url(output):
    """Return the first image URL found in a Replicate ``run`` result."""
    if isinstance(output, list) and output:
        first = output[0]
        # Items may be file-like objects exposing ``url`` or plain URL strings.
        return getattr(first, "url", str(first))
    if isinstance(output, str):
        return output
    if hasattr(output, "url"):
        return output.url
    # Empty list, None, or an unknown type: fail with a clear message instead
    # of the UnboundLocalError the previous implementation produced.
    raise RuntimeError(f"Image model returned no usable output: {output!r}")


def generate_images(prompt, aspect_ratio):
    """Generate one image with ``google/imagen-4-ultra`` and return its URL.

    Args:
        prompt: Text prompt passed to the model.
        aspect_ratio: Aspect ratio string passed to the model (e.g. "1:1").

    Returns:
        The URL of the first generated image.

    Raises:
        RuntimeError: If the model output is empty or of an unrecognised shape.
    """
    replicate_client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    output = replicate_client.run(
        "google/imagen-4-ultra",
        input={
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
        },
    )
    print(output)
    return _first_output_url(output)
def fetch_image_bytes(url):
    """Download ``url`` and return the response body as bytes.

    The request uses a 60-second timeout; ``raise_for_status`` is called on
    the response before its content is returned.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    payload = response.content
    return payload
def init_s3():
    """Create a boto3 ``s3`` client configured from the ``R2_*`` environment variables.

    Reads ``R2_ENDPOINT``, ``R2_ACCESS_KEY`` and ``R2_SECRET_KEY``; the region
    name is fixed to ``"auto"``.
    """
    connection_settings = {
        "endpoint_url": os.getenv("R2_ENDPOINT"),
        "aws_access_key_id": os.getenv("R2_ACCESS_KEY"),
        "aws_secret_access_key": os.getenv("R2_SECRET_KEY"),
        "region_name": "auto",
    }
    return boto3.client("s3", **connection_settings)
def upload_to_r2(image_bytes):
    """Upload PNG bytes under ``infinityverse/`` and return the public URL.

    Args:
        image_bytes: Image payload stored as the object body.

    Returns:
        ``NEW_BASE`` (without trailing slash) joined with the generated key
        ``infinityverse/<uuid hex>.png``.

    Raises:
        RuntimeError: If ``R2_BUCKET_NAME`` or ``NEW_BASE`` is not set.
    """
    # Resolve configuration up front: previously a missing NEW_BASE raised
    # AttributeError only after the object had already been uploaded.
    bucket_name = os.getenv("R2_BUCKET_NAME")
    public_base = os.getenv("NEW_BASE")
    missing = [
        name
        for name, value in (("R2_BUCKET_NAME", bucket_name), ("NEW_BASE", public_base))
        if value is None
    ]
    if missing:
        raise RuntimeError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    s3_client = init_s3()
    file_key = f"infinityverse/{uuid4().hex}.png"
    s3_client.put_object(
        Bucket=bucket_name,
        Key=file_key,
        Body=image_bytes,
        ContentType="image/png",
    )
    return f"{public_base.rstrip('/')}/{file_key}"
def _process_one(aspect_ratio: str, category: str, keyword: str, subkeyword: str) -> str:
    """Run the full pipeline for one image and return its uploaded URL.

    Steps: generate a prompt, generate an image from it, download the image,
    pass the bytes through ``meta_data_helper_function``, and upload the result.

    Args:
        aspect_ratio: Aspect ratio string used for both prompt and image.
        category: Ad category for the prompt.
        keyword: Keyword for the prompt.
        subkeyword: Sub-keyword for the prompt.

    Returns:
        The URL returned by ``upload_to_r2``.

    Raises:
        RuntimeError: If prompt generation did not yield a prompt.
    """
    prompt_obj = generate_prompt(category, keyword, subkeyword, aspect_ratio)
    # generate_prompt returns an object with ``.prompt`` on success and the
    # refusal text (or None) otherwise; a refusal must not be sent to the
    # image model as if it were a prompt.
    prompt_text = getattr(prompt_obj, "prompt", None)
    if not prompt_text:
        raise RuntimeError(f"Prompt generation failed or was refused: {prompt_obj!r}")

    image_url = generate_images(prompt_text, aspect_ratio)
    image_bytes = fetch_image_bytes(image_url)
    image_with_metadata = meta_data_helper_function(image_bytes)
    return upload_to_r2(image_with_metadata)
def get_images(
    category: str,
    keyword: str,
    subkeyword: str,
    aspect_ratios: Optional[List[str]] = None,
    max_workers: int = 5,
) -> List[str]:
    """Generate several ad images concurrently and return their uploaded URLs.

    Args:
        category: Ad category for the prompts.
        keyword: Keyword for the prompts.
        subkeyword: Sub-keyword for the prompts.
        aspect_ratios: One entry per image to generate. Defaults to two
            "1:1" and three "16:9" images.
        max_workers: Upper bound on worker threads (at least one is used).

    Returns:
        URLs of the images that succeeded, in the order of ``aspect_ratios``.
        Failed images are reported on stdout and omitted.
    """
    if aspect_ratios is None:
        ratios = ["1:1", "1:1", "16:9", "16:9", "16:9"]
    else:
        ratios = list(aspect_ratios)
    if not ratios:
        return []

    # Slots are filled by index so results keep the requested order even
    # though futures complete in arbitrary order.
    urls: List[Optional[str]] = [None] * len(ratios)
    workers = max(1, min(max_workers, len(ratios)))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {
            pool.submit(_process_one, ar, category, keyword, subkeyword): idx
            for idx, ar in enumerate(ratios)
        }
        for fut in as_completed(futures):
            idx = futures[fut]
            try:
                urls[idx] = fut.result()
            except Exception as e:
                # Best-effort: one failed image must not abort the others,
                # but say which one failed and why.
                print(f"Image generation failed for aspect ratio {ratios[idx]}: {e!r}")
    return [u for u in urls if u is not None]