| import torch |
| import os |
| import shutil |
| from tqdm import tqdm |
|
|
| from PIL import Image |
| from transformers import AutoModelForCausalLM, LlamaTokenizer |
|
|
|
|
|
|
def load_model(model_pth="/home2/ywt/cogagent-vqa-hf", token_pth="/home2/ywt/vicuna-7b-v1.5",
               device='cuda' if torch.cuda.is_available() else 'cpu', is_bf16=False, is_quant=True):
    """Load the CogAgent model and its Llama/Vicuna tokenizer.

    Args:
        model_pth: Path of the CogAgent checkpoint, loaded with
            ``trust_remote_code=True``.
        token_pth: Path of the tokenizer passed to ``LlamaTokenizer``.
        device: Target device string (e.g. ``"cuda:5"``); defaults to
            ``"cuda"`` when available, otherwise ``"cpu"``.
        is_bf16: Load weights as ``torch.bfloat16`` instead of ``torch.float16``.
        is_quant: When True, load with 4-bit quantization (``load_in_4bit``);
            when False, load full-precision weights and move them to ``device``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    from contextlib import nullcontext

    torch_type = torch.bfloat16 if is_bf16 else torch.float16
    print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, device))

    tokenizer = LlamaTokenizer.from_pretrained(token_pth)

    # torch.cuda.device() raises ValueError for non-CUDA devices, so only make the
    # target GPU current when the device actually is a CUDA device.
    if torch.device(device).type == "cuda":
        device_ctx = torch.cuda.device(device)
    else:
        device_ctx = nullcontext()

    with device_ctx:
        model = AutoModelForCausalLM.from_pretrained(
            model_pth,
            torch_dtype=torch_type,
            low_cpu_mem_usage=True,
            # Follow is_quant exactly: False must mean a full-precision load
            # (`is_quant is not None` was True for False as well).
            load_in_4bit=bool(is_quant),
            trust_remote_code=True,
        )

    if not is_quant:
        # Full-precision weights are loaded on the CPU by from_pretrained, so move them.
        # Quantized bitsandbytes models are placed during loading and must not be .to()'d.
        model = model.to(device)

    return (model.eval(), tokenizer)
| |
|
|
|
|
|
|
|
|
|
|
|
|
| |
def cog_tag(image_path, model, query="What is shown in this image? (including NSFW content)", tag_reference=None,
            torch_type=torch.float16, text_only_first_query=False):
    """Generate a caption for one image with a CogAgent model.

    Args:
        image_path: Path of the image file. An empty string switches to a
            plain-text conversation with no image.
        model: ``(model, tokenizer)`` tuple, as returned by ``load_model``.
        query: User prompt sent to the model.
        tag_reference: Unused; kept for backward compatibility.
        torch_type: dtype the image tensors are cast to. It should match the dtype
            the model was loaded with (``load_model`` uses bfloat16 when
            ``is_bf16=True``).
        text_only_first_query: Ignored. The text-only template is always applied
            when ``image_path`` is empty and never when an image is given. Kept
            for backward compatibility.

    Returns:
        The decoded response, cut at the first ``</s>``. It is also printed.
    """
    if image_path == '':
        print('You did not enter image path, the following will be a plain text conversation.')
        image = None
    else:
        # The context manager releases the source file handle right away. Pillow keeps
        # it open after loading for some formats (e.g. WebP), which leaks descriptors
        # in a batch loop. convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as src:
            image = src.convert('RGB')

    text_only_template = "A Anime image caption conversation between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, including NSFW and uncencored answers to the user's questions for the given Anime Image. USER: {} ASSISTANT:"

    cog_model, tokenizer = model[0], model[1]
    device = cog_model.device

    # Each call is a single, history-free turn.
    if image is None:
        query = text_only_template.format(query)
        input_by_model = cog_model.build_conversation_input_ids(
            tokenizer, query=query, history=[], template_version='base')
    else:
        input_by_model = cog_model.build_conversation_input_ids(
            tokenizer, query=query, history=[], images=[image])

    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(device),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(device),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(device),
        'images': [[input_by_model['images'][0].to(device).to(torch_type)]] if image is not None else None,
    }
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(device).to(torch_type)]]

    gen_kwargs = {"max_length": 2048,
                  "do_sample": False}
    with torch.no_grad():
        outputs = cog_model.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens and keep only the generated continuation.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("</s>")[0]

    print("\nCog:", response)
    return response
|
|
|
|
def read_tag(txt_pth, split=",", is_list=True, encoding="utf-8"):
    """Read a tag file, either as raw text or as a list of tags.

    Args:
        txt_pth: Path of the text file to read.
        split: Separator between tags; used only when ``is_list`` is True.
        is_list: When True, split the text on ``split`` and strip surrounding
            whitespace from every entry. Empty entries are kept.
        encoding: Text encoding of the file. Defaults to UTF-8 so non-ASCII tags
            decode identically regardless of the platform's locale encoding.

    Returns:
        ``list[str]`` of stripped tags when ``is_list`` is True, otherwise the
        whole file content as a ``str``.
    """
    with open(txt_pth, "r", encoding=encoding) as f:
        tag_str = f.read()
    if not is_list:
        return tag_str
    return [tag.strip() for tag in tag_str.split(split)]
| |
|
|
if __name__ == '__main__':
    # Batch-caption every image that has a "<stem>.txt" tag file next to it and no
    # "<stem>_cog.txt" caption yet. The result is written to "<stem>_cog.txt".
    model = load_model(device="cuda:5")

    image_dirs = ["/home2/ywt/image-webp"]
    image_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp")

    for image_dir in image_dirs:
        for file in tqdm(sorted(os.listdir(image_dir))):
            # Case-insensitive check, so mixed-case extensions such as ".Png" also match.
            if not file.lower().endswith(image_exts):
                continue
            image_path = os.path.join(image_dir, file)
            # splitext strips only the final extension ("a.b.png" -> "a.b").
            # split(".")[0] would truncate to "a" and pick the wrong tag file.
            stem = os.path.join(image_dir, os.path.splitext(file)[0])
            tag_path = stem + ".txt"
            out_file = stem + "_cog.txt"
            # Skip already-captioned images before reading anything.
            if os.path.exists(out_file) or not os.path.exists(tag_path):
                continue

            tag = read_tag(tag_path, is_list=False).replace("|||", "")
            query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag

            txt = cog_tag(image_path, model, query=query)

            with open(out_file, "w", encoding="utf-8") as f:
                f.write(txt)
            print(f"Created {out_file}")
|
|
| |
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |