dilip025 committed on
Commit
ebf1daa
·
verified ·
1 Parent(s): 6228521

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +84 -8
README.md CHANGED
@@ -14,7 +14,7 @@ tags:
14
  ---
15
  # Mini GPT1 Clone
16
 
17
- This is a decoder-only transformer model (GPT1-style) trained from scratch using PyTorch.
18
 
19
  ## Model Details
20
 
@@ -32,14 +32,90 @@ Trained using `ByteLevelBPETokenizer` from the `tokenizers` library.
32
 
33
  ## Inference Example
34
 
 
 
35
  ```python
36
- from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  import torch
38
 
39
- tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer/tokenizer.json")
40
- model = AutoModelForCausalLM.from_pretrained("dilip025/mini-gpt1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- prompt = "Once upon a time,"
43
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
44
- outputs = model.generate(input_ids, max_length=50)
45
- print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
14
  ---
15
  # Mini GPT1 Clone
16
 
17
+ This is a custom decoder-only transformer model (GPT1-style) trained from scratch using PyTorch.
18
 
19
  ## Model Details
20
 
 
32
 
33
  ## Inference Example
34
 
35
+ Run it in Google Colab: https://colab.research.google.com
36
+
37
  ```python
38
# Clone the model repository only if it has not been cloned already.
import os
import subprocess
import sys

if not os.path.exists("mini-gpt1"):
    # subprocess.run works in notebooks and in plain Python alike, whereas
    # `!git clone` is IPython-only syntax; check=True surfaces a failed clone
    # immediately instead of failing later on a missing file.
    subprocess.run(
        ["git", "clone", "https://huggingface.co/dilip025/mini-gpt1"],
        check=True,
    )

# Install dependencies; uncomment the next line if they are not installed yet.
# !pip install torch tokenizers

# Make the cloned repository importable (it provides the `model_code` package).
sys.path.append("mini-gpt1")

# These imports must come after the sys.path tweak above.
import torch
from tokenizers import ByteLevelBPETokenizer

from model_code.decoder_only_transformer import DecoderOnlyTransformer

# Load the tokenizer from the vocabulary and merges files in the cloned repo.
tokenizer = ByteLevelBPETokenizer(
    "mini-gpt1/vocab.json",
    "mini-gpt1/merges.txt",
)

# Model config.
# NOTE(review): presumably these must match the values the checkpoint was
# trained with, otherwise load_state_dict below will reject it -- confirm.
vocab_size = 35000
max_len = 128
embed_dim = 512
num_heads = 8
depth = 6
ff_dim = 2048

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model and move it to the selected device.
model = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    max_len=max_len,
    embed_dim=embed_dim,
    num_heads=num_heads,
    depth=depth,
    ff_dim=ff_dim,
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded.
state_dict = torch.load(
    "mini-gpt1/pytorch_model.bin",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()
85
+
86
+ # 💡 Your generation function with temperature & top-k
87
def generate(model, tokenizer, prompt, max_length=50, temperature=1.0, top_k=50,
             max_context=None):
    """Sample a continuation of ``prompt`` from ``model`` one token at a time.

    Args:
        model: Callable torch module; ``model(ids)`` is expected to return
            logits indexed as ``[batch, position, vocab]`` (assumed from how
            the result is sliced below -- confirm against the model code).
        tokenizer: Object with ``encode(text)`` returning something with an
            ``ids`` attribute, and ``decode(list_of_ids)``. If it also offers
            ``token_to_id`` and knows an ``'[EOS]'`` token, sampling stops as
            soon as that token is produced.
        prompt: Text to continue.
        max_length: Maximum number of new tokens to sample.
        temperature: Positive divisor applied to the logits before sampling.
        top_k: Keep only the ``top_k`` highest logits before sampling, or
            ``None`` to sample from the full distribution. Values larger than
            the vocabulary are clamped to the vocabulary size.
        max_context: If given, only the last ``max_context`` tokens are fed to
            the model on each step. Leave as ``None`` to feed the whole
            sequence (the original behaviour). Presumably this should be set
            to the model's ``max_len`` when prompt plus ``max_length`` may
            exceed it -- confirm against the model code.

    Returns:
        The decoded text of the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not positive, or ``top_k`` or
            ``max_context`` is given but smaller than 1.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k is not None and top_k < 1:
        raise ValueError("top_k must be a positive integer or None")
    if max_context is not None and max_context < 1:
        raise ValueError("max_context must be a positive integer or None")

    model.eval()
    device = next(model.parameters()).device

    encoding = tokenizer.encode(prompt)
    generated = torch.tensor([encoding.ids], dtype=torch.long, device=device)

    # Resolve the optional end-of-sequence id once instead of on every step.
    eos_id = None
    if hasattr(tokenizer, 'token_to_id'):
        eos_id = tokenizer.token_to_id('[EOS]')

    # Inference only: skip building the autograd graph on every forward pass.
    with torch.no_grad():
        for _ in range(max_length):
            if max_context is None:
                context = generated
            else:
                context = generated[:, -max_context:]

            logits = model(context)
            next_token_logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # torch.topk raises if k exceeds the size of the last dim.
                k = min(top_k, next_token_logits.size(-1))
                values, indices = torch.topk(next_token_logits, k)
                # Everything outside the top-k gets -inf, i.e. probability 0.
                filtered = torch.full_like(next_token_logits, float('-inf'))
                filtered.scatter_(1, indices, values)
                next_token_logits = filtered

            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat((generated, next_token), dim=1)

            # The EOS token is kept in the output, then sampling stops.
            if eos_id is not None and next_token.item() == eos_id:
                break

    return tokenizer.decode(generated[0].tolist())
116
+
117
 
118
+ # 🔥 Example inference -- run this in a second cell to see gibberish ;)
119
+ prompt = "He told me a story"
120
+ output = generate(model, tokenizer, prompt, max_length=100, temperature=1.2, top_k=40)
121
+ print("Generated Output:\n", output)