Spaces:
Running on Zero
Running on Zero
Upload 21 files
Browse files- pr_iqa/__init__.py +3 -0
- pr_iqa/__pycache__/__init__.cpython-310.pyc +0 -0
- pr_iqa/__pycache__/__init__.cpython-38.pyc +0 -0
- pr_iqa/dataset.py +173 -0
- pr_iqa/loss.py +108 -0
- pr_iqa/model/__init__.py +43 -0
- pr_iqa/model/__pycache__/__init__.cpython-310.pyc +0 -0
- pr_iqa/model/__pycache__/__init__.cpython-38.pyc +0 -0
- pr_iqa/model/__pycache__/layers.cpython-310.pyc +0 -0
- pr_iqa/model/__pycache__/layers.cpython-38.pyc +0 -0
- pr_iqa/model/__pycache__/priqa.cpython-310.pyc +0 -0
- pr_iqa/model/__pycache__/priqa.cpython-38.pyc +0 -0
- pr_iqa/model/layers.py +413 -0
- pr_iqa/model/priqa.py +264 -0
- pr_iqa/partial_map/__init__.py +3 -0
- pr_iqa/partial_map/__pycache__/__init__.cpython-310.pyc +0 -0
- pr_iqa/partial_map/__pycache__/__init__.cpython-38.pyc +0 -0
- pr_iqa/partial_map/__pycache__/feature_metric.cpython-310.pyc +0 -0
- pr_iqa/partial_map/__pycache__/feature_metric.cpython-38.pyc +0 -0
- pr_iqa/partial_map/feature_metric.py +285 -0
- pr_iqa/transforms.py +86 -0
pr_iqa/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import PRIQA, build_priqa
|
| 2 |
+
|
| 3 |
+
__all__ = ["PRIQA", "build_priqa"]
|
pr_iqa/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (214 Bytes). View file
|
|
|
pr_iqa/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (212 Bytes). View file
|
|
|
pr_iqa/dataset.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset for PR-IQA training.
|
| 3 |
+
|
| 4 |
+
Expected directory structure per scene::
|
| 5 |
+
|
| 6 |
+
s000/
|
| 7 |
+
├── total/ # Original keyframe images (RGB)
|
| 8 |
+
│ ├── 0000.jpg
|
| 9 |
+
│ ├── 0001.jpg
|
| 10 |
+
│ └── ...
|
| 11 |
+
├── tgt_diffusion/ # Generated images per target
|
| 12 |
+
│ └── 0005/
|
| 13 |
+
│ ├── 0005_diff_0.jpg
|
| 14 |
+
│ └── ...
|
| 15 |
+
├── total_map/ # Full quality maps (GT, grayscale)
|
| 16 |
+
│ └── 0005/
|
| 17 |
+
│ ├── 0005_diff_0.png
|
| 18 |
+
│ └── ...
|
| 19 |
+
├── partial_map/ # Partial quality maps (from FeatureMetric)
|
| 20 |
+
│ └── 0005/
|
| 21 |
+
│ ├── 0005_diff_0_ref+10_0015.png
|
| 22 |
+
│ └── ...
|
| 23 |
+
└── partial_mask/ # Overlap masks
|
| 24 |
+
└── 0005/
|
| 25 |
+
├── 0005_diff_0_ref+10_0015.png
|
| 26 |
+
└── ...
|
| 27 |
+
|
| 28 |
+
Each sample is a tuple: (tgt, tgt_diff, full_map, partial_map, partial_mask, current_ref).
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import random
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
from PIL import Image
|
| 36 |
+
from torch.utils.data import Dataset
|
| 37 |
+
import torchvision.transforms.functional as TF
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class SceneDataset(Dataset):
    """Dataset enumerating all valid (tgt, diff, ref, partial_map, mask) combinations.

    Scans each ``s*`` scene directory under ``root_dir`` and builds one sample
    per (target keyframe, diffusion variant, reference offset) triple for which
    both the partial quality map and the overlap mask exist on disk.

    Args:
        root_dir: Directory containing scene folders named ``s*``.
        rgb_transform: Optional transform applied to the three RGB images.
        grayscale_transform: Optional transform applied to the three "L" maps.
        training: When True, apply flip/color-jitter augmentation in __getitem__.
    """

    def __init__(self, root_dir, rgb_transform=None, grayscale_transform=None, training=True):
        self.root_dir = Path(root_dir)
        self.rgb_transform = rgb_transform
        self.grayscale_transform = grayscale_transform
        self.samples = []
        # Fixed frame offsets used to pick reference keyframes around a target.
        self.ref_deltas = [-20, -10, 10, 20]
        self.training = training

        for scene_path in sorted(self.root_dir.glob("s*")):
            if scene_path.is_dir():
                self._index_scene(scene_path)

    def _index_scene(self, scene_path):
        """Append every valid sample combination found under one scene directory."""
        total_dir = scene_path / "total"
        if not total_dir.is_dir():
            return

        total_images = sorted(total_dir.glob("*.jpg"), key=lambda p: int(p.stem))
        num_total = len(total_images)
        if num_total == 0:
            return

        for i, tgt_path in enumerate(total_images):
            tgt_stem = tgt_path.stem
            # References at fixed offsets, wrapping around the sequence.
            # (Paths come straight from glob, so no extra existence check needed.)
            refs = [(total_images[(i + d) % num_total], d) for d in self.ref_deltas]

            tgt_diff_dir = scene_path / "tgt_diffusion" / tgt_stem
            total_map_dir = scene_path / "total_map" / tgt_stem

            for tgt_diff_path in sorted(tgt_diff_dir.glob("*_diff_*.jpg")):
                full_map_path = total_map_dir / f"{tgt_diff_path.stem}.png"
                if not full_map_path.exists():
                    continue

                tgt_diff_stem = tgt_diff_path.stem
                for ref_path, d in refs:
                    # Shared filename pattern for partial map and overlap mask,
                    # e.g. "0005_diff_0_ref+10_0015.png".
                    ref_stem = ref_path.stem
                    name = f"{tgt_diff_stem}_ref{d:+d}_{ref_stem}.png"
                    mask_path = scene_path / "partial_mask" / tgt_stem / name
                    map_path = scene_path / "partial_map" / tgt_stem / name

                    if mask_path.exists() and map_path.exists():
                        self.samples.append({
                            "tgt": tgt_path,
                            "tgt_diff": tgt_diff_path,
                            "full_map": full_map_path,
                            "partial_mask": mask_path,
                            "partial_map": map_path,
                            "current_ref": ref_path,
                        })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        paths = self.samples[idx]

        tgt_img = Image.open(paths["tgt"]).convert("RGB")
        tgt_diff_img = Image.open(paths["tgt_diff"]).convert("RGB")
        full_map_img = Image.open(paths["full_map"]).convert("L")
        partial_mask_img = Image.open(paths["partial_mask"]).convert("L")
        partial_map_img = Image.open(paths["partial_map"]).convert("L")
        cur_ref_img = Image.open(paths["current_ref"]).convert("RGB")

        # -- Augmentation (training only) --
        if self.training:
            # Horizontal flip (p=0.5) applied jointly to all six images so
            # images, maps and masks stay spatially aligned.
            if random.random() > 0.5:
                (tgt_img, tgt_diff_img, cur_ref_img,
                 full_map_img, partial_mask_img, partial_map_img) = [
                    TF.hflip(im) for im in (
                        tgt_img, tgt_diff_img, cur_ref_img,
                        full_map_img, partial_mask_img, partial_map_img,
                    )
                ]

            # Vertical flip (p=0.3), also applied jointly.
            if random.random() > 0.7:
                (tgt_img, tgt_diff_img, cur_ref_img,
                 full_map_img, partial_mask_img, partial_map_img) = [
                    TF.vflip(im) for im in (
                        tgt_img, tgt_diff_img, cur_ref_img,
                        full_map_img, partial_mask_img, partial_map_img,
                    )
                ]

            # Mild color jitter (p=0.5) on the RGB images only; the same factor
            # is applied to tgt/diff/ref so their photometric relation is kept.
            # Factors are drawn in brightness, contrast, saturation order (this
            # preserves the RNG draw sequence of the original implementation).
            if random.random() > 0.5:
                jitter_ops = [
                    (TF.adjust_brightness, random.uniform(0.9, 1.1)),
                    (TF.adjust_contrast, random.uniform(0.9, 1.1)),
                    (TF.adjust_saturation, random.uniform(0.9, 1.1)),
                ]
                for fn, val in jitter_ops:
                    tgt_img = fn(tgt_img, val)
                    tgt_diff_img = fn(tgt_diff_img, val)
                    cur_ref_img = fn(cur_ref_img, val)

        if self.rgb_transform:
            tgt_img, tgt_diff_img, cur_ref_img = map(
                self.rgb_transform, [tgt_img, tgt_diff_img, cur_ref_img]
            )
        if self.grayscale_transform:
            full_map_img, partial_mask_img, partial_map_img = map(
                self.grayscale_transform, [full_map_img, partial_mask_img, partial_map_img]
            )

        return {
            "tgt": tgt_img,
            "tgt_diff": tgt_diff_img,
            "partial_mask": partial_mask_img,
            "partial_map": partial_map_img,
            "full_map": full_map_img,
            "current_ref": cur_ref_img,
        }
