#!/usr/bin/env python3
"""
Simple DI²FIX inference script.

Expected files in the same folder:
  - inference.py
  - config.json
  - model.py
  - mv_unet.py
  - pipeline_difix.py
  - di2fix_utils.py

Expected file on the Hugging Face model repo:
  - model_80001.pkl

Example:
    python inference.py \
        --source examples/source.png \
        --reference examples/reference.png \
        --output output.png \
        --repo_id ChengYou305/DI2FIX_HF
"""

import argparse
import json
from pathlib import Path
from typing import Any, Dict

import torch
import torchvision.transforms.functional as TF
from huggingface_hub import hf_hub_download
from PIL import Image

from model import Di2fix
from pipeline_difix import DifixPipeline
from di2fix_utils import (
    to_uint8,
    load_pipe_weights_into_model,
    load_finetuned_ckpt_model_only,
)


# Defaults used when config.json is missing; any key present in config.json
# overrides the matching entry here.
DEFAULT_CONFIG: Dict[str, Any] = {
    "model_repo_id": "ChengYou305/DI2FIX_HF",
    "checkpoint_filename": "model_80001.pkl",
    "base_difix_repo_id": "nvidia/difix_ref",
    "lora_rank_vae": 4,
    "timestep": 199,
    "mv_unet": True,
    "prompt": "remove degradation",
    # Longest image side after resizing; also the side of the square padded canvas.
    "max_side": 536,
    "weight_dtype": "float32",
}


