| | import gc |
| | import math |
| | import multiprocessing |
| | import os |
| | import traceback |
| | from datetime import datetime |
| | from io import BytesIO |
| | from itertools import permutations |
| | from multiprocessing.pool import Pool |
| | from pathlib import Path |
| | from urllib.parse import quote_plus |
| |
|
| | import numpy as np |
| | import nltk |
| | import torch |
| |
|
| | from PIL.Image import Image |
| | from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline |
| | from diffusers.utils import load_image |
| | from fastapi import FastAPI |
| | from fastapi.middleware.gzip import GZipMiddleware |
| | from loguru import logger |
| | from starlette.middleware.cors import CORSMiddleware |
| | from starlette.responses import FileResponse |
| | from starlette.responses import JSONResponse |
| |
|
| | from env import BUCKET_PATH, BUCKET_NAME |
| | |
# Ask TorchDynamo to suppress compilation errors instead of raising them
# (presumably falling back to eager execution — verify for the installed torch
# version); the UNets are wrapped with torch.compile further down this module.
torch._dynamo.config.suppress_errors = True
| |
|
| | import string |
| | import random |
| |
|
def generate_save_path(length=7):
    """Return a random file stem of uppercase ASCII letters and digits.

    Args:
        length: Number of characters in the stem (default 7, as before).

    Returns:
        A string such as ``"Q7X2K9A"``.

    Note:
        Uses ``random`` rather than ``secrets``. The names are unlikely to
        collide, but they are not unguessable.
    """
    alphabet = string.ascii_uppercase + string.digits
    return ''.join(random.choices(alphabet, k=length))
| |
|
# Base SDXL text-to-image pipeline, loaded in bfloat16 from a local checkout
# of the base 1.0 weights and moved to the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "models/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16",

)
# Presumably disables the pipeline's invisible-watermark step — verify against
# the installed diffusers version.
pipe.watermark = None

pipe.to("cuda")

# SDXL refiner. It reuses the base pipeline's second text encoder and VAE
# instead of loading its own copies.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=pipe.text_encoder_2,
    vae=pipe.vae,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16",
)
refiner.watermark = None
refiner.to("cuda")

# Inpainting variant of the base model. Every component (scheduler, both text
# encoders, both tokenizers, UNet, VAE) is taken from `pipe`, so these modules
# are shared rather than loaded a second time.
inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
    scheduler=pipe.scheduler,
    text_encoder=pipe.text_encoder,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer=pipe.tokenizer,
    tokenizer_2=pipe.tokenizer_2,
    unet=pipe.unet,
    vae=pipe.vae,

)
inpaintpipe.to("cuda")
inpaintpipe.watermark = None

# Inpainting variant of the refiner. text_encoder_2 and vae come from
# `inpaintpipe` (ultimately `pipe`); tokenizers, scheduler, text_encoder and
# unet come from `refiner`.
# NOTE(review): the SDXL refiner is not expected to have a first text encoder /
# tokenizer, so refiner.text_encoder / refiner.tokenizer are probably None — confirm.
inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=inpaintpipe.text_encoder_2,
    vae=inpaintpipe.vae,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16",

    tokenizer_2=refiner.tokenizer_2,
    tokenizer=refiner.tokenizer,
    scheduler=refiner.scheduler,
    text_encoder=refiner.text_encoder,
    unet=refiner.unet,
)
inpaint_refiner.to("cuda")
inpaint_refiner.watermark = None

# NOTE(review): n_steps is not referenced anywhere else in this file, and
# inpaint_image_from_prompt assigns its own local high_noise_frac, so neither
# module-level value is read here. Kept in case other modules import them.
n_steps = 40
high_noise_frac = 0.8

# Compile the UNets. The inpaint pipelines above were built with the
# uncompiled modules, so point them at the compiled ones as well.
pipe.unet = torch.compile(pipe.unet)
refiner.unet = torch.compile(refiner.unet)

