| |
"""
Simple DI²FIX inference script.

Expected files in the same folder:
    - inference.py
    - config.json
    - model.py
    - mv_unet.py
    - pipeline_difix.py
    - di2fix_utils.py

Expected file on the Hugging Face model repo:
    - model_80001.pkl

Example:
    python inference.py \
        --source examples/source.png \
        --reference examples/reference.png \
        --output output.png \
        --repo_id ChengYou305/DI2FIX_HF
"""
|
|
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Dict, Any |
|
|
| import torch |
| from PIL import Image |
| import torchvision.transforms.functional as TF |
| from huggingface_hub import hf_hub_download |
|
|
| from model import Di2fix |
| from pipeline_difix import DifixPipeline |
| from di2fix_utils import ( |
| to_uint8, |
| load_pipe_weights_into_model, |
| load_finetuned_ckpt_model_only, |
| ) |
|
|
|
|
# Built-in settings used when config.json is missing; any key present in
# config.json replaces the value given here.
DEFAULT_CONFIG: Dict[str, Any] = dict(
    # Hugging Face repo the finetuned checkpoint is downloaded from.
    model_repo_id="ChengYou305/DI2FIX_HF",
    # File name of the checkpoint inside that repo.
    checkpoint_filename="model_80001.pkl",
    # Repo handed to DifixPipeline.from_pretrained for the base weights.
    base_difix_repo_id="nvidia/difix_ref",
    # The next three values are forwarded to the Di2fix constructor.
    lora_rank_vae=4,
    timestep=199,
    mv_unet=True,
    # Text prompt that is tokenized and passed to the model.
    prompt="remove degradation",
    # Longest allowed image side; also the side of the square padded canvas.
    max_side=536,
    # One of "float16", "bfloat16", "float32".
    weight_dtype="float32",
)
|
|
|
|
def load_config(config_path: str) -> Dict[str, Any]:
    """Return the effective configuration: defaults overlaid with config.json.

    Args:
        config_path: Path to a JSON file whose top-level value is an object.

    Returns:
        A new dict that starts as a shallow copy of DEFAULT_CONFIG and has
        every key found in the file written over it. When the file does not
        exist a warning is printed and the defaults are returned unchanged.

    Raises:
        ValueError: If the file exists but its top-level JSON value is not an
            object (for example a list, a string or null).
        json.JSONDecodeError: If the file is not valid JSON (raised by
            json.load).
    """
    config = dict(DEFAULT_CONFIG)
    path = Path(config_path)

    if not path.is_file():
        print(f"[WARN] Config file not found: {path}. Using default config.")
        return config

    with path.open("r", encoding="utf-8") as f:
        user_config = json.load(f)

    # dict.update() on a non-mapping either fails with a cryptic message or,
    # for an empty list, silently does nothing. Reject it explicitly instead.
    if not isinstance(user_config, dict):
        raise ValueError(
            f"Config file {path} must contain a JSON object at the top level, "
            f"got {type(user_config).__name__}."
        )

    config.update(user_config)
    return config
|
|
|
|
def get_dtype(dtype_name: str) -> torch.dtype:
    """Translate a config dtype name into the matching torch.dtype.

    Accepted names are "float16", "bfloat16" and "float32"; anything else
    raises ValueError.
    """
    known_dtypes = (
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
        ("float32", torch.float32),
    )
    for name, dtype in known_dtypes:
        if dtype_name == name:
            return dtype
    raise ValueError(f"Unsupported weight_dtype: {dtype_name}")
