"""Smoke-test greedy generation from a local (mask_)DeltaNet checkpoint.

Registers the DeltaNet and mask-DeltaNet config/model classes with the
transformers Auto* factories, loads the checkpoint at ``model_path`` in
bfloat16 with ``device_map="cuda"``, and prints the decoded output for a
single prompt.
"""

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from fla.models import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
from opencompass.models.fla2.models import mask_deltanetConfig, mask_deltanetForCausalLM


def load_tokenizer(model_name: str, is_hf: bool = False):
    """Load the tokenizer to pair with ``model_name``.

    Args:
        model_name: Hub id or local path of the model. Only consulted when
            ``is_hf`` is true.
        is_hf: If false, ignore ``model_name`` and return the "gpt2" tokenizer
            with ``model_max_length`` set to 2048. If true, return the
            "EleutherAI/gpt-neox-20b" tokenizer when ``model_name`` contains
            "mamba" or "mpt", otherwise the tokenizer loaded from
            ``model_name`` itself.

    Returns:
        The tokenizer, with its pad token (and pad token id) set to its EOS
        token.
    """
    if not is_hf:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.model_max_length = 2048
    elif "mamba" in model_name or "mpt" in model_name:
        # NOTE(review): plain case-sensitive substring match on the whole
        # string, so a parent directory such as ".../attempt/..." also matches
        # "mpt". Presumably these checkpoints ship without their own
        # tokenizer files — confirm before tightening the check.
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS for padding; this overwrites any pad token the tokenizer had.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer


# Register the custom architectures so the Auto* factories resolve checkpoints
# whose config declares model_type "delta_net" / "mask_deltanet" to these
# classes. The prints echo each class's declared model_type for debugging.
print(DeltaNetConfig.model_type)
AutoConfig.register("delta_net", DeltaNetConfig)
AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)

print(mask_deltanetConfig.model_type)
AutoConfig.register("mask_deltanet", mask_deltanetConfig)
AutoModelForCausalLM.register(mask_deltanetConfig, mask_deltanetForCausalLM)

# Alternative checkpoint:
# model_path = "/mnt/jfzn/msj/delta_net-1.3B-100B"
model_path = "/mnt/jfzn/msj/train_exp/mask_deltanet_1B_rank4"

# The original author notes DeltaNet must be loaded with trust_remote_code=True.
# NOTE(review): with the classes registered above this may not be strictly
# required — confirm before removing it.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # bf16 weights (an earlier comment wrongly said fp16)
    trust_remote_code=True,
    device_map="cuda",
)
tokenizer = load_tokenizer(model_path, is_hf=True)

prompt = "What is the official language of China?"
# Place the inputs wherever from_pretrained put the model instead of
# repeating the "cuda" literal.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=False,  # deterministic decoding, no sampling
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))