{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Check if a GPU is available\n",
"print(\"GPU Available:\", torch.cuda.is_available())\n",
"\n",
"# Check the name of the GPU\n",
"if torch.cuda.is_available():\n",
" print(\"GPU Name:\", torch.cuda.get_device_name(0))\n"
]
},
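{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch (not part of the original run): pick a device explicitly so the\n",
"# notebook also runs on CPU-only machines. fastai normally places the model on the\n",
"# GPU automatically, so this cell is just a convenience check.\n",
"import torch\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print('Using device:', device)\n",
"\n",
"# Example: move a tensor to the chosen device\n",
"x = torch.randn(2, 3).to(device)\n",
"print(x.device)\n"
]
},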
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision.all import *\n",
"import matplotlib.pyplot as plt\n",
"\n",
"### Data Loader\n",
"path = Path('../Spectrogram_Images')\n",
"\n",
"dls = ImageDataLoaders.from_folder(\n",
" path,\n",
" train = '.',\n",
" valid_pct=0.2, # 20% of data for validation\n",
" item_tfms = Resize(224), # Resize images to 224x224 pixels for CNN\n",
" batch_tfms=aug_transforms(mult=1.0) # Apply basic augmentations\n",
")\n",
"\n",
"# Show a batch of images to verify data loading\n",
"dls.show_batch(max_n=9, figsize=(8, 8))\n",
"\n",
"### CNN\n",
"from fastai.callback.tracker import EarlyStoppingCallback\n",
"learn = vision_learner(dls, resnet34, metrics=accuracy,wd=1e-4, cbs=[EarlyStoppingCallback(monitor='valid_loss', patience=3)])\n",
"\n",
"\n",
"\n",
"lr_steep = learn.lr_find().valley\n",
"\n",
"plt.show()\n",
"\n",
"learn.model[-1].add_module('dropout', nn.Dropout(p=0.5))\n",
"learn.fine_tune(10, base_lr=lr_steep)\n",
"\n",
"# Fine-tune with early stopping\n",
"\n",
"# Plot training and validation loss curves\n",
"learn.recorder.plot_loss()\n",
"\n",
"\n"
]
},
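{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: the evaluation cells below load 'final_model' via learn.load() and\n",
"# 'final_model.pkl' via load_learner(), so the trained model was presumably persisted\n",
"# along these lines. The exact names and paths are assumptions, not taken from the\n",
"# original run.\n",
"learn.save('final_model')        # writes models/final_model.pth relative to the data path\n",
"learn.export('final_model.pkl')  # writes a pickled Learner for later use with load_learner()\n"
]
},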
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation Results: [0.3122069239616394, 0.8672986030578613]\n",
"Validation Loss: 0.3122\n",
"Validation Accuracy: 0.8673\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Loss: 0.4510\n",
"Training Accuracy: 0.8065\n",
"The model seems to be generalizing well.\n"
]
}
],
"source": [
"### CNN\n",
"from fastai.vision.all import *\n",
"from fastai.callback.tracker import EarlyStoppingCallback\n",
"path = Path('../Spectrogram_Images/s')\n",
"set_seed(31)\n",
"dls = ImageDataLoaders.from_folder(\n",
" path,\n",
" train = '.',\n",
" valid_pct=0.2, # 20% of data for validation\n",
" item_tfms = Resize(224), # Resize images to 224x224 pixels for CNN\n",
" batch_tfms=aug_transforms(mult=1.0) # Apply basic augmentations\n",
")\n",
"learn = vision_learner(dls, resnet34, metrics=accuracy, wd=1e-4 ,cbs=[EarlyStoppingCallback(monitor='valid_loss', patience=3)])\n",
"\n",
"learn.load('final_model')\n",
"learn.model.eval()\n",
"# Validate the loaded model on the validation dataset\n",
"validation_results = learn.validate(dl=dls.valid)\n",
"print(f\"Validation Results: {validation_results}\")\n",
"\n",
"# Print the validation results\n",
"validation_loss, validation_accuracy = validation_results\n",
"print(f\"Validation Loss: {validation_loss:.4f}\")\n",
"print(f\"Validation Accuracy: {validation_accuracy:.4f}\")\n",
"\n",
"# Compare with training metrics (if available)\n",
"training_results = learn.validate(dl=dls.train)\n",
"training_loss, training_accuracy = training_results\n",
"print(f\"Training Loss: {training_loss:.4f}\")\n",
"print(f\"Training Accuracy: {training_accuracy:.4f}\")\n",
"\n",
"# Check for overfitting\n",
"if training_accuracy > validation_accuracy + 0.05:\n",
" print(\"Warning: The model might be overfitting.\")\n",
"else:\n",
" print(\"The model seems to be generalizing well.\")\n",
" \n"
]
},
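{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: inspect where the validation errors come from using fastai's\n",
"# ClassificationInterpretation. This assumes the `learn` object from the previous\n",
"# cell is still in memory; by default from_learner() uses the validation set.\n",
"interp = ClassificationInterpretation.from_learner(learn)\n",
"interp.plot_confusion_matrix(figsize=(4, 4))\n",
"interp.plot_top_losses(9, figsize=(10, 8))\n"
]
},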
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predictions shape: torch.Size([848, 2])\n",
"Targets shape: torch.Size([848])\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "ValueError",
"evalue": "Predictions or targets are None. Please check the DataLoader and the test data.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[16], line 25\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# Check if preds or targets are None\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m preds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m targets \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 25\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPredictions or targets are None. Please check the DataLoader and the test data.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m# Print the shapes of preds and targets for debugging\u001b[39;00m\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPredictions shape: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpreds\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mValueError\u001b[0m: Predictions or targets are None. Please check the DataLoader and the test data."
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"\n",
"path = Path(\"Test\")\n",
"\n",
"learn = load_learner('Test/final_model.pkl')\n",
"\n",
"test_dl = learn.dls.test_dl(get_image_files(path))\n",
"\n",
"preds, targets =learn.get_preds(dl=dls.train)\n",
"\n",
"# Check if preds or targets are None\n",
"if preds is None or targets is None:\n",
" raise ValueError(\"Predictions or targets are None. Please check the DataLoader and the test data.\")\n",
"\n",
"# Print the shapes of preds and targets for debugging\n",
"print(f\"Predictions shape: {preds.shape}\")\n",
"print(f\"Targets shape: {targets.shape}\")\n",
"\n",
"# Get predictions and targets\n",
"preds, targets = learn.get_preds(dl=test_dl)\n",
"\n",
"# Check if preds or targets are None\n",
"if preds is None or targets is None:\n",
" raise ValueError(\"Predictions or targets are None. Please check the DataLoader and the test data.\")\n",
"\n",
"# Print the shapes of preds and targets for debugging\n",
"print(f\"Predictions shape: {preds.shape}\")\n",
"print(f\"Targets shape: {targets.shape}\")\n",
"\n",
"# Calculate accuracy\n",
"acc = accuracy(preds, targets)\n",
"print(f'Accuracy on test set: {acc.item():.4f}')\n",
"\n",
"# Plot confusion matrix\n",
"interp = ClassificationInterpretation.from_learner(learn, dl=test_dl)\n",
"interp.plot_confusion_matrix()\n",
"plt.show()\n",
"\n",
"# Print classification report\n",
"y_pred = preds.argmax(dim=1)\n",
"y_true = targets\n",
"print(classification_report(y_true, y_pred, target_names=learn.dls.vocab))\n",
"\n",
"# Print additional information\n",
"print(f\"Number of classes: {len(learn.dls.vocab)}\")\n",
"print(f\"Class names: {learn.dls.vocab}\")\n",
"print(f\"Number of test images: {len(test_dl.dataset)}\")"
]
}
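,
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal inference sketch: classify a single spectrogram with the exported learner.\n",
"# The image chosen below (the first file under Test/) is just a placeholder assumption.\n",
"from fastai.vision.all import *\n",
"\n",
"learn = load_learner('Test/final_model.pkl')\n",
"\n",
"files = get_image_files(Path('Test'))\n",
"if files:\n",
"    pred_class, pred_idx, probs = learn.predict(files[0])\n",
"    print(f'{files[0].name} -> {pred_class} (p={probs[pred_idx].item():.3f})')\n"
]
}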
],
"metadata": {
"kernelspec": {
"display_name": "laugh_detection",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}