Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| # import pandas as pd | |
| # import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import io | |
| import numpy as np | |
| import logging | |
| import os | |
| from datetime import datetime | |
| from torch.utils.data import DataLoader, Dataset | |
| from torchvision import transforms | |
| from feed_forward_nn import feedforward | |
| from positional_encoding import Positional_Encoding | |
| from multihead_attention import MultiHeadAttention | |
| d_model = 768 # main model dimension | |
| num_heads = 8 # number of heads | |
| d_ff = 2048 # feedforward hidden dimension | |
| seq_len = 128 # max input length | |
| vocab_size = 30000 | |
| class FoodDataset(Dataset): | |
| def __init__(self, dataframe, transform=None): | |
| self.df = dataframe.reset_index(drop=True) | |
| self.transform = transform | |
| def __len__(self): | |
| return len(self.df) | |
| def __getitem__(self, idx): | |
| row = self.df.iloc[idx] | |
| # ----- Image Bytes → PIL ----- | |
| image_bytes = row["image"]["bytes"] | |
| image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| # ----- Transform ----- | |
| if self.transform: | |
| image = self.transform(image) | |
| return image | |
| # class MultiParquetFoodDataset(Dataset): | |
| # def __init__(self, parquet_files, transform=None): | |
| # self.parquet_files = parquet_files | |
| # self.transform = transform | |
| # self.file_dfs = [pd.read_parquet(f) for f in parquet_files] | |
| # # total indexing mapping | |
| # self.index_map = [] | |
| # for file_idx, df in enumerate(self.file_dfs): | |
| # for row_idx in range(len(df)): | |
| # self.index_map.append((file_idx, row_idx)) | |
| # def __len__(self): | |
| # return len(self.index_map) | |
| # def __getitem__(self, idx): | |
| # file_idx, row_idx = self.index_map[idx] | |
| # row = self.file_dfs[file_idx].iloc[row_idx] | |
| # image_bytes = row['image']['bytes'] | |
| # image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| # if self.transform: | |
| # image = self.transform(image) | |
| # return image | |
| class PatchEmbedding(nn.Module): | |
| def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768): | |
| super().__init__() | |
| self.patch_size = patch_size | |
| self.num_patches = (img_size // patch_size) ** 2 | |
| self.proj = nn.Linear( | |
| patch_size * patch_size * in_channels, | |
| embed_dim | |
| ) | |
| def forward(self, x): | |
| B, C, H, W = x.shape | |
| # ---- Image → patches ---- | |
| x = x.unfold(2, self.patch_size, self.patch_size) | |
| x = x.unfold(3, self.patch_size, self.patch_size) | |
| # shape: B, C, num_patch_h, num_patch_w, P, P | |
| x = x.contiguous().view(B, C, -1, self.patch_size, self.patch_size) | |
| x = x.permute(0, 2, 1, 3, 4) | |
| # shape: B, num_patches, C, P, P | |
| x = x.flatten(2) | |
| # shape: B, num_patches, patch_dim | |
| x = self.proj(x) | |
| return x | |
| def random_masking(x, mask_ratio=0.75): | |
| """ | |
| x: [B, N, D] | |
| """ | |
| B, N, D = x.shape | |
| len_keep = int(N * (1 - mask_ratio)) | |
| # ---- random noise generate ---- | |
| noise = torch.rand(B, N, device=x.device) | |
| # ---- shuffle indices ---- | |
| ids_shuffle = torch.argsort(noise, dim=1) | |
| ids_restore = torch.argsort(ids_shuffle, dim=1) | |
| # ---- keep first tokens ---- | |
| ids_keep = ids_shuffle[:, :len_keep] | |
| x_masked = torch.gather( | |
| x, | |
| dim=1, | |
| index=ids_keep.unsqueeze(-1).repeat(1, 1, D) | |
| ) | |
| # ---- create mask ---- | |
| mask = torch.ones(B, N, device=x.device) | |
| mask[:, :len_keep] = 0 | |
| # unshuffle mask | |
| mask = torch.gather(mask, dim=1, index=ids_restore) | |
| return x_masked, mask, ids_restore | |
| class Encoder_block(nn.Module): | |
| def __init__(self, d_model, d_ff, num_heads, dropout=0.1): | |
| super().__init__() | |
| self.ffn = feedforward(d_model, d_ff) | |
| self.multi_att = MultiHeadAttention(d_model, num_heads) #d_model >> embed_dim | |
| self.norm_layer1 = nn.LayerNorm(d_model) | |
| self.norm_layer2 = nn.LayerNorm(d_model) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x, mask=None): | |
| # Multi-Head Self Attention | |
| mha_out, attn = self.multi_att(x, mask) | |
| # first Add & Norm (Residual connection) | |
| residual_1 = x + self.dropout(mha_out) | |
| norm_layer1_out = self.norm_layer1(residual_1) | |
| # Feed Forward Network | |
| ffn_out = self.ffn(norm_layer1_out) | |
| # second Add & Norm (Residual connection) | |
| residual_2 = norm_layer1_out + self.dropout(ffn_out) | |
| norm_layer2_out = self.norm_layer2(residual_2) | |
| return norm_layer2_out, attn | |
| class TransformerEncoder(nn.Module): | |
| def __init__(self, num_layers, d_model, d_ff, num_heads): | |
| super().__init__() | |
| self.layers = nn.ModuleList([ | |
| Encoder_block(d_model, d_ff, num_heads) | |
| for _ in range(num_layers) | |
| ]) | |
| def forward(self, x): | |
| for layer in self.layers: | |
| x, _ = layer(x) | |
| return x | |
| def insert_mask_tokens(x, ids_restore, mask_token): | |
| B, N_visible, D = x.shape | |
| N_full = ids_restore.shape[1] | |
| # ---- create mask tokens ---- | |
| mask_tokens = mask_token.repeat(B, N_full - N_visible, 1) | |
| # ---- concatenate visible + mask ---- | |
| x_ = torch.cat([x, mask_tokens], dim=1) | |
| # ---- restore original order ---- | |
| x_full = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1,1,D)) | |
| return x_full | |
| class MAE_Decoder_Block(nn.Module): | |
| def __init__(self, d_model, num_heads, d_ff, dropout=0.1): | |
| super().__init__() | |
| self.attn = MultiHeadAttention(d_model, num_heads) | |
| self.ffn = feedforward(d_model, d_ff) | |
| self.norm1 = nn.LayerNorm(d_model) | |
| self.norm2 = nn.LayerNorm(d_model) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, x): | |
| # ---- Self Attention ---- | |
| attn_out, _ = self.attn(x) | |
| x = x + self.dropout(attn_out) | |
| x = self.norm1(x) | |
| # ---- Feed Forward ---- | |
| ffn_out = self.ffn(x) | |
| x = x + self.dropout(ffn_out) | |
| x = self.norm2(x) | |
| return x | |
| class MAEDecoder(nn.Module): | |
| def __init__(self, embed_dim=768, depth=4, num_heads=8, d_ff=2048, patch_dim=768): | |
| super().__init__() | |
| # self.pos_embed = nn.Parameter(torch.zeros(1, 196, embed_dim)) | |
| self.pos_embed = Positional_Encoding(196, d_model) | |
| self.blocks = nn.ModuleList([ | |
| MAE_Decoder_Block(embed_dim, num_heads, d_ff) | |
| for _ in range(depth) | |
| ]) | |
| self.norm = nn.LayerNorm(embed_dim) | |
| self.head = nn.Linear(embed_dim, patch_dim) | |
| def forward(self, x): | |
| x = self.pos_embed(x) | |
| for block in self.blocks: | |
| x = block(x) | |
| x = self.norm(x) | |
| x = self.head(x) | |
| return x | |
| def patchify(images, patch_size=16): | |
| B, C, H, W = images.shape | |
| num_patches = H // patch_size | |
| patches = images.unfold(2, patch_size, patch_size)\ | |
| .unfold(3, patch_size, patch_size) | |
| patches = patches.contiguous().view( | |
| B, C, -1, patch_size, patch_size | |
| ) | |
| patches = patches.permute(0, 2, 1, 3, 4) | |
| patches = patches.flatten(2) | |
| return patches | |
| def unpatchify(patches, patch_size=16, img_size=224): | |
| """ | |
| patches: [B, N, C*ps*ps] | |
| return: [B, C, H, W] | |
| """ | |
| B, N, D = patches.shape | |
| C = 3 | |
| h = w = img_size // patch_size # 14 | |
| patches = patches.view(B, N, C, patch_size, patch_size) | |
| # [B, N, C, ps, ps] | |
| patches = patches.view(B, h, w, C, patch_size, patch_size) | |
| # [B, 14, 14, C, ps, ps] | |
| patches = patches.permute(0, 3, 1, 4, 2, 5) | |
| # [B, C, 14, ps, 14, ps] | |
| images = patches.reshape(B, C, img_size, img_size) | |
| return images | |
| def setup_logger(log_dir="logs", name="mae"): | |
| os.makedirs(log_dir, exist_ok=True) | |
| log_file = os.path.join( | |
| log_dir, | |
| f"{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" | |
| ) | |
| logger = logging.getLogger(name) | |
| logger.setLevel(logging.INFO) | |
| logger.propagate = False | |
| formatter = logging.Formatter( | |
| "%(asctime)s | %(levelname)s | %(message)s", | |
| datefmt="%Y-%m-%d %H:%M:%S" | |
| ) | |
| # ---- file handler ---- | |
| fh = logging.FileHandler(log_file) | |
| fh.setFormatter(formatter) | |
| # ---- console handler ---- | |
| ch = logging.StreamHandler() | |
| ch.setFormatter(formatter) | |
| logger.addHandler(fh) | |
| logger.addHandler(ch) | |
| return logger |