---
model-index:
- name: gemma-from-scratch
  results: []
---

# My Gemma-like Model from Scratch

This model is a custom implementation of a Gemma-like architecture, trained from scratch.

## Training Details

- **Architecture**: An 18-layer decoder-only transformer with Grouped-Query Attention (8 query heads sharing a single key/value head in the configuration shown in the loading section).
- **Data**: Trained on the Wikitext-2 dataset.
- **Training Script**: The training script is available on GitHub at [https://github.com/your_github_repo](https://github.com/your_github_repo).
- **Parameters**: Total trainable parameters: 330.64 million.

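The reported parameter count can be cross-checked against the hyperparameters listed in the loading section (vocabulary 30000, hidden size 1024, 8 query heads, 1 key/value head, 18 layers, intermediate size 4096). The sketch below is a back-of-the-envelope tally under assumptions this card does not state: a gated MLP with three projections, bias terms on every attention and MLP linear layer, two norm weight vectors per block plus a final one, and an untied, bias-free output head. The function name `count_parameters` is hypothetical. Under these assumptions the total rounds to 330.64 million, but the actual layer layout in the training script may differ.

```python
def count_parameters(vocab_size=30000, hidden_size=1024, num_attention_heads=8,
                     num_key_value_heads=1, num_layers=18, intermediate_size=4096):
    """Tally trainable parameters under the assumptions stated in the lead-in."""
    head_dim = hidden_size // num_attention_heads
    kv_dim = num_key_value_heads * head_dim  # width of the shared key/value projections

    def linear(n_in, n_out, bias=True):
        # Weight matrix plus an optional bias vector (assumed present by default)
        return n_in * n_out + (n_out if bias else 0)

    attention = (
        linear(hidden_size, hidden_size)      # query projection
        + 2 * linear(hidden_size, kv_dim)     # key and value projections
        + linear(hidden_size, hidden_size)    # output projection
    )
    mlp = (
        2 * linear(hidden_size, intermediate_size)  # gate and up projections (assumed gated MLP)
        + linear(intermediate_size, hidden_size)    # down projection
    )
    norms = 2 * hidden_size                   # two norm weight vectors per block
    block = attention + mlp + norms

    embedding = vocab_size * hidden_size
    lm_head = linear(hidden_size, vocab_size, bias=False)  # assumed untied from the embedding
    final_norm = hidden_size
    return num_layers * block + embedding + lm_head + final_norm


total = count_parameters()
print(f"{total / 1e6:.2f}M parameters")  # prints "330.64M parameters"
```
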
### Checkpointing

The training script includes a checkpointing mechanism. It saves the model's progress to `checkpoint.pt` every 50 steps and at the end of each epoch. To resume an interrupted run, re-run the script. The final model is saved as `pytorch_model.bin`.

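A save-and-resume loop of this kind could look roughly like the sketch below. It is not the training script's actual code: the helper names (`save_checkpoint`, `load_checkpoint`, `train`), the checkpoint dictionary layout, and the tiny stand-in model are all assumptions made for illustration. One simplification to note: a mid-epoch checkpoint records the current epoch, so resuming from it restarts that epoch from its first batch.

```python
import os
import tempfile

import torch
import torch.nn as nn

SAVE_EVERY = 50  # steps between periodic saves, as described above


def save_checkpoint(path, model, optimizer, epoch, step):
    # Store everything needed to continue training (hypothetical layout)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    # Start from scratch when no checkpoint file exists yet
    if not os.path.exists(path):
        return 0, 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"], state["step"]


def train(path, num_epochs, steps_per_epoch=120):
    model = nn.Linear(4, 1)  # stand-in for the real model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    start_epoch, step = load_checkpoint(path, model, optimizer)

    for epoch in range(start_epoch, num_epochs):
        for _ in range(steps_per_epoch):
            loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy objective
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            if step % SAVE_EVERY == 0:
                save_checkpoint(path, model, optimizer, epoch, step)
        # End-of-epoch save: the next run should begin at the following epoch
        save_checkpoint(path, model, optimizer, epoch + 1, step)
    return step


with tempfile.TemporaryDirectory() as tmp:
    checkpoint_path = os.path.join(tmp, "checkpoint.pt")
    train(checkpoint_path, num_epochs=1)                # first run covers epoch 0
    total_steps = train(checkpoint_path, num_epochs=2)  # second run resumes at epoch 1
```

Re-running `train` with the same path picks up the saved epoch and step counters, which mirrors the "re-run the script to resume" behaviour described above.
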
### Early Stopping

To prevent overfitting, training uses early stopping based on the validation loss. The script monitors the loss on a dedicated validation set and stops training if it has not improved for 2 consecutive epochs.

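The patience rule described above can be sketched in a few lines. The class name `EarlyStopping`, its method, and the validation losses are hypothetical and only illustrate the stopping logic; the training script may track improvement differently (for example with a minimum-delta threshold).

```python
class EarlyStopping:
    """Signal a stop once the validation loss fails to improve for `patience` epochs in a row."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, val_loss):
        if val_loss < self.best_loss:
            # New best: remember it and reset the counter
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up validation losses, one per epoch
val_losses = [3.2, 2.9, 2.7, 2.8, 2.75, 2.6]

stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.should_stop(val_loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_loss)  # prints "4 2.7"
```

In this example the loss stops improving after epoch 2, so training halts at epoch 4 and the later value of 2.6 is never reached. That is the trade-off of a short patience window.
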
## Loading and Chatting with the Model

Since this model uses a custom architecture, loading it requires the model class definitions from the training script.

Here's a step-by-step guide to get started:

1. **Install Required Libraries**:

```bash
pip install torch huggingface-hub tokenizers
```

2. **Copy the Model Architecture**:

Copy the `GemmaForCausalLM` class and all of its required sub-classes (`RMSNorm`, `RotaryPositionalEmbedding`, `MultiHeadAttention`, `MLP`, `TransformerBlock`) from the training script into your new Python file.

3. **Load the Model and Tokenizer**:

```python
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# GemmaForCausalLM and its sub-classes must already be defined in this file (step 2)

# Define your model's hyperparameters
config = {
    "vocab_size": 30000,
    "hidden_size": 1024,
    "num_attention_heads": 8,
    "num_key_value_heads": 1,
    "num_layers": 18,
    "intermediate_size": 4096,
    "max_position_embeddings": 32768,
    "attention_dropout": 0.0,
    "hidden_dropout": 0.0,
    "sliding_window": 512,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# Instantiate the custom model and load the weights
model = GemmaForCausalLM(config)
model_path = hf_hub_download(repo_id="your_username/gemma-from-scratch", filename="pytorch_model.bin")
model.load_state_dict(torch.load(model_path, map_location=config["device"]))
model.to(config["device"]).eval()

# Load the tokenizer
tokenizer_path = hf_hub_download(repo_id="your_username/gemma-from-scratch", filename="tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_path)
```

4. **Generate Text**:

```python
def generate_text(model, tokenizer, prompt, max_length=50):
    # max_length is the maximum number of new tokens to generate
    device = next(model.parameters()).device  # run on the device that holds the weights
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
    eos_token_id = tokenizer.token_to_id("</s>")  # None if the vocabulary has no "</s>"

    with torch.no_grad():
        for _ in range(max_length):
            logits, _ = model(input_tensor)
            next_token_logits = logits[:, -1, :]
            # Greedy decoding; keepdim=True keeps shape (1, 1) for concatenation
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            input_tensor = torch.cat([input_tensor, next_token], dim=-1)

            # Stop if we generate the end-of-sequence token
            if next_token.item() == eos_token_id:
                break

    return tokenizer.decode(input_tensor[0].tolist(), skip_special_tokens=True)

# Example usage
prompt = "The early bird catches the worm, but the second mouse gets the "
generated_text = generate_text(model, tokenizer, prompt)
print("Generated Text:")
print(generated_text)
```

> **Note**: This model is for demonstration purposes. Its custom architecture is not compatible with the Hugging Face `transformers` library out of the box, so your script must include the full model class definitions in order to use the model.