# elliot-evno's picture
# init
# d0c5b08
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import numpy as np
# Identifier of the pretrained checkpoint; the same value is handed to both
# from_pretrained() calls below so the tokenizer and the model stay in sync.
MODEL = "elliot-evno/kb-bert-swedish-dep"
# Both objects are loaded once, at import time, and reused by every call to
# predict_dependencies().
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForTokenClassification.from_pretrained(MODEL)
# Dependency relation names, looked up by predicted class id: the position of
# a label in this list is the id that predict_dependencies() indexes with, so
# the order must not change.
# NOTE(review): presumably mirrors the label set the model was fine-tuned
# with -- confirm against the checkpoint's config.
DEP_LABELS = [
    "_",
    "acl", "acl:cleft", "acl:relcl",
    "advcl", "advmod", "amod", "appos",
    "aux", "aux:pass",
    "case", "cc", "ccomp", "compound:prt", "conj", "cop",
    "csubj", "csubj:pass",
    "det", "discourse", "dislocated",
    "expl", "fixed", "flat:name", "iobj", "mark",
    "nmod", "nmod:poss",
    "nsubj", "nsubj:outer", "nsubj:pass",
    "nummod",
    "obj", "obl", "obl:agent", "orphan",
    "parataxis", "punct", "root", "vocative", "xcomp",
]
def predict_dependencies(text):
    """Predict a dependency relation label for every word of *text*.

    The text is split on whitespace (so punctuation stays attached to the
    neighbouring word), run through the module-level ``tokenizer`` and
    ``model``, and each word receives the label predicted for its first
    sub-token.

    Args:
        text: Input sentence(s) as a single string. ``None``, empty and
            whitespace-only input are all treated as "no input".

    Returns:
        A string with one line per word in the form ``"<word>\\t<label>"``,
        or a short prompt message when no input was given. A word gets the
        label ``"UNK"`` when it has no sub-token in the encoded input
        (presumably because it was cut off by truncation) or when the
        predicted id falls outside ``DEP_LABELS``.
    """
    # Guard against None as well as blank strings so the UI never crashes
    # on an empty submission.
    if not text or not text.strip():
        return "Please enter some Swedish text!"

    tokens = text.split()
    inputs = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax is monotonic, so the argmax of the raw logits is the same class
    # the softmax would pick. Convert the single sequence in the batch to a
    # plain list once instead of calling .item() per position.
    predicted_label_ids = outputs.logits.argmax(-1)[0].tolist()

    # One pass over the sub-tokens: remember the label id at the FIRST
    # sub-token of each word. Special tokens have word_id None and are skipped.
    first_label_id = {}
    for position, word_id in enumerate(inputs.word_ids()):
        if word_id is not None:
            first_label_id.setdefault(word_id, predicted_label_ids[position])

    predicted_labels = []
    for word_index in range(len(tokens)):
        label_id = first_label_id.get(word_index)
        if label_id is not None and label_id < len(DEP_LABELS):
            predicted_labels.append(DEP_LABELS[label_id])
        else:
            predicted_labels.append("UNK")

    # Tab-separate word and label; without a separator the two run together
    # (e.g. "Jagnsubj") and the output is unreadable.
    return "\n".join(
        f"{token}\t{label}" for token, label in zip(tokens, predicted_labels)
    )
# Sample Swedish inputs offered as one-click examples in the UI.
examples = [
    "Jag heter Elliot.",
    "När barnen kom hem från skolan åt de pizza med sina föräldrar.",
    "Den svenska flickan som jag träffade igår läser en bok.",
    "Stockholm är Sveriges huvudstad och en vacker stad.",
]
# Assemble the web UI: a multi-line text input wired through
# predict_dependencies() into a multi-line text output.
demo = gr.Interface(
    fn=predict_dependencies,
    inputs=gr.Textbox(
        lines=3,
        label="Swedish Text",
        placeholder="Enter Swedish text here...",
    ),
    outputs=gr.Textbox(
        lines=10,
        label="Dependency Relations",
    ),
    examples=examples,
    theme=gr.themes.Soft(),
    title="🌲 Swedish Dependency Parser",
    description=(
        "Enter Swedish text to get dependency relations using a fine-tuned BERT model. "
        "Shows grammatical relationships between words using Universal Dependencies format."
    ),
)

# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()