# Author: p7inc3
# Commit: "Update handler.py" (d31e158, verified)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import joblib
from transformers import AutoTokenizer, AutoModel
from typing import Dict, List, Any
# =========================================================
# 1. Multi-Task Architecture
# =========================================================
class MultiTaskModel(nn.Module):
    """Shared transformer encoder feeding three linear classification heads.

    Token states from the encoder are mean-pooled over the attention mask,
    passed through dropout, and then projected by:
      * ``binary_head``: a single logit,
      * ``multi_head``: ``num_fine`` fine-grained attack-type logits,
      * ``family_head``: ``num_family`` attack-family logits.
    """

    def __init__(self, model_name, num_fine, num_family):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        width = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.2)
        # Attribute names double as state_dict keys, so they must match the
        # checkpoint that gets loaded into this module.
        self.binary_head = nn.Linear(width, 1)
        self.multi_head = nn.Linear(width, num_fine)
        self.family_head = nn.Linear(width, num_family)

    def mean_pooling(self, hidden, attention_mask):
        """Average ``hidden`` over the positions where ``attention_mask`` is set.

        Assumes ``hidden`` is (batch, seq, dim) and ``attention_mask`` is
        (batch, seq). TODO confirm against the encoder output. The
        denominator is clamped, so an all-zero mask row yields zeros rather
        than a division by zero.
        """
        weights = attention_mask.unsqueeze(-1).float()
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def forward(self, input_ids, attention_mask):
        """Return ``(binary_logits, fine_logits, family_logits)`` for a batch."""
        token_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        features = self.dropout(self.mean_pooling(token_states, attention_mask))
        heads = (self.binary_head, self.multi_head, self.family_head)
        return tuple(head(features) for head in heads)
