DI2FIX_HF / inference.py
ChengYou305's picture
Upload 9 files
99baf56 verified
Raw
History Blame Contribute Delete
8.57 kB
#!/usr/bin/env python3
"""
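Minimal DI²FIX inference script: runs the model on a (source, reference)
image pair and saves the fixed source image.
Expected files in the same folder:
- inference.py
- config.json (optional; built-in defaults are used if it is missing)
- model.py
- mv_unet.py
- pipeline_difix.py
- di2fix_utils.py
Expected file in the Hugging Face model repo (downloaded automatically):
- model_80001.pkl
Requirements: a CUDA GPU (model.py uses hard-coded .cuda() calls).
Example:
    python inference.py \
        --source examples/source.png \
        --reference examples/reference.png \
        --output output.png \
        --repo_id ChengYou305/DI2FIX_HF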
"""
import argparse
import json
from pathlib import Path
from typing import Dict, Any
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from huggingface_hub import hf_hub_download
from model import Di2fix
from pipeline_difix import DifixPipeline
from di2fix_utils import (
to_uint8,
load_pipe_weights_into_model,
load_finetuned_ckpt_model_only,
)
# Built-in settings used when config.json is absent; any key present in
# config.json overrides the corresponding entry here (see load_config).
DEFAULT_CONFIG: Dict[str, Any] = dict(
    # Hugging Face model repo that hosts the finetuned DI²FIX checkpoint.
    model_repo_id="ChengYou305/DI2FIX_HF",
    checkpoint_filename="model_80001.pkl",
    # Pipeline whose weights are copied into the model before the checkpoint.
    base_difix_repo_id="nvidia/difix_ref",
    lora_rank_vae=4,
    timestep=199,
    mv_unet=True,
    # Text prompt tokenized and fed to the model at inference time.
    prompt="remove degradation",
    # Longest image side after resizing; inputs are padded to max_side x max_side.
    max_side=536,
    weight_dtype="float32",
)
def load_config(config_path: str) -> Dict[str, Any]:
    """Load config.json merged on top of DEFAULT_CONFIG.

    Keys present in the JSON file override the defaults; keys it omits keep
    their default values. A missing file is not an error: a warning is printed
    and a copy of the defaults is returned.

    Args:
        config_path: Path to a JSON file whose top-level value is an object.

    Returns:
        A new dict (never DEFAULT_CONFIG itself), so callers may mutate it.

    Raises:
        ValueError: if the file's top-level JSON value is not an object, or the
            file is not valid JSON (json.JSONDecodeError is a ValueError).
    """
    config = dict(DEFAULT_CONFIG)
    path = Path(config_path)
    if not path.is_file():
        print(f"[WARN] Config file not found: {path}. Using default config.")
        return config
    with path.open("r", encoding="utf-8") as f:
        user_config = json.load(f)
    # dict.update() would silently accept e.g. [["key", "value"]] as config, or
    # fail with a cryptic message for other JSON types; require an object.
    if not isinstance(user_config, dict):
        raise ValueError(
            f"Config file {path} must contain a JSON object at the top level, "
            f"got {type(user_config).__name__}."
        )
    config.update(user_config)
    return config
def get_dtype(dtype_name: str) -> torch.dtype:
    """Translate a config ``weight_dtype`` name into a ``torch.dtype``.

    Recognized names are "float16", "bfloat16" and "float32".

    Raises:
        ValueError: for any other name.
    """
    supported = (
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
        ("float32", torch.float32),
    )
    for name, dtype in supported:
        if dtype_name == name:
            return dtype
    raise ValueError(f"Unsupported weight_dtype: {dtype_name}")
def resize_keep_aspect(img: Image.Image, max_side: int) -> Image.Image:
    """Convert to RGB and downscale so the longest side is at most ``max_side``.

    Images that already fit are returned as-is (after the RGB conversion);
    images are never upscaled. The aspect ratio is preserved up to rounding,
    and every output side is at least one pixel.

    Args:
        img: Input PIL image, any mode.
        max_side: Upper bound for the longest output side; must be positive.

    Returns:
        An RGB PIL image whose longest side is <= ``max_side``.

    Raises:
        ValueError: if ``max_side`` is not positive.
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side}.")
    img = img.convert("RGB")
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # Clamp to >= 1: for extreme aspect ratios (e.g. 5000x3 with max_side=536)
    # the short side would otherwise round to 0 and PIL's resize would fail.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    return img.resize((new_w, new_h), Image.BICUBIC)
def pad_01(img_t: torch.Tensor, target_h: int, target_w: int):
    """Zero-pad a CHW tensor on the bottom/right to ``target_h x target_w``.

    The input is placed in the top-left corner, so the original content can
    be recovered later by slicing ``[..., :h, :w]``.

    Args:
        img_t: Tensor of shape (C, H, W).
        target_h: Output height; must be >= H.
        target_w: Output width; must be >= W.

    Returns:
        ``(padded, h, w)``: the (C, target_h, target_w) padded tensor, with the
        same dtype and device as ``img_t``, plus the original H and W.

    Raises:
        ValueError: if ``img_t`` is not 3-D or is larger than the target size.
    """
    if img_t.dim() != 3:
        raise ValueError(f"Expected a CHW tensor, got shape {tuple(img_t.shape)}.")
    c, h, w = img_t.shape
    if h > target_h or w > target_w:
        raise ValueError(
            f"Input tensor size {(h, w)} is larger than target size {(target_h, target_w)}."
        )
    # new_zeros keeps both dtype and device; torch.zeros(..., dtype=...) alone
    # would silently put the result on the CPU even for a CUDA input.
    out = img_t.new_zeros((c, target_h, target_w))
    out[:, :h, :w] = img_t
    return out, h, w
def preprocess_pair(
    source_img: Image.Image,
    reference_img: Image.Image,
    max_side: int,
):
    """Build the DI²FIX input tensor from a source/reference image pair.

    Each image is converted to RGB, downscaled so its longest side is at most
    ``max_side``, zero-padded on the bottom/right to ``max_side x max_side``,
    and normalized per channel as (x - 0.5) / 0.5.

    Returns:
        x_src: tensor of shape (1, 2, 3, max_side, max_side); view 0 is the
            source image and view 1 the reference image.
        org_h, org_w: height/width of the resized source image before padding,
            used to crop the padding away from the prediction.
    """

    def _to_model_view(pil_img):
        # Resize, then pad in [0, 1] space (padding is 0), then normalize.
        as_tensor = TF.to_tensor(resize_keep_aspect(pil_img, max_side))
        padded, height, width = pad_01(as_tensor, max_side, max_side)
        return TF.normalize(padded, mean=[0.5], std=[0.5]), height, width

    src_view, org_h, org_w = _to_model_view(source_img)
    ref_view, _, _ = _to_model_view(reference_img)

    # (2, 3, H, W) -> add a leading batch dimension -> (1, 2, 3, H, W)
    x_src = torch.stack((src_view, ref_view), dim=0)[None]
    return x_src, org_h, org_w
def postprocess_source(x_pred: torch.Tensor, org_h: int, org_w: int) -> Image.Image:
    """Extract the restored source view from a DI²FIX prediction as a PIL image.

    Takes batch item 0 and view 0 of ``x_pred`` (indexed as
    (batch, view, channel, H, W)), keeps only the top-left ``org_h x org_w``
    region to drop the padding, and converts the CHW result to an HWC array
    via ``to_uint8`` before wrapping it in a PIL image.
    """
    # Equivalent to cropping all items first and then selecting [:, 0][0].
    source_chw = x_pred[0, 0, :, :org_h, :org_w].detach().float().cpu()
    source_hwc = source_chw.permute(1, 2, 0)
    return Image.fromarray(to_uint8(source_hwc).numpy())
def load_model(config: Dict[str, Any], repo_id: str, report: bool = True):
    """Construct DI²FIX, load its weights, and prepare it for GPU inference.

    Follows the same loading order as the Gradio demo:
      1. build ``Di2fix`` from ``lora_rank_vae`` / ``timestep`` / ``mv_unet``;
      2. copy the base DIFIX weights out of the ``base_difix_repo_id`` pipeline;
      3. download ``checkpoint_filename`` from the HF model repo ``repo_id``;
      4. load that finetuned DI²FIX checkpoint into the model.
    The model is then moved to CUDA in ``weight_dtype`` and put in eval mode.
    Side effect: autograd is disabled process-wide via
    ``torch.set_grad_enabled(False)``.

    Args:
        config: Merged configuration dict (keys as in DEFAULT_CONFIG).
        repo_id: HF model repo that contains the finetuned checkpoint.
        report: Forwarded to ``load_pipe_weights_into_model``; the CLI turns
            it off with ``--quiet``.

    Returns:
        ``(model, device, weight_dtype)``.

    Raises:
        RuntimeError: if no CUDA device is available.
        ValueError: if ``weight_dtype`` is not a supported name.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA GPU is required because the current model.py uses hard-coded .cuda() calls."
        )
    cuda = torch.device("cuda")
    target_dtype = get_dtype(config.get("weight_dtype", "float32"))

    print("[DI2FIX] Building model...")
    net = Di2fix(
        lora_rank_vae=int(config["lora_rank_vae"]),
        timestep=int(config["timestep"]),
        mv_unet=bool(config["mv_unet"]),
    )

    _load_base_difix_weights(net, config["base_difix_repo_id"], cuda, report)
    local_ckpt = _download_checkpoint(repo_id, config["checkpoint_filename"])

    print(f"[DI2FIX] Loading checkpoint: {local_ckpt}")
    # NOTE(review): the checkpoint is a .pkl fetched from the Hub. If
    # load_finetuned_ckpt_model_only unpickles it (e.g. torch.load without
    # weights_only=True), only point repo_id at trusted repositories.
    load_finetuned_ckpt_model_only(net, local_ckpt)

    net.to(device=cuda, dtype=target_dtype)
    net.eval()
    torch.set_grad_enabled(False)
    print("[DI2FIX] Model ready.")
    return net, cuda, target_dtype


def _load_base_difix_weights(model, base_repo_id, device, report):
    """Copy the base DIFIX weights from a pretrained pipeline into ``model``."""
    print(f"[DI2FIX] Loading base DIFIX weights from: {base_repo_id}")
    base_pipe = DifixPipeline.from_pretrained(
        base_repo_id,
        trust_remote_code=True,
    )
    base_pipe.to(device)
    load_pipe_weights_into_model(base_pipe, model, report=report)
    # The pipeline is only a weight source; release it and its cached GPU memory.
    del base_pipe
    torch.cuda.empty_cache()


def _download_checkpoint(repo_id, checkpoint_filename):
    """Fetch ``checkpoint_filename`` from the HF model repo; return its local path."""
    print(f"[DI2FIX] Downloading checkpoint from {repo_id}: {checkpoint_filename}")
    return hf_hub_download(
        repo_id=repo_id,
        filename=checkpoint_filename,
        repo_type="model",
    )
@torch.no_grad()
def run_inference(
    model,
    source_path: str,
    reference_path: str,
    output_path: str,
    config: Dict[str, Any],
    device: torch.device,
    weight_dtype: torch.dtype,
) -> Image.Image:
    """Run DI²FIX on one source/reference pair and save the fixed source image.

    Args:
        model: Loaded DI²FIX model (see load_model); must expose ``tokenizer``.
        source_path: Path of the image to fix.
        reference_path: Path of the reference image.
        output_path: Where the result is written; missing parent directories
            are created.
        config: Configuration dict; reads ``max_side`` and optionally ``prompt``.
        device: Device the inputs and prompt tokens are moved to.
        weight_dtype: Dtype the image tensor is cast to (the model's dtype).

    Returns:
        The fixed source image, which is also saved to ``output_path``.
    """
    # Context managers close the file handles deterministically: Image.open is
    # lazy, and multi-frame formats keep the file open even after load().
    # convert("RGB") always returns a fully loaded copy, so it outlives the file.
    with Image.open(source_path) as src_file:
        source_img = src_file.convert("RGB")
    with Image.open(reference_path) as ref_file:
        reference_img = ref_file.convert("RGB")

    x_src, org_h, org_w = preprocess_pair(
        source_img=source_img,
        reference_img=reference_img,
        max_side=int(config["max_side"]),
    )
    x_src = x_src.to(device=device, dtype=weight_dtype)

    prompt = config.get("prompt", "remove degradation")
    input_ids = model.tokenizer(
        prompt,
        max_length=model.tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    x_pred = model(x_src, prompt_tokens=input_ids)
    output_img = postprocess_source(x_pred, org_h, org_w)

    # Saving into a not-yet-existing directory would otherwise fail at the end.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    output_img.save(output_path)
    print(f"[DI2FIX] Saved output to: {output_path}")
    return output_img
def parse_args(argv=None):
    """Parse the command-line arguments of the inference script.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, as before; passing a list
            lets other code and tests drive the CLI without touching sys.argv.

    Returns:
        The parsed ``argparse.Namespace``. Override options (``--repo_id``,
        ``--checkpoint_filename``, ``--base_difix_repo_id``) are ``None``
        when not given.
    """
    parser = argparse.ArgumentParser(description="Run DI²FIX inference.")
    parser.add_argument("--source", type=str, required=True, help="Path to the source image.")
    parser.add_argument("--reference", type=str, required=True, help="Path to the reference image.")
    parser.add_argument("--output", type=str, default="di2fix_output.png", help="Output image path.")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config.json.")
    parser.add_argument(
        "--repo_id",
        type=str,
        default=None,
        help="HF model repo ID, e.g. ChengYou305/DI2FIX or DF3DV/DI2FIX. Overrides config.json.",
    )
    parser.add_argument(
        "--checkpoint_filename",
        type=str,
        default=None,
        help="Checkpoint filename in the HF model repo. Overrides config.json.",
    )
    parser.add_argument(
        "--base_difix_repo_id",
        type=str,
        default=None,
        help="Base DIFIX repo ID. Overrides config.json.",
    )
    parser.add_argument("--quiet", action="store_true", help="Disable detailed weight-loading report.")
    return parser.parse_args(argv)
def main():
    """CLI entry point: merge config and CLI overrides, load DI²FIX, run one pair."""
    args = parse_args()
    config = load_config(args.config)

    # Flags left unset (None) keep the value from config.json or the defaults.
    for key in ("checkpoint_filename", "base_difix_repo_id"):
        cli_value = getattr(args, key)
        if cli_value is not None:
            config[key] = cli_value

    # Unlike the flags above, an empty --repo_id also falls back to the config.
    model_repo = args.repo_id or config["model_repo_id"]
    model, device, weight_dtype = load_model(
        config=config,
        repo_id=model_repo,
        report=not args.quiet,
    )
    run_inference(
        model=model,
        source_path=args.source,
        reference_path=args.reference,
        output_path=args.output,
        config=config,
        device=device,
        weight_dtype=weight_dtype,
    )
# Script entry point: parse CLI arguments, load the model, and process one image pair.
if __name__ == "__main__":
    main()