# Transformer-340M-0428 / evaluate.py
# Uploaded by windsornguyen (commit dc91c82, "add: eval script")
import os
from dataclasses import dataclass
from typing import Optional
from huggingface_hub import hf_hub_download
import lm_eval as evaluator
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from torchtune.modules import RotaryPositionalEmbeddings
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
PreTrainedModel,
PretrainedConfig,
)
from transformers.modeling_outputs import CausalLMOutput
from flash_attn import flash_attn_func
# Opt in to code-evaluation tasks (e.g. HumanEval) in lm-eval-harness.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
# Silence tokenizers fork/parallelism warnings when the harness spawns workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Module-level CE loss used by Transformer.forward (default mean reduction).
loss_fn = nn.CrossEntropyLoss()
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and FlashAttention."""

    def __init__(self, config):
        super().__init__()
        # Separate Q/K/V projections plus the output projection.
        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)
        self.wo = nn.Linear(config.dim, config.dim)
        # Flags the output projection for depth-scaled init (see Transformer.init_weights).
        self.wo.SCALE_INIT = 1
        self.dim = config.dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads
        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x):
        batch, length, _ = x.shape
        # Project and split into heads: (batch, length, heads, head_dim).
        # NOTE(review): wk/wv project to the full model dim even though k/v are
        # viewed with num_local_heads; this only works when
        # num_local_heads == num_heads (the default) — confirm if GQA is intended.
        queries = self.wq(x).view(batch, length, self.num_heads, self.head_dim)
        keys = self.wk(x).view(batch, length, self.num_local_heads, self.head_dim)
        values = self.wv(x).view(batch, length, self.num_local_heads, self.head_dim)
        # Rotary encoding is applied to queries and keys only.
        queries = self.rotary_emb(queries)
        keys = self.rotary_emb(keys)
        attn_out = flash_attn_func(
            q=queries,
            k=keys,
            v=values,
            causal=True,
        )
        # Merge heads back into the model dimension and project out.
        merged = attn_out.reshape(batch, length, -1)
        return self.wo(merged)
def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of *k* that is greater than or equal to *n*."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
class BaseConfigForCausalLM(PretrainedConfig):
    """Base PretrainedConfig class to be decorated with dataclass"""
    # HF config identifier; concrete subclasses override this.
    model_type = "base_model"
@dataclass
class TransformerConfig(BaseConfigForCausalLM):
    """Configuration for the Transformer LM.

    Declared as a dataclass for the field listing, but keeps an explicit
    __init__ (dataclass does not regenerate one when it is user-defined) so
    extra kwargs can be forwarded to PretrainedConfig.
    """
    model_type = "Transformer"
    # Define fields with defaults (as before)
    bsz: int = 1
    dim: int = 768
    num_heads: int = 12
    num_local_heads: int = -1  # -1 sentinel: resolved to num_heads in _post_init_logic
    num_layers: int = 12
    seq_len: int = 4096
    vocab_size: int = 200064
    inter_dim: Optional[int] = None  # derived from mlp_scale * dim when None
    mlp_scale: float = 12.0
    weight_tying: bool = True
    bias: bool = False
    rope_theta: float = 10000.0
    torch_dtype: str = "torch.bfloat16"  # normalized to a torch.dtype in _post_init_logic
    device: Optional[str] = None
    head_dim: Optional[int] = None  # always recomputed as dim // num_heads

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 768,
        num_heads: int = 12,
        num_local_heads: int = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = None,
        mlp_scale: float = 12.0,
        weight_tying: bool = True,
        bias: bool = False,
        rope_theta: float = 10000.0,
        torch_dtype: str = "torch.bfloat16",
        device: Optional[str] = None,
        head_dim: Optional[int] = None,
        **kwargs,
    ):
        """Store raw fields, forward unknown kwargs to PretrainedConfig, then
        derive dependent fields via _post_init_logic()."""
        super().__init__(**kwargs)
        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.torch_dtype = torch_dtype
        self.device = device
        self.head_dim = head_dim
        self._post_init_logic()

    def _post_init_logic(self):
        """Resolve derived and normalized fields after raw assignment."""
        # -1 sentinel means "same number of KV heads as query heads" (plain MHA).
        if self.num_local_heads == -1:
            self.num_local_heads = self.num_heads
        # Default inter_dim: ~(2/3) * mlp_scale * dim, rounded up to a multiple of 256.
        if self.inter_dim is None:
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            multiple = 256
            self.inter_dim = find_multiple(num_hidden, multiple) if num_hidden > 0 else multiple
        # head_dim is always recomputed; any value passed in is overwritten.
        if self.num_heads > 0:
            self.head_dim = self.dim // self.num_heads
        else:
            raise ValueError("num_heads must be positive")
        # Accept "bfloat16" / "torch.bfloat16" strings or an actual torch.dtype.
        if isinstance(self.torch_dtype, str):
            dtype_str = self.torch_dtype.replace("torch.", "")
            try:
                self.torch_dtype = getattr(torch, dtype_str)
            except AttributeError as err:
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}") from err
        elif not isinstance(self.torch_dtype, torch.dtype):
            raise ValueError(f"torch_dtype must be a string or torch.dtype, got {type(self.torch_dtype)}")
        if isinstance(self.device, str):
            self.device = torch.device(self.device)

    @classmethod
    def from_name(cls, name: str):
        # Stub: named presets are not implemented; prints and returns None.
        print("Not yet implemented")
        pass
class MLP(nn.Module):
    """Two-layer feed-forward block using tanh-approximated GELU."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.inter_dim)
        self.w2 = nn.Linear(config.inter_dim, config.dim)
        # Flags the down-projection for depth-scaled init (see Transformer.init_weights).
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.w1(x)
        activated = F.gelu(hidden, approximate="tanh")
        return self.w2(activated)
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, config):
        super().__init__()
        self.attn_norm = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.attn = Attention(config)
        self.mlp_norm = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.mlp = MLP(config)

    def forward(self, x):
        # Normalize before each sub-block; residual around each.
        x = x + self.attn(self.attn_norm(x))
        return x + self.mlp(self.mlp_norm(x))
