Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, Request, UploadFile, File, Query | |
| from fastapi.exceptions import HTTPException | |
| from PIL import Image, ImageDraw, ImageFont | |
| import torch | |
| from diffusers import StableDiffusionControlNetPipeline, StableDiffusionImg2ImgPipeline, DiffusionPipeline, StableDiffusionControlNetImg2ImgPipeline | |
| from diffusers import ControlNetModel, UniPCMultistepScheduler | |
| from transformers import pipeline | |
| from starlette.responses import StreamingResponse | |
| from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse | |
| import io | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import os | |
| from dotenv import load_dotenv | |
| import requests | |
| import cv2 | |
| import numpy as np | |
# Hugging Face access token, read from the process environment (None if unset).
# NOTE(review): `load_dotenv` is imported above but never called in the code
# visible here, so a local .env file would not be loaded - confirm intent.
HF_TOKEN = os.getenv('HF_TOKEN')

# In-memory state shared by all request handlers in this process.
uploaded_image = None    # set by upload_image
uploaded_logo = None     # set by upload_logo
generated_image = None   # set by generate_new_img

app = FastAPI()

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True - confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Choose the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("DEVICE: ", device)

# Module-level img2img pipeline: float32 weights on CPU, float16 elsewhere.
# NOTE(review): generate_new_img assigns its own local pipeline and never reads
# this one in the code visible here - confirm it is still needed.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "nitrosocke/Ghibli-Diffusion",
    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
    use_auth_token=HF_TOKEN,
).to(device)
#pipe1 = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
def hex_to_rgb(hex_code):
    """Convert a hex colour code to an ``(r, g, b)`` tuple of ints.

    Accepts codes with or without a leading ``#`` and with surrounding
    whitespace. Three- and four-digit shorthand (``#f80``, ``#f80a``) is
    expanded by doubling each digit. Digits beyond the sixth (for example an
    alpha channel) are ignored.

    Args:
        hex_code: Colour string such as ``"#FF8800"``, ``"ff8800"`` or ``"#f80"``.

    Returns:
        Tuple of three ints in the range 0-255.

    Raises:
        ValueError: If fewer than six usable digits remain or a digit is not
            hexadecimal.
    """
    digits = hex_code.strip().lstrip('#')
    if len(digits) in (3, 4):
        # CSS-style shorthand: each digit stands for a doubled pair.
        digits = "".join(ch * 2 for ch in digits)

    rgb_digits = digits[:6]
    is_hex = all(ch in "0123456789abcdefABCDEF" for ch in rgb_digits)
    if len(rgb_digits) != 6 or not is_hex:
        raise ValueError(f"Invalid hex colour code: {hex_code!r}")

    return tuple(int(rgb_digits[i:i + 2], 16) for i in (0, 2, 4))
def get_depth_map(image, depth_estimator):
    """Build a three-channel, channels-first depth tensor for ``image``.

    Args:
        image: Input handed unchanged to ``depth_estimator``.
        depth_estimator: Callable whose result has a ``"depth"`` entry that
            ``np.array`` can convert (assumed to be a 2-D H x W map - TODO
            confirm against the estimator in use).

    Returns:
        Float tensor holding the depth values divided by 255, replicated over
        three channels and permuted so the channel axis comes first.
    """
    depth = np.array(depth_estimator(image)["depth"])
    # Add a trailing channel axis and replicate it three times.
    three_channel = np.repeat(depth[:, :, None], 3, axis=2)
    scaled = torch.from_numpy(three_channel).float() / 255.0
    return scaled.permute(2, 0, 1)
async def root():
    """Return the static welcome payload of the API.

    NOTE(review): no route decorator is visible on this coroutine in the
    reviewed text - confirm it is registered on ``app``.
    """
    welcome = "Welcome to the Creating Ad Template With Stable Diffusion API!"
    return {"message": welcome}
async def upload_image(image_file:UploadFile):
    """Store an uploaded JPEG/PNG as the current source image.

    The file is decoded, converted to RGB and kept in the module-level
    ``uploaded_image``. That global is only replaced after decoding succeeds.

    Args:
        image_file: Uploaded file; its name must end in .jpg, .jpeg or .png
            (compared case-insensitively).

    Returns:
        JSONResponse confirming the upload.

    Raises:
        HTTPException: 400 when the extension is not accepted (or the upload
            has no filename) or the content cannot be decoded as an image.
    """
    global uploaded_image

    # filename can be missing; treat that like an unsupported extension.
    filename = (image_file.filename or "").lower()
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        raise HTTPException(status_code=400, detail="Invalid file format")

    image_bytes = await image_file.read()
    try:
        decoded = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError) as e:
        # Pillow signals unreadable or corrupt image data with OSError subclasses.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from e

    uploaded_image = decoded
    return JSONResponse(content={"message": "Image uploaded successfully"})
