DornierDo17's picture
Update app.py
7bc4a8b verified
raw
history blame contribute delete
865 Bytes
import logging

import gradio as gr
from huggingface_hub import hf_hub_download
from transformers import RobertaTokenizerFast

from RoBERTaModule import RoBERTaModule
# Hugging Face Hub location of the trained checkpoint served by this demo.
MODEL_REPO_ID = "DornierDo17/RoBERTa_17.7M"
WEIGHTS_FILE = "finishedBest10.pt"


def _fetch_and_load(repo_id, filename):
    """Download ``filename`` from ``repo_id`` and load it into a new model.

    Returns a ``(checkpoint_path, model)`` pair, where ``checkpoint_path`` is
    the value returned by ``hf_hub_download`` and ``model`` is a
    ``RoBERTaModule`` with that checkpoint loaded.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    net = RoBERTaModule()
    net.load_checkpoint(path=checkpoint_path)
    return checkpoint_path, net


weight_path, model = _fetch_and_load(MODEL_REPO_ID, WEIGHTS_FILE)

# NOTE(review): `tokenizer` is not referenced anywhere else in this file
# (prediction goes through `model.inference`); confirm it is still needed
# before keeping the extra "roberta-base" download at startup.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
_logger = logging.getLogger(__name__)

# Placeholder token the UI label asks users to include in their sentence.
MASK_TOKEN = "<mask>"


def predict(sentece):
    """Fill the masked token(s) in one sentence; used as the Gradio callback.

    Args:
        sentece: Text entered by the user. It must contain at least one
            ``<mask>`` placeholder. (The misspelled parameter name is kept so
            existing keyword callers keep working.)

    Returns:
        The result of ``model.inference`` on success. Otherwise, a
        human-readable message: empty input and input without ``<mask>``
        are rejected before the model is called. Exceptions raised during
        inference are logged with a traceback and reported as
        ``"Error: <message>"``, so they cannot be mistaken for a prediction.
    """
    if not sentece or not sentece.strip():
        return "Please enter a sentence containing <mask>."
    if MASK_TOKEN not in sentece:
        return f"The sentence must contain a {MASK_TOKEN} token to predict."
    try:
        return model.inference(sentece)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        _logger.exception("Inference failed for input %r", sentece)
        return f"Error: {e}"
# Text-in / text-out Gradio UI around `predict`. Keeping it in a module-level
# `demo` lets the interface be reused without launching a server on import.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        label="Enter sentence with <mask>",
        placeholder="Example: The water boils at <mask> degrees Celsius",
    ),
    outputs=gr.Textbox(label="Predicted token(s)"),
    title="RoBERTa MLM Inference",
)

if __name__ == "__main__":
    # Start the web server only when this file is run as a script.
    demo.launch()