from fastapi import FastAPI, Request, UploadFile, File, Query
from fastapi.exceptions import HTTPException
from PIL import Image, ImageDraw, ImageFont
import torch
from diffusers import StableDiffusionControlNetPipeline, StableDiffusionImg2ImgPipeline, DiffusionPipeline, StableDiffusionControlNetImg2ImgPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler
from transformers import pipeline
from starlette.responses import StreamingResponse
from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse
import io
from fastapi.middleware.cors import CORSMiddleware
import os
from dotenv import load_dotenv
import requests
import cv2
import numpy as np
# Read a local .env file (if one exists) before looking up HF_TOKEN.
# load_dotenv was imported but never called, so a token supplied through
# .env was silently ignored. Variables already set in the process
# environment are left untouched by load_dotenv's default behaviour.
load_dotenv()
HF_TOKEN = os.getenv('HF_TOKEN')

app = FastAPI()

# Process-wide, in-memory state shared between the route handlers.
uploaded_image = None    # PIL image stored by /uploadImage
uploaded_logo = None     # PIL image stored by /uploadLogo
generated_image = None   # PIL image stored by /generate_new_img
# Register CORS handling: every origin, HTTP method and header is accepted,
# and credentialed requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Choose the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("DEVICE: ",device)

# Load the Ghibli img2img pipeline once at import time: full precision on
# CPU, half precision on any accelerator.
# NOTE(review): newer diffusers releases may prefer `token=` over
# `use_auth_token=` -- confirm against the pinned version.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "nitrosocke/Ghibli-Diffusion",
    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
    use_auth_token=HF_TOKEN,
).to(device)
#pipe1 = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
def hex_to_rgb(hex_code):
    """Convert a hexadecimal colour code to an ``(r, g, b)`` tuple of ints.

    Accepts an optional leading ``#`` and surrounding whitespace. Supported
    forms are ``RRGGBB``, ``RRGGBBAA`` (alpha is ignored), and the shorthand
    ``RGB`` / ``RGBA`` where every digit is doubled.

    Args:
        hex_code: Colour code such as ``"#FF8800"``, ``"ff8800"`` or ``"#f80"``.

    Returns:
        A 3-tuple of integers, each in the range 0-255.

    Raises:
        ValueError: If the code has an unsupported length or contains
            characters that are not hexadecimal digits.
    """
    digits = hex_code.strip().lstrip('#')

    # Expand shorthand notation (#rgb / #rgba) to the full two-digit form.
    if len(digits) in (3, 4):
        digits = "".join(ch * 2 for ch in digits)

    # Check the characters explicitly: int(..., 16) alone would also accept
    # signed pairs such as "-f" and silently yield negative channels.
    is_hex = all(ch in "0123456789abcdefABCDEF" for ch in digits)
    if len(digits) not in (6, 8) or not is_hex:
        raise ValueError(f"Invalid hex color code: {hex_code!r}")

    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
def get_depth_map(image, depth_estimator):
    """Build a three-channel, channel-first depth tensor for ``image``.

    Args:
        image: Input passed unchanged to ``depth_estimator``.
        depth_estimator: Callable returning a mapping whose ``"depth"`` entry
            can be converted to a 2-D array (presumably with values in
            0-255 -- confirm against the estimator in use).

    Returns:
        A float ``torch.Tensor`` of shape ``(3, H, W)`` holding the depth
        values divided by 255, replicated across the three channels.
    """
    depth = np.array(depth_estimator(image)["depth"])
    # Add a trailing channel axis, then repeat it so the map has 3 channels.
    single_channel = depth[:, :, None]
    three_channel = np.concatenate([single_channel] * 3, axis=2)
    scaled = torch.from_numpy(three_channel).float() / 255.0
    # Reorder from (H, W, C) to (C, H, W).
    return scaled.permute(2, 0, 1)
@app.get("/")
async def root():
    """Return the API's welcome message as a JSON object."""
    greeting = "Welcome to the Creating Ad Template With Stable Diffusion API!"
    return {"message": greeting}
@app.post("/uploadImage", response_class=JSONResponse)
async def upload_image(image_file:UploadFile):
    """Store the uploaded picture in memory as the base image for generation.

    Args:
        image_file: Uploaded file; its name must end in ``.jpg``, ``.jpeg``
            or ``.png`` (compared case-insensitively).

    Returns:
        A JSON confirmation message.

    Raises:
        HTTPException: 400 if the extension is not accepted or the bytes
            cannot be decoded as an image.
    """
    global uploaded_image

    # The filename may be missing, and extensions such as ".JPG" are valid.
    filename = (image_file.filename or "").lower()
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        raise HTTPException(status_code=400, detail="Invalid file format")

    image_bytes = await image_file.read()
    try:
        decoded = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:
        # Pillow reports undecodable or truncated data as OSError subclasses;
        # that is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from exc

    # Only replace the stored image once decoding has succeeded.
    uploaded_image = decoded
    return JSONResponse(content={"message": "Image uploaded successfully"})
