File size: 3,265 Bytes
a919b01
 
 
e241c21
a919b01
 
 
e241c21
 
a919b01
 
 
 
 
e241c21
 
 
 
 
 
a919b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e241c21
 
a919b01
 
 
 
 
 
 
 
e241c21
 
 
 
 
 
 
 
 
 
 
 
 
 
a919b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cb4ff7
a919b01
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import gradio as gr
from PIL import Image
import yaml

import clip
from datasets.imagenet import imagenet_templates
from utils import get_clip_logits
from dota import DOTA


# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Class-list file; each non-blank line is "<token> <class name>" and only the
# class name (everything after the first whitespace run) is kept.
CLASSNAMES_PATH = 'data/imagenet-sketch/classnames.txt'
# How many highest-probability classes predict() reports.
TOP_K = 5

def load_cfg(path="configs/vit/imagenet.yaml"):
    """Load the YAML run configuration stored at ``path``.

    The file is decoded as ``utf-8-sig`` so that a leading byte-order mark,
    if present, is stripped before parsing. ``yaml.safe_load`` is used so the
    file cannot construct arbitrary Python objects.

    Returns:
        The parsed document. An empty file (for which ``yaml.safe_load``
        returns ``None``) yields ``{}`` instead, so callers can still use
        ``cfg.get(...)`` with defaults rather than crashing on ``None``.
    """
    with open(path, 'r', encoding='utf-8-sig') as f:
        data = yaml.safe_load(f)
    return {} if data is None else data

# Module-level initialisation, executed once at import time: read the run
# configuration from load_cfg()'s default path, then load the CLIP ViT-B/16
# model onto DEVICE together with the image transform that predict() applies
# to uploaded PIL images.
cfg = load_cfg()
clip_model, preprocess = clip.load("ViT-B/16", device=DEVICE)
# Inference-only app: keep the backbone in eval mode.
clip_model.eval()


