# Hugging Face Space app: Argument Scheme Classifier (Gradio demo).
# (Original lines were Spaces page-scrape residue, not valid Python.)
import json

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# -------------------------
# Load model + tokenizer
# -------------------------
# Local directory containing a fine-tuned sequence-classification checkpoint.
model_dir = "./argument-scheme-classifier-json"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)

# Prefer GPU when available; fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference-only app: switch off dropout so predictions are deterministic.
model.eval()

# -------------------------
# Load Label Mapping
# -------------------------
# Maps stringified class indices ("0", "1", ...) to scheme names.
with open("./label2scheme.json", "r", encoding="utf-8") as f:
    label2scheme = json.load(f)
| # ------------------------- | |
| # Prediction Function | |
| # ------------------------- | |
def classify_argument(argument_text):
    """Classify an argument and return a Markdown-formatted report.

    Parameters
    ----------
    argument_text : str | None
        Raw user input from the Gradio textbox. Gradio may pass ``None``
        when the textbox is cleared.

    Returns
    -------
    str
        Markdown with the top predicted scheme and the full probability
        distribution, or a prompt message when the input is empty.
    """
    # Guard against None as well as blank input: the original
    # `argument_text.strip()` would raise AttributeError on None.
    if not argument_text or not argument_text.strip():
        return "Please enter an argument."

    inputs = tokenizer(
        argument_text, return_tensors="pt", truncation=True, max_length=128
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Single example in the batch -> take row 0 of the softmax output.
    probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]

    # Pair each scheme name with its probability, highest confidence first.
    results = sorted(
        ((label2scheme[str(i)], float(p)) for i, p in enumerate(probs)),
        key=lambda pair: pair[1],
        reverse=True,
    )

    # Format nicely for the Markdown output component.
    top_scheme, top_conf = results[0]
    output_text = f"🎯 Predicted Scheme: **{top_scheme}** (confidence {top_conf:.4f})\n\n"
    output_text += "---\n📊 Full Probability Distribution:\n"
    for scheme, p in results:
        output_text += f"- {scheme}: **{p:.4f}**\n"
    return output_text
| # ------------------------- | |
| # Gradio interface | |
| # ------------------------- | |
# -------------------------
# Gradio interface
# -------------------------
title = "🧩 Argument Scheme Classifier"
description = "Enter an argument and see confidence scores for all possible argument schemes."

# Single-input, single-output demo: free-text argument in, Markdown report out.
argument_box = gr.Textbox(lines=4, placeholder="Type your argument here...")
prediction_view = gr.Markdown(label="Model Prediction")

iface = gr.Interface(
    fn=classify_argument,
    inputs=argument_box,
    outputs=prediction_view,
    title=title,
    description=description,
)
iface.launch()