github-actions[bot] commited on
Commit
a5b9bcb
·
1 Parent(s): c7caad9

🚀 Deploy from GitHub Actions - 2026-02-09 15:34:11

Browse files
Files changed (1) hide show
  1. app.py +48 -9
app.py CHANGED
@@ -173,6 +173,14 @@ async def startup_event():
173
  cache_dir="/tmp/models"
174
  )
175
 
 
 
 
 
 
 
 
 
176
  onnx_session = ort.InferenceSession(onnx_path)
177
  print("✅ ONNX chargé directement")
178
 
@@ -189,6 +197,10 @@ async def startup_event():
189
  filename="pytorch_model.bin",
190
  cache_dir="/tmp/models"
191
  )
 
 
 
 
192
 
193
  # -------------------------
194
  # 2. Charger PyTorch
@@ -206,7 +218,16 @@ async def startup_event():
206
  NUM_CLASSES
207
  )
208
 
209
- state_dict = torch.load(bin_path, map_location=DEVICE)
 
 
 
 
 
 
 
 
 
210
  model.load_state_dict(state_dict, strict=False)
211
  model.eval()
212
 
@@ -219,32 +240,50 @@ async def startup_event():
219
 
220
  dummy = torch.randn(1, 3, 224, 224)
221
 
 
222
  torch.onnx.export(
223
  model,
224
  dummy,
225
  tmp_onnx,
226
- export_params=True,
227
- opset_version=17,
228
- do_constant_folding=False,
229
  input_names=["input"],
230
- output_names=["output"]
 
 
 
 
 
231
  )
232
 
233
  print("✅ Conversion ONNX locale OK")
 
 
 
 
 
 
 
234
 
235
  # -------------------------
236
  # 4. ORT session
237
  # -------------------------
238
  onnx_session = ort.InferenceSession(tmp_onnx)
 
 
 
 
 
239
 
240
  except Exception as e2:
241
  print(f"❌ Fallback PyTorch échoué : {e2}")
242
  onnx_session = None
243
 
244
- if onnx_session:
245
- input_name = onnx_session.get_inputs()[0].name
246
- input_shape = onnx_session.get_inputs()[0].shape
247
- print(f" Input : {input_name} {input_shape}\n")
248
 
249
  # 2. Database
250
  if NEON_DATABASE_URL:
 
173
  cache_dir="/tmp/models"
174
  )
175
 
176
+ # ✅ Vérifier la taille avant de charger
177
+ file_size_mb = os.path.getsize(onnx_path) / 1e6
178
+ print(f" ONNX file size: {file_size_mb:.2f} MB")
179
+
180
+ if file_size_mb < 10:
181
+ print(f"⚠️ ONNX file too small ({file_size_mb:.2f} MB), using fallback")
182
+ raise ValueError("ONNX file incomplete")
183
+
184
  onnx_session = ort.InferenceSession(onnx_path)
185
  print("✅ ONNX chargé directement")
186
 
 
197
  filename="pytorch_model.bin",
198
  cache_dir="/tmp/models"
199
  )
200
+
201
+ # ✅ Vérifier la taille du .bin
202
+ bin_size_mb = os.path.getsize(bin_path) / 1e6
203
+ print(f" PyTorch .bin size: {bin_size_mb:.2f} MB")
204
 
205
  # -------------------------
206
  # 2. Charger PyTorch
 
218
  NUM_CLASSES
219
  )
220
 
221
+ # CORRECTION : Ajouter weights_only=False
222
+ state_dict = torch.load(bin_path, map_location=DEVICE, weights_only=False)
223
+
224
+ # ✅ CORRECTION : Gérer les cas où state_dict est nested
225
+ if isinstance(state_dict, dict):
226
+ if 'model' in state_dict:
227
+ state_dict = state_dict['model']
228
+ elif 'state_dict' in state_dict:
229
+ state_dict = state_dict['state_dict']
230
+
231
  model.load_state_dict(state_dict, strict=False)
232
  model.eval()
233
 
 
240
 
241
  dummy = torch.randn(1, 3, 224, 224)
242
 
243
+ # ✅ CORRECTION PRINCIPALE : do_constant_folding=True
244
  torch.onnx.export(
245
  model,
246
  dummy,
247
  tmp_onnx,
248
+ export_params=True, # ✅ OK
249
+ opset_version=17, # ✅ OK
250
+ do_constant_folding=True, # ✅ CHANGÉ : True au lieu de False !
251
  input_names=["input"],
252
+ output_names=["output"],
253
+ dynamic_axes={ # ✅ AJOUTÉ : Pour batch dynamique
254
+ 'input': {0: 'batch_size'},
255
+ 'output': {0: 'batch_size'}
256
+ },
257
+ verbose=False
258
  )
259
 
260
  print("✅ Conversion ONNX locale OK")
261
+
262
+ # ✅ AJOUTÉ : Vérifier la taille du ONNX
263
+ onnx_size_mb = os.path.getsize(tmp_onnx) / 1e6
264
+ print(f" ONNX file size: {onnx_size_mb:.2f} MB")
265
+
266
+ if onnx_size_mb < 10:
267
+ raise ValueError(f"ONNX file too small ({onnx_size_mb:.2f} MB)! Weights not exported.")
268
 
269
  # -------------------------
270
  # 4. ORT session
271
  # -------------------------
272
  onnx_session = ort.InferenceSession(tmp_onnx)
273
+
274
+ # ✅ AJOUTÉ : Test que le modèle marche
275
+ test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
276
+ test_output = onnx_session.run(['output'], {'input': test_input})
277
+ print(f" Test inference OK, output shape: {test_output[0].shape}")
278
 
279
  except Exception as e2:
280
  print(f"❌ Fallback PyTorch échoué : {e2}")
281
  onnx_session = None
282
 
283
+ if onnx_session:
284
+ input_name = onnx_session.get_inputs()[0].name
285
+ input_shape = onnx_session.get_inputs()[0].shape
286
+ print(f" Input : {input_name} {input_shape}\n")
287
 
288
  # 2. Database
289
  if NEON_DATABASE_URL: