| | from argparse import ArgumentParser |
| | from pathlib import Path |
| | from typing import Dict, List, Optional, TextIO, Tuple |
| |
|
| | import torch |
| | from PIL import Image, UnidentifiedImageError |
| | from torch import Tensor |
| | from torch.nn import Module, Parameter |
| | from torch.nn.functional import relu, sigmoid |
| | from torch.utils.data import DataLoader, Dataset |
| | from tqdm import tqdm |
| | import torch.nn.functional as F |
| | import os |
| | import json |
| |
|
| | from ram import get_transform |
| | from ram.models import ram_plus, ram, tag2text |
| | from ram.utils import build_openset_llm_label_embedding, build_openset_label_embedding, get_mAP, get_PR |
| |
|
# Run on GPU when available; every model/tensor placement below uses this.
device = "cuda" if torch.cuda.is_available() else "cpu"
| |
|
| |
|
| | class _Dataset(Dataset): |
| | def __init__(self, imglist, input_size): |
| | self.imglist = imglist |
| | self.transform = get_transform(input_size) |
| |
|
| | def __len__(self): |
| | return len(self.imglist) |
| |
|
| | def __getitem__(self, index): |
| | try: |
| | img = Image.open(self.imglist[index]+".jpg") |
| | except (OSError, FileNotFoundError, UnidentifiedImageError): |
| | img = Image.new('RGB', (10, 10), 0) |
| | print("Error loading image:", self.imglist[index]) |
| | return self.transform(img) |
| |
|
| |
|
def parse_args():
    """Parse and validate command-line arguments for tagging evaluation.

    Returns the parsed namespace with `model_type` lowercased and
    `backbone` defaulted from the model type when not given explicitly.
    Exits with a usage error when `--open-set` is combined with tag2text.
    """
    parser = ArgumentParser()
    # model
    parser.add_argument("--model-type",
                        type=str,
                        choices=("ram_plus", "ram", "tag2text"),
                        required=True)
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True)
    parser.add_argument("--backbone",
                        type=str,
                        choices=("swin_l", "swin_b"),
                        default=None,
                        help="If `None`, will judge from `--model-type`")
    parser.add_argument("--open-set",
                        action="store_true",
                        help=(
                            "Treat all categories in the taglist file as "
                            "unseen and perform open-set classification. Only "
                            "works with RAM."
                        ))
    # data
    parser.add_argument("--dataset",
                        type=str,
                        choices=(
                            "openimages_common_214",
                            "openimages_rare_200"
                        ),
                        required=True)
    parser.add_argument("--input-size",
                        type=int,
                        default=384)
    # threshold
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--threshold",
                       type=float,
                       default=None,
                       help=(
                           "Use custom threshold for all classes. Mutually "
                           "exclusive with `--threshold-file`. If both "
                           "`--threshold` and `--threshold-file` is `None`, "
                           "will use a default threshold setting."
                       ))
    group.add_argument("--threshold-file",
                       type=str,
                       default=None,
                       help=(
                           "Use custom class-wise thresholds by providing a "
                           "text file. Each line is a float-type threshold, "
                           "following the order of the tags in taglist file. "
                           "See `ram/data/ram_tag_list_threshold.txt` as an "
                           "example. Mutually exclusive with `--threshold`. "
                           "If both `--threshold` and `--threshold-file` is "
                           "`None`, will use default threshold setting."
                       ))
    # miscellaneous
    parser.add_argument("--output-dir", type=str, default="./outputs")
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-workers", type=int, default=4)

    args = parser.parse_args()

    # post process and validate check
    args.model_type = args.model_type.lower()

    # Open-set inference is only supported by RAM/RAM++. Report a proper
    # usage error instead of a bare assert (asserts are stripped under -O).
    if args.model_type == "tag2text" and args.open_set:
        parser.error("`--open-set` is not supported with `--model-type tag2text`")

    if args.backbone is None:
        args.backbone = "swin_l" if args.model_type in ("ram_plus", "ram") else "swin_b"

    return args
