|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
import lm_eval as evaluator |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from safetensors.torch import load_file |
|
|
from torchtune.modules import RotaryPositionalEmbeddings |
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModel, |
|
|
AutoModelForCausalLM, |
|
|
PreTrainedModel, |
|
|
PretrainedConfig, |
|
|
) |
|
|
from transformers.modeling_outputs import CausalLMOutput |
|
|
from flash_attn import flash_attn_func |
|
|
|
|
|
# Permit lm-eval tasks that execute generated code (e.g. HumanEval) and
# silence tokenizer fork-related parallelism warnings.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"


os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Module-level mean-reduced cross entropy, shared by Transformer.forward.
loss_fn = nn.CrossEntropyLoss()
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings + FlashAttention.

    Supports grouped-query attention: queries use ``config.num_heads`` heads
    while keys/values use ``config.num_local_heads`` heads (equal by default,
    since the config resolves ``num_local_heads == -1`` to ``num_heads``).
    """

    def __init__(self, config):
        super().__init__()
        # Bug fix: key/value projections must produce num_local_heads *
        # head_dim features so the .view(...) in forward() is valid when
        # num_local_heads < num_heads. The original sized all projections to
        # config.dim, which only works when the head counts match (the
        # default); in that case kv_dim == config.dim and shapes are
        # unchanged, so existing checkpoints still load.
        kv_dim = config.num_local_heads * config.head_dim
        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, kv_dim)
        self.wv = nn.Linear(config.dim, kv_dim)
        self.wo = nn.Linear(config.dim, config.dim)
        # Marker consumed by Transformer.init_weights: residual-path
        # projections get their init std scaled by (2 * num_layers) ** -0.5.
        self.wo.SCALE_INIT = 1

        self.dim = config.dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads

        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (bsz, seq_len, dim)."""
        bsz, seq_len, dim = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # flash_attn_func expects (batch, seqlen, nheads, headdim) layout.
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        q, k = self.rotary_emb(q), self.rotary_emb(k)

        y = flash_attn_func(
            q=q,
            k=k,
            v=v,
            causal=True,
        )

        out = y.reshape(bsz, seq_len, -1)
        out = self.wo(out)

        return out
|
|
|
|
|
|
|
|
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k`` (``n`` itself if it already is one)."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
|
|
|
|
|
|
|
|
class BaseConfigForCausalLM(PretrainedConfig):
    """Shared ``PretrainedConfig`` parent for the causal-LM configs below.

    Subclasses are decorated with ``@dataclass`` and override ``model_type``.
    """

    model_type = "base_model"
