Spaces:
Running
Running
File size: 3,631 Bytes
53605cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load model and tokenizer once at import time (downloads from the HF Hub on
# first run, then served from the local cache).
MODEL_ID = "LLM-Semantic-Router/halugate-sentinel"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout / training-only behavior
# Label mapping: class index -> (display name, status emoji).
# Index order must match the model head's output logits.
LABELS = {
    0: ("NO_FACT_CHECK_NEEDED", "🟢"),
    1: ("FACT_CHECK_NEEDED", "🔴"),
}
def classify_text(text: str) -> tuple[str, dict]:
    """Classify whether a prompt needs external fact-checking.

    Args:
        text: The user prompt to classify.

    Returns:
        A ``(markdown, scores)`` pair: a markdown string naming the winning
        label with its emoji and confidence, and a dict mapping each
        "<emoji> <label>" display key to its softmax probability (empty when
        the input is blank).
    """
    if not text.strip():
        return "Please enter some text to classify.", {}

    # Encode and run one forward pass with autograd disabled.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = torch.softmax(logits, dim=-1)[0]

    # Winning class and its probability.
    predicted = int(torch.argmax(probabilities))
    name, icon = LABELS[predicted]
    confidence = probabilities[predicted].item()
    verdict = f"{icon} **{name}**\n\nConfidence: {confidence:.1%}"

    # Per-class scores keyed the way the gr.Label widget expects.
    scores = {
        f"{icon_} {name_}": float(probabilities[idx])
        for idx, (name_, icon_) in LABELS.items()
    }
    return verdict, scores
# Example prompts for the gr.Examples widget below. Each entry is a
# one-element row because the interface has a single input component.
# The list mixes factual questions with creative/computational/opinion
# prompts so both labels are represented.
EXAMPLES = [
    ["When was the Eiffel Tower built?"],
    ["What is the population of Tokyo?"],
    ["Who invented the telephone?"],
    ["Write a poem about the ocean"],
    ["Can you help me debug this Python code?"],
    ["What do you think about modern art?"],
    ["What year did World War II end?"],
    ["Calculate 15 * 7 + 3"],
    ["Translate 'hello' to Spanish"],
    ["What is the current population of China?"],
]
# Build the Gradio UI: header, input column, result column, examples, footer.
with gr.Blocks(title="HaluGate Sentinel - Fact Check Classifier") as demo:
    gr.Markdown(
        """
# 🛡️ HaluGate Sentinel
**Fact-Check Classifier** - Determines whether a prompt requires external factual verification.
This model helps identify prompts that contain factual claims or questions that should be
verified against authoritative sources to prevent hallucinations in LLM responses.
- 🔴 **FACT_CHECK_NEEDED**: The prompt contains factual claims/questions that should be verified
- 🟢 **NO_FACT_CHECK_NEEDED**: The prompt is creative, computational, or opinion-based
"""
    )

    with gr.Row():
        # Left: prompt entry + trigger button.
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter a prompt to classify...",
                lines=4,
            )
            classify_button = gr.Button("Classify", variant="primary")
        # Right: markdown verdict + per-class probability bars.
        with gr.Column(scale=1):
            result_md = gr.Markdown(label="Classification Result")
            score_label = gr.Label(label="Confidence Scores", num_top_classes=2)

    # Clickable examples; results are precomputed and cached at startup.
    gr.Examples(
        examples=EXAMPLES,
        inputs=prompt_box,
        outputs=[result_md, score_label],
        fn=classify_text,
        cache_examples=True,
    )

    # Same handler fires on the button click and on Enter in the textbox.
    for trigger in (classify_button.click, prompt_box.submit):
        trigger(
            fn=classify_text,
            inputs=prompt_box,
            outputs=[result_md, score_label],
        )

    gr.Markdown(
        """
---
**Model**: [LLM-Semantic-Router/halugate-sentinel](https://huggingface.co/LLM-Semantic-Router/halugate-sentinel)
| **Architecture**: ModernBERT for Sequence Classification
"""
    )

if __name__ == "__main__":
    demo.launch()
|