| """ |
| High-level inference pipeline for UniBioTransfer. |
| Designed for easy use in Hugging Face Spaces and other applications. |
| |
| ZeroGPU Compatible: |
| - Supports CPU initialization (device="cpu") |
| - Dynamically switches to CUDA during inference when called from @spaces.GPU |
| """ |
| from pathlib import Path |
| import torch |
| import numpy as np |
| from PIL import Image |
| import cv2 |
|
|
| import global_ |
| from hf_model import UniBioTransferModel, TASK_NAME2ID, TASK_ID2NAME |
| from ldm.models.diffusion.ddim import DDIMSampler |
| from pytorch_lightning import seed_everything |
|
|
# Defaults used by the pipeline call, ``infer_single`` and the CLI.
DDIM_STEPS_DEFAULT = 50  # number of DDIM sampling steps
SCALE_DEFAULT = 3.0  # unconditional guidance scale

# Geometry constants. Inputs are resized to H x W, and the sampler is asked
# for a tensor of shape [C, H // F, W // F].
# NOTE(review): C / F look like latent channel count and downsampling
# factor -- confirm against the model config.
H = 512
W = 512
C = 4
F = 8


class UniBioTransferPipeline:
    """
    High-level pipeline for UniBioTransfer inference.

    Bundles a loaded model with a ``DDIMSampler`` and exposes a single call,
    ``pipeline(tgt_image, ref_image)``, that returns a list of PIL images.
    """

    def __init__(self, model, task="face", device="cpu"):
        """
        Initialize pipeline with a loaded model.

        Args:
            model: Loaded model; it is handed to ``DDIMSampler`` and its
                ``task`` attribute is set to the resolved task id.
            task: Task name (resolved through ``TASK_NAME2ID``) or task id.
            device: Device the model was initialized on. It is only stored in
                ``_init_device``; the inference device is chosen at call time.
        """
        self.model = model
        self._init_device = device
        # Single place that resolves the task and pushes it to global_/model.
        self.set_task(task)
        self.sampler = DDIMSampler(model)

    @classmethod
    def from_pretrained(
        cls,
        repo_id="scy639/UniBioTransfer",
        task="face",
        device="cpu",
        cache_dir=None,
        **kwargs,
    ):
        """
        Load pipeline from Hugging Face Hub.

        Args:
            repo_id: Passed to ``UniBioTransferModel.from_pretrained`` as
                ``pretrained_model_name_or_path``.
            task: Task name or id, forwarded to the model loader and to the
                pipeline constructor.
            device: Forwarded to the model loader and the constructor.
            cache_dir: Forwarded to the model loader.
            **kwargs: Extra keyword arguments forwarded to the model loader.

        Returns:
            A new pipeline wrapping the loaded model.
        """
        model = UniBioTransferModel.from_pretrained(
            pretrained_model_name_or_path=repo_id,
            task=task,
            device=device,
            cache_dir=cache_dir,
            **kwargs,
        )
        return cls(model, task=task, device=device)

    def set_task(self, task):
        """
        Switch to a different task.

        String names are looked up in ``TASK_NAME2ID`` (unknown names are kept
        unchanged); any other value is used as the task id directly. The id is
        written to ``self.task_id``, ``global_.task`` and ``self.model.task``.
        """
        self.task = task
        self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task
        global_.task = self.task_id
        self.model.task = self.task_id

    def __call__(
        self,
        tgt_image,
        ref_image,
        ddim_steps=DDIM_STEPS_DEFAULT,
        scale=SCALE_DEFAULT,
        seed=42,
        num_images=1,
    ):
        """
        Run inference on a pair of images.

        Args:
            tgt_image: Target image (PIL image, numpy array, or path).
            ref_image: Reference image (PIL image, numpy array, or path).
            ddim_steps: Number of DDIM sampling steps.
            scale: Unconditional guidance scale; ``1.0`` disables the
                unconditional conditioning.
            seed: Seed passed to ``seed_everything``.
            num_images: Number of images to generate for the pair.

        Returns:
            A list of PIL images, one per generated sample.

        Raises:
            ValueError: If ``num_images`` is smaller than 1 or an input has
                an unsupported type.
        """
        if num_images < 1:
            raise ValueError(f"num_images must be >= 1, got {num_images}")

        seed_everything(seed)

        # PIL sizes are (width, height).
        size = (W, H)
        tgt_img = self._resize_image(self._load_image(tgt_image), size)
        ref_img = self._resize_image(self._load_image(ref_image), size)

        result_tensors = self._run_inference(tgt_img, ref_img, ddim_steps, scale, num_images)

        # Iterating a tensor walks its first (batch) dimension.
        return [self._postprocess(sample) for sample in result_tensors]

    def _load_image(self, img):
        """
        Load image from various formats.

        Accepts a PIL image, a numpy array, or a ``str``/``Path`` to an image
        file, and always returns an RGB PIL image.

        Raises:
            ValueError: If ``img`` is of any other type.
        """
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        if isinstance(img, np.ndarray):
            return Image.fromarray(img).convert("RGB")
        if isinstance(img, (str, Path)):
            # convert() loads the pixels and returns a new image, so the file
            # handle can be released as soon as the block exits.
            with Image.open(img) as opened:
                return opened.convert("RGB")
        raise ValueError(f"Unsupported image type: {type(img)}")

    def _resize_image(self, img, size):
        """
        Resize image to target size.

        ``size`` follows the PIL convention ``(width, height)``. The image is
        returned unchanged when it already has that size.
        """
        if img.size == size:
            return img
        return img.resize(size, Image.LANCZOS)

    @staticmethod
    def _repeat_to_batch(tensor, num_images, device):
        """Move ``tensor`` to ``device`` and repeat a size-1 leading dim ``num_images`` times."""
        tensor = tensor.to(device)
        if tensor.ndim > 0 and tensor.shape[0] == 1:
            tensor = tensor.repeat(num_images, *([1] * (tensor.ndim - 1)))
        return tensor

    def _expand_model_kwargs(self, model_kwargs, num_images, device):
        """
        Return a copy of ``model_kwargs`` expanded to ``num_images`` samples.

        Tensors (top level, or one level down inside a dict value) are moved
        to ``device`` and repeated along a size-1 leading dimension. Top-level
        lists are repeated ``num_images`` times. Everything else is passed
        through untouched.
        """
        expanded = {}
        for key, value in model_kwargs.items():
            if isinstance(value, torch.Tensor):
                value = self._repeat_to_batch(value, num_images, device)
            elif isinstance(value, dict):
                nested = {}
                for sub_key, sub_value in value.items():
                    if isinstance(sub_value, torch.Tensor):
                        sub_value = self._repeat_to_batch(sub_value, num_images, device)
                    nested[sub_key] = sub_value
                value = nested
            elif isinstance(value, list):
                value = value * num_images
            expanded[key] = value
        return expanded

    def _sample_batch(self, model_kwargs, ddim_steps, scale, num_images, run_device):
        """
        Run DDIM sampling for one expanded batch and decode the result.

        Returns the decoded samples rescaled from [-1, 1] to [0, 1] and
        clamped to that range.
        """
        self.model.set_task(model_kwargs)
        try:
            batch = {
                **model_kwargs,
                # Zero placeholder for "GT", shaped like "inpaint_image".
                "GT": torch.zeros(
                    num_images,
                    *model_kwargs["inpaint_image"].shape[1:],
                    device=run_device,
                ),
            }
            batch, cond = self.model.get_input_and_conditioning(batch, device=run_device)

            # Unconditional conditioning is only needed when guidance is on.
            uncond = None
            if scale != 1.0:
                uncond = self.model.learnable_vector[self.task_id].repeat(num_images, 1, 1)

            samples, _ = self.sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=num_images,
                shape=[C, H // F, W // F],
                verbose=False,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uncond,
                eta=0.0,
                x_T=None,
                log_every_t=100,
                z_inpaint=batch["z4_inpaint"],
                z_inpaint_mask=batch["tgt_mask_64"],
                z_ref=batch["z_ref"],
                z9=batch["z9"],
            )

            decoded = self.model.decode_first_stage(samples)
            return torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
        finally:
            # Always undo set_task, even when sampling fails.
            self.model.unset_task()

    def _run_inference(self, tgt_img, ref_img, ddim_steps, scale, num_images):
        """
        Run diffusion sampling.

        Fully reuses the logic of infer.py by going through a dataloader: the
        two images are written to a temporary directory, ``gen_lmk_and_mask``
        is run on them, and a one-pair ``Dataset_custom`` feeds the sampler.

        Args:
            tgt_img: Target PIL image (already resized).
            ref_img: Reference PIL image (already resized).
            ddim_steps: Number of DDIM sampling steps.
            scale: Unconditional guidance scale.
            num_images: Number of samples to generate.

        Returns:
            Tensor of decoded samples with values in [0, 1]; its first
            dimension indexes the generated images.

        Raises:
            RuntimeError: If the dataset yields no batch for the pair.
        """
        # Imported lazily, as in the original code path.
        from Dataset_custom import Dataset_custom
        from gen_lmk_and_mask import gen_lmk_and_mask
        import tempfile

        # Use CUDA whenever it is available at call time (see the ZeroGPU note
        # in the module docstring), otherwise stay on CPU.
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        with tempfile.TemporaryDirectory() as tmpdir:
            tgt_path = Path(tmpdir) / "tgt.png"
            ref_path = Path(tmpdir) / "ref.png"
            tgt_img.save(tgt_path)
            ref_img.save(ref_path)

            # NOTE(review): presumably writes landmark/mask caches that
            # Dataset_custom reads back -- confirm in gen_lmk_and_mask.
            gen_lmk_and_mask([str(tgt_path), str(ref_path)], write_cache=True)

            dataset = Dataset_custom(
                "test",
                task=self.task_id,
                paths_tgt=[str(tgt_path)],
                paths_ref=[str(ref_path)],
            )
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=1,
                num_workers=1,
                # Pinned memory only helps (and only works) with CUDA.
                pin_memory=run_device.type == "cuda",
                shuffle=False,
                drop_last=False,
            )

            self.model = self.model.to(run_device)

            with torch.no_grad():
                # Only the model kwargs are consumed; the other items of each
                # batch are not needed for sampling.
                for _test_batch, _prior, test_model_kwargs, _out_stem in dataloader:
                    model_kwargs = self._expand_model_kwargs(
                        test_model_kwargs, num_images, run_device
                    )
                    # The dataset was built from exactly one pair, so the
                    # first batch is the result.
                    return self._sample_batch(
                        model_kwargs, ddim_steps, scale, num_images, run_device
                    )

        raise RuntimeError("Dataset_custom produced no batch for the given image pair")

    def _postprocess(self, tensor):
        """
        Convert model output tensor to PIL Image.

        ``tensor`` is one sample with values in [0, 1]; it is permuted with
        ``(1, 2, 0)`` (assumed channel-first -> channel-last) and scaled to
        ``uint8``.
        """
        # detach()/float() make the numpy conversion safe for tensors that
        # track gradients or use a reduced-precision dtype.
        img_array = tensor.detach().float().cpu().permute(1, 2, 0).numpy()
        img_array = (img_array * 255).astype(np.uint8)
        return Image.fromarray(img_array)
