# QCQC / src/inference.py
# Uploaded by Johnny050407 using huggingface_hub (commit 9ed01de, verified)
import os
import sys
import argparse
import numpy as np
import torch
from utils import *
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed, GPT2Tokenizer, GPT2LMHeadModel
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
# Make sibling modules importable regardless of the current working directory:
# put this file's directory (for `utils`) and its parent (for
# `search_preparation`) on sys.path before the local imports below run.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)
sys.path.insert(0, os.path.join(BASE_DIR, ".."))
from search_preparation import index_configs
def resolve_path(*parts):
    """Resolve *parts* relative to this file's directory into an absolute path."""
    joined = os.path.join(BASE_DIR, *parts)
    return os.path.abspath(joined)
def load_index_f(index_dir):
    """Load a serialized FAISS index and its feature-transform chain.

    Expects `index_dir` to contain `faiss_IVPQ_PCA.index`, `norm1.bin`,
    optionally `pca.bin` plus `norm2.bin`, and `img_ids.pt`.

    Returns a tuple `(index, feat_transform, img_ids)` where `feat_transform`
    normalises query features (and, when `pca.bin` exists, additionally applies
    PCA followed by a second normalisation).
    """
    print(os.getcwd())
    index = faiss.read_index(os.path.join(index_dir, 'faiss_IVPQ_PCA.index'))
    norm_first = faiss.read_VectorTransform(os.path.join(index_dir, "norm1.bin"))
    pca_path = os.path.join(index_dir, "pca.bin")
    has_pca = os.path.exists(pca_path)
    if has_pca:
        pca_transform = faiss.read_VectorTransform(pca_path)
        norm_second = faiss.read_VectorTransform(os.path.join(index_dir, "norm2.bin"))

    def feat_transform(feats):
        # Normalise first; optionally project with PCA, then renormalise.
        out = norm_first.apply_py(feats)
        if has_pca:
            out = norm_second.apply_py(pca_transform.apply_py(out))
        return out

    # weights_only=False: the file stores a pickled non-tensor object.
    img_ids = torch.load(os.path.join(index_dir, 'img_ids.pt'), weights_only=False)
    return index, feat_transform, img_ids
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any stop-word sequence appears in the decoded output.

    Designed for ``num_return_sequences=1``: only ``input_ids[0]`` is inspected.

    Parameters
    ----------
    tokenizer : tokenizer used to decode the generated ids.
    stops : list of 1-D LongTensors, the tokenized stop words (default: none).
    encounters : kept for interface compatibility; currently unused.
    """

    def __init__(self, tokenizer, stops=None, encounters=1):
        super().__init__()
        self.tokenizer = tokenizer
        # Fix: `stops=[]` was a shared mutable default; also only move the stop
        # tensors to CUDA when it exists, so this works on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.stops = [stop.to(device) for stop in (stops if stops is not None else [])]
        # Decode each stop sequence once up front instead of on every
        # generation step (the original decoded them inside __call__).
        self._stop_strings = [self.tokenizer.decode(stop) for stop in self.stops]
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when any stop string occurs in the decoded first sequence.

        ``**kwargs`` absorbs extra keyword arguments passed by newer
        `transformers` versions.
        """
        decoded = self.tokenizer.decode(input_ids[0])
        return any(stop_str in decoded for stop_str in self._stop_strings)
