---
library_name: transformers
model_name: Asterisk
base_model: HuggingFaceTB/SmolLM2-135M-Instruct
tags:
- aspp
- hybrid-architecture
- graph-reasoning
- sft
- trl
license: apache-2.0
language:
- en
---

# Asterisk: Hybrid ASPP-Attention Architecture

**Asterisk** is a research implementation that combines the **ASPP (Adjacency-Structured Parallel Propagation)** operator with standard attention to enhance SmolLM2-135M. Each hybrid layer fuses graph-based local reasoning (ASPP) with global attention, improving expressiveness on structured reasoning tasks.

## Model Description

- **Base Model**: [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct)
- **Architecture**: Hybrid ASPP-Attention (30 hybrid layers)
- **Parameters**: 171.2M total (≈35M ASPP parameters added to the 135M base)
- **Training**: Supervised fine-tuning (SFT) on the Capybara dataset
- **Framework**: Transformers 4.57.6, TRL 0.27.0

## Evaluation Results

Evaluated with the LM Evaluation Harness. These are preliminary results obtained with per-task sample limits; full evaluation is pending.

| Task | Metric | Score | Stderr |
|------|--------|-------|--------|
| **HellaSwag** | acc_norm | **0.4430** | ±0.0157 |
| **ARC-Easy** | acc_norm | **0.5450** | ±0.0158 |
| **ARC-Challenge** | acc_norm | **0.2884** | ±0.0132 |
| **PIQA** | acc_norm | **0.6770** | ±0.0148 |
| **WinoGrande** | acc | **0.5210** | ±0.0158 |
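
For reference, a hedged sketch of how scores like these could be reproduced with the harness's Python API; the model path and the sample limit below are assumptions, not the exact settings used for the table.

```python
# Minimal sketch using lm-evaluation-harness (pip install lm-eval).
from lm_eval import simple_evaluate

results = simple_evaluate(
    model="hf",
    model_args="pretrained=path/to/Asterisk,trust_remote_code=True,dtype=bfloat16",
    tasks=["hellaswag", "arc_easy", "arc_challenge", "piqa", "winogrande"],
    limit=1000,  # per-task sample limit (assumed; the exact limit is undocumented)
)
for task, metrics in results["results"].items():
    print(task, metrics)
```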

### Key Innovation: The Asterisk Operator (★-operator)

The **Asterisk Operator** performs local parallel state evolution through point-wise transformations:

```
h_i^(t+1) = LayerNorm(h_i^(t) + α · φ(h_i^(t)))    [repeated for K steps]
```

The operator's output is then gated and fused with the standard Llama attention output (pseudocode and a runnable sketch follow in the Architecture section):

```
output = gate * ASPP(x) + (1 - gate) * Attention(x)
```

## Architecture

### 1. ASPPOperator (Point-wise Parallel Propagation)

```python
class ASPPOperator:
    """
    Forward pass:
    1. Optional dimensionality reduction: h_t = down_proj(hidden_states)
    2. K-step evolution: h_t = h_t + α * φ(h_t)   [K times]
    3. Layer normalization after each step
    4. Optional projection back: output = up_proj(h_t)

    Parameters:
    - hidden_size: 576 (model dimension)
    - aspp_hidden_dim: 256 (internal ASPP dimension)
    - aspp_num_steps: 8 (evolution iterations)
    - aspp_dropout: 0.2
    """
```

**Pseudocode:**

```
function ASPP(hidden_states):
    # Optional dimensionality reduction
    if use_projection:
        h_t ← down_proj(hidden_states)
        h_t ← dropout(h_t)
    else:
        h_t ← hidden_states

    # Learnable number of steps
    k_steps ← max(1, int(sigmoid(k_logit) * num_steps))

    # K-step point-wise evolution
    for t = 1 to k_steps:
        # Point-wise update: φ(h_t) = MLP(h_t)
        h_t_next ← update_net(h_t)

        # Scaled residual connection
        h_t ← h_t + residual_scale * h_t_next
        h_t ← layer_norm(h_t)

    # Project back to original dimension
    if use_projection:
        h_t ← up_proj(h_t)
        h_t ← dropout(h_t)

    return h_t
```
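
The pseudocode above condenses into a runnable PyTorch sketch. This is a minimal reconstruction under stated assumptions (in particular, `update_net` as a two-layer MLP and the `residual_scale` initialization); the released `AsteriskForCausalLM.py` is authoritative.

```python
import torch
import torch.nn as nn

class ASPPOperator(nn.Module):
    """Minimal sketch; details of the released implementation may differ."""

    def __init__(self, hidden_size=576, aspp_hidden_dim=256, num_steps=8, dropout=0.2):
        super().__init__()
        self.use_projection = aspp_hidden_dim != hidden_size
        inner = aspp_hidden_dim if self.use_projection else hidden_size
        if self.use_projection:
            self.down_proj = nn.Linear(hidden_size, inner)
            self.up_proj = nn.Linear(inner, hidden_size)
        self.update_net = nn.Sequential(          # φ: assumed two-layer point-wise MLP
            nn.Linear(inner, inner),
            nn.GELU(),
            nn.Linear(inner, inner),
        )
        self.layer_norm = nn.LayerNorm(inner)
        self.dropout = nn.Dropout(dropout)
        self.k_logit = nn.Parameter(torch.zeros(1))             # learnable step count
        self.residual_scale = nn.Parameter(torch.tensor(0.1))   # assumed init value
        self.num_steps = num_steps

    def forward(self, hidden_states):
        h = hidden_states
        if self.use_projection:
            h = self.dropout(self.down_proj(h))
        # Learnable number of evolution steps in [1, num_steps]
        k_steps = max(1, int(torch.sigmoid(self.k_logit).item() * self.num_steps))
        for _ in range(k_steps):
            # Scaled residual update followed by layer norm
            h = self.layer_norm(h + self.residual_scale * self.update_net(h))
        if self.use_projection:
            h = self.dropout(self.up_proj(h))
        return h

# Shape check: (batch, seq, hidden) in -> (batch, seq, hidden) out
x = torch.randn(2, 16, 576)
print(ASPPOperator()(x).shape)  # torch.Size([2, 16, 576])
```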

### 2. HybridASPPAttentionLayer

```python
class HybridASPPAttentionLayer(LlamaDecoderLayer):
    """
    Extends LlamaDecoderLayer with a parallel ASPP branch.

    Architecture:
    1. Input LayerNorm
    2. Parallel branches:
       - ASPP operator for local structured reasoning
       - Standard LlamaAttention for global context
    3. Gated fusion: gate * ASPP + (1 - gate) * Attention
    4. Residual connection
    5. Feed-forward MLP
    """
```

**Pseudocode:**

```
function HybridLayer(hidden_states, attention_mask, ...):
    residual ← hidden_states
    hidden_states ← input_layernorm(hidden_states)

    # Parallel branches
    aspp_output ← aspp_operator(hidden_states)
    attn_output ← self_attention(hidden_states, attention_mask, ...)

    # Gated fusion
    fusion_input ← concat([aspp_output, attn_output])
    gate ← sigmoid(linear(dropout(fusion_input)))
    fused_output ← gate * aspp_output + (1 - gate) * attn_output

    # Residual connection
    hidden_states ← residual + fused_output

    # MLP block
    residual ← hidden_states
    hidden_states ← post_attention_layernorm(hidden_states)
    hidden_states ← mlp(hidden_states)
    hidden_states ← residual + hidden_states

    return hidden_states
```
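
The gated-fusion step at the heart of this layer also reduces to a short runnable sketch. `gate_proj` is a hypothetical name for the fusion linear, and a per-feature gate is an assumption; the released code may use a different granularity.

```python
import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    def __init__(self, hidden_size=576, dropout=0.2):
        super().__init__()
        # Maps concat([ASPP, Attention]) to a gate in (0, 1)
        self.gate_proj = nn.Linear(2 * hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, aspp_output, attn_output):
        fusion_input = torch.cat([aspp_output, attn_output], dim=-1)
        gate = torch.sigmoid(self.gate_proj(self.dropout(fusion_input)))
        # gate near 1 favors the local ASPP branch; near 0, the attention branch
        return gate * aspp_output + (1 - gate) * attn_output

fusion = GatedFusion()
a, b = torch.randn(2, 16, 576), torch.randn(2, 16, 576)
print(fusion(a, b).shape)  # torch.Size([2, 16, 576])
```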

### 3. AsteriskForCausalLM

```python
class AsteriskForCausalLM(LlamaForCausalLM):
    """
    Main model class with custom model_type "asterisk".

    Configuration:
    - hybrid_layer_indices: None (all 30 layers are hybrid)
    - aspp_hidden_dim: 256 (reduces overfitting)
    - aspp_num_steps: 8 (learnable; actual steps ≈ 6)
    - aspp_dropout: 0.2
    """
```

## Quick Start

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "path/to/Asterisk",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("path/to/Asterisk")

# Generate text
messages = [{"role": "user", "content": "Explain quantum computing in simple terms."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(
    inputs,
    max_new_tokens=256,
    temperature=0.7,
    do_sample=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Training Details

### Training Configuration

- **Dataset**: Capybara (conversational instruction following)
- **Optimizer**: AdamW (lr=2e-5, weight_decay=0.01)
- **Batch Size**: 4 per device, gradient accumulation 4 (effective batch 16)
- **Epochs**: 2
- **Scheduler**: Cosine with warmup (100 steps)
- **Mixed Precision**: bfloat16
- **Gradient Checkpointing**: Enabled
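
A hedged sketch of this setup with TRL's `SFTTrainer`; the dataset id (`LDJnr/Capybara`) and any mapping into the chat format TRL expects are assumptions, while the hyperparameters mirror the list above.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

# Hybrid model to fine-tune; see "Model Creation from Base" below for how it
# is built from SmolLM2 (the local path here is a placeholder).
model = AutoModelForCausalLM.from_pretrained("path/to/Asterisk", trust_remote_code=True)

# Assumed dataset id; Capybara may need preprocessing into TRL's chat format.
dataset = load_dataset("LDJnr/Capybara", split="train")

config = SFTConfig(
    output_dir="asterisk-sft",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,   # effective batch size 16
    num_train_epochs=2,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    bf16=True,
    gradient_checkpointing=True,
)

trainer = SFTTrainer(model=model, args=config, train_dataset=dataset)
trainer.train()
```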

### ASPP Configuration

```python
aspp_hidden_dim = 256        # Internal dimension (vs 576 model hidden_size)
aspp_num_steps = 8           # Max evolution steps (learnable)
aspp_dropout = 0.2           # Regularization
hybrid_layer_indices = None  # All 30 layers
```

## Model Creation from Base

```python
import torch

from AsteriskForCausalLM import AsteriskForCausalLM

# Create an Asterisk model from the SmolLM2 base
model, base_model = AsteriskForCausalLM.from_pretrained_base(
    "HuggingFaceTB/SmolLM2-135M-Instruct",
    hybrid_layer_indices=None,   # None = all layers
    aspp_hidden_dim=256,         # Internal ASPP dimension
    aspp_num_steps=8,            # K-step evolution
    aspp_dropout=0.2,            # Dropout rate
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Transfer the base model weights; ASPP parameters remain randomly
# initialized (strict=False skips the keys the base model does not have)
model.load_state_dict(base_model.state_dict(), strict=False)
```

## Theoretical Background

### Universality (Theorem 2.1)

ASPP can simulate any Message-Passing Neural Network (MPNN) function on finite graphs in D steps, where D is the graph diameter.
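
As a toy illustration of the diameter bound (not the model code): one round of adjacency-based propagation moves information a single hop, so a feature planted at one end of a path graph reaches every node after exactly D steps.

```python
import torch

n = 6                                  # path graph with diameter D = 5
adj = torch.zeros(n, n)
for i in range(n - 1):
    adj[i, i + 1] = adj[i + 1, i] = 1.0

h = torch.zeros(n)
h[0] = 1.0                             # plant a feature at node 0
for step in range(1, n):
    h = torch.clamp(h + adj @ h, max=1.0)   # one propagation round
    print(step, (h > 0).sum().item(), "nodes reached")
# After 5 steps (the diameter), all 6 nodes are reached.
```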

### Convergence (Theorem 2.2)

Under a Lipschitz-continuity assumption, ASPP iterates converge exponentially to a fixed point, with contraction rate c = 0.76.
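
A toy numeric illustration (again, not the model code): iterating a c-Lipschitz contraction with c = 0.76 shrinks the distance to the fixed point by that factor at every step.

```python
c = 0.76                     # contraction rate from Theorem 2.2
phi = lambda h: c * h        # a c-Lipschitz map with fixed point h* = 0
h = 1.0
for t in range(1, 11):
    h = phi(h)
    print(t, f"|h - h*| = {h:.4f}")   # decays geometrically, like 0.76**t
```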

### Turing Completeness

Proven via simulation of cyclic tag systems: given sufficient depth, ASPP can compute any Turing-computable function.

**Implementation Note**: This implementation simplifies the theoretical ASPP operator to point-wise evolution, reducing overfitting while retaining the benefits of iterative refinement.

## Files in Checkpoint

```
Asterisk/
├── AsteriskForCausalLM.py    # Model implementation (required for trust_remote_code)
├── config.json               # Model configuration with auto_map
├── model.safetensors         # Model weights
├── tokenizer.json            # Tokenizer
├── generation_config.json    # Generation settings
└── README.md                 # This file
```

## Dependencies

```bash
pip install "torch>=2.0.0"
pip install "transformers>=4.40.0"
pip install "trl>=0.8.0"
pip install "datasets>=2.14.0"
pip install "accelerate>=0.25.0"
pip install bitsandbytes
```

## Citations

If you use this model, please cite:

```bibtex
@misc{asterisk2026,
  title={Asterisk: Hybrid ASPP-Attention Architecture for Enhanced Language Modeling},
  author={NoesisLab},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/NoesisLab/Asterisk}
}
```

```bibtex
@misc{vonwerra2022trl,
  title={{TRL: Transformer Reinforcement Learning}},
  author={Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  year={2020},
  journal={GitHub repository},
  publisher={GitHub},
  howpublished={\url{https://github.com/huggingface/trl}}
}
```

```bibtex
@article{allal2024SmolLM2,
  title={SmolLM2 - with great data, comes great performance},
  author={Allal, Loubna Ben and Lozhkov, Anton and Penedo, Guilherme and Wolf, Thomas and von Werra, Leandro},
  year={2024}
}
```

## License

This model inherits the Apache 2.0 license from SmolLM2-135M-Instruct.

## Framework Versions

- **TRL**: 0.27.0
- **Transformers**: 4.57.6
- **PyTorch**: 2.8.0+cu128
- **Datasets**: 4.5.0
- **Tokenizers**: 0.22.2

## Acknowledgments

Built on top of [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) by Hugging Face. Training framework powered by [TRL](https://github.com/huggingface/trl).