2cu1001 commited on
Commit
e6c4c00
·
verified ·
1 Parent(s): ed3bf54

Upload 21 files

Browse files
pr_iqa/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .model import PRIQA, build_priqa
2
+
3
+ __all__ = ["PRIQA", "build_priqa"]
pr_iqa/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (214 Bytes). View file
 
pr_iqa/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (212 Bytes). View file
 
pr_iqa/dataset.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset for PR-IQA training.
3
+
4
+ Expected directory structure per scene::
5
+
6
+ s000/
7
+ ├── total/ # Original keyframe images (RGB)
8
+ │ ├── 0000.jpg
9
+ │ ├── 0001.jpg
10
+ │ └── ...
11
+ ├── tgt_diffusion/ # Generated images per target
12
+ │ └── 0005/
13
+ │ ├── 0005_diff_0.jpg
14
+ │ └── ...
15
+ ├── total_map/ # Full quality maps (GT, grayscale)
16
+ │ └── 0005/
17
+ │ ├── 0005_diff_0.png
18
+ │ └── ...
19
+ ├── partial_map/ # Partial quality maps (from FeatureMetric)
20
+ │ └── 0005/
21
+ │ ├── 0005_diff_0_ref+10_0015.png
22
+ │ └── ...
23
+ └── partial_mask/ # Overlap masks
24
+ └── 0005/
25
+ ├── 0005_diff_0_ref+10_0015.png
26
+ └── ...
27
+
28
+ Each sample is a tuple: (tgt, tgt_diff, full_map, partial_map, partial_mask, current_ref).
29
+ """
30
+
31
+ import random
32
+ from pathlib import Path
33
+
34
+ import torch
35
+ from PIL import Image
36
+ from torch.utils.data import Dataset
37
+ import torchvision.transforms.functional as TF
38
+
39
+
40
class SceneDataset(Dataset):
    """Dataset enumerating all valid (tgt, diff, ref, partial_map, mask) combinations.

    Sample path-tuples are resolved once at construction time; the actual
    images are loaded lazily in ``__getitem__``. Reference frames are taken
    at fixed index offsets relative to each target frame, wrapping around the
    scene, so for very short scenes a "reference" may coincide with the
    target itself.
    """

    def __init__(self, root_dir, rgb_transform=None, grayscale_transform=None, training=True):
        """
        Args:
            root_dir: Directory containing scene folders named ``s*``.
            rgb_transform: Optional transform applied to the three RGB images.
            grayscale_transform: Optional transform applied to the three maps.
            training: If True, apply random flip / color-jitter augmentation.
        """
        self.root_dir = Path(root_dir)
        self.rgb_transform = rgb_transform
        self.grayscale_transform = grayscale_transform
        self.samples = []
        self.ref_deltas = [-20, -10, 10, 20]
        self.training = training

        for scene_path in sorted(self.root_dir.glob("s*")):
            if scene_path.is_dir():
                self._index_scene(scene_path)

    def _index_scene(self, scene_path):
        """Collect every valid sample path-tuple for a single scene directory."""
        total_dir = scene_path / "total"
        if not total_dir.is_dir():
            return

        total_images = sorted(total_dir.glob("*.jpg"), key=lambda p: int(p.stem))
        num_total = len(total_images)
        if num_total == 0:
            return

        for i, tgt_path in enumerate(total_images):
            tgt_stem = tgt_path.stem

            # References at fixed offsets; indices wrap around the scene.
            # (Paths come from glob, so no extra existence check is needed.)
            ref_info_list = [
                {"path": total_images[(i + d) % num_total], "offset": d}
                for d in self.ref_deltas
            ]

            tgt_diff_dir = scene_path / "tgt_diffusion" / tgt_stem
            total_map_dir = scene_path / "total_map" / tgt_stem

            for tgt_diff_path in sorted(tgt_diff_dir.glob("*_diff_*.jpg")):
                full_map_path = total_map_dir / f"{tgt_diff_path.stem}.png"
                if not full_map_path.exists():
                    continue

                tgt_diff_stem = tgt_diff_path.stem

                for ref_info in ref_info_list:
                    ref_path = ref_info["path"]
                    d = ref_info["offset"]
                    # Partial map/mask share the same filename convention.
                    name = f"{tgt_diff_stem}_ref{d:+d}_{ref_path.stem}.png"
                    mask_path = scene_path / "partial_mask" / tgt_stem / name
                    map_path = scene_path / "partial_map" / tgt_stem / name

                    if mask_path.exists() and map_path.exists():
                        self.samples.append({
                            "tgt": tgt_path,
                            "tgt_diff": tgt_diff_path,
                            "full_map": full_map_path,
                            "partial_mask": mask_path,
                            "partial_map": map_path,
                            "current_ref": ref_path,
                        })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        paths = self.samples[idx]

        tgt_img = Image.open(paths["tgt"]).convert("RGB")
        tgt_diff_img = Image.open(paths["tgt_diff"]).convert("RGB")
        full_map_img = Image.open(paths["full_map"]).convert("L")
        partial_mask_img = Image.open(paths["partial_mask"]).convert("L")
        partial_map_img = Image.open(paths["partial_map"]).convert("L")
        cur_ref_img = Image.open(paths["current_ref"]).convert("RGB")

        rgb_imgs = [tgt_img, tgt_diff_img, cur_ref_img]
        map_imgs = [full_map_img, partial_mask_img, partial_map_img]

        # -- Augmentation (training only). Geometric ops are applied to RGB
        # images and maps alike so everything stays pixel-aligned; photometric
        # jitter applies to the RGB images only. --
        if self.training:
            if random.random() > 0.5:  # horizontal flip, p = 0.5
                rgb_imgs = [TF.hflip(im) for im in rgb_imgs]
                map_imgs = [TF.hflip(im) for im in map_imgs]

            if random.random() > 0.7:  # vertical flip, p = 0.3
                rgb_imgs = [TF.vflip(im) for im in rgb_imgs]
                map_imgs = [TF.vflip(im) for im in map_imgs]

            if random.random() > 0.5:  # mild color jitter, p = 0.5
                jitters = [
                    (TF.adjust_brightness, random.uniform(0.9, 1.1)),
                    (TF.adjust_contrast, random.uniform(0.9, 1.1)),
                    (TF.adjust_saturation, random.uniform(0.9, 1.1)),
                ]
                for fn, val in jitters:
                    rgb_imgs = [fn(im, val) for im in rgb_imgs]

        if self.rgb_transform:
            rgb_imgs = [self.rgb_transform(im) for im in rgb_imgs]
        if self.grayscale_transform:
            map_imgs = [self.grayscale_transform(im) for im in map_imgs]

        tgt_img, tgt_diff_img, cur_ref_img = rgb_imgs
        full_map_img, partial_mask_img, partial_map_img = map_imgs

        return {
            "tgt": tgt_img,
            "tgt_diff": tgt_diff_img,
            "partial_mask": partial_mask_img,
            "partial_map": partial_map_img,
            "full_map": full_map_img,
            "current_ref": cur_ref_img,
        }
pr_iqa/loss.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss functions for PR-IQA training.
3
+
4
+ Core losses:
5
+ - JSD (Jensen-Shannon Divergence): Distribution matching
6
+ - Masked L1: Pixel-wise L1 on partial map regions
7
+ - Pearson: Correlation-based structural loss
8
+
9
+ Additional losses (optional):
10
+ - Ranking: Pairwise ranking consistency
11
+ - Global mean/std: Statistics matching
12
+ """
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
def loss_jsd(pred, target, tau=0.2, reduction="mean", eps=1e-6):
    """Jensen-Shannon divergence between predicted and target quality maps.

    Each map is flattened per-sample, converted to temperature-scaled logits,
    and soft-maxed into a probability distribution; the JSD is the symmetrised
    KL divergence of the two distributions against their 50/50 mixture.

    Args:
        pred: Predicted map, values expected in (0, 1).
        target: Ground-truth map, values expected in (0, 1).
        tau: Softmax temperature applied to the logits.
        reduction: "mean", "sum", or anything else for per-sample values.
        eps: Numerical floor for clamping / logit / log.
    """
    # Force full precision even under autocast: logit/softmax/log are
    # numerically delicate in half precision.
    with torch.autocast(device_type="cuda", enabled=False):
        pred_f = pred.float().clamp(min=eps, max=1 - eps)
        tgt_f = target.float().clamp(min=eps, max=1 - eps)

        dist_pred = torch.softmax(torch.logit(pred_f, eps=eps).flatten(1) / tau, dim=1)
        dist_tgt = torch.softmax(torch.logit(tgt_f, eps=eps).flatten(1) / tau, dim=1)
        mix = 0.5 * (dist_tgt + dist_pred)

        def _kl(a, b):
            return (a * (torch.log(a + eps) - torch.log(b + eps))).sum(dim=1)

        per_sample = 0.5 * (_kl(dist_tgt, mix) + _kl(dist_pred, mix))

    if reduction == "mean":
        return per_sample.mean().to(pred.dtype)
    if reduction == "sum":
        return per_sample.sum().to(pred.dtype)
    return per_sample.to(pred.dtype)
