Spaces:
Runtime error
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import Trainer
# Hugging Face Hub repositories for the tokenizer and the fine-tuned classifier.
# NOTE(review): the tokenizer is loaded from the base "roberta-small" repo while the
# weights come from the pun-detector repo; presumably the classifier was fine-tuned
# from that base so the vocabularies match -- confirm.
TOKENIZER_REPO = "smallbenchnlp/roberta-small"
MODEL_REPO = "frostymelonade/roberta-small-pun-detector-v2"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)
# Two output logits; classify_pun compares index 0 against index 1.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO, num_labels=2)

# Pads the examples of a batch to a common length so they can be stacked.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# No TrainingArguments or datasets are supplied: in the visible code this Trainer
# is only used as a prediction wrapper around `model`.
trainer = Trainer(model=model, data_collator=data_collator, tokenizer=tokenizer)
def classify_pun(text):
    """Classify *text* as a pun or not with the fine-tuned RoBERTa model.

    Runs one forward pass of the module-level ``model`` on the ``tokenizer``
    encoding of *text* (truncated to the tokenizer's maximum length).

    Args:
        text: The sentence to classify, as delivered by the Gradio textbox.

    Returns:
        A ``(label, raw)`` tuple of strings. ``label`` is ``"Pun"`` when the
        logit at index 1 is greater than the logit at index 0, otherwise
        ``"Not a pun"`` (a tie counts as "Not a pun"). ``raw`` is the string
        form of the two logits. Empty or whitespace-only input returns a
        prompt message and an empty ``raw`` without running the model.
    """
    if not (text and text.strip()):
        # Nothing to classify; don't run the model on special tokens alone.
        return "Please enter some text.", ""

    encoding = tokenizer(text, truncation=True, return_tensors="pt")
    # The model may have been moved to a GPU (e.g. by Trainer construction);
    # keep the inputs on the same device as the weights.
    encoding = encoding.to(model.device)
    # Single-example inference: no gradients needed, and this avoids spinning
    # up Trainer.predict's evaluation loop for every request.
    with torch.no_grad():
        logits = model(**encoding).logits[0].cpu().numpy()

    # NOTE(review): index 1 is assumed to be the "pun" class, as in the original
    # code; confirm against model.config.id2label.
    label = "Pun" if logits[0] < logits[1] else "Not a pun"
    return label, str(logits)
# Gradio UI: one input textbox and a Submit button. The (label, raw logits)
# tuple returned by classify_pun is written into the two output textboxes.
with gr.Blocks() as demo:
    input_box = gr.Textbox(label="Text")
    label_box = gr.Textbox(label="Classification")
    raw_box = gr.Textbox(label="Raw Results")
    submit_btn = gr.Button("Submit")
    # api_name gives this event a named API endpoint ("classify_pun").
    submit_btn.click(
        fn=classify_pun,
        inputs=input_box,
        outputs=[label_box, raw_box],
        api_name="classify_pun",
    )

demo.launch()