Model Card for modern-bert-router
A binary text-classification model, fine-tuned to route coding prompts to an appropriate model by classifying each prompt's difficulty as "easy" or "hard".
Model Details
Base model: answerdotai/ModernBERT-base
Model Description
- Developed by: Christian @ Prime Intellect
- Fine-tuned from model: answerdotai/ModernBERT-base
How to Get Started with the Model
from transformers import pipeline
import torch
# Run on the first CUDA GPU when one is present; a device of -1 keeps the
# pipeline on the CPU.
pipeline_device = 0 if torch.cuda.is_available() else -1
classifier = pipeline(
    "text-classification",
    model="cdreetz/modern-bert-router",
    device=pipeline_device,
)

# A trivially easy prompt: add two numbers.
easy_problem = """
Write a function that returns the sum of two numbers.
Example:
Input: add(2, 3)
Output: 5
"""

# A harder prompt: maximum path sum in a binary tree.
hard_problem = """
Given a binary tree, find the maximum path sum. The path may start and end at any node in the tree.
A path is defined as any sequence of nodes from some starting node to any node in the tree along the parent-child connections.
The path must contain at least one node and does not need to go through the root.
Example:
Input: root = [1,2,3]
Output: 6 (2 -> 1 -> 3)
"""

# Classify both prompts before printing anything. The first element of each
# pipeline result carries the predicted "label" and its "score".
result_easy, result_hard = (
    classifier(prompt)[0] for prompt in (easy_problem, hard_problem)
)

# Report each prediction; the easy report ends with an extra blank line so the
# two reports are visually separated.
for heading, result, trailer in (
    ("Easy Problem:", result_easy, "\n"),
    ("Hard Problem:", result_hard, ""),
):
    print(heading)
    print(f" Difficulty: {result['label']}")
    print(f" Confidence: {result['score']:.2%}{trailer}")
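Training Details
The script below (which declares its dependencies with PEP 723 inline metadata) does the following:
- It generates 1,000 synthetic coding problems, each labelled "easy" or "hard", using an OpenAI model through `chatan`.
- It holds out 20% of them for evaluation.
- It fine-tunes ModernBERT-base on the rest for 3 epochs.

It reads the `OPENAI_API_KEY` environment variable and calls `huggingface_hub.login()` at start-up.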
# /// script
# dependencies = [
# "chatan",
# "transformers",
# "datasets",
# "torch",
# "accelerate",
# "scikit-learn",
# "triton",
# "huggingface_hub"
# ]
# ///
import os
import asyncio
import chatan as ch
from datasets import Dataset as hf_ds
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import torch
import triton
from huggingface_hub import login
# Left disabled: uncommenting the next line would change how torch's dynamo
# compiler handles errors (presumably falling back to eager execution instead
# of raising — confirm against the torch docs before relying on it).
#torch._dynamo.config.suppress_errors = True
# Log in to the Hugging Face Hub as a module-level side effect, before the
# dataset is generated or training starts. No token is passed, so this
# presumably uses a cached token or prompts for one interactively — verify.
# NOTE(review): confirm this is needed. A plain training run only downloads
# answerdotai/ModernBERT-base by name; authentication matters mainly when
# pushing the trained model to the Hub.
login()
async def create_dataset(
    *,
    n: int = 1000,
    max_concurrent_rows: int = 500,
    test_size: float = 0.2,
    seed: int = 42,
):
    """Generate a synthetic easy/hard coding-problem dataset and split it.

    For each row, a difficulty is sampled from ``"easy"``/``"hard"`` and an
    OpenAI model (via ``chatan``) is asked to write a coding problem of that
    difficulty. The difficulty is then mapped to an integer ``labels`` column
    (easy=0, hard=1), matching the classifier's ``label2id``.

    NOTE(review): the generated text may itself mention "easy"/"hard"; consider
    checking for label leakage before trusting evaluation scores.

    Args:
        n: Number of rows to generate.
        max_concurrent_rows: Concurrency limit passed to ``chatan``'s generator.
        test_size: Fraction of rows held out for evaluation.
        seed: Seed for the train/test shuffle split.

    Returns:
        The result of ``Dataset.train_test_split``: a ``"train"`` and a
        ``"test"`` split, each with ``text`` and ``labels`` columns.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is not set.
        ValueError: If any generated row has a difficulty outside the label map.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # Fail fast with a clear message rather than deep inside the API client.
        raise RuntimeError(
            "OPENAI_API_KEY is not set; it is required to generate the dataset."
        )

    label2id = {"easy": 0, "hard": 1}
    gen = ch.async_generator("openai", api_key)
    ds = ch.async_dataset({
        "difficulty": ch.sample.choice(list(label2id)),
        # The {difficulty} placeholder is filled from the column above.
        "text": gen("write a coding problem of difficulty {difficulty}"),
    })
    df = await ds.generate(n=n, max_concurrent_rows=max_concurrent_rows)

    df["labels"] = df["difficulty"].map(label2id)
    unmapped = df["labels"].isna()
    if unmapped.any():
        # An unmapped value would otherwise become NaN and turn labels into floats.
        bad = df.loc[unmapped, "difficulty"].unique().tolist()
        raise ValueError(f"unexpected difficulty values: {bad}")

    # preserve_index=False keeps the pandas index from leaking in as a column.
    dataset = hf_ds.from_pandas(df[["text", "labels"]], preserve_index=False)
    return dataset.train_test_split(test_size=test_size, seed=seed)
def _compute_metrics(eval_pred):
    """Return accuracy and binary F1 (positive class 1 = "hard") for an eval pass."""
    logits = eval_pred.predictions
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=1)
    labels = eval_pred.label_ids
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds),
    }


def train(
    dataset,
    *,
    output_dir="./modernbert-router",
    max_length=512,
    hub_repo_id=None,
):
    """Fine-tune ModernBERT-base as a binary easy/hard difficulty classifier.

    Args:
        dataset: Splits named ``"train"`` and ``"test"``, each with a ``text``
            column and an integer ``labels`` column (0 = easy, 1 = hard), as
            produced by ``create_dataset``.
        output_dir: Directory for epoch checkpoints and the final model and
            tokenizer.
        max_length: Token limit; longer prompts are truncated.
        hub_repo_id: If given (e.g. ``"cdreetz/modern-bert-router"``), push the
            final model and tokenizer to this Hub repo after training. This
            requires being logged in to the Hugging Face Hub.
    """
    base_model = "answerdotai/ModernBERT-base"
    label2id = {"easy": 0, "hard": 1}
    id2label = {idx: name for name, idx in label2id.items()}

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
        problem_type="single_label_classification",
    )

    def tokenize(examples):
        # No padding here: with a tokenizer supplied as processing_class, the
        # Trainer's default collator pads each batch to its own longest
        # sequence instead of always padding to max_length.
        return tokenizer(examples["text"], truncation=True, max_length=max_length)

    tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])

    training_args = TrainingArguments(
        output_dir=output_dir,
        # Evaluation and saving must share a schedule for load_best_model_at_end.
        eval_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        learning_rate=2e-5,
        # With no metric_for_best_model set, "best" means lowest eval loss.
        load_best_model_at_end=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        processing_class=tokenizer,
        compute_metrics=_compute_metrics,
    )

    print("starting training")
    trainer.train()
    # Saves the (best, reloaded) model; the processing_class tokenizer is
    # written alongside it.
    trainer.save_model(output_dir)
    print("donezo")

    if hub_repo_id is not None:
        trainer.model.push_to_hub(hub_repo_id)
        tokenizer.push_to_hub(hub_repo_id)
if __name__ == "__main__":
    # Generate the synthetic dataset (async LLM calls), then fine-tune on it.
    train(asyncio.run(create_dataset()))
Citation
BibTeX:
@software{modern-bert-router,
author = {Reetz, Christian},
title = {ModernBERTRouter},
url = {https://huggingface.co/cdreetz/modern-bert-router/},
year = {2025}
}
- Downloads last month: 6