46
+
47
+
48
def loss_masked_l1(pred, target, mask, reduction="mean"):
    """L1 loss restricted to regions where ``mask`` is non-zero.

    With ``reduction="mean"`` the summed masked error is normalised by the
    mask mass (the number of valid pixels for a binary mask).
    """
    err = (pred - target).abs() * mask
    if reduction == "mean":
        # Epsilon guards the all-zero-mask case against division by zero.
        return err.sum() / (mask.sum() + 1e-8)
    if reduction == "sum":
        return err.sum()
    return err
57
+
58
+
59
def loss_l1(pred, target, reduction="mean"):
    """Plain element-wise L1 loss with an optional reduction.

    The result is cast back to ``pred``'s dtype (relevant under autocast).
    """
    diff = torch.abs(pred - target)
    if reduction == "sum":
        return diff.sum().to(pred.dtype)
    if reduction == "mean":
        return diff.mean().to(pred.dtype)
    return diff.to(pred.dtype)
67
+
68
+
69
def loss_pearson(pred, target, reduction="mean", eps=1e-6):
    """1 - Pearson correlation, computed per sample over flattened maps.

    Perfectly correlated maps give ~0, anti-correlated maps give ~2.
    """
    a = pred.float().reshape(pred.shape[0], -1).contiguous()
    b = target.float().reshape(target.shape[0], -1).contiguous()

    # Center both signals per sample.
    a = a - a.mean(dim=1, keepdim=True)
    b = b - b.mean(dim=1, keepdim=True)

    cov = (a * b).sum(dim=1)
    # eps inside the sqrt keeps the denominator away from zero for flat maps.
    denom = torch.sqrt((a * a).sum(dim=1) * (b * b).sum(dim=1) + eps)
    corr = (cov / denom).clamp(-1.0, 1.0)

    per_sample = 1.0 - corr
    if reduction == "mean":
        return per_sample.mean().to(pred.dtype)
    if reduction == "sum":
        return per_sample.sum().to(pred.dtype)
    return per_sample.to(pred.dtype)
90
+
91
+
92
def loss_ranking(pred, gt, margin=0.1):
    """Pairwise margin ranking loss on randomly sampled pixel pairs.

    Samples H*W/2 random index pairs per image and asks the predicted map to
    respect the ground-truth ordering of each pair. Tied ground-truth values
    produce a target of 0 (sign of zero), contributing a constant margin term.
    """
    B, C, H, W = pred.shape
    num_pairs = int(H * W * 0.5)

    # Two independent uniform index draws per pair (same RNG order as before).
    first = torch.randint(0, H * W, (B, num_pairs), device=pred.device)
    second = torch.randint(0, H * W, (B, num_pairs), device=pred.device)

    flat_pred = pred.view(B, -1)
    flat_gt = gt.view(B, -1)

    # Desired ordering comes from the ground truth.
    order = torch.sign(flat_gt.gather(1, first) - flat_gt.gather(1, second))
    return F.margin_ranking_loss(
        flat_pred.gather(1, first), flat_pred.gather(1, second), order, margin=margin
    )
pr_iqa/model/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .priqa import PRIQA
2
+ from .layers import (
3
+ PartialConv2d,
4
+ GatedPartialEmb,
5
+ GatedEmb,
6
+ FeedForward,
7
+ ChannelGate,
8
+ Attention,
9
+ TransformerLikeBlock,
10
+ SandwichBlock,
11
+ Downsample,
12
+ Upsample,
13
+ Pos2d,
14
+ DropPath,
15
+ LayerNorm,
16
+ )
17
+
18
+
19
def build_priqa(
    out_channels: int = 1,
    dim: int = 48,
    num_blocks: tuple = (2, 3, 3, 4),
    heads: tuple = (1, 2, 4, 8),
    ffn_expansion_factor: float = 2.66,
    bias: bool = False,
    layernorm_type: str = "WithBias",
    use_partial_conv: bool = True,
) -> PRIQA:
    """Build a PR-IQA model with default or custom hyperparameters.

    The input channel count is fixed at 4 (RGB + mask); per-level block and
    head counts are passed through as lists.
    """
    config = dict(
        inp_channels=4,
        out_channels=out_channels,
        dim=dim,
        num_blocks=list(num_blocks),
        heads=list(heads),
        ffn_expansion_factor=ffn_expansion_factor,
        bias=bias,
        LayerNorm_type=layernorm_type,
        use_partial_conv=use_partial_conv,
    )
    return PRIQA(**config)
41
+
42
+
43
+ __all__ = ["PRIQA", "build_priqa"]
pr_iqa/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.18 kB). View file
 
pr_iqa/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.13 kB). View file
 
pr_iqa/model/__pycache__/layers.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
pr_iqa/model/__pycache__/layers.cpython-38.pyc ADDED
Binary file (14.2 kB). View file
 
pr_iqa/model/__pycache__/priqa.cpython-310.pyc ADDED
Binary file (6.95 kB). View file
 
pr_iqa/model/__pycache__/priqa.cpython-38.pyc ADDED
Binary file (6.94 kB). View file
 