|
|
|
|
|
|
|
|
@dataclass
class TransformerConfig(BaseConfigForCausalLM):
    """Configuration for the Transformer causal LM.

    NOTE: the explicit ``__init__`` below suppresses dataclass ``__init__``
    generation, so the field declarations mainly serve as typed defaults.
    Derived values (num_local_heads, inter_dim, head_dim, torch_dtype,
    device) are normalized in ``_post_init_logic``.
    """

    model_type = "Transformer"

    bsz: int = 1  # batch size
    dim: int = 768  # model width / embedding dimension
    num_heads: int = 12  # query heads
    num_local_heads: int = -1  # key/value heads; -1 means "same as num_heads"
    num_layers: int = 12
    seq_len: int = 4096  # max sequence length (sizes the rotary cache)
    vocab_size: int = 200064
    inter_dim: Optional[int] = None  # MLP hidden size; derived when None
    mlp_scale: float = 12.0  # multiplier used to derive inter_dim
    weight_tying: bool = True  # tie tok_emb and lm_head weights
    bias: bool = False
    rope_theta: float = 10000.0  # rotary base frequency
    torch_dtype: str = "torch.bfloat16"  # string form; converted to a torch.dtype
    device: Optional[str] = None  # string form; converted to a torch.device
    head_dim: Optional[int] = None  # derived as dim // num_heads

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 768,
        num_heads: int = 12,
        num_local_heads: int = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = None,
        mlp_scale: float = 12.0,
        weight_tying: bool = True,
        bias: bool = False,
        rope_theta: float = 10000.0,
        torch_dtype: str = "torch.bfloat16",
        device: Optional[str] = None,
        head_dim: Optional[int] = None,
        **kwargs,
    ):
        # Extra kwargs (e.g. from a hub config.json) go to PretrainedConfig.
        super().__init__(**kwargs)

        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.torch_dtype = torch_dtype
        self.device = device
        self.head_dim = head_dim

        # Resolve derived/normalized fields (head_dim, inter_dim, dtype, ...).
        self._post_init_logic()

    def _post_init_logic(self):
        """Normalize sentinel/derived fields; raises ValueError on bad input."""
        # -1 means "no grouped-query attention": one kv head per query head.
        if self.num_local_heads == -1:
            self.num_local_heads = self.num_heads
        if self.inter_dim is None:
            # Derive the MLP hidden size as 2/3 of mlp_scale * dim, rounded up
            # to a multiple of 256 for hardware-friendly shapes.
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            multiple = 256
            self.inter_dim = find_multiple(num_hidden, multiple) if num_hidden > 0 else multiple

        # head_dim is always recomputed from dim/num_heads, overriding any
        # explicitly passed value.
        if self.num_heads > 0:
            self.head_dim = self.dim // self.num_heads
        else:
            raise ValueError("num_heads must be positive")

        # Accept "torch.bfloat16"/"bfloat16" strings or an actual torch.dtype.
        if isinstance(self.torch_dtype, str):
            dtype_str = self.torch_dtype.replace("torch.", "")
            try:
                self.torch_dtype = getattr(torch, dtype_str)
            except AttributeError as err:
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}") from err
        elif not isinstance(self.torch_dtype, torch.dtype):
            raise ValueError(f"torch_dtype must be a string or torch.dtype, got {type(self.torch_dtype)}")

        # Likewise accept a device string ("cuda", "cpu", ...).
        if isinstance(self.device, str):
            self.device = torch.device(self.device)

    @classmethod
    def from_name(cls, name: str):
        # Placeholder for named presets; currently prints and returns None.
        print("Not yet implemented")
        pass
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Two-layer feed-forward block with tanh-approximate GELU activation."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.inter_dim)
        self.w2 = nn.Linear(config.inter_dim, config.dim)
        # Marker read by the weight initializer: down-scale this
        # residual-path projection at init time.
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up to inter_dim, apply GELU (tanh approximation), project back."""
        hidden = self.w1(x)
        activated = F.gelu(hidden, approximate="tanh")
        return self.w2(activated)
|
|
|
|
|
|
|
|
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, config):
        super().__init__()
        # Submodule creation order matches the original (attn before mlp) so
        # parameter init consumes the RNG identically.
        self.attn_norm = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.attn = Attention(config)
        self.mlp_norm = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.attn(self.attn_norm(x))
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Decoder-only transformer LM: token embedding -> N layers -> norm -> head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList()
        for _ in range(config.num_layers):
            self.layers.append(TransformerLayer(config))
        self.norm_f = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Weight tying: embedding and output head share one parameter tensor.
        if self.config.weight_tying:
            self.tok_emb.weight = self.lm_head.weight

        # Base init std = 1/sqrt(dim); consumed by init_weights below.
        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        """Per-module initializer, intended to be used via ``self.apply``.

        Linear layers tagged with ``SCALE_INIT`` (residual-path projections)
        get their std further scaled by ``(2 * num_layers) ** -0.5``.
        """
        std = self.std
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, **kwargs) -> CausalLMOutput:
        """Run the LM; when ``labels`` is given, also return a cross-entropy loss.

        NOTE(review): the loss is computed on labels as passed, with no
        next-token shift — callers are expected to pre-align labels.
        """
        x = self.tok_emb(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Flatten (batch, seq) so the module-level loss_fn sees 2D logits.
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def get_num_params(self):
        """
        Return the total number of parameters in the model.

        Tied tensors (tok_emb/lm_head under weight tying) are counted once,
        since ``Module.parameters()`` de-duplicates shared parameters.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
|
|
|
|
|
|
|
|
def create_base_model_components(model_name_or_path=None, **kwargs):
    """Build a TransformerConfig — from a hub repo/path when a name is given."""
    if model_name_or_path is None:
        return TransformerConfig(**kwargs)
    return TransformerConfig.from_pretrained(model_name_or_path, **kwargs)