class Transformer(nn.Module):
    """Decoder-only transformer LM: token embedding -> num_layers pre-norm
    blocks -> final LayerNorm -> linear LM head (optionally weight-tied)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList()
        for _ in range(config.num_layers):
            self.layers.append(TransformerLayer(config))
        self.norm_f = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Weight tying: embedding and LM head share a single tensor.
        if self.config.weight_tying:
            self.tok_emb.weight = self.lm_head.weight
        # Base init std ~ 1/sqrt(dim), consumed by init_weights below.
        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        # Intended for nn.Module.apply(): normal init, with std further scaled
        # by 1/sqrt(2 * num_layers) for modules flagged with SCALE_INIT
        # (the residual-branch output projections wo / w2).
        std = self.std
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs) -> CausalLMOutput:
        """Run the LM; when labels are given, also return the cross-entropy loss.

        NOTE(review): loss compares logits[t] with labels[t] with no shift —
        callers wanting next-token loss must shift themselves (as
        run_model_diagnostics does). Confirm this matches the training setup.
        """
        x = self.tok_emb(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Uses the module-level loss_fn (CrossEntropyLoss, mean reduction).
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))
        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def get_num_params(self):
        """
        Return the total number of parameters in the model.
        Tied tensors are counted once (nn.Module.parameters() deduplicates
        shared parameters); nothing else is subtracted — there are no
        positional-embedding parameters in this model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
def create_base_model_components(model_name_or_path=None, **kwargs):
    """Build a TransformerConfig: from a pretrained repo when a path/id is
    given, otherwise directly from the supplied kwargs."""
    if model_name_or_path is None:
        return TransformerConfig(**kwargs)
    return TransformerConfig.from_pretrained(model_name_or_path, **kwargs)