inpaintpipe.unet = pipe.unet
inpaint_refiner.unet = refiner.unet
| | |
| | |
| | from pydantic import BaseModel |
| |
|
# FastAPI application. The OpenAPI schema and the interactive docs are served
# from the non-default URLs below.
# NOTE(review): description says "Character Chat API" while the title and the
# endpoints are about image generation — confirm which is intended.
app = FastAPI(
    openapi_url="/static/openapi.json",
    docs_url="/swagger-docs",
    redoc_url="/redoc",
    title="Generate Images Netwrck API",
    description="Character Chat API",

    version="1",
)
# Gzip-compress responses (minimum_size=1000 bytes).
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Fully permissive CORS: any origin, method and header, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# English stopword list from NLTK. It is used to drop filler words when
# shortening prompts. Needs the NLTK "stopwords" corpus to be installed
# (presumably done at deploy time — verify).
stopwords = nltk.corpus.stopwords.words("english")
| |
|
class Img(BaseModel):
    """Request body for ``POST /image_url``."""

    # Accepted in the request, but the /image_url handler does not use it to
    # generate the image.
    system_prompt: str
    # Text used as the image-generation prompt. The field name is part of the
    # JSON wire format, so it keeps its unusual casing.
    ASSISTANT: str
| |
|
| | |
# Base URL under which files saved in images/ are assumed to be served.
# NOTE(review): hard-coded internal host name — consider moving it to env.py
# next to BUCKET_NAME / BUCKET_PATH.
img_url = "http://phlrr3058.guest.corp.microsoft.com:8000/"


@app.post("/image_url")
def image_url(img: Img):
    """Generate an image from ``img.ASSISTANT`` and return a link to it.

    The image is saved locally as ``images/<random>.png``. The response is
    ``{"path": "<img_url>/images/<random>.png"}``.
    """
    prompt = img.ASSISTANT
    image = pipe(prompt=prompt).images[0]

    save_path = f"images/{generate_save_path()}.png"
    # Make sure the target directory exists before saving into it.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    image.save(save_path)

    # img_url already ends with "/"; strip it so the link is not ".../ /images/..."
    # with a doubled slash.
    path = f"{img_url.rstrip('/')}/{save_path}"
    return JSONResponse({"path": path})
| |
|
| |
|
@app.get("/make_image")
def make_image(prompt: str, save_path: str = ""):
    """Return a PNG for *prompt*, reusing ``save_path`` if that file already exists.

    Args:
        prompt: Text prompt for the base pipeline.
        save_path: Optional file to serve if it exists, or to write to
            otherwise. When empty, the image is written to
            ``images/<quoted prompt>.png``.

    Returns:
        A ``FileResponse`` streaming the PNG.
    """
    # Path("") normalises to Path("."), which always exists. Guard on a
    # non-empty path that is actually a file so the default never short-circuits.
    # NOTE(review): save_path comes straight from the query string, so a client
    # can read or overwrite any file the server can access (e.g. "../env.py").
    # Consider restricting it to images/.
    if save_path and Path(save_path).is_file():
        return FileResponse(save_path, media_type="image/png")
    image = pipe(prompt=prompt).images[0]
    if not save_path:
        # The prompt is user text. Percent-encode it so "/" cannot escape
        # images/, and cap the length to stay under file-name limits.
        save_path = f"images/{quote_plus(prompt)[:200]}.png"
        Path("images").mkdir(exist_ok=True)
    image.save(save_path)
    return FileResponse(save_path, media_type="image/png")
| |
|
| |
|
@app.get("/create_and_upload_image")
def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
    """Generate (or reuse) an image for *prompt* and return its bucket URL.

    Only the last segment of ``save_path`` is URL-quoted. The directory
    segments are kept, empty segments from leading or doubled slashes are
    dropped, and everything is re-joined with "/".

    Returns:
        ``{"path": url}``, where ``url`` is None if no image could be produced.
    """
    *directories, final_name = save_path.split("/")
    segments = [part for part in directories if part]
    # Keep the "/" before the file name: joining the directories and then
    # concatenating the name would turn "a/b/c.png" into "a/bc.png".
    segments.append(quote_plus(final_name))
    save_path = "/".join(segments)
    path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
    return JSONResponse({"path": path})
| |
|
@app.get("/inpaint_and_upload_image")
def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
    """Inpaint ``image_url`` under ``mask_url`` (or reuse a stored result) and return its bucket URL.

    Only the last segment of ``save_path`` is URL-quoted. The directory
    segments are kept, empty segments from leading or doubled slashes are
    dropped, and everything is re-joined with "/".

    Returns:
        ``{"path": url}``, where ``url`` is None if no image could be produced.
    """
    *directories, final_name = save_path.split("/")
    segments = [part for part in directories if part]
    # Keep the "/" before the file name: joining the directories and then
    # concatenating the name would turn "a/b/c.png" into "a/bc.png".
    segments.append(quote_plus(final_name))
    save_path = "/".join(segments)
    path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
    return JSONResponse({"path": path})
