Update README.md
Browse files
README.md
CHANGED
|
@@ -129,17 +129,57 @@ This reinforces the distinction between:
|
|
| 129 |
|
| 130 |
```python
|
| 131 |
import torch
|
| 132 |
-
from safetensors.torch import load_file
|
| 133 |
-
from model import NanoThink
|
| 134 |
from tokenizers import Tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
tokenizer = Tokenizer.from_file("tokenizer.json")
|
| 137 |
|
| 138 |
-
|
| 139 |
-
state_dict = load_file("model.safetensors")
|
| 140 |
-
model.load_state_dict(state_dict)
|
| 141 |
|
|
|
|
|
|
|
| 142 |
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
```
|
| 144 |
|
| 145 |
---
|
|
|
|
| 129 |
|
| 130 |
```python
|
| 131 |
import torch
from tokenizers import Tokenizer
from model import NanoThink
from safetensors.torch import load_file

MODEL_PATH = "model.safetensors"
TOKENIZER_PATH = "tokenizer.json"

# Load tokenizer first: the model's vocab size is derived from it.
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

model = NanoThink(vocab_size=tokenizer.get_vocab_size())
model.load_state_dict(load_file(MODEL_PATH))
model.eval()

# Running conversation transcript, replayed as context on every turn.
history = ""

while True:
    user_input = input("You: ")

    if user_input.lower() in ["get out", "exit", "quit"]:
        break

    prompt = history + f"\n<USER>\n{user_input}\n</USER>\n"

    input_ids = torch.tensor([tokenizer.encode(prompt).ids])

    output_ids = []

    # Inference only: without no_grad() each of the (up to) 120 forward
    # passes and torch.cat calls would extend one ever-growing autograd
    # graph, leaking memory across the generation loop.
    with torch.no_grad():
        for _ in range(120):
            logits = model(input_ids)
            # Sample the next token from the softmax over the last position.
            next_token = torch.multinomial(torch.softmax(logits[0, -1], dim=-1), 1).item()

            input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
            output_ids.append(next_token)

            # Stop as soon as the closing answer tag appears. Only the tail
            # of the output can complete the tag, so decoding the last few
            # tokens is enough (avoids re-decoding everything every step).
            if "</ANSWER>" in tokenizer.decode(output_ids[-16:]):
                break

    output = tokenizer.decode(output_ids)

    # Keep only the text between the answer tags, if the model emitted them.
    if "<ANSWER>" in output:
        output = output.split("<ANSWER>")[1].split("</ANSWER>")[0]

    print("\n💬 Answer:")
    print(output.strip())
    print("\n" + "-"*50 + "\n")

    history += f"\n<USER>\n{user_input}\n</USER>\n<ANSWER>\n{output.strip()}\n</ANSWER>\n"
|
| 183 |
```
|
| 184 |
|
| 185 |
---
|