# Sbhatti33's picture
# Update app.py
# 8dde0de verified
import os
import gradio as gr
import torch
import numpy as np
import joblib
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
from huggingface_hub import hf_hub_download
# Inference device: use CUDA when it is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ========== Load Models ==========
# Hugging Face Hub repo ids for the two ensemble members.
SBERT_REPO_ID = "Sbhatti33/sbert_model"
DEBERTA_REPO_ID = "Sbhatti33/deberta_model"

# SBERT sentence encoder, loaded from the root of its Hub repo.
sbert = SentenceTransformer(SBERT_REPO_ID)

# The classifier applied on top of the SBERT embeddings is a joblib file
# stored in the same repo as the encoder.
# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only load from a repo you control (consider pinning `revision=` in
# hf_hub_download so a later push cannot swap the artifact).
clf_path = hf_hub_download(repo_id=SBERT_REPO_ID, filename="sbert_mlp.pkl")
clf = joblib.load(clf_path)

# ========== Load DeBERTa Model + Tokenizer ==========
# use_fast=False selects the slow tokenizer implementation. It is kept on
# purpose: reportedly the fast tokenizer crashed with this checkpoint.
# NOTE(review): confirm this is still required.
deberta_tokenizer = AutoTokenizer.from_pretrained(DEBERTA_REPO_ID, use_fast=False)
deberta_model = AutoModelForSequenceClassification.from_pretrained(DEBERTA_REPO_ID)
deberta_model.to(device)  # nn.Module.to moves parameters in place
deberta_model.eval()  # inference mode: disables dropout
# ========== Class Names ==========
# Stance names in class-index order; label_map is the index -> name lookup.
# NOTE(review): this order is assumed to match both the SBERT classifier's
# predict_proba columns and the DeBERTa logits. Confirm against the
# training-time label encoding.
class_names = ["believer", "neutral", "denier"]
label_map = dict(enumerate(class_names))
# ========== Prediction Function ==========
def analyze_stance(text):
    """Classify the narrative stance of ``text`` with an SBERT + DeBERTa ensemble.

    The SBERT sentence embedding is scored by the joblib-loaded classifier
    (``clf.predict_proba``). The DeBERTa sequence-classifier logits are
    softmaxed. The two probability vectors are then blended with weights
    0.6 (SBERT) and 0.4 (DeBERTa).

    Args:
        text: The message to classify.

    Returns:
        A ``{class_name: probability}`` dict over ``class_names``. This is the
        confidence format ``gr.Label`` expects; it displays the highest entry
        as the predicted label. Returns ``None`` for empty or whitespace-only
        input, so no prediction is produced for blank text.
    """
    # Nothing meaningful to classify; don't run both models on blank input.
    if text is None or not text.strip():
        return None

    # --- SBERT embedding + classifier probabilities ---
    # NOTE(review): predict_proba columns follow clf.classes_. This assumes
    # they are in the same order as class_names; confirm.
    sbert_emb = sbert.encode([text])
    probs_sbert = clf.predict_proba(sbert_emb)[0]

    # --- DeBERTa probabilities ---
    tokens = deberta_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    ).to(device)
    with torch.no_grad():
        logits = deberta_model(**tokens).logits
    probs_deberta = softmax(logits, dim=-1)[0].cpu().numpy()

    # --- Weighted ensemble ---
    final_probs = 0.6 * probs_sbert + 0.4 * probs_deberta

    # gr.Label needs numeric confidences keyed by label. Pre-formatted
    # percentage strings plus a separate "Predicted Stance" text entry are not
    # valid confidences for it.
    return {name: float(prob) for name, prob in zip(class_names, final_probs)}
# ========== Gradio Interface ==========
# One free-text input; the Label component renders analyze_stance's result.
_stance_input = gr.Textbox(label="Enter climate-related sentence or message")
_stance_output = gr.Label(label="Predicted Narrative Stance", num_top_classes=3)
_app_description = "Classifies a message as believer, neutral, or denier using an ensemble of SBERT and DeBERTa models."

demo = gr.Interface(
    fn=analyze_stance,
    inputs=_stance_input,
    outputs=_stance_output,
    title="Narrative Stance Analyzer",
    description=_app_description,
)


def _main():
    """Start the Gradio app server."""
    demo.launch()


if __name__ == "__main__":
    _main()