File size: 3,631 Bytes
53605cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model and tokenizer
# NOTE(review): from_pretrained downloads from the Hugging Face Hub on first
# run — requires network access / a warm local cache.
MODEL_ID = "LLM-Semantic-Router/halugate-sentinel"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout and other train-time behavior

# Label mapping
# Maps the model's output class index to a (display name, emoji) pair used
# throughout the UI.
LABELS = {
    0: ("NO_FACT_CHECK_NEEDED", "🟢"),
    1: ("FACT_CHECK_NEEDED", "🔴"),
}


def classify_text(text: str) -> tuple[str, dict]:
    """Classify whether a prompt needs external fact-checking.

    Args:
        text: The user-supplied prompt.

    Returns:
        A pair of (markdown summary string, {label: probability} dict
        suitable for a ``gr.Label`` component). Blank/whitespace-only
        input yields a help message and an empty dict.
    """
    # Guard clause: nothing meaningful to classify.
    if not text.strip():
        return "Please enter some text to classify.", {}

    # Encode the prompt and run a forward pass without a gradient graph.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        probabilities = torch.softmax(model(**encoded).logits, dim=-1)[0]

    # Winning class plus its display metadata.
    winner = torch.argmax(probabilities).item()
    name, icon = LABELS[winner]
    top_prob = probabilities[winner].item()

    summary = f"{icon} **{name}**\n\nConfidence: {top_prob:.1%}"

    # One probability entry per class, keyed by "<emoji> <label>".
    per_class = {
        f"{marker} {title}": float(probabilities[idx])
        for idx, (title, marker) in LABELS.items()
    }

    return summary, per_class


# Example prompts
# A mix of factual questions (which should be flagged FACT_CHECK_NEEDED) and
# creative / computational / opinion prompts (NO_FACT_CHECK_NEEDED).
# Each entry is a one-element list because gr.Examples expects one row of
# values per input component.
EXAMPLES = [
    ["When was the Eiffel Tower built?"],
    ["What is the population of Tokyo?"],
    ["Who invented the telephone?"],
    ["Write a poem about the ocean"],
    ["Can you help me debug this Python code?"],
    ["What do you think about modern art?"],
    ["What year did World War II end?"],
    ["Calculate 15 * 7 + 3"],
    ["Translate 'hello' to Spanish"],
    ["What is the current population of China?"],
]

# Create Gradio interface
# Layout: header markdown, a two-column row (input on the left, results on
# the right), clickable examples, and a footer with model links.
with gr.Blocks(title="HaluGate Sentinel - Fact Check Classifier") as demo:
    # Header / explanation copy shown above the controls.
    gr.Markdown(
        """
    # 🛡️ HaluGate Sentinel

    **Fact-Check Classifier** - Determines whether a prompt requires external factual verification.

    This model helps identify prompts that contain factual claims or questions that should be
    verified against authoritative sources to prevent hallucinations in LLM responses.

    - 🔴 **FACT_CHECK_NEEDED**: The prompt contains factual claims/questions that should be verified
    - 🟢 **NO_FACT_CHECK_NEEDED**: The prompt is creative, computational, or opinion-based
    """
    )

    with gr.Row():
        # Left column (wider): prompt entry and the trigger button.
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter a prompt to classify...",
                lines=4,
            )
            submit_btn = gr.Button("Classify", variant="primary")

        # Right column: classification verdict plus per-class probabilities.
        with gr.Column(scale=1):
            output_label = gr.Markdown(label="Classification Result")
            output_scores = gr.Label(label="Confidence Scores", num_top_classes=2)

    # Clickable example prompts. cache_examples=True precomputes outputs by
    # running classify_text on every example when the app is built.
    gr.Examples(
        examples=EXAMPLES,
        inputs=input_text,
        outputs=[output_label, output_scores],
        fn=classify_text,
        cache_examples=True,
    )

    # Classify on button click...
    submit_btn.click(
        fn=classify_text,
        inputs=input_text,
        outputs=[output_label, output_scores],
    )

    # ...and on pressing Enter in the textbox (same handler, same outputs).
    input_text.submit(
        fn=classify_text,
        inputs=input_text,
        outputs=[output_label, output_scores],
    )

    # Footer: model card link and architecture note.
    gr.Markdown(
        """
    ---
    **Model**: [LLM-Semantic-Router/halugate-sentinel](https://huggingface.co/LLM-Semantic-Router/halugate-sentinel)
    | **Architecture**: ModernBERT for Sequence Classification
    """
    )

if __name__ == "__main__":
    # Start the Gradio server when the file is run as a script.
    demo.launch()