File size: 865 Bytes
0c8750c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d317fe7
 
7bc4a8b
0c8750c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import gradio as gr
from RoBERTaModule import RoBERTaModule
from transformers import RobertaTokenizerFast
from huggingface_hub import hf_hub_download


# Hugging Face Hub repository and checkpoint filename holding the weights to load.
MODEL_REPO_ID = "DornierDo17/RoBERTa_17.7M"
WEIGHTS_FILE = "finishedBest10.pt"

# Runs at import time: fetch the checkpoint file from the Hub. hf_hub_download
# returns the local filesystem path of the downloaded (cached) file.
weight_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=WEIGHTS_FILE)

# Project-local model wrapper, constructed with its default arguments.
model = RoBERTaModule()

# Presumably restores the downloaded weights into `model`; load_checkpoint is
# defined in RoBERTaModule, outside this file — confirm its contract there.
model.load_checkpoint(path=weight_path)
# NOTE(review): `tokenizer` is not referenced anywhere else in the visible part
# of this file — verify whether it is still needed (RoBERTaModule may tokenize
# internally) before removing it.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

def predict(sentece):
    """Run the module-level ``model`` on one input sentence for the Gradio UI.

    Args:
        sentece: Text entered in the input box. The misspelled name is kept
            deliberately because it is part of the function's signature.

    Returns:
        Whatever ``model.inference`` returns on success. If inference raises,
        the exception's message is returned instead so it is shown in the
        output box; when that message is empty, the exception's class name
        is returned so the user never sees a blank result.
    """
    try:
        return model.inference(sentece)
    except Exception as e:
        # UI boundary: never let the request crash, but do leave a traceback
        # in the server log — the text shown to the user is not enough to
        # debug a failure after the fact.
        import logging

        logging.getLogger(__name__).exception("Inference failed")
        # str(e) is "" for exceptions raised without a message.
        return str(e) or type(e).__name__
        


# Single-textbox UI: the user's sentence is passed to `predict` and the value
# it returns is shown in the output textbox.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        label="Enter sentence with <mask>",
        placeholder="Example: The water boils at <mask> degrees Celsius",
    ),
    outputs=gr.Textbox(label="Predicted token(s)"),
    title="RoBERTa MLM Inference",
)

# Started unconditionally when the module is executed/imported, as before.
demo.launch()