|
|
|
|
def infer_single(
    tgt_path,
    ref_path,
    task="face",
    output_path=None,
    ddim_steps=DDIM_STEPS_DEFAULT,
    scale=SCALE_DEFAULT,
    device="cuda",
):
    """
    Convenience function for single inference.

    Loads a pipeline with ``UniBioTransferPipeline.from_pretrained``, runs it
    once on the given pair and optionally saves the generated image.

    Args:
        tgt_path: Target image (anything the pipeline call accepts).
        ref_path: Reference image (anything the pipeline call accepts).
        task: Task name or id forwarded to ``from_pretrained``.
        output_path: Where to save the result; nothing is written when None.
        ddim_steps: Number of DDIM sampling steps.
        scale: Unconditional guidance scale.
        device: Device forwarded to ``from_pretrained``.

    Returns:
        The generated PIL image.
    """
    pipeline = UniBioTransferPipeline.from_pretrained(task=task, device=device)

    # The pipeline returns a list with one image per requested sample; with
    # the default of a single sample, unwrap it so ``.save`` / ``.size`` work.
    images = pipeline(tgt_path, ref_path, ddim_steps=ddim_steps, scale=scale)
    result = images[0]

    if output_path is not None:
        result.save(output_path)
        print(f"Saved result to {output_path}")

    return result
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: run one target/reference pair and save it.
    import argparse

    parser = argparse.ArgumentParser(description="UniBioTransfer inference")
    parser.add_argument("--task", type=str, default="face", choices=["face", "hair", "motion", "head"])
    parser.add_argument("--tgt", type=str, required=True, help="Path to target image")
    parser.add_argument("--ref", type=str, required=True, help="Path to reference image")
    parser.add_argument("--out", type=str, default="result.png", help="Output path")
    parser.add_argument("--ddim-steps", type=int, default=DDIM_STEPS_DEFAULT)
    parser.add_argument("--scale", type=float, default=SCALE_DEFAULT)
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args()

    # Saving is done here instead of through ``output_path`` so the CLI works
    # whether infer_single hands back a single image or a list of images
    # (the pipeline call itself returns a list).
    result = infer_single(
        args.tgt,
        args.ref,
        task=args.task,
        output_path=None,
        ddim_steps=args.ddim_steps,
        scale=args.scale,
        device=args.device,
    )
    image = result[0] if isinstance(result, (list, tuple)) else result

    image.save(args.out)
    print(f"Saved result to {args.out}")
    print(f"Inference complete. Result shape: {image.size}")
|
|