| |
|
| |
|
def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
    """Return a public URL for *save_path*, generating and uploading the image on a miss.

    Both the prompt and the storage key go through ``shorten_too_long_text``
    first. Note that this can truncate a long key, including its extension.

    NOTE(review): ``check_if_blob_exists`` and ``upload_to_bucket`` are neither
    defined nor imported in the visible file. They are presumably provided
    elsewhere — verify.

    Returns:
        The URL of an existing object, the result of ``upload_to_bucket`` for a
        fresh upload, or None when ``create_image_from_prompt`` produced nothing.
    """
    prompt, save_path = shorten_too_long_text(prompt), shorten_too_long_text(save_path)

    if not check_if_blob_exists(save_path):
        image_bytes = create_image_from_prompt(prompt, width, height)
        if image_bytes is None:
            return None
        return upload_to_bucket(save_path, image_bytes, is_bytesio=True)
    return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
    """Return a public URL for *save_path*, inpainting and uploading the image on a miss.

    Both the prompt and the storage key go through ``shorten_too_long_text``
    first. Note that this can truncate a long key, including its extension.

    NOTE(review): ``check_if_blob_exists`` and ``upload_to_bucket`` are neither
    defined nor imported in the visible file. They are presumably provided
    elsewhere — verify.

    Returns:
        The URL of an existing object, the result of ``upload_to_bucket`` for a
        fresh upload, or None when ``inpaint_image_from_prompt`` produced nothing.
    """
    prompt, save_path = shorten_too_long_text(prompt), shorten_too_long_text(save_path)

    if not check_if_blob_exists(save_path):
        image_bytes = inpaint_image_from_prompt(prompt, image_url, mask_url)
        if image_bytes is None:
            return None
        return upload_to_bucket(save_path, image_bytes, is_bytesio=True)
    return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
