YAML Metadata Warning: this repo card has empty or missing YAML metadata
(see https://huggingface.co/docs/hub/model-cards#model-card-metadata).
Example usage for this model:
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Config
import torch
from huggingface_hub import hf_hub_download
def load_model(checkpoint_name="best_val_loss_dense_step_9000.bin", model_id="idhant297/dense-5l-test"):
    """
    Load a model from HuggingFace Hub with a specific checkpoint.

    Args:
        checkpoint_name (str): Filename of the checkpoint weights inside the repo
        model_id (str): The HuggingFace model repository ID

    Returns:
        tuple: (model, tokenizer) loaded from the checkpoint
    """
    # BUG FIX: the original body referenced two undefined names
    # (`checkpoint_step` on the first print, `checkpoint_filename` in the
    # hf_hub_download call) and would raise NameError. Both are the
    # `checkpoint_name` parameter.
    print(f"Loading model from {model_id} with checkpoint {checkpoint_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Build an untrained model from the repo's config, then overwrite its
    # weights with the downloaded checkpoint's state dict.
    config = GPT2Config.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_config(config)
    checkpoint_path = hf_hub_download(
        repo_id=model_id,
        filename=checkpoint_name
    )
    # map_location="cpu" so loading works without a GPU.
    # NOTE(review): torch.load unpickles arbitrary data; only load
    # checkpoints from repos you trust.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Model loaded successfully from checkpoint {checkpoint_name}")
    return model, tokenizer
def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.8, top_p=0.95, num_return_sequences=1):
    """
    Generate text using the loaded model.

    Args:
        model: The loaded model
        tokenizer: The loaded tokenizer
        prompt (str): Input text prompt
        max_length (int): Maximum length of generated text
        temperature (float): Sampling temperature
        top_p (float): Top-p sampling parameter
        num_return_sequences (int): Number of sequences to generate

    Returns:
        list: Generated text sequences
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    # Sampling only — no gradients are needed for inference.
    with torch.no_grad():
        sampled = model.generate(
            encoded["input_ids"],
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode every sampled sequence back into plain text.
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in sampled]
# Example usage: load the best-val-loss checkpoint and sample a short completion.
checkpoint_name = "best_val_loss_dense_step_9000.bin"
model, tokenizer = load_model(checkpoint_name)

prompt = "The quick brown fox"
generated = generate_text(model, tokenizer, prompt, max_length=20)
print(generated)
- Downloads last month
- 1
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support