pr_iqa/model/layers.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Building blocks for the PR-IQA architecture.
3
+
4
+ Includes:
5
+ - PartialConv2d: Mask-aware convolution for inpainting
6
+ - GatedPartialEmb / GatedEmb: Gated patch embeddings
7
+ - FeedForward (FFN): Gated depth-wise separable FFN
8
+ - ChannelGate: SE/CBAM-style channel attention
9
+ - Attention: Spatial attention with xformers memory-efficient attention
10
+ - TransformerLikeBlock: Channel gate → Spatial attn → FFN with residuals
11
+ - SandwichBlock: FFN → Channel gate → Spatial attn → FFN
12
+ - Downsample / Upsample: Strided conv / PixelShuffle
13
+ - Pos2d: 2D sinusoidal positional encoding
14
+ - DropPath: Stochastic depth
15
+ - LayerNorm: Bias-free or with-bias layer normalization
16
+ """
17
+
18
+ import numbers
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from einops import rearrange
24
+ from xformers import ops
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Layer Normalization
29
+ # ---------------------------------------------------------------------------
30
+
31
def to_3d(x):
    """Flatten (B, C, H, W) feature maps to token form (B, H*W, C)."""
    return x.flatten(2).transpose(1, 2)
33
+
34
+
35
def to_4d(x, h, w):
    """Inverse of ``to_3d``: tokens (B, H*W, C) back to maps (B, C, H, W)."""
    return x.transpose(1, 2).reshape(x.shape[0], -1, h, w)
37
+
38
+
39
class BiasFree_LayerNorm(nn.Module):
    """LayerNorm variant: rescale by std over the last dim, no centering, no bias."""

    def __init__(self, normalized_shape):
        super().__init__()
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        shape = torch.Size(shape)
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        # Variance only — the mean is deliberately NOT subtracted.
        var = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(var + 1e-5) * self.weight
52
+
53
+
54
class WithBias_LayerNorm(nn.Module):
    """Standard LayerNorm over the last dim with learnable scale and shift."""

    def __init__(self, normalized_shape):
        super().__init__()
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        shape = torch.Size(shape)
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True, unbiased=False)
        return centered / torch.sqrt(var + 1e-5) * self.weight + self.bias
69
+
70
+
71
class LayerNorm(nn.Module):
    """Channel-wise LayerNorm for (B, C, H, W) tensors.

    Maps are temporarily reshaped to tokens (B, H*W, C), normalised over C,
    and restored to map form.
    """

    def __init__(self, dim, LayerNorm_type="WithBias"):
        super().__init__()
        body_cls = BiasFree_LayerNorm if LayerNorm_type == "BiasFree" else WithBias_LayerNorm
        self.body = body_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), height, width)
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Partial Convolution
86
+ # ---------------------------------------------------------------------------
87
+
88
class PartialConv2d(nn.Module):
    """Mask-aware convolution for inpainting.

    Given input ``x`` and binary mask ``mask`` (1 = valid), the convolution of
    ``x * mask`` is renormalised by the per-location count of valid input
    pixels in each receptive field; locations whose receptive field contains
    no valid pixel are zeroed out.

    Returns:
        (output, new_mask): ``output`` is (B, out_ch, H', W'); ``new_mask`` is
        (B, 1, H', W'), 1 wherever at least one valid input pixel was visible.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=False)
        # All-ones, frozen kernel: its output counts valid pixels per window.
        self.mask_conv = nn.Conv2d(1, out_ch, kernel_size, stride, padding, bias=False)
        nn.init.constant_(self.mask_conv.weight, 1.0)
        self.mask_conv.weight.requires_grad = False
        self.bias = nn.Parameter(torch.zeros(out_ch)) if bias else None

    def forward(self, x, mask):
        with torch.no_grad():
            valid_count = self.mask_conv(mask)
            # BUG FIX: derive the propagated mask from the RAW count before
            # clamping. Previously `(mask_sum > 0)` was evaluated after
            # clamp(min=1e-8), which made new_mask all-ones, so fully-invalid
            # regions were never masked out of the output.
            new_mask = (valid_count > 0).float()
            mask_sum = valid_count.clamp(min=1e-8)

        output = self.conv(x * mask) / mask_sum
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1, 1)
        output = output * new_mask
        return output, new_mask[:, 0:1]
113
+
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Gated Embeddings
117
+ # ---------------------------------------------------------------------------
118
+
119
class GatedPartialEmb(nn.Module):
    """Gated patch embedding built on ``PartialConv2d`` for masked inputs.

    Projects to twice the embedding width and gates one half with GELU of the
    other (GLU-style), propagating the partial-conv mask alongside.
    """

    def __init__(self, in_c=4, embed_dim=48, bias=False):
        super().__init__()
        self.pconv = PartialConv2d(
            in_c, embed_dim * 2, kernel_size=3, stride=1, padding=1, bias=bias
        )

    def forward(self, x_with_mask, mask):
        """
        Args:
            x_with_mask: (B, in_c, H, W) — e.g. RGB(3) + mask(1) concatenated.
            mask: (B, 1, H, W) — binary mask for the partial convolution.
        """
        feats, mask_out = self.pconv(x_with_mask, mask)
        gate, value = feats.chunk(2, dim=1)
        return F.gelu(gate) * value, mask_out
136
+
137
+
138
class GatedEmb(nn.Module):
    """Gated patch embedding (plain convolution, no mask handling)."""

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.gproj1 = nn.Conv2d(in_c, embed_dim * 2, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        # GLU-style gating: GELU(first half) * second half.
        gate, value = self.gproj1(x).chunk(2, dim=1)
        return F.gelu(gate) * value
149
+
150
+
151
+ # ---------------------------------------------------------------------------
152
+ # Feed-Forward Network
153
+ # ---------------------------------------------------------------------------
154
+
155
class FeedForward(nn.Module):
    """Gated depth-wise separable feed-forward block.

    1x1 expansion to 2x hidden width → depth-wise 3x3 → GLU-style gate →
    1x1 projection back to ``dim``.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
        )
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        expanded = self.project_in(x)
        gate, value = self.dwconv(expanded).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
174
+
175
+
176
+ # ---------------------------------------------------------------------------
177
+ # Channel Attention
178
+ # ---------------------------------------------------------------------------
179
+
180
class ChannelGate(nn.Module):
    """SE/CBAM-style channel gate, optionally conditioned on a second tensor.

    When ``kv`` is given, gate logits from both tensors are summed and the
    same sigmoid gate is applied to each.
    """

    def __init__(self, dim, reduction=16, use_max=True, bias=True):
        super().__init__()
        hidden = max(1, dim // reduction)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden, 1, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, dim, 1, bias=bias),
        )
        self.use_max = use_max

    def _pooled(self, t):
        # Global average (plus optional max) pooling, then the bottleneck MLP.
        pooled = F.adaptive_avg_pool2d(t, 1)
        if self.use_max:
            pooled = pooled + F.adaptive_max_pool2d(t, 1)
        return self.mlp(pooled)

    def forward(self, x, kv=None):
        logits = self._pooled(x)
        if kv is not None:
            logits = logits + self._pooled(kv)
        gate = torch.sigmoid(logits)
        return x * gate, (None if kv is None else kv * gate)
208
+
209
+
210
+ # ---------------------------------------------------------------------------
211
+ # Spatial Attention (xformers)
212
+ # ---------------------------------------------------------------------------
213
+
214
class Attention(nn.Module):
    """Spatial attention with xformers memory-efficient attention.

    Supports both self-attention (kv=None) and cross-attention (kv provided).
    Includes a spatial gating branch.

    Args:
        dim: Channel dimension of the input feature map.
        num_heads: Number of attention heads; ``dim`` must be divisible by it.
        bias: Whether the projection convolutions carry a bias term.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads

        # Self-attention projections
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3, dim * 3, kernel_size=3, stride=1, padding=1,
            groups=dim * 3, bias=bias,
        )

        # Cross-attention projections
        self.q_proj_qonly = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dw_qonly = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.kv_proj_cross = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv_cross = nn.Conv2d(
            dim * 2, dim * 2, kernel_size=3, stride=1, padding=1,
            groups=dim * 2, bias=bias,
        )

        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

        # Spatial gating
        # NOTE(review): a lambda stored on the module is not a submodule and
        # makes the module non-picklable with the default pickler — confirm
        # checkpoints use state_dict only.
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.upsample_to = lambda t, size: F.interpolate(t, size=size, mode="bilinear", align_corners=False)
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True),
            LayerNorm(dim, "WithBias"),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True),
            LayerNorm(dim, "WithBias"),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, kv=None):
        """
        Args:
            x: (B, C, H, W) query feature map.
            kv: Optional (B, C, H', W') key/value feature map; when given,
                cross-attention projections are used instead of qkv.

        Returns:
            (B, C, H, W) attended, spatially gated, projected feature map.
        """
        b, c, h, w = x.shape
        head_dim = c // self.num_heads

        if kv is None:
            # Self-attention: fused 1x1 projection + depth-wise 3x3, split 3-way.
            qkv = self.qkv_dwconv(self.qkv(x))
            q, k, v = qkv.chunk(3, dim=1)
        else:
            # Cross-attention: q from x, k/v from the external feature map.
            q = self.q_dw_qonly(self.q_proj_qonly(x))
            kv_feat = self.kv_dwconv_cross(self.kv_proj_cross(kv))
            k, v = kv_feat.chunk(2, dim=1)

        # Reshape to the (batch, seq_len, heads, head_dim) layout xformers
        # expects; k/v use -1 so the kv spatial size may differ from x's.
        q = q.view(b, self.num_heads, head_dim, h * w).permute(0, 3, 1, 2).contiguous()
        k = k.view(b, self.num_heads, head_dim, -1).permute(0, 3, 1, 2).contiguous()
        v = v.view(b, self.num_heads, head_dim, -1).permute(0, 3, 1, 2).contiguous()

        # NOTE(review): relies on memory_efficient_attention's default
        # 1/sqrt(head_dim) scaling — no explicit scale is passed.
        out = ops.memory_efficient_attention(q, k, v)
        out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)

        # Spatial gating: weights computed at half resolution on the query
        # map, then bilinearly upsampled back and multiplied in.
        spatial_weight = self.avg_pool(x)
        spatial_weight = self.conv(spatial_weight)
        spatial_weight = self.upsample_to(spatial_weight, (h, w))
        out = out * spatial_weight

        return self.project_out(out)
