| import argparse |
| import json |
| import os |
| import random |
| from collections import OrderedDict |
| from typing import List, Sequence, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# `transformers` is treated as an optional dependency at import time so this
# module can still be imported without it. Every name pulled from the package
# gets a `None` placeholder on failure, so later code can test `X is None`
# instead of crashing with a NameError on an unbound name.
try:
    from transformers import AutoTokenizer, AutoConfig, XLMRobertaModel
except ImportError:
    AutoTokenizer = None
    AutoConfig = None
    XLMRobertaModel = None
    # Legacy placeholder kept only for backward compatibility with any code
    # that still references it; nothing in the import above provides it.
    HFBertModel = None
|
|
|
|
class XLMRobertaLanguageBackbone(nn.Module):
    """XLM-RoBERTa text encoder followed by a linear projection to 768 dims.

    The encoder architecture and tokenizer are read from a local directory
    (``./xlm-roberta-base/`` or ``./xlm-roberta-large/``). The weights for
    both the encoder and the projection head are then loaded from the
    ``backbone.text_model.*`` entries of the checkpoint at ``ckpt_path``.
    """

    def __init__(
        self,
        ckpt_path,
        frozen_modules: Sequence[str] = (),
        dropout: float = 0.0,
        init_cfg=None,
    ) -> None:
        """Build the encoder and load its weights from ``ckpt_path``.

        Args:
            ckpt_path: Path to a checkpoint whose top-level ``'state_dict'``
                holds the text-model weights under the
                ``backbone.text_model.`` prefix. The path must contain the
                substring ``'base'`` or ``'large'``; this selects the
                encoder variant (``'base'`` wins if both are present).
            frozen_modules: Stored on the instance; not otherwise used by
                the code in this class.
            dropout: Accepted for interface compatibility; unused here.
            init_cfg: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: If ``ckpt_path`` names neither variant.
            ImportError: If the ``transformers`` package is unavailable.
        """
        super().__init__()

        # str() lets os.PathLike inputs work with the substring tests below.
        path_text = str(ckpt_path)
        if 'base' in path_text:
            variant, encoder_dim = "xlm-roberta-base", 768
        elif 'large' in path_text:
            variant, encoder_dim = "xlm-roberta-large", 1024
        else:
            # Previously this fell through and failed later with an unrelated
            # unbound-variable error; fail early with an actionable message.
            raise ValueError(
                "Cannot infer the XLM-RoBERTa variant from ckpt_path "
                f"{path_text!r}: expected it to contain 'base' or 'large'."
            )

        if AutoTokenizer is None:
            raise ImportError(
                "The 'transformers' package is required to build "
                "XLMRobertaLanguageBackbone."
            )

        # Projection from the encoder width to the shared 768-d space.
        self.head = nn.Linear(encoder_dim, 768, bias=True)
        model_name = f"./{variant}/"

        self.frozen_modules = frozen_modules
        cfg = AutoConfig.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Built from the config only: the weights come from the checkpoint.
        self.model = XLMRobertaModel(cfg)
        self.language_dim = cfg.hidden_size

        # SECURITY NOTE: weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from a trusted source.
        state_dict = torch.load(
            ckpt_path,
            map_location="cpu",
            weights_only=False,
        )['state_dict']

        # Keep only the text-model weights and strip the prefix so the keys
        # line up with this module's own parameter names.
        prefix = 'backbone.text_model.'
        new_state_dict = OrderedDict(
            (key[len(prefix):], value)
            for key, value in state_dict.items()
            if key.startswith(prefix)
        )
        msg = self.load_state_dict(new_state_dict, strict=True)
        print(msg)

        print(f"TEXT-ENCODER {variant} LOADING WEIGHTS !!!!")

    def forward(self, text: List[str], max_seq_len: int = 32):
        """Encode a batch of strings into projected text features.

        Args:
            text: Strings to encode.
            max_seq_len: Length every sequence is padded to.

        Returns:
            The projection-head output for the position-0 token of each
            input, i.e. one 768-d vector per string.
        """
        # NOTE(review): no truncation is requested, so inputs longer than
        # max_seq_len tokens are presumably not cut down to it - confirm
        # this is intended if a fixed length matters downstream.
        tokens = self.tokenizer(text=text, return_tensors="pt",
                                padding="max_length", max_length=max_seq_len)
        tokens = tokens.to(device=self.model.device)

        # Take the hidden state at sequence position 0 for each input.
        first_token_feats = self.model(**tokens)["last_hidden_state"][:, 0]
        return self.head(first_token_feats)
|
|
|
|
def _compress_directory(src_dir, tar_path):
    """Pack every entry of ``src_dir`` (sorted by name) into a .tar.gz file.

    Entries are stored under their bare file name, without the directory.
    """
    import tarfile

    with tarfile.open(tar_path, "w:gz") as tar:
        for fname in sorted(os.listdir(src_dir)):
            tar.add(os.path.join(src_dir, fname), arcname=fname)
    print(f"Compressed: {tar_path}")


