PleIAs/SYNTH
Viewer • Updated • 68M • 6.31k • 269
A 110M-parameter GPT-style language model pretrained on the PleIAs/SYNTH synthetic reasoning dataset.
Note: This is a base pretrained model only. It has not been instruction-tuned or aligned.
| Parameter | Value |
|---|---|
| Layers | 12 |
| Heads | 12 |
| Embedding Dim | 768 |
| Head Dim | 64 |
import json
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
from model import GPT, GPTConfig
from tokenizer import t5
# Load tokenizer
tok, bos_id, eos_id = t5()

# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
repo_id = "ethanthoma/synth-gpt-110m"
config_path = hf_hub_download(repo_id, "config.json")
model_path = hf_hub_download(repo_id, "model.safetensors")
# JSON is UTF-8 by specification; do not depend on the platform default encoding.
with open(config_path, encoding="utf-8") as f:
    config = GPTConfig(**json.load(f))
model = GPT(config)
# `load_model(..., device=...)` only decides where the checkpoint tensors are
# materialised; they are then copied into the module's existing parameters.
# The module itself must therefore be moved explicitly, otherwise it stays on
# the CPU while `generate` builds its inputs on `device`.
model.to(device)
load_model(model, model_path, device=device)
# Inference mode for layers that behave differently in training (e.g. dropout).
model.eval()
# Generate
def generate(prompt, max_tokens=100, temperature=0.8, context_len=1024):
    """Autoregressively extend ``prompt`` and return the decoded text.

    Uses the module-level ``tok``, ``model``, ``device`` and ``eos_id``.

    Args:
        prompt: Text to condition the model on.
        max_tokens: Upper bound on the number of newly generated tokens.
        temperature: Divisor applied to the logits before sampling. Values
            ``<= 0`` select greedy (argmax) decoding instead of dividing by
            zero.
        context_len: Number of most recent tokens fed to the model at each
            step (presumably the model's maximum context size -- confirm
            against the model config before changing it).

    Returns:
        The decoded string for the prompt tokens followed by the generated
        tokens. If the end-of-sequence token is produced it is included and
        generation stops early.

    Raises:
        ValueError: If the prompt encodes to zero tokens, since there is
            then no position to predict from.
    """
    tokens = tok.encode(prompt).ids
    if not tokens:
        raise ValueError("prompt must encode to at least one token")
    # Explicit integer dtype so the token ids are never inferred as floats.
    idx = torch.tensor([tokens], dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in range(max_tokens):
            # Only the trailing window is fed to the model.
            logits, _ = model(idx[:, -context_len:])
            # Keep the prediction for the last position only.
            logits = logits[:, -1, :]
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Zero temperature is the greedy limit of sampling.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            idx = torch.cat([idx, next_token], dim=1)
            if next_token.item() == eos_id:
                break
    return tok.decode(idx[0].tolist())
# Run one sample completion and show the decoded text.
sample_output = generate("What is 2 + 2?")
print(sample_output)
License: MIT