import math
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk


MODEL_ID = "Ilyakk/t5-summarization"   # fine-tuned T5 checkpoint on the Hugging Face Hub
MAX_INPUT_LEN = 512                    # maximum number of input tokens per window
GEN_MAX_LEN = 64                       # maximum token length of a generated title

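# Make sure the NLTK sentence-tokenizer data is available; recent NLTK releases
# also require the separate "punkt_tab" resource.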
def ensure_nltk():
    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt")
    try:
        nltk.data.find("tokenizers/punkt_tab")
    except LookupError:
        nltk.download("punkt_tab")

ensure_nltk()


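# Load the tokenizer and model once at startup; run on GPU when one is available.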
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

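# Reduce a generated summary to its first sentence so it reads like a title,
# falling back to naive punctuation splitting if NLTK tokenization fails.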
def _first_sentence(text: str) -> str:
    text = (text or "").strip()
    if not text:
        return ""
    try:
        sents = nltk.sent_tokenize(text)
        return sents[0].strip() if sents else text
    except Exception:
        for sep in [".", "!", "?"]:
            if sep in text:
                return text.split(sep)[0].strip()
        return text

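# Produce `num_titles` candidate titles: tokenize the full article, split it into
# overlapping windows that fit the model, and sample one title per requested candidate.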
def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
    num_titles = int(num_titles)  # Gradio sliders can deliver floats; range() below needs an int
    text = (text or "").strip()
    if not text:
        return ["Please enter the article text above."]

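    # Tokenize the entire article without truncation; it is chunked into
    # model-sized windows manually below.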
    enc = tokenizer(["summarize: " + text], return_tensors="pt", truncation=False)
    ids = enc["input_ids"][0]
    mask = enc["attention_mask"][0]
    num_tokens = len(ids)

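    # Split the token sequence into overlapping windows of MAX_INPUT_LEN tokens.
    # The overlap is sized so that the windows jointly cover the whole article,
    # with the final window ending at the last token.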
    num_spans = max(1, math.ceil(num_tokens / MAX_INPUT_LEN))
    overlap = math.ceil((num_spans * MAX_INPUT_LEN - num_tokens) / max(num_spans - 1, 1)) if num_spans > 1 else 0

    spans = []
    start = 0
    for i in range(num_spans):
        b0 = start + MAX_INPUT_LEN * i
        b1 = start + MAX_INPUT_LEN * (i + 1)
        spans.append([max(0, b0), min(num_tokens, b1)])
        start -= overlap

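    # Pick one window per requested title, cycling through the windows when more
    # titles are requested than there are windows.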
    chosen = [spans[i % len(spans)] for i in range(num_titles)]

    batch_ids = [ids[b0:b1] for (b0, b1) in chosen]
    batch_mask = [mask[b0:b1] for (b0, b1) in chosen]
    batch = {
        "input_ids": torch.stack(batch_ids).to(device),
        "attention_mask": torch.stack(batch_mask).to(device),
    }

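    # Sample (do_sample=True, num_beams=1) so repeated runs yield varied titles;
    # the temperature slider controls how adventurous the sampling is.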
    with torch.no_grad():
        outputs = model.generate(
            **batch,
            do_sample=True,
            temperature=float(temperature),
            max_length=GEN_MAX_LEN,
            num_beams=1
        )

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    titles = [_first_sentence(d) for d in decoded]
    return titles

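# Gradio UI: article text plus generation controls in, a list of candidate titles out.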
demo = gr.Interface(
    fn=generate_titles,
    inputs=[
        gr.Textbox(label="Article text", lines=10, placeholder="Paste your article text here"),
        gr.Slider(1, 10, value=3, step=1, label="Number of titles"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature"),
    ],
    outputs=gr.List(label="Generated titles"),
    title="T5 Title Generator",
    description="Generate candidate titles for articles using your fine-tuned T5 model."
)

if __name__ == "__main__":
    demo.launch()