"""Gradio demo: classify an uploaded image with CLIP logits plus a DOTA term.

Importing this module reads the YAML config and the class-name file, loads
the CLIP ViT-B/16 model, builds the per-class text weights, and constructs
the DOTA model. Running it as a script launches the Gradio interface.
"""

import torch
import gradio as gr
from PIL import Image
import yaml
import clip

from datasets.imagenet import imagenet_templates
from utils import get_clip_logits
from dota import DOTA

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CLASSNAMES_PATH = 'data/imagenet-sketch/classnames.txt'
TOP_K = 5


def load_cfg(path: str = "configs/vit/imagenet.yaml"):
    """Parse the YAML file at ``path`` and return the loaded object.

    The file is opened as ``utf-8-sig`` so a leading byte-order mark, if
    present, is not passed to the YAML parser. ``yaml.safe_load`` is used, so
    only plain YAML types are constructed. Callers in this file use ``.get``
    on the result, i.e. they expect the top level to be a mapping.
    """
    with open(path, 'r', encoding='utf-8-sig') as f:
        return yaml.safe_load(f)


cfg = load_cfg()
clip_model, preprocess = clip.load("ViT-B/16", device=DEVICE)
clip_model.eval()


def load_imagenet_classnames(path: str) -> list[str]:
    """Read class names from a text file, one entry per line.

    Each non-blank line is split once on whitespace; the part after the first
    token is kept as the class name (it may itself contain spaces). Lines
    consisting of a single token are skipped. The first token is presumably
    an identifier such as a synset id -- confirm against the data file.

    Returns:
        Class names in file order.
    """
    classnames = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split(maxsplit=1)
            if len(parts) == 2:
                classnames.append(parts[1])
    return classnames


IMAGENET_CLASSNAMES = load_imagenet_classnames(CLASSNAMES_PATH)


def build_clip_weights(classnames, templates):
    """Build one normalized CLIP text embedding per class.

    For every class name, each template is filled in with ``str.format``,
    tokenized and encoded by the module-level ``clip_model``. The per-prompt
    embeddings are L2-normalized, averaged over the prompts, and the average
    is normalized again.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings with one positional ``{}`` slot.

    Returns:
        A tensor on ``DEVICE`` whose columns are the class embeddings
        (classes are stacked along ``dim=1``).
    """
    with torch.no_grad():
        weights = []
        for classname in classnames:
            texts = [template.format(classname) for template in templates]
            tokens = clip.tokenize(texts).to(DEVICE)
            text_features = clip_model.encode_text(tokens)
            # Normalize each prompt embedding before averaging so every
            # prompt contributes with unit length.
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            class_feature = text_features.mean(dim=0)
            # The mean of unit vectors is generally not unit length.
            class_feature = class_feature / class_feature.norm()
            weights.append(class_feature)
        return torch.stack(weights, dim=1).to(DEVICE)


CLIP_WEIGHTS = build_clip_weights(IMAGENET_CLASSNAMES, imagenet_templates)
dota_model = DOTA(
    cfg,
    input_shape=CLIP_WEIGHTS.shape[0],
    num_classes=CLIP_WEIGHTS.shape[1],
    # A copy is handed over so CLIP_WEIGHTS itself is not shared with DOTA.
    clip_weights=CLIP_WEIGHTS.clone(),
)
dota_model.eval()


def predict(image: Image.Image):
    """Classify ``image`` and return the highest-probability classes.

    The final logits are the CLIP logits plus the DOTA logits scaled by a
    weight that is capped at ``cfg['eta']`` (default 0.2). After the logits
    are computed, ``dota_model.fit`` and ``dota_model.update`` are called, so
    every call also changes the module-level DOTA model.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A list of ``{"class": str, "probability": float}`` dicts with at most
        ``TOP_K`` entries, ordered as returned by ``Tensor.topk``; an empty
        list when ``image`` is ``None``.
    """
    if image is None:
        return []

    img = preprocess(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # Only three of the five returned values are used here.
        image_features, clip_logits, _, prob_map, _ = get_clip_logits(
            img, clip_model, CLIP_WEIGHTS
        )

        # DOTA logits computation
        dota_logits = dota_model.predict(image_features.mean(0).unsqueeze(0))
        eta = cfg.get('eta', 0.2)
        rho = cfg.get('rho', 0.005)
        # max(1, ...) guards the division against a zero-sized first dimension.
        feature_count = max(1, image_features.size(0))
        dota_weight = min(eta, rho * float(dota_model.c.mean()) / feature_count)
        final_logits = clip_logits + dota_weight * dota_logits

        # In testing environments, DOTA usually updates from streaming data.
        # We perform an adaptation step here, after the logits for this image
        # have already been computed.
        dota_model.fit(image_features, prob_map)
        dota_model.update()

        probs = torch.softmax(final_logits, dim=-1).squeeze(0)
        # topk raises if k exceeds the number of classes, so clamp it.
        k = min(TOP_K, probs.shape[-1])
        top_probs, top_indices = probs.topk(k)

    # tolist() yields plain Python floats/ints, so the list lookup and the
    # JSON output do not depend on 0-dim tensors.
    return [
        {"class": IMAGENET_CLASSNAMES[idx], "probability": prob}
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    ]


description = "上传图片后,基于 ImageNet 预定义类别列表输出 Top-5 概率结果。"

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="上传图片"),
    outputs=gr.JSON(label="Top-5 预测结果"),
    description=description,
    flagging_mode="never",
)

if __name__ == '__main__':
    # server_name='0.0.0.0' binds to all network interfaces.
    iface.launch(server_name='0.0.0.0', share=False)