functionNormally commited on
Commit
27c7e24
·
1 Parent(s): 8e71d97

Persister les features sur disque /tmp pour partage entre workers Gradio

Browse files
Files changed (1) hide show
  1. backbone_utils.py +29 -7
backbone_utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import numpy as np
2
  import torch
3
  import torch.nn as nn
@@ -10,6 +12,9 @@ from config import HF_BACKBONE_REPO, HF_TOKEN
10
  _BACKBONE = None
11
  _FEATURES_CACHE = None
12
 
 
 
 
13
 
14
  def load_backbone(device: torch.device) -> nn.Module:
15
  global _BACKBONE
@@ -17,12 +22,6 @@ def load_backbone(device: torch.device) -> nn.Module:
17
  if _BACKBONE is not None:
18
  return _BACKBONE.to(device)
19
 
20
- if not HF_BACKBONE_REPO:
21
- raise RuntimeError(
22
- "HF_BACKBONE_REPO n'est pas configuré. "
23
- "Ajoutez-le dans les Secrets du Space Hugging Face."
24
- )
25
-
26
  pt_path = hf_hub_download(
27
  repo_id=HF_BACKBONE_REPO,
28
  filename="resnet18_charcoal_backbone.pt",
@@ -73,9 +72,32 @@ def extract_all_features(batch_size: int = 64):
73
  }
74
  counts[split_name] = len(cache[split_name]["y"])
75
 
 
 
 
 
 
 
 
 
76
  _FEATURES_CACHE = cache
77
  return cache, class_names, counts
78
 
79
 
80
  def get_cached_features():
81
- return _FEATURES_CACHE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
  import numpy as np
4
  import torch
5
  import torch.nn as nn
 
12
  _BACKBONE = None
13
  _FEATURES_CACHE = None
14
 
15
+ # Partagé entre tous les workers Gradio (même process group)
16
+ _DISK_CACHE_PATH = "/tmp/charcoal_features.npz"
17
+
18
 
19
  def load_backbone(device: torch.device) -> nn.Module:
20
  global _BACKBONE
 
22
  if _BACKBONE is not None:
23
  return _BACKBONE.to(device)
24
 
 
 
 
 
 
 
25
  pt_path = hf_hub_download(
26
  repo_id=HF_BACKBONE_REPO,
27
  filename="resnet18_charcoal_backbone.pt",
 
72
  }
73
  counts[split_name] = len(cache[split_name]["y"])
74
 
75
+ # Sauvegarde sur disque pour que tous les workers Gradio y aient accès
76
+ np.savez(
77
+ _DISK_CACHE_PATH,
78
+ train_X=cache["train"]["X"], train_y=cache["train"]["y"],
79
+ validation_X=cache["validation"]["X"], validation_y=cache["validation"]["y"],
80
+ test_X=cache["test"]["X"], test_y=cache["test"]["y"],
81
+ )
82
+
83
  _FEATURES_CACHE = cache
84
  return cache, class_names, counts
85
 
86
 
87
  def get_cached_features():
88
+ global _FEATURES_CACHE
89
+
90
+ if _FEATURES_CACHE is not None:
91
+ return _FEATURES_CACHE
92
+
93
+ # Essaye de charger depuis le disque (autre worker a peut-être déjà extrait)
94
+ if os.path.exists(_DISK_CACHE_PATH):
95
+ data = np.load(_DISK_CACHE_PATH)
96
+ _FEATURES_CACHE = {
97
+ "train": {"X": data["train_X"], "y": data["train_y"]},
98
+ "validation": {"X": data["validation_X"], "y": data["validation_y"]},
99
+ "test": {"X": data["test_X"], "y": data["test_y"]},
100
+ }
101
+ return _FEATURES_CACHE
102
+
103
+ return None