{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-01-27 16:42:47.893272: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1737974567.952170 6163 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1737974567.970030 6163 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-01-27 16:42:48.097247: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } ], "source": [ "import os\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms\n", "from torchvision.datasets import ImageFolder\n", "from transformers import ViTForImageClassification, ViTFeatureExtractor" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "# Define device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Paths\n", "dataset_path = \"/home/shanin/Desktop/SHANIN/MAIN/ALL_CODE/Face_Recognition/FACE_CROP\" # Replace with your dataset path\n", "model_save_path = \"/home/shanin/Desktop/SHANIN/MAIN/ALL_CODE/Face_Recognition/v1.pth\"" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:\n", "- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1364]) in the model instantiated\n", "- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([1364, 768]) in the model instantiated\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "# Data Transformations with 160x160 resize\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)), # Resize to 160x160 as per your dataset\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomRotation(15),\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n", "])\n", "\n", "# Load Dataset\n", "dataset = ImageFolder(root=dataset_path, transform=transform)\n", "\n", "# DataLoader (using all images)\n", "dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)\n", "\n", "# Define ViT Model\n", "model = ViTForImageClassification.from_pretrained(\n", " \"google/vit-base-patch16-224\", num_labels=len(dataset.classes), \n", " ignore_mismatched_sizes=True # This will ignore size mismatch warnings\n", ")\n", "\n", "# Modify the classifier head to match the number of classes in your dataset\n", "model.classifier = nn.Linear(model.config.hidden_size, len(dataset.classes))\n", "\n", "# Move model to device\n", "model = model.to(device)\n", "\n", "# Define Optimizer, Loss Function, and Scheduler\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4)\n", "criterion = nn.CrossEntropyLoss()\n", "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Training Function\n", "def train_model(model, dataloader, epochs=100):\n", " model.train()\n", "\n", " for epoch in range(epochs):\n", " print(f\"Epoch {epoch + 1}/{epochs}\")\n", " epoch_loss = 0.0\n", "\n", " for images, labels in dataloader:\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " # Forward pass\n", " optimizer.zero_grad()\n", " outputs = model(images).logits\n", " loss = criterion(outputs, labels)\n", "\n", " # Backward pass\n", " loss.backward()\n", " optimizer.step()\n", "\n", " epoch_loss += loss.item()\n", "\n", " # Step the scheduler\n", " scheduler.step()\n", "\n", " print(f\"Epoch Loss: {epoch_loss / len(dataloader):.4f}\")\n", "\n", " # Save the trained model\n", " torch.save(model.state_dict(), model_save_path)\n", " print(\"Training complete! Model saved.\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "Epoch Loss: 6.4940\n", "Epoch 2/100\n", "Epoch Loss: 4.4488\n", "Epoch 3/100\n", "Epoch Loss: 2.7864\n", "Epoch 4/100\n", "Epoch Loss: 1.6240\n", "Epoch 5/100\n", "Epoch Loss: 0.9264\n", "Epoch 6/100\n", "Epoch Loss: 0.5430\n", "Epoch 7/100\n", "Epoch Loss: 0.4927\n", "Epoch 8/100\n", "Epoch Loss: 0.4542\n", "Epoch 9/100\n", "Epoch Loss: 0.4172\n", "Epoch 10/100\n", "Epoch Loss: 0.3850\n", "Epoch 11/100\n", "Epoch Loss: 0.3539\n", "Epoch 12/100\n", "Epoch Loss: 0.3482\n", "Epoch 13/100\n", "Epoch Loss: 0.3455\n", "Epoch 14/100\n", "Epoch Loss: 0.3437\n", "Epoch 15/100\n", "Epoch Loss: 0.3408\n", "Epoch 16/100\n", "Epoch Loss: 0.3362\n", "Epoch 17/100\n", "Epoch Loss: 0.3362\n", "Epoch 18/100\n", "Epoch Loss: 0.3366\n", "Epoch 19/100\n", "Epoch Loss: 0.3359\n", "Epoch 20/100\n", "Epoch Loss: 0.3364\n", "Epoch 21/100\n", "Epoch Loss: 0.3349\n", "Epoch 22/100\n", "Epoch Loss: 0.3352\n", "Epoch 23/100\n", "Epoch Loss: 0.3356\n", "Epoch 24/100\n", "Epoch Loss: 0.3347\n", "Epoch 25/100\n", "Epoch Loss: 0.3343\n", "Epoch 26/100\n", "Epoch Loss: 0.3355\n", "Epoch 27/100\n", "Epoch Loss: 0.3347\n", "Epoch 28/100\n", "Epoch Loss: 0.3350\n", "Epoch 29/100\n", "Epoch Loss: 0.3354\n", "Epoch 30/100\n", "Epoch Loss: 0.3350\n", "Epoch 31/100\n", "Epoch Loss: 0.3356\n", "Epoch 32/100\n", "Epoch Loss: 0.3358\n", "Epoch 33/100\n", "Epoch Loss: 0.3349\n", "Epoch 34/100\n", "Epoch Loss: 0.3354\n", "Epoch 35/100\n", "Epoch Loss: 0.3347\n", "Epoch 36/100\n", "Epoch Loss: 0.3352\n", "Epoch 37/100\n", "Epoch Loss: 0.3351\n", "Epoch 38/100\n", "Epoch Loss: 0.3349\n", "Epoch 39/100\n", "Epoch Loss: 0.3343\n", "Epoch 40/100\n", "Epoch Loss: 0.3356\n", "Epoch 41/100\n", "Epoch Loss: 0.3349\n", "Epoch 42/100\n", "Epoch Loss: 0.3348\n", "Epoch 43/100\n", "Epoch Loss: 0.3348\n", "Epoch 44/100\n", "Epoch Loss: 0.3360\n", "Epoch 45/100\n", "Epoch Loss: 0.3352\n", "Epoch 46/100\n", "Epoch Loss: 0.3344\n", "Epoch 47/100\n", "Epoch Loss: 0.3351\n", "Epoch 48/100\n", "Epoch Loss: 0.3360\n", "Epoch 49/100\n", "Epoch Loss: 0.3351\n", "Epoch 50/100\n", "Epoch Loss: 0.3348\n", "Epoch 51/100\n", "Epoch Loss: 0.3344\n", "Epoch 52/100\n", "Epoch Loss: 0.3347\n", "Epoch 53/100\n", "Epoch Loss: 0.3349\n", "Epoch 54/100\n", "Epoch Loss: 0.3359\n", "Epoch 55/100\n", "Epoch Loss: 0.3353\n", "Epoch 56/100\n", "Epoch Loss: 0.3347\n", "Epoch 57/100\n", "Epoch Loss: 0.3355\n", "Epoch 58/100\n", "Epoch Loss: 0.3356\n", "Epoch 59/100\n", "Epoch Loss: 0.3352\n", "Epoch 60/100\n", "Epoch Loss: 0.3357\n", "Epoch 61/100\n", "Epoch Loss: 0.3359\n", "Epoch 62/100\n", "Epoch Loss: 0.3357\n", "Epoch 63/100\n", "Epoch Loss: 0.3349\n", "Epoch 64/100\n", "Epoch Loss: 0.3358\n", "Epoch 65/100\n", "Epoch Loss: 0.3349\n", "Epoch 66/100\n", "Epoch Loss: 0.3347\n", "Epoch 67/100\n", "Epoch Loss: 0.3359\n", "Epoch 68/100\n", "Epoch Loss: 0.3349\n", "Epoch 69/100\n", "Epoch Loss: 0.3338\n", "Epoch 70/100\n", "Epoch Loss: 0.3351\n", "Epoch 71/100\n", "Epoch Loss: 0.3358\n", "Epoch 72/100\n", "Epoch Loss: 0.3347\n", "Epoch 73/100\n", "Epoch Loss: 0.3353\n", "Epoch 74/100\n", "Epoch Loss: 0.3347\n", "Epoch 75/100\n", "Epoch Loss: 0.3344\n", "Epoch 76/100\n", "Epoch Loss: 0.3341\n", "Epoch 77/100\n", "Epoch Loss: 0.3352\n", "Epoch 78/100\n", "Epoch Loss: 0.3349\n", "Epoch 79/100\n", "Epoch Loss: 0.3344\n", "Epoch 80/100\n", "Epoch Loss: 0.3350\n", "Epoch 81/100\n", "Epoch Loss: 0.3351\n", "Epoch 82/100\n", "Epoch Loss: 0.3347\n", "Epoch 83/100\n", "Epoch Loss: 0.3358\n", "Epoch 84/100\n", "Epoch Loss: 0.3346\n", "Epoch 85/100\n", "Epoch Loss: 0.3351\n", "Epoch 86/100\n", "Epoch Loss: 0.3347\n", "Epoch 87/100\n", "Epoch Loss: 0.3364\n", "Epoch 88/100\n", "Epoch Loss: 0.3356\n", "Epoch 89/100\n", "Epoch Loss: 0.3349\n", "Epoch 90/100\n", "Epoch Loss: 0.3347\n", "Epoch 91/100\n", "Epoch Loss: 0.3346\n", "Epoch 92/100\n", "Epoch Loss: 0.3354\n", "Epoch 93/100\n", "Epoch Loss: 0.3362\n", "Epoch 94/100\n", "Epoch Loss: 0.3344\n", "Epoch 95/100\n", "Epoch Loss: 0.3351\n", "Epoch 96/100\n", "Epoch Loss: 0.3346\n", "Epoch 97/100\n", "Epoch Loss: 0.3352\n", "Epoch 98/100\n", "Epoch Loss: 0.3343\n", "Epoch 99/100\n", "Epoch Loss: 0.3352\n", "Epoch 100/100\n", "Epoch Loss: 0.3346\n", "Training complete! Model saved.\n" ] } ], "source": [ "# Train the Model\n", "train_model(model, dataloader, epochs=100)" ] } ], "metadata": { "kernelspec": { "display_name": "minibat", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 2 }