| import gradio as gr |
| from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration |
|
|
model_name = "ainize/kobart-news"

# Load the tokenizer and the seq2seq model from the same Hub checkpoint so the
# token ids produced by one are the ids understood by the other.
tokenizer, model = (
    PreTrainedTokenizerFast.from_pretrained(model_name),
    BartForConditionalGeneration.from_pretrained(model_name),
)
|
|
def summ(
    txt: str,
    *,
    num_beams: int = 4,
    max_length: int = 142,
    min_length: int = 56,
    length_penalty: float = 2.0,
) -> str:
    """Summarize ``txt`` with the module-level ``model``/``tokenizer`` via beam search.

    Args:
        txt: Source text to summarize. Empty or whitespace-only input (or
            ``None``) returns ``""`` without running the model.
        num_beams: Beam width passed to ``model.generate``.
        max_length: Maximum length, in tokens, of the generated summary.
        min_length: Minimum length, in tokens, of the generated summary.
        length_penalty: Beam-search length penalty passed to ``model.generate``
            (values > 0.0 promote longer sequences).

    Returns:
        The decoded summary with special tokens stripped.
    """
    # A blank textbox submits "": skip tokenization/generation entirely rather
    # than feeding the model an empty or meaningless sequence.
    if not txt or not txt.strip():
        return ""

    # Truncate to the model's positional-embedding capacity; without this,
    # long articles index past the embedding table and raise inside the encoder.
    input_ids = tokenizer.encode(
        txt,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    summary_text_ids = model.generate(
        input_ids=input_ids,
        bos_token_id=model.config.bos_token_id,
        eos_token_id=model.config.eos_token_id,
        length_penalty=length_penalty,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
    )
    # A single string was encoded, so the batch holds exactly one output sequence.
    return tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)
|
|
# Web UI: one text box in (the article), one text box out (its summary).
interface = gr.Interface(
    fn=summ,
    inputs=[gr.Textbox(label="original text")],
    outputs=[gr.Textbox(label="summary")],
)

# NOTE(review): share=True asks Gradio for a temporary public link in addition
# to the local server, and this runs at import time — confirm that is intended.
interface.launch(share=True)