load custom model fix
- scripts/eval.py (+43 -32)

scripts/eval.py (CHANGED)
@@ -205,8 +205,15 @@ def load_custom_model(args) -> HFLM:
         from modeling_myolmoe import MyOlmoeForCausalLM, MyOlmoeConfig
         logger.info("Successfully imported MyOlmoeForCausalLM and MyOlmoeConfig")
 
-        # CRITICAL FIX:
+        # CRITICAL FIX: Ensure the config is properly registered as a dataclass
         from transformers import AutoConfig, AutoModelForCausalLM
+        from dataclasses import dataclass
+
+        # Make sure the config is a proper dataclass
+        if not hasattr(MyOlmoeConfig, '__dataclass_fields__'):
+            logger.warning("MyOlmoeConfig is not a dataclass, this may cause issues")
+
+        # Register with correct model type
         AutoConfig.register("myolmoe", MyOlmoeConfig)
         AutoModelForCausalLM.register(MyOlmoeConfig, MyOlmoeForCausalLM)
         logger.info("Registered MyOlmoeForCausalLM with MyOlmoeConfig")
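For reference, the registration above is the standard transformers Auto-class extension mechanism: once a config class and a model class are registered, generic from_pretrained calls can resolve the custom architecture from the "model_type" field in config.json. A minimal sketch of the pattern, assuming MyOlmoeConfig subclasses PretrainedConfig and declares model_type = "myolmoe" (names taken from the diff; the class bodies here are illustrative stubs):

    from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

    class MyOlmoeConfig(PretrainedConfig):
        model_type = "myolmoe"  # must match the string passed to AutoConfig.register

    class MyOlmoeForCausalLM(PreTrainedModel):
        config_class = MyOlmoeConfig

    AutoConfig.register("myolmoe", MyOlmoeConfig)
    AutoModelForCausalLM.register(MyOlmoeConfig, MyOlmoeForCausalLM)

    # After registration, the generic entry points resolve to the custom classes:
    #   AutoConfig.from_pretrained(path)            -> MyOlmoeConfig
    #   AutoModelForCausalLM.from_pretrained(path)  -> MyOlmoeForCausalLM

One note on the dataclass check in the new hunk: transformers PretrainedConfig subclasses are ordinarily not dataclasses, so the __dataclass_fields__ warning will fire for a typical custom config; the registration itself does not require a dataclass.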
@@ -216,52 +223,36 @@ def load_custom_model(args) -> HFLM:
         logger.error("Make sure the custom model code is available in the specified path")
         raise
 
-    # Load model
-    config = MyOlmoeConfig.from_pretrained(
-        args.model_path,
-        trust_remote_code=args.trust_remote_code
-    )
-
-    logger.info("Model will use default top-k routing configuration")
-
-    # Create HFLM with explicit model class specification
+    # Load model manually to avoid HFLM wrapper issues
     try:
-        model = HFLM(
-            pretrained=args.model_path,
-            device=args.device,
-            batch_size=args.batch_size,
-            max_batch_size=args.max_batch_size,
-            dtype=args.dtype,
-            trust_remote_code=args.trust_remote_code,
-            # Pass the custom model class explicitly
-            backend="causal",
-            model_kwargs={"torch_dtype": torch.bfloat16 if args.dtype == "bfloat16" else "auto"}
-        )
-    except Exception as e:
-        logger.error(f"Failed to create HFLM wrapper: {e}")
-        # Alternative approach: load model manually then wrap
-        logger.info("Trying alternative loading approach...")
+        logger.info("Loading model manually to avoid wrapper issues...")
 
-        # Load tokenizer
+        # Load tokenizer first
         tokenizer = AutoTokenizer.from_pretrained(
             args.model_path,
             trust_remote_code=args.trust_remote_code
         )
 
-        #
+        # Load config with explicit class
         config = MyOlmoeConfig.from_pretrained(
             args.model_path,
             trust_remote_code=args.trust_remote_code
         )
 
+        # Verify config is valid
+        logger.info(f"Loaded config type: {type(config)}")
+        logger.info(f"Config model_type: {getattr(config, 'model_type', 'unknown')}")
+
+        # Load model instance
         model_instance = MyOlmoeForCausalLM.from_pretrained(
             args.model_path,
             config=config,
             trust_remote_code=args.trust_remote_code,
-            torch_dtype=torch.bfloat16 if args.dtype == "bfloat16" else "auto"
+            torch_dtype=torch.bfloat16 if args.dtype == "bfloat16" else "auto",
+            low_cpu_mem_usage=True
         )
 
-        # Create HFLM with pre-loaded model
+        # Create HFLM wrapper with pre-loaded model
         model = HFLM(
             pretrained=model_instance,
             tokenizer=tokenizer,
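The substantive change in this hunk is that the model is instantiated directly with its concrete class and only then handed to the harness wrapper, rather than letting the wrapper resolve a path. A hedged sketch of the pattern, assuming HFLM is lm-eval-harness's lm_eval.models.huggingface.HFLM (which accepts either a name/path or an already-loaded PreTrainedModel as pretrained) and a hypothetical checkpoint path:

    import torch
    from lm_eval.models.huggingface import HFLM
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "checkpoints/myolmoe"  # hypothetical path

    # Loading with an explicit torch_dtype avoids a second cast afterwards;
    # low_cpu_mem_usage=True streams weights instead of materializing them twice.
    model_instance = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # With a pre-loaded model, HFLM skips its own loading logic; device and
    # dtype come from the model instance rather than from wrapper arguments.
    lm = HFLM(pretrained=model_instance, tokenizer=tokenizer, batch_size=8)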
@@ -269,9 +260,29 @@ def load_custom_model(args) -> HFLM:
             batch_size=args.batch_size,
             max_batch_size=args.max_batch_size
         )
-
-
-
+
+        logger.info("Custom model loaded successfully")
+        return model
+
+    except Exception as e:
+        logger.error(f"Failed to load custom model: {e}")
+        logger.error("Trying fallback approach...")
+
+        # Fallback: Try loading as standard transformers model
+        try:
+            model = HFLM(
+                pretrained=args.model_path,
+                device=args.device,
+                batch_size=args.batch_size,
+                max_batch_size=args.max_batch_size,
+                dtype=args.dtype,
+                trust_remote_code=True  # Force trust remote code
+            )
+            logger.info("Fallback loading successful")
+            return model
+        except Exception as fallback_e:
+            logger.error(f"Fallback also failed: {fallback_e}")
+            raise e
 
 
 def validate_model_config(model_path: str, trust_remote_code: bool = False) -> Dict[str, Any]:
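One behavioral detail in the fallback path: inside the nested except, "raise e" re-raises the primary error, and Python attaches the fallback failure only as implicit context. Chaining explicitly keeps both tracebacks readable. A minimal sketch of that pattern (hypothetical helper, not code from this diff):

    def load_with_fallback(primary, fallback):
        """Try primary() first; on failure try fallback(); surface both errors."""
        try:
            return primary()
        except Exception as e:
            try:
                return fallback()
            except Exception as fallback_e:
                # "raise e from fallback_e" records the fallback failure as the
                # explicit cause, so neither traceback is lost.
                raise e from fallback_e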
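End to end, the wrapper returned by load_custom_model plugs directly into the harness. A hedged usage sketch, assuming the standard lm-eval entry point and a hypothetical task selection:

    import lm_eval

    results = lm_eval.simple_evaluate(
        model=lm,             # the HFLM instance returned by load_custom_model
        tasks=["hellaswag"],  # hypothetical task selection
    )
    print(results["results"])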