classification / app.py
hobbylxx's picture
Update app.py
e241c21 verified
import torch
import gradio as gr
from PIL import Image
import yaml
import clip
from datasets.imagenet import imagenet_templates
from utils import get_clip_logits
from dota import DOTA
# Run on the GPU whenever torch reports a usable CUDA device; otherwise
# stay on the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Text file the class names are read from at import time.
CLASSNAMES_PATH = 'data/imagenet-sketch/classnames.txt'

# How many of the highest-probability classes each prediction reports.
TOP_K = 5
def load_cfg(path="configs/vit/imagenet.yaml"):
    """Parse the YAML configuration stored at ``path``.

    The file is decoded as ``utf-8-sig`` so that a leading byte-order mark,
    if present, is dropped before the text reaches the YAML parser.

    Args:
        path: Location of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    with open(path, mode='r', encoding='utf-8-sig') as stream:
        parsed = yaml.safe_load(stream)
    return parsed
# Read the configuration once at import time, from load_cfg's default path.
cfg = load_cfg()
# Load the ViT-B/16 CLIP checkpoint onto DEVICE; `preprocess` is the second
# value returned by clip.load and is applied to every input image in predict().
clip_model, preprocess = clip.load("ViT-B/16", device=DEVICE)
# Put the model in evaluation mode; this app only runs inference with it.
clip_model.eval()
def load_imagenet_classnames(path: str):
    """Read class names from a text file with one ``<id> <name>`` entry per line.

    Each line is stripped and split once on the first run of whitespace; the
    text after that split is taken as the class name, so names may contain
    spaces. Blank lines and lines holding only a single token are skipped.

    Args:
        path: Location of the UTF-8 encoded class-name file.

    Returns:
        The class names as a list of strings, in file order.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        stripped_rows = (raw.strip() for raw in handle)
        split_rows = (row.split(maxsplit=1) for row in stripped_rows if row)
        # Keep only rows that had both an identifier and a name.
        return [fields[1] for fields in split_rows if len(fields) == 2]
# Class names loaded once at import time; predict() indexes this list with
# the class indices returned by top-k, so its order defines the label mapping.
IMAGENET_CLASSNAMES = load_imagenet_classnames(CLASSNAMES_PATH)
def build_clip_weights(classnames, templates):
    """Build one text-derived classifier vector per class name.

    For every class name, each template is filled in with the name, the
    resulting prompts are encoded with the module-level ``clip_model``, each
    prompt embedding is L2-normalised, the embeddings are averaged over the
    prompts, and the average is L2-normalised again.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings with one positional ``{}``
            placeholder for the class name.

    Returns:
        A tensor on ``DEVICE`` with the per-class vectors stacked along
        ``dim=1`` (one column per class, assuming ``encode_text`` returns a
        ``(prompts, features)`` matrix).
    """
    def _class_embedding(name):
        # Encode every prompt for this class and normalise each embedding.
        prompts = [tpl.format(name) for tpl in templates]
        token_ids = clip.tokenize(prompts).to(DEVICE)
        embeddings = clip_model.encode_text(token_ids)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        # Average over the prompts, then re-normalise the averaged vector.
        pooled = embeddings.mean(dim=0)
        return pooled / pooled.norm()

    # Gradients are never needed for these fixed text features.
    with torch.no_grad():
        per_class = [_class_embedding(name) for name in classnames]
        return torch.stack(per_class, dim=1).to(DEVICE)
# Text-derived classifier weights, computed once at import time.
CLIP_WEIGHTS = build_clip_weights(IMAGENET_CLASSNAMES, imagenet_templates)
# build_clip_weights stacks the per-class vectors along dim=1, so shape[0] is
# the feature size and shape[1] the number of classes. DOTA receives a clone,
# so it does not share storage with CLIP_WEIGHTS.
dota_model = DOTA(cfg, input_shape=CLIP_WEIGHTS.shape[0], num_classes=CLIP_WEIGHTS.shape[1], clip_weights=CLIP_WEIGHTS.clone())
# NOTE(review): presumably DOTA is a torch.nn.Module and this selects
# evaluation mode -- confirm against dota.py.
dota_model.eval()
def predict(image: Image.Image):
    """Classify one image and return its most probable class names.

    The image is preprocessed and passed through ``get_clip_logits``. The
    resulting CLIP logits are combined with the logits from the module-level
    ``dota_model``, scaled by a weight capped at ``cfg['eta']`` (default
    0.2). After the combined logits are computed, the image's features and
    ``prob_map`` are fed to ``dota_model.fit`` / ``dota_model.update``; per
    the original comments this is an online adaptation step, so a call is
    expected to influence later calls.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A list of ``{"class": str, "probability": float}`` dicts ordered by
        descending probability. It holds at most ``TOP_K`` entries (fewer
        when fewer classes are loaded) and is empty when ``image`` is
        ``None``.
    """
    if image is None:
        return []

    img = preprocess(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # The loss and hard prediction returned by the helper are not used.
        image_features, clip_logits, _loss, prob_map, _pred = get_clip_logits(img, clip_model, CLIP_WEIGHTS)

        # DOTA logits computation
        dota_logits = dota_model.predict(image_features.mean(0).unsqueeze(0))
        dota_weight = min(cfg.get('eta', 0.2), cfg.get('rho', 0.005) * float(dota_model.c.mean()) / max(1, image_features.size(0)))
        final_logits = clip_logits + dota_weight * dota_logits

        # In testing environments, DOTA usually updates from streaming data,
        # so an adaptation step is performed here. It runs after the logits
        # above were computed, so it cannot affect this call's result.
        dota_model.fit(image_features, prob_map)
        dota_model.update()

        probs = torch.softmax(final_logits, dim=-1).squeeze(0)
        # topk raises if k exceeds the number of classes, so clamp it for
        # class lists shorter than TOP_K.
        k = min(TOP_K, probs.size(-1))
        top_probs, top_indices = probs.topk(k)

    # tolist() yields plain Python floats/ints, so the class list is indexed
    # with real integers rather than 0-dim tensors.
    return [
        {"class": IMAGENET_CLASSNAMES[idx], "probability": prob}
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    ]
# User-facing blurb shown in the UI (Chinese): after uploading an image, the
# Top-5 probabilities over the predefined ImageNet class list are shown.
description = "上传图片后,基于 ImageNet 预定义类别列表输出 Top-5 概率结果。"
# Web UI: one image upload (delivered to predict() as a PIL image) and one
# JSON panel showing the list that predict() returns.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="上传图片"),
    outputs=gr.JSON(label="Top-5 预测结果"),
    description=description,
    flagging_mode="never",
)
# Start the server only when this file is run as a script, not on import.
if __name__ == '__main__':
    # NOTE(review): '0.0.0.0' presumably makes the app listen on all network
    # interfaces (e.g. for container hosting); share=False requests no
    # public share link -- confirm against the Gradio version in use.
    iface.launch(server_name='0.0.0.0', share=False)