import math
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
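# App: generate candidate titles for an article with a fine-tuned T5
# summarization checkpoint. Long inputs are split into overlapping
# MAX_INPUT_LEN-token spans so each sampled title can be conditioned on a
# different slice of the text.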
MODEL_ID = "Ilyakk/t5-summarization"
MAX_INPUT_LEN = 512  # encoder span length, in tokens
GEN_MAX_LEN = 64     # maximum length of each generated title
def ensure_nltk():
    """Download the NLTK sentence-tokenizer data on first run."""
    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt")
    try:
        nltk.data.find("tokenizers/punkt_tab")
    except LookupError:
        nltk.download("punkt_tab")
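
# Fetch the tokenizer data at import time so the first request does not block.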
ensure_nltk()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disable dropout
def _first_sentence(text: str) -> str:
    """Return the first sentence of `text`; fall back to naive splitting."""
    text = (text or "").strip()
    if not text:
        return ""
    try:
        sents = nltk.sent_tokenize(text)
        return sents[0].strip() if sents else text
    except Exception:
        # NLTK data unavailable: split on the first sentence terminator.
        for sep in [".", "!", "?"]:
            if sep in text:
                return text.split(sep)[0].strip()
        return text
def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
    """Generate `num_titles` candidate titles by sampling from the model."""
    text = (text or "").strip()
    if not text:
        return ["Enter the article text above."]
    # Tokenize the whole article; long inputs are split into overlapping
    # spans of at most MAX_INPUT_LEN tokens below instead of being truncated.
    enc = tokenizer(["summarize: " + text], return_tensors="pt", truncation=False)
    ids = enc["input_ids"][0]
    mask = enc["attention_mask"][0]
    num_tokens = len(ids)
    num_spans = max(1, math.ceil(num_tokens / MAX_INPUT_LEN))
    # Overlap between consecutive spans, chosen so that num_spans equal-length
    # windows tile the input exactly.
    overlap = math.ceil((num_spans * MAX_INPUT_LEN - num_tokens) / max(num_spans - 1, 1)) if num_spans > 1 else 0
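    # Worked example: num_tokens = 900 with MAX_INPUT_LEN = 512 gives
    # num_spans = 2 and overlap = ceil((1024 - 900) / 1) = 124, so the loop
    # below yields the equal-length windows [0, 512] and [388, 900].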
    spans = []
    start = 0
    for i in range(num_spans):
        b0 = start + MAX_INPUT_LEN * i
        b1 = start + MAX_INPUT_LEN * (i + 1)
        spans.append([max(0, b0), min(num_tokens, b1)])
        start -= overlap
    # Cycle through the spans so every requested title is conditioned on
    # some window of the article.
    chosen = [spans[i % len(spans)] for i in range(num_titles)]
    batch_ids = [ids[b0:b1] for (b0, b1) in chosen]
    batch_mask = [mask[b0:b1] for (b0, b1) in chosen]
    batch = {
        "input_ids": torch.stack(batch_ids).to(device),
        "attention_mask": torch.stack(batch_mask).to(device),
    }
    with torch.no_grad():
        outputs = model.generate(
            **batch,
            do_sample=True,
            temperature=float(temperature),
            max_length=GEN_MAX_LEN,
            num_beams=1,
        )
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    titles = [_first_sentence(d) for d in decoded]
    return titles
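
# gr.List renders the returned list of strings as a single-column table.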
demo = gr.Interface(
    fn=generate_titles,
    inputs=[
        gr.Textbox(label="Article text", lines=10, placeholder="Paste your article text here"),
        gr.Slider(1, 10, value=3, step=1, label="Number of titles"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature"),
    ],
    outputs=gr.List(label="Generated titles"),
    title="T5 Title Generator",
    description="Generate candidate titles for articles using your fine-tuned T5 model.",
)
if __name__ == "__main__":
    demo.launch()
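
# Quick sanity check without the UI (hypothetical input; assumes the file is
# named app.py, the usual Spaces convention):
#   python -c "from app import generate_titles; print(generate_titles('Some article text.'))"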