Update modeling_tinylm.py
Browse files- modeling_tinylm.py +2 -2
modeling_tinylm.py
CHANGED
|
@@ -89,7 +89,7 @@ def load_tinylm(model_dir, device="cpu"):
|
|
| 89 |
return model, tokenizer, config
|
| 90 |
|
| 91 |
|
| 92 |
-
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.
|
| 93 |
MAX_SEQ_LEN = model.pos_emb.num_embeddings
|
| 94 |
model.eval()
|
| 95 |
ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
|
@@ -114,4 +114,4 @@ def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_
|
|
| 114 |
if __name__ == "__main__":
|
| 115 |
model, tokenizer, config = load_tinylm("./tinylm")
|
| 116 |
print("Model loaded!")
|
| 117 |
-
print("Use 'module.generate(model, tokenizer, \"Once upon a time
|
|
|
|
| 89 |
return model, tokenizer, config
|
| 90 |
|
| 91 |
|
| 92 |
+
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.1, top_k=25, device="cpu"):
|
| 93 |
MAX_SEQ_LEN = model.pos_emb.num_embeddings
|
| 94 |
model.eval()
|
| 95 |
ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
|
|
|
| 114 |
if __name__ == "__main__":
|
| 115 |
model, tokenizer, config = load_tinylm("./tinylm")
|
| 116 |
print("Model loaded!")
|
| 117 |
+
print("Use 'module.generate(model, tokenizer, \"Once upon a time\")' to generate.")
|