Hugging Face Spaces listing (scrape artifact): the Space reported a runtime error.
| import torch | |
| from PIL import Image | |
| from .dpm_solver_pytorch import (NoiseScheduleVP, | |
| model_wrapper, | |
| DPM_Solver) | |
class FontDiffuserDPMPipeline:
    """FontDiffuser pipeline with a DPM-Solver scheduler.

    Wraps a trained FontDiffuser model and samples glyph images with
    (multistep) DPM-Solver, using classifier-free guidance by default.
    """

    def __init__(
        self,
        model,
        ddpm_train_scheduler,
        version="V3",
        model_type="noise",
        guidance_type="classifier-free",
        guidance_scale=7.5,
    ):
        """
        Args:
            model: trained diffusion model; must expose a ``.device`` attribute.
            ddpm_train_scheduler: DDPM scheduler used at train time; only its
                ``betas`` are read, to build the discrete noise schedule.
            version: model version tag forwarded to the model via kwargs.
            model_type: parameterization of the model output (e.g. "noise").
            guidance_type: guidance mode passed to ``model_wrapper``
                (e.g. "classifier-free").
            guidance_scale: classifier-free guidance weight.
        """
        self.model = model
        self.train_scheduler_betas = ddpm_train_scheduler.betas
        # Discrete-time noise schedule derived from the training betas.
        self.noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.train_scheduler_betas)
        self.version = version
        self.model_type = model_type
        self.guidance_type = guidance_type
        self.guidance_scale = guidance_scale

    def numpy_to_pil(self, images):
        """Convert a numpy image or a batch of images to a list of PIL images.

        Accepts float arrays in [0, 1] shaped (H, W, C) or (N, H, W, C).
        Single-channel (C == 1) images are returned as grayscale ("L") PIL
        images; multi-channel images are converted as-is.
        """
        if images.ndim == 3:
            # Promote a single image to a batch of one.
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # Image.fromarray cannot build an image from an (H, W, 1) array;
            # squeeze the channel axis and convert as grayscale.
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def generate(
        self,
        content_images,
        style_images,
        batch_size,
        order,
        num_inference_step,
        content_encoder_downsample_size,
        t_start=None,
        t_end=None,
        dm_size=(96, 96),
        algorithm_type="dpmsolver++",
        skip_type="time_uniform",
        method="multistep",
        correcting_x0_fn=None,
        generator=None,
    ):
        """Sample font images conditioned on content and style images.

        Args:
            content_images: tensor of content glyph images (condition input).
            style_images: tensor of style reference images (condition input).
            batch_size: number of samples to draw.
            order: DPM-Solver order.
            num_inference_step: number of solver steps.
            content_encoder_downsample_size: forwarded to the model via kwargs.
            t_start: currently unused; kept for interface compatibility.
            t_end: currently unused; kept for interface compatibility.
            dm_size: (height, width) of the generated images.
            algorithm_type: DPM-Solver variant, e.g. "dpmsolver++".
            skip_type: time-step spacing strategy, e.g. "time_uniform".
            method: "multistep" (recommended for conditional sampling).
            correcting_x0_fn: optional x0 correction; for pixel-space models
                "dynamic_thresholding" can be used.
            generator: optional torch.Generator for reproducible noise.

        Returns:
            List of PIL images.
        """
        # 1. Pack per-call model kwargs and the (un)conditional inputs.
        model_kwargs = {
            "version": self.version,
            "content_encoder_downsample_size": content_encoder_downsample_size,
        }
        cond = [content_images, style_images]
        # All-ones images act as the "null" condition for classifier-free
        # guidance.
        uncond = [
            torch.ones_like(content_images).to(self.model.device),
            torch.ones_like(style_images).to(self.model.device),
        ]

        # 2. Convert the discrete-time model to a continuous-time wrapper.
        model_fn = model_wrapper(
            model=self.model,
            noise_schedule=self.noise_schedule,
            model_type=self.model_type,
            model_kwargs=model_kwargs,
            guidance_type=self.guidance_type,
            condition=cond,
            unconditional_condition=uncond,
            guidance_scale=self.guidance_scale,
        )

        # 3. Define the DPM-Solver. Multistep DPM-Solver is recommended for
        #    conditional sampling; adjust `steps` to trade compute for quality.
        dpm_solver = DPM_Solver(
            model_fn=model_fn,
            noise_schedule=self.noise_schedule,
            algorithm_type=algorithm_type,
            correcting_x0_fn=correcting_x0_fn,
        )

        # 4. Sample: start from Gaussian noise [batch, 3, H, W] and integrate.
        x_T = torch.randn(
            (batch_size, 3, dm_size[0], dm_size[1]),
            generator=generator,
        )
        x_T = x_T.to(self.model.device)
        x_sample = dpm_solver.sample(
            x=x_T,
            steps=num_inference_step,
            order=order,
            skip_type=skip_type,
            method=method,
        )

        # Map from [-1, 1] to [0, 1], move to CPU in NHWC layout, and convert
        # to PIL images.
        x_sample = (x_sample / 2 + 0.5).clamp(0, 1)
        x_sample = x_sample.cpu().permute(0, 2, 3, 1).numpy()
        return self.numpy_to_pil(x_sample)