|
|
import streamlit as st |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import torch |
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load and cache the fine-tuned T5 summarisation model and tokenizer.

    Streamlit re-executes this script on every user interaction, so without
    caching the model would be downloaded/instantiated on each rerun.
    ``st.cache_resource`` makes the load happen once per server process.

    Returns:
        tuple: ``(tokenizer, model)`` for the
        ``majorSeaweed/T5-Conv_summarisation_small`` checkpoint.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained("majorSeaweed/T5-Conv_summarisation_small")
    tokenizer = AutoTokenizer.from_pretrained("majorSeaweed/T5-Conv_summarisation_small")
    return tokenizer, model
|
|
|
|
|
# Runs at import time, before any UI is drawn: instantiates the tokenizer and
# model. NOTE(review): Streamlit re-executes this script on every interaction,
# so this line runs on each rerun.
tokenizer, model = load_model()
|
|
|
|
|
|
|
|
# Page header and one-line description of the app.
# NOTE(review): the original title emoji was mojibake ("๐ง" — a mis-decoded
# UTF-8 sequence rendered as Thai digits); replaced with a brain emoji.
# Confirm the originally intended glyph.
st.title("🧠 Dialogue Summarisation (T5)")
st.markdown(
    "This app generates **non-toxic summaries** of multi-turn conversations using a SFT Tuned `t5-small` model."
)
|
|
|
|
|
|
|
|
# Example text shown inside the empty input box (not submitted as input).
_PLACEHOLDER = (
    "User: Hey, you there?\n"
    "Bot: Yeah, what's up?\n"
    "User: I just wanted to say you're terrible at this."
)

# Multi-line input area where the user pastes the conversation to summarise.
dialogue = st.text_area(
    label="Paste a conversation below (multi-turn dialogue with speaker turns):",
    height=200,
    placeholder=_PLACEHOLDER,
)
|
|
|
|
|
|
|
|
# Handle the "Generate Summary" click: validate input, run beam-search
# generation, and render the decoded summary.
if st.button("Generate Summary"):
    if not dialogue.strip():
        # Guard: empty / whitespace-only input gets a warning, no model call.
        st.warning("Please enter a conversation.")
    else:
        with st.spinner("Generating summary..."):
            # Tokenise the dialogue, truncating to the 512-token encoder limit.
            inputs = tokenizer(dialogue, return_tensors="pt", truncation=True, max_length=512)
            # no_grad guarantees no autograd graph is built during inference,
            # keeping memory flat regardless of the transformers version.
            with torch.no_grad():
                summary_ids = model.generate(
                    **inputs,
                    max_length=60,
                    num_beams=4,
                    early_stopping=True,
                    no_repeat_ngram_size=2,
                )
            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            # NOTE(review): the original subheader emoji was mojibake ("๐");
            # replaced with a memo emoji — confirm the intended glyph.
            st.subheader("📝 Summary")
            st.success(summary)
|
|
|