281
+
282
+
283
+ # ---------------------------------------------------------------------------
284
+ # Drop Path (Stochastic Depth)
285
+ # ---------------------------------------------------------------------------
286
+
287
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples, rescaling survivors."""

    def __init__(self, p: float = 0.0):
        super().__init__()
        self.p = float(p)

    def forward(self, x):
        # Identity in eval mode or when the drop probability is zero.
        if not self.training or self.p == 0.0:
            return x
        keep_prob = 1.0 - self.p
        keep = torch.rand(x.shape[0], 1, 1, 1, device=x.device, dtype=x.dtype) < keep_prob
        return x * keep / keep_prob
298
+
299
+
300
+ # ---------------------------------------------------------------------------
301
+ # Transformer-like Block
302
+ # ---------------------------------------------------------------------------
303
+
304
class TransformerLikeBlock(nn.Module):
    """Pre-norm residual block: channel gate → spatial attention → FFN.

    Each sub-module's output is scaled by a learnable per-channel layer-scale
    factor and passed through optional stochastic depth before the residual
    add. Supplying ``kv`` switches the attention stage to cross-attention.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type,
                 drop_path=0.0, layerscale_init=1e-2):
        super().__init__()
        self.norm_c = LayerNorm(dim, LayerNorm_type)
        self.chan = ChannelGate(dim, reduction=16, use_max=True, bias=bias)
        self.norm_s = LayerNorm(dim, LayerNorm_type)
        self.sattn = Attention(dim, num_heads, bias)
        self.norm_f = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

        # Small initial layer-scale keeps each residual branch near-identity.
        self.gamma_c = nn.Parameter(torch.ones(1, dim, 1, 1) * layerscale_init)
        self.gamma_s = nn.Parameter(torch.ones(1, dim, 1, 1) * layerscale_init)
        self.gamma_f = nn.Parameter(torch.ones(1, dim, 1, 1) * layerscale_init)
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x, kv=None):
        # Channel-gate stage (also gates kv when present).
        gated, kv_gated = self.chan(self.norm_c(x), kv)
        x = x + self.drop_path(self.gamma_c * gated)

        # Spatial attention stage (cross-attention when kv is supplied).
        attended = self.sattn(self.norm_s(x), kv_gated if kv is not None else None)
        x = x + self.drop_path(self.gamma_s * attended)

        # Feed-forward stage.
        return x + self.drop_path(self.gamma_f * self.ffn(self.norm_f(x)))
335
+
336
+
337
+ # ---------------------------------------------------------------------------
338
+ # Sandwich Block
339
+ # ---------------------------------------------------------------------------
340
+
341
class SandwichBlock(nn.Module):
    """Residual block ordered FFN → channel gate → spatial attention → FFN."""

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()
        self.norm1_1 = LayerNorm(dim, LayerNorm_type)
        self.ffn1 = FeedForward(dim, ffn_expansion_factor, bias)
        self.norm_c = LayerNorm(dim, LayerNorm_type)
        self.chan = ChannelGate(dim, reduction=16, use_max=True, bias=bias)
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x, kv=None):
        # Leading FFN ("bottom slice" of the sandwich).
        x = x + self.ffn1(self.norm1_1(x))

        # Channel gate, then spatial (cross-)attention, each residual.
        gated, kv_gated = self.chan(self.norm_c(x), kv)
        x = x + gated
        x = x + self.attn(self.norm1(x), kv_gated if kv is not None else None)

        # Trailing FFN ("top slice").
        return x + self.ffn(self.norm2(x))
363
+
364
+
365
+ # ---------------------------------------------------------------------------
366
+ # Downsample / Upsample
367
+ # ---------------------------------------------------------------------------
368
+
369
class Downsample(nn.Module):
    """Halve spatial resolution while doubling channels (strided 3x3 conv)."""

    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=2, padding=1, bias=False),
        )

    def forward(self, x, mask=None):
        # ``mask`` is accepted for interface uniformity with masked layers
        # but is ignored here.
        return self.body(x)
378
+
379
+
380
class Upsample(nn.Module):
    """Double spatial resolution while halving channels (conv + PixelShuffle)."""

    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
        )

    def forward(self, x, mask=None):
        # ``mask`` is accepted for interface uniformity with masked layers
        # but is ignored here.
        return self.body(x)