| |
|
| |
|
def load_dataset(
    dataset: str,
    model_type: str,
    input_size: int,
    batch_size: int,
    num_workers: int
) -> Tuple[DataLoader, Dict]:
    """Build the evaluation DataLoader and collect dataset metadata.

    Returns `(loader, info)` where `info` carries the taglist, image list,
    annotation file path, image root, and (optionally) LLM tag descriptions.
    """
    dataset_root = str(Path(__file__).resolve().parent / "datasets" / dataset)
    img_root = dataset_root + "/imgs"

    # Tag/annotation file naming differs between the RAM family and tag2text.
    if model_type in ("ram_plus", "ram"):
        tag_file = dataset_root + f"/{dataset}_ram_taglist.txt"
        annot_file = dataset_root + f"/{dataset}_ram_annots.txt"
    else:
        tag_file = dataset_root + f"/{dataset}_tag2text_tagidlist.txt"
        annot_file = dataset_root + f"/{dataset}_{model_type}_idannots.txt"

    with open(tag_file, "r", encoding="utf-8") as f:
        taglist = [line.strip() for line in f]

    # The first comma-separated field of each annotation line is the image id.
    with open(annot_file, "r", encoding="utf-8") as f:
        imglist = [img_root + "/" + line.strip().split(",")[0] for line in f]

    loader = DataLoader(
        dataset=_Dataset(imglist, input_size),
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        batch_size=batch_size,
        num_workers=num_workers
    )

    # Optional LLM tag descriptions (used by RAM++ open-set inference).
    tag_des = None
    open_tag_des = dataset_root + f"/{dataset}_llm_tag_descriptions.json"
    if os.path.exists(open_tag_des):
        with open(open_tag_des, 'rb') as fo:
            tag_des = json.load(fo)

    info = {
        "taglist": taglist,
        "imglist": imglist,
        "annot_file": annot_file,
        "img_root": img_root,
        "tag_des": tag_des
    }
    return loader, info
| |
|
| |
|
def get_class_idxs(
    model_type: str,
    open_set: bool,
    taglist: List[str]
) -> Optional[List[int]]:
    """Get indices of required categories in the label system.

    For RAM/RAM++ closed-set evaluation, maps each tag to its position in
    the model's full tag list; for open-set, returns None (all classes are
    rebuilt from embeddings). For tag2text, the taglist already contains
    numeric ids.
    """
    if model_type in ("ram_plus", "ram"):
        if not open_set:
            model_taglist_file = "ram/data/ram_tag_list.txt"
            with open(model_taglist_file, "r", encoding="utf-8") as f:
                model_taglist = [line.strip() for line in f]
            # One dict build instead of list.index per tag: O(n + m)
            # rather than O(n * m) over the ~4.5k-entry model tag list.
            # Note: a tag absent from the model list now raises KeyError
            # (previously ValueError from list.index).
            tag_to_idx = {tag: idx for idx, tag in enumerate(model_taglist)}
            return [tag_to_idx[tag] for tag in taglist]
        else:
            return None
    else:
        # tag2text: tags are already numeric ids.
        return [int(tag) for tag in taglist]
| |
|
| |
|
def load_thresholds(
    threshold: Optional[float],
    threshold_file: Optional[str],
    model_type: str,
    open_set: bool,
    class_idxs: List[int],
    num_classes: int,
) -> List[float]:
    """Decide what threshold(s) to use.

    Precedence: `threshold_file` > `threshold` > model-specific defaults.
    Returns one float per class.

    Raises:
        ValueError: if `threshold_file` does not contain exactly
            `num_classes` lines.
    """
    # Compare against None explicitly so a legal `--threshold 0.0` is not
    # mistaken for "unset" (0.0 is falsy).
    if threshold_file is None and threshold is None:
        if model_type in ("ram_plus", "ram"):
            if not open_set:
                # Seen classes: per-tag thresholds shipped with RAM.
                ram_threshold_file = "ram/data/ram_tag_list_threshold.txt"
                with open(ram_threshold_file, "r", encoding="utf-8") as f:
                    idx2thre = {
                        idx: float(line.strip()) for idx, line in enumerate(f)
                    }
                return [idx2thre[idx] for idx in class_idxs]
            else:
                return [0.5] * num_classes
        else:  # tag2text
            return [0.68] * num_classes
    elif threshold_file is not None:
        with open(threshold_file, "r", encoding="utf-8") as f:
            thresholds = [float(line.strip()) for line in f]
        # Explicit error instead of assert (asserts vanish under -O).
        if len(thresholds) != num_classes:
            raise ValueError(
                f"expected {num_classes} thresholds in {threshold_file}, "
                f"got {len(thresholds)}"
            )
        return thresholds
    else:
        return [threshold] * num_classes
