import os
from dataclasses import dataclass
from typing import Optional

from huggingface_hub import hf_hub_download
import lm_eval as evaluator
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from torchtune.modules import RotaryPositionalEmbeddings
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.modeling_outputs import CausalLMOutput
from flash_attn import flash_attn_func

os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Shared mean-reduced cross-entropy used by Transformer.forward.
loss_fn = nn.CrossEntropyLoss()


class Attention(nn.Module):
    """Causal self-attention with rotary position embeddings.

    Queries use ``config.num_heads`` heads; keys/values use
    ``config.num_local_heads`` heads (grouped-query attention when fewer).
    The attention itself is computed by ``flash_attn_func(causal=True)``.
    """

    def __init__(self, config):
        super().__init__()
        if config.num_local_heads <= 0 or config.num_heads % config.num_local_heads != 0:
            raise ValueError(
                f"num_heads ({config.num_heads}) must be a positive multiple of "
                f"num_local_heads ({config.num_local_heads})"
            )
        # K/V projections are sized for the number of KV heads. When
        # num_local_heads == num_heads this equals config.dim, so existing
        # checkpoints keep the same parameter shapes.
        kv_dim = config.num_local_heads * config.head_dim
        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, kv_dim)
        self.wv = nn.Linear(config.dim, kv_dim)
        self.wo = nn.Linear(config.dim, config.dim)
        # Marker read by Transformer.init_weights to down-scale residual projections.
        self.wo.SCALE_INIT = 1
        self.dim = config.dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads
        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention to ``x`` of shape (batch, seq_len, dim)."""
        bsz, seq_len, _ = x.shape
        # (batch, seq, heads, head_dim) is the layout expected by both
        # torchtune's RotaryPositionalEmbeddings and flash_attn_func.
        q = self.wq(x).view(bsz, seq_len, self.num_heads, self.head_dim)
        k = self.wk(x).view(bsz, seq_len, self.num_local_heads, self.head_dim)
        v = self.wv(x).view(bsz, seq_len, self.num_local_heads, self.head_dim)
        q, k = self.rotary_emb(q), self.rotary_emb(k)
        y = flash_attn_func(q=q, k=k, v=v, causal=True)
        return self.wo(y.reshape(bsz, seq_len, -1))


def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of ``k`` that is >= ``n``."""
    if n % k == 0:
        return n
    return n + k - (n % k)


class BaseConfigForCausalLM(PretrainedConfig):
    """Base PretrainedConfig class to be decorated with dataclass"""

    model_type = "base_model"


@dataclass
class TransformerConfig(BaseConfigForCausalLM):
    """Hyper-parameters for the Transformer model.

    Derived values (num_local_heads when -1, inter_dim when None, head_dim,
    torch_dtype/device conversion) are filled in by ``_post_init_logic``.
    """

    model_type = "Transformer"

    # Field declarations (mirrored by the explicit __init__ below).
    bsz: int = 1
    dim: int = 768
    num_heads: int = 12
    num_local_heads: int = -1  # -1 means "same as num_heads"
    num_layers: int = 12
    seq_len: int = 4096
    vocab_size: int = 200064
    inter_dim: Optional[int] = None  # None -> derived from mlp_scale * dim
    mlp_scale: float = 12.0
    weight_tying: bool = True
    bias: bool = False
    rope_theta: float = 10000.0
    torch_dtype: str = "torch.bfloat16"
    device: Optional[str] = None
    head_dim: Optional[int] = None  # always recomputed as dim // num_heads

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 768,
        num_heads: int = 12,
        num_local_heads: int = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = None,
        mlp_scale: float = 12.0,
        weight_tying: bool = True,
        bias: bool = False,
        rope_theta: float = 10000.0,
        torch_dtype: str = "torch.bfloat16",
        device: Optional[str] = None,
        head_dim: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.torch_dtype = torch_dtype
        self.device = device
        self.head_dim = head_dim
        self._post_init_logic()

    def _post_init_logic(self):
        """Fill in derived fields and normalise dtype/device types.

        Raises:
            ValueError: if num_heads is not positive or torch_dtype is invalid.
        """
        if self.num_local_heads == -1:
            self.num_local_heads = self.num_heads
        if self.inter_dim is None:
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            multiple = 256
            self.inter_dim = find_multiple(num_hidden, multiple) if num_hidden > 0 else multiple
        if self.num_heads > 0:
            self.head_dim = self.dim // self.num_heads
        else:
            raise ValueError("num_heads must be positive")
        if isinstance(self.torch_dtype, str):
            dtype_str = self.torch_dtype.replace("torch.", "")
            try:
                self.torch_dtype = getattr(torch, dtype_str)
            except AttributeError as err:
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}") from err
        elif not isinstance(self.torch_dtype, torch.dtype):
            raise ValueError(f"torch_dtype must be a string or torch.dtype, got {type(self.torch_dtype)}")
        if isinstance(self.device, str):
            self.device = torch.device(self.device)

    @classmethod
    def from_name(cls, name: str):
        """Build a config from a preset name (not implemented).

        Raises:
            NotImplementedError: always; previously this silently returned None.
        """
        raise NotImplementedError(f"TransformerConfig.from_name({name!r}) is not yet implemented")


