| import gradio as gr |
| import re |
|
|
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForSeq2SeqLM, |
| ) |
|
|
def clean_text(text):
    """Normalise pasted article text before it is handed to the model.

    The steps, in order:

    1. Drop every non-ASCII character.
    2. Remove URLs (any run of non-whitespace starting with ``http``).
    3. Replace the literal marker ``ADVERTISEMENT`` with a space.
    4. Collapse every run of whitespace (spaces, tabs, newlines, carriage
       returns, ...) into a single space and trim both ends.

    Args:
        text: Raw text as a ``str``.

    Returns:
        The cleaned, single-line string (possibly empty).
    """
    text = text.encode("ascii", errors="ignore").decode("ascii")
    # URLs are removed before whitespace is collapsed so that the ``\S+``
    # match still stops at the original line breaks and tabs.
    text = re.sub(r"http\S+", "", text)
    text = text.replace("ADVERTISEMENT", " ")
    # A single ``\s+`` pass covers newlines and tabs and also carriage
    # returns from Windows-style line endings, which would otherwise be
    # left embedded in the text.
    return re.sub(r"\s+", " ", text).strip()
|
|
|
|
# Checkpoint name passed to ``from_pretrained`` for both tokenizer and model.
model_name = "chinhon/pegasus-newsroom-headline_writer_57k"

# (tokenizer, model) pairs already loaded in this process, keyed by
# checkpoint name, so the weights are not reloaded on every request.
_loaded_models = {}


def _get_tokenizer_and_model(name):
    """Return the ``(tokenizer, model)`` pair for *name*, loading it at most once."""
    if name not in _loaded_models:
        _loaded_models[name] = (
            AutoTokenizer.from_pretrained(name),
            AutoModelForSeq2SeqLM.from_pretrained(name),
        )
    return _loaded_models[name]


def headline_writer(text):
    """Generate a headline for the given news text.

    The text is cleaned with ``clean_text``, tokenized (truncated to the
    tokenizer's limit), passed through ``model.generate`` with the model's
    default generation settings, and the first decoded sequence is returned.

    Args:
        text: Raw article text as a ``str``.

    Returns:
        The decoded headline string with special tokens removed.
    """
    input_text = clean_text(text)
    tokenizer, model = _get_tokenizer_and_model(model_name)

    # Encode with the plain tokenizer call: ``as_target_tokenizer()`` is the
    # (deprecated) context for encoding labels, not model inputs.
    batch = tokenizer(
        input_text,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )

    raw_write = model.generate(**batch)

    # NOTE(review): ``min_length=200`` and ``length_penalty=50.5`` used to be
    # passed to ``batch_decode``, where they have no effect on generation.
    # They are dropped here rather than moved to ``generate`` so the produced
    # headlines stay the same; pass them to ``generate`` if they are wanted.
    headline = tokenizer.batch_decode(raw_write, skip_special_tokens=True)

    return headline[0]
|
|
|
|
# Input widget: a 20-line text area for the article text.
_story_box = gr.inputs.Textbox(
    lines=20,
    label="Paste the first few paras of your news story here",
)

# Output widget: shows the string returned by ``headline_writer``.
_headline_box = gr.outputs.Textbox(label="Suggested Headline")

# Web interface wiring the input box to ``headline_writer`` and its result
# to the output box.
gradio_ui = gr.Interface(
    fn=headline_writer,
    inputs=_story_box,
    outputs=_headline_box,
    title="Generate News Headlines with AI",
    description="Too busy or tired to write a headline? Try this instead.",
    theme="darkdefault",
)

# Start the app with request queueing switched on.
gradio_ui.launch(enable_queue=True)
|
|