|
pr_iqa/loss.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Loss functions for PR-IQA training.
|
| 3 |
+
|
| 4 |
+
Core losses:
|
| 5 |
+
- JSD (Jensen-Shannon Divergence): Distribution matching
|
| 6 |
+
- Masked L1: Pixel-wise L1 on partial map regions
|
| 7 |
+
- Pearson: Correlation-based structural loss
|
| 8 |
+
|
| 9 |
+
Additional losses (optional):
|
| 10 |
+
- Ranking: Pairwise ranking consistency
|
| 11 |
+
- Global mean/std: Statistics matching
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def loss_jsd(pred, target, tau=0.2, reduction="mean", eps=1e-6):
    """Jensen-Shannon Divergence loss.

    Both maps are turned into per-sample probability distributions by taking a
    temperature-scaled logit followed by a softmax over all pixels, then the
    symmetric KL divergence against the midpoint distribution is computed.

    Args:
        pred, target: Maps with values in (0, 1), shape (B, ...).
        tau: Softmax temperature applied to the logits.
        reduction: "mean", "sum", or anything else for per-sample values.
        eps: Numerical-stability constant for clamping/logs.
    """
    # Force fp32 math regardless of any ambient autocast context.
    with torch.autocast(device_type="cuda", enabled=False):
        pred_f = pred.float().clamp(min=eps, max=1 - eps)
        tgt_f = target.float().clamp(min=eps, max=1 - eps)

        pred_logits = torch.logit(pred_f, eps=eps) / tau
        tgt_logits = torch.logit(tgt_f, eps=eps) / tau

        dist_pred = torch.softmax(pred_logits.flatten(start_dim=1), dim=1)
        dist_tgt = torch.softmax(tgt_logits.flatten(start_dim=1), dim=1)

        midpoint = 0.5 * (dist_tgt + dist_pred)

        def _kl(a, b):
            # KL(a || b) per sample, stabilized with eps inside the logs.
            return torch.sum(a * (torch.log(a + eps) - torch.log(b + eps)), dim=1)

        per_sample = 0.5 * (_kl(dist_tgt, midpoint) + _kl(dist_pred, midpoint))

        if reduction == "mean":
            return per_sample.mean().to(pred.dtype)
        if reduction == "sum":
            return per_sample.sum().to(pred.dtype)
        return per_sample.to(pred.dtype)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def loss_masked_l1(pred, target, mask, reduction="mean"):
    """L1 loss restricted to the regions selected by ``mask``.

    With ``reduction="mean"`` the sum of masked absolute errors is divided by
    the number of active mask elements (plus a small epsilon to avoid division
    by zero on empty masks).
    """
    err = torch.abs(pred - target) * mask
    if reduction == "sum":
        return err.sum()
    if reduction == "mean":
        return err.sum() / (mask.sum() + 1e-8)
    return err
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def loss_l1(pred, target, reduction="mean"):
    """Standard (unmasked) L1 loss, cast back to ``pred``'s dtype."""
    diff = torch.abs(pred - target)
    if reduction == "sum":
        return diff.sum().to(pred.dtype)
    if reduction == "mean":
        return diff.mean().to(pred.dtype)
    return diff.to(pred.dtype)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def loss_pearson(pred, target, reduction="mean", eps=1e-6):
    """1 - Pearson correlation coefficient, computed per sample.

    Flattens each sample, centers both signals, and correlates them; a
    perfectly correlated pair gives a loss near 0, anti-correlated near 2.
    """
    a = pred.float().reshape(pred.shape[0], -1).contiguous()
    b = target.float().reshape(target.shape[0], -1).contiguous()

    a = a - a.mean(dim=1, keepdim=True)
    b = b - b.mean(dim=1, keepdim=True)

    cov = (a * b).sum(dim=1)
    # eps inside the sqrt guards against zero-variance inputs.
    norm = torch.sqrt((a * a).sum(dim=1) * (b * b).sum(dim=1) + eps)
    corr = (cov / norm).clamp(-1.0, 1.0)

    out = 1.0 - corr
    if reduction == "mean":
        return out.mean().to(pred.dtype)
    if reduction == "sum":
        return out.sum().to(pred.dtype)
    return out.to(pred.dtype)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def loss_ranking(pred, gt, margin=0.1):
    """Pairwise ranking loss enforcing relative quality ordering.

    Samples H*W/2 random pixel pairs per image and penalizes predictions whose
    ordering disagrees (by more than ``margin``) with the ground-truth ordering.
    """
    batch, _, height, width = pred.shape
    num_pixels = height * width
    num_pairs = int(num_pixels * 0.5)

    pred_vec = pred.view(batch, -1)
    gt_vec = gt.view(batch, -1)

    first = torch.randint(0, num_pixels, (batch, num_pairs), device=pred.device)
    second = torch.randint(0, num_pixels, (batch, num_pairs), device=pred.device)

    # Desired ordering sign per sampled pair (+1, 0, or -1).
    order = torch.sign(gt_vec.gather(1, first) - gt_vec.gather(1, second))
    return F.margin_ranking_loss(
        pred_vec.gather(1, first), pred_vec.gather(1, second), order, margin=margin
    )