def query_completion(model_name, ori_query, args, text_dir):
    """Produce `args.cmpl_k` auto-completions for each prompt in `ori_query`.

    Parameters
    ----------
    model_name : str
        "gpt2" / "Qwen" load the pretrained base models by name;
        "no_completion" skips generation and just replicates each prompt;
        anything else is treated as a fine-tuned checkpoint name under
        `text_dir` (e.g. "model_ADS" -> outputs/gpt2cocoADS/model_ADS/checkpoint).
    ori_query : sequence of str
        Prompts that condition the generation (ignored when
        `args.no_condition` is set).
    args : argparse.Namespace
        Sampling hyper-parameters: cmpl_k, min_len, max_len, tmpr, top_k,
        top_p, rept_pnal, no_rept_ngram, no_condition.
    text_dir : str
        Root directory holding the fine-tuned checkpoints.

    Returns
    -------
    list[list[str]]
        One inner list per prompt. In the conditional path each inner list
        holds `cmpl_k` distinct generations; in the unconditional and
        "no_completion" paths it holds `cmpl_k` copies of one text.
    """
    # ---- resolve where to load the model from --------------------------
    if model_name == "gpt2" or model_name == "Qwen":
        model_save_path = model_name
    elif model_name == 'no_completion':
        print(f"No query completion, copy original query {args.cmpl_k} times to keep the shape ! ")
        return [[item] * args.cmpl_k for item in ori_query] # copy cmpl_k times to keep the same shape with the ones using k auto-completions
    else:
        # Path: outputs/gpt2coco{suffix}/model_{suffix}/checkpoint, e.g. model_name="model_ADS" -> gpt2cocoADS/model_ADS/checkpoint
        suffix = model_name.replace("model_", "", 1) if model_name.startswith("model_") else model_name
        if "Qwen" in model_name:
            tmp_path = os.path.join(text_dir, f"qwencoco{suffix}", model_name)
        else:
            tmp_path = os.path.join(text_dir, f"gpt2coco{suffix}", model_name)
        model_save_path = os.path.join(tmp_path, "checkpoint")
    print(f"loading model from {model_save_path}...")
    # ---- load tokenizer + model ----------------------------------------
    if model_name == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained(model_save_path)
        model = GPT2LMHeadModel.from_pretrained(model_save_path)
        tokenizer.pad_token = tokenizer.eos_token
    elif model_name == "Qwen":
        model = AutoModelForCausalLM.from_pretrained(model_save_path+"/Qwen2.5-0.5B")
        tokenizer = AutoTokenizer.from_pretrained(model_save_path+"/Qwen2.5-0.5B")
        tokenizer.pad_token = tokenizer.eos_token
    else:
        if "Qwen" in model_name:
            model = AutoModelForCausalLM.from_pretrained(model_save_path)
            tokenizer = AutoTokenizer.from_pretrained(model_save_path)
        else:
            tokenizer = GPT2Tokenizer.from_pretrained(model_save_path)
            model = GPT2LMHeadModel.from_pretrained(model_save_path)
        # NOTE(review): unlike the two branches above, pad_token is NOT set for
        # fine-tuned checkpoints (the assignments below are commented out).
        # A GPT-2 checkpoint without a pad token can make the
        # `tokenizer(..., padding=True)` call below warn or fail — confirm the
        # fine-tuned checkpoints ship a pad token.
        #tokenizer.pad_token = tokenizer.eos_token # This explicitly sets the pad token to the same value as the EOS token, ensuring that padding is handled.
        #tokenizer.padding_side ='left'
        #model.config.pad_token_id = tokenizer.eos_token_id
        #model.config.eos_token_id = tokenizer.eos_token_id
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters: {total_params}")
    # Stop generation at end-of-sentence punctuation via a custom criterion.
    stop_words = [".", "!", "?"]
    stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(tokenizer, stops=stop_words_ids)])
    if args.no_condition:
        # Unconditional generation: sample from the model's default start state.
        # NOTE(review): the hard-coded 80 presumably equals the number of COCO
        # class prompts (see the `len(tokenized_inputs)=80` note below) — confirm.
        with torch.no_grad():
            autocompleted_queries = []
            for _ in range(80):
                generated_ids = model.generate(
                    do_sample = True,
                    min_length = args.min_len,
                    #max_length = args.max_len,
                    max_new_tokens = args.max_len,
                    temperature = args.tmpr,
                    top_k = args.top_k,
                    top_p = args.top_p,
                    repetition_penalty = args.rept_pnal,
                    no_repeat_ngram_size = args.no_rept_ngram,
                    eos_token_id = tokenizer.eos_token_id , # Use the end-of-sentence token as a stop condition
                    pad_token_id = tokenizer.pad_token_id, # tokenizer.pad_token_id,
                    stopping_criteria = stopping_criteria,
                    num_return_sequences = 1,
                )
                autocompleted_queries.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
            # Duplicate each unconditional sample cmpl_k times so the output
            # shape matches the conditional path below.
            autocompleted_queries = [[item] * args.cmpl_k for item in autocompleted_queries]
    else:
        with torch.no_grad():
            tokenized_inputs = [tokenizer(text, padding=True, truncation=True, return_tensors="pt") for text in ori_query] # len(tokenized_inputs)=80
            autocompleted_queries = []
            for input_data in tokenized_inputs:
                queries_for_each_class = []
                for i in range(args.cmpl_k):
                    generated_ids = model.generate(
                        input_data['input_ids'],
                        attention_mask = input_data["attention_mask"],
                        do_sample = True,
                        min_length = args.min_len,
                        #max_length = args.max_len,
                        max_new_tokens = args.max_len,
                        temperature = args.tmpr, # Control randomness (lower temperature makes the output more deterministic)
                        top_k = args.top_k,
                        top_p = args.top_p,
                        repetition_penalty = args.rept_pnal,
                        no_repeat_ngram_size = args.no_rept_ngram,
                        eos_token_id = tokenizer.eos_token_id , # Use the end-of-sentence token as a stop condition
                        pad_token_id = tokenizer.pad_token_id, # tokenizer.pad_token_id,
                        stopping_criteria = stopping_criteria,
                        return_dict_in_generate=False, # if true: generated_ids[0] should be --> generated_ids["sequences"][0]
                        num_return_sequences = 1, # only generate one sentence at a time, stopping_criteria didn't work well for multi-outputs
                    )
                    # tokenizer.decode(model.generate(input_data['input_ids'], attention_mask = input_data["attention_mask"], max_new_tokens=200, num_return_sequences = 1, eos_token_id = tokenizer.eos_token_id , pad_token_id = tokenizer.pad_token_id, )[0])
                    queries_for_each_class.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
                autocompleted_queries.append(queries_for_each_class)
    return autocompleted_queries
