| import torch | |
| import torch.nn as nn | |
| from torchvision.transforms import Grayscale | |
class P(nn.Module):
    r"""Closed-form update for P in an alternating least-squares scheme.

    Solves, element-wise,

        min_P ||I - P * Q||^2 + gamma * ||P - R||^2

    Differentiating with respect to P and setting the result to zero gives
    the least-squares solution

        P* = (I * Q + gamma * R) / (Q * Q + gamma)

    All operations are element-wise, so the usual PyTorch broadcasting rules
    apply between I, Q, R and gamma. The module has no learnable parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, I, Q, R, gamma):
        """Return the minimizer P* for the given I, Q, prior R and weight gamma."""
        numerator = I * Q + gamma * R
        # NOTE(review): with gamma == 0 this divides by zero wherever Q == 0;
        # callers presumably pass gamma > 0 — confirm.
        denominator = Q * Q + gamma
        return numerator / denominator
class Q(nn.Module):
    r"""Closed-form update for Q in an alternating least-squares scheme.

    Solves, per spatial position,

        min_Q ||I - P * Q||^2 + lamda * ||Q - L||^2

    where the data term runs over every channel c of I and P, and Q has a
    single channel. Differentiating with respect to Q and setting the result
    to zero gives

        Q* = (sum_c I_c * P_c + lamda * L) / (sum_c P_c * P_c + lamda)

    Channel sums are taken over dim 1, so I and P are expected to be laid out
    as (batch, channels, ...). Any channel count works; for 3 channels this is
    the explicit R/G/B expansion.
    """

    def __init__(self):
        super().__init__()

    def forward(self, I, P, L, lamda):
        """Return the minimizer Q* with a size-1 channel dimension.

        Args:
            I: observed tensor, channels along dim 1.
            P: current P estimate, broadcastable against I.
            L: prior for Q, broadcastable against a single-channel map.
            lamda: regularization weight (scalar or broadcastable tensor).

        Returns:
            (sum_c I_c P_c + lamda * L) / (sum_c P_c^2 + lamda), with dim 1
            reduced to size 1 (before broadcasting with L).
        """
        # Reduce over the channel axis instead of slicing R/G/B by hand: the
        # hand-written slices returned an empty tensor for inputs with fewer
        # than three channels and ignored any channel past the third.
        numerator = (I * P).sum(dim=1, keepdim=True) + lamda * L
        # NOTE(review): with lamda == 0 this divides by zero wherever P is
        # all-zero across channels; callers presumably pass lamda > 0.
        denominator = (P * P).sum(dim=1, keepdim=True) + lamda
        return numerator / denominator