| import os |
| import re |
| import torch |
| import gradio as gr |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
| |
| |
# Hugging Face Hub model id; override with the MODEL_ID env var.
MODEL_ID = os.environ.get("MODEL_ID", "Spyspook/my-t5-medium-summarizer")


# Token budget for the encoder input (articles are truncated to this).
MAX_INPUT_LENGTH = 512
# Maximum length (in tokens) of each generated title.
MAX_TARGET_LENGTH = 64
# Defaults for the Gradio controls defined below.
DEFAULT_NUM_TITLES = 3
DEFAULT_TEMPERATURE = 0.7
DEFAULT_BEAMS = 4


# Prefer GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Load tokenizer and model once at import time; eval() disables dropout
# so generation is deterministic for beam search.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device)
model.eval()
|
|
| |
# A sentence terminator (., !, ?) captured, followed by whitespace.
_SENT_END_RE = re.compile(r"([.!?])\s+")


def first_sentence(text: str) -> str:
    """Return the first sentence of *text*, keeping its terminator.

    Falls back to the stripped input when no sentence boundary
    (terminator + whitespace) is found.
    """
    stripped = text.strip()
    if not stripped:
        return stripped

    pieces = _SENT_END_RE.split(stripped, maxsplit=1)
    if len(pieces) < 2:
        # No boundary inside the text — treat the whole thing as one sentence.
        return stripped

    # pieces[0] is the sentence body, pieces[1] the captured punctuation.
    return f"{pieces[0]}{pieces[1]}".strip()
|
|
def generate_titles(article_text, num_titles, temperature, beams, do_sample):
    """Generate candidate titles for an article with the T5 model.

    Args:
        article_text: Raw article text. Empty/whitespace input yields [].
        num_titles: How many candidate titles to return.
        temperature: Sampling temperature (only used when do_sample is True).
        beams: Beam count for beam search (only used when do_sample is False).
        do_sample: True -> nucleus sampling; False -> beam search.

    Returns:
        A list of single-element lists (rows for the gr.Dataframe output),
        one deduplicated first-sentence title per row.
    """
    if not article_text or not article_text.strip():
        return []

    # T5 summarization checkpoints expect the task prefix.
    prefixed = "summarize: " + article_text.strip()

    inputs = tokenizer(
        prefixed,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    num_titles = int(num_titles)
    gen_kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        num_return_sequences=num_titles,
    )

    if do_sample:
        gen_kwargs.update(
            dict(
                do_sample=True,
                temperature=float(temperature),
                top_p=0.95,
                num_beams=1,
            )
        )
    else:
        gen_kwargs.update(
            dict(
                do_sample=False,
                # generate() raises ValueError when num_return_sequences
                # exceeds num_beams, so widen the beam if the user asked
                # for more titles than beams.
                num_beams=max(int(beams), num_titles),
                length_penalty=1.0,
                # early_stopping only applies to beam search; keeping it
                # out of the sampling branch avoids a transformers warning.
                early_stopping=True,
            )
        )

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Trim each candidate to its first sentence and drop duplicates,
    # preserving generation order.
    seen = set()
    titles = []
    for t in decoded:
        t1 = first_sentence(t)
        if t1 and t1 not in seen:
            seen.add(t1)
            titles.append(t1)

    # gr.Dataframe expects a list of rows (each row a list of cells).
    return [[t] for t in titles]
|
|
| |
# Build the Gradio UI: one input textbox, generation controls, and a
# one-column table of generated titles.
with gr.Blocks() as demo:
    gr.Markdown("## T5 Article Title Generator")


    with gr.Row():
        # Free-form article input; truncated to MAX_INPUT_LENGTH tokens
        # inside generate_titles.
        text_in = gr.Textbox(
            label="Article text",
            placeholder="Paste article text here…",
            lines=14,
        )


    with gr.Row():
        num_titles = gr.Slider(1, 10, value=DEFAULT_NUM_TITLES, step=1, label="Number of titles")
        # Temperature only takes effect when the sampling checkbox is on.
        temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature (sampling)")
    with gr.Row():
        # Beam count only takes effect when the sampling checkbox is off.
        beams = gr.Slider(1, 8, value=DEFAULT_BEAMS, step=1, label="Beams (if sampling is OFF)")
        do_sample = gr.Checkbox(value=True, label="Use sampling (ON) / Beam search (OFF)")


    generate_btn = gr.Button("Generate")
    # One row per generated title; row_count grows dynamically.
    out_table = gr.Dataframe(headers=["Title"], row_count=(0, "dynamic"), wrap=True)


    # Wire the button to the generator; api_name exposes it as an
    # /generate endpoint for API clients.
    generate_btn.click(
        fn=generate_titles,
        inputs=[text_in, num_titles, temperature, beams, do_sample],
        outputs=out_table,
        api_name="generate",
    )
|
|
| |
# Module-level alias for hosting platforms that look for an `app` object.
app = demo


# Launch the local Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()
|
|