async def get_image(response_class=StreamingResponse):
    """Stream the currently stored uploaded image back as a PNG.

    Args:
        response_class: Not read anywhere in the body.
            NOTE(review): this looks like it was meant as a route-decorator
            argument rather than a function parameter - confirm.

    Returns:
        StreamingResponse with media type ``image/png``.

    Raises:
        HTTPException: 400 when no image has been uploaded.
    """
    if uploaded_image is None:
        raise HTTPException(status_code=400, detail="No image uploaded")

    # Encode into an in-memory buffer and rewind it so streaming starts at byte 0.
    png_buffer = io.BytesIO()
    uploaded_image.save(png_buffer, format="PNG")
    png_buffer.seek(0)
    return StreamingResponse(png_buffer, media_type="image/png")
async def upload_logo(logo_file:UploadFile):
    """Store an uploaded JPEG/PNG as the current logo.

    The file is decoded, converted to RGB and kept in the module-level
    ``uploaded_logo``. That global is only replaced after decoding succeeds.

    Args:
        logo_file: Uploaded file; its name must end in .jpg, .jpeg or .png
            (compared case-insensitively).

    Returns:
        JSONResponse confirming the upload.

    Raises:
        HTTPException: 400 when the extension is not accepted (or the upload
            has no filename) or the content cannot be decoded as an image.
    """
    global uploaded_logo

    # filename can be missing; treat that like an unsupported extension.
    filename = (logo_file.filename or "").lower()
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        raise HTTPException(status_code=400, detail="Invalid file format")

    logo_bytes = await logo_file.read()
    try:
        decoded = Image.open(io.BytesIO(logo_bytes)).convert("RGB")
    except (OSError, ValueError) as e:
        # Pillow signals unreadable or corrupt image data with OSError subclasses.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from e

    uploaded_logo = decoded
    return JSONResponse(content={"message": "Logo uploaded successfully"})
async def get_logo(response_class=StreamingResponse):
    """Stream the currently stored logo back as a PNG.

    Args:
        response_class: Not read anywhere in the body.
            NOTE(review): this looks like it was meant as a route-decorator
            argument rather than a function parameter - confirm.

    Returns:
        StreamingResponse with media type ``image/png``.

    Raises:
        HTTPException: 400 when no logo has been uploaded.
    """
    if uploaded_logo is None:
        raise HTTPException(status_code=400, detail="No logo uploaded")

    # Encode into an in-memory buffer and rewind it so streaming starts at byte 0.
    png_buffer = io.BytesIO()
    uploaded_logo.save(png_buffer, format="PNG")
    png_buffer.seek(0)
    return StreamingResponse(png_buffer, media_type="image/png")
