# llir / network / Math_Module.py
# Uploaded by linxin02 — "Upload portable Low_light_rainy_new code export" (commit 4336727, verified)
import torch
import torch.nn as nn
from torchvision.transforms import Grayscale
class P(nn.Module):
    r"""Closed-form update for the map ``P``.

    Solves, element-wise, the regularised least-squares problem

        min_P  ||I - P * Q||^2 + gamma * ||P - R||^2

    Setting the gradient to zero, -Q * (I - P * Q) + gamma * (P - R) = 0,
    yields the minimiser

        P* = (I * Q + gamma * R) / (Q * Q + gamma)

    The module registers no learnable parameters; it only evaluates this
    formula.
    """

    def __init__(self):
        super().__init__()

    def forward(self, I, Q, R, gamma):
        """Return ``(I * Q + gamma * R) / (gamma + Q * Q)``.

        Args:
            I: observed tensor.
            Q: multiplicative factor; must broadcast against ``I``
                (presumably a single-channel map — TODO confirm with caller).
            R: prior for ``P``; must broadcast against ``I * Q``.
            gamma: regularisation weight (Python scalar or broadcastable
                tensor).

        Returns:
            The element-wise minimiser ``P*`` (broadcast shape of the inputs).

        NOTE(review): where both ``gamma`` and ``Q`` are 0 the denominator is
        0 and the result is NaN/inf; callers presumably pass ``gamma > 0``.
        """
        data_term = I * Q
        prior_term = gamma * R
        denominator = gamma + Q * Q
        return (data_term + prior_term) / denominator
class Q(nn.Module):
    r"""Closed-form update for the single-channel map ``Q``.

    Solves the regularised least-squares problem

        min_Q  ||I - P * Q||^2 + lamda * ||Q - L||^2

    where one ``Q`` is shared by the first three channels (dim 1) of ``I``
    and ``P``. Zeroing the gradient sums the data term over those channels:

        Q* = (sum_c I_c * P_c + lamda * L) / (sum_c P_c * P_c + lamda)

    The module registers no learnable parameters; it only evaluates this
    formula.
    """

    def __init__(self):
        super().__init__()

    def forward(self, I, P, L, lamda):
        """Return the channel-summed minimiser ``Q*``.

        Args:
            I: observed 4-D tensor; only channels 0, 1, 2 of dim 1 are read,
                any further channels are ignored.
            P: 4-D tensor indexed the same way as ``I``.
            L: prior for ``Q``; must broadcast against a single-channel slice.
            lamda: regularisation weight (Python scalar or broadcastable
                tensor).

        Returns:
            Tensor with a size-1 channel axis holding ``Q*``.

        NOTE(review): where ``lamda`` is 0 and all three ``P`` channels are 0
        the denominator is 0; callers presumably pass ``lamda > 0``.
        """
        # Slice with c:c + 1 (not c) so every channel keeps its dim-1 axis.
        i_r, i_g, i_b = (I[:, c:c + 1, :, :] for c in range(3))
        p_r, p_g, p_b = (P[:, c:c + 1, :, :] for c in range(3))

        # Same left-to-right accumulation order as the summed formula above.
        correlation = i_r * p_r + i_g * p_g + i_b * p_b
        p_energy = p_r * p_r + p_g * p_g + p_b * p_b
        return (correlation + lamda * L) / (p_energy + lamda)