# =========================================================
# 2. Hugging Face Endpoint Handler
# =========================================================
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for the multi-task classifier.

    Loads the label encoders, tokenizer and fine-tuned ``MultiTaskModel``
    from ``path``. Each input text becomes a SAFE / DANGEROUS verdict with
    attack-type and attack-family labels plus matched trigger words.
    """

    # Decision thresholds read by ``decide``. They are copied per instance in
    # ``__init__`` so that tuning one handler does not affect others.
    DEFAULT_THRESHOLDS = {
        "HIGH_ATTACK": 0.80,  # tightened: raised from 0.75
        "MEDIUM_ATTACK": 0.55,
        "HIGH_CONF": 0.85,
        "LOW_CONF": 0.30,
        "EXTREME_CONF": 0.95,
    }

    # Substrings reported as "trigger_words". They are for explainability
    # only and do not influence the verdict. Output keeps this order.
    SUSPICIOUS_KEYWORDS = (
        "ignore", "override", "reveal", "system prompt",
        "developer mode", "bypass", "disable",
        "forget instructions", "pretend", "simulate", "jailbreak",
    )

    # =====================================================
    # Setup
    # =====================================================
    def __init__(self, path=""):
        """Load encoders, tokenizer and model weights from ``path``.

        Args:
            path: Directory holding ``fine_encoder.pkl``,
                ``family_encoder.pkl``, the tokenizer files and
                ``multitask_model_FINAL.pt``.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        print(f"[INFO] Using device: {self.device}")

        # Label encoders map predicted class indices back to label strings.
        # NOTE: joblib.load unpickles the file, so only point ``path`` at
        # trusted repository artifacts.
        self.fine_le = joblib.load(os.path.join(path, "fine_encoder.pkl"))
        self.family_le = joblib.load(os.path.join(path, "family_encoder.pkl"))

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Head sizes come from the encoders so that the checkpoint's weight
        # shapes line up in load_state_dict below.
        self.model = MultiTaskModel(
            model_name="microsoft/deberta-v3-small",
            num_fine=len(self.fine_le.classes_),
            num_family=len(self.family_le.classes_),
        ).to(self.device)

        checkpoint = torch.load(
            os.path.join(path, "multitask_model_FINAL.pt"),
            map_location=self.device,
        )
        # Accept either a bare state dict or a checkpoint dict that wraps it
        # under "model_state".
        state_dict = checkpoint.get("model_state", checkpoint)
        self.model.load_state_dict(state_dict)
        self.model.eval()  # disables dropout for inference
        print("[INFO] Model loaded successfully")

        self.config = dict(self.DEFAULT_THRESHOLDS)

    # =====================================================
    # Decision Logic
    # =====================================================
    def decide(self, danger_prob, fine_score, family_score):
        """Return True if the input should be flagged as dangerous.

        The binary head (``danger_prob``) is the authority: nothing is flagged
        unless it reaches ``HIGH_CONF``. The attack-type scores only decide
        among inputs the binary head already considers dangerous.

        Args:
            danger_prob: Sigmoid probability from the binary head.
            fine_score: Top softmax probability of the fine-grained head.
            family_score: Top softmax probability of the family head.
        """
        cfg = self.config

        # 1. Definite safe zone.
        if danger_prob <= cfg["LOW_CONF"]:
            return False

        # 2. The binary head must be confident the input is dangerous.
        if danger_prob >= cfg["HIGH_CONF"]:
            # Strong attack-type confidence -> dangerous.
            if fine_score >= cfg["HIGH_ATTACK"] or family_score >= cfg["HIGH_ATTACK"]:
                return True
            # Medium attack-type confidence -> dangerous only if the binary
            # head is extremely confident.
            if fine_score >= cfg["MEDIUM_ATTACK"] or family_score >= cfg["MEDIUM_ATTACK"]:
                if danger_prob >= cfg["EXTREME_CONF"]:
                    return True

        # 3. Everything else is treated as safe. This covers the gray zone
        #    (LOW_CONF < danger_prob < HIGH_CONF) and the case where the
        #    binary head is confident but no attack score clears a threshold.
        return False

    # =====================================================
    # Explainability
    # =====================================================
    @classmethod
    def _find_trigger_words(cls, text):
        """Return the SUSPICIOUS_KEYWORDS occurring in ``text``, in declared order.

        Matching is a case-insensitive substring search, so e.g. "ignored"
        also matches "ignore".
        """
        lowered = text.lower()  # lowercase once, not once per keyword
        return [kw for kw in cls.SUSPICIOUS_KEYWORDS if kw in lowered]

    # =====================================================
    # Predict Single
    # =====================================================
    def predict_single(self, text: str):
        """Classify one text and build its response dict.

        The dict has these keys:
          * ``status``: "DANGEROUS" or "SAFE".
          * ``binary_confidence``: the raw danger probability.
          * ``confidence``: the danger probability if DANGEROUS, else its
            complement.
          * ``attack_type`` and ``attack_family``: label plus top softmax
            score, or "none" / 0.0 when SAFE.
          * ``trigger_words``: the matched suspicious keywords.
        """
        tokenized = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        )
        tokenized = {k: v.to(self.device) for k, v in tokenized.items()}

        with torch.no_grad():
            binary_logits, multi_logits, family_logits = self.model(
                tokenized["input_ids"],
                tokenized["attention_mask"],
            )

        # Probabilities. The batch holds a single text, hence row 0 below.
        danger_prob = torch.sigmoid(binary_logits).item()
        multi_probs = F.softmax(multi_logits, dim=1)
        family_probs = F.softmax(family_logits, dim=1)

        fine_idx = torch.argmax(multi_probs, dim=1).item()
        family_idx = torch.argmax(family_probs, dim=1).item()
        fine_score = multi_probs[0][fine_idx].item()
        family_score = family_probs[0][family_idx].item()

        is_dangerous = self.decide(danger_prob, fine_score, family_score)

        # Only report an attack type when the verdict is dangerous.
        if is_dangerous:
            attack_type = self.fine_le.inverse_transform([fine_idx])[0]
            attack_family = self.family_le.inverse_transform([family_idx])[0]
        else:
            attack_type = "none"
            attack_family = "none"
            fine_score = 0.0
            family_score = 0.0

        return {
            "status": "DANGEROUS" if is_dangerous else "SAFE",
            "binary_confidence": round(danger_prob, 4),
            # NOTE(review): a SAFE verdict caused by low attack-type scores
            # (with danger_prob >= HIGH_CONF) gets a low value here.
            "confidence": round(
                danger_prob if is_dangerous else (1 - danger_prob),
                4,
            ),
            "attack_type": {
                "label": attack_type,
                "score": round(fine_score, 4),
            },
            "attack_family": {
                "label": attack_family,
                "score": round(family_score, 4),
            },
            "trigger_words": self._find_trigger_words(text),
        }

    # =====================================================
    # Main Entry
    # =====================================================
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Endpoint entry point: return one result dict per input text.

        ``data`` is either a payload dict with an ``"inputs"`` field (a string
        or a list of strings) or the inputs themselves.

        Raises:
            KeyError: ``data`` is a dict without an ``"inputs"`` field.
        """
        if isinstance(data, dict):
            if "inputs" not in data:
                raise KeyError("request payload must contain an 'inputs' field")
            inputs = data["inputs"]
        else:
            inputs = data
        if isinstance(inputs, str):
            inputs = [inputs]
        return [self.predict_single(text) for text in inputs]