|
|
|
|
|
""" |
|
|
eval_with_expert_tracking.py - Evaluation script for OLMoE models with expert usage tracking |
|
|
|
|
|
This script extends the standard evaluation to track: |
|
|
1. Which experts are being used |
|
|
2. Frequency of expert usage |
|
|
3. Distribution across experts |
|
|
4. Small vs regular expert usage |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
import logging |
|
|
from typing import Dict, List, Optional, Any, Tuple |
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
from lm_eval import evaluator |
|
|
from lm_eval.models.huggingface import HFLM |
|
|
|
|
|
|
|
|
# Root logging starts at DEBUG with a timestamped format; main() later
# narrows the level according to --verbosity.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ExpertTrackingHFLM(HFLM):
    """Wrapper around HFLM that tracks expert usage statistics.

    A forward hook is attached to the ``mlp`` module of every decoder layer
    that exposes an ``experts`` attribute. When such a module returns a
    2-tuple, the second element is used to derive the top-k experts per token,
    which are accumulated into ``self.expert_stats``.

    Expert ids ``[0, num_regular)`` are counted as regular experts and ids
    ``[num_regular, num_regular + num_small)`` as small experts, where
    ``num_small`` is read from ``config.small_expert_count`` (0 if absent).
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying HFLM, then install the tracking hooks."""
        super().__init__(*args, **kwargs)
        self.expert_stats = {
            'total_tokens': 0,
            'regular_expert_usage': {},
            'small_expert_usage': {},
            'layer_stats': {}
        }
        self._register_hooks()

    def _routing_top_k(self) -> int:
        """Return the number of experts routed per token.

        ``num_experts_per_tok`` is preferred because ``top_k`` on a Hugging
        Face config is normally the *sampling* parameter, not the routing
        fan-out. ``top_k`` is kept as a fallback for configs that only define
        that name; 1 is returned when neither is a positive int.
        """
        config = getattr(self.model, 'config', None)
        for attr in ('num_experts_per_tok', 'top_k'):
            value = getattr(config, attr, None)
            if isinstance(value, int) and value > 0:
                return value
        return 1

    def _register_hooks(self):
        """Register forward hooks to track expert usage."""
        if not hasattr(self.model, 'model') or not hasattr(self.model.model, 'layers'):
            logger.warning("Model doesn't have expected layer structure - expert tracking disabled")
            return

        for layer_idx, layer in enumerate(self.model.model.layers):
            if hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts'):
                # Keep the handle on the module so the hook can be removed later.
                layer.mlp._expert_hook_handle = layer.mlp.register_forward_hook(
                    self._make_expert_hook(layer_idx)
                )

    def _make_expert_hook(self, layer_idx):
        """Create the forward hook that records routing for ``layer_idx``.

        The returned hook never returns a value, so the module output is left
        untouched.
        """
        def hook(module, inputs, output):
            # Only MoE blocks that expose routing information are tracked.
            if not (isinstance(output, tuple) and len(output) == 2):
                return
            router_out = output[1]
            if not torch.is_tensor(router_out) or router_out.dim() == 0 or router_out.numel() == 0:
                return

            with torch.no_grad():
                # NOTE(review): assumes output[1] holds router logits with the
                # expert axis last (as in HF's OlmoeSparseMoeBlock) - confirm
                # for custom models that may return probabilities instead.
                scores = torch.softmax(router_out.detach().float(), dim=-1)
                total_experts = scores.size(-1)
                k = min(self._routing_top_k(), total_experts)
                topk_probs, topk_experts = torch.topk(scores, k, dim=-1)

            config = getattr(self.model, 'config', None)
            num_small_experts = getattr(config, 'small_expert_count', 0) or 0
            num_small_experts = max(0, min(int(num_small_experts), total_experts))
            # Small experts are assumed to occupy the highest ids, matching
            # the indexing used by update_expert_stats.
            num_regular_experts = total_experts - num_small_experts

            hidden = inputs[0] if inputs else None
            if torch.is_tensor(hidden) and hidden.dim() == 3:
                batch_size, seq_len = hidden.shape[0], hidden.shape[1]
            else:
                batch_size, seq_len = topk_experts.reshape(-1, k).size(0), 1

            self.update_expert_stats(
                layer_idx, topk_experts, topk_probs,
                num_regular_experts, num_small_experts, batch_size, seq_len
            )

        return hook

    def update_expert_stats(self, layer_idx: int, topk_experts: torch.Tensor,
                            topk_probs: torch.Tensor, num_regular_experts: int,
                            num_small_experts: int, batch_size: int, seq_len: int):
        """Update expert usage statistics with debug logging.

        Args:
            layer_idx: Index of the layer the routing decision belongs to.
            topk_experts: Selected expert ids; the last axis is the top-k axis.
            topk_probs: Routing weights matching ``topk_experts`` element-wise.
            num_regular_experts: Number of regular experts (ids from 0).
            num_small_experts: Number of small experts (ids following the
                regular ones).
            batch_size: Unused; kept for interface compatibility.
            seq_len: Unused; kept for interface compatibility.
        """
        debug = logger.isEnabledFor(logging.DEBUG)
        if debug:
            logger.debug(f"\n{'='*40}")
            logger.debug(f"Updating stats for layer {layer_idx}")
            logger.debug(f"Input shapes - experts: {topk_experts.shape}, probs: {topk_probs.shape}")
            logger.debug(f"Num experts - regular: {num_regular_experts}, small: {num_small_experts}")

        # Collapse every leading axis so each row is one token.
        topk_experts_flat = topk_experts.reshape(-1, topk_experts.size(-1))
        topk_probs_flat = topk_probs.reshape(-1, topk_probs.size(-1))

        if layer_idx not in self.expert_stats['layer_stats']:
            if debug:
                logger.debug(f"Initializing new layer stats with {num_regular_experts} regular and {num_small_experts} small experts")
            self.expert_stats['layer_stats'][layer_idx] = {
                'total_tokens': 0,
                'regular_expert_counts': [0] * num_regular_experts,
                'small_expert_counts': [0] * num_small_experts if num_small_experts > 0 else None,
                'regular_expert_load': [0.0] * num_regular_experts,
                'small_expert_load': [0.0] * num_small_experts if num_small_experts > 0 else None
            }

        layer_stats = self.expert_stats['layer_stats'][layer_idx]
        num_tokens = topk_experts_flat.size(0)

        if debug:
            logger.debug(f"Current layer stats structure: {layer_stats.keys()}")
            if layer_stats['small_expert_counts'] is None:
                logger.debug("Small expert counts is None - no small experts initialized")
            else:
                logger.debug(f"Small expert counts length: {len(layer_stats['small_expert_counts'])}")

        # One histogram pass per batch instead of a device sync per expert.
        total_experts = num_regular_experts + num_small_experts
        ids = topk_experts_flat.reshape(-1).long()
        weights = topk_probs_flat.reshape(-1).float()
        in_range = (ids >= 0) & (ids < total_experts)  # ignore unknown ids
        ids, weights = ids[in_range], weights[in_range]
        counts = torch.bincount(ids, minlength=total_experts).tolist()
        loads = torch.bincount(ids, weights=weights, minlength=total_experts).tolist()

        regular_expert_used = False
        regular_usage = self.expert_stats['regular_expert_usage']
        for expert_idx in range(num_regular_experts):
            count = counts[expert_idx]
            if count > 0:
                regular_expert_used = True
                layer_stats['regular_expert_counts'][expert_idx] += count
                layer_stats['regular_expert_load'][expert_idx] += loads[expert_idx]
                regular_usage[expert_idx] = regular_usage.get(expert_idx, 0) + count

        if debug:
            logger.debug(f"Regular experts used this batch: {regular_expert_used}")

        if num_small_experts > 0:
            small_expert_used = False
            small_usage = self.expert_stats['small_expert_usage']
            for expert_idx in range(num_small_experts):
                # Small experts are numbered after the regular ones.
                small_expert_num = expert_idx + num_regular_experts
                count = counts[small_expert_num]
                if count > 0:
                    small_expert_used = True
                    layer_stats['small_expert_counts'][expert_idx] += count
                    layer_stats['small_expert_load'][expert_idx] += loads[small_expert_num]
                    small_usage[expert_idx] = small_usage.get(expert_idx, 0) + count

            if debug:
                logger.debug(f"Small experts used this batch: {small_expert_used}")
                if not small_expert_used:
                    logger.debug(f"Top-k experts sample: {topk_experts_flat[:5].tolist()}")
                    logger.debug(f"Num regular experts: {num_regular_experts}, looking for experts >= this number")
        elif debug:
            logger.debug("No small experts configured for this layer")

        # The global counter grows once per tracked layer, so it equals
        # tokens x tracked layers; the global percentages rely on that.
        self.expert_stats['total_tokens'] += num_tokens
        layer_stats['total_tokens'] += num_tokens
        if debug:
            logger.debug(f"Updated token counts - layer: {layer_stats['total_tokens']}, total: {self.expert_stats['total_tokens']}")

    def get_expert_stats(self) -> Dict[str, Any]:
        """Return expert usage statistics in a serializable format.

        Returns:
            A dict with ``total_tokens``, ``regular_expert_usage`` and
            ``small_expert_usage`` (expert id -> ``count`` / ``percentage``),
            and ``layer_stats`` (layer id -> per-expert counts and loads).
        """
        def convert(obj):
            """Recursively convert objects to JSON-serializable formats."""
            if isinstance(obj, (np.integer, np.floating)):
                return int(obj) if isinstance(obj, np.integer) else float(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, torch.Tensor):
                return obj.cpu().numpy().tolist()
            elif isinstance(obj, torch.dtype):
                return str(obj)
            elif isinstance(obj, dict):
                return {k: convert(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [convert(v) for v in obj]
            else:
                return obj

        total_activations = self.expert_stats['total_tokens'] * self._routing_top_k()

        def usage_entry(count):
            """Build the count / share-of-all-activations record."""
            share = count / total_activations * 100 if total_activations else 0.0
            return {'count': convert(count), 'percentage': convert(share)}

        stats = {
            'total_tokens': convert(self.expert_stats['total_tokens']),
            'regular_expert_usage': {
                expert_idx: usage_entry(count)
                for expert_idx, count in self.expert_stats['regular_expert_usage'].items()
            },
            'small_expert_usage': {
                expert_idx: usage_entry(count)
                for expert_idx, count in self.expert_stats['small_expert_usage'].items()
            },
            'layer_stats': {}
        }

        for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
            stats['layer_stats'][layer_idx] = {
                'total_tokens': convert(layer_stat['total_tokens']),
                'regular_expert_counts': convert(layer_stat['regular_expert_counts']),
                'regular_expert_load': convert(layer_stat['regular_expert_load']),
                'small_expert_counts': convert(layer_stat['small_expert_counts']),
                'small_expert_load': convert(layer_stat['small_expert_load'])
            }

        return stats

    def print_expert_stats(self) -> None:
        """Print expert usage statistics in a human-readable format."""
        if not self.expert_stats['total_tokens']:
            print("No expert usage statistics collected.")
            return

        total_tokens = self.expert_stats['total_tokens']
        top_k = self._routing_top_k()
        total_expert_activations = total_tokens * top_k

        def pct(part, whole):
            """Percentage of ``whole``; 0.0 when ``whole`` is zero."""
            return part / whole * 100 if whole else 0.0

        print("\n" + "="*80)
        print("EXPERT USAGE STATISTICS")
        print("="*80)
        print(f"Total tokens processed: {total_tokens:,}")
        print(f"Total expert activations (top-{top_k}): {total_expert_activations:,}")
        print("\nOverall Expert Usage:")

        if self.expert_stats['regular_expert_usage']:
            print("\nRegular Experts:")
            for expert_idx, count in sorted(self.expert_stats['regular_expert_usage'].items()):
                percentage = pct(count, total_expert_activations)
                print(f" Expert {expert_idx}: {count:,} ({percentage:.2f}%)")

        if self.expert_stats['small_expert_usage']:
            print("\nSmall Experts:")
            for expert_idx, count in sorted(self.expert_stats['small_expert_usage'].items()):
                percentage = pct(count, total_expert_activations)
                print(f" Small Expert {expert_idx}: {count:,} ({percentage:.2f}%)")

        print("\nLayer-wise Statistics:")
        for layer_idx, layer_stat in self.expert_stats['layer_stats'].items():
            layer_tokens = layer_stat['total_tokens']
            print(f"\nLayer {layer_idx}:")
            print(f" Tokens processed: {layer_tokens:,}")

            print(" Regular Experts:")
            for expert_idx, (count, load) in enumerate(zip(
                layer_stat['regular_expert_counts'],
                layer_stat['regular_expert_load']
            )):
                count_pct = pct(count, layer_tokens * top_k)
                load_pct = pct(load, layer_tokens)
                print(f" Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")

            if layer_stat['small_expert_counts'] is not None:
                print(" Small Experts:")
                for expert_idx, (count, load) in enumerate(zip(
                    layer_stat['small_expert_counts'],
                    layer_stat['small_expert_load']
                )):
                    count_pct = pct(count, layer_tokens * top_k)
                    load_pct = pct(load, layer_tokens)
                    print(f" Small Expert {expert_idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")

        print("="*80 + "\n")
|
|
|
|
|
def parse_args():
    """Build the command-line interface and return the parsed namespace.

    Options cover model selection, evaluation settings, output location and
    miscellaneous flags; values are read from ``sys.argv``.
    """
    usage_examples = """
