phirni's picture
Update app.py
0f3f7a4 verified
# 1. Imports
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# 2. Constants
# Hub repository that holds the fine-tuned classification weights.
MODEL_ID = "phirni/iab-url-classifier"

# 3. Tokenizer + model, loaded a single time when the module is imported
# NOTE(review): the tokenizer is pulled from the base DistilBERT checkpoint
# rather than MODEL_ID — presumably the fine-tuned repo ships no tokenizer
# files; confirm before changing.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# nn.Module.eval() returns the module itself, so the load and the switch to
# inference mode can be written as one expression.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).eval()
# 4. Inference function
def classify(text, threshold, top_n):
    """Score ``text`` against every model label and return the best matches.

    Each logit is passed through a sigmoid independently, so every score is
    a per-label value in [0, 1] and the scores do not sum to 1.

    Args:
        text: URL or page text to classify.
        threshold: Minimum score a label needs in order to be kept.
        top_n: Maximum number of labels to return. UI sliders may deliver
            this as a float (e.g. ``5.0``), so it is coerced to ``int``.
            ``None`` means "no limit".

    Returns:
        A dict mapping label name to score, ordered from highest to lowest
        score and holding at most ``top_n`` entries. It is empty when no
        label reaches ``threshold``.
    """
    # A float is not a valid slice bound, and a negative bound would silently
    # drop entries from the end instead of limiting the count.
    limit = None if top_n is None else max(int(top_n), 0)

    inputs = tokenizer(
        text,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # Only one text is scored per call, so row 0 is the whole result.
    # Converting to Python floats once avoids a tensor op per label below.
    scores = torch.sigmoid(logits)[0].tolist()
    id2label = model.config.id2label

    label_scores = {
        id2label[index]: score
        for index, score in enumerate(scores)
        if score >= threshold
    }
    ranked = sorted(
        label_scores.items(),
        key=lambda item: item[1],
        reverse=True,
    )
    return dict(ranked[:limit])
# 5. Gradio UI wired to classify()
# Input widgets, in the same order as classify()'s parameters:
# text, threshold, top_n.
_input_components = [
    gr.Textbox(label="URL or Page Text"),
    gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top N Categories"),
]

# The label widget shows at most 10 classes, matching the upper bound of the
# "Top N Categories" slider.
_output_component = gr.Label(num_top_classes=10)

demo = gr.Interface(
    fn=classify,
    inputs=_input_components,
    outputs=_output_component,
    title="IAB URL Classifier",
)

# 6. Start the web server
demo.launch()