rudygpt-instruct
An instruction-tuned version of rudyon/rudygpt, a 124M parameter causal language model. Fine-tuned on the tatsu-lab/alpaca dataset, using full fine-tuning with mixed precision training on a Kaggle T4 GPU.
Usage
import torch
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import sys, os

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download the custom model code and the fine-tuned weights from the Hub.
model_py = hf_hub_download(repo_id="rudyon/rudygpt-instruct", filename="model.py")
weights = hf_hub_download(repo_id="rudyon/rudygpt-instruct", filename="pytorch_model.bin")

# Make the downloaded model.py importable.
sys.path.insert(0, os.path.dirname(model_py))
from model import GPT, GPTConfig

tokenizer = AutoTokenizer.from_pretrained("rudyon/rudygpt-instruct")

# Reconstruct the architecture and load the weights.
model = GPT(GPTConfig(depth=12, vocab_size=50304))
state_dict = torch.load(weights, map_location=device)
model.load_state_dict(state_dict)
# BUG FIX: map_location only places the *loaded tensors*; load_state_dict copies
# them into the model's existing CPU parameters. Without this .to(device) the
# model stays on CPU while the inputs are moved to CUDA, crashing generation
# on GPU machines.
model.to(device)
model.eval()
def chat(instruction, max_new_tokens=200, temperature=0.7, top_p=0.9):
    """Generate a response to *instruction* with nucleus (top-p) sampling.

    The instruction is wrapped in the Alpaca-style prompt the model was
    fine-tuned on. Sampling stops at EOS, at max_new_tokens, or as soon as
    the model starts hallucinating a new "### Instruction:" turn (the text
    before it is returned).

    Args:
        instruction: the user request to answer.
        max_new_tokens: hard cap on generated tokens.
        temperature: logit scaling; must be > 0.
        top_p: nucleus cutoff — smallest probability mass kept for sampling.

    Returns:
        The decoded response string, stripped of surrounding whitespace.
    """
    prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    out_tokens = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # NOTE(review): assumes the custom GPT returns (logits, loss) —
            # matches the original card code; confirm against model.py.
            logits, _ = model(input_ids)
            scaled = logits[:, -1, :].float() / temperature
            probs = torch.nn.functional.softmax(scaled, dim=-1)

            # Nucleus filtering: sort descending, zero out everything past
            # the top-p mass (the highest-probability token always survives
            # because its own mass is excluded from the cutoff test).
            ranked_probs, ranked_ids = torch.sort(probs, descending=True)
            cumulative = torch.cumsum(ranked_probs, dim=-1)
            outside_nucleus = cumulative - ranked_probs > top_p
            ranked_probs[outside_nucleus] = 0
            ranked_probs = ranked_probs / ranked_probs.sum()

            choice = torch.multinomial(ranked_probs[0], 1)
            next_tok = ranked_ids[0][choice]
            if next_tok.item() == tokenizer.eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_tok.reshape(1, 1)], dim=1)
            out_tokens.append(next_tok.item())

            # Early stop if the model begins a new instruction turn; return
            # only the text generated before it.
            text_so_far = tokenizer.decode(out_tokens, skip_special_tokens=True)
            if "### Instruction:" in text_so_far:
                return text_so_far.split("### Instruction:")[0].strip()
    return tokenizer.decode(out_tokens, skip_special_tokens=True).strip()
# Example: ask the instruction-tuned model a question and print its answer.
print(chat("What is gravity?"))
Prompt format
### Instruction:
Your instruction here

### Response:
Model tree for rudyon/rudygpt-instruct
Base model
rudyon/rudygpt