Examples:
# Standard evaluation with expert tracking
python eval_with_expert_tracking.py --model_type transformers --tasks mmlu arc_easy

# Custom model evaluation with expert tracking
python eval_with_expert_tracking.py --model_type custom --tasks mmlu hellaswag
"""
    parser = argparse.ArgumentParser(
        description="Evaluate OLMoE models with expert usage tracking",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    add = parser.add_argument

    # --- model selection ---
    add("--model_path", type=str, default="allenai/OLMoE-1B-7B-0924",
        help="Path or name of the pretrained model")
    add("--model_type", type=str, default="transformers",
        choices=["transformers", "custom"],
        help="Model type: 'transformers' for standard OLMoE, 'custom' for MyOLMoE")
    add("--custom_model_path", type=str, default="./myolmoe_model",
        help="Path to custom MyOLMoE model code (when using --model_type custom)")

    # --- evaluation settings ---
    add("--tasks", type=str, nargs="+", default=["mmlu"],
        help="Tasks to evaluate on (e.g., mmlu, hellaswag, arc_easy, gsm8k)")
    add("--num_fewshot", type=int, default=0,
        help="Number of few-shot examples")
    add("--batch_size", type=int, default=8,
        help="Batch size for evaluation")
    add("--max_batch_size", type=int, default=None,
        help="Maximum batch size (auto if None)")
    add("--device", type=str, default="auto",
        help="Device to use ('auto', 'cuda', 'cpu')")
    add("--dtype", type=str, default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Data type for model weights")

    # --- output location ---
    add("--output_dir", type=str, default="./eval_results",
        help="Directory to save evaluation results")
    add("--output_filename", type=str, default=None,
        help="Custom filename for results (auto-generated if not provided)")

    # --- miscellaneous ---
    add("--limit", type=int, default=None,
        help="Limit number of examples per task (for testing)")
    add("--write_out", action="store_true",
        help="Write out individual predictions to files")
    add("--trust_remote_code", action="store_true",
        help="Trust remote code when loading model")
    add("--verbosity", type=str, default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging verbosity level")

    parsed = parser.parse_args()
    return parsed
|
|
|
|
|
|
|
|
def load_transformers_model(args) -> ExpertTrackingHFLM:
    """
    Build the expert-tracking wrapper around a standard Transformers OLMoE model.

    Args:
        args: Parsed command line arguments

    Returns:
        ExpertTrackingHFLM: Wrapped model ready for evaluation with expert tracking
    """
    logger.info(f"Loading Transformers OLMoE model with expert tracking: {args.model_path}")

    # Everything is forwarded to the wrapper, which loads the checkpoint itself.
    wrapper_kwargs = dict(
        pretrained=args.model_path,
        device=args.device,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
    )
    tracked_model = ExpertTrackingHFLM(**wrapper_kwargs)

    logger.info("Transformers model with expert tracking loaded successfully")
    return tracked_model
|
|
|
|
|
|
|
|
def load_custom_model(args) -> ExpertTrackingHFLM:
    """
    Load custom MyOLMoE model with expert tracking.

    The custom model code is imported from ``args.custom_model_path``, the
    checkpoint is instantiated with ``MyOlmoeForCausalLM`` and that instance
    is handed to the tracking wrapper, so the evaluated model is the custom
    one rather than a fresh stock load of ``args.model_path``.

    Args:
        args: Parsed command line arguments

    Returns:
        ExpertTrackingHFLM: Wrapped model ready for evaluation with expert tracking

    Raises:
        ImportError: If ``modeling_myolmoe`` cannot be imported.
    """
    logger.info(f"Loading custom MyOLMoE model with expert tracking: {args.model_path}")

    # Make the custom modeling code importable.
    if os.path.exists(args.custom_model_path):
        sys.path.insert(0, args.custom_model_path)
        logger.info(f"Added {args.custom_model_path} to Python path")
    else:
        logger.warning(f"Custom model path not found: {args.custom_model_path}")

    try:
        from modeling_myolmoe import MyOlmoeForCausalLM
        logger.info("Successfully imported MyOlmoeForCausalLM")
    except ImportError as e:
        logger.error(f"Failed to import custom model: {e}")
        logger.error("Make sure the custom model code is available in the specified path")
        raise

    config = AutoConfig.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    logger.info("Model will use default top-k routing configuration")

    if args.dtype == "auto":
        torch_dtype = "auto"
    else:
        torch_dtype = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32
        }[args.dtype]

    hf_model = MyOlmoeForCausalLM.from_pretrained(
        args.model_path,
        config=config,
        torch_dtype=torch_dtype,
        device_map="auto" if args.device == "auto" else None,
        trust_remote_code=args.trust_remote_code
    ).eval()

    # Without a device_map the weights stay where from_pretrained put them,
    # so honour an explicit --device here.
    if args.device != "auto":
        hf_model = hf_model.to(args.device)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    # Pass the instantiated custom model (not the path) so the wrapper
    # evaluates MyOlmoeForCausalLM instead of reloading the stock class.
    model = ExpertTrackingHFLM(
        pretrained=hf_model,
        tokenizer=tokenizer,
        device=args.device,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        dtype=args.dtype
    )

    logger.info("Custom model with expert tracking loaded successfully")
    return model
|
|
|
|
|
|
|
|
def run_evaluation(args) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Load the requested model, evaluate it, and collect its expert statistics.

    Args:
        args: Parsed command line arguments

    Returns:
        Tuple of (evaluation_results, expert_stats)

    Raises:
        ValueError: If ``args.model_type`` is neither "transformers" nor "custom".
    """
    logger.info("Starting evaluation with expert tracking...")

    # Guard clause first, then dispatch to the matching loader.
    if args.model_type not in ("transformers", "custom"):
        raise ValueError(f"Unknown model type: {args.model_type}")
    loader = load_transformers_model if args.model_type == "transformers" else load_custom_model
    model = loader(args)

    logger.info(f"Running evaluation on tasks: {args.tasks}")
    logger.info(f"Few-shot examples: {args.num_fewshot}")
    logger.info(f"Batch size: {args.batch_size}")

    eval_kwargs = {
        "model": model,
        "tasks": args.tasks,
        "num_fewshot": args.num_fewshot,
        "limit": args.limit,
        "write_out": args.write_out,
    }
    results = evaluator.simple_evaluate(**eval_kwargs)

    # Statistics accumulated by the wrapper while the tasks were running.
    expert_stats = model.get_expert_stats()

    logger.info("Evaluation completed successfully")
    return results, expert_stats
