---
license: mit
tags:
- split-learning
- gpt2
- federated-learning
- lora
---

# DisLLM GPT-2 Split Learning - Client Model

This repository contains the **client-side model** (the first 4 layers) for a split learning implementation of GPT-2 Small with LoRA fine-tuning.

## Model Details

- **Architecture**: GPT-2 Small (124M parameters)
- **Split Configuration**: First 4 layers out of 12 transformer blocks
- **LoRA Parameters**: 1,735,304 trainable parameters
- **Training Method**: Federated Split Learning
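
The exact adapter configuration behind the 1,735,304 trainable parameters (rank, alpha, target modules) is not stated on this card. As a minimal sketch only, here is how a LoRA setup for GPT-2 Small typically looks with the `peft` library; the hyperparameter values below are assumptions and will not necessarily reproduce the exact count above.

```python
from peft import LoraConfig, get_peft_model
from transformers import GPT2LMHeadModel

# Hypothetical LoRA configuration; the actual rank/alpha/target modules
# used for this checkpoint are not documented on this card.
base = GPT2LMHeadModel.from_pretrained("gpt2")
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["c_attn"],  # GPT-2's fused QKV projection
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # prints the trainable-parameter count
```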

## Performance

- **Training PPL**: 30.64
- **Validation PPL**: 27.03
- **Test PPL**: 29.75
- **Training Epochs**: 5
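
Perplexity (PPL) is the exponential of the mean per-token cross-entropy loss, so the figures above map directly back to loss values. A quick illustration (the loss value below is made up to roughly match the validation PPL):

```python
import math

# Perplexity is exp(mean cross-entropy per token, in nats).
mean_loss = 3.30            # illustrative value, not a reported number
ppl = math.exp(mean_loss)
print(f"PPL = {ppl:.2f}")   # ~27.1, close to the validation PPL above
```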

## Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Download the client checkpoint from the Hub
model_path = hf_hub_download(repo_id="Chandij123/websockets", filename="client_model.pth")

# Load the checkpoint onto CPU; move tensors to GPU after loading if needed
checkpoint = torch.load(model_path, map_location="cpu")
model_state = checkpoint['model_state_dict']
config = checkpoint['model_config']

# Initialize your FirstPartModel (the first 4 GPT-2 blocks) and load the weights:
# first_part_model = FirstPartModel(config)
# first_part_model.load_state_dict(model_state)
```

## Split Learning Architecture

This model works in conjunction with a server-side model that contains the remaining layers; a minimal sketch of the hand-off follows the list.

- **Client**: processes input through the first 4 layers
- **Server**: continues processing through the remaining 8 layers
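
As a rough illustration, the sketch below splits a stock GPT-2 from `transformers` into a 4-block client stage and an 8-block server stage. The in-process "transfer" of the hidden states is an assumption for demonstration; this repository's actual `FirstPartModel`, server classes, and websocket transport are not shown here.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Illustrative split of a stock GPT-2 into a 4-block "client" and an
# 8-block "server"; not this repository's exact implementation.
full = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

ids = tokenizer("Split learning keeps raw text on the client.", return_tensors="pt")["input_ids"]

# --- Client side: token + position embeddings, then the first 4 blocks ---
positions = torch.arange(ids.size(1)).unsqueeze(0)
hidden = full.transformer.wte(ids) + full.transformer.wpe(positions)
for block in full.transformer.h[:4]:
    hidden = block(hidden)[0]

# `hidden` is the only tensor that would travel over the wire
# (e.g., via the websocket link this repo is named after).

# --- Server side: remaining 8 blocks, final layer norm, LM head ---
for block in full.transformer.h[4:]:
    hidden = block(hidden)[0]
hidden = full.transformer.ln_f(hidden)
logits = full.lm_head(hidden)
print(logits.shape)  # (1, seq_len, vocab_size)
```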

## Training Details

Dataset: WikiText-2

- Training: 2,359 examples
- Validation: 243 examples
- Test: 279 examples
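
The example counts above suggest custom preprocessing; the card does not specify how WikiText-2 was filtered or chunked. A minimal starting point, assuming the raw WikiText-2 configuration from the `datasets` library:

```python
from datasets import load_dataset

# Load WikiText-2. How the 2,359 / 243 / 279 examples above were derived
# (filtering, 1024-token chunking, etc.) is not specified on this card,
# so this only shows the raw splits.
wikitext = load_dataset("wikitext", "wikitext-2-raw-v1")
print(wikitext)  # DatasetDict with 'train', 'validation', 'test' splits
```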

Training configuration:

- Batch size: 2
- Context length: 1024
- Learning rate: 1e-6
- Optimizer: AdamW
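
A minimal sketch of wiring these hyperparameters up, assuming only the trainable (LoRA) parameters of the client model are handed to AdamW; `client_model` is a placeholder, and settings not listed above (weight decay, scheduler, etc.) are unknown:

```python
import torch
import torch.nn as nn

# Hyperparameters reported above.
BATCH_SIZE = 2
CONTEXT_LENGTH = 1024
LEARNING_RATE = 1e-6

# Placeholder; in practice this is the loaded FirstPartModel with LoRA adapters.
client_model = nn.Linear(768, 768)

# Optimize only parameters that require gradients (the LoRA weights).
optimizer = torch.optim.AdamW(
    (p for p in client_model.parameters() if p.requires_grad),
    lr=LEARNING_RATE,
)
```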

## License

MIT License
|