| import gradio as gr |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import nltk |
| import torch |
|
|
# Hugging Face Hub checkpoint used for title generation (a T5-base model,
# judging by its name).
model_name = "Timur1984/t5-base-title-generation"

# Load the tokenizer and the seq2seq model once at import time; both are
# reused by every request.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Sentence-splitter data needed by nltk.sent_tokenize.  Both resource names
# are fetched, presumably so this works across NLTK releases that expect
# either "punkt" or "punkt_tab".
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource)
|
|
def generate_titles(text, num_titles, temperature):
    """Sample candidate titles for an article using the loaded seq2seq model.

    The article is prefixed with ``"summarize: "``, tokenized with truncation
    to 512 tokens, and ``num_titles`` sequences are sampled from ``model``.
    Only the first sentence of each decoded sequence is kept as a title.

    Args:
        text: The article body.
        num_titles: Number of sequences to sample.  Gradio sliders can
            deliver numeric values as floats, so this is coerced to ``int``
            (``num_return_sequences`` requires an integer).
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The titles joined with newlines.  Returns an empty string when
        ``text`` is empty or whitespace-only, and omits any sampled
        sequence that decodes to empty text.
    """
    # Nothing to title: don't run the model on the bare task prefix.
    if not text or not text.strip():
        return ""

    inputs = tokenizer(
        ["summarize: " + text],
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    outputs = model.generate(
        **inputs,
        do_sample=True,
        temperature=float(temperature),
        num_return_sequences=int(num_titles),
    )

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # sent_tokenize("") returns [], so indexing [0] on an empty decode would
    # raise IndexError; skip empty candidates instead.
    titles = [
        nltk.sent_tokenize(candidate)[0]
        for candidate in (raw.strip() for raw in decoded)
        if candidate
    ]

    return "\n".join(titles)
|
|
|
|
# Gradio UI wiring.  The order of the input components must match the
# positional parameters of generate_titles: (text, num_titles, temperature).
_article_box = gr.Textbox(label="Article text", lines=15)
_count_slider = gr.Slider(
    minimum=1,
    maximum=10,
    value=5,
    step=1,
    label="Number of titles",
)
_temperature_slider = gr.Slider(
    minimum=0.1,
    maximum=1.5,
    value=0.7,
    step=0.05,
    label="Temperature",
)

interface = gr.Interface(
    fn=generate_titles,
    inputs=[_article_box, _count_slider, _temperature_slider],
    outputs="text",
    title="Article Title Generator",
)

# Start the web server (runs at import time, as in the original script).
interface.launch()