File size: 2,212 Bytes
abc86a6
 
 
 
 
 
 
 
 
 
5c41baf
abc86a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import joblib
import numpy as np
from sentence_transformers import SentenceTransformer

# Location of the trained classifier artifact, resolved relative to this file.
MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    "models",
    "log_classifier.joblib",
)
# Minimum top-class probability required to accept a prediction.
CONFIDENCE_THRESHOLD = 0.3

# Model handles start empty and are filled in by _load_models() on first use,
# which keeps Spaces startup fast.
_classifier = None
_embedding_model = None


def _load_models():
    """Populate the module-level model handles if they are still empty.

    The embedding model is created first; the classifier is then read from
    MODEL_PATH with joblib. Handles that are already set are left alone.

    Raises:
        FileNotFoundError: if the classifier still has to be loaded and no
            file exists at MODEL_PATH.
    """
    global _embedding_model, _classifier

    if _embedding_model is None:
        _embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

    # Nothing more to do once the classifier has been loaded.
    if _classifier is not None:
        return

    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"Model not found at {MODEL_PATH}. "
            "Run the Colab training notebook first and upload log_classifier.joblib."
        )

    _classifier = joblib.load(MODEL_PATH)


def classify_with_bert(log_message: str) -> tuple[str, float]:
    """
    Tier 2: BERT embedding + Logistic Regression classifier.

    Embeds ``log_message`` with the lazily loaded embedding model, asks the
    classifier for per-class probabilities, and selects the most probable
    class.

    Returns:
        ``(label, confidence)`` where ``confidence`` is the highest class
        probability. The label is ``'Unclassified'`` when that probability is
        below ``CONFIDENCE_THRESHOLD``.

    Raises:
        FileNotFoundError: propagated from ``_load_models()`` when the
            classifier file is missing.

    Latency: ~20-80ms on CPU (original author's estimate).
    """
    _load_models()

    embedding = _embedding_model.encode([log_message])
    probabilities = _classifier.predict_proba(embedding)[0]
    best_index = int(np.argmax(probabilities))
    max_prob = float(probabilities[best_index])

    if max_prob < CONFIDENCE_THRESHOLD:
        return "Unclassified", max_prob

    # predict_proba columns are ordered like classes_, so indexing classes_
    # with the argmax yields the label that the reported confidence belongs
    # to, without running a second inference pass through predict().
    return _classifier.classes_[best_index], max_prob


def get_classes() -> list[str]:
    """List the labels exposed by the loaded classifier's ``classes_``."""
    _load_models()
    known_labels = _classifier.classes_
    return [*known_labels]


if __name__ == "__main__":
    # Manual smoke test: classify a handful of sample log lines and print the
    # confidence, the label, and the first 70 characters of each message.
    sample_messages = (
        "GET /v2/servers/detail HTTP/1.1 status: 404 len: 1583 time: 0.19",
        "System crashed due to driver errors when restarting the server",
        "Multiple login failures occurred on user 6454 account",
        "Admin access escalation detected for user 9429",
        "CPU usage at 98% for the last 10 minutes on node-7",
        "Hey bro chill ya!",   # should be Unclassified
    )
    for message in sample_messages:
        predicted, confidence = classify_with_bert(message)
        preview = message[:70]
        print(f"[{confidence:.0%}] {predicted:25s} | {preview}")