LewisBabong commited on
Commit
724d182
·
verified ·
1 Parent(s): ef0f71c

Upload 2 files

Browse files
Files changed (2) hide show
  1. Ntonga_brain_tumor_model.pth +3 -0
  2. transfer.ipynb +212 -0
Ntonga_brain_tumor_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bc00ad0ca0b37db31f9cfe49229d49aa880d31c763c7a6aa1b00423c31e559f
3
+ size 44789468
transfer.ipynb ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# Notebook: Entraînement du modèle de classification d'IRM cérébrales\n",
10
+ "import torch\n",
11
+ "import torch.nn as nn\n",
12
+ "import torch.optim as optim\n",
13
+ "from torchvision import datasets, transforms, models\n",
14
+ "from torch.utils.data import DataLoader, random_split\n",
15
+ "import os"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 2,
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "# Répertoire contenant les images\n",
25
+ "data_dir = \"Brain Tumor MRI Images\""
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 3,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "# Définir les transformations pour l'entraînement et la validation\n",
35
+ "transform = transforms.Compose([\n",
36
+ " transforms.Resize((224, 224)),\n",
37
+ " transforms.RandomHorizontalFlip(),\n",
38
+ " transforms.ToTensor(),\n",
39
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
40
+ " std=[0.229, 0.224, 0.225])\n",
41
+ "])"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 4,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "# Charger l'ensemble des données avec ImageFolder (les dossiers \"Healthy\" et \"Tumor\")\n",
51
+ "full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)\n",
52
+ "dataset_size = len(full_dataset)\n",
53
+ "train_size = int(0.8 * dataset_size)\n",
54
+ "val_size = dataset_size - train_size\n",
55
+ "\n",
56
+ "train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])\n",
57
+ "\n",
58
+ "batch_size = 32\n",
59
+ "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)\n",
60
+ "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 5,
66
+ "metadata": {},
67
+ "outputs": [
68
+ {
69
+ "name": "stderr",
70
+ "output_type": "stream",
71
+ "text": [
72
+ "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
73
+ " warnings.warn(\n",
74
+ "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
75
+ " warnings.warn(msg)\n",
76
+ "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /Users/mac/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n",
77
+ "45.8%\n"
78
+ ]
79
+ },
80
+ {
81
+ "ename": "KeyboardInterrupt",
82
+ "evalue": "",
83
+ "output_type": "error",
84
+ "traceback": [
85
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
86
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
87
+ "Cell \u001b[0;32mIn[5], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Charger le modèle pré-entraîné ResNet18\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mmodels\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresnet18\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpretrained\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m num_ftrs \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mfc\u001b[38;5;241m.\u001b[39min_features\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# Remplacer la dernière couche pour la classification binaire\u001b[39;00m\n",
88
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torchvision/models/_utils.py:142\u001b[0m, in \u001b[0;36mkwonly_to_pos_or_kw.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 135\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 136\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msequence_to_str(\u001b[38;5;28mtuple\u001b[39m(keyword_only_kwargs\u001b[38;5;241m.\u001b[39mkeys()),\u001b[38;5;250m \u001b[39mseparate_last\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mand \u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m as positional \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minstead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 139\u001b[0m )\n\u001b[1;32m 140\u001b[0m kwargs\u001b[38;5;241m.\u001b[39mupdate(keyword_only_kwargs)\n\u001b[0;32m--> 142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
89
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torchvision/models/_utils.py:228\u001b[0m, in \u001b[0;36mhandle_legacy_interface.<locals>.outer_wrapper.<locals>.inner_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m kwargs[pretrained_param]\n\u001b[1;32m 226\u001b[0m kwargs[weights_param] \u001b[38;5;241m=\u001b[39m default_weights_arg\n\u001b[0;32m--> 228\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbuilder\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
90
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torchvision/models/resnet.py:705\u001b[0m, in \u001b[0;36mresnet18\u001b[0;34m(weights, progress, **kwargs)\u001b[0m\n\u001b[1;32m 685\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.\u001b[39;00m\n\u001b[1;32m 686\u001b[0m \n\u001b[1;32m 687\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 701\u001b[0m \u001b[38;5;124;03m :members:\u001b[39;00m\n\u001b[1;32m 702\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 703\u001b[0m weights \u001b[38;5;241m=\u001b[39m ResNet18_Weights\u001b[38;5;241m.\u001b[39mverify(weights)\n\u001b[0;32m--> 705\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_resnet\u001b[49m\u001b[43m(\u001b[49m\u001b[43mBasicBlock\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprogress\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
91
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torchvision/models/resnet.py:301\u001b[0m, in \u001b[0;36m_resnet\u001b[0;34m(block, layers, weights, progress, **kwargs)\u001b[0m\n\u001b[1;32m 298\u001b[0m model \u001b[38;5;241m=\u001b[39m ResNet(block, layers, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 301\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(\u001b[43mweights\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_state_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprogress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheck_hash\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m)\n\u001b[1;32m 303\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model\n",
92
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torchvision/models/_api.py:90\u001b[0m, in \u001b[0;36mWeightsEnum.get_state_dict\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_state_dict\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Mapping[\u001b[38;5;28mstr\u001b[39m, Any]:\n\u001b[0;32m---> 90\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mload_state_dict_from_url\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
93
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/hub.py:766\u001b[0m, in \u001b[0;36mload_state_dict_from_url\u001b[0;34m(url, model_dir, map_location, progress, check_hash, file_name, weights_only)\u001b[0m\n\u001b[1;32m 764\u001b[0m r \u001b[38;5;241m=\u001b[39m HASH_REGEX\u001b[38;5;241m.\u001b[39msearch(filename) \u001b[38;5;66;03m# r is Optional[Match[str]]\u001b[39;00m\n\u001b[1;32m 765\u001b[0m hash_prefix \u001b[38;5;241m=\u001b[39m r\u001b[38;5;241m.\u001b[39mgroup(\u001b[38;5;241m1\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m r \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 766\u001b[0m \u001b[43mdownload_url_to_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcached_file\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhash_prefix\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprogress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 768\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_legacy_zip_format(cached_file):\n\u001b[1;32m 769\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _legacy_zip_load(cached_file, model_dir, map_location, weights_only)\n",
94
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/hub.py:651\u001b[0m, in \u001b[0;36mdownload_url_to_file\u001b[0;34m(url, dst, hash_prefix, progress)\u001b[0m\n\u001b[1;32m 648\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m tqdm(total\u001b[38;5;241m=\u001b[39mfile_size, disable\u001b[38;5;241m=\u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m progress,\n\u001b[1;32m 649\u001b[0m unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m'\u001b[39m, unit_scale\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, unit_divisor\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m pbar:\n\u001b[1;32m 650\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m--> 651\u001b[0m buffer \u001b[38;5;241m=\u001b[39m \u001b[43mu\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m8192\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(buffer) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
95
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/http/client.py:472\u001b[0m, in \u001b[0;36mHTTPResponse.read\u001b[0;34m(self, amt)\u001b[0m\n\u001b[1;32m 469\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlength \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m amt \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlength:\n\u001b[1;32m 470\u001b[0m \u001b[38;5;66;03m# clip the read to the \"end of response\"\u001b[39;00m\n\u001b[1;32m 471\u001b[0m amt \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlength\n\u001b[0;32m--> 472\u001b[0m s \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mamt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 473\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m s \u001b[38;5;129;01mand\u001b[39;00m amt:\n\u001b[1;32m 474\u001b[0m \u001b[38;5;66;03m# Ideally, we would raise IncompleteRead if the content-length\u001b[39;00m\n\u001b[1;32m 475\u001b[0m \u001b[38;5;66;03m# wasn't satisfied, but it might break compatibility.\u001b[39;00m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_close_conn()\n",
96
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socket.py:707\u001b[0m, in \u001b[0;36mSocketIO.readinto\u001b[0;34m(self, b)\u001b[0m\n\u001b[1;32m 705\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 706\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 707\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrecv_into\u001b[49m\u001b[43m(\u001b[49m\u001b[43mb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 708\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m timeout:\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_timeout_occurred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
97
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/ssl.py:1249\u001b[0m, in \u001b[0;36mSSLSocket.recv_into\u001b[0;34m(self, buffer, nbytes, flags)\u001b[0m\n\u001b[1;32m 1245\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m flags \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1246\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1247\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnon-zero flags not allowed in calls to recv_into() on \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m\n\u001b[1;32m 1248\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m)\n\u001b[0;32m-> 1249\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnbytes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1250\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39mrecv_into(buffer, nbytes, flags)\n",
98
+ "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/ssl.py:1105\u001b[0m, in \u001b[0;36mSSLSocket.read\u001b[0;34m(self, len, buffer)\u001b[0m\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1104\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m buffer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1105\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sslobj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sslobj\u001b[38;5;241m.\u001b[39mread(\u001b[38;5;28mlen\u001b[39m)\n",
99
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
100
+ ]
101
+ }
102
+ ],
103
+ "source": [
104
+ "# Charger le modèle pré-entraîné ResNet18\n",
105
+ "model = models.resnet18(pretrained=True)\n",
106
+ "num_ftrs = model.fc.in_features\n",
107
+ "# Remplacer la dernière couche pour la classification binaire\n",
108
+ "model.fc = nn.Linear(num_ftrs, 2)"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "# Définir l'appareil (GPU si disponible)\n",
118
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
119
+ "model = model.to(device)"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "# Critère et optimiseur\n",
129
+ "criterion = nn.CrossEntropyLoss()\n",
130
+ "optimizer = optim.Adam(model.parameters(), lr=0.001)"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": [
139
+ "# Boucle d'entraînement\n",
140
+ "num_epochs = 10\n",
141
+ "best_acc = 0.0\n",
142
+ "\n",
143
+ "for epoch in range(num_epochs):\n",
144
+ " model.train()\n",
145
+ " running_loss = 0.0\n",
146
+ " running_corrects = 0\n",
147
+ "\n",
148
+ " for inputs, labels in train_loader:\n",
149
+ " inputs = inputs.to(device)\n",
150
+ " labels = labels.to(device)\n",
151
+ " \n",
152
+ " optimizer.zero_grad()\n",
153
+ " outputs = model(inputs)\n",
154
+ " loss = criterion(outputs, labels)\n",
155
+ " loss.backward()\n",
156
+ " optimizer.step()\n",
157
+ " \n",
158
+ " running_loss += loss.item() * inputs.size(0)\n",
159
+ " _, preds = torch.max(outputs, 1)\n",
160
+ " running_corrects += torch.sum(preds == labels.data)\n",
161
+ " \n",
162
+ " epoch_loss = running_loss / train_size\n",
163
+ " epoch_acc = running_corrects.double() / train_size\n",
164
+ " print(f\"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Accuracy: {epoch_acc:.4f}\")\n",
165
+ "\n",
166
+ "\n",
167
+ " # Évaluation sur l'ensemble de validation\n",
168
+ " model.eval()\n",
169
+ " val_running_corrects = 0\n",
170
+ " with torch.no_grad():\n",
171
+ " for inputs, labels in val_loader:\n",
172
+ " inputs = inputs.to(device)\n",
173
+ " labels = labels.to(device)\n",
174
+ " outputs = model(inputs)\n",
175
+ " _, preds = torch.max(outputs, 1)\n",
176
+ " val_running_corrects += torch.sum(preds == labels.data)\n",
177
+ " val_acc = val_running_corrects.double() / val_size\n",
178
+ " print(f\"Validation Accuracy: {val_acc:.4f}\")\n",
179
+ "\n",
180
+ " # Sauvegarde du meilleur modèle\n",
181
+ " if val_acc > best_acc:\n",
182
+ " best_acc = val_acc\n",
183
+ " \n",
184
+ "torch.save(model.state_dict(), \"Ntonga_brain_tumor_model.pth\")\n",
185
+ "print(\"Modèle sauvegardé.\")\n",
186
+ "\n",
187
+ "print(\"Entraînement terminé.\")"
188
+ ]
189
+ }
190
+ ],
191
+ "metadata": {
192
+ "kernelspec": {
193
+ "display_name": "Python 3",
194
+ "language": "python",
195
+ "name": "python3"
196
+ },
197
+ "language_info": {
198
+ "codemirror_mode": {
199
+ "name": "ipython",
200
+ "version": 3
201
+ },
202
+ "file_extension": ".py",
203
+ "mimetype": "text/x-python",
204
+ "name": "python",
205
+ "nbconvert_exporter": "python",
206
+ "pygments_lexer": "ipython3",
207
+ "version": "3.12.0"
208
+ }
209
+ },
210
+ "nbformat": 4,
211
+ "nbformat_minor": 2
212
+ }