# CPU-Based-AI-Guardrail / inference3.py
# boying07's picture
# Upload 2 files
# 8487231 verified
# Raw
# History Blame Contribute Delete
# 5.84 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
class SharedEncoder(nn.Module):
    """Pretrained transformer wrapped to emit L2-normalized, mean-pooled vectors.

    The backbone named by ``model_name`` is loaded with
    ``AutoModel.from_pretrained`` and kept on ``self.encoder``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def mean_pool(self, hidden, mask):
        """Average ``hidden`` along dim 1, weighting each position by ``mask``.

        ``mask`` gets a trailing axis, is expanded to ``hidden.size()`` and
        cast to float before use.
        # assumes hidden is (batch, seq, dim) and mask is (batch, seq) — TODO confirm
        """
        weights = mask.unsqueeze(-1).expand(hidden.size()).float()
        token_total = (hidden * weights).sum(1)
        # Clamp the denominator so an all-zero mask row divides by 1e-9, not 0.
        token_count = weights.sum(1).clamp(min=1e-9)
        return token_total / token_count

    def forward(self, input_ids, attention_mask):
        """Run the backbone, mean-pool its last hidden state, then L2-normalize."""
        last_hidden = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        sentence_vec = self.mean_pool(last_hidden, attention_mask)
        return F.normalize(sentence_vec, p=2, dim=-1)
class ClassifierHead(nn.Module):
    """Small MLP (Linear -> ReLU -> Dropout -> Linear) producing one logit per row.

    Args:
        dim: Size of the last axis of the input fed to the first linear layer.
    """

    def __init__(self, dim=768):
        super().__init__()
        hidden_units = 256
        # Layer order fixes the Sequential indices (0 and 3 hold parameters),
        # which are the names stored in a saved state dict.
        layers = [
            nn.Linear(dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_units, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits with the trailing size-1 axis squeezed away."""
        logits = self.net(x)
        return logits.squeeze(-1)
def load_models(
    model_name="dbmdz/bert-base-turkish-cased",
    model_dir="HomayShield_v5",
    device=None,
):
    """Load the tokenizer, encoder, classifier head and attack bank.

    Calling ``load_models()`` with no arguments uses the same model name and
    the same ``HomayShield_v5`` directory as before; the parameters only make
    those values overridable.

    Args:
        model_name: Hugging Face model id used for both the tokenizer and the
            encoder backbone.
        model_dir: Directory holding ``homayshield_encoder.pt``,
            ``homayshield_classifier.pt`` and ``homayshield_attack_bank.npy``.
        device: Device to place the modules on. When ``None``, ``"cuda"`` is
            used if ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.

    Returns:
        Tuple ``(tokenizer, encoder, classifier, attack_bank, device)``; the
        encoder and classifier are in eval mode.
    """
    import os

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    encoder_path = os.path.join(model_dir, "homayshield_encoder.pt")
    classifier_path = os.path.join(model_dir, "homayshield_classifier.pt")
    attack_bank_path = os.path.join(model_dir, "homayshield_attack_bank.npy")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # NOTE(security): torch.load may unpickle arbitrary objects depending on
    # the installed torch version's `weights_only` default; only point
    # `model_dir` at checkpoints from a trusted source.
    encoder = SharedEncoder(model_name).to(device)
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    encoder.eval()

    classifier = ClassifierHead().to(device)
    classifier.load_state_dict(torch.load(classifier_path, map_location=device))
    classifier.eval()

    attack_bank = np.load(attack_bank_path)

    return tokenizer, encoder, classifier, attack_bank, device
def encode_text(text, tokenizer, encoder, device):
    """Tokenize ``text``, run ``encoder`` without gradients, return row 0 as NumPy.

    The tokenizer is called with truncation on, ``max_length=256``,
    ``padding="max_length"`` and PyTorch tensors as output. Every returned
    tensor is moved to ``device``; only ``input_ids`` and ``attention_mask``
    are passed to the encoder.

    Returns:
        ``encoder(...)[0]`` moved to the CPU and converted with ``.numpy()``.
    """
    tokenized = tokenizer(
        text,
        truncation=True,
        max_length=256,
        padding="max_length",
        return_tensors="pt",
    )

    on_device = {}
    for name, tensor in tokenized.items():
        on_device[name] = tensor.to(device)

    with torch.no_grad():
        embeddings = encoder(on_device["input_ids"], on_device["attention_mask"])

    first_row = embeddings[0]
    return first_row.cpu().numpy()
def semantic_score(emb, attack_bank):
    """Return the largest entry of ``attack_bank @ emb`` as a Python float.

    # presumably attack_bank rows and emb are unit-length, making each entry a
    # cosine similarity — verify against how the bank file was produced
    """
    similarities = np.matmul(attack_bank, emb)
    best = np.max(similarities)
    return float(best)