@app.get("/get_image", response_class=StreamingResponse)
async def get_image():
    """Stream the currently stored uploaded image back as a PNG.

    ``response_class`` is declared on the route decorator, matching the other
    routes in this file. It was previously a handler parameter, which FastAPI
    treats as a request parameter rather than as route configuration.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.

    Raises:
        HTTPException: 400 if no image has been uploaded yet.
    """
    if uploaded_image is None:
        raise HTTPException(status_code=400, detail="No image uploaded")

    # Encode the in-memory image to PNG and rewind the buffer for streaming.
    image_bytes = io.BytesIO()
    uploaded_image.save(image_bytes, format="PNG")
    image_bytes.seek(0)
    return StreamingResponse(image_bytes, media_type="image/png")
@app.post("/uploadLogo", response_class=JSONResponse)
async def upload_logo(logo_file:UploadFile):
    """Store the uploaded logo in memory for use in the ad template.

    Args:
        logo_file: Uploaded file; its name must end in ``.jpg``, ``.jpeg``
            or ``.png`` (compared case-insensitively).

    Returns:
        A JSON confirmation message.

    Raises:
        HTTPException: 400 if the extension is not accepted or the bytes
            cannot be decoded as an image.
    """
    global uploaded_logo

    # The filename may be missing, and extensions such as ".PNG" are valid.
    filename = (logo_file.filename or "").lower()
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        raise HTTPException(status_code=400, detail="Invalid file format")

    logo_bytes = await logo_file.read()
    try:
        decoded = Image.open(io.BytesIO(logo_bytes)).convert("RGB")
    except OSError as exc:
        # Pillow reports undecodable or truncated data as OSError subclasses;
        # that is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from exc

    # Only replace the stored logo once decoding has succeeded.
    uploaded_logo = decoded
    return JSONResponse(content={"message": "Logo uploaded successfully"})
@app.get("/get_logo", response_class=StreamingResponse)
async def get_logo():
    """Stream the currently stored logo back as a PNG.

    ``response_class`` is declared on the route decorator, matching the other
    routes in this file. It was previously a handler parameter, which FastAPI
    treats as a request parameter rather than as route configuration.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``.

    Raises:
        HTTPException: 400 if no logo has been uploaded yet.
    """
    if uploaded_logo is None:
        raise HTTPException(status_code=400, detail="No logo uploaded")

    # Encode the in-memory logo to PNG and rewind the buffer for streaming.
    logo_bytes = io.BytesIO()
    uploaded_logo.save(logo_bytes, format="PNG")
    logo_bytes.seek(0)
    return StreamingResponse(logo_bytes, media_type="image/png")
# Holds the (depth_estimator, controlnet_pipeline) pair once loaded, so the
# models are loaded a single time per process instead of on every request.
_depth_models_cache = {}


def _get_depth_models():
    """Return the cached ``(depth_estimator, controlnet_pipeline)`` pair, loading it on first use."""
    if "models" not in _depth_models_cache:
        # Full precision on CPU, half precision on any accelerator.
        dtype = torch.float32 if device == "cpu" else torch.float16
        depth_estimator = pipeline("depth-estimation")
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=dtype
        )
        sd_pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=dtype
        ).to(device)
        if device != "cpu":
            sd_pipe.enable_model_cpu_offload()
        sd_pipe.scheduler = UniPCMultistepScheduler.from_config(sd_pipe.scheduler.config)
        _depth_models_cache["models"] = (depth_estimator, sd_pipe)
    return _depth_models_cache["models"]


@app.get("/generate_new_img",response_class=StreamingResponse)
async def generate_new_img(hex_code: str, prompt: str = Query(..., description="Text prompt for image generation")):
    """Generate a new ad image from the uploaded image using depth ControlNet.

    The result is stored in the module-level ``generated_image`` and streamed
    back as a PNG.

    Args:
        hex_code: Hex colour code whose RGB value is embedded in the prompt.
        prompt: User text embedded at the start of the generation prompt.

    Returns:
        A PNG ``StreamingResponse`` on success. If no image has been uploaded,
        a plain-text message with status 400.

    Raises:
        HTTPException: 400 if ``hex_code`` cannot be parsed; 500 if model
            loading or generation fails.

    NOTE(review): inference runs synchronously inside this async handler, so
    other requests wait while an image is being generated.
    """
    global generated_image

    if uploaded_image is None:
        # A missing prerequisite is a client error, so it must not answer 200.
        return PlainTextResponse("You have not uploaded the image!", status_code=400)

    # Parse the colour outside the broad try below so that a malformed code is
    # reported as a client error instead of being rewrapped as a 500.
    try:
        tone_rgb = hex_to_rgb(hex_code)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid hex_code: {hex_code!r}") from exc

    ad_prompt = (
        f"Your system prompt is this: {prompt} Consider your system prompt first.\n"
        "Then from the initial image create a new image that will attract customers to put in an (ad template)\n"
        f"Also, use this RGB color {tone_rgb} as a tone in the image while image is still recognized as it is original."
    )

    try:
        print("Image is creating....")
        depth_estimator, sd_pipe = _get_depth_models()

        # Add a batch dimension and match the pipeline's precision and device.
        depth_map = get_depth_map(uploaded_image, depth_estimator).unsqueeze(0)
        if device != "cpu":
            depth_map = depth_map.half()
        depth_map = depth_map.to(device)

        # uploaded_image is already an RGB PIL image, so it is passed directly.
        image = sd_pipe(
            ad_prompt, image=uploaded_image, control_image=depth_map
        ).images[0]
        print("Image created")

        image_data = io.BytesIO()
        image.save(image_data, format="PNG")
        image_data.seek(0)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Publish the result only after generation and encoding both succeeded.
    generated_image = image
    return StreamingResponse(image_data, media_type="image/png")
