Youhorng
add gradio spaces
19c2163
# Import libraries
import json
from pathlib import Path
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
# Artifact locations. Resolve them relative to this file rather than the current
# working directory, so the app works regardless of where it is launched from.
# With a bare relative Path("model"), launching from another directory means no
# local folder is found and `from_pretrained` falls back to treating "model" as a
# Hugging Face Hub repo id.
BASE_DIR = Path(__file__).resolve().parent
MODEL_DIR = BASE_DIR / "model"
META_PATH = BASE_DIR / "preprocess_meta.json"
THRESH_PATH = MODEL_DIR / "threshold_global.json"
# Load preprocessing metadata: the label -> id mapping and the tokenizer
# max_length used at inference time. JSON is UTF-8 by spec, so decode it
# explicitly instead of relying on the platform's locale encoding.
with META_PATH.open(encoding="utf-8") as f:
    meta = json.load(f)
label2id = meta["label2id"]
# int() normalises ids in case they were serialised as strings.
id2label = {int(i): label for label, i in label2id.items()}
# labels[k] must name the k-th model output, so the ids must be exactly
# 0..n-1 with no duplicates (a duplicate id would silently drop a label).
if sorted(id2label) != list(range(len(label2id))):
    raise ValueError(
        f"{META_PATH}: label ids must be unique and contiguous "
        f"0..{len(label2id) - 1}, got {sorted(int(i) for i in label2id.values())}"
    )
labels = [id2label[i] for i in range(len(id2label))]
max_length = meta["max_length"]

# Single decision threshold applied to every label's sigmoid probability.
with THRESH_PATH.open(encoding="utf-8") as f:
    GLOBAL_THRESHOLD = json.load(f)["global_threshold"]
# Load tokenizer + classifier. If MODEL_DIR carries a PEFT/LoRA adapter
# (adapter_config.json), wrap the base model with it; otherwise the weights
# are expected to be merged into the base checkpoint already.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# NOTE(review): trust_remote_code=True executes any custom modeling code shipped
# in MODEL_DIR; keep it only if this checkpoint actually requires it.
base_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_DIR, trust_remote_code=True
)
# Decide by the presence of the adapter config rather than by catching
# ValueError around PeftModel.from_pretrained: that broad catch would also
# swallow genuine adapter-loading errors and silently serve the un-adapted model.
if (MODEL_DIR / "adapter_config.json").is_file():
    model = PeftModel.from_pretrained(base_model, MODEL_DIR, is_trainable=False)
else:
    # No adapter present: already merged.
    model = base_model
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def preprocess(subject: str, body: str) -> str:
    """Combine an email's subject and body into a single model input string.

    The tokenizer's separator token sits between the two parts when the
    tokenizer defines one; otherwise a plain space is used. Both parts are
    stripped, and the joined result is stripped again so an empty subject or
    body leaves no leading/trailing whitespace.
    """
    joiner = tokenizer.sep_token or " "
    pieces = (subject.strip(), joiner, body.strip())
    return " ".join(pieces).strip()
def predict_intents(subject: str, body: str):
    """Score an email against every intent label and apply the global threshold.

    Args:
        subject: Email subject line. ``None`` is treated as an empty string.
        body: Email body text. ``None`` is treated as an empty string.

    Returns:
        A dict for the JSON output component with two keys:
        ``"Predicted labels"`` -- labels whose sigmoid probability is
        ``>= GLOBAL_THRESHOLD`` (or a single placeholder message if none are),
        and ``"Scores"`` -- a mapping of every label to its probability.
        If both inputs are empty or whitespace-only, the model is not run and
        empty labels/scores are returned.
    """
    # Normalise here so None (e.g. from API callers) never reaches .strip()
    # inside preprocess, and whitespace-only input counts as empty instead of
    # classifying a lone separator token.
    subject = (subject or "").strip()
    body = (body or "").strip()
    if not subject and not body:
        return {"Predicted labels": [], "Scores": {}}

    text = preprocess(subject, body)
    inputs = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        # Logits are (batch=1, num_labels); take the single row explicitly.
        # A bare .squeeze() would also collapse the label axis when
        # num_labels == 1, leaving a 0-d tensor that cannot be zipped with labels.
        logits = model(**inputs).logits[0].float()
    # Independent sigmoid per label (multi-label), not a softmax over labels.
    probs = torch.sigmoid(logits).cpu().tolist()

    scores = {label: float(prob) for label, prob in zip(labels, probs)}
    predicted = [label for label, prob in scores.items() if prob >= GLOBAL_THRESHOLD]
    return {
        "Predicted labels": predicted or ["No label ≥ threshold"],
        "Scores": scores,
    }
# Build the web UI: subject + body text boxes in, a JSON panel out showing the
# predicted labels and every label's probability.
demo = gr.Interface(
    fn=predict_intents,
    title="Multi-Label Email Intent Classifier",
    description=(
        "DistilBERT + LoRA fine-tuned on synthetic email intents. "
        f"Predictions use a global sigmoid threshold of {GLOBAL_THRESHOLD:.2f}."
    ),
    inputs=[
        gr.Textbox(label="Subject", placeholder="Email subject"),
        gr.Textbox(label="Body", lines=12, placeholder="Email body"),
    ],
    outputs=gr.JSON(label="Result (labels & probabilities)"),
    examples=[
        [
            "Meeting Reminder: Project Sync",
            "Dear team, this is a reminder for tomorrow's sync at 10 AM.",
        ],
        [
            "Travel Booking Confirmation",
            "Your flight to London on 12 June has been confirmed. See attached itinerary.",
        ],
    ],
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()