class TransformerForCausalLM(PreTrainedModel):
    """Thin wrapper to comply with HuggingFace's expected interface"""

    config_class = TransformerConfig
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = Transformer(config)
        self.transformer.apply(self.transformer.init_weights)

    def forward(
        self, input_ids: torch.Tensor, labels: torch.Tensor = None, attention_mask: torch.Tensor = None, **kwargs
    ) -> CausalLMOutput:
        """Delegate to the inner Transformer.

        attention_mask is accepted for HF-interface compatibility but ignored
        (the inner model is always causal over the full sequence).
        """
        outputs = self.transformer(input_ids, labels=labels, **kwargs)
        return outputs

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 32,
        num_return_sequences: int = 4,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        seed: int = 42,
    ) -> torch.Tensor:
        """Generate text using top-k and nucleus sampling with temperature and repetition penalty.

        Args:
            input_ids: Input token ids of shape (batch_size, seq_len)
            max_length: Maximum length of generated sequence
            num_return_sequences: Number of sequences to generate per input
            temperature: Sampling temperature. Higher = more random, lower = more focused
            top_k: Number of highest probability tokens to keep for top-k sampling
            top_p: Cumulative probability cutoff for nucleus sampling
            repetition_penalty: Penalty factor for repeating tokens. 1.0 = no penalty
            seed: Random seed for reproducibility

        Returns:
            Generated token ids of shape (num_return_sequences, max_length)
        """
        self.eval()  # Set to eval mode
        device = input_ids.device
        # Expand the prompt so each return sequence is sampled independently.
        generated = input_ids.repeat(num_return_sequences, 1)
        # Dedicated RNG so sampling is reproducible regardless of global state.
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)
        with torch.no_grad():
            while generated.size(1) < max_length:
                # Logits for the next position only.
                outputs = self.transformer(generated)
                next_token_logits = outputs.logits[:, -1, :]
                # Repetition penalty (CTRL-style): dampen every token id already
                # present in the sequence. Positive logits are divided and
                # negative logits multiplied so the penalty always lowers the
                # token's probability.
                # BUGFIX: the previous code used `if token in next_token_logits[i]`,
                # which tested whether the token *id* appeared among the logit
                # *values* — effectively a no-op that almost never penalized.
                if repetition_penalty != 1.0:
                    for i in range(generated.shape[0]):
                        seen = torch.unique(generated[i])
                        seen_logits = next_token_logits[i, seen]
                        next_token_logits[i, seen] = torch.where(
                            seen_logits > 0,
                            seen_logits / repetition_penalty,
                            seen_logits * repetition_penalty,
                        )
                # Temperature scaling before softmax.
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                # Top-k sampling: zero everything below the k-th largest probability.
                if top_k > 0:
                    indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None]
                    probs[indices_to_remove] = 0
                # Nucleus (top-p) sampling: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs[indices_to_remove] = 0
                # Renormalize (clamp guards against an all-zero row).
                probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)
                # Sample and append the next token.
                next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng)
                generated = torch.cat([generated, next_token], dim=1)
        return generated

    def get_num_params(self):
        """Total parameter count of the wrapped Transformer."""
        return self.transformer.get_num_params()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load config + safetensors weights from the hub, reconstructing weight tying.

        Bypasses the generic HF loader: downloads `model.safetensors` directly,
        re-links tok_emb/lm_head if only one of the tied tensors was serialized,
        prefixes keys with base_model_prefix to match the wrapper's state dict,
        and moves the model to GPU (if available) in bfloat16.
        """
        config = create_base_model_components(pretrained_model_name_or_path, **kwargs)
        model = cls(config)
        # Download safetensors file from hub
        weights_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="model.safetensors",
            cache_dir=kwargs.get("cache_dir"),
            force_download=kwargs.get("force_download", False),
            proxies=kwargs.get("proxies", None),
            local_files_only=kwargs.get("local_files_only", False),
            use_auth_token=kwargs.get("use_auth_token", None),
            revision=kwargs.get("revision", None),
            subfolder=kwargs.get("subfolder", ""),
        )
        state_dict = load_file(weights_path)
        # Only one of the tied pair may be present on disk; point the missing
        # key at the existing tensor so load_state_dict (strict) succeeds.
        tok_emb_key = "tok_emb.weight"
        lm_head_key = "lm_head.weight"
        tok_emb_present = tok_emb_key in state_dict
        lm_head_present = lm_head_key in state_dict
        if tok_emb_present and not lm_head_present:
            print(f"Reconstructing weight tying: Linking missing '{lm_head_key}' to existing '{tok_emb_key}'")
            state_dict[lm_head_key] = state_dict[tok_emb_key]
        elif lm_head_present and not tok_emb_present:
            print(f"Reconstructing weight tying: Linking missing '{tok_emb_key}' to existing '{lm_head_key}'")
            state_dict[tok_emb_key] = state_dict[lm_head_key]
        elif not tok_emb_present and not lm_head_present:
            # This case should ideally not happen if the file is valid
            print(
                f"Warning: Neither '{tok_emb_key}' nor '{lm_head_key}' found in state_dict. Weight tying cannot be reconstructed."
            )
        # If both are present, assume they are loaded correctly (or were never tied).
        # Keys on disk belong to the bare Transformer; the wrapper nests it
        # under `transformer.`, so prefix every key before loading.
        final_state_dict = {f"{cls.base_model_prefix}.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(final_state_dict)
        # Move to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()
        # Print parameter count as a sanity check
        num_params = model.get_num_params()
        print(f"\nModel loaded: {pretrained_model_name_or_path}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")
        return model
# Create initial config using the correct class
config = TransformerConfig()
# Register models with correct names so AutoConfig/AutoModel* can resolve
# model_type "Transformer" to these local classes.
# NOTE(review): AutoModel is registered with the bare nn.Module `Transformer`,
# not a PreTrainedModel subclass — confirm this is intended.
AutoConfig.register("Transformer", TransformerConfig)
AutoModel.register(TransformerConfig, Transformer)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)
print("Registered Transformer model and configuration.")
def run_model_diagnostics(model, tokenizer, device):
    """Run detailed diagnostics to analyze model behavior."""
    print("\nRunning model diagnostics...")
    # Prompts spanning arithmetic, factual recall, reasoning, pattern
    # completion, and longer context.
    test_cases = [
        "2 + 2 =",
        "The capital of France is Paris. The capital of Germany is",
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        "1, 2, 3, 4,",
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]
    per_token_loss = nn.CrossEntropyLoss(reduction="none")
    with torch.no_grad():
        for prompt in test_cases:
            print(f"\nAnalyzing prompt: {prompt}")
            encoded = tokenizer(prompt, return_tensors="pt")
            input_ids = encoded["input_ids"].to(device)
            outputs = model.transformer(input_ids, labels=input_ids)
            # Per-position next-token loss: logits at t predict token t+1.
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = input_ids.clone()[..., 1:].contiguous()
            token_losses = per_token_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            ).view(shift_labels.size())
            # Token-by-token report (first token has no loss, hence the [1:]).
            input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
            print("\nToken-by-token loss:")
            for token, loss in zip(input_tokens[1:], token_losses[0]):
                print(f"{token}: {loss.item():.3f}")
            print(f"Average loss: {token_losses.mean().item():.3f}")
            # Compare generations across sampling temperatures.
            print("\nGeneration temperature comparison:")
            for temp in [0.5, 0.7, 1.0]:
                gen_ids = model.generate(
                    input_ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")
def validate_model_generation():
    """Smoke-test: load the pretrained checkpoint and run the diagnostics suite."""
    print("\nRunning generation validation test...")
    try:
        from transformers import AutoTokenizer

        model_id = "Hazan-Lab/Transformer-340M-0428"
        model = TransformerForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        # Prefer GPU when present; run in bfloat16 either way.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()
        # Parameter count as a sanity check.
        num_params = model.get_num_params()
        print(f"\nModel loaded: {model_id}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")
        run_model_diagnostics(model, tokenizer, device)
    except Exception as e:
        print(f"\nError during validation: {str(e)}")
        raise
# Run evaluation tasks
# Only hellaswag is active; the commented entries document the candidate suite.
# NOTE(review): "gms8k" below is likely a typo for "gsm8k".
tasks = [
    "hellaswag",
    # "mmlu",
    # "piqa",
    # "siqa",
    # "boolq",
    # "winogrande",
    # "commonsense_qa",
    # "openbookqa",
    # "arc",
    # "arc_easy",
    # "arc_challenge",
    # "triviaqa",
    # "nq_open",
    # "humaneval",
    # "mbpp",
    # "gms8k",
    # "hendrycks_math",
    # "mathqa",
    # "minerva_math",
    # "score",
    # "asdiv",
    # "agieval",
    # "bigbench",
]
# Per-task few-shot counts; -1 means "omit num_fewshot and use the harness default"
# (see the loop below).
tasks_fewshot = {
    "hellaswag": 0,
    # "mmlu": 5,
    # "piqa": 0,
    # "siqa": 0,
    # "boolq": 0,
    # "winogrande": -1,
    # "commonsense_qa": 7,
    # "openbookqa": -1,
    # "arc": -1,
    # "arc_easy": -1,
    # "arc_challenge": -1,
    # "triviaqa": 5,
    # "nq_open": 5,
    # "humaneval": -1,
    # "mbpp": 3,
    # "gms8k": -1,
    # "hendrycks_math": 4,
    # "mathqa": -1,
    # "minerva_math": -1,
    # "score": -1,
    # "asdiv": -1,
    # "agieval": -1,
    # "bigbench": -1,
}
# Accumulates {task_name: results dict} for the summary printed at the end.
all_results = {}
# First validate generation works
validate_model_generation()
model_id = "Hazan-Lab/Transformer-340M-0428"
print("\nStarting evaluation tasks...")
for task in tasks:
    print(f"\nEvaluating task: {task}")
    # lm-eval loads the model itself from the hub (via the registered
    # AutoModelForCausalLM + trust_remote_code) rather than reusing the
    # instance validated above.
    eval_kwargs = dict(
        model="hf",
        model_args=(
            f"pretrained={model_id},"
            "trust_remote_code=True,"
            "dtype=bfloat16,"
            "cache_dir=/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache"
        ),
        tasks=[task],
        batch_size="auto",
        device="cuda:0",
    )
    # -1 sentinel: let the harness pick its default few-shot count.
    few_shot_value = tasks_fewshot.get(task, -1)
    if few_shot_value != -1:
        eval_kwargs["num_fewshot"] = few_shot_value
    results = evaluator.simple_evaluate(**eval_kwargs)
    task_result = results["results"].get(task, {})
    all_results[task] = task_result
    print(f"Results for {task}:")
    print(task_result)
    print("\n" + "=" * 50 + "\n")
print("All Evaluation Results:")
for task, result in all_results.items():
    print(f"{task}: {result}")