diff --git "a/Diffusion_Models.ipynb" "b/Diffusion_Models.ipynb" new file mode 100644--- /dev/null +++ "b/Diffusion_Models.ipynb" @@ -0,0 +1,3357 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 383, + "metadata": {}, + "outputs": [], + "source": [ + "# Loading all the important libraries\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from matplotlib import pyplot as plt\n", + "from PIL import Image\n", + "import torchvision\n", + "from datasets import load_dataset\n", + "from torchvision import transforms\n", + "from diffusers import StableDiffusionPipeline\n", + "from peft import LoraConfig, get_peft_model\n", + "from torch.utils.data import DataLoader\n", + "from tqdm.auto import tqdm\n", + "import torch, os\n", + "from torch.amp import autocast, GradScaler\n", + "import warnings" + ] + }, + { + "cell_type": "code", + "execution_count": 384, + "metadata": {}, + "outputs": [], + "source": [ + "warnings.filterwarnings('ignore')" + ] + }, + { + "cell_type": "code", + "execution_count": 385, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Repo card metadata block was not found. 
Setting CardData to empty.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Combined dataset size: 24554\n" + ] + } + ], + "source": [ + "# Load the Rapidata Animals-10 dataset and the Smithsonian butterflies subset.\n", + "# concatenate_datasets is imported here because this cell is the only place it is used.\n", + "from datasets import concatenate_datasets\n", + "\n", + "butterflies = load_dataset(\"huggan/smithsonian_butterflies_subset\", split=\"train\")\n", + "animals = load_dataset(\"Rapidata/Animals-10\", split=\"train\")\n", + "\n", + "# Combine datasets\n", + "dataset = concatenate_datasets([animals, butterflies])\n", + "print(\"Combined dataset size:\", len(dataset))" + ] + }, + { + "cell_type": "code", + "execution_count": 386, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 138, + "referenced_widgets": [ + "a8b35c686e9042d5ba14cd29462d2baf", + "499b384d89104156915fb644cae54692", + "ab80f5f0927d40d98892cfd57d503266", + "84446392113b4a7abeaeea80529d5d99", + "57ed90c0acb54ec9884233b4ae4ed81c", + "84448187f23f4d418010fdc4a5aa9678", + "dfa3ce719ac040c0a9fd0b161a06d20a", + "84f312d7b02c4353b99b73ca99d5662b", + "3413c74883d740909dea596481092d50", + "aee8b5c034bf42e88776b5311a395b1a", + "56dc56a5bc3a49d6a7dec255d7c1fc68" + ] + }, + "id": "Fp3D-fRmnlDn", + "outputId": "454c42d4-b396-4b6c-e0fe-1e8ce307786b" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "68db9e1e82b1492292b73783dfc5f06d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/6 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize training loss\n", + "plt.figure(figsize=(8,5))\n", + "plt.plot(train_losses, marker='o', color='blue')\n", + "plt.title(\"Stable Diffusion LoRA Fine-tuning Loss Curve\")\n", + "plt.xlabel(\"Batch\")\n", + "plt.ylabel(\"Mean MSE Loss\")\n", + "plt.grid(True)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 397, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Pipeline saved successfully at models/stable_diffusion_training\n" + ] + } + ], + "source": [ + "# Attach trained components back to pipeline\n", + "unet = unet.merge_and_unload() # Converts back to standard UNet\n", + "\n", + "# Re-attach to pipeline\n", + "pipe.unet = unet\n", + "pipe.vae = vae\n", + "pipe.text_encoder = text_encoder\n", + "pipe.tokenizer = tokenizer\n", + "\n", + "\n", + "# Save the full fine-tuned pipeline\n", + "pipe.save_pretrained(out_dir)\n", + "pipe.tokenizer.save_pretrained(out_dir)\n", + "\n", + "print(f\"Pipeline saved successfully at {out_dir}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 402, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "616198a201fa4310a3d232a32e5a8364", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/30 [00:00" + ] + }, + "execution_count": 402, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Generating an initial image from a promt\n", + "init_prompt = \"horse on a beach with waves that reaching its feet and the water splashes after striking its feet\"\n", + "generator = torch.Generator(device=device).manual_seed(40)\n", + "init_image = pipe(init_prompt, negative_promt=\"blurry, satureated, deformed\", guidance_scale=9, num_inference_steps=30, generator=generator).images[0]\n", + "init_image" + ] + }, + { + "cell_type": "code", + "execution_count": 403, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be7f70f111ab4da58a97366017397492", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/6 [00:00" + ] + }, + "execution_count": 404, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Progressive image generation using the prompts\n", + "prompts = [\n", + " \"cowboy sitting straight on the horse on a beach 
with medium waves that reaches its feet and the water splashes after striking its feet\",\n", + " \"cowboy sitting straight on the horse on a beach with medium waves that reaches its feet and the water splashes after striking its feet, mountains in the background\"\n", + "]\n", + "generator = torch.Generator(device=device).manual_seed(42)\n", + "\n", + "# Feed each result back in as the starting image for the next prompt.\n", + "# The keyword must be spelled negative_prompt for the negative prompt to be applied.\n", + "image = init_image\n", + "for prompt in prompts:\n", + "    image = pipe_img2img(prompt=prompt, negative_prompt=\"blurry, deformed, distorted,saturated\", image=image, strength=0.9, guidance_scale=10.0, num_inference_steps=100, generator=generator).images[0]\n", + "\n", + "image" + ] + }, + { + "cell_type": "code", + "execution_count": 405, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e693bd7c397e4bed86beb6148e5b7848", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='