390
+
391
+
392
+ # ---------------------------------------------------------------------------
393
+ # Positional Encoding
394
+ # ---------------------------------------------------------------------------
395
+
396
class Pos2d(nn.Module):
    """Coordinate-based positional encoding.

    Builds a 4-channel coordinate feature (x, y, sin(pi*x), cos(pi*y)) over a
    normalised [-1, 1] grid and projects it to ``dim`` channels with a 1x1
    conv, added to the input. NOTE(review): despite being described elsewhere
    as "2D sinusoidal", this is a learned projection of fixed coordinate
    features, not a classic sinusoidal embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Conv2d(4, dim, kernel_size=1)

    def forward(self, x):
        batch, _, height, width = x.shape
        ys, xs = torch.meshgrid(
            torch.linspace(-1, 1, height, device=x.device),
            torch.linspace(-1, 1, width, device=x.device),
            indexing="ij",
        )
        coord_feats = torch.stack(
            [xs, ys, torch.sin(xs * 3.14159), torch.cos(ys * 3.14159)], dim=0
        )
        pos = self.proj(coord_feats.unsqueeze(0)).repeat(batch, 1, 1, 1)
        return x + pos
pr_iqa/model/priqa.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PR-IQA: Partial-Reference Image Quality Assessment model.
3
+
4
+ 3-input U-Net encoder-decoder with cross-attention:
5
+ - tgt_img: partial quality map (from FeatureMetric) replicated to 3ch
6
+ - dif_img: generated / distorted image
7
+ - ref_img: reference image
8
+
9
+ Each input comes with a 4-scale mask pyramid (whole, half, quarter, tiny).
10
+
11
+ Architecture:
12
+ Encoder: 4 levels (dim → 2*dim → 4*dim → 8*dim)
13
+ - img_encoder: shared for ref_img and dif_img (self-attention)
14
+ - map_encoder: for tgt_img (cross-attention with ref features)
15
+ - qfuse: fuses dif and tgt encoder outputs at each level
16
+ Decoder: 3 levels with skip connections from the dif encoder
17
+ Output: tanh-activated quality map
18
+ """
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ from .layers import (
25
+ GatedPartialEmb,
26
+ GatedEmb,
27
+ TransformerLikeBlock,
28
+ Downsample,
29
+ Upsample,
30
+ Pos2d,
31
+ )
32
+
33
+
34
+ class PRIQA(nn.Module):
35
+ """Partial-Reference Image Quality Assessment model.
36
+
37
+ Args:
38
+ inp_channels: Input channels per image (typically 4 = RGB + mask).
39
+ out_channels: Output channels (1 for quality map, 3 for RGB).
40
+ dim: Base feature dimension (doubles at each encoder level).
41
+ num_blocks: Number of TransformerLikeBlocks at each level.
42
+ heads: Number of attention heads at each level.
43
+ ffn_expansion_factor: FFN hidden dim multiplier.
44
+ bias: Use bias in convolutions.
45
+ LayerNorm_type: ``"WithBias"`` or ``"BiasFree"``.
46
+ use_partial_conv: Use PartialConv2d in patch embedding.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ inp_channels=4,
52
+ out_channels=3,
53
+ dim=48,
54
+ num_blocks=[4, 6, 6, 8],
55
+ heads=[1, 2, 4, 8],
56
+ ffn_expansion_factor=2.66,
57
+ bias=False,
58
+ LayerNorm_type="WithBias",
59
+ use_partial_conv=True,
60
+ ):
61
+ super().__init__()
62
+ self.use_partial_conv = use_partial_conv
63
+
64
+ # -- Patch embedding --
65
+ if use_partial_conv:
66
+ self.patch_embed = GatedPartialEmb(inp_channels, dim, bias)
67
+ else:
68
+ self.patch_embed = GatedEmb(inp_channels, dim, bias)
69
+
70
+ # -- Quality fusion (dif + tgt) at each level --
71
+ self.qfuse_l1 = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
72
+ self.qfuse_l2 = nn.Conv2d(int(dim * 2 ** 1) * 2, int(dim * 2 ** 1), kernel_size=1, bias=bias)
73
+ self.qfuse_l3 = nn.Conv2d(int(dim * 2 ** 2) * 2, int(dim * 2 ** 2), kernel_size=1, bias=bias)
74
+ self.qfuse_l4 = nn.Conv2d(int(dim * 2 ** 3) * 2, int(dim * 2 ** 3), kernel_size=1, bias=bias)
75
+
76
+ # -- Downsampler --
77
+ self.down1_2 = Downsample(dim)
78
+ self.down2_3 = Downsample(int(dim * 2 ** 1))
79
+ self.down3_4 = Downsample(int(dim * 2 ** 2))
80
+
81
+ # -- Positional Encoding --
82
+ self.pos_l1 = Pos2d(dim)
83
+ self.pos_l2 = Pos2d(int(dim * 2 ** 1))
84
+ self.pos_l3 = Pos2d(int(dim * 2 ** 2))
85
+ self.pos_l4 = Pos2d(int(dim * 2 ** 3))
86
+ self.pos_d3 = Pos2d(int(dim * 2 ** 2))
87
+ self.pos_d2 = Pos2d(int(dim * 2 ** 1))
88
+ self.pos_d1 = Pos2d(int(dim * 2 ** 1))
89
+
90
+ # -- Encoder (img: shared for ref & dif) --
91
+ def _make_encoder(level_dim, n_blocks, n_heads):
92
+ return nn.ModuleList([
93
+ TransformerLikeBlock(
94
+ dim=level_dim, num_heads=n_heads,
95
+ ffn_expansion_factor=ffn_expansion_factor,
96
+ bias=bias, LayerNorm_type=LayerNorm_type,
97
+ )
98
+ for _ in range(n_blocks)
99
+ ])
100
+
101
+ self.img_encoder_level1 = _make_encoder(dim, num_blocks[0], heads[0])
102
+ self.img_encoder_level2 = _make_encoder(int(dim * 2 ** 1), num_blocks[1], heads[1])
103
+ self.img_encoder_level3 = _make_encoder(int(dim * 2 ** 2), num_blocks[2], heads[2])
104
+ self.img_latent = _make_encoder(int(dim * 2 ** 3), num_blocks[3], heads[3])
105
+
106
+ # -- Encoder (map: for tgt, cross-attention with ref) --
107
+ self.map_encoder_level1 = _make_encoder(dim, num_blocks[0], heads[0])
108
+ self.map_encoder_level2 = _make_encoder(int(dim * 2 ** 1), num_blocks[1], heads[1])
109
+ self.map_encoder_level3 = _make_encoder(int(dim * 2 ** 2), num_blocks[2], heads[2])
110
+ self.map_latent = _make_encoder(int(dim * 2 ** 3), num_blocks[3], heads[3])
111
+
112
+ # -- Decoder --
113
+ self.up4_3 = Upsample(int(dim * 2 ** 3))
114
+ self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
115
+ self.decoder_level3 = _make_encoder(int(dim * 2 ** 2), num_blocks[2], heads[2])
116
+
117
+ self.up3_2 = Upsample(int(dim * 2 ** 2))
118
+ self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
119
+ self.decoder_level2 = _make_encoder(int(dim * 2 ** 1), num_blocks[1], heads[1])
120
+
121
+ self.up2_1 = Upsample(int(dim * 2 ** 1))
122
+ self.decoder_level1 = _make_encoder(int(dim * 2 ** 1), num_blocks[0], heads[0])
123
+
124
+ # -- Output --
125
+ self.output = nn.Sequential(
126
+ nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias),
127
+ )
128
+
129
    def forward(
        self,
        tgt_img, dif_img, ref_img,
        tgt_mask_whole, tgt_mask_half, tgt_mask_quarter, tgt_mask_tiny,
        dif_mask_whole, dif_mask_half, dif_mask_quarter, dif_mask_tiny,
        ref_mask_whole, ref_mask_half, ref_mask_quarter, ref_mask_tiny,
    ):
        """Predict a full-resolution quality map from three input streams.

        At every encoder level the ref stream is self-attended first and its
        output is reused as the key/value memory for both the tgt and dif
        streams; the dif and tgt features are then fused by a 1x1 conv
        (``qfuse_l*``) before descending to the next level.  Only the fused
        dif stream is decoded back to full resolution.

        Args:
            tgt_img: (B, 3, H, W) — partial quality map replicated to 3ch.
            dif_img: (B, 3, H, W) — generated / distorted image.
            ref_img: (B, 3, H, W) — reference image.
            *_mask_*: (B, 1, H/s, W/s) — mask pyramids at 4 scales.

        Returns:
            (B, out_channels, H, W) quality map (tanh activated).
        """
        # -- Patch embedding --
        # The mask is concatenated as an extra input channel in both branches;
        # with partial conv, the mask additionally gates the convolution.
        if self.use_partial_conv:
            tgt_enc_level1, _ = self.patch_embed(
                torch.cat((tgt_img, tgt_mask_whole), dim=1), tgt_mask_whole,
            )
            dif_enc_level1, _ = self.patch_embed(
                torch.cat((dif_img, dif_mask_whole), dim=1), dif_mask_whole,
            )
            ref_enc_level1, _ = self.patch_embed(
                torch.cat((ref_img, ref_mask_whole), dim=1), ref_mask_whole,
            )
        else:
            tgt_enc_level1 = self.patch_embed(torch.cat((tgt_img, tgt_mask_whole), dim=1))
            dif_enc_level1 = self.patch_embed(torch.cat((dif_img, dif_mask_whole), dim=1))
            ref_enc_level1 = self.patch_embed(torch.cat((ref_img, ref_mask_whole), dim=1))

        tgt_enc_level1 = self.pos_l1(tgt_enc_level1)
        dif_enc_level1 = self.pos_l1(dif_enc_level1)
        ref_enc_level1 = self.pos_l1(ref_enc_level1)

        # ── ENCODER Level 1 ──
        # Ref stream: pure self-attention; its output becomes the KV memory.
        out_ref_enc_level1 = ref_enc_level1
        for block in self.img_encoder_level1:
            out_ref_enc_level1 = block(out_ref_enc_level1)
        kv_level1 = out_ref_enc_level1

        # Tgt (map) stream cross-attends to the ref features.
        out_tgt_enc_level1 = tgt_enc_level1
        for block in self.map_encoder_level1:
            out_tgt_enc_level1 = block(out_tgt_enc_level1, kv_level1)

        # NOTE(review): the dif stream reuses the *img* encoder weights but is
        # called with a kv argument, while ref uses the same blocks without
        # one — presumably TransformerLikeBlock takes an optional
        # cross-attention memory; confirm in its definition.
        out_dif_enc_level1 = dif_enc_level1
        for block in self.img_encoder_level1:
            out_dif_enc_level1 = block(out_dif_enc_level1, kv_level1)

        # Fuse dif + tgt features (2C -> C) before moving to level 2.
        out_dif_enc_level1 = self.qfuse_l1(torch.cat([out_dif_enc_level1, out_tgt_enc_level1], dim=1))

        # ── ENCODER Level 2 ──
        # NOTE(review): down1_2 receives the *whole*-resolution mask while
        # producing half-resolution features — presumably Downsample handles
        # the mask scale internally; confirm in Downsample.
        inp_tgt_enc_level2 = self.pos_l2(self.down1_2(out_tgt_enc_level1, tgt_mask_whole))
        inp_dif_enc_level2 = self.pos_l2(self.down1_2(out_dif_enc_level1, dif_mask_whole))
        inp_ref_enc_level2 = self.pos_l2(self.down1_2(out_ref_enc_level1, ref_mask_whole))

        out_ref_enc_level2 = inp_ref_enc_level2
        for block in self.img_encoder_level2:
            out_ref_enc_level2 = block(out_ref_enc_level2)
        kv_level2 = out_ref_enc_level2

        out_tgt_enc_level2 = inp_tgt_enc_level2
        for block in self.map_encoder_level2:
            out_tgt_enc_level2 = block(out_tgt_enc_level2, kv_level2)

        out_dif_enc_level2 = inp_dif_enc_level2
        for block in self.img_encoder_level2:
            out_dif_enc_level2 = block(out_dif_enc_level2, kv_level2)

        out_dif_enc_level2 = self.qfuse_l2(torch.cat([out_dif_enc_level2, out_tgt_enc_level2], dim=1))

        # ── ENCODER Level 3 ──
        inp_tgt_enc_level3 = self.pos_l3(self.down2_3(out_tgt_enc_level2, tgt_mask_half))
        inp_dif_enc_level3 = self.pos_l3(self.down2_3(out_dif_enc_level2, dif_mask_half))
        inp_ref_enc_level3 = self.pos_l3(self.down2_3(out_ref_enc_level2, ref_mask_half))

        out_ref_enc_level3 = inp_ref_enc_level3
        for block in self.img_encoder_level3:
            out_ref_enc_level3 = block(out_ref_enc_level3)
        kv_level3 = out_ref_enc_level3

        out_tgt_enc_level3 = inp_tgt_enc_level3
        for block in self.map_encoder_level3:
            out_tgt_enc_level3 = block(out_tgt_enc_level3, kv_level3)

        out_dif_enc_level3 = inp_dif_enc_level3
        for block in self.img_encoder_level3:
            out_dif_enc_level3 = block(out_dif_enc_level3, kv_level3)

        out_dif_enc_level3 = self.qfuse_l3(torch.cat([out_dif_enc_level3, out_tgt_enc_level3], dim=1))

        # ── ENCODER Level 4 (Latent) ──
        inp_tgt_enc_level4 = self.pos_l4(self.down3_4(out_tgt_enc_level3, tgt_mask_quarter))
        inp_dif_enc_level4 = self.pos_l4(self.down3_4(out_dif_enc_level3, dif_mask_quarter))
        inp_ref_enc_level4 = self.pos_l4(self.down3_4(out_ref_enc_level3, ref_mask_quarter))

        ref_latent_out = inp_ref_enc_level4
        for block in self.img_latent:
            ref_latent_out = block(ref_latent_out)
        kv_level4 = ref_latent_out

        tgt_latent_out = inp_tgt_enc_level4
        for block in self.map_latent:
            tgt_latent_out = block(tgt_latent_out, kv_level4)

        dif_latent_out = inp_dif_enc_level4
        for block in self.img_latent:
            dif_latent_out = block(dif_latent_out, kv_level4)

        latent_out = self.qfuse_l4(torch.cat([dif_latent_out, tgt_latent_out], dim=1))

        # ── DECODER ──
        # Only the fused dif stream is decoded; encoder outputs serve as skips.
        # NOTE(review): up4_3 is given dif_mask_tiny (the level-4 mask) —
        # presumably Upsample uses it at the source resolution; confirm.
        inp_dec_level3 = self.up4_3(latent_out, dif_mask_tiny)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_dif_enc_level3], 1)
        inp_dec_level3 = self.pos_d3(self.reduce_chan_level3(inp_dec_level3))
        out_dec_level3 = inp_dec_level3
        for block in self.decoder_level3:
            out_dec_level3 = block(out_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3, dif_mask_quarter)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_dif_enc_level2], 1)
        inp_dec_level2 = self.pos_d2(self.reduce_chan_level2(inp_dec_level2))
        out_dec_level2 = inp_dec_level2
        for block in self.decoder_level2:
            out_dec_level2 = block(out_dec_level2)

        # Level 1 keeps the concatenated 2*dim channels (no reduce conv).
        inp_dec_level1 = self.up2_1(out_dec_level2, dif_mask_half)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_dif_enc_level1], 1)
        inp_dec_level1 = self.pos_d1(inp_dec_level1)
        out_dec_level1 = inp_dec_level1
        for block in self.decoder_level1:
            out_dec_level1 = block(out_dec_level1)

        # tanh bounds the predicted map to [-1, 1].
        return torch.tanh(self.output(out_dec_level1))
pr_iqa/partial_map/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .feature_metric import FeatureMetric
2
+
3
+ __all__ = ["FeatureMetric"]
pr_iqa/partial_map/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (214 Bytes). View file
 
pr_iqa/partial_map/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (212 Bytes). View file
 
pr_iqa/partial_map/__pycache__/feature_metric.cpython-310.pyc ADDED
Binary file (8.03 kB). View file
 
pr_iqa/partial_map/__pycache__/feature_metric.cpython-38.pyc ADDED
Binary file (7.91 kB). View file
 
pr_iqa/partial_map/feature_metric.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FeatureMetric: DINOv2 + LoftUp feature-based quality metric.
3
+
4
+ Generates partial quality maps by:
5
+ 1. Extracting DINOv2 features (upsampled via LoftUp) from input images
6
+ 2. Using VGGT for monocular depth and pose estimation
7
+ 3. Constructing a colored 3D point cloud with features
8
+ 4. Rendering the point cloud from the target viewpoint via PyTorch3D
9
+ 5. Computing cosine similarity between rendered features and target features
10
+
11
+ Two modes:
12
+ - partial_generation=True: Full 3D pipeline → partial map + overlap mask
13
+ - partial_generation=False: Direct cosine similarity → total quality map
14
+
15
+ Dependencies (Level 1):
16
+ - VGGT (facebook/VGGT-1B)
17
+ - LoftUp (andrehuang/loftup)
18
+ - PyTorch3D
19
+ """
20
+
21
+ import sys
22
+ import torch
23
+ from torch import Tensor
24
+ from torch.nn import Module
25
+ import numpy as np
26
+ from typing import Optional, Tuple, Union
27
+ from pathlib import Path
28
+ from einops import rearrange
29
+
30
+ # Auto-detect submodule paths
31
+ _THIS_DIR = Path(__file__).resolve().parent
32
+ _REPO_ROOT = _THIS_DIR.parent.parent
33
+ _SUBMODULES = _REPO_ROOT / "submodules"
34
+
35
+ if (_SUBMODULES / "vggt").exists():
36
+ sys.path.insert(0, str(_SUBMODULES / "vggt"))
37
+ if (_SUBMODULES / "loftup").exists():
38
+ sys.path.insert(0, str(_SUBMODULES / "loftup"))
39
+
40
+ # Lazy imports for heavy dependencies — loaded on first use
41
+ _VGGT = None
42
+ _LOFTUP_FEATURIZERS = None
43
+ _LOFTUP_UPSAMPLERS = None
44
+ _PYTORCH3D = None
45
+
46
+
47
def _import_vggt():
    """Lazily import VGGT and memoise the handles in a module-level dict."""
    global _VGGT
    if _VGGT is not None:
        return _VGGT
    from vggt.models.vggt import VGGT as _V
    from vggt.utils.pose_enc import pose_encoding_to_extri_intri as _pe
    from vggt.utils.geometry import unproject_depth_map_to_point_map as _ud
    from vggt.utils.load_fn import load_and_preprocess_images as _lpi
    _VGGT = {
        "VGGT": _V,
        "pose_encoding_to_extri_intri": _pe,
        "unproject_depth_map_to_point_map": _ud,
        "load_and_preprocess_images": _lpi,
    }
    return _VGGT
