Spaces:
Runtime error
Runtime error
File size: 4,887 Bytes
fcfea15 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import os
from PIL import Image
import torch
from infer_runtime.infer_config import InferConfig, load_infer_config_class_from_pyfile
from infer_runtime.prompt_rewrite import rewrite_prompt
from infer_runtime.settings import InferSettings
from modules.models import load_dit, load_pipeline
from modules.utils import _dynamic_resize_from_bucket, seed_everything
@dataclass
class InferenceParams:
    """Per-request arguments consumed by ``EditModel.infer``.

    When ``image`` is None, ``height``/``width`` give the output size.
    Otherwise the image is resized using ``basesize`` and its resized
    dimensions are used instead.
    """

    # Text conditioning.
    prompt: str
    # Optional reference image; None selects the text-only path in ``infer``.
    image: Optional[Image.Image]
    # Output size for the text-only path.
    height: int
    width: int
    # Sampler controls forwarded to the pipeline.
    steps: int
    guidance_scale: float
    seed: int
    # Negative prompt forwarded alongside ``prompt``.
    neg_prompt: str
    # Base size passed to ``_dynamic_resize_from_bucket`` when an image is given.
    basesize: int
class EditModel:
    """Own a DiT model and the sampling pipeline built around it.

    Construction loads the config class from ``settings.config_path``, points
    it at ``settings.ckpt_path``, loads the DiT (frozen, eval mode) and builds
    the pipeline on ``device``. Prompt rewrites are memoised per instance in a
    bounded LRU cache.
    """

    # Maximum number of memoised prompt rewrites kept per instance; the least
    # recently used entry is evicted so a long-running server cannot grow the
    # cache without limit.
    _REWRITE_CACHE_MAX = 256

    def __init__(
        self,
        settings: InferSettings,
        device: torch.device,
        hsdp_shard_dim_override: int | None = None,
    ):
        """Load the inference config, DiT weights and pipeline.

        Args:
            settings: provides ``config_path``, ``ckpt_path`` and the prompt
                rewrite options read by ``maybe_rewrite_prompt``.
            device: target device; strings such as ``'cuda:1'`` are accepted
                and normalised to ``torch.device``.
            hsdp_shard_dim_override: when not None, replaces the config's
                ``hsdp_shard_dim``.
        """
        self.settings = settings
        # Normalise so ``self.device.type`` and equality checks below work
        # even if a plain string was passed (move_to_device does the same).
        self.device = torch.device(device)
        self._rewrite_cache: dict[str, str] = {}

        config_class = load_infer_config_class_from_pyfile(settings.config_path)
        self.cfg: InferConfig = config_class()
        self.cfg.dit_ckpt = settings.ckpt_path
        self.cfg.training_mode = False
        if hsdp_shard_dim_override is not None:
            self.cfg.hsdp_shard_dim = hsdp_shard_dim_override
        # Enable FSDP inference only when WORLD_SIZE reports more than one
        # process (presumably a multi-rank launch) and sharding is requested.
        if int(os.environ.get('WORLD_SIZE', '1')) > 1 and self.cfg.hsdp_shard_dim > 1:
            self.cfg.use_fsdp_inference = True

        self.dit = load_dit(self.cfg, device=self.device)
        self.dit.requires_grad_(False)
        self.dit.eval()
        self.pipeline = load_pipeline(self.cfg, self.dit, self.device)

    def current_device(self) -> torch.device:
        """Return the device the model currently lives on."""
        return self.device

    def move_to_device(self, device: torch.device) -> torch.device:
        """Move the DiT and pipeline to ``device`` and return the new device.

        A no-op when the model is already on an equal ``torch.device``.
        Strings are accepted and normalised.
        """
        target = torch.device(device)
        if self.device == target:
            return self.device
        self.dit = self.dit.to(device=target)
        self.pipeline = self.pipeline.to(target)
        self.device = target
        return self.device

    def move_to_cpu(self) -> torch.device:
        """Move the model to CPU; return the new device."""
        return self.move_to_device(torch.device('cpu'))

    def move_to_gpu(self, device: torch.device | None = None) -> torch.device:
        """Move the model to ``device`` (default: ``'cuda'``); return the new device."""
        target = torch.device(device) if device is not None else torch.device('cuda')
        return self.move_to_device(target)

    @staticmethod
    def _image_digest(image: Image.Image) -> str:
        """Return a short content hash of ``image`` (mode + raw pixel bytes).

        Used in the rewrite cache key so different pictures that merely share
        a size do not reuse each other's rewritten prompt.
        """
        import hashlib

        hasher = hashlib.blake2b(digest_size=16)
        hasher.update(str(image.mode).encode('utf-8'))
        hasher.update(image.tobytes())
        return hasher.hexdigest()

    def maybe_rewrite_prompt(self, prompt: str, image: Optional[Image.Image], enabled: bool) -> str:
        """Return ``prompt`` rewritten by ``rewrite_prompt`` when ``enabled``.

        When disabled, the prompt is returned as a string (None becomes '').
        Results are memoised per (stripped prompt, image size, image content);
        the cache is bounded by ``_REWRITE_CACHE_MAX`` with LRU eviction.
        """
        text = str(prompt or '')
        if not enabled:
            return text

        cache_key = f"prompt={text.strip()}"
        if image is not None:
            # Size alone is not enough: include a content hash so distinct
            # images of equal dimensions get distinct rewrites.
            cache_key += f"|image={image.size[0]}x{image.size[1]}|{self._image_digest(image)}"

        cache = self._rewrite_cache
        if cache_key in cache:
            # Re-insert to mark as most recently used (dicts keep insertion order).
            cached = cache.pop(cache_key)
            cache[cache_key] = cached
            return cached

        rewritten = rewrite_prompt(
            text,
            image,
            model=self.settings.rewrite_model,
            api_key=self.settings.openai_api_key,
            base_url=self.settings.openai_base_url,
        )
        while len(cache) >= self._REWRITE_CACHE_MAX:
            # Evict the least recently used entry (the first key in order).
            cache.pop(next(iter(cache)))
        cache[cache_key] = rewritten
        return rewritten

    @torch.no_grad()
    def infer(self, params: InferenceParams) -> Image.Image:
        """Run one pipeline call and return the resulting frame as a PIL image.

        Without ``params.image`` the prompts are passed through verbatim and
        ``params.height``/``params.width`` set the output size. With an image,
        it is resized by ``_dynamic_resize_from_bucket`` (using
        ``params.basesize``), the resized size replaces height/width, and both
        prompts are wrapped in the ``<|im_start|>user`` template with an
        ``<image>`` placeholder.
        """
        if params.image is None:
            prompts = [params.prompt]
            negative_prompt = [params.neg_prompt]
            images = None
            height = params.height
            width = params.width
        else:
            processed = _dynamic_resize_from_bucket(params.image, basesize=params.basesize)
            width, height = processed.size
            image_tokens = '<image>\n'
            prompts = [f"<|im_start|>user\n{image_tokens}{params.prompt}<|im_end|>\n"]
            negative_prompt = [f"<|im_start|>user\n{image_tokens}{params.neg_prompt}<|im_end|>\n"]
            images = [processed]

        # Seed on the model's own CUDA device, keeping its index (e.g. cuda:1);
        # a bare 'cuda' generator would bind to the current default GPU instead.
        generator_device = self.device if self.device.type == 'cuda' else torch.device('cpu')
        generator = torch.Generator(device=generator_device).manual_seed(int(params.seed))
        output = self.pipeline(
            prompt=prompts,
            negative_prompt=negative_prompt,
            images=images,
            height=height,
            width=width,
            num_frames=1,
            num_inference_steps=params.steps,
            guidance_scale=params.guidance_scale,
            generator=generator,
            num_videos_per_prompt=1,
            output_type='pt',
            return_dict=False,
        )
        # NOTE(review): assumes output[0, -1, 0] is a channels-first (C, H, W)
        # frame with values nominally in [0, 1] -- permute(1, 2, 0) below
        # requires exactly 3 dims; confirm against the pipeline.
        # Clamp before the uint8 cast: out-of-range float->int casts are not
        # well-defined and would wrap; round instead of truncating.
        frame = output[0, -1, 0].float().clamp(0.0, 1.0)
        image_tensor = (frame * 255).round().to(torch.uint8).cpu()
        return Image.fromarray(image_tensor.permute(1, 2, 0).numpy())
def build_model(
    settings: InferSettings,
    device: torch.device | None = None,
    hsdp_shard_dim_override: int | None = None,
) -> EditModel:
    """Seed via ``seed_everything(settings.default_seed)`` and build an ``EditModel``.

    Args:
        settings: inference settings; ``default_seed`` is used for seeding and
            the whole object is forwarded to ``EditModel``.
        device: placement for the model; when omitted, ``cuda`` is chosen if
            ``torch.cuda.is_available()``, otherwise ``cpu``.
        hsdp_shard_dim_override: forwarded unchanged to ``EditModel``.

    Returns:
        The newly constructed ``EditModel``.
    """
    seed_everything(settings.default_seed)

    target = device
    if target is None:
        if torch.cuda.is_available():
            target = torch.device('cuda')
        else:
            target = torch.device('cpu')

    model = EditModel(
        settings=settings,
        device=target,
        hsdp_shard_dim_override=hsdp_shard_dim_override,
    )
    return model
|