| import torch | |
| import torch.nn as nn | |
| from transformers import AutoConfig | |
def auto_upgrade(config):
    """Interactively upgrade an old (v0) LLaVA checkpoint's config in place.

    Old checkpoints were saved with ``model_type == 'llama'``; the newer code
    base expects a ``llava`` model type. If the checkpoint path looks like a
    LLaVA checkpoint but its config does not, ask the user for confirmation
    and rewrite the config on disk.

    Args:
        config: Path (or hub id) of the checkpoint; the path string itself is
            expected to contain ``'llava'`` when it holds a LLaVA checkpoint.

    Raises:
        SystemExit: If the user declines the upgrade.
    """
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and 'llava' not in cfg.model_type:
        # Only the known v0 layout (plain llama config) is upgradable.
        assert cfg.model_type == 'llama'
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            # NOTE(review): this mutates the *class* attribute (model_type is
            # declared on the config class in transformers), so it affects every
            # instance of that config class in this process — presumably
            # intentional so save_pretrained serializes 'llava'; confirm against
            # the transformers version in use before changing.
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            # raise SystemExit instead of bare exit(): exit() is a `site`
            # convenience and is not guaranteed to exist in every runtime.
            raise SystemExit(1)
class LayerNorm(nn.LayerNorm):
    """LayerNorm that tolerates fp16/bf16 inputs.

    The normalization is computed in float32 for numerical stability and the
    result is cast back to the input's original dtype, so callers see the
    same dtype in as out.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        # Upcast explicitly and normalize in float32. The previous
        # torch.cuda.amp.autocast(dtype=torch.float32) wrapper was redundant
        # (the input is already cast to float32 below) and that API is
        # deprecated in favor of torch.amp; dropping it preserves behavior.
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)