|
|
|
|
|
|
|
|
def save_results(results: Dict[str, Any], expert_stats: Dict[str, Any], args) -> str:
    """
    Write task results, expert statistics and run metadata to a JSON file.

    The file is placed in ``args.output_dir`` (created if needed). When no
    filename is given one is derived from the model name, model type and up to
    three task names. Returns the path that was written.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    filename = args.output_filename
    if filename is None:
        model_name = os.path.basename(args.model_path.rstrip('/'))
        extra_tasks = len(args.tasks) - 3
        tasks_str = "_".join(args.tasks[:3])
        if extra_tasks > 0:
            tasks_str += f"_and_{extra_tasks}_more"
        filename = f"{model_name}_{args.model_type}_{tasks_str}_results.json"

    if not filename.endswith('.json'):
        filename = filename + '.json'

    output_path = os.path.join(args.output_dir, filename)

    metadata = {
        "model_path": args.model_path,
        "model_type": args.model_type,
        "tasks": args.tasks,
        "num_fewshot": args.num_fewshot,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": str(args.dtype),
        "limit": args.limit,
    }
    if args.model_type == "custom":
        metadata["routing_type"] = "top-k (default)"

    def to_jsonable(value):
        """Recursively map numpy/torch values and containers to JSON types."""
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, torch.Tensor):
            return value.cpu().tolist()
        if isinstance(value, torch.dtype):
            return str(value)
        if isinstance(value, dict):
            return {key: to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [to_jsonable(item) for item in value]
        if value is None or isinstance(value, (int, float, str, bool)):
            return value
        # Anything else is stored as its string representation.
        return str(value)

    payload = to_jsonable({
        "metadata": metadata,
        "task_results": results,
        "expert_statistics": expert_stats
    })

    with open(output_path, 'w') as f:
        json.dump(payload, f, indent=2)

    logger.info(f"Results saved to {output_path}")
    return output_path
|
|
|
|
|
|
|
|
def print_summary(results: Dict[str, Any], expert_stats: Dict[str, Any], args) -> None:
    """
    Print a formatted summary of evaluation results and expert statistics.

    Metric keys may be plain (``acc`` / ``acc_stderr``) or carry a filter
    suffix (``acc,none`` / ``acc_stderr,none``); in both cases the stderr
    entry is shown next to its metric instead of as a metric of its own.

    The routing top-k is taken from ``args.top_k`` when present; otherwise it
    is inferred from the statistics as activations per token (minimum 1).

    Args:
        results: Evaluation results
        expert_stats: Expert usage statistics
        args: Parsed command line arguments
    """
    def pct(part, whole):
        """Percentage of ``whole``; 0.0 when ``whole`` is zero."""
        return part / whole * 100 if whole else 0.0

    print(f"\n{'='*80}")
    print(f"EVALUATION SUMMARY")
    print(f"Model: {args.model_path}")
    print(f"Type: {args.model_type.upper()}")
    if args.model_type == "custom":
        print(f"Routing: TOP-K (default)")
    print(f"Tasks: {', '.join(args.tasks)}")
    print(f"{'='*80}")

    if "results" in results:
        for task, metrics in results["results"].items():
            if isinstance(metrics, dict):
                print(f"\n📊 {task.upper()}:")
                for metric, value in metrics.items():
                    if not isinstance(value, (int, float)):
                        continue
                    # Split "name,filter" so the suffix check and the stderr
                    # lookup act on the metric name rather than the filter.
                    base, sep, filter_name = metric.partition(',')
                    if base.endswith('_stderr'):
                        continue
                    stderr = metrics.get(f"{base}_stderr{sep}{filter_name}", 0)
                    if isinstance(stderr, (int, float)):
                        stderr_text = f"{stderr:.4f}"
                    else:
                        # e.g. a non-numeric placeholder such as "N/A"
                        stderr_text = str(stderr)
                    print(f" {metric:.<20} {value:.4f} (±{stderr_text})")
    else:
        print("\n⚠️ No results found in evaluation output")

    if expert_stats:
        total_tokens = expert_stats.get('total_tokens', 0)
        if total_tokens > 0:
            regular_usage = expert_stats.get('regular_expert_usage') or {}
            small_usage = expert_stats.get('small_expert_usage') or {}

            top_k = getattr(args, 'top_k', None)
            if top_k is None:
                # No CLI option defines top_k, so derive it from the data:
                # total recorded activations divided by tokens processed.
                activations = sum(
                    entry['count']
                    for usage in (regular_usage, small_usage)
                    for entry in usage.values()
                )
                top_k = max(1, round(activations / total_tokens))
            total_expert_activations = total_tokens * top_k

            print(f"\n🔍 EXPERT USAGE SUMMARY (Top-{top_k})")
            print(f"Total tokens processed: {total_tokens:,}")
            print(f"Total expert activations: {total_expert_activations:,}")

            if regular_usage:
                print("\nRegular Experts:")
                for expert_idx, stats in sorted(regular_usage.items()):
                    print(f" Expert {expert_idx}: {stats['count']:,} ({stats['percentage']:.2f}%)")

            if small_usage:
                print("\nSmall Experts:")
                for expert_idx, stats in sorted(small_usage.items()):
                    print(f" Small Expert {expert_idx}: {stats['count']:,} ({stats['percentage']:.2f}%)")

            if expert_stats.get('layer_stats'):
                print("\nLayer-wise Statistics (Top 3 most used experts per layer):")
                for layer_idx, layer_stat in expert_stats['layer_stats'].items():
                    layer_tokens = layer_stat['total_tokens']
                    print(f"\nLayer {layer_idx}:")
                    print(f" Tokens processed: {layer_tokens:,}")

                    if layer_stat.get('regular_expert_counts'):
                        counts = layer_stat['regular_expert_counts']
                        # Indices of the three largest counts, biggest first.
                        top_indices = np.argsort(counts)[-3:][::-1]
                        print(" Top Regular Experts:")
                        for idx in top_indices:
                            count = counts[idx]
                            load = layer_stat['regular_expert_load'][idx]
                            count_pct = pct(count, layer_tokens * top_k)
                            load_pct = pct(load, layer_tokens)
                            print(f" Expert {idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")

                    if layer_stat.get('small_expert_counts'):
                        counts = layer_stat['small_expert_counts']
                        top_indices = np.argsort(counts)[-3:][::-1]
                        print(" Top Small Experts:")
                        for idx in top_indices:
                            count = counts[idx]
                            load = layer_stat['small_expert_load'][idx]
                            count_pct = pct(count, layer_tokens * top_k)
                            load_pct = pct(load, layer_tokens)
                            print(f" Small Expert {idx}: Count={count:,} ({count_pct:.2f}%), Load={load:.2f} ({load_pct:.2f}%)")

    print(f"\n{'='*80}")
|
|
|
|
|
|
|
|
def main():
    """Script entry point: parse arguments, evaluate, save and summarise.

    Exits with status 1 when the run is interrupted or any step fails.
    """
    args = parse_args()

    # Apply --verbosity to both the root logger and this module's logger.
    level = getattr(logging, args.verbosity.upper(), None)
    if isinstance(level, int):
        for target in (logging.getLogger(), logger):
            target.setLevel(level)

    try:
        rule = "=" * 80
        logger.info(rule)
        logger.info("Starting OLMoE Model Evaluation with Expert Tracking")
        logger.info(rule)

        results, expert_stats = run_evaluation(args)
        output_path = save_results(results, expert_stats, args)
        print_summary(results, expert_stats, args)

        logger.info(f"✅ Evaluation completed successfully!")
        logger.info(f"📁 Results saved to: {output_path}")
    except KeyboardInterrupt:
        logger.info("Evaluation interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"❌ Evaluation failed: {e}")
        # The traceback is only emitted when DEBUG logging is enabled.
        logger.debug("Full traceback:", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()