# Source: Hugging Face Space file (commit 24911b1, "Upload 2 files", uploader Ilyakk)
import math
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
# Hugging Face model repo holding the fine-tuned T5 summarization checkpoint.
MODEL_ID = "Ilyakk/t5-summarization"
# Encoder window size in tokens; longer inputs are split into chunks of this size.
MAX_INPUT_LEN = 512
# Upper bound (in tokens) on each generated title.
GEN_MAX_LEN = 64
def ensure_nltk():
    """Fetch the NLTK sentence-tokenizer data files if they are missing locally."""
    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            # Not installed yet: download into the default NLTK data directory.
            nltk.download(resource)
# One-time setup at import: make sure tokenizer data exists, then load the
# model/tokenizer pair and move the model to the best available device.
ensure_nltk()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
# Prefer GPU when present; generation batches are moved to this device later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def _first_sentence(text: str) -> str:
    """Return the first sentence of *text*, stripped; "" for empty input.

    Falls back to naive punctuation splitting when NLTK tokenization fails
    (e.g. missing data files), and returns the whole text when no sentence
    boundary can be found at all.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return ""
    try:
        sentences = nltk.sent_tokenize(cleaned)
    except Exception:
        # Crude fallback: cut at the first terminal punctuation mark found.
        for mark in (".", "!", "?"):
            if mark in cleaned:
                return cleaned.split(mark)[0].strip()
        return cleaned
    return sentences[0].strip() if sentences else cleaned
def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
text = (text or "").strip()
if not text:
return ["Введите текст статьи выше."]
enc = tokenizer(["summarize: " + text], return_tensors="pt", truncation=False)
ids = enc["input_ids"][0]
mask = enc["attention_mask"][0]
num_tokens = len(ids)
num_spans = max(1, math.ceil(num_tokens / MAX_INPUT_LEN))
overlap = math.ceil((num_spans * MAX_INPUT_LEN - num_tokens) / max(num_spans - 1, 1)) if num_spans > 1 else 0
spans = []
start = 0
for i in range(num_spans):
b0 = start + MAX_INPUT_LEN * i
b1 = start + MAX_INPUT_LEN * (i + 1)
spans.append([max(0, b0), min(num_tokens, b1)])
start -= overlap
chosen = [spans[i % len(spans)] for i in range(num_titles)]
batch_ids = [ids[b0:b1] for (b0, b1) in chosen]
batch_mask = [mask[b0:b1] for (b0, b1) in chosen]
batch = {
"input_ids": torch.stack(batch_ids).to(device),
"attention_mask": torch.stack(batch_mask).to(device),
}
with torch.no_grad():
outputs = model.generate(
**batch,
do_sample=True,
temperature=float(temperature),
max_length=GEN_MAX_LEN,
num_beams=1
)
decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
titles = [_first_sentence(d) for d in decoded]
return titles
# Gradio UI: the three inputs map positionally onto
# generate_titles(text, num_titles, temperature).
demo = gr.Interface(
    fn=generate_titles,
    inputs=[
        gr.Textbox(label="Article text", lines=10, placeholder="Paste your article text here"),
        gr.Slider(1, 10, value=3, step=1, label="Number of titles"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature"),
    ],
    # NOTE(review): gr.List renders the returned list of strings as rows;
    # confirm this component exists in the pinned gradio version.
    outputs=gr.List(label="Generated titles"),
    title="T5 Title Generator",
    description="Generate candidate titles for articles using your fine-tuned T5 model."
)
# Launch the app only when run as a script (Spaces also calls launch this way).
if __name__ == "__main__":
    demo.launch()