File size: 4,127 Bytes
9b020f1
 
580b956
 
 
 
 
 
0b2440d
 
3bc7bf2
9b020f1
3bc7bf2
9b020f1
 
 
3bc7bf2
9b020f1
 
 
 
580b956
 
 
 
 
 
 
 
 
 
 
 
 
fe2ae52
580b956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import sys
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
import math
import torch

# Redirect HOME and every model/data cache location into /tmp.
# NOTE(review): transformers and nltk are imported above, before these variables
# are assigned, so anything those libraries resolved at import time will not see
# them. Confirm the deployment also sets them externally, or pass explicit
# cache/download directories where the libraries are called.
_CACHE_ENV = {
    'HOME': '/tmp',
    'TRANSFORMERS_CACHE': '/tmp/tr_cache',
    'HF_HOME': '/tmp/hf_cache',
    'HF_HUB_CACHE': '/tmp/hub_cache',
    'NLTK_DATA': '/tmp/nltk_data',
}
os.environ.update(_CACHE_ENV)

# Make sure every cache directory exists before anything tries to use it.
cache_paths = ['/tmp/hf_cache', '/tmp/nltk_data', '/tmp/tr_cache', '/tmp/hub_cache']
for path in cache_paths:
    os.makedirs(path, exist_ok=True)


# Hugging Face Hub id of the seq2seq checkpoint that generates the titles.
model_name: str = "dnj0/t5_base_article_sum"
# Number of tokens in each input span fed to the model.
max_input_length: int = 512

st.header("Generate candidate titles for articles")

# Temporary placeholder; it is blanked out once the model has loaded.
st_model_load = st.text('Loading title generator model...')

@st.cache_resource
def load_model():
    """Load the title-generation tokenizer/model and the NLTK sentence-splitter data.

    Cached with ``st.cache_resource`` so the model is built once per server process
    and the same object is shared by every rerun and session. (``st.cache_data``
    serialises the return value and hands back a fresh copy on every call, which
    is wasteful for a model.)

    Returns:
        tuple: ``(tokenizer, model)`` loaded from ``model_name``.
    """
    print("Loading model...")
    # Pass the cache location explicitly: transformers is imported before this
    # file assigns the cache environment variables, so defaults resolved at import
    # time may not reflect them.
    hub_cache = os.environ.get('HF_HUB_CACHE')  # None -> library default
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=hub_cache)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=hub_cache)

    nltk_dir = os.environ.get('NLTK_DATA')  # None -> nltk's default download dir
    # nltk.data.path is built when nltk is imported (before NLTK_DATA is set in
    # this file). Register the directory so sent_tokenize() finds the data
    # downloaded below.
    if nltk_dir and nltk_dir not in nltk.data.path:
        nltk.data.path.append(nltk_dir)
    nltk.download('punkt', download_dir=nltk_dir)
    nltk.download('punkt_tab', download_dir=nltk_dir)
    print("Model loaded!")
    return tokenizer, model

tokenizer, model = load_model()  # cached, so later reruns reuse the same objects

# Loading finished: show a confirmation and blank out the placeholder text.
st.success('Model loaded!')
st_model_load.text("")

with st.sidebar:
    st.header("Model parameters")
    # Each slider is bound to a Session State key, so Streamlit itself keeps
    # st.session_state.num_titles / st.session_state.temperature equal to the
    # widget value. Hand-written on_change callbacks cannot do this by reading
    # the module-level variable: callbacks run before the rerun, while that
    # variable still holds the previous run's value.
    num_titles = st.slider(
        "Number of titles to generate",
        min_value=1, max_value=10, value=1, step=1,
        key="num_titles",
    )
    temperature = st.slider(
        "Temperature",
        min_value=0.1, max_value=1.5, value=0.6, step=0.05,
        key="temperature",
    )
    st.markdown("_High temperature means that results are more random_")
    st.markdown("_High temperature means that results are more random_")

