import os
import io
import base64
import contextlib

import torch
from PIL import Image
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline

# fp12.Linear / fp12.Conv2d are local wrappers that are constructed from an
# existing nn.Linear / nn.Conv2d instance (see to_fp12 below).
from fp12 import Linear, Conv2d


# Lazily-initialised global pipeline shared by load_model() and run().
pipe = None


PATH_TO_MODEL = "./animagineXLV3_v30.safetensors"
# Replace eligible UNet layers with their fp12 equivalents after loading.
USE_FP12 = True
# Only convert modules whose name contains 'attn' (attention blocks).
FP12_ONLY_ATTN = True
# Which layer types to convert.
FP12_APPLY_LINEAR = False
FP12_APPLY_CONV = False


def free_memory():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def to_fp12(module: torch.nn.Module):
    """Replace the direct children of `module` with their fp12 counterparts."""
    target_modules = []

    if FP12_APPLY_LINEAR:
        target_modules.append((torch.nn.Linear, Linear))

    if FP12_APPLY_CONV:
        target_modules.append((torch.nn.Conv2d, Conv2d))

    for name, mod in list(module.named_children()):
        for orig_class, fp12_class in target_modules:
            if isinstance(mod, orig_class):
                try:
                    new_mod = fp12_class(mod)
                except Exception as e:
                    print(f' -> failed: {name} {str(e)}')
                    continue

                # Re-register the child under the same name; this drops the
                # reference to the original fp16 module.
                setattr(module, name, new_mod)
                del mod
                break


def load_model_cpu(path: str):
    pipe = StableDiffusionXLPipeline.from_single_file(
        path,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    return pipe


def replace_fp12(pipe: DiffusionPipeline):
    # to_fp12() only swaps direct children, so walk every submodule of the
    # UNet and let it convert matching layers in place.
    for name, mod in pipe.unet.named_modules():
        if FP12_ONLY_ATTN and 'attn' not in name:
            continue
        print('[fp12] REPLACE', name)
        to_fp12(mod)
    return pipe


@contextlib.contextmanager
def cuda_profiler(device: str):
    """Measure elapsed GPU time (ms) and peak allocated memory (bytes) of the wrapped block."""
    cuda_start = torch.cuda.Event(enable_timing=True)
    cuda_end = torch.cuda.Event(enable_timing=True)

    obj = {}

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device)
    cuda_start.record()

    try:
        yield obj
    finally:
        # Record in the finally block so timings are available even if the
        # profiled block raises.
        cuda_end.record()
        torch.cuda.synchronize()
        obj['time'] = cuda_start.elapsed_time(cuda_end)
        obj['memory'] = torch.cuda.max_memory_allocated(device)


def generate(pipe: DiffusionPipeline, prompt: str, negative_prompt: str, seed: int, device: str, use_amp: bool = False, guidance_scale=None, steps=None):
    # autocast expects a device *type* such as 'cuda', not an indexed device
    # string such as 'cuda:0'.
    device_type = torch.device(device).type
    context = (
        torch.autocast if use_amp
        else contextlib.nullcontext
    )

    with torch.no_grad(), context(device_type):
        rng = torch.Generator(device=device)
        if 0 <= seed:
            rng = rng.manual_seed(seed)

        latents, *_ = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=1024,
            height=1024,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            generator=rng,
            return_dict=False,
            output_type='latent',
        )

    return latents


def save_image(pipe, latents):
    with torch.no_grad():
        images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        images = pipe.image_processor.postprocess(images, output_type='pil')

    # num_images_per_prompt is 1, so return the single decoded image.
    return images[0]


def load_model(model=None, device=None):
    global pipe

    model = model or PATH_TO_MODEL
    device = device or 'cuda:0'

    pipe = load_model_cpu(model)

    if USE_FP12:
        pipe = replace_fp12(pipe)

    free_memory()
    with cuda_profiler(device) as prof:
        pipe.unet = pipe.unet.to(device)
    print('LOAD VRAM', prof['memory'])
    print('LOAD TIME', prof['time'])

    pipe.text_encoder = pipe.text_encoder.to(device)
    pipe.text_encoder_2 = pipe.text_encoder_2.to(device)

    if torch.cuda.is_available():
        torch.cuda.synchronize(device)


def run(prompt=None, negative_prompt=None, model=None, guidance_scale=None, steps=None, seed=None, device: str = None, use_amp: bool = False):
    global pipe

    if not pipe:
        # Pass the requested device through so the model is loaded where the
        # generation will run.
        load_model(model, device)

    _prompt = "masterpiece, best quality, 1girl, portrait"
    _negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"

    prompt = prompt or _prompt
    negative_prompt = negative_prompt or _negative_prompt
    guidance_scale = float(guidance_scale) if guidance_scale else 5.0
    steps = int(steps) if steps else 20
    seed = int(seed) if seed else -1
    device = device or 'cuda:0'

    free_memory()
    with cuda_profiler(device) as prof:
        latents = generate(pipe, prompt, negative_prompt, seed, device, use_amp, guidance_scale, steps)
    print('UNET VRAM', prof['memory'])
    print('UNET TIME', prof['time'])

    # Move the VAE onto the device only for decoding.
    free_memory()
    pipe.vae = pipe.vae.to(device)
    pipe.vae.enable_slicing()
    return save_image(pipe, latents)


def pil_to_webp(img):
    buffer = io.BytesIO()
    img.save(buffer, 'webp')
    return buffer.getvalue()


def bin_to_base64(bin):
    return base64.b64encode(bin).decode('ascii')
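

# Example usage (a minimal sketch, not part of the module above): generate one
# image, encode it as WebP, and print a base64 preview. The prompt, seed, and
# output path below are arbitrary illustrative choices.
if __name__ == '__main__':
    image = run(
        prompt="masterpiece, best quality, 1girl, portrait",
        steps=20,
        seed=42,
    )
    webp_bytes = pil_to_webp(image)
    with open('output.webp', 'wb') as f:
        f.write(webp_bytes)
    print('base64 preview:', bin_to_base64(webp_bytes)[:64])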