|
pr_iqa/model/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .priqa import PRIQA
|
| 2 |
+
from .layers import (
|
| 3 |
+
PartialConv2d,
|
| 4 |
+
GatedPartialEmb,
|
| 5 |
+
GatedEmb,
|
| 6 |
+
FeedForward,
|
| 7 |
+
ChannelGate,
|
| 8 |
+
Attention,
|
| 9 |
+
TransformerLikeBlock,
|
| 10 |
+
SandwichBlock,
|
| 11 |
+
Downsample,
|
| 12 |
+
Upsample,
|
| 13 |
+
Pos2d,
|
| 14 |
+
DropPath,
|
| 15 |
+
LayerNorm,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_priqa(
    out_channels: int = 1,
    dim: int = 48,
    num_blocks: tuple = (2, 3, 3, 4),
    heads: tuple = (1, 2, 4, 8),
    ffn_expansion_factor: float = 2.66,
    bias: bool = False,
    layernorm_type: str = "WithBias",
    use_partial_conv: bool = True,
) -> PRIQA:
    """Build a PR-IQA model with default or custom hyperparameters.

    The input channel count is fixed at 4 (e.g. RGB + mask); tuples are
    converted to lists before being handed to the model constructor.
    """
    config = {
        "inp_channels": 4,
        "out_channels": out_channels,
        "dim": dim,
        "num_blocks": list(num_blocks),
        "heads": list(heads),
        "ffn_expansion_factor": ffn_expansion_factor,
        "bias": bias,
        "LayerNorm_type": layernorm_type,
        "use_partial_conv": use_partial_conv,
    }
    return PRIQA(**config)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
__all__ = ["PRIQA", "build_priqa"]
|
pr_iqa/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.18 kB). View file
|
|
|
pr_iqa/model/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (1.13 kB). View file
|
|
|
pr_iqa/model/__pycache__/layers.cpython-310.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
pr_iqa/model/__pycache__/layers.cpython-38.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
pr_iqa/model/__pycache__/priqa.cpython-310.pyc
ADDED
|
Binary file (6.95 kB). View file
|
|
|
pr_iqa/model/__pycache__/priqa.cpython-38.pyc
ADDED
|
Binary file (6.94 kB). View file
|
|
|
pr_iqa/model/layers.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Building blocks for the PR-IQA architecture.
|
| 3 |
+
|
| 4 |
+
Includes:
|
| 5 |
+
- PartialConv2d: Mask-aware convolution for inpainting
|
| 6 |
+
- GatedPartialEmb / GatedEmb: Gated patch embeddings
|
| 7 |
+
- FeedForward (FFN): Gated depth-wise separable FFN
|
| 8 |
+
- ChannelGate: SE/CBAM-style channel attention
|
| 9 |
+
- Attention: Spatial attention with xformers memory-efficient attention
|
| 10 |
+
- TransformerLikeBlock: Channel gate → Spatial attn → FFN with residuals
|
| 11 |
+
- SandwichBlock: FFN → Channel gate → Spatial attn → FFN
|
| 12 |
+
- Downsample / Upsample: Strided conv / PixelShuffle
|
| 13 |
+
- Pos2d: 2D sinusoidal positional encoding
|
| 14 |
+
- DropPath: Stochastic depth
|
| 15 |
+
- LayerNorm: Bias-free or with-bias layer normalization
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import numbers
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from einops import rearrange
|
| 24 |
+
from xformers import ops
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Layer Normalization
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
def to_3d(x):
    """Flatten a (b, c, h, w) tensor into token form (b, h*w, c)."""
    batch, channels = x.shape[0], x.shape[1]
    return x.permute(0, 2, 3, 1).reshape(batch, -1, channels)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def to_4d(x, h, w):
    """Restore token form (b, h*w, c) back to a (b, c, h, w) feature map."""
    batch, channels = x.shape[0], x.shape[2]
    return x.reshape(batch, h, w, channels).permute(0, 3, 1, 2)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BiasFree_LayerNorm(nn.Module):
    """Bias-free layer norm: rescales by the (uncentered-output) std only.

    Note this variant does not subtract the mean before normalizing; it
    divides by sqrt(var + 1e-5) and applies a learnable per-channel weight.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        shape = torch.Size(shape)
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(variance + 1e-5) * self.weight
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class WithBias_LayerNorm(nn.Module):
    """Standard layer norm over the last dimension with learnable affine."""

    def __init__(self, normalized_shape):
        super().__init__()
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        shape = torch.Size(shape)
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        normed = (x - mean) / torch.sqrt(variance + 1e-5)
        return normed * self.weight + self.bias
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class LayerNorm(nn.Module):
    """Channel layer norm for (b, c, h, w) maps.

    Tokens are flattened to (b, h*w, c), normalized over the channel
    dimension by the selected variant, then reshaped back.
    """

    def __init__(self, dim, LayerNorm_type="WithBias"):
        super().__init__()
        variant = BiasFree_LayerNorm if LayerNorm_type == "BiasFree" else WithBias_LayerNorm
        self.body = variant(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), height, width)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# Partial Convolution
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
class PartialConv2d(nn.Module):
    """Mask-aware convolution for inpainting.

    Given input ``x`` and binary mask ``mask`` (1 = valid), the convolution is
    computed on ``x * mask`` and renormalized by the per-location count of
    valid pixels in the receptive field. Locations whose receptive field
    contains no valid pixel are zeroed, and the returned mask marks them
    invalid so holes propagate through stacked layers.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=False)
        # Frozen all-ones conv: counts valid mask pixels per receptive field.
        self.mask_conv = nn.Conv2d(1, out_ch, kernel_size, stride, padding, bias=False)
        nn.init.constant_(self.mask_conv.weight, 1.0)
        self.mask_conv.weight.requires_grad = False
        self.bias = nn.Parameter(torch.zeros(out_ch)) if bias else None

    def forward(self, x, mask):
        """
        Args:
            x: (B, in_ch, H, W) input features.
            mask: (B, 1, H, W) binary validity mask.

        Returns:
            (output, new_mask) where new_mask is (B, 1, H', W').
        """
        with torch.no_grad():
            valid_count = self.mask_conv(mask)
            # BUG FIX: derive the updated mask from the *raw* count. The
            # previous code clamped to 1e-8 first, which made every entry
            # strictly positive and thus new_mask identically 1 — holes could
            # never propagate.
            new_mask = (valid_count > 0).float()
            mask_sum = valid_count.clamp(min=1e-8)

        output = self.conv(x * mask) / mask_sum
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1, 1)
        # Zero out positions with no valid support.
        output = output * new_mask
        return output, new_mask[:, 0:1]
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
# Gated Embeddings
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
class GatedPartialEmb(nn.Module):
    """Gated patch embedding built on PartialConv2d (for masked inputs).

    The partial convolution produces 2*embed_dim channels which are split
    into a gate half and a value half: out = gelu(gate) * value.
    """

    def __init__(self, in_c=4, embed_dim=48, bias=False):
        super().__init__()
        self.pconv = PartialConv2d(
            in_c, embed_dim * 2, kernel_size=3, stride=1, padding=1, bias=bias
        )

    def forward(self, x_with_mask, mask):
        """
        Args:
            x_with_mask: (B, in_c, H, W) — e.g. RGB(3) + mask(1) concatenated.
            mask: (B, 1, H, W) — binary mask for the partial conv.

        Returns:
            (embedded features, updated mask).
        """
        feats, mask_out = self.pconv(x_with_mask, mask)
        gate, value = feats.chunk(2, dim=1)
        return F.gelu(gate) * value, mask_out
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class GatedEmb(nn.Module):
    """Gated patch embedding (plain convolution, no mask handling).

    Projects to 2*embed_dim channels, then gates one half with the other:
    out = gelu(gate) * value.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.gproj1 = nn.Conv2d(
            in_c, embed_dim * 2, kernel_size=3, stride=1, padding=1, bias=bias
        )

    def forward(self, x):
        gate, value = self.gproj1(x).chunk(2, dim=1)
        return F.gelu(gate) * value
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
# Feed-Forward Network
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
class FeedForward(nn.Module):
    """Gated depth-wise separable feed-forward network.

    1x1 expansion to 2*hidden channels, depth-wise 3x3 conv, gating
    (gelu(half) * half), then 1x1 projection back to ``dim`` channels.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
        )
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
# Channel Attention
|
| 178 |
+
# ---------------------------------------------------------------------------
|
| 179 |
+
|
| 180 |
+
class ChannelGate(nn.Module):
    """SE/CBAM-style channel gate.

    Pools each input globally (avg, optionally + max), maps the pooled vector
    through a bottleneck MLP, and applies the resulting sigmoid gate to the
    input (and to ``kv`` when present — both share the same gate).
    """

    def __init__(self, dim, reduction=16, use_max=True, bias=True):
        super().__init__()
        hidden = max(1, dim // reduction)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden, 1, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, dim, 1, bias=bias),
        )
        self.use_max = use_max

    def _pooled(self, t):
        """Global descriptor for one tensor, passed through the shared MLP."""
        pooled = F.adaptive_avg_pool2d(t, 1)
        if self.use_max:
            pooled = pooled + F.adaptive_max_pool2d(t, 1)
        return self.mlp(pooled)

    def forward(self, x, kv=None):
        logits = self._pooled(x)
        if kv is not None:
            logits = logits + self._pooled(kv)
        gate = torch.sigmoid(logits)
        if kv is None:
            return x * gate, None
        return x * gate, kv * gate
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
# Spatial Attention (xformers)
|
| 212 |
+
# ---------------------------------------------------------------------------
|
| 213 |
+
|
| 214 |
+
class Attention(nn.Module):
    """Spatial attention with xformers memory-efficient attention.

    Supports both self-attention (kv=None) and cross-attention (kv provided).
    Includes a spatial gating branch: the attention output is modulated by a
    map computed from the (downsampled) query input before the final 1x1
    projection.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        # dim must be divisible by num_heads (head_dim = dim // num_heads).
        self.num_heads = num_heads

        # Self-attention projections: one fused 1x1 conv producing q,k,v,
        # followed by a depth-wise 3x3 for local mixing.
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3, dim * 3, kernel_size=3, stride=1, padding=1,
            groups=dim * 3, bias=bias,
        )

        # Cross-attention projections: q from x, k/v from the external kv map.
        self.q_proj_qonly = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dw_qonly = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.kv_proj_cross = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv_cross = nn.Conv2d(
            dim * 2, dim * 2, kernel_size=3, stride=1, padding=1,
            groups=dim * 2, bias=bias,
        )

        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

        # Spatial gating: work at half resolution, then bilinearly upsample
        # the gate back to the attention output's size.
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.upsample_to = lambda t, size: F.interpolate(t, size=size, mode="bilinear", align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True),
            LayerNorm(dim, "WithBias"),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True),
            LayerNorm(dim, "WithBias"),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, kv=None):
        """
        Args:
            x: (B, dim, H, W) query features.
            kv: Optional (B, dim, H', W') key/value features for cross-attention;
                when None, self-attention over ``x`` is used.

        Returns:
            (B, dim, H, W) attended features.
        """
        b, c, h, w = x.shape
        head_dim = c // self.num_heads

        if kv is None:
            qkv = self.qkv_dwconv(self.qkv(x))
            q, k, v = qkv.chunk(3, dim=1)
        else:
            q = self.q_dw_qonly(self.q_proj_qonly(x))
            kv_feat = self.kv_dwconv_cross(self.kv_proj_cross(kv))
            k, v = kv_feat.chunk(2, dim=1)

        # Reshape to (B, seq_len, num_heads, head_dim) — the layout consumed
        # by xformers' memory_efficient_attention. k/v use -1 for seq_len so
        # cross-attention sources of a different spatial size are accepted.
        q = q.view(b, self.num_heads, head_dim, h * w).permute(0, 3, 1, 2).contiguous()
        k = k.view(b, self.num_heads, head_dim, -1).permute(0, 3, 1, 2).contiguous()
        v = v.view(b, self.num_heads, head_dim, -1).permute(0, 3, 1, 2).contiguous()

        out = ops.memory_efficient_attention(q, k, v)
        # Back to (B, C, H, W) at the query's spatial resolution.
        out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)

        # Spatial gating: multiplicative, non-negative (ReLU-terminated branch).
        spatial_weight = self.avg_pool(x)
        spatial_weight = self.conv(spatial_weight)
        spatial_weight = self.upsample_to(spatial_weight, (h, w))
        out = out * spatial_weight

        return self.project_out(out)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# ---------------------------------------------------------------------------
|
| 284 |
+
# Drop Path (Stochastic Depth)
|
| 285 |
+
# ---------------------------------------------------------------------------
|
| 286 |
+
|
| 287 |
+
class DropPath(nn.Module):
    """Stochastic depth: drops entire samples of a residual branch.

    During training, each sample in the batch is zeroed with probability
    ``p`` and survivors are rescaled by 1/(1-p) to keep the expectation
    unchanged. At eval time (or when p == 0) this is the identity.
    """

    def __init__(self, p: float = 0.0):
        super().__init__()
        self.p = float(p)

    def forward(self, x):
        if not self.training or self.p == 0.0:
            return x
        keep_prob = 1.0 - self.p
        # Per-sample binary mask, broadcast over (C, H, W).
        survivors = torch.rand(x.shape[0], 1, 1, 1, device=x.device, dtype=x.dtype) < keep_prob
        return x * survivors / keep_prob
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# ---------------------------------------------------------------------------
|
| 301 |
+
# Transformer-like Block
|
| 302 |
+
# ---------------------------------------------------------------------------
|
| 303 |
+
|
| 304 |
+
class TransformerLikeBlock(nn.Module):
    """Channel gate → spatial attention → FFN, each with pre-norm, layer
    scale, stochastic depth, and a residual connection.

    When ``kv`` is given, the channel gate also gates ``kv`` and the spatial
    attention runs in cross-attention mode against the gated ``kv``.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type,
                 drop_path=0.0, layerscale_init=1e-2):
        super().__init__()
        self.norm_c = LayerNorm(dim, LayerNorm_type)
        self.chan = ChannelGate(dim, reduction=16, use_max=True, bias=bias)
        self.norm_s = LayerNorm(dim, LayerNorm_type)
        self.sattn = Attention(dim, num_heads, bias)
        self.norm_f = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

        # Per-branch layer-scale parameters (small init stabilizes training).
        self.gamma_c = nn.Parameter(torch.ones(1, dim, 1, 1) * layerscale_init)
        self.gamma_s = nn.Parameter(torch.ones(1, dim, 1, 1) * layerscale_init)
        self.gamma_f = nn.Parameter(torch.ones(1, dim, 1, 1) * layerscale_init)
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x, kv=None):
        # 1) Channel-gate branch (also produces the gated kv for attention).
        gated, kv_gated = self.chan(self.norm_c(x), kv)
        x = x + self.drop_path(self.gamma_c * gated)

        # 2) Spatial attention branch (cross-attention when kv was provided).
        attn_kv = kv_gated if kv is not None else None
        x = x + self.drop_path(self.gamma_s * self.sattn(self.norm_s(x), attn_kv))

        # 3) Feed-forward branch.
        x = x + self.drop_path(self.gamma_f * self.ffn(self.norm_f(x)))
        return x
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# ---------------------------------------------------------------------------
|
| 338 |
+
# Sandwich Block
|
| 339 |
+
# ---------------------------------------------------------------------------
|
| 340 |
+
|
| 341 |
+
class SandwichBlock(nn.Module):
    """FFN → channel gate → spatial attention → FFN, all as plain residuals.

    Same branch structure as ``TransformerLikeBlock`` but with a leading FFN
    and without layer scales or drop path.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()
        self.norm1_1 = LayerNorm(dim, LayerNorm_type)
        self.ffn1 = FeedForward(dim, ffn_expansion_factor, bias)
        self.norm_c = LayerNorm(dim, LayerNorm_type)
        self.chan = ChannelGate(dim, reduction=16, use_max=True, bias=bias)
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x, kv=None):
        # Leading FFN residual.
        x = x + self.ffn1(self.norm1_1(x))

        # Channel gate; also produces the gated cross-attention source.
        gated, gated_kv = self.chan(self.norm_c(x), kv)
        x = x + gated

        # Spatial attention, cross-attending when reference features exist.
        x = x + self.attn(self.norm1(x), gated_kv if kv is not None else None)

        # Trailing FFN residual.
        return x + self.ffn(self.norm2(x))
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ---------------------------------------------------------------------------
|
| 366 |
+
# Downsample / Upsample
|
| 367 |
+
# ---------------------------------------------------------------------------
|
| 368 |
+
|
| 369 |
+
class Downsample(nn.Module):
    """Halve the spatial resolution while doubling channels (strided conv)."""

    def __init__(self, n_feat):
        super().__init__()
        halve = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.body = nn.Sequential(halve)

    def forward(self, x, mask=None):
        # `mask` is accepted for interface symmetry with other blocks
        # but is not used here.
        return self.body(x)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class Upsample(nn.Module):
    """Double the spatial resolution and halve channels (conv + pixel shuffle)."""

    def __init__(self, n_feat):
        super().__init__()
        # Conv expands to 2*n_feat; PixelShuffle(2) then trades 4x channels
        # for 2x spatial size, leaving n_feat // 2 channels.
        expand = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(expand, nn.PixelShuffle(2))

    def forward(self, x, mask=None):
        # `mask` is accepted for interface symmetry with other blocks
        # but is not used here.
        return self.body(x)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# ---------------------------------------------------------------------------
|
| 393 |
+
# Positional Encoding
|
| 394 |
+
# ---------------------------------------------------------------------------
|
| 395 |
+
|
| 396 |
+
class Pos2d(nn.Module):
    """2D sinusoidal positional encoding.

    Builds a 4-channel coordinate grid ``(x, y, sin(pi*x), cos(pi*y))`` over
    the input's spatial extent, projects it to ``dim`` channels with a 1x1
    conv, and adds it to the input.

    Fixes over the original:
      - the coordinate grid is cast to the projection's weight dtype, so the
        module works when converted to half precision (the grid was always
        float32 before, which made the 1x1 conv fail on a ``.half()`` model);
      - the batch dimension is broadcast with ``expand`` instead of
        materializing ``B`` copies via ``repeat``.
    """

    def __init__(self, dim):
        super().__init__()
        # 4 coordinate channels -> feature dimension.
        self.proj = nn.Conv2d(4, dim, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        device = x.device
        yy, xx = torch.meshgrid(
            torch.linspace(-1, 1, H, device=device),
            torch.linspace(-1, 1, W, device=device),
            indexing="ij",
        )
        pe4 = torch.stack([xx, yy, torch.sin(xx * 3.14159), torch.cos(yy * 3.14159)], dim=0)
        # Match the conv's dtype so half-precision modules work.
        pe4 = pe4.to(self.proj.weight.dtype)
        # Broadcast over the batch instead of copying B times.
        pe = self.proj(pe4.unsqueeze(0)).expand(B, -1, -1, -1)
        return x + pe
|
pr_iqa/model/priqa.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PR-IQA: Partial-Reference Image Quality Assessment model.
|
| 3 |
+
|
| 4 |
+
3-input U-Net encoder-decoder with cross-attention:
|
| 5 |
+
- tgt_img: partial quality map (from FeatureMetric) replicated to 3ch
|
| 6 |
+
- dif_img: generated / distorted image
|
| 7 |
+
- ref_img: reference image
|
| 8 |
+
|
| 9 |
+
Each input comes with a 4-scale mask pyramid (whole, half, quarter, tiny).
|
| 10 |
+
|
| 11 |
+
Architecture:
|
| 12 |
+
Encoder: 4 levels (dim → 2*dim → 4*dim → 8*dim)
|
| 13 |
+
- img_encoder: shared for ref_img and dif_img (self-attention)
|
| 14 |
+
- map_encoder: for tgt_img (cross-attention with ref features)
|
| 15 |
+
- qfuse: fuses dif and tgt encoder outputs at each level
|
| 16 |
+
Decoder: 3 levels with skip connections from the dif encoder
|
| 17 |
+
Output: tanh-activated quality map
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
from .layers import (
|
| 25 |
+
GatedPartialEmb,
|
| 26 |
+
GatedEmb,
|
| 27 |
+
TransformerLikeBlock,
|
| 28 |
+
Downsample,
|
| 29 |
+
Upsample,
|
| 30 |
+
Pos2d,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class PRIQA(nn.Module):
    """Partial-Reference Image Quality Assessment model.

    3-input U-Net encoder-decoder with cross-attention. The shared
    ``img_encoder_*`` stacks process both the reference and the distorted
    image (same weights, different inputs); the ``map_encoder_*`` stacks
    process the partial quality map. Both the dif and tgt branches
    cross-attend to the reference features of the same level.

    Args:
        inp_channels: Input channels per image (typically 4 = RGB + mask).
        out_channels: Output channels (1 for quality map, 3 for RGB).
        dim: Base feature dimension (doubles at each encoder level).
        num_blocks: Number of TransformerLikeBlocks at each level.
        heads: Number of attention heads at each level.
        ffn_expansion_factor: FFN hidden dim multiplier.
        bias: Use bias in convolutions.
        LayerNorm_type: ``"WithBias"`` or ``"BiasFree"``.
        use_partial_conv: Use PartialConv2d in patch embedding.
    """

    def __init__(
        self,
        inp_channels=4,
        out_channels=3,
        dim=48,
        num_blocks=[4, 6, 6, 8],  # NOTE(review): mutable default; read-only here, but tuples would be safer
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        use_partial_conv=True,
    ):
        super().__init__()
        self.use_partial_conv = use_partial_conv

        # -- Patch embedding (one module shared by all three inputs) --
        if use_partial_conv:
            self.patch_embed = GatedPartialEmb(inp_channels, dim, bias)
        else:
            self.patch_embed = GatedEmb(inp_channels, dim, bias)

        # -- Quality fusion (dif + tgt) at each level: 1x1 conv halves channels --
        self.qfuse_l1 = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
        self.qfuse_l2 = nn.Conv2d(int(dim * 2 ** 1) * 2, int(dim * 2 ** 1), kernel_size=1, bias=bias)
        self.qfuse_l3 = nn.Conv2d(int(dim * 2 ** 2) * 2, int(dim * 2 ** 2), kernel_size=1, bias=bias)
        self.qfuse_l4 = nn.Conv2d(int(dim * 2 ** 3) * 2, int(dim * 2 ** 3), kernel_size=1, bias=bias)

        # -- Downsampler (Downsample accepts a mask argument but ignores it) --
        self.down1_2 = Downsample(dim)
        self.down2_3 = Downsample(int(dim * 2 ** 1))
        self.down3_4 = Downsample(int(dim * 2 ** 2))

        # -- Positional Encoding: one per resolution/width, encoder and decoder --
        self.pos_l1 = Pos2d(dim)
        self.pos_l2 = Pos2d(int(dim * 2 ** 1))
        self.pos_l3 = Pos2d(int(dim * 2 ** 2))
        self.pos_l4 = Pos2d(int(dim * 2 ** 3))
        self.pos_d3 = Pos2d(int(dim * 2 ** 2))
        self.pos_d2 = Pos2d(int(dim * 2 ** 1))
        self.pos_d1 = Pos2d(int(dim * 2 ** 1))  # level-1 decoder runs at 2*dim (no channel reduce)

        # -- Encoder (img: shared for ref & dif) --
        def _make_encoder(level_dim, n_blocks, n_heads):
            # Stack of TransformerLikeBlocks at a fixed resolution/width.
            return nn.ModuleList([
                TransformerLikeBlock(
                    dim=level_dim, num_heads=n_heads,
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias, LayerNorm_type=LayerNorm_type,
                )
                for _ in range(n_blocks)
            ])

        self.img_encoder_level1 = _make_encoder(dim, num_blocks[0], heads[0])
        self.img_encoder_level2 = _make_encoder(int(dim * 2 ** 1), num_blocks[1], heads[1])
        self.img_encoder_level3 = _make_encoder(int(dim * 2 ** 2), num_blocks[2], heads[2])
        self.img_latent = _make_encoder(int(dim * 2 ** 3), num_blocks[3], heads[3])

        # -- Encoder (map: for tgt, cross-attention with ref) --
        self.map_encoder_level1 = _make_encoder(dim, num_blocks[0], heads[0])
        self.map_encoder_level2 = _make_encoder(int(dim * 2 ** 1), num_blocks[1], heads[1])
        self.map_encoder_level3 = _make_encoder(int(dim * 2 ** 2), num_blocks[2], heads[2])
        self.map_latent = _make_encoder(int(dim * 2 ** 3), num_blocks[3], heads[3])

        # -- Decoder (skip connections come from the fused dif branch) --
        self.up4_3 = Upsample(int(dim * 2 ** 3))
        self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
        self.decoder_level3 = _make_encoder(int(dim * 2 ** 2), num_blocks[2], heads[2])

        self.up3_2 = Upsample(int(dim * 2 ** 2))
        self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
        self.decoder_level2 = _make_encoder(int(dim * 2 ** 1), num_blocks[1], heads[1])

        self.up2_1 = Upsample(int(dim * 2 ** 1))
        self.decoder_level1 = _make_encoder(int(dim * 2 ** 1), num_blocks[0], heads[0])

        # -- Output --
        self.output = nn.Sequential(
            nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias),
        )

    def forward(
        self,
        tgt_img, dif_img, ref_img,
        tgt_mask_whole, tgt_mask_half, tgt_mask_quarter, tgt_mask_tiny,
        dif_mask_whole, dif_mask_half, dif_mask_quarter, dif_mask_tiny,
        ref_mask_whole, ref_mask_half, ref_mask_quarter, ref_mask_tiny,
    ):
        """
        Args:
            tgt_img: (B, 3, H, W) — partial quality map replicated to 3ch.
            dif_img: (B, 3, H, W) — generated / distorted image.
            ref_img: (B, 3, H, W) — reference image.
            *_mask_*: (B, 1, H/s, W/s) — mask pyramids at 4 scales.

        Returns:
            (B, out_channels, H, W) quality map (tanh activated, so in [-1, 1]).
        """
        # -- Patch embedding: each input is concatenated with its full-res mask --
        if self.use_partial_conv:
            # GatedPartialEmb also returns an updated mask, which is discarded.
            tgt_enc_level1, _ = self.patch_embed(
                torch.cat((tgt_img, tgt_mask_whole), dim=1), tgt_mask_whole,
            )
            dif_enc_level1, _ = self.patch_embed(
                torch.cat((dif_img, dif_mask_whole), dim=1), dif_mask_whole,
            )
            ref_enc_level1, _ = self.patch_embed(
                torch.cat((ref_img, ref_mask_whole), dim=1), ref_mask_whole,
            )
        else:
            tgt_enc_level1 = self.patch_embed(torch.cat((tgt_img, tgt_mask_whole), dim=1))
            dif_enc_level1 = self.patch_embed(torch.cat((dif_img, dif_mask_whole), dim=1))
            ref_enc_level1 = self.patch_embed(torch.cat((ref_img, ref_mask_whole), dim=1))

        tgt_enc_level1 = self.pos_l1(tgt_enc_level1)
        dif_enc_level1 = self.pos_l1(dif_enc_level1)
        ref_enc_level1 = self.pos_l1(ref_enc_level1)

        # ── ENCODER Level 1 ──
        # Reference is encoded first (self-attention only) and serves as kv
        # for both the tgt and dif branches at this level.
        out_ref_enc_level1 = ref_enc_level1
        for block in self.img_encoder_level1:
            out_ref_enc_level1 = block(out_ref_enc_level1)
        kv_level1 = out_ref_enc_level1

        out_tgt_enc_level1 = tgt_enc_level1
        for block in self.map_encoder_level1:
            out_tgt_enc_level1 = block(out_tgt_enc_level1, kv_level1)

        # Same weights as the ref pass above, but cross-attending to ref.
        out_dif_enc_level1 = dif_enc_level1
        for block in self.img_encoder_level1:
            out_dif_enc_level1 = block(out_dif_enc_level1, kv_level1)

        # Fuse dif + tgt features back to `dim` channels.
        out_dif_enc_level1 = self.qfuse_l1(torch.cat([out_dif_enc_level1, out_tgt_enc_level1], dim=1))

        # ── ENCODER Level 2 ──
        inp_tgt_enc_level2 = self.pos_l2(self.down1_2(out_tgt_enc_level1, tgt_mask_whole))
        inp_dif_enc_level2 = self.pos_l2(self.down1_2(out_dif_enc_level1, dif_mask_whole))
        inp_ref_enc_level2 = self.pos_l2(self.down1_2(out_ref_enc_level1, ref_mask_whole))

        out_ref_enc_level2 = inp_ref_enc_level2
        for block in self.img_encoder_level2:
            out_ref_enc_level2 = block(out_ref_enc_level2)
        kv_level2 = out_ref_enc_level2

        out_tgt_enc_level2 = inp_tgt_enc_level2
        for block in self.map_encoder_level2:
            out_tgt_enc_level2 = block(out_tgt_enc_level2, kv_level2)

        out_dif_enc_level2 = inp_dif_enc_level2
        for block in self.img_encoder_level2:
            out_dif_enc_level2 = block(out_dif_enc_level2, kv_level2)

        out_dif_enc_level2 = self.qfuse_l2(torch.cat([out_dif_enc_level2, out_tgt_enc_level2], dim=1))

        # ── ENCODER Level 3 ──
        inp_tgt_enc_level3 = self.pos_l3(self.down2_3(out_tgt_enc_level2, tgt_mask_half))
        inp_dif_enc_level3 = self.pos_l3(self.down2_3(out_dif_enc_level2, dif_mask_half))
        inp_ref_enc_level3 = self.pos_l3(self.down2_3(out_ref_enc_level2, ref_mask_half))

        out_ref_enc_level3 = inp_ref_enc_level3
        for block in self.img_encoder_level3:
            out_ref_enc_level3 = block(out_ref_enc_level3)
        kv_level3 = out_ref_enc_level3

        out_tgt_enc_level3 = inp_tgt_enc_level3
        for block in self.map_encoder_level3:
            out_tgt_enc_level3 = block(out_tgt_enc_level3, kv_level3)

        out_dif_enc_level3 = inp_dif_enc_level3
        for block in self.img_encoder_level3:
            out_dif_enc_level3 = block(out_dif_enc_level3, kv_level3)

        out_dif_enc_level3 = self.qfuse_l3(torch.cat([out_dif_enc_level3, out_tgt_enc_level3], dim=1))

        # ── ENCODER Level 4 (Latent) ──
        inp_tgt_enc_level4 = self.pos_l4(self.down3_4(out_tgt_enc_level3, tgt_mask_quarter))
        inp_dif_enc_level4 = self.pos_l4(self.down3_4(out_dif_enc_level3, dif_mask_quarter))
        inp_ref_enc_level4 = self.pos_l4(self.down3_4(out_ref_enc_level3, ref_mask_quarter))

        ref_latent_out = inp_ref_enc_level4
        for block in self.img_latent:
            ref_latent_out = block(ref_latent_out)
        kv_level4 = ref_latent_out

        tgt_latent_out = inp_tgt_enc_level4
        for block in self.map_latent:
            tgt_latent_out = block(tgt_latent_out, kv_level4)

        dif_latent_out = inp_dif_enc_level4
        for block in self.img_latent:
            dif_latent_out = block(dif_latent_out, kv_level4)

        latent_out = self.qfuse_l4(torch.cat([dif_latent_out, tgt_latent_out], dim=1))

        # ── DECODER ──
        # Skip connections come from the fused dif-branch encoder outputs.
        # Upsample also takes a mask argument that is currently ignored.
        inp_dec_level3 = self.up4_3(latent_out, dif_mask_tiny)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_dif_enc_level3], 1)
        inp_dec_level3 = self.pos_d3(self.reduce_chan_level3(inp_dec_level3))
        out_dec_level3 = inp_dec_level3
        for block in self.decoder_level3:
            out_dec_level3 = block(out_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3, dif_mask_quarter)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_dif_enc_level2], 1)
        inp_dec_level2 = self.pos_d2(self.reduce_chan_level2(inp_dec_level2))
        out_dec_level2 = inp_dec_level2
        for block in self.decoder_level2:
            out_dec_level2 = block(out_dec_level2)

        # Level 1 keeps 2*dim channels (no reduce conv before the blocks).
        inp_dec_level1 = self.up2_1(out_dec_level2, dif_mask_half)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_dif_enc_level1], 1)
        inp_dec_level1 = self.pos_d1(inp_dec_level1)
        out_dec_level1 = inp_dec_level1
        for block in self.decoder_level1:
            out_dec_level1 = block(out_dec_level1)

        return torch.tanh(self.output(out_dec_level1))
|
pr_iqa/partial_map/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .feature_metric import FeatureMetric
|
| 2 |
+
|
| 3 |
+
__all__ = ["FeatureMetric"]
|
pr_iqa/partial_map/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (214 Bytes). View file
|
|
|
pr_iqa/partial_map/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (212 Bytes). View file
|
|
|
pr_iqa/partial_map/__pycache__/feature_metric.cpython-310.pyc
ADDED
|
Binary file (8.03 kB). View file
|
|
|
pr_iqa/partial_map/__pycache__/feature_metric.cpython-38.pyc
ADDED
|
Binary file (7.91 kB). View file
|
|
|
pr_iqa/partial_map/feature_metric.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FeatureMetric: DINOv2 + LoftUp feature-based quality metric.
|
| 3 |
+
|
| 4 |
+
Generates partial quality maps by:
|
| 5 |
+
1. Extracting DINOv2 features (upsampled via LoftUp) from input images
|
| 6 |
+
2. Using VGGT for monocular depth and pose estimation
|
| 7 |
+
3. Constructing a colored 3D point cloud with features
|
| 8 |
+
4. Rendering the point cloud from the target viewpoint via PyTorch3D
|
| 9 |
+
5. Computing cosine similarity between rendered features and target features
|
| 10 |
+
|
| 11 |
+
Two modes:
|
| 12 |
+
- partial_generation=True: Full 3D pipeline → partial map + overlap mask
|
| 13 |
+
- partial_generation=False: Direct cosine similarity → total quality map
|
| 14 |
+
|
| 15 |
+
Dependencies (Level 1):
|
| 16 |
+
- VGGT (facebook/VGGT-1B)
|
| 17 |
+
- LoftUp (andrehuang/loftup)
|
| 18 |
+
- PyTorch3D
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import sys
|
| 22 |
+
import torch
|
| 23 |
+
from torch import Tensor
|
| 24 |
+
from torch.nn import Module
|
| 25 |
+
import numpy as np
|
| 26 |
+
from typing import Optional, Tuple, Union
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from einops import rearrange
|
| 29 |
+
|
| 30 |
+
# Auto-detect submodule paths
|
| 31 |
+
_THIS_DIR = Path(__file__).resolve().parent
|
| 32 |
+
_REPO_ROOT = _THIS_DIR.parent.parent
|
| 33 |
+
_SUBMODULES = _REPO_ROOT / "submodules"
|
| 34 |
+
|
| 35 |
+
if (_SUBMODULES / "vggt").exists():
|
| 36 |
+
sys.path.insert(0, str(_SUBMODULES / "vggt"))
|
| 37 |
+
if (_SUBMODULES / "loftup").exists():
|
| 38 |
+
sys.path.insert(0, str(_SUBMODULES / "loftup"))
|
| 39 |
+
|
| 40 |
+
# Lazy imports for heavy dependencies — loaded on first use
|
| 41 |
+
_VGGT = None
|
| 42 |
+
_LOFTUP_FEATURIZERS = None
|
| 43 |
+
_LOFTUP_UPSAMPLERS = None
|
| 44 |
+
_PYTORCH3D = None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _import_vggt():
    """Import VGGT lazily on first use and memoise its symbols.

    Returns a dict mapping symbol names to the imported VGGT callables.
    """
    global _VGGT
    if _VGGT is not None:
        return _VGGT
    from vggt.models.vggt import VGGT
    from vggt.utils.pose_enc import pose_encoding_to_extri_intri
    from vggt.utils.geometry import unproject_depth_map_to_point_map
    from vggt.utils.load_fn import load_and_preprocess_images
    _VGGT = {
        "VGGT": VGGT,
        "pose_encoding_to_extri_intri": pose_encoding_to_extri_intri,
        "unproject_depth_map_to_point_map": unproject_depth_map_to_point_map,
        "load_and_preprocess_images": load_and_preprocess_images,
    }
    return _VGGT
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _import_loftup():
    """Import the LoftUp featurizer factory and norm helper lazily.

    Returns the ``(get_featurizer, norm)`` pair, memoised in module globals.
    """
    global _LOFTUP_FEATURIZERS, _LOFTUP_UPSAMPLERS
    if _LOFTUP_FEATURIZERS is not None:
        return _LOFTUP_FEATURIZERS, _LOFTUP_UPSAMPLERS
    from featurizers import get_featurizer
    from upsamplers import norm
    _LOFTUP_FEATURIZERS = get_featurizer
    _LOFTUP_UPSAMPLERS = norm
    return _LOFTUP_FEATURIZERS, _LOFTUP_UPSAMPLERS
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _import_pytorch3d():
    """Import the PyTorch3D point-rendering pieces lazily and memoise them.

    Returns a dict of the renderer classes plus the private OpenCV-camera
    conversion helper.
    """
    global _PYTORCH3D
    if _PYTORCH3D is not None:
        return _PYTORCH3D
    from pytorch3d.structures import Pointclouds
    from pytorch3d.renderer import (
        PointsRasterizationSettings,
        PointsRasterizer,
        AlphaCompositor,
    )
    # Private helper: converts OpenCV-style (R, T, K) to PyTorch3D cameras.
    from pytorch3d.renderer.camera_conversions import _cameras_from_opencv_projection
    _PYTORCH3D = {
        "Pointclouds": Pointclouds,
        "PointsRasterizationSettings": PointsRasterizationSettings,
        "PointsRasterizer": PointsRasterizer,
        "AlphaCompositor": AlphaCompositor,
        "_cameras_from_opencv_projection": _cameras_from_opencv_projection,
    }
    return _PYTORCH3D
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class FeatureMetric(Module):
    """DINOv2 + LoftUp + VGGT → partial / total quality map.

    Args:
        img_size: Inference image size (controls rasterizer resolution).
        feature_backbone: Name of the feature backbone (default: ``"dinov2"``).
        loftup_torch_hub: Torch Hub repository for LoftUp.
        loftup_model_name: LoftUp model name.
        vggt_weights: HuggingFace model ID for VGGT.
        use_vggt: Load VGGT for depth/pose estimation.
        use_loftup: Load LoftUp for feature upsampling.
    """

    def __init__(
        self,
        img_size: int = 256,
        feature_backbone: str = "dinov2",
        loftup_torch_hub: Union[str, Path] = "andrehuang/loftup",
        loftup_model_name: Union[str, Path] = "loftup_dinov2s",
        vggt_weights: Union[str, Path] = "facebook/VGGT-1B",
        use_vggt: bool = True,
        use_loftup: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.img_size = img_size

        # Backbone factory returns (model, patch_size, feature_dim).
        get_featurizer, _ = _import_loftup()
        self.feature_backbone, self.patch_size, self.dim = get_featurizer(feature_backbone)

        # Optional LoftUp upsampler fetched from Torch Hub (network access).
        self.upsampler = (
            torch.hub.load(loftup_torch_hub, loftup_model_name, pretrained=True)
            if use_loftup else None
        )
        self.use_loftup = use_loftup

        # VGGT supplies depth + camera poses for the 3D partial-map mode.
        # NOTE(review): when use_vggt=False, `self.vggt` is never set, so
        # forward(partial_generation=True) would raise AttributeError.
        if use_vggt:
            vggt_mod = _import_vggt()
            self.vggt = vggt_mod["VGGT"].from_pretrained(vggt_weights)

        p3d = _import_pytorch3d()
        self.compositor = p3d["AlphaCompositor"]()

    def _render(self, point_clouds, **kwargs):
        """Render point cloud features to images.

        Requires ``self.rasterizer`` to have been set (done at the start of
        ``forward``). Returns ``(images, zbuf)`` where images is (K, H, W, C).
        """
        # Rasterization is numerically fragile in half precision; force fp32.
        with torch.autocast("cuda", enabled=False):
            fragments = self.rasterizer(point_clouds, **kwargs)

        # Distance-based splat weights: 1 at the point center, 0 at radius r.
        r = self.rasterizer.raster_settings.radius
        dists2 = fragments.dists.permute(0, 3, 1, 2)
        weights = 1 - dists2 / (r * r)

        images = self.compositor(
            fragments.idx.long().permute(0, 3, 1, 2),
            weights,
            point_clouds.features_packed().permute(1, 0),
            **kwargs,
        )
        images = images.permute(0, 2, 3, 1)
        return images, fragments.zbuf

    @torch.no_grad()
    def forward(
        self,
        device: str,
        images: Tensor,  # (K, 3, H, W)
        return_overlap_mask: bool = False,
        return_score_map: bool = False,
        return_projections: bool = False,
        partial_generation: bool = False,
        use_filtering: bool = False,
    ) -> Tuple[float, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """Compute quality score map.

        Args:
            device: Torch device string.
            images: (K, 3, H, W) input images. First image is the target.
            partial_generation: If True, use full 3D pipeline for partial map.

        Returns:
            (score_scalar, overlap_mask, score_map, projections)
        """
        k, c, h, w = images.shape
        p3d = _import_pytorch3d()
        _, norm_fn = _import_loftup()

        # Setup rasterizer — rebuilt per call because image size can vary.
        raster_settings = p3d["PointsRasterizationSettings"](
            image_size=(h, w), radius=0.01, points_per_pixel=10, bin_size=0,
        )
        self.rasterizer = p3d["PointsRasterizer"](cameras=None, raster_settings=raster_settings)

        # Extract per-image features (one forward per image to bound memory).
        images_norm = norm_fn(images)
        hr_feats = []
        for i in range(k):
            img = images_norm[i:i + 1]
            lr_feat = self.feature_backbone(img)
            if self.use_loftup and self.upsampler is not None:
                hr_feat = self.upsampler(lr_feat, img)
            else:
                hr_feat = lr_feat
            # Flatten spatial dims: (1, C, h', w') -> (1, h'*w', C).
            hr_feat = rearrange(hr_feat, "b c h w -> b (h w) c")
            hr_feats.append(hr_feat)
        hr_feats = torch.cat(hr_feats, dim=0)

        if not partial_generation:
            # Fast cosine similarity mode: per-pixel cosine between image 0
            # (target) and image 1 (reference); other images are ignored.
            dot = (hr_feats[0] * hr_feats[1]).sum(dim=1)
            tgt_norm = torch.linalg.norm(hr_feats[0], dim=1)
            ref_norm = torch.linalg.norm(hr_feats[1], dim=1)
            cosine_sim = dot / (tgt_norm * ref_norm + 1e-8)
            score_map = torch.clamp(cosine_sim, min=0.0, max=1.0)

            # Output resolution: full-res with LoftUp, else the patch grid.
            if self.use_loftup and self.upsampler is not None:
                H_out, W_out = h, w
            else:
                H_out = h // self.patch_size
                W_out = w // self.patch_size
            score_map = score_map.reshape(H_out, W_out).unsqueeze(0)
            return score_map.mean().item(), None, score_map if return_score_map else None, None

        # Full 3D partial map generation
        vggt_mod = _import_vggt()
        pose_encoding_to_extri_intri = vggt_mod["pose_encoding_to_extri_intri"]
        unproject_depth_map_to_point_map = vggt_mod["unproject_depth_map_to_point_map"]

        preds = self.vggt(images)
        extrinsic, intrinsic = pose_encoding_to_extri_intri(preds["pose_enc"], images.shape[-2:])
        depth, depth_conf = preds["depth"], preds["depth_conf"]

        # Unproject each view's depth into a world-space point map (numpy).
        point_map = unproject_depth_map_to_point_map(
            depth.squeeze(0), extrinsic.squeeze(0), intrinsic.squeeze(0),
        )
        # NOTE(review): `cols` is computed but never used afterwards — dead code?
        cols = images.cpu().numpy().transpose(0, 2, 3, 1)
        cols = cols / cols.max()
        pts_flatten = torch.from_numpy(
            rearrange(point_map, "k h w c -> k (h w) c")
        ).float().to(device)

        if use_filtering:
            # Drop the least-confident 20% of depth samples (global quantile).
            percent = 20
            quantile = torch.quantile(depth_conf, percent / 100.0)
            mask_flat = rearrange((depth_conf > quantile).squeeze(0), "k h w -> k (h w)")
            points_list, features_list = [], []
            for i in range(k):
                valid = mask_flat[i]
                points_list.append(pts_flatten[i][valid])
                features_list.append(hr_feats[i][valid])
            point_clouds = p3d["Pointclouds"](points=points_list, features=features_list)
        else:
            point_clouds = p3d["Pointclouds"](points=pts_flatten, features=hr_feats)

        # Render from target viewpoint (camera 0), replicated for each cloud.
        # NOTE(review): extrinsic/intrinsic are recomputed here — identical to
        # the values already computed above.
        extrinsic, intrinsic = pose_encoding_to_extri_intri(preds["pose_enc"], images.shape[-2:])
        E, K = extrinsic.squeeze(0), intrinsic.squeeze(0)
        R0, T0, K0 = E[0, :3, :3], E[0, :3, 3], K[0]
        B = pts_flatten.shape[0]

        R_repeat = R0.unsqueeze(0).repeat(B, 1, 1)
        T_repeat = T0.unsqueeze(0).repeat(B, 1)
        K_repeat = K0.unsqueeze(0).repeat(B, 1, 1)
        im_size = torch.tensor([[h, w]]).repeat(B, 1).to(device)

        cameras_p3d = p3d["_cameras_from_opencv_projection"](R_repeat, T_repeat, K_repeat, im_size)

        with torch.autocast("cuda", enabled=False):
            # Large negative sentinel so unfilled background pixels are
            # clearly distinguishable from real feature values.
            bg_color = torch.tensor(
                [-10000] * hr_feats[0].shape[-1], dtype=torch.float32, device=device,
            )
            rendering, zbuf = self._render(point_clouds, cameras=cameras_p3d, background_color=bg_color)
            rendering = rearrange(rendering, "k h w c -> k c h w")

        # Cosine similarity score map: target render vs each reference render.
        target = rendering[0:1]
        reference = rendering[1:]
        dot = (reference * target).sum(dim=1)
        tgt_norm = torch.linalg.norm(target, dim=1)
        ref_norm = torch.linalg.norm(reference, dim=1)
        cosine_sim = dot / (tgt_norm * ref_norm + 1e-8)
        score_map = torch.clamp(cosine_sim, min=0.0, max=1.0)

        # Mask true background: pixels covered by neither target nor any
        # reference render (zbuf < 0 means no point landed there).
        target_mask = zbuf[0, ..., 0] >= 0
        reference_mask = zbuf[1:, ..., 0] >= 0
        true_bg = ~target_mask & ~torch.any(reference_mask, dim=0)
        score_map[:, true_bg] = 0.0

        # Per-reference visibility mask (same expression as reference_mask).
        overlap_mask = zbuf[1:, ..., 0] >= 0

        return (
            score_map.mean().item(),
            overlap_mask if return_overlap_mask else None,
            score_map if return_score_map else None,
            rendering if return_projections else None,
        )
|
pr_iqa/transforms.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data transforms and batch preparation utilities for PR-IQA training.
|
| 3 |
+
|
| 4 |
+
ImageNet normalization is applied to RGB inputs.
|
| 5 |
+
Grayscale inputs (partial maps, masks) are kept in [0, 1].
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torchvision.transforms as T
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ImageNet normalization constants
|
| 14 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 15 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def build_rgb_transform(img_size: int = 256) -> T.Compose:
    """Build the preprocessing pipeline for RGB inputs.

    Resizes to a square ``img_size`` x ``img_size``, converts to a float
    tensor in [0, 1], then applies ImageNet mean/std normalization.
    """
    steps = [
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def build_grey_transform(img_size: int = 256) -> T.Compose:
    """Build the preprocessing pipeline for grayscale maps/masks.

    Resizes to a square ``img_size`` x ``img_size`` and converts to a
    float tensor; values stay in [0, 1] (no normalization applied).
    """
    steps = [
        T.Resize((img_size, img_size)),
        T.ToTensor(),
    ]
    return T.Compose(steps)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def make_pyramid_masks(mask_whole: torch.Tensor):
    """Downsample a (B, 1, H, W) mask to 1/2, 1/4, and 1/8 resolution.

    Nearest-neighbour interpolation is used so mask values are picked,
    never blended, keeping hard edges intact.

    Returns:
        Tuple of three tensors at scale factors 0.5, 0.25, 0.125.
    """
    half, quarter, tiny = (
        F.interpolate(mask_whole, scale_factor=factor, mode="nearest")
        for factor in (0.5, 0.25, 0.125)
    )
    return half, quarter, tiny
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def prepare_batch(batch: dict, device: torch.device):
    """Assemble PRIQA.forward() arguments from a raw dataset batch.

    Moves tensors to ``device`` as bfloat16 in channels-last layout,
    expands the single-channel partial map to 3 channels, and builds the
    four-level mask pyramid (whole/half/quarter/tiny) for the target,
    diff, and reference streams.

    Args:
        batch: Dataset batch with keys "tgt_diff", "partial_mask",
            "partial_map", "current_ref", "full_map".
        device: Destination device for all tensors.

    Returns:
        model_args: Tuple of 15 tensors matching PRIQA.forward()
            signature (tgt_img, dif_img, ref_img + 12 mask tensors).
        gt: (B, 1, H, W) ground-truth quality map.
    """
    def _move(t: torch.Tensor) -> torch.Tensor:
        # Shared transfer policy: bfloat16, async copy, channels-last.
        return t.to(device, dtype=torch.bfloat16, non_blocking=True,
                    memory_format=torch.channels_last)

    dif_img = _move(batch["tgt_diff"])
    tgt_mask_whole = _move(batch["partial_mask"])
    # Partial map is single-channel; tile to 3 channels for the backbone.
    tgt_img = _move(batch["partial_map"]).repeat(1, 3, 1, 1)
    ref_img = _move(batch["current_ref"])
    gt = _move(batch["full_map"])

    tgt_masks = (tgt_mask_whole,) + tuple(make_pyramid_masks(tgt_mask_whole))
    # Diff and reference streams are fully valid, so their masks are
    # all-ones at every pyramid level.
    dif_masks = tuple(torch.ones_like(m) for m in tgt_masks)
    ref_masks = tuple(torch.ones_like(m) for m in tgt_masks)

    model_args = (tgt_img, dif_img, ref_img) + tgt_masks + dif_masks + ref_masks
    return model_args, gt
|