|
|
|
|
|
|
|
|
class TransformerForCausalLM(PreTrainedModel):
    """Thin wrapper to comply with HuggingFace's expected interface"""

    config_class = TransformerConfig
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)

        self.transformer = Transformer(config)
        self.transformer.apply(self.transformer.init_weights)

    def forward(
        self, input_ids: torch.Tensor, labels: torch.Tensor = None, attention_mask: torch.Tensor = None, **kwargs
    ) -> CausalLMOutput:
        """Delegate to the inner Transformer.

        ``attention_mask`` is accepted for HF-interface compatibility but is
        ignored: the inner model always applies a causal mask.
        """
        outputs = self.transformer(input_ids, labels=labels, **kwargs)
        return outputs

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 32,
        num_return_sequences: int = 4,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        seed: int = 42,
    ) -> torch.Tensor:
        """Generate text using top-k and nucleus sampling with temperature and repetition penalty.

        Args:
            input_ids: Input token ids of shape (batch_size, seq_len)
            max_length: Maximum length of generated sequence
            num_return_sequences: Number of sequences to generate per input
            temperature: Sampling temperature. Higher = more random, lower = more focused
            top_k: Number of highest probability tokens to keep for top-k sampling
            top_p: Cumulative probability cutoff for nucleus sampling
            repetition_penalty: Penalty factor for repeating tokens. 1.0 = no penalty
            seed: Random seed for reproducibility

        Returns:
            Generated token ids of shape (num_return_sequences, max_length)
        """
        # Note: switches the whole model to eval mode as a side effect.
        self.eval()
        device = input_ids.device

        # Expand the prompt so each returned sequence is sampled independently.
        input_ids = input_ids.repeat(num_return_sequences, 1)
        generated = input_ids

        # Dedicated RNG so sampling is reproducible regardless of global state.
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)

        with torch.no_grad():
            while generated.size(1) < max_length:
                # No KV cache: re-runs the full prefix each step.
                outputs = self.transformer(generated)
                next_token_logits = outputs.logits[:, -1, :]

                if repetition_penalty != 1.0:
                    # Bug fix: the original tested `token in next_token_logits[i]`,
                    # which checks membership among logit *values* (not indices),
                    # so the penalty almost never fired. Apply the standard
                    # CTRL-style penalty to each distinct token id already in
                    # the sequence: divide positive logits, multiply negative
                    # ones, so the token's probability always decreases.
                    for i in range(generated.shape[0]):
                        for token_id in set(generated[i].tolist()):
                            logit = next_token_logits[i, token_id]
                            if logit > 0:
                                next_token_logits[i, token_id] = logit / repetition_penalty
                            else:
                                next_token_logits[i, token_id] = logit * repetition_penalty

                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

                # Top-k: zero out everything below the k-th largest probability.
                if top_k > 0:
                    indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None]
                    probs[indices_to_remove] = 0

                # Nucleus (top-p): drop the tail beyond cumulative mass top_p,
                # always keeping at least the single most likely token.
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs[indices_to_remove] = 0

                # Renormalize after filtering (clamp guards an all-zero row).
                probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)

                next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng)

                generated = torch.cat([generated, next_token], dim=1)

        return generated

    def get_num_params(self):
        """Total parameter count (delegates to the inner Transformer)."""
        return self.transformer.get_num_params()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load config + safetensors weights from the Hub (or local cache).

        Reconstructs weight tying when only one of tok_emb/lm_head was
        serialized, prefixes state-dict keys with ``base_model_prefix``, and
        returns the model in eval mode, in bfloat16, on CUDA if available.
        """
        config = create_base_model_components(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        # Fetch the single-shard safetensors file directly from the hub.
        # NOTE(review): use_auth_token is deprecated in newer huggingface_hub
        # releases in favor of `token` — confirm against the pinned version.
        weights_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="model.safetensors",
            cache_dir=kwargs.get("cache_dir"),
            force_download=kwargs.get("force_download", False),
            proxies=kwargs.get("proxies", None),
            local_files_only=kwargs.get("local_files_only", False),
            use_auth_token=kwargs.get("use_auth_token", None),
            revision=kwargs.get("revision", None),
            subfolder=kwargs.get("subfolder", ""),
        )

        state_dict = load_file(weights_path)

        # Checkpoints saved with tied weights may contain only one of the two
        # tensors; mirror whichever one is present so strict loading succeeds.
        tok_emb_key = "tok_emb.weight"
        lm_head_key = "lm_head.weight"

        tok_emb_present = tok_emb_key in state_dict
        lm_head_present = lm_head_key in state_dict

        if tok_emb_present and not lm_head_present:
            print(f"Reconstructing weight tying: Linking missing '{lm_head_key}' to existing '{tok_emb_key}'")
            state_dict[lm_head_key] = state_dict[tok_emb_key]
        elif lm_head_present and not tok_emb_present:
            print(f"Reconstructing weight tying: Linking missing '{tok_emb_key}' to existing '{lm_head_key}'")
            state_dict[tok_emb_key] = state_dict[lm_head_key]
        elif not tok_emb_present and not lm_head_present:
            print(
                f"Warning: Neither '{tok_emb_key}' nor '{lm_head_key}' found in state_dict. Weight tying cannot be reconstructed."
            )

        # Checkpoint keys are relative to the inner model; re-prefix them to
        # match this wrapper's attribute ("transformer.").
        final_state_dict = {f"{cls.base_model_prefix}.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(final_state_dict)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        num_params = model.get_num_params()
        print(f"\nModel loaded: {pretrained_model_name_or_path}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        return model
|
|
|
|
|
|
|
|
|
|
|
# Module-level default config instance (not referenced below — presumably
# kept for quick interactive inspection; TODO confirm it is still needed).
config = TransformerConfig()


# Register with the HF Auto* factories so Auto{Config,Model,ModelForCausalLM}
# can resolve model_type == "Transformer" to the classes defined here.
AutoConfig.register("Transformer", TransformerConfig)

AutoModel.register(TransformerConfig, Transformer)

AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)


print("Registered Transformer model and configuration.")
|
|
|
|
|
|
|
|
def run_model_diagnostics(model, tokenizer, device):
    """Run detailed diagnostics to analyze model behavior.

    For each probe prompt: prints the per-token (next-token) losses, their
    average, and sample generations at a few temperatures.

    Args:
        model: a TransformerForCausalLM (uses .transformer and .generate).
        tokenizer: HF tokenizer matching the model's vocabulary.
        device: device the model lives on; inputs are moved there.
    """
    print("\nRunning model diagnostics...")

    # Small probes covering arithmetic, factual recall, reasoning,
    # pattern continuation, and free-form completion.
    test_cases = [
        "2 + 2 =",
        "The capital of France is Paris. The capital of Germany is",
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        "1, 2, 3, 4,",
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]

    with torch.no_grad():
        for prompt in test_cases:
            print(f"\nAnalyzing prompt: {prompt}")

            tokens = tokenizer(prompt, return_tensors="pt")
            input_ids = tokens["input_ids"].to(device)

            # Forward pass for logits only. (The original also passed
            # labels=input_ids, which produced an unshifted loss that was
            # never used — the shifted per-token losses are computed below.)
            outputs = model.transformer(input_ids)

            # Standard causal-LM shift: position t predicts token t+1.
            labels = input_ids.clone()
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(reduction="none")
            token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(
                shift_labels.size()
            )

            # Skip the first token: it has no preceding context to predict it.
            input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
            print("\nToken-by-token loss:")
            for token, loss in zip(input_tokens[1:], token_losses[0]):
                print(f"{token}: {loss.item():.3f}")

            print(f"Average loss: {token_losses.mean().item():.3f}")

            # Compare sampled continuations across temperatures (fixed seed).
            temps = [0.5, 0.7, 1.0]
            print("\nGeneration temperature comparison:")
            for temp in temps:
                gen_ids = model.generate(
                    input_ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")
|
|
|
|
|
|
|
|
def validate_model_generation():
    """Load the published checkpoint and run the diagnostic suite against it."""
    print("\nRunning generation validation test...")

    try:
        from transformers import AutoTokenizer

        model_id = "Hazan-Lab/Transformer-340M-0428"
        model = TransformerForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Prefer GPU when available; run in bfloat16 either way.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=target_device, dtype=torch.bfloat16)
        model.eval()

        num_params = model.get_num_params()
        print(f"\nModel loaded: {model_id}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        run_model_diagnostics(model, tokenizer, target_device)

    except Exception as e:
        # Surface the failure message, then propagate the original traceback.
        print(f"\nError during validation: {str(e)}")
        raise
|
|
|
|
|
|
|
|
|
|
|
# Benchmark tasks to run through lm-eval; extend this list to add more.
tasks = [
    "hellaswag",
]


# Per-task few-shot counts. Tasks absent from this dict fall back to the
# harness default via the -1 sentinel handling below.
tasks_fewshot = {
    "hellaswag": 0,
}


all_results = {}


# Sanity-check loading + generation before paying for the full eval run.
validate_model_generation()

model_id = "Hazan-Lab/Transformer-340M-0428"

print("\nStarting evaluation tasks...")
for task in tasks:
    print(f"\nEvaluating task: {task}")
    # NOTE(review): cache_dir is a cluster-specific absolute path — confirm
    # it exists on the machine running this script.
    eval_kwargs = dict(
        model="hf",
        model_args=(
            f"pretrained={model_id},"
            "trust_remote_code=True,"
            "dtype=bfloat16,"
            "cache_dir=/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache"
        ),
        tasks=[task],
        batch_size="auto",
        device="cuda:0",
    )
    # Only pin num_fewshot when explicitly configured (-1 = use task default).
    few_shot_value = tasks_fewshot.get(task, -1)
    if few_shot_value != -1:
        eval_kwargs["num_fewshot"] = few_shot_value
    results = evaluator.simple_evaluate(**eval_kwargs)
    task_result = results["results"].get(task, {})
    all_results[task] = task_result
    print(f"Results for {task}:")
    print(task_result)
    print("\n" + "=" * 50 + "\n")

print("All Evaluation Results:")
for task, result in all_results.items():
    print(f"{task}: {result}")
|
|
|