import os

import gradio as gr
import torch
from transformers import BertForMaskedLM, BertTokenizer
|
|
| |
# Pretrained checkpoint used for masked-token prediction.
model_name = "bert-base-uncased"
# Load the masked-language-model head and its matching tokenizer.
# No force_download: the weights are fetched on first run and then served
# from the local Hugging Face cache, so restarts are fast and work offline.
# from_pretrained already returns the model in eval mode (dropout disabled).
model = BertForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
|
|
| |
def inference(input_text):
    """Fill every ``[MASK]`` placeholder in *input_text* with BERT's top-1 guess.

    Args:
        input_text: Text containing one or more literal ``[MASK]`` tokens.

    Returns:
        The input text with each ``[MASK]`` replaced, left to right, by the
        highest-scoring token the model predicts at that position. If the text
        is empty/None or contains no ``[MASK]``, an error message string is
        returned instead (it is displayed to the user by the UI).
    """
    # Guard clause: covers empty/None input as well as text without a mask.
    if not input_text or "[MASK]" not in input_text:
        return "Error: The input text must contain the [MASK] token."

    # Tokenize; "[MASK]" is a special token and maps to tokenizer.mask_token_id.
    # NOTE(review): no truncation is requested, so inputs longer than the
    # model's position limit are expected to fail inside the model — confirm.
    inputs = tokenizer(input_text, return_tensors="pt")
    # Positions of ALL mask tokens in the single (batch index 0) sequence.
    mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]

    # Forward pass without gradient tracking (inference only).
    with torch.no_grad():
        logits = model(**inputs).logits

    # Top-1 vocabulary id for each mask position, in left-to-right order.
    top_ids = logits[0, mask_positions, :].argmax(dim=-1).tolist()
    predicted_tokens = [tokenizer.decode([token_id]) for token_id in top_ids]

    # Splice predictions back into the original text. Splitting on the literal
    # placeholder (rather than repeated str.replace) preserves the user's
    # casing and can never re-match text that a prediction itself introduced.
    pieces = input_text.split("[MASK]")
    result_parts = [pieces[0]]
    for index, piece in enumerate(pieces[1:]):
        # Defensive: if fewer mask ids than placeholders were produced, keep
        # the remaining placeholders unchanged.
        if index < len(predicted_tokens):
            result_parts.append(predicted_tokens[index])
        else:
            result_parts.append("[MASK]")
        result_parts.append(piece)

    return "".join(result_parts)
|
|
| |
# Sample prompts offered under the input box; each inner list is one row of
# inputs for the single text input.
_EXAMPLE_PROMPTS = [
    ["The capital of France is [MASK]."],
    ["The quick brown fox jumps over the [MASK] dog."],
]

# Minimal text-in / text-out web UI wrapping `inference`.
iface = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",
    examples=_EXAMPLE_PROMPTS,
)
|
|
| |
if __name__ == "__main__":
    # Port defaults to 7862 but can be overridden through Gradio's own
    # GRADIO_SERVER_PORT environment variable (an explicit server_port
    # argument would otherwise silently take precedence over it).
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7862"))
    iface.launch(server_port=port)