Ilyakk committed on
Commit
8f7b5bc
·
verified ·
1 Parent(s): 94aa154

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +64 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import nltk
4
+ import torch
5
+ import math
6
+
7
# Hugging Face Hub checkpoint used for both the tokenizer and the seq2seq model.
model_name = "AGIvan/t5-base-title-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# nltk.sent_tokenize needs the Punkt sentence-tokenizer data at runtime.
# Older NLTK releases load it from the "punkt" resource, while recent releases
# look for "punkt_tab" instead. requirements.txt does not pin nltk, so fetch
# both; nltk.download() only reports (does not raise) when a resource name is
# unknown to the installed version.
nltk.download("punkt")
nltk.download("punkt_tab")
11
+
12
def _span_boundaries(num_tokens, max_input_length):
    """Return ``[start, end]`` token windows that cover ``num_tokens`` tokens.

    Every window is ``max_input_length`` wide. When more than one window is
    needed, neighbouring windows overlap so that the last one still ends
    inside the sequence.
    """
    num_spans = math.ceil(num_tokens / max_input_length)
    # Spread the unused capacity as overlap between neighbouring windows.
    # Rounding up keeps the final window within the sequence, so every slice
    # has the same length and the slices can be stacked into one batch.
    overlap = math.ceil(
        (num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1)
    )
    stride = max_input_length - overlap
    return [[i * stride, i * stride + max_input_length] for i in range(num_spans)]


def generate_titles(text, num_titles=3, temperature=0.7):
    """Generate candidate titles for ``text`` with the module-level model.

    The text is prefixed with ``"summarize: "``, tokenized, and split into
    overlapping windows of at most 512 tokens. One title is sampled per
    selected window; windows are reused cyclically when ``num_titles``
    exceeds the number of windows.

    Args:
        text: Article text to generate titles for.
        num_titles: Number of titles to return. Coerced to ``int`` because
            UI sliders may deliver the value as a float.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        A list of title strings (the first sentence of each decoded output).
        Empty when ``text`` is blank or ``num_titles`` is less than 1.
    """
    num_titles = int(num_titles)
    if not text or not text.strip() or num_titles < 1:
        # Nothing to summarize; also avoids torch.stack() on an empty list.
        return []

    encoded = tokenizer(["summarize: " + text], return_tensors="pt")
    input_ids = encoded["input_ids"][0]
    attention_mask = encoded["attention_mask"][0]

    max_input_length = 512
    boundaries = _span_boundaries(len(input_ids), max_input_length)
    if not boundaries:
        return []

    # Walk the windows round-robin until num_titles windows are selected.
    # NOTE(review): only the first window contains the "summarize: " prefix
    # tokens, since the prefix is tokenized once with the whole text.
    selected = [boundaries[i % len(boundaries)] for i in range(num_titles)]

    batch = {
        "input_ids": torch.stack([input_ids[lo:hi] for lo, hi in selected]),
        "attention_mask": torch.stack(
            [attention_mask[lo:hi] for lo, hi in selected]
        ),
    }

    # float(): a slider value such as 1 may arrive as an int, and some
    # transformers versions appear to reject a non-float temperature.
    outputs = model.generate(
        **batch, do_sample=True, temperature=float(temperature)
    )
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    predicted_titles = []
    for decoded_output in decoded_outputs:
        stripped = decoded_output.strip()
        sentences = nltk.sent_tokenize(stripped)
        # An empty generation yields no sentences; keep it as "" rather than
        # raising IndexError on sentences[0].
        predicted_titles.append(sentences[0] if sentences else stripped)

    return predicted_titles
49
+
50
# Gradio interface: build the input/output components first, then wire them
# to generate_titles. Inputs are passed positionally, in the order listed.
_article_box = gr.Textbox(
    label="Article text",
    lines=10,
    placeholder="Paste your article text here",
)
_title_count_slider = gr.Slider(1, 10, value=3, step=1, label="Number of titles")
_temperature_slider = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature")
_titles_output = gr.List(label="Generated titles")

demo = gr.Interface(
    fn=generate_titles,
    inputs=[_article_box, _title_count_slider, _temperature_slider],
    outputs=_titles_output,
    title="📰 T5 Title Generator",
    description="Generate candidate titles for articles using a fine-tuned T5 model.",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ nltk
4
+ gradio
5
+ sentencepiece