File size: 5,250 Bytes
2997d61 784595b 2997d61 784595b 2997d61 784595b 2997d61 784595b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import pkgutil
import re
from transformers import AutoConfig, AutoModelForCausalLM
import yaml
from stripedhyena.utils import dotdict
from stripedhyena.model import StripedHyena
from stripedhyena.tokenizer import CharLevelTokenizer
# Canonical checkpoint names accepted by `Evo(...)` and `load_checkpoint(...)`.
# Each maps to a HuggingFace repo via HF_MODEL_NAME_MAP below.
MODEL_NAMES = [
    'evo-1.5-8k-base',
    'evo-1-8k-base',
    'evo-1-131k-base',
    'evo-1-8k-crispr',
    'evo-1-8k-transposon',
]
class Evo:
    """Thin wrapper that loads an Evo StripedHyena checkpoint and its tokenizer."""

    def __init__(self, model_name: str = MODEL_NAMES[1], device: str = None):
        """
        Loads an Evo model checkpoint given a model name.
        If the checkpoint does not exist, we automatically download it from HuggingFace.

        Parameters
        ----------
        model_name:
            One of MODEL_NAMES. Defaults to 'evo-1-8k-base'.
        device:
            Optional torch device string (e.g. 'cuda:0'); ``None`` leaves the
            model on its default device.

        Raises
        ------
        ValueError
            If `model_name` is not one of MODEL_NAMES.
        """
        self.device = device

        # Validate the model name up front so every later branch can assume
        # it is one of the known checkpoints.
        if model_name not in MODEL_NAMES:
            raise ValueError(
                f'Invalid model name {model_name}. Should be one of: '
                f'{", ".join(MODEL_NAMES)}.'
            )

        # All 8k-context variants share one inference config; only the
        # 131k-context model needs its own. (The previous if/elif chain
        # ended in an unreachable `else: raise` duplicating the check above.)
        if model_name == 'evo-1-131k-base':
            config_path = 'configs/evo-1-131k-base_inference.yml'
        else:
            config_path = 'configs/evo-1-8k-base_inference.yml'

        # Load model weights (downloads from HuggingFace on first use).
        self.model = load_checkpoint(
            model_name=model_name,
            config_path=config_path,
            device=self.device
        )

        # Evo operates on raw nucleotide characters; vocab size 512 matches
        # the CharLevelTokenizer used at training time.
        self.tokenizer = CharLevelTokenizer(512)
# Maps each short model name in MODEL_NAMES to its HuggingFace Hub repo id.
HF_MODEL_NAME_MAP = {
    'evo-1.5-8k-base': 'evo-design/evo-1.5-8k-base',
    'evo-1-8k-base': 'togethercomputer/evo-1-8k-base',
    'evo-1-131k-base': 'togethercomputer/evo-1-131k-base',
    'evo-1-8k-crispr': 'LongSafari/evo-1-8k-crispr',
    'evo-1-8k-transposon': 'LongSafari/evo-1-8k-transposon',
}
def load_checkpoint(
    model_name: str = MODEL_NAMES[1],
    config_path: str = 'evo/configs/evo-1-131k-base_inference.yml',
    device: str = None,
    *args, **kwargs
):
    """
    Load checkpoint from HuggingFace and place it into SH model.

    Parameters
    ----------
    model_name:
        A key of HF_MODEL_NAME_MAP; raises KeyError otherwise.
    config_path:
        Package-relative path to the local StripedHyena inference YAML.
    device:
        Optional device to move the model to after loading.

    Returns
    -------
    A StripedHyena model, in bfloat16 (except poles/residues), with
    pretrained weights if the HuggingFace download succeeded, otherwise
    randomly initialized.
    """
    # Map model name to HuggingFace model name.
    hf_model_name = HF_MODEL_NAME_MAP[model_name]

    # Load SH config first (local). pkgutil.get_data returns bytes, which
    # yaml.safe_load accepts directly.
    config = yaml.safe_load(pkgutil.get_data(__name__, config_path))
    # BUGFIX: the previous code called dotdict(config, Loader=yaml.FullLoader),
    # which — since dotdict construction behaves like dict() — silently stored
    # a spurious 'Loader' key in the config. The Loader kwarg belongs to
    # yaml.load, not here; safe_load above already parsed the YAML.
    global_config = dotdict(config)

    # The '1.1_fix' revision carries corrected weights for the evo-1 base
    # models; all other checkpoints use 'main'. Computed once and reused for
    # both the config and the weights download so they cannot diverge.
    revision = '1.1_fix' if re.match(r'evo-1-.*-base', model_name) else 'main'

    try:
        # Try to load from HuggingFace Hub
        model_config = AutoConfig.from_pretrained(
            hf_model_name,
            trust_remote_code=True,
            revision=revision,
        )
        model_config.use_cache = True

        # Load pretrained model from HuggingFace
        hf_model = AutoModelForCausalLM.from_pretrained(
            hf_model_name,
            config=model_config,
            trust_remote_code=True,
            revision=revision,
        )

        # Extract state dict from HuggingFace model, then free the HF wrapper
        # before instantiating the StripedHyena copy to limit peak memory.
        state_dict = hf_model.backbone.state_dict()
        del hf_model
        del model_config

        # Load into StripedHyena model with our config
        model = StripedHyena(global_config)
        model.load_state_dict(state_dict, strict=True)

        # Fix the tokenizer import issue by copying files to HF cache
        _fix_hf_tokenizer_cache(hf_model_name)
    except Exception as e:
        # Deliberate best-effort fallback: if the HuggingFace download fails
        # (offline, auth, revision missing), continue with random weights.
        print(f"Warning: Could not load pretrained weights from HuggingFace: {e}")
        print("Initializing model with random weights...")
        model = StripedHyena(global_config)

    # Poles/residues stay in float32 for numerical stability; everything
    # else is cast to bfloat16.
    model.to_bfloat16_except_poles_residues()
    if device is not None:
        model = model.to(device)

    return model
def _fix_hf_tokenizer_cache(hf_model_name):
"""Copy tokenizer files to HuggingFace cache after download."""
import shutil
from pathlib import Path
try:
hf_cache = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"
# Get our local files
import stripedhyena
stripedhyena_path = Path(stripedhyena.__file__).parent
local_tokenizer = stripedhyena_path / "tokenizer.py"
local_utils = stripedhyena_path / "utils.py"
if not local_tokenizer.exists():
return
# Find the model cache directory
model_short_name = hf_model_name.split("/")[-1] # e.g., "evo-1-8k-base"
model_cache = hf_cache / hf_model_name
if model_cache.exists():
# Copy to all version subdirectories
for version_dir in model_cache.iterdir():
if version_dir.is_dir():
shutil.copy2(local_tokenizer, version_dir / "tokenizer.py")
shutil.copy2(local_utils, version_dir / "utils.py")
print(f"✓ Fixed tokenizer cache for {model_short_name}")
except Exception as e:
print(f"Warning: Could not fix HF cache: {e}")
|