---
{}
---
# EPIC Router Family
This repository hosts the public checkpoints for the EPIC router models. Each
checkpoint is a router trained to pick the best reasoning configuration (method,
aggregator, sample count, etc.) for a given natural-language math question.
## Available Versions
| Subdirectory | File | Notes |
|--------------|-------------------|--------------------------------------------------------|
| `0.25/` | `router_model.pt` | Router trained for a cost-accuracy trade-off of 0.25. |
| `0.5/` | `router_model.pt` | Router trained for a cost-accuracy trade-off of 0.5. |
| `0.75/` | `router_model.pt` | Router trained for a cost-accuracy trade-off of 0.75. |
| `1.0/` | `router_model.pt` | Router trained for a cost-accuracy trade-off of 1.0. |
Each checkpoint contains:
- `state_dict`: PyTorch weights for `RouterScoringModel`
- `model_name`: base encoder identifier (defaults to `sentence-transformers/all-MiniLM-L6-v2`)
- `projection_dim`: dimension of the projection head
- `methods`: serialized reasoning configurations; each row corresponds to one column in the router head
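The documented schema can be checked offline with a mock checkpoint before downloading the real weights. The sketch below builds a placeholder dictionary with the four documented keys and round-trips it through `torch.save`/`torch.load`; the placeholder values (the tiny tensor, the method strings) are made up for illustration.

```python
import io

import torch

# Mock checkpoint matching the documented schema. Values are placeholders,
# not real router weights or serialized ReasoningConfig payloads.
mock_checkpoint = {
    "state_dict": {"projector.weight": torch.zeros(2, 4)},
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    "projection_dim": 2,
    "methods": ["<serialized config 0>", "<serialized config 1>"],
}

# Round-trip through torch.save/torch.load, as the real file would be read.
buffer = io.BytesIO()
torch.save(mock_checkpoint, buffer)
buffer.seek(0)
loaded = torch.load(buffer, map_location="cpu")

print(sorted(loaded.keys()))
# ['methods', 'model_name', 'projection_dim', 'state_dict']
```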
## Quickstart (Python)
Install the package locally:
```bash
git clone https://github.com/nguyenngocbaocmt02/epic.git
cd epic
pip install -e .
```
Load a checkpoint and route a question:
```python
from huggingface_hub import hf_hub_download
import torch
from epic.router.models import RouterScoringModel, MiniLMQuestionEncoder, QuestionProjector
from epic.data_schemas.reasoning import ReasoningConfig
REPO_ID = "baonn/epic"
VERSION = "1.0" # or 0.5 / 0.75 / 0.25
checkpoint_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="router_model.pt",
    subfolder=VERSION,
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")

encoder = MiniLMQuestionEncoder(
    model_name=checkpoint.get("model_name", "sentence-transformers/all-MiniLM-L6-v2"),
    trainable=False,
)
projector = QuestionProjector(
    input_dim=encoder.transformer.config.hidden_size,
    projection_dim=int(checkpoint["projection_dim"]),
)
model = RouterScoringModel(
    question_encoder=encoder,
    projector=projector,
    num_methods=len(checkpoint["methods"]),
)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

reasoning_configs = [
    ReasoningConfig.deserialize(payload) for payload in checkpoint["methods"]
]

questions = ["x + 20 = 30 then x = ?", "How many positive divisors does 3600 have?"]
with torch.no_grad():
    logits = model(questions)
method_indices = torch.argmax(logits, dim=1).tolist()
print("Recommended config for question 1:", reasoning_configs[method_indices[0]].serialize(include_samples=True))
print("Recommended config for question 2:", reasoning_configs[method_indices[1]].serialize(include_samples=True))
```
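The final selection step in the quickstart can be understood in isolation: the router emits one row of logits per question, with one column per candidate configuration, and `argmax` over the columns picks a configuration index for each question. The sketch below uses made-up logits so it runs without the model or checkpoint.

```python
import torch

# Stand-in for `model(questions)`: one row per question, one column per
# candidate reasoning configuration. These numbers are invented for the example.
logits = torch.tensor([
    [0.1, 2.3, -0.5],  # question 1 -> config index 1 scores highest
    [1.7, 0.2, 0.4],   # question 2 -> config index 0 scores highest
])
method_indices = torch.argmax(logits, dim=1).tolist()
print(method_indices)  # [1, 0]
```

These indices are then used to look up the corresponding entries in `reasoning_configs`, exactly as the quickstart does.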