|
|
import torch, warnings, json, pathlib |
|
|
from transformers.models.auto.tokenization_auto import AutoTokenizer |
|
|
from transformers.models.auto.modeling_auto import AutoModelForCausalLM |
|
|
from configuration_evo2 import Evo2Config |
|
|
|
|
|
|
|
|
# Directory containing the checkpoint files (tokenizer, config, weights).
# NOTE(review): assumes the script is launched from inside the checkpoint
# directory — confirm, or pass an explicit path.
root = pathlib.Path(".")


# Context length to apply at inference time; assigned onto the loaded
# config below, replacing whatever max_seqlen the checkpoint shipped with.
INFERENCE_MAX_SEQLEN = 8192
|
|
|
|
|
print("Loading tokenizer...")


# Load the tokenizer that ships with the checkpoint in `root`.
# trust_remote_code=True lets transformers execute the tokenizer class
# defined in the repo's own code — only appropriate for trusted checkpoints.
tok = AutoTokenizer.from_pretrained(root, trust_remote_code=True)
|
|
|
|
|
print(f"Loading configuration and overriding max_seqlen to {INFERENCE_MAX_SEQLEN}...")




# Read the checkpoint's config so the override below is applied before the
# model is instantiated.
config = Evo2Config.from_pretrained(root, trust_remote_code=True)




# NOTE(review): assumes the Evo2 model code sizes its sequence buffers from
# `config.max_seqlen` at construction time — confirm the attribute name
# against configuration_evo2.py.
config.max_seqlen = INFERENCE_MAX_SEQLEN
|
|
|
|
|
|
|
|
|
|
|
print("Loading model with modified config... (this takes ~30 s on first run)")

# Pick the target device up front. The original hard-coded
# device_map="cuda:0", which makes the script crash on CPU-only machines;
# falling back to "cpu" keeps behavior identical whenever a GPU is present
# while still letting the script run without one.
_device_map = "cuda:0" if torch.cuda.is_available() else "cpu"

# Instantiate the model from the checkpoint using the overridden config,
# in bfloat16. trust_remote_code=True executes the modeling code shipped
# with the checkpoint — only use with trusted repositories.
model = AutoModelForCausalLM.from_pretrained(
    root,
    config=config,
    torch_dtype=torch.bfloat16,
    device_map=_device_map,
    trust_remote_code=True,
)
|
|
|
|
|
|
|
|
# DNA prompt to condition generation on.
prompt = "ATGGCGAAAACGTGGCTCGTCCGGTAGGGATCTGGAAACAATTGTAGACAGTTCCGAGTTGTCAAGGGCCA"

# Tokenize and move the batch onto the same device as the model.
inputs = tok(prompt, return_tensors="pt").to(model.device)

# Sample a continuation; gradients are unnecessary for inference.
with torch.no_grad():
    generated = model.generate(
        input_ids=inputs["input_ids"],
        max_new_tokens=64,
        do_sample=True,
        temperature=0.8,
    )

# Decode the full sequence (prompt + continuation) and show it.
decoded = tok.decode(generated[0], skip_special_tokens=True)
print("\n--- Generated sequence ---\n", decoded)
|
|
|