Update README.md
Browse files
README.md
CHANGED
|
@@ -13,10 +13,13 @@ metrics:
|
|
| 13 |
|
| 14 |
Welcome to SmaLLMPro 350M, our latest instruct model based on FineWeb-Edu.
|
| 15 |
|
| 16 |
-
# 1.
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# 2. Training Code
|
| 22 |
```python
|
|
@@ -707,11 +710,12 @@ import torch
|
|
| 707 |
import tiktoken
|
| 708 |
from model import GPTConfig, GPT
|
| 709 |
|
| 710 |
-
|
|
|
|
| 711 |
device = 'cuda'
|
| 712 |
enc = tiktoken.get_encoding("gpt2")
|
| 713 |
|
| 714 |
-
print("Loading SmaLLMPro
|
| 715 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
| 716 |
gptconf = GPTConfig(**checkpoint['model_args'])
|
| 717 |
model = GPT(gptconf)
|
|
@@ -728,7 +732,7 @@ model.to(device)
|
|
| 728 |
print(f"Model {ckpt_path} ready!\n")
|
| 729 |
|
| 730 |
def run_chat():
|
| 731 |
-
print("--- SmaLLMPro Chatbot (Type 'exit' to
|
| 732 |
|
| 733 |
while True:
|
| 734 |
user_input = input("You: ")
|
|
@@ -742,10 +746,16 @@ def run_chat():
|
|
| 742 |
print("SmaLLMPro: ", end="", flush=True)
|
| 743 |
with torch.no_grad():
|
| 744 |
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
|
| 745 |
-
y = model.generate(x,
|
|
|
|
| 746 |
full_text = enc.decode(y[0].tolist())
|
| 747 |
|
| 748 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
print(response + "\n")
|
| 750 |
|
| 751 |
if __name__ == "__main__":
|
|
|
|
| 13 |
|
| 14 |
Welcome to SmaLLMPro 350M, our latest instruct model based on FineWeb-Edu.
|
| 15 |
|
| 16 |
+
# 1. Model Details
|
| 17 |
+
- **Parameters:** 353.55M
|
| 18 |
+
- **Layers:** 24
|
| 19 |
+
- **Heads:** 16
|
| 20 |
+
- **Embedding Dim:** 1024
|
| 21 |
+
- **Context Length:** 1024
|
| 22 |
+
- **Format:** ONNX (Opset 18)
|
| 23 |
|
| 24 |
# 2. Training Code
|
| 25 |
```python
|
|
|
|
| 710 |
import tiktoken
|
| 711 |
from model import GPTConfig, GPT
|
| 712 |
|
| 713 |
+
# --- Config ---
|
| 714 |
+
ckpt_path = '/media/leo/Data/checkpoints/350m_SmaLLMPro_Final/SmaLLMPro_iter_1500.pt'
|
| 715 |
device = 'cuda'
|
| 716 |
enc = tiktoken.get_encoding("gpt2")
|
| 717 |
|
| 718 |
+
print("Loading SmaLLMPro...")
|
| 719 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
| 720 |
gptconf = GPTConfig(**checkpoint['model_args'])
|
| 721 |
model = GPT(gptconf)
|
|
|
|
| 732 |
print(f"Model {ckpt_path} ready!\n")
|
| 733 |
|
| 734 |
def run_chat():
|
| 735 |
+
print("--- SmaLLMPro Chatbot (Type 'exit' to quit) ---")
|
| 736 |
|
| 737 |
while True:
|
| 738 |
user_input = input("You: ")
|
|
|
|
| 746 |
print("SmaLLMPro: ", end="", flush=True)
|
| 747 |
with torch.no_grad():
|
| 748 |
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
|
| 749 |
+
y = model.generate(x, max_new_tokens=500, temperature=0.65, top_k=25)
|
| 750 |
+
|
| 751 |
full_text = enc.decode(y[0].tolist())
|
| 752 |
|
| 753 |
+
if "Response:\n" in full_text:
|
| 754 |
+
response = full_text.split("Response:\n")[-1]
|
| 755 |
+
else:
|
| 756 |
+
response = full_text
|
| 757 |
+
|
| 758 |
+
response = response.split("<|endoftext|>")[0].split("Instruction:")[0].strip()
|
| 759 |
print(response + "\n")
|
| 760 |
|
| 761 |
if __name__ == "__main__":
|