import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

def load_tokenizer(model_name: str, is_hf: bool = False):
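    """Load a tokenizer for `model_name`.

    Non-HF checkpoints fall back to the GPT-2 tokenizer capped at a
    2048-token context; Mamba and MPT checkpoints reuse the
    EleutherAI/gpt-neox-20b tokenizer they were trained with.
    """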
    if not is_hf:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.model_max_length = 2048
    else:
        if "mamba" in model_name or "mpt" in model_name:
            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
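    # These checkpoints define no dedicated pad token, so reuse EOS for padding.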
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer

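# Register the custom DeltaNet classes with the Auto* registries so that
# from_pretrained can resolve model_type "delta_net" from the checkpoint's
# config.json.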
from fla.models import DeltaNetConfig, DeltaNetForCausalLM
print(DeltaNetConfig.model_type)
AutoConfig.register("delta_net", DeltaNetConfig)
AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)

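# Register the custom mask_deltanet variant the same way.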
from opencompass.models.fla2.models import mask_deltanetConfig, mask_deltanetForCausalLM
print(mask_deltanetConfig.model_type)
AutoConfig.register("mask_deltanet", mask_deltanetConfig)
AutoModelForCausalLM.register(mask_deltanetConfig, mask_deltanetForCausalLM)
# model_path = "/mnt/jfzn/msj/delta_net-1.3B-100B"
model_path = "/mnt/jfzn/msj/train_exp/mask_deltanet_1B_rank4"
# Note: DeltaNet must be loaded with trust_remote_code=True!
# tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # bfloat16 is recommended: fast and numerically stable
    trust_remote_code=True,      # <<=== your code was missing this!
    device_map="cuda",
)
tokenizer = load_tokenizer(model_path, is_hf=True)
prompt = "What is the official language of China?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))