def load_imagenet_classnames(path: str):
    """Read class names from a whitespace-separated "<token> <name>" file.

    Each line is stripped and split once on its first whitespace run. The
    leading token is discarded and the remainder (which may contain spaces)
    is kept as the class name. Blank lines, and lines with no text after the
    first token, are skipped. File order is preserved.

    Args:
        path: UTF-8 text file to read.

    Returns:
        List of class-name strings in file order.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        tokenised = [raw.strip().split(maxsplit=1) for raw in handle]
    # Blank lines give [] and single-token lines give [token]; keep only
    # lines that actually carry a name.
    return [fields[1] for fields in tokenised if len(fields) == 2]


# Ordered class names. Position i is the class whose text embedding
# build_clip_weights places in column i, and the label predict() reports for
# logit index i.
IMAGENET_CLASSNAMES = load_imagenet_classnames(CLASSNAMES_PATH)


def build_clip_weights(classnames, templates):
    """Build a zero-shot CLIP text classifier from prompt ensembles.

    For every class name, each ``str.format``-style template is filled in
    and encoded with the module-level ``clip_model``. Each prompt embedding
    is L2-normalised, the prompts are averaged, and the average is
    L2-normalised again.

    Args:
        classnames: Iterable of class-name strings, in output-column order.
        templates: Prompt templates containing a ``{}`` placeholder.

    Returns:
        Tensor on ``DEVICE`` with one normalised class embedding per
        column (classes stacked along dim 1).
    """
    def _class_embedding(name):
        # Encode every prompt variant for one class and return the
        # renormalised mean of the unit-length prompt embeddings.
        prompts = [tpl.format(name) for tpl in templates]
        encoded = clip_model.encode_text(clip.tokenize(prompts).to(DEVICE))
        encoded = encoded / encoded.norm(dim=-1, keepdim=True)
        centroid = encoded.mean(dim=0)
        return centroid / centroid.norm()

    with torch.no_grad():
        columns = [_class_embedding(name) for name in classnames]
        return torch.stack(columns, dim=1).to(DEVICE)


# Zero-shot text classifier: one normalised prompt-ensemble embedding per
# class, stacked along dim 1 by build_clip_weights, so shape[0] is the
# embedding size and shape[1] the number of classes.
CLIP_WEIGHTS = build_clip_weights(IMAGENET_CLASSNAMES, imagenet_templates)
# DOTA head sized from that classifier. It gets a clone of the weights,
# presumably so that anything DOTA does to its copy cannot alter the
# CLIP_WEIGHTS tensor used by get_clip_logits in predict() — TODO confirm
# whether DOTA mutates clip_weights in place.
dota_model = DOTA(cfg, input_shape=CLIP_WEIGHTS.shape[0], num_classes=CLIP_WEIGHTS.shape[1], clip_weights=CLIP_WEIGHTS.clone())
# NOTE(review): assumes DOTA exposes an nn.Module-style eval(); confirm what
# it toggles.
dota_model.eval()

def predict(image: Image.Image):
    """Classify one uploaded image with CLIP + DOTA, then adapt DOTA on it.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the input
            was cleared.

    Returns:
        ``[]`` for a missing image. Otherwise a list of up to ``TOP_K``
        dicts ``{"class": str, "probability": float}`` in descending
        probability order. Fewer than ``TOP_K`` entries are returned when
        fewer classes are loaded, instead of ``topk`` raising.

    Side effect:
        After the prediction is computed, ``dota_model.fit`` and
        ``dota_model.update`` are called on this image's features. These
        presumably update the shared module-level model, which would make
        later predictions depend on earlier requests.
    """
    if image is None:
        return []

    img = preprocess(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # loss and pred are returned by get_clip_logits but unused here.
        image_features, clip_logits, _loss, prob_map, _pred = get_clip_logits(img, clip_model, CLIP_WEIGHTS)

        # Blend the zero-shot CLIP logits with the DOTA head's logits. The
        # DOTA weight is rho * mean(dota_model.c) / (number of feature rows),
        # capped at eta (config defaults 0.005 and 0.2).
        dota_logits = dota_model.predict(image_features.mean(0).unsqueeze(0))
        dota_weight = min(
            cfg.get('eta', 0.2),
            cfg.get('rho', 0.005) * float(dota_model.c.mean()) / max(1, image_features.size(0)),
        )
        final_logits = clip_logits + dota_weight * dota_logits

        # Test-time adaptation: the prediction above is made first, then DOTA
        # is fitted on this sample using CLIP's probabilities as soft labels.
        dota_model.fit(image_features, prob_map)
        dota_model.update()

        # Softmax in float32 so the reported probabilities do not depend on
        # the model's working precision.
        probs = torch.softmax(final_logits.float(), dim=-1).squeeze(0)

    # Clamp k so a class list shorter than TOP_K does not make topk raise.
    k = min(TOP_K, probs.size(-1))
    top_probs, top_indices = probs.topk(k)

    return [
        {"class": IMAGENET_CLASSNAMES[idx], "probability": prob}
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    ]


# UI blurb shown under the title; in English: "After uploading an image,
# output Top-5 probability results based on the predefined ImageNet class
# list."
description = "上传图片后,基于 ImageNet 预定义类别列表输出 Top-5 概率结果。"

# Single-image Gradio app. The input is a PIL image (label: "Upload image")
# and the output is predict()'s list of {"class", "probability"} dicts
# rendered as JSON (label: "Top-5 prediction results"). Flagging is disabled.
# NOTE: the "Top-5" wording in these strings is hard-coded and will not
# follow TOP_K if that constant changes.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="上传图片"),
    outputs=gr.JSON(label="Top-5 预测结果"),
    description=description,
    flagging_mode="never",
)


if __name__ == '__main__':
    # Listen on all network interfaces (0.0.0.0) without creating a public
    # Gradio share link.
    iface.launch(server_name='0.0.0.0', share=False)