cyberai-1 commited on
Commit
b899fa4
Β·
1 Parent(s): fd654aa
Files changed (3) hide show
  1. app.py +60 -51
  2. parfait_model.keras +1 -1
  3. parfait_model.pth +2 -2
app.py CHANGED
@@ -8,43 +8,55 @@ import numpy as np
8
  from flask import Flask, jsonify, render_template, request
9
  from PIL import Image
10
 
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from torchvision import transforms
15
-
16
  app = Flask(__name__)
17
 
18
- CLASSES = ["buildings", "forest", "glacier", "mountain", "sea", "street"]
19
  IMG_SIZE = 150
20
 
21
  _pytorch_model = None
22
- _tf_model = None
 
 
 
 
 
 
 
23
 
24
 
25
  class CNN_Torch(nn.Module):
26
  """
27
- CNN PyTorch pour images RGB (3, 150, 150)
28
- Retourne des log-probabilitΓ©s via log_softmax.
 
 
 
 
 
 
 
 
 
 
29
  """
30
  def __init__(self, num_classes: int = 6, dropout: float = 0.5):
31
  super().__init__()
32
 
33
  self.features = nn.Sequential(
34
- # Block 1: 150 -> 75
35
  nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
36
  nn.BatchNorm2d(32),
37
  nn.ReLU(inplace=True),
38
  nn.MaxPool2d(2),
39
 
40
- # Block 2: 75 -> 37
41
  nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
42
  nn.BatchNorm2d(64),
43
  nn.ReLU(inplace=True),
44
  nn.MaxPool2d(2),
45
  nn.Dropout2d(0.1),
46
 
47
- # Block 3: 37 -> 18
48
  nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
49
  nn.BatchNorm2d(128),
50
  nn.ReLU(inplace=True),
@@ -52,7 +64,7 @@ class CNN_Torch(nn.Module):
52
  nn.Dropout2d(0.2),
53
  )
54
 
55
- self.gap = nn.AdaptiveAvgPool2d(1)
56
 
57
  self.classifier = nn.Sequential(
58
  nn.Flatten(),
@@ -68,29 +80,17 @@ class CNN_Torch(nn.Module):
68
  x = self.classifier(x)
69
  return F.log_softmax(x, dim=1)
70
 
71
-
72
- def load_pytorch():
73
- global _pytorch_model
74
- if _pytorch_model is not None:
75
- return _pytorch_model
76
-
77
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
- model = CNN_Torch(num_classes=6).to(device)
79
-
80
- state_dict = torch.load("parfait_model.pth", map_location=device)
81
- model.load_state_dict(state_dict)
82
  model.eval()
83
 
84
- tf_transform = transforms.Compose([
85
  transforms.Resize((IMG_SIZE, IMG_SIZE)),
86
  transforms.ToTensor(),
87
- transforms.Normalize(
88
- [0.485, 0.456, 0.406],
89
- [0.229, 0.224, 0.225]
90
- ),
91
  ])
92
-
93
- _pytorch_model = (model, device, tf_transform)
94
  return _pytorch_model
95
 
96
 
@@ -102,55 +102,64 @@ def load_tensorflow():
102
  return _tf_model
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  @app.route("/")
106
  def index():
107
  return render_template("index.html")
108
 
109
 
 
 
110
  @app.route("/predict", methods=["POST"])
111
  def predict():
112
- if "image" not in request.files:
113
- return jsonify({"error": "Aucune image fournie"}), 400
114
 
115
- framework = request.form.get("model", "pytorch")
116
 
117
  try:
118
- pil_img = Image.open(io.BytesIO(request.files["image"].read())).convert("RGB")
119
  except Exception:
120
  return jsonify({"error": "Fichier image invalide"}), 400
121
 
122
  try:
123
  if framework == "pytorch":
124
- model, device, tf_transform = load_pytorch()
125
- tensor = tf_transform(pil_img).unsqueeze(0).to(device)
126
-
127
  with torch.no_grad():
128
- out = model(tensor)
129
  probs = torch.exp(out).cpu().numpy()[0]
130
-
131
  else:
132
  model = load_tensorflow()
133
- arr = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE)), dtype=np.float32)
134
- arr = np.expand_dims(arr, axis=0)
135
- probs = model.predict(arr, verbose=0)[0]
136
 
137
  pred_idx = int(np.argmax(probs))
138
  return jsonify({
139
- "class": CLASSES[pred_idx],
140
- "confidence": float(probs[pred_idx]),
141
- "probabilities": {
142
- c: float(p) for c, p in zip(CLASSES, probs)
143
- },
144
  })
145
 
146
  except FileNotFoundError as e:
147
- return jsonify({
148
- "error": f"Modèle introuvable : {e}. Placez les fichiers .pth et .keras à la racine."
149
- }), 500
150
  except Exception as e:
151
  return jsonify({"error": str(e)}), 500
152
 
153
 
154
  if __name__ == "__main__":
155
  port = int(os.environ.get("PORT", 5000))
156
- app.run(host="0.0.0.0", port=port, debug=False)
 
8
  from flask import Flask, jsonify, render_template, request
9
  from PIL import Image
10
 
 
 
 
 
 
11
  app = Flask(__name__)
12
 
13
+ CLASSES = ["buildings", "forest", "glacier", "mountain", "sea", "street"]
14
  IMG_SIZE = 150
15
 
16
  _pytorch_model = None
17
+ _tf_model = None
18
+
19
+
20
+ # ── Loaders ────────────────────────────────────────────────────────────────────
21
+ def load_pytorch():
22
+ global _pytorch_model
23
+ if _pytorch_model is not None:
24
+ return _pytorch_model
25
 
26
 
27
  class CNN_Torch(nn.Module):
