# create_ai_ad / app.py
# farukbera's picture
# Update app.py
# 8f90c8c
from fastapi import FastAPI, Request, UploadFile, File, Query
from fastapi.exceptions import HTTPException
from PIL import Image, ImageDraw, ImageFont
import torch
from diffusers import StableDiffusionControlNetPipeline, StableDiffusionImg2ImgPipeline, DiffusionPipeline, StableDiffusionControlNetImg2ImgPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler
from transformers import pipeline
from starlette.responses import StreamingResponse
from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse
import io
from fastapi.middleware.cors import CORSMiddleware
import os
from dotenv import load_dotenv
import requests
import cv2
import numpy as np
# Load variables from a local ``.env`` file (if one exists) before reading the
# environment. Without this call the ``load_dotenv`` import has no effect and
# HF_TOKEN can only come from the real process environment.
load_dotenv()

# Hugging Face access token passed to model downloads; None when unset.
HF_TOKEN = os.getenv('HF_TOKEN')

app = FastAPI()

# Process-wide, in-memory state shared by the endpoints in this module.
# Each holds a PIL image once the corresponding endpoint has stored one.
uploaded_image = None
uploaded_logo = None
generated_image = None

# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS policy — confirm this is intended outside of a demo setting.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Pick the compute backend: Apple MPS first, then CUDA, otherwise plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("DEVICE: ",device)

# Full precision on CPU, half precision on the accelerator backends.
# NOTE(review): generate_new_img builds its own local pipeline, so nothing in
# the visible handlers appears to read this module-level `pipe` — confirm it
# is still needed before paying for the download/memory at start-up.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "nitrosocke/Ghibli-Diffusion",
    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
    use_auth_token=HF_TOKEN,
).to(device)
#pipe1 = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
def hex_to_rgb(hex_code):
    """Convert a hex colour string into an ``(r, g, b)`` tuple of ints (0-255).

    Accepted forms are ``"#RRGGBB"`` / ``"RRGGBB"`` and the three-digit
    shorthand ``"#RGB"`` / ``"RGB"`` (each digit is doubled). Surrounding
    whitespace is ignored. As before, any characters after the first six
    digits are ignored.

    Raises:
        ValueError: if the code is too short or contains non-hex characters.
    """
    digits = hex_code.strip().lstrip('#')
    if len(digits) == 3:
        # Expand shorthand: "f0a" -> "ff00aa".
        digits = "".join(ch * 2 for ch in digits)
    channels = digits[:6]
    # Check characters explicitly: int(..., 16) would otherwise accept signs
    # and whitespace such as "+f".
    if len(channels) < 6 or any(ch not in "0123456789abcdefABCDEF" for ch in channels):
        raise ValueError(f"Invalid hex colour code: {hex_code!r}")
    return tuple(int(channels[i:i + 2], 16) for i in (0, 2, 4))
def get_depth_map(image, depth_estimator):
    """Run ``depth_estimator`` on ``image`` and return a 3-channel depth tensor.

    The estimator is called with the image and its ``"depth"`` entry is used.
    That depth map is replicated into three identical channels, scaled by
    1/255 into a float tensor, and reordered so the channel axis comes first.
    """
    depth = np.array(depth_estimator(image)["depth"])
    # Add a trailing channel axis and copy it three times along that axis.
    three_channel = np.repeat(depth[:, :, None], 3, axis=2)
    scaled = torch.from_numpy(three_channel).float() / 255.0
    # Move the channel axis (index 2) to the front.
    return scaled.permute(2, 0, 1)
@app.get("/")
async def root():
    """Landing endpoint: return a static welcome message."""
    welcome = "Welcome to the Creating Ad Template With Stable Diffusion API!"
    return {"message": welcome}
@app.post("/uploadImage", response_class=JSONResponse)
async def upload_image(image_file:UploadFile):
    """Store the uploaded picture in the module-level ``uploaded_image``.

    The file name must end in .jpg, .jpeg or .png (case-insensitive) and the
    content must be decodable by Pillow; the image is kept in RGB mode.

    Raises:
        HTTPException: 400 for an unsupported extension or undecodable data.
    """
    global uploaded_image
    # ``filename`` may be missing, and "PHOTO.JPG" should be accepted too.
    filename = (image_file.filename or "").lower()
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        raise HTTPException(status_code=400, detail="Invalid file format")
    image_bytes = await image_file.read()
    try:
        decoded = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:
        # Pillow signals unreadable/truncated data with OSError (including
        # UnidentifiedImageError); report it as a client error, not a 500.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from exc
    # Only replace the stored image once decoding has succeeded.
    uploaded_image = decoded
    return JSONResponse(content={"message": "Image uploaded successfully"})
