---
license: mit
datasets:
- sixf0ur/nano_wiki
- sixf0ur/nano_chat
- sixf0ur/babylm_eng_distilled_1024
- sixf0ur/single_turn_chat
language:
- en
tags:
- gemma
---
# tiny-lm-chat
tiny-lm-chat is an ultra-lightweight, 14-million-parameter Small Language Model based on the Gemma architecture.
It has been trained from scratch for conversational efficiency within a compact 192-token context window.
It is designed for single-turn chat in very simple language (one input, one output).
## Evaluation Metrics
* **Final Training Loss:** 2.601
* **Validation Loss (Eval):** 2.687
* **Perplexity (PPL):** ~14.68
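As a quick sanity check on these numbers, perplexity is conventionally the exponential of the mean per-token cross-entropy loss. The sketch below assumes the reported losses are in nats and that the perplexity was derived from the validation loss; the helper name `perplexity` is illustrative and not part of this repository.

```python
import math


def perplexity(cross_entropy_loss: float) -> float:
    """Perplexity for a mean per-token cross-entropy loss measured in nats."""
    return math.exp(cross_entropy_loss)


val_ppl = perplexity(2.687)    # validation loss reported above
train_ppl = perplexity(2.601)  # final training loss reported above
print(f"validation: {val_ppl:.2f}, training: {train_ppl:.2f}")
```

The validation value lands within about 0.01 of the reported ~14.68, which is the kind of gap that rounding the loss to three decimals would produce.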
## Chat Examples (Zero-Shot)
💬 **User:** What is the meaning of life?
🤖 **Bot:** To find friends and help others.
💬 **User:** I cant find my keys.
🤖 **Bot:** Check your pockets or the door.
💬 **User:** Are you a human?
🤖 **Bot:** Nope, just a language.
### Recommended generation config
```
Context Length 192
Temperature 0.75
Top P 0.90
Repeat Penalty 1.10
```
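To give a feel for what the temperature and top-p values do, here is a toy sketch in plain Python. It is a generic illustration of temperature scaling and nucleus (top-p) filtering on made-up logits, not the exact code path of any particular runtime, and it leaves the repeat penalty out. The function names and the `toy_logits` values are assumptions chosen for the example.

```python
import math


def apply_temperature(logits, temperature):
    """Divide the logits by the temperature, then softmax into probabilities."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of most likely tokens whose mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # renormalise so the surviving probabilities sum to 1
    return {i: probs[i] / mass for i in kept}


toy_logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for four tokens
probs = apply_temperature(toy_logits, temperature=0.75)
nucleus = top_p_filter(probs, top_p=0.90)
print(sorted(nucleus))  # indices of the tokens that remain candidates
```

A temperature below 1 sharpens the distribution toward the most likely token, and top-p then discards the low-probability tail before sampling.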
## Usage for chat in terminal
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def main():
    REPO_ID = "sixf0ur/tiny-lm-chat"
    MAX_CONTEXT = 192

    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, torch_dtype=torch.float32)
    model.eval()

    # custom tokens used for interaction
    BOS_ID = tokenizer.convert_tokens_to_ids("<bos>")
    USER_ID = tokenizer.convert_tokens_to_ids("<user>")
    BOT_ID = tokenizer.convert_tokens_to_ids("<bot>")

    print("=" * 50 + "\n")
    while True:
        user_input = input("User: ").strip()
        if user_input.lower() in ["exit", "quit", "q"]:
            break
        if not user_input:
            continue

        # prompt layout: <bos><user> user text <bot>
        user_token_ids = tokenizer.encode(user_input, add_special_tokens=False)
        prompt_ids = [BOS_ID, USER_ID] + user_token_ids + [BOT_ID]
        prompt_len = len(prompt_ids)

        # whatever is left of the context window is the reply budget
        max_new_tokens = MAX_CONTEXT - prompt_len
        if max_new_tokens <= 5:
            print("Input too long!\n")
            continue

        input_ids = torch.tensor([prompt_ids])
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=0.75,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # drop the prompt tokens and decode only the generated reply
        generated_ids = outputs[0][prompt_len:]
        response = tokenizer.decode(generated_ids, skip_special_tokens=True)
        print(f"\nBot: {response.strip()}")
        print("\n" + "=" * 50)


if __name__ == "__main__":
    main()
```
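The prompt layout and token budget used in the script can be checked without downloading the model. The sketch below mirrors that arithmetic with placeholder token ids; the real ids come from the tokenizer, and the helper names `build_prompt_ids` and `reply_budget` are hypothetical.

```python
MAX_CONTEXT = 192  # context window stated in this README


def build_prompt_ids(user_token_ids, bos_id, user_id, bot_id):
    """Lay out a single-turn prompt as <bos><user> ... <bot>."""
    return [bos_id, user_id] + list(user_token_ids) + [bot_id]


def reply_budget(prompt_ids, max_context=MAX_CONTEXT):
    """Tokens left for the reply once the prompt is placed in the window."""
    return max_context - len(prompt_ids)


# Placeholder ids for illustration only; they are not the model's real ids.
prompt = build_prompt_ids([10, 11, 12], bos_id=1, user_id=2, bot_id=3)
budget = reply_budget(prompt)
print(prompt, budget)
```

Three special tokens are always spent on the prompt frame, so a user message can use at most 189 tokens before the budget reaches zero, and the script above already refuses inputs that leave 5 or fewer tokens for the reply.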
## LM Studio template (Jinja)
```
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<bos><user>' + message['content'].strip() + '<bot>' }}
{% elif message['role'] == 'assistant' %}
{{ message['content'].strip() + '<eos>' }}
{% endif %}
{% endfor %}
```
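For readers who want to see the string this template is meant to produce, here is a plain-Python equivalent. It is a sketch that assumes the whitespace between the Jinja tags is trimmed by the runtime; the function name `render_chat` is illustrative. As in the template, messages with any other role are skipped.

```python
def render_chat(messages):
    """Mimic the template: wrap user turns, close assistant turns with <eos>."""
    parts = []
    for message in messages:
        content = message["content"].strip()
        if message["role"] == "user":
            parts.append("<bos><user>" + content + "<bot>")
        elif message["role"] == "assistant":
            parts.append(content + "<eos>")
    return "".join(parts)


single_turn = render_chat([{"role": "user", "content": " Are you a human? "}])
print(single_turn)
```

Because the model is trained for single-turn chat, the typical rendered prompt is a single user turn ending in `<bot>`, which matches the `[BOS_ID, USER_ID] + ... + [BOT_ID]` layout in the terminal script.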