| import gc |
| import math |
| import multiprocessing |
| import os |
| import traceback |
| from datetime import datetime |
| from io import BytesIO |
| from itertools import permutations |
| from multiprocessing.pool import Pool |
| from pathlib import Path |
| from urllib.parse import quote_plus |
|
|
| import numpy as np |
| import nltk |
| import torch |
|
|
| from PIL.Image import Image |
| from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline |
| from diffusers.utils import load_image |
| from fastapi import FastAPI |
| from fastapi.middleware.gzip import GZipMiddleware |
| from loguru import logger |
| from starlette.middleware.cors import CORSMiddleware |
| from starlette.responses import FileResponse |
| from starlette.responses import JSONResponse |
|
|
| from env import BUCKET_PATH, BUCKET_NAME |
| |
# Let TorchDynamo fall back to eager execution instead of raising when it
# cannot compile a graph (the UNets are wrapped with torch.compile further down).
torch._dynamo.config.suppress_errors = True
|
|
| import string |
| import random |
|
|
def generate_save_path():
    """Return a random 7-character token made of uppercase ASCII letters and digits.

    ``image_url`` uses it as the file stem of the PNG it writes under ``images/``.
    """
    alphabet = string.ascii_uppercase + string.digits
    picks = random.choices(alphabet, k=7)
    return "".join(picks)
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
model_dir = os.getenv("SDXL_MODEL_DIR")

# A non-empty SDXL_MODEL_DIR points at a directory holding local copies of both
# checkpoints; otherwise the Hugging Face Hub repository ids are used.
if not model_dir:
    model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
    model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
else:
    model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
    model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
|
|
# Base SDXL text-to-image pipeline, fp16, on the GPU. Setting `watermark` to
# None skips the SDXL invisible-watermark post-processing step.
pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
pipe.watermark = None
pipe.to("cuda")

# SDXL refiner sharing the base model's second text encoder and VAE.
# Loaded from model_key_refiner so SDXL_MODEL_DIR is honoured like for the
# base model, and in float16 so its UNet matches the dtype of the shared fp16
# text encoder / VAE whose outputs it consumes (bfloat16 here would mix dtypes
# inside a single pipeline).
refiner = DiffusionPipeline.from_pretrained(
    model_key_refiner,
    text_encoder_2=pipe.text_encoder_2,
    vae=pipe.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.watermark = None
refiner.to("cuda")
|
|
| |
# Inpainting variant of the base model. The scheduler, text encoders,
# tokenizers, UNet and VAE are all reused from `pipe`. The config is read from
# model_key_base (previously a hard-coded "models/..." path that ignored
# SDXL_MODEL_DIR), and the dtype matches the fp16 components it receives.
# NOTE(review): the scheduler object is shared with `pipe`; schedulers keep
# per-call state, so concurrent requests could interfere — confirm.
inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    model_key_base,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    scheduler=pipe.scheduler,
    text_encoder=pipe.text_encoder,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer=pipe.tokenizer,
    tokenizer_2=pipe.tokenizer_2,
    unet=pipe.unet,
    vae=pipe.vae,
)
inpaintpipe.to("cuda")
inpaintpipe.watermark = None
| |
|
|
# Inpainting variant of the refiner. It reuses the refiner's UNet, tokenizers,
# scheduler and text encoder, plus the base model's second text encoder and VAE.
# The config is read from model_key_refiner so SDXL_MODEL_DIR is honoured; the
# dtype is fp16 to match every component passed in.
inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
    model_key_refiner,
    text_encoder_2=inpaintpipe.text_encoder_2,
    vae=inpaintpipe.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    tokenizer_2=refiner.tokenizer_2,
    tokenizer=refiner.tokenizer,
    scheduler=refiner.scheduler,
    text_encoder=refiner.text_encoder,
    unet=refiner.unet,
)
inpaint_refiner.to("cuda")
inpaint_refiner.watermark = None
| |
|
|
# NOTE(review): n_steps and high_noise_frac are not read by the functions in
# this file (inpaint_image_from_prompt defines its own local high_noise_frac);
# they may be vestigial or used elsewhere — confirm before removing.
n_steps = 40
high_noise_frac = 0.8

