cyberai-1 commited on
Commit
001319d
·
1 Parent(s): a0faa4d
Files changed (1) hide show
  1. app.py +50 -6
app.py CHANGED
@@ -24,43 +24,79 @@ _tf_model = None
24
 
25
 
26
  class CNN_Torch(nn.Module):
 
 
 
 
 
 
27
  def __init__(self, num_classes=6):
28
  super().__init__()
29
 
30
  self.features = nn.Sequential(
 
31
  nn.Conv2d(3, 32, kernel_size=3, padding=1),
32
  nn.BatchNorm2d(32),
33
  nn.ReLU(inplace=True),
 
 
 
 
 
34
  nn.MaxPool2d(2),
35
 
 
36
  nn.Conv2d(32, 64, kernel_size=3, padding=1),
37
  nn.BatchNorm2d(64),
38
  nn.ReLU(inplace=True),
 
 
 
 
 
39
  nn.MaxPool2d(2),
40
- nn.Dropout2d(0.1),
41
 
 
42
  nn.Conv2d(64, 128, kernel_size=3, padding=1),
43
  nn.BatchNorm2d(128),
44
  nn.ReLU(inplace=True),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  nn.MaxPool2d(2),
46
- nn.Dropout2d(0.2),
47
  )
48
 
49
  self.gap = nn.AdaptiveAvgPool2d(1)
50
 
51
  self.classifier = nn.Sequential(
52
  nn.Flatten(),
53
- nn.Linear(128, 256),
54
  nn.ReLU(inplace=True),
55
- nn.Dropout(0.5),
56
- nn.Linear(256, num_classes),
57
  )
58
 
59
  def forward(self, x):
60
  x = self.features(x)
61
  x = self.gap(x)
62
  x = self.classifier(x)
63
- return F.log_softmax(x, dim=1)
64
 
65
 
66
  def load_pytorch():
@@ -88,6 +124,14 @@ def load_pytorch():
88
  return _pytorch_model
89
 
90
 
 
 
 
 
 
 
 
 
91
  def read_input_image():
92
  if "image" in request.files and request.files["image"].filename:
93
  return Image.open(io.BytesIO(request.files["image"].read())).convert("RGB")
 
24
 
25
 
26
  class CNN_Torch(nn.Module):
27
+ """
28
+ CNN amélioré 4 blocs
29
+ Entrée : (B, 3, 150, 150)
30
+ Sortie : (B, num_classes)
31
+ """
32
+
33
  def __init__(self, num_classes=6):
34
  super().__init__()
35
 
36
  self.features = nn.Sequential(
37
+ # Block 1: 150 -> 75
38
  nn.Conv2d(3, 32, kernel_size=3, padding=1),
39
  nn.BatchNorm2d(32),
40
  nn.ReLU(inplace=True),
41
+
42
+ nn.Conv2d(32, 32, kernel_size=3, padding=1),
43
+ nn.BatchNorm2d(32),
44
+ nn.ReLU(inplace=True),
45
+
46
  nn.MaxPool2d(2),
47
 
48
+ # Block 2: 75 -> 37
49
  nn.Conv2d(32, 64, kernel_size=3, padding=1),
50
  nn.BatchNorm2d(64),
51
  nn.ReLU(inplace=True),
52
+
53
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
54
+ nn.BatchNorm2d(64),
55
+ nn.ReLU(inplace=True),
56
+
57
  nn.MaxPool2d(2),
58
+ nn.Dropout2d(0.10),
59
 
60
+ # Block 3: 37 -> 18
61
  nn.Conv2d(64, 128, kernel_size=3, padding=1),
62
  nn.BatchNorm2d(128),
63
  nn.ReLU(inplace=True),
64
+
65
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
66
+ nn.BatchNorm2d(128),
67
+ nn.ReLU(inplace=True),
68
+
69
+ nn.MaxPool2d(2),
70
+ nn.Dropout2d(0.15),
71
+
72
+ # Block 4: 18 -> 9
73
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
74
+ nn.BatchNorm2d(256),
75
+ nn.ReLU(inplace=True),
76
+
77
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
78
+ nn.BatchNorm2d(256),
79
+ nn.ReLU(inplace=True),
80
+
81
  nn.MaxPool2d(2),
82
+ nn.Dropout2d(0.20),
83
  )
84
 
85
  self.gap = nn.AdaptiveAvgPool2d(1)
86
 
87
  self.classifier = nn.Sequential(
88
  nn.Flatten(),
89
+ nn.Linear(256, 256),
90
  nn.ReLU(inplace=True),
91
+ nn.Dropout(0.3),
92
+ nn.Linear(256, num_classes)
93
  )
94
 
95
  def forward(self, x):
96
  x = self.features(x)
97
  x = self.gap(x)
98
  x = self.classifier(x)
99
+ return x
100
 
101
 
102
  def load_pytorch():
 
124
  return _pytorch_model
125
 
126
 
127
+ def load_tensorflow():
128
+ global _tf_model
129
+ if _tf_model is None:
130
+ import tensorflow as tf
131
+ _tf_model = tf.keras.models.load_model("parfait_model.keras")
132
+ return _tf_model
133
+
134
+
135
  def read_input_image():
136
  if "image" in request.files and request.files["image"].filename:
137
  return Image.open(io.BytesIO(request.files["image"].read())).convert("RGB")