|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoint that provides both the tokenizer and the
# sequence-classification model used by the app.
model_name = "iro-malta07/distilbert-base-german-lang-level-class"

# Both objects are loaded once, at import time, from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Mapping from output-logit index to human-readable label name, taken from
# the checkpoint's own config so it always matches the model head.
id2label = model.config.id2label
|
|
|
|
|
|
|
|
def classify_text(text):
    """Score a piece of German text against every label of the classifier.

    The input is stripped of surrounding whitespace. If it does not already
    end in sentence-final punctuation, a period is appended. The result is
    tokenized (truncated to the model's maximum length) and run through the
    model without gradient tracking.

    Args:
        text: Raw text from the Gradio textbox. ``None`` and blank strings
            are accepted.

    Returns:
        A dict mapping each label name from ``model.config.id2label`` to its
        softmax probability (plain Python floats), which is the format
        ``gr.Label`` consumes. For blank input an empty dict is returned and
        the model is not called.
    """
    text = (text or "").strip()
    if not text:
        # Nothing to classify: don't feed the model an artificial "." input.
        return {}

    # Ensure the text is terminated, but don't turn "Wie geht's?" into
    # "Wie geht's?." -- only add a period when no terminal mark is present.
    if not text.endswith((".", "!", "?", "…")):
        text += "."

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # A single input string yields a batch of one; take its probability row.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    return {id2label[i]: p for i, p in enumerate(probs.tolist())}
|
|
|
|
|
|
|
|
# Gradio UI: a four-line textbox feeds classify_text, and a Label component
# displays the four highest-scoring levels with their probabilities.
demo = gr.Interface(
    title="German Language Level Classifier",
    description=(
        "Enter German text and get the predicted CEFR level (A1 to C2). "
        "🚧 Work in progress. 🚧"
    ),
    fn=classify_text,
    inputs=gr.Textbox(lines=4, placeholder="Schreibe etwas auf Deutsch..."),
    outputs=gr.Label(num_top_classes=4),
)

# Start the Gradio app.
demo.launch()
|
|
|