rkonan commited on
Commit
9b3e2a4
·
1 Parent(s): 9ee11ef

ajout version cache heatmap

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. app/main.py +0 -1
  3. app/model.py +72 -2
  4. app/utils.py +1 -1
.gitignore CHANGED
@@ -17,3 +17,6 @@ app/__pycache__/
17
 
18
  # Fichiers de logs
19
  *.log
 
 
 
 
17
 
18
  # Fichiers de logs
19
  *.log
20
+
21
+ #cache des heatmaps
22
+ cache_heatmaps/
app/main.py CHANGED
@@ -14,7 +14,6 @@ from app.log import logger
14
  from app.config import MODEL_NAME, ENV,MODEL_TYPE
15
 
16
 
17
-
18
  logger.info(f"ENV :{ENV}")
19
 
20
  app = FastAPI()
 
14
  from app.config import MODEL_NAME, ENV,MODEL_TYPE
15
 
16
 
 
17
  logger.info(f"ENV :{ENV}")
18
 
19
  app = FastAPI()
app/model.py CHANGED
@@ -9,7 +9,7 @@ from keras.applications.efficientnet_v2 import preprocess_input as effnet_prepro
9
  import io
10
  from tf_keras_vis.gradcam import Gradcam,GradcamPlusPlus
11
  from tf_keras_vis.utils import normalize
12
-
13
  import numpy as np
14
  import tensorflow as tf
15
  from tf_keras_vis.saliency import Saliency
@@ -20,8 +20,15 @@ from tf_keras_vis.saliency import Saliency
20
  from tf_keras_vis.utils import normalize
21
  import logging
22
  import time
 
23
 
24
  from typing import TypedDict, Callable, Any
 
 
 
 
 
 
25
  logging.basicConfig(
26
  level=logging.INFO, # ou logging.DEBUG
27
  format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
@@ -322,7 +329,70 @@ def compute_entropy_safe(probas):
322
  return entropy
323
 
324
 
325
- def get_heatmap(config, image_bytes: bytes,predicted_class_index):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  result={}
327
  try:
328
  _,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"])
 
9
  import io
10
  from tf_keras_vis.gradcam import Gradcam,GradcamPlusPlus
11
  from tf_keras_vis.utils import normalize
12
+ import hashlib
13
  import numpy as np
14
  import tensorflow as tf
15
  from tf_keras_vis.saliency import Saliency
 
20
  from tf_keras_vis.utils import normalize
21
  import logging
22
  import time
23
+ import os
24
 
25
  from typing import TypedDict, Callable, Any
26
+
27
+ HEATMAP_CACHE = {}
28
+ CACHE_DIR = "./cache_heatmaps"
29
+ os.makedirs(CACHE_DIR, exist_ok=True)
30
+
31
+
32
  logging.basicConfig(
33
  level=logging.INFO, # ou logging.DEBUG
34
  format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
 
329
  return entropy
330
 
331
 
332
+ def hash_image_bytes(image_bytes):
333
+ return hashlib.md5(image_bytes).hexdigest()
334
+
335
+
336
+ def get_heatmap(config, image_bytes: bytes, predicted_class_index):
337
+ result = {}
338
+ try:
339
+ image_hash = hash_image_bytes(image_bytes)
340
+ cache_key = f"{image_hash}_{predicted_class_index}"
341
+
342
+ # Vérification cache mémoire d'abord
343
+ if cache_key in HEATMAP_CACHE:
344
+ logger.info(f"✅ Heatmap trouvée dans cache mémoire pour {cache_key}")
345
+ result["heatmap"] = HEATMAP_CACHE[cache_key]
346
+ return result
347
+
348
+ # Vérification cache disque ensuite
349
+ cache_file_path = os.path.join(CACHE_DIR, f"{cache_key}.pkl")
350
+ if os.path.exists(cache_file_path):
351
+ logger.info(f"✅ Heatmap trouvée sur disque pour {cache_key}")
352
+ with open(cache_file_path, "rb") as f:
353
+ cached_heatmap = pickle.load(f)
354
+ result["heatmap"] = cached_heatmap
355
+ # On remet aussi en mémoire pour accélérer prochaines requêtes
356
+ HEATMAP_CACHE[cache_key] = cached_heatmap
357
+ return result
358
+
359
+ # Calcul si non trouvé dans le cache
360
+ _, raw_input = preprocess_image(
361
+ image_bytes, config["target_size"], config["preprocess_input"]
362
+ )
363
+ logger.info("✅ Début de la génération de la heatmap")
364
+ start_time = time.time()
365
+
366
+ heatmap = compute_gradcam(
367
+ config["gradcam_model"],
368
+ raw_input,
369
+ class_index=predicted_class_index,
370
+ layer_name=config["last_conv_layer"],
371
+ gradcam_type=config["gradcam_type"],
372
+ )
373
+
374
+ elapsed_time = time.time() - start_time
375
+ logger.info(f"✅ Heatmap générée en {elapsed_time:.2f} secondes")
376
+
377
+ # Conversion en liste pour le JSON
378
+ heatmap_list = heatmap.tolist()
379
+ result["heatmap"] = heatmap_list
380
+
381
+ # Sauvegarde dans cache mémoire
382
+ HEATMAP_CACHE[cache_key] = heatmap_list
383
+
384
+ # Sauvegarde sur disque
385
+ with open(cache_file_path, "wb") as f:
386
+ pickle.dump(heatmap_list, f)
387
+
388
+ except Exception as e:
389
+ logger.error(f"❌ Erreur lors de la génération de la heatmap: {e}")
390
+ result["heatmap"] = []
391
+
392
+ return result
393
+
394
+
395
+ def get_heatmap_old(config, image_bytes: bytes,predicted_class_index):
396
  result={}
397
  try:
398
  _,raw_input = preprocess_image(image_bytes,config["target_size"],config["preprocess_input"])
app/utils.py CHANGED
@@ -30,7 +30,7 @@ def register_with_orchestrator():
30
  logger.info(f"📡 Tentative d'enregistrement de {MODEL_NAME} à l'orchestrateur...")
31
  response = requests.post(
32
  f"{ORCHESTRATOR_URL}/register_model",
33
- json={"model_name": MODEL_NAME, "model_type": MODEL_TYPE,"url": f"{OWN_URL}/predict"}
34
  )
35
  if response.status_code == 200:
36
  logger.info("✅ Modèle enregistré avec succès")
 
30
  logger.info(f"📡 Tentative d'enregistrement de {MODEL_NAME} à l'orchestrateur...")
31
  response = requests.post(
32
  f"{ORCHESTRATOR_URL}/register_model",
33
+ json={"model_name": MODEL_NAME, "model_type": MODEL_TYPE,"url": f"{OWN_URL}"}
34
  )
35
  if response.status_code == 200:
36
  logger.info("✅ Modèle enregistré avec succès")