OrthoReg / src /heads.py
gezi2333's picture
Upload folder using huggingface_hub
3589275 verified
import os
import open_clip
import torch
from tqdm import tqdm
from src.datasets.registry import get_dataset
from src.datasets.templates import get_templates
from src.modeling import ClassificationHead, ImageEncoder
def build_classification_head(model, dataset_name, template, data_location, device):
    """Build a zero-shot classification head from text-prompt embeddings.

    For every class name of the dataset, each prompt template is rendered,
    tokenized with ``open_clip.tokenize`` and encoded with
    ``model.encode_text``. The per-prompt embeddings are L2-normalized,
    averaged, and normalized again, giving one vector per class. The class
    vectors are then scaled by ``model.logit_scale.exp()`` and used as the
    weights of the returned head.

    Args:
        model: Model exposing ``encode_text`` and ``logit_scale``. As a side
            effect it is switched to eval mode and moved to ``device``.
        dataset_name: Name passed to ``get_dataset`` and ``get_templates``.
        template: Ignored. Kept so existing callers keep working; the
            templates are always looked up from ``dataset_name``.
        data_location: Forwarded to ``get_dataset`` as ``location``.
        device: Device on which the text encoder is run.

    Returns:
        A ``ClassificationHead`` created with ``normalize=True`` whose weight
        matrix has one row per class name.
    """
    # NOTE: the ``template`` argument is deliberately overridden here, exactly
    # as in the original implementation.
    template = get_templates(dataset_name)

    logit_scale = model.logit_scale
    dataset = get_dataset(dataset_name, None, location=data_location)
    model.eval()
    model.to(device)

    print("Building classification head.")
    with torch.no_grad():
        class_embeddings = []
        for classname in tqdm(dataset.classnames):
            prompts = [render(classname) for render in template]
            tokens = open_clip.tokenize(prompts).to(device)  # tokenize
            embeddings = model.encode_text(tokens)  # embed with text encoder
            # Normalize each prompt embedding before averaging so that every
            # prompt contributes equally to the class vector.
            embeddings /= embeddings.norm(dim=-1, keepdim=True)

            # Average over prompts (keeping a leading axis of size 1), then
            # bring the mean back to unit length.
            embeddings = embeddings.mean(dim=0, keepdim=True)
            embeddings /= embeddings.norm()

            class_embeddings.append(embeddings)

        # Every entry has a leading axis of size 1, so concatenating along
        # dim 0 yields one row per class. The previous
        # stack -> transpose -> squeeze() -> transpose sequence gave the same
        # matrix, but the bare squeeze() also dropped the class axis when the
        # dataset had a single class, making the final transpose fail.
        zeroshot_weights = torch.cat(class_embeddings, dim=0).to(device)
        zeroshot_weights *= logit_scale.exp()
        zeroshot_weights = zeroshot_weights.float()

    classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)

    return classification_head
def get_classification_head(args, dataset):
    """Return the classification head for ``dataset``, using a file cache.

    The head is looked up at ``<args.save>/head_<dataset>.pt``, where the
    dataset name always carries a ``"Val"`` suffix. If that file exists it is
    loaded with ``ClassificationHead.load``; otherwise a new head is built
    with ``build_classification_head``, saved to that path, and returned.

    Args:
        args: Object providing ``save``, ``model``, ``data_location`` and
            ``device``; it is also passed to ``ImageEncoder``.
        dataset: Dataset name, with or without the ``"Val"`` suffix.

    Returns:
        The loaded or newly built classification head.
    """
    # Always resolve to the validation-split name so the head matches the one
    # generated at training time.
    dataset = dataset if dataset.endswith("Val") else dataset + "Val"
    filename = os.path.join(args.save, f"head_{dataset}.pt")

    if not os.path.exists(filename):
        print(
            f"Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch."  # noqa: E501
        )
        # keep_lang=True is requested because the text side of the model is
        # needed to embed the class prompts.
        clip_model = ImageEncoder(args, keep_lang=True).model
        prompt_templates = get_templates(dataset)
        head = build_classification_head(
            clip_model, dataset, prompt_templates, args.data_location, args.device
        )
        os.makedirs(args.save, exist_ok=True)
        head.save(filename)
        return head

    print(f"Classification head for {args.model} on {dataset} exists at {filename}")
    return ClassificationHead.load(filename)