# HuggingFace Spaces app: Narrative Stance Analyzer (SBERT + DeBERTa ensemble)
| import os | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import joblib | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from torch.nn.functional import softmax | |
| from huggingface_hub import hf_hub_download | |
# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ========== Load Models ==========

# ---- SBERT encoder + MLP classifier head ----
# The SentenceTransformer weights and the pickled MLP classifier
# (sbert_mlp.pkl, at the repo root) live in the same Hub repo.
sbert = SentenceTransformer("Sbhatti33/sbert_model")
clf_path = hf_hub_download(repo_id="Sbhatti33/sbert_model", filename="sbert_mlp.pkl")
clf = joblib.load(clf_path)

# ---- DeBERTa tokenizer + sequence-classification model ----
deberta_tokenizer = AutoTokenizer.from_pretrained(
    "Sbhatti33/deberta_model",
    use_fast=False,  # the fast tokenizer crashes for this repo; keep the slow one
)
deberta_model = AutoModelForSequenceClassification.from_pretrained(
    "Sbhatti33/deberta_model"
).to(device)
deberta_model.eval()  # inference only: disable dropout etc.
# ========== Class Names ==========
# Index order must match the classifiers' output columns.
class_names = ["believer", "neutral", "denier"]
label_map = dict(enumerate(class_names))
# ========== Prediction Function ==========
def analyze_stance(text, sbert_weight=0.6, deberta_weight=0.4):
    """Classify the climate stance of *text* with an SBERT+DeBERTa ensemble.

    Each model produces a 3-way probability distribution over
    (believer, neutral, denier); the distributions are combined as a
    weighted average and the argmax is reported.

    Args:
        text: Input sentence/message to classify.
        sbert_weight: Ensemble weight for the SBERT+MLP probabilities
            (default 0.6, the original hard-coded value).
        deberta_weight: Ensemble weight for the DeBERTa probabilities
            (default 0.4).

    Returns:
        Dict with the predicted label under "Predicted Stance" and the
        three ensemble probabilities formatted as percentage strings.
    """
    # --- SBERT embedding + MLP probabilities ---
    # NOTE(review): assumes clf.predict_proba column order is
    # [believer, neutral, denier] matching label_map — verify via clf.classes_.
    sbert_emb = sbert.encode([text])
    probs_sbert = clf.predict_proba(sbert_emb)[0]

    # --- DeBERTa probabilities ---
    tokens = deberta_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    ).to(device)
    with torch.no_grad():
        outputs = deberta_model(**tokens)
    probs_deberta = softmax(outputs.logits, dim=-1).cpu().numpy()[0]

    # --- Weighted ensemble ---
    final_probs = sbert_weight * probs_sbert + deberta_weight * probs_deberta
    pred_idx = int(np.argmax(final_probs))
    pred_label = label_map[pred_idx]

    # Format response: predicted label plus per-class percentages.
    result = {"Predicted Stance": pred_label}
    for name, prob in zip(class_names, final_probs):
        result[name.capitalize()] = f"{round(float(prob), 4):.2%}"
    return result
# ========== Gradio Interface ==========
# BUG FIX: analyze_stance returns a dict of strings ("Predicted Stance" plus
# "61.00%"-style percentages), but gr.Label requires either a single string or
# a {label: float} confidence mapping — feeding it this payload raises a
# ValueError at request time. gr.JSON renders the returned dict as-is.
demo = gr.Interface(
    fn=analyze_stance,
    inputs=gr.Textbox(label="Enter climate-related sentence or message"),
    outputs=gr.JSON(label="Predicted Narrative Stance"),
    title="Narrative Stance Analyzer",
    description="Classifies a message as believer, neutral, or denier using an ensemble of SBERT and DeBERTa models.",
)

if __name__ == "__main__":
    demo.launch()