# teststrealitTlh / app.py
# Last commit ec70cef by TLH01: "Rename t5_small.py to app.py" (verified)
import streamlit as st
import torch
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Streamlit page configuration: browser-tab title and icon.
# (The icon literal was previously mis-encoded as "πŸ”‘"; it is the 🔑 emoji.)
st.set_page_config(page_title="Review Keypoint Extractor", page_icon="🔑")

# Hugging Face Hub checkpoint used by load_model_and_tokenizer() for both the
# tokenizer and the seq2seq model.
model_name = "t5-small"
# Load heavyweight objects once; Streamlit reruns the script on every interaction.
@st.cache_resource
def load_model_and_tokenizer():
    """Build the tokenizer and seq2seq model for ``model_name``.

    Decorated with ``st.cache_resource``, so the returned objects are created
    once and then reused across reruns instead of reloading the weights.

    Returns:
        A ``(tokenizer, model, device)`` tuple. ``model`` has already been moved
        to ``device``, which is CUDA when available and the CPU otherwise.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    seq2seq = seq2seq.to(target)
    return tok, seq2seq, target
# Keypoint generation function
def generate_keypoint(review, max_new_tokens=64):
    """Summarize ``review`` into a short keypoint using the cached T5 model.

    Args:
        review: Product-review text to summarize.
        max_new_tokens: Maximum number of tokens ``model.generate`` may emit.

    Returns:
        A ``(keypoint, elapsed)`` tuple.

        - ``keypoint`` is the decoded, whitespace-stripped model output. It is
          normalized to ``"No key point"`` when the output is empty or is one
          of a few "nothing found" phrases.
        - ``elapsed`` is the inference time in seconds. Model loading happens
          before the timer starts and is not counted.
    """
    tokenizer, model, device = load_model_and_tokenizer()
    # perf_counter is monotonic and high-resolution. time.time() can jump if
    # the system clock is adjusted, which would skew the reported duration.
    start_time = time.perf_counter()

    # T5 selects its task from a text prefix; "summarize: " asks for a summary.
    prompt = f"summarize: {review}"

    # truncation=True cuts the input to the tokenizer's maximum length, so very
    # long reviews do not overflow the model's input size.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    keypoint = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Normalize the various ways of saying "nothing here" to one display string.
    if keypoint.lower() in {"none", "no keypoint", "no key point", "n/a", "na", "", "nothing"}:
        keypoint = "No key point"

    elapsed = time.perf_counter() - start_time
    return keypoint, elapsed
# Streamlit UI
# (Emoji literals below were previously mis-encoded as "πŸ”‘", "βœ…" and "⚠️".)
st.title("🔑 Review Keypoint Extractor")
st.write("Enter a product review below to extract its key points using the T5-Small model.")

# Input field for review
review = st.text_area(
    "Product Review",
    placeholder="e.g., The Jackery power station is lightweight and charges quickly, but the battery life could be longer.",
)

# Streamlit re-executes this script on every interaction; st.button returns
# True only on the rerun triggered by the click.
if st.button("Extract Keypoint"):
    # Reject empty or whitespace-only input before invoking the model.
    if review.strip():
        with st.spinner("Generating keypoint..."):
            keypoint, elapsed = generate_keypoint(review)
        st.success(f"✅ Keypoint generated in {elapsed:.2f} seconds!")
        st.subheader("Results")
        st.write(f"**Review:** {review}")
        st.write(f"**Keypoint:** {keypoint}")
    else:
        st.error("⚠️ Please enter a valid review.")
# Footer: a horizontal rule, then attribution links to the libraries used.
_FOOTER_CREDITS = (
    "Powered by [Hugging Face Transformers](https://huggingface.co/) "
    "and [Streamlit](https://streamlit.io/)"
)
st.markdown("---")
st.markdown(_FOOTER_CREDITS)