File size: 4,176 Bytes
5e94db5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70e1e32
5e94db5
 
 
70e1e32
5e94db5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from __future__ import annotations
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.data import resolve_model_data_config, create_transform
from contextlib import nullcontext

from .utils import load_tag_names

class EVAHeadPreserving:
    """Head-preserving inference for EVA-02 backbones (Animetimm / WD-EVA02).

    The timm backbone keeps its own classifier head, whose logits are exposed
    as per-tag scores. A separate linear head, loaded from a checkpoint, is
    applied to the L2-normalised pre-logits features to score the custom
    ``categories``.

    Interface: encode / logits / prob / tags_prob / top_tags
    """

    def __init__(self,
                 repo_id: str,
                 head_path: str,
                 categories: List[str],
                 tag_csv: str = "selected_tags.csv"):
        """Load the backbone, build the custom head and restore its weights.

        Args:
            repo_id: Hub repository id; loaded through timm as
                ``hf-hub:{repo_id}`` with pretrained weights.
            head_path: Path of the custom-head checkpoint. It is read with
                ``torch.load(..., weights_only=True)`` and must provide
                ``"head.weight"`` and ``"head.bias"``, either at the top level
                or nested under ``"state_dict"``.
            categories: Names of the custom head outputs; their count fixes
                the output width of the custom head.
            tag_csv: Tag file handed to ``load_tag_names``.

        Raises:
            KeyError: If the checkpoint lacks ``head.weight`` / ``head.bias``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.categories = list(categories)
        self.tag_csv = tag_csv

        self.backbone = timm.create_model(f"hf-hub:{repo_id}", pretrained=True)
        self.backbone = self.backbone.to(self.device).eval().requires_grad_(False)

        cfg = resolve_model_data_config(self.backbone)
        self.preprocess = create_transform(**cfg)

        # Push one blank image through the backbone to discover the
        # pre-logits width and the number of backbone tag outputs.
        with torch.no_grad():
            in_size = cfg.get("input_size", (3, 448, 448))
            height, width = int(in_size[-2]), int(in_size[-1])
            dummy = torch.zeros(1, 3, height, width, device=self.device)
            fx = self.backbone.forward_features(dummy)
            pre = self.backbone.forward_head(fx, pre_logits=True)
            tags_log = self.backbone.forward_head(fx, pre_logits=False)

        # Kept on the instance so encode() can shape results for empty batches.
        self.feat_dim = int(pre.shape[-1])
        self.num_tags = int(tags_log.shape[-1])

        self.custom_head = (
            nn.Linear(self.feat_dim, len(self.categories))
            .to(self.device)
            .eval()
            .requires_grad_(False)
        )

        ckpt = torch.load(head_path, map_location=self.device, weights_only=True)
        state = ckpt.get("state_dict", ckpt)

        head_w = state["head.weight"].to(self.device).float()
        head_b = state["head.bias"].to(self.device).float()
        # Accept a weight matrix stored transposed (in_features x out_features).
        target_shape = self.custom_head.weight.shape
        if head_w.shape != target_shape and head_w.t().shape == target_shape:
            head_w = head_w.t()

        with torch.no_grad():
            self.custom_head.weight.copy_(head_w)
            self.custom_head.bias.copy_(head_b)

        self.tag_names = load_tag_names(self.num_tags, self.tag_csv)

        # Mixed precision is off by default; callers may flip this flag.
        self.use_amp = False
        if self.device == "cuda":
            # Process-wide backend switches.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True

    @torch.inference_mode()
    def encode(self, pil_list: List) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the backbone on a batch of images.

        Args:
            pil_list: Images exposing ``convert("RGB")`` (e.g. PIL images).

        Returns:
            ``(feat, tags_log)`` as float32 tensors: the L2-normalised
            pre-logits features and the backbone head logits, one row per
            image. An empty input yields tensors with zero rows.
        """
        images = list(pil_list)
        if not images:
            # torch.stack rejects an empty list; return correctly shaped
            # empty results instead of failing.
            return (
                torch.zeros(0, self.feat_dim, device=self.device),
                torch.zeros(0, self.num_tags, device=self.device),
            )

        x = torch.stack([self.preprocess(im.convert("RGB")) for im in images], 0)
        x = x.to(self.device, non_blocking=True, memory_format=torch.channels_last)

        # CUDA autocast is only meaningful when the model lives on a CUDA device.
        on_cuda = str(self.device).startswith("cuda")
        if self.use_amp and on_cuda:
            ctx = torch.amp.autocast("cuda", dtype=self.torch_dtype)
        else:
            ctx = nullcontext()

        with ctx:
            fx = self.backbone.forward_features(x)
            pre = self.backbone.forward_head(fx, pre_logits=True)
            feat = F.normalize(pre, dim=1)
            tags_log = self.backbone.forward_head(fx, pre_logits=False)
        return feat.float(), tags_log.float()

    @torch.inference_mode()
    def logits(self, pil_list: List) -> torch.Tensor:
        """Return custom-head logits, one row per image."""
        feat_norm, _ = self.encode(pil_list)
        return self.custom_head(feat_norm)

    @torch.inference_mode()
    def prob(self, pil_list: List) -> torch.Tensor:
        """Return sigmoid of the custom-head logits, clamped to [-20, 20] first."""
        z = torch.clamp(self.logits(pil_list), -20, 20)
        return torch.sigmoid(z)

    @torch.inference_mode()
    def tags_prob(self, pil_list: List) -> torch.Tensor:
        """Return sigmoid of the backbone tag logits, clamped to [-20, 20] first."""
        _, tags_log = self.encode(pil_list)
        z = torch.clamp(tags_log, -20, 20)
        return torch.sigmoid(z)

    @torch.inference_mode()
    def top_tags(self, pil_image, top_k: int = 50):
        """Return the ``top_k`` highest-scoring backbone tags for one image.

        Args:
            pil_image: A single image accepted by ``encode``.
            top_k: Maximum number of tags to return; values below zero give
                an empty list and values above the tag count are capped.

        Returns:
            A list of ``(tag_name, probability)`` pairs ordered by descending
            probability. Indices without an entry in ``tag_names`` are named
            ``tag_XXXX`` from their index.
        """
        p = self.tags_prob([pil_image])[0].tolist()
        k = max(0, min(top_k, len(p)))
        idx = sorted(range(len(p)), key=lambda i: -p[i])[:k]
        names = self.tag_names
        return [(names[i] if i < len(names) else f"tag_{i:04d}", float(p[i])) for i in idx]