| |
|
| |
|
def gen_pred_file(
    imglist: List[str],
    tags: List[List[str]],
    img_root: str,
    pred_file: str
) -> None:
    """Generate text file of tag prediction results.

    Each output line is `<relative image path>[,tag1,tag2,...]`.
    """
    lines = []
    for image, tag in zip(imglist, tags):
        # Strip the image-root prefix so the file holds relative paths.
        entry = str(Path(image).relative_to(img_root))
        if tag:
            entry += "," + ",".join(tag)
        lines.append(entry + "\n")
    with open(pred_file, "w", encoding="utf-8") as f:
        f.writelines(lines)
| |
|
def load_ram_plus(
    backbone: str,
    checkpoint: str,
    input_size: int,
    taglist: List[str],
    tag_des: List[str],
    open_set: bool,
    class_idxs: List[int],
) -> Module:
    """Load a RAM++ model restricted to the evaluated categories."""
    model = ram_plus(pretrained=checkpoint, image_size=input_size, vit=backbone)

    if open_set:
        # Unseen categories: build label embeddings from LLM tag descriptions.
        print("Building tag embeddings ...")
        label_embed, _ = build_openset_llm_label_embedding(tag_des)
        model.label_embed = Parameter(label_embed.float())
        model.num_class = len(tag_des)
    else:
        # Seen categories: slice the pretrained embeddings. The checkpoint
        # stores 51 description embeddings of dim 512 per class, flattened
        # along the first axis.
        num_des, embed_dim = 51, 512
        per_class = model.label_embed.data.reshape(model.num_class, num_des, embed_dim)
        selected = per_class[class_idxs, :, :]
        model.label_embed = Parameter(selected.reshape(len(class_idxs) * num_des, embed_dim))
        model.num_class = len(class_idxs)
    return model.to(device).eval()
| |
|
| |
|
def load_ram(
    backbone: str,
    checkpoint: str,
    input_size: int,
    taglist: List[str],
    open_set: bool,
    class_idxs: List[int],
) -> Module:
    """Load a RAM model restricted to the evaluated categories."""
    model = ram(pretrained=checkpoint, image_size=input_size, vit=backbone)

    if open_set:
        # Unseen categories: synthesize label embeddings from the taglist.
        print("Building tag embeddings ...")
        label_embed, _ = build_openset_label_embedding(taglist)
        model.label_embed = Parameter(label_embed.float())
    else:
        # Seen categories: keep only the rows for the evaluated classes.
        selected = model.label_embed[class_idxs, :]
        model.label_embed = Parameter(selected)
    return model.to(device).eval()
| |
|
| |
|
def load_tag2text(
    backbone: str,
    checkpoint: str,
    input_size: int
) -> Module:
    """Load a Tag2Text model onto the global device in eval mode."""
    model = tag2text(pretrained=checkpoint, image_size=input_size, vit=backbone)
    return model.to(device).eval()
