{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Spectrogram image classification with fastai (ResNet-34)\n",
    "\n",
    "Workflow, top to bottom:\n",
    "\n",
    "1. Check whether PyTorch can see a GPU.\n",
    "2. Fine-tune a ResNet-34 on the spectrogram images.\n",
    "3. Reload the `final_model` checkpoint and compare training vs. validation metrics.\n",
    "4. Score the exported learner `Test/final_model.pkl` on the images under `Test/`.\n",
    "\n",
    "**Note:** no cell in this notebook writes `final_model` or `Test/final_model.pkl`; sections 3 and 4 expect those files to exist already."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# fastai's own convention is the star import. Names used below that come from it:\n",
    "# Path, ImageDataLoaders, Resize, aug_transforms, vision_learner, resnet34, accuracy,\n",
    "# set_seed, load_learner, get_image_files, ClassificationInterpretation.\n",
    "from fastai.vision.all import *\n",
    "from fastai.callback.tracker import EarlyStoppingCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune lives here.\n",
    "SEED = 31         # seed for the train/validation split and for set_seed\n",
    "VALID_PCT = 0.2   # fraction of the images held out for validation\n",
    "IMAGE_SIZE = 224  # every image is resized to IMAGE_SIZE x IMAGE_SIZE\n",
    "\n",
    "TRAIN_DATA_DIR = Path('../Spectrogram_Images')\n",
    "# NOTE(review): the checkpoint evaluation reads from a different folder than training.\n",
    "# Confirm both contain the same images; otherwise the two validation splits differ.\n",
    "EVAL_DATA_DIR = Path('../Spectrogram_Images/s')\n",
    "CHECKPOINT_NAME = 'final_model'\n",
    "\n",
    "TEST_DIR = Path('Test')\n",
    "EXPORTED_LEARNER = TEST_DIR / 'final_model.pkl'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. GPU check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if a GPU is available\n",
    "print('GPU Available:', torch.cuda.is_available())\n",
    "\n",
    "# Check the name of the GPU\n",
    "if torch.cuda.is_available():\n",
    "    print('GPU Name:', torch.cuda.get_device_name(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers\n",
    "\n",
    "Training and checkpoint evaluation must build their data loaders and learner the same way, so both go through these two functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataloaders(image_dir, seed=SEED):\n",
    "    \"\"\"Return ImageDataLoaders over `image_dir` with a seeded random validation split.\n",
    "\n",
    "    Passing `seed` makes the VALID_PCT hold-out repeatable, so a model trained in one\n",
    "    cell is validated on the same held-out images when the loaders are rebuilt later\n",
    "    (provided `image_dir` contains the same files).\n",
    "    \"\"\"\n",
    "    return ImageDataLoaders.from_folder(\n",
    "        image_dir,\n",
    "        train='.',\n",
    "        valid_pct=VALID_PCT,\n",
    "        seed=seed,\n",
    "        item_tfms=Resize(IMAGE_SIZE),         # resize images for the CNN\n",
    "        batch_tfms=aug_transforms(mult=1.0),  # basic augmentations\n",
    "    )\n",
    "\n",
    "\n",
    "def build_learner(dls):\n",
    "    \"\"\"Return a ResNet-34 vision learner with weight decay and early stopping.\n",
    "\n",
    "    Dropout is configured through `ps`, which sets the dropout inside the classifier\n",
    "    head that `vision_learner` builds. Do not append an `nn.Dropout` to the end of the\n",
    "    head: it would sit after the final linear layer and randomly zero the output logits\n",
    "    during training.\n",
    "    \"\"\"\n",
    "    return vision_learner(\n",
    "        dls,\n",
    "        resnet34,\n",
    "        metrics=accuracy,\n",
    "        wd=1e-4,\n",
    "        ps=0.5,\n",
    "        cbs=[EarlyStoppingCallback(monitor='valid_loss', patience=3)],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(SEED)\n",
    "train_dls = build_dataloaders(TRAIN_DATA_DIR)\n",
    "\n",
    "# Show a batch of images to verify data loading\n",
    "train_dls.show_batch(max_n=9, figsize=(8, 8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_learn = build_learner(train_dls)\n",
    "\n",
    "# lr_find's 'valley' suggestion is used as the base learning rate for fine-tuning.\n",
    "valley_lr = train_learn.lr_find().valley\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune; the EarlyStoppingCallback attached in build_learner monitors valid_loss\n",
    "# with patience=3.\n",
    "train_learn.fine_tune(10, base_lr=valley_lr)\n",
    "\n",
    "# Plot training and validation loss curves\n",
    "train_learn.recorder.plot_loss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Evaluate the saved checkpoint\n",
    "\n",
    "Rebuilds the learner, loads the `final_model` weights and compares training with validation metrics. Training metrics are computed through a validation-style loader (no shuffling, no dropped last batch, no random augmentation) so the two numbers are comparable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(SEED)\n",
    "eval_dls = build_dataloaders(EVAL_DATA_DIR)\n",
    "eval_learn = build_learner(eval_dls)\n",
    "eval_learn.load(CHECKPOINT_NAME)\n",
    "\n",
    "# Validate the loaded model on the validation dataset\n",
    "validation_results = eval_learn.validate(dl=eval_dls.valid)\n",
    "print(f'Validation Results: {validation_results}')\n",
    "\n",
    "validation_loss, validation_accuracy = validation_results\n",
    "print(f'Validation Loss: {validation_loss:.4f}')\n",
    "print(f'Validation Accuracy: {validation_accuracy:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the training items with the validation pipeline. Using eval_dls.train directly\n",
    "# would shuffle, drop the last partial batch and apply random augmentation, which\n",
    "# makes the training metrics incomparable to the validation metrics.\n",
    "train_eval_dl = eval_dls.test_dl(eval_dls.train_ds.items, with_labels=True)\n",
    "training_loss, training_accuracy = eval_learn.validate(dl=train_eval_dl)\n",
    "print(f'Training Loss: {training_loss:.4f}')\n",
    "print(f'Training Accuracy: {training_accuracy:.4f}')\n",
    "\n",
    "# Check for overfitting\n",
    "if training_accuracy > validation_accuracy + 0.05:\n",
    "    print('Warning: The model might be overfitting.')\n",
    "else:\n",
    "    print('The model seems to be generalizing well.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Evaluate the exported learner on the test folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load_learner unpickles the file: only load exports you created or trust.\n",
    "test_learn = load_learner(EXPORTED_LEARNER)\n",
    "\n",
    "test_files = get_image_files(TEST_DIR)\n",
    "if len(test_files) == 0:\n",
    "    raise FileNotFoundError(f'No image files found under {TEST_DIR}')\n",
    "\n",
    "# with_labels=True is required: without it test_dl builds an unlabelled loader and\n",
    "# get_preds returns targets=None.\n",
    "# NOTE(review): labels are derived by the learner's own labelling function; this\n",
    "# presumably means each test image sits in a sub-folder named after its class - confirm.\n",
    "test_dl = test_learn.dls.test_dl(test_files, with_labels=True)\n",
    "\n",
    "# Get predictions and targets\n",
    "preds, targets = test_learn.get_preds(dl=test_dl)\n",
    "\n",
    "# Check if preds or targets are None\n",
    "if preds is None or targets is None:\n",
    "    raise ValueError('Predictions or targets are None. Please check the DataLoader and the test data.')\n",
    "\n",
    "# Print the shapes of preds and targets for debugging\n",
    "print(f'Predictions shape: {preds.shape}')\n",
    "print(f'Targets shape: {targets.shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate accuracy\n",
    "acc = accuracy(preds, targets)\n",
    "print(f'Accuracy on test set: {acc.item():.4f}')\n",
    "\n",
    "# Plot confusion matrix\n",
    "interp = ClassificationInterpretation.from_learner(test_learn, dl=test_dl)\n",
    "interp.plot_confusion_matrix()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print classification report\n",
    "class_names = list(test_learn.dls.vocab)\n",
    "y_pred = preds.argmax(dim=1).numpy()\n",
    "y_true = targets.numpy()\n",
    "\n",
    "# Explicit `labels` keeps the report aligned with `target_names` even when some class\n",
    "# has no images in the test folder.\n",
    "print(classification_report(\n",
    "    y_true,\n",
    "    y_pred,\n",
    "    labels=list(range(len(class_names))),\n",
    "    target_names=class_names,\n",
    "))\n",
    "\n",
    "# Print additional information\n",
    "print(f'Number of classes: {len(class_names)}')\n",
    "print(f'Class names: {test_learn.dls.vocab}')\n",
    "print(f'Number of test images: {len(test_dl.dataset)}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "laugh_detection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}