def load_config(config_path: str) -> Dict[str, Any]:
    """Load config.json if it exists, otherwise fall back to DEFAULT_CONFIG.

    A relative ``config_path`` is looked up in the current working directory
    first and, if not found there, next to this script (where the module
    docstring says config.json lives).

    Args:
        config_path: Path to a JSON file containing a single object.

    Returns:
        A new dict: DEFAULT_CONFIG updated with the keys from the file.

    Raises:
        ValueError: If the file exists but does not contain a JSON object.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    config = dict(DEFAULT_CONFIG)
    path = Path(config_path)

    if not path.is_file() and not path.is_absolute():
        # Running the script from another directory should still pick up the
        # config.json shipped alongside it.
        candidate = Path(__file__).resolve().parent / path
        if candidate.is_file():
            path = candidate

    if path.is_file():
        with path.open("r", encoding="utf-8") as f:
            user_config = json.load(f)
        if not isinstance(user_config, dict):
            raise ValueError(
                f"Config file {path} must contain a JSON object, "
                f"got {type(user_config).__name__}."
            )
        config.update(user_config)
    else:
        print(f"[WARN] Config file not found: {path}. Using default config.")
    return config


def get_dtype(dtype_name: str) -> torch.dtype:
    """Map a config ``weight_dtype`` name to the corresponding torch dtype.

    Raises:
        ValueError: If ``dtype_name`` is not float16, bfloat16 or float32.
    """
    dtypes = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    try:
        return dtypes[dtype_name]
    except (KeyError, TypeError):
        raise ValueError(f"Unsupported weight_dtype: {dtype_name}") from None


def resize_keep_aspect(img: Image.Image, max_side: int) -> Image.Image:
    """Resize image so the longest side is <= max_side, preserving aspect ratio.

    The image is converted to RGB first. Images that already fit are returned
    unscaled; larger images are downscaled with bicubic resampling.
    """
    img = img.convert("RGB")
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img

    scale = max_side / longest
    # Clamp to 1 px so extreme aspect ratios never round a side down to 0,
    # which PIL rejects.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    return img.resize((new_w, new_h), Image.BICUBIC)


def pad_01(img_t: torch.Tensor, target_h: int, target_w: int):
    """Pad a CHW tensor with zeros (bottom/right) to target_h x target_w.

    Returns:
        (padded, h, w): the padded tensor, which has the same dtype and device
        as ``img_t``, and the original height and width.

    Raises:
        ValueError: If the input is larger than the target in either dimension.
    """
    c, h, w = img_t.shape
    if h > target_h or w > target_w:
        raise ValueError(
            f"Input tensor size {(h, w)} is larger than target size {(target_h, target_w)}."
        )
    out = torch.zeros((c, target_h, target_w), dtype=img_t.dtype, device=img_t.device)
    out[:, :h, :w] = img_t
    return out, h, w


def preprocess_pair(
    source_img: Image.Image,
    reference_img: Image.Image,
    max_side: int,
):
    """
    Prepare source/reference pair for DI²FIX.

    Processing steps:
      1. Resize each image so its longest side is at most ``max_side``.
      2. Convert each image to a float CHW tensor in [0, 1].
      3. Zero-pad each tensor at the bottom/right to ``max_side`` x ``max_side``.
      4. Normalize to [-1, 1], so the padded area becomes -1.

    NOTE(review): the model presumably needs ``max_side`` to be compatible with
    its VAE downsampling factor (the default 536 is a multiple of 8). Confirm
    this before changing it.

    Returns:
        x_src: tensor of shape (1, 2, 3, max_side, max_side). Index 0 on dim 1
            is the source view and index 1 is the reference view.
        org_h, org_w: size of the resized source before padding, used to crop
            the prediction back.
    """
    source_img = resize_keep_aspect(source_img, max_side)
    reference_img = resize_keep_aspect(reference_img, max_side)

    src_t = TF.to_tensor(source_img)
    ref_t = TF.to_tensor(reference_img)

    # Only the source's original size matters: the output crop is taken from the source view.
    src_t, org_h, org_w = pad_01(src_t, max_side, max_side)
    ref_t, _, _ = pad_01(ref_t, max_side, max_side)

    # A single-element mean/std broadcasts over all channels: [0, 1] -> [-1, 1].
    src_t = TF.normalize(src_t, mean=[0.5], std=[0.5])
    ref_t = TF.normalize(ref_t, mean=[0.5], std=[0.5])

    x_src = torch.stack([src_t, ref_t], dim=0).unsqueeze(0)
    return x_src, org_h, org_w


def postprocess_source(x_pred: torch.Tensor, org_h: int, org_w: int) -> Image.Image:
    """Convert DI²FIX output tensor to PIL image and crop away padding.

    Takes batch 0 and view 0 (the source view).

    Assumes ``x_pred`` is 5-D (batch, view, channel, H, W), matching the layout
    produced by preprocess_pair. Also assumes ``to_uint8`` maps model-range values
    to a uint8 HWC tensor. Both assumptions should be confirmed against
    di2fix_utils.
    """
    x_pred = x_pred[:, :, :, :org_h, :org_w]
    # .float() ensures bf16/fp16 outputs can be handed to numpy.
    fixed = x_pred[:, 0].detach().float().cpu()[0]
    fixed_hwc = fixed.permute(1, 2, 0)
    fixed_u8 = to_uint8(fixed_hwc).numpy()
    return Image.fromarray(fixed_u8)


def load_model(config: Dict[str, Any], repo_id: str, report: bool = True):
    """
    Load DI²FIX.

    This follows the same loading logic as the Gradio demo:
      1. Build Di2fix.
      2. Load base DIFIX weights from ``config["base_difix_repo_id"]``.
      3. Download ``config["checkpoint_filename"]`` (default model_80001.pkl)
         from the HF model repo ``repo_id`` with hf_hub_download.
      4. Load the finetuned DI²FIX checkpoint.

    Side effect: disables autograd globally via torch.set_grad_enabled(False).

    Returns:
        (model, device, weight_dtype): the model in eval mode on CUDA, the CUDA
        device, and the torch dtype the model weights were cast to.

    Raises:
        RuntimeError: If no CUDA GPU is available.
        ValueError: If ``config["weight_dtype"]`` is not supported.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA GPU is required because the current model.py uses hard-coded .cuda() calls."
        )

    device = torch.device("cuda")
    weight_dtype = get_dtype(config.get("weight_dtype", "float32"))

    print("[DI2FIX] Building model...")
    model = Di2fix(
        lora_rank_vae=int(config["lora_rank_vae"]),
        timestep=int(config["timestep"]),
        mv_unet=bool(config["mv_unet"]),
    )

    base_repo_id = config["base_difix_repo_id"]
    print(f"[DI2FIX] Loading base DIFIX weights from: {base_repo_id}")
    # NOTE(review): trust_remote_code executes code from the Hub repo; only
    # point base_difix_repo_id at trusted repositories.
    pipe = DifixPipeline.from_pretrained(
        base_repo_id,
        trust_remote_code=True,
    )
    pipe.to(device)
    load_pipe_weights_into_model(pipe, model, report=report)
    # The pipeline is only a weight source; free it before loading the checkpoint.
    del pipe
    torch.cuda.empty_cache()

    checkpoint_filename = config["checkpoint_filename"]
    print(f"[DI2FIX] Downloading checkpoint from {repo_id}: {checkpoint_filename}")
    ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename=checkpoint_filename,
        repo_type="model",
    )

    print(f"[DI2FIX] Loading checkpoint: {ckpt_path}")
    # NOTE(review): the checkpoint is a downloaded .pkl. If this helper unpickles
    # it (e.g. torch.load without weights_only=True), it can run arbitrary code,
    # so only load checkpoints from trusted repo IDs. Confirm in di2fix_utils.
    load_finetuned_ckpt_model_only(model, ckpt_path)

    model.to(device=device, dtype=weight_dtype)
    model.eval()
    torch.set_grad_enabled(False)

    print("[DI2FIX] Model ready.")
    return model, device, weight_dtype