def predict(text, mode, config, tokenizer, encoder, classifier, attack_bank, device):
    """Score ``text`` and label it ``"ATTACK"`` or ``"NORMAL"`` according to ``mode``.

    Both scores are always computed first: the semantic score from
    ``semantic_score`` on the text embedding, and the classifier score as the
    sigmoid of the classifier's output for that same embedding.

    Modes and the ``config`` keys they read:
        ``"or"``: attack if either score reaches its threshold
            (``semantic_threshold``, ``classifier_threshold``).
        ``"fusion"``: attack if the weighted sum (``semantic_weight``,
            ``classifier_weight``) reaches ``fusion_threshold``.
        ``"semantic_only"``: attack if the semantic score reaches
            ``semantic_threshold``.
        ``"classifier_only"``: attack if the classifier score reaches
            ``classifier_threshold``.

    Returns:
        Dict with keys ``label``, ``semantic_score`` and ``classifier_score``.

    Raises:
        ValueError: If ``mode`` is none of the four modes above.
    """
    embedding = encode_text(text, tokenizer, encoder, device)
    attack_score = semantic_score(embedding, attack_bank)

    with torch.no_grad():
        # The classifier gets a batch of one, so add the leading axis.
        features = torch.tensor(embedding).float().unsqueeze(0).to(device)
        classifier_score = torch.sigmoid(classifier(features)).item()

    if mode == "or":
        is_attack = (
            attack_score >= config["semantic_threshold"]
            or classifier_score >= config["classifier_threshold"]
        )
    elif mode == "fusion":
        semantic_part = config["semantic_weight"] * attack_score
        classifier_part = config["classifier_weight"] * classifier_score
        is_attack = semantic_part + classifier_part >= config["fusion_threshold"]
    elif mode == "semantic_only":
        is_attack = attack_score >= config["semantic_threshold"]
    elif mode == "classifier_only":
        is_attack = classifier_score >= config["classifier_threshold"]
    else:
        raise ValueError("Invalid mode")

    result = {"label": "ATTACK" if is_attack else "NORMAL"}
    result["semantic_score"] = attack_score
    result["classifier_score"] = classifier_score
    return result
def ask_mode():
    """Print the mode menu, read a choice from stdin and return its mode name.

    Returns:
        One of ``"or"``, ``"fusion"``, ``"semantic_only"``, ``"classifier_only"``.

    Raises:
        ValueError: If the stripped input is not ``1``, ``2``, ``3`` or ``4``.
    """
    # (menu line, accepted key, returned mode name)
    options = (
        ("1 -> OR", "1", "or"),
        ("2 -> Fusion", "2", "fusion"),
        ("3 -> Semantic Only", "3", "semantic_only"),
        ("4 -> Classifier Only", "4", "classifier_only"),
    )

    print("\nSelect Mode:")
    for menu_line, _, _ in options:
        print(menu_line)

    answer = input("Enter choice: ").strip()
    for _, key, mode_name in options:
        if answer == key:
            return mode_name
    raise ValueError("Invalid choice")
def _prompt_float(prompt, default):
    """Read one float from stdin; blank or whitespace-only input yields ``default``."""
    raw = input(prompt).strip()
    # Strip before the emptiness check: a stray space is truthy and would
    # otherwise reach float() and raise instead of selecting the default.
    return float(raw) if raw else float(default)


def ask_thresholds(mode):
    """Prompt for the numeric settings that ``mode`` needs and return them.

    Keys asked per mode (in prompt order):
        ``"or"``: ``semantic_threshold``, ``classifier_threshold``
        ``"semantic_only"``: ``semantic_threshold``
        ``"classifier_only"``: ``classifier_threshold``
        ``"fusion"``: ``semantic_weight``, ``classifier_weight``,
            ``fusion_threshold``

    Pressing Enter (or typing only whitespace) accepts the default shown in
    the prompt. Any other mode results in an empty dict.

    Returns:
        Dict mapping each asked key to a float.

    Raises:
        ValueError: If a non-blank answer cannot be parsed as a float.
    """
    config = {}
    if mode in ("or", "semantic_only"):
        config["semantic_threshold"] = _prompt_float(
            "Semantic threshold (default 0.92): ", 0.92
        )
    if mode in ("or", "classifier_only"):
        config["classifier_threshold"] = _prompt_float(
            "Classifier threshold (default 0.80): ", 0.80
        )
    if mode == "fusion":
        config["semantic_weight"] = _prompt_float(
            "Semantic weight (default 0.3): ", 0.3
        )
        config["classifier_weight"] = _prompt_float(
            "Classifier weight (default 0.7): ", 0.7
        )
        config["fusion_threshold"] = _prompt_float(
            "Fusion threshold (default 0.75): ", 0.75
        )
    return config
def main():
    """Run the interactive analysis loop on stdin/stdout.

    Models are loaded once. Each round asks for a mode, its thresholds and a
    text, prints the label and both scores, then asks whether to continue.
    Any answer other than ``y`` ends the loop. An error during a round is
    printed and the round restarts. End-of-input on stdin ends the loop,
    because retrying ``input()`` on a closed stream would fail forever.
    """
    tokenizer, encoder, classifier, attack_bank, device = load_models()

    while True:
        try:
            mode = ask_mode()
            config = ask_thresholds(mode)
            text = input("\nEnter text to analyze:\n")
            result = predict(
                text,
                mode,
                config,
                tokenizer,
                encoder,
                classifier,
                attack_bank,
                device,
            )
            print("\n========== RESULT ==========")
            print("Label:", result["label"])
            print("Semantic Score:", result["semantic_score"])
            print("Classifier Score:", result["classifier_score"])
            again = input("\nAnalyze another? (y/n): ").strip().lower()
        except EOFError:
            # stdin is exhausted/closed: every further input() would raise
            # again, so leave instead of spinning on the generic handler.
            break
        except Exception as e:
            # Top-level boundary: report the problem and start a new round.
            print("Error:", e)
            continue

        if again != "y":
            break


if __name__ == "__main__":
    main()