|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer,AutoConfig |
|
|
|
|
|
def load_tokenizer(model_name: str, is_hf: bool=False): |
|
|
if not is_hf: |
|
|
tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
|
tokenizer.model_max_length = 2048 |
|
|
else: |
|
|
if "mamba" in model_name or "mpt" in model_name: |
|
|
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") |
|
|
else: |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id |
|
|
return tokenizer |
|
|
|
|
|
from fla.models import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel |
|
|
def _register_causal_lm(config_cls, model_cls):
    """Make a custom config/model pair resolvable through the transformers Auto* API.

    Prints ``config_cls.model_type``, registers ``config_cls`` with
    ``AutoConfig`` under that same ``model_type``, then maps ``config_cls`` to
    ``model_cls`` for ``AutoModelForCausalLM``. Using the class attribute as the
    registration key (rather than repeating a string literal) keeps the key and
    ``model_type`` from drifting apart.
    """
    print(config_cls.model_type)
    AutoConfig.register(config_cls.model_type, config_cls)
    AutoModelForCausalLM.register(config_cls, model_cls)


# Register both architectures before from_pretrained below, presumably so that
# checkpoints declaring these model types can be loaded — confirm.
_register_causal_lm(DeltaNetConfig, DeltaNetForCausalLM)

from opencompass.models.fla2.models import mask_deltanetConfig,mask_deltanetForCausalLM

_register_causal_lm(mask_deltanetConfig, mask_deltanetForCausalLM)
|
|
|
|
|
# Local checkpoint to load and sample from.
model_path = "/mnt/jfzn/msj/train_exp/mask_deltanet_1B_rank4"

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cuda",
)

tokenizer = load_tokenizer(model_path, is_hf=True)

prompt = "What is the official language of China?"

# Move the encoded prompt to wherever the model was actually placed instead of
# repeating the device string, so editing ``device_map`` cannot desync them.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Greedy decoding (do_sample=False) of at most 100 new tokens; EOS is used both
# as the stop token and as the pad token.
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Decode the first (only) sequence; for a decoder-only model this presumably
# includes the prompt tokens ahead of the continuation — confirm.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))