# frawdllm-100m/hf_wrapper.py
"""
HuggingFace wrapper for FrawdLLM.
This allows the model to be loaded with:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("tsingla1998/frawdllm-100m", trust_remote_code=True)
"""
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .config import ModelConfig
from .gpt import FrawdLLM


class FrawdLLMConfig(PretrainedConfig):
    """HuggingFace-compatible configuration for FrawdLLM."""

    model_type = "frawdllm"

def __init__(
self,
vocab_size: int = 32000,
n_embd: int = 768,
n_layer: int = 12,
n_head: int = 12,
context_length: int = 1024,
dropout: float = 0.1,
use_rope: bool = True,
use_rmsnorm: bool = False,
use_swiglu: bool = False,
pad_token_id: int = 0,
bos_token_id: int = 2,
eos_token_id: int = 3,
**kwargs,
):
self.vocab_size = vocab_size
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.context_length = context_length
self.dropout = dropout
self.use_rope = use_rope
self.use_rmsnorm = use_rmsnorm
self.use_swiglu = use_swiglu
# Aliases for HuggingFace compatibility
self.num_hidden_layers = n_layer
self.hidden_size = n_embd
self.num_attention_heads = n_head
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)

    def to_model_config(self) -> ModelConfig:
"""Convert to internal ModelConfig for the model."""
return ModelConfig(
vocab_size=self.vocab_size,
n_embd=self.n_embd,
n_layer=self.n_layer,
n_head=self.n_head,
context_length=self.context_length,
dropout=self.dropout,
use_rope=self.use_rope,
use_rmsnorm=self.use_rmsnorm,
use_swiglu=self.use_swiglu,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
)

    @classmethod
def from_model_config(cls, config: ModelConfig) -> "FrawdLLMConfig":
"""Create from internal ModelConfig."""
return cls(
vocab_size=config.vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
context_length=config.context_length,
dropout=config.dropout,
use_rope=config.use_rope,
use_rmsnorm=config.use_rmsnorm,
use_swiglu=config.use_swiglu,
pad_token_id=config.pad_token_id,
bos_token_id=config.bos_token_id,
eos_token_id=config.eos_token_id,
)
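

# Illustrative round-trip (values shown are the defaults declared above): the
# HF config converts to the internal ModelConfig and back without loss.
#
#     hf_cfg = FrawdLLMConfig(n_layer=12, n_embd=768)
#     internal = hf_cfg.to_model_config()            # internal ModelConfig
#     assert FrawdLLMConfig.from_model_config(internal).n_embd == hf_cfg.n_embd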


class FrawdLLMForCausalLM(PreTrainedModel, GenerationMixin):
    """HuggingFace-compatible wrapper for FrawdLLM."""

    config_class = FrawdLLMConfig
    base_model_prefix = "model"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TransformerBlock"]
    _tied_weights_keys = ["model.lm_head.weight"]

    def __init__(self, config: FrawdLLMConfig):
        super().__init__(config)
        # Convert the HF config to the internal config and build the model.
        model_config = config.to_model_config()
        self.model = FrawdLLM(model_config)

    def get_input_embeddings(self):
        return self.model.embeddings.token_emb

    def set_input_embeddings(self, value):
        self.model.embeddings.token_emb = value

    def get_output_embeddings(self):
        return self.model.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.model.lm_head = new_embeddings

    def tie_weights(self):
        """Tie the input and output embedding weights."""
        self.model.lm_head.weight = self.model.embeddings.token_emb.weight
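
    # Illustration: after tie_weights() both modules reference one tensor, so
    # a gradient update to the embedding also moves the LM head.
    #
    #     model.tie_weights()
    #     assert (model.get_output_embeddings().weight.data_ptr()
    #             == model.get_input_embeddings().weight.data_ptr())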

    def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Forward pass compatible with HuggingFace API.
Note: attention_mask, past_key_values, use_cache are accepted but
not fully implemented (our model doesn't use KV caching yet).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the inner model to get logits; the loss is computed below from
        # labels rather than inside the model.
        logits, _ = self.model(input_ids, None)
# Compute loss if labels provided
loss = None
if labels is not None:
            # Shift so that logits at position i are scored against the token
            # at position i + 1 (the standard causal LM objective).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,  # HF convention: positions labeled -100 are masked
            )
if not return_dict:
output = (logits,)
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
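
    # Usage sketch (illustrative shapes): passing labels=input_ids makes
    # forward() compute the shifted next-token cross-entropy itself.
    #
    #     ids = torch.randint(0, model.config.vocab_size, (1, 16))  # (batch, seq)
    #     out = model(input_ids=ids, labels=ids)
    #     out.loss          # scalar loss tensor
    #     out.logits.shape  # (1, 16, vocab_size)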

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Prepare inputs for generation (called by HF generate())."""
        # No KV cache yet, so the full sequence is re-run at every step; crop
        # to the context window so generation past context_length still works.
        if input_ids.size(1) > self.config.context_length:
            input_ids = input_ids[:, -self.config.context_length:]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
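
    # Usage sketch: GenerationMixin.generate() drives decoding through
    # forward() and the method above; with no KV cache, each new token
    # re-runs the full (cropped) sequence, so decoding is quadratic in length.
    #
    #     out_ids = model.generate(ids, max_new_tokens=50, do_sample=True)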

    @classmethod
def from_frawdllm_checkpoint(
cls,
checkpoint_path: str,
device: str = "cpu",
) -> "FrawdLLMForCausalLM":
"""
Load from a FrawdLLM .pt checkpoint.
Args:
checkpoint_path: Path to the .pt checkpoint file
device: Device to load the model on
Returns:
FrawdLLMForCausalLM instance
"""
        # weights_only=False: the checkpoint pickles the internal ModelConfig.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
# Get the internal config
model_config = checkpoint["config"]
# Create HF config
hf_config = FrawdLLMConfig.from_model_config(model_config)
# Create the wrapper model
model = cls(hf_config)
# Load the weights
model.model.load_state_dict(checkpoint["model_state_dict"])
return model
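
    # Usage sketch (hypothetical path): convert a training checkpoint into an
    # HF-compatible wrapper in one call.
    #
    #     model = FrawdLLMForCausalLM.from_frawdllm_checkpoint("ckpt/final.pt")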

    def save_pretrained_simple(self, save_directory: str):
        """
        Save in HuggingFace format.

        This writes:
            - config.json
            - model.safetensors
        """
import os
from safetensors.torch import save_file
os.makedirs(save_directory, exist_ok=True)
# Save config
self.config.save_pretrained(save_directory)
        # Save model weights. The token embedding and LM head share one tensor
        # (weight tying), and safetensors refuses duplicated storage, so drop
        # the lm_head entry; _tied_weights_keys re-ties it at load time.
        state_dict = self.state_dict()
        state_dict.pop("model.lm_head.weight", None)
save_file(state_dict, os.path.join(save_directory, "model.safetensors"))
print(f"Saved model to {save_directory}")
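
    # Usage sketch (hypothetical directory): pair with
    # tokenizer.save_pretrained() so the saved repo is complete and loadable.
    #
    #     model.save_pretrained_simple("./frawdllm-100m-hf")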


# Register for AutoClass: saving then writes an auto_map entry into
# config.json, which is what lets the Auto* loaders find these custom classes.
FrawdLLMConfig.register_for_auto_class()
FrawdLLMForCausalLM.register_for_auto_class("AutoModelForCausalLM")
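
# The registrations above cause HF's save machinery to record an auto_map in
# config.json, roughly of the form:
#
#     "auto_map": {
#         "AutoConfig": "hf_wrapper.FrawdLLMConfig",
#         "AutoModelForCausalLM": "hf_wrapper.FrawdLLMForCausalLM"
#     }
#
# which is what makes the trust_remote_code=True load shown in the module
# docstring resolve to these custom classes.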