|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) |
|
|
sys.path.insert(0, ROOT_DIR) |
|
|
|
|
|
from src.language_detection import detect_language |
|
|
from src.preprocessing import clean_text |
|
|
from src.predict import predict |
|
|
from src.feature_builder import build_features |
|
|
from src.anchor_similarity import compute_similarity |
|
|
from src.embeddings import embedder |
|
|
from src.sarcasm import sarcasm_score |
|
|
from src.sentiment import sentiment_scores |
|
|
from src.translation import translate_to_english |
|
|
from src.context_llm import get_context_probs |
|
|
|
|
|
|
|
|
# Two-letter language codes this project targets.
# NOTE(review): nothing in the visible code reads this set — confirm whether
# classify() is expected to gate on it.
SUPPORTED_LANGS = {
    "en", "hi", "ta", "ur", "bn",
    "te", "ml", "gu", "kn", "mr",
}

# Human-readable stance names. classify() indexes this list with the label
# index returned by predict(), so the order must not change.
LABELS = [
    "Pro-India", "Anti-India",
    "Pro-Government", "Anti-Government",
    "Neutral",
]
|
|
|
|
|
def init_anchors():
    """
    Read the anchor sentences under data/anchors/, embed them, and hand the
    result to the anchor_similarity module.

    One ``<stance>.txt`` file is expected per stance, with one example per
    line; blank lines are ignored. A stance whose file is missing or holds no
    usable lines is reported with a warning and left out of the mapping passed
    to ``load_anchor_embeddings``.
    """
    print("[INIT] Loading anchor embeddings...")
    anchors_root = os.path.join(ROOT_DIR, "data", "anchors")

    anchors_by_stance = {}
    for stance in ("pro_india", "anti_india", "pro_government", "anti_government", "neutral"):
        path = os.path.join(anchors_root, f"{stance}.txt")
        if not os.path.exists(path):
            print(f"[WARNING] Anchor file missing: {path}")
            continue

        with open(path, "r", encoding="utf-8") as handle:
            stripped = (raw.strip() for raw in handle)
            sentences = [sentence for sentence in stripped if sentence]

        if not sentences:
            print(f"[WARNING] Anchor file empty: {stance}")
            continue

        # One encode() call per stance; the returned value is stored as-is.
        # NOTE(review): classify() encodes with normalize_embeddings=True but
        # this call uses the encoder defaults — confirm compute_similarity
        # is insensitive to that difference.
        anchors_by_stance[stance] = embedder.encode(sentences)
        print(f" - Loaded {stance}: {len(sentences)} examples")

    # Imported here, as in the original, rather than at module level.
    from src.anchor_similarity import load_anchor_embeddings

    load_anchor_embeddings(anchors_by_stance)
    print("[INIT] Anchor embeddings initialized.\n")
|
|
|
|
|
def classify(text: str):
    """
    Run the stance-classification pipeline on a single post.

    The text is cleaned, its language is detected, and anything not detected
    as ``'en'`` is translated to English. The English text is then embedded
    and scored for anchor similarity, sentiment, sarcasm and context
    probabilities; those signals are combined by ``build_features`` and
    passed to ``predict``.

    Args:
        text: Raw post text.

    Returns:
        ``{"error": "Empty input text"}`` when the cleaned text is blank.
        Otherwise a dict with the cleaned ``text``, the predicted ``label``
        (an entry of ``LABELS``), ``confidence``, detected ``language``,
        ``sarcasm_score`` and a ``sentiment`` dict with ``negative`` /
        ``neutral`` / ``positive`` entries. Numeric values are rounded to
        three decimals.
    """
    cleaned = clean_text(text)
    if not cleaned.strip():
        return {"error": "Empty input text"}

    lang, prob = detect_language(cleaned)
    print(f"[DEBUG] Detected language: {lang}, confidence: {round(prob, 3)}")

    # Everything downstream runs on English text; the cleaned original is
    # only echoed back in the result.
    english_text = cleaned
    if lang != 'en':
        print(f"[INFO] Translating {lang} to en...")
        english_text = translate_to_english(cleaned, source=lang)
        print(f" -> {english_text}")

    vector = embedder.encode(english_text, normalize_embeddings=True)

    # anchor_embeddings=None: presumably makes compute_similarity fall back to
    # the embeddings registered by init_anchors() — confirm in that module.
    similarity_scores = compute_similarity(
        text_embedding=vector,
        anchor_embeddings=None,
    )

    sentiment = sentiment_scores(english_text)
    sarcasm = sarcasm_score(english_text)
    context_probs = get_context_probs(english_text)

    feature_vector = build_features(
        similarity=similarity_scores,
        sentiment=sentiment,
        sarcasm=sarcasm,
        context_probs=context_probs,
    )
    label_idx, confidence = predict(feature_vector)

    return {
        "text": cleaned,
        "label": LABELS[label_idx],
        "confidence": round(confidence, 3),
        "language": lang,
        "sarcasm_score": round(sarcasm, 3),
        # sentiment is read positionally: 0 = negative, 1 = neutral, 2 = positive.
        "sentiment": {
            name: round(sentiment[position], 3)
            for position, name in enumerate(("negative", "neutral", "positive"))
        },
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Anchor embeddings must be registered before the first classify() call.
    init_anchors()

    # Optional batch mode: if a test.txt exists in the current working
    # directory, classify every non-blank line before going interactive.
    if os.path.exists("test.txt"):
        print("Processing test.txt...")
        # Explicit UTF-8 (matching the anchor files): posts may be in Hindi,
        # Tamil, Urdu, etc., and the platform default encoding cannot be
        # relied on to decode them.
        with open("test.txt", "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    result = classify(line)
                    print(result)
                    print("-" * 50)

    print("\n🔍 Reddit Political Stance Classifier")
    print("Type 'exit' to quit\n")

    # Interactive loop: one post per prompt until the user types 'exit'.
    while True:
        try:
            text = input("Enter Reddit post: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin was closed (e.g. piped input ran out) or the user pressed
            # Ctrl-C: leave the loop instead of dying with a traceback.
            print()
            break

        if text.lower() == "exit":
            break

        result = classify(text)
        print("\nResult:")
        print(result)
        print("-" * 50)
|
|
|