---
language:
- en
- vi
tags:
- safety
- guardrail
- routing
- pytorch
- tabular-classification
metrics:
- f1
- accuracy
- precision
- recall
---

# SafeRoute Router Model (DynaGuard 1.7B / 8B)

This repository contains the weights for the **SafeRoute Router**, a neural router that dynamically directs input prompts and responses either to a lightweight safety classifier (the Small Model, DynaGuard 1.7B) or to a high-capacity safety classifier (the Large Model, DynaGuard 8B).

By sending easy (likely safe) inputs to the Small Model and escalating only hard (potentially unsafe) inputs to the Large Model, the system aims to cut inference latency and compute cost while preserving overall safety-evaluation performance.
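
A minimal sketch of this cascade follows. `small_model`, `router_prob`, and `large_model` are hypothetical callables (not part of this repository) standing in for the Small Model, this router, and the Large Model; the 0.6 threshold and the strict `>` comparison mirror the inference snippet further down.

```python
from typing import Callable, Sequence, Tuple

# Hypothetical interfaces, not part of this repository:
#   small_model(text) -> (verdict, features)  Small Model label plus its 2048-d features
#   router_prob(features) -> float            sigmoid of the router logit
#   large_model(text) -> verdict              Large Model label
SmallModel = Callable[[str], Tuple[str, Sequence[float]]]
RouterProb = Callable[[Sequence[float]], float]
LargeModel = Callable[[str], str]


def guarded_verdict(text: str, small_model: SmallModel, router_prob: RouterProb,
                    large_model: LargeModel, threshold: float = 0.6) -> Tuple[str, str]:
    """Return (verdict, model_used) for one input."""
    # The Small Model always runs: its features are the router's input.
    small_verdict, features = small_model(text)
    if router_prob(features) > threshold:
        # Hard case: defer to the Large Model's verdict.
        return large_model(text), "large"
    # Easy case: keep the Small Model's verdict.
    return small_verdict, "small"


# Usage with stub models
def stub_small(text: str) -> Tuple[str, Sequence[float]]:
    return "safe", [0.0] * 2048


def stub_large(text: str) -> str:
    return "unsafe"


print(guarded_verdict("hello", stub_small, lambda f: 0.9, stub_large))  # ('unsafe', 'large')
print(guarded_verdict("hello", stub_small, lambda f: 0.2, stub_large))  # ('safe', 'small')
```

Because the router consumes features produced by the Small Model, the Small Model runs on every input; the savings come from skipping the Large Model on inputs that are not escalated.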

## Model Details

- **Architecture:** Multi-Layer Perceptron (MLP) with 3 hidden layers (`1024 -> 512 -> 256`), each followed by `BatchNorm1d`, `GELU`, and `Dropout` (0.3 after the first two hidden layers, 0.2 after the third).
- **Input Dimension:** `2048` (feature embeddings extracted from the small safety model).
- **Output Dimension:** `1` (a single binary-classification logit; apply a sigmoid to obtain the probability of routing to the Large Model).
- **Loss Function:** `Focal Loss` ($\alpha=0.75, \gamma=2.0$) tailored to address severe class imbalance.
- **Optimizer & Scheduler:** `AdamW` with `CosineAnnealingWarmRestarts`.
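
The loss and optimizer listed above can be sketched as follows. This assumes the standard binary focal-loss formulation, with $\alpha$ weighting the positive (route-to-Large) class; the card does not say which variant was used, and the learning rate, weight decay, `T_0`, and `T_mult` values are hypothetical placeholders. An `nn.Linear` stand-in replaces the router so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def binary_focal_loss(logits: torch.Tensor, targets: torch.Tensor,
                      alpha: float = 0.75, gamma: float = 2.0) -> torch.Tensor:
    """Mean binary focal loss computed from raw logits.

    targets holds 0/1 floats; 1 is assumed to mean "route to the Large Model".
    """
    # Per-sample BCE equals -log(p_t), computed stably from the logits.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability the model assigns to the true class
    # alpha weights positives, (1 - alpha) weights negatives.
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t) ** gamma down-weights examples that are already classified well.
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()


# Hypothetical training setup; nn.Linear stands in for RouterMLP.
model = nn.Linear(2048, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

x = torch.randn(8, 2048)               # placeholder features
y = torch.randint(0, 2, (8,)).float()  # placeholder routing labels
loss = binary_focal_loss(model(x).squeeze(-1), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()  # advance the warm-restart schedule by one step
```

With $\gamma = 0$ and $\alpha = 0.5$ the loss reduces to half the ordinary binary cross-entropy, which is a quick way to sanity-check an implementation.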

## Evaluation Results

Evaluated on a balanced Test Benchmark at the optimal decision threshold (**0.6**):

| Metric | Score |
| :--- | :---: |
| **F1 Score** | **0.7525** |
| **Accuracy** | **0.7500** |
| **Precision** | **0.7451** |
| **Recall** | **0.7600** |
| **Overall AUPRC** | **0.7588** |

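*Note: A recall of 0.76 means that about three in four inputs that warrant the Large Model are escalated to it, while the remaining ~24% are handled by the Small Model alone; a precision of 0.75 means that roughly one in four escalations goes to an input the Small Model could have handled. Raising the threshold trades recall for fewer escalations, and lowering it does the reverse.*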
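
As a sanity check, the metrics can be recomputed from confusion-matrix counts. The helper `binary_metrics` below is illustrative and treats "route to the Large Model" as the positive class; the last lines confirm that the reported F1 is the harmonic mean of the reported precision and recall.

```python
def binary_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Precision, recall, F1, and accuracy from confusion-matrix counts.

    Positive class = "route to the Large Model".
    """
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}


# F1 is the harmonic mean of precision and recall; check the reported values.
precision, recall = 0.7451, 0.7600
print(round(2 * precision * recall / (precision + recall), 4))  # 0.7525
```

Because the benchmark is balanced, accuracy is approximately the mean of recall and specificity, so the table implies a specificity of roughly 2 × 0.75 − 0.76 = 0.74: about a quarter of the inputs that the Small Model could handle are still escalated.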

## How to Get Started with the Model

The snippet below downloads the checkpoint from the Hugging Face Hub and runs the router in PyTorch:

```python
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

# 1. Define the Router Architecture
class RouterMLP(nn.Module):
    def __init__(self, input_dim=2048):
        super().__init__()
        self.cls = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        return self.cls(x).squeeze(-1)

# 2. Download and Load the Checkpoint
repo_id = "YOUR_HF_USERNAME/safe-route-dynaguard" # <-- Replace with your repo name
model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")

device = "cuda" if torch.cuda.is_available() else "cpu"
router = RouterMLP(input_dim=2048).to(device)

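ckpt = torch.load(model_path, map_location=device)
# strict=False tolerates extra keys in the checkpoint, but fail loudly if any
# router weights are missing instead of silently running with random weights.
result = router.load_state_dict(ckpt["state_dict"], strict=False)
if result.missing_keys:
    raise RuntimeError(f"Checkpoint is missing router weights: {result.missing_keys}")
router.eval()  # use BatchNorm running statistics and disable Dropout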

# 3. Perform Routing Inference
with torch.no_grad():
    # Placeholder: random values stand in for real 2048-d features from the small model
    sample_features = torch.randn(4, 2048, device=device)
    
    logits = router(sample_features)
    routing_probs = torch.sigmoid(logits)
    
    # Use recommended threshold 0.6
    decisions = (routing_probs > 0.6).long()
    
    for i, decision in enumerate(decisions):
        if decision == 1:
            print(f"Sample {i}: Route to LARGE Model (Hard/Unsafe)")
        else:
            print(f"Sample {i}: Use SMALL Model (Easy/Safe)")
```
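
The card reports 0.6 as the optimal threshold but does not state the selection criterion. If you re-tune it on your own held-out data, a simple sweep that maximizes F1 while reporting the escalation rate (the share of inputs sent to the Large Model) might look like this; `sweep_thresholds` is a hypothetical helper, and maximizing F1 is an assumption.

```python
from typing import List, Sequence, Tuple


def sweep_thresholds(probs: Sequence[float], labels: Sequence[int],
                     thresholds: Sequence[float]) -> Tuple[float, float, List[tuple]]:
    """Return (best_threshold, best_f1, rows), where rows holds
    (threshold, f1, escalation_rate) for every candidate threshold.

    labels: 1 = should go to the Large Model (assumed convention).
    """
    rows = []
    for t in thresholds:
        preds = [1 if p > t else 0 for p in probs]  # same strict '>' as the snippet above
        tp = sum(1 for y, yhat in zip(labels, preds) if y == 1 and yhat == 1)
        fp = sum(1 for y, yhat in zip(labels, preds) if y == 0 and yhat == 1)
        fn = sum(1 for y, yhat in zip(labels, preds) if y == 1 and yhat == 0)
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
        escalation_rate = sum(preds) / len(preds)
        rows.append((t, f1, escalation_rate))
    best_t, best_f1, _ = max(rows, key=lambda row: row[1])  # first best on ties
    return best_t, best_f1, rows


# Toy validation data, for illustration only
val_probs = [0.9, 0.7, 0.65, 0.4, 0.3, 0.55]
val_labels = [1, 1, 0, 0, 0, 1]
best_t, best_f1, rows = sweep_thresholds(val_probs, val_labels, [0.3, 0.4, 0.5, 0.6, 0.7])
for t, f1, rate in rows:
    print(f"threshold={t:.2f}  F1={f1:.3f}  escalation_rate={rate:.2f}")
```

Tune on a validation split rather than on the test benchmark, and weigh F1 against the escalation rate, since the latter determines how often the Large Model runs.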

## Intended Use

- **Primary Use Case:** Guardrail optimization in LLM serving pipelines.
- **Out-of-Scope:** Standalone toxicity classification directly from raw text (this model requires intermediate hidden feature representations from a pre-trained small safety model).
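
The card does not specify which layer or pooling of the small model produces the 2048-dimensional input features. Purely as an illustration, masked mean pooling over a final hidden state (for example, one obtained by calling a Hugging Face `transformers` model with `output_hidden_states=True`) could be written as below; `masked_mean_pool` is a hypothetical helper, and whatever extraction you use must match the one the router was trained on, or its probabilities will be unreliable.

```python
import torch


def masked_mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token states over non-padding positions.

    hidden: (batch, seq_len, 2048) hidden states from the small model (layer is an assumption).
    attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    Returns a (batch, 2048) feature tensor for the router.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                   # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1.0)               # avoid division by zero
    return summed / counts


# Usage with placeholder hidden states
hidden = torch.randn(2, 7, 2048)
mask = torch.ones(2, 7, dtype=torch.long)
features = masked_mean_pool(hidden, mask)  # shape (2, 2048)
```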