# MentorCollab MLP Models
Branch prediction MLP models for collaborative decoding between different-sized language models.
## Overview
This repository contains trained MLP (Multi-Layer Perceptron) models that predict which branch to take during collaborative decoding. Each model is trained for a specific base LLM and task type.
## Available Models
| Base Model | HuggingFace Model | Hidden Size | Tasks |
|---|---|---|---|
| Qwen3_1.7B | Qwen/Qwen3-1.7B | 1536 | Math, General |
| qwen3_8B_base | Qwen/Qwen3-8B-Base | 4096 | Math, General |
| llama3_1_8B | meta-llama/Llama-3.1-8B | 4096 | Math, General |
| llama3_2_3B | meta-llama/Llama-3.2-3B | 3072 | Math, General |
| gemma3_4b_it | google/gemma-3-4b-it | 3072 | Math, General |
| gemma3_4b_pt | google/gemma-3-4b-pt | 3072 | Math, General |
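The hidden sizes in the table determine the input width each MLP expects. A small lookup table (illustrative only, not part of the repo's code) makes it easy to build correctly shaped dummy inputs for testing:

```python
import torch

# Hidden sizes copied from the table above; keys match the repo's directory names.
HIDDEN_SIZES = {
    "Qwen3_1.7B": 1536,
    "qwen3_8B_base": 4096,
    "llama3_1_8B": 4096,
    "llama3_2_3B": 3072,
    "gemma3_4b_it": 3072,
    "gemma3_4b_pt": 3072,
}

def dummy_hidden_state(base_model: str) -> torch.Tensor:
    """Random stand-in hidden state with the right width for `base_model`."""
    return torch.randn(1, HIDDEN_SIZES[base_model])

print(dummy_hidden_state("llama3_2_3B").shape)  # torch.Size([1, 3072])
```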
## Quick Start

### Installation

```bash
pip install torch huggingface_hub
```
### Download the Model Loader

```python
from huggingface_hub import hf_hub_download

# Download the model loader script from the repo
loader_file = hf_hub_download(
    repo_id="haojinw0027/MentorCollab-MLP",
    filename="model_loader.py",
)
```
### Load a Model

```python
import torch

from model_loader import load_branch_mlp

# Load the MLP trained for Qwen3-8B-Base on Math tasks
mlp_model = load_branch_mlp(
    base_model="qwen3_8B_base",
    task="Math",
)

# Use it for inference
with torch.no_grad():
    hidden_state = torch.randn(1, 4096)  # example hidden state (4096 = Qwen3-8B hidden size)
    score = mlp_model(hidden_state)
    branch = 'A' if score.item() > 0.5 else 'B'
    print(f"Predicted branch: {branch}")
```
### List All Available Models

```python
from model_loader import list_available_models

list_available_models()
```
## Model Details

### Architecture

Each MLP model has the following architecture:
- Input: Hidden state from the base LLM (dimension varies by model)
- Hidden Layer: 4096 units with ReLU activation
- Output: Single sigmoid output (probability of choosing branch A)
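The architecture above can be sketched in PyTorch as follows. This is a minimal reconstruction from the description, not the repo's actual class; the names `BranchMLP` and `net` are illustrative:

```python
import torch
import torch.nn as nn

class BranchMLP(nn.Module):
    """Sketch of the branch-prediction MLP described above.

    Layer sizes follow the README: input = base-LLM hidden size,
    one 4096-unit hidden layer with ReLU, a single sigmoid output.
    """
    def __init__(self, input_dim: int, hidden_dim: int = 4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns the probability of choosing branch A, in (0, 1)
        return self.net(x)

mlp = BranchMLP(input_dim=4096)          # e.g. for qwen3_8B_base
score = mlp(torch.randn(1, 4096))
print(score.shape)                        # torch.Size([1, 1])
```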
### Training
Models are trained to predict the optimal branch choice at decision points during collaborative decoding between a smaller "student" model and a larger "mentor" model.
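The repo does not ship training code, but under the description above the objective is binary classification of hidden states. A minimal sketch of one training step, assuming labels where 1 means "branch A was optimal" (all names and hyperparameters here are illustrative):

```python
import torch
import torch.nn as nn

# Hypothetical MLP standing in for a branch model (input width 4096)
mlp = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, 1),
    nn.Sigmoid(),
)

# Toy batch: hidden states captured at decision points, with binary labels
hidden_states = torch.randn(8, 4096)
labels = torch.randint(0, 2, (8, 1)).float()  # 1 = branch A was the better choice

criterion = nn.BCELoss()  # binary cross-entropy on the sigmoid output
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

optimizer.zero_grad()
loss = criterion(mlp(hidden_states), labels)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```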
## Usage Example

### Complete Inference Pipeline

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from model_loader import load_branch_mlp

# Load the base model
base_model_name = "Qwen/Qwen3-8B-Base"
model = AutoModelForCausalLM.from_pretrained(base_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Load the branch-prediction MLP
mlp = load_branch_mlp(
    base_model="qwen3_8B_base",
    task="Math",
)

# Extract the hidden state at the branch point
prompt = "Solve: 2 + 2 = ?"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    hidden_state = outputs.hidden_states[-1][0, -1, :]  # last layer, last token

# Predict the branch
score = mlp(hidden_state.unsqueeze(0).float())
branch = 'A' if score.item() > 0.5 else 'B'
print(f"Branch prediction: {branch} (score: {score.item():.4f})")
```
## Repository Structure

```
MentorCollab-MLP/
├── Qwen3_1.7B/
│   ├── Math/
│   │   ├── branch_mlp.pth
│   │   └── config.json
│   └── General/
│       ├── branch_mlp.pth
│       └── config.json
├── qwen3_8B_base/
│   └── ... (similar structure)
├── llama3_1_8B/
├── llama3_2_3B/
├── gemma3_4b_it/
├── gemma3_4b_pt/
├── model_loader.py
└── README.md
```
## Advanced Usage

### Load from a Local Path

```python
mlp = load_branch_mlp(
    base_model="qwen3_8B_base",
    task="Math",
    local_path="/path/to/local/models",
)
```
### Specify a Device

```python
mlp = load_branch_mlp(
    base_model="qwen3_8B_base",
    task="Math",
    device="cuda:0",  # or "cpu"
)
```
## Model Sizes
Each MLP model file (branch_mlp.pth) is approximately 289 MB.
## License
Please refer to the main repository for license information.