Spaces:
Running
Running
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| import yaml | |
| import clip | |
| from datasets.imagenet import imagenet_templates | |
| from utils import get_clip_logits | |
| from dota import DOTA | |
# Use the GPU when torch reports one is available; otherwise run on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Text file from which the class label names are loaded.
CLASSNAMES_PATH = 'data/imagenet-sketch/classnames.txt'

# How many of the highest-probability classes predict() reports.
TOP_K = 5
def load_cfg(path="configs/vit/imagenet.yaml"):
    """Parse the YAML file at *path* and return the resulting object.

    The file is opened with the ``utf-8-sig`` codec, so a leading UTF-8
    byte-order mark (if present) is stripped before the YAML parser sees
    the text. Parsing uses ``yaml.safe_load``.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    with open(path, 'r', encoding='utf-8-sig') as stream:
        parsed = yaml.safe_load(stream)
    return parsed
# --- One-time setup, executed when the module is imported ---

# Configuration parsed from the default YAML path.
cfg = load_cfg()

# clip.load hands back the model together with the image transform that is
# later applied to uploaded pictures in predict().
clip_model, preprocess = clip.load("ViT-B/16", device=DEVICE)

# The model is only used for inference in this file, so put it in eval mode.
clip_model.eval()
def load_imagenet_classnames(path: str):
    """Read class names from a whitespace-separated ``<token> <name>`` file.

    Each line is stripped and split once on whitespace. The first field is
    discarded and the remainder of the line (which may itself contain
    spaces) is kept as the class name. Blank lines and lines consisting of
    a single field are skipped.

    Args:
        path: Location of the UTF-8 encoded class-name file.

    Returns:
        The class names as a list of strings, in file order.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        # A blank line strips to "" and splits to [], so the length check
        # below filters it out together with single-field lines.
        split_rows = (row.strip().split(maxsplit=1) for row in handle)
        return [fields[1] for fields in split_rows if len(fields) == 2]
# Label names loaded once at import time; predict() indexes this list with
# the class positions returned by topk.
IMAGENET_CLASSNAMES = load_imagenet_classnames(CLASSNAMES_PATH)
def build_clip_weights(classnames, templates):
    """Build one normalized CLIP text embedding per class.

    For every class name, each template is filled in with the name, the
    resulting prompts are encoded with ``clip_model.encode_text``, the
    per-prompt embeddings are L2-normalized and averaged, and the average
    is normalized again.

    Args:
        classnames: Iterable of class-name strings.
        templates: Collection of format strings with one ``{}`` placeholder;
            it is iterated once per class.

    Returns:
        A tensor on ``DEVICE`` with the class embeddings stacked along
        dimension 1 (one column per class).
    """

    def _embed_class(name):
        # Encode every prompt for this class and collapse them to one vector.
        prompts = [tpl.format(name) for tpl in templates]
        prompt_tokens = clip.tokenize(prompts).to(DEVICE)
        embeddings = clip_model.encode_text(prompt_tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        centroid = embeddings.mean(dim=0)
        # Re-normalize: the mean of unit vectors is generally not unit length.
        return centroid / centroid.norm()

    with torch.no_grad():
        per_class = [_embed_class(name) for name in classnames]
        return torch.stack(per_class, dim=1).to(DEVICE)
# Text-side class embeddings, computed once at import time.
CLIP_WEIGHTS = build_clip_weights(IMAGENET_CLASSNAMES, imagenet_templates)

# DOTA is sized from the weight matrix (dim 0 -> input_shape, dim 1 ->
# num_classes) and receives its own copy of the weights via .clone().
dota_model = DOTA(
    cfg,
    input_shape=CLIP_WEIGHTS.shape[0],
    num_classes=CLIP_WEIGHTS.shape[1],
    clip_weights=CLIP_WEIGHTS.clone(),
)
dota_model.eval()
def predict(image: Image.Image):
    """Classify *image* and return its most probable classes.

    The image is preprocessed and scored with ``get_clip_logits``; the
    DOTA logits are added on top, scaled by a weight capped at
    ``cfg['eta']``. After scoring, the shared ``dota_model`` is adapted on
    this sample (``fit`` + ``update``), so later calls are influenced by
    earlier ones.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.

    Returns:
        A list of ``{"class": str, "probability": float}`` dicts ordered
        from most to least probable. It holds at most ``TOP_K`` entries
        (fewer when there are fewer classes) and is empty when *image* is
        ``None``.
    """
    if image is None:
        return []

    batch = preprocess(image).unsqueeze(0).to(DEVICE)

    # Everything below is inference or statistics updating; nothing here
    # back-propagates, so keep it all under no_grad.
    with torch.no_grad():
        # The loss and hard prediction returned by the helper are not used.
        image_features, clip_logits, _, prob_map, _ = get_clip_logits(
            batch, clip_model, CLIP_WEIGHTS
        )

        # DOTA logits computation
        dota_logits = dota_model.predict(image_features.mean(0).unsqueeze(0))
        dota_weight = min(
            cfg.get('eta', 0.2),
            cfg.get('rho', 0.005) * float(dota_model.c.mean())
            / max(1, image_features.size(0)),
        )
        final_logits = clip_logits + dota_weight * dota_logits

        # In testing environments DOTA usually updates from streaming data,
        # so perform an adaptation step with the current sample.
        dota_model.fit(image_features, prob_map)
        dota_model.update()

        probs = torch.softmax(final_logits, dim=-1).squeeze(0)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(TOP_K, probs.size(-1))
    top_probs, top_indices = probs.topk(k)

    # .tolist() yields plain Python floats/ints, which are JSON-friendly and
    # index the class-name list directly.
    return [
        {"class": IMAGENET_CLASSNAMES[idx], "probability": prob}
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    ]
# UI text shown under the title. English: "After uploading an image, the
# Top-5 probabilities over the predefined ImageNet class list are shown."
description = "上传图片后,基于 ImageNet 预定义类别列表输出 Top-5 概率结果。"

# Single-function Gradio app: a PIL image goes in, predict()'s list of
# class/probability dicts comes out as JSON. Flagging is turned off.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="上传图片"),  # label: "Upload image"
    outputs=gr.JSON(label="Top-5 预测结果"),  # label: "Top-5 predictions"
    description=description,
    flagging_mode="never",
)

if __name__ == '__main__':
    # Listen on all network interfaces; no public share link is created.
    iface.launch(server_name='0.0.0.0', share=False)