def _text_size(draw, text, font):
    """Return ``(width, height)`` of ``text``, on Pillow versions with or without ``textsize``."""
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top
    # Fallback for older Pillow releases that lack textbbox.
    return draw.textsize(text, font=font)


@app.get("/create_ad_template",response_class=StreamingResponse)
async def create_ad_template(punchline: str, punchline_color:str, button_text:str,button_color:str):
    """Compose an 800x600 ad template and stream it back as a PNG.

    The template is a white canvas holding the uploaded logo (top, centred),
    the generated image (centred), the punchline text, a filled button with
    its label, and a red border.

    Args:
        punchline: Text drawn below the generated image.
        punchline_color: Pillow colour specifier for the punchline text.
        button_text: Label drawn in white inside the button.
        button_color: Pillow colour specifier for the button fill.

    Returns:
        A PNG ``StreamingResponse`` on success. If the logo or the generated
        image is missing, a plain-text message with status 400.

    Raises:
        HTTPException: 400 if either colour specifier is not recognised.
    """
    if uploaded_logo is None or generated_image is None:
        # The original built this response without returning it.
        return PlainTextResponse(
            "If you did not generate the new image or if you did not upload logo image. You should check and try again...",
            status_code=400,
        )

    # Validate the colours up front so a bad value is a 400, not a 500 raised
    # half-way through drawing. Local import: ImageColor is only needed here.
    from PIL import ImageColor
    for color in (punchline_color, button_color):
        try:
            ImageColor.getrgb(color)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=f"Invalid color: {color!r}") from exc

    canvas_width, canvas_height = 800, 600

    # Logo: resized and centred horizontally near the top edge.
    print("Drawing")
    logo_width, logo_height = 100, 100  # Desired size for the logo
    logo_image_resized = uploaded_logo.resize((logo_width, logo_height))
    logo_image_location = ((canvas_width - logo_width) // 2, 10)

    # Generated image: resized and centred on the canvas.
    generated_image_width, generated_image_height = 350, 350  # Desired size for the generated image
    generated_image_resized = generated_image.resize((generated_image_width, generated_image_height))
    generated_image_location = (
        (canvas_width - generated_image_width) // 2,
        (canvas_height - generated_image_height) // 2,
    )

    # Create a blank canvas for the ad template.
    ad_template = Image.new("RGB", (canvas_width, canvas_height), "#FFFFFF")
    print("Template created")

    ad_template.paste(logo_image_resized, logo_image_location)
    print("Logo added")

    ad_template.paste(generated_image_resized, generated_image_location)
    print("Image added")

    # Punchline, centred on the point (400, 500).
    draw = ImageDraw.Draw(ad_template)
    font = ImageFont.load_default()
    text_width, text_height = _text_size(draw, punchline, font)
    text_position = (400 - text_width / 2, 500 - text_height / 2)
    draw.text(text_position, punchline, font=font, fill=punchline_color)
    print("Punchline added")

    # Button, centred on the point (400, 550), with its label centred inside.
    button_width, button_height = 200, 50
    button_left = 400 - button_width // 2
    button_top = 550 - button_height // 2
    draw.rectangle(
        [button_left, button_top, button_left + button_width, button_top + button_height],
        fill=button_color,
    )
    text_width, text_height = _text_size(draw, button_text, font)
    text_position = (
        button_left + (button_width - text_width) / 2,
        button_top + (button_height - text_height) / 2,
    )
    rect_text_color = (255, 255, 255)  # Text color within the rectangle
    draw.text(text_position, button_text, fill=rect_text_color, font=font)

    # Border: the last valid pixel is (799, 599), so the rectangle ends there
    # rather than at (800, 600), which lies outside the canvas.
    border_width = 2  # Border width
    draw.rectangle(
        [0, 0, canvas_width - 1, canvas_height - 1],
        outline="#D60D0D",
        width=border_width,
    )
    print("Button added")

    ad_template_data = io.BytesIO()
    ad_template.save(ad_template_data, format="PNG")
    ad_template_data.seek(0)
    print("Template finished")
    return StreamingResponse(ad_template_data, media_type="image/png")