| |
|
@torch.no_grad()
def forward_ram_plus(model: Module, imgs: Tensor) -> Tensor:
    """Run RAM++ tagging on a batch and return per-class sigmoid scores.

    Returns a tensor of shape (batch, num_class) with values in [0, 1].
    """
    image_embeds = model.image_proj(model.visual_encoder(imgs.to(device)))
    image_atts = torch.ones(
        image_embeds.size()[:-1], dtype=torch.long).to(device)

    image_cls_embeds = image_embeds[:, 0, :]
    image_spatial_embeds = image_embeds[:, 1:, :]

    bs = image_spatial_embeds.shape[0]

    # Number of description embeddings stored per class (flattened in
    # model.label_embed along dim 0).
    des_per_class = int(model.label_embed.shape[0] / model.num_class)

    # Cosine-style similarity between the image CLS embedding and every
    # description embedding, scaled by the learned reweighting temperature.
    image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
    reweight_scale = model.reweight_scale.exp()
    logits_per_image = (reweight_scale * image_cls_embeds @ model.label_embed.t())
    logits_per_image = logits_per_image.view(bs, -1, des_per_class)

    # Reweight each class's description embeddings by image similarity.
    # (The original computed these two lines twice back-to-back; the
    # duplicate has been removed.)
    weight_normalized = F.softmax(logits_per_image, dim=2)
    # Allocate on the active device instead of hard-coded .cuda() so the
    # script also runs on CPU-only hosts, consistent with the rest of
    # this file.
    label_embed_reweight = torch.empty(bs, model.num_class, 512, device=device)
    # Loop-invariant view hoisted out of the per-sample loop.
    reshaped_value = model.label_embed.view(-1, des_per_class, 512)
    for i in range(bs):
        product = weight_normalized[i].unsqueeze(-1) * reshaped_value
        label_embed_reweight[i] = product.sum(dim=1)

    label_embed = relu(model.wordvec_proj(label_embed_reweight))

    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_embed,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=False,
        mode='tagging',
    )
    return sigmoid(model.fc(tagging_embed).squeeze(-1))
| |
|
@torch.no_grad()
def forward_ram(model: Module, imgs: Tensor) -> Tensor:
    """Run RAM tagging on a batch and return per-class sigmoid scores."""
    batch = imgs.shape[0]
    image_embeds = model.image_proj(model.visual_encoder(imgs.to(device)))
    attention = torch.ones(
        image_embeds.size()[:-1], dtype=torch.long).to(device)
    # Project the label embeddings and broadcast them across the batch.
    projected = relu(model.wordvec_proj(model.label_embed))
    label_embed = projected.unsqueeze(0).repeat(batch, 1, 1)
    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_embed,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=attention,
        mode='tagging',
        return_dict=False,
    )
    return sigmoid(model.fc(tagging_embed).squeeze(-1))
| |
|
| |
|
@torch.no_grad()
def forward_tag2text(
    model: Module,
    class_idxs: List[int],
    imgs: Tensor
) -> Tensor:
    """Run Tag2Text tagging and return sigmoid scores for `class_idxs` only."""
    batch = imgs.shape[0]
    image_embeds = model.visual_encoder(imgs.to(device))
    attention = torch.ones(
        image_embeds.size()[:-1], dtype=torch.long).to(device)
    # Broadcast the learned label embeddings across the batch.
    label_embed = model.label_embed.weight.unsqueeze(0).repeat(batch, 1, 1)
    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_embed,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=attention,
        mode='tagging',
        return_dict=False,
    )
    # Keep only the evaluated categories.
    return sigmoid(model.fc(tagging_embed))[:, class_idxs]
| |
|
| |
|
def print_write(f: TextIO, s: str):
    """Echo `s` to stdout and also append it, newline-terminated, to `f`."""
    print(s)
    f.write(f"{s}\n")