class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> tanh-approximated GELU -> Linear."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.inter_dim)
        self.w2 = nn.Linear(config.inter_dim, config.dim)
        # Marker read by Transformer.init_weights to down-scale residual projections.
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(x), approximate="tanh"))


class TransformerLayer(nn.Module):
    """Pre-norm residual block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, config):
        super().__init__()
        self.attn_norm = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.attn = Attention(config)
        self.mlp_norm = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x


class Transformer(nn.Module):
    """Decoder-only language model: embedding, N TransformerLayers, final norm, LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(TransformerLayer(config) for _ in range(config.num_layers))
        self.norm_f = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        if self.config.weight_tying:
            # Embedding and LM head share one (vocab_size, dim) parameter.
            self.tok_emb.weight = self.lm_head.weight
        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, dim**-0.5).

        Linear layers tagged with ``SCALE_INIT`` (the residual output
        projections) get an extra ``(2 * num_layers) ** -0.5`` factor.
        Linear biases are zeroed.
        """
        std = self.std
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(
        self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, **kwargs
    ) -> CausalLMOutput:
        """Compute logits and, if ``labels`` is given, the mean cross-entropy.

        Note: ``labels`` are compared position-by-position with ``logits``;
        no next-token shift is applied here, so callers must pass labels
        that are already aligned with the positions they should predict.
        Extra ``kwargs`` are accepted and ignored.
        """
        x = self.tok_emb(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))
        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def get_num_params(self):
        """Return the total number of parameters.

        ``nn.Module.parameters()`` de-duplicates shared tensors, so a tied
        embedding/LM-head weight is counted once.
        """
        return sum(p.numel() for p in self.parameters())


def create_base_model_components(model_name_or_path=None, **kwargs):
    """Just load the config (from a name/path when given, else from kwargs)."""
    if model_name_or_path is not None:
        return TransformerConfig.from_pretrained(model_name_or_path, **kwargs)
    return TransformerConfig(**kwargs)


def _apply_repetition_penalty(logits: torch.Tensor, token_ids: torch.Tensor, penalty: float) -> torch.Tensor:
    """Penalise every token id that already occurs in ``token_ids``.

    Positive logits are divided by ``penalty`` and negative logits are
    multiplied by it, so ``penalty > 1`` always lowers a repeated token's
    score (plain division would *raise* a negative logit).

    Args:
        logits: (batch, vocab) next-token scores.
        token_ids: (batch, seq) int64 ids already in each sequence.
        penalty: penalty factor; 1.0 leaves logits unchanged.
    """
    seen = torch.gather(logits, 1, token_ids)
    seen = torch.where(seen < 0, seen * penalty, seen / penalty)
    # Duplicate ids scatter the same value, so the result is deterministic.
    return logits.scatter(1, token_ids, seen)


def _filter_top_k(probs: torch.Tensor, top_k: int) -> torch.Tensor:
    """Zero every probability below the k-th largest in its row (ties kept).

    ``top_k <= 0`` disables the filter; ``top_k`` larger than the vocabulary
    is clamped instead of raising.
    """
    if top_k <= 0:
        return probs
    k = min(top_k, probs.size(-1))
    kth_largest = torch.topk(probs, k, dim=-1).values[..., -1:]
    return probs.masked_fill(probs < kth_largest, 0.0)


def _filter_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus filter: keep the smallest prefix of sorted tokens whose mass reaches ``top_p``.

    ``top_p >= 1.0`` disables the filter. The most likely token is always kept.
    """
    if top_p >= 1.0:
        return probs
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_remove = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is also kept.
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
    return probs.masked_fill(remove, 0.0)