@app.get("/get_image", response_class=StreamingResponse)
async def get_image():
    """Stream the previously uploaded image back as a PNG.

    ``response_class`` belongs on the route decorator; as a function
    parameter it was exposed as a request parameter instead.

    Raises:
        HTTPException: 400 when no image has been uploaded yet.
    """
    if uploaded_image is None:
        raise HTTPException(status_code=400, detail="No image uploaded")
    # Encode into an in-memory buffer and rewind it so streaming starts at byte 0.
    image_bytes = io.BytesIO()
    uploaded_image.save(image_bytes, format="PNG")
    image_bytes.seek(0)
    return StreamingResponse(image_bytes, media_type="image/png")
@app.post("/uploadLogo", response_class=JSONResponse)
async def upload_logo(logo_file:UploadFile):
    """Store the uploaded logo in the module-level ``uploaded_logo``.

    The file name must end in .jpg, .jpeg or .png (case-insensitive) and the
    content must be decodable by Pillow; the logo is kept in RGB mode.

    Raises:
        HTTPException: 400 for an unsupported extension or undecodable data.
    """
    global uploaded_logo
    # ``filename`` may be missing, and "LOGO.PNG" should be accepted too.
    filename = (logo_file.filename or "").lower()
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        raise HTTPException(status_code=400, detail="Invalid file format")
    logo_bytes = await logo_file.read()
    try:
        decoded = Image.open(io.BytesIO(logo_bytes)).convert("RGB")
    except OSError as exc:
        # Pillow signals unreadable/truncated data with OSError (including
        # UnidentifiedImageError); report it as a client error, not a 500.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from exc
    # Only replace the stored logo once decoding has succeeded.
    uploaded_logo = decoded
    return JSONResponse(content={"message": "Logo uploaded successfully"})
@app.get("/get_logo", response_class=StreamingResponse)
async def get_logo():
    """Stream the previously uploaded logo back as a PNG.

    ``response_class`` belongs on the route decorator; as a function
    parameter it was exposed as a request parameter instead.

    Raises:
        HTTPException: 400 when no logo has been uploaded yet.
    """
    if uploaded_logo is None:
        raise HTTPException(status_code=400, detail="No logo uploaded")
    # Encode into an in-memory buffer and rewind it so streaming starts at byte 0.
    logo_bytes = io.BytesIO()
    uploaded_logo.save(logo_bytes, format="PNG")
    logo_bytes.seek(0)
    return StreamingResponse(logo_bytes, media_type="image/png")
# Lazily filled cache so the heavy models are loaded once per process instead
# of on every request.
_GENERATION_MODELS = {}


def _get_generation_models():
    """Return ``(depth_estimator, sd_pipe, dtype)``, loading them on first use."""
    if not _GENERATION_MODELS:
        # Full precision on CPU, half precision on accelerator backends.
        dtype = torch.float32 if device == "cpu" else torch.float16
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=dtype
        )
        sd_pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=dtype
        )
        if device == "cuda":
            # Offloading manages device placement itself, so the pipeline is
            # not moved to the GPU first.
            sd_pipe.enable_model_cpu_offload()
        else:
            sd_pipe = sd_pipe.to(device)
        sd_pipe.scheduler = UniPCMultistepScheduler.from_config(sd_pipe.scheduler.config)
        # Fill the cache only after everything loaded, so a failed load is
        # retried on the next request.
        _GENERATION_MODELS.update(
            depth_estimator=pipeline("depth-estimation"),
            pipe=sd_pipe,
            dtype=dtype,
        )
    return (
        _GENERATION_MODELS["depth_estimator"],
        _GENERATION_MODELS["pipe"],
        _GENERATION_MODELS["dtype"],
    )


@app.get("/generate_new_img",response_class=StreamingResponse)
async def generate_new_img(hex_code: str, prompt: str = Query(..., description="Text prompt for image generation")):
    """Generate an ad image from the uploaded picture and stream it as PNG.

    The uploaded image is used as the img2img input and its depth map as the
    ControlNet conditioning. The result is also stored in the module-level
    ``generated_image`` for ``create_ad_template``.

    Args:
        hex_code: colour (e.g. ``#FF8000``) to mention in the prompt as a tone.
        prompt: user text embedded into the generation prompt.

    Raises:
        HTTPException: 400 when no image has been uploaded or ``hex_code`` is
            not a valid hex colour; 500 when model loading or generation fails.
    """
    global generated_image

    # Use one reference throughout, even if a new upload arrives meanwhile.
    source_image = uploaded_image
    if source_image is None:
        # Same error style as the other endpoints, instead of a 200 text body
        # on an endpoint that is declared to return an image.
        raise HTTPException(status_code=400, detail="You have not uploaded the image!")

    try:
        rgb = hex_to_rgb(hex_code)
    except ValueError as exc:
        # A malformed colour is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Invalid hex color code") from exc

    ad_prompt = (
        f"Your system prompt is this: {prompt} Consider your system prompt first.\n"
        "Then from the initial image create a new image that will attract customers to put in an (ad template)\n"
        f"Also, use this RGB color {rgb} as a tone in the image while image is still recognized as it is original."
    )

    # NOTE(review): everything below is synchronous work inside an async
    # handler, so other requests wait while an image is being generated.
    try:
        print("Image is creating....")
        depth_estimator, sd_pipe, dtype = _get_generation_models()
        # Add a batch axis and match the pipeline's device and precision.
        depth_map = get_depth_map(source_image, depth_estimator).unsqueeze(0).to(device=device, dtype=dtype)
        # convert() returns a new image, so the stored upload is not touched.
        init_image = source_image.convert("RGB")
        image = sd_pipe(
            ad_prompt, image=init_image, control_image=depth_map
        ).images[0]
        print("Image created")
        image_data = io.BytesIO()
        image.save(image_data, format="PNG")
        image_data.seek(0)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    generated_image = image
    return StreamingResponse(image_data, media_type="image/png")