| |
|
| |
|
if __name__ == "__main__":
    args = parse_args()

    # Set up output paths.
    output_dir = args.output_dir
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    pred_file, pr_file, ap_file, summary_file, logit_file = [
        output_dir + "/" + name for name in
        ("pred.txt", "pr.txt", "ap.txt", "summary.txt", "logits.pth")
    ]

    # Record the run configuration at the top of the summary file.
    with open(summary_file, "w", encoding="utf-8") as f:
        print_write(f, "****************")
        for key in (
            "model_type", "backbone", "checkpoint", "open_set",
            "dataset", "input_size",
            "threshold", "threshold_file",
            "output_dir", "batch_size", "num_workers"
        ):
            print_write(f, f"{key}: {getattr(args, key)}")
        print_write(f, "****************")

    # Load dataset.
    loader, info = load_dataset(
        dataset=args.dataset,
        model_type=args.model_type,
        input_size=args.input_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )
    taglist, imglist, annot_file, img_root, tag_des = \
        info["taglist"], info["imglist"], info["annot_file"], info["img_root"], info["tag_des"]

    # Indices of the evaluated categories inside the model's label system.
    class_idxs = get_class_idxs(
        model_type=args.model_type,
        open_set=args.open_set,
        taglist=taglist
    )

    # Decision thresholds, one per class.
    thresholds = load_thresholds(
        threshold=args.threshold,
        threshold_file=args.threshold_file,
        model_type=args.model_type,
        open_set=args.open_set,
        class_idxs=class_idxs,
        num_classes=len(taglist)
    )

    if Path(logit_file).is_file():
        # Logits cached from a previous run; skip model loading & inference.
        logits = torch.load(logit_file)
    else:
        # Load the requested model.
        if args.model_type == "ram_plus":
            model = load_ram_plus(
                backbone=args.backbone,
                checkpoint=args.checkpoint,
                input_size=args.input_size,
                taglist=taglist,
                tag_des=tag_des,
                open_set=args.open_set,
                class_idxs=class_idxs
            )
        elif args.model_type == "ram":
            model = load_ram(
                backbone=args.backbone,
                checkpoint=args.checkpoint,
                input_size=args.input_size,
                taglist=taglist,
                open_set=args.open_set,
                class_idxs=class_idxs
            )
        elif args.model_type == "tag2text":
            model = load_tag2text(
                backbone=args.backbone,
                checkpoint=args.checkpoint,
                input_size=args.input_size
            )

        # Inference: fill one row of logits per image.
        logits = torch.empty(len(imglist), len(taglist))
        pos = 0
        for imgs in tqdm(loader, desc="inference"):
            if args.model_type == "ram_plus":
                out = forward_ram_plus(model, imgs)
            elif args.model_type == "ram":
                out = forward_ram(model, imgs)
            else:
                out = forward_tag2text(model, class_idxs, imgs)
            bs = imgs.shape[0]
            logits[pos:pos+bs, :] = out.cpu()
            pos += bs

        # Cache logits so subsequent runs can skip inference.
        torch.save(logits, logit_file)

    # Threshold the scores into per-image tag predictions.
    pred_tags = []
    for scores in logits.tolist():
        pred_tags.append([
            taglist[i] for i, s in enumerate(scores) if s >= thresholds[i]
        ])

    # Generate the prediction file.
    gen_pred_file(imglist, pred_tags, img_root, pred_file)

    # Evaluate.
    mAP, APs = get_mAP(logits.numpy(), annot_file, taglist)
    CP, CR, Ps, Rs = get_PR(pred_file, annot_file, taglist)

    with open(ap_file, "w", encoding="utf-8") as f:
        f.write("Tag,AP\n")
        for tag, AP in zip(taglist, APs):
            f.write(f"{tag},{AP*100.0:.2f}\n")

    with open(pr_file, "w", encoding="utf-8") as f:
        f.write("Tag,Precision,Recall\n")
        for tag, P, R in zip(taglist, Ps, Rs):
            f.write(f"{tag},{P*100.0:.2f},{R*100.0:.2f}\n")

    # Append overall metrics. Mode "a", not "w": the original reopened the
    # summary with "w", truncating the configuration block written above.
    with open(summary_file, "a", encoding="utf-8") as f:
        print_write(f, f"mAP: {mAP*100.0}")
        print_write(f, f"CP: {CP*100.0}")
        print_write(f, f"CR: {CR*100.0}")
|