def image_retrieve_coco(sear_k, index, q_feats, loaded_data, img_ids):
    """Search `index` with `q_feats`, score the hits, and print a summary.

    Parameters
    ----------
    sear_k : int — number of nearest images to retrieve per query.
    index : FAISS index to search.
    q_feats : query feature matrix accepted by `index.search`.
    loaded_data : dict with "aesthetics", "IQAs" and "image_ids" tables.
    img_ids : mapping from FAISS row index to image-id string.

    Returns
    -------
    (aesthetics, faiss_smi, iqas, img_hash_list) as produced by
    `get_scores_coco`; `print_scores_iqa` is called for its side effect only.
    """
    # Fix: dropped the unused `img_list` / `dis_list` locals of the original.
    D, I = index.search(q_feats, sear_k)
    aesthetics, faiss_smi, iqas, img_hash_list = get_scores_coco(I, D, loaded_data, img_ids)
    print_scores_iqa(aesthetics, faiss_smi, iqas)
    return aesthetics, faiss_smi, iqas, img_hash_list
def get_scores_coco(I, D, loaded_data, img_ids):
    """Map FAISS search results to aesthetic and IQA score tensors.

    Parameters
    ----------
    I : array-like of shape (n_queries, k) — FAISS row indices per query.
    D : array-like of shape (n_queries, k) — the matching distances/similarities.
    loaded_data : dict with "aesthetics" and "IQAs" score sequences and
        "image_ids" (integer ids aligned with those scores).
    img_ids : indexable mapping a FAISS row index to an id string like "123.jpg".

    Returns
    -------
    (aesthetics, faiss_smi, iqas, img_hash) — score tensors of shape
    (n_queries, k), the distances as a tensor, and the per-query id lists.
    Ids missing from `loaded_data` fall back to the mean of their score table.
    """
    aesthetics_score = torch.tensor(loaded_data["aesthetics"])
    IQAs_score = torch.tensor(loaded_data["IQAs"])
    # Fix: the original ran `in` + `list.index()` per retrieved image
    # (O(n) each, O(n*k*queries) overall) and parsed each id twice. A one-time
    # id -> position dict makes every lookup O(1); `setdefault` keeps the
    # first occurrence, matching `.index()` semantics on duplicate ids.
    id_to_pos = {}
    for pos, img_id in enumerate(loaded_data["image_ids"]):
        id_to_pos.setdefault(img_id, pos)
    # Hoist the fallback means instead of recomputing them per miss.
    aes_mean = aesthetics_score.mean()
    iqa_mean = IQAs_score.mean()
    img_hash = [img_ids[idx] for idx in I]
    aesthetics = []
    iqas = []
    for each_class in img_hash:
        # "123.jpg" -> 123; None marks ids absent from the score tables.
        indices = [id_to_pos.get(int(s.split(".")[0])) for s in each_class]
        aesthetics.append(torch.stack(
            [aesthetics_score[i] if i is not None else aes_mean for i in indices]))
        iqas.append(torch.stack(
            [IQAs_score[i] if i is not None else iqa_mean for i in indices]))
    aesthetics = torch.stack(aesthetics)
    iqas = torch.stack(iqas)
    faiss_smi = torch.tensor(D)
    return aesthetics, faiss_smi, iqas, img_hash
