Ali Mohsin
feat: Add virtual try-on system components including DensePose, SMPL, and pix2pixHD models, rendering, and utilities.
5db43ff
import torch
import torch.nn as nn
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    The input and the previous hidden state are concatenated along the
    channel axis (dim=1) and fed through one 2-D convolution that produces
    the pre-activations of all four gates (input, forget, output, candidate)
    in a single pass.

    Args:
        input_channels: Number of channels of the input ``x``.
        hidden_channels: Number of channels of the hidden and cell states.
        kernel_size: Positive odd integer kernel size. Padding of
            ``kernel_size // 2`` keeps the spatial size unchanged, which is
            required so the new gates line up with ``c_prev``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            # With an even kernel, padding=kernel_size // 2 makes the conv
            # output one pixel larger than its input, so the gates would no
            # longer match c_prev and forward() would fail with a shape error.
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        padding = kernel_size // 2
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        # One convolution computes all four gates; its output is split into
        # four equal chunks along the channel axis in forward().
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x, h_prev, c_prev):
        """Advance the cell by one step.

        Args:
            x: Input tensor, presumably (batch, input_channels, H, W).
            h_prev: Previous hidden state, (batch, hidden_channels, H, W).
            c_prev: Previous cell state, same shape as ``h_prev``.

        Returns:
            Tuple ``(h, c)`` with the new hidden and cell states.
        """
        combined = torch.cat([x, h_prev], dim=1)  # concatenate along channel axis
        conv_output = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.chunk(conv_output, 4, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        c = f * c_prev + i * g
        h = o * torch.tanh(c)
        return h, c
class ConvLSTM(nn.Module):
    """Stack of ConvLSTMCell layers whose hidden state persists between calls.

    The per-layer hidden/cell states are stored on the module as lists
    (``self.h`` / ``self.c``). They are created only when ``self.c`` is
    ``None``, so every call to :meth:`forward` or :meth:`inference_forward`
    continues from the state left by the previous call. Call :meth:`reset`
    to start a new sequence, and in particular before feeding a different
    batch size or spatial size, which the stored state would not match.

    NOTE(review): the carried-over state is never detached, so calling
    ``backward()`` on losses from two consecutive training calls will likely
    try to backprop through the first call's already-freed graph. Call
    ``reset()`` (or detach the state) between training sequences.

    Args:
        input_channels: Channels of each input frame.
        output__channels: Channels of the hidden state and of the output.
            The odd double-underscore name is kept for backward compatibility.
        kernel_size: Kernel size used by every cell.
        num_layers: Number of stacked cells.
        hidden_channels: Keyword alias for ``output__channels``.

    Raises:
        TypeError: If neither ``output__channels`` nor ``hidden_channels``
            is given.
        ValueError: If both are given with different values.
    """

    def __init__(self, input_channels, output__channels=None, kernel_size=5,
                 num_layers=2, hidden_channels=None):
        super(ConvLSTM, self).__init__()
        if output__channels is None:
            output__channels = hidden_channels
        elif hidden_channels is not None and hidden_channels != output__channels:
            raise ValueError(
                f"output__channels ({output__channels}) and hidden_channels "
                f"({hidden_channels}) disagree"
            )
        if output__channels is None:
            raise TypeError(
                "ConvLSTM() missing required argument 'output__channels' "
                "(or 'hidden_channels')"
            )
        self.num_layers = num_layers
        self.hidden_channels = output__channels
        self.cells = nn.ModuleList()
        for i in range(num_layers):
            # Only the first layer sees the raw input; deeper layers consume
            # the hidden state of the layer below.
            in_channels = input_channels if i == 0 else output__channels
            self.cells.append(ConvLSTMCell(in_channels, output__channels, kernel_size))
        self.c = None
        self.h = None

    def _step(self, x):
        """Push one frame through every layer, updating self.h / self.c in place."""
        for i, cell in enumerate(self.cells):
            self.h[i], self.c[i] = cell(x, self.h[i], self.c[i])
            x = self.h[i]
        return x

    def forward(self, input_seq):
        """Run the stack over a whole sequence.

        Args:
            input_seq: Tensor of shape (batch, seq_len, channels, height, width).

        Returns:
            Tensor of shape (batch, seq_len, hidden_channels, height, width)
            holding the last layer's hidden state at every time step.
        """
        batch_size, seq_len, _, height, width = input_seq.size()
        if self.c is None:
            self.h, self.c = self.init_hidden(batch_size, height, width)
        outputs = [self._step(input_seq[:, t]) for t in range(seq_len)]
        return torch.stack(outputs, dim=1)

    def inference_forward(self, input_frame):
        """Advance the stack by a single time step.

        Args:
            input_frame: Tensor of shape (batch, channels, height, width).

        Returns:
            The last layer's new hidden state, shape
            (batch, hidden_channels, height, width).
        """
        batch_size, _, height, width = input_frame.size()
        if self.c is None:
            self.h, self.c = self.init_hidden(batch_size, height, width)
        # Previously this returned ``x[-1].unsqueeze(0)``, which picks the last
        # *batch element*; that equals the full result only for batch size 1.
        return self._step(input_frame)

    def reset(self):
        """Drop the stored state so the next call starts from zeros."""
        self.c = None
        self.h = None

    def init_hidden(self, batch_size, height, width):
        """Return zero-initialised per-layer (h, c) lists.

        The tensors are created directly on the device and with the dtype of
        the module's parameters, so half/double-precision models get state of
        the matching dtype. A stack without parameters falls back to torch's
        default device and dtype.
        """
        ref = next(self.parameters(), None)
        factory = {} if ref is None else {"device": ref.device, "dtype": ref.dtype}
        shape = (batch_size, self.hidden_channels, height, width)
        h = [torch.zeros(shape, **factory) for _ in range(self.num_layers)]
        c = [torch.zeros(shape, **factory) for _ in range(self.num_layers)]
        return h, c
if __name__ == '__main__':
    # Dummy input: batch of 5 sequences, each with 10 frames of 1-channel 64x64 images.
    input_tensor = torch.randn(5, 10, 1, 64, 64)
    # One follow-up frame for each of the same 5 sequences.
    input_tensor2 = torch.randn(5, 1, 64, 64)
    # The constructor's hidden-size parameter is named ``output__channels``.
    model = ConvLSTM(input_channels=1, output__channels=16, kernel_size=3, num_layers=1)
    output = model(input_tensor)
    print(output.shape)  # Expected: (5, 10, 16, 64, 64)
    # inference_forward continues from the hidden state left by the call above.
    output2 = model.inference_forward(input_tensor2)
    print(output2.shape)