""" High-level inference pipeline for UniBioTransfer. Designed for easy use in Hugging Face Spaces and other applications. ZeroGPU Compatible: - Supports CPU initialization (device="cpu") - Dynamically switches to CUDA during inference when called from @spaces.GPU """ from pathlib import Path import torch import numpy as np from PIL import Image import cv2 import global_ from hf_model import UniBioTransferModel, TASK_NAME2ID, TASK_ID2NAME from ldm.models.diffusion.ddim import DDIMSampler from pytorch_lightning import seed_everything DDIM_STEPS_DEFAULT = 50 SCALE_DEFAULT = 3.0 H, W, C, F = 512, 512, 4, 8 class UniBioTransferPipeline: """ High-level pipeline for UniBioTransfer inference. """ def __init__(self, model, task="face", device="cpu"): """ Initialize pipeline with a loaded model. """ self.model = model self.task = task self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task self._init_device = device global_.task = self.task_id self.model.task = self.task_id self.sampler = DDIMSampler(model) @classmethod def from_pretrained( cls, repo_id="scy639/UniBioTransfer", task="face", device="cpu", cache_dir=None, **kwargs, ): """ Load pipeline from Hugging Face Hub. """ model = UniBioTransferModel.from_pretrained( pretrained_model_name_or_path=repo_id, task=task, device=device, cache_dir=cache_dir, **kwargs, ) return cls(model, task=task, device=device) def set_task(self, task): """Switch to a different task.""" self.task = task self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task global_.task = self.task_id self.model.task = self.task_id def __call__( self, tgt_image, ref_image, ddim_steps=DDIM_STEPS_DEFAULT, scale=SCALE_DEFAULT, seed=42, num_images=1, ): """ Run inference on a pair of images. """ seed_everything(seed) tgt_img = self._load_image(tgt_image) ref_img = self._load_image(ref_image) tgt_img = self._resize_image(tgt_img, (H, W)) ref_img = self._resize_image(ref_img, (H, W)) result_tensors = self._run_inference(tgt_img, ref_img, ddim_steps, scale, num_images) result_imgs = [self._postprocess(result_tensors[i]) for i in range(result_tensors.shape[0])] return result_imgs def _load_image(self, img): """Load image from various formats.""" if isinstance(img, Image.Image): return img.convert("RGB") elif isinstance(img, np.ndarray): return Image.fromarray(img).convert("RGB") elif isinstance(img, (str, Path)): return Image.open(img).convert("RGB") else: raise ValueError(f"Unsupported image type: {type(img)}") def _resize_image(self, img, size): """Resize image to target size.""" if img.size != size: img = img.resize(size, Image.LANCZOS) return img def _run_inference(self, tgt_img, ref_img, ddim_steps, scale, num_images): """ Run diffusion sampling. 完全复用 infer.py 的逻辑,使用 dataloader。 """ from Dataset_custom import Dataset_custom from gen_lmk_and_mask import gen_lmk_and_mask import tempfile with tempfile.TemporaryDirectory() as tmpdir: tgt_path = Path(tmpdir) / "tgt.png" ref_path = Path(tmpdir) / "ref.png" tgt_img.save(tgt_path) ref_img.save(ref_path) gen_lmk_and_mask([str(tgt_path), str(ref_path)], write_cache=True) dataset = Dataset_custom( "test", task=self.task_id, paths_tgt=[str(tgt_path)], paths_ref=[str(ref_path)], ) dataloader = torch.utils.data.DataLoader( dataset, batch_size=1, num_workers=1, pin_memory=True, shuffle=False, drop_last=False, ) run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = self.model.to(run_device) with torch.no_grad(): for test_batch, prior, test_model_kwargs, out_stem_batch in dataloader: test_batch = test_batch.to(run_device) if test_batch.shape[0] == 1: test_batch = test_batch.repeat(num_images, 1, 1, 1) if isinstance(prior, torch.Tensor): prior = prior.to(run_device) if prior.shape[0] == 1: prior = prior.repeat(num_images, 1, 1, 1) for k, v in test_model_kwargs.items(): if isinstance(v, torch.Tensor): v = v.to(run_device) if v.shape[0] == 1: repeats = [num_images] + [1] * (v.ndim - 1) v = v.repeat(*repeats) test_model_kwargs[k] = v elif isinstance(v, dict): new_v = {} for kk, vv in v.items(): if isinstance(vv, torch.Tensor): vv = vv.to(run_device) if vv.shape[0] == 1: repeats = [num_images] + [1] * (vv.ndim - 1) vv = vv.repeat(*repeats) new_v[kk] = vv else: new_v[kk] = vv test_model_kwargs[k] = new_v elif isinstance(v, list): test_model_kwargs[k] = v * num_images self.model.set_task(test_model_kwargs) bs = num_images batch_ = { **test_model_kwargs, "GT": torch.zeros(num_images, *test_model_kwargs["inpaint_image"].shape[1:], device=run_device), } batch_, c = self.model.get_input_and_conditioning(batch_, device=run_device) z_inpaint = batch_["z4_inpaint"] z_inpaint_mask = batch_["tgt_mask_64"] z_ref = batch_["z_ref"] z9 = batch_["z9"] uc = None if scale != 1.0: uc = self.model.learnable_vector[self.task_id].repeat(bs, 1, 1) shape = [C, H // F, W // F] start_code = None samples_ddim, _ = self.sampler.sample( S=ddim_steps, conditioning=c, batch_size=bs, shape=shape, verbose=False, unconditional_guidance_scale=scale, unconditional_conditioning=uc, eta=0.0, x_T=start_code, log_every_t=100, z_inpaint=z_inpaint, z_inpaint_mask=z_inpaint_mask, z_ref=z_ref, z9=z9, ) x_samples_ddim = self.model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) self.model.unset_task() return x_samples_ddim def _postprocess(self, tensor): """Convert model output tensor to PIL Image.""" img_array = tensor.cpu().permute(1, 2, 0).numpy() img_array = (img_array * 255).astype(np.uint8) return Image.fromarray(img_array) def infer_single( tgt_path, ref_path, task="face", output_path=None, ddim_steps=DDIM_STEPS_DEFAULT, scale=SCALE_DEFAULT, device="cuda", ): """ Convenience function for single inference. """ pipeline = UniBioTransferPipeline.from_pretrained(task=task, device=device) result = pipeline(tgt_path, ref_path, ddim_steps=ddim_steps, scale=scale) if output_path is not None: result.save(output_path) print(f"Saved result to {output_path}") return result if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="UniBioTransfer inference") parser.add_argument("--task", type=str, default="face", choices=["face", "hair", "motion", "head"]) parser.add_argument("--tgt", type=str, required=True, help="Path to target image") parser.add_argument("--ref", type=str, required=True, help="Path to reference image") parser.add_argument("--out", type=str, default="result.png", help="Output path") parser.add_argument("--ddim-steps", type=int, default=50) parser.add_argument("--scale", type=float, default=3.0) parser.add_argument("--device", type=str, default="cuda") args = parser.parse_args() result = infer_single( args.tgt, args.ref, task=args.task, output_path=args.out, ddim_steps=args.ddim_steps, scale=args.scale, device=args.device, ) print(f"Inference complete. Result shape: {result.size}")