
EPIC Router Family

This repository hosts the public checkpoints for the EPIC router models. Each router is trained to pick the best reasoning configuration (method, aggregator, sample count, and so on) for a given natural-language math question.

Available Versions

| Subdirectory | File | Notes |
| --- | --- | --- |
| 0.25/ | router_model.pt | Router trained for the cost-accuracy trade-off = 0.25. |
| 0.5/ | router_model.pt | Router trained for the cost-accuracy trade-off = 0.5. |
| 0.75/ | router_model.pt | Router trained for the cost-accuracy trade-off = 0.75. |
| 1.0/ | router_model.pt | Router trained for the cost-accuracy trade-off = 1.0. |

Each checkpoint contains:

  • state_dict: PyTorch weights for RouterScoringModel
  • model_name: base encoder identifier (defaults to sentence-transformers/all-MiniLM-L6-v2)
  • projection_dim: dimension of the projection head
  • methods: serialized reasoning configurations; each row corresponds to one column in the router head
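As a quick sanity check on the fields listed above, a loaded checkpoint can be summarized before building the model. The snippet below is a minimal sketch: `summarize_checkpoint` is a hypothetical helper, not part of the EPIC package, and it assumes the checkpoint is a plain dictionary in which `model_name` may be absent (the quickstart falls back to the default encoder in that case).

```python
# Hypothetical helper for illustration; not part of the EPIC package.
DEFAULT_ENCODER = "sentence-transformers/all-MiniLM-L6-v2"
REQUIRED_KEYS = ("state_dict", "projection_dim", "methods")


def summarize_checkpoint(checkpoint):
    """Check that the documented keys are present and return a short summary."""
    missing = [key for key in REQUIRED_KEYS if key not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    return {
        # model_name is treated as optional and falls back to the default encoder.
        "model_name": checkpoint.get("model_name", DEFAULT_ENCODER),
        "projection_dim": int(checkpoint["projection_dim"]),
        # One serialized reasoning configuration per router output column.
        "num_methods": len(checkpoint["methods"]),
        "num_tensors": len(checkpoint["state_dict"]),
    }
```

With a real checkpoint, the dictionary returned by `torch.load(...)` in the quickstart can be passed in directly.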

Quickstart (Python)

Install the package locally:

git clone https://github.com/nguyenngocbaocmt02/epic.git
cd epic
pip install -e .

Load a checkpoint and route a question:

from huggingface_hub import hf_hub_download
import torch
from epic.router.models import RouterScoringModel, MiniLMQuestionEncoder, QuestionProjector
from epic.data_schemas.reasoning import ReasoningConfig

REPO_ID = "baonn/epic"
VERSION = "1.0"  # or "0.25", "0.5", "0.75"

checkpoint_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="router_model.pt",
    subfolder=VERSION,
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")

encoder = MiniLMQuestionEncoder(
    model_name=checkpoint.get("model_name", "sentence-transformers/all-MiniLM-L6-v2"),
    trainable=False,
)
projector = QuestionProjector(
    input_dim=encoder.transformer.config.hidden_size,
    projection_dim=int(checkpoint["projection_dim"]),
)
model = RouterScoringModel(
    question_encoder=encoder,
    projector=projector,
    num_methods=len(checkpoint["methods"]),
)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

reasoning_configs = [
    ReasoningConfig.deserialize(payload) for payload in checkpoint["methods"]
]

questions = ["x + 20 = 30 then x = ?", "How many positive divisors does 3600 have?"]
with torch.no_grad():
    # One row of scores per question, one column per reasoning configuration.
    logits = model(questions)
    # Index of the highest-scoring configuration for each question.
    method_indices = torch.argmax(logits, dim=1).tolist()

for number, index in enumerate(method_indices, start=1):
    print(f"Recommended config for question {number}:", reasoning_configs[index].serialize(include_samples=True))
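The quickstart keeps only the single highest-scoring configuration per question. If a few fallback candidates are useful, the same scores can be ranked instead. The following is a minimal sketch in plain Python: `top_k_methods` is a hypothetical helper, not part of the EPIC package, and it assumes (as the `argmax` above implies) that a higher score means a more strongly recommended configuration.

```python
# Hypothetical helper for illustration; not part of the EPIC package.
def top_k_methods(score_rows, k=3):
    """Return, for each question, the indices of the k highest-scoring configs."""
    ranked = []
    for row in score_rows:
        # Sort column indices by descending score.
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        ranked.append(order[:k])
    return ranked
```

With the quickstart variables, `top_k_methods(logits.tolist(), k=3)` would give a list of index lists, and each index can be looked up in `reasoning_configs` the same way as `method_indices`.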