File size: 972 Bytes
4336727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
from torchvision.transforms import Grayscale


class P(nn.Module):
    """Closed-form solution of the ``P`` subproblem.

    Solves, with ``*`` the elementwise product,

        argmin_P  ||I - P * Q||^2 + gamma * ||P - R||^2

    This is a per-element least-squares problem. Setting its gradient to zero,
    ``-Q * (I - P * Q) + gamma * (P - R) = 0``, gives

        P* = (I * Q + gamma * R) / (Q * Q + gamma)

    The module holds no learnable parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, I, Q, R, gamma):
        """Return the minimiser ``P*`` of the objective in the class docstring.

        Args:
            I: target tensor of the data term.
            Q: tensor multiplied with ``P`` in the data term; must broadcast
                against ``I`` (presumably a single-channel map broadcast over
                the channels of ``I`` — TODO confirm against callers).
            R: prior that ``P`` is pulled towards; broadcastable against ``I``.
            gamma: weight of the ``||P - R||^2`` term (scalar or tensor).

        Returns:
            ``(I * Q + gamma * R) / (gamma + Q * Q)``, computed elementwise.
            The denominator is not guarded, so ``gamma == 0`` where ``Q == 0``
            yields inf/nan.
        """
        numerator = I * Q + gamma * R
        denominator = gamma + Q * Q
        return numerator / denominator
        
class Q(nn.Module):
    """Closed-form solution of the single-channel ``Q`` subproblem.

    Solves

        argmin_Q  sum_c ||I_c - P_c * Q||^2 + lamda * ||Q - L||^2

    where ``c`` runs over the first three slices along dim 1 of ``I`` and
    ``P``, and ``*`` is the elementwise product. Setting the gradient to zero
    gives

        Q* = (sum_c I_c * P_c + lamda * L) / (sum_c P_c * P_c + lamda)

    Any slices of ``I``/``P`` beyond index 2 on dim 1 are ignored. The module
    holds no learnable parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, I, P, L, lamda):
        """Return the minimiser ``Q*`` of the objective in the class docstring.

        Args:
            I: 4-D tensor (indexed as ``[:, c:c+1, :, :]``); only dim-1
                slices 0, 1, 2 are used (presumably R, G, B — TODO confirm).
            P: 4-D tensor indexed the same way as ``I``.
            L: prior that ``Q`` is pulled towards; must broadcast against the
                single-slice ``[:, c:c+1, :, :]`` views.
            lamda: weight of the ``||Q - L||^2`` term (scalar or tensor).

        Returns:
            A tensor with a size-1 dim 1 (after broadcasting with ``L``).
            The denominator is not guarded, so ``lamda == 0`` where all three
            ``P`` slices are zero yields inf/nan.
        """
        # Seed both sums with slice 0, then add slices 1 and 2 in order so the
        # floating-point summation order is exactly c0 + c1 + c2.
        i_first = I[:, 0:1, :, :]
        p_first = P[:, 0:1, :, :]
        cross = i_first * p_first
        energy = p_first * p_first
        for c in (1, 2):
            i_c = I[:, c:c + 1, :, :]
            p_c = P[:, c:c + 1, :, :]
            cross = cross + i_c * p_c
            energy = energy + p_c * p_c
        return (cross + lamda * L) / (energy + lamda)