Spaces:
Sleeping
Sleeping
File size: 3,053 Bytes
27c7e24 cdc317a 27c7e24 cdc317a 27c7e24 cdc317a 27c7e24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | import os
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torch.utils.data import DataLoader
from torchvision import models
from config import HF_BACKBONE_REPO, HF_TOKEN
# Process-local cache of the loaded backbone module; set by load_backbone().
_BACKBONE = None
# Process-local cache of the extracted features; set by extract_all_features()
# or, when read back from disk, by get_cached_features().
_FEATURES_CACHE = None
# Shared between all Gradio workers (same process group).
_DISK_CACHE_PATH = "/tmp/charcoal_features.npz"
def load_backbone(device: torch.device) -> nn.Module:
    """Return the frozen ResNet-18 feature extractor, moved to ``device``.

    The first call downloads ``resnet18_charcoal_backbone.pt`` from the
    ``HF_BACKBONE_REPO`` model repository, loads it into a ResNet-18 whose
    ``fc`` head is replaced by ``nn.Identity`` (so the network yields features
    instead of class logits), freezes every parameter and switches the module
    to evaluation mode. The result is cached in ``_BACKBONE`` and reused by
    every later call in the same process.

    Args:
        device: Device the backbone is moved to before being returned.

    Returns:
        The cached backbone module. The same object is returned on every
        call; ``Module.to`` moves it in place.
    """
    global _BACKBONE

    if _BACKBONE is None:
        checkpoint_path = hf_hub_download(
            repo_id=HF_BACKBONE_REPO,
            filename="resnet18_charcoal_backbone.pt",
            token=HF_TOKEN,
            repo_type="model",
        )

        backbone = models.resnet18()
        # The head must be swapped before loading so the module's keys match
        # the checkpoint (presumably saved without ``fc`` weights -- confirm).
        backbone.fc = nn.Identity()
        # NOTE(review): torch.load unpickles the checkpoint. Consider passing
        # weights_only=True once the minimum supported PyTorch allows it.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        backbone.load_state_dict(state_dict)

        for param in backbone.parameters():
            param.requires_grad = False

        # A frozen extractor must not run in training mode: BatchNorm would
        # normalise with batch statistics and keep updating its running stats.
        backbone.eval()

        _BACKBONE = backbone

    return _BACKBONE.to(device)
def _save_features_to_disk(cache):
    """Atomically write the train/validation/test arrays to ``_DISK_CACHE_PATH``."""
    import tempfile

    arrays = {}
    for split_name in ("train", "validation", "test"):
        arrays[f"{split_name}_X"] = cache[split_name]["X"]
        arrays[f"{split_name}_y"] = cache[split_name]["y"]

    # Write to a unique sibling file first, then rename it over the final
    # path. Other workers probing the shared path therefore see either no
    # file or a complete archive, never a partially written one.
    # Note: mkstemp creates the file readable by the current user only.
    target_dir = os.path.dirname(_DISK_CACHE_PATH) or "."
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".npz.tmp")
    try:
        # Passing a file object keeps np.savez from appending ".npz".
        with os.fdopen(fd, "wb") as tmp_file:
            np.savez(tmp_file, **arrays)
        os.replace(tmp_path, _DISK_CACHE_PATH)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def extract_all_features(batch_size: int = 64):
    """Run every dataset split through the backbone and cache the features.

    Each split returned by ``prepare_splits()`` is wrapped with the
    evaluation transform, fed through the backbone in batches without
    gradient tracking, and collected into NumPy arrays. The result is stored
    in the process-local ``_FEATURES_CACHE`` and written to
    ``_DISK_CACHE_PATH`` so that other Gradio workers can load it.

    Args:
        batch_size: Number of samples per forward pass.

    Returns:
        A ``(cache, class_names, counts)`` tuple where ``cache`` maps each
        split name to ``{"X": features, "y": labels}``, ``class_names`` is
        the value of ``get_class_names()``, and ``counts`` maps each split
        name to its number of samples.

    Raises:
        KeyError: If the splits lack ``"train"``, ``"validation"`` or
            ``"test"``, which the on-disk archive requires.
    """
    global _FEATURES_CACHE
    # Imported inside the function, as in the original -- presumably to avoid
    # a circular import with data_utils (confirm).
    from data_utils import prepare_splits, get_class_names, HFDatasetWrapper, get_eval_transform

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    backbone = load_backbone(device)
    backbone.eval()

    splits = prepare_splits()
    class_names = get_class_names()

    cache = {}
    counts = {}
    for split_name, split_data in splits.items():
        dataset = HFDatasetWrapper(split_data, get_eval_transform())
        # shuffle=False keeps the feature rows in dataset order.
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        X_parts, y_parts = [], []
        with torch.no_grad():
            for images, labels in loader:
                features = backbone(images.to(device))
                X_parts.append(features.cpu().numpy())
                y_parts.append(labels.numpy())

        cache[split_name] = {
            "X": np.concatenate(X_parts, axis=0),
            "y": np.concatenate(y_parts, axis=0),
        }
        counts[split_name] = len(cache[split_name]["y"])

    # Persist to disk so that every Gradio worker can access the features.
    _save_features_to_disk(cache)

    _FEATURES_CACHE = cache
    return cache, class_names, counts
def get_cached_features():
    """Return previously extracted features, or ``None`` if none are available.

    The process-local ``_FEATURES_CACHE`` is returned when it is set.
    Otherwise the archive at ``_DISK_CACHE_PATH`` is tried, since another
    worker may already have run the extraction. On success the loaded arrays
    are stored in ``_FEATURES_CACHE`` for later calls.

    Returns:
        A dict mapping ``"train"``, ``"validation"`` and ``"test"`` to
        ``{"X": features, "y": labels}``, or ``None`` when there is no
        in-memory cache and the disk archive is missing or unreadable.
    """
    global _FEATURES_CACHE
    if _FEATURES_CACHE is not None:
        return _FEATURES_CACHE

    import zipfile

    # Try to load from disk (another worker may already have extracted).
    # Attempting the load directly avoids an exists()/load() race. A missing,
    # partially written or otherwise unusable archive counts as a cache miss.
    try:
        # The context manager closes the archive's file handle. Indexing
        # reads each array fully into memory, so nothing refers to the file
        # once the block exits.
        with np.load(_DISK_CACHE_PATH) as archive:
            loaded = {
                split_name: {
                    "X": archive[f"{split_name}_X"],
                    "y": archive[f"{split_name}_y"],
                }
                for split_name in ("train", "validation", "test")
            }
    except (OSError, EOFError, ValueError, KeyError, zipfile.BadZipFile):
        return None

    _FEATURES_CACHE = loaded
    return _FEATURES_CACHE
|