cyberai-1 commited on
Commit
66d25ea
·
1 Parent(s): 4d2ac45
Files changed (1) hide show
  1. app.py +5 -13
app.py CHANGED
@@ -24,30 +24,22 @@ _tf_model = None
24
 
25
 
26
  class CNN_Torch(nn.Module):
27
- """
28
- CNN PyTorch léger pour images RGB.
29
- Entrée : (B, 3, 150, 150)
30
- Sortie : logits transformés en log_softmax
31
- """
32
- def __init__(self, num_classes: int = 6, dropout: float = 0.5):
33
  super().__init__()
34
 
35
  self.features = nn.Sequential(
36
- # Block 1
37
- nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
38
  nn.BatchNorm2d(32),
39
  nn.ReLU(inplace=True),
40
  nn.MaxPool2d(2),
41
 
42
- # Block 2
43
- nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
44
  nn.BatchNorm2d(64),
45
  nn.ReLU(inplace=True),
46
  nn.MaxPool2d(2),
47
  nn.Dropout2d(0.1),
48
 
49
- # Block 3
50
- nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
51
  nn.BatchNorm2d(128),
52
  nn.ReLU(inplace=True),
53
  nn.MaxPool2d(2),
@@ -60,7 +52,7 @@ class CNN_Torch(nn.Module):
60
  nn.Flatten(),
61
  nn.Linear(128, 256),
62
  nn.ReLU(inplace=True),
63
- nn.Dropout(dropout),
64
  nn.Linear(256, num_classes),
65
  )
66
 
 
24
 
25
 
26
  class CNN_Torch(nn.Module):
27
+ def __init__(self, num_classes=6):
 
 
 
 
 
28
  super().__init__()
29
 
30
  self.features = nn.Sequential(
31
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
 
32
  nn.BatchNorm2d(32),
33
  nn.ReLU(inplace=True),
34
  nn.MaxPool2d(2),
35
 
36
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
 
37
  nn.BatchNorm2d(64),
38
  nn.ReLU(inplace=True),
39
  nn.MaxPool2d(2),
40
  nn.Dropout2d(0.1),
41
 
42
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
 
43
  nn.BatchNorm2d(128),
44
  nn.ReLU(inplace=True),
45
  nn.MaxPool2d(2),
 
52
  nn.Flatten(),
53
  nn.Linear(128, 256),
54
  nn.ReLU(inplace=True),
55
+ nn.Dropout(0.5),
56
  nn.Linear(256, num_classes),
57
  )
58