attempts to fix more
Browse files- myolmoe/modeling_myolmoe.py +49 -14
- scripts/eval.py +37 -23
myolmoe/modeling_myolmoe.py
CHANGED
|
@@ -22,12 +22,16 @@ from dataclasses import dataclass, field
|
|
| 22 |
from typing import Optional, List, Any
|
| 23 |
from transformers import PretrainedConfig
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
@dataclass
|
| 26 |
class MyOlmoeConfig(PretrainedConfig):
|
| 27 |
"""
|
| 28 |
Configuration class for MyOlmoe model.
|
| 29 |
"""
|
| 30 |
-
model_type: str = "
|
| 31 |
|
| 32 |
# Core model parameters
|
| 33 |
vocab_size: int = 50304
|
|
@@ -72,31 +76,62 @@ class MyOlmoeConfig(PretrainedConfig):
|
|
| 72 |
rope_theta: float = 10000.0
|
| 73 |
rope_scaling: Optional[dict] = None
|
| 74 |
|
| 75 |
-
# Token IDs
|
| 76 |
pad_token_id: int = 1
|
| 77 |
eos_token_id: int = 50279
|
|
|
|
| 78 |
|
| 79 |
# Model architecture
|
| 80 |
architectures: List[str] = field(default_factory=lambda: ["MyOlmoeForCausalLM"])
|
| 81 |
|
| 82 |
def __init__(self, **kwargs):
|
| 83 |
-
#
|
| 84 |
-
model_loading_params = ['torch_dtype', 'device_map', 'low_cpu_mem_usage'
|
|
|
|
| 85 |
for param in model_loading_params:
|
| 86 |
kwargs.pop(param, None)
|
| 87 |
|
| 88 |
-
#
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
logger = logging.get_logger(__name__)
|
| 102 |
|
|
|
|
| 22 |
from typing import Optional, List, Any
|
| 23 |
from transformers import PretrainedConfig
|
| 24 |
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from typing import Optional, List, Dict, Any
|
| 27 |
+
from transformers import PretrainedConfig
|
| 28 |
+
|
| 29 |
@dataclass
|
| 30 |
class MyOlmoeConfig(PretrainedConfig):
|
| 31 |
"""
|
| 32 |
Configuration class for MyOlmoe model.
|
| 33 |
"""
|
| 34 |
+
model_type: str = "olmoe" # Keep as "olmoe" to match your trained model
|
| 35 |
|
| 36 |
# Core model parameters
|
| 37 |
vocab_size: int = 50304
|
|
|
|
| 76 |
rope_theta: float = 10000.0
|
| 77 |
rope_scaling: Optional[dict] = None
|
| 78 |
|
| 79 |
+
# Token IDs - Set proper defaults
|
| 80 |
pad_token_id: int = 1
|
| 81 |
eos_token_id: int = 50279
|
| 82 |
+
bos_token_id: int = 1
|
| 83 |
|
| 84 |
# Model architecture
|
| 85 |
architectures: List[str] = field(default_factory=lambda: ["MyOlmoeForCausalLM"])
|
| 86 |
|
| 87 |
def __init__(self, **kwargs):
|
| 88 |
+
# Handle model loading parameters that shouldn't go to config
|
| 89 |
+
model_loading_params = ['torch_dtype', 'device_map', 'low_cpu_mem_usage',
|
| 90 |
+
'load_in_8bit', 'load_in_4bit', 'quantization_config']
|
| 91 |
for param in model_loading_params:
|
| 92 |
kwargs.pop(param, None)
|
| 93 |
|
| 94 |
+
# Set defaults for any missing required fields
|
| 95 |
+
if 'pad_token_id' not in kwargs:
|
| 96 |
+
kwargs['pad_token_id'] = self.pad_token_id
|
| 97 |
+
if 'eos_token_id' not in kwargs:
|
| 98 |
+
kwargs['eos_token_id'] = self.eos_token_id
|
| 99 |
+
if 'bos_token_id' not in kwargs:
|
| 100 |
+
kwargs['bos_token_id'] = self.bos_token_id
|
| 101 |
+
if 'architectures' not in kwargs:
|
| 102 |
+
kwargs['architectures'] = ["MyOlmoeForCausalLM"]
|
| 103 |
+
|
| 104 |
+
# Initialize the parent class first
|
| 105 |
super().__init__(**kwargs)
|
| 106 |
+
|
| 107 |
+
# Then set dataclass fields from remaining kwargs or defaults
|
| 108 |
+
for field_name, field_def in self.__dataclass_fields__.items():
|
| 109 |
+
if hasattr(self, field_name):
|
| 110 |
+
continue # Already set by parent
|
| 111 |
+
if field_name in kwargs:
|
| 112 |
+
setattr(self, field_name, kwargs[field_name])
|
| 113 |
+
else:
|
| 114 |
+
# Use default value from dataclass field
|
| 115 |
+
if field_def.default != field_def.default_factory:
|
| 116 |
+
setattr(self, field_name, field_def.default)
|
| 117 |
+
elif field_def.default_factory != field_def.default_factory: # type: ignore
|
| 118 |
+
setattr(self, field_name, field_def.default_factory())
|
| 119 |
|
| 120 |
+
@classmethod
|
| 121 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
| 122 |
+
"""Override from_pretrained to handle the model type properly."""
|
| 123 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 124 |
+
|
| 125 |
+
# Keep the original model_type from the saved config
|
| 126 |
+
# This allows loading models trained with "olmoe" type
|
| 127 |
+
if 'model_type' in config_dict:
|
| 128 |
+
original_model_type = config_dict['model_type']
|
| 129 |
+
# But register with the class model_type for compatibility
|
| 130 |
+
if original_model_type == "olmoe":
|
| 131 |
+
config_dict['model_type'] = "olmoe" # Keep as olmoe
|
| 132 |
+
|
| 133 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 134 |
+
|
| 135 |
|
| 136 |
logger = logging.get_logger(__name__)
|
| 137 |
|
scripts/eval.py
CHANGED
|
@@ -183,12 +183,6 @@ def load_transformers_model(args) -> HFLM:
|
|
| 183 |
def load_custom_model(args) -> HFLM:
|
| 184 |
"""
|
| 185 |
Load custom MyOLMoE model (uses top-k routing by default).
|
| 186 |
-
|
| 187 |
-
Args:
|
| 188 |
-
args: Parsed command line arguments
|
| 189 |
-
|
| 190 |
-
Returns:
|
| 191 |
-
HFLM: Wrapped model ready for evaluation
|
| 192 |
"""
|
| 193 |
logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
|
| 194 |
logger.info("Using top-k routing (default)")
|
|
@@ -205,15 +199,11 @@ def load_custom_model(args) -> HFLM:
|
|
| 205 |
from modeling_myolmoe import MyOlmoeForCausalLM, MyOlmoeConfig
|
| 206 |
logger.info("Successfully imported MyOlmoeForCausalLM and MyOlmoeConfig")
|
| 207 |
|
| 208 |
-
#
|
| 209 |
-
if not hasattr(MyOlmoeConfig, '__dataclass_fields__'):
|
| 210 |
-
logger.warning("MyOlmoeConfig is not a dataclass, this may cause issues")
|
| 211 |
-
|
| 212 |
-
# Register the custom model class with the correct config
|
| 213 |
from transformers import AutoConfig, AutoModelForCausalLM
|
| 214 |
-
AutoConfig.register("
|
| 215 |
-
AutoModelForCausalLM.register(MyOlmoeConfig, MyOlmoeForCausalLM)
|
| 216 |
-
logger.info("Registered MyOlmoeForCausalLM with MyOlmoeConfig")
|
| 217 |
|
| 218 |
except ImportError as e:
|
| 219 |
logger.error(f"Failed to import custom model: {e}")
|
|
@@ -224,31 +214,51 @@ def load_custom_model(args) -> HFLM:
|
|
| 224 |
logger.info("Loading model manually to avoid wrapper issues...")
|
| 225 |
|
| 226 |
try:
|
| 227 |
-
# Load tokenizer
|
| 228 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 229 |
args.model_path,
|
| 230 |
trust_remote_code=args.trust_remote_code
|
| 231 |
)
|
| 232 |
|
| 233 |
-
# Load config
|
| 234 |
-
model_config =
|
| 235 |
args.model_path,
|
| 236 |
trust_remote_code=args.trust_remote_code
|
| 237 |
)
|
| 238 |
|
| 239 |
-
# Debug information
|
| 240 |
logger.info(f"Loaded config type: {type(model_config)}")
|
| 241 |
logger.info(f"Config model_type: {model_config.model_type}")
|
| 242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
# Load model instance
|
| 244 |
-
model_instance =
|
| 245 |
args.model_path,
|
| 246 |
-
|
| 247 |
-
trust_remote_code=args.trust_remote_code,
|
| 248 |
-
torch_dtype=torch.bfloat16 if args.dtype == "bfloat16" else "auto"
|
| 249 |
)
|
| 250 |
|
| 251 |
-
|
|
|
|
|
|
|
| 252 |
model = HFLM(
|
| 253 |
pretrained=model_instance,
|
| 254 |
tokenizer=tokenizer,
|
|
@@ -259,11 +269,15 @@ def load_custom_model(args) -> HFLM:
|
|
| 259 |
|
| 260 |
except Exception as e:
|
| 261 |
logger.error(f"Failed to load custom model: {e}")
|
|
|
|
|
|
|
|
|
|
| 262 |
raise
|
| 263 |
|
| 264 |
logger.info("Custom model loaded successfully")
|
| 265 |
return model
|
| 266 |
|
|
|
|
| 267 |
def validate_model_config(model_path: str, trust_remote_code: bool = False) -> Dict[str, Any]:
|
| 268 |
"""
|
| 269 |
Validate model configuration and return key information.
|
|
|
|
| 183 |
def load_custom_model(args) -> HFLM:
|
| 184 |
"""
|
| 185 |
Load custom MyOLMoE model (uses top-k routing by default).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
"""
|
| 187 |
logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
|
| 188 |
logger.info("Using top-k routing (default)")
|
|
|
|
| 199 |
from modeling_myolmoe import MyOlmoeForCausalLM, MyOlmoeConfig
|
| 200 |
logger.info("Successfully imported MyOlmoeForCausalLM and MyOlmoeConfig")
|
| 201 |
|
| 202 |
+
# IMPORTANT: Register with "olmoe" since that's what your model was trained with
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
from transformers import AutoConfig, AutoModelForCausalLM
|
| 204 |
+
AutoConfig.register("olmoe", MyOlmoeConfig, exist_ok=True) # Use exist_ok=True
|
| 205 |
+
AutoModelForCausalLM.register(MyOlmoeConfig, MyOlmoeForCausalLM, exist_ok=True)
|
| 206 |
+
logger.info("Registered MyOlmoeForCausalLM with MyOlmoeConfig for 'olmoe' type")
|
| 207 |
|
| 208 |
except ImportError as e:
|
| 209 |
logger.error(f"Failed to import custom model: {e}")
|
|
|
|
| 214 |
logger.info("Loading model manually to avoid wrapper issues...")
|
| 215 |
|
| 216 |
try:
|
| 217 |
+
# Load tokenizer first
|
| 218 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 219 |
args.model_path,
|
| 220 |
trust_remote_code=args.trust_remote_code
|
| 221 |
)
|
| 222 |
|
| 223 |
+
# Load config - this should now work with the olmoe type
|
| 224 |
+
model_config = AutoConfig.from_pretrained(
|
| 225 |
args.model_path,
|
| 226 |
trust_remote_code=args.trust_remote_code
|
| 227 |
)
|
| 228 |
|
|
|
|
| 229 |
logger.info(f"Loaded config type: {type(model_config)}")
|
| 230 |
logger.info(f"Config model_type: {model_config.model_type}")
|
| 231 |
|
| 232 |
+
# Verify the config is properly initialized
|
| 233 |
+
if not hasattr(model_config, '__dataclass_fields__'):
|
| 234 |
+
logger.warning("Config is not recognized as a dataclass, attempting to recreate...")
|
| 235 |
+
# Recreate config as proper dataclass instance
|
| 236 |
+
config_dict = model_config.to_dict()
|
| 237 |
+
model_config = MyOlmoeConfig(**config_dict)
|
| 238 |
+
|
| 239 |
+
# Prepare model loading kwargs
|
| 240 |
+
model_kwargs = {
|
| 241 |
+
'config': model_config,
|
| 242 |
+
'trust_remote_code': args.trust_remote_code,
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
# Add torch_dtype if specified
|
| 246 |
+
if args.dtype == "bfloat16":
|
| 247 |
+
model_kwargs['torch_dtype'] = torch.bfloat16
|
| 248 |
+
elif args.dtype == "float16":
|
| 249 |
+
model_kwargs['torch_dtype'] = torch.float16
|
| 250 |
+
elif args.dtype == "float32":
|
| 251 |
+
model_kwargs['torch_dtype'] = torch.float32
|
| 252 |
+
|
| 253 |
# Load model instance
|
| 254 |
+
model_instance = AutoModelForCausalLM.from_pretrained(
|
| 255 |
args.model_path,
|
| 256 |
+
**model_kwargs
|
|
|
|
|
|
|
| 257 |
)
|
| 258 |
|
| 259 |
+
logger.info(f"Loaded model type: {type(model_instance)}")
|
| 260 |
+
|
| 261 |
+
# Create HFLM wrapper
|
| 262 |
model = HFLM(
|
| 263 |
pretrained=model_instance,
|
| 264 |
tokenizer=tokenizer,
|
|
|
|
| 269 |
|
| 270 |
except Exception as e:
|
| 271 |
logger.error(f"Failed to load custom model: {e}")
|
| 272 |
+
logger.error(f"Error type: {type(e)}")
|
| 273 |
+
import traceback
|
| 274 |
+
logger.error(f"Traceback: {traceback.format_exc()}")
|
| 275 |
raise
|
| 276 |
|
| 277 |
logger.info("Custom model loaded successfully")
|
| 278 |
return model
|
| 279 |
|
| 280 |
+
|
| 281 |
def validate_model_config(model_path: str, trust_remote_code: bool = False) -> Dict[str, Any]:
|
| 282 |
"""
|
| 283 |
Validate model configuration and return key information.
|