# Wrap the base and refiner UNets with torch.compile. The inpainting pipelines
# were constructed around the same uncompiled UNet objects (unet=pipe.unet /
# unet=refiner.unet), so they are re-pointed at the compiled wrappers and all
# four pipelines share two compiled UNets.
pipe.unet = torch.compile(pipe.unet)
refiner.unet = torch.compile(refiner.unet)

inpaintpipe.unet = pipe.unet
inpaint_refiner.unet = refiner.unet
| |
| |
| from pydantic import BaseModel |
|
|
# ASGI application; OpenAPI schema, Swagger UI and ReDoc are served from
# non-default URLs.
app = FastAPI(
    title="Generate Images Netwrck API",
    description="Character Chat API",
    version="1",
    openapi_url="/static/openapi.json",
    docs_url="/swagger-docs",
    redoc_url="/redoc",
)

# Gzip responses of at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True may
# let any website make credentialed cross-origin requests — confirm intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# NLTK English stopword list, used to condense over-long prompts.
stopwords = nltk.corpus.stopwords.words("english")
|
|
class Img(BaseModel):
    # Request body of POST /image_url.
    # Forwarded to get_summary, which does not use it.
    system_prompt: str
    # Text that is summarised by get_summary and then rendered as an image.
    ASSISTANT: str
|
|
| |
# Base URL that image_url() prefixes to the saved "images/<name>.png" path;
# presumably the host serving this process's local images/ directory.
# Note that it already ends with "/".
img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"

# NOTE(review): not read or written by any function in this file — possibly
# vestigial or used elsewhere.
is_gpu_busy = False
|
|
def get_summary(system_prompt, prompt, timeout=120):
    """Ask the remote text-generation service for a short summary of ``prompt``.

    The instruction asks for a summary of up to 60 words. It is POSTed to the
    service's ``/generate`` endpoint with ``max_tokens=90`` and ``n=1``.

    Args:
        system_prompt: Currently unused; kept for call compatibility.
        prompt: Text to summarise, embedded verbatim in the instruction.
        timeout: Seconds to wait for the service before ``requests.Timeout``
            is raised.

    Returns:
        The text after the last ``ASSISTANT:`` marker of the final entry in the
        response's ``text`` list, with surrounding whitespace stripped.

    Raises:
        requests.HTTPError: if the service answers with a 4xx/5xx status.
        requests.Timeout: if the service does not answer within ``timeout``.
    """
    import requests

    summary_sys = """I want you to act as a text summarizer to help me create a concise summary of the text I provide. The summary can be up to 60.0 words in length, expressing the key points, key scenarios, main character and concepts written in the original text without adding your interpretations."""
    message = f"""My first request is to summarize this text – [{prompt}]"""
    instruction = summary_sys + ' USER: ' + message + ' ASSISTANT:'
    logger.info(f"Ins: {instruction}")

    json_object = {"prompt": instruction,
                   "max_tokens": 90,
                   "n": 1,
                   }
    # A timeout keeps a stalled generation service from blocking this worker
    # thread indefinitely.
    generate_response = requests.post(
        "http://phlrr3105.guest.corp.microsoft.com:7991/generate",
        json=json_object,
        timeout=timeout,
    )
    # Fail loudly on HTTP errors instead of later on an unexpected body.
    generate_response.raise_for_status()

    res_json = generate_response.json()
    summary = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
    logger.info(summary)
    return summary
|
|
@app.post("/image_url")
def image_url(img: Img):
    """Summarise the posted text, render it, and return a link to the PNG.

    The text in ``img.ASSISTANT`` is summarised by ``get_summary`` and trimmed
    by ``shorten_too_long_text``. It is rendered at 1024x1024 by the base SDXL
    pipeline and saved as ``images/<random 7 chars>.png``.

    Returns:
        ``JSONResponse`` of the form ``{"path": "<img_url>/images/<name>.png"}``.
    """
    prompt = get_summary(img.system_prompt, img.ASSISTANT)
    prompt = shorten_too_long_text(prompt)

    g = torch.Generator(device="cuda")
    image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]

    # image.save fails if the target directory is missing.
    os.makedirs("images", exist_ok=True)
    save_path = f"images/{generate_save_path()}.png"
    image.save(save_path)

    # img_url already ends with "/"; strip it so the link has no "//".
    path = f"{img_url.rstrip('/')}/{save_path}"
    return JSONResponse({"path": path})
