
MentorCollab MLP Models

Branch prediction MLP models for collaborative decoding between different-sized language models.

🎯 Overview

This repository contains trained MLP (Multi-Layer Perceptron) models that predict which branch to take during collaborative decoding. Each model is trained for a specific base LLM and task type.

📦 Available Models

| Base Model | HuggingFace Model | Hidden Size | Tasks |
|---|---|---|---|
| Qwen3_1.7B | Qwen/Qwen3-1.7B | 1536 | Math, General |
| qwen3_8B_base | Qwen/Qwen3-8B-Base | 4096 | Math, General |
| llama3_1_8B | meta-llama/Llama-3.1-8B | 4096 | Math, General |
| llama3_2_3B | meta-llama/Llama-3.2-3B | 3072 | Math, General |
| gemma3_4b_it | google/gemma-3-4b-it | 3072 | Math, General |
| gemma3_4b_pt | google/gemma-3-4b-pt | 3072 | Math, General |

🚀 Quick Start

Installation

pip install torch huggingface_hub

Download Model Loader

from huggingface_hub import hf_hub_download

# Download the model loader
loader_file = hf_hub_download(
    repo_id="haojinw0027/MentorCollab-MLP",
    filename="model_loader.py"
)

Load a Model

from model_loader import load_branch_mlp

# Load MLP for Qwen3-8B on Math tasks
mlp_model = load_branch_mlp(
    base_model="qwen3_8B_base",
    task="Math"
)

# Use for inference
import torch
with torch.no_grad():
    hidden_state = torch.randn(1, 4096)  # Example hidden state (dim matches qwen3_8B_base)
    score = mlp_model(hidden_state)
    branch = 'A' if score.item() > 0.5 else 'B'
    print(f"Predicted branch: {branch}")

List All Available Models

from model_loader import list_available_models

list_available_models()

📖 Model Details

Architecture

Each MLP model has the following architecture:

  • Input: Hidden state from the base LLM (dimension varies by model)
  • Hidden Layer: 4096 units with ReLU activation
  • Output: Single sigmoid output (probability of choosing branch A)
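The architecture above can be sketched in PyTorch as follows. The class name and exact layer composition are assumptions for illustration; only the layer sizes and activations come from the description above.

```python
import torch
import torch.nn as nn

class BranchMLP(nn.Module):
    """Hypothetical reconstruction of the described branch MLP:
    input -> 4096-unit ReLU hidden layer -> single sigmoid output."""

    def __init__(self, input_dim: int, hidden_dim: int = 4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # probability of choosing branch A
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.net(hidden_state)

# Example: the qwen3_8B_base variant takes 4096-dimensional hidden states
mlp = BranchMLP(input_dim=4096)
score = mlp(torch.randn(1, 4096))
print(score.shape)  # (1, 1), values in (0, 1)
```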

Training

Models are trained to predict the optimal branch choice at decision points during collaborative decoding between a smaller "student" model and a larger "mentor" model.

💡 Usage Example

Complete Inference Pipeline

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from model_loader import load_branch_mlp

# Load base model
base_model_name = "Qwen/Qwen3-8B-Base"
model = AutoModelForCausalLM.from_pretrained(base_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Load branch prediction MLP
mlp = load_branch_mlp(
    base_model="qwen3_8B_base",
    task="Math"
)

# Extract hidden state at branch point
prompt = "Solve: 2 + 2 = ?"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    hidden_state = outputs.hidden_states[-1][0, -1, :]  # Last token, last layer

    # Predict branch
    score = mlp(hidden_state.unsqueeze(0).float())
    branch = 'A' if score.item() > 0.5 else 'B'
    print(f"Branch prediction: {branch} (score: {score.item():.4f})")

📂 Repository Structure

MentorCollab-MLP/
├── Qwen3_1.7B/
│   ├── Math/
│   │   ├── branch_mlp.pth
│   │   └── config.json
│   └── General/
│       ├── branch_mlp.pth
│       └── config.json
├── qwen3_8B_base/
│   └── ... (similar structure)
├── llama3_1_8B/
├── llama3_2_3B/
├── gemma3_4b_it/
├── gemma3_4b_pt/
├── model_loader.py
└── README.md

🔧 Advanced Usage

Load from Local Path

mlp = load_branch_mlp(
    base_model="qwen3_8B_base",
    task="Math",
    local_path="/path/to/local/models"
)

Specify Device

mlp = load_branch_mlp(
    base_model="qwen3_8B_base",
    task="Math",
    device="cuda:0"  # or "cpu"
)

📊 Model Sizes

Each MLP model file (branch_mlp.pth) is approximately 289 MB.

📄 License

Please refer to the main repository for license information.
