|
|
from transformers import pipeline |
|
|
import torch |
|
|
|
|
|
# Import-time marker: confirms the module (and its heavyweight transformers
# import above) was loaded successfully.
print("context_llm module loaded (Zero-Shot BART)")


# Module-level singleton for the zero-shot pipeline. Populated lazily by
# load_context_model() so the model is only downloaded/loaded on first use.
classifier = None
|
|
|
|
|
def load_context_model():
    """
    Lazily load the Zero-Shot Classification pipeline into the module-level
    ``classifier`` singleton.

    Uses valhalla/distilbart-mnli-12-3 (a distilled BART-MNLI variant),
    placed on GPU when CUDA is available, otherwise CPU.

    On any failure the error is logged and ``classifier`` remains ``None``;
    callers (see get_context_probs) then fall back to uniform probabilities,
    so this function never raises.
    """
    global classifier

    # Already loaded — nothing to do.
    if classifier is not None:
        return

    try:
        # transformers convention: device index >= 0 selects a CUDA device,
        # -1 selects CPU.
        device = 0 if torch.cuda.is_available() else -1

        print("[LLM] Loading valhalla/distilbart-mnli-12-3 (Distilled) for context analysis...")
        classifier = pipeline(
            "zero-shot-classification",
            model="valhalla/distilbart-mnli-12-3",
            device=device,
        )
        print("[LLM] Context model loaded successfully.")
    except Exception as e:
        # Deliberate best-effort: log and leave classifier as None so the
        # application degrades gracefully instead of crashing on load.
        print(f"[LLM] CRITICAL ERROR: {e}")
|
|
|
|
|
def get_context_probs(text: str) -> list:
    """
    Analyze text against fixed hypotheses to determine deep context.

    Returns a list of four probabilities, always in this canonical order:
        0: "criticism of the government" (Political Criticism / Anti-Govt)
        1: "criticism of the country"    (National Criticism / Anti-India)
        2: "praise of the government"    (Political Praise / Pro-Govt)
        3: "praise of the country"       (National Praise / Pro-India)

    Falls back to a uniform [0.25, 0.25, 0.25, 0.25] when the model is
    unavailable or inference fails, so this function never raises.
    """
    # Lazy-load on first use; load_context_model is a no-op when loaded.
    if classifier is None:
        load_context_model()

    # Model could not be loaded — return an uninformative uniform prior.
    if classifier is None:
        return [0.25, 0.25, 0.25, 0.25]

    labels = [
        "criticism of the government",
        "criticism of the country",
        "praise of the government",
        "praise of the country",
    ]

    try:
        result = classifier(text, candidate_labels=labels, multi_label=False)

        # The pipeline returns labels sorted by descending score, so map
        # each score back to the canonical order declared above.
        score_map = dict(zip(result['labels'], result['scores']))
        return [score_map.get(label, 0.0) for label in labels]
    except Exception as e:
        print(f"[LLM] Inference failed: {e}")
        return [0.25, 0.25, 0.25, 0.25]
|
|
|