import os
import warnings

from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig, AutoProcessor
import torch
from ola.model import *
from ola.model.speech_encoder.builder import build_speech_encoder

# Suppress PyTorch's warning about copying from a non-meta parameter in the checkpoint to a meta parameter.
warnings.filterwarnings("ignore", message=".*copying from a non-meta parameter in the checkpoint to a meta parameter.*")

def load_pretrained_model(model_path, model_type, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs):
    """Load an Ola checkpoint (full, LoRA, or base-plus-projector) and return
    (tokenizer, model, image_processor, context_len)."""
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
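        # 4-bit NF4 quantization via bitsandbytes, with double quantization
        # and fp16 compute for the dequantized matmuls.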
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.bfloat16

    if use_flash_attn:
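        # FlashAttention-2 kernels require a flash-attn installation
        # compatible with the GPU and the installed torch build.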
        kwargs['attn_implementation'] = 'flash_attention_2'

    if model_type == 'ola_internvl':
        model_cls = OlaQwen3ForCausalLM
        print('Loading OlaQwen3ForCausalLM model...')
    else:
        model_cls = OlaQwenForCausalLM

    # Load Ola model
    if is_lora:
        assert model_base is not None, "model_base is required for LoRA models."
        from ola.model.language_model.ola_qwen import OlaConfigQwen
        lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
        print('Loading Ola from base model...')
        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
        print('Loading additional Ola weights...')
        non_lora_path = os.path.join(model_path, 'non_lora_trainables.bin')
        if not os.path.exists(non_lora_path):
            raise FileNotFoundError(f'non_lora_trainables.bin not found in {model_path}')
        non_lora_trainables = torch.load(non_lora_path, map_location='cpu')
        # Strip the PEFT 'base_model.' / 'model.' prefixes so the keys match
        # the bare model's state dict.
        non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
        if any(k.startswith('model.model.') for k in non_lora_trainables):
            non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
        model.load_state_dict(non_lora_trainables, strict=False, assign=True)

        from peft import PeftModel
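        # PeftModel re-attaches the LoRA adapters saved in model_path;
        # merge_and_unload() folds the low-rank updates into the base weights
        # and returns a plain, adapter-free model.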
        print('Loading LoRA weights...')
        model = PeftModel.from_pretrained(model, model_path)
        print('Merging LoRA weights...')
        model = model.merge_and_unload()
        print('Model is loaded...')
    elif model_base is not None:
        print('Loading Ola from base model...')
        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
        cfg_pretrained = AutoConfig.from_pretrained(model_path)
        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)
        
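        # The base model ships without the speech projector; load its weights
        # separately from the Ola checkpoint directory (non-strictly, since
        # the file contains only the projector tensors).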
        speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
        speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
        model.load_state_dict(speech_projector_weights, strict=False, assign=True)
        model = model.to(device=device)
    else:
        # NOTE: the hardcoded local paths below override the model_path
        # argument; the tokenizer is taken from the InternVL3.5-2B directory
        # rather than from the checkpoint itself.
        model_path = "/data1/cxy/plm-v/modeling/ckpt/ola_audio_8_8gpu/checkpoint-120"
        tokenizer_path = "/data1/cxy/plm-v/modeling/internvl3_5-2B"
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, trust_remote_code=True)
        cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        with torch.device(device):
            model = model_cls.from_pretrained(
                model_path,
                trust_remote_code=True,
                config=cfg,
                **kwargs,
            )
        model = model.to(device=device)
    model.resize_token_embeddings(len(tokenizer))
    
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 16384
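    # Image preprocessing comes from the HF-format InternVL3.5 processor checkpoint.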
    image_processor = AutoProcessor.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B-HF")

    return tokenizer, model, image_processor, context_len
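

# Minimal usage sketch (assumption: the argument values below are illustrative
# placeholders, not shipped defaults; note that the non-LoRA/non-base branch
# above currently overrides model_path with a hardcoded local checkpoint):
if __name__ == "__main__":
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path="/path/to/ola_checkpoint",  # hypothetical path
        model_type="ola_internvl",
        model_base=None,
        load_4bit=True,  # takes the NF4 BitsAndBytesConfig path above
    )
    print(f"context_len={context_len}, tokenizer size={len(tokenizer)}")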