|
|
|
|
def resize_keep_aspect(img: Image.Image, max_side: int) -> Image.Image:
    """Return an RGB copy of *img* whose longest side is at most *max_side*.

    Images that already fit are only converted to RGB, never upscaled.
    Larger images are downscaled with bicubic resampling, preserving the
    aspect ratio.

    Args:
        img: Input PIL image in any mode.
        max_side: Maximum allowed length, in pixels, of the longest side.

    Returns:
        The RGB image, resized if necessary.

    Raises:
        ValueError: If max_side is not positive.
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be a positive integer, got {max_side}.")

    img = img.convert("RGB")
    w, h = img.size
    longest = max(w, h)

    if longest <= max_side:
        return img

    scale = max_side / longest
    # For very elongated images the short side can round down to 0, which
    # Image.resize rejects; keep at least one pixel in each dimension.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    return img.resize((new_w, new_h), Image.BICUBIC)
|
|
|
|
def pad_01(img_t: torch.Tensor, target_h: int, target_w: int):
    """Zero-pad a CHW tensor on the bottom/right to target_h x target_w.

    The input is placed in the top-left corner of a new zero tensor, so the
    original content can later be recovered by cropping to the returned
    height and width.

    Args:
        img_t: Tensor of shape (C, H, W).
        target_h: Height of the padded output.
        target_w: Width of the padded output.

    Returns:
        A tuple (padded, h, w): the new (C, target_h, target_w) tensor with
        the same dtype and device as img_t, and the input's original height
        and width.

    Raises:
        ValueError: If img_t is not 3-D, or is larger than the target size
            in either spatial dimension.
    """
    if img_t.dim() != 3:
        raise ValueError(
            f"Expected a CHW tensor with 3 dimensions, got shape {tuple(img_t.shape)}."
        )

    c, h, w = img_t.shape
    if h > target_h or w > target_w:
        raise ValueError(
            f"Input tensor size {(h, w)} is larger than target size {(target_h, target_w)}."
        )

    # new_zeros inherits both dtype and device from the input, so a tensor
    # that already lives on an accelerator is padded in place there.
    out = img_t.new_zeros((c, target_h, target_w))
    out[:, :h, :w] = img_t
    return out, h, w
|
|
|
|
def preprocess_pair(
    source_img: Image.Image,
    reference_img: Image.Image,
    max_side: int,
):
    """
    Turn a source/reference image pair into one DI²FIX input tensor.

    Each image is resized to fit within max_side, converted to a tensor,
    zero-padded to a max_side x max_side canvas and then normalized with
    mean 0.5 and std 0.5.

    Returns:
        x_src: tensor of shape (1, 2, 3, max_side, max_side), with the source
            at index 0 and the reference at index 1 of the second axis.
        org_h, org_w: height and width of the resized source before padding,
            needed to crop the padding off the model output.
    """

    def _prepare(image: Image.Image):
        # Same pipeline for both views: resize -> tensor -> pad -> normalize.
        fitted = resize_keep_aspect(image, max_side)
        padded, height, width = pad_01(TF.to_tensor(fitted), max_side, max_side)
        return TF.normalize(padded, mean=[0.5], std=[0.5]), height, width

    source_t, org_h, org_w = _prepare(source_img)
    reference_t, _, _ = _prepare(reference_img)

    views = torch.stack([source_t, reference_t], dim=0)
    return views.unsqueeze(0), org_h, org_w
|
|
|
|
def postprocess_source(x_pred: torch.Tensor, org_h: int, org_w: int) -> Image.Image:
    """Build the output PIL image from the model prediction.

    The prediction is cropped to org_h x org_w to drop the padding, then
    entry 0 of the second axis of batch element 0 is converted to an HWC
    uint8 array.

    NOTE(review): assumes x_pred is laid out like the input, i.e.
    (batch, view, C, H, W) with view 0 being the source; confirm against
    the model.
    """
    unpadded = x_pred[:, :, :, :org_h, :org_w]
    first_view = unpadded[:, 0]
    chw = first_view.detach().float().cpu()[0]

    # PIL expects channels last.
    hwc = chw.permute(1, 2, 0)
    pixels = to_uint8(hwc).numpy()
    return Image.fromarray(pixels)
|
|
|
|
def load_model(config: Dict[str, Any], repo_id: str, report: bool = True):
    """
    Build DI²FIX and load all of its weights.

    Steps, in order:
    1. Construct Di2fix from the config.
    2. Load the base DIFIX pipeline from config["base_difix_repo_id"] and
       copy its weights into the model.
    3. Download config["checkpoint_filename"] from the HF model repo
       `repo_id` with hf_hub_download.
    4. Load that finetuned checkpoint into the model.
    5. Move the model to CUDA with the configured dtype and switch to eval.

    Side effect: gradient tracking is disabled globally through
    torch.set_grad_enabled(False).

    Returns:
        (model, device, weight_dtype)

    Raises:
        RuntimeError: If no CUDA device is available.
        ValueError: If config["weight_dtype"] is not supported.
    """
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        raise RuntimeError(
            "CUDA GPU is required because the current model.py uses hard-coded .cuda() calls."
        )

    device = torch.device("cuda")
    dtype_name = config.get("weight_dtype", "float32")
    weight_dtype = get_dtype(dtype_name)

    print("[DI2FIX] Building model...")
    model_kwargs = {
        "lora_rank_vae": int(config["lora_rank_vae"]),
        "timestep": int(config["timestep"]),
        "mv_unet": bool(config["mv_unet"]),
    }
    model = Di2fix(**model_kwargs)

    difix_repo = config["base_difix_repo_id"]
    print(f"[DI2FIX] Loading base DIFIX weights from: {difix_repo}")
    # NOTE(review): trust_remote_code=True lets the hub repo run its own
    # Python code; only point base_difix_repo_id at repos you trust.
    base_pipe = DifixPipeline.from_pretrained(difix_repo, trust_remote_code=True)
    base_pipe.to(device)
    load_pipe_weights_into_model(base_pipe, model, report=report)

    # The pipeline is only a weight source; drop it and release cached GPU
    # memory before the checkpoint is loaded.
    del base_pipe
    torch.cuda.empty_cache()

    ckpt_name = config["checkpoint_filename"]
    print(f"[DI2FIX] Downloading checkpoint from {repo_id}: {ckpt_name}")
    local_ckpt = hf_hub_download(
        repo_id=repo_id,
        filename=ckpt_name,
        repo_type="model",
    )

    print(f"[DI2FIX] Loading checkpoint: {local_ckpt}")
    # NOTE(review): the default checkpoint is a .pkl file; if the loader
    # unpickles it, only checkpoints from trusted repos are safe to load.
    load_finetuned_ckpt_model_only(model, local_ckpt)

    model.to(device=device, dtype=weight_dtype)
    model.eval()
    torch.set_grad_enabled(False)

    print("[DI2FIX] Model ready.")
    return model, device, weight_dtype
|
|
|
|
@torch.no_grad()
def run_inference(
    model,
    source_path: str,
    reference_path: str,
    output_path: str,
    config: Dict[str, Any],
    device: torch.device,
    weight_dtype: torch.dtype,
):
    """Run DI²FIX on one source/reference pair and save the fixed source.

    Args:
        model: Loaded DI²FIX model; must expose `tokenizer` and be callable
            as model(x, prompt_tokens=...).
        source_path: Path of the image to fix.
        reference_path: Path of the reference image.
        output_path: Where the result is written. Missing parent directories
            are created.
        config: Effective configuration; reads "max_side" and "prompt".
        device: Device the inputs are moved to.
        weight_dtype: dtype the image tensor is cast to.
    """
    # Context managers close the underlying files deterministically;
    # convert() returns an independent in-memory copy that outlives them.
    with Image.open(source_path) as source_file:
        source_img = source_file.convert("RGB")
    with Image.open(reference_path) as reference_file:
        reference_img = reference_file.convert("RGB")

    x_src, org_h, org_w = preprocess_pair(
        source_img=source_img,
        reference_img=reference_img,
        max_side=int(config["max_side"]),
    )
    x_src = x_src.to(device=device, dtype=weight_dtype)

    prompt = config.get("prompt", "remove degradation")
    input_ids = model.tokenizer(
        prompt,
        max_length=model.tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    # Create the output directory before the forward pass so an unwritable
    # destination fails early instead of after the expensive model call.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    x_pred = model(x_src, prompt_tokens=input_ids)

    output_img = postprocess_source(x_pred, org_h, org_w)
    output_img.save(output_path)

    print(f"[DI2FIX] Saved output to: {output_path}")
|
|
|
|
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads sys.argv[1:], so existing parse_args() callers
            behave as before. Passing a list allows programmatic use.

    Returns:
        Namespace with: source, reference, output, config, repo_id,
        checkpoint_filename, base_difix_repo_id and quiet. The three
        override options default to None, meaning "keep the value from
        config.json".
    """
    parser = argparse.ArgumentParser(description="Run DI²FIX inference.")
    parser.add_argument("--source", type=str, required=True, help="Path to the source image.")
    parser.add_argument("--reference", type=str, required=True, help="Path to the reference image.")
    parser.add_argument("--output", type=str, default="di2fix_output.png", help="Output image path.")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config.json.")
    parser.add_argument(
        "--repo_id",
        type=str,
        default=None,
        help="HF model repo ID, e.g. ChengYou305/DI2FIX or DF3DV/DI2FIX. Overrides config.json.",
    )
    parser.add_argument(
        "--checkpoint_filename",
        type=str,
        default=None,
        help="Checkpoint filename in the HF model repo. Overrides config.json.",
    )
    parser.add_argument(
        "--base_difix_repo_id",
        type=str,
        default=None,
        help="Base DIFIX repo ID. Overrides config.json.",
    )
    parser.add_argument("--quiet", action="store_true", help="Disable detailed weight-loading report.")
    return parser.parse_args(argv)
|
|
|
|
def main():
    """Command-line entry point: parse options, load the model, run once.

    Values given on the command line take precedence over config.json,
    which in turn takes precedence over DEFAULT_CONFIG.
    """
    args = parse_args()
    config = load_config(args.config)

    # These options share their name with the config key they override and
    # are applied only when explicitly given.
    for key in ("checkpoint_filename", "base_difix_repo_id"):
        cli_value = getattr(args, key)
        if cli_value is not None:
            config[key] = cli_value

    # An absent or empty --repo_id falls back to the configured repo.
    repo_id = args.repo_id if args.repo_id else config["model_repo_id"]

    verbose_report = not args.quiet
    model, device, weight_dtype = load_model(
        config=config,
        repo_id=repo_id,
        report=verbose_report,
    )

    run_inference(
        model=model,
        source_path=args.source,
        reference_path=args.reference,
        output_path=args.output,
        config=config,
        device=device,
        weight_dtype=weight_dtype,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|