| import os |
| import torch |
| from transformers import AutoTokenizer, AutoModel, AutoConfig |
|
|
# Location of the base ChatGLM-6B files; every from_pretrained call below
# reads from it with trust_remote_code=True.
model_path = "../base-model/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Load the config with pre_seq_len=128, then attach the prefix-expert options
# in a single update call.
# NOTE(review): presumably two prefix experts blended with weights 0.4/0.6 and
# -1 meaning "no single expert selected" -- confirm against the modeling code.
config = AutoConfig.from_pretrained(
    model_path,
    trust_remote_code=True,
    pre_seq_len=128,
)
config.update(
    {
        "prefix_n_experts": 2,
        "prefix_cur_expert": -1,
        "expert_weights": [0.4, 0.6],
    }
)

# Instantiate the model from the same path using the customised config.
model = AutoModel.from_pretrained(
    model_path,
    config=config,
    trust_remote_code=True,
)
|
|
| |
# Load the fine-tuned checkpoint and copy its prefix-encoder weights into the
# model's prefix encoder.
# NOTE(review): torch.load unpickles the file -- only point this at
# checkpoints from a trusted source.
prefix_state_dict = torch.load(
    os.path.join("./checkpoints/checkpoint-3000", "pytorch_model.bin")
)

# load_state_dict is called on the prefix_encoder sub-module, so the
# "transformer.prefix_encoder." qualifier has to be removed from each key.
# Only keys that actually start with that qualifier are kept: blindly slicing
# every key would turn unrelated entries into bogus parameter names.
prefix_key = "transformer.prefix_encoder."
new_prefix_state_dict = {
    k[len(prefix_key):]: v
    for k, v in prefix_state_dict.items()
    if k.startswith(prefix_key)
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
|
| |
| |
# Prepare the model for inference. The call order matches the original:
# quantize(4) first, then half(), then cuda().
# NOTE(review): quantize(4) is presumably 4-bit weight quantization provided
# by the remote modeling code -- confirm.
model = (
    model.quantize(4)
    .half()
    .cuda()
)
# The model-wide half() above also converted the prefix encoder; cast that
# sub-module back to float32.
model.transformer.prefix_encoder.float()
# Switch to evaluation mode.
model = model.eval()
|
|
| |
| |
|
|
| |
| |
|
|
# Interactive chat loop: read a prompt, pass it to model.chat together with
# the history returned by the previous turn, and print the reply.
response = ""
history = []

while True:
    try:
        user_prompt = input(">>>>>>>> USER: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D, Ctrl-C or exhausted piped stdin at the prompt: end the
        # session cleanly instead of dying with a traceback.
        print()
        break
    # model.chat returns the reply and the updated history for the next turn.
    response, history = model.chat(tokenizer, user_prompt, history=history)
    print(f">>>>>>>> Assistant: {response}")
    print("\n")
|
|