@torch.no_grad()
def run_inference(
    model,
    source_path: str,
    reference_path: str,
    output_path: str,
    config: Dict[str, Any],
    device: torch.device,
    weight_dtype: torch.dtype,
):
    """Restore ``source_path`` using ``reference_path`` and save the result.

    The output image has the resized (not the original) source resolution. Its
    longest side is at most ``config["max_side"]``. Missing parent directories
    of ``output_path`` are created.
    """
    # Use context managers so the underlying file handles are closed;
    # convert() returns a fully loaded copy that outlives the handle.
    with Image.open(source_path) as src_file:
        source_img = src_file.convert("RGB")
    with Image.open(reference_path) as ref_file:
        reference_img = ref_file.convert("RGB")

    x_src, org_h, org_w = preprocess_pair(
        source_img=source_img,
        reference_img=reference_img,
        max_side=int(config["max_side"]),
    )
    x_src = x_src.to(device=device, dtype=weight_dtype)

    prompt = config.get("prompt", "remove degradation")
    # Assumes model.tokenizer is a Hugging Face-style tokenizer (callable with
    # these kwargs and exposing model_max_length). Confirm in model.py.
    input_ids = model.tokenizer(
        prompt,
        max_length=model.tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    x_pred = model(x_src, prompt_tokens=input_ids)

    output_img = postprocess_source(x_pred, org_h, org_w)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    output_img.save(output_path)

    print(f"[DI2FIX] Saved output to: {output_path}")


def parse_args():
    """Parse command-line arguments. CLI values override config.json entries."""
    parser = argparse.ArgumentParser(description="Run DI²FIX inference.")

    parser.add_argument("--source", type=str, required=True, help="Path to the source image.")
    parser.add_argument("--reference", type=str, required=True, help="Path to the reference image.")
    parser.add_argument("--output", type=str, default="di2fix_output.png", help="Output image path.")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config.json.")

    parser.add_argument(
        "--repo_id",
        type=str,
        default=None,
        help="HF model repo ID, e.g. ChengYou305/DI2FIX or DF3DV/DI2FIX. Overrides config.json.",
    )
    parser.add_argument(
        "--checkpoint_filename",
        type=str,
        default=None,
        help="Checkpoint filename in the HF model repo. Overrides config.json.",
    )
    parser.add_argument(
        "--base_difix_repo_id",
        type=str,
        default=None,
        help="Base DIFIX repo ID. Overrides config.json.",
    )
    parser.add_argument("--quiet", action="store_true", help="Disable detailed weight-loading report.")

    return parser.parse_args()


def main():
    """CLI entry point: load the config, apply CLI overrides, load the model, run once."""
    args = parse_args()
    config = load_config(args.config)

    if args.checkpoint_filename is not None:
        config["checkpoint_filename"] = args.checkpoint_filename

    if args.base_difix_repo_id is not None:
        config["base_difix_repo_id"] = args.base_difix_repo_id

    repo_id = args.repo_id or config["model_repo_id"]

    model, device, weight_dtype = load_model(
        config=config,
        repo_id=repo_id,
        report=not args.quiet,
    )

    run_inference(
        model=model,
        source_path=args.source,
        reference_path=args.reference,
        output_path=args.output,
        config=config,
        device=device,
        weight_dtype=weight_dtype,
    )


if __name__ == "__main__":
    main()