MSLars's picture
Add application file
348749a
# app.py
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import pandas as pd
# --- 1. Modelle laden ---
# Wir laden die Modelle einmal beim Start der Anwendung, um die Ladezeit bei jeder Anfrage zu vermeiden.
# Modell 1: German GPT-2 für Text-Vervollständigung
# Wir laden hier Tokenizer und Modell manuell, da wir die Logits für das nächste Wort brauchen,
# was die Standard-Pipeline nicht direkt so ausgibt.
gpt2_model_name = "dbmdz/german-gpt2"
gpt2_tokenizer = AutoTokenizer.from_pretrained(gpt2_model_name)
# AutoModelWithLMHead ist veraltet, wir benutzen stattdessen AutoModelForCausalLM
gpt2_model = AutoModelForCausalLM.from_pretrained(gpt2_model_name)
# Modell 2: BERT für Masken-Füllung
# Hier können wir die Pipeline direkt nutzen, da sie genau das tut, was wir wollen.
unmasker = pipeline('fill-mask', model='bert-base-multilingual-uncased', top_k=5)
# --- 2. Logik-Funktionen für die Modelle ---
def predict_next_word(text):
"""
Diese Funktion nimmt einen Text entgegen, tokenisiert ihn, gibt ihn an das GPT-2 Modell
und berechnet die Wahrscheinlichkeitsverteilung für das *unmittelbar* folgende Wort.
"""
if not text or not text.strip():
return {}
# Text in Token-IDs umwandeln
inputs = gpt2_tokenizer(text, return_tensors='pt')
# Modell-Vorhersage ohne Gradientenberechnung (schneller)
with torch.no_grad():
logits = gpt2_model(**inputs).logits
# Wir interessieren uns nur für die Logits des letzten Tokens im Input
# Shape: [batch_size, sequence_length, vocab_size] -> wir nehmen das letzte Token
last_token_logits = logits[0, -1, :]
# Wende Softmax an, um Wahrscheinlichkeiten zu erhalten
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
# Finde die Top 5 wahrscheinlichsten Token-IDs und ihre Wahrscheinlichkeiten
top_k_probs, top_k_indices = torch.topk(probabilities, 5)
# Konvertiere die Token-IDs zurück in lesbare Tokens (Wörter)
top_k_tokens = gpt2_tokenizer.convert_ids_to_tokens(top_k_indices)
# Bereinige die Tokens (z.B. 'Ġ' entfernen, das für ein Leerzeichen steht)
cleaned_tokens = [token.replace('Ġ', '') for token in top_k_tokens]
# Erstelle ein Dictionary für die Ausgabe in Gradio
result_dict = {token: prob.item() for token, prob in zip(cleaned_tokens, top_k_probs)}
return result_dict
def fill_the_mask(text):
"""
Diese Funktion nutzt die fill-mask Pipeline, um die wahrscheinlichsten Wörter für
das [MASK]-Token im Text zu finden.
"""
# Überprüfen, ob das Masken-Token vorhanden ist
mask_token = unmasker.tokenizer.mask_token
if mask_token not in text:
# Einen Fehler in der Gradio-UI anzeigen
raise gr.Error(f"Die Eingabe muss das Masken-Token '{mask_token}' enthalten.")
# Pipeline aufrufen
predictions = unmasker(text)
# Die Ausgabe der Pipeline in ein Dictionary für Gradio umwandeln
# Die Pipeline liefert eine Liste von Dictionaries, wir formatieren sie um.
result_dict = {pred['token_str']: pred['score'] for pred in predictions}
return result_dict
# --- 3. Gradio Interface erstellen ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# Funktionalität von Language Models demonstrieren
Diese Anwendung zeigt die Kernfähigkeiten von zwei verschiedenen Arten von Transformer-Modellen:
1. **Causal Language Models (z.B. GPT-2):** Diese Modelle sind darauf trainiert, das nächste Wort in einer Sequenz vorherzusagen. Sie sind ideal für die Textgenerierung.
2. **Masked Language Models (z.B. BERT):** Diese Modelle lernen, fehlende Wörter in einem Satz zu ergänzen. Sie verstehen den Kontext in beide Richtungen (bidirektional) und sind stark in Klassifikations- oder Extraktionsaufgaben.
"""
)
with gr.Tabs():
# --- Tab 1: Causal LM (GPT-2) ---
with gr.TabItem("1. Causal LM (Text-Vervollständigung)"):
gr.Markdown(
"""
## German GPT-2 (`dbmdz/german-gpt2`)
Dieses Modell sagt das wahrscheinlichste *nächste* Wort voraus.
### Anleitung
Gib einen deutschen Satzanfang in das Textfeld ein und klicke auf "Nächstes Wort vorhersagen".
"""
)
with gr.Row():
text_input_gpt = gr.Textbox(
label="Dein Satzanfang",
placeholder="Tippe hier...",
value="Der Sinn des Lebens ist",
lines=3
)
predict_button_gpt = gr.Button("Nächstes Wort vorhersagen")
gr.Markdown("### Wahrscheinlichste Folgewörter")
output_label_gpt = gr.Label(num_top_classes=5)
# --- Tab 2: Masked LM (BERT) ---
with gr.TabItem("2. Masked LM (Masken-Füllung)"):
gr.Markdown(
f"""
## Multilingual BERT (`bert-base-multilingual-uncased`)
Dieses Modell füllt die Lücke, die durch ein sogenanntes "Masken-Token" (`{unmasker.tokenizer.mask_token}`) markiert ist.
### Anleitung
Gib einen Satz mit dem Masken-Token `{unmasker.tokenizer.mask_token}` ein und klicke auf "Maske füllen".
"""
)
with gr.Row():
text_input_bert = gr.Textbox(
label="Satz mit Maske",
placeholder="Tippe hier...",
value=f"Berlin ist die Hauptstadt von {unmasker.tokenizer.mask_token}.",
lines=3
)
fill_mask_button_bert = gr.Button("Maske füllen")
gr.Markdown("### Wahrscheinlichste Füllwörter")
output_label_bert = gr.Label(num_top_classes=5)
# --- 4. Events verknüpfen ---
predict_button_gpt.click(
fn=predict_next_word,
inputs=text_input_gpt,
outputs=output_label_gpt,
api_name="predict_next_word"
)
fill_mask_button_bert.click(
fn=fill_the_mask,
inputs=text_input_bert,
outputs=output_label_bert,
api_name="fill_mask"
)
# Anwendung starten
demo.launch()