File size: 5,459 Bytes
f4b81bf
 
 
 
 
 
 
 
 
 
021abb6
 
f4b81bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fec1e2
f4b81bf
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import html

import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


@st.cache_resource
def load_model():
    """Build the summarization model and its tokenizer, cached by Streamlit.

    The ``t5-small`` base checkpoint is loaded first and the fine-tuned PEFT
    adapter from ``Lakshan2003/finetuned-t5-xsum`` is attached on top of it;
    the tokenizer is fetched from that same adapter repository.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped T5.
    """
    adapter_repo = "Lakshan2003/finetuned-t5-xsum"

    base = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    model = PeftModel.from_pretrained(base, adapter_repo)
    tok = AutoTokenizer.from_pretrained(adapter_repo)

    return model, tok


def generate_summary(text, model, tokenizer, max_length=128, min_length=30,
                     *, num_beams=4, max_input_length=512):
    """Generate a summary of ``text`` with a seq2seq (T5) model.

    The T5 task prefix ``"summarize: "`` is prepended, the input is truncated
    to ``max_input_length`` tokens, and the first returned sequence from beam
    search is decoded.

    Args:
        text: The text to summarize.
        model: Seq2seq model exposing ``.to()`` and ``.generate()`` (the
            PEFT-wrapped T5 returned by ``load_model``).
        tokenizer: Tokenizer paired with ``model``.
        max_length: Upper bound on the generated summary length, in tokens.
        min_length: Lower bound on the generated summary length, in tokens.
            Values greater than ``max_length`` are clamped to ``max_length``.
        num_beams: Beam width used by ``generate`` (keyword-only).
        max_input_length: Maximum number of input tokens kept after
            truncation (keyword-only).

    Returns:
        str: The decoded summary with special tokens removed.
    """
    # Prefer the GPU when present; the model and its inputs must share a device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # A minimum above the maximum can never be met (generation stops at
    # max_length regardless), so clamp rather than pass contradictory bounds.
    min_length = min(min_length, max_length)

    input_text = "summarize: " + text
    inputs = tokenizer(input_text, return_tensors="pt",
                       max_length=max_input_length, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            length_penalty=2.0,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )

    # A single input text was encoded, so its summary is row 0 of the batch.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def _render_summary(text, summary):
    """Display word-count statistics and the generated summary.

    Args:
        text: The original input text.
        summary: The model-generated summary.
    """
    st.markdown("### 📊 Summary Results")

    # Side-by-side word counts for the input and the output.
    stat_col1, stat_col2 = st.columns(2)
    with stat_col1:
        st.info(f"📄 Original text: {len(text.split())} words")
    with stat_col2:
        st.info(f"✂️ Summarized text: {len(summary.split())} words")

    # The summary is interpolated into raw HTML (unsafe_allow_html), so it
    # must be escaped: model output echoes user input and may contain
    # "<", ">" or "&", which would otherwise be parsed as markup.
    safe_summary = html.escape(summary)

    st.markdown("### ✨ Generated Summary")
    st.markdown(f"""
        <div style="
            padding: 20px;
            border-radius: 10px;
            background-color: #f0f2f6;
            border-left: 5px solid #1E88E5;
        ">
            {safe_summary}
        </div>
    """, unsafe_allow_html=True)


def main():
    """Render the SummarizeAI Pro page and handle summary requests.

    Draws the header, input area, length sliders and generate button. When the
    button is pressed the input is validated, summarized with
    ``generate_summary`` and displayed. The sidebar and footer are drawn on
    every rerun.
    """
    st.set_page_config(
        page_title="SummarizeAI Pro",
        page_icon="✨",
        layout="wide"
    )

    # Custom CSS for the title/subtitle classes used below.
    st.markdown("""
        <style>
        .main-title {
            text-align: center;
            color: #1E88E5;
            font-size: 3rem !important;
            font-weight: 700;
            margin-bottom: 1rem;
        }
        .subtitle {
            text-align: center;
            color: #424242;
            font-size: 1.2rem !important;
            margin-bottom: 2rem;
        }
        </style>
    """, unsafe_allow_html=True)

    # App title and subtitle
    st.markdown("<h1 class='main-title'>✨ SummarizeAI Pro</h1>", unsafe_allow_html=True)
    st.markdown("<p class='subtitle'>Transform lengthy text into concise, meaningful summaries with AI</p>",
                unsafe_allow_html=True)

    # Load model and tokenizer (cached by load_model after the first run).
    with st.spinner("Loading model... (this may take a few moments)"):
        model, tokenizer = load_model()

    # Input text area
    text = st.text_area(
        "📝 Enter your text below:",
        height=200,
        placeholder="Paste your text here and let SummarizeAI Pro work its magic..."
    )

    # Three equal columns: two sliders and the generate button.
    col1, col2, col3 = st.columns([1, 1, 1])

    with col1:
        max_length = st.slider("Maximum summary length", 50, 250, 128)
    with col2:
        min_length = st.slider("Minimum summary length", 10, 100, 30)
    with col3:
        st.markdown("<br>", unsafe_allow_html=True)  # Align button with sliders
        generate_button = st.button("✨ Generate Summary", use_container_width=True)

    if generate_button:
        if not text.strip():
            # Empty or whitespace-only input has nothing to summarize.
            st.warning("⚠️ Please enter some text to summarize.")
        elif min_length > max_length:
            # The slider ranges overlap (10-100 vs 50-250), so this is reachable.
            st.warning("⚠️ Minimum summary length cannot exceed the maximum summary length.")
        else:
            with st.spinner("✨ AI is crafting your summary..."):
                try:
                    summary = generate_summary(text, model, tokenizer,
                                               max_length=max_length,
                                               min_length=min_length)
                except Exception as e:
                    # UI boundary: report generation failures to the user.
                    st.error(f"🚫 An error occurred: {str(e)}")
                else:
                    _render_summary(text, summary)

    # Sidebar with enhanced styling
    st.sidebar.markdown("## 🎯 About SummarizeAI Pro")
    st.sidebar.markdown("""
        SummarizeAI Pro uses advanced AI technology powered by a PEFT-tuned T5 model 
        to generate accurate and concise summaries while preserving the key points 
        of your text.
    """)

    st.sidebar.markdown("## 📖 How to Use")
    st.sidebar.markdown("""
        1. 📝 Paste your text in the input box
        2. 🎚️ Adjust summary length with sliders
        3. 🚀 Click 'Generate Summary'
        4. ✨ Get your AI-powered summary
    """)

    # Footer
    st.markdown("""
        <div style='text-align: center; color: #666; padding: 20px;'>
            <p>Made with ❀️ by Lakshan Cooray</p>
        </div>
    """, unsafe_allow_html=True)


if __name__ == "__main__":
    main()