| |
|
| | import random |
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | Image.MAX_IMAGE_PIXELS = None |
| | from transformers import AutoModel, AutoTokenizer |
| | import opencc |
| | from ultralytics import YOLO |
| | from config.configu import * |
| | from utils.utils import * |
| | import logging |
| | import argparse |
| |
|
def setup_logger(log_file):
    """Point root logging at ``log_file`` and hand back the root logger.

    Args:
        log_file: Path passed to ``logging.basicConfig`` as ``filename``.

    Returns:
        The logger returned by ``logging.getLogger()`` (the root logger).
    """
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format='%(asctime)s - %(message)s',
    )
    return logging.getLogger()
| |
|
def set_seed(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random``, NumPy and PyTorch; when CUDA is available the
    CUDA generators are seeded as well. cuDNN is then switched to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # NOTE(review): indentation was lost in the copy this was restored from;
    # the cuDNN flags are assumed to be set regardless of CUDA availability
    # -- confirm against the upstream file.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
| |
|
# Module-level setup, executed on import.
# NOTE(review): both converters are built from the "t2s" OpenCC configuration
# and neither is referenced in the visible part of this file -- confirm
# whether both are still needed.
cc = opencc.OpenCC("t2s.json")
# Seed all RNGs once, using the project-wide SEED constant.
set_seed(SEED)
converter_t2s = opencc.OpenCC("t2s")
| |
|
| |
|
def single_rec(model, tokenizer, detect_model, generation_config, image_path, prompt,
               use_p, hard_vq, drop_zero, repetition_penalty, verbose):
    """Run ``model.chat_ocr`` on one image, print the exchange and return the answer.

    Args:
        model: Object exposing ``chat_ocr``.
        tokenizer: Forwarded to ``chat_ocr`` as its first positional argument.
        detect_model: Forwarded to ``chat_ocr`` as its second positional argument.
        generation_config: Generation settings forwarded to ``chat_ocr``.
        image_path: Path of the image to recognise.
        prompt: Question sent along with the image.
        use_p: Forwarded to ``chat_ocr`` as ``use_p``.
        hard_vq: Forwarded to ``chat_ocr`` as ``hard_vq``.
        drop_zero: Forwarded to ``chat_ocr`` as ``drop_zero``.
        repetition_penalty: Forwarded to ``chat_ocr`` as ``repetition_penalty``.
        verbose: Forwarded to ``chat_ocr`` as ``verbose``.

    Returns:
        The response produced by ``chat_ocr``. Earlier versions only printed
        it and returned ``None``, so callers could not use the result.
    """
    # The history is requested (as before) but is not needed here.
    response, _history = model.chat_ocr(
        tokenizer,
        detect_model,
        image_path,
        prompt,
        generation_config,
        use_p=use_p,
        hard_vq=hard_vq,
        drop_zero=drop_zero,
        repetition_penalty=repetition_penalty,
        return_history=True,
        verbose=verbose,
    )
    print(f'User: {prompt}\nAssistant: {response}')
    return response
| |
|
def folder_rec(model, tokenizer, detect_model, generation_config, folder_path, prompt,
               save_name, use_p, hard_vq, drop_zero, repetition_penalty, verbose):
    """Recognise every image listed for ``folder_path`` and save the answers as JSON.

    Each entry returned by ``get_image_paths(folder_path)`` is joined onto
    ``folder_path`` and sent to ``model.chat_ocr``. A failure on one image is
    printed and recorded as the response ``"ERROR!"`` so the remaining images
    are still processed.

    Args:
        model: Object exposing ``chat_ocr``.
        tokenizer: Forwarded to ``chat_ocr``.
        detect_model: Forwarded to ``chat_ocr``.
        generation_config: Generation settings forwarded to ``chat_ocr``.
        folder_path: Folder whose images are recognised.
        prompt: Question sent along with every image.
        save_name: Output path; ``_result.json`` is appended unless it
            already ends with ``.json``.
        use_p: Forwarded to ``chat_ocr`` as ``use_p``.
        hard_vq: Forwarded to ``chat_ocr`` as ``hard_vq``.
        drop_zero: Forwarded to ``chat_ocr`` as ``drop_zero``.
        repetition_penalty: Forwarded to ``chat_ocr`` as ``repetition_penalty``.
        verbose: Forwarded to ``chat_ocr`` as ``verbose``.

    Returns:
        The list of ``{"imagePath", "prompt", "response"}`` records that was
        written with ``save_json`` (earlier versions returned ``None``).
    """
    results = []

    for pic in tqdm(get_image_paths(folder_path)):
        pic_path = os.path.join(folder_path, pic)
        try:
            response, _history = model.chat_ocr(
                tokenizer,
                detect_model,
                pic_path,
                prompt,
                generation_config,
                use_p=use_p,
                hard_vq=hard_vq,
                drop_zero=drop_zero,
                repetition_penalty=repetition_penalty,
                return_history=True,
                verbose=verbose,
            )
        except Exception as e:
            # Deliberate best-effort: one bad image must not abort the batch.
            print(f"An error has occured:\n{e}")
            response = "ERROR!"
        print(f'User: {prompt}\nAssistant: {response}')
        results.append({"imagePath": pic_path, 'prompt': prompt, 'response': response})

    # Check for the real extension: a bare 'json' suffix (e.g. "myjson")
    # is not a JSON file name and still needs the suffix appended.
    if not save_name.endswith('.json'):
        save_name += '_result.json'
    save_json(save_name, results)
    return results
| |
|
| |
|
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    """Command-line entry point: recognise one image or every image in a folder.

    Parses the arguments, checks that ``--tgt`` is an image (per ``is_image``)
    or a directory, loads the InternVL model/tokenizer and the YOLO detector,
    then dispatches to ``single_rec`` or ``folder_rec``. Folder results are
    written under ``results/``.

    Raises:
        ValueError: If ``--tgt`` is missing, or is neither an image path nor
            a directory.
    """
    parser = argparse.ArgumentParser(description="args for inference task")

    parser.add_argument('--tgt', type=str, help='Recognition target')
    parser.add_argument('--prompt', type=str, default='这幅书法作品内容是什么?',
                        help='Prompt for recognition')
    parser.add_argument('--save_name', type=str, default="recognition.json",
                        help="Storage of results if multiple images recognition mode")

    # ``type=bool`` would turn any non-empty string (including "False") into
    # True, so boolean flags are parsed explicitly instead.
    parser.add_argument('--use_p', type=_str2bool, default=True,
                        help='Decide the usage of perceiver resampler')
    parser.add_argument('--hard_vq', type=_str2bool, default=False,
                        help='Decide the usage of closest similarity match')
    parser.add_argument('--drop_zero', type=_str2bool, default=False,
                        help='Decide the deletion of zero padding in pseudo tokens')
    parser.add_argument('--verbose', type=_str2bool, default=False,
                        help='Decide the output of extra information')
    parser.add_argument('--repetition_penalty', type=float, default=1.0,
                        help='Repetition penalty for generation')

    args = parser.parse_args()

    if not isinstance(args.tgt, str):
        raise ValueError(f"The target should a string, not a instance of {type(args.tgt)}!")

    # Validate the target before loading any model so a bad path fails
    # immediately instead of after the weights have been loaded.
    single_mode = is_image(args.tgt)
    if not single_mode and not os.path.isdir(args.tgt):
        raise ValueError("The target should be either a image path or a folder that contain images!")

    model = AutoModel.from_pretrained(
        INTERNVL_PATH,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(INTERNVL_PATH, trust_remote_code=True)

    generation_config = dict(
        num_beams=1,
        max_new_tokens=1024,
        do_sample=False,
    )

    detect_model = YOLO(YOLO_CHECKPOINT)

    if single_mode:
        print("Single image recognition mode.")
        single_rec(
            model,
            tokenizer,
            detect_model,
            generation_config,
            args.tgt,
            args.prompt,
            args.use_p,
            args.hard_vq,
            args.drop_zero,
            args.repetition_penalty,
            args.verbose)
    else:
        print("Multiple images recognition mode")
        os.makedirs('results', exist_ok=True)
        folder_rec(
            model,
            tokenizer,
            detect_model,
            generation_config,
            args.tgt,
            args.prompt,
            os.path.join('results', args.save_name),
            args.use_p,
            args.hard_vq,
            args.drop_zero,
            args.repetition_penalty,
            args.verbose)
| |
|
def single_image_wrapped(image, prompts):
    """Recognise an in-memory image with fixed inference settings.

    The image is written to ``temp_images/uploaded_image.png`` and that path
    is passed to ``single_rec`` together with ``prompts``.

    Args:
        image: Object exposing ``save(path)`` (presumably a PIL image --
            confirm with the caller).
        prompts: Prompt forwarded to ``single_rec``.

    Returns:
        Whatever ``single_rec`` returns for the saved image. Earlier versions
        dropped that value and always returned ``None``.
    """
    # NOTE(review): the model, tokenizer and detector are loaded on every
    # call; consider loading them once if this is invoked repeatedly.
    model = AutoModel.from_pretrained(
        INTERNVL_PATH,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(INTERNVL_PATH, trust_remote_code=True)

    generation_config = dict(
        num_beams=1,
        max_new_tokens=1024,
        do_sample=False,
    )
    detect_model = YOLO(YOLO_CHECKPOINT)

    # single_rec works on file paths, so persist the image first.
    temp_dir = "temp_images"
    os.makedirs(temp_dir, exist_ok=True)
    temp_image_path = os.path.join(temp_dir, "uploaded_image.png")
    image.save(temp_image_path)

    # Fixed settings, passed by keyword so each value is self-describing.
    return single_rec(
        model,
        tokenizer,
        detect_model,
        generation_config,
        temp_image_path,
        prompts,
        use_p=True,
        hard_vq=False,
        drop_zero=True,
        repetition_penalty=1.2,
        verbose=False)
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
| |
|
| |
|
| |
|