|
|
|
|
@app.get("/make_image")
def make_image(prompt: str, save_path: str = ""):
    """Render ``prompt`` with the base SDXL pipeline and send it back as a PNG file.

    Args:
        prompt: Text prompt passed unchanged to ``pipe``.
        save_path: Optional file path, which must lie inside the server's
            working directory. If it names an existing file, that file is
            returned without rendering; otherwise the render is saved there.
            When empty, the render is saved to
            ``images/<URL-quoted prompt>.png``.

    Returns:
        ``FileResponse`` with the PNG, or a 400 ``JSONResponse`` when
        ``save_path`` points outside the working directory.
    """
    if save_path:
        # save_path comes straight from the query string; without this check
        # any readable file (e.g. /etc/passwd) could be served and any
        # writable location overwritten.
        base_dir = Path.cwd().resolve()
        if base_dir not in Path(save_path).resolve().parents:
            return JSONResponse(
                {"error": "save_path must be inside the server directory"},
                status_code=400,
            )
        # Only an existing *file* is a cache hit. The check is skipped for an
        # empty save_path, because Path("") means "." and always exists.
        if Path(save_path).is_file():
            return FileResponse(save_path, media_type="image/png")

    image = pipe(prompt=prompt).images[0]
    if not save_path:
        # Quote the prompt so "/" or ".." in it cannot escape images/.
        save_path = f"images/{quote_plus(prompt)}.png"
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    image.save(save_path)
    return FileResponse(save_path, media_type="image/png")
|
|
|
|
@app.get("/create_and_upload_image")
def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
    """Generate (or reuse) an image for ``prompt`` and return its storage URL.

    Only the last component of ``save_path`` is URL-quoted; any directory part
    is kept as-is, including the "/" separating it from the file name.

    Returns:
        ``JSONResponse`` ``{"path": ...}`` holding the value returned by
        ``get_image_or_create_upload_to_cloud_storage`` (``None`` on failure).
    """
    directory, _, final_name = save_path.rpartition("/")
    quoted_name = quote_plus(final_name)
    # Re-insert the separator: joining directory and name without it turned
    # "a/b.png" into "ab.png".
    save_path = f"{directory}/{quoted_name}" if directory else quoted_name
    path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
    return JSONResponse({"path": path})
|
|
@app.get("/inpaint_and_upload_image")
def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
    """Inpaint ``image_url`` under ``mask_url`` (or reuse a stored result) and return its URL.

    Only the last component of ``save_path`` is URL-quoted; any directory part
    is kept as-is, including the "/" separating it from the file name.

    Returns:
        ``JSONResponse`` ``{"path": ...}`` holding the value returned by
        ``get_image_or_inpaint_upload_to_cloud_storage`` (``None`` on failure).
    """
    directory, _, final_name = save_path.rpartition("/")
    quoted_name = quote_plus(final_name)
    # Re-insert the separator: joining directory and name without it turned
    # "a/b.png" into "ab.png".
    save_path = f"{directory}/{quoted_name}" if directory else quoted_name
    path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
    return JSONResponse({"path": path})
|
|
|
|
def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
    """Return the bucket URL for ``save_path``, rendering and uploading on a cache miss.

    Both ``prompt`` and ``save_path`` are first passed through
    ``shorten_too_long_text``. If a blob already exists at the shortened
    ``save_path``, its public URL is returned directly. Otherwise
    ``create_image_from_prompt`` renders the image and ``upload_to_bucket``
    stores it.

    Returns:
        The existing blob's URL, whatever ``upload_to_bucket`` returns
        (presumably the new link), or ``None`` if rendering returned ``None``.
    """
    save_path = shorten_too_long_text(save_path)
    prompt = shorten_too_long_text(prompt)

    if not check_if_blob_exists(save_path):
        encoded = create_image_from_prompt(prompt, width, height)
        if encoded is None:
            return None
        return upload_to_bucket(save_path, encoded, is_bytesio=True)
    return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
    """Return the bucket URL for ``save_path``, inpainting and uploading on a cache miss.

    Both ``prompt`` and ``save_path`` are first passed through
    ``shorten_too_long_text``. If a blob already exists at the shortened
    ``save_path``, its public URL is returned directly. Otherwise
    ``inpaint_image_from_prompt`` produces the image and ``upload_to_bucket``
    stores it.

    Returns:
        The existing blob's URL, whatever ``upload_to_bucket`` returns
        (presumably the new link), or ``None`` if inpainting returned ``None``.
    """
    save_path = shorten_too_long_text(save_path)
    prompt = shorten_too_long_text(prompt)

    if not check_if_blob_exists(save_path):
        encoded = inpaint_image_from_prompt(prompt, image_url, mask_url)
        if encoded is None:
            return None
        return upload_to_bucket(save_path, encoded, is_bytesio=True)
    return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
