LH-Tech-AI committed on
Commit
a13cf9c
·
verified ·
1 Parent(s): bfdf9fc

Upload catgen-v2.ipynb

Browse files
Files changed (1) hide show
  1. catgen-v2.ipynb +1 -0
catgen-v2.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!pip install -q datasets torchvision matplotlib","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true,"execution":{"iopub.status.busy":"2026-04-01T16:35:07.405121Z","iopub.execute_input":"2026-04-01T16:35:07.405808Z","iopub.status.idle":"2026-04-01T16:35:11.696955Z","shell.execute_reply.started":"2026-04-01T16:35:07.405771Z","shell.execute_reply":"2026-04-01T16:35:11.696140Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torchvision.transforms as transforms\nfrom torch.utils.data import DataLoader, Dataset\nfrom datasets import load_dataset\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# --- CatGen v2 Config ---\nBATCH_SIZE = 128\nIMAGE_SIZE = 128\nCHANNELS = 3\nZ_DIM = 128\nFEATURES_G = 256 \nFEATURES_D = 128\nEPOCHS = 165\nLR = 0.0002\nBETA1 = 0.5 \n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Training running on: {device}\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-01T16:39:22.093315Z","iopub.execute_input":"2026-04-01T16:39:22.093860Z","iopub.status.idle":"2026-04-01T16:39:22.099627Z","shell.execute_reply.started":"2026-04-01T16:39:22.093822Z","shell.execute_reply":"2026-04-01T16:39:22.098820Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"print(\"Loading 
dataset...\")\nhf_dataset = load_dataset(\"huggan/cats\", split=\"train\")\n\ntransform = transforms.Compose([\n transforms.Resize(IMAGE_SIZE),\n transforms.CenterCrop(IMAGE_SIZE),\n transforms.RandomHorizontalFlip(),\n transforms.ToTensor(),\n transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n])\n\nclass CatDataset(Dataset):\n def __init__(self, hf_ds, transform):\n self.hf_ds = hf_ds\n self.transform = transform\n def __len__(self):\n return len(self.hf_ds)\n def __getitem__(self, idx):\n img = self.hf_ds[idx]['image'].convert(\"RGB\")\n return self.transform(img)\n\ndataset = CatDataset(hf_dataset, transform)\n\ndataloader = DataLoader(\n dataset, \n batch_size=BATCH_SIZE, \n shuffle=True, \n drop_last=True,\n num_workers=4, \n pin_memory=True\n)\nprint(f\"Dataset ready with {len(dataset)} images.\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-01T16:38:54.344262Z","iopub.execute_input":"2026-04-01T16:38:54.344902Z","iopub.status.idle":"2026-04-01T16:38:54.992746Z","shell.execute_reply.started":"2026-04-01T16:38:54.344870Z","shell.execute_reply":"2026-04-01T16:38:54.992102Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class Generator(nn.Module):\n def __init__(self, z_dim, channels, features_g):\n super(Generator, self).__init__()\n self.net = nn.Sequential(\n # Input: Z_DIM x 1 x 1\n nn.ConvTranspose2d(z_dim, features_g * 16, 4, 1, 0, bias=False),\n nn.BatchNorm2d(features_g * 16),\n nn.ReLU(True),\n # 4x4 -> 8x8\n nn.ConvTranspose2d(features_g * 16, features_g * 8, 4, 2, 1, bias=False),\n nn.BatchNorm2d(features_g * 8),\n nn.ReLU(True),\n # 8x8 -> 16x16\n nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),\n nn.BatchNorm2d(features_g * 4),\n nn.ReLU(True),\n # 16x16 -> 32x32\n nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),\n nn.BatchNorm2d(features_g * 2),\n nn.ReLU(True),\n # 32x32 -> 64x64\n nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, 
bias=False),\n nn.BatchNorm2d(features_g),\n nn.ReLU(True),\n # 64x64 -> 128x128\n nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),\n nn.Tanh()\n )\n\n def forward(self, x):\n return self.net(x)\n\nnetG = Generator(Z_DIM, CHANNELS, FEATURES_G).to(device)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-01T16:38:57.438781Z","iopub.execute_input":"2026-04-01T16:38:57.439563Z","iopub.status.idle":"2026-04-01T16:38:58.950564Z","shell.execute_reply.started":"2026-04-01T16:38:57.439530Z","shell.execute_reply":"2026-04-01T16:38:58.949989Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"class Discriminator(nn.Module):\n def __init__(self, channels, features_d):\n super(Discriminator, self).__init__()\n self.net = nn.Sequential(\n # 128x128 -> 64x64\n nn.Conv2d(channels, features_d, 4, 2, 1, bias=False),\n nn.LeakyReLU(0.2, inplace=True),\n # 64x64 -> 32x32\n nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),\n nn.BatchNorm2d(features_d * 2),\n nn.LeakyReLU(0.2, inplace=True),\n # 32x32 -> 16x16\n nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),\n nn.BatchNorm2d(features_d * 4),\n nn.LeakyReLU(0.2, inplace=True),\n # 16x16 -> 8x8\n nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),\n nn.BatchNorm2d(features_d * 8),\n nn.LeakyReLU(0.2, inplace=True),\n # 8x8 -> 4x4\n nn.Conv2d(features_d * 8, features_d * 16, 4, 2, 1, bias=False),\n nn.BatchNorm2d(features_d * 16),\n nn.LeakyReLU(0.2, inplace=True),\n # 4x4 -> 1x1\n nn.Conv2d(features_d * 16, 1, 4, 1, 0, bias=False),\n )\n\n def forward(self, x):\n return self.net(x)\n\nnetD = Discriminator(CHANNELS, 
FEATURES_D).to(device)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-01T16:39:01.726444Z","iopub.execute_input":"2026-04-01T16:39:01.726972Z","iopub.status.idle":"2026-04-01T16:39:02.092572Z","shell.execute_reply.started":"2026-04-01T16:39:01.726941Z","shell.execute_reply":"2026-04-01T16:39:02.092014Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Init weights\ndef weights_init(m):\n classname = m.__class__.__name__\n if classname.find('Conv') != -1:\n nn.init.normal_(m.weight.data, 0.0, 0.02)\n elif classname.find('BatchNorm') != -1:\n nn.init.normal_(m.weight.data, 1.0, 0.02)\n nn.init.constant_(m.bias.data, 0)\n\nnetG.apply(weights_init)\nnetD.apply(weights_init)\n\ncriterion = nn.BCEWithLogitsLoss()\n\n# Separated optimizers for both networks\noptG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.999))\noptD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.999))\n\nfixed_noise = torch.randn(64, Z_DIM, 1, 1, device=device)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-01T16:39:04.634364Z","iopub.execute_input":"2026-04-01T16:39:04.635106Z","iopub.status.idle":"2026-04-01T16:39:04.642551Z","shell.execute_reply.started":"2026-04-01T16:39:04.635066Z","shell.execute_reply":"2026-04-01T16:39:04.642006Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"from torch.cuda.amp import autocast, GradScaler\nimport torchvision.utils as vutils\nfrom IPython.display import display\n\nscaler = torch.amp.GradScaler('cuda') \n\nprint(f\"Model size G: {sum(p.numel() for p in netG.parameters())/1e6:.2f}M parameters\")\nprint(f\"Model size D: {sum(p.numel() for p in netD.parameters())/1e6:.2f}M parameters\")\n\nreal_label_val = 0.9\nfake_label_val = 0.1\n\nfor epoch in range(EPOCHS):\n for i, real_images in enumerate(dataloader):\n real_images = real_images.to(device)\n b_size = real_images.size(0)\n \n # --- Discriminator Update ---\n optD.zero_grad()\n with 
torch.amp.autocast('cuda'):\n output_real = netD(real_images).view(-1)\n lossD_real = criterion(output_real, torch.full((b_size,), real_label_val, device=device))\n \n noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)\n fake_images = netG(noise)\n output_fake = netD(fake_images.detach()).view(-1)\n lossD_fake = criterion(output_fake, torch.full((b_size,), fake_label_val, device=device))\n lossD = lossD_real + lossD_fake\n\n scaler.scale(lossD).backward()\n scaler.step(optD)\n \n # --- Generator Update ---\n optG.zero_grad()\n with torch.amp.autocast('cuda'):\n output_fake_G = netD(fake_images).view(-1)\n lossG = criterion(output_fake_G, torch.full((b_size,), real_label_val, device=device))\n\n scaler.scale(lossG).backward()\n scaler.step(optG)\n scaler.update()\n\n if i % 10 == 0:\n print(f\"E[{epoch}] I[{i}/{len(dataloader)}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}\")\n\n if (epoch + 1) % 10 == 0 or epoch == 0:\n netG.eval()\n with torch.no_grad():\n with torch.amp.autocast('cuda'):\n sample = netG(fixed_noise[0:1]).detach().cpu().float()\n \n sample_rescaled = (sample.squeeze().permute(1, 2, 0).numpy() + 1) / 2\n \n plt.figure(figsize=(5,5))\n plt.imshow(sample_rescaled)\n plt.title(f\"Sample Epoch {epoch}\")\n plt.axis('off')\n plt.show()\n \n vutils.save_image(sample, f\"sample_epoch_{epoch}.png\", normalize=True)\n \n netG.train()\n\ncheckpoint = {\n 'epoch': EPOCHS,\n 'netG_state_dict': netG.state_dict(),\n 'netD_state_dict': netD.state_dict(),\n 'optimizerG_state_dict': optG.state_dict(),\n 'optimizerD_state_dict': optD.state_dict(),\n 'scaler_state_dict': scaler.state_dict(),\n}\n\ntorch.save(checkpoint, f'catgen_ultra_FULL_RESUME_E{EPOCHS}.ckpt')\n\ntorch.save(netG.state_dict(), 'catgen_ultra_G_ONLY_HF.pth')\n\nprint(f\"Checkpoints saved 
successfully.\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-01T12:10:15.449906Z","iopub.execute_input":"2026-04-01T12:10:15.450636Z","iopub.status.idle":"2026-04-01T15:42:24.140809Z","shell.execute_reply.started":"2026-04-01T12:10:15.450605Z","shell.execute_reply":"2026-04-01T15:42:24.139936Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import torch\ntorch.save({\n 'epoch': epoch,\n 'model_state_dict': netG.state_dict(),\n 'optimizer_state_dict': optG.state_dict(),\n}, 'catgen_v1_full_checkpoint.pth')\n\ntorch.save(netG.state_dict(), 'catgen_v2_generator_only.pth')\n\nprint(\"Files saved!\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-01T16:32:15.464511Z","iopub.execute_input":"2026-04-01T16:32:15.464820Z","iopub.status.idle":"2026-04-01T16:32:19.150682Z","shell.execute_reply.started":"2026-04-01T16:32:15.464794Z","shell.execute_reply":"2026-04-01T16:32:19.149802Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import torchvision.utils as vutils\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n# --- CatGen v2: Professional Gallery Export (Fix) ---\nnetG.eval()\n\nwith torch.no_grad():\n with torch.amp.autocast('cuda'):\n fake_cats = netG(fixed_noise).detach().cpu().float() \n\ngrid = vutils.make_grid(fake_cats, padding=4, normalize=True)\ngrid_np = grid.numpy().transpose((1, 2, 0))\n\nplt.figure(figsize=(12, 12), facecolor='#111111') \nplt.imshow(grid_np, interpolation='bilinear')\nplt.axis(\"off\")\n\nplt.title(f\"CatGen v2 | Training Complete | {FEATURES_G}x{FEATURES_D} Filters\", \n color='white', fontsize=16, fontweight='bold', pad=20)\n\nplt.tight_layout()\nplt.show()\n\nplt.savefig(\"catgen_v2_results.png\", facecolor='#111111', bbox_inches='tight')","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def generate_single_cat(model, epoch_num):\n model.eval()\n noise = torch.randn(1, Z_DIM, 1, 1, device=device)\n 
with torch.no_grad():\n fake = model(noise).detach().cpu().float()\n \n img = (fake.squeeze().permute(1, 2, 0).numpy() + 1) / 2\n plt.figure(figsize=(8, 8))\n plt.imshow(img, interpolation='lanczos')\n plt.axis('off')\n plt.title(f\"CatGen v2 ULTRA - Single Specimen (Epoch {epoch_num})\")\n plt.show()","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}