# stock-image / utils.py
# Last commit: cf6d25f "Update utils.py" (author: userIdc2024, verified)
from openai import OpenAI
from pydantic import BaseModel
import replicate
from dotenv import load_dotenv
import boto3
import os
from uuid import uuid4
import requests
from meta_data import meta_data_helper_function
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional
# Populate os.environ from a .env file (if one exists) before any of the
# os.getenv / os.environ lookups in this module run.
load_dotenv()

# Module-wide OpenAI client; generate_prompt() sends its requests through it.
gpt_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
def generate_prompt(category, keyword, subkeyword, aspect_ratio):
    """Ask gpt-4o to write an image-generation prompt for the given topic.

    The request goes through the OpenAI structured-output ``parse`` helper with
    a one-field schema (``prompt: str``).

    Args:
        category, keyword, subkeyword: topic description interpolated into the
            user message.
        aspect_ratio: aspect ratio string mentioned to the model (e.g. "16:9").

    Returns:
        The parsed schema instance (read the text via ``.prompt``) when the model
        complies; otherwise ``message.refusal`` -- the refusal text, or ``None``.
        Callers must handle both shapes.
    """

    class GeneratedPrompt(BaseModel):
        prompt: str

    system_prompt = """You are professional prompt generator for the images which is to be used in the ads for the performance marketing.
User will give you category, keyword and subkeyword as input. You will generate the prompt for the given category, keyword and subkeyword.
Make sure the prompt which you would generate for generating the images should relevant to information provided. The images should be stocky.
Strictly, their should be no text overlays mentioned in the prompt. I don't want any text on images so don't mention any text overlays in the prompt.
For generating the images, google/imagen-4 model will be used so, craft the prompt for that model."""
    user_prompt = f"Generate me the stocky images for the given category: {category} and keyword: {keyword} and subkeyword: {subkeyword}. The aspect ratio of the images should be: {aspect_ratio}."

    def _text_message(role, text):
        # Chat message whose content is a single text part.
        return {"role": role, "content": [{"type": "text", "text": text}]}

    completion = gpt_client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[
            _text_message("system", system_prompt),
            _text_message("user", user_prompt),
        ],
        response_format=GeneratedPrompt,
    )
    reply = completion.choices[0].message
    # Fall back to the refusal field whenever no parsed object came back.
    return reply.parsed or reply.refusal
def _extract_image_url(output):
    """Return the URL of the first image in a Replicate ``run`` result.

    Handles the result shapes seen from ``replicate.Client.run``:
      * a list/tuple of items -- the first item's ``.url`` is used, or ``str(item)``
        if it has no ``url`` attribute;
      * a plain URL string;
      * a single object exposing ``.url``.

    Raises:
        ValueError: if the result is empty or has an unrecognised shape. This
            replaces the UnboundLocalError the old inline code raised here.
    """
    if isinstance(output, (list, tuple)):
        if not output:
            raise ValueError("Replicate returned no images")
        first = output[0]
        return getattr(first, "url", str(first))
    if isinstance(output, str):
        return output
    if hasattr(output, "url"):
        return output.url
    raise ValueError(f"Unexpected Replicate output type: {type(output).__name__}")


def generate_images(prompt, aspect_ratio, model="google/imagen-4-ultra"):
    """Generate one image on Replicate and return its URL.

    Args:
        prompt: text prompt sent to the model.
        aspect_ratio: aspect ratio string passed through as model input.
        model: Replicate model identifier; defaults to the previously
            hard-coded "google/imagen-4-ultra".

    Returns:
        The URL (str) of the first generated image.

    Raises:
        ValueError: if Replicate's output contains no usable image URL.
    """
    # The token is read from the environment on each call.
    replicate_client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    output = replicate_client.run(
        model,
        input={
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
        },
    )
    print(output)  # raw Replicate result, kept for debugging in the Space logs
    return _extract_image_url(output)
def fetch_image_bytes(url):
    """Download *url* (60 s timeout) and return the raw response body.

    Raises requests.HTTPError for 4xx/5xx responses.
    """
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()  # surface HTTP errors instead of returning an error page
    return resp.content
