"""Gradio demo that scores input text against the labels of a classifier.

The model's logits are passed through a sigmoid, so every label receives an
independent score in [0, 1]; labels at or above a user-chosen threshold are
shown, highest score first.
"""

# 1. Imports
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# 2. Constants
MODEL_ID = "phirni/iab-url-classifier"
# NOTE(review): the tokenizer is loaded from the base checkpoint rather than
# from MODEL_ID — presumably the fine-tuned repo ships no tokenizer files;
# confirm that the vocabularies match.
TOKENIZER_ID = "distilbert-base-uncased"

# 3. Load tokenizer + model (ONCE, at startup)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()


# 4. Inference function
def classify(text: str, threshold: float, top_n: int) -> dict:
    """Return the best-scoring labels for *text*.

    Args:
        text: URL or page text to classify.
        threshold: Minimum sigmoid score a label needs in order to be kept.
        top_n: Maximum number of labels to return. Non-integer values are
            truncated to an int and negative values are treated as 0.

    Returns:
        A dict mapping label name to score, ordered by descending score and
        holding at most ``top_n`` entries. Empty when *text* is blank or no
        label reaches *threshold*.
    """
    # Nothing to classify: skip the model call entirely.
    if text is None or not text.strip():
        return {}

    # UI sliders may hand over floats (e.g. 5.0); slice bounds must be ints,
    # and a negative bound would silently drop items from the end.
    limit = max(int(top_n), 0)

    inputs = tokenizer(
        text,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single text was passed in, so row 0 holds its per-label scores.
    probs = torch.sigmoid(logits)[0]

    # Compute the threshold mask in torch (same comparison as before), then
    # move to plain Python values once instead of indexing the tensor per item.
    keep_mask = (probs >= threshold).tolist()
    scores = probs.tolist()

    id2label = model.config.id2label
    label_scores = {
        id2label[index]: score
        for index, (score, keep) in enumerate(zip(scores, keep_mask))
        if keep
    }

    ranked = sorted(
        label_scores.items(),
        key=lambda item: item[1],
        reverse=True,
    )
    return dict(ranked[:limit])


# 5. Gradio UI (fn points to classify)
demo = gr.Interface(
    fn=classify,
    inputs=[
        gr.Textbox(label="URL or Page Text"),
        gr.Slider(0, 1, value=0.6, step=0.05, label="Confidence Threshold"),
        gr.Slider(1, 10, value=5, step=1, label="Top N Categories"),
    ],
    outputs=gr.Label(num_top_classes=10),
    title="IAB URL Classifier",
)

# 6. Launch
# Guarded so that importing this module does not start the web server.
if __name__ == "__main__":
    demo.launch()