# Author: Ali Mohsin
# feat: Add virtual try-on system components including DensePose, SMPL, and
#       pix2pixHD models, rendering, and utilities.
# Commit: 5db43ff
import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):
    """One convolutional LSTM cell.

    A single 2D convolution over the channel-concatenated input and previous
    hidden state emits the pre-activations of all four LSTM gates at once.
    Spatial size is preserved via "same" padding (assumes an odd kernel size).
    """

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        # 4 * hidden_channels output maps: input, forget, output and cell gates.
        self.conv = nn.Conv2d(
            in_channels=input_channels + hidden_channels,
            out_channels=4 * hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,  # "same" spatial size for odd kernels
        )

    def forward(self, x, h_prev, c_prev):
        """Run one step.

        x:      (batch, input_channels, H, W) current input
        h_prev: (batch, hidden_channels, H, W) previous hidden state
        c_prev: (batch, hidden_channels, H, W) previous cell state
        Returns the next (hidden, cell) state pair, both
        (batch, hidden_channels, H, W).
        """
        stacked = torch.cat((x, h_prev), dim=1)  # join along channel axis
        gates = self.conv(stacked)
        in_gate, forget_gate, out_gate, cell_gate = torch.chunk(gates, 4, dim=1)
        in_gate = torch.sigmoid(in_gate)
        forget_gate = torch.sigmoid(forget_gate)
        out_gate = torch.sigmoid(out_gate)
        cell_gate = torch.tanh(cell_gate)
        c_next = forget_gate * c_prev + in_gate * cell_gate
        h_next = out_gate * torch.tanh(c_next)
        return h_next, c_next
class ConvLSTM(nn.Module):
    """Stacked convolutional LSTM that keeps its hidden state across calls.

    State is created lazily on the first forward pass and persists until
    `reset()` is called, so successive `inference_forward` calls run the
    recurrence frame by frame.
    NOTE(review): state is carried without `.detach()`, so calling forward
    twice and backpropagating will try to reuse the first call's graph —
    confirm intended usage before training across calls.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size=5, num_layers=2):
        # BUG FIX: parameter was the typo `output__channels`; the call site
        # instantiates with `hidden_channels=...`, which raised a TypeError.
        super(ConvLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.cells = nn.ModuleList()
        for i in range(num_layers):
            # First layer consumes the input; deeper layers consume hidden maps.
            in_channels = input_channels if i == 0 else hidden_channels
            self.cells.append(ConvLSTMCell(in_channels, hidden_channels, kernel_size))
        # Per-layer persistent state; None until the first forward pass.
        self.c = None
        self.h = None

    def forward(self, input_seq):
        """Process a whole sequence.

        input_seq: (batch, seq_len, channels, height, width)
        Returns (batch, seq_len, hidden_channels, height, width) — the
        top-layer hidden state at every time step.
        """
        batch_size, seq_len, _, height, width = input_seq.size()
        if self.c is None:
            self.h, self.c = self.init_hidden(batch_size, height, width)
        outputs = []
        for t in range(seq_len):
            x = input_seq[:, t]
            for i, cell in enumerate(self.cells):
                self.h[i], self.c[i] = cell(x, self.h[i], self.c[i])
                x = self.h[i]  # feed this layer's hidden state to the next layer
            outputs.append(self.h[-1])
        return torch.stack(outputs, dim=1)

    def inference_forward(self, input_frame):
        """Process a single frame using (and updating) the persistent state.

        input_frame: (batch, channels, height, width)
        Returns (batch, hidden_channels, height, width) — the top-layer
        hidden state after this step.
        """
        batch_size, _, height, width = input_frame.size()
        if self.c is None:
            self.h, self.c = self.init_hidden(batch_size, height, width)
        x = input_frame
        for i, cell in enumerate(self.cells):
            self.h[i], self.c[i] = cell(x, self.h[i], self.c[i])
            x = self.h[i]
        # BUG FIX: previously returned x[-1].unsqueeze(0), which kept only the
        # LAST batch sample (shape (1, C, H, W)). Return the full batch.
        return x

    def reset(self):
        """Drop all persistent state; the next call re-initializes it."""
        self.c = None
        self.h = None

    def init_hidden(self, batch_size, height, width):
        """Return zeroed per-layer (h, c) lists on the model's device."""
        device = next(self.parameters()).device
        h = [torch.zeros(batch_size, self.hidden_channels, height, width, device=device)
             for _ in range(self.num_layers)]
        c = [torch.zeros(batch_size, self.hidden_channels, height, width, device=device)
             for _ in range(self.num_layers)]
        return h, c
if __name__ == '__main__':
    # Smoke test: batch of 5 sequences, each with 10 frames of 1-channel 64x64 images.
    input_tensor = torch.randn(5, 10, 1, 64, 64)
    # A single frame for the step-by-step inference path.
    input_tensor2 = torch.randn(5, 1, 64, 64)
    model = ConvLSTM(input_channels=1, hidden_channels=16, kernel_size=3, num_layers=1)
    output = model(input_tensor)
    print(output.shape)  # Expected: (5, 10, 16, 64, 64)
    # BUG FIX: the expected-shape comment previously claimed the sequence
    # shape for this single-frame result; a per-frame output cannot have a
    # seq_len axis.
    output2 = model.inference_forward(input_tensor2)
    print(output2.shape)