def main():
    """Generate calibration data for the text encoder.

    Writes, under ``--calib-dir``:
      * ``input_ids/NNN.npy`` and ``attention_mask/NNN.npy`` - int64 token
        arrays for randomly sampled groups of class names;
      * ``class_embedding_4cls/NNN.npy`` - float32 L2-normalised text
        embeddings with a leading batch axis of size 1, for a second,
        independent set of random groups;
      * one ``.tar.gz`` archive per directory above.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--wedetect_checkpoint', type=str, default='checkpoints/wedetect_base.pth')
    parser.add_argument('--classname_file', type=str, default='data/texts/coco_zh_class_texts.json')
    parser.add_argument('--max-seq-len', type=int, default=32,
                        help='Fixed token length (must match ONNX export).')
    parser.add_argument('--num-classes-per-group', type=int, default=4,
                        help='Number of classes per group npy.')
    parser.add_argument('--num-groups', type=int, default=64,
                        help='Number of random groups to generate.')
    parser.add_argument('--calib-dir', type=str, default='calib_data',
                        help='Directory for text-encoder quantisation calibration data.')
    args = parser.parse_args()

    # Explicit UTF-8: the class names are non-ASCII and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(args.classname_file, encoding="utf-8") as f:
        name_chinese = json.load(f)
    # Each JSON entry is indexable; its first element is used as the name.
    name_chinese = [name[0] for name in name_chinese]

    total_classes = len(name_chinese)
    # Fail fast with a readable message instead of an opaque ValueError from
    # random.sample after the (slow) model load.
    if args.num_groups > 0 and not 1 <= args.num_classes_per_group <= total_classes:
        parser.error(
            f"--num-classes-per-group must be between 1 and {total_classes} "
            f"(the number of classes in {args.classname_file}), "
            f"got {args.num_classes_per_group}"
        )

    # Use the GPU when present, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    language_encoder = XLMRobertaLanguageBackbone(args.wedetect_checkpoint).to(device)
    # The encoder is built directly from a config and therefore starts in
    # training mode; switch to eval so dropout does not perturb the
    # embeddings (torch.no_grad() alone does not disable dropout).
    language_encoder.eval()

    print(f"Total classes: {total_classes} → Generating {args.num_groups} random groups")

    # ---- Token-level calibration inputs ---------------------------------
    calib_input_ids = os.path.join(args.calib_dir, "input_ids")
    calib_attn_mask = os.path.join(args.calib_dir, "attention_mask")
    for d in (calib_input_ids, calib_attn_mask):
        os.makedirs(d, exist_ok=True)

    tokenizer = language_encoder.tokenizer

    for g in range(args.num_groups):
        idx = random.sample(range(total_classes), args.num_classes_per_group)
        group_texts = [name_chinese[i] for i in idx]
        tokens = tokenizer(group_texts, padding="max_length",
                           max_length=args.max_seq_len, return_tensors="np")

        np.save(os.path.join(calib_input_ids, f"{g:03d}.npy"),
                tokens["input_ids"].astype(np.int64))
        np.save(os.path.join(calib_attn_mask, f"{g:03d}.npy"),
                tokens["attention_mask"].astype(np.int64))
        print(f"calib [{g:03d}] input_ids: {tokens['input_ids'].shape} "
              f"classes: {group_texts}")

    for sub_name in ("input_ids", "attention_mask"):
        _compress_directory(
            os.path.join(args.calib_dir, sub_name),
            os.path.join(args.calib_dir, f"{sub_name}.tar.gz"),
        )

    print(f"Saved calibration data to {args.calib_dir}/")

    # ---- Class-embedding calibration inputs ------------------------------
    # NOTE(review): the directory name is hard-coded to "4cls" even when
    # --num-classes-per-group differs from 4; kept as-is so that existing
    # consumers of this path keep working.
    embed_dir = os.path.join(args.calib_dir, "class_embedding_4cls")
    os.makedirs(embed_dir, exist_ok=True)

    print(f"\nGenerating {args.num_groups} random {args.num_classes_per_group}-class "
          f"text embeddings → {embed_dir}/")
    for g in range(args.num_groups):
        # Groups are re-sampled here, so they are independent of the token
        # groups written above.
        idx = random.sample(range(total_classes), args.num_classes_per_group)
        group_texts = [name_chinese[i] for i in idx]
        with torch.no_grad():
            feats = language_encoder(group_texts, max_seq_len=args.max_seq_len)
            # L2-normalise each embedding, then add a leading axis of size 1.
            feats = F.normalize(feats, dim=-1).unsqueeze(0)
        fpath = os.path.join(embed_dir, f"{g:03d}.npy")
        np.save(fpath, feats.cpu().numpy().astype(np.float32))
        if (g + 1) % 16 == 0 or g == args.num_groups - 1:
            print(f"  [{g + 1:3d}/{args.num_groups}] shape={feats.shape} "
                  f"classes: {group_texts}")

    _compress_directory(
        embed_dir,
        os.path.join(args.calib_dir, "class_embedding_4cls.tar.gz"),
    )

    print(f"Saved calibration data to {args.calib_dir}/")


if __name__ == '__main__':
    main()
|
|