def _centered_text_origin(draw, text, font, center):
    """Return the top-left drawing origin that centres ``text`` on ``center``."""
    center_x, center_y = center
    try:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # Older Pillow: textbbox is missing or rejects this font type, while
        # the legacy textsize API is still available there.
        width, height = draw.textsize(text, font=font)
        return (center_x - width / 2, center_y - height / 2)
    # The bounding box can start away from (0, 0); centre the box itself.
    return (center_x - (left + right) / 2, center_y - (top + bottom) / 2)


@app.get("/create_ad_template",response_class=StreamingResponse)
async def create_ad_template(punchline: str, punchline_color:str, button_text:str,button_color:str):
    """Compose an 800x600 ad template and stream it as PNG.

    The layout is, top to bottom: the uploaded logo (100x100), the generated
    image (350x350, centred), the punchline text and a coloured button, all
    on a white canvas with a red border.

    Args:
        punchline: text drawn under the generated image.
        punchline_color: Pillow colour string for the punchline.
        button_text: white label drawn inside the button.
        button_color: Pillow colour string for the button fill.

    Raises:
        HTTPException: 400 when the logo or generated image is missing, or
            when a colour string is not understood by Pillow.
    """
    # Local import keeps this block self-contained; PIL is already a dependency.
    from PIL import ImageColor

    logo, artwork = uploaded_logo, generated_image
    if logo is None or artwork is None:
        # The original built a response here but never returned it.
        raise HTTPException(
            status_code=400,
            detail="If you did not generate the new image or if you did not upload logo image. You should check and try again...",
        )

    # Reject bad colours up front so they surface as 400 instead of a 500
    # raised halfway through drawing.
    try:
        ImageColor.getrgb(punchline_color)
        ImageColor.getrgb(button_color)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid color: {exc}") from exc

    canvas_width, canvas_height = 800, 600
    print("Drawing")

    # Logo: fixed size, horizontally centred near the top edge.
    logo_width, logo_height = 100, 100
    logo_image_resized = logo.resize((logo_width, logo_height))
    logo_image_location = ((canvas_width - logo_width) // 2, 10)

    # Generated image: fixed size, centred on the canvas.
    generated_image_width, generated_image_height = 350, 350
    generated_image_resized = artwork.resize((generated_image_width, generated_image_height))
    generated_image_location = (
        (canvas_width - generated_image_width) // 2,
        (canvas_height - generated_image_height) // 2,
    )

    # Blank white canvas for the ad template.
    ad_template = Image.new("RGB", (canvas_width, canvas_height), "#FFFFFF")
    print("Template created")
    ad_template.paste(logo_image_resized, logo_image_location)
    print("Logo added")
    ad_template.paste(generated_image_resized, generated_image_location)
    print("Image added")

    draw = ImageDraw.Draw(ad_template)
    font = ImageFont.load_default()

    # Punchline centred on (400, 500).
    punchline_origin = _centered_text_origin(draw, punchline, font, (canvas_width / 2, 500))
    draw.text(punchline_origin, punchline, font=font, fill=punchline_color)
    print("Punchline added")

    # Button: 200x50 rectangle centred on (400, 550), label centred inside it.
    button_width, button_height = 200, 50
    button_center = (canvas_width / 2, 550)
    button_left = button_center[0] - button_width / 2
    button_top = button_center[1] - button_height / 2
    draw.rectangle(
        [button_left, button_top, button_left + button_width, button_top + button_height],
        fill=button_color,
    )
    label_origin = _centered_text_origin(draw, button_text, font, button_center)
    draw.text(label_origin, button_text, fill=(255, 255, 255), font=font)

    # Border around the whole canvas. Rectangle coordinates are inclusive, so
    # the far corner is the last pixel (799, 599); using (800, 600) put part
    # of the right/bottom border outside the image.
    draw.rectangle([0, 0, canvas_width - 1, canvas_height - 1], outline="#D60D0D", width=2)
    print("Button added")

    ad_template_data = io.BytesIO()
    ad_template.save(ad_template_data, format="PNG")
    ad_template_data.seek(0)
    print("Template finished")
    return StreamingResponse(ad_template_data, media_type="image/png")