File size: 3,053 Bytes
27c7e24
 
cdc317a
 
 
 
 
 
 
 
 
 
 
 
27c7e24
 
 
cdc317a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27c7e24
 
 
 
 
 
 
 
cdc317a
 
 
 
 
27c7e24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from torch.utils.data import DataLoader
from torchvision import models

from config import HF_BACKBONE_REPO, HF_TOKEN

# Module-level caches: _BACKBONE is set by load_backbone(), _FEATURES_CACHE
# by extract_all_features() / get_cached_features(). Both start out empty.
_BACKBONE = None
_FEATURES_CACHE = None

# Shared between all Gradio workers (same process group)
_DISK_CACHE_PATH = "/tmp/charcoal_features.npz"


def load_backbone(device: torch.device) -> nn.Module:
    """Return the frozen ResNet-18 backbone, moved to *device*.

    On the first call the weights file ``resnet18_charcoal_backbone.pt`` is
    downloaded from ``HF_BACKBONE_REPO`` (authenticated with ``HF_TOKEN``),
    loaded into a torchvision ``resnet18`` whose ``fc`` layer is replaced by
    ``nn.Identity``, and every parameter is marked as not requiring
    gradients. The resulting module is kept in the module-level ``_BACKBONE``
    so later calls only move the same instance to the requested device.

    Args:
        device: Target device for the returned module.

    Returns:
        The cached backbone after ``.to(device)``.
    """
    global _BACKBONE

    if _BACKBONE is None:
        weights_file = hf_hub_download(
            repo_id=HF_BACKBONE_REPO,
            filename="resnet18_charcoal_backbone.pt",
            token=HF_TOKEN,
            repo_type="model",
        )

        net = models.resnet18()
        # Drop the classification head so the forward pass yields features.
        net.fc = nn.Identity()

        # NOTE(review): torch.load unpickles the downloaded file; confirm the
        # repo is trusted, or consider weights_only=True if the installed
        # torch version supports it.
        state_dict = torch.load(weights_file, map_location="cpu")
        net.load_state_dict(state_dict)

        # Freeze the weights: the backbone is only used as a feature extractor.
        for param in net.parameters():
            param.requires_grad = False

        _BACKBONE = net

    return _BACKBONE.to(device)


def extract_all_features(batch_size: int = 64):
    """Run the backbone over every dataset split and cache the features.

    Each split returned by ``prepare_splits()`` is wrapped in
    ``HFDatasetWrapper`` with the evaluation transform and passed through the
    backbone in batches under ``torch.no_grad()``. The per-batch outputs and
    labels are concatenated along axis 0 into NumPy arrays.

    The result is stored in the module-level ``_FEATURES_CACHE`` and written
    to ``_DISK_CACHE_PATH`` as an ``.npz`` archive so that other workers can
    load it. The archive is written to a temporary file first and then moved
    into place with ``os.replace``, so a concurrent reader never observes a
    half-written file.

    Args:
        batch_size: Number of samples per ``DataLoader`` batch.

    Returns:
        A tuple ``(cache, class_names, counts)`` where ``cache`` maps each
        split name to ``{"X": features, "y": labels}``, ``class_names`` is
        the value of ``get_class_names()``, and ``counts`` maps each split
        name to the number of labels in that split.

    Raises:
        KeyError: If the splits do not include ``"train"``, ``"validation"``
            and ``"test"`` (all three are required for the disk archive).
    """
    global _FEATURES_CACHE

    # NOTE(review): imported lazily inside the function — presumably to avoid
    # a circular import with data_utils; confirm before hoisting.
    from data_utils import prepare_splits, get_class_names, HFDatasetWrapper, get_eval_transform

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    backbone = load_backbone(device)
    backbone.eval()

    splits = prepare_splits()
    class_names = get_class_names()

    cache = {}
    counts = {}

    for split_name, split_data in splits.items():
        dataset = HFDatasetWrapper(split_data, get_eval_transform())
        # shuffle=False keeps features aligned with the dataset order.
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        X_parts, y_parts = [], []
        with torch.no_grad():
            for images, labels in loader:
                features = backbone(images.to(device))
                X_parts.append(features.cpu().numpy())
                y_parts.append(labels.numpy())

        cache[split_name] = {
            "X": np.concatenate(X_parts, axis=0),
            "y": np.concatenate(y_parts, axis=0),
        }
        counts[split_name] = len(cache[split_name]["y"])

    # Save to disk so that all Gradio workers can access the features.
    # Write to a sibling temp file and rename it into place: readers check
    # for the final path, so they must never see a partially written archive.
    tmp_path = f"{_DISK_CACHE_PATH}.{os.getpid()}.tmp"
    try:
        # Passing a file object stops np.savez from appending ".npz" to the
        # temporary name.
        with open(tmp_path, "wb") as fh:
            np.savez(
                fh,
                train_X=cache["train"]["X"],
                train_y=cache["train"]["y"],
                validation_X=cache["validation"]["X"],
                validation_y=cache["validation"]["y"],
                test_X=cache["test"]["X"],
                test_y=cache["test"]["y"],
            )
        os.replace(tmp_path, _DISK_CACHE_PATH)
    finally:
        # Only still present if writing or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    _FEATURES_CACHE = cache
    return cache, class_names, counts


def get_cached_features():
    """Return the cached features, loading them from disk if necessary.

    The in-memory ``_FEATURES_CACHE`` is returned when it is already set.
    Otherwise the archive at ``_DISK_CACHE_PATH`` is read (another worker may
    already have extracted the features), stored in ``_FEATURES_CACHE`` and
    returned.

    Returns:
        A dict mapping ``"train"``, ``"validation"`` and ``"test"`` to
        ``{"X": ..., "y": ...}`` arrays, or ``None`` when neither the
        in-memory cache nor the disk archive is available.

    Raises:
        KeyError: If the archive exists but lacks one of the expected arrays.
    """
    global _FEATURES_CACHE

    if _FEATURES_CACHE is not None:
        return _FEATURES_CACHE

    # Try to load from disk (another worker may already have extracted).
    # Opening directly instead of checking os.path.exists first avoids a
    # check-then-use race if the file disappears in between.
    try:
        archive = np.load(_DISK_CACHE_PATH)
    except FileNotFoundError:
        return None

    # The context manager closes the underlying file handle once the arrays
    # have been read into memory.
    with archive:
        _FEATURES_CACHE = {
            split: {"X": archive[f"{split}_X"], "y": archive[f"{split}_y"]}
            for split in ("train", "validation", "test")
        }
    return _FEATURES_CACHE