import contextlib
import os
import uuid
import zipfile

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torch.utils.data import DataLoader
from torchvision import models

from config import HF_BACKBONE_REPO, HF_TOKEN

# Process-local memo of the frozen backbone (filled by load_backbone).
_BACKBONE = None
# Process-local memo of extracted features: {split_name: {"X": ndarray, "y": ndarray}}.
_FEATURES_CACHE = None

# Shared by all Gradio workers (same process group): features extracted by one
# worker are written here so the others can load them instead of re-running
# the backbone.
_DISK_CACHE_PATH = "/tmp/charcoal_features.npz"


def load_backbone(device: torch.device) -> nn.Module:
    """Return the frozen ResNet-18 feature extractor, moved to ``device``.

    On the first call the weights file ``resnet18_charcoal_backbone.pt`` is
    downloaded from ``HF_BACKBONE_REPO``. It is loaded into a torchvision
    ResNet-18 whose ``fc`` head is replaced by ``nn.Identity``, so the forward
    pass returns the pre-classifier features. Every parameter is set to
    ``requires_grad=False``.

    The module is memoized in ``_BACKBONE``. Later calls only move it to
    ``device``. ``nn.Module.to`` moves the module in place, so the shared
    instance ends up on whichever device was requested last.
    """
    global _BACKBONE
    if _BACKBONE is not None:
        return _BACKBONE.to(device)

    pt_path = hf_hub_download(
        repo_id=HF_BACKBONE_REPO,
        filename="resnet18_charcoal_backbone.pt",
        token=HF_TOKEN,
        repo_type="model",
    )
    backbone = models.resnet18()
    backbone.fc = nn.Identity()
    # NOTE(review): torch.load unpickles a file downloaded from the hub. The
    # file is expected to hold a plain state dict. If the pinned torch version
    # supports it (>= 1.13), pass weights_only=True so arbitrary pickled
    # objects are refused.
    backbone.load_state_dict(torch.load(pt_path, map_location="cpu"))
    for p in backbone.parameters():
        p.requires_grad = False
    _BACKBONE = backbone
    return _BACKBONE.to(device)


def _save_features_to_disk(cache):
    """Atomically write ``cache`` to ``_DISK_CACHE_PATH`` as an ``.npz`` archive.

    Each split is stored under the keys ``<split>_X`` and ``<split>_y``. The
    archive is first written to a uniquely named temp file in the same
    directory, then moved into place with ``os.replace``. A concurrent reader
    therefore sees either the previous file or the complete new one, never a
    partial write.
    """
    arrays = {}
    for split_name, split in cache.items():
        arrays[f"{split_name}_X"] = split["X"]
        arrays[f"{split_name}_y"] = split["y"]

    tmp_path = f"{_DISK_CACHE_PATH}.{os.getpid()}.{uuid.uuid4().hex}.tmp"
    try:
        # Passing an open file object stops np.savez from appending ".npz".
        with open(tmp_path, "wb") as fh:
            np.savez(fh, **arrays)
        os.replace(tmp_path, _DISK_CACHE_PATH)
    finally:
        # On success the temp file has already been renamed away; on failure
        # remove the leftover without masking the original exception.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)


def _load_features_from_disk():
    """Load a feature cache written by ``_save_features_to_disk``.

    Returns:
        ``{split_name: {"X": ndarray, "y": ndarray}}`` containing every split
        that has both arrays. Returns ``None`` if the file is missing,
        unreadable or corrupt, or holds no complete split, so the caller can
        treat it as a cache miss.
    """
    cache = {}
    try:
        # The context manager closes the archive's file handle. Indexing
        # copies each array into memory, so the arrays stay valid afterwards.
        with np.load(_DISK_CACHE_PATH) as data:
            for key in data.files:
                split_name, _, field = key.rpartition("_")
                if split_name and field in ("X", "y"):
                    cache.setdefault(split_name, {})[field] = data[key]
    except (OSError, EOFError, ValueError, KeyError, zipfile.BadZipFile):
        return None

    complete = {
        name: split for name, split in cache.items() if "X" in split and "y" in split
    }
    return complete or None


def extract_all_features(batch_size: int = 64):
    """Run the frozen backbone over every dataset split and cache the features.

    Args:
        batch_size: Batch size of the DataLoader used for each split.

    Returns:
        A tuple ``(cache, class_names, counts)``:
        ``cache`` maps each split name returned by ``prepare_splits()`` to
        ``{"X": features, "y": labels}`` (stacked NumPy arrays);
        ``class_names`` is ``get_class_names()``;
        ``counts`` maps each split name to its number of samples.

    Side effects:
        Writes the features to ``_DISK_CACHE_PATH`` and stores ``cache`` in
        the module-level ``_FEATURES_CACHE``.
    """
    global _FEATURES_CACHE
    # Local import, kept as in the original (presumably to avoid a circular or
    # heavy import at module load — confirm).
    from data_utils import prepare_splits, get_class_names, HFDatasetWrapper, get_eval_transform

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    backbone = load_backbone(device)
    backbone.eval()

    splits = prepare_splits()
    class_names = get_class_names()

    cache = {}
    counts = {}
    for split_name, split_data in splits.items():
        dataset = HFDatasetWrapper(split_data, get_eval_transform())
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        X_parts, y_parts = [], []
        with torch.no_grad():
            for images, labels in loader:
                features = backbone(images.to(device))
                X_parts.append(features.cpu().numpy())
                y_parts.append(labels.numpy())
        cache[split_name] = {
            "X": np.concatenate(X_parts, axis=0),
            "y": np.concatenate(y_parts, axis=0),
        }
        counts[split_name] = len(cache[split_name]["y"])

    # Persist to disk so every Gradio worker can reuse the extracted features.
    _save_features_to_disk(cache)

    _FEATURES_CACHE = cache
    return cache, class_names, counts


def get_cached_features():
    """Return previously extracted features, or ``None`` if none are available.

    The in-process ``_FEATURES_CACHE`` is checked first. If it is empty, the
    on-disk archive is tried, since another worker may already have run the
    extraction. A successful disk load is memoized in ``_FEATURES_CACHE``.
    """
    global _FEATURES_CACHE
    if _FEATURES_CACHE is not None:
        return _FEATURES_CACHE

    _FEATURES_CACHE = _load_features_from_disk()
    return _FEATURES_CACHE