Spaces:
Runtime error
Runtime error
| # Import libraries | |
| import json | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from peft import PeftModel | |
# Paths to the exported artifacts (resolved relative to the working directory).
MODEL_DIR = Path("model")
META_PATH = Path("preprocess_meta.json")
THRESH_PATH = MODEL_DIR / "threshold_global.json"

# Load preprocessing metadata: the label -> class-index mapping and the
# truncation length passed to the tokenizer at inference time.
# Read as UTF-8 explicitly so label names do not depend on the host locale.
with META_PATH.open(encoding="utf-8") as f:
    meta = json.load(f)
label2id = meta["label2id"]
id2label = {int(i): label for label, i in label2id.items()}
# labels[k] must name output column k of the classifier, so the ids have to be
# exactly 0..n-1 with no duplicates (a duplicate id would silently drop a label
# from id2label). Fail loudly here rather than mislabel predictions later.
if sorted(id2label) != list(range(len(label2id))):
    raise ValueError(
        f"label2id in {META_PATH} must map {len(label2id)} labels to the distinct "
        f"ids 0..{len(label2id) - 1}; got ids {sorted(label2id.values())}"
    )
labels = [id2label[i] for i in range(len(id2label))]
max_length = meta["max_length"]

# Single decision threshold applied to every label's sigmoid probability.
with THRESH_PATH.open(encoding="utf-8") as f:
    GLOBAL_THRESHOLD = json.load(f)["global_threshold"]
def _load_classifier(model_dir):
    """Load the tokenizer and sequence classifier stored in *model_dir*.

    A LoRA adapter is attached via ``PeftModel.from_pretrained`` when that call
    succeeds; if it raises ``ValueError`` the plain base model is used instead.

    Returns:
        ``(tokenizer, base_model, model)``. ``model`` is either the PEFT wrapper
        or ``base_model`` itself, and has already been switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(model_dir)
    # NOTE(review): trust_remote_code=True allows the checkpoint to execute its
    # own Python code; make sure model_dir only ever holds trusted artifacts.
    backbone = AutoModelForSequenceClassification.from_pretrained(
        model_dir, trust_remote_code=True
    )
    try:
        classifier = PeftModel.from_pretrained(backbone, model_dir, is_trainable=False)
    except ValueError:
        # Presumably no adapter config in model_dir (e.g. LoRA weights already
        # merged into the base checkpoint) — TODO confirm.
        classifier = backbone
    classifier.eval()
    return tok, backbone, classifier


tokenizer, base_model, model = _load_classifier(MODEL_DIR)
# Prefer a visible CUDA device, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def preprocess(subject: str, body: str, sep: "str | None" = None) -> str:
    """Join an email's subject and body into the single string fed to the tokenizer.

    The result is ``"<subject> <sep> <body>"`` with each part stripped and the
    whole string stripped again, so an empty part leaves no leading/trailing gap.

    Args:
        subject: Email subject; ``None`` is treated as an empty string.
        body: Email body; ``None`` is treated as an empty string.
        sep: Separator placed between the two parts. When ``None`` (default),
            the module-level tokenizer's ``sep_token`` is used, or a single
            space if the tokenizer defines none.

    Returns:
        The combined text.
    """
    if sep is None:
        sep = tokenizer.sep_token if tokenizer.sep_token else " "
    # Gradio/API callers may send null for an untouched field; treat it as "".
    subject = (subject or "").strip()
    body = (body or "").strip()
    return f"{subject} {sep} {body}".strip()
def predict_intents(subject: str, body: str) -> dict:
    """Classify one email and report every label's sigmoid probability.

    Args:
        subject: Email subject; ``None`` is treated as an empty string.
        body: Email body; ``None`` is treated as an empty string.

    Returns:
        ``{"Predicted labels": [...], "Scores": {label: prob}}``. Predicted
        labels are those with probability ``>= GLOBAL_THRESHOLD``; if none
        qualify, the list holds a single placeholder message. When both inputs
        are empty or whitespace-only, the model is not run and both entries
        are empty.

    Raises:
        ValueError: if the model emits a different number of scores than there
            are labels in the metadata (labels would otherwise be misaligned).
    """
    # Normalise here so this function does not depend on preprocess() being None-safe.
    subject = subject or ""
    body = body or ""
    if not subject.strip() and not body.strip():
        return {"Predicted labels": [], "Scores": {}}

    text = preprocess(subject, body)
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        # Take batch row 0 instead of squeeze(): squeeze() would also drop the
        # label axis for a single-output model, leaving a 0-d tensor that
        # cannot be zipped with `labels`.
        logits = model(**encoded).logits[0].float()
    probs = torch.sigmoid(logits).cpu().tolist()

    if len(probs) != len(labels):
        raise ValueError(
            f"Model produced {len(probs)} scores but {len(labels)} labels are defined"
        )

    scores = {label: float(prob) for label, prob in zip(labels, probs)}
    predicted = [label for label, prob in scores.items() if prob >= GLOBAL_THRESHOLD]
    return {
        "Predicted labels": predicted or ["No label ≥ threshold"],
        "Scores": scores,
    }
# Sample (subject, body) pairs offered as clickable examples in the UI.
_EXAMPLES = [
    ["Meeting Reminder: Project Sync", "Dear team, this is a reminder for tomorrow's sync at 10 AM."],
    ["Travel Booking Confirmation", "Your flight to London on 12 June has been confirmed. See attached itinerary."],
]

# Header text shown above the form; embeds the loaded threshold to 2 decimals.
_DESCRIPTION = (
    "DistilBERT + LoRA fine-tuned on synthetic email intents. "
    f"Predictions use a global sigmoid threshold of {GLOBAL_THRESHOLD:.2f}."
)

# The two text boxes are passed positionally to predict_intents(subject, body);
# the dict it returns is rendered by the JSON output component.
demo = gr.Interface(
    fn=predict_intents,
    inputs=[
        gr.Textbox(label="Subject", placeholder="Email subject"),
        gr.Textbox(label="Body", lines=12, placeholder="Email body"),
    ],
    outputs=gr.JSON(label="Result (labels & probabilities)"),
    title="Multi-Label Email Intent Classifier",
    description=_DESCRIPTION,
    examples=_EXAMPLES,
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()