if __name__ == '__main__':
    # ---- CLI: generation hyper-parameters and pipeline switches ---------
    parser = argparse.ArgumentParser(description="Dataset-condition query completion")
    parser.add_argument('--cmpl_k', type=int, default=10, help='perform k times query completion')
    parser.add_argument('--sear_k', type=int, default=50, help='search k images for each completed query')
    parser.add_argument('--min_len', type=int, default=10, help='minimal length in query completion')
    parser.add_argument('--max_len', type=int, default=20, help='maximal length in query completion')
    parser.add_argument('--tmpr', type=float, default=0.7, help='temperature in query completion')
    parser.add_argument('--top_k', type=int, default=50, help='select tok_k tokens in query completion')
    parser.add_argument('--top_p', type=float, default=0.9, help='top_p in query completion')
    parser.add_argument('--rept_pnal', type=float, default=1.2, help='repetition_penalty in query completion')
    parser.add_argument('--no_rept_ngram', type=int, default=2, help='no_repeat_ngram_size in query completion')
    parser.add_argument('--seed', type=int, default=42, help='seed')
    parser.add_argument('--search_imgs', action='store_true', help='load pre-downloaded data for image retrieval')
    parser.add_argument('--an_img_showing', action='store_true', help='query: an image showing x')
    parser.add_argument('--prt_cmpl_qry', action='store_true', help='print some completed queries for illustration')
    parser.add_argument('--no_condition', action='store_true', help='generate texts without conditions')
    parser.add_argument('--model_names', nargs='+', default=['no_completion', 'gpt2'], help='Choices of completion models')
    parser.add_argument('--aes_level', type=str, default='high', help='low or high or median')
    parser.add_argument('--sim_level', type=str, default='high', help='low or high or median')
    parser.add_argument('--iqa_level', type=str, default='high', help='low or high or median')
    args = parser.parse_args()
    set_seed(args.seed)  # make the sampled generations reproducible
    # Resolve all paths relative to this file so the script is cwd-independent.
    work_space = resolve_path("..")
    text_dir = resolve_path("..", "outputs")
    coco_class_path = os.path.join(work_space, "MC-COCO-Class.txt")
    if not os.path.exists(coco_class_path):
        raise FileNotFoundError(f"Class list not found: {coco_class_path}")
    # load_ori_query comes from `utils` (star import): two query phrasings per class.
    a_cls_list, an_image_showing_list = load_ori_query(coco_class_path)
    index_dir = resolve_path("..", "coco_faiss_indexes", "CLIP-Huge-Flickr-Flat")
    if not os.path.isdir(index_dir):
        raise FileNotFoundError(f"Index dir not found: {index_dir}. Run search_preparation.py first.")
    # NOTE(review): this file defines `load_index_f` but calls `load_index`.
    # Presumably `load_index` is provided by the `utils` star import — confirm,
    # otherwise this line should call `load_index_f`.
    index, feat_transform, img_ids = load_index(index_dir)
    print("loaded index, transform, img_ids.")
    model_id = 'CLIP-Huge-Flickr-Flat'
    ai_config = index_configs[model_id]['a1_config']
    weight_path = index_configs[model_id]['weight_path']
    for model_name in args.model_names:
        print(args)
        data_dict = {'text': np.array(a_cls_list)}
        dataset = Dataset.from_dict(data_dict)
        # Prompt templates: must match training order (A=Aesthetic, D=DeQA, S=Similarity)
        PROMPT_TEMPLATES = {
            # "gpt2": "Similarity: {sim}, Aesthetic: {aes}, Query: {cond}",
            # "Qwen": "Similarity: {sim}, Aesthetic: {aes}, Query: {cond}",
            "model_ADS": "<|startoftext|>Aesthetic: {aes}, DeQA-Score: {iqa}, Similarity: {sim}, Query: {cond}",
            "model_DSA": "<|startoftext|>DeQA-Score: {iqa}, Similarity: {sim}, Aesthetic: {aes}, Query: {cond}",
            "model_SAD": "<|startoftext|>Similarity: {sim}, Aesthetic: {aes}, DeQA-Score: {iqa}, Query: {cond}",
            "model_SA": "<|startoftext|>Similarity: {sim}, Aesthetic: {aes}, Query: {cond}",
        }
        # NOTE(review): the default --model_names ('no_completion', 'gpt2') are
        # not keys of PROMPT_TEMPLATES, so running with defaults raises here —
        # confirm whether those names should be allowed to bypass this check.
        if model_name not in PROMPT_TEMPLATES:
            raise ValueError(f"Invalid model name: {model_name}. Valid: {list(PROMPT_TEMPLATES)}")
        prompt = PROMPT_TEMPLATES[model_name]
        def apply_prompt_template(sample):
            # Fill condition levels from the CLI and the class name from the row.
            sim = args.sim_level
            aes = args.aes_level
            iqa = args.iqa_level
            con_str = sample["text"]
            return {"prompt": prompt.format(aes=aes, iqa=iqa, sim=sim, cond=con_str)}
        dataset = dataset.map(apply_prompt_template, remove_columns=["text"])
        queries = query_completion(model_name, dataset['prompt'], args, text_dir)
        # Keep only the text after "Query: " from the FIRST completion per class.
        autocompleted_queries = []
        for item in queries:
            query_text = item[0].split('Query: ')[1]
            autocompleted_queries.append([query_text])
        ## print some query completion results for illustration
        if args.prt_cmpl_qry:
            print(f"------------------------- {model_name} --------------------------------")
            for ii in range(len(autocompleted_queries)):
                for jj, query in enumerate(autocompleted_queries[ii]):
                    print(query)
        # Optional retrieval stage: embed the completed queries and search FAISS.
        if args.search_imgs:
            data_path = resolve_path("..", "processed_data", "coco", "data.pt")
            if not os.path.exists(data_path):
                raise FileNotFoundError(f"Processed data not found: {data_path}")
            loaded_data = torch.load(data_path, weights_only=False)
            ## textual sequences to vectors
            q_feats = get_text_list_feature(autocompleted_queries, ai_config, weight_path)
            q_feats = torch.tensor(np.array(q_feats)).squeeze()
            q_feats /= q_feats.norm(dim=-1, keepdim=True)  # L2-normalise before the FAISS transforms
            q_feats = feat_transform(q_feats.numpy())
            aesthetics, faiss_smi, iqas, img_hash_list = image_retrieve_coco(args.sear_k, index, q_feats, loaded_data, img_ids)