58
+
59
+
60
def _import_loftup():
    """Lazily import LoftUp helpers; returns (get_featurizer, norm)."""
    global _LOFTUP_FEATURIZERS, _LOFTUP_UPSAMPLERS
    if _LOFTUP_FEATURIZERS is not None:
        return _LOFTUP_FEATURIZERS, _LOFTUP_UPSAMPLERS
    from featurizers import get_featurizer as _gf
    from upsamplers import norm as _n
    _LOFTUP_FEATURIZERS, _LOFTUP_UPSAMPLERS = _gf, _n
    return _LOFTUP_FEATURIZERS, _LOFTUP_UPSAMPLERS
68
+
69
+
70
def _import_pytorch3d():
    """Lazily import the PyTorch3D rendering pieces; returns a handle dict."""
    global _PYTORCH3D
    if _PYTORCH3D is not None:
        return _PYTORCH3D
    from pytorch3d.structures import Pointclouds
    from pytorch3d.renderer import (
        PointsRasterizationSettings,
        PointsRasterizer,
        AlphaCompositor,
    )
    from pytorch3d.renderer.camera_conversions import _cameras_from_opencv_projection
    # Keys intentionally mirror the imported names so call sites read naturally.
    _PYTORCH3D = {
        "Pointclouds": Pointclouds,
        "PointsRasterizationSettings": PointsRasterizationSettings,
        "PointsRasterizer": PointsRasterizer,
        "AlphaCompositor": AlphaCompositor,
        "_cameras_from_opencv_projection": _cameras_from_opencv_projection,
    }
    return _PYTORCH3D