def create_image_from_prompt(prompt, width, height):
    """Generate an image for *prompt* with the base pipeline and encode it as WebP.

    The pipeline runs at the requested size rounded down to a multiple of 64
    (minimum 64). If that differs from the request, the result is rescaled
    proportionally and cropped to exactly ``width`` x ``height``.

    If generation raises, the prompt is shortened and generation is retried,
    for up to three attempts in total. Shortening drops stopwords and then
    keeps the first half of the remaining words. The final failure is printed
    and re-raised.

    Args:
        prompt: Text prompt; passed through ``shorten_too_long_text`` first.
        width: Requested output width in pixels.
        height: Requested output height in pixels.

    Returns:
        WebP-encoded bytes. Returns None if shortening left an empty prompt, or
        if the image came out entirely black (a server restart is also
        requested in that case).
    """
    # Feed the pipeline dimensions that are multiples of 64, never below 64.
    block_width = max(64, width - (width % 64))
    block_height = max(64, height - (height % 64))
    prompt = shorten_too_long_text(prompt)

    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            image = pipe(
                prompt=prompt,
                width=block_width,
                height=block_height,
                num_inference_steps=50,
            ).images[0]
            break
        except Exception:
            if attempt == max_attempts:
                traceback.print_exc()
                raise
            logger.info(f"trying to shorten prompt of length {len(prompt)}")
            # Split into words before filtering. Iterating the string itself
            # would filter single characters, not words.
            words = [word for word in prompt.split() if word not in stopwords]
            prompt = ' '.join(words[:len(words) // 2])
            logger.info(f"shortened prompt to: {len(prompt)}")
            if not prompt:
                # Nothing meaningful left to retry with.
                return None

    if width != block_width or height != block_height:
        # Rescale by the larger ratio so both sides cover the requested size,
        # then crop the overhang from the right and bottom edges.
        scale_up_ratio = max(width / block_width, height / block_height)
        image = image.resize((
            math.ceil(block_width * scale_up_ratio),
            math.ceil(block_height * scale_up_ratio),
        ))
        image = image.crop((0, 0, width, height))

    # An all-black result is treated as a sign that CUDA is in a bad state.
    bright_count = np.sum(np.array(image) > 0)
    if bright_count == 0:
        logger.info("restarting server to fix cuda issues (device side asserts)")
        # `kill -1` sends SIGHUP to every PID that pgrep prints; presumably the
        # server processes reload on it.
        os.system("kill -1 `pgrep gunicorn`")
        os.system("kill -1 `pgrep uvicorn`")
        return None

    bs = BytesIO()
    image.save(bs, quality=85, optimize=True, format="webp")
    bio = bs.getvalue()

    # Record the time of the last successful generation (presumably polled as a
    # liveness signal — confirm).
    with open("progress.txt", "w") as f:
        current_time = datetime.now().strftime("%H:%M:%S")
        f.write(f"{current_time}")
    return bio
| |
|
def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
    """Inpaint the masked region of an image and return the result as WebP bytes.

    This runs SDXL as an ensemble of experts. ``inpaintpipe`` denoises up to
    ``high_noise_frac`` of the schedule and returns latents, and
    ``inpaint_refiner`` resumes from that point to finish the remaining steps.

    If the base inpainting call raises, the prompt is shortened and the call is
    retried, for up to three attempts in total. Shortening drops stopwords and
    then keeps the first half of the remaining words. The final failure is
    printed and re-raised.

    Args:
        prompt: Text prompt; passed through ``shorten_too_long_text`` first.
        image_url: Source image location, loaded with diffusers' ``load_image``.
        mask_url: Mask image location, loaded with diffusers' ``load_image``.

    Returns:
        WebP-encoded bytes. Returns None if shortening left an empty prompt, or
        if the image came out entirely black (a server restart is also
        requested in that case).
    """
    prompt = shorten_too_long_text(prompt)

    init_image = load_image(image_url).convert("RGB")
    mask_image = load_image(mask_url).convert("RGB")
    num_inference_steps = 75
    # Fraction of the schedule run by the base expert; the refiner does the rest.
    high_noise_frac = 0.7

    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            # Every attempt must use the inpainting pipeline. The plain
            # text-to-image `pipe` does not take image/mask inputs.
            latents = inpaintpipe(
                prompt=prompt,
                image=init_image,
                mask_image=mask_image,
                num_inference_steps=num_inference_steps,
                # The base stops at high_noise_frac (denoising_end), and the
                # refiner below starts there (denoising_start).
                denoising_end=high_noise_frac,
                output_type="latent",
            ).images[0]
            break
        except Exception:
            if attempt == max_attempts:
                traceback.print_exc()
                raise
            logger.info(f"trying to shorten prompt of length {len(prompt)}")
            # Split into words before filtering. Iterating the string itself
            # would filter single characters, not words.
            words = [word for word in prompt.split() if word not in stopwords]
            prompt = ' '.join(words[:len(words) // 2])
            logger.info(f"shortened prompt to: {len(prompt)}")
            if not prompt:
                # Nothing meaningful left to retry with.
                return None

    image = inpaint_refiner(
        prompt=prompt,
        image=latents,
        mask_image=mask_image,
        num_inference_steps=num_inference_steps,
        denoising_start=high_noise_frac,
    ).images[0]

    # An all-black result is treated as a sign that CUDA is in a bad state.
    bright_count = np.sum(np.array(image) > 0)
    if bright_count == 0:
        logger.info("restarting server to fix cuda issues (device side asserts)")
        # `kill -1` sends SIGHUP to every PID that pgrep prints; presumably the
        # server processes reload on it.
        os.system("kill -1 `pgrep gunicorn`")
        os.system("kill -1 `pgrep uvicorn`")
        return None

    bs = BytesIO()
    image.save(bs, quality=85, optimize=True, format="webp")
    bio = bs.getvalue()

    # Record the time of the last successful generation (presumably polled as a
    # liveness signal — confirm).
    with open("progress.txt", "w") as f:
        current_time = datetime.now().strftime("%H:%M:%S")
        f.write(f"{current_time}")
    return bio
| |
|
| |
|
| |
|
def shorten_too_long_text(prompt, max_length=200, stop_words=None):
    """Return *prompt* cut down to at most ``max_length`` characters.

    Text already within the limit is returned unchanged. Longer text is split
    on whitespace, stopwords are removed, and the words are re-joined with
    single spaces. If the result is still too long, it is truncated to
    ``max_length`` characters.

    Args:
        prompt: Text to shorten.
        max_length: Character limit (default 200, as before).
        stop_words: Words to drop, matched exactly and case-sensitively.
            Defaults to the module-level NLTK ``stopwords`` list.

    Returns:
        The possibly shortened string.
    """
    if len(prompt) <= max_length:
        return prompt
    # Resolve the default lazily, so short inputs never touch the global.
    if stop_words is None:
        stop_words = stopwords
    prompt = ' '.join(word for word in prompt.split() if word not in stop_words)
    return prompt[:max_length]
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|