Adapters
Afar
chemistry
AI_Text_to_Image / maindata
dghdgkl's picture
Create maindata
6a3f4a1 verified
Raw
History Blame Contribute Delete
1.5 kB
# Install dependencies first: pip install torch diffusers transformers datasets wandb
import torch
import torch.nn as nn
from torch.nn import functional as F
# Define a basic U-Net style model (you can scale this up for an XL model)
class UNetModel(nn.Module):
    """Small U-Net-style encoder/decoder built from 3x3 convolution blocks.

    The encoder blocks (``enc1``..``enc3``) and the ``middle`` block each
    halve the spatial resolution with a 2x2 max-pool while widening the
    channels (``base_channels`` -> x2 -> x4 -> x8).  The decoder blocks
    (``dec3``..``dec1``) narrow the channels again and do not pool; their
    outputs are resized in ``forward`` so they can be added to the encoder
    feature maps of the same width.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned tensor.
        base_channels: Channel width of the first encoder block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        base_channels: int = 64,
    ) -> None:
        super().__init__()
        # Downsample: every encoder block ends in a 2x2 max-pool.
        self.enc1 = self.conv_block(in_channels, base_channels)
        self.enc2 = self.conv_block(base_channels, base_channels * 2)
        self.enc3 = self.conv_block(base_channels * 2, base_channels * 4)
        # Middle (bottleneck), also pooled.
        self.middle = self.conv_block(base_channels * 4, base_channels * 8)
        # Upsample path: decoder blocks must NOT pool, otherwise their
        # outputs keep shrinking and can never be added to the skips.
        self.dec3 = self.conv_block(
            base_channels * 8, base_channels * 4, pool=False
        )
        self.dec2 = self.conv_block(
            base_channels * 4, base_channels * 2, pool=False
        )
        self.dec1 = self.conv_block(base_channels * 2, out_channels, pool=False)

    def conv_block(
        self, in_channels: int, out_channels: int, *, pool: bool = True
    ) -> nn.Sequential:
        """Build a ``Conv3x3 -> ReLU -> Conv3x3 -> ReLU [-> MaxPool2d(2)]`` stack.

        Args:
            in_channels: Channels consumed by the first convolution.
            out_channels: Channels produced by both convolutions.
            pool: When true (the default), append a 2x2 max-pool that halves
                the spatial size.  Pass ``False`` for decoder blocks.

        Returns:
            The layers wrapped in an ``nn.Sequential``.  The convolutions sit
            at indices 0 and 2 regardless of ``pool``.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        if pool:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    @staticmethod
    def _resize_like(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Resize ``tensor`` to the height/width of ``reference`` (nearest)."""
        return F.interpolate(tensor, size=reference.shape[-2:], mode="nearest")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder.

        Args:
            x: Batched input of shape ``(N, in_channels, H, W)``.  ``H`` and
                ``W`` must be at least 16 so that the four max-pools do not
                shrink a dimension to zero.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.  Values are
            non-negative because ``dec1`` ends in a ReLU.
        """
        # Encode (Downsample)
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        # Middle block
        x_middle = self.middle(x3)
        # Decode (Upsample): resize to the skip tensor's exact size rather
        # than by a fixed factor, because max-pool floors odd dimensions.
        x3_dec = self._resize_like(self.dec3(x_middle), x3)
        x2_dec = self._resize_like(self.dec2(x3_dec + x3), x2)
        # x1 only feeds enc2; there is no skip connection at that level.
        x1_dec = self.dec1(x2_dec + x2)
        # Restore the input resolution.
        return self._resize_like(x1_dec, x)