88
+
89
+
90
class FeatureMetric(Module):
    """DINOv2 + LoftUp + VGGT → partial / total quality map.

    Args:
        img_size: Inference image size (controls rasterizer resolution).
        feature_backbone: Name of the feature backbone (default: ``"dinov2"``).
        loftup_torch_hub: Torch Hub repository for LoftUp.
        loftup_model_name: LoftUp model name.
        vggt_weights: HuggingFace model ID for VGGT.
        use_vggt: Load VGGT for depth/pose estimation.
        use_loftup: Load LoftUp for feature upsampling.
    """

    def __init__(
        self,
        img_size: int = 256,
        feature_backbone: str = "dinov2",
        loftup_torch_hub: Union[str, Path] = "andrehuang/loftup",
        loftup_model_name: Union[str, Path] = "loftup_dinov2s",
        vggt_weights: Union[str, Path] = "facebook/VGGT-1B",
        use_vggt: bool = True,
        use_loftup: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.img_size = img_size

        # Backbone is constructed by the LoftUp repo's factory; it also
        # reports the patch size and feature dimension.
        get_featurizer, _ = _import_loftup()
        self.feature_backbone, self.patch_size, self.dim = get_featurizer(feature_backbone)

        # Optional feature upsampler loaded from Torch Hub.
        self.upsampler = (
            torch.hub.load(loftup_torch_hub, loftup_model_name, pretrained=True)
            if use_loftup else None
        )
        self.use_loftup = use_loftup

        # NOTE(review): when use_vggt is False, self.vggt is never set, so
        # forward(partial_generation=True) would raise AttributeError.
        if use_vggt:
            vggt_mod = _import_vggt()
            self.vggt = vggt_mod["VGGT"].from_pretrained(vggt_weights)

        p3d = _import_pytorch3d()
        self.compositor = p3d["AlphaCompositor"]()

    def _render(self, point_clouds, **kwargs):
        """Render point cloud features to images.

        Returns:
            (images, zbuf): rendered feature images (K, H, W, C) and the
            rasterizer depth buffer (negative where no point projects).
        """
        # Rasterization is kept in full precision even under outer autocast.
        with torch.autocast("cuda", enabled=False):
            fragments = self.rasterizer(point_clouds, **kwargs)

            # Distance-based splat weights: 1 at the point center, 0 at radius.
            r = self.rasterizer.raster_settings.radius
            dists2 = fragments.dists.permute(0, 3, 1, 2)
            weights = 1 - dists2 / (r * r)

            images = self.compositor(
                fragments.idx.long().permute(0, 3, 1, 2),
                weights,
                point_clouds.features_packed().permute(1, 0),
                **kwargs,
            )
            images = images.permute(0, 2, 3, 1)
        return images, fragments.zbuf

    @torch.no_grad()
    def forward(
        self,
        device: str,
        images: Tensor,  # (K, 3, H, W)
        return_overlap_mask: bool = False,
        return_score_map: bool = False,
        return_projections: bool = False,
        partial_generation: bool = False,
        use_filtering: bool = False,
    ) -> Tuple[float, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """Compute quality score map.

        Args:
            device: Torch device string.
            images: (K, 3, H, W) input images. First image is the target.
            return_overlap_mask: Also return the per-reference overlap mask.
            return_score_map: Also return the dense score map.
            return_projections: Also return the rendered feature images.
            partial_generation: If True, use full 3D pipeline for partial map.
            use_filtering: Drop the 20% least-confident depth points before
                building the point cloud.

        Returns:
            (score_scalar, overlap_mask, score_map, projections)
        """
        k, c, h, w = images.shape
        p3d = _import_pytorch3d()
        _, norm_fn = _import_loftup()

        # Setup rasterizer (rebuilt each call so image_size tracks the input).
        raster_settings = p3d["PointsRasterizationSettings"](
            image_size=(h, w), radius=0.01, points_per_pixel=10, bin_size=0,
        )
        self.rasterizer = p3d["PointsRasterizer"](cameras=None, raster_settings=raster_settings)

        # Extract per-image features, flattened to (1, H*W, C) each.
        images_norm = norm_fn(images)
        hr_feats = []
        for i in range(k):
            img = images_norm[i:i + 1]
            lr_feat = self.feature_backbone(img)
            if self.use_loftup and self.upsampler is not None:
                hr_feat = self.upsampler(lr_feat, img)
            else:
                hr_feat = lr_feat
            hr_feat = rearrange(hr_feat, "b c h w -> b (h w) c")
            hr_feats.append(hr_feat)
        hr_feats = torch.cat(hr_feats, dim=0)

        if not partial_generation:
            # Fast cosine similarity mode: compare target (index 0) against
            # the first reference (index 1) directly in feature space.
            dot = (hr_feats[0] * hr_feats[1]).sum(dim=1)
            tgt_norm = torch.linalg.norm(hr_feats[0], dim=1)
            ref_norm = torch.linalg.norm(hr_feats[1], dim=1)
            cosine_sim = dot / (tgt_norm * ref_norm + 1e-8)
            score_map = torch.clamp(cosine_sim, min=0.0, max=1.0)

            # Without LoftUp, features live on the patch grid, not pixels.
            if self.use_loftup and self.upsampler is not None:
                H_out, W_out = h, w
            else:
                H_out = h // self.patch_size
                W_out = w // self.patch_size
            score_map = score_map.reshape(H_out, W_out).unsqueeze(0)
            return score_map.mean().item(), None, score_map if return_score_map else None, None

        # Full 3D partial map generation
        vggt_mod = _import_vggt()
        pose_encoding_to_extri_intri = vggt_mod["pose_encoding_to_extri_intri"]
        unproject_depth_map_to_point_map = vggt_mod["unproject_depth_map_to_point_map"]

        preds = self.vggt(images)
        extrinsic, intrinsic = pose_encoding_to_extri_intri(preds["pose_enc"], images.shape[-2:])
        depth, depth_conf = preds["depth"], preds["depth_conf"]

        # Lift every pixel of every view into world-space 3D points.
        point_map = unproject_depth_map_to_point_map(
            depth.squeeze(0), extrinsic.squeeze(0), intrinsic.squeeze(0),
        )
        # NOTE(review): `cols` (normalized image colors) is computed but never
        # used below — appears to be leftover from a colored-cloud debug path.
        cols = images.cpu().numpy().transpose(0, 2, 3, 1)
        cols = cols / cols.max()
        pts_flatten = torch.from_numpy(
            rearrange(point_map, "k h w c -> k (h w) c")
        ).float().to(device)

        if use_filtering:
            # Keep only points above the 20th percentile of depth confidence.
            percent = 20
            quantile = torch.quantile(depth_conf, percent / 100.0)
            mask_flat = rearrange((depth_conf > quantile).squeeze(0), "k h w -> k (h w)")
            points_list, features_list = [], []
            for i in range(k):
                valid = mask_flat[i]
                points_list.append(pts_flatten[i][valid])
                features_list.append(hr_feats[i][valid])
            point_clouds = p3d["Pointclouds"](points=points_list, features=features_list)
        else:
            point_clouds = p3d["Pointclouds"](points=pts_flatten, features=hr_feats)

        # Render every per-view cloud from the *target* camera (index 0).
        extrinsic, intrinsic = pose_encoding_to_extri_intri(preds["pose_enc"], images.shape[-2:])
        E, K = extrinsic.squeeze(0), intrinsic.squeeze(0)
        R0, T0, K0 = E[0, :3, :3], E[0, :3, 3], K[0]
        B = pts_flatten.shape[0]

        R_repeat = R0.unsqueeze(0).repeat(B, 1, 1)
        T_repeat = T0.unsqueeze(0).repeat(B, 1)
        K_repeat = K0.unsqueeze(0).repeat(B, 1, 1)
        im_size = torch.tensor([[h, w]]).repeat(B, 1).to(device)

        cameras_p3d = p3d["_cameras_from_opencv_projection"](R_repeat, T_repeat, K_repeat, im_size)

        with torch.autocast("cuda", enabled=False):
            # Large negative background so unrendered pixels get near-zero
            # cosine similarity after clamping.
            bg_color = torch.tensor(
                [-10000] * hr_feats[0].shape[-1], dtype=torch.float32, device=device,
            )
            rendering, zbuf = self._render(point_clouds, cameras=cameras_p3d, background_color=bg_color)
            rendering = rearrange(rendering, "k h w c -> k c h w")

        # Cosine similarity score map: rendered target vs. each rendered ref.
        target = rendering[0:1]
        reference = rendering[1:]
        dot = (reference * target).sum(dim=1)
        tgt_norm = torch.linalg.norm(target, dim=1)
        ref_norm = torch.linalg.norm(reference, dim=1)
        cosine_sim = dot / (tgt_norm * ref_norm + 1e-8)
        score_map = torch.clamp(cosine_sim, min=0.0, max=1.0)

        # Mask true background: pixels covered by neither the target's own
        # points nor any reference's points (zbuf < 0 means no point hit).
        target_mask = zbuf[0, ..., 0] >= 0
        reference_mask = zbuf[1:, ..., 0] >= 0
        true_bg = ~target_mask & ~torch.any(reference_mask, dim=0)
        score_map[:, true_bg] = 0.0

        # Per-reference overlap: where that reference's cloud projects.
        overlap_mask = zbuf[1:, ..., 0] >= 0

        return (
            score_map.mean().item(),
            overlap_mask if return_overlap_mask else None,
            score_map if return_score_map else None,
            rendering if return_projections else None,
        )
pr_iqa/transforms.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data transforms and batch preparation utilities for PR-IQA training.
3
+
4
+ ImageNet normalization is applied to RGB inputs.
5
+ Grayscale inputs (partial maps, masks) are kept in [0, 1].
6
+ """
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+
12
+
13
+ # ImageNet normalization constants
14
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
15
+ IMAGENET_STD = (0.229, 0.224, 0.225)
16
+
17
+
18
+ def build_rgb_transform(img_size: int = 256) -> T.Compose:
19
+ """Transform for RGB images: resize → tensor → ImageNet normalize."""
20
+ return T.Compose([
21
+ T.Resize((img_size, img_size)),
22
+ T.ToTensor(),
23
+ T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
24
+ ])
25
+
26
+
27
+ def build_grey_transform(img_size: int = 256) -> T.Compose:
28
+ """Transform for grayscale images (maps/masks): resize → tensor [0,1]."""
29
+ return T.Compose([
30
+ T.Resize((img_size, img_size)),
31
+ T.ToTensor(),
32
+ ])
33
+
34
+
35
+ def make_pyramid_masks(mask_whole: torch.Tensor):
36
+ """Build 3 downscaled masks from (B, 1, H, W) → half, quarter, tiny."""
37
+ mask_half = F.interpolate(mask_whole, scale_factor=0.5, mode="nearest")
38
+ mask_quarter = F.interpolate(mask_whole, scale_factor=0.25, mode="nearest")
39
+ mask_tiny = F.interpolate(mask_whole, scale_factor=0.125, mode="nearest")
40
+ return mask_half, mask_quarter, mask_tiny
41
+
42
+
43
+ def prepare_batch(batch: dict, device: torch.device):
44
+ """Prepare a training batch for the PR-IQA model.
45
+
46
+ Takes a dataset batch dict and returns (model_args, gt) where
47
+ model_args is a tuple of 15 tensors matching PRIQA.forward() signature.
48
+
49
+ Returns:
50
+ model_args: (tgt_img, dif_img, ref_img, + 12 mask tensors)
51
+ gt: (B, 1, H, W) ground truth quality map
52
+ """
53
+ dtype = torch.bfloat16
54
+
55
+ dif_img = batch["tgt_diff"].to(device, dtype=dtype, non_blocking=True,
56
+ memory_format=torch.channels_last)
57
+ tgt_mask_whole = batch["partial_mask"].to(device, dtype=dtype, non_blocking=True,
58
+ memory_format=torch.channels_last)
59
+ tgt_img_1ch = batch["partial_map"].to(device, dtype=dtype, non_blocking=True,
60
+ memory_format=torch.channels_last)
61
+ tgt_img = tgt_img_1ch.repeat(1, 3, 1, 1)
62
+ ref_img = batch["current_ref"].to(device, dtype=dtype, non_blocking=True,
63
+ memory_format=torch.channels_last)
64
+ gt = batch["full_map"].to(device, dtype=dtype, non_blocking=True,
65
+ memory_format=torch.channels_last)
66
+
67
+ tgt_mask_half, tgt_mask_quarter, tgt_mask_tiny = make_pyramid_masks(tgt_mask_whole)
68
+
69
+ ones = torch.ones_like
70
+ dif_mask_whole = ones(tgt_mask_whole)
71
+ dif_mask_half = ones(tgt_mask_half)
72
+ dif_mask_quarter = ones(tgt_mask_quarter)
73
+ dif_mask_tiny = ones(tgt_mask_tiny)
74
+
75
+ ref_mask_whole = ones(tgt_mask_whole)
76
+ ref_mask_half = ones(tgt_mask_half)
77
+ ref_mask_quarter = ones(tgt_mask_quarter)
78
+ ref_mask_tiny = ones(tgt_mask_tiny)
79
+
80
+ model_args = (
81
+ tgt_img, dif_img, ref_img,
82
+ tgt_mask_whole, tgt_mask_half, tgt_mask_quarter, tgt_mask_tiny,
83
+ dif_mask_whole, dif_mask_half, dif_mask_quarter, dif_mask_tiny,
84
+ ref_mask_whole, ref_mask_half, ref_mask_quarter, ref_mask_tiny,
85
+ )
86
+ return model_args, gt