Spaces:
Sleeping
Sleeping
| from transformers_interpret import SequenceClassificationExplainer | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
| import torch, gradio as gr | |
| import numpy as np | |
MODEL_ID = "sosohrabian/my-fine-tuned-bert"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)

# Display names for the AG News topics, keyed by the integer class index that
# predict() parses out of the pipeline's "LABEL_<i>" label strings.
label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

# transformers pipeline convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1
clf = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)

# Build the explainer only AFTER the pipeline. The pipeline moves the shared
# `model` to `device` in place, while the explainer records the model's device
# when it is constructed (NOTE(review): per transformers-interpret's
# BaseExplainer — confirm for the pinned version). Creating it first would
# leave it building CPU input tensors for a CUDA-resident model on GPU hosts.
explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
def predict(text: str):
    """Classify ``text`` and return a ``{topic_name: probability}`` dict.

    Blank or ``None`` input returns ``{}`` without calling the model.
    Otherwise the module-level ``clf`` pipeline (built with ``top_k=None``)
    scores the text, and entries are inserted highest score first.
    Pipeline labels of the form ``"LABEL_<i>"`` are mapped through
    ``label_names``. Any other label, or an index missing from
    ``label_names``, is used verbatim instead of raising.
    """
    text = (text or "").strip()
    if not text:
        return {}
    out = clf(text, truncation=True)
    # For a single string the pipeline may return [[{...}, ...]] instead of
    # [{...}, ...]; unwrap the outer list. `out and` avoids indexing an
    # empty list.
    if out and isinstance(out[0], list):
        out = out[0]
    results = {}
    for entry in sorted(out, key=lambda e: e["score"], reverse=True):
        label = entry["label"]
        prefix, _, index = label.rpartition("_")
        if prefix == "LABEL" and index.isdigit():
            name = label_names.get(int(index), label)
        else:
            # e.g. a model config whose id2label already holds readable names
            name = label
        results[name] = float(entry["score"])
    return results
# Build script-free HTML (inline styles only, no <script>) so it renders in
# Gradio's HTML output component.
def explain_html(text: str) -> str:
    """Render ``text``'s token attributions as red-shaded HTML spans.

    Each (token, attribution) pair from the module-level ``explainer`` becomes
    an inline ``<span>``. The span's red alpha grows with the absolute
    attribution, which is min-max scaled across this input and then mapped
    into 0.15-1.0, so even the least important token stays faintly shaded.
    Token text is HTML-escaped because it is derived from user input.
    Blank or ``None`` input returns a short placeholder message instead.
    """
    import html  # stdlib; escapes token text before it is embedded in markup

    text = (text or "").strip()
    if not text:
        return "<i>Enter text to see highlighted words.</i>"
    atts = explainer(text)  # list of (token, attribution) pairs
    toks = [html.escape(tok) for tok, _ in atts]
    scores = np.abs([s for _, s in atts])
    if scores.size:  # np.min / np.max raise on an empty array
        smin, smax = float(np.min(scores)), float(np.max(scores))
        # The epsilon avoids division by zero when every score is equal.
        scores = (scores - smin) / (smax - smin + 1e-8)
    spans = [
        f"<span style='background: rgba(255,0,0,{0.15 + 0.85 * s:.2f});"
        f"padding:2px 3px; margin:1px; border-radius:4px; display:inline-block'>{tok}</span>"
        for tok, s in zip(toks, scores)
    ]
    return "<div style='line-height:2'>" + " ".join(spans) + "</div>"
def predict_and_explain(text: str):
    """Produce both Gradio outputs for one input string.

    Returns a ``(label_scores, highlight_html)`` tuple, in the same order as
    the interface's output components. The first item is the dict from
    ``predict`` and the second is the HTML string from ``explain_html``.
    """
    label_scores = predict(text)
    highlight_html = explain_html(text)
    return label_scores, highlight_html
# Gradio UI: a headline textbox in; the topic-score Label and the attribution
# HTML out, matching the (scores, html) tuple predict_and_explain returns.
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Textbox(lines=3, label="Enter news headline"),
    outputs=[
        # Show every known topic rather than a hard-coded count.
        gr.Label(num_top_classes=len(label_names), label="Predicted topic"),
        gr.HTML(label="Important-word highlights"),
    ],
    title="AG News Topic Classifier (BERT-base)",
    description="Shows predicted topic and highlights words that influenced the decision.",
)

if __name__ == "__main__":
    # Start the server only when run as a script, so merely importing this
    # module does not launch it. NOTE(review): Hugging Face Spaces presumably
    # runs `python app.py` for Gradio Spaces — confirm for this Space.
    # NOTE(review): Gradio does not support share links on Hugging Face Spaces
    # (it warns and ignores share=True there). It is kept so local runs still
    # get a public link — confirm that exposure is intended.
    demo.launch(share=True)