File size: 2,380 Bytes
bbd259b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from transformers import pipeline
import torch

print("context_llm module loaded (Zero-Shot BART)")

# Global pipeline variable
classifier = None

def load_context_model():
    """
    Lazy load the Zero-Shot Classification pipeline.
    Uses facebook/bart-large-mnli.
    """
    global classifier
    if classifier is not None:
        return

    try:
        # Use CPU by default to be safe on Windows, or cuda if available
        device = 0 if torch.cuda.is_available() else -1
        
        print("[LLM] Loading valhalla/distilbart-mnli-12-3 (Distilled) for context analysis...")
        classifier = pipeline(
            "zero-shot-classification",
            model="valhalla/distilbart-mnli-12-3",
            device=device
        )
        print("[LLM] Context model loaded successfully.")
    except Exception as e:
        print(f"[LLM] CRITICAL ERROR: {e}")
        # non-fatal, will just return neutral scores
        pass

def get_context_probs(text: str) -> list:
    """
    Analyzes text against specific hypotheses to determine deep context.
    Returns probabilities for:
    [
      0: "Political Criticism" (Anti-Govt),
      1: "National Criticism" (Anti-India), 
      2: "Political Praise" (Pro-Govt),
      3: "National Praise" (Pro-India)
    ]
    """
    # Lazy load
    if classifier is None:
        load_context_model()

    if classifier is None:
        # Fallback if model failed to load
        return [0.25, 0.25, 0.25, 0.25]

    labels = [
        "criticism of the government",   # 0
        "criticism of the country",      # 1
        "praise of the government",      # 2
        "praise of the country"          # 3
    ]

    try:
        result = classifier(text, candidate_labels=labels, multi_label=False)
        
        # Result has 'labels' and 'scores' sorted by score descending.
        # We need to map them back to our fixed order [0, 1, 2, 3]
        
        score_map = {label: score for label, score in zip(result['labels'], result['scores'])}
        
        ordered_scores = [
            score_map.get(labels[0], 0.0),
            score_map.get(labels[1], 0.0),
            score_map.get(labels[2], 0.0),
            score_map.get(labels[3], 0.0)
        ]
        
        return ordered_scores

    except Exception as e:
        print(f"[LLM] Inference failed: {e}")
        return [0.25, 0.25, 0.25, 0.25]