async def generate_new_img(hex_code: str, prompt: str = Query(..., description="Text prompt for image generation")):
    """Generate an ad-style variant of the uploaded image.

    A depth map is estimated from the uploaded image and passed, together
    with the image and a prompt built from ``prompt`` and the requested
    colour, to a depth-ControlNet img2img Stable Diffusion pipeline. The
    result is stored in the module-level ``generated_image`` and streamed
    back as a PNG.

    Args:
        hex_code: Colour as hex digits; parsed with ``hex_to_rgb``.
        prompt: Free text embedded in the generation prompt.

    Returns:
        StreamingResponse with the generated PNG, or a PlainTextResponse
        when no image has been uploaded yet.

    Raises:
        HTTPException: 400 when ``hex_code`` cannot be parsed; 500 (with the
            underlying error text) when any generation step fails.

    NOTE(review): the depth estimator, ControlNet and pipeline are rebuilt
    on every call; consider loading them once if latency matters.
    """
    global generated_image

    if uploaded_image is None:
        return PlainTextResponse("You have not uploaded the image!")

    # Reject a bad colour up front as a client error instead of a generic 500.
    try:
        tone_rgb = hex_to_rgb(hex_code)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid hex colour code: {hex_code!r}") from e

    try:
        print("Image is creating....")
        ad_prompt = (
            f"Your system prompt is this: {prompt} Consider your system prompt first.\n"
            "Then from the initial image create a new image that will attract customers to put in an (ad template)\n"
            f"Also, use this RGB color {tone_rgb} as a tone in the image while image is still recognized as it is original."
        )
        print(f"uploaded image type: {type(uploaded_image)}")

        # float32 on CPU, float16 on accelerators - applied to the control
        # image and to both models so their dtypes agree.
        dtype = torch.float32 if device == "cpu" else torch.float16

        depth_estimator = pipeline("depth-estimation")
        depth_map = (
            get_depth_map(uploaded_image, depth_estimator)
            .unsqueeze(0)
            .to(device=device, dtype=dtype)
        )

        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=dtype
        )
        controlnet_pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=dtype
        ).to(device)
        if device != "cpu":
            # Model CPU offload needs an accelerator; skip it on CPU-only hosts.
            controlnet_pipe.enable_model_cpu_offload()
        controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(
            controlnet_pipe.scheduler.config
        )

        # Round-trip through PNG so the pipeline gets its own RGB copy.
        image_bytes = io.BytesIO()
        uploaded_image.save(image_bytes, format="PNG")
        image_bytes.seek(0)
        init_image = Image.open(image_bytes).convert("RGB")
        print(f"Image bytes length: {len(image_bytes.getvalue())}")
        print(f"init_image type: {type(init_image)}")

        image = controlnet_pipe(
            ad_prompt, image=init_image, control_image=depth_map
        ).images[0]
        print(f"image type: {type(image)}")
        print("Image created")

        image_data = io.BytesIO()
        image.save(image_data, format="PNG")
        image_data.seek(0)

        generated_image = image
        return StreamingResponse(image_data, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def _measure_text(draw, text, font):
    """Return the ``(width, height)`` in pixels of ``text`` drawn with ``font``."""
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right - left, bottom - top
    # Older Pillow releases without textbbox still provide textsize.
    return draw.textsize(text, font=font)


async def create_ad_template(punchline: str, punchline_color:str, button_text:str,button_color:str):
    """Compose the 800x600 ad template from the logo and generated image.

    Layout on a white canvas: the logo (100x100) centred at the top, the
    generated image (350x350) centred, the punchline centred around y=500,
    a 200x50 button centred around y=550 with white text, and a 2px red
    border around the canvas.

    Args:
        punchline: Text drawn below the generated image.
        punchline_color: Fill colour for the punchline, passed to Pillow.
        button_text: Label drawn inside the button.
        button_color: Fill colour of the button rectangle, passed to Pillow.

    Returns:
        StreamingResponse with the template as PNG, or a PlainTextResponse
        explaining that the logo or generated image is missing.
    """
    if uploaded_logo is None or generated_image is None:
        # The response must actually be returned; otherwise the handler yields None.
        return PlainTextResponse(
            "If you did not generate the new image or if you did not upload logo image. You should check and try again..."
        )

    canvas_width, canvas_height = 800, 600

    print("Drawing")
    # Logo: fixed size, horizontally centred, 10px from the top.
    logo_width, logo_height = 100, 100
    logo_image_resized = uploaded_logo.resize((logo_width, logo_height))
    logo_image_location = ((canvas_width - logo_width) // 2, 10)

    # Generated image: fixed size, centred on the canvas.
    generated_image_width, generated_image_height = 350, 350
    generated_image_resized = generated_image.resize((generated_image_width, generated_image_height))
    generated_image_location = (
        (canvas_width - generated_image_width) // 2,
        (canvas_height - generated_image_height) // 2,
    )

    ad_template = Image.new("RGB", (canvas_width, canvas_height), "#FFFFFF")
    print("Template created")

    ad_template.paste(logo_image_resized, logo_image_location)
    print("Logo added")

    ad_template.paste(generated_image_resized, generated_image_location)
    print("Image added")

    draw = ImageDraw.Draw(ad_template)
    font = ImageFont.load_default()

    # Punchline centred on (400, 500).
    text_width, text_height = _measure_text(draw, punchline, font)
    text_position = (canvas_width / 2 - text_width / 2, 500 - text_height / 2)
    draw.text(text_position, punchline, font=font, fill=punchline_color)
    print("Punchline added")

    # Button centred on (400, 550), with its label centred inside it.
    button_width, button_height = 200, 50
    button_left = canvas_width / 2 - button_width / 2
    button_top = 550 - button_height / 2
    draw.rectangle(
        [button_left, button_top, button_left + button_width, button_top + button_height],
        fill=button_color,
    )
    text_width, text_height = _measure_text(draw, button_text, font)
    text_position = (
        button_left + (button_width - text_width) / 2,
        button_top + (button_height - text_height) / 2,
    )
    draw.text(text_position, button_text, fill=(255, 255, 255), font=font)

    # Rectangle coordinates are inclusive, so the last valid pixel is size - 1;
    # this keeps the full border width visible on the right and bottom edges.
    draw.rectangle([0, 0, canvas_width - 1, canvas_height - 1], outline="#D60D0D", width=2)
    print("Button added")

    ad_template_data = io.BytesIO()
    ad_template.save(ad_template_data, format="PNG")
    ad_template_data.seek(0)
    print("Template finished")
    return StreamingResponse(ad_template_data, media_type="image/png")