"""Gradio demo for a multi-label email intent classifier.

At import time this script loads label metadata, a global decision
threshold, a tokenizer and a sequence-classification model (optionally
wrapped with LoRA adapters), then builds a Gradio interface that scores an
email's subject and body against every known label.
"""

# Import libraries
import json
from pathlib import Path

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel

# Define paths
MODEL_DIR = Path("model")
META_PATH = Path("preprocess_meta.json")
THRESH_PATH = MODEL_DIR / "threshold_global.json"

# Load metadata. The encoding is explicit so decoding does not depend on the
# platform's default locale.
with META_PATH.open(encoding="utf-8") as f:
    meta = json.load(f)

label2id = meta["label2id"]
id2label = {int(i): label for label, i in label2id.items()}
# Indexing by 0..n-1 (rather than sorting the keys) deliberately raises
# KeyError if the ids are not contiguous, because a gap would silently
# misalign labels with the model's output positions.
labels = [id2label[i] for i in range(len(id2label))]
max_length = meta["max_length"]

with THRESH_PATH.open(encoding="utf-8") as f:
    GLOBAL_THRESHOLD = json.load(f)["global_threshold"]

# Load model + LoRA adapters (handles both plain + adapter cases)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# NOTE(review): trust_remote_code=True lets the checkpoint execute its own
# Python code at load time. Confirm it is actually required for this model
# and that MODEL_DIR only ever holds trusted files.
base_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_DIR, trust_remote_code=True
)
try:
    model = PeftModel.from_pretrained(base_model, MODEL_DIR, is_trainable=False)
except ValueError:
    # No adapter could be loaded; per the original note the weights are
    # presumably already merged, so use the base model directly.
    model = base_model

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def preprocess(subject: str, body: str) -> str:
    """Join an email's subject and body into a single model input string.

    Both parts are stripped and joined with the tokenizer's separator token
    (or a plain space when the tokenizer defines none), then the combined
    string is stripped again.

    Args:
        subject: Email subject line; ``None`` is treated as empty.
        body: Email body text; ``None`` is treated as empty.

    Returns:
        The combined text to feed to the tokenizer.
    """
    sep = tokenizer.sep_token if tokenizer.sep_token else " "
    subject = (subject or "").strip()
    body = (body or "").strip()
    return f"{subject} {sep} {body}".strip()


def predict_intents(subject: str, body: str) -> dict:
    """Score an email against every label and apply the global threshold.

    Args:
        subject: Email subject line; ``None`` is treated as empty.
        body: Email body text; ``None`` is treated as empty.

    Returns:
        A dict with two keys: ``"Predicted labels"``, the labels whose
        sigmoid probability is at least ``GLOBAL_THRESHOLD`` (or a single
        placeholder entry when none qualify), and ``"Scores"``, a mapping of
        every label to its probability. When both inputs are empty or
        whitespace-only, the model is not run and both values are empty.
    """
    # Normalise first so None and whitespace-only input count as "no input"
    # instead of crashing or scoring a bare separator token.
    subject = (subject or "").strip()
    body = (body or "").strip()
    if not subject and not body:
        return {"Predicted labels": [], "Scores": {}}

    text = preprocess(subject, body)
    inputs = tokenizer(
        text, truncation=True, max_length=max_length, return_tensors="pt"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        # squeeze(0) removes only the batch dimension. A bare squeeze() would
        # also collapse a single-label output to a 0-d tensor, which cannot
        # be zipped with `labels`.
        logits = model(**inputs).logits.squeeze(0).float()

    probs = torch.sigmoid(logits).cpu().numpy()
    scores = {label: float(prob) for label, prob in zip(labels, probs)}
    predicted = [
        label for label, prob in scores.items() if prob >= GLOBAL_THRESHOLD
    ]
    return {
        "Predicted labels": predicted or ["No label ≥ threshold"],
        "Scores": scores,
    }


demo = gr.Interface(
    fn=predict_intents,
    inputs=[
        gr.Textbox(label="Subject", placeholder="Email subject"),
        gr.Textbox(label="Body", lines=12, placeholder="Email body"),
    ],
    outputs=gr.JSON(label="Result (labels & probabilities)"),
    title="Multi-Label Email Intent Classifier",
    description=(
        "DistilBERT + LoRA fine-tuned on synthetic email intents. "
        f"Predictions use a global sigmoid threshold of {GLOBAL_THRESHOLD:.2f}."
    ),
    examples=[
        [
            "Meeting Reminder: Project Sync",
            "Dear team, this is a reminder for tomorrow's sync at 10 AM.",
        ],
        [
            "Travel Booking Confirmation",
            "Your flight to London on 12 June has been confirmed. See attached itinerary.",
        ],
    ],
)

if __name__ == "__main__":
    demo.launch()