| import torch | |
| from llava.model import * | |
| from transformers import AutoConfig, StoppingCriteria | |
def auto_upgrade(config):
    """Interactively upgrade an old (v0) LLaVA checkpoint config to the new code base.

    Args:
        config: Path (or hub id) of the checkpoint whose config is inspected
            and, with explicit user confirmation, rewritten in place.

    Raises:
        AssertionError: If the checkpoint does not look like a v0 'llama'
            config with exactly one architecture entry.
        SystemExit: If the user declines the upgrade.
    """
    cfg = AutoConfig.from_pretrained(config)
    # v0 checkpoints: the path mentions 'llava' but the saved config still
    # declares a plain 'llama' model_type.
    if 'llava' in config and 'llava' not in cfg.model_type:
        assert cfg.model_type == 'llama'
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        # strip() guards against stray whitespace (e.g. "y ") aborting the upgrade.
        if confirm.strip().lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            # NOTE(review): model_type is set on the *class*, not the instance —
            # presumably because transformers serializes model_type from the
            # config class; verify against the installed transformers version
            # before changing this to an instance attribute.
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            # Equivalent to exit(1) but does not depend on the site module's
            # interactive `exit` builtin.
            raise SystemExit(1)
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as any keyword appears in the generated text.

    Matching is attempted first on raw token ids (cheap, exact suffix compare)
    and then, as a fallback, on the decoded text of everything generated after
    the prompt.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: List of keyword strings that should stop generation.
            tokenizer: Tokenizer used both to encode keywords and to decode
                the generated suffix.
            input_ids: Prompt ids of shape (batch, prompt_len); only the
                length is used, to know where generated text begins.
        """
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Tokenizers typically prepend BOS; strip it so the stored id
            # sequence matches what actually appears in generated output.
            # (The previous `len == 1` filter silently dropped every keyword
            # whose encoding included BOS or spanned multiple tokens.)
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == getattr(tokenizer, "bos_token_id", None):
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        # input_ids is available now, so record the prompt length immediately
        # instead of lazily on the first __call__ (which skipped checking).
        self.start_len = input_ids.shape[1]
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if any keyword has been generated (batch size 1 assumed)."""
        for keyword_id in self.keyword_ids:
            n = keyword_id.shape[0]
            # Suffix compare on ids: supports multi-token keywords.
            if output_ids.shape[1] >= n and torch.equal(
                output_ids[0, -n:], keyword_id.to(output_ids.device)
            ):
                return True
        # Fallback: decode only the generated portion and search for the
        # keyword strings (handles tokenizations that differ in context).
        outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False