Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import os | |
| from PIL import Image | |
| import torch | |
| from infer_runtime.infer_config import InferConfig, load_infer_config_class_from_pyfile | |
| from infer_runtime.prompt_rewrite import rewrite_prompt | |
| from infer_runtime.settings import InferSettings | |
| from modules.models import load_dit, load_pipeline | |
| from modules.utils import _dynamic_resize_from_bucket, seed_everything | |
@dataclass
class InferenceParams:
    """Per-request inputs consumed by ``EditModel.infer``.

    Declared as a dataclass so the annotated fields become constructor
    arguments and instance attributes; without the decorator the class has
    no ``__init__`` accepting these values.
    """

    # Text instruction for the request.
    prompt: str
    # Optional input image; ``None`` selects the text-only path in ``infer``.
    image: Optional[Image.Image]
    # Output height/width used when no input image is supplied.
    height: int
    width: int
    # Passed to the pipeline as ``num_inference_steps``.
    steps: int
    # Passed to the pipeline as ``guidance_scale``.
    guidance_scale: float
    # Seed for the per-request ``torch.Generator``.
    seed: int
    # Negative prompt text.
    neg_prompt: str
    # Passed to ``_dynamic_resize_from_bucket`` when an input image is given.
    basesize: int
class EditModel:
    """Holds a loaded DiT and its pipeline and runs single-image inference.

    The constructor builds the inference config from
    ``settings.config_path``, points it at ``settings.ckpt_path``, loads the
    DiT in eval mode with gradients disabled, and wraps it in a pipeline on
    the requested device.
    """

    def __init__(
        self,
        settings: InferSettings,
        device: torch.device,
        hsdp_shard_dim_override: int | None = None,
    ):
        """Load config, DiT weights and pipeline.

        Args:
            settings: Runtime settings; ``config_path``, ``ckpt_path`` and
                the prompt-rewrite fields are read from it.
            device: Device the DiT and pipeline are loaded onto.
            hsdp_shard_dim_override: When not ``None``, replaces
                ``cfg.hsdp_shard_dim`` from the config file.
        """
        self.settings = settings
        self.device = device
        # Per-instance memo of rewritten prompts, keyed by prompt (+ image).
        self._rewrite_cache: dict[str, str] = {}

        config_class = load_infer_config_class_from_pyfile(settings.config_path)
        self.cfg: InferConfig = config_class()
        self.cfg.dit_ckpt = settings.ckpt_path
        self.cfg.training_mode = False
        if hsdp_shard_dim_override is not None:
            self.cfg.hsdp_shard_dim = hsdp_shard_dim_override
        # FSDP inference is only switched on for multi-process launches that
        # also request a shard dimension greater than one.
        if int(os.environ.get('WORLD_SIZE', '1')) > 1 and self.cfg.hsdp_shard_dim > 1:
            self.cfg.use_fsdp_inference = True

        self.dit = load_dit(self.cfg, device=self.device)
        self.dit.requires_grad_(False)
        self.dit.eval()
        self.pipeline = load_pipeline(self.cfg, self.dit, self.device)

    def current_device(self) -> torch.device:
        """Return the device the model is currently recorded as living on."""
        return self.device

    def move_to_device(self, device: torch.device) -> torch.device:
        """Move the DiT and pipeline to ``device`` and return the new device.

        Does nothing when the model is already on the requested device.
        """
        target = torch.device(device)
        if self.device == target:
            return self.device
        self.dit = self.dit.to(device=target)
        self.pipeline = self.pipeline.to(target)
        self.device = target
        return self.device

    def move_to_cpu(self) -> torch.device:
        """Move the model to the CPU and return the resulting device."""
        return self.move_to_device(torch.device('cpu'))

    def move_to_gpu(self, device: torch.device | None = None) -> torch.device:
        """Move the model to ``device``, defaulting to ``'cuda'``."""
        target = torch.device(device) if device is not None else torch.device('cuda')
        return self.move_to_device(target)

    def maybe_rewrite_prompt(self, prompt: str, image: Optional[Image.Image], enabled: bool) -> str:
        """Return ``prompt`` rewritten by the configured model, with caching.

        Args:
            prompt: User prompt; ``None`` or empty is treated as ``''``.
            image: Optional image forwarded to ``rewrite_prompt``.
            enabled: When false the prompt is returned unchanged (as ``str``).

        Returns:
            The rewritten prompt, or the original text when rewriting is
            disabled.
        """
        # Normalise once so the enabled path tolerates a missing prompt just
        # like the disabled path does.
        text = str(prompt or '')
        if not enabled:
            return text

        cache_key = f"prompt={text.strip()}"
        if image is not None:
            # Key on pixel content, not only dimensions: two different images
            # of the same size must not share a cached rewrite.
            import hashlib

            digest = hashlib.sha256(image.tobytes()).hexdigest()
            cache_key += f"|image={image.mode}:{image.size[0]}x{image.size[1]}:{digest}"

        if cache_key not in self._rewrite_cache:
            self._rewrite_cache[cache_key] = rewrite_prompt(
                text,
                image,
                model=self.settings.rewrite_model,
                api_key=self.settings.openai_api_key,
                base_url=self.settings.openai_base_url,
            )
        return self._rewrite_cache[cache_key]

    def infer(self, params: InferenceParams) -> Image.Image:
        """Run the pipeline once and return the result as a PIL image.

        Without an input image the raw prompts and the requested
        ``height``/``width`` are used. With an input image, the image is
        resized via ``_dynamic_resize_from_bucket``, the output size is taken
        from the resized image, and both prompts are wrapped in the chat
        template with an ``<image>`` placeholder.
        """
        if params.image is None:
            prompts = [params.prompt]
            negative_prompt = [params.neg_prompt]
            images = None
            height = params.height
            width = params.width
        else:
            processed = _dynamic_resize_from_bucket(params.image, basesize=params.basesize)
            width, height = processed.size
            image_tokens = '<image>\n'
            prompts = [f"<|im_start|>user\n{image_tokens}{params.prompt}<|im_end|>\n"]
            negative_prompt = [f"<|im_start|>user\n{image_tokens}{params.neg_prompt}<|im_end|>\n"]
            images = [processed]

        # Seed a fresh generator per call so results depend only on the seed.
        generator_device = 'cuda' if self.device.type == 'cuda' else 'cpu'
        generator = torch.Generator(device=generator_device).manual_seed(int(params.seed))
        output = self.pipeline(
            prompt=prompts,
            negative_prompt=negative_prompt,
            images=images,
            height=height,
            width=width,
            num_frames=1,
            num_inference_steps=params.steps,
            guidance_scale=params.guidance_scale,
            generator=generator,
            num_videos_per_prompt=1,
            output_type='pt',
            return_dict=False,
        )
        # NOTE(review): assumes ``output[0, -1, 0]`` is a channel-first image
        # tensor with values in [0, 1] -- confirm against the pipeline.
        image_tensor = (output[0, -1, 0] * 255).to(torch.uint8).cpu()
        return Image.fromarray(image_tensor.permute(1, 2, 0).numpy())
def build_model(
    settings: InferSettings,
    device: torch.device | None = None,
    hsdp_shard_dim_override: int | None = None,
) -> EditModel:
    """Seed global RNGs and construct an ``EditModel``.

    Args:
        settings: Runtime settings; ``default_seed`` is passed to
            ``seed_everything`` before the model is built.
        device: Target device. When ``None``, CUDA is chosen if
            ``torch.cuda.is_available()`` reports true, otherwise the CPU.
        hsdp_shard_dim_override: Forwarded unchanged to ``EditModel``.

    Returns:
        The newly constructed ``EditModel``.
    """
    seed_everything(settings.default_seed)

    resolved_device = device
    if resolved_device is None:
        if torch.cuda.is_available():
            backend = 'cuda'
        else:
            backend = 'cpu'
        resolved_device = torch.device(backend)

    model = EditModel(
        settings=settings,
        device=resolved_device,
        hsdp_shard_dim_override=hsdp_shard_dim_override,
    )
    return model