28
  """
29
+ CNN PyTorch allΓ©gΓ© pour images RGB (3 canaux).
30
+ EntrΓ©e : (B, 3, 150, 150)
31
+ Sortie : (B, num_classes) β€” log-softmax
32
+
33
+ Architecture :
34
+ Block 1 : Conv2d(3β†’32) + BN + ReLU + MaxPool2d(2) β†’ 75Γ—75
35
+ Block 2 : Conv2d(32β†’64) + BN + ReLU + MaxPool2d(2) + Drop2d β†’ 37Γ—37
36
+ Block 3 : Conv2d(64β†’128)+ BN + ReLU + MaxPool2d(2) + Drop2d β†’ 18Γ—18
37
+ GAP : AdaptiveAvgPool2d(1) β†’ (B,128)
38
+ Head : Linear(128β†’256) + ReLU + Dropout + Linear(256β†’C)
39
+
40
+ Paramètre `dropout` contrôlé depuis l'extérieur → utilisé dans le CV.
41
  """
42
  def __init__(self, num_classes: int = 6, dropout: float = 0.5):
43
  super().__init__()
44
 
45
  self.features = nn.Sequential(
46
+ # Block 1 β€” 150 β†’ 75
47
  nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
48
  nn.BatchNorm2d(32),
49
  nn.ReLU(inplace=True),
50
  nn.MaxPool2d(2),
51
 
52
+ # Block 2 β€” 75 β†’ 37
53
  nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
54
  nn.BatchNorm2d(64),
55
  nn.ReLU(inplace=True),
56
  nn.MaxPool2d(2),
57
  nn.Dropout2d(0.1),
58
 
59
+ # Block 3 β€” 37 β†’ 18
60
  nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
61
  nn.BatchNorm2d(128),
62
  nn.ReLU(inplace=True),
 
64
  nn.Dropout2d(0.2),
65
  )
66
 
67
+ self.gap = nn.AdaptiveAvgPool2d(1) # (B, 128, 18, 18) β†’ (B, 128, 1, 1)
68
 
69
  self.classifier = nn.Sequential(
70
  nn.Flatten(),
 
80
  x = self.classifier(x)
81
  return F.log_softmax(x, dim=1)
82
 
 
 
 
 
 
 
83
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
+ model = CNN_Torch(6).to(device)
85
+ model.load_state_dict(torch.load("parfait_model.pth", map_location=device))
 
 
86
  model.eval()
87
 
88
+ tf = transforms.Compose([
89
  transforms.Resize((IMG_SIZE, IMG_SIZE)),
90
  transforms.ToTensor(),
91
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
 
 
 
92
  ])
93
+ _pytorch_model = (model, device, tf)
 
94
  return _pytorch_model
95
 
96
 
 
102
  return _tf_model
103
 
104
 
105
+
106
+ def read_input_image():
107
+ if "image" in request.files and request.files["image"].filename:
108
+ return Image.open(io.BytesIO(request.files["image"].read())).convert("RGB")
109
+
110
+ image_url = request.form.get("image_url", "").strip()
111
+ if image_url:
112
+ with urllib.request.urlopen(image_url) as resp:
113
+ return Image.open(io.BytesIO(resp.read())).convert("RGB")
114
+
115
+ raise ValueError("No image provided")
116
+
117
+ # ── Routes ─────────────────────────────────────────────────────────────────────
118
  @app.route("/")
119
  def index():
120
  return render_template("index.html")
121
 
122
 
123
+
124
+
125
  @app.route("/predict", methods=["POST"])
126
  def predict():
127
+ # if "image" not in request.files:
128
+ # return jsonify({"error": "Aucune image fournie"}), 400
129
 
130
+ # framework = request.form.get("model", "pytorch")
131
 
132
  try:
133
+ pil_img = read_input_image()
134
  except Exception:
135
  return jsonify({"error": "Fichier image invalide"}), 400
136
 
137
  try:
138
  if framework == "pytorch":
139
+ import torch
140
+ model, device, tf = load_pytorch()
141
+ tensor = tf(pil_img).unsqueeze(0).to(device)
142
  with torch.no_grad():
143
+ out = model(tensor)
144
  probs = torch.exp(out).cpu().numpy()[0]
 
145
  else:
146
  model = load_tensorflow()
147
+ arr = np.array(pil_img.resize((IMG_SIZE, IMG_SIZE)), dtype=np.float32)
148
+ probs = model.predict(np.expand_dims(arr, 0), verbose=0)[0]
 
149
 
150
  pred_idx = int(np.argmax(probs))
151
  return jsonify({
152
+ "class": CLASSES[pred_idx],
153
+ "confidence": float(probs[pred_idx]),
154
+ "probabilities": {c: float(p) for c, p in zip(CLASSES, probs)},
 
 
155
  })
156
 
157
  except FileNotFoundError as e:
158
+ return jsonify({"error": f"Modèle introuvable : {e}. Placez les fichiers .pth et .keras à la racine."}), 500
 
 
159
  except Exception as e:
160
  return jsonify({"error": str(e)}), 500
161
 
162
 
163
  if __name__ == "__main__":
164
  port = int(os.environ.get("PORT", 5000))
165
+ app.run(host="0.0.0.0", port=port, debug=False)
parfait_model.keras CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a477cdf46a8abe2d6c9aede0dee169e4d3b0f5706f57e8e40a8db3ed1ac0c133
3
  size 1619355
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70968e23930d4f12cd05b9beaa9ffd0fe969d8fb2bb312153c599adf1d8a0ade
3
  size 1619355
parfait_model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1ad964d670faae671f200a74dc90713a00af83e18f2114bcddc521b496a29f96
3
- size 521545
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fe727faa7a44e19589bdfe86434ab17416ff240f150d84df7f08e4d4862be0f
3
+ size 523282