"""Gradio demo: AG News topic classification with a fine-tuned BERT.

Loads a fine-tuned sequence-classification model from the Hugging Face
Model Hub, serves predictions through a Gradio interface, and highlights
the input words that most influenced each prediction using
transformers-interpret attributions.
"""

import random

import numpy as np
import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline,
)
from transformers_interpret import SequenceClassificationExplainer

# ---------------------------
# Configuration
# ---------------------------
SEED = 42
MODEL_ID = "sosohrabian/my-fine-tuned-bert"  # your model on the Model Hub

# ---------------------------
# Reproducibility
# ---------------------------
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

USE_MPS = torch.backends.mps.is_available()
device = torch.device("mps" if USE_MPS else "cpu")
print("Using device:", device)

# ---------------------------
# Load model and tokenizer from the Model Hub
# ---------------------------
print("Loading model from:", MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

# ---------------------------
# Pipeline and explainer
# ---------------------------
# AG News label indices -> human-readable topic names.
label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

# transformers' pipeline takes a CUDA device index (-1 = CPU).
# NOTE(review): the MPS `device` computed above is only printed, not used by
# the pipeline — confirm whether MPS inference was intended here.
device_index = 0 if torch.cuda.is_available() else -1
clf = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,  # return scores for every class, not just the top one
    device=device_index,
)
explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)


def predict(text: str) -> dict:
    """Return ``{label_name: probability}`` for *text*, sorted by score.

    Returns an empty dict for empty/whitespace-only input.
    """
    text = (text or "").strip()
    if not text:
        return {}
    out = clf(text, truncation=True)
    # With top_k=None the pipeline may wrap the per-class dicts in an
    # extra list; unwrap it.
    if isinstance(out, list) and isinstance(out[0], list):
        out = out[0]
    results = {}
    for o in sorted(out, key=lambda x: -x["score"]):
        # Pipeline labels look like "LABEL_3"; map the index to a name.
        idx = int(o["label"].split("_")[1])
        results[label_names[idx]] = float(o["score"])
    return results


def explain_html(text: str) -> str:
    """Return an HTML snippet highlighting influential tokens in *text*.

    Attribution magnitudes are min-max normalized to [0, 1] and used as
    the opacity of each token's background colour, so more influential
    words appear more strongly highlighted.
    """
    text = (text or "").strip()
    if not text:
        return "Enter text to see highlighted words."
    atts = explainer(text)  # list of (token, attribution_score) pairs
    toks = [t for t, _ in atts]
    scores = np.abs([s for _, s in atts])
    # Min-max normalize; the epsilon avoids division by zero when all
    # attribution scores are identical.
    smin, smax = float(np.min(scores)), float(np.max(scores))
    scores = (scores - smin) / (smax - smin + 1e-8)
    # NOTE(review): the span/div markup below was reconstructed — the
    # original HTML tags were stripped from this file. Adjust the styling
    # to taste; the normalized score drives the background opacity.
    spans = [
        f'<span style="background-color: rgba(255, 165, 0, {s:.2f}); '
        f'padding: 1px 2px; border-radius: 3px;">{tok}</span>'
        for tok, s in zip(toks, scores)
    ]
    return '<div style="line-height: 2;">' + " ".join(spans) + "</div>"


def predict_and_explain(text: str):
    """Run both prediction and explanation (Gradio callback)."""
    return predict(text), explain_html(text)


# ---------------------------
# Gradio App
# ---------------------------
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Textbox(lines=3, label="Enter news headline"),
    outputs=[
        gr.Label(num_top_classes=4, label="Predicted topic"),
        gr.HTML(label="Important-word highlights"),
    ],
    title="AG News Topic Classifier (Fine-tuned BERT)",
    description="Classifies news headlines and highlights words that influenced the prediction.",
    theme="default",
    examples=[
        ["Apple unveils new iPhone during annual event"],
        ["The stock market saw major gains today"],
        ["Scientists discover new exoplanet"],
        ["The local team wins the championship"],
    ],
)

if __name__ == "__main__":
    demo.launch()