---
model-index:
- name: gemma-from-scratch
results: []
---
# My Gemma-like Model from Scratch
This model is a custom implementation of a Gemma-like architecture, trained from scratch.
## Training Details
- **Architecture**: An 18-layer decoder-only transformer with Grouped-Query Attention (GQA).
- **Data**: Trained on the Wikitext-2 dataset.
- **Training Script**: The training script is available on GitHub at [https://github.com/your_github_repo](https://github.com/your_github_repo).
- **Parameters**: Total trainable parameters: 330.64 million.
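With Grouped-Query Attention as configured here, all 8 query heads share a single key/value head (`num_key_value_heads: 1`), which shrinks the KV projection and cache. A minimal sketch of the core idea, broadcasting the KV heads to match the query heads (the shapes are illustrative, not taken from the training script):

```python
import torch

batch, seq_len, head_dim = 2, 16, 128
num_q_heads, num_kv_heads = 8, 1  # values from this model's config

q = torch.randn(batch, num_q_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Repeat each KV head so every query head has a matching key/value head
repeats = num_q_heads // num_kv_heads
k = k.repeat_interleave(repeats, dim=1)
v = v.repeat_interleave(repeats, dim=1)

# Standard scaled dot-product attention over the expanded heads
attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1) @ v
print(attn.shape)  # torch.Size([2, 8, 16, 128])
```

Only the projections for `k` and `v` are smaller; the attention computation itself is unchanged once the heads are repeated.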
### Checkpointing
The training script includes a checkpointing mechanism. It automatically saves the model's progress every 50 steps and at the end of each epoch to a file named `checkpoint.pt`. You can resume training by simply re-running the script. The final model is saved as `pytorch_model.bin`.
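The save/resume logic described above can be sketched as follows. This is a minimal illustration; the actual training script's checkpoint keys and structure may differ:

```python
import os
import torch

def save_checkpoint(model, optimizer, epoch, step, path="checkpoint.pt"):
    # Persist everything needed to resume training mid-run
    torch.save({
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
    }, path)

def load_checkpoint(model, optimizer, path="checkpoint.pt"):
    # Returns (epoch, step) to resume from, or (0, 0) if no checkpoint exists
    if not os.path.exists(path):
        return 0, 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"], ckpt["step"]
```

On startup the script checks for `checkpoint.pt`; if it exists, training picks up from the saved epoch and step, which is why simply re-running the script resumes training.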
### Early Stopping
To prevent overfitting, the training process includes early stopping based on the validation loss. The script will monitor the loss on a dedicated validation set and stop training if it does not improve for 2 consecutive epochs.
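The stopping rule above (no improvement for 2 consecutive epochs, often called "patience") can be sketched as a small helper. This is an illustration of the rule, not the script's exact implementation:

```python
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        # Call once per epoch; returns True when training should stop
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience
```

The training loop would call `step(val_loss)` after each validation pass and break out of the loop when it returns `True`.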
## Loading and Chatting with the Model
Since this model uses a custom architecture, the model class definitions from the training script must be available in your environment before the weights can be loaded.
Here's a step-by-step guide to get started:
1. **Install Required Libraries**:
```bash
pip install torch huggingface-hub tokenizers
```
2. **Copy the Model Architecture**:
Copy the `GemmaForCausalLM` and all its required sub-classes (`RMSNorm`, `RotaryPositionalEmbedding`, `MultiHeadAttention`, `MLP`, `TransformerBlock`) from this training script into your new Python file.
3. **Load the Model and Tokenizer**:
```python
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
# Define your model's hyperparameters
config = {
    "vocab_size": 30000,
    "hidden_size": 1024,
    "num_attention_heads": 8,
    "num_key_value_heads": 1,
    "num_layers": 18,
    "intermediate_size": 4096,
    "max_position_embeddings": 32768,
    "attention_dropout": 0.0,
    "hidden_dropout": 0.0,
    "sliding_window": 512,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
# Instantiate the custom model and load the weights
model = GemmaForCausalLM(config)
model_path = hf_hub_download(repo_id="your_username/gemma-from-scratch", filename="pytorch_model.bin")
model.load_state_dict(torch.load(model_path, map_location=config["device"]))
model.to(config["device"]).eval()
# Load the tokenizer
tokenizer_path = hf_hub_download(repo_id="your_username/gemma-from-scratch", filename="tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_path)
```
4. **Generate Text**:
```python
def generate_text(model, tokenizer, prompt, max_length=50):
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor(input_ids).unsqueeze(0).to(config["device"])
    eos_id = tokenizer.token_to_id("</s>")  # may be None if the vocab has no </s>
    with torch.no_grad():
        for _ in range(max_length):
            logits, _ = model(input_tensor)
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
            input_tensor = torch.cat([input_tensor, next_token], dim=-1)
            # Stop if we generate the end-of-sentence token
            if eos_id is not None and next_token.item() == eos_id:
                break
    return tokenizer.decode(input_tensor[0].tolist(), skip_special_tokens=True)
# Example usage
prompt = "The early bird catches the worm, but the second mouse gets the "
generated_text = generate_text(model, tokenizer, prompt)
print("Generated Text:")
print(generated_text)
```
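The `argmax` decoding above is greedy: it always picks the single most likely token, which tends to produce repetitive text. A common alternative (not part of the original script) is top-k sampling with a temperature, applied to the same last-step logits:

```python
import torch

def sample_next_token(next_token_logits, temperature=0.8, top_k=50):
    # Soften the distribution, keep only the top-k logits, then sample
    logits = next_token_logits / temperature
    topk_vals, topk_idx = torch.topk(logits, k=min(top_k, logits.size(-1)), dim=-1)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx.gather(-1, choice)  # shape (batch, 1)
```

To use it, replace the `torch.argmax(...)` line in `generate_text` with `next_token = sample_next_token(next_token_logits)`. Lower temperatures approach greedy decoding; higher ones increase diversity.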
> **Note**: This model is for demonstration purposes. Its custom architecture is not directly compatible with the Hugging Face `transformers` library out-of-the-box. To use the model, you must also include the full model class definitions in your script.