|
|
| |
| |
| |
| |
| |
|
|
def create_image_from_prompt(prompt, width, height):
    """Render ``prompt`` with the base SDXL pipeline and return it as WebP bytes.

    SDXL runs at ``width`` x ``height`` rounded down to a multiple of 64, never
    below 64. If that differs from the request, the render is scaled by one
    factor (keeping its aspect ratio) until it covers the requested size. It is
    then cropped from the top-left corner to exactly ``width`` x ``height``.

    If the pipeline raises, up to two more attempts are made. Before each
    retry the prompt's stopwords are removed and only the first half of the
    remaining words is kept.

    On success the current time (HH:MM:SS) is written to ``progress.txt``.

    Args:
        prompt: Text prompt; first trimmed with ``shorten_too_long_text``.
        width: Requested output width in pixels.
        height: Requested output height in pixels.

    Returns:
        WebP bytes (quality 85), or ``None`` in either of these cases:

        * the prompt shrank to an empty string before any attempt succeeded;
        * the render has no non-zero pixel (all black). A SIGHUP is then also
          sent to gunicorn/uvicorn.

    Raises:
        Exception: re-raised from the pipeline when the third attempt fails.
    """
    # Clamp to 64 so tiny requests do not produce a zero-sized render.
    block_width = max(64, width - (width % 64))
    block_height = max(64, height - (height % 64))
    prompt = shorten_too_long_text(prompt)

    max_attempts = 3
    image = None
    for attempt in range(1, max_attempts + 1):
        try:
            image = pipe(prompt=prompt,
                         width=block_width,
                         height=block_height,
                         num_inference_steps=50).images[0]
            break
        except Exception:
            if attempt == max_attempts:
                traceback.print_exc()
                raise
            logger.info(f"trying to shorten prompt of length {len(prompt)}")
            # Filter stopwords per *word*. Iterating the string itself would
            # drop single characters and space-separate every letter.
            words = [word for word in prompt.split() if word not in stopwords]
            prompt = ' '.join(words[:len(words) // 2])
            logger.info(f"shortened prompt to: {len(prompt)}")
            if not prompt:
                break

    if image is None:
        # The prompt shrank to nothing; there is nothing left to try.
        return None

    if width != block_width or height != block_height:
        scale_up_ratio = max(width / block_width, height / block_height)
        # Scale both generated sides by the same factor so the aspect ratio is
        # kept, then crop to the exact requested size.
        image = image.resize((math.ceil(block_width * scale_up_ratio),
                              math.ceil(block_height * scale_up_ratio)))
        image = image.crop((0, 0, width, height))

    bright_count = np.sum(np.array(image) > 0)
    if bright_count == 0:
        logger.info("restarting server to fix cuda issues (device side asserts)")
        # NOTE(review): the "/usr/bin/bash kill ..." commands ask bash to run a
        # script file named "kill" and most likely do nothing; the plain
        # "kill -1" (SIGHUP) commands are the ones that take effect.
        os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
        os.system("kill -1 `pgrep gunicorn`")
        os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
        os.system("kill -1 `pgrep uvicorn`")
        return None

    bs = BytesIO()
    image.save(bs, quality=85, optimize=True, format="webp")
    bio = bs.getvalue()

    # Record the wall-clock time of the last successful render.
    with open("progress.txt", "w") as f:
        f.write(datetime.now().strftime("%H:%M:%S"))
    return bio
|
|
def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
    """Inpaint the masked region of an image and return the result as WebP bytes.

    Uses the SDXL base + refiner ensemble over a 75-step schedule.
    ``inpaintpipe`` (base) denoises up to ``high_noise_frac`` and returns
    latents. ``inpaint_refiner`` resumes from that point and decodes the image.

    If the base stage raises, up to two more attempts are made. Before each
    retry the prompt's stopwords are removed and only the first half of the
    remaining words is kept.

    On success the current time (HH:MM:SS) is written to ``progress.txt``.

    Args:
        prompt: Text prompt; first trimmed with ``shorten_too_long_text``.
        image_url: Location of the source image, loaded via ``load_image``.
        mask_url: Location of the mask image, loaded via ``load_image``.

    Returns:
        WebP bytes (quality 85), or ``None`` in either of these cases:

        * the prompt shrank to an empty string before the base stage succeeded;
        * the result has no non-zero pixel (all black). A SIGHUP is then also
          sent to gunicorn/uvicorn.

    Raises:
        Exception: re-raised from the base pipeline when its third attempt fails.
    """
    prompt = shorten_too_long_text(prompt)

    # NOTE(review): both URLs come from the request and are fetched by the
    # server; consider restricting which hosts may be fetched (SSRF).
    init_image = load_image(image_url).convert("RGB")
    mask_image = load_image(mask_url).convert("RGB")
    num_inference_steps = 75
    high_noise_frac = 0.7

    max_attempts = 3
    latents = None
    for attempt in range(1, max_attempts + 1):
        try:
            # Base stage *ends* at high_noise_frac (denoising_end) and returns
            # the batch of latents; the refiner *starts* there (denoising_start).
            latents = inpaintpipe(
                prompt=prompt,
                image=init_image,
                mask_image=mask_image,
                num_inference_steps=num_inference_steps,
                denoising_end=high_noise_frac,
                output_type="latent",
            ).images
            break
        except Exception:
            if attempt == max_attempts:
                traceback.print_exc()
                raise
            logger.info(f"trying to shorten prompt of length {len(prompt)}")
            # Filter stopwords per *word*, not per character.
            words = [word for word in prompt.split() if word not in stopwords]
            prompt = ' '.join(words[:len(words) // 2])
            logger.info(f"shortened prompt to: {len(prompt)}")
            if not prompt:
                break

    if latents is None:
        # The prompt shrank to nothing; there is nothing left to refine.
        return None

    image = inpaint_refiner(
        prompt=prompt,
        image=latents,
        mask_image=mask_image,
        num_inference_steps=num_inference_steps,
        denoising_start=high_noise_frac,
    ).images[0]

    bright_count = np.sum(np.array(image) > 0)
    if bright_count == 0:
        logger.info("restarting server to fix cuda issues (device side asserts)")
        # NOTE(review): the "/usr/bin/bash kill ..." commands ask bash to run a
        # script file named "kill" and most likely do nothing; the plain
        # "kill -1" (SIGHUP) commands are the ones that take effect.
        os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
        os.system("kill -1 `pgrep gunicorn`")
        os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
        os.system("kill -1 `pgrep uvicorn`")
        return None

    bs = BytesIO()
    image.save(bs, quality=85, optimize=True, format="webp")
    bio = bs.getvalue()

    # Record the wall-clock time of the last successful render.
    with open("progress.txt", "w") as f:
        f.write(datetime.now().strftime("%H:%M:%S"))
    return bio
|
|
|
|
|
|
def shorten_too_long_text(prompt):
    """Condense text longer than 200 characters.

    Text of at most 200 characters is returned unchanged. Longer text is handled
    in three steps:

    1. Split it on whitespace.
    2. Drop words that exactly match an entry of the module-level
       ``stopwords`` list (case-sensitive) and rejoin with single spaces.
    3. Hard-truncate the result to 200 characters, possibly mid-word.
    """
    max_chars = 200
    if len(prompt) <= max_chars:
        return prompt
    kept_words = [word for word in prompt.split() if word not in stopwords]
    return ' '.join(kept_words)[:max_chars]
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|