AxionLab-official commited on
Commit
b73274c
·
verified ·
1 Parent(s): 93ca81b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +46 -6
README.md CHANGED
@@ -129,17 +129,57 @@ This reinforces the distinction between:
129
 
130
  ```python
131
  import torch
132
- from safetensors.torch import load_file
133
- from model import NanoThink
134
  from tokenizers import Tokenizer
 
 
 
 
 
135
 
136
- tokenizer = Tokenizer.from_file("tokenizer.json")
137
 
138
- model = NanoThink(vocab_size=1229)
139
- state_dict = load_file("model.safetensors")
140
- model.load_state_dict(state_dict)
141
 
 
 
142
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  ```
144
 
145
  ---
 
129
 
130
  ```python
131
  import torch
 
 
132
  from tokenizers import Tokenizer
133
+ from model import NanoThink
134
+ from safetensors.torch import load_file
135
+
136
+ MODEL_PATH = "model.safetensors"
137
+ TOKENIZER_PATH = "tokenizer.json"
138
 
 
139
 
140
+ tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
 
 
141
 
142
+ model = NanoThink(vocab_size=tokenizer.get_vocab_size())
143
+ model.load_state_dict(load_file(MODEL_PATH))
144
  model.eval()
145
+
146
+ history = ""
147
+
148
+ while True:
149
+ user_input = input("You: ")
150
+
151
+ if user_input.lower() in ["get out", "exit", "quit"]:
152
+ break
153
+
154
+ prompt = history + f"\n<USER>\n{user_input}\n</USER>\n"
155
+
156
+ input_ids = torch.tensor([tokenizer.encode(prompt).ids])
157
+
158
+ output_ids = []
159
+
160
+ for _ in range(120):
161
+ logits = model(input_ids)
162
+ next_token = torch.multinomial(torch.softmax(logits[0, -1], dim=-1), 1).item()
163
+
164
+ input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
165
+ output_ids.append(next_token)
166
+
167
+ text = tokenizer.decode(output_ids)
168
+
169
+ if "</ANSWER>" in text:
170
+ break
171
+
172
+ output = tokenizer.decode(output_ids)
173
+
174
+
175
+ if "<ANSWER>" in output:
176
+ output = output.split("<ANSWER>")[1].split("</ANSWER>")[0]
177
+
178
+ print("\n💬 Answer:")
179
+ print(output.strip())
180
+ print("\n" + "-"*50 + "\n")
181
+
182
+ history += f"\n<USER>\n{user_input}\n</USER>\n<ANSWER>\n{output.strip()}\n</ANSWER>\n"
183
  ```
184
 
185
  ---