# Session State carries the article text across reruns; generate_title() stores
# the latest input back into it.
if "text" not in st.session_state:
    st.session_state["text"] = ""
st_text_area = st.text_area(
    'Text to generate the title for',
    value=st.session_state["text"],
    height=500,
)

def _span_boundaries(num_tokens, span_len):
    """Return ``[start, end]`` slice bounds covering ``num_tokens`` tokens in ``span_len`` chunks.

    For inputs longer than one span, the spans are spread with a shared overlap
    that is rounded *up*. As a result every span has exactly ``span_len`` tokens,
    none runs past the end of the input, and the slices can be stacked. The price
    is that the last span may stop up to ``num_spans - 2`` tokens before the end.
    An input no longer than ``span_len`` gets the single span ``[0, span_len]``;
    slicing clips it to the real length. Returns ``[]`` for ``num_tokens <= 0``.
    """
    num_spans = math.ceil(num_tokens / span_len)
    if num_spans <= 0:
        return []
    overlap = math.ceil((num_spans * span_len - num_tokens) / max(num_spans - 1, 1))
    step = span_len - overlap
    return [[i * step, i * step + span_len] for i in range(num_spans)]


def generate_title():
    """Generate ``num_titles`` candidate titles for the text in the text area.

    Used as the "Generate title" button's ``on_click`` callback. Steps:

    1. Tokenize the prompt ``"summarize: " + text`` once.
    2. Cut it into overlapping spans of ``max_input_length`` tokens.
    3. Assign spans to the requested titles round-robin.
    4. Sample all of them in one batch at the sidebar ``temperature``.
    5. Keep the first sentence of each non-empty generation.

    Side effects: writes ``st.session_state.text`` (the input) and
    ``st.session_state.titles`` (the result list).
    """
    st.session_state.text = st_text_area

    # Nothing to summarise: clear stale titles instead of sampling from a bare prompt.
    if not st_text_area.strip():
        st.session_state.titles = []
        st.warning("Please enter some text to generate titles for.")
        return

    # tokenize text
    encoded = tokenizer(["summarize: " + st_text_area], return_tensors="pt")
    input_ids = encoded["input_ids"][0]
    attention_mask = encoded["attention_mask"][0]

    # compute span boundaries
    num_tokens = len(input_ids)
    print(f"Input has {num_tokens} tokens")
    spans_boundaries = _span_boundaries(num_tokens, max_input_length)
    print(f"Input has {len(spans_boundaries)} spans")
    print(f"Span boundaries are {spans_boundaries}")
    # Round-robin: title k is generated from span k mod num_spans.
    spans_boundaries_selected = [
        spans_boundaries[k % len(spans_boundaries)] for k in range(num_titles)
    ]
    print(f"Selected span boundaries are {spans_boundaries_selected}")

    # transform input with spans (all selected slices have equal length, so they stack)
    batch = {
        "input_ids": torch.stack([input_ids[s:e] for s, e in spans_boundaries_selected]),
        "attention_mask": torch.stack([attention_mask[s:e] for s, e in spans_boundaries_selected]),
    }

    # compute predictions
    outputs = model.generate(**batch, do_sample=True, temperature=temperature)
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Keep the first sentence of each generation. An empty generation has no
    # sentence (indexing [0] would raise IndexError), so it is skipped rather
    # than producing an empty title.
    predicted_titles = []
    for decoded in decoded_outputs:
        sentences = nltk.sent_tokenize(decoded.strip())
        if sentences:
            predicted_titles.append(sentences[0])

    st.session_state.titles = predicted_titles

# generate title button
st_generate_button = st.button('Generate title', on_click=generate_title)

# title generation labels
if 'titles' not in st.session_state:
    st.session_state.titles = []

if len(st.session_state.titles) > 0:
    with st.container():
        st.subheader("Generated titles")
        for title in st.session_state.titles:
            st.markdown("__" + title + "__")