class TransformerForCausalLM(PreTrainedModel):
    """Thin wrapper to comply with HuggingFace's expected interface"""

    config_class = TransformerConfig
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = Transformer(config)
        self.transformer.apply(self.transformer.init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ) -> CausalLMOutput:
        """Delegate to the inner Transformer. ``attention_mask`` is accepted but not used."""
        return self.transformer(input_ids, labels=labels, **kwargs)

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 32,
        num_return_sequences: int = 4,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        seed: int = 42,
    ) -> torch.Tensor:
        """Generate text using top-k and nucleus sampling with temperature and repetition penalty.

        No KV cache is used: the full sequence is re-encoded at every step.

        Args:
            input_ids: Input token ids of shape (batch_size, seq_len)
            max_length: Total length (prompt + new tokens) to generate up to
            num_return_sequences: Number of sequences to generate per input
            temperature: Sampling temperature. Higher = more random, lower = more focused
            top_k: Number of highest probability tokens to keep (<= 0 disables)
            top_p: Cumulative probability cutoff for nucleus sampling (>= 1.0 disables)
            repetition_penalty: Penalty factor for tokens already present. 1.0 = no penalty
            seed: Random seed for reproducibility

        Returns:
            Token ids of shape (batch_size * num_return_sequences,
            max(max_length, seq_len)). Rows follow ``input_ids.repeat(num_return_sequences, 1)``
            ordering, i.e. the whole batch repeated ``num_return_sequences`` times.
        """
        self.eval()
        device = input_ids.device
        generated = input_ids.repeat(num_return_sequences, 1)

        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)

        with torch.no_grad():
            while generated.size(1) < max_length:
                # Sample in float32: a bf16 cumsum over a ~200k vocabulary is
                # too coarse for a reliable top-p cutoff.
                next_token_logits = self.transformer(generated).logits[:, -1, :].float()
                if repetition_penalty != 1.0:
                    next_token_logits = _apply_repetition_penalty(next_token_logits, generated, repetition_penalty)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                probs = F.softmax(next_token_logits, dim=-1)
                probs = _filter_top_k(probs, top_k)
                probs = _filter_top_p(probs, top_p)
                probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)

                next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng)
                generated = torch.cat([generated, next_token], dim=1)

        return generated

    def get_num_params(self):
        return self.transformer.get_num_params()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load config + ``model.safetensors`` from a Hub repo id or a local directory.

        The model is moved to CUDA when available (else CPU), cast to
        bfloat16, and put in eval mode. ``*model_args`` are ignored.
        """
        config = create_base_model_components(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        subfolder = kwargs.get("subfolder", "")
        if os.path.isdir(pretrained_model_name_or_path):
            weights_path = os.path.join(pretrained_model_name_or_path, subfolder, "model.safetensors")
        else:
            weights_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="model.safetensors",
                cache_dir=kwargs.get("cache_dir"),
                force_download=kwargs.get("force_download", False),
                proxies=kwargs.get("proxies", None),
                local_files_only=kwargs.get("local_files_only", False),
                # `use_auth_token` is deprecated in huggingface_hub; accept it as a fallback.
                token=kwargs.get("token", kwargs.get("use_auth_token")),
                revision=kwargs.get("revision", None),
                subfolder=subfolder,
            )

        state_dict = load_file(weights_path)

        # Tied weights may have been saved under only one name; restore the other.
        tok_emb_key = "tok_emb.weight"
        lm_head_key = "lm_head.weight"
        tok_emb_present = tok_emb_key in state_dict
        lm_head_present = lm_head_key in state_dict
        if tok_emb_present and not lm_head_present:
            print(f"Reconstructing weight tying: Linking missing '{lm_head_key}' to existing '{tok_emb_key}'")
            state_dict[lm_head_key] = state_dict[tok_emb_key]
        elif lm_head_present and not tok_emb_present:
            print(f"Reconstructing weight tying: Linking missing '{tok_emb_key}' to existing '{lm_head_key}'")
            state_dict[tok_emb_key] = state_dict[lm_head_key]
        elif not tok_emb_present and not lm_head_present:
            # This case should ideally not happen if the file is valid
            print(
                f"Warning: Neither '{tok_emb_key}' nor '{lm_head_key}' found in state_dict. Weight tying cannot be reconstructed."
            )
        # If both are present, assume they are loaded correctly (or were never tied)

        # Prefix keys so they match the wrapper's `transformer.` submodule.
        final_state_dict = {f"{cls.base_model_prefix}.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(final_state_dict)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        num_params = model.get_num_params()
        print(f"\nModel loaded: {pretrained_model_name_or_path}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")
        return model


# Create initial config using the correct class
config = TransformerConfig()

# Register models with correct names
AutoConfig.register("Transformer", TransformerConfig)
AutoModel.register(TransformerConfig, Transformer)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)
print("Registered Transformer model and configuration.")


def run_model_diagnostics(model, tokenizer, device):
    """Print per-token next-token losses and sample generations for fixed prompts."""
    print("\nRunning model diagnostics...")

    # Test cases of varying difficulty and length
    test_cases = [
        # Simple completion
        "2 + 2 =",
        # Medium difficulty
        "The capital of France is Paris. The capital of Germany is",
        # Complex reasoning
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        # Pattern completion
        "1, 2, 3, 4,",
        # Long context
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]

    with torch.no_grad():
        for prompt in test_cases:
            print(f"\nAnalyzing prompt: {prompt}")

            input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
            logits = model.transformer(input_ids).logits

            # Logits at position t predict token t+1; compute the loss in fp32.
            shift_logits = logits[..., :-1, :].float()
            shift_labels = input_ids[..., 1:]
            token_losses = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction="none",
            ).view(shift_labels.size())

            input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
            print("\nToken-by-token loss:")
            for token, loss in zip(input_tokens[1:], token_losses[0]):
                print(f"{token}: {loss.item():.3f}")
            print(f"Average loss: {token_losses.mean().item():.3f}")

            print("\nGeneration temperature comparison:")
            for temp in (0.5, 0.7, 1.0):
                gen_ids = model.generate(
                    input_ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")


MODEL_ID = "Hazan-Lab/Transformer-340M-0428"


def validate_model_generation(model_id: str = MODEL_ID):
    """Load ``model_id`` and its tokenizer, then run ``run_model_diagnostics``.

    Any exception is printed and re-raised.
    """
    print("\nRunning generation validation test...")
    try:
        from transformers import AutoTokenizer

        # from_pretrained already moves/casts the model and prints the parameter count.
        model = TransformerForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        device = next(model.parameters()).device

        run_model_diagnostics(model, tokenizer, device)
    except Exception as e:
        print(f"\nError during validation: {str(e)}")
        raise


# Run evaluation tasks
tasks = [
    "hellaswag",
    # "mmlu",
    # "piqa",
    # "siqa",
    # "boolq",
    # "winogrande",
    # "commonsense_qa",
    # "openbookqa",
    # "arc",
    # "arc_easy",
    # "arc_challenge",
    # "triviaqa",
    # "nq_open",
    # "humaneval",
    # "mbpp",
    # "gsm8k",
    # "hendrycks_math",
    # "mathqa",
    # "minerva_math",
    # "score",
    # "asdiv",
    # "agieval",
    # "bigbench",
]

# Few-shot count per task; -1 means "use the lm_eval task default".
tasks_fewshot = {
    "hellaswag": 0,
    # "mmlu": 5,
    # "piqa": 0,
    # "siqa": 0,
    # "boolq": 0,
    # "winogrande": -1,
    # "commonsense_qa": 7,
    # "openbookqa": -1,
    # "arc": -1,
    # "arc_easy": -1,
    # "arc_challenge": -1,
    # "triviaqa": 5,
    # "nq_open": 5,
    # "humaneval": -1,
    # "mbpp": 3,
    # "gsm8k": -1,
    # "hendrycks_math": 4,
    # "mathqa": -1,
    # "minerva_math": -1,
    # "score": -1,
    # "asdiv": -1,
    # "agieval": -1,
    # "bigbench": -1,
}


def main(model_id: str = MODEL_ID):
    """Validate generation, then run every task in ``tasks`` through lm_eval and print results."""
    all_results = {}

    # First validate generation works
    validate_model_generation(model_id)

    print("\nStarting evaluation tasks...")
    for task in tasks:
        print(f"\nEvaluating task: {task}")
        eval_kwargs = dict(
            model="hf",
            model_args=(
                f"pretrained={model_id},"
                "trust_remote_code=True,"
                "dtype=bfloat16,"
                "cache_dir=/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache"
            ),
            tasks=[task],
            batch_size="auto",
            device="cuda:0",
        )
        few_shot_value = tasks_fewshot.get(task, -1)
        if few_shot_value != -1:
            eval_kwargs["num_fewshot"] = few_shot_value
        results = evaluator.simple_evaluate(**eval_kwargs)

        task_result = results["results"].get(task, {})
        all_results[task] = task_result
        print(f"Results for {task}:")
        print(task_result)
        print("\n" + "=" * 50 + "\n")

    print("All Evaluation Results:")
    for task, result in all_results.items():
        print(f"{task}: {result}")
    return all_results


if __name__ == "__main__":
    main()