def init_s3():
    """Create a boto3 S3 client for the S3-compatible endpoint configured via env.

    Reads R2_ENDPOINT, R2_ACCESS_KEY and R2_SECRET_KEY; the region is fixed to
    "auto". A new client is built on every call.
    """
    settings = {
        "endpoint_url": os.getenv("R2_ENDPOINT"),
        "aws_access_key_id": os.getenv("R2_ACCESS_KEY"),
        "aws_secret_access_key": os.getenv("R2_SECRET_KEY"),
        "region_name": "auto",
    }
    return boto3.client("s3", **settings)
def upload_to_r2(image_bytes):
    """Upload *image_bytes* to the configured bucket and return a URL for it.

    The object is stored under ``infinityverse/<random-hex>.png`` with
    ``Content-Type: image/png``. The returned URL is NEW_BASE (trailing slashes
    stripped) joined with that key -- presumably NEW_BASE is the bucket's public
    base URL; confirm against deployment config.

    Raises:
        RuntimeError: if R2_BUCKET_NAME or NEW_BASE is unset or empty. Both are
            checked before uploading, so a misconfiguration cannot leave an
            object stored with no URL ever returned for it.
    """
    bucket = os.getenv("R2_BUCKET_NAME")
    public_base = os.getenv("NEW_BASE")
    missing = [
        name
        for name, value in (("R2_BUCKET_NAME", bucket), ("NEW_BASE", public_base))
        if not value
    ]
    if missing:
        raise RuntimeError(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )

    file_key = f"infinityverse/{uuid4().hex}.png"
    s3_client = init_s3()
    s3_client.put_object(
        Bucket=bucket,
        Key=file_key,
        Body=image_bytes,
        ContentType="image/png",
    )
    return f"{public_base.rstrip('/')}/{file_key}"
def _process_one(aspect_ratio: str, category: str, keyword: str, subkeyword: str) -> str:
    """Run the whole pipeline for one image and return its uploaded URL.

    Steps:
      1. generate_prompt
      2. generate_images
      3. fetch_image_bytes
      4. meta_data_helper_function
      5. upload_to_r2

    Raises:
        RuntimeError: if prompt generation yielded no usable prompt (e.g. the
            model refused). Errors from later stages propagate unchanged.
    """
    prompt_obj = generate_prompt(category, keyword, subkeyword, aspect_ratio)
    # generate_prompt returns a parsed object with ``.prompt`` on success, but the
    # refusal text (or None) otherwise. Do not fall back to prompt_obj itself:
    # that would send the model's refusal message to the image generator.
    prompt_text = getattr(prompt_obj, "prompt", None)
    if not prompt_text:
        raise RuntimeError(
            f"Prompt generation failed for aspect ratio {aspect_ratio}: {prompt_obj!r}"
        )
    image_url = generate_images(prompt_text, aspect_ratio)
    image_bytes = fetch_image_bytes(image_url)
    # NOTE(review): meta_data_helper_function is defined in meta_data; presumably
    # it embeds metadata and returns image bytes -- confirm there.
    image_with_metadata = meta_data_helper_function(image_bytes)
    return upload_to_r2(image_with_metadata)
def get_images(
    category: str,
    keyword: str,
    subkeyword: str,
    *,
    aspect_ratios: Optional[List[str]] = None,
    max_workers: int = 5,
) -> List[str]:
    """Generate and upload one image per aspect ratio, in parallel threads.

    Args:
        category, keyword, subkeyword: forwarded to the per-image pipeline.
        aspect_ratios: one image is produced per entry (duplicates allowed).
            Defaults to ["1:1", "1:1", "16:9", "16:9", "16:9"].
        max_workers: upper bound on concurrent pipelines. The pool size used is
            max(1, min(max_workers, len(aspect_ratios))).

    Returns:
        URLs of the images that succeeded, ordered like ``aspect_ratios``.
        Failures are printed and skipped, so the list may be shorter than
        ``aspect_ratios`` or empty.
    """
    if aspect_ratios is None:
        aspect_ratios = ["1:1", "1:1", "16:9", "16:9", "16:9"]
    urls: List[Optional[str]] = [None] * len(aspect_ratios)
    workers = max(1, min(max_workers, len(aspect_ratios)))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {
            pool.submit(_process_one, ar, category, keyword, subkeyword): idx
            for idx, ar in enumerate(aspect_ratios)
        }
        for fut in as_completed(futures):
            idx = futures[fut]
            try:
                urls[idx] = fut.result()
            except Exception as e:
                # Best-effort batch: one failed image must not discard the others.
                print(f"Image {idx} ({aspect